Fix inaccurate std::intrinsics::simd documentation #137828

Merged
merged 2 commits into from
Mar 2, 2025
70 changes: 34 additions & 36 deletions library/core/src/intrinsics/simd.rs
@@ -26,28 +26,28 @@ pub unsafe fn simd_extract<T, U>(_x: T, _idx: u32) -> U;

/// Adds two simd vectors elementwise.
///
-/// `T` must be a vector of integer or floating point primitive types.
+/// `T` must be a vector of integers or floats.
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn simd_add<T>(_x: T, _y: T) -> T;

/// Subtracts `rhs` from `lhs` elementwise.
///
-/// `T` must be a vector of integer or floating point primitive types.
+/// `T` must be a vector of integers or floats.
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn simd_sub<T>(_lhs: T, _rhs: T) -> T;

/// Multiplies two simd vectors elementwise.
///
-/// `T` must be a vector of integer or floating point primitive types.
+/// `T` must be a vector of integers or floats.
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn simd_mul<T>(_x: T, _y: T) -> T;

/// Divides `lhs` by `rhs` elementwise.
///
-/// `T` must be a vector of integer or floating point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// # Safety
/// For integers, `rhs` must not contain any zero elements.
@@ -58,7 +58,7 @@ pub unsafe fn simd_div<T>(_lhs: T, _rhs: T) -> T;

/// Returns remainder of two vectors elementwise.
///
-/// `T` must be a vector of integer or floating point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// # Safety
/// For integers, `rhs` must not contain any zero elements.
@@ -71,7 +71,7 @@ pub unsafe fn simd_rem<T>(_lhs: T, _rhs: T) -> T;
///
/// Shifts `lhs` left by `rhs`, shifting in sign bits for signed types.
///
-/// `T` must be a vector of integer primitive types.
+/// `T` must be a vector of integers.
///
/// # Safety
///
@@ -82,7 +82,7 @@ pub unsafe fn simd_shl<T>(_lhs: T, _rhs: T) -> T;

/// Shifts vector right elementwise, with UB on overflow.
///
-/// `T` must be a vector of integer primitive types.
+/// `T` must be a vector of integers.
///
/// Shifts `lhs` right by `rhs`, shifting in sign bits for signed types.
///
@@ -95,29 +95,28 @@ pub unsafe fn simd_shr<T>(_lhs: T, _rhs: T) -> T;

/// "Ands" vectors elementwise.
///
-/// `T` must be a vector of integer primitive types.
+/// `T` must be a vector of integers.
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn simd_and<T>(_x: T, _y: T) -> T;

/// "Ors" vectors elementwise.
///
-/// `T` must be a vector of integer primitive types.
+/// `T` must be a vector of integers.
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn simd_or<T>(_x: T, _y: T) -> T;

/// "Exclusive ors" vectors elementwise.
///
-/// `T` must be a vector of integer primitive types.
+/// `T` must be a vector of integers.
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn simd_xor<T>(_x: T, _y: T) -> T;

/// Numerically casts a vector, elementwise.
///
-/// `T` and `U` must be vectors of integer or floating point primitive types, and must have the
-/// same length.
+/// `T` and `U` must be vectors of integers or floats, and must have the same length.
///
/// When casting floats to integers, the result is truncated. Out-of-bounds result lead to UB.
/// When casting integers to floats, the result is rounded.
@@ -138,8 +137,7 @@ pub unsafe fn simd_cast<T, U>(_x: T) -> U;

/// Numerically casts a vector, elementwise.
///
-/// `T` and `U` be a vectors of integer or floating point primitive types, and must have the
-/// same length.
+/// `T` and `U` be a vectors of integers or floats, and must have the same length.
///
/// Like `simd_cast`, but saturates float-to-integer conversions (NaN becomes 0).
/// This matches regular `as` and is always safe.
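
To make the contrast with `simd_cast` concrete, the following is a scalar sketch built on Rust's ordinary `as` operator, which the doc comment above says `simd_as` matches. The helper `as_lanes` is hypothetical: it works on a plain array rather than a real SIMD vector and does not call the intrinsic.

```rust
/// Hypothetical scalar stand-in for `simd_as` on four `f32` lanes.
/// Each lane goes through the regular `as` cast, which truncates toward
/// zero, saturates out-of-range values, and maps NaN to 0.
fn as_lanes(x: [f32; 4]) -> [i8; 4] {
    x.map(|lane| lane as i8)
}
```

For the input `[1.9, -1.9, 300.0, f32::NAN]` this model yields `[1, -1, 127, 0]`. Per the `simd_cast` doc comment, the out-of-range `300.0` lane would be UB with that intrinsic instead.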
@@ -153,7 +151,7 @@ pub unsafe fn simd_as<T, U>(_x: T) -> U;

/// Negates a vector elementwise.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// Rust panics for `-<int>::Min` due to overflow, but it is not UB with this intrinsic.
#[rustc_intrinsic]
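
The overflow remark for `simd_neg` can be illustrated with scalar integer methods. This sketch assumes the overflowing lane wraps around; the doc comment only states that it is not UB, so the wrapping result is an assumption. `neg_lanes` is a hypothetical helper on a plain array, not the intrinsic.

```rust
/// Hypothetical scalar model of `simd_neg` on integer lanes, assuming
/// that negating the minimum value wraps instead of panicking.
fn neg_lanes(x: [i32; 4]) -> [i32; 4] {
    x.map(i32::wrapping_neg)
}

/// The checked scalar negation reports the overflow case as `None`.
fn checked_neg_min() -> Option<i32> {
    i32::MIN.checked_neg()
}
```

Under that assumption, `neg_lanes([1, -2, 0, i32::MIN])` gives `[-1, 2, 0, i32::MIN]`, while `checked_neg_min()` returns `None`.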
@@ -187,7 +185,7 @@ pub unsafe fn simd_fmax<T>(_x: T, _y: T) -> T;

/// Tests elementwise equality of two vectors.
///
-/// `T` must be a vector of floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be a vector of integers with the same number of elements and element size as `T`.
///
@@ -198,7 +196,7 @@ pub unsafe fn simd_eq<T, U>(_x: T, _y: T) -> U;

/// Tests elementwise inequality equality of two vectors.
///
-/// `T` must be a vector of floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be a vector of integers with the same number of elements and element size as `T`.
///
@@ -209,7 +207,7 @@ pub unsafe fn simd_ne<T, U>(_x: T, _y: T) -> U;

/// Tests if `x` is less than `y`, elementwise.
///
-/// `T` must be a vector of floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be a vector of integers with the same number of elements and element size as `T`.
///
@@ -220,7 +218,7 @@ pub unsafe fn simd_lt<T, U>(_x: T, _y: T) -> U;

/// Tests if `x` is less than or equal to `y`, elementwise.
///
-/// `T` must be a vector of floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be a vector of integers with the same number of elements and element size as `T`.
///
@@ -231,7 +229,7 @@ pub unsafe fn simd_le<T, U>(_x: T, _y: T) -> U;

/// Tests if `x` is greater than `y`, elementwise.
///
-/// `T` must be a vector of floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be a vector of integers with the same number of elements and element size as `T`.
///
@@ -242,7 +240,7 @@ pub unsafe fn simd_gt<T, U>(_x: T, _y: T) -> U;

/// Tests if `x` is greater than or equal to `y`, elementwise.
///
-/// `T` must be a vector of floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be a vector of integers with the same number of elements and element size as `T`.
///
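
The relationship between `T` and `U` in the comparison intrinsics can be sketched with plain arrays. This is a hypothetical scalar model of `simd_lt` for `f32` lanes with an `i32` result (same lane count, same element size). It assumes the result uses the `!0` / `0` lane convention that the mask arguments elsewhere in this file use; the hunks shown here do not state the return values.

```rust
/// Hypothetical scalar model of `simd_lt` with `T = [f32; 4]` and
/// `U = [i32; 4]`. Assumes a true lane is `!0` (all bits set) and a
/// false lane is `0`.
fn lt_lanes(x: [f32; 4], y: [f32; 4]) -> [i32; 4] {
    let mut out = [0i32; 4];
    for i in 0..4 {
        // Any comparison involving NaN is false, so that lane stays 0.
        out[i] = if x[i] < y[i] { !0 } else { 0 };
    }
    out
}
```

With `x = [1.0, 2.0, NaN, -1.0]` and `y = [2.0, 1.0, 0.0, -1.0]`, only the first lane is set in this model.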
@@ -273,7 +271,7 @@ pub unsafe fn simd_shuffle<T, U, V>(_x: T, _y: T, _idx: U) -> V;
///
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
///
-/// `V` must be a vector of integers with the same length as `T` (but any element size).
+/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
Member

Maybe outside the scope of this PR, but this would need to be checked in codegen and emit an ICE

Contributor Author

@folkertdev folkertdev Feb 28, 2025

what exactly? you get a monomorphization error when these rules are violated. They are not the prettiest errors, but they are not ICEs either. Unless I'm missing something. e.g. https://godbolt.org/z/f5j5re5z8

Member

Yep, that's a post-mono error.

Member

Oh, good. I misunderstood and thought you meant the intrinsic was permitting invalid behavior.

Contributor Author

yeah this is just bringing the docs up to date with the behavior that we already have and enforce.

Member

Looking at this in the LLVM backend I don't even see such a widening cast. This code converts the masks to i1 vectors which LLVM seems to use for them, and it does that with lshr and trunc. The lshr is entirely unnecessary for correctness as we require the input to be all-1 or all-0, but

    /// The rust simd semantics are that each element should either consist of all ones or all zeroes,
    /// but this information is not available to llvm. Truncating the vector effectively uses the lowest bit,
    /// but codegen for several targets is better if we consider the highest bit by shifting.

But I can't find anything that would go wrong with unsigned integers.
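
The point about `lshr` and `trunc` can be checked with a scalar model. The sketch below is not the backend code; the function names are hypothetical and each one models a single 32-bit lane or a four-lane array. It shows that for lanes that are all-ones or all-zeros, taking the highest bit and taking the lowest bit agree, and that the signed or unsigned type of the lane does not change the outcome.

```rust
/// Scalar model of the `lshr` + `trunc` lowering for one 32-bit mask lane:
/// shift the highest bit down to position 0, then keep only that bit.
fn lane_to_bit_high(lane: u32) -> bool {
    (lane >> 31) == 1
}

/// Scalar model of a plain `trunc`: keep only the lowest bit.
fn lane_to_bit_low(lane: u32) -> bool {
    (lane & 1) == 1
}

/// Mask with signed lanes; `as u32` reinterprets the bits unchanged.
fn signed_mask_to_bits(mask: [i32; 4]) -> [bool; 4] {
    mask.map(|lane| lane_to_bit_high(lane as u32))
}

/// Mask with unsigned lanes, pushed through the same bitwise steps.
fn unsigned_mask_to_bits(mask: [u32; 4]) -> [bool; 4] {
    mask.map(lane_to_bit_high)
}
```

In this model `signed_mask_to_bits([-1, 0, -1, 0])` and `unsigned_mask_to_bits([u32::MAX, 0, u32::MAX, 0])` produce the same `[true, false, true, false]`.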

Contributor Author

I personally agree: for an intrinsic (that normal users should never have to worry about), we could either demand that the lane width matches the other arguments, or that the mask uses sign extension no matter the signedness of the argument.

From what I can tell the codegen backends already handle this, because e.g. in LLVM (and from what I can see, also Cranelift and GCC) the integers are just a bunch of bits.

The counter-argument was that performing sign extension on what rust believes is a vector with unsigned types would violate type safety.
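
For context on why signedness comes up when a mask lane is narrower than the data lane, here is a small sketch using plain Rust `as` casts. It is only an illustration of how the language widens signed versus unsigned integers; it makes no claim about what any codegen backend does, and both helper names are hypothetical.

```rust
/// Widens an 8-bit signed lane to 32 bits; `as` sign-extends here.
fn widen_signed(lane: i8) -> i32 {
    lane as i32
}

/// Widens an 8-bit unsigned lane to 32 bits; `as` zero-extends here.
fn widen_unsigned(lane: u8) -> u32 {
    lane as u32
}
```

An all-ones `i8` lane stays all-ones after widening (`!0`), whereas an all-ones `u8` lane becomes `0x0000_00FF`, which is no longer a valid `!0` mask value.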

Member

> The counter-argument was that performing sign extension on what rust believes is a vector with unsigned types would violate type safety.

I don't know what you mean by that. There's no sign extension happening, as far as I can tell. Also, all the backends have to do is implement the intended mask semantics, which is described in a bitwise way. If they use sign extension as part of that, that's completely fine. There's no type safety violation here.

Contributor Author

I agree, the backends support it, and the intrinsic could (and in my opinion should) support it too. I'm just relaying the response I got when I suggested exactly that https://rust-lang.zulipchat.com/#narrow/channel/257879-project-portable-simd/topic/add.20.60simd_max.60.20and.20.60simd_min.60/near/502647748

///
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, read the pointer.
/// Otherwise if the corresponding value in `mask` is `0`, return the corresponding value from
@@ -294,7 +292,7 @@ pub unsafe fn simd_gather<T, U, V>(_val: T, _ptr: U, _mask: V) -> T;
///
/// `U` must be a vector of pointers to the element type of `T`, with the same length as `T`.
///
-/// `V` must be a vector of integers with the same length as `T` (but any element size).
+/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
///
/// For each pointer in `ptr`, if the corresponding value in `mask` is `!0`, write the
/// corresponding value in `val` to the pointer.
@@ -318,7 +316,7 @@ pub unsafe fn simd_scatter<T, U, V>(_val: T, _ptr: U, _mask: V);
///
/// `U` must be a pointer to the element type of `T`
///
-/// `V` must be a vector of integers with the same length as `T` (but any element size).
+/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
///
/// For each element, if the corresponding value in `mask` is `!0`, read the corresponding
/// pointer offset from `ptr`.
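
The masked-load rule above can be sketched as a scalar loop. The function below is a hypothetical model on plain arrays, with a slice standing in for the memory behind `ptr`. It assumes that lanes whose mask value is `0` keep the corresponding lane of `val`; that part is not visible in this hunk.

```rust
/// Hypothetical scalar model of `simd_masked_load` for four `f32` lanes.
/// `mem` stands in for the memory that `ptr` points to.
fn masked_load_model(mask: [i32; 4], mem: &[f32; 4], val: [f32; 4]) -> [f32; 4] {
    let mut out = val;
    for i in 0..4 {
        // `!0` (all bits set, i.e. -1 for a signed lane) enables the read.
        if mask[i] == !0 {
            out[i] = mem[i];
        }
        // Assumption: a `0` lane leaves the value from `val` in place.
    }
    out
}
```

For example, with mask `[!0, 0, !0, 0]`, memory `[1.0, 2.0, 3.0, 4.0]` and `val` filled with `9.0`, the model returns `[1.0, 9.0, 3.0, 9.0]`.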
@@ -341,7 +339,7 @@ pub unsafe fn simd_masked_load<V, U, T>(_mask: V, _ptr: U, _val: T) -> T;
///
/// `U` must be a pointer to the element type of `T`
///
-/// `V` must be a vector of integers with the same length as `T` (but any element size).
+/// `V` must be a vector of signed integers with the same length as `T` (but any element size).
///
/// For each element, if the corresponding value in `mask` is `!0`, write the corresponding
/// value in `val` to the pointer offset from `ptr`.
@@ -375,7 +373,7 @@ pub unsafe fn simd_saturating_sub<T>(_lhs: T, _rhs: T) -> T;

/// Adds elements within a vector from left to right.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
///
@@ -387,7 +385,7 @@ pub unsafe fn simd_reduce_add_ordered<T, U>(_x: T, _y: U) -> U;
/// Adds elements within a vector in arbitrary order. May also be re-associated with
/// unordered additions on the inputs/outputs.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
#[rustc_intrinsic]
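
Why the ordered and unordered reductions are documented separately can be seen with floats, where addition is not associative. The sketch below is a hypothetical scalar model of the ordered variant. It assumes the second argument `y` is the starting accumulator, which the hunk above does not spell out.

```rust
/// Hypothetical scalar model of `simd_reduce_add_ordered` for three
/// `f64` lanes: a strict left-to-right fold, assuming `y` is the
/// initial accumulator.
fn reduce_add_ordered_model(x: [f64; 3], y: f64) -> f64 {
    x.iter().fold(y, |acc, lane| acc + *lane)
}
```

Folding `[1e20, -1e20, 1.0]` from `0.0` gives `1.0`, while folding the same values in the order `[1e20, 1.0, -1e20]` gives `0.0`, because `1.0` is lost when added to `1e20`. An unordered reduction is allowed to pick either grouping.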
@@ -396,7 +394,7 @@ pub unsafe fn simd_reduce_add_unordered<T, U>(_x: T) -> U;

/// Multiplies elements within a vector from left to right.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
///
@@ -408,7 +406,7 @@ pub unsafe fn simd_reduce_mul_ordered<T, U>(_x: T, _y: U) -> U;
/// Multiplies elements within a vector in arbitrary order. May also be re-associated with
/// unordered additions on the inputs/outputs.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
#[rustc_intrinsic]
@@ -437,7 +435,7 @@ pub unsafe fn simd_reduce_any<T>(_x: T) -> bool;

/// Returns the maximum element of a vector.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
///
@@ -448,7 +446,7 @@ pub unsafe fn simd_reduce_max<T, U>(_x: T) -> U;

/// Returns the minimum element of a vector.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
///
@@ -459,7 +457,7 @@ pub unsafe fn simd_reduce_min<T, U>(_x: T) -> U;

/// Logical "ands" all elements together.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
#[rustc_intrinsic]
@@ -468,7 +466,7 @@ pub unsafe fn simd_reduce_and<T, U>(_x: T) -> U;

/// Logical "ors" all elements together.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
#[rustc_intrinsic]
@@ -477,7 +475,7 @@ pub unsafe fn simd_reduce_or<T, U>(_x: T) -> U;

/// Logical "exclusive ors" all elements together.
///
-/// `T` must be a vector of integer or floating-point primitive types.
+/// `T` must be a vector of integers or floats.
///
/// `U` must be the element type of `T`.
#[rustc_intrinsic]
@@ -523,9 +521,9 @@ pub unsafe fn simd_bitmask<T, U>(_x: T) -> U;

/// Selects elements from a mask.
///
-/// `M` must be an integer vector.
+/// `T` must be a vector.
///
-/// `T` must be a vector with the same number of elements as `M`.
+/// `M` must be a signed integer vector with the same length as `T` (but any element size).
///
/// For each element, if the corresponding value in `mask` is `!0`, select the element from
/// `if_true`. If the corresponding value in `mask` is `0`, select the element from