Skip to content

Commit e0c030c

Browse files
authored
Emit switch bytecode when matching unions of a switchable type (#20411)
Fixes #20410
2 parents 7d559ad + 7279bf7 commit e0c030c

File tree

2 files changed

+128
-8
lines changed

2 files changed

+128
-8
lines changed

Diff for: compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

+8-8
Original file line numberDiff line numberDiff line change
@@ -818,11 +818,11 @@ object PatternMatcher {
818818
*/
819819
private def collectSwitchCases(scrutinee: Tree, plan: SeqPlan): List[(List[Tree], Plan)] = {
820820
def isSwitchableType(tpe: Type): Boolean =
821-
(tpe isRef defn.IntClass) ||
822-
(tpe isRef defn.ByteClass) ||
823-
(tpe isRef defn.ShortClass) ||
824-
(tpe isRef defn.CharClass) ||
825-
(tpe isRef defn.StringClass)
821+
(tpe <:< defn.IntType) ||
822+
(tpe <:< defn.ByteType) ||
823+
(tpe <:< defn.ShortType) ||
824+
(tpe <:< defn.CharType) ||
825+
(tpe <:< defn.StringType)
826826

827827
val seen = mutable.Set[Any]()
828828

@@ -872,7 +872,7 @@ object PatternMatcher {
872872
(Nil, plan) :: Nil
873873
}
874874

875-
if (isSwitchableType(scrutinee.tpe.widen)) recur(plan)
875+
if (isSwitchableType(scrutinee.tpe)) recur(plan)
876876
else Nil
877877
}
878878

@@ -893,8 +893,8 @@ object PatternMatcher {
893893
*/
894894

895895
val (primScrutinee, scrutineeTpe) =
896-
if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType)
897-
else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType)
896+
if (scrutinee.tpe <:< defn.IntType) (scrutinee, defn.IntType)
897+
else if (scrutinee.tpe <:< defn.StringType) (scrutinee, defn.StringType)
898898
else (scrutinee.select(nme.toInt), defn.IntType)
899899

900900
def primLiteral(lit: Tree): Tree =

Diff for: compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala

+120
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,126 @@ class DottyBytecodeTests extends DottyBytecodeTest {
158158
}
159159
}
160160

161+
@Test def switchOnUnionOfInts = {
162+
val source =
163+
"""
164+
|object Foo {
165+
| def foo(x: 1 | 2 | 3 | 4 | 5) = x match {
166+
| case 1 => println(3)
167+
| case 2 | 3 => println(2)
168+
| case 4 => println(1)
169+
| case 5 => println(0)
170+
| }
171+
|}
172+
""".stripMargin
173+
174+
checkBCode(source) { dir =>
175+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
176+
val moduleNode = loadClassNode(moduleIn.input)
177+
val methodNode = getMethod(moduleNode, "foo")
178+
assert(verifySwitch(methodNode))
179+
}
180+
}
181+
182+
@Test def switchOnUnionOfStrings = {
183+
val source =
184+
"""
185+
|object Foo {
186+
| def foo(s: "one" | "two" | "three" | "four" | "five") = s match {
187+
| case "one" => println(3)
188+
| case "two" | "three" => println(2)
189+
| case "four" | "five" => println(1)
190+
| case _ => println(0)
191+
| }
192+
|}
193+
""".stripMargin
194+
195+
checkBCode(source) { dir =>
196+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
197+
val moduleNode = loadClassNode(moduleIn.input)
198+
val methodNode = getMethod(moduleNode, "foo")
199+
assert(verifySwitch(methodNode))
200+
}
201+
}
202+
203+
@Test def switchOnUnionOfChars = {
204+
val source =
205+
"""
206+
|object Foo {
207+
| def foo(ch: 'a' | 'b' | 'c' | 'd' | 'e'): Int = ch match {
208+
| case 'a' => 1
209+
| case 'b' => 2
210+
| case 'c' => 3
211+
| case 'd' => 4
212+
| case 'e' => 5
213+
| }
214+
|}
215+
""".stripMargin
216+
217+
checkBCode(source) { dir =>
218+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
219+
val moduleNode = loadClassNode(moduleIn.input)
220+
val methodNode = getMethod(moduleNode, "foo")
221+
assert(verifySwitch(methodNode))
222+
}
223+
}
224+
225+
@Test def switchOnUnionOfIntSingletons = {
226+
val source =
227+
"""
228+
|object Foo {
229+
| final val One = 1
230+
| final val Two = 2
231+
| final val Three = 3
232+
| final val Four = 4
233+
| final val Five = 5
234+
| type Values = One.type | Two.type | Three.type | Four.type | Five.type
235+
|
236+
| def foo(s: Values) = s match {
237+
| case One => println(3)
238+
| case Two | Three => println(2)
239+
| case Four => println(1)
240+
| case Five => println(0)
241+
| }
242+
|}
243+
""".stripMargin
244+
245+
checkBCode(source) { dir =>
246+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
247+
val moduleNode = loadClassNode(moduleIn.input)
248+
val methodNode = getMethod(moduleNode, "foo")
249+
assert(verifySwitch(methodNode))
250+
}
251+
}
252+
253+
@Test def switchOnUnionOfStringSingletons = {
254+
val source =
255+
"""
256+
|object Foo {
257+
| final val One = "one"
258+
| final val Two = "two"
259+
| final val Three = "three"
260+
| final val Four = "four"
261+
| final val Five = "five"
262+
| type Values = One.type | Two.type | Three.type | Four.type | Five.type
263+
|
264+
| def foo(s: Values) = s match {
265+
| case One => println(3)
266+
| case Two | Three => println(2)
267+
| case Four => println(1)
268+
| case Five => println(0)
269+
| }
270+
|}
271+
""".stripMargin
272+
273+
checkBCode(source) { dir =>
274+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
275+
val moduleNode = loadClassNode(moduleIn.input)
276+
val methodNode = getMethod(moduleNode, "foo")
277+
assert(verifySwitch(methodNode))
278+
}
279+
}
280+
161281
@Test def matchWithDefaultNoThrowMatchError = {
162282
val source =
163283
"""class Test {

0 commit comments

Comments
 (0)