Skip to content

Commit a6c6c91

Browse files
Fix undefined behaviour in integer shifts (#1385)
Prior to this change, Integer shifts used to exhibit undefined behavior if the number of bits is outside of 0..BITS-1 range (i.e., 0..31 for shifts on Int and 0..63 for shifts on Long). Here we address this change by masking the bits part to always be in range.
1 parent 1a0a878 commit a6c6c91

File tree

4 files changed

+167
-22
lines changed

4 files changed

+167
-22
lines changed

nir/src/main/scala/scala/scalanative/nir/Ops.scala

+29-2
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,44 @@ sealed abstract class Op {
4242

4343
final def show: String = nir.Show(this)
4444

45+
/** Op is pure if it doesn't have any side-effects, including:
46+
*
47+
* * doesn't throw exceptions
48+
* * doesn't perform any unsafe reads or writes from the memory
49+
* * doesn't call foreign code
50+
*
51+
* Recomputing pure op will always yield to the same result.
52+
*/
4553
final def isPure: Boolean = this match {
4654
case _: Op.Elem | _: Op.Extract | _: Op.Insert | _: Op.Comp | _: Op.Conv |
47-
_: Op.Select =>
55+
_: Op.Select | _: Op.Is | _: Op.Copy | _: Op.Sizeof =>
4856
true
49-
case Op.Bin(Bin.Sdiv | Bin.Udiv | Bin.Srem | Bin.Urem, _, _, _) =>
57+
// Division and modulo on integers are not pure as
58+
// they may throw if the divisor is zero.
59+
case Op.Bin(Bin.Sdiv | Bin.Udiv | Bin.Srem | Bin.Urem, _: Type.I, _, _) =>
5060
false
5161
case _: Op.Bin =>
5262
true
5363
case _ =>
5464
false
5565
}
66+
67+
/** Op is idempotent if re-evaluation of the operation with the same
68+
* arguments is going to produce the same results, without any extra
69+
* side effects as long as previous evaluation did not throw.
70+
*/
71+
final def isIdempotent: Boolean = this match {
72+
case op if op.isPure =>
73+
true
74+
// Division and modulo are non-pure but idempotent.
75+
case op: Op.Bin =>
76+
true
77+
case _: Op.Method | _: Op.Dynmethod | _: Op.Module | _: Op.Box |
78+
_: Op.Unbox | _: Op.Arraylength =>
79+
true
80+
case _ =>
81+
false
82+
}
5683
}
5784
object Op {
5885
// low-level

tools/src/main/scala/scala/scalanative/codegen/Lower.scala

+19
Original file line numberDiff line numberDiff line change
@@ -506,13 +506,32 @@ object Lower {
506506
label(resultL, Seq(Val.Local(n, ty)))
507507
}
508508

509+
// Shifts are undefined if the bits shifted by are >= bits in the type.
510+
// We mask the right hand side with bits in type - 1 to make it defined.
511+
def maskShift(op: Op.Bin) = {
512+
val Op.Bin(_, ty: Type.I, _, r) = op
513+
val mask = ty match {
514+
case Type.Int => Val.Int(31)
515+
case Type.Long => Val.Int(63)
516+
case _ => util.unreachable
517+
}
518+
val masked = bin(Bin.And, ty, r, mask, unwind)
519+
let(n, op.copy(r = masked), unwind)
520+
}
521+
509522
op match {
510523
case op @ Op.Bin(bin @ (Bin.Srem | Bin.Urem | Bin.Sdiv | Bin.Udiv),
511524
ty: Type.I,
512525
l,
513526
r) =>
514527
checkDivisionByZero(op)
515528

529+
case op @ Op.Bin(bin @ (Bin.Shl | Bin.Lshr | Bin.Ashr),
530+
ty: Type.I,
531+
l,
532+
r) =>
533+
maskShift(op)
534+
516535
case op =>
517536
let(n, op, unwind)
518537
}

tools/src/main/scala/scala/scalanative/optimizer/pass/GlobalValueNumbering.scala

+2-20
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class GlobalValueNumbering extends Pass {
5252
val newBlockInsts = block.insts.map {
5353

5454
case inst: Inst.Let => {
55-
val idempotent = isIdempotent(inst.op)
55+
val idempotent = inst.op.isIdempotent
5656

5757
val instHash =
5858
if (idempotent)
@@ -97,24 +97,6 @@ class GlobalValueNumbering extends Pass {
9797
}
9898

9999
object GlobalValueNumbering extends PassCompanion {
100-
def isIdempotent(op: Op): Boolean = {
101-
import Op._
102-
op match {
103-
// Always idempotent:
104-
case (_: Method | _: Dynmethod | _: As | _: Is | _: Copy | _: Sizeof |
105-
_: Module | _: Box | _: Unbox | _: Arraylength) =>
106-
true
107-
case op if op.isPure =>
108-
true
109-
110-
// Never idempotent:
111-
case (_: Load | _: Store | _: Stackalloc | _: Classalloc | _: Call |
112-
_: Closure | _: Fieldload | _: Fieldstore | _: Var | _: Varload |
113-
_: Varstore | _: Arrayalloc | _: Arrayload | _: Arraystore) =>
114-
false
115-
}
116-
}
117-
118100
class DeepEquals(localDefs: Local => Inst) {
119101

120102
def eqInst(instA: Inst.Let, instB: Inst.Let): Boolean = {
@@ -123,7 +105,7 @@ object GlobalValueNumbering extends PassCompanion {
123105

124106
def eqOp(opA: Op, opB: Op): Boolean = {
125107
import Op._
126-
if (!(isIdempotent(opA) && isIdempotent(opB)))
108+
if (!(opA.isIdempotent && opB.isIdempotent))
127109
false
128110
else {
129111
(opA, opB) match {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package scala
2+
3+
object ShiftOverflowSuite extends tests.Suite {
4+
@noinline def noinlineByte42: Byte = 42.toByte
5+
@noinline def noinlineShort42: Short = 42.toShort
6+
@noinline def noinlineChar42: Char = 42.toChar
7+
@noinline def noinlineInt42: Int = 42
8+
@noinline def noinlineLong42: Long = 42L
9+
@noinline def noinlineInt33: Int = 33
10+
@noinline def noinlineLong65: Long = 65L
11+
@noinline def noinlineInt21: Int = 21
12+
@noinline def noinlineLong21: Long = 21L
13+
@noinline def noinlineInt84: Int = 84
14+
@noinline def noinlineLong84: Long = 84L
15+
16+
@inline def inlineByte42: Byte = 42.toByte
17+
@inline def inlineShort42: Short = 42.toShort
18+
@inline def inlineChar42: Char = 42.toChar
19+
@inline def inlineInt42: Int = 42
20+
@inline def inlineLong42: Long = 42L
21+
@inline def inlineInt33: Int = 33
22+
@inline def inlineLong65: Long = 65L
23+
@inline def inlineInt21: Int = 21
24+
@inline def inlineLong21: Long = 21L
25+
@inline def inlineInt84: Int = 84
26+
@inline def inlineLong84: Long = 84L
27+
28+
test("x << 33 (noinline)") {
29+
assert((noinlineByte42 << noinlineInt33) == noinlineInt84)
30+
assert((noinlineShort42 << noinlineInt33) == noinlineInt84)
31+
assert((noinlineChar42 << noinlineInt33) == noinlineInt84)
32+
assert((noinlineInt42 << noinlineInt33) == noinlineInt84)
33+
}
34+
35+
test("x << 33 (inline)") {
36+
assert((inlineByte42 << inlineInt33) == inlineInt84)
37+
assert((inlineShort42 << inlineInt33) == inlineInt84)
38+
assert((inlineChar42 << inlineInt33) == inlineInt84)
39+
assert((inlineInt42 << inlineInt33) == inlineInt84)
40+
}
41+
42+
test("x << 65L (noinline)") {
43+
assert((noinlineByte42 << noinlineLong65) == noinlineLong84)
44+
assert((noinlineShort42 << noinlineLong65) == noinlineLong84)
45+
assert((noinlineChar42 << noinlineLong65) == noinlineLong84)
46+
assert((noinlineInt42 << noinlineLong65) == noinlineLong84)
47+
assert((noinlineLong42 << noinlineLong65) == noinlineLong84)
48+
}
49+
50+
test("x << 65L (inline)") {
51+
assert((inlineByte42 << inlineLong65) == inlineLong84)
52+
assert((inlineShort42 << inlineLong65) == inlineLong84)
53+
assert((inlineChar42 << inlineLong65) == inlineLong84)
54+
assert((inlineInt42 << inlineLong65) == inlineLong84)
55+
assert((inlineLong42 << inlineLong65) == inlineLong84)
56+
}
57+
58+
test("x >> 33 (noinline)") {
59+
assert((noinlineByte42 >> noinlineInt33) == noinlineInt21)
60+
assert((noinlineShort42 >> noinlineInt33) == noinlineInt21)
61+
assert((noinlineChar42 >> noinlineInt33) == noinlineInt21)
62+
assert((noinlineInt42 >> noinlineInt33) == noinlineInt21)
63+
}
64+
65+
test("x >> 33 (inline)") {
66+
assert((inlineByte42 >> inlineInt33) == inlineInt21)
67+
assert((inlineShort42 >> inlineInt33) == inlineInt21)
68+
assert((inlineChar42 >> inlineInt33) == inlineInt21)
69+
assert((inlineInt42 >> inlineInt33) == inlineInt21)
70+
}
71+
72+
test("x >> 65L (noinline)") {
73+
assert((noinlineByte42 >> noinlineLong65) == noinlineLong21)
74+
assert((noinlineShort42 >> noinlineLong65) == noinlineLong21)
75+
assert((noinlineChar42 >> noinlineLong65) == noinlineLong21)
76+
assert((noinlineInt42 >> noinlineLong65) == noinlineLong21)
77+
assert((noinlineLong42 >> noinlineLong65) == noinlineLong21)
78+
}
79+
80+
test("x >> 65L (inline)") {
81+
assert((inlineByte42 >> inlineLong65) == inlineLong21)
82+
assert((inlineShort42 >> inlineLong65) == inlineLong21)
83+
assert((inlineChar42 >> inlineLong65) == inlineLong21)
84+
assert((inlineInt42 >> inlineLong65) == inlineLong21)
85+
assert((inlineLong42 >> inlineLong65) == inlineLong21)
86+
}
87+
88+
test("x >>> 33 (noinline)") {
89+
assert((noinlineByte42 >>> noinlineInt33) == noinlineInt21)
90+
assert((noinlineShort42 >>> noinlineInt33) == noinlineInt21)
91+
assert((noinlineChar42 >>> noinlineInt33) == noinlineInt21)
92+
assert((noinlineInt42 >>> noinlineInt33) == noinlineInt21)
93+
}
94+
95+
test("x >>> 33 (inline)") {
96+
assert((inlineByte42 >>> inlineInt33) == inlineInt21)
97+
assert((inlineShort42 >>> inlineInt33) == inlineInt21)
98+
assert((inlineChar42 >>> inlineInt33) == inlineInt21)
99+
assert((inlineInt42 >>> inlineInt33) == inlineInt21)
100+
}
101+
102+
test("x >>> 65L (noinline)") {
103+
assert((noinlineByte42 >>> noinlineLong65) == noinlineLong21)
104+
assert((noinlineShort42 >>> noinlineLong65) == noinlineLong21)
105+
assert((noinlineChar42 >>> noinlineLong65) == noinlineLong21)
106+
assert((noinlineInt42 >>> noinlineLong65) == noinlineLong21)
107+
assert((noinlineLong42 >>> noinlineLong65) == noinlineLong21)
108+
}
109+
110+
test("x >>> 65L (inline)") {
111+
assert((inlineByte42 >>> inlineLong65) == inlineLong21)
112+
assert((inlineShort42 >>> inlineLong65) == inlineLong21)
113+
assert((inlineChar42 >>> inlineLong65) == inlineLong21)
114+
assert((inlineInt42 >>> inlineLong65) == inlineLong21)
115+
assert((inlineLong42 >>> inlineLong65) == inlineLong21)
116+
}
117+
}

0 commit comments

Comments
 (0)