@@ -66,6 +66,7 @@ use rustc_hir::lang_items::LangItem;
66
66
use rustc_hir:: Node ;
67
67
use rustc_middle:: dep_graph:: DepContext ;
68
68
use rustc_middle:: ty:: print:: with_no_trimmed_paths;
69
+ use rustc_middle:: ty:: relate:: { self , RelateResult , TypeRelation } ;
69
70
use rustc_middle:: ty:: {
70
71
self , error:: TypeError , Binder , List , Region , Subst , Ty , TyCtxt , TypeFoldable ,
71
72
TypeSuperVisitable , TypeVisitable ,
@@ -2661,67 +2662,92 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
2661
2662
/// Float types, respectively). When comparing two ADTs, these rules apply recursively.
2662
2663
pub fn same_type_modulo_infer ( & self , a : Ty < ' tcx > , b : Ty < ' tcx > ) -> bool {
2663
2664
let ( a, b) = self . resolve_vars_if_possible ( ( a, b) ) ;
2664
- match ( a. kind ( ) , b. kind ( ) ) {
2665
- ( & ty:: Adt ( def_a, substs_a) , & ty:: Adt ( def_b, substs_b) ) => {
2666
- if def_a != def_b {
2667
- return false ;
2668
- }
2665
+ SameTypeModuloInfer ( self ) . relate ( a, b) . is_ok ( )
2666
+ }
2667
+ }
2669
2668
2670
- substs_a
2671
- . types ( )
2672
- . zip ( substs_b. types ( ) )
2673
- . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2674
- }
2675
- ( & ty:: FnDef ( did_a, substs_a) , & ty:: FnDef ( did_b, substs_b) ) => {
2676
- if did_a != did_b {
2677
- return false ;
2678
- }
2669
+ struct SameTypeModuloInfer < ' a , ' tcx > ( & ' a InferCtxt < ' a , ' tcx > ) ;
2679
2670
2680
- substs_a
2681
- . types ( )
2682
- . zip ( substs_b. types ( ) )
2683
- . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2684
- }
2685
- ( & ty:: Int ( _) | & ty:: Uint ( _) , & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) )
2671
+ impl < ' tcx > TypeRelation < ' tcx > for SameTypeModuloInfer < ' _ , ' tcx > {
2672
+ fn tcx ( & self ) -> TyCtxt < ' tcx > {
2673
+ self . 0 . tcx
2674
+ }
2675
+
2676
+ fn param_env ( & self ) -> ty:: ParamEnv < ' tcx > {
2677
+ // Unused, only for consts which we treat as always equal
2678
+ ty:: ParamEnv :: empty ( )
2679
+ }
2680
+
2681
+ fn tag ( & self ) -> & ' static str {
2682
+ "SameTypeModuloInfer"
2683
+ }
2684
+
2685
+ fn a_is_expected ( & self ) -> bool {
2686
+ true
2687
+ }
2688
+
2689
+ fn relate_with_variance < T : relate:: Relate < ' tcx > > (
2690
+ & mut self ,
2691
+ _variance : ty:: Variance ,
2692
+ _info : ty:: VarianceDiagInfo < ' tcx > ,
2693
+ a : T ,
2694
+ b : T ,
2695
+ ) -> relate:: RelateResult < ' tcx , T > {
2696
+ self . relate ( a, b)
2697
+ }
2698
+
2699
+ fn tys ( & mut self , a : Ty < ' tcx > , b : Ty < ' tcx > ) -> RelateResult < ' tcx , Ty < ' tcx > > {
2700
+ match ( a. kind ( ) , b. kind ( ) ) {
2701
+ ( ty:: Int ( _) | ty:: Uint ( _) , ty:: Infer ( ty:: InferTy :: IntVar ( _) ) )
2686
2702
| (
2687
- & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2688
- & ty:: Int ( _) | & ty:: Uint ( _) | & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2703
+ ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2704
+ ty:: Int ( _) | ty:: Uint ( _) | ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2689
2705
)
2690
- | ( & ty:: Float ( _) , & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) )
2706
+ | ( ty:: Float ( _) , ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) )
2691
2707
| (
2692
- & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2693
- & ty:: Float ( _) | & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2708
+ ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2709
+ ty:: Float ( _) | ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2694
2710
)
2695
- | ( & ty:: Infer ( ty:: InferTy :: TyVar ( _) ) , _)
2696
- | ( _, & ty:: Infer ( ty:: InferTy :: TyVar ( _) ) ) => true ,
2697
- ( & ty:: Ref ( _, ty_a, mut_a) , & ty:: Ref ( _, ty_b, mut_b) ) => {
2698
- mut_a == mut_b && self . same_type_modulo_infer ( ty_a, ty_b)
2699
- }
2700
- ( & ty:: RawPtr ( a) , & ty:: RawPtr ( b) ) => {
2701
- a. mutbl == b. mutbl && self . same_type_modulo_infer ( a. ty , b. ty )
2702
- }
2703
- ( & ty:: Slice ( a) , & ty:: Slice ( b) ) => self . same_type_modulo_infer ( a, b) ,
2704
- ( & ty:: Array ( a_ty, a_ct) , & ty:: Array ( b_ty, b_ct) ) => {
2705
- self . same_type_modulo_infer ( a_ty, b_ty) && a_ct == b_ct
2706
- }
2707
- ( & ty:: Tuple ( a) , & ty:: Tuple ( b) ) => {
2708
- if a. len ( ) != b. len ( ) {
2709
- return false ;
2710
- }
2711
- std:: iter:: zip ( a. iter ( ) , b. iter ( ) ) . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2712
- }
2713
- ( & ty:: FnPtr ( a) , & ty:: FnPtr ( b) ) => {
2714
- let a = a. skip_binder ( ) . inputs_and_output ;
2715
- let b = b. skip_binder ( ) . inputs_and_output ;
2716
- if a. len ( ) != b. len ( ) {
2717
- return false ;
2718
- }
2719
- std:: iter:: zip ( a. iter ( ) , b. iter ( ) ) . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2720
- }
2721
- // FIXME(compiler-errors): This needs to be generalized more
2722
- _ => a == b,
2711
+ | ( ty:: Infer ( ty:: InferTy :: TyVar ( _) ) , _)
2712
+ | ( _, ty:: Infer ( ty:: InferTy :: TyVar ( _) ) ) => Ok ( a) ,
2713
+ ( ty:: Infer ( _) , _) | ( _, ty:: Infer ( _) ) => Err ( TypeError :: Mismatch ) ,
2714
+ _ => relate:: super_relate_tys ( self , a, b) ,
2723
2715
}
2724
2716
}
2717
+
2718
+ fn regions (
2719
+ & mut self ,
2720
+ a : ty:: Region < ' tcx > ,
2721
+ b : ty:: Region < ' tcx > ,
2722
+ ) -> RelateResult < ' tcx , ty:: Region < ' tcx > > {
2723
+ if ( a. is_var ( ) && b. is_free_or_static ( ) ) || ( b. is_var ( ) && a. is_free_or_static ( ) ) || a == b
2724
+ {
2725
+ Ok ( a)
2726
+ } else {
2727
+ Err ( TypeError :: Mismatch )
2728
+ }
2729
+ }
2730
+
2731
+ fn binders < T > (
2732
+ & mut self ,
2733
+ a : ty:: Binder < ' tcx , T > ,
2734
+ b : ty:: Binder < ' tcx , T > ,
2735
+ ) -> relate:: RelateResult < ' tcx , ty:: Binder < ' tcx , T > >
2736
+ where
2737
+ T : relate:: Relate < ' tcx > ,
2738
+ {
2739
+ Ok ( ty:: Binder :: dummy ( self . relate ( a. skip_binder ( ) , b. skip_binder ( ) ) ?) )
2740
+ }
2741
+
2742
+ fn consts (
2743
+ & mut self ,
2744
+ a : ty:: Const < ' tcx > ,
2745
+ _b : ty:: Const < ' tcx > ,
2746
+ ) -> relate:: RelateResult < ' tcx , ty:: Const < ' tcx > > {
2747
+ // FIXME(compiler-errors): This could at least do some first-order
2748
+ // relation
2749
+ Ok ( a)
2750
+ }
2725
2751
}
2726
2752
2727
2753
impl < ' a , ' tcx > InferCtxt < ' a , ' tcx > {
0 commit comments