Skip to content

Commit 53385eb

Browse files
authored
128-bit load/store primitives for GC'd arrays (#2247)
1 parent bfb237d commit 53385eb

26 files changed

+2534
-1576
lines changed

middle_end/flambda2/docs/simd.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
# SIMD in flambda2
3+
4+
## Vector types
5+
6+
<!-- CR mslater: more docs -->
7+
8+
## Intrinsics
9+
10+
<!-- CR mslater: more docs -->
11+
12+
## Load/Store
13+
14+
Unlike intrinsics, SIMD loads and stores are represented as flambda2-visible primitives.
15+
16+
- `string`, `bytes`: `caml_{string,bytes}_{getu128,setu128}{u}` map to `MOVUPD`, an unaligned 128-bit vector load/store.
17+
The primitives can operate on all 128-bit vector types.
18+
The safe primitives raise `Invalid_argument` if any part of the vector is not within the array bounds; the `u` suffix omits this check.
19+
Aligned load/store is not available because these values may be moved by the GC.
20+
21+
- `bigstring`: `caml_bigstring_{get,set}{u}128{u}` map to `MOVAPD` or `MOVUPD`.
22+
The primitives can operate on all 128-bit vector types.
23+
The prefix `u` indicates an unaligned operation (`MOVUPD`), and the suffix `u` omits bounds checking.
24+
Aligned load/store is available because bigstrings are allocated by `malloc`.
25+
26+
- `float array`, `floatarray`, `float# array`: the corresponding primitives take an index in `float`s and are required to operate on `float64x2`s.
27+
The address is computed as `array + index * 8`; the safe primitives bounds-check against `0, length - 1`.
28+
The primitives on `float array` are only available when the float array optimization is enabled.
29+
Aligned load/store is not available because these values may be moved by the GC.
30+
31+
- `nativeint# array`, `int64# array`: the corresponding primitives take an index in `nativeint`s/`int64`s and are required to operate on `int64x2`s.
32+
The address is computed as `array + index * 8`; the safe primitives bounds-check against `0, length - 1`.
33+
The primitives on `nativeint# array` are only available in 64-bit mode.
34+
Aligned load/store is not available because these values may be moved by the GC.
35+
36+
- `int32# array`: the corresponding primitives take an index in `int32`s and are required to operate on `int32x4`s.
37+
The address is computed as `array + index * 4`; the safe primitives bounds-check against `0, length - 3`.
38+
Aligned load/store is not available because these values may be moved by the GC.
39+
40+
- `%immediate64 array`: the corresponding primitives take an index in immediates, and are required to operate on `int64x2`s.
41+
The primitives can operate on all `('a : immediate64) array`s and are only available in 64-bit mode.
42+
The address is computed as `array + index * 8`; the safe primitives bounds-check against `0, length - 1`.
43+
Aligned load/store is not available because these values may be moved by the GC.
44+
Load/store directly reads/writes two 64-bit **tagged** values. The "safe" primitives do not check for proper tagging,
45+
so are not to be exposed to users as "safe."

middle_end/flambda2/from_lambda/closure_conversion.ml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -848,9 +848,15 @@ let close_primitive acc env ~let_bound_ids_with_kinds named
848848
| Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _ | Pbytes_set_128 _
849849
| Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _
850850
| Pbigstring_load_128 _ | Pbigstring_set_16 _ | Pbigstring_set_32 _
851-
| Pbigstring_set_64 _ | Pbigstring_set_128 _ | Pctconst _ | Pbswap16
852-
| Pbbswap _ | Pint_as_pointer _ | Popaque _ | Pprobe_is_enabled _
853-
| Pobj_dup | Pobj_magic _ | Punbox_float _
851+
| Pbigstring_set_64 _ | Pbigstring_set_128 _ | Pfloatarray_load_128 _
852+
| Pfloat_array_load_128 _ | Pint_array_load_128 _
853+
| Punboxed_float_array_load_128 _ | Punboxed_int32_array_load_128 _
854+
| Punboxed_int64_array_load_128 _ | Punboxed_nativeint_array_load_128 _
855+
| Pfloatarray_set_128 _ | Pfloat_array_set_128 _ | Pint_array_set_128 _
856+
| Punboxed_float_array_set_128 _ | Punboxed_int32_array_set_128 _
857+
| Punboxed_int64_array_set_128 _ | Punboxed_nativeint_array_set_128 _
858+
| Pctconst _ | Pbswap16 | Pbbswap _ | Pint_as_pointer _ | Popaque _
859+
| Pprobe_is_enabled _ | Pobj_dup | Pobj_magic _ | Punbox_float _
854860
| Pbox_float (_, _)
855861
| Punbox_int _ | Pbox_int _ | Pmake_unboxed_product _
856862
| Punboxed_product_field _ | Pget_header _ | Prunstack | Pperform

middle_end/flambda2/from_lambda/lambda_to_flambda.ml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,20 @@ let primitive_can_raise (prim : Lambda.primitive) =
592592
| Pbigstring_set_32 { unsafe = false; boxed = _ }
593593
| Pbigstring_set_64 { unsafe = false; boxed = _ }
594594
| Pbigstring_set_128 { unsafe = false; _ }
595+
| Pfloatarray_load_128 { unsafe = false; _ }
596+
| Pfloat_array_load_128 { unsafe = false; _ }
597+
| Pint_array_load_128 { unsafe = false; _ }
598+
| Punboxed_float_array_load_128 { unsafe = false; _ }
599+
| Punboxed_int32_array_load_128 { unsafe = false; _ }
600+
| Punboxed_int64_array_load_128 { unsafe = false; _ }
601+
| Punboxed_nativeint_array_load_128 { unsafe = false; _ }
602+
| Pfloatarray_set_128 { unsafe = false; _ }
603+
| Pfloat_array_set_128 { unsafe = false; _ }
604+
| Pint_array_set_128 { unsafe = false; _ }
605+
| Punboxed_float_array_set_128 { unsafe = false; _ }
606+
| Punboxed_int32_array_set_128 { unsafe = false; _ }
607+
| Punboxed_int64_array_set_128 { unsafe = false; _ }
608+
| Punboxed_nativeint_array_set_128 { unsafe = false; _ }
595609
| Pdivbint { is_safe = Safe; _ }
596610
| Pmodbint { is_safe = Safe; _ }
597611
| Pbigarrayref (false, _, _, _)
@@ -664,6 +678,20 @@ let primitive_can_raise (prim : Lambda.primitive) =
664678
| Pbigstring_set_32 { unsafe = true; boxed = _ }
665679
| Pbigstring_set_64 { unsafe = true; boxed = _ }
666680
| Pbigstring_set_128 { unsafe = true; _ }
681+
| Pfloatarray_load_128 { unsafe = true; _ }
682+
| Pfloat_array_load_128 { unsafe = true; _ }
683+
| Pint_array_load_128 { unsafe = true; _ }
684+
| Punboxed_float_array_load_128 { unsafe = true; _ }
685+
| Punboxed_int32_array_load_128 { unsafe = true; _ }
686+
| Punboxed_int64_array_load_128 { unsafe = true; _ }
687+
| Punboxed_nativeint_array_load_128 { unsafe = true; _ }
688+
| Pfloatarray_set_128 { unsafe = true; _ }
689+
| Pfloat_array_set_128 { unsafe = true; _ }
690+
| Pint_array_set_128 { unsafe = true; _ }
691+
| Punboxed_float_array_set_128 { unsafe = true; _ }
692+
| Punboxed_int32_array_set_128 { unsafe = true; _ }
693+
| Punboxed_int64_array_set_128 { unsafe = true; _ }
694+
| Punboxed_nativeint_array_set_128 { unsafe = true; _ }
667695
| Pctconst _ | Pbswap16 | Pbbswap _ | Pint_as_pointer _ | Popaque _
668696
| Pprobe_is_enabled _ | Pobj_dup | Pobj_magic _
669697
| Pbox_float (_, _)

middle_end/flambda2/from_lambda/lambda_to_flambda_primitives.ml

Lines changed: 149 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,12 @@ let convert_block_shape (shape : L.block_shape) ~num_fields =
110110
num_fields shape_length;
111111
List.map K.With_subkind.from_lambda_value_kind shape
112112

113-
let check_float_array_optimisation_enabled () =
113+
let check_float_array_optimisation_enabled name =
114114
if not (Flambda_features.flat_float_array ())
115115
then
116-
Misc.fatal_error
117-
"[Pgenarray] is not expected when the float array optimisation is \
118-
disabled"
116+
Misc.fatal_errorf
117+
"[%s] is not expected when the float array optimisation is disabled" name
118+
()
119119

120120
type converted_array_kind =
121121
| Array_kind of P.Array_kind.t
@@ -124,7 +124,7 @@ type converted_array_kind =
124124
let convert_array_kind (kind : L.array_kind) : converted_array_kind =
125125
match kind with
126126
| Pgenarray ->
127-
check_float_array_optimisation_enabled ();
127+
check_float_array_optimisation_enabled "Pgenarray";
128128
Float_array_opt_dynamic
129129
| Paddrarray -> Array_kind Values
130130
| Pintarray -> Array_kind Immediates
@@ -257,7 +257,7 @@ let convert_array_kind_to_duplicate_array_kind (kind : L.array_kind) :
257257
converted_duplicate_array_kind =
258258
match kind with
259259
| Pgenarray ->
260-
check_float_array_optimisation_enabled ();
260+
check_float_array_optimisation_enabled "Pgenarray";
261261
Float_array_opt_dynamic
262262
| Paddrarray -> Duplicate_array_kind Values
263263
| Pintarray -> Duplicate_array_kind Immediates
@@ -565,6 +565,78 @@ let bytes_like_set_safe ~dbg ~size_int ~access_size kind ~boxed bytes index
565565
new_value)
566566
bytes index
567567

568+
(* Array vector load/store *)
569+
570+
let array_vector_access_validity_condition array ~size_int
571+
(array_kind : P.Array_kind.t) index =
572+
let width_in_scalars =
573+
match array_kind with
574+
| Naked_floats | Immediates | Naked_int64s | Naked_nativeints -> 2
575+
| Naked_int32s -> 4
576+
| Values ->
577+
Misc.fatal_error
578+
"Attempted to load/store a SIMD vector from/to a value array."
579+
in
580+
let length_untagged =
581+
untag_int (H.Prim (Unary (Array_length (Array_kind array_kind), array)))
582+
in
583+
let reduced_length_untagged =
584+
H.Prim
585+
(Binary
586+
( Int_arith (Naked_immediate, Sub),
587+
length_untagged,
588+
Simple
589+
(Simple.untagged_const_int
590+
(Targetint_31_63.of_int (width_in_scalars - 1))) ))
591+
in
592+
(* We need to convert the length into a naked_nativeint because the optimised
593+
version of the max_with_zero function needs to be on machine-width integers
594+
to work (or at least on an integer number of bytes to work). *)
595+
let reduced_length_nativeint =
596+
H.Prim
597+
(Unary
598+
( Num_conv { src = Naked_immediate; dst = Naked_nativeint },
599+
reduced_length_untagged ))
600+
in
601+
let check_nativeint = max_with_zero ~size_int reduced_length_nativeint in
602+
let check_untagged =
603+
H.Prim
604+
(Unary
605+
( Num_conv { src = Naked_nativeint; dst = Naked_immediate },
606+
check_nativeint ))
607+
in
608+
check_bound_tagged index check_untagged
609+
610+
let check_array_vector_access ~dbg ~size_int ~array array_kind ~index primitive
611+
: H.expr_primitive =
612+
checked_access ~primitive
613+
~conditions:
614+
[array_vector_access_validity_condition ~size_int array array_kind index]
615+
~dbg
616+
617+
let array_like_load_128 ~dbg ~size_int ~unsafe ~mode ~current_region array_kind
618+
array index =
619+
let primitive =
620+
box_vec128 mode ~current_region
621+
(H.Binary (Array_load (array_kind, Vec128, Mutable), array, index))
622+
in
623+
if unsafe
624+
then primitive
625+
else
626+
check_array_vector_access ~dbg ~size_int ~array array_kind ~index primitive
627+
628+
let array_like_set_128 ~dbg ~size_int ~unsafe array_kind array index new_value =
629+
let primitive =
630+
H.Ternary
631+
(Array_set (array_kind, Vec128), array, index, unbox_vec128 new_value)
632+
in
633+
if unsafe
634+
then primitive
635+
else
636+
check_array_vector_access ~dbg ~size_int ~array
637+
(P.Array_set_kind.array_kind array_kind)
638+
~index primitive
639+
568640
(* Bigarray accesses *)
569641
let bigarray_box_or_tag_raw_value_to_read kind alloc_mode =
570642
let error what =
@@ -688,17 +760,20 @@ let check_array_access ~dbg ~array array_kind ~index primitive :
688760
let array_load_unsafe ~array ~index (array_ref_kind : Array_ref_kind.t)
689761
~current_region : H.expr_primitive =
690762
match array_ref_kind with
691-
| Immediates -> Binary (Array_load (Immediates, Mutable), array, index)
692-
| Values -> Binary (Array_load (Values, Mutable), array, index)
763+
| Immediates -> Binary (Array_load (Immediates, Scalar, Mutable), array, index)
764+
| Values -> Binary (Array_load (Values, Scalar, Mutable), array, index)
693765
| Naked_floats_to_be_boxed mode ->
694766
box_float mode
695-
(Binary (Array_load (Naked_floats, Mutable), array, index))
767+
(Binary (Array_load (Naked_floats, Scalar, Mutable), array, index))
696768
~current_region
697-
| Naked_floats -> Binary (Array_load (Naked_floats, Mutable), array, index)
698-
| Naked_int32s -> Binary (Array_load (Naked_int32s, Mutable), array, index)
699-
| Naked_int64s -> Binary (Array_load (Naked_int64s, Mutable), array, index)
769+
| Naked_floats ->
770+
Binary (Array_load (Naked_floats, Scalar, Mutable), array, index)
771+
| Naked_int32s ->
772+
Binary (Array_load (Naked_int32s, Scalar, Mutable), array, index)
773+
| Naked_int64s ->
774+
Binary (Array_load (Naked_int64s, Scalar, Mutable), array, index)
700775
| Naked_nativeints ->
701-
Binary (Array_load (Naked_nativeints, Mutable), array, index)
776+
Binary (Array_load (Naked_nativeints, Scalar, Mutable), array, index)
702777

703778
let array_set_unsafe ~array ~index ~new_value
704779
(array_set_kind : Array_set_kind.t) : H.expr_primitive =
@@ -710,7 +785,7 @@ let array_set_unsafe ~array ~index ~new_value
710785
| Naked_floats_to_be_unboxed -> unbox_float new_value
711786
in
712787
let array_set_kind = convert_intermediate_array_set_kind array_set_kind in
713-
Ternary (Array_set array_set_kind, array, index, new_value)
788+
Ternary (Array_set (array_set_kind, Scalar), array, index, new_value)
714789

715790
let[@inline always] match_on_array_ref_kind ~array array_ref_kind f :
716791
H.expr_primitive =
@@ -1526,6 +1601,58 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
15261601
[ bytes_like_set_safe ~dbg ~size_int
15271602
~access_size:(One_twenty_eight { aligned })
15281603
Bigstring ~boxed bigstring index new_value ]
1604+
| Pfloat_array_load_128 { unsafe; mode }, [[array]; [index]] ->
1605+
check_float_array_optimisation_enabled "Pfloat_array_load_128";
1606+
[ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
1607+
Naked_floats array index ]
1608+
| Pfloatarray_load_128 { unsafe; mode }, [[array]; [index]]
1609+
| Punboxed_float_array_load_128 { unsafe; mode }, [[array]; [index]] ->
1610+
[ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
1611+
Naked_floats array index ]
1612+
| Pint_array_load_128 { unsafe; mode }, [[array]; [index]] ->
1613+
if Targetint.size <> 64
1614+
then Misc.fatal_error "[Pint_array_load_128]: immediates must be 64 bits.";
1615+
[ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
1616+
Immediates array index ]
1617+
| Punboxed_int64_array_load_128 { unsafe; mode }, [[array]; [index]] ->
1618+
[ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
1619+
Naked_int64s array index ]
1620+
| Punboxed_nativeint_array_load_128 { unsafe; mode }, [[array]; [index]] ->
1621+
if Targetint.size <> 64
1622+
then
1623+
Misc.fatal_error
1624+
"[Punboxed_nativeint_array_load_128]: nativeint must be 64 bits.";
1625+
[ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
1626+
Naked_nativeints array index ]
1627+
| Punboxed_int32_array_load_128 { unsafe; mode }, [[array]; [index]] ->
1628+
[ array_like_load_128 ~dbg ~size_int ~current_region ~unsafe ~mode
1629+
Naked_int32s array index ]
1630+
| Pfloat_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
1631+
check_float_array_optimisation_enabled "Pfloat_array_set_128";
1632+
[ array_like_set_128 ~dbg ~size_int ~unsafe Naked_floats array index
1633+
new_value ]
1634+
| Pfloatarray_set_128 { unsafe }, [[array]; [index]; [new_value]]
1635+
| Punboxed_float_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
1636+
[ array_like_set_128 ~dbg ~size_int ~unsafe Naked_floats array index
1637+
new_value ]
1638+
| Pint_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
1639+
if Targetint.size <> 64
1640+
then Misc.fatal_error "[Pint_array_set_128]: immediates must be 64 bits.";
1641+
[array_like_set_128 ~dbg ~size_int ~unsafe Immediates array index new_value]
1642+
| Punboxed_int64_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
1643+
[ array_like_set_128 ~dbg ~size_int ~unsafe Naked_int64s array index
1644+
new_value ]
1645+
| Punboxed_nativeint_array_set_128 { unsafe }, [[array]; [index]; [new_value]]
1646+
->
1647+
if Targetint.size <> 64
1648+
then
1649+
Misc.fatal_error
1650+
"[Punboxed_nativeint_array_load_128]: nativeint must be 64 bits.";
1651+
[ array_like_set_128 ~dbg ~size_int ~unsafe Naked_nativeints array index
1652+
new_value ]
1653+
| Punboxed_int32_array_set_128 { unsafe }, [[array]; [index]; [new_value]] ->
1654+
[ array_like_set_128 ~dbg ~size_int ~unsafe Naked_int32s array index
1655+
new_value ]
15291656
| Pcompare_ints, [[i1]; [i2]] ->
15301657
[ tag_int
15311658
(Binary
@@ -1612,7 +1739,10 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
16121739
| Pasrbint _ | Pfield_computed _ | Pdivbint _ | Pmodbint _
16131740
| Psetfloatfield _ | Psetufloatfield _ | Pbintcomp _ | Punboxed_int_comp _
16141741
| Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _
1615-
| Pbigstring_load_128 _
1742+
| Pbigstring_load_128 _ | Pfloatarray_load_128 _ | Pfloat_array_load_128 _
1743+
| Pint_array_load_128 _ | Punboxed_float_array_load_128 _
1744+
| Punboxed_int32_array_load_128 _ | Punboxed_int64_array_load_128 _
1745+
| Punboxed_nativeint_array_load_128 _
16161746
| Parrayrefu
16171747
( Pgenarray_ref _ | Paddrarray_ref | Pintarray_ref | Pfloatarray_ref _
16181748
| Punboxedfloatarray_ref _ | Punboxedintarray_ref _ )
@@ -1639,7 +1769,10 @@ let convert_lprim ~big_endian (prim : L.primitive) (args : Simple.t list list)
16391769
| Punboxedfloatarray_set _ | Punboxedintarray_set _ )
16401770
| Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _ | Pbytes_set_128 _
16411771
| Pbigstring_set_16 _ | Pbigstring_set_32 _ | Pbigstring_set_64 _
1642-
| Pbigstring_set_128 _ | Patomic_cas ),
1772+
| Pbigstring_set_128 _ | Pfloatarray_set_128 _ | Pfloat_array_set_128 _
1773+
| Pint_array_set_128 _ | Punboxed_float_array_set_128 _
1774+
| Punboxed_int32_array_set_128 _ | Punboxed_int64_array_set_128 _
1775+
| Punboxed_nativeint_array_set_128 _ | Patomic_cas ),
16431776
( []
16441777
| [_]
16451778
| [_; _]

middle_end/flambda2/parser/fexpr.ml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ type string_accessor_width = Flambda_primitive.string_accessor_width =
327327
| Sixty_four
328328
| One_twenty_eight of { aligned : bool }
329329

330+
type array_accessor_width = Flambda_primitive.array_accessor_width =
331+
| Scalar
332+
| Vec128
333+
330334
type string_like_value = Flambda_primitive.string_like_value =
331335
| String
332336
| Bytes
@@ -344,7 +348,7 @@ type infix_binop =
344348
| Float_comp of unit comparison_behaviour
345349

346350
type binop =
347-
| Array_load of array_kind * mutability
351+
| Array_load of array_kind * array_accessor_width * mutability
348352
| Block_load of block_access_kind * mutability
349353
| Phys_equal of equality_comparison
350354
| Int_arith of standard_int * binary_int_arith_op
@@ -356,7 +360,7 @@ type binop =
356360

357361
type ternop =
358362
(* CR mshinwell: Array_set should use "array_set_kind" *)
359-
| Array_set of array_kind * init_or_assign
363+
| Array_set of array_kind * array_accessor_width * init_or_assign
360364
| Block_set of block_access_kind * init_or_assign
361365
| Bytes_or_bigstring_set of bytes_like_value * string_accessor_width
362366

0 commit comments

Comments
 (0)