Skip to content

Commit 1a0a878

Browse files
Make division by zero defined behaviour (#1384)
This change makes integer division behave the same as on JVM: 1. Division by 0 throws an ArithmeticException. 2. Division that overflows wraps around. Previously these two cases were undefined. While integer segfaults are relatively easy to debug, the problem gets worse once LLVM starts optimizing assuming undefined behaviour (by removing said segfaults and replacing ops with 0).
1 parent ae56535 commit 1a0a878

File tree

9 files changed

+243
-45
lines changed

9 files changed

+243
-45
lines changed

docs/user/lang.rst

-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ A number of error conditions which are well-defined on JVM are undefined
3434
behavior:
3535

3636
1. Dereferencing null.
37-
2. Division by zero.
3837
3. Stack overflows.
3938

4039
Those typically crash application with a segfault on the supported architectures.

nativelib/src/main/scala/scala/scalanative/runtime/package.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,10 @@ package object runtime {
8989
/** Run the runtime's event loop. The method is called from the
9090
* generated C-style after the application's main method terminates.
9191
*/
92-
def loop(): Unit = ExecutionContext.loop()
92+
def loop(): Unit =
93+
ExecutionContext.loop()
94+
95+
/** Called by the compiler in case of division by zero. */
96+
def throwDivisionByZero(): Nothing =
97+
throw new java.lang.ArithmeticException("/ by zero")
9398
}

nir/src/main/scala/scala/scalanative/nir/Ops.scala

+19-9
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,33 @@ sealed abstract class Op {
4141
}
4242

4343
final def show: String = nir.Show(this)
44+
45+
final def isPure: Boolean = this match {
46+
case _: Op.Elem | _: Op.Extract | _: Op.Insert | _: Op.Comp | _: Op.Conv |
47+
_: Op.Select =>
48+
true
49+
case Op.Bin(Bin.Sdiv | Bin.Udiv | Bin.Srem | Bin.Urem, _, _, _) =>
50+
false
51+
case _: Op.Bin =>
52+
true
53+
case _ =>
54+
false
55+
}
4456
}
4557
object Op {
46-
sealed abstract class Pure extends Op
47-
4858
// low-level
4959
final case class Call(ty: Type, ptr: Val, args: Seq[Val]) extends Op
5060
final case class Load(ty: Type, ptr: Val, isVolatile: Boolean) extends Op
5161
final case class Store(ty: Type, ptr: Val, value: Val, isVolatile: Boolean)
5262
extends Op
53-
final case class Elem(ty: Type, ptr: Val, indexes: Seq[Val]) extends Pure
54-
final case class Extract(aggr: Val, indexes: Seq[Int]) extends Pure
55-
final case class Insert(aggr: Val, value: Val, indexes: Seq[Int]) extends Pure
63+
final case class Elem(ty: Type, ptr: Val, indexes: Seq[Val]) extends Op
64+
final case class Extract(aggr: Val, indexes: Seq[Int]) extends Op
65+
final case class Insert(aggr: Val, value: Val, indexes: Seq[Int]) extends Op
5666
final case class Stackalloc(ty: Type, n: Val) extends Op
57-
final case class Bin(bin: nir.Bin, ty: Type, l: Val, r: Val) extends Pure
58-
final case class Comp(comp: nir.Comp, ty: Type, l: Val, r: Val) extends Pure
59-
final case class Conv(conv: nir.Conv, ty: Type, value: Val) extends Pure
60-
final case class Select(cond: Val, thenv: Val, elsev: Val) extends Pure
67+
final case class Bin(bin: nir.Bin, ty: Type, l: Val, r: Val) extends Op
68+
final case class Comp(comp: nir.Comp, ty: Type, l: Val, r: Val) extends Op
69+
final case class Conv(conv: nir.Conv, ty: Type, value: Val) extends Op
70+
final case class Select(cond: Val, thenv: Val, elsev: Val) extends Op
6171

6272
def Load(ty: Type, ptr: Val): Load =
6373
Load(ty, ptr, isVolatile = false)

tools/src/main/scala/scala/scalanative/codegen/Lower.scala

+92-26
Original file line numberDiff line numberDiff line change
@@ -427,36 +427,91 @@ object Lower {
427427
def genBinOp(buf: Buffer, n: Local, op: Op.Bin, unwind: Next): Unit = {
428428
import buf._
429429

430-
op match {
431-
// Detects taking remainder for division by -1 and replaces
432-
// it by division by 1 which can't overflow.
433-
//
434-
// We implement '%' (remainder) with LLVM's 'srem' and it
435-
// can overflow for cases:
436-
//
437-
// - Int.MinValue % -1
438-
// - Long.MinValue % -1
439-
//
440-
// E.g. On x86_64 'srem' might get translated to 'idiv'
441-
// which computes both quotient and remainder at once
442-
// and quotient can overflow.
443-
case sremBin @ Op.Bin(Bin.Srem, intType: Type.I, _, divisor)
444-
if intType.width == 32 || intType.width == 64 =>
445-
val safeDivisor = Val.Local(fresh(), intType)
446-
val thenL, elseL, contL = fresh()
430+
// LLVM's division by zero is undefined behaviour. We guard
431+
// the case when the divisor is zero and fail gracefully
432+
// by throwing an arithmetic exception.
433+
def checkDivisionByZero(op: Op.Bin): Unit = {
434+
val Op.Bin(bin, ty: Type.I, dividend, divisor) = op
435+
436+
val thenL, elseL = fresh()
437+
438+
val isZero =
439+
comp(Comp.Ieq, ty, divisor, Val.Zero(ty), unwind)
440+
branch(isZero, Next(thenL), Next(elseL))
441+
442+
label(thenL)
443+
call(throwDivisionByZeroTy,
444+
throwDivisionByZeroVal,
445+
Seq(Val.Null),
446+
unwind)
447+
unreachable
448+
449+
label(elseL)
450+
if (bin == Bin.Srem || bin == Bin.Sdiv) {
451+
checkDivisionOverflow(op)
452+
} else {
453+
let(n, op, unwind)
454+
}
455+
}
447456

448-
val isPossibleOverflow =
449-
let(Op.Comp(Comp.Ieq, intType, divisor, Val.Int(-1)), unwind)
450-
branch(isPossibleOverflow, Next(thenL), Next(elseL))
457+
// Detects taking remainder for division by -1 and replaces
458+
// it by division by 1 which can't overflow.
459+
//
460+
// We implement '%' (remainder) with LLVM's 'srem' and it
461+
// can overflow for cases:
462+
//
463+
// - Int.MinValue % -1
464+
// - Long.MinValue % -1
465+
//
466+
// E.g. On x86_64 'srem' might get translated to 'idiv'
467+
// which computes both quotient and remainder at once
468+
// and quotient can overflow.
469+
def checkDivisionOverflow(op: Op.Bin): Unit = {
470+
val Op.Bin(bin, ty: Type.I, dividend, divisor) = op
471+
472+
val mayOverflowL, noOverflowL, didOverflowL, resultL = fresh()
473+
474+
val minus1 = ty match {
475+
case Type.Int => Val.Int(-1)
476+
case Type.Long => Val.Long(-1L)
477+
case _ => util.unreachable
478+
}
479+
val minValue = ty match {
480+
case Type.Int => Val.Int(java.lang.Integer.MIN_VALUE)
481+
case Type.Long => Val.Long(java.lang.Long.MIN_VALUE)
482+
case _ => util.unreachable
483+
}
451484

452-
label(thenL)
453-
jump(contL, Seq(Val.Int(1)))
485+
val divisorIsMinus1 =
486+
let(Op.Comp(Comp.Ieq, ty, divisor, minus1), unwind)
487+
branch(divisorIsMinus1, Next(mayOverflowL), Next(noOverflowL))
454488

455-
label(elseL)
456-
jump(contL, Seq(divisor))
489+
label(mayOverflowL)
490+
val dividendIsMinValue =
491+
let(Op.Comp(Comp.Ieq, ty, dividend, minValue), unwind)
492+
branch(dividendIsMinValue, Next(didOverflowL), Next(noOverflowL))
493+
494+
label(didOverflowL)
495+
val overflowResult = bin match {
496+
case Bin.Srem => Val.Zero(ty)
497+
case Bin.Sdiv => minValue
498+
case _ => util.unreachable
499+
}
500+
jump(resultL, Seq(overflowResult))
501+
502+
label(noOverflowL)
503+
val noOverflowResult = let(op, unwind)
504+
jump(resultL, Seq(noOverflowResult))
457505

458-
label(contL, Seq(safeDivisor))
459-
let(n, sremBin.copy(r = safeDivisor), unwind)
506+
label(resultL, Seq(Val.Local(n, ty)))
507+
}
508+
509+
op match {
510+
case op @ Op.Bin(bin @ (Bin.Srem | Bin.Urem | Bin.Sdiv | Bin.Udiv),
511+
ty: Type.I,
512+
l,
513+
r) =>
514+
checkDivisionByZero(op)
460515

461516
case op =>
462517
let(n, op, unwind)
@@ -792,6 +847,16 @@ object Lower {
792847
Type.Function(Seq(Type.Ref(Global.Top("scala.scalanative.runtime.Array"))),
793848
Type.Int)
794849

850+
val throwDivisionByZeroTy =
851+
Type.Function(
852+
Seq(Type.Ref(Global.Top("scala.scalanative.runtime.package$"))),
853+
Type.Nothing)
854+
val throwDivisionByZero =
855+
Global.Member(Global.Top("scala.scalanative.runtime.package$"),
856+
Sig.Method("throwDivisionByZero", Seq(Type.Nothing)))
857+
val throwDivisionByZeroVal =
858+
Val.Global(throwDivisionByZero, Type.Ptr)
859+
795860
val injects: Seq[Defn] = {
796861
val buf = mutable.UnrolledBuffer.empty[Defn]
797862
buf += Defn.Declare(Attrs.None, allocSmallName, allocSig)
@@ -824,6 +889,7 @@ object Lower {
824889
buf ++= arrayApply.values
825890
buf ++= arrayUpdateGeneric.values
826891
buf ++= arrayUpdate.values
892+
buf += throwDivisionByZero
827893
buf
828894
}
829895
}

tools/src/main/scala/scala/scalanative/optimizer/pass/GlobalValueNumbering.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,10 @@ object GlobalValueNumbering extends PassCompanion {
101101
import Op._
102102
op match {
103103
// Always idempotent:
104-
case (_: Pure | _: Method | _: Dynmethod | _: As | _: Is | _: Copy |
105-
_: Sizeof | _: Module | _: Box | _: Unbox | _: Arraylength) =>
104+
case (_: Method | _: Dynmethod | _: As | _: Is | _: Copy | _: Sizeof |
105+
_: Module | _: Box | _: Unbox | _: Arraylength) =>
106+
true
107+
case op if op.isPure =>
106108
true
107109

108110
// Never idempotent:

tools/src/main/scala/scala/scalanative/sema/UseDef.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ object UseDef {
7272
pureWhitelist.contains(name)
7373
case Inst.Let(_, Op.Module(name), _) =>
7474
pureWhitelist.contains(name)
75-
case Inst.Let(_, _: Op.Pure, _) =>
75+
case Inst.Let(_, op, _) if op.isPure =>
7676
true
7777
case _ =>
7878
false
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package scala
2+
3+
object DivisionByZeroSuite extends tests.Suite {
4+
@noinline def byte1 = 1.toByte
5+
@noinline def char1 = 1.toChar
6+
@noinline def short1 = 1.toShort
7+
@noinline def int1 = 1
8+
@noinline def long1 = 1L
9+
@noinline def byte0 = 0.toByte
10+
@noinline def char0 = 0.toChar
11+
@noinline def short0 = 0.toShort
12+
@noinline def int0 = 0
13+
@noinline def long0 = 0L
14+
15+
test("byte / zero") {
16+
assertThrows[ArithmeticException](byte1 / byte0)
17+
assertThrows[ArithmeticException](byte1 / short0)
18+
assertThrows[ArithmeticException](byte1 / char0)
19+
assertThrows[ArithmeticException](byte1 / int0)
20+
assertThrows[ArithmeticException](byte1 / long0)
21+
}
22+
23+
test("byte % zero") {
24+
assertThrows[ArithmeticException](byte1 / byte0)
25+
assertThrows[ArithmeticException](byte1 / short0)
26+
assertThrows[ArithmeticException](byte1 / char0)
27+
assertThrows[ArithmeticException](byte1 / int0)
28+
assertThrows[ArithmeticException](byte1 / long0)
29+
}
30+
31+
test("short / zero") {
32+
assertThrows[ArithmeticException](short1 / byte0)
33+
assertThrows[ArithmeticException](short1 / short0)
34+
assertThrows[ArithmeticException](short1 / char0)
35+
assertThrows[ArithmeticException](short1 / int0)
36+
assertThrows[ArithmeticException](short1 / long0)
37+
}
38+
39+
test("short % zero") {
40+
assertThrows[ArithmeticException](short1 / byte0)
41+
assertThrows[ArithmeticException](short1 / short0)
42+
assertThrows[ArithmeticException](short1 / char0)
43+
assertThrows[ArithmeticException](short1 / int0)
44+
assertThrows[ArithmeticException](short1 / long0)
45+
}
46+
47+
test("char / zero") {
48+
assertThrows[ArithmeticException](char1 / byte0)
49+
assertThrows[ArithmeticException](char1 / short0)
50+
assertThrows[ArithmeticException](char1 / char0)
51+
assertThrows[ArithmeticException](char1 / int0)
52+
assertThrows[ArithmeticException](char1 / long0)
53+
}
54+
55+
test("char % zero") {
56+
assertThrows[ArithmeticException](char1 / byte0)
57+
assertThrows[ArithmeticException](char1 / short0)
58+
assertThrows[ArithmeticException](char1 / char0)
59+
assertThrows[ArithmeticException](char1 / int0)
60+
assertThrows[ArithmeticException](char1 / long0)
61+
}
62+
63+
test("int / zero") {
64+
assertThrows[ArithmeticException](int1 / byte0)
65+
assertThrows[ArithmeticException](int1 / short0)
66+
assertThrows[ArithmeticException](int1 / char0)
67+
assertThrows[ArithmeticException](int1 / int0)
68+
assertThrows[ArithmeticException](int1 / long0)
69+
}
70+
71+
test("int % zero") {
72+
assertThrows[ArithmeticException](int1 / byte0)
73+
assertThrows[ArithmeticException](int1 / short0)
74+
assertThrows[ArithmeticException](int1 / char0)
75+
assertThrows[ArithmeticException](int1 / int0)
76+
assertThrows[ArithmeticException](int1 / long0)
77+
}
78+
79+
test("long / zero") {
80+
assertThrows[ArithmeticException](long1 / byte0)
81+
assertThrows[ArithmeticException](long1 / short0)
82+
assertThrows[ArithmeticException](long1 / char0)
83+
assertThrows[ArithmeticException](long1 / int0)
84+
assertThrows[ArithmeticException](long1 / long0)
85+
}
86+
87+
test("long % zero") {
88+
assertThrows[ArithmeticException](long1 / byte0)
89+
assertThrows[ArithmeticException](long1 / short0)
90+
assertThrows[ArithmeticException](long1 / char0)
91+
assertThrows[ArithmeticException](long1 / int0)
92+
assertThrows[ArithmeticException](long1 / long0)
93+
}
94+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package scala
2+
3+
object DivisionOverflowSuite extends tests.Suite {
4+
@noinline def intMinus1 = -1
5+
@noinline def longMinus1 = -1L
6+
7+
test("Integer.MIN_VALUE / -1") {
8+
assert(
9+
(java.lang.Integer.MIN_VALUE / intMinus1) == java.lang.Integer.MIN_VALUE)
10+
}
11+
12+
test("Integer.MIN_VALUE % -1") {
13+
assert((java.lang.Integer.MIN_VALUE % intMinus1) == 0)
14+
}
15+
16+
test("Long.MIN_VALUE / -1") {
17+
assert((java.lang.Long.MIN_VALUE / longMinus1) == java.lang.Long.MIN_VALUE)
18+
}
19+
20+
test("Long.MIN_VALUE % -1") {
21+
assert((java.lang.Long.MIN_VALUE % longMinus1) == 0)
22+
}
23+
}

unit-tests/src/test/scala/scala/scalanative/IssuesSuite.scala

+4-5
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ object IssuesSuite extends tests.Suite {
4949
}
5050

5151
test("#314") {
52-
// Division by zero is undefined behavior in production mode.
53-
// Optimizer can assume it never happens and remove unused result.
52+
// Division by zero is defined behavior.
5453
assert {
5554
try {
5655
5 / 0
57-
true
56+
false
5857
} catch {
59-
case _: Throwable =>
60-
false
58+
case _: ArithmeticException =>
59+
true
6160
}
6261
}
6362
}

0 commit comments

Comments
 (0)