Skip to content

Commit 45d8135

Browse files
committed
[naga] Two structs with the same members are not equivalent
Fixes gfx-rs#5796
1 parent 545d84f commit 45d8135

File tree

12 files changed

+302
-79
lines changed

12 files changed

+302
-79
lines changed

naga/src/arena/unique_arena.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ use crate::{FastIndexSet, Span};
1616
/// The element type must implement `Eq` and `Hash`. Insertions of equivalent
1717
/// elements, according to `Eq`, all return the same `Handle`.
1818
///
19-
/// Once inserted, elements may not be mutated.
19+
/// Once inserted, elements generally may not be mutated, although a `replace`
20+
/// method exists to support rare cases.
2021
///
2122
/// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like,
2223
/// `UniqueArena` is `HashSet`-like.

naga/src/front/wgsl/lower/conversion.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl<'source> super::ExpressionContext<'source, '_, '_> {
4646
}
4747

4848
// If `expr` already has the requested type, we're done.
49-
if expr_inner.non_struct_equivalent(goal_inner, types) {
49+
if self.module.compare_types(expr_resolution, goal_ty) {
5050
return Ok(expr);
5151
}
5252

naga/src/front/wgsl/lower/mod.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1311,9 +1311,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
13111311
})?;
13121312

13131313
let init_ty = ectx.register_type(init)?;
1314-
let explicit_inner = &ectx.module.types[explicit_ty].inner;
1315-
let init_inner = &ectx.module.types[init_ty].inner;
1316-
if !explicit_inner.non_struct_equivalent(init_inner, &ectx.module.types) {
1314+
if !ectx.module.compare_types(
1315+
&crate::proc::TypeResolution::Handle(explicit_ty),
1316+
&crate::proc::TypeResolution::Handle(init_ty),
1317+
) {
13171318
return Err(Box::new(Error::InitializationTypeMismatch {
13181319
name: name.span,
13191320
expected: ectx.type_to_string(explicit_ty),

naga/src/ir/mod.rs

+23-5
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,15 @@ pub struct Type {
636636
}
637637

638638
/// Enum with additional information, depending on the kind of type.
639+
///
640+
/// Comparison using `==` is not reliable in the case of [`Pointer`],
641+
/// [`ValuePointer`], or [`Struct`] variants. For these variants,
642+
/// use [`TypeInner::non_struct_equivalent`] or [`compare_types`].
643+
///
644+
/// [`compare_types`]: crate::proc::compare_types
645+
/// [`ValuePointer`]: TypeInner::ValuePointer
646+
/// [`Pointer`]: TypeInner::Pointer
647+
/// [`Struct`]: TypeInner::Struct
639648
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
640649
#[cfg_attr(feature = "serialize", derive(Serialize))]
641650
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
@@ -656,8 +665,9 @@ pub enum TypeInner {
656665
/// Pointer to another type.
657666
///
658667
/// Pointers to scalars and vectors should be treated as equivalent to
659-
/// [`ValuePointer`] types. Use the [`TypeInner::equivalent`] method to
660-
/// compare types in a way that treats pointers correctly.
668+
/// [`ValuePointer`] types. Use either [`TypeInner::non_struct_equivalent`]
669+
/// or [`compare_types`] to compare types in a way that treats pointers
670+
/// correctly.
661671
///
662672
/// ## Pointers to non-`SIZED` types
663673
///
@@ -679,6 +689,7 @@ pub enum TypeInner {
679689
/// [`ValuePointer`]: TypeInner::ValuePointer
680690
/// [`GlobalVariable`]: Expression::GlobalVariable
681691
/// [`AccessIndex`]: Expression::AccessIndex
692+
/// [`compare_types`]: crate::proc::compare_types
682693
Pointer {
683694
base: Handle<Type>,
684695
space: AddressSpace,
@@ -690,12 +701,13 @@ pub enum TypeInner {
690701
/// `Scalar` or `Vector` type. This is for use in [`TypeResolution::Value`]
691702
/// variants; see the documentation for [`TypeResolution`] for details.
692703
///
693-
/// Use the [`TypeInner::equivalent`] method to compare types that could be
694-
/// pointers, to ensure that `Pointer` and `ValuePointer` types are
695-
/// recognized as equivalent.
704+
/// Use [`TypeInner::non_struct_equivalent`] or [`compare_types`] to compare
705+
/// types that could be pointers, to ensure that `Pointer` and
706+
/// `ValuePointer` types are recognized as equivalent.
696707
///
697708
/// [`TypeResolution`]: crate::proc::TypeResolution
698709
/// [`TypeResolution::Value`]: crate::proc::TypeResolution::Value
710+
/// [`compare_types`]: crate::proc::compare_types
699711
ValuePointer {
700712
size: Option<VectorSize>,
701713
scalar: Scalar,
@@ -744,9 +756,15 @@ pub enum TypeInner {
744756
/// struct, which may be a dynamically sized [`Array`]. The
745757
/// `Struct` type itself is `SIZED` when all its members are `SIZED`.
746758
///
759+
/// Two structure types with different names are not equivalent. Because
760+
/// this variant does not contain the name, it is not possible to use it
761+
/// to compare struct types. Use [`compare_types`] to compare two types
762+
/// that may be structs.
763+
///
747764
/// [`DATA`]: crate::valid::TypeFlags::DATA
748765
/// [`SIZED`]: crate::∅TypeFlags::SIZED
749766
/// [`Array`]: TypeInner::Array
767+
/// [`compare_types`]: crate::proc::compare_types
750768
Struct {
751769
members: Vec<StructMember>,
752770
//TODO: should this be unaligned?

naga/src/proc/mod.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
2323
pub use terminator::ensure_block_returns;
2424
use thiserror::Error;
2525
pub use type_methods::min_max_float_representable_by;
26-
pub use typifier::{ResolveContext, ResolveError, TypeResolution};
26+
pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};
2727

2828
impl From<super::StorageFormat> for super::Scalar {
2929
fn from(format: super::StorageFormat) -> Self {
@@ -403,6 +403,10 @@ impl crate::Module {
403403
global_expressions: &self.global_expressions,
404404
}
405405
}
406+
407+
pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
408+
compare_types(lhs, rhs, &self.types)
409+
}
406410
}
407411

408412
#[derive(Debug)]
@@ -491,6 +495,10 @@ impl GlobalCtx<'_> {
491495
_ => get(*self, handle, arena),
492496
}
493497
}
498+
499+
pub fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
500+
compare_types(lhs, rhs, self.types)
501+
}
494502
}
495503

496504
#[derive(Error, Debug, Clone, Copy, PartialEq)]

naga/src/proc/type_methods.rs

+23-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
//! [`Scalar`]: crate::Scalar
55
//! [`ScalarKind`]: crate::ScalarKind
66
7+
use crate::ir;
8+
79
use super::TypeResolution;
810

911
impl crate::ScalarKind {
@@ -255,24 +257,38 @@ impl crate::TypeInner {
255257
}
256258
}
257259

258-
/// Compare `self` and `rhs` as types.
260+
/// Compare value type `self` and `rhs` as types.
259261
///
260262
/// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
261-
/// `ValuePointer` and `Pointer` types as equivalent. This method
263+
/// [`ValuePointer`] and [`Pointer`] types as equivalent. This method
262264
/// cannot be used for structs, because it cannot distinguish two
263265
/// structs with different names but the same members. For structs,
264-
/// use `Module::compare_types`.
266+
/// use [`compare_types`].
267+
///
268+
/// When you know that one side of the comparison is never a pointer or
269+
/// struct, it's fine to not bother with canonicalization, and just
270+
/// compare `TypeInner` values with `==`.
271+
///
272+
/// # Panics
273+
///
274+
/// If both `self` and `rhs` are structs.
265275
///
266-
/// When you know that one side of the comparison is never a pointer, it's
267-
/// fine to not bother with canonicalization, and just compare `TypeInner`
268-
/// values with `==`.
276+
/// [`compare_types`]: crate::proc::compare_types
277+
/// [`ValuePointer`]: ir::TypeInner::ValuePointer
278+
/// [`Pointer`]: ir::TypeInner::Pointer
269279
pub fn non_struct_equivalent(
270280
&self,
271-
rhs: &crate::TypeInner,
281+
rhs: &ir::TypeInner,
272282
types: &crate::UniqueArena<crate::Type>,
273283
) -> bool {
274284
let left = self.canonical_form(types);
275285
let right = rhs.canonical_form(types);
286+
287+
let left_struct = matches!(*self, ir::TypeInner::Struct { .. });
288+
let right_struct = matches!(*rhs, ir::TypeInner::Struct { .. });
289+
290+
assert!(!left_struct || !right_struct);
291+
276292
left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
277293
}
278294

naga/src/proc/typifier.rs

+43-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@ use alloc::{format, string::String};
22

33
use thiserror::Error;
44

5-
use crate::arena::{Arena, Handle, UniqueArena};
6-
use crate::common::ForDebugWithTypes;
5+
use crate::{
6+
arena::{Arena, Handle, UniqueArena},
7+
common::ForDebugWithTypes,
8+
ir,
9+
};
710

811
/// The result of computing an expression's type.
912
///
@@ -773,6 +776,44 @@ impl<'a> ResolveContext<'a> {
773776
}
774777
}
775778

779+
/// Compare two types.
780+
///
781+
/// This is the most general way of comparing two types, as it can distinguish
782+
/// two structs with different names but the same members. For other ways, see
783+
/// [`TypeInner::non_struct_equivalent`] and [`TypeInner::eq`].
784+
///
785+
/// In Naga code, this is usually called via the like-named methods on [`Module`],
786+
/// [`GlobalCtx`], and `BlockContext`.
787+
///
788+
/// [`TypeInner::non_struct_equivalent`]: crate::ir::TypeInner::non_struct_equivalent
789+
/// [`TypeInner::eq`]: crate::ir::TypeInner
790+
/// [`Module`]: crate::ir::Module
791+
/// [`GlobalCtx`]: crate::proc::GlobalCtx
792+
pub fn compare_types(
793+
lhs: &TypeResolution,
794+
rhs: &TypeResolution,
795+
types: &UniqueArena<crate::Type>,
796+
) -> bool {
797+
match lhs {
798+
&TypeResolution::Handle(lhs_handle)
799+
if matches!(
800+
types[lhs_handle],
801+
ir::Type {
802+
inner: ir::TypeInner::Struct { .. },
803+
..
804+
}
805+
) =>
806+
{
807+
// Structs can only be in the arena, not in a TypeResolution::Value
808+
rhs.handle()
809+
.is_some_and(|rhs_handle| lhs_handle == rhs_handle)
810+
}
811+
_ => lhs
812+
.inner_with(types)
813+
.non_struct_equivalent(rhs.inner_with(types), types),
814+
}
815+
}
816+
776817
#[test]
777818
fn test_error_size() {
778819
assert_eq!(size_of::<ResolveError>(), 32);

naga/src/valid/compose.rs

+2-10
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,7 @@ pub fn validate_compose(
8484
});
8585
}
8686
for (index, comp_res) in component_resolutions.enumerate() {
87-
let base_inner = &gctx.types[base].inner;
88-
let comp_res_inner = comp_res.inner_with(gctx.types);
89-
// We don't support arrays of pointers, but it seems best not to
90-
// embed that assumption here, so use `TypeInner::equivalent`.
91-
if !base_inner.non_struct_equivalent(comp_res_inner, gctx.types) {
87+
if !gctx.compare_types(&TypeResolution::Handle(base), &comp_res) {
9288
log::error!("Array component[{}] type {:?}", index, comp_res);
9389
return Err(ComposeError::ComponentType {
9490
index: index as u32,
@@ -105,11 +101,7 @@ pub fn validate_compose(
105101
}
106102
for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate()
107103
{
108-
let member_inner = &gctx.types[member.ty].inner;
109-
let comp_res_inner = comp_res.inner_with(gctx.types);
110-
// We don't support pointers in structs, but it seems best not to embed
111-
// that assumption here, so use `TypeInner::equivalent`.
112-
if !comp_res_inner.non_struct_equivalent(member_inner, gctx.types) {
104+
if !gctx.compare_types(&TypeResolution::Handle(member.ty), &comp_res) {
113105
log::error!("Struct component[{}] type {:?}", index, comp_res);
114106
return Err(ComposeError::ComponentType {
115107
index: index as u32,

naga/src/valid/function.rs

+35-26
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@ pub enum FunctionError {
120120
ContinueOutsideOfLoop,
121121
#[error("The `return` is called within a `continuing` block")]
122122
InvalidReturnSpot,
123-
#[error("The `return` value {0:?} does not match the function return value")]
124-
InvalidReturnType(Option<Handle<crate::Expression>>),
123+
#[error("The `return` expression {expression:?} does not match the declared return type {expected_ty:?}")]
124+
InvalidReturnType {
125+
expression: Option<Handle<crate::Expression>>,
126+
expected_ty: Option<Handle<crate::Type>>,
127+
},
125128
#[error("The `if` condition {0:?} is not a boolean scalar")]
126129
InvalidIfType(Handle<crate::Expression>),
127130
#[error("The `switch` value {0:?} is not an integer scalar")]
@@ -310,8 +313,8 @@ impl<'a> BlockContext<'a> {
310313
self.info[handle].ty.inner_with(self.types)
311314
}
312315

313-
fn inner_type<'t>(&'t self, ty: &'t TypeResolution) -> &'t crate::TypeInner {
314-
ty.inner_with(self.types)
316+
fn compare_types(&self, lhs: &TypeResolution, rhs: &TypeResolution) -> bool {
317+
crate::proc::compare_types(lhs, rhs, self.types)
315318
}
316319
}
317320

@@ -338,8 +341,7 @@ impl super::Validator {
338341
CallError::Argument { index, source }
339342
.with_span_handle(expr, context.expressions)
340343
})?;
341-
let arg_inner = &context.types[arg.ty].inner;
342-
if !ty.inner_with(context.types).non_struct_equivalent(arg_inner, context.types) {
344+
if !context.compare_types(&TypeResolution::Handle(arg.ty), ty) {
343345
return Err(CallError::ArgumentType {
344346
index,
345347
required: arg.ty,
@@ -964,13 +966,12 @@ impl super::Validator {
964966
let value_ty = value
965967
.map(|expr| context.resolve_type(expr, &self.valid_expression_set))
966968
.transpose()?;
967-
let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
968969
// We can't return pointers, but it seems best not to embed that
969970
// assumption here, so use `TypeInner::equivalent` for comparison.
970-
let okay = match (value_ty, expected_ty) {
971+
let okay = match (value_ty, context.return_type) {
971972
(None, None) => true,
972-
(Some(value_inner), Some(expected_inner)) => {
973-
value_inner.inner_with(context.types).non_struct_equivalent(expected_inner, context.types)
973+
(Some(value_inner), Some(expected_ty)) => {
974+
context.compare_types(value_inner, &TypeResolution::Handle(expected_ty))
974975
}
975976
(_, _) => false,
976977
};
@@ -979,14 +980,20 @@ impl super::Validator {
979980
log::error!(
980981
"Returning {:?} where {:?} is expected",
981982
value_ty,
982-
expected_ty
983+
context.return_type,
983984
);
984985
if let Some(handle) = value {
985-
return Err(FunctionError::InvalidReturnType(value)
986-
.with_span_handle(handle, context.expressions));
986+
return Err(FunctionError::InvalidReturnType {
987+
expression: value,
988+
expected_ty: context.return_type,
989+
}
990+
.with_span_handle(handle, context.expressions));
987991
} else {
988-
return Err(FunctionError::InvalidReturnType(value)
989-
.with_span_static(span, "invalid return"));
992+
return Err(FunctionError::InvalidReturnType {
993+
expression: value,
994+
expected_ty: context.return_type,
995+
}
996+
.with_span_static(span, "invalid return"));
990997
}
991998
}
992999
finished = true;
@@ -1036,7 +1043,8 @@ impl super::Validator {
10361043
}
10371044
}
10381045

1039-
let value_ty = context.resolve_type_inner(value, &self.valid_expression_set)?;
1046+
let value_tr = context.resolve_type(value, &self.valid_expression_set)?;
1047+
let value_ty = value_tr.inner_with(context.types);
10401048
match *value_ty {
10411049
Ti::Image { .. } | Ti::Sampler { .. } => {
10421050
return Err(FunctionError::InvalidStoreTexture {
@@ -1053,16 +1061,19 @@ impl super::Validator {
10531061
}
10541062

10551063
let pointer_ty = context.resolve_pointer_type(pointer);
1056-
let good = match pointer_ty
1057-
.pointer_base_type()
1064+
let pointer_base_tr = pointer_ty.pointer_base_type();
1065+
let pointer_base_ty = pointer_base_tr
10581066
.as_ref()
1059-
.map(|ty| context.inner_type(ty))
1060-
{
1067+
.map(|ty| ty.inner_with(context.types));
1068+
let good = if let Some(&Ti::Atomic(ref scalar)) = pointer_base_ty {
10611069
// The Naga IR allows storing a scalar to an atomic.
1062-
Some(&Ti::Atomic(ref scalar)) => *value_ty == Ti::Scalar(*scalar),
1063-
Some(other) => *value_ty == *other,
1064-
None => false,
1070+
*value_ty == Ti::Scalar(*scalar)
1071+
} else if let Some(tr) = pointer_base_tr {
1072+
context.compare_types(value_tr, &tr)
1073+
} else {
1074+
false
10651075
};
1076+
10661077
if !good {
10671078
return Err(FunctionError::InvalidStoreTypes { pointer, value }
10681079
.with_span()
@@ -1640,9 +1651,7 @@ impl super::Validator {
16401651
}
16411652

16421653
if let Some(init) = var.init {
1643-
let decl_ty = &gctx.types[var.ty].inner;
1644-
let init_ty = fun_info[init].ty.inner_with(gctx.types);
1645-
if !decl_ty.non_struct_equivalent(init_ty, gctx.types) {
1654+
if !gctx.compare_types(&TypeResolution::Handle(var.ty), &fun_info[init].ty) {
16461655
return Err(LocalVariableError::InitializerType);
16471656
}
16481657

0 commit comments

Comments
 (0)