diff --git a/chalk-rust/src/ir/mod.rs b/chalk-rust/src/ir/mod.rs index 9da79f32bd8..d351e953a0c 100644 --- a/chalk-rust/src/ir/mod.rs +++ b/chalk-rust/src/ir/mod.rs @@ -12,6 +12,9 @@ pub struct Program { /// For each struct/trait: pub type_kinds: HashMap, + /// For each struct: + pub struct_data: HashMap, + /// For each impl: pub impl_data: HashMap, @@ -20,12 +23,34 @@ pub struct Program { /// For each trait: pub associated_ty_data: HashMap, +} + +impl Program { + pub fn split_projection<'p>(&self, projection: &'p ProjectionTy) + -> (&AssociatedTyDatum, &'p [Parameter], &'p [Parameter]) { + let ProjectionTy { associated_ty_id, ref parameters } = *projection; + let associated_ty_data = &self.associated_ty_data[&associated_ty_id]; + let trait_datum = &self.trait_data[&associated_ty_data.trait_id]; + let trait_num_params = trait_datum.binders.len(); + let split_point = parameters.len() - trait_num_params; + let (other_params, trait_params) = parameters.split_at(split_point); + (associated_ty_data, trait_params, other_params) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProgramEnvironment { + /// For each trait: + pub trait_data: HashMap, + + /// For each trait: + pub associated_ty_data: HashMap, /// Compiled forms of the above: pub program_clauses: Vec, } -impl Program { +impl ProgramEnvironment { pub fn split_projection<'p>(&self, projection: &'p ProjectionTy) -> (&AssociatedTyDatum, &'p [Parameter], &'p [Parameter]) { let ProjectionTy { associated_ty_id, ref parameters } = *projection; diff --git a/chalk-rust/src/lower/mod.rs b/chalk-rust/src/lower/mod.rs index e175ee3d45a..86e498c8683 100644 --- a/chalk-rust/src/lower/mod.rs +++ b/chalk-rust/src/lower/mod.rs @@ -225,33 +225,7 @@ impl LowerProgram for Program { } } - // Construct the set of *clauses*; these are sort of a compiled form - // of the data above that always has the form: - // - // forall P0...Pn. Something :- Conditions - let mut program_clauses = vec![]; - - for struct_datum in struct_data.values() { - program_clauses.extend(struct_datum.to_program_clauses()); - } - - for trait_datum in trait_data.values() { - program_clauses.extend(trait_datum.to_program_clauses()); - } - - for (&id, associated_ty_datum) in &associated_ty_data { - program_clauses.extend(associated_ty_datum.to_program_clauses(id)); - } - - for impl_datum in impl_data.values() { - program_clauses.push(impl_datum.to_program_clause()); - - for atv in &impl_datum.binders.value.associated_ty_values { - program_clauses.extend(atv.to_program_clauses(impl_datum)); - } - } - - Ok(ir::Program { type_ids, type_kinds, trait_data, impl_data, associated_ty_data, program_clauses }) + Ok(ir::Program { type_ids, type_kinds, struct_data, trait_data, impl_data, associated_ty_data, }) } } @@ -882,6 +856,34 @@ impl LowerQuantifiedGoal for Goal { } } +impl ir::Program { + pub fn environment(&self) -> ir::ProgramEnvironment { + // Construct the set of *clauses*; these are sort of a compiled form + // of the data above that always has the form: + // + // forall P0...Pn. Something :- Conditions + let mut program_clauses = vec![]; + + program_clauses.extend(self.struct_data.values().flat_map(|d| d.to_program_clauses())); + program_clauses.extend(self.trait_data.values().flat_map(|d| d.to_program_clauses())); + program_clauses.extend(self.associated_ty_data.iter().flat_map(|(&id, d)| { + d.to_program_clauses(id) + })); + + for datum in self.impl_data.values() { + program_clauses.push(datum.to_program_clause()); + program_clauses.extend(datum.binders.value.associated_ty_values.iter().flat_map(|atv| { + atv.to_program_clauses(datum) + })); + } + + let trait_data = self.trait_data.clone(); + let associated_ty_data = self.associated_ty_data.clone(); + + ir::ProgramEnvironment { trait_data, associated_ty_data, program_clauses } + } +} + impl ir::ImplDatum { /// Given `impl Clone for Vec`, generate: /// diff --git a/chalk-rust/src/solve/environment.rs b/chalk-rust/src/solve/environment.rs index 4998d25a03c..e01e91131fa 100644 --- a/chalk-rust/src/solve/environment.rs +++ b/chalk-rust/src/solve/environment.rs @@ -29,7 +29,7 @@ impl Environment { Arc::new(env) } - pub fn elaborated_clauses(&self, program: &Program) -> impl Iterator { + pub fn elaborated_clauses(&self, program: &ProgramEnvironment) -> impl Iterator { let mut set = HashSet::new(); set.extend(self.clauses.iter().cloned()); diff --git a/chalk-rust/src/solve/fulfill.rs b/chalk-rust/src/solve/fulfill.rs index 40b64820a1d..e604eb942a3 100644 --- a/chalk-rust/src/solve/fulfill.rs +++ b/chalk-rust/src/solve/fulfill.rs @@ -23,7 +23,7 @@ impl<'s> Fulfill<'s> { Fulfill { solver, infer, obligations: vec![], constraints: HashSet::new() } } - pub fn program(&self) -> Arc { + pub fn program(&self) -> Arc { self.solver.program.clone() } diff --git a/chalk-rust/src/solve/solver/mod.rs b/chalk-rust/src/solve/solver/mod.rs index 97dd8cd589a..f6b41450e2b 100644 --- a/chalk-rust/src/solve/solver/mod.rs +++ b/chalk-rust/src/solve/solver/mod.rs @@ -13,13 +13,13 @@ use std::sync::Arc; use super::*; pub struct Solver { - pub(super) program: Arc, + pub(super) program: Arc, overflow_depth: usize, stack: Vec>>, } impl Solver { - pub fn new(program: &Arc, overflow_depth: usize) -> Self { + pub fn new(program: &Arc, overflow_depth: usize) -> Self { Solver { program: program.clone(), stack: vec![], overflow_depth, } } diff --git a/chalk-rust/src/solve/test.rs b/chalk-rust/src/solve/test.rs index 6c68064df90..d4e5b513fb9 100644 --- a/chalk-rust/src/solve/test.rs +++ b/chalk-rust/src/solve/test.rs @@ -27,6 +27,7 @@ fn solve_goal(program_text: &str, assert!(program_text.starts_with("{")); assert!(program_text.ends_with("}")); let program = Arc::new(parse_and_lower_program(&program_text[1..program_text.len()-1]).unwrap()); + let env = Arc::new(program.environment()); ir::set_current_program(&program, || { for (goal_text, expected) in goals { println!("----------------------------------------------------------------------"); @@ -39,7 +40,7 @@ fn solve_goal(program_text: &str, // tests don't require a higher one. let overflow_depth = 3; - let mut solver = Solver::new(&program, overflow_depth); + let mut solver = Solver::new(&env, overflow_depth); let result = match Prove::new(&mut solver, goal).solve() { Ok(v) => format!("{:#?}", v), Err(e) => format!("{}", e),