Skip to content

Commit 93d28fb

Browse files
Merge #3099
3099: Init implementation of structural search replace r=matklad a=mikhail-m1 next steps: * ignore space and other minor difference * add support to ra_cli * call rust parser to check pattern * documentation original issue #2267 Co-authored-by: Mikhail Modin <[email protected]>
2 parents 429fa44 + f8f454a commit 93d28fb

File tree

10 files changed

+399
-1
lines changed

10 files changed

+399
-1
lines changed

crates/ra_ide/src/lib.rs

+12
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ mod display;
3737
mod inlay_hints;
3838
mod expand;
3939
mod expand_macro;
40+
mod ssr;
4041

4142
#[cfg(test)]
4243
mod marks;
@@ -73,6 +74,7 @@ pub use crate::{
7374
},
7475
runnables::{Runnable, RunnableKind, TestId},
7576
source_change::{FileSystemEdit, SourceChange, SourceFileEdit},
77+
ssr::SsrError,
7678
syntax_highlighting::HighlightedRange,
7779
};
7880

@@ -464,6 +466,16 @@ impl Analysis {
464466
self.with_db(|db| references::rename(db, position, new_name))
465467
}
466468

469+
pub fn structural_search_replace(
470+
&self,
471+
query: &str,
472+
) -> Cancelable<Result<SourceChange, SsrError>> {
473+
self.with_db(|db| {
474+
let edits = ssr::parse_search_replace(query, db)?;
475+
Ok(SourceChange::source_file_edits("ssr", edits))
476+
})
477+
}
478+
467479
/// Performs an operation on that may be Canceled.
468480
fn with_db<F: FnOnce(&RootDatabase) -> T + std::panic::UnwindSafe, T>(
469481
&self,

crates/ra_ide/src/ssr.rs

+324
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
//! structural search replace
2+
3+
use crate::source_change::SourceFileEdit;
4+
use ra_ide_db::RootDatabase;
5+
use ra_syntax::ast::make::expr_from_text;
6+
use ra_syntax::AstNode;
7+
use ra_syntax::SyntaxElement;
8+
use ra_syntax::SyntaxNode;
9+
use ra_text_edit::{TextEdit, TextEditBuilder};
10+
use rustc_hash::FxHashMap;
11+
use std::collections::HashMap;
12+
use std::str::FromStr;
13+
14+
pub use ra_db::{SourceDatabase, SourceDatabaseExt};
15+
use ra_ide_db::symbol_index::SymbolsDatabase;
16+
17+
#[derive(Debug, PartialEq)]
18+
pub struct SsrError(String);
19+
20+
impl std::fmt::Display for SsrError {
21+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
22+
write!(f, "Parse error: {}", self.0)
23+
}
24+
}
25+
26+
impl std::error::Error for SsrError {}
27+
28+
pub fn parse_search_replace(
29+
query: &str,
30+
db: &RootDatabase,
31+
) -> Result<Vec<SourceFileEdit>, SsrError> {
32+
let mut edits = vec![];
33+
let query: SsrQuery = query.parse()?;
34+
for &root in db.local_roots().iter() {
35+
let sr = db.source_root(root);
36+
for file_id in sr.walk() {
37+
dbg!(db.file_relative_path(file_id));
38+
let matches = find(&query.pattern, db.parse(file_id).tree().syntax());
39+
if !matches.matches.is_empty() {
40+
edits.push(SourceFileEdit { file_id, edit: replace(&matches, &query.template) });
41+
}
42+
}
43+
}
44+
Ok(edits)
45+
}
46+
47+
#[derive(Debug)]
48+
struct SsrQuery {
49+
pattern: SsrPattern,
50+
template: SsrTemplate,
51+
}
52+
53+
#[derive(Debug)]
54+
struct SsrPattern {
55+
pattern: SyntaxNode,
56+
vars: Vec<Var>,
57+
}
58+
59+
/// represents an `$var` in an SSR query
60+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
61+
struct Var(String);
62+
63+
#[derive(Debug)]
64+
struct SsrTemplate {
65+
template: SyntaxNode,
66+
placeholders: FxHashMap<SyntaxNode, Var>,
67+
}
68+
69+
type Binding = HashMap<Var, SyntaxNode>;
70+
71+
#[derive(Debug)]
72+
struct Match {
73+
place: SyntaxNode,
74+
binding: Binding,
75+
}
76+
77+
#[derive(Debug)]
78+
struct SsrMatches {
79+
matches: Vec<Match>,
80+
}
81+
82+
impl FromStr for SsrQuery {
83+
type Err = SsrError;
84+
85+
fn from_str(query: &str) -> Result<SsrQuery, SsrError> {
86+
let mut it = query.split("==>>");
87+
let pattern = it.next().expect("at least empty string").trim();
88+
let mut template =
89+
it.next().ok_or(SsrError("Cannot find delemiter `==>>`".into()))?.trim().to_string();
90+
if it.next().is_some() {
91+
return Err(SsrError("More than one delimiter found".into()));
92+
}
93+
let mut vars = vec![];
94+
let mut it = pattern.split('$');
95+
let mut pattern = it.next().expect("something").to_string();
96+
97+
for part in it.map(split_by_var) {
98+
let (var, var_type, remainder) = part?;
99+
is_expr(var_type)?;
100+
let new_var = create_name(var, &mut vars)?;
101+
pattern.push_str(new_var);
102+
pattern.push_str(remainder);
103+
template = replace_in_template(template, var, new_var);
104+
}
105+
106+
let template = expr_from_text(&template).syntax().clone();
107+
let mut placeholders = FxHashMap::default();
108+
109+
traverse(&template, &mut |n| {
110+
if let Some(v) = vars.iter().find(|v| v.0.as_str() == n.text()) {
111+
placeholders.insert(n.clone(), v.clone());
112+
false
113+
} else {
114+
true
115+
}
116+
});
117+
118+
let pattern = SsrPattern { pattern: expr_from_text(&pattern).syntax().clone(), vars };
119+
let template = SsrTemplate { template, placeholders };
120+
Ok(SsrQuery { pattern, template })
121+
}
122+
}
123+
124+
fn traverse(node: &SyntaxNode, go: &mut impl FnMut(&SyntaxNode) -> bool) {
125+
if !go(node) {
126+
return;
127+
}
128+
for ref child in node.children() {
129+
traverse(child, go);
130+
}
131+
}
132+
133+
fn split_by_var(s: &str) -> Result<(&str, &str, &str), SsrError> {
134+
let end_of_name = s.find(":").ok_or(SsrError("Use $<name>:expr".into()))?;
135+
let name = &s[0..end_of_name];
136+
is_name(name)?;
137+
let type_begin = end_of_name + 1;
138+
let type_length = s[type_begin..].find(|c| !char::is_ascii_alphanumeric(&c)).unwrap_or(s.len());
139+
let type_name = &s[type_begin..type_begin + type_length];
140+
Ok((name, type_name, &s[type_begin + type_length..]))
141+
}
142+
143+
fn is_name(s: &str) -> Result<(), SsrError> {
144+
if s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
145+
Ok(())
146+
} else {
147+
Err(SsrError("Name can contain only alphanumerics and _".into()))
148+
}
149+
}
150+
151+
fn is_expr(s: &str) -> Result<(), SsrError> {
152+
if s == "expr" {
153+
Ok(())
154+
} else {
155+
Err(SsrError("Only $<name>:expr is supported".into()))
156+
}
157+
}
158+
159+
fn replace_in_template(template: String, var: &str, new_var: &str) -> String {
160+
let name = format!("${}", var);
161+
template.replace(&name, new_var)
162+
}
163+
164+
fn create_name<'a>(name: &str, vars: &'a mut Vec<Var>) -> Result<&'a str, SsrError> {
165+
let sanitized_name = format!("__search_pattern_{}", name);
166+
if vars.iter().any(|a| a.0 == sanitized_name) {
167+
return Err(SsrError(format!("Name `{}` repeats more than once", name)));
168+
}
169+
vars.push(Var(sanitized_name));
170+
Ok(&vars.last().unwrap().0)
171+
}
172+
173+
fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
174+
fn check(
175+
pattern: &SyntaxElement,
176+
code: &SyntaxElement,
177+
placeholders: &[Var],
178+
match_: &mut Match,
179+
) -> bool {
180+
match (pattern, code) {
181+
(SyntaxElement::Token(ref pattern), SyntaxElement::Token(ref code)) => {
182+
pattern.text() == code.text()
183+
}
184+
(SyntaxElement::Node(ref pattern), SyntaxElement::Node(ref code)) => {
185+
if placeholders.iter().find(|&n| n.0.as_str() == pattern.text()).is_some() {
186+
match_.binding.insert(Var(pattern.text().to_string()), code.clone());
187+
true
188+
} else {
189+
pattern.green().children().count() == code.green().children().count()
190+
&& pattern
191+
.children_with_tokens()
192+
.zip(code.children_with_tokens())
193+
.all(|(a, b)| check(&a, &b, placeholders, match_))
194+
}
195+
}
196+
_ => false,
197+
}
198+
}
199+
let kind = pattern.pattern.kind();
200+
let matches = code
201+
.descendants_with_tokens()
202+
.filter(|n| n.kind() == kind)
203+
.filter_map(|code| {
204+
let mut match_ =
205+
Match { place: code.as_node().unwrap().clone(), binding: HashMap::new() };
206+
if check(
207+
&SyntaxElement::from(pattern.pattern.clone()),
208+
&code,
209+
&pattern.vars,
210+
&mut match_,
211+
) {
212+
Some(match_)
213+
} else {
214+
None
215+
}
216+
})
217+
.collect();
218+
SsrMatches { matches }
219+
}
220+
221+
fn replace(matches: &SsrMatches, template: &SsrTemplate) -> TextEdit {
222+
let mut builder = TextEditBuilder::default();
223+
for match_ in &matches.matches {
224+
builder.replace(match_.place.text_range(), render_replace(&match_.binding, template));
225+
}
226+
builder.finish()
227+
}
228+
229+
fn render_replace(binding: &Binding, template: &SsrTemplate) -> String {
230+
let mut builder = TextEditBuilder::default();
231+
for element in template.template.descendants() {
232+
if let Some(var) = template.placeholders.get(&element) {
233+
builder.replace(element.text_range(), binding[var].to_string())
234+
}
235+
}
236+
builder.finish().apply(&template.template.text().to_string())
237+
}
238+
239+
#[cfg(test)]
240+
mod tests {
241+
use super::*;
242+
use ra_syntax::SourceFile;
243+
244+
fn parse_error_text(query: &str) -> String {
245+
format!("{}", query.parse::<SsrQuery>().unwrap_err())
246+
}
247+
248+
#[test]
249+
fn parser_happy_case() {
250+
let result: SsrQuery = "foo($a:expr, $b:expr) ==>> bar($b, $a)".parse().unwrap();
251+
assert_eq!(&result.pattern.pattern.text(), "foo(__search_pattern_a, __search_pattern_b)");
252+
assert_eq!(result.pattern.vars.len(), 2);
253+
assert_eq!(result.pattern.vars[0].0, "__search_pattern_a");
254+
assert_eq!(result.pattern.vars[1].0, "__search_pattern_b");
255+
assert_eq!(&result.template.template.text(), "bar(__search_pattern_b, __search_pattern_a)");
256+
dbg!(result.template.placeholders);
257+
}
258+
259+
#[test]
260+
fn parser_empty_query() {
261+
assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`");
262+
}
263+
264+
#[test]
265+
fn parser_no_delimiter() {
266+
assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`");
267+
}
268+
269+
#[test]
270+
fn parser_two_delimiters() {
271+
assert_eq!(
272+
parse_error_text("foo() ==>> a ==>> b "),
273+
"Parse error: More than one delimiter found"
274+
);
275+
}
276+
277+
#[test]
278+
fn parser_no_pattern_type() {
279+
assert_eq!(parse_error_text("foo($a) ==>>"), "Parse error: Use $<name>:expr");
280+
}
281+
282+
#[test]
283+
fn parser_invalid_name() {
284+
assert_eq!(
285+
parse_error_text("foo($a+:expr) ==>>"),
286+
"Parse error: Name can contain only alphanumerics and _"
287+
);
288+
}
289+
290+
#[test]
291+
fn parser_invalid_type() {
292+
assert_eq!(
293+
parse_error_text("foo($a:ident) ==>>"),
294+
"Parse error: Only $<name>:expr is supported"
295+
);
296+
}
297+
298+
#[test]
299+
fn parser_repeated_name() {
300+
assert_eq!(
301+
parse_error_text("foo($a:expr, $a:expr) ==>>"),
302+
"Parse error: Name `a` repeats more than once"
303+
);
304+
}
305+
306+
#[test]
307+
fn parse_match_replace() {
308+
let query: SsrQuery = "foo($x:expr) ==>> bar($x)".parse().unwrap();
309+
let input = "fn main() { foo(1+2); }";
310+
311+
let code = SourceFile::parse(input).tree();
312+
let matches = find(&query.pattern, code.syntax());
313+
assert_eq!(matches.matches.len(), 1);
314+
assert_eq!(matches.matches[0].place.text(), "foo(1+2)");
315+
assert_eq!(matches.matches[0].binding.len(), 1);
316+
assert_eq!(
317+
matches.matches[0].binding[&Var("__search_pattern_x".to_string())].text(),
318+
"1+2"
319+
);
320+
321+
let edit = replace(&matches, &query.template);
322+
assert_eq!(edit.apply(input), "fn main() { bar(1+2); }");
323+
}
324+
}

crates/ra_lsp_server/src/main_loop.rs

+1
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ fn on_request(
527527
.on::<req::CallHierarchyPrepare>(handlers::handle_call_hierarchy_prepare)?
528528
.on::<req::CallHierarchyIncomingCalls>(handlers::handle_call_hierarchy_incoming)?
529529
.on::<req::CallHierarchyOutgoingCalls>(handlers::handle_call_hierarchy_outgoing)?
530+
.on::<req::Ssr>(handlers::handle_ssr)?
530531
.finish();
531532
Ok(())
532533
}

crates/ra_lsp_server/src/main_loop/handlers.rs

+5
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,11 @@ pub fn handle_document_highlight(
882882
))
883883
}
884884

885+
pub fn handle_ssr(world: WorldSnapshot, params: req::SsrParams) -> Result<req::SourceChange> {
886+
let _p = profile("handle_ssr");
887+
world.analysis().structural_search_replace(&params.arg)??.try_conv_with(&world)
888+
}
889+
885890
pub fn publish_diagnostics(world: &WorldSnapshot, file_id: FileId) -> Result<DiagnosticTask> {
886891
let _p = profile("publish_diagnostics");
887892
let line_index = world.analysis().file_line_index(file_id)?;

crates/ra_lsp_server/src/req.rs

+13
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,16 @@ pub struct InlayHint {
206206
pub kind: InlayKind,
207207
pub label: String,
208208
}
209+
210+
pub enum Ssr {}
211+
212+
impl Request for Ssr {
213+
type Params = SsrParams;
214+
type Result = SourceChange;
215+
const METHOD: &'static str = "rust-analyzer/ssr";
216+
}
217+
218+
#[derive(Debug, Deserialize, Serialize)]
219+
pub struct SsrParams {
220+
pub arg: String,
221+
}

0 commit comments

Comments
 (0)