Skip to content

Commit faa810a

Browse files
committed
[naga] Two structs with the same members are not equivalent
Fixes gfx-rs#5796
1 parent c098c63 commit faa810a

File tree

11 files changed

+250
-64
lines changed

11 files changed

+250
-64
lines changed

naga/src/arena/unique_arena.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ use crate::{FastIndexSet, Span};
1616
/// The element type must implement `Eq` and `Hash`. Insertions of equivalent
1717
/// elements, according to `Eq`, all return the same `Handle`.
1818
///
19-
/// Once inserted, elements may not be mutated.
19+
/// Once inserted, elements generally may not be mutated, although a `replace`
20+
/// method exists to support rare cases.
2021
///
2122
/// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like,
2223
/// `UniqueArena` is `HashSet`-like.

naga/src/front/wgsl/lower/conversion.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ impl<'source> super::ExpressionContext<'source, '_, '_> {
4646
}
4747

4848
// If `expr` already has the requested type, we're done.
49-
if expr_inner.non_struct_equivalent(goal_inner, types) {
49+
if self
50+
.module
51+
.compare_types(expr_resolution.clone(), goal_ty.clone())
52+
{
5053
return Ok(expr);
5154
}
5255

naga/src/front/wgsl/lower/mod.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -1276,13 +1276,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
12761276
})?;
12771277

12781278
let init_ty = ectx.register_type(init)?;
1279-
let explicit_inner = &ectx.module.types[explicit_ty].inner;
1280-
let init_inner = &ectx.module.types[init_ty].inner;
1281-
if !explicit_inner.non_struct_equivalent(init_inner, &ectx.module.types) {
1279+
if !ectx.module.compare_types(explicit_ty, init_ty) {
12821280
return Err(Box::new(Error::InitializationTypeMismatch {
12831281
name: name.span,
1284-
expected: ectx.type_inner_to_string(explicit_inner),
1285-
got: ectx.type_inner_to_string(init_inner),
1282+
expected: ectx.type_to_string(explicit_ty),
1283+
got: ectx.type_to_string(init_ty),
12861284
}));
12871285
}
12881286
ty = explicit_ty;
@@ -1508,10 +1506,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
15081506
if let Some(ty) = explicit_ty {
15091507
let mut ctx = ctx.as_expression(block, &mut emitter);
15101508
let init_ty = ctx.register_type(value)?;
1511-
if !ctx.module.types[ty]
1512-
.inner
1513-
.non_struct_equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types)
1514-
{
1509+
if !ctx.module.compare_types(ty, init_ty) {
15151510
return Err(Box::new(Error::InitializationTypeMismatch {
15161511
name: l.name.span,
15171512
expected: ctx.type_to_string(ty),

naga/src/proc/mod.rs

+17-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub use namer::{EntryPointIndex, NameKey, Namer};
2121
pub use terminator::ensure_block_returns;
2222
use thiserror::Error;
2323
pub use type_methods::min_max_float_representable_by;
24-
pub use typifier::{ResolveContext, ResolveError, TypeResolution};
24+
pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
2525

2626
impl From<super::StorageFormat> for super::Scalar {
2727
fn from(format: super::StorageFormat) -> Self {
@@ -399,6 +399,14 @@ impl crate::Module {
399399
global_expressions: &self.global_expressions,
400400
}
401401
}
402+
403+
pub fn compare_types<L: Into<TypeResolution>, R: Into<TypeResolution>>(
404+
&self,
405+
lhs: L,
406+
rhs: R,
407+
) -> bool {
408+
compare_types(lhs, rhs, &self.types)
409+
}
402410
}
403411

404412
#[derive(Debug)]
@@ -487,6 +495,14 @@ impl GlobalCtx<'_> {
487495
_ => get(*self, handle, arena),
488496
}
489497
}
498+
499+
pub fn compare_types<L: Into<TypeResolution>, R: Into<TypeResolution>>(
500+
&self,
501+
lhs: L,
502+
rhs: R,
503+
) -> bool {
504+
compare_types(lhs, rhs, self.types)
505+
}
490506
}
491507

492508
#[derive(Error, Debug, Clone, Copy, PartialEq)]

naga/src/proc/type_methods.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
//! [`Scalar`]: crate::Scalar
55
//! [`ScalarKind`]: crate::ScalarKind
66
7+
use crate::ir;
8+
79
use super::TypeResolution;
810

911
impl crate::ScalarKind {
@@ -224,7 +226,7 @@ impl crate::TypeInner {
224226
}
225227
}
226228

227-
/// Compare `self` and `rhs` as types.
229+
/// Compare value type `self` and `rhs` as types.
228230
///
229231
/// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
230232
/// `ValuePointer` and `Pointer` types as equivalent. This method
@@ -237,11 +239,17 @@ impl crate::TypeInner {
237239
/// values with `==`.
238240
pub fn non_struct_equivalent(
239241
&self,
240-
rhs: &crate::TypeInner,
242+
rhs: &ir::TypeInner,
241243
types: &crate::UniqueArena<crate::Type>,
242244
) -> bool {
243245
let left = self.canonical_form(types);
244246
let right = rhs.canonical_form(types);
247+
248+
let left_struct = matches!(*self, ir::TypeInner::Struct { .. });
249+
let right_struct = matches!(*rhs, ir::TypeInner::Struct { .. });
250+
251+
assert!(!left_struct || !right_struct);
252+
245253
left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
246254
}
247255

naga/src/proc/typifier.rs

+46-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use alloc::{format, string::String};
22

33
use thiserror::Error;
44

5-
use crate::arena::{Arena, Handle, UniqueArena};
5+
use crate::{
6+
arena::{Arena, Handle, UniqueArena},
7+
ir,
8+
};
69

710
/// The result of computing an expression's type.
811
///
@@ -121,6 +124,12 @@ impl TypeResolution {
121124
}
122125
}
123126

127+
impl From<Handle<crate::Type>> for TypeResolution {
128+
fn from(value: Handle<crate::Type>) -> Self {
129+
Self::Handle(value)
130+
}
131+
}
132+
124133
// Clone is only implemented for numeric variants of `TypeInner`.
125134
impl Clone for TypeResolution {
126135
fn clone(&self) -> Self {
@@ -916,6 +925,42 @@ impl<'a> ResolveContext<'a> {
916925
}
917926
}
918927

928+
/// Compare two types.
929+
///
930+
/// This is the most general way of comparing two types, as it can distinguish
931+
/// two structs with different names but the same members. For other ways, see
932+
/// `TypeInner::non_struct_equivalent` or even `TypeInner::eq`.
933+
///
934+
/// This is usually called via the like-named methods on `Module` and `GlobalCtx`.
935+
pub fn compare_types<L: Into<TypeResolution>, R: Into<TypeResolution>>(
936+
lhs: L,
937+
rhs: R,
938+
types: &UniqueArena<crate::Type>,
939+
) -> bool {
940+
let lhs_tr = lhs.into();
941+
let rhs_tr = rhs.into();
942+
943+
match lhs_tr {
944+
TypeResolution::Handle(lhs_handle)
945+
if matches!(
946+
types[lhs_handle],
947+
ir::Type {
948+
inner: ir::TypeInner::Struct { .. },
949+
..
950+
}
951+
) =>
952+
{
953+
// Structs can only be in the arena, not in a TypeResolution::Value
954+
rhs_tr
955+
.handle()
956+
.is_some_and(|rhs_handle| lhs_handle == rhs_handle)
957+
}
958+
_ => lhs_tr
959+
.inner_with(types)
960+
.non_struct_equivalent(rhs_tr.inner_with(types), types),
961+
}
962+
}
963+
919964
#[test]
920965
fn test_error_size() {
921966
assert_eq!(size_of::<ResolveError>(), 32);

naga/src/valid/compose.rs

+2-10
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,7 @@ pub fn validate_compose(
8484
});
8585
}
8686
for (index, comp_res) in component_resolutions.enumerate() {
87-
let base_inner = &gctx.types[base].inner;
88-
let comp_res_inner = comp_res.inner_with(gctx.types);
89-
// We don't support arrays of pointers, but it seems best not to
90-
// embed that assumption here, so use `TypeInner::equivalent`.
91-
if !base_inner.non_struct_equivalent(comp_res_inner, gctx.types) {
87+
if !gctx.compare_types(base, comp_res.clone()) {
9288
log::error!("Array component[{}] type {:?}", index, comp_res);
9389
return Err(ComposeError::ComponentType {
9490
index: index as u32,
@@ -105,11 +101,7 @@ pub fn validate_compose(
105101
}
106102
for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate()
107103
{
108-
let member_inner = &gctx.types[member.ty].inner;
109-
let comp_res_inner = comp_res.inner_with(gctx.types);
110-
// We don't support pointers in structs, but it seems best not to embed
111-
// that assumption here, so use `TypeInner::equivalent`.
112-
if !comp_res_inner.non_struct_equivalent(member_inner, gctx.types) {
104+
if !gctx.compare_types(member.ty, comp_res.clone()) {
113105
log::error!("Struct component[{}] type {:?}", index, comp_res);
114106
return Err(ComposeError::ComponentType {
115107
index: index as u32,

naga/src/valid/function.rs

+29-16
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@ pub enum FunctionError {
120120
ContinueOutsideOfLoop,
121121
#[error("The `return` is called within a `continuing` block")]
122122
InvalidReturnSpot,
123-
#[error("The `return` value {0:?} does not match the function return value")]
124-
InvalidReturnType(Option<Handle<crate::Expression>>),
123+
#[error("The `return` expression {expression:?} does not match the declared return type {expected_ty:?}")]
124+
InvalidReturnType {
125+
expression: Option<Handle<crate::Expression>>,
126+
expected_ty: Option<Handle<crate::Type>>,
127+
},
125128
#[error("The `if` condition {0:?} is not a boolean scalar")]
126129
InvalidIfType(Handle<crate::Expression>),
127130
#[error("The `switch` value {0:?} is not an integer scalar")]
@@ -313,6 +316,14 @@ impl<'a> BlockContext<'a> {
313316
fn inner_type<'t>(&'t self, ty: &'t TypeResolution) -> &'t crate::TypeInner {
314317
ty.inner_with(self.types)
315318
}
319+
320+
fn compare_types<L: Into<TypeResolution>, R: Into<TypeResolution>>(
321+
&self,
322+
lhs: L,
323+
rhs: R,
324+
) -> bool {
325+
crate::proc::compare_types(lhs, rhs, self.types)
326+
}
316327
}
317328

318329
impl super::Validator {
@@ -338,8 +349,7 @@ impl super::Validator {
338349
CallError::Argument { index, source }
339350
.with_span_handle(expr, context.expressions)
340351
})?;
341-
let arg_inner = &context.types[arg.ty].inner;
342-
if !ty.non_struct_equivalent(arg_inner, context.types) {
352+
if !context.compare_types(arg.ty, ty.clone()) {
343353
return Err(CallError::ArgumentType {
344354
index,
345355
required: arg.ty,
@@ -964,13 +974,12 @@ impl super::Validator {
964974
let value_ty = value
965975
.map(|expr| context.resolve_type(expr, &self.valid_expression_set))
966976
.transpose()?;
967-
let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
968977
// We can't return pointers, but it seems best not to embed that
969978
// assumption here, so use `TypeInner::equivalent` for comparison.
970-
let okay = match (value_ty, expected_ty) {
979+
let okay = match (value_ty, context.return_type) {
971980
(None, None) => true,
972-
(Some(value_inner), Some(expected_inner)) => {
973-
value_inner.non_struct_equivalent(expected_inner, context.types)
981+
(Some(value_inner), Some(expected_ty)) => {
982+
context.compare_types(value_inner.clone(), expected_ty)
974983
}
975984
(_, _) => false,
976985
};
@@ -979,14 +988,20 @@ impl super::Validator {
979988
log::error!(
980989
"Returning {:?} where {:?} is expected",
981990
value_ty,
982-
expected_ty
991+
context.return_type,
983992
);
984993
if let Some(handle) = value {
985-
return Err(FunctionError::InvalidReturnType(value)
986-
.with_span_handle(handle, context.expressions));
994+
return Err(FunctionError::InvalidReturnType {
995+
expression: value,
996+
expected_ty: context.return_type,
997+
}
998+
.with_span_handle(handle, context.expressions));
987999
} else {
988-
return Err(FunctionError::InvalidReturnType(value)
989-
.with_span_static(span, "invalid return"));
1000+
return Err(FunctionError::InvalidReturnType {
1001+
expression: value,
1002+
expected_ty: context.return_type,
1003+
}
1004+
.with_span_static(span, "invalid return"));
9901005
}
9911006
}
9921007
finished = true;
@@ -1640,9 +1655,7 @@ impl super::Validator {
16401655
}
16411656

16421657
if let Some(init) = var.init {
1643-
let decl_ty = &gctx.types[var.ty].inner;
1644-
let init_ty = fun_info[init].ty.inner_with(gctx.types);
1645-
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
1658+
if !gctx.compare_types(var.ty, fun_info[init].ty.clone()) {
16461659
return Err(LocalVariableError::InitializerType);
16471660
}
16481661

naga/src/valid/interface.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -636,9 +636,7 @@ impl super::Validator {
636636
return Err(GlobalVariableError::InitializerExprType);
637637
}
638638

639-
let decl_ty = &gctx.types[var.ty].inner;
640-
let init_ty = mod_info[init].inner_with(gctx.types);
641-
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
639+
if !gctx.compare_types(var.ty, mod_info[init].clone()) {
642640
return Err(GlobalVariableError::InitializerType);
643641
}
644642
}

naga/src/valid/mod.rs

+4-8
Original file line numberDiff line numberDiff line change
@@ -532,9 +532,7 @@ impl Validator {
532532
return Err(ConstantError::InitializerExprType);
533533
}
534534

535-
let decl_ty = &gctx.types[con.ty].inner;
536-
let init_ty = mod_info[con.init].inner_with(gctx.types);
537-
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
535+
if !gctx.compare_types(con.ty, mod_info[con.init].clone()) {
538536
return Err(ConstantError::InvalidType);
539537
}
540538

@@ -560,9 +558,8 @@ impl Validator {
560558
return Err(OverrideError::NonConstructibleType);
561559
}
562560

563-
let decl_ty = &gctx.types[o.ty].inner;
564-
match decl_ty {
565-
&crate::TypeInner::Scalar(
561+
match gctx.types[o.ty].inner {
562+
crate::TypeInner::Scalar(
566563
crate::Scalar::BOOL
567564
| crate::Scalar::I32
568565
| crate::Scalar::U32
@@ -574,8 +571,7 @@ impl Validator {
574571
}
575572

576573
if let Some(init) = o.init {
577-
let init_ty = mod_info[init].inner_with(gctx.types);
578-
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
574+
if !gctx.compare_types(o.ty, mod_info[init].clone()) {
579575
return Err(OverrideError::InvalidType);
580576
}
581577
} else if self.overrides_resolved {

0 commit comments

Comments
 (0)