Skip to content

Commit 4a99c5f

Browse files
committed
Auto merge of #97345 - lcnr:fast_reject, r=nnethercote
add a deep fast_reject routine continues the work on #97136. r? `@nnethercote` Actually agree with you on the match structure 😆 let's see how that impacted perf 😅
2 parents cbdce42 + bff7b51 commit 4a99c5f

File tree

4 files changed

+252
-66
lines changed

4 files changed

+252
-66
lines changed

compiler/rustc_middle/src/ty/fast_reject.rs

+221
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use crate::mir::Mutability;
2+
use crate::ty::subst::GenericArgKind;
23
use crate::ty::{self, Ty, TyCtxt, TypeFoldable};
34
use rustc_hir::def_id::DefId;
45
use std::fmt::Debug;
56
use std::hash::Hash;
7+
use std::iter;
68

79
use self::SimplifiedTypeGen::*;
810

@@ -72,6 +74,10 @@ pub enum TreatParams {
7274

7375
/// Tries to simplify a type by only returning the outermost injective¹ layer, if one exists.
7476
///
77+
/// **This function should only be used if you need to store or retrieve the type from some
78+
/// hashmap. If you want to quickly decide whether two types may unify, use the [DeepRejectCtxt]
79+
/// instead.**
80+
///
7581
/// The idea is to get something simple that we can use to quickly decide if two types could unify,
7682
/// for example during method lookup. If this function returns `Some(x)` it can only unify with
7783
/// types for which this method returns either `Some(x)` as well or `None`.
@@ -182,3 +188,218 @@ impl<D: Copy + Debug + Eq> SimplifiedTypeGen<D> {
182188
}
183189
}
184190
}
191+
192+
/// Given generic arguments from an obligation and an impl,
193+
/// could these two be unified after replacing parameters in the
194+
/// the impl with inference variables.
195+
///
196+
/// For obligations, parameters won't be replaced by inference
197+
/// variables and only unify with themselves. We treat them
198+
/// the same way we treat placeholders.
199+
///
200+
/// We also use this function during coherence. For coherence the
201+
/// impls only have to overlap for some value, so we treat parameters
202+
/// on both sides like inference variables. This behavior is toggled
203+
/// using the `treat_obligation_params` field.
204+
#[derive(Debug, Clone, Copy)]
205+
pub struct DeepRejectCtxt {
206+
pub treat_obligation_params: TreatParams,
207+
}
208+
209+
impl DeepRejectCtxt {
210+
pub fn generic_args_may_unify(
211+
self,
212+
obligation_arg: ty::GenericArg<'_>,
213+
impl_arg: ty::GenericArg<'_>,
214+
) -> bool {
215+
match (obligation_arg.unpack(), impl_arg.unpack()) {
216+
// We don't fast reject based on regions for now.
217+
(GenericArgKind::Lifetime(_), GenericArgKind::Lifetime(_)) => true,
218+
(GenericArgKind::Type(obl), GenericArgKind::Type(imp)) => {
219+
self.types_may_unify(obl, imp)
220+
}
221+
(GenericArgKind::Const(obl), GenericArgKind::Const(imp)) => {
222+
self.consts_may_unify(obl, imp)
223+
}
224+
_ => bug!("kind mismatch: {obligation_arg} {impl_arg}"),
225+
}
226+
}
227+
228+
pub fn types_may_unify(self, obligation_ty: Ty<'_>, impl_ty: Ty<'_>) -> bool {
229+
match impl_ty.kind() {
230+
// Start by checking whether the type in the impl may unify with
231+
// pretty much everything. Just return `true` in that case.
232+
ty::Param(_) | ty::Projection(_) | ty::Error(_) => return true,
233+
// These types only unify with inference variables or their own
234+
// variant.
235+
ty::Bool
236+
| ty::Char
237+
| ty::Int(_)
238+
| ty::Uint(_)
239+
| ty::Float(_)
240+
| ty::Adt(..)
241+
| ty::Str
242+
| ty::Array(..)
243+
| ty::Slice(..)
244+
| ty::RawPtr(..)
245+
| ty::Dynamic(..)
246+
| ty::Ref(..)
247+
| ty::Never
248+
| ty::Tuple(..)
249+
| ty::FnPtr(..)
250+
| ty::Foreign(..)
251+
| ty::Opaque(..) => {}
252+
ty::FnDef(..)
253+
| ty::Closure(..)
254+
| ty::Generator(..)
255+
| ty::GeneratorWitness(..)
256+
| ty::Placeholder(..)
257+
| ty::Bound(..)
258+
| ty::Infer(_) => bug!("unexpected impl_ty: {impl_ty}"),
259+
}
260+
261+
let k = impl_ty.kind();
262+
match *obligation_ty.kind() {
263+
// Purely rigid types, use structural equivalence.
264+
ty::Bool
265+
| ty::Char
266+
| ty::Int(_)
267+
| ty::Uint(_)
268+
| ty::Float(_)
269+
| ty::Str
270+
| ty::Never
271+
| ty::Foreign(_) => obligation_ty == impl_ty,
272+
ty::Ref(_, obl_ty, obl_mutbl) => match k {
273+
&ty::Ref(_, impl_ty, impl_mutbl) => {
274+
obl_mutbl == impl_mutbl && self.types_may_unify(obl_ty, impl_ty)
275+
}
276+
_ => false,
277+
},
278+
ty::Adt(obl_def, obl_substs) => match k {
279+
&ty::Adt(impl_def, impl_substs) => {
280+
obl_def == impl_def
281+
&& iter::zip(obl_substs, impl_substs)
282+
.all(|(obl, imp)| self.generic_args_may_unify(obl, imp))
283+
}
284+
_ => false,
285+
},
286+
ty::Slice(obl_ty) => {
287+
matches!(k, &ty::Slice(impl_ty) if self.types_may_unify(obl_ty, impl_ty))
288+
}
289+
ty::Array(obl_ty, obl_len) => match k {
290+
&ty::Array(impl_ty, impl_len) => {
291+
self.types_may_unify(obl_ty, impl_ty)
292+
&& self.consts_may_unify(obl_len, impl_len)
293+
}
294+
_ => false,
295+
},
296+
ty::Tuple(obl) => match k {
297+
&ty::Tuple(imp) => {
298+
obl.len() == imp.len()
299+
&& iter::zip(obl, imp).all(|(obl, imp)| self.types_may_unify(obl, imp))
300+
}
301+
_ => false,
302+
},
303+
ty::RawPtr(obl) => match k {
304+
ty::RawPtr(imp) => obl.mutbl == imp.mutbl && self.types_may_unify(obl.ty, imp.ty),
305+
_ => false,
306+
},
307+
ty::Dynamic(obl_preds, ..) => {
308+
// Ideally we would walk the existential predicates here or at least
309+
// compare their length. But considering that the relevant `Relate` impl
310+
// actually sorts and deduplicates these, that doesn't work.
311+
matches!(k, ty::Dynamic(impl_preds, ..) if
312+
obl_preds.principal_def_id() == impl_preds.principal_def_id()
313+
)
314+
}
315+
ty::FnPtr(obl_sig) => match k {
316+
ty::FnPtr(impl_sig) => {
317+
let ty::FnSig { inputs_and_output, c_variadic, unsafety, abi } =
318+
obl_sig.skip_binder();
319+
let impl_sig = impl_sig.skip_binder();
320+
321+
abi == impl_sig.abi
322+
&& c_variadic == impl_sig.c_variadic
323+
&& unsafety == impl_sig.unsafety
324+
&& inputs_and_output.len() == impl_sig.inputs_and_output.len()
325+
&& iter::zip(inputs_and_output, impl_sig.inputs_and_output)
326+
.all(|(obl, imp)| self.types_may_unify(obl, imp))
327+
}
328+
_ => false,
329+
},
330+
331+
// Opaque types in impls should be forbidden, but that doesn't
332+
// stop compilation. So this match arm should never return true
333+
// if compilation succeeds.
334+
ty::Opaque(..) => matches!(k, ty::Opaque(..)),
335+
336+
// Impls cannot contain these types as these cannot be named directly.
337+
ty::FnDef(..) | ty::Closure(..) | ty::Generator(..) => false,
338+
339+
ty::Placeholder(..) => false,
340+
341+
// Depending on the value of `treat_obligation_params`, we either
342+
// treat generic parameters like placeholders or like inference variables.
343+
ty::Param(_) => match self.treat_obligation_params {
344+
TreatParams::AsPlaceholder => false,
345+
TreatParams::AsInfer => true,
346+
},
347+
348+
ty::Infer(_) => true,
349+
350+
// As we're walking the whole type, it may encounter projections
351+
// inside of binders and what not, so we're just going to assume that
352+
// projections can unify with other stuff.
353+
//
354+
// Looking forward to lazy normalization this is the safer strategy anyways.
355+
ty::Projection(_) => true,
356+
357+
ty::Error(_) => true,
358+
359+
ty::GeneratorWitness(..) | ty::Bound(..) => {
360+
bug!("unexpected obligation type: {:?}", obligation_ty)
361+
}
362+
}
363+
}
364+
365+
pub fn consts_may_unify(self, obligation_ct: ty::Const<'_>, impl_ct: ty::Const<'_>) -> bool {
366+
match impl_ct.val() {
367+
ty::ConstKind::Param(_) | ty::ConstKind::Unevaluated(_) | ty::ConstKind::Error(_) => {
368+
return true;
369+
}
370+
ty::ConstKind::Value(_) => {}
371+
ty::ConstKind::Infer(_) | ty::ConstKind::Bound(..) | ty::ConstKind::Placeholder(_) => {
372+
bug!("unexpected impl arg: {:?}", impl_ct)
373+
}
374+
}
375+
376+
let k = impl_ct.val();
377+
match obligation_ct.val() {
378+
ty::ConstKind::Param(_) => match self.treat_obligation_params {
379+
TreatParams::AsPlaceholder => false,
380+
TreatParams::AsInfer => true,
381+
},
382+
383+
// As we don't necessarily eagerly evaluate constants,
384+
// they might unify with any value.
385+
ty::ConstKind::Unevaluated(_) | ty::ConstKind::Error(_) => true,
386+
ty::ConstKind::Value(obl) => match k {
387+
ty::ConstKind::Value(imp) => {
388+
// FIXME(valtrees): Once we have valtrees, we can just
389+
// compare them directly here.
390+
match (obl.try_to_scalar_int(), imp.try_to_scalar_int()) {
391+
(Some(obl), Some(imp)) => obl == imp,
392+
_ => true,
393+
}
394+
}
395+
_ => true,
396+
},
397+
398+
ty::ConstKind::Infer(_) => true,
399+
400+
ty::ConstKind::Bound(..) | ty::ConstKind::Placeholder(_) => {
401+
bug!("unexpected obl const: {:?}", obligation_ct)
402+
}
403+
}
404+
}
405+
}

compiler/rustc_trait_selection/src/traits/coherence.rs

+14-19
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use rustc_hir::CRATE_HIR_ID;
2020
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
2121
use rustc_infer::traits::{util, TraitEngine};
2222
use rustc_middle::traits::specialization_graph::OverlapMode;
23-
use rustc_middle::ty::fast_reject::{self, TreatParams};
23+
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
2424
use rustc_middle::ty::fold::TypeFoldable;
2525
use rustc_middle::ty::subst::Subst;
2626
use rustc_middle::ty::{self, ImplSubject, Ty, TyCtxt};
@@ -79,26 +79,21 @@ where
7979
// Before doing expensive operations like entering an inference context, do
8080
// a quick check via fast_reject to tell if the impl headers could possibly
8181
// unify.
82+
let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::AsInfer };
8283
let impl1_ref = tcx.impl_trait_ref(impl1_def_id);
8384
let impl2_ref = tcx.impl_trait_ref(impl2_def_id);
84-
85-
// Check if any of the input types definitely do not unify.
86-
if iter::zip(
87-
impl1_ref.iter().flat_map(|tref| tref.substs.types()),
88-
impl2_ref.iter().flat_map(|tref| tref.substs.types()),
89-
)
90-
.any(|(ty1, ty2)| {
91-
let t1 = fast_reject::simplify_type(tcx, ty1, TreatParams::AsInfer);
92-
let t2 = fast_reject::simplify_type(tcx, ty2, TreatParams::AsInfer);
93-
94-
if let (Some(t1), Some(t2)) = (t1, t2) {
95-
// Simplified successfully
96-
t1 != t2
97-
} else {
98-
// Types might unify
99-
false
85+
let may_overlap = match (impl1_ref, impl2_ref) {
86+
(Some(a), Some(b)) => iter::zip(a.substs, b.substs)
87+
.all(|(arg1, arg2)| drcx.generic_args_may_unify(arg1, arg2)),
88+
(None, None) => {
89+
let self_ty1 = tcx.type_of(impl1_def_id);
90+
let self_ty2 = tcx.type_of(impl2_def_id);
91+
drcx.types_may_unify(self_ty1, self_ty2)
10092
}
101-
}) {
93+
_ => bug!("unexpected impls: {impl1_def_id:?} {impl2_def_id:?}"),
94+
};
95+
96+
if !may_overlap {
10297
// Some types involved are definitely different, so the impls couldn't possibly overlap.
10398
debug!("overlapping_impls: fast_reject early-exit");
10499
return no_overlap();
@@ -519,7 +514,7 @@ pub fn orphan_check(tcx: TyCtxt<'_>, impl_def_id: DefId) -> Result<(), OrphanChe
519514
/// 3. Before this local type, no generic type parameter of the impl must
520515
/// be reachable through fundamental types.
521516
/// - e.g. `impl<T> Trait<LocalType> for Vec<T>` is fine, as `Vec` is not fundamental.
522-
/// - while `impl<T> Trait<LocalType for Box<T>` results in an error, as `T` is
517+
/// - while `impl<T> Trait<LocalType> for Box<T>` results in an error, as `T` is
523518
/// reachable through the fundamental type `Box`.
524519
/// 4. Every type in the local key parameter not known in C, going
525520
/// through the parameter's type tree, must appear only as a subtree of

compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,16 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
539539
obligation.predicate.def_id(),
540540
obligation.predicate.skip_binder().trait_ref.self_ty(),
541541
|impl_def_id| {
542+
// Before we create the substitutions and everything, first
543+
// consider a "quick reject". This avoids creating more types
544+
// and so forth that we need to.
545+
let impl_trait_ref = self.tcx().bound_impl_trait_ref(impl_def_id).unwrap();
546+
if self.fast_reject_trait_refs(obligation, &impl_trait_ref.0) {
547+
return;
548+
}
549+
542550
self.infcx.probe(|_| {
543-
if let Ok(_substs) = self.match_impl(impl_def_id, obligation) {
551+
if let Ok(_substs) = self.match_impl(impl_def_id, impl_trait_ref, obligation) {
544552
candidates.vec.push(ImplCandidate(impl_def_id));
545553
}
546554
});

0 commit comments

Comments
 (0)