Skip to content

Commit 545d84f

Browse files
committed
Refactor BlockContext type resolution methods
Change `resolve_type` and `resolve_type_impl` to return `TypeResolution`s. Add a new method `resolve_type_inner` that returns a `TypeInner` (i.e. what `resolve_type` used to do).
1 parent bbf34ea commit 545d84f

File tree

1 file changed

+41
-29
lines changed

1 file changed

+41
-29
lines changed

naga/src/valid/function.rs

+41-29
Original file line numberDiff line numberDiff line change
@@ -280,23 +280,32 @@ impl<'a> BlockContext<'a> {
280280
&self,
281281
handle: Handle<crate::Expression>,
282282
valid_expressions: &HandleSet<crate::Expression>,
283-
) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
283+
) -> Result<&TypeResolution, WithSpan<ExpressionError>> {
284284
if !valid_expressions.contains(handle) {
285285
Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
286286
} else {
287-
Ok(self.info[handle].ty.inner_with(self.types))
287+
Ok(&self.info[handle].ty)
288288
}
289289
}
290290

291291
fn resolve_type(
292292
&self,
293293
handle: Handle<crate::Expression>,
294294
valid_expressions: &HandleSet<crate::Expression>,
295-
) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
295+
) -> Result<&TypeResolution, WithSpan<FunctionError>> {
296296
self.resolve_type_impl(handle, valid_expressions)
297297
.map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
298298
}
299299

300+
fn resolve_type_inner(
301+
&self,
302+
handle: Handle<crate::Expression>,
303+
valid_expressions: &HandleSet<crate::Expression>,
304+
) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
305+
self.resolve_type(handle, valid_expressions)
306+
.map(|tr| tr.inner_with(self.types))
307+
}
308+
300309
fn resolve_pointer_type(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
301310
self.info[handle].ty.inner_with(self.types)
302311
}
@@ -330,7 +339,7 @@ impl super::Validator {
330339
.with_span_handle(expr, context.expressions)
331340
})?;
332341
let arg_inner = &context.types[arg.ty].inner;
333-
if !ty.non_struct_equivalent(arg_inner, context.types) {
342+
if !ty.inner_with(context.types).non_struct_equivalent(arg_inner, context.types) {
334343
return Err(CallError::ArgumentType {
335344
index,
336345
required: arg.ty,
@@ -393,7 +402,7 @@ impl super::Validator {
393402
context: &BlockContext,
394403
) -> Result<(), WithSpan<FunctionError>> {
395404
// The `pointer` operand must be a pointer to an atomic value.
396-
let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
405+
let pointer_inner = context.resolve_type_inner(pointer, &self.valid_expression_set)?;
397406
let crate::TypeInner::Pointer {
398407
base: pointer_base,
399408
space: pointer_space,
@@ -415,7 +424,7 @@ impl super::Validator {
415424
};
416425

417426
// The `value` operand must be a scalar of the same type as the atomic.
418-
let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
427+
let value_inner = context.resolve_type_inner(value, &self.valid_expression_set)?;
419428
let crate::TypeInner::Scalar(value_scalar) = *value_inner else {
420429
log::error!("Atomic operand type {:?}", *value_inner);
421430
return Err(AtomicError::InvalidOperand(value)
@@ -543,7 +552,7 @@ impl super::Validator {
543552
// The comparison value must be a scalar of the same type as the
544553
// atomic we're operating on.
545554
let compare_inner =
546-
context.resolve_type(compare, &self.valid_expression_set)?;
555+
context.resolve_type_inner(compare, &self.valid_expression_set)?;
547556
if !compare_inner.non_struct_equivalent(value_inner, context.types) {
548557
log::error!(
549558
"Atomic exchange comparison has a different type from the value"
@@ -620,7 +629,7 @@ impl super::Validator {
620629
result: Handle<crate::Expression>,
621630
context: &BlockContext,
622631
) -> Result<(), WithSpan<FunctionError>> {
623-
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
632+
let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
624633

625634
let (is_scalar, scalar) = match *argument_inner {
626635
crate::TypeInner::Scalar(scalar) => (true, scalar),
@@ -695,7 +704,7 @@ impl super::Validator {
695704
| crate::GatherMode::ShuffleDown(index)
696705
| crate::GatherMode::ShuffleUp(index)
697706
| crate::GatherMode::ShuffleXor(index) => {
698-
let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
707+
let index_ty = context.resolve_type_inner(index, &self.valid_expression_set)?;
699708
match *index_ty {
700709
crate::TypeInner::Scalar(crate::Scalar::U32) => {}
701710
_ => {
@@ -710,7 +719,7 @@ impl super::Validator {
710719
}
711720
}
712721
}
713-
let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
722+
let argument_inner = context.resolve_type_inner(argument, &self.valid_expression_set)?;
714723
if !matches!(*argument_inner,
715724
crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
716725
if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
@@ -802,7 +811,7 @@ impl super::Validator {
802811
ref accept,
803812
ref reject,
804813
} => {
805-
match *context.resolve_type(condition, &self.valid_expression_set)? {
814+
match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
806815
Ti::Scalar(crate::Scalar {
807816
kind: crate::ScalarKind::Bool,
808817
width: _,
@@ -820,7 +829,7 @@ impl super::Validator {
820829
ref cases,
821830
} => {
822831
let uint = match context
823-
.resolve_type(selector, &self.valid_expression_set)?
832+
.resolve_type_inner(selector, &self.valid_expression_set)?
824833
.scalar_kind()
825834
{
826835
Some(crate::ScalarKind::Uint) => true,
@@ -917,7 +926,7 @@ impl super::Validator {
917926
.stages;
918927

919928
if let Some(condition) = break_if {
920-
match *context.resolve_type(condition, &self.valid_expression_set)? {
929+
match *context.resolve_type_inner(condition, &self.valid_expression_set)? {
921930
Ti::Scalar(crate::Scalar {
922931
kind: crate::ScalarKind::Bool,
923932
width: _,
@@ -961,7 +970,7 @@ impl super::Validator {
961970
let okay = match (value_ty, expected_ty) {
962971
(None, None) => true,
963972
(Some(value_inner), Some(expected_inner)) => {
964-
value_inner.non_struct_equivalent(expected_inner, context.types)
973+
value_inner.inner_with(context.types).non_struct_equivalent(expected_inner, context.types)
965974
}
966975
(_, _) => false,
967976
};
@@ -1027,7 +1036,7 @@ impl super::Validator {
10271036
}
10281037
}
10291038

1030-
let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
1039+
let value_ty = context.resolve_type_inner(value, &self.valid_expression_set)?;
10311040
match *value_ty {
10321041
Ti::Image { .. } | Ti::Sampler { .. } => {
10331042
return Err(FunctionError::InvalidStoreTexture {
@@ -1145,7 +1154,7 @@ impl super::Validator {
11451154

11461155
// The `coordinate` operand must be a vector of the appropriate size.
11471156
if context
1148-
.resolve_type(coordinate, &self.valid_expression_set)?
1157+
.resolve_type_inner(coordinate, &self.valid_expression_set)?
11491158
.image_storage_coordinates()
11501159
.is_none_or(|coord_dim| coord_dim != dim)
11511160
{
@@ -1167,7 +1176,7 @@ impl super::Validator {
11671176
// If present, `array_index` must be a scalar integer type.
11681177
if let Some(expr) = array_index {
11691178
if !matches!(
1170-
*context.resolve_type(expr, &self.valid_expression_set)?,
1179+
*context.resolve_type_inner(expr, &self.valid_expression_set)?,
11711180
Ti::Scalar(crate::Scalar {
11721181
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
11731182
width: _,
@@ -1188,7 +1197,7 @@ impl super::Validator {
11881197
// The value we're writing had better match the scalar type
11891198
// for `image`'s format.
11901199
let actual_value_ty =
1191-
context.resolve_type(value, &self.valid_expression_set)?;
1200+
context.resolve_type_inner(value, &self.valid_expression_set)?;
11921201
if actual_value_ty != &value_ty {
11931202
return Err(FunctionError::InvalidStoreValue {
11941203
actual: value,
@@ -1273,7 +1282,7 @@ impl super::Validator {
12731282
dim,
12741283
} => {
12751284
match context
1276-
.resolve_type(coordinate, &self.valid_expression_set)?
1285+
.resolve_type_inner(coordinate, &self.valid_expression_set)?
12771286
.image_storage_coordinates()
12781287
{
12791288
Some(coord_dim) if coord_dim == dim => {}
@@ -1293,7 +1302,9 @@ impl super::Validator {
12931302
.with_span_handle(coordinate, context.expressions));
12941303
}
12951304
if let Some(expr) = array_index {
1296-
match *context.resolve_type(expr, &self.valid_expression_set)? {
1305+
match *context
1306+
.resolve_type_inner(expr, &self.valid_expression_set)?
1307+
{
12971308
Ti::Scalar(crate::Scalar {
12981309
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
12991310
width: _,
@@ -1404,15 +1415,15 @@ impl super::Validator {
14041415
}
14051416
};
14061417

1407-
if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
1418+
if *context.resolve_type_inner(value, &self.valid_expression_set)? != value_ty {
14081419
return Err(FunctionError::InvalidImageAtomicValue(value)
14091420
.with_span_handle(value, context.expressions));
14101421
}
14111422
}
14121423
S::WorkGroupUniformLoad { pointer, result } => {
14131424
stages &= super::ShaderStages::COMPUTE;
14141425
let pointer_inner =
1415-
context.resolve_type(pointer, &self.valid_expression_set)?;
1426+
context.resolve_type_inner(pointer, &self.valid_expression_set)?;
14161427
match *pointer_inner {
14171428
Ti::Pointer {
14181429
space: AddressSpace::WorkGroup,
@@ -1468,9 +1479,10 @@ impl super::Validator {
14681479
acceleration_structure,
14691480
descriptor,
14701481
} => {
1471-
match *context
1472-
.resolve_type(acceleration_structure, &self.valid_expression_set)?
1473-
{
1482+
match *context.resolve_type_inner(
1483+
acceleration_structure,
1484+
&self.valid_expression_set,
1485+
)? {
14741486
Ti::AccelerationStructure { vertex_return } => {
14751487
if (!vertex_return) && rq_vertex_return {
14761488
return Err(FunctionError::MissingAccelerationStructureVertexReturn(acceleration_structure, query).with_span_static(span, "invalid acceleration structure"));
@@ -1483,8 +1495,8 @@ impl super::Validator {
14831495
.with_span_static(span, "invalid acceleration structure"))
14841496
}
14851497
}
1486-
let desc_ty_given =
1487-
context.resolve_type(descriptor, &self.valid_expression_set)?;
1498+
let desc_ty_given = context
1499+
.resolve_type_inner(descriptor, &self.valid_expression_set)?;
14881500
let desc_ty_expected = context
14891501
.special_types
14901502
.ray_desc
@@ -1498,7 +1510,7 @@ impl super::Validator {
14981510
self.emit_expression(result, context)?;
14991511
}
15001512
crate::RayQueryFunction::GenerateIntersection { hit_t } => {
1501-
match *context.resolve_type(hit_t, &self.valid_expression_set)? {
1513+
match *context.resolve_type_inner(hit_t, &self.valid_expression_set)? {
15021514
Ti::Scalar(crate::Scalar {
15031515
kind: crate::ScalarKind::Float,
15041516
width: _,
@@ -1534,7 +1546,7 @@ impl super::Validator {
15341546
}
15351547
if let Some(predicate) = predicate {
15361548
let predicate_inner =
1537-
context.resolve_type(predicate, &self.valid_expression_set)?;
1549+
context.resolve_type_inner(predicate, &self.valid_expression_set)?;
15381550
if !matches!(
15391551
*predicate_inner,
15401552
crate::TypeInner::Scalar(crate::Scalar::BOOL,)

0 commit comments

Comments
 (0)