Skip to content

Commit d3d5a3b

Browse files
committed
cmd/compile: devirtualize interface calls with type assertions
1 parent ff27d27 commit d3d5a3b

File tree

4 files changed

+205
-9
lines changed

4 files changed

+205
-9
lines changed

src/cmd/compile/internal/devirtualize/devirtualize.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,8 @@ func StaticCall(call *ir.CallExpr) {
4040
}
4141

4242
sel := call.Fun.(*ir.SelectorExpr)
43-
r := ir.StaticValue(sel.X)
44-
if r.Op() != ir.OCONVIFACE {
45-
return
46-
}
47-
recv := r.(*ir.ConvExpr)
48-
49-
typ := recv.X.Type()
50-
if typ.IsInterface() {
43+
typ := ir.StaticType(sel.X)
44+
if typ == nil {
5145
return
5246
}
5347

src/cmd/compile/internal/ir/expr.go

+48-1
Original file line numberDiff line numberDiff line change
@@ -840,6 +840,37 @@ func IsAddressable(n Node) bool {
840840
return false
841841
}
842842

843+
var Implements = func(t, iface *types.Type) bool {
844+
panic("unreachable")
845+
}
846+
847+
// StaticType is like StaticValue but for types.
848+
func StaticType(n Node) *types.Type {
849+
out, typs := staticValue(n, true)
850+
851+
if out.Op() != OCONVIFACE {
852+
return nil
853+
}
854+
855+
recv := out.(*ConvExpr)
856+
857+
typ := recv.X.Type()
858+
if typ.IsInterface() {
859+
return nil
860+
}
861+
862+
// Make sure that every type assertion that involves interfaes is satisfied.
863+
for _, t := range typs {
864+
if t.IsInterface() {
865+
if !Implements(typ, t) {
866+
return nil
867+
}
868+
}
869+
}
870+
871+
return typ
872+
}
873+
843874
// StaticValue analyzes n to find the earliest expression that always
844875
// evaluates to the same value as n, which might be from an enclosing
845876
// function.
@@ -855,6 +886,16 @@ func IsAddressable(n Node) bool {
855886
// calling StaticValue on the "int(y)" expression returns the outer
856887
// "g()" expression.
857888
func StaticValue(n Node) Node {
889+
v, t := staticValue(n, false)
890+
if len(t) != 0 {
891+
base.Fatalf("len(t) != 0; len(t) = %v", len(t))
892+
}
893+
return v
894+
895+
}
896+
897+
func staticValue(n Node, forDevirt bool) (Node, []*types.Type) {
898+
typeAssertTypes := []*types.Type{}
858899
for {
859900
switch n1 := n.(type) {
860901
case *ConvExpr:
@@ -870,11 +911,17 @@ func StaticValue(n Node) Node {
870911
case *ParenExpr:
871912
n = n1.X
872913
continue
914+
case *TypeAssertExpr:
915+
if forDevirt && n1.Op() == ODOTTYPE {
916+
typeAssertTypes = append(typeAssertTypes, n1.Type())
917+
n = n1.X
918+
continue
919+
}
873920
}
874921

875922
n1 := staticValue1(n)
876923
if n1 == nil {
877-
return n
924+
return n, typeAssertTypes
878925
}
879926
n = n1
880927
}

src/cmd/compile/internal/typecheck/subr.go

+4
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,10 @@ func Implements(t, iface *types.Type) bool {
612612
return implements(t, iface, &missing, &have, &ptr)
613613
}
614614

615+
func init() {
616+
ir.Implements = Implements
617+
}
618+
615619
// ImplementsExplain reports whether t implements the interface iface. t can be
616620
// an interface, a type parameter, or a concrete type. If t does not implement
617621
// iface, a non-empty string is returned explaining why.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// errorcheck -0 -m
2+
3+
// Copyright 2025 The Go Authors. All rights reserved.
4+
// Use of this source code is governed by a BSD-style
5+
// license that can be found in the LICENSE file.
6+
7+
package escape
8+
9+
import (
10+
"crypto/sha256"
11+
"encoding"
12+
"hash"
13+
"io"
14+
)
15+
16+
type M interface{ M() }
17+
18+
type A interface{ A() }
19+
20+
type C interface{ C() }
21+
22+
type Impl struct{}
23+
24+
func (*Impl) M() {} // ERROR "can inline"
25+
26+
func (*Impl) A() {} // ERROR "can inline"
27+
28+
type CImpl struct{}
29+
30+
func (CImpl) C() {} // ERROR "can inline"
31+
32+
func t() {
33+
var a M = &Impl{} // ERROR "&Impl{} does not escape"
34+
35+
a.(M).M() // ERROR "devirtualizing a.\(M\).M" "inlining call"
36+
a.(A).A() // ERROR "devirtualizing a.\(A\).A" "inlining call"
37+
a.(*Impl).M() // ERROR "inlining call"
38+
a.(*Impl).A() // ERROR "inlining call"
39+
40+
v := a.(M)
41+
v.M() // ERROR "devirtualizing v.M" "inlining call"
42+
v.(A).A() // ERROR "devirtualizing v.\(A\).A" "inlining call"
43+
v.(*Impl).A() // ERROR "inlining call"
44+
v.(*Impl).M() // ERROR "inlining call"
45+
46+
v2 := a.(A)
47+
v2.A() // ERROR "devirtualizing v2.A" "inlining call"
48+
v2.(M).M() // ERROR "devirtualizing v2.\(M\).M" "inlining call"
49+
v2.(*Impl).A() // ERROR "inlining call"
50+
v2.(*Impl).M() // ERROR "inlining call"
51+
52+
a.(M).(A).A() // ERROR "devirtualizing a.\(M\).\(A\).A" "inlining call"
53+
a.(A).(M).M() // ERROR "devirtualizing a.\(A\).\(M\).M" "inlining call"
54+
55+
a.(M).(A).(*Impl).A() // ERROR "inlining call"
56+
a.(A).(M).(*Impl).M() // ERROR "inlining call"
57+
58+
{
59+
var a C = &CImpl{} // ERROR "does not escape"
60+
a.(any).(C).C() // ERROR "devirtualizing" "inlining"
61+
a.(any).(*CImpl).C() // ERROR "inlining"
62+
}
63+
}
64+
65+
// TODO: these type assertions could also be devirtualized.
66+
func t2() {
67+
{
68+
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
69+
if v, ok := a.(M); ok {
70+
v.M()
71+
}
72+
}
73+
{
74+
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
75+
if v, ok := a.(A); ok {
76+
v.A()
77+
}
78+
}
79+
{
80+
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
81+
v, ok := a.(M)
82+
if ok {
83+
v.M()
84+
}
85+
}
86+
{
87+
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
88+
v, ok := a.(A)
89+
if ok {
90+
v.A()
91+
}
92+
}
93+
{
94+
var a M = &Impl{} // ERROR "does not escape"
95+
v, ok := a.(*Impl)
96+
if ok {
97+
v.A() // ERROR "inlining"
98+
v.M() // ERROR "inlining"
99+
}
100+
}
101+
{
102+
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
103+
v, _ := a.(M)
104+
v.M()
105+
}
106+
{
107+
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
108+
v, _ := a.(A)
109+
v.A()
110+
}
111+
{
112+
var a M = &Impl{} // ERROR "does not escape"
113+
v, _ := a.(*Impl)
114+
v.A() // ERROR "inlining"
115+
v.M() // ERROR "inlining"
116+
}
117+
}
118+
119+
//go:noinline
120+
func testInvalidAsserts() {
121+
{
122+
var a M = &Impl{} // ERROR "escapes"
123+
a.(C).C() // this will panic
124+
a.(any).(C).C() // this will panic
125+
}
126+
{
127+
var a C = &CImpl{} // ERROR "escapes"
128+
a.(M).M() // this will panic
129+
a.(any).(M).M() // this will panic
130+
}
131+
{
132+
var a C = &CImpl{} // ERROR "does not escape"
133+
134+
// this will panic
135+
a.(M).(*Impl).M() // ERROR "inlining"
136+
137+
// this will panic
138+
a.(any).(M).(*Impl).M() // ERROR "inlining"
139+
}
140+
}
141+
142+
func testSha256() {
143+
h := sha256.New() // ERROR "inlining call" "does not escape"
144+
h.Write(nil) // ERROR "devirtualizing"
145+
h.(io.Writer).Write(nil) // ERROR "devirtualizing"
146+
h.(hash.Hash).Write(nil) // ERROR "devirtualizing"
147+
h.(encoding.BinaryUnmarshaler).UnmarshalBinary(nil) // ERROR "devirtualizing"
148+
149+
h2 := sha256.New() // ERROR "escapes" "inlining call"
150+
h2.(M).M() // this will panic
151+
}

0 commit comments

Comments
 (0)