
Commit 1d7574a

Unrolled build for rust-lang#117953
Rollup merge of rust-lang#117953 - farnoy:masked-load-store, r=workingjubilee

Add more SIMD platform-intrinsics

- [x] simd_masked_load
  - [x] LLVM codegen - llvm.masked.load
  - [x] cranelift codegen - implemented but untested
- [ ] simd_masked_store
  - [x] LLVM codegen - llvm.masked.store
  - [ ] cranelift codegen

Also added a run-pass test exercising both intrinsics, plus additional build-fail and check-fail tests covering validation for both.
2 parents: 1dfb228 + 97ae509 · commit 1d7574a

File tree: 11 files changed, +594 -1 lines changed


compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs (+51 -1)
@@ -1,5 +1,6 @@
 //! Codegen `extern "platform-intrinsic"` intrinsics.
 
+use cranelift_codegen::ir::immediates::Offset32;
 use rustc_middle::ty::GenericArgsRef;
 use rustc_span::Symbol;
 use rustc_target::abi::Endian;
@@ -1008,8 +1009,57 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
             }
         }
 
+        sym::simd_masked_load => {
+            intrinsic_args!(fx, args => (mask, ptr, val); intrinsic);
+
+            let (val_lane_count, val_lane_ty) = val.layout().ty.simd_size_and_type(fx.tcx);
+            let (mask_lane_count, _mask_lane_ty) = mask.layout().ty.simd_size_and_type(fx.tcx);
+            let (ret_lane_count, ret_lane_ty) = ret.layout().ty.simd_size_and_type(fx.tcx);
+            assert_eq!(val_lane_count, mask_lane_count);
+            assert_eq!(val_lane_count, ret_lane_count);
+
+            let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap();
+            let ret_lane_layout = fx.layout_of(ret_lane_ty);
+            let ptr_val = ptr.load_scalar(fx);
+
+            for lane_idx in 0..ret_lane_count {
+                let val_lane = val.value_lane(fx, lane_idx).load_scalar(fx);
+                let mask_lane = mask.value_lane(fx, lane_idx).load_scalar(fx);
+
+                let if_enabled = fx.bcx.create_block();
+                let if_disabled = fx.bcx.create_block();
+                let next = fx.bcx.create_block();
+                let res_lane = fx.bcx.append_block_param(next, lane_clif_ty);
+
+                fx.bcx.ins().brif(mask_lane, if_enabled, &[], if_disabled, &[]);
+                fx.bcx.seal_block(if_enabled);
+                fx.bcx.seal_block(if_disabled);
+
+                fx.bcx.switch_to_block(if_enabled);
+                let offset = lane_idx as i32 * lane_clif_ty.bytes() as i32;
+                let res = fx.bcx.ins().load(
+                    lane_clif_ty,
+                    MemFlags::trusted(),
+                    ptr_val,
+                    Offset32::new(offset),
+                );
+                fx.bcx.ins().jump(next, &[res]);
+
+                fx.bcx.switch_to_block(if_disabled);
+                fx.bcx.ins().jump(next, &[val_lane]);
+
+                fx.bcx.seal_block(next);
+                fx.bcx.switch_to_block(next);
+
+                fx.bcx.ins().nop();
+
+                ret.place_lane(fx, lane_idx)
+                    .write_cvalue(fx, CValue::by_val(res_lane, ret_lane_layout));
+            }
+        }
+
         sym::simd_scatter => {
-            intrinsic_args!(fx, args => (val, ptr, mask); intrinsic);
+            intrinsic_args!(fx, args => (mask, ptr, val); intrinsic);
 
             let (val_lane_count, _val_lane_ty) = val.layout().ty.simd_size_and_type(fx.tcx);
             let (ptr_lane_count, _ptr_lane_ty) = ptr.layout().ty.simd_size_and_type(fx.tcx);
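
As a reading aid, not part of the commit: the scalarized cranelift lowering above is a per-lane select between a load and the passthrough value. A plain-Rust sketch of what it computes, assuming an `i32` mask (note that `brif` treats a lane as enabled whenever its mask element is non-zero, whereas the LLVM path below keeps only the low bit):

```rust
// Illustrative model of the cranelift lowering, not compiler code.
// Enabled lanes load from `ptr` at their lane offset; disabled lanes keep
// the passthrough value and their addresses are never dereferenced.
unsafe fn masked_load_model<T: Copy, const N: usize>(
    mask: [i32; N],
    ptr: *const T,
    passthrough: [T; N],
) -> [T; N] {
    let mut out = passthrough;
    for lane in 0..N {
        if mask[lane] != 0 {
            // mirrors `brif(mask_lane, if_enabled, ..., if_disabled, ...)`
            out[lane] = *ptr.add(lane);
        }
    }
    out
}
```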

compiler/rustc_codegen_llvm/src/intrinsic.rs (+192)
@@ -1492,6 +1492,198 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         return Ok(v);
     }
 
+    if name == sym::simd_masked_load {
+        // simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
+        // * N: number of elements in the input vectors
+        // * T: type of the element to load
+        // * M: any integer width is supported, will be truncated to i1
+        // Loads contiguous elements from memory behind `pointer`, but only for
+        // those lanes whose `mask` bit is enabled.
+        // The memory addresses corresponding to the “off” lanes are not accessed.
+
+        // The element type of the "mask" argument must be a signed integer type of any width
+        let mask_ty = in_ty;
+        let (mask_len, mask_elem) = (in_len, in_elem);
+
+        // The second argument must be a pointer matching the element type
+        let pointer_ty = arg_tys[1];
+
+        // The last argument is a passthrough vector providing values for disabled lanes
+        let values_ty = arg_tys[2];
+        let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
+
+        require_simd!(ret_ty, SimdReturn);
+
+        // Of the same length:
+        require!(
+            values_len == mask_len,
+            InvalidMonomorphization::ThirdArgumentLength {
+                span,
+                name,
+                in_len: mask_len,
+                in_ty: mask_ty,
+                arg_ty: values_ty,
+                out_len: values_len
+            }
+        );
+
+        // The return type must match the last argument type
+        require!(
+            ret_ty == values_ty,
+            InvalidMonomorphization::ExpectedReturnType { span, name, in_ty: values_ty, ret_ty }
+        );
+
+        require!(
+            matches!(
+                pointer_ty.kind(),
+                ty::RawPtr(p) if p.ty == values_elem && p.ty.kind() == values_elem.kind()
+            ),
+            InvalidMonomorphization::ExpectedElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                second_arg: pointer_ty,
+                in_elem: values_elem,
+                in_ty: values_ty,
+                mutability: ExpectedPointerMutability::Not,
+            }
+        );
+
+        require!(
+            matches!(mask_elem.kind(), ty::Int(_)),
+            InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                third_arg: mask_ty,
+            }
+        );
+
+        // Alignment of T, must be a constant integer value:
+        let alignment_ty = bx.type_i32();
+        let alignment = bx.const_i32(bx.align_of(values_ty).bytes() as i32);
+
+        // Truncate the mask vector to a vector of i1s:
+        let (mask, mask_ty) = {
+            let i1 = bx.type_i1();
+            let i1xn = bx.type_vector(i1, mask_len);
+            (bx.trunc(args[0].immediate(), i1xn), i1xn)
+        };
+
+        let llvm_pointer = bx.type_ptr();
+
+        // Type of the vector of elements:
+        let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);
+        let llvm_elem_vec_str = llvm_vector_str(bx, values_elem, values_len);
+
+        let llvm_intrinsic = format!("llvm.masked.load.{llvm_elem_vec_str}.p0");
+        let fn_ty = bx
+            .type_func(&[llvm_pointer, alignment_ty, mask_ty, llvm_elem_vec_ty], llvm_elem_vec_ty);
+        let f = bx.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
+        let v = bx.call(
+            fn_ty,
+            None,
+            None,
+            f,
+            &[args[1].immediate(), alignment, mask, args[2].immediate()],
+            None,
+        );
+        return Ok(v);
+    }
+
+    if name == sym::simd_masked_store {
+        // simd_masked_store(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
+        // * N: number of elements in the input vectors
+        // * T: type of the element to store
+        // * M: any integer width is supported, will be truncated to i1
+        // Stores contiguous elements to memory behind `pointer`, but only for
+        // those lanes whose `mask` bit is enabled.
+        // The memory addresses corresponding to the “off” lanes are not accessed.
+
+        // The element type of the "mask" argument must be a signed integer type of any width
+        let mask_ty = in_ty;
+        let (mask_len, mask_elem) = (in_len, in_elem);
+
+        // The second argument must be a pointer matching the element type
+        let pointer_ty = arg_tys[1];
+
+        // The last argument specifies the values to store to memory
+        let values_ty = arg_tys[2];
+        let (values_len, values_elem) = require_simd!(values_ty, SimdThird);
+
+        // Of the same length:
+        require!(
+            values_len == mask_len,
+            InvalidMonomorphization::ThirdArgumentLength {
+                span,
+                name,
+                in_len: mask_len,
+                in_ty: mask_ty,
+                arg_ty: values_ty,
+                out_len: values_len
+            }
+        );
+
+        // The second argument must be a mutable pointer type matching the element type
+        require!(
+            matches!(
+                pointer_ty.kind(),
+                ty::RawPtr(p) if p.ty == values_elem && p.ty.kind() == values_elem.kind() && p.mutbl.is_mut()
+            ),
+            InvalidMonomorphization::ExpectedElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                second_arg: pointer_ty,
+                in_elem: values_elem,
+                in_ty: values_ty,
+                mutability: ExpectedPointerMutability::Mut,
+            }
+        );
+
+        require!(
+            matches!(mask_elem.kind(), ty::Int(_)),
+            InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: values_elem,
+                third_arg: mask_ty,
+            }
+        );
+
+        // Alignment of T, must be a constant integer value:
+        let alignment_ty = bx.type_i32();
+        let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
+
+        // Truncate the mask vector to a vector of i1s:
+        let (mask, mask_ty) = {
+            let i1 = bx.type_i1();
+            let i1xn = bx.type_vector(i1, in_len);
+            (bx.trunc(args[0].immediate(), i1xn), i1xn)
+        };
+
+        let ret_t = bx.type_void();
+
+        let llvm_pointer = bx.type_ptr();
+
+        // Type of the vector of elements:
+        let llvm_elem_vec_ty = llvm_vector_ty(bx, values_elem, values_len);
+        let llvm_elem_vec_str = llvm_vector_str(bx, values_elem, values_len);
+
+        let llvm_intrinsic = format!("llvm.masked.store.{llvm_elem_vec_str}.p0");
+        let fn_ty = bx.type_func(&[llvm_elem_vec_ty, llvm_pointer, alignment_ty, mask_ty], ret_t);
+        let f = bx.declare_cfn(&llvm_intrinsic, llvm::UnnamedAddr::No, fn_ty);
+        let v = bx.call(
+            fn_ty,
+            None,
+            None,
+            f,
+            &[args[2].immediate(), args[1].immediate(), alignment, mask],
+            None,
+        );
+        return Ok(v);
+    }
+
     if name == sym::simd_scatter {
         // simd_scatter(values: <N x T>, pointers: <N x *mut T>,
         //              mask: <N x i{M}>) -> ()
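
Again as a reading aid: on the LLVM path both intrinsics first truncate the mask to `<N x i1>`, so only bit 0 of each mask element selects a lane. A plain-Rust model of the resulting `llvm.masked.store` semantics (illustrative only; `masked_load_model` above is the load-side counterpart):

```rust
// Illustrative model of llvm.masked.store after the i1 truncation:
// lane i writes values[i] to ptr.add(i) iff bit 0 of mask[i] is set;
// disabled lanes perform no memory access at all.
unsafe fn masked_store_model<T: Copy, const N: usize>(
    mask: [i32; N],
    ptr: *mut T,
    values: [T; N],
) {
    for lane in 0..N {
        if mask[lane] & 1 != 0 {
            *ptr.add(lane) = values[lane];
        }
    }
}
```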

compiler/rustc_hir_analysis/src/check/intrinsic.rs (+2)

@@ -521,6 +521,8 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
         sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
         sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
         sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
+        sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(2)),
+        sym::simd_masked_store => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
         sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
         sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
         sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),
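
Each `(3, 0, ...)` entry registers an intrinsic with three generic type parameters and no const parameters; `simd_masked_load` returns the type of its third argument while `simd_masked_store` returns unit. On the caller side this corresponds to declarations like those used by the new tests below:

```rust
#![feature(repr_simd, platform_intrinsics)]

extern "platform-intrinsic" {
    // mask: M, pointer: P, passthrough/source vector: T
    fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
    fn simd_masked_store<M, P, T>(mask: M, pointer: P, values: T) -> ();
}
```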

compiler/rustc_span/src/symbol.rs (+2)

@@ -1516,6 +1516,8 @@ symbols! {
         simd_insert,
         simd_le,
         simd_lt,
+        simd_masked_load,
+        simd_masked_store,
         simd_mul,
         simd_ne,
         simd_neg,
New file (+34): codegen test for simd_masked_load

@@ -0,0 +1,34 @@
+// compile-flags: -C no-prepopulate-passes
+
+#![crate_type = "lib"]
+
+#![feature(repr_simd, platform_intrinsics)]
+#![allow(non_camel_case_types)]
+
+#[repr(simd)]
+#[derive(Copy, Clone, PartialEq, Debug)]
+pub struct Vec2<T>(pub T, pub T);
+
+#[repr(simd)]
+#[derive(Copy, Clone, PartialEq, Debug)]
+pub struct Vec4<T>(pub T, pub T, pub T, pub T);
+
+extern "platform-intrinsic" {
+    fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
+}
+
+// CHECK-LABEL: @load_f32x2
+#[no_mangle]
+pub unsafe fn load_f32x2(mask: Vec2<i32>, pointer: *const f32,
+                         values: Vec2<f32>) -> Vec2<f32> {
+    // CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x float> {{.*}})
+    simd_masked_load(mask, pointer, values)
+}
+
+// CHECK-LABEL: @load_pf32x4
+#[no_mangle]
+pub unsafe fn load_pf32x4(mask: Vec4<i32>, pointer: *const *const f32,
+                          values: Vec4<*const f32>) -> Vec4<*const f32> {
+    // CHECK: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> {{.*}}, <4 x ptr> {{.*}})
+    simd_masked_load(mask, pointer, values)
+}
New file (+32): codegen test for simd_masked_store

@@ -0,0 +1,32 @@
+// compile-flags: -C no-prepopulate-passes
+
+#![crate_type = "lib"]
+
+#![feature(repr_simd, platform_intrinsics)]
+#![allow(non_camel_case_types)]
+
+#[repr(simd)]
+#[derive(Copy, Clone, PartialEq, Debug)]
+pub struct Vec2<T>(pub T, pub T);
+
+#[repr(simd)]
+#[derive(Copy, Clone, PartialEq, Debug)]
+pub struct Vec4<T>(pub T, pub T, pub T, pub T);
+
+extern "platform-intrinsic" {
+    fn simd_masked_store<M, P, T>(mask: M, pointer: P, values: T) -> ();
+}
+
+// CHECK-LABEL: @store_f32x2
+#[no_mangle]
+pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>) {
+    // CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 {{.*}}, <2 x i1> {{.*}})
+    simd_masked_store(mask, pointer, values)
+}
+
+// CHECK-LABEL: @store_pf32x4
+#[no_mangle]
+pub unsafe fn store_pf32x4(mask: Vec4<i32>, pointer: *mut *const f32, values: Vec4<*const f32>) {
+    // CHECK: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <2 x i1> {{.*}})
+    simd_masked_store(mask, pointer, values)
+}
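
The commit message mentions a run-pass test exercising both intrinsics; as a standalone sketch (not that test), here is how the two compose, reusing the declarations from the codegen tests above. The vector types, mask, and lane values are invented for illustration:

```rust
#![feature(repr_simd, platform_intrinsics)]

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
struct i32x4(i32, i32, i32, i32);

#[repr(simd)]
#[derive(Copy, Clone, PartialEq, Debug)]
struct f32x4(f32, f32, f32, f32);

extern "platform-intrinsic" {
    fn simd_masked_load<M, P, T>(mask: M, pointer: P, values: T) -> T;
    fn simd_masked_store<M, P, T>(mask: M, pointer: P, values: T) -> ();
}

fn main() {
    let src = [1.0_f32, 2.0, 3.0, 4.0];
    let mut dst = [0.0_f32; 4];
    // Lanes 0-2 enabled, lane 3 disabled: the load keeps the passthrough
    // 9.0 in lane 3, and the store leaves dst[3] untouched.
    let mask = i32x4(-1, -1, -1, 0);
    unsafe {
        let v = simd_masked_load(mask, src.as_ptr(), f32x4(9.0, 9.0, 9.0, 9.0));
        assert_eq!(v, f32x4(1.0, 2.0, 3.0, 9.0));
        simd_masked_store(mask, dst.as_mut_ptr(), v);
    }
    assert_eq!(dst, [1.0, 2.0, 3.0, 0.0]);
}
```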
