@@ -66,6 +66,7 @@ use rustc_hir::lang_items::LangItem;
66
66
use rustc_hir:: Node ;
67
67
use rustc_middle:: dep_graph:: DepContext ;
68
68
use rustc_middle:: ty:: print:: with_no_trimmed_paths;
69
+ use rustc_middle:: ty:: relate:: { self , RelateResult , TypeRelation } ;
69
70
use rustc_middle:: ty:: {
70
71
self , error:: TypeError , Binder , List , Region , Subst , Ty , TyCtxt , TypeFoldable ,
71
72
TypeSuperVisitable , TypeVisitable ,
@@ -2660,67 +2661,92 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
2660
2661
/// Float types, respectively). When comparing two ADTs, these rules apply recursively.
2661
2662
pub fn same_type_modulo_infer ( & self , a : Ty < ' tcx > , b : Ty < ' tcx > ) -> bool {
2662
2663
let ( a, b) = self . resolve_vars_if_possible ( ( a, b) ) ;
2663
- match ( a. kind ( ) , b. kind ( ) ) {
2664
- ( & ty:: Adt ( def_a, substs_a) , & ty:: Adt ( def_b, substs_b) ) => {
2665
- if def_a != def_b {
2666
- return false ;
2667
- }
2664
+ SameTypeModuloInfer ( self ) . relate ( a, b) . is_ok ( )
2665
+ }
2666
+ }
2668
2667
2669
- substs_a
2670
- . types ( )
2671
- . zip ( substs_b. types ( ) )
2672
- . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2673
- }
2674
- ( & ty:: FnDef ( did_a, substs_a) , & ty:: FnDef ( did_b, substs_b) ) => {
2675
- if did_a != did_b {
2676
- return false ;
2677
- }
2668
+ struct SameTypeModuloInfer < ' a , ' tcx > ( & ' a InferCtxt < ' a , ' tcx > ) ;
2678
2669
2679
- substs_a
2680
- . types ( )
2681
- . zip ( substs_b. types ( ) )
2682
- . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2683
- }
2684
- ( & ty:: Int ( _) | & ty:: Uint ( _) , & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) )
2670
+ impl < ' tcx > TypeRelation < ' tcx > for SameTypeModuloInfer < ' _ , ' tcx > {
2671
+ fn tcx ( & self ) -> TyCtxt < ' tcx > {
2672
+ self . 0 . tcx
2673
+ }
2674
+
2675
+ fn param_env ( & self ) -> ty:: ParamEnv < ' tcx > {
2676
+ // Unused, only for consts which we treat as always equal
2677
+ ty:: ParamEnv :: empty ( )
2678
+ }
2679
+
2680
+ fn tag ( & self ) -> & ' static str {
2681
+ "SameTypeModuloInfer"
2682
+ }
2683
+
2684
+ fn a_is_expected ( & self ) -> bool {
2685
+ true
2686
+ }
2687
+
2688
+ fn relate_with_variance < T : relate:: Relate < ' tcx > > (
2689
+ & mut self ,
2690
+ _variance : ty:: Variance ,
2691
+ _info : ty:: VarianceDiagInfo < ' tcx > ,
2692
+ a : T ,
2693
+ b : T ,
2694
+ ) -> relate:: RelateResult < ' tcx , T > {
2695
+ self . relate ( a, b)
2696
+ }
2697
+
2698
+ fn tys ( & mut self , a : Ty < ' tcx > , b : Ty < ' tcx > ) -> RelateResult < ' tcx , Ty < ' tcx > > {
2699
+ match ( a. kind ( ) , b. kind ( ) ) {
2700
+ ( ty:: Int ( _) | ty:: Uint ( _) , ty:: Infer ( ty:: InferTy :: IntVar ( _) ) )
2685
2701
| (
2686
- & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2687
- & ty:: Int ( _) | & ty:: Uint ( _) | & ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2702
+ ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2703
+ ty:: Int ( _) | ty:: Uint ( _) | ty:: Infer ( ty:: InferTy :: IntVar ( _) ) ,
2688
2704
)
2689
- | ( & ty:: Float ( _) , & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) )
2705
+ | ( ty:: Float ( _) , ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) )
2690
2706
| (
2691
- & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2692
- & ty:: Float ( _) | & ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2707
+ ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2708
+ ty:: Float ( _) | ty:: Infer ( ty:: InferTy :: FloatVar ( _) ) ,
2693
2709
)
2694
- | ( & ty:: Infer ( ty:: InferTy :: TyVar ( _) ) , _)
2695
- | ( _, & ty:: Infer ( ty:: InferTy :: TyVar ( _) ) ) => true ,
2696
- ( & ty:: Ref ( _, ty_a, mut_a) , & ty:: Ref ( _, ty_b, mut_b) ) => {
2697
- mut_a == mut_b && self . same_type_modulo_infer ( ty_a, ty_b)
2698
- }
2699
- ( & ty:: RawPtr ( a) , & ty:: RawPtr ( b) ) => {
2700
- a. mutbl == b. mutbl && self . same_type_modulo_infer ( a. ty , b. ty )
2701
- }
2702
- ( & ty:: Slice ( a) , & ty:: Slice ( b) ) => self . same_type_modulo_infer ( a, b) ,
2703
- ( & ty:: Array ( a_ty, a_ct) , & ty:: Array ( b_ty, b_ct) ) => {
2704
- self . same_type_modulo_infer ( a_ty, b_ty) && a_ct == b_ct
2705
- }
2706
- ( & ty:: Tuple ( a) , & ty:: Tuple ( b) ) => {
2707
- if a. len ( ) != b. len ( ) {
2708
- return false ;
2709
- }
2710
- std:: iter:: zip ( a. iter ( ) , b. iter ( ) ) . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2711
- }
2712
- ( & ty:: FnPtr ( a) , & ty:: FnPtr ( b) ) => {
2713
- let a = a. skip_binder ( ) . inputs_and_output ;
2714
- let b = b. skip_binder ( ) . inputs_and_output ;
2715
- if a. len ( ) != b. len ( ) {
2716
- return false ;
2717
- }
2718
- std:: iter:: zip ( a. iter ( ) , b. iter ( ) ) . all ( |( a, b) | self . same_type_modulo_infer ( a, b) )
2719
- }
2720
- // FIXME(compiler-errors): This needs to be generalized more
2721
- _ => a == b,
2710
+ | ( ty:: Infer ( ty:: InferTy :: TyVar ( _) ) , _)
2711
+ | ( _, ty:: Infer ( ty:: InferTy :: TyVar ( _) ) ) => Ok ( a) ,
2712
+ ( ty:: Infer ( _) , _) | ( _, ty:: Infer ( _) ) => Err ( TypeError :: Mismatch ) ,
2713
+ _ => relate:: super_relate_tys ( self , a, b) ,
2722
2714
}
2723
2715
}
2716
+
2717
+ fn regions (
2718
+ & mut self ,
2719
+ a : ty:: Region < ' tcx > ,
2720
+ b : ty:: Region < ' tcx > ,
2721
+ ) -> RelateResult < ' tcx , ty:: Region < ' tcx > > {
2722
+ if ( a. is_var ( ) && b. is_free_or_static ( ) ) || ( b. is_var ( ) && a. is_free_or_static ( ) ) || a == b
2723
+ {
2724
+ Ok ( a)
2725
+ } else {
2726
+ Err ( TypeError :: Mismatch )
2727
+ }
2728
+ }
2729
+
2730
+ fn binders < T > (
2731
+ & mut self ,
2732
+ a : ty:: Binder < ' tcx , T > ,
2733
+ b : ty:: Binder < ' tcx , T > ,
2734
+ ) -> relate:: RelateResult < ' tcx , ty:: Binder < ' tcx , T > >
2735
+ where
2736
+ T : relate:: Relate < ' tcx > ,
2737
+ {
2738
+ Ok ( ty:: Binder :: dummy ( self . relate ( a. skip_binder ( ) , b. skip_binder ( ) ) ?) )
2739
+ }
2740
+
2741
+ fn consts (
2742
+ & mut self ,
2743
+ a : ty:: Const < ' tcx > ,
2744
+ _b : ty:: Const < ' tcx > ,
2745
+ ) -> relate:: RelateResult < ' tcx , ty:: Const < ' tcx > > {
2746
+ // FIXME(compiler-errors): This could at least do some first-order
2747
+ // relation
2748
+ Ok ( a)
2749
+ }
2724
2750
}
2725
2751
2726
2752
impl < ' a , ' tcx > InferCtxt < ' a , ' tcx > {
0 commit comments