@@ -280,23 +280,32 @@ impl<'a> BlockContext<'a> {
280
280
& self ,
281
281
handle : Handle < crate :: Expression > ,
282
282
valid_expressions : & HandleSet < crate :: Expression > ,
283
- ) -> Result < & crate :: TypeInner , WithSpan < ExpressionError > > {
283
+ ) -> Result < & TypeResolution , WithSpan < ExpressionError > > {
284
284
if !valid_expressions. contains ( handle) {
285
285
Err ( ExpressionError :: NotInScope . with_span_handle ( handle, self . expressions ) )
286
286
} else {
287
- Ok ( self . info [ handle] . ty . inner_with ( self . types ) )
287
+ Ok ( & self . info [ handle] . ty )
288
288
}
289
289
}
290
290
291
291
fn resolve_type (
292
292
& self ,
293
293
handle : Handle < crate :: Expression > ,
294
294
valid_expressions : & HandleSet < crate :: Expression > ,
295
- ) -> Result < & crate :: TypeInner , WithSpan < FunctionError > > {
295
+ ) -> Result < & TypeResolution , WithSpan < FunctionError > > {
296
296
self . resolve_type_impl ( handle, valid_expressions)
297
297
. map_err_inner ( |source| FunctionError :: Expression { handle, source } . with_span ( ) )
298
298
}
299
299
300
+ fn resolve_type_inner (
301
+ & self ,
302
+ handle : Handle < crate :: Expression > ,
303
+ valid_expressions : & HandleSet < crate :: Expression > ,
304
+ ) -> Result < & crate :: TypeInner , WithSpan < FunctionError > > {
305
+ self . resolve_type ( handle, valid_expressions)
306
+ . map ( |tr| tr. inner_with ( self . types ) )
307
+ }
308
+
300
309
fn resolve_pointer_type ( & self , handle : Handle < crate :: Expression > ) -> & crate :: TypeInner {
301
310
self . info [ handle] . ty . inner_with ( self . types )
302
311
}
@@ -330,7 +339,7 @@ impl super::Validator {
330
339
. with_span_handle ( expr, context. expressions )
331
340
} ) ?;
332
341
let arg_inner = & context. types [ arg. ty ] . inner ;
333
- if !ty. non_struct_equivalent ( arg_inner, context. types ) {
342
+ if !ty. inner_with ( context . types ) . non_struct_equivalent ( arg_inner, context. types ) {
334
343
return Err ( CallError :: ArgumentType {
335
344
index,
336
345
required : arg. ty ,
@@ -393,7 +402,7 @@ impl super::Validator {
393
402
context : & BlockContext ,
394
403
) -> Result < ( ) , WithSpan < FunctionError > > {
395
404
// The `pointer` operand must be a pointer to an atomic value.
396
- let pointer_inner = context. resolve_type ( pointer, & self . valid_expression_set ) ?;
405
+ let pointer_inner = context. resolve_type_inner ( pointer, & self . valid_expression_set ) ?;
397
406
let crate :: TypeInner :: Pointer {
398
407
base : pointer_base,
399
408
space : pointer_space,
@@ -415,7 +424,7 @@ impl super::Validator {
415
424
} ;
416
425
417
426
// The `value` operand must be a scalar of the same type as the atomic.
418
- let value_inner = context. resolve_type ( value, & self . valid_expression_set ) ?;
427
+ let value_inner = context. resolve_type_inner ( value, & self . valid_expression_set ) ?;
419
428
let crate :: TypeInner :: Scalar ( value_scalar) = * value_inner else {
420
429
log:: error!( "Atomic operand type {:?}" , * value_inner) ;
421
430
return Err ( AtomicError :: InvalidOperand ( value)
@@ -543,7 +552,7 @@ impl super::Validator {
543
552
// The comparison value must be a scalar of the same type as the
544
553
// atomic we're operating on.
545
554
let compare_inner =
546
- context. resolve_type ( compare, & self . valid_expression_set ) ?;
555
+ context. resolve_type_inner ( compare, & self . valid_expression_set ) ?;
547
556
if !compare_inner. non_struct_equivalent ( value_inner, context. types ) {
548
557
log:: error!(
549
558
"Atomic exchange comparison has a different type from the value"
@@ -620,7 +629,7 @@ impl super::Validator {
620
629
result : Handle < crate :: Expression > ,
621
630
context : & BlockContext ,
622
631
) -> Result < ( ) , WithSpan < FunctionError > > {
623
- let argument_inner = context. resolve_type ( argument, & self . valid_expression_set ) ?;
632
+ let argument_inner = context. resolve_type_inner ( argument, & self . valid_expression_set ) ?;
624
633
625
634
let ( is_scalar, scalar) = match * argument_inner {
626
635
crate :: TypeInner :: Scalar ( scalar) => ( true , scalar) ,
@@ -695,7 +704,7 @@ impl super::Validator {
695
704
| crate :: GatherMode :: ShuffleDown ( index)
696
705
| crate :: GatherMode :: ShuffleUp ( index)
697
706
| crate :: GatherMode :: ShuffleXor ( index) => {
698
- let index_ty = context. resolve_type ( index, & self . valid_expression_set ) ?;
707
+ let index_ty = context. resolve_type_inner ( index, & self . valid_expression_set ) ?;
699
708
match * index_ty {
700
709
crate :: TypeInner :: Scalar ( crate :: Scalar :: U32 ) => { }
701
710
_ => {
@@ -710,7 +719,7 @@ impl super::Validator {
710
719
}
711
720
}
712
721
}
713
- let argument_inner = context. resolve_type ( argument, & self . valid_expression_set ) ?;
722
+ let argument_inner = context. resolve_type_inner ( argument, & self . valid_expression_set ) ?;
714
723
if !matches ! ( * argument_inner,
715
724
crate :: TypeInner :: Scalar ( scalar, .. ) | crate :: TypeInner :: Vector { scalar, .. }
716
725
if matches!( scalar. kind, crate :: ScalarKind :: Uint | crate :: ScalarKind :: Sint | crate :: ScalarKind :: Float )
@@ -802,7 +811,7 @@ impl super::Validator {
802
811
ref accept,
803
812
ref reject,
804
813
} => {
805
- match * context. resolve_type ( condition, & self . valid_expression_set ) ? {
814
+ match * context. resolve_type_inner ( condition, & self . valid_expression_set ) ? {
806
815
Ti :: Scalar ( crate :: Scalar {
807
816
kind : crate :: ScalarKind :: Bool ,
808
817
width : _,
@@ -820,7 +829,7 @@ impl super::Validator {
820
829
ref cases,
821
830
} => {
822
831
let uint = match context
823
- . resolve_type ( selector, & self . valid_expression_set ) ?
832
+ . resolve_type_inner ( selector, & self . valid_expression_set ) ?
824
833
. scalar_kind ( )
825
834
{
826
835
Some ( crate :: ScalarKind :: Uint ) => true ,
@@ -917,7 +926,7 @@ impl super::Validator {
917
926
. stages ;
918
927
919
928
if let Some ( condition) = break_if {
920
- match * context. resolve_type ( condition, & self . valid_expression_set ) ? {
929
+ match * context. resolve_type_inner ( condition, & self . valid_expression_set ) ? {
921
930
Ti :: Scalar ( crate :: Scalar {
922
931
kind : crate :: ScalarKind :: Bool ,
923
932
width : _,
@@ -961,7 +970,7 @@ impl super::Validator {
961
970
let okay = match ( value_ty, expected_ty) {
962
971
( None , None ) => true ,
963
972
( Some ( value_inner) , Some ( expected_inner) ) => {
964
- value_inner. non_struct_equivalent ( expected_inner, context. types )
973
+ value_inner. inner_with ( context . types ) . non_struct_equivalent ( expected_inner, context. types )
965
974
}
966
975
( _, _) => false ,
967
976
} ;
@@ -1027,7 +1036,7 @@ impl super::Validator {
1027
1036
}
1028
1037
}
1029
1038
1030
- let value_ty = context. resolve_type ( value, & self . valid_expression_set ) ?;
1039
+ let value_ty = context. resolve_type_inner ( value, & self . valid_expression_set ) ?;
1031
1040
match * value_ty {
1032
1041
Ti :: Image { .. } | Ti :: Sampler { .. } => {
1033
1042
return Err ( FunctionError :: InvalidStoreTexture {
@@ -1145,7 +1154,7 @@ impl super::Validator {
1145
1154
1146
1155
// The `coordinate` operand must be a vector of the appropriate size.
1147
1156
if context
1148
- . resolve_type ( coordinate, & self . valid_expression_set ) ?
1157
+ . resolve_type_inner ( coordinate, & self . valid_expression_set ) ?
1149
1158
. image_storage_coordinates ( )
1150
1159
. is_none_or ( |coord_dim| coord_dim != dim)
1151
1160
{
@@ -1167,7 +1176,7 @@ impl super::Validator {
1167
1176
// If present, `array_index` must be a scalar integer type.
1168
1177
if let Some ( expr) = array_index {
1169
1178
if !matches ! (
1170
- * context. resolve_type ( expr, & self . valid_expression_set) ?,
1179
+ * context. resolve_type_inner ( expr, & self . valid_expression_set) ?,
1171
1180
Ti :: Scalar ( crate :: Scalar {
1172
1181
kind: crate :: ScalarKind :: Sint | crate :: ScalarKind :: Uint ,
1173
1182
width: _,
@@ -1188,7 +1197,7 @@ impl super::Validator {
1188
1197
// The value we're writing had better match the scalar type
1189
1198
// for `image`'s format.
1190
1199
let actual_value_ty =
1191
- context. resolve_type ( value, & self . valid_expression_set ) ?;
1200
+ context. resolve_type_inner ( value, & self . valid_expression_set ) ?;
1192
1201
if actual_value_ty != & value_ty {
1193
1202
return Err ( FunctionError :: InvalidStoreValue {
1194
1203
actual : value,
@@ -1273,7 +1282,7 @@ impl super::Validator {
1273
1282
dim,
1274
1283
} => {
1275
1284
match context
1276
- . resolve_type ( coordinate, & self . valid_expression_set ) ?
1285
+ . resolve_type_inner ( coordinate, & self . valid_expression_set ) ?
1277
1286
. image_storage_coordinates ( )
1278
1287
{
1279
1288
Some ( coord_dim) if coord_dim == dim => { }
@@ -1293,7 +1302,9 @@ impl super::Validator {
1293
1302
. with_span_handle ( coordinate, context. expressions ) ) ;
1294
1303
}
1295
1304
if let Some ( expr) = array_index {
1296
- match * context. resolve_type ( expr, & self . valid_expression_set ) ? {
1305
+ match * context
1306
+ . resolve_type_inner ( expr, & self . valid_expression_set ) ?
1307
+ {
1297
1308
Ti :: Scalar ( crate :: Scalar {
1298
1309
kind : crate :: ScalarKind :: Sint | crate :: ScalarKind :: Uint ,
1299
1310
width : _,
@@ -1404,15 +1415,15 @@ impl super::Validator {
1404
1415
}
1405
1416
} ;
1406
1417
1407
- if * context. resolve_type ( value, & self . valid_expression_set ) ? != value_ty {
1418
+ if * context. resolve_type_inner ( value, & self . valid_expression_set ) ? != value_ty {
1408
1419
return Err ( FunctionError :: InvalidImageAtomicValue ( value)
1409
1420
. with_span_handle ( value, context. expressions ) ) ;
1410
1421
}
1411
1422
}
1412
1423
S :: WorkGroupUniformLoad { pointer, result } => {
1413
1424
stages &= super :: ShaderStages :: COMPUTE ;
1414
1425
let pointer_inner =
1415
- context. resolve_type ( pointer, & self . valid_expression_set ) ?;
1426
+ context. resolve_type_inner ( pointer, & self . valid_expression_set ) ?;
1416
1427
match * pointer_inner {
1417
1428
Ti :: Pointer {
1418
1429
space : AddressSpace :: WorkGroup ,
@@ -1468,9 +1479,10 @@ impl super::Validator {
1468
1479
acceleration_structure,
1469
1480
descriptor,
1470
1481
} => {
1471
- match * context
1472
- . resolve_type ( acceleration_structure, & self . valid_expression_set ) ?
1473
- {
1482
+ match * context. resolve_type_inner (
1483
+ acceleration_structure,
1484
+ & self . valid_expression_set ,
1485
+ ) ? {
1474
1486
Ti :: AccelerationStructure { vertex_return } => {
1475
1487
if ( !vertex_return) && rq_vertex_return {
1476
1488
return Err ( FunctionError :: MissingAccelerationStructureVertexReturn ( acceleration_structure, query) . with_span_static ( span, "invalid acceleration structure" ) ) ;
@@ -1483,8 +1495,8 @@ impl super::Validator {
1483
1495
. with_span_static ( span, "invalid acceleration structure" ) )
1484
1496
}
1485
1497
}
1486
- let desc_ty_given =
1487
- context . resolve_type ( descriptor, & self . valid_expression_set ) ?;
1498
+ let desc_ty_given = context
1499
+ . resolve_type_inner ( descriptor, & self . valid_expression_set ) ?;
1488
1500
let desc_ty_expected = context
1489
1501
. special_types
1490
1502
. ray_desc
@@ -1498,7 +1510,7 @@ impl super::Validator {
1498
1510
self . emit_expression ( result, context) ?;
1499
1511
}
1500
1512
crate :: RayQueryFunction :: GenerateIntersection { hit_t } => {
1501
- match * context. resolve_type ( hit_t, & self . valid_expression_set ) ? {
1513
+ match * context. resolve_type_inner ( hit_t, & self . valid_expression_set ) ? {
1502
1514
Ti :: Scalar ( crate :: Scalar {
1503
1515
kind : crate :: ScalarKind :: Float ,
1504
1516
width : _,
@@ -1534,7 +1546,7 @@ impl super::Validator {
1534
1546
}
1535
1547
if let Some ( predicate) = predicate {
1536
1548
let predicate_inner =
1537
- context. resolve_type ( predicate, & self . valid_expression_set ) ?;
1549
+ context. resolve_type_inner ( predicate, & self . valid_expression_set ) ?;
1538
1550
if !matches ! (
1539
1551
* predicate_inner,
1540
1552
crate :: TypeInner :: Scalar ( crate :: Scalar :: BOOL , )
0 commit comments