Skip to content

Commit 64bd8c1

Browse files
Make same_type_modulo_infer a proper TypeRelation
1 parent 86c6ebe commit 64bd8c1

File tree

4 files changed

+150
-53
lines changed

4 files changed

+150
-53
lines changed

compiler/rustc_infer/src/infer/error_reporting/mod.rs

+79-53
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ use rustc_hir::lang_items::LangItem;
6666
use rustc_hir::Node;
6767
use rustc_middle::dep_graph::DepContext;
6868
use rustc_middle::ty::print::with_no_trimmed_paths;
69+
use rustc_middle::ty::relate::{self, RelateResult, TypeRelation};
6970
use rustc_middle::ty::{
7071
self, error::TypeError, Binder, List, Region, Subst, Ty, TyCtxt, TypeFoldable,
7172
TypeSuperVisitable, TypeVisitable,
@@ -2660,67 +2661,92 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
26602661
/// Float types, respectively). When comparing two ADTs, these rules apply recursively.
26612662
pub fn same_type_modulo_infer(&self, a: Ty<'tcx>, b: Ty<'tcx>) -> bool {
26622663
let (a, b) = self.resolve_vars_if_possible((a, b));
2663-
match (a.kind(), b.kind()) {
2664-
(&ty::Adt(def_a, substs_a), &ty::Adt(def_b, substs_b)) => {
2665-
if def_a != def_b {
2666-
return false;
2667-
}
2664+
SameTypeModuloInfer(self).relate(a, b).is_ok()
2665+
}
2666+
}
26682667

2669-
substs_a
2670-
.types()
2671-
.zip(substs_b.types())
2672-
.all(|(a, b)| self.same_type_modulo_infer(a, b))
2673-
}
2674-
(&ty::FnDef(did_a, substs_a), &ty::FnDef(did_b, substs_b)) => {
2675-
if did_a != did_b {
2676-
return false;
2677-
}
2668+
struct SameTypeModuloInfer<'a, 'tcx>(&'a InferCtxt<'a, 'tcx>);
26782669

2679-
substs_a
2680-
.types()
2681-
.zip(substs_b.types())
2682-
.all(|(a, b)| self.same_type_modulo_infer(a, b))
2683-
}
2684-
(&ty::Int(_) | &ty::Uint(_), &ty::Infer(ty::InferTy::IntVar(_)))
2670+
impl<'tcx> TypeRelation<'tcx> for SameTypeModuloInfer<'_, 'tcx> {
2671+
fn tcx(&self) -> TyCtxt<'tcx> {
2672+
self.0.tcx
2673+
}
2674+
2675+
fn param_env(&self) -> ty::ParamEnv<'tcx> {
2676+
// Unused, only for consts which we treat as always equal
2677+
ty::ParamEnv::empty()
2678+
}
2679+
2680+
fn tag(&self) -> &'static str {
2681+
"SameTypeModuloInfer"
2682+
}
2683+
2684+
fn a_is_expected(&self) -> bool {
2685+
true
2686+
}
2687+
2688+
fn relate_with_variance<T: relate::Relate<'tcx>>(
2689+
&mut self,
2690+
_variance: ty::Variance,
2691+
_info: ty::VarianceDiagInfo<'tcx>,
2692+
a: T,
2693+
b: T,
2694+
) -> relate::RelateResult<'tcx, T> {
2695+
self.relate(a, b)
2696+
}
2697+
2698+
fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
2699+
match (a.kind(), b.kind()) {
2700+
(ty::Int(_) | ty::Uint(_), ty::Infer(ty::InferTy::IntVar(_)))
26852701
| (
2686-
&ty::Infer(ty::InferTy::IntVar(_)),
2687-
&ty::Int(_) | &ty::Uint(_) | &ty::Infer(ty::InferTy::IntVar(_)),
2702+
ty::Infer(ty::InferTy::IntVar(_)),
2703+
ty::Int(_) | ty::Uint(_) | ty::Infer(ty::InferTy::IntVar(_)),
26882704
)
2689-
| (&ty::Float(_), &ty::Infer(ty::InferTy::FloatVar(_)))
2705+
| (ty::Float(_), ty::Infer(ty::InferTy::FloatVar(_)))
26902706
| (
2691-
&ty::Infer(ty::InferTy::FloatVar(_)),
2692-
&ty::Float(_) | &ty::Infer(ty::InferTy::FloatVar(_)),
2707+
ty::Infer(ty::InferTy::FloatVar(_)),
2708+
ty::Float(_) | ty::Infer(ty::InferTy::FloatVar(_)),
26932709
)
2694-
| (&ty::Infer(ty::InferTy::TyVar(_)), _)
2695-
| (_, &ty::Infer(ty::InferTy::TyVar(_))) => true,
2696-
(&ty::Ref(_, ty_a, mut_a), &ty::Ref(_, ty_b, mut_b)) => {
2697-
mut_a == mut_b && self.same_type_modulo_infer(ty_a, ty_b)
2698-
}
2699-
(&ty::RawPtr(a), &ty::RawPtr(b)) => {
2700-
a.mutbl == b.mutbl && self.same_type_modulo_infer(a.ty, b.ty)
2701-
}
2702-
(&ty::Slice(a), &ty::Slice(b)) => self.same_type_modulo_infer(a, b),
2703-
(&ty::Array(a_ty, a_ct), &ty::Array(b_ty, b_ct)) => {
2704-
self.same_type_modulo_infer(a_ty, b_ty) && a_ct == b_ct
2705-
}
2706-
(&ty::Tuple(a), &ty::Tuple(b)) => {
2707-
if a.len() != b.len() {
2708-
return false;
2709-
}
2710-
std::iter::zip(a.iter(), b.iter()).all(|(a, b)| self.same_type_modulo_infer(a, b))
2711-
}
2712-
(&ty::FnPtr(a), &ty::FnPtr(b)) => {
2713-
let a = a.skip_binder().inputs_and_output;
2714-
let b = b.skip_binder().inputs_and_output;
2715-
if a.len() != b.len() {
2716-
return false;
2717-
}
2718-
std::iter::zip(a.iter(), b.iter()).all(|(a, b)| self.same_type_modulo_infer(a, b))
2719-
}
2720-
// FIXME(compiler-errors): This needs to be generalized more
2721-
_ => a == b,
2710+
| (ty::Infer(ty::InferTy::TyVar(_)), _)
2711+
| (_, ty::Infer(ty::InferTy::TyVar(_))) => Ok(a),
2712+
(ty::Infer(_), _) | (_, ty::Infer(_)) => Err(TypeError::Mismatch),
2713+
_ => relate::super_relate_tys(self, a, b),
27222714
}
27232715
}
2716+
2717+
fn regions(
2718+
&mut self,
2719+
a: ty::Region<'tcx>,
2720+
b: ty::Region<'tcx>,
2721+
) -> RelateResult<'tcx, ty::Region<'tcx>> {
2722+
if (a.is_var() && b.is_free_or_static()) || (b.is_var() && a.is_free_or_static()) || a == b
2723+
{
2724+
Ok(a)
2725+
} else {
2726+
Err(TypeError::Mismatch)
2727+
}
2728+
}
2729+
2730+
fn binders<T>(
2731+
&mut self,
2732+
a: ty::Binder<'tcx, T>,
2733+
b: ty::Binder<'tcx, T>,
2734+
) -> relate::RelateResult<'tcx, ty::Binder<'tcx, T>>
2735+
where
2736+
T: relate::Relate<'tcx>,
2737+
{
2738+
Ok(ty::Binder::dummy(self.relate(a.skip_binder(), b.skip_binder())?))
2739+
}
2740+
2741+
fn consts(
2742+
&mut self,
2743+
a: ty::Const<'tcx>,
2744+
_b: ty::Const<'tcx>,
2745+
) -> relate::RelateResult<'tcx, ty::Const<'tcx>> {
2746+
// FIXME(compiler-errors): This could at least do some first-order
2747+
// relation
2748+
Ok(a)
2749+
}
27242750
}
27252751

27262752
impl<'a, 'tcx> InferCtxt<'a, 'tcx> {

compiler/rustc_middle/src/ty/sty.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1617,6 +1617,10 @@ impl<'tcx> Region<'tcx> {
16171617
_ => self.is_free(),
16181618
}
16191619
}
1620+
1621+
pub fn is_var(self) -> bool {
1622+
matches!(self.kind(), ty::ReVar(_))
1623+
}
16201624
}
16211625

16221626
/// Type utilities
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// This code (probably) _should_ compile, but it currently does not because we
2+
// are not smart enough about implied bounds.
3+
4+
use std::io;
5+
6+
fn real_dispatch<T, F>(f: F) -> Result<(), io::Error>
7+
//~^ NOTE required by a bound in this
8+
where
9+
F: FnOnce(&mut UIView<T>) -> Result<(), io::Error> + Send + 'static,
10+
//~^ NOTE required by this bound in `real_dispatch`
11+
//~| NOTE required by a bound in `real_dispatch`
12+
{
13+
todo!()
14+
}
15+
16+
#[derive(Debug)]
17+
struct UIView<'a, T: 'a> {
18+
_phantom: std::marker::PhantomData<&'a mut T>,
19+
}
20+
21+
trait Handle<'a, T: 'a, V, R> {
22+
fn dispatch<F>(&self, f: F) -> Result<(), io::Error>
23+
where
24+
F: FnOnce(&mut V) -> R + Send + 'static;
25+
}
26+
27+
#[derive(Debug, Clone)]
28+
struct TUIHandle<T> {
29+
_phantom: std::marker::PhantomData<T>,
30+
}
31+
32+
impl<'a, T: 'a> Handle<'a, T, UIView<'a, T>, Result<(), io::Error>> for TUIHandle<T> {
33+
fn dispatch<F>(&self, f: F) -> Result<(), io::Error>
34+
where
35+
F: FnOnce(&mut UIView<'a, T>) -> Result<(), io::Error> + Send + 'static,
36+
{
37+
real_dispatch(f)
38+
//~^ ERROR expected a `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
39+
//~| NOTE expected an `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
40+
//~| NOTE expected a closure with arguments
41+
//~| NOTE required by a bound introduced by this call
42+
}
43+
}
44+
45+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
error[E0277]: expected a `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
2+
--> $DIR/issue-100690.rs:37:23
3+
|
4+
LL | real_dispatch(f)
5+
| ------------- ^ expected an `FnOnce<(&mut UIView<'_, T>,)>` closure, found `F`
6+
| |
7+
| required by a bound introduced by this call
8+
|
9+
= note: expected a closure with arguments `(&mut UIView<'a, T>,)`
10+
found a closure with arguments `(&mut UIView<'_, T>,)`
11+
note: required by a bound in `real_dispatch`
12+
--> $DIR/issue-100690.rs:9:8
13+
|
14+
LL | fn real_dispatch<T, F>(f: F) -> Result<(), io::Error>
15+
| ------------- required by a bound in this
16+
...
17+
LL | F: FnOnce(&mut UIView<T>) -> Result<(), io::Error> + Send + 'static,
18+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `real_dispatch`
19+
20+
error: aborting due to previous error
21+
22+
For more information about this error, try `rustc --explain E0277`.

0 commit comments

Comments
 (0)