Skip to content

Commit 78fad84

Browse files
committed
Prototype using const generic for simd_shuffle IDX array
1 parent 078eb11 commit 78fad84

File tree

10 files changed

+282
-43
lines changed

10 files changed

+282
-43
lines changed

Diff for: compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs

+49-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fn report_simd_type_validation_error(
2121
pub(super) fn codegen_simd_intrinsic_call<'tcx>(
2222
fx: &mut FunctionCx<'_, '_, 'tcx>,
2323
intrinsic: Symbol,
24-
_args: GenericArgsRef<'tcx>,
24+
generic_args: GenericArgsRef<'tcx>,
2525
args: &[mir::Operand<'tcx>],
2626
ret: CPlace<'tcx>,
2727
target: BasicBlock,
@@ -117,6 +117,54 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
117117
});
118118
}
119119

120+
// simd_shuffle_generic<T, U, const I: &[u32]>(x: T, y: T) -> U
121+
sym::simd_shuffle_generic => {
122+
let [x, y] = args else {
123+
bug!("wrong number of args for intrinsic {intrinsic}");
124+
};
125+
let x = codegen_operand(fx, x);
126+
let y = codegen_operand(fx, y);
127+
128+
if !x.layout().ty.is_simd() {
129+
report_simd_type_validation_error(fx, intrinsic, span, x.layout().ty);
130+
return;
131+
}
132+
133+
let idx = generic_args[2]
134+
.expect_const()
135+
.eval(fx.tcx, ty::ParamEnv::reveal_all(), Some(span))
136+
.unwrap()
137+
.unwrap_branch();
138+
139+
assert_eq!(x.layout(), y.layout());
140+
let layout = x.layout();
141+
142+
let (lane_count, lane_ty) = layout.ty.simd_size_and_type(fx.tcx);
143+
let (ret_lane_count, ret_lane_ty) = ret.layout().ty.simd_size_and_type(fx.tcx);
144+
145+
assert_eq!(lane_ty, ret_lane_ty);
146+
assert_eq!(idx.len() as u64, ret_lane_count);
147+
148+
let total_len = lane_count * 2;
149+
150+
let indexes =
151+
idx.iter().map(|idx| idx.unwrap_leaf().try_to_u16().unwrap()).collect::<Vec<u16>>();
152+
153+
for &idx in &indexes {
154+
assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len);
155+
}
156+
157+
for (out_idx, in_idx) in indexes.into_iter().enumerate() {
158+
let in_lane = if u64::from(in_idx) < lane_count {
159+
x.value_lane(fx, in_idx.into())
160+
} else {
161+
y.value_lane(fx, u64::from(in_idx) - lane_count)
162+
};
163+
let out_lane = ret.place_lane(fx, u64::try_from(out_idx).unwrap());
164+
out_lane.write_cvalue(fx, in_lane);
165+
}
166+
}
167+
120168
// simd_shuffle<T, I, U>(x: T, y: T, idx: I) -> U
121169
sym::simd_shuffle => {
122170
let (x, y, idx) = match args {

Diff for: compiler/rustc_codegen_llvm/src/intrinsic.rs

+55-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use rustc_codegen_ssa::mir::place::PlaceRef;
1515
use rustc_codegen_ssa::traits::*;
1616
use rustc_hir as hir;
1717
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, LayoutOf};
18-
use rustc_middle::ty::{self, Ty};
18+
use rustc_middle::ty::{self, GenericArgsRef, Ty};
1919
use rustc_middle::{bug, span_bug};
2020
use rustc_span::{sym, symbol::kw, Span, Symbol};
2121
use rustc_target::abi::{self, Align, HasDataLayout, Primitive};
@@ -376,7 +376,9 @@ impl<'ll, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'_, 'll, 'tcx> {
376376
}
377377

378378
_ if name.as_str().starts_with("simd_") => {
379-
match generic_simd_intrinsic(self, name, callee_ty, args, ret_ty, llret_ty, span) {
379+
match generic_simd_intrinsic(
380+
self, name, callee_ty, fn_args, args, ret_ty, llret_ty, span,
381+
) {
380382
Ok(llval) => llval,
381383
Err(()) => return,
382384
}
@@ -911,6 +913,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
911913
bx: &mut Builder<'_, 'll, 'tcx>,
912914
name: Symbol,
913915
callee_ty: Ty<'tcx>,
916+
fn_args: GenericArgsRef<'tcx>,
914917
args: &[OperandRef<'tcx, &'ll Value>],
915918
ret_ty: Ty<'tcx>,
916919
llret_ty: &'ll Type,
@@ -1030,6 +1033,56 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
10301033
));
10311034
}
10321035

1036+
if name == sym::simd_shuffle_generic {
1037+
let idx = fn_args[2]
1038+
.expect_const()
1039+
.eval(tcx, ty::ParamEnv::reveal_all(), Some(span))
1040+
.unwrap()
1041+
.unwrap_branch();
1042+
let n = idx.len() as u64;
1043+
1044+
require_simd!(ret_ty, InvalidMonomorphization::SimdReturn { span, name, ty: ret_ty });
1045+
let (out_len, out_ty) = ret_ty.simd_size_and_type(bx.tcx());
1046+
require!(
1047+
out_len == n,
1048+
InvalidMonomorphization::ReturnLength { span, name, in_len: n, ret_ty, out_len }
1049+
);
1050+
require!(
1051+
in_elem == out_ty,
1052+
InvalidMonomorphization::ReturnElement { span, name, in_elem, in_ty, ret_ty, out_ty }
1053+
);
1054+
1055+
let total_len = in_len * 2;
1056+
1057+
let indices: Option<Vec<_>> = idx
1058+
.iter()
1059+
.enumerate()
1060+
.map(|(arg_idx, val)| {
1061+
let idx = val.unwrap_leaf().try_to_i32().unwrap();
1062+
if idx >= i32::try_from(total_len).unwrap() {
1063+
bx.sess().emit_err(InvalidMonomorphization::ShuffleIndexOutOfBounds {
1064+
span,
1065+
name,
1066+
arg_idx: arg_idx as u64,
1067+
total_len: total_len.into(),
1068+
});
1069+
None
1070+
} else {
1071+
Some(bx.const_i32(idx))
1072+
}
1073+
})
1074+
.collect();
1075+
let Some(indices) = indices else {
1076+
return Ok(bx.const_null(llret_ty));
1077+
};
1078+
1079+
return Ok(bx.shuffle_vector(
1080+
args[0].immediate(),
1081+
args[1].immediate(),
1082+
bx.const_vector(&indices),
1083+
));
1084+
}
1085+
10331086
if name == sym::simd_shuffle {
10341087
// Make sure this is actually an array, since typeck only checks the length-suffixed
10351088
// version of this intrinsic.

Diff for: compiler/rustc_hir_analysis/src/check/intrinsic.rs

+23-21
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ fn equate_intrinsic_type<'tcx>(
2020
it: &hir::ForeignItem<'_>,
2121
n_tps: usize,
2222
n_lts: usize,
23+
n_cts: usize,
2324
sig: ty::PolyFnSig<'tcx>,
2425
) {
2526
let (own_counts, span) = match &it.kind {
@@ -51,7 +52,7 @@ fn equate_intrinsic_type<'tcx>(
5152

5253
if gen_count_ok(own_counts.lifetimes, n_lts, "lifetime")
5354
&& gen_count_ok(own_counts.types, n_tps, "type")
54-
&& gen_count_ok(own_counts.consts, 0, "const")
55+
&& gen_count_ok(own_counts.consts, n_cts, "const")
5556
{
5657
let fty = Ty::new_fn_ptr(tcx, sig);
5758
let it_def_id = it.owner_id.def_id;
@@ -492,7 +493,7 @@ pub fn check_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>) {
492493
};
493494
let sig = tcx.mk_fn_sig(inputs, output, false, unsafety, Abi::RustIntrinsic);
494495
let sig = ty::Binder::bind_with_vars(sig, bound_vars);
495-
equate_intrinsic_type(tcx, it, n_tps, n_lts, sig)
496+
equate_intrinsic_type(tcx, it, n_tps, n_lts, 0, sig)
496497
}
497498

498499
/// Type-check `extern "platform-intrinsic" { ... }` functions.
@@ -504,9 +505,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
504505

505506
let name = it.ident.name;
506507

507-
let (n_tps, inputs, output) = match name {
508+
let (n_tps, n_cts, inputs, output) = match name {
508509
sym::simd_eq | sym::simd_ne | sym::simd_lt | sym::simd_le | sym::simd_gt | sym::simd_ge => {
509-
(2, vec![param(0), param(0)], param(1))
510+
(2, 0, vec![param(0), param(0)], param(1))
510511
}
511512
sym::simd_add
512513
| sym::simd_sub
@@ -522,8 +523,8 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
522523
| sym::simd_fmax
523524
| sym::simd_fpow
524525
| sym::simd_saturating_add
525-
| sym::simd_saturating_sub => (1, vec![param(0), param(0)], param(0)),
526-
sym::simd_arith_offset => (2, vec![param(0), param(1)], param(0)),
526+
| sym::simd_saturating_sub => (1, 0, vec![param(0), param(0)], param(0)),
527+
sym::simd_arith_offset => (2, 0, vec![param(0), param(1)], param(0)),
527528
sym::simd_neg
528529
| sym::simd_bswap
529530
| sym::simd_bitreverse
@@ -541,25 +542,25 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
541542
| sym::simd_ceil
542543
| sym::simd_floor
543544
| sym::simd_round
544-
| sym::simd_trunc => (1, vec![param(0)], param(0)),
545-
sym::simd_fpowi => (1, vec![param(0), tcx.types.i32], param(0)),
546-
sym::simd_fma => (1, vec![param(0), param(0), param(0)], param(0)),
547-
sym::simd_gather => (3, vec![param(0), param(1), param(2)], param(0)),
548-
sym::simd_scatter => (3, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
549-
sym::simd_insert => (2, vec![param(0), tcx.types.u32, param(1)], param(0)),
550-
sym::simd_extract => (2, vec![param(0), tcx.types.u32], param(1)),
545+
| sym::simd_trunc => (1, 0, vec![param(0)], param(0)),
546+
sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
547+
sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
548+
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
549+
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
550+
sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
551+
sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),
551552
sym::simd_cast
552553
| sym::simd_as
553554
| sym::simd_cast_ptr
554555
| sym::simd_expose_addr
555-
| sym::simd_from_exposed_addr => (2, vec![param(0)], param(1)),
556-
sym::simd_bitmask => (2, vec![param(0)], param(1)),
556+
| sym::simd_from_exposed_addr => (2, 0, vec![param(0)], param(1)),
557+
sym::simd_bitmask => (2, 0, vec![param(0)], param(1)),
557558
sym::simd_select | sym::simd_select_bitmask => {
558-
(2, vec![param(0), param(1), param(1)], param(1))
559+
(2, 0, vec![param(0), param(1), param(1)], param(1))
559560
}
560-
sym::simd_reduce_all | sym::simd_reduce_any => (1, vec![param(0)], tcx.types.bool),
561+
sym::simd_reduce_all | sym::simd_reduce_any => (1, 0, vec![param(0)], tcx.types.bool),
561562
sym::simd_reduce_add_ordered | sym::simd_reduce_mul_ordered => {
562-
(2, vec![param(0), param(1)], param(1))
563+
(2, 0, vec![param(0), param(1)], param(1))
563564
}
564565
sym::simd_reduce_add_unordered
565566
| sym::simd_reduce_mul_unordered
@@ -569,8 +570,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
569570
| sym::simd_reduce_min
570571
| sym::simd_reduce_max
571572
| sym::simd_reduce_min_nanless
572-
| sym::simd_reduce_max_nanless => (2, vec![param(0)], param(1)),
573-
sym::simd_shuffle => (3, vec![param(0), param(0), param(1)], param(2)),
573+
| sym::simd_reduce_max_nanless => (2, 0, vec![param(0)], param(1)),
574+
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
575+
sym::simd_shuffle_generic => (2, 1, vec![param(0), param(0)], param(1)),
574576
_ => {
575577
let msg = format!("unrecognized platform-specific intrinsic function: `{name}`");
576578
tcx.sess.struct_span_err(it.span, msg).emit();
@@ -580,5 +582,5 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
580582

581583
let sig = tcx.mk_fn_sig(inputs, output, false, hir::Unsafety::Unsafe, Abi::PlatformIntrinsic);
582584
let sig = ty::Binder::dummy(sig);
583-
equate_intrinsic_type(tcx, it, n_tps, 0, sig)
585+
equate_intrinsic_type(tcx, it, n_tps, 0, n_cts, sig)
584586
}

Diff for: compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,7 @@ symbols! {
14651465
simd_shl,
14661466
simd_shr,
14671467
simd_shuffle,
1468+
simd_shuffle_generic,
14681469
simd_sub,
14691470
simd_trunc,
14701471
simd_xor,

Diff for: src/tools/miri/src/shims/intrinsics/mod.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
6060
}
6161

6262
// The rest jumps to `ret` immediately.
63-
this.emulate_intrinsic_by_name(intrinsic_name, args, dest)?;
63+
this.emulate_intrinsic_by_name(intrinsic_name, instance.args, args, dest)?;
6464

6565
trace!("{:?}", this.dump_place(dest));
6666
this.go_to_block(ret);
@@ -71,6 +71,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
7171
fn emulate_intrinsic_by_name(
7272
&mut self,
7373
intrinsic_name: &str,
74+
generic_args: ty::GenericArgsRef<'tcx>,
7475
args: &[OpTy<'tcx, Provenance>],
7576
dest: &PlaceTy<'tcx, Provenance>,
7677
) -> InterpResult<'tcx> {
@@ -80,7 +81,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
8081
return this.emulate_atomic_intrinsic(name, args, dest);
8182
}
8283
if let Some(name) = intrinsic_name.strip_prefix("simd_") {
83-
return this.emulate_simd_intrinsic(name, args, dest);
84+
return this.emulate_simd_intrinsic(name, generic_args, args, dest);
8485
}
8586

8687
match intrinsic_name {

Diff for: src/tools/miri/src/shims/intrinsics/simd.rs

+33
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
1212
fn emulate_simd_intrinsic(
1313
&mut self,
1414
intrinsic_name: &str,
15+
generic_args: ty::GenericArgsRef<'tcx>,
1516
args: &[OpTy<'tcx, Provenance>],
1617
dest: &PlaceTy<'tcx, Provenance>,
1718
) -> InterpResult<'tcx> {
@@ -490,6 +491,38 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
490491
this.write_immediate(val, &dest)?;
491492
}
492493
}
494+
"shuffle_generic" => {
495+
let [left, right] = check_arg_count(args)?;
496+
let (left, left_len) = this.operand_to_simd(left)?;
497+
let (right, right_len) = this.operand_to_simd(right)?;
498+
let (dest, dest_len) = this.place_to_simd(dest)?;
499+
500+
let index = generic_args[2].expect_const().eval(*this.tcx, this.param_env(), Some(this.tcx.span)).unwrap().unwrap_branch();
501+
let index_len = index.len();
502+
503+
assert_eq!(left_len, right_len);
504+
assert_eq!(index_len as u64, dest_len);
505+
506+
for i in 0..dest_len {
507+
let src_index: u64 = index[i as usize].unwrap_leaf()
508+
.try_to_u32().unwrap()
509+
.into();
510+
let dest = this.project_index(&dest, i)?;
511+
512+
let val = if src_index < left_len {
513+
this.read_immediate(&this.project_index(&left, src_index)?)?
514+
} else if src_index < left_len.checked_add(right_len).unwrap() {
515+
let right_idx = src_index.checked_sub(left_len).unwrap();
516+
this.read_immediate(&this.project_index(&right, right_idx)?)?
517+
} else {
518+
span_bug!(
519+
this.cur_span(),
520+
"simd_shuffle index {src_index} is out of bounds for 2 vectors of size {left_len}",
521+
);
522+
};
523+
this.write_immediate(*val, &dest)?;
524+
}
525+
}
493526
"shuffle" => {
494527
let [left, right, index] = check_arg_count(args)?;
495528
let (left, left_len) = this.operand_to_simd(left)?;

Diff for: tests/ui/simd/intrinsic/generic-elements.rs

+27-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// build-fail
22

3-
#![feature(repr_simd, platform_intrinsics, rustc_attrs)]
3+
#![feature(repr_simd, platform_intrinsics, rustc_attrs, adt_const_params)]
4+
#![allow(incomplete_features)]
45

56
#[repr(simd)]
67
#[derive(Copy, Clone)]
@@ -35,6 +36,7 @@ extern "platform-intrinsic" {
3536
fn simd_extract<T, E>(x: T, idx: u32) -> E;
3637

3738
fn simd_shuffle<T, I, U>(x: T, y: T, idx: I) -> U;
39+
fn simd_shuffle_generic<T, U, const IDX: &'static [u32]>(x: T, y: T) -> U;
3840
}
3941

4042
fn main() {
@@ -71,5 +73,29 @@ fn main() {
7173
//~^ ERROR expected return type of length 4, found `i32x8` with length 8
7274
simd_shuffle::<_, _, i32x2>(x, x, IDX8);
7375
//~^ ERROR expected return type of length 8, found `i32x2` with length 2
76+
77+
const I2: &[u32] = &[0; 2];
78+
simd_shuffle_generic::<i32, i32, I2>(0, 0);
79+
//~^ ERROR expected SIMD input type, found non-SIMD `i32`
80+
const I4: &[u32] = &[0; 4];
81+
simd_shuffle_generic::<i32, i32, I4>(0, 0);
82+
//~^ ERROR expected SIMD input type, found non-SIMD `i32`
83+
const I8: &[u32] = &[0; 8];
84+
simd_shuffle_generic::<i32, i32, I8>(0, 0);
85+
//~^ ERROR expected SIMD input type, found non-SIMD `i32`
86+
87+
simd_shuffle_generic::<_, f32x2, I2>(x, x);
88+
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x2` with element type `f32`
89+
simd_shuffle_generic::<_, f32x4, I4>(x, x);
90+
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x4` with element type `f32`
91+
simd_shuffle_generic::<_, f32x8, I8>(x, x);
92+
//~^ ERROR element type `i32` (element of input `i32x4`), found `f32x8` with element type `f32`
93+
94+
simd_shuffle_generic::<_, i32x8, I2>(x, x);
95+
//~^ ERROR expected return type of length 2, found `i32x8` with length 8
96+
simd_shuffle_generic::<_, i32x8, I4>(x, x);
97+
//~^ ERROR expected return type of length 4, found `i32x8` with length 8
98+
simd_shuffle_generic::<_, i32x2, I8>(x, x);
99+
//~^ ERROR expected return type of length 8, found `i32x2` with length 2
74100
}
75101
}

0 commit comments

Comments
 (0)