diff --git a/crates/pgt_statement_splitter/src/lib.rs b/crates/pgt_statement_splitter/src/lib.rs index 68f5daaf..63e68cd2 100644 --- a/crates/pgt_statement_splitter/src/lib.rs +++ b/crates/pgt_statement_splitter/src/lib.rs @@ -4,10 +4,10 @@ pub mod diagnostics; mod parser; -use parser::{Parse, Parser, source}; +use parser::{Parser, ParserResult, source}; use pgt_lexer::diagnostics::ScanError; -pub fn split(sql: &str) -> Result> { +pub fn split(sql: &str) -> Result> { let tokens = pgt_lexer::lex(sql)?; let mut parser = Parser::new(tokens); @@ -28,7 +28,7 @@ mod tests { struct Tester { input: String, - parse: Parse, + parse: ParserResult, } impl From<&str> for Tester { diff --git a/crates/pgt_statement_splitter/src/parser.rs b/crates/pgt_statement_splitter/src/parser.rs index 05de8cb0..c94fe245 100644 --- a/crates/pgt_statement_splitter/src/parser.rs +++ b/crates/pgt_statement_splitter/src/parser.rs @@ -13,24 +13,24 @@ use crate::diagnostics::SplitDiagnostic; /// Main parser that exposes the `cstree` api, and collects errors and statements /// It is modelled after a Pratt Parser. For a gentle introduction to Pratt Parsing, see https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html pub struct Parser { - /// The ranges of the statements - ranges: Vec<(usize, usize)>, + /// The statement ranges are defined by the indices of the start/end tokens + stmt_ranges: Vec<(usize, usize)>, + /// The syntax errors accumulated during parsing errors: Vec, - /// The start of the current statement, if any + current_stmt_start: Option, - /// The tokens to parse - pub tokens: Vec, + + tokens: Vec, eof_token: Token, - next_pos: usize, + current_pos: usize, } -/// Result of Building #[derive(Debug)] -pub struct Parse { - /// The ranges of the errors +pub struct ParserResult { + /// The ranges of the parsed statements pub ranges: Vec, /// The syntax errors accumulated during parsing pub errors: Vec, @@ -41,40 +41,34 @@ impl Parser { let eof_token = Token::eof(usize::from( tokens .last() - .map(|t| t.span.start()) + .map(|t| t.span.end()) .unwrap_or(TextSize::from(0)), )); - // next_pos should be the initialised with the first valid token already - let mut next_pos = 0; - loop { - let token = tokens.get(next_pos).unwrap_or(&eof_token); - - if is_irrelevant_token(token) { - next_pos += 1; - } else { - break; - } + // Place `current_pos` on the first relevant token + let mut current_pos = 0; + while is_irrelevant_token(tokens.get(current_pos).unwrap_or(&eof_token)) { + current_pos += 1; } Self { - ranges: Vec::new(), + stmt_ranges: Vec::new(), eof_token, errors: Vec::new(), current_stmt_start: None, tokens, - next_pos, + current_pos, } } - pub fn finish(self) -> Parse { - Parse { + pub fn finish(self) -> ParserResult { + ParserResult { ranges: self - .ranges + .stmt_ranges .iter() - .map(|(start, end)| { - let from = self.tokens.get(*start); - let to = self.tokens.get(*end).unwrap_or(&self.eof_token); + .map(|(start_token_pos, end_token_pos)| { + let from = self.tokens.get(*start_token_pos); + let to = self.tokens.get(*end_token_pos).unwrap_or(&self.eof_token); TextRange::new(from.unwrap().span.start(), to.span.end()) }) @@ -83,124 +77,87 @@ impl Parser { } } - /// Start statement pub fn start_stmt(&mut self) { assert!( self.current_stmt_start.is_none(), "cannot start statement within statement at {:?}", self.tokens.get(self.current_stmt_start.unwrap()) ); - self.current_stmt_start = Some(self.next_pos); + self.current_stmt_start = Some(self.current_pos); } - /// Close statement pub fn close_stmt(&mut self) { - assert!(self.next_pos > 0); - - // go back the positions until we find the first relevant token - let mut end_token_pos = self.next_pos - 1; - loop { - let token = self.tokens.get(end_token_pos); + assert!( + self.current_stmt_start.is_some(), + "Must start statement before closing it." + ); - if end_token_pos == 0 || token.is_none() { - break; - } + let start_token_pos = self.current_stmt_start.unwrap(); - if !is_irrelevant_token(token.unwrap()) { - break; - } + assert!( + self.current_pos > start_token_pos, + "Must close the statement on a token that's later than the start token." + ); - end_token_pos -= 1; - } + let (end_token_pos, _) = self.find_last_relevant().unwrap(); - self.ranges.push(( - self.current_stmt_start.expect("Expected active statement"), - end_token_pos, - )); + self.stmt_ranges.push((start_token_pos, end_token_pos)); self.current_stmt_start = None; } - fn advance(&mut self) -> &Token { - let mut first_relevant_token = None; - loop { - let token = self.tokens.get(self.next_pos).unwrap_or(&self.eof_token); - - // we need to continue with next_pos until the next relevant token after we already - // found the first one - if !is_irrelevant_token(token) { - if let Some(t) = first_relevant_token { - return t; - } - first_relevant_token = Some(token); - } - - self.next_pos += 1; - } - } - - fn peek(&self) -> &Token { - match self.tokens.get(self.next_pos) { + fn current(&self) -> &Token { + match self.tokens.get(self.current_pos) { Some(token) => token, None => &self.eof_token, } } - /// Look ahead to the next relevant token - /// Returns `None` if we are already at the last relevant token - fn look_ahead(&self) -> Option<&Token> { - // we need to look ahead to the next relevant token - let mut look_ahead_pos = self.next_pos + 1; - loop { - let token = self.tokens.get(look_ahead_pos)?; - - if !is_irrelevant_token(token) { - return Some(token); - } + fn advance(&mut self) -> &Token { + // can't reuse any `find_next_relevant` logic because of Mr. Borrow Checker + let (pos, token) = self + .tokens + .iter() + .enumerate() + .skip(self.current_pos + 1) + .find(|(_, t)| is_relevant(t)) + .unwrap_or((self.tokens.len(), &self.eof_token)); + + self.current_pos = pos; + token + } - look_ahead_pos += 1; - } + fn look_ahead(&self) -> Option<&Token> { + self.tokens + .iter() + .skip(self.current_pos + 1) + .find(|t| is_relevant(t)) } /// Returns `None` if there are no previous relevant tokens fn look_back(&self) -> Option<&Token> { - // we need to look back to the last relevant token - let mut look_back_pos = self.next_pos - 1; - loop { - let token = self.tokens.get(look_back_pos); - - if look_back_pos == 0 || token.is_none() { - return None; - } - - if !is_irrelevant_token(token.unwrap()) { - return token; - } - - look_back_pos -= 1; - } + self.find_last_relevant().map(|it| it.1) } - /// checks if the current token is of `kind` and advances if true - /// returns true if the current token is of `kind` - pub fn eat(&mut self, kind: SyntaxKind) -> bool { - if self.peek().kind == kind { + /// Will advance if the `kind` matches the current token. + /// Otherwise, will add a diagnostic to the internal `errors`. + pub fn expect(&mut self, kind: SyntaxKind) { + if self.current().kind == kind { self.advance(); - true } else { - false + self.errors.push(SplitDiagnostic::new( + format!("Expected {:#?}", kind), + self.current().span, + )); } } - pub fn expect(&mut self, kind: SyntaxKind) { - if self.eat(kind) { - return; - } - - self.errors.push(SplitDiagnostic::new( - format!("Expected {:#?}", kind), - self.peek().span, - )); + fn find_last_relevant(&self) -> Option<(usize, &Token)> { + self.tokens + .iter() + .enumerate() + .take(self.current_pos) + .rfind(|(_, t)| is_relevant(t)) } } @@ -219,3 +176,57 @@ fn is_irrelevant_token(t: &Token) -> bool { WHITESPACE_TOKENS.contains(&t.kind) && (t.kind != SyntaxKind::Newline || t.text.chars().count() == 1) } + +fn is_relevant(t: &Token) -> bool { + !is_irrelevant_token(t) +} + +#[cfg(test)] +mod tests { + use pgt_lexer::SyntaxKind; + + use crate::parser::Parser; + + #[test] + fn advance_works_as_expected() { + let sql = r#" + create table users ( + id serial primary key, + name text, + email text + ); + "#; + let tokens = pgt_lexer::lex(sql).unwrap(); + let total_num_tokens = tokens.len(); + + let mut parser = Parser::new(tokens); + + let expected = vec![ + (SyntaxKind::Create, 2), + (SyntaxKind::Table, 4), + (SyntaxKind::Ident, 6), + (SyntaxKind::Ascii40, 8), + (SyntaxKind::Ident, 11), + (SyntaxKind::Ident, 13), + (SyntaxKind::Primary, 15), + (SyntaxKind::Key, 17), + (SyntaxKind::Ascii44, 18), + (SyntaxKind::NameP, 21), + (SyntaxKind::TextP, 23), + (SyntaxKind::Ascii44, 24), + (SyntaxKind::Ident, 27), + (SyntaxKind::TextP, 29), + (SyntaxKind::Ascii41, 32), + (SyntaxKind::Ascii59, 33), + ]; + + for (kind, pos) in expected { + assert_eq!(parser.current().kind, kind); + assert_eq!(parser.current_pos, pos); + parser.advance(); + } + + assert_eq!(parser.current().kind, SyntaxKind::Eof); + assert_eq!(parser.current_pos, total_num_tokens); + } +} diff --git a/crates/pgt_statement_splitter/src/parser/common.rs b/crates/pgt_statement_splitter/src/parser/common.rs index d145018d..1a355f08 100644 --- a/crates/pgt_statement_splitter/src/parser/common.rs +++ b/crates/pgt_statement_splitter/src/parser/common.rs @@ -9,7 +9,7 @@ use super::{ pub fn source(p: &mut Parser) { loop { - match p.peek() { + match p.current() { Token { kind: SyntaxKind::Eof, .. @@ -33,7 +33,7 @@ pub fn source(p: &mut Parser) { pub(crate) fn statement(p: &mut Parser) { p.start_stmt(); - match p.peek().kind { + match p.current().kind { SyntaxKind::With => { cte(p); } @@ -68,7 +68,7 @@ pub(crate) fn parenthesis(p: &mut Parser) { let mut depth = 1; loop { - match p.peek().kind { + match p.current().kind { SyntaxKind::Ascii40 => { p.advance(); depth += 1; @@ -91,7 +91,7 @@ pub(crate) fn case(p: &mut Parser) { p.expect(SyntaxKind::Case); loop { - match p.peek().kind { + match p.current().kind { SyntaxKind::EndP => { p.advance(); break; @@ -105,7 +105,7 @@ pub(crate) fn case(p: &mut Parser) { pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) { loop { - match p.peek() { + match p.current() { Token { kind: SyntaxKind::Ascii59, .. diff --git a/crates/pgt_statement_splitter/src/parser/dml.rs b/crates/pgt_statement_splitter/src/parser/dml.rs index a45f6c40..015c50b6 100644 --- a/crates/pgt_statement_splitter/src/parser/dml.rs +++ b/crates/pgt_statement_splitter/src/parser/dml.rs @@ -13,7 +13,9 @@ pub(crate) fn cte(p: &mut Parser) { p.expect(SyntaxKind::As); parenthesis(p); - if !p.eat(SyntaxKind::Ascii44) { + if p.current().kind == SyntaxKind::Ascii44 { + p.advance(); + } else { break; } }