Skip to content

Commit 70f0c93

Browse files
refactor: simplify parser ? (#330)
* simplify * simplify 2 * simplify 3 * ok * ffs * comment * ok............ * tidying up * comment… * ok * comment * more * ok * ok * end test
1 parent 1cfa5b8 commit 70f0c93

File tree

4 files changed

+134
-121
lines changed

4 files changed

+134
-121
lines changed

crates/pgt_statement_splitter/src/lib.rs

+3-3
Original file line number · Diff line number · Diff line change
@@ -4,10 +4,10 @@
44
pub mod diagnostics;
55
mod parser;
66

7-
use parser::{Parse, Parser, source};
7+
use parser::{Parser, ParserResult, source};
88
use pgt_lexer::diagnostics::ScanError;
99

10-
pub fn split(sql: &str) -> Result<Parse, Vec<ScanError>> {
10+
pub fn split(sql: &str) -> Result<ParserResult, Vec<ScanError>> {
1111
let tokens = pgt_lexer::lex(sql)?;
1212

1313
let mut parser = Parser::new(tokens);
@@ -28,7 +28,7 @@ mod tests {
2828

2929
struct Tester {
3030
input: String,
31-
parse: Parse,
31+
parse: ParserResult,
3232
}
3333

3434
impl From<&str> for Tester {

crates/pgt_statement_splitter/src/parser.rs

+123-112
Original file line number · Diff line number · Diff line change
@@ -13,24 +13,24 @@ use crate::diagnostics::SplitDiagnostic;
1313
/// Main parser that exposes the `cstree` api, and collects errors and statements
1414
/// It is modelled after a Pratt Parser. For a gentle introduction to Pratt Parsing, see https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html
1515
pub struct Parser {
16-
/// The ranges of the statements
17-
ranges: Vec<(usize, usize)>,
16+
/// The statement ranges are defined by the indices of the start/end tokens
17+
stmt_ranges: Vec<(usize, usize)>,
18+
1819
/// The syntax errors accumulated during parsing
1920
errors: Vec<SplitDiagnostic>,
20-
/// The start of the current statement, if any
21+
2122
current_stmt_start: Option<usize>,
22-
/// The tokens to parse
23-
pub tokens: Vec<Token>,
23+
24+
tokens: Vec<Token>,
2425

2526
eof_token: Token,
2627

27-
next_pos: usize,
28+
current_pos: usize,
2829
}
2930

30-
/// Result of Building
3131
#[derive(Debug)]
32-
pub struct Parse {
33-
/// The ranges of the errors
32+
pub struct ParserResult {
33+
/// The ranges of the parsed statements
3434
pub ranges: Vec<TextRange>,
3535
/// The syntax errors accumulated during parsing
3636
pub errors: Vec<SplitDiagnostic>,
@@ -41,40 +41,34 @@ impl Parser {
4141
let eof_token = Token::eof(usize::from(
4242
tokens
4343
.last()
44-
.map(|t| t.span.start())
44+
.map(|t| t.span.end())
4545
.unwrap_or(TextSize::from(0)),
4646
));
4747

48-
// next_pos should be the initialised with the first valid token already
49-
let mut next_pos = 0;
50-
loop {
51-
let token = tokens.get(next_pos).unwrap_or(&eof_token);
52-
53-
if is_irrelevant_token(token) {
54-
next_pos += 1;
55-
} else {
56-
break;
57-
}
48+
// Place `current_pos` on the first relevant token
49+
let mut current_pos = 0;
50+
while is_irrelevant_token(tokens.get(current_pos).unwrap_or(&eof_token)) {
51+
current_pos += 1;
5852
}
5953

6054
Self {
61-
ranges: Vec::new(),
55+
stmt_ranges: Vec::new(),
6256
eof_token,
6357
errors: Vec::new(),
6458
current_stmt_start: None,
6559
tokens,
66-
next_pos,
60+
current_pos,
6761
}
6862
}
6963

70-
pub fn finish(self) -> Parse {
71-
Parse {
64+
pub fn finish(self) -> ParserResult {
65+
ParserResult {
7266
ranges: self
73-
.ranges
67+
.stmt_ranges
7468
.iter()
75-
.map(|(start, end)| {
76-
let from = self.tokens.get(*start);
77-
let to = self.tokens.get(*end).unwrap_or(&self.eof_token);
69+
.map(|(start_token_pos, end_token_pos)| {
70+
let from = self.tokens.get(*start_token_pos);
71+
let to = self.tokens.get(*end_token_pos).unwrap_or(&self.eof_token);
7872

7973
TextRange::new(from.unwrap().span.start(), to.span.end())
8074
})
@@ -83,124 +77,87 @@ impl Parser {
8377
}
8478
}
8579

86-
/// Start statement
8780
pub fn start_stmt(&mut self) {
8881
assert!(
8982
self.current_stmt_start.is_none(),
9083
"cannot start statement within statement at {:?}",
9184
self.tokens.get(self.current_stmt_start.unwrap())
9285
);
93-
self.current_stmt_start = Some(self.next_pos);
86+
self.current_stmt_start = Some(self.current_pos);
9487
}
9588

96-
/// Close statement
9789
pub fn close_stmt(&mut self) {
98-
assert!(self.next_pos > 0);
99-
100-
// go back the positions until we find the first relevant token
101-
let mut end_token_pos = self.next_pos - 1;
102-
loop {
103-
let token = self.tokens.get(end_token_pos);
90+
assert!(
91+
self.current_stmt_start.is_some(),
92+
"Must start statement before closing it."
93+
);
10494

105-
if end_token_pos == 0 || token.is_none() {
106-
break;
107-
}
95+
let start_token_pos = self.current_stmt_start.unwrap();
10896

109-
if !is_irrelevant_token(token.unwrap()) {
110-
break;
111-
}
97+
assert!(
98+
self.current_pos > start_token_pos,
99+
"Must close the statement on a token that's later than the start token."
100+
);
112101

113-
end_token_pos -= 1;
114-
}
102+
let (end_token_pos, _) = self.find_last_relevant().unwrap();
115103

116-
self.ranges.push((
117-
self.current_stmt_start.expect("Expected active statement"),
118-
end_token_pos,
119-
));
104+
self.stmt_ranges.push((start_token_pos, end_token_pos));
120105

121106
self.current_stmt_start = None;
122107
}
123108

124-
fn advance(&mut self) -> &Token {
125-
let mut first_relevant_token = None;
126-
loop {
127-
let token = self.tokens.get(self.next_pos).unwrap_or(&self.eof_token);
128-
129-
// we need to continue with next_pos until the next relevant token after we already
130-
// found the first one
131-
if !is_irrelevant_token(token) {
132-
if let Some(t) = first_relevant_token {
133-
return t;
134-
}
135-
first_relevant_token = Some(token);
136-
}
137-
138-
self.next_pos += 1;
139-
}
140-
}
141-
142-
fn peek(&self) -> &Token {
143-
match self.tokens.get(self.next_pos) {
109+
fn current(&self) -> &Token {
110+
match self.tokens.get(self.current_pos) {
144111
Some(token) => token,
145112
None => &self.eof_token,
146113
}
147114
}
148115

149-
/// Look ahead to the next relevant token
150-
/// Returns `None` if we are already at the last relevant token
151-
fn look_ahead(&self) -> Option<&Token> {
152-
// we need to look ahead to the next relevant token
153-
let mut look_ahead_pos = self.next_pos + 1;
154-
loop {
155-
let token = self.tokens.get(look_ahead_pos)?;
156-
157-
if !is_irrelevant_token(token) {
158-
return Some(token);
159-
}
116+
fn advance(&mut self) -> &Token {
117+
// can't reuse any `find_next_relevant` logic because of Mr. Borrow Checker
118+
let (pos, token) = self
119+
.tokens
120+
.iter()
121+
.enumerate()
122+
.skip(self.current_pos + 1)
123+
.find(|(_, t)| is_relevant(t))
124+
.unwrap_or((self.tokens.len(), &self.eof_token));
125+
126+
self.current_pos = pos;
127+
token
128+
}
160129

161-
look_ahead_pos += 1;
162-
}
130+
fn look_ahead(&self) -> Option<&Token> {
131+
self.tokens
132+
.iter()
133+
.skip(self.current_pos + 1)
134+
.find(|t| is_relevant(t))
163135
}
164136

165137
/// Returns `None` if there are no previous relevant tokens
166138
fn look_back(&self) -> Option<&Token> {
167-
// we need to look back to the last relevant token
168-
let mut look_back_pos = self.next_pos - 1;
169-
loop {
170-
let token = self.tokens.get(look_back_pos);
171-
172-
if look_back_pos == 0 || token.is_none() {
173-
return None;
174-
}
175-
176-
if !is_irrelevant_token(token.unwrap()) {
177-
return token;
178-
}
179-
180-
look_back_pos -= 1;
181-
}
139+
self.find_last_relevant().map(|it| it.1)
182140
}
183141

184-
/// checks if the current token is of `kind` and advances if true
185-
/// returns true if the current token is of `kind`
186-
pub fn eat(&mut self, kind: SyntaxKind) -> bool {
187-
if self.peek().kind == kind {
142+
/// Will advance if the `kind` matches the current token.
143+
/// Otherwise, will add a diagnostic to the internal `errors`.
144+
pub fn expect(&mut self, kind: SyntaxKind) {
145+
if self.current().kind == kind {
188146
self.advance();
189-
true
190147
} else {
191-
false
148+
self.errors.push(SplitDiagnostic::new(
149+
format!("Expected {:#?}", kind),
150+
self.current().span,
151+
));
192152
}
193153
}
194154

195-
pub fn expect(&mut self, kind: SyntaxKind) {
196-
if self.eat(kind) {
197-
return;
198-
}
199-
200-
self.errors.push(SplitDiagnostic::new(
201-
format!("Expected {:#?}", kind),
202-
self.peek().span,
203-
));
155+
fn find_last_relevant(&self) -> Option<(usize, &Token)> {
156+
self.tokens
157+
.iter()
158+
.enumerate()
159+
.take(self.current_pos)
160+
.rfind(|(_, t)| is_relevant(t))
204161
}
205162
}
206163

@@ -219,3 +176,57 @@ fn is_irrelevant_token(t: &Token) -> bool {
219176
WHITESPACE_TOKENS.contains(&t.kind)
220177
&& (t.kind != SyntaxKind::Newline || t.text.chars().count() == 1)
221178
}
179+
180+
fn is_relevant(t: &Token) -> bool {
181+
!is_irrelevant_token(t)
182+
}
183+
184+
#[cfg(test)]
185+
mod tests {
186+
use pgt_lexer::SyntaxKind;
187+
188+
use crate::parser::Parser;
189+
190+
#[test]
191+
fn advance_works_as_expected() {
192+
let sql = r#"
193+
create table users (
194+
id serial primary key,
195+
name text,
196+
email text
197+
);
198+
"#;
199+
let tokens = pgt_lexer::lex(sql).unwrap();
200+
let total_num_tokens = tokens.len();
201+
202+
let mut parser = Parser::new(tokens);
203+
204+
let expected = vec![
205+
(SyntaxKind::Create, 2),
206+
(SyntaxKind::Table, 4),
207+
(SyntaxKind::Ident, 6),
208+
(SyntaxKind::Ascii40, 8),
209+
(SyntaxKind::Ident, 11),
210+
(SyntaxKind::Ident, 13),
211+
(SyntaxKind::Primary, 15),
212+
(SyntaxKind::Key, 17),
213+
(SyntaxKind::Ascii44, 18),
214+
(SyntaxKind::NameP, 21),
215+
(SyntaxKind::TextP, 23),
216+
(SyntaxKind::Ascii44, 24),
217+
(SyntaxKind::Ident, 27),
218+
(SyntaxKind::TextP, 29),
219+
(SyntaxKind::Ascii41, 32),
220+
(SyntaxKind::Ascii59, 33),
221+
];
222+
223+
for (kind, pos) in expected {
224+
assert_eq!(parser.current().kind, kind);
225+
assert_eq!(parser.current_pos, pos);
226+
parser.advance();
227+
}
228+
229+
assert_eq!(parser.current().kind, SyntaxKind::Eof);
230+
assert_eq!(parser.current_pos, total_num_tokens);
231+
}
232+
}

crates/pgt_statement_splitter/src/parser/common.rs

+5-5
Original file line number · Diff line number · Diff line change
@@ -9,7 +9,7 @@ use super::{
99

1010
pub fn source(p: &mut Parser) {
1111
loop {
12-
match p.peek() {
12+
match p.current() {
1313
Token {
1414
kind: SyntaxKind::Eof,
1515
..
@@ -33,7 +33,7 @@ pub fn source(p: &mut Parser) {
3333

3434
pub(crate) fn statement(p: &mut Parser) {
3535
p.start_stmt();
36-
match p.peek().kind {
36+
match p.current().kind {
3737
SyntaxKind::With => {
3838
cte(p);
3939
}
@@ -68,7 +68,7 @@ pub(crate) fn parenthesis(p: &mut Parser) {
6868
let mut depth = 1;
6969

7070
loop {
71-
match p.peek().kind {
71+
match p.current().kind {
7272
SyntaxKind::Ascii40 => {
7373
p.advance();
7474
depth += 1;
@@ -91,7 +91,7 @@ pub(crate) fn case(p: &mut Parser) {
9191
p.expect(SyntaxKind::Case);
9292

9393
loop {
94-
match p.peek().kind {
94+
match p.current().kind {
9595
SyntaxKind::EndP => {
9696
p.advance();
9797
break;
@@ -105,7 +105,7 @@ pub(crate) fn case(p: &mut Parser) {
105105

106106
pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
107107
loop {
108-
match p.peek() {
108+
match p.current() {
109109
Token {
110110
kind: SyntaxKind::Ascii59,
111111
..

crates/pgt_statement_splitter/src/parser/dml.rs

+3-1
Original file line number · Diff line number · Diff line change
@@ -13,7 +13,9 @@ pub(crate) fn cte(p: &mut Parser) {
1313
p.expect(SyntaxKind::As);
1414
parenthesis(p);
1515

16-
if !p.eat(SyntaxKind::Ascii44) {
16+
if p.current().kind == SyntaxKind::Ascii44 {
17+
p.advance();
18+
} else {
1719
break;
1820
}
1921
}

0 commit comments

Comments (0)