refactor: simplify parser #330

Merged
merged 17 commits on Apr 12, 2025
6 changes: 3 additions & 3 deletions crates/pgt_statement_splitter/src/lib.rs
@@ -4,10 +4,10 @@
pub mod diagnostics;
mod parser;

use parser::{Parse, Parser, source};
use parser::{Parser, ParserResult, source};
use pgt_lexer::diagnostics::ScanError;

pub fn split(sql: &str) -> Result<Parse, Vec<ScanError>> {
pub fn split(sql: &str) -> Result<ParserResult, Vec<ScanError>> {
let tokens = pgt_lexer::lex(sql)?;

let mut parser = Parser::new(tokens);
@@ -28,7 +28,7 @@ mod tests {

struct Tester {
input: String,
parse: Parse,
parse: ParserResult,
}

impl From<&str> for Tester {
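
For reference, the renamed `ParserResult` is what `split` now hands back: statement ranges plus accumulated syntax errors. A minimal usage sketch — hedged, since only the signatures visible in this diff are known; the `Debug` impls on the error types and the `TextRange`-to-`usize` conversions (via the text-size crate) are assumptions:

```rust
use pgt_statement_splitter::split;

fn main() {
    let sql = "select 1; select 2;";

    // `split` fails early with lexer errors (ScanError); otherwise it
    // returns the statement ranges plus any splitter diagnostics.
    match split(sql) {
        Ok(result) => {
            for range in &result.ranges {
                // TextRange indexes back into the original source string.
                let start = usize::from(range.start());
                let end = usize::from(range.end());
                println!("statement: {}", &sql[start..end]);
            }
            for err in &result.errors {
                eprintln!("syntax error: {:?}", err);
            }
        }
        Err(scan_errors) => eprintln!("lexing failed: {:?}", scan_errors),
    }
}
```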
235 changes: 123 additions & 112 deletions crates/pgt_statement_splitter/src/parser.rs
@@ -13,24 +13,24 @@ use crate::diagnostics::SplitDiagnostic;
/// Main parser that exposes the `cstree` api, and collects errors and statements
/// It is modelled after a Pratt Parser. For a gentle introduction to Pratt Parsing, see https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html
pub struct Parser {
/// The ranges of the statements
ranges: Vec<(usize, usize)>,
/// The statement ranges are defined by the indices of the start/end tokens
stmt_ranges: Vec<(usize, usize)>,

/// The syntax errors accumulated during parsing
errors: Vec<SplitDiagnostic>,
/// The start of the current statement, if any

current_stmt_start: Option<usize>,
/// The tokens to parse
pub tokens: Vec<Token>,

tokens: Vec<Token>,

eof_token: Token,

next_pos: usize,
current_pos: usize,
}

/// Result of Building
#[derive(Debug)]
pub struct Parse {
/// The ranges of the errors
pub struct ParserResult {
/// The ranges of the parsed statements
pub ranges: Vec<TextRange>,
/// The syntax errors accumulated during parsing
pub errors: Vec<SplitDiagnostic>,
@@ -41,40 +41,34 @@ impl Parser {
let eof_token = Token::eof(usize::from(
tokens
.last()
.map(|t| t.span.start())
.map(|t| t.span.end())
.unwrap_or(TextSize::from(0)),
));

// next_pos should be the initialised with the first valid token already
let mut next_pos = 0;
loop {
let token = tokens.get(next_pos).unwrap_or(&eof_token);

if is_irrelevant_token(token) {
next_pos += 1;
} else {
break;
}
// Place `current_pos` on the first relevant token
let mut current_pos = 0;
while is_irrelevant_token(tokens.get(current_pos).unwrap_or(&eof_token)) {
current_pos += 1;
}

Self {
ranges: Vec::new(),
stmt_ranges: Vec::new(),
eof_token,
errors: Vec::new(),
current_stmt_start: None,
tokens,
next_pos,
current_pos,
}
}

pub fn finish(self) -> Parse {
Parse {
pub fn finish(self) -> ParserResult {
ParserResult {
ranges: self
.ranges
.stmt_ranges
.iter()
.map(|(start, end)| {
let from = self.tokens.get(*start);
let to = self.tokens.get(*end).unwrap_or(&self.eof_token);
.map(|(start_token_pos, end_token_pos)| {
let from = self.tokens.get(*start_token_pos);
let to = self.tokens.get(*end_token_pos).unwrap_or(&self.eof_token);

TextRange::new(from.unwrap().span.start(), to.span.end())
})
@@ -83,124 +77,87 @@ impl Parser {
}
}

/// Start statement
pub fn start_stmt(&mut self) {
assert!(
self.current_stmt_start.is_none(),
"cannot start statement within statement at {:?}",
self.tokens.get(self.current_stmt_start.unwrap())
);
self.current_stmt_start = Some(self.next_pos);
self.current_stmt_start = Some(self.current_pos);
}

/// Close statement
pub fn close_stmt(&mut self) {
assert!(self.next_pos > 0);

// go back the positions until we find the first relevant token
let mut end_token_pos = self.next_pos - 1;
loop {
let token = self.tokens.get(end_token_pos);
assert!(
self.current_stmt_start.is_some(),
"Must start statement before closing it."
);

if end_token_pos == 0 || token.is_none() {
break;
}
let start_token_pos = self.current_stmt_start.unwrap();

if !is_irrelevant_token(token.unwrap()) {
break;
}
assert!(
self.current_pos > start_token_pos,
"Must close the statement on a token that's later than the start token."
);

end_token_pos -= 1;
}
let (end_token_pos, _) = self.find_last_relevant().unwrap();

self.ranges.push((
self.current_stmt_start.expect("Expected active statement"),
end_token_pos,
));
self.stmt_ranges.push((start_token_pos, end_token_pos));

self.current_stmt_start = None;
}

fn advance(&mut self) -> &Token {
let mut first_relevant_token = None;
loop {
let token = self.tokens.get(self.next_pos).unwrap_or(&self.eof_token);

// we need to continue with next_pos until the next relevant token after we already
// found the first one
if !is_irrelevant_token(token) {
if let Some(t) = first_relevant_token {
return t;
}
first_relevant_token = Some(token);
}

self.next_pos += 1;
}
}

fn peek(&self) -> &Token {
match self.tokens.get(self.next_pos) {
fn current(&self) -> &Token {
match self.tokens.get(self.current_pos) {
Some(token) => token,
None => &self.eof_token,
}
}

/// Look ahead to the next relevant token
/// Returns `None` if we are already at the last relevant token
fn look_ahead(&self) -> Option<&Token> {
// we need to look ahead to the next relevant token
let mut look_ahead_pos = self.next_pos + 1;
loop {
let token = self.tokens.get(look_ahead_pos)?;

if !is_irrelevant_token(token) {
return Some(token);
}
fn advance(&mut self) -> &Token {
// can't reuse any `find_next_relevant` logic because of Mr. Borrow Checker
let (pos, token) = self
.tokens
.iter()
.enumerate()
.skip(self.current_pos + 1)
.find(|(_, t)| is_relevant(t))
.unwrap_or((self.tokens.len(), &self.eof_token));

self.current_pos = pos;
token
}

look_ahead_pos += 1;
}
fn look_ahead(&self) -> Option<&Token> {
self.tokens
.iter()
.skip(self.current_pos + 1)
.find(|t| is_relevant(t))
}

/// Returns `None` if there are no previous relevant tokens
fn look_back(&self) -> Option<&Token> {
// we need to look back to the last relevant token
let mut look_back_pos = self.next_pos - 1;
loop {
let token = self.tokens.get(look_back_pos);

if look_back_pos == 0 || token.is_none() {
return None;
}

if !is_irrelevant_token(token.unwrap()) {
return token;
}

look_back_pos -= 1;
}
self.find_last_relevant().map(|it| it.1)
}

/// checks if the current token is of `kind` and advances if true
/// returns true if the current token is of `kind`
pub fn eat(&mut self, kind: SyntaxKind) -> bool {
if self.peek().kind == kind {
/// Will advance if the `kind` matches the current token.
/// Otherwise, will add a diagnostic to the internal `errors`.
pub fn expect(&mut self, kind: SyntaxKind) {
if self.current().kind == kind {
self.advance();
true
} else {
false
self.errors.push(SplitDiagnostic::new(
format!("Expected {:#?}", kind),
self.current().span,
));
}
}

pub fn expect(&mut self, kind: SyntaxKind) {
if self.eat(kind) {
return;
}

self.errors.push(SplitDiagnostic::new(
format!("Expected {:#?}", kind),
self.peek().span,
));
fn find_last_relevant(&self) -> Option<(usize, &Token)> {
self.tokens
.iter()
.enumerate()
.take(self.current_pos)
.rfind(|(_, t)| is_relevant(t))
}
}

@@ -219,3 +176,57 @@ fn is_irrelevant_token(t: &Token) -> bool {
WHITESPACE_TOKENS.contains(&t.kind)
&& (t.kind != SyntaxKind::Newline || t.text.chars().count() == 1)
}

fn is_relevant(t: &Token) -> bool {
!is_irrelevant_token(t)
}

#[cfg(test)]
mod tests {
use pgt_lexer::SyntaxKind;

use crate::parser::Parser;

#[test]
fn advance_works_as_expected() {
let sql = r#"
create table users (
id serial primary key,
name text,
email text
);
"#;
let tokens = pgt_lexer::lex(sql).unwrap();
let total_num_tokens = tokens.len();

let mut parser = Parser::new(tokens);

let expected = vec![
(SyntaxKind::Create, 2),
(SyntaxKind::Table, 4),
(SyntaxKind::Ident, 6),
(SyntaxKind::Ascii40, 8),
(SyntaxKind::Ident, 11),
(SyntaxKind::Ident, 13),
(SyntaxKind::Primary, 15),
(SyntaxKind::Key, 17),
(SyntaxKind::Ascii44, 18),
(SyntaxKind::NameP, 21),
(SyntaxKind::TextP, 23),
(SyntaxKind::Ascii44, 24),
(SyntaxKind::Ident, 27),
(SyntaxKind::TextP, 29),
(SyntaxKind::Ascii41, 32),
(SyntaxKind::Ascii59, 33),
];

for (kind, pos) in expected {
assert_eq!(parser.current().kind, kind);
assert_eq!(parser.current_pos, pos);
parser.advance();
}

assert_eq!(parser.current().kind, SyntaxKind::Eof);
assert_eq!(parser.current_pos, total_num_tokens);
}
}
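
The core of this refactor is the cursor model: `current_pos` always rests on a relevant token, `advance` scans forward with iterator combinators, and `look_back`/`find_last_relevant` use `take(current_pos)` so the current token itself is excluded. A standalone sketch of that scheme — `Tok` and `Cursor` are hypothetical stand-ins, not the crate's types:

```rust
// Minimal model of the cursor refactor: `current_pos` always sits on a
// relevant token, and navigation skips irrelevant tokens in both directions.
#[derive(Debug, PartialEq)]
enum Tok {
    Word(&'static str),
    Whitespace,
}

fn is_relevant(t: &Tok) -> bool {
    !matches!(t, Tok::Whitespace)
}

struct Cursor {
    tokens: Vec<Tok>,
    current_pos: usize,
}

impl Cursor {
    fn advance(&mut self) {
        // Jump to the next relevant token, or one past the end (the EOF slot).
        self.current_pos = self
            .tokens
            .iter()
            .enumerate()
            .skip(self.current_pos + 1)
            .find(|(_, t)| is_relevant(t))
            .map(|(pos, _)| pos)
            .unwrap_or(self.tokens.len());
    }

    fn look_back(&self) -> Option<&Tok> {
        // `take(current_pos)` stops *before* the current token, so this
        // finds the closest relevant token strictly behind the cursor.
        self.tokens
            .iter()
            .take(self.current_pos)
            .rfind(|t| is_relevant(t))
    }
}

fn main() {
    let mut c = Cursor {
        tokens: vec![Tok::Word("select"), Tok::Whitespace, Tok::Word("1")],
        current_pos: 0,
    };
    c.advance();
    assert_eq!(c.tokens[c.current_pos], Tok::Word("1"));
    assert_eq!(c.look_back(), Some(&Tok::Word("select")));
}
```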
10 changes: 5 additions & 5 deletions crates/pgt_statement_splitter/src/parser/common.rs
@@ -9,7 +9,7 @@ use super::{

pub fn source(p: &mut Parser) {
loop {
match p.peek() {
match p.current() {
Token {
kind: SyntaxKind::Eof,
..
@@ -33,7 +33,7 @@

pub(crate) fn statement(p: &mut Parser) {
p.start_stmt();
match p.peek().kind {
match p.current().kind {
SyntaxKind::With => {
cte(p);
}
@@ -68,7 +68,7 @@ pub(crate) fn parenthesis(p: &mut Parser) {
let mut depth = 1;

loop {
match p.peek().kind {
match p.current().kind {
SyntaxKind::Ascii40 => {
p.advance();
depth += 1;
@@ -91,7 +91,7 @@ pub(crate) fn case(p: &mut Parser) {
p.expect(SyntaxKind::Case);

loop {
match p.peek().kind {
match p.current().kind {
SyntaxKind::EndP => {
p.advance();
break;
@@ -105,7 +105,7 @@

pub(crate) fn unknown(p: &mut Parser, exclude: &[SyntaxKind]) {
loop {
match p.peek() {
match p.current() {
Token {
kind: SyntaxKind::Ascii59,
..
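
`parenthesis` in common.rs is a plain depth counter over `Ascii40`/`Ascii41` (the `(`/`)` tokens): increment on open, decrement on close, stop when the depth returns to zero. A self-contained sketch of the same idea over a char stream — a hypothetical helper, not the crate's code:

```rust
// Depth-tracked skip over a parenthesized span; returns the index just
// past the matching `)`, or None if the input is unbalanced.
fn skip_parenthesized(input: &str) -> Option<usize> {
    let mut depth = 0usize;
    for (i, ch) in input.char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => {
                depth = depth.checked_sub(1)?; // stray `)` -> None
                if depth == 0 {
                    return Some(i + ch.len_utf8());
                }
            }
            _ => {}
        }
    }
    None // input ended before the parenthesis closed
}

fn main() {
    let sql = "(select (1), 2) rest";
    assert_eq!(skip_parenthesized(sql), Some(15));
    assert_eq!(&sql[..15], "(select (1), 2)");
}
```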
4 changes: 3 additions & 1 deletion crates/pgt_statement_splitter/src/parser/dml.rs
@@ -13,7 +13,7 @@ pub(crate) fn cte(p: &mut Parser) {
p.expect(SyntaxKind::As);
parenthesis(p);

if !p.eat(SyntaxKind::Ascii44) {
if p.current().kind == SyntaxKind::Ascii44 {
p.advance();
} else {
break;
}
}
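
The dml.rs change is mechanical: with `eat` removed from the parser, the comma check in the CTE loop is spelled out as a `current().kind` comparison plus an explicit `advance`. The underlying shape — parse an item, continue only while a separator was consumed — is the usual comma-separated-list loop. A generic sketch with hypothetical names, not the crate's API:

```rust
use std::cell::Cell;

// Generic comma-separated-list loop mirroring the shape of `cte`:
// parse one item, then continue only if a separator was consumed.
fn parse_comma_separated<T>(
    mut parse_item: impl FnMut() -> T,
    mut eat_comma: impl FnMut() -> bool, // consumes a comma if present
) -> Vec<T> {
    let mut items = Vec::new();
    loop {
        items.push(parse_item());
        if !eat_comma() {
            break;
        }
    }
    items
}

fn main() {
    // Toy driver over a char buffer; `Cell` lets both closures share the
    // cursor position without fighting the borrow checker.
    let src: Vec<char> = "a,b,c".chars().collect();
    let pos = Cell::new(0usize);

    let items = parse_comma_separated(
        || {
            let c = src[pos.get()];
            pos.set(pos.get() + 1);
            c
        },
        || {
            if src.get(pos.get()) == Some(&',') {
                pos.set(pos.get() + 1);
                true
            } else {
                false
            }
        },
    );

    assert_eq!(items, vec!['a', 'b', 'c']);
}
```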