Skip to content
This repository was archived by the owner on Dec 22, 2021. It is now read-only.

Commit da4f63b

Browse files
authored
Implement floating-point rounding in interpreter (#344)
Implement f32x4 and f64x2 ceil, floor, trunc, nearest. They have the same behavior as the f32 and f64 instructions. Also implemented to encoding and decoding. These new instructions were added to simd_f32x4.wast test case, and the test generation script is updated with these new instructions.
1 parent 9576fd0 commit da4f63b

15 files changed

+1072
-3
lines changed

interpreter/binary/decode.ml

+8
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,14 @@ let simd_prefix s =
377377
| 0xcel -> i64x2_add
378378
| 0xd1l -> i64x2_sub
379379
| 0xd5l -> i64x2_mul
380+
| 0xd8l -> f32x4_ceil
381+
| 0xd9l -> f32x4_floor
382+
| 0xdal -> f32x4_trunc
383+
| 0xdbl -> f32x4_nearest
384+
| 0xdcl -> f64x2_ceil
385+
| 0xddl -> f64x2_floor
386+
| 0xdel -> f64x2_trunc
387+
| 0xdfl -> f64x2_nearest
380388
| 0xe0l -> f32x4_abs
381389
| 0xe1l -> f32x4_neg
382390
| 0xe3l -> f32x4_sqrt

interpreter/binary/encode.ml

+8
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,14 @@ let encode m =
339339
| Unary (V128 V128Op.(I32x4 WidenLowU)) -> simd_op 0xa9l
340340
| Unary (V128 V128Op.(I32x4 WidenHighU)) -> simd_op 0xaal
341341
| Unary (V128 V128Op.(I64x2 Neg)) -> simd_op 0xc1l
342+
| Unary (V128 V128Op.(F32x4 Ceil)) -> simd_op 0xd8l
343+
| Unary (V128 V128Op.(F32x4 Floor)) -> simd_op 0xd9l
344+
| Unary (V128 V128Op.(F32x4 Trunc)) -> simd_op 0xdal
345+
| Unary (V128 V128Op.(F32x4 Nearest)) -> simd_op 0xdbl
346+
| Unary (V128 V128Op.(F64x2 Ceil)) -> simd_op 0xdcl
347+
| Unary (V128 V128Op.(F64x2 Floor)) -> simd_op 0xddl
348+
| Unary (V128 V128Op.(F64x2 Trunc)) -> simd_op 0xdel
349+
| Unary (V128 V128Op.(F64x2 Nearest)) -> simd_op 0xdfl
342350
| Unary (V128 V128Op.(F32x4 Abs)) -> simd_op 0xe0l
343351
| Unary (V128 V128Op.(F32x4 Neg)) -> simd_op 0xe1l
344352
| Unary (V128 V128Op.(F32x4 Sqrt)) -> simd_op 0xe3l

interpreter/exec/eval_numeric.ml

+8
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,19 @@ struct
148148
| F32x4 Abs -> to_value (SXX.F32x4.abs (of_value 1 v))
149149
| F32x4 Neg -> to_value (SXX.F32x4.neg (of_value 1 v))
150150
| F32x4 Sqrt -> to_value (SXX.F32x4.sqrt (of_value 1 v))
151+
| F32x4 Ceil -> to_value (SXX.F32x4.ceil (of_value 1 v))
152+
| F32x4 Floor -> to_value (SXX.F32x4.floor (of_value 1 v))
153+
| F32x4 Trunc -> to_value (SXX.F32x4.trunc (of_value 1 v))
154+
| F32x4 Nearest -> to_value (SXX.F32x4.nearest (of_value 1 v))
151155
| F32x4 ConvertI32x4S -> to_value (SXX.F32x4_convert.convert_i32x4_s (of_value 1 v))
152156
| F32x4 ConvertI32x4U -> to_value (SXX.F32x4_convert.convert_i32x4_u (of_value 1 v))
153157
| F64x2 Abs -> to_value (SXX.F64x2.abs (of_value 1 v))
154158
| F64x2 Neg -> to_value (SXX.F64x2.neg (of_value 1 v))
155159
| F64x2 Sqrt -> to_value (SXX.F64x2.sqrt (of_value 1 v))
160+
| F64x2 Ceil -> to_value (SXX.F64x2.ceil (of_value 1 v))
161+
| F64x2 Floor -> to_value (SXX.F64x2.floor (of_value 1 v))
162+
| F64x2 Trunc -> to_value (SXX.F64x2.trunc (of_value 1 v))
163+
| F64x2 Nearest -> to_value (SXX.F64x2.nearest (of_value 1 v))
156164
| V128 Not -> to_value (SXX.V128.lognot (of_value 1 v))
157165
| _ -> failwith "TODO v128 unimplemented unop"
158166

interpreter/exec/simd.ml

+8
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ sig
109109
val abs : t -> t
110110
val neg : t -> t
111111
val sqrt : t -> t
112+
val ceil : t -> t
113+
val floor : t -> t
114+
val trunc : t -> t
115+
val nearest : t -> t
112116
val add : t -> t -> t
113117
val sub : t -> t -> t
114118
val mul : t -> t -> t
@@ -254,6 +258,10 @@ struct
254258
let abs = unop Float.abs
255259
let neg = unop Float.neg
256260
let sqrt = unop Float.sqrt
261+
let ceil = unop Float.ceil
262+
let floor = unop Float.floor
263+
let trunc = unop Float.trunc
264+
let nearest = unop Float.nearest
257265
let add = binop Float.add
258266
let sub = binop Float.sub
259267
let mul = binop Float.mul

interpreter/syntax/ast.ml

+3-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ struct
5454
| Eq | Ne | LtS | LtU | LeS | LeU | GtS | GtU | GeS | GeU
5555
| Swizzle | Shuffle of int list | NarrowS | NarrowU
5656
| AddSatS | AddSatU | SubSatS | SubSatU
57-
type funop = Abs | Neg | Sqrt | ConvertI32x4S | ConvertI32x4U
57+
type funop = Abs | Neg | Sqrt
58+
| Ceil | Floor | Trunc | Nearest
59+
| ConvertI32x4S | ConvertI32x4U
5860
type fbinop = Add | Sub | Mul | Div | Min | Max
5961
| Eq | Ne | Lt | Le | Gt | Ge
6062
type vunop = Not

interpreter/syntax/operators.ml

+8
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,10 @@ let f32x4_ge = Binary (V128 V128Op.(F32x4 Ge))
385385
let f32x4_abs = Unary (V128 (V128Op.F32x4 V128Op.Abs))
386386
let f32x4_neg = Unary (V128 (V128Op.F32x4 V128Op.Neg))
387387
let f32x4_sqrt = Unary (V128 (V128Op.F32x4 V128Op.Sqrt))
388+
let f32x4_ceil = Unary (V128 (V128Op.(F32x4 Ceil)))
389+
let f32x4_floor = Unary (V128 (V128Op.(F32x4 Floor)))
390+
let f32x4_trunc = Unary (V128 (V128Op.(F32x4 Trunc)))
391+
let f32x4_nearest = Unary (V128 (V128Op.(F32x4 Nearest)))
388392
let f32x4_add = Binary (V128 (V128Op.F32x4 V128Op.Add))
389393
let f32x4_sub = Binary (V128 (V128Op.F32x4 V128Op.Sub))
390394
let f32x4_mul = Binary (V128 (V128Op.F32x4 V128Op.Mul))
@@ -405,6 +409,10 @@ let f64x2_gt = Binary (V128 V128Op.(F64x2 Gt))
405409
let f64x2_ge = Binary (V128 V128Op.(F64x2 Ge))
406410
let f64x2_neg = Unary (V128 (V128Op.F64x2 V128Op.Neg))
407411
let f64x2_sqrt = Unary (V128 (V128Op.F64x2 V128Op.Sqrt))
412+
let f64x2_ceil = Unary (V128 (V128Op.(F64x2 Ceil)))
413+
let f64x2_floor = Unary (V128 (V128Op.(F64x2 Floor)))
414+
let f64x2_trunc = Unary (V128 (V128Op.(F64x2 Trunc)))
415+
let f64x2_nearest = Unary (V128 (V128Op.(F64x2 Nearest)))
408416
let f64x2_add = Binary (V128 (V128Op.F64x2 V128Op.Add))
409417
let f64x2_sub = Binary (V128 (V128Op.F64x2 V128Op.Sub))
410418
let f64x2_mul = Binary (V128 (V128Op.F64x2 V128Op.Mul))

interpreter/text/arrange.ml

+8
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,14 @@ struct
219219
| I32x4 TruncSatF32x4S -> "i32x4.trunc_sat_f32x4_s"
220220
| I32x4 TruncSatF32x4U -> "i32x4.trunc_sat_f32x4_u"
221221
| I64x2 Neg -> "i64x2.neg"
222+
| F32x4 Ceil -> "f32x4.ceil"
223+
| F32x4 Floor -> "f32x4.floor"
224+
| F32x4 Trunc -> "f32x4.trunc"
225+
| F32x4 Nearest -> "f32x4.nearest"
226+
| F64x2 Ceil -> "f64x2.ceil"
227+
| F64x2 Floor -> "f64x2.floor"
228+
| F64x2 Trunc -> "f64x2.trunc"
229+
| F64x2 Nearest -> "f64x2.nearest"
222230
| F32x4 Abs -> "f32x4.abs"
223231
| F32x4 Neg -> "f32x4.neg"
224232
| F32x4 Sqrt -> "f32x4.sqrt"

interpreter/text/lexer.mll

+4
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,10 @@ rule token = parse
490490
| (simd_shape as s)".neg"
491491
{ UNARY (simdop s i8x16_neg i16x8_neg i32x4_neg i64x2_neg f32x4_neg f64x2_neg) }
492492
| (simd_float_shape as s)".sqrt" { UNARY (simd_float_op s f32x4_sqrt f64x2_sqrt) }
493+
| (simd_float_shape as s)".ceil" { UNARY (simd_float_op s f32x4_ceil f64x2_ceil) }
494+
| (simd_float_shape as s)".floor" { UNARY (simd_float_op s f32x4_floor f64x2_floor) }
495+
| (simd_float_shape as s)".trunc" { UNARY (simd_float_op s f32x4_trunc f64x2_trunc) }
496+
| (simd_float_shape as s)".nearest" { UNARY (simd_float_op s f32x4_nearest f64x2_nearest) }
493497
| (simd_shape as s)".add"
494498
{ BINARY (simdop s i8x16_add i16x8_add i32x4_add i64x2_add f32x4_add f64x2_add) }
495499
| (simd_shape as s)".sub"

test/core/simd/meta/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Currently it only support following simd test files generation.
2222
- 'simd_i16x8_sat_arith.wast'
2323
- 'simd_f32x4.wast'
2424
- 'simd_f64x2.wast'
25+
- 'simd_f32x4_rounding'
26+
- 'simd_f64x2_rounding'
2527

2628

2729
Usage:

test/core/simd/meta/gen_tests.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
'simd_f32x4',
2727
'simd_f64x2',
2828
'simd_int_arith2',
29+
'simd_f32x4_rounding',
30+
'simd_f64x2_rounding',
2931
)
3032

3133

@@ -61,4 +63,4 @@ def main():
6163

6264
if __name__ == '__main__':
6365
main()
64-
print('Done.')
66+
print('Done.')
+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
Generate f32x4 [ceil, floor, trunc, nearest] cases.
5+
"""
6+
7+
from simd_f32x4_arith import Simdf32x4ArithmeticCase
8+
from simd_float_op import FloatingPointRoundingOp
9+
from simd import SIMD
10+
from test_assert import AssertReturn
11+
12+
13+
class Simdf32x4RoundingCase(Simdf32x4ArithmeticCase):
14+
UNARY_OPS = ('ceil', 'floor', 'trunc', 'nearest')
15+
BINARY_OPS = ()
16+
floatOp = FloatingPointRoundingOp()
17+
18+
def get_combine_cases(self):
19+
return ''
20+
21+
def get_normal_case(self):
22+
"""Normal test cases from WebAssembly core tests.
23+
"""
24+
cases = []
25+
unary_test_data = []
26+
27+
for op in self.UNARY_OPS:
28+
op_name = self.full_op_name(op)
29+
for operand in self.FLOAT_NUMBERS:
30+
result = self.floatOp.unary_op(op, operand)
31+
if 'nan' in result:
32+
unary_test_data.append([op_name, operand, 'nan:canonical'])
33+
else:
34+
unary_test_data.append([op_name, operand, result])
35+
36+
for operand in self.LITERAL_NUMBERS:
37+
result = self.floatOp.unary_op(op, operand, hex_form=False)
38+
unary_test_data.append([op_name, operand, result])
39+
40+
for operand in self.NAN_NUMBERS:
41+
if 'nan:' in operand:
42+
unary_test_data.append([op_name, operand, 'nan:arithmetic'])
43+
else:
44+
unary_test_data.append([op_name, operand, 'nan:canonical'])
45+
46+
for case in unary_test_data:
47+
cases.append(str(AssertReturn(case[0],
48+
[SIMD.v128_const(elem, self.LANE_TYPE) for elem in case[1:-1]],
49+
SIMD.v128_const(case[-1], self.LANE_TYPE))))
50+
51+
self.get_unknown_operator_case(cases)
52+
53+
return '\n'.join(cases)
54+
55+
def get_unknown_operator_case(self, cases):
56+
"""Unknown operator cases.
57+
"""
58+
59+
tpl_assert = "(assert_malformed (module quote \"(memory 1) (func (result v128) " \
60+
"({lane_type}.{op} {value}))\") \"unknown operator\")"
61+
62+
unknown_op_cases = ['\n\n;; Unknown operators\n']
63+
cases.extend(unknown_op_cases)
64+
65+
for lane_type in ['i8x16', 'i16x8', 'i32x4', 'i64x2']:
66+
for op in self.UNARY_OPS:
67+
cases.append(tpl_assert.format(lane_type=lane_type, op=op, value=self.v128_const('i32x4', '0')))
68+
69+
def gen_test_cases(self):
70+
wast_filename = '../simd_{lane_type}_rounding.wast'.format(lane_type=self.LANE_TYPE)
71+
with open(wast_filename, 'w') as fp:
72+
txt_test_case = self.get_all_cases()
73+
txt_test_case = txt_test_case.replace(
74+
self.LANE_TYPE + ' arithmetic',
75+
self.LANE_TYPE + ' [ceil, floor, trunc, nearest]')
76+
fp.write(txt_test_case)
77+
78+
79+
def gen_test_cases():
80+
simd_f32x4_case = Simdf32x4RoundingCase()
81+
simd_f32x4_case.gen_test_cases()
82+
83+
84+
if __name__ == '__main__':
85+
gen_test_cases()
+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
Generate f64x2 [ceil, floor, trunc, nearest] cases.
5+
"""
6+
7+
from simd_f32x4_rounding import Simdf32x4RoundingCase
8+
from simd_f64x2 import Simdf64x2Case
9+
from simd_f64x2_arith import Simdf64x2ArithmeticCase
10+
from simd_float_op import FloatingPointRoundingOp
11+
from simd import SIMD
12+
from test_assert import AssertReturn
13+
14+
15+
class Simdf64x2RoundingCase(Simdf32x4RoundingCase):
16+
17+
LANE_TYPE = 'f64x2'
18+
FLOAT_NUMBERS = Simdf64x2ArithmeticCase.FLOAT_NUMBERS
19+
LITERAL_NUMBERS = Simdf64x2ArithmeticCase.LITERAL_NUMBERS
20+
NAN_NUMBERS = Simdf64x2ArithmeticCase.NAN_NUMBERS
21+
22+
23+
def gen_test_cases():
24+
simd_f64x2_case = Simdf64x2RoundingCase()
25+
simd_f64x2_case.gen_test_cases()
26+
27+
28+
if __name__ == '__main__':
29+
gen_test_cases()

test/core/simd/meta/simd_float_op.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,53 @@ def binary_op(self, op: str, p1: str, p2: str) -> str:
252252
elif op == 'ge':
253253
return '-1' if f1 >= f2 else '0'
254254
else:
255-
raise Exception('Unknown binary operation')
255+
raise Exception('Unknown binary operation')
256+
257+
258+
class FloatingPointRoundingOp(FloatingPointOp):
259+
def unary_op(self, op: str, p1: str, hex_form=True) -> str:
260+
"""Unnary operation on p1 with the operation specified by op
261+
262+
:param op: ceil, floor, trunc, nearest
263+
:param p1: float number in hex
264+
:return:
265+
"""
266+
if '0x' in p1:
267+
f1 = float.fromhex(p1)
268+
else:
269+
f1 = float(p1)
270+
271+
if 'nan' in p1:
272+
return 'nan'
273+
274+
if 'inf' in p1:
275+
return p1
276+
277+
# The rounding ops don't treat -0.0 correctly, e.g.:
278+
# math.ceil(-0.4) returns +0.0, so copy the sign.
279+
elif op == 'ceil':
280+
r = math.copysign(math.ceil(f1), f1)
281+
if hex_form:
282+
return r.hex()
283+
else:
284+
return str(r)
285+
elif op == 'floor':
286+
r = math.copysign(math.floor(f1), f1)
287+
if hex_form:
288+
return r.hex()
289+
else:
290+
return str(r)
291+
elif op == 'trunc':
292+
r = math.copysign(math.trunc(f1), f1)
293+
if hex_form:
294+
return r.hex()
295+
else:
296+
return str(r)
297+
elif op == 'nearest':
298+
r = math.copysign(round(f1), f1)
299+
if hex_form:
300+
return r.hex()
301+
else:
302+
return str(r)
303+
else:
304+
raise Exception('Unknown binary operation')

0 commit comments

Comments
 (0)