Skip to content

Commit 4d8454c

Browse files
committed
Rust: Handle associated types in trait methods
1 parent f9ff92a commit 4d8454c

File tree

6 files changed

+135
-20
lines changed

6 files changed

+135
-20
lines changed

rust/ql/lib/codeql/rust/elements/internal/TraitImpl.qll

+7
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,12 @@ module Impl {
2626
*/
2727
class Trait extends Generated::Trait {
2828
override string toStringImpl() { result = "trait " + this.getName().getText() }
29+
30+
int getNumberOfGenericParams() {
31+
result = this.getGenericParamList().getNumberOfGenericParams()
32+
or
33+
not this.hasGenericParamList() and
34+
result = 0
35+
}
2936
}
3037
}

rust/ql/lib/codeql/rust/internal/Type.qll

+63-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
private import rust
44
private import PathResolution
5-
private import TypeInference
65
private import TypeMention
76
private import codeql.rust.internal.CachedStages
7+
private import codeql.rust.elements.internal.generated.Raw
8+
private import codeql.rust.elements.internal.generated.Synth
89

910
cached
1011
newtype TType =
@@ -15,6 +16,7 @@ newtype TType =
1516
TArrayType() or // todo: add size?
1617
TRefType() or // todo: add mut?
1718
TTypeParamTypeParameter(TypeParam t) or
19+
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getADescendant() = t } or
1820
TRefTypeParameter() or
1921
TSelfTypeParameter(Trait t)
2022

@@ -144,6 +146,9 @@ class TraitType extends Type, TTrait {
144146

145147
override TypeParameter getTypeParameter(int i) {
146148
result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i))
149+
or
150+
result =
151+
any(AssociatedTypeTypeParameter param | param.getTrait() = trait and param.getIndex() = i)
147152
}
148153

149154
pragma[nomagic]
@@ -297,6 +302,14 @@ abstract class TypeParameter extends Type {
297302
override TypeParameter getTypeParameter(int i) { none() }
298303
}
299304

305+
private class RawTypeParameter = @type_param or @trait or @type_alias;
306+
307+
private predicate id(RawTypeParameter x, RawTypeParameter y) { x = y }
308+
309+
private predicate idOfRaw(RawTypeParameter x, int y) = equivalenceRelation(id/2)(x, y)
310+
311+
int idOfTypeParameterAstNode(AstNode node) { idOfRaw(Synth::convertAstNodeToRaw(node), result) }
312+
300313
/** A type parameter from source code. */
301314
class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
302315
private TypeParam typeParam;
@@ -320,6 +333,55 @@ class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
320333
}
321334
}
322335

336+
/** Gets type alias that is the `i`th type parameter of `trait`. */
337+
predicate traitAliasIndex(Trait trait, int i, TypeAlias typeAlias) {
338+
typeAlias =
339+
rank[i + 1 - trait.getNumberOfGenericParams()](TypeAlias alias |
340+
trait.(TraitItemNode).getADescendant() = alias
341+
|
342+
alias order by idOfTypeParameterAstNode(alias)
343+
)
344+
}
345+
346+
/**
347+
* A type parameter corresponding to an associated type in a trait.
348+
*
349+
* We treat associated type declarations in traits as type parameters. E.g., a
350+
* trait such as
351+
* ```rust
352+
* trait ATrait {
353+
* type AssociatedType;
354+
* // ...
355+
* }
356+
* ```
357+
* is treated as if it where
358+
* ```rust
359+
* trait ATrait<AssociatedType> {
360+
* // ...
361+
* }
362+
* ```
363+
*/
364+
class AssociatedTypeTypeParameter extends TypeParameter, TAssociatedTypeTypeParameter {
365+
private TypeAlias typeAlias;
366+
367+
AssociatedTypeTypeParameter() { this = TAssociatedTypeTypeParameter(typeAlias) }
368+
369+
TypeAlias getTypeAlias() { result = typeAlias }
370+
371+
/** Gets the trait that contains this associated type declaration. */
372+
TraitItemNode getTrait() { result.getADescendant() = typeAlias }
373+
374+
int getIndex() { traitAliasIndex(this.getTrait(), result, typeAlias) }
375+
376+
override Function getMethod(string name) { none() }
377+
378+
override string toString() { result = typeAlias.getName().getText() }
379+
380+
override Location getLocation() { result = typeAlias.getLocation() }
381+
382+
override TypeMention getABaseTypeMention() { none() }
383+
}
384+
323385
/** An implicit reference type parameter. */
324386
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
325387
override Function getMethod(string name) { none() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

+13-14
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,21 @@ private module Input1 implements InputSig1<Location> {
4040

4141
private newtype TTypeParameterPosition =
4242
TTypeParamTypeParameterPosition(TypeParam tp) or
43-
TSelfTypeParameterPosition()
43+
TImplicitTypeParameterPosition()
4444

4545
class TypeParameterPosition extends TTypeParameterPosition {
4646
TypeParam asTypeParam() { this = TTypeParamTypeParameterPosition(result) }
4747

48-
predicate isSelf() { this = TSelfTypeParameterPosition() }
48+
/**
49+
* Holds if this is the implicit type parameter position used to represent
50+
* parameters that are never passed explicitly as arguments.
51+
*/
52+
predicate isImplicit() { this = TImplicitTypeParameterPosition() }
4953

5054
string toString() {
5155
result = this.asTypeParam().toString()
5256
or
53-
result = "Self" and this.isSelf()
57+
result = "Implicit" and this.isImplicit()
5458
}
5559
}
5660

@@ -69,15 +73,6 @@ private module Input1 implements InputSig1<Location> {
6973
apos.asMethodTypeArgumentPosition() = ppos.asTypeParam().getPosition()
7074
}
7175

72-
/** A raw AST node that might correspond to a type parameter. */
73-
private class RawTypeParameter = @type_param or @trait;
74-
75-
private predicate id(RawTypeParameter x, RawTypeParameter y) { x = y }
76-
77-
private predicate idOfRaw(RawTypeParameter x, int y) = equivalenceRelation(id/2)(x, y)
78-
79-
private int idOf(AstNode node) { idOfRaw(Synth::convertAstNodeToRaw(node), result) }
80-
8176
int getTypeParameterId(TypeParameter tp) {
8277
tp =
8378
rank[result](TypeParameter tp0, int kind, int id |
@@ -86,8 +81,9 @@ private module Input1 implements InputSig1<Location> {
8681
id = 0
8782
or
8883
kind = 1 and
89-
exists(AstNode node | id = idOf(node) |
84+
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
9085
node = tp0.(TypeParamTypeParameter).getTypeParam() or
86+
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
9187
node = tp0.(SelfTypeParameter).getTrait()
9288
)
9389
|
@@ -500,7 +496,10 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
500496
exists(TraitItemNode trait | this = trait.getAnAssocItem() |
501497
typeParamMatchPosition(trait.getTypeParam(_), result, ppos)
502498
or
503-
ppos.isSelf() and result = TSelfTypeParameter(trait)
499+
ppos.isImplicit() and result = TSelfTypeParameter(trait)
500+
or
501+
ppos.isImplicit() and
502+
result.(AssociatedTypeTypeParameter).getTrait() = trait
504503
)
505504
}
506505

rust/ql/lib/codeql/rust/internal/TypeMention.qll

+38-2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,27 @@ class PathMention extends TypeMention, Path {
7575
this = node.getASelfPath() and
7676
result = node.(ImplItemNode).getSelfPath().getSegment().getGenericArgList().getTypeArg(i)
7777
)
78+
or
79+
// If `this` is the trait of an `impl` block then any associated types
80+
// defined in the `impl` block are type arguments to the trait.
81+
//
82+
// For instance, for a trait implementation like this
83+
// ```rust
84+
// impl MyTrait for MyType {
85+
// ^^^^^^^ this
86+
// type AssociatedType = i64
87+
// ^^^ result
88+
// // ...
89+
// }
90+
// ```
91+
// the rhs. of the type alias is a type argument to the trait.
92+
exists(ImplItemNode impl, AssociatedTypeTypeParameter param, TypeAlias alias |
93+
this = impl.getTraitPath() and
94+
param.getTrait() = resolvePath(this) and
95+
alias = impl.getASuccessor(param.getTypeAlias().getName().getText()) and
96+
result = alias.getTypeRepr() and
97+
param.getIndex() = i
98+
)
7899
}
79100

80101
override Type resolveType() {
@@ -93,7 +114,11 @@ class PathMention extends TypeMention, Path {
93114
or
94115
result = TTypeParamTypeParameter(i)
95116
or
96-
result = i.(TypeAlias).getTypeRepr().(TypeReprMention).resolveType()
117+
exists(TypeAlias alias | alias = i |
118+
result.(AssociatedTypeTypeParameter).getTypeAlias() = alias
119+
or
120+
result = alias.getTypeRepr().(TypeReprMention).resolveType()
121+
)
97122
)
98123
}
99124
}
@@ -106,6 +131,13 @@ class TypeParamMention extends TypeMention, TypeParam {
106131
override Type resolveType() { result = TTypeParamTypeParameter(this) }
107132
}
108133

134+
// Used to represent implicit associated type type arguments in traits.
135+
class TypeAliasMention extends TypeMention, TypeAlias {
136+
override TypeReprMention getTypeArgument(int i) { none() }
137+
138+
override Type resolveType() { result = TAssociatedTypeTypeParameter(this) }
139+
}
140+
109141
/**
110142
* Holds if the `i`th type argument of `selfPath`, belonging to `impl`, resolves
111143
* to type parameter `tp`.
@@ -157,7 +189,11 @@ class ImplMention extends TypeMention, ImplItemNode {
157189
}
158190

159191
class TraitMention extends TypeMention, TraitItemNode {
160-
override TypeMention getTypeArgument(int i) { result = this.getTypeParam(i) }
192+
override TypeMention getTypeArgument(int i) {
193+
result = this.getTypeParam(i)
194+
or
195+
traitAliasIndex(this, i, result)
196+
}
161197

162198
override Type resolveType() { result = TTrait(this) }
163199
}

rust/ql/test/library-tests/type-inference/main.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ mod trait_associated_type {
351351
Self::AssociatedType: Default,
352352
Self: Sized,
353353
{
354-
self.m1(); // $ method=MyTrait::m1
354+
self.m1(); // $ method=MyTrait::m1 type=self.m1():AssociatedType
355355
Self::AssociatedType::default()
356356
}
357357
}
@@ -424,7 +424,7 @@ mod trait_associated_type {
424424

425425
let x2 = S;
426426
// Call to default method in `trait` block
427-
let y = x2.m2(); // $ method=m2 MISSING: type=y:AT
427+
let y = x2.m2(); // $ method=m2 type=y:AT
428428
println!("{:?}", y);
429429

430430
let x3 = S;
@@ -440,7 +440,7 @@ mod trait_associated_type {
440440
let x5 = S2;
441441
println!("{:?}", x5.m1()); // $ method=m1 MISSING: type=x5.m1():A.S2
442442
let x6 = S2;
443-
println!("{:?}", x6.m2()); // $ method=m2 MISSING: type=x6.m2():A.S2
443+
println!("{:?}", x6.m2()); // $ method=m2 type=x6.m2():A.S2
444444
}
445445
}
446446

rust/ql/test/library-tests/type-inference/type-inference.expected

+11
Original file line numberDiff line numberDiff line change
@@ -336,19 +336,25 @@ inferType
336336
| main.rs:339:13:339:22 | self.field | | main.rs:337:10:337:10 | A |
337337
| main.rs:347:15:347:18 | SelfParam | | main.rs:343:5:357:5 | Self [trait MyTrait] |
338338
| main.rs:349:15:349:18 | SelfParam | | main.rs:343:5:357:5 | Self [trait MyTrait] |
339+
| main.rs:353:9:356:9 | { ... } | | main.rs:344:9:344:28 | AssociatedType |
339340
| main.rs:354:13:354:16 | self | | main.rs:343:5:357:5 | Self [trait MyTrait] |
341+
| main.rs:354:13:354:21 | self.m1() | | main.rs:344:9:344:28 | AssociatedType |
342+
| main.rs:355:13:355:43 | ...::default(...) | | main.rs:344:9:344:28 | AssociatedType |
340343
| main.rs:363:19:363:23 | SelfParam | | file://:0:0:0:0 | & |
341344
| main.rs:363:19:363:23 | SelfParam | &T | main.rs:359:5:369:5 | Self [trait MyTraitAssoc2] |
342345
| main.rs:363:26:363:26 | a | | main.rs:363:16:363:16 | A |
343346
| main.rs:365:22:365:26 | SelfParam | | file://:0:0:0:0 | & |
344347
| main.rs:365:22:365:26 | SelfParam | &T | main.rs:359:5:369:5 | Self [trait MyTraitAssoc2] |
345348
| main.rs:365:29:365:29 | a | | main.rs:365:19:365:19 | A |
346349
| main.rs:365:35:365:35 | b | | main.rs:365:19:365:19 | A |
350+
| main.rs:365:75:368:9 | { ... } | | main.rs:360:9:360:52 | GenericAssociatedType |
347351
| main.rs:366:13:366:16 | self | | file://:0:0:0:0 | & |
348352
| main.rs:366:13:366:16 | self | &T | main.rs:359:5:369:5 | Self [trait MyTraitAssoc2] |
353+
| main.rs:366:13:366:23 | self.put(...) | | main.rs:360:9:360:52 | GenericAssociatedType |
349354
| main.rs:366:22:366:22 | a | | main.rs:365:19:365:19 | A |
350355
| main.rs:367:13:367:16 | self | | file://:0:0:0:0 | & |
351356
| main.rs:367:13:367:16 | self | &T | main.rs:359:5:369:5 | Self [trait MyTraitAssoc2] |
357+
| main.rs:367:13:367:23 | self.put(...) | | main.rs:360:9:360:52 | GenericAssociatedType |
352358
| main.rs:367:22:367:22 | b | | main.rs:365:19:365:19 | A |
353359
| main.rs:384:15:384:18 | SelfParam | | main.rs:371:5:372:13 | S |
354360
| main.rs:384:45:386:9 | { ... } | | main.rs:377:5:378:14 | AT |
@@ -380,7 +386,10 @@ inferType
380386
| main.rs:423:26:423:32 | x1.m1() | | main.rs:377:5:378:14 | AT |
381387
| main.rs:425:13:425:14 | x2 | | main.rs:371:5:372:13 | S |
382388
| main.rs:425:18:425:18 | S | | main.rs:371:5:372:13 | S |
389+
| main.rs:427:13:427:13 | y | | main.rs:377:5:378:14 | AT |
383390
| main.rs:427:17:427:18 | x2 | | main.rs:371:5:372:13 | S |
391+
| main.rs:427:17:427:23 | x2.m2() | | main.rs:377:5:378:14 | AT |
392+
| main.rs:428:26:428:26 | y | | main.rs:377:5:378:14 | AT |
384393
| main.rs:430:13:430:14 | x3 | | main.rs:371:5:372:13 | S |
385394
| main.rs:430:18:430:18 | S | | main.rs:371:5:372:13 | S |
386395
| main.rs:432:26:432:27 | x3 | | main.rs:371:5:372:13 | S |
@@ -394,6 +403,8 @@ inferType
394403
| main.rs:442:13:442:14 | x6 | | main.rs:374:5:375:14 | S2 |
395404
| main.rs:442:18:442:19 | S2 | | main.rs:374:5:375:14 | S2 |
396405
| main.rs:443:26:443:27 | x6 | | main.rs:374:5:375:14 | S2 |
406+
| main.rs:443:26:443:32 | x6.m2() | | main.rs:332:5:335:5 | Wrapper |
407+
| main.rs:443:26:443:32 | x6.m2() | A | main.rs:374:5:375:14 | S2 |
397408
| main.rs:460:15:460:18 | SelfParam | | main.rs:448:5:452:5 | MyEnum |
398409
| main.rs:460:15:460:18 | SelfParam | A | main.rs:459:10:459:10 | T |
399410
| main.rs:460:26:465:9 | { ... } | | main.rs:459:10:459:10 | T |

0 commit comments

Comments
 (0)