diff --git a/Cargo.lock b/Cargo.lock index 0044279e..5aa27e7a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,6 +36,18 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" + [[package]] name = "anyhow" version = "1.0.81" @@ -309,6 +321,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.0.83" @@ -333,6 +351,33 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clang-sys" version = "1.7.0" @@ -344,6 +389,31 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "4.5.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" + [[package]] name = "cmake" version = "0.1.50" @@ -422,6 +492,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.12" @@ -431,6 +537,25 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-queue" version = "0.3.11" @@ -446,6 +571,12 @@ version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -834,6 +965,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -874,6 +1015,12 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + [[package]] name = "hex" version = "0.4.3" @@ -965,11 +1112,22 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.5", "libc", "windows-sys 0.48.0", ] +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi 0.4.0", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "itertools" version = "0.10.5" @@ -1246,7 +1404,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi", + "hermit-abi 0.3.5", "libc", ] @@ -1256,6 +1414,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "parking" version = "2.2.0" @@ -1507,6 +1671,8 @@ dependencies = [ name = "pg_statement_splitter" version = "0.0.0" dependencies = [ + "criterion", + "insta", "pg_lexer", "pg_query", "text-size", @@ -1616,6 +1782,34 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polling" version = "2.8.0" @@ -1807,6 +2001,26 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -1948,6 +2162,15 @@ version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2437,6 +2660,16 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -2598,6 +2831,16 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -2720,6 +2963,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/crates/pg_base_db/src/change.rs b/crates/pg_base_db/src/change.rs index 26a926ff..4c7b7632 100644 --- a/crates/pg_base_db/src/change.rs +++ b/crates/pg_base_db/src/change.rs @@ -126,7 +126,6 @@ impl Change { ); // TODO also use errors returned by extract sql statement ranges doc.statement_ranges = pg_statement_splitter::split(&self.text) - .ranges .iter() .map(|r| r.clone()) .collect(); @@ -248,7 +247,7 @@ impl Change { + 1, ); - for range in pg_statement_splitter::split(extracted_text).ranges { + for range in pg_statement_splitter::split(extracted_text) { match doc .statement_ranges .binary_search_by(|r| r.start().cmp(&range.start())) diff --git a/crates/pg_base_db/src/document.rs b/crates/pg_base_db/src/document.rs index a9838833..0e7297b2 100644 --- a/crates/pg_base_db/src/document.rs +++ b/crates/pg_base_db/src/document.rs @@ -50,7 +50,6 @@ impl Document { || Vec::new(), |f| { pg_statement_splitter::split(&f) - .ranges .iter() .map(|range| range.clone()) .collect() diff --git a/crates/pg_statement_splitter/Cargo.toml b/crates/pg_statement_splitter/Cargo.toml index 15a30680..ce70d1a1 100644 --- a/crates/pg_statement_splitter/Cargo.toml +++ b/crates/pg_statement_splitter/Cargo.toml @@ -9,4 +9,10 @@ text-size = "1.1.1" [dev-dependencies] pg_query = "0.8" +insta = "1.31.0" +criterion = { version = "0.5" } + +[[bench]] +name = "pg_statement_splitter" +harness = false diff --git a/crates/pg_statement_splitter/benches/pg_statement_splitter.rs b/crates/pg_statement_splitter/benches/pg_statement_splitter.rs new file mode 100644 index 00000000..153adfb2 --- /dev/null +++ b/crates/pg_statement_splitter/benches/pg_statement_splitter.rs @@ -0,0 +1,67 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use std::fs::{self}; + +const POSTGRES_REGRESS_PATH: &str = "../../libpg_query/test/sql/postgres_regress/"; +const SKIPPED_REGRESS_TESTS: &str = include_str!("../tests/skipped.txt"); + +fn from_elem(c: &mut Criterion) { + let mut paths: Vec<_> = fs::read_dir(POSTGRES_REGRESS_PATH) + .unwrap() + .map(|r| r.unwrap()) + .collect(); + paths.sort_by_key(|dir| dir.path()); + + for f in paths.iter() { + let path = f.path(); + + let test_name = path.file_stem().unwrap().to_str().unwrap(); + + // these require fixes in the parser + if SKIPPED_REGRESS_TESTS + .lines() + .collect::>() + .contains(&test_name) + { + continue; + } + + println!("Running test: {}", test_name); + + // remove \commands because pg_query doesn't support them + let contents = fs::read_to_string(&path) + .unwrap() + .lines() + .filter_map(|l| { + if !l.starts_with("\\") + && !l.ends_with("\\gset") + && !l.starts_with("--") + && !l.contains(":'") + && l.split("\t").count() <= 1 + && l != "ALTER INDEX attmp_idx ALTER COLUMN 0 SET STATISTICS 1000;" + { + if let Some(index) = l.find("--") { + Some(l[..index].to_string()) + } else { + Some(l.to_string()) + } + } else { + None + } + }) + .collect::>() + .join("\n"); + + let contents_str = contents.as_str(); + + c.bench_with_input( + BenchmarkId::new(test_name, contents_str), + &contents_str, + |b, &s| { + b.iter(|| pg_statement_splitter::split(&s)); + }, + ); + } +} + +criterion_group!(benches, from_elem); +criterion_main!(benches); diff --git a/crates/pg_statement_splitter/src/data.rs b/crates/pg_statement_splitter/src/data.rs new file mode 100644 index 00000000..4b32841f --- /dev/null +++ b/crates/pg_statement_splitter/src/data.rs @@ -0,0 +1,2023 @@ +use pg_lexer::SyntaxKind; +use std::{collections::HashMap, sync::LazyLock}; + +#[derive(Debug)] +pub enum SyntaxDefinition { + RequiredToken(SyntaxKind), // A single required token + OptionalToken(SyntaxKind), // A single optional token + OptionalGroup(Vec), // A group of tokens that are required if the group is present + AnyToken, // Any single token + AnyTokens(Option>), // A sequence of 0 or more tokens, of which any can be present + OneOf(Vec), // One of the specified tokens + OptionalRepeatedGroup(Vec), // A group of tokens that can be repeated +} + +impl SyntaxDefinition { + pub fn is_group(&self) -> bool { + match self { + SyntaxDefinition::OptionalGroup(_) => true, + SyntaxDefinition::OptionalRepeatedGroup(_) => true, + _ => false, + } + } + + pub fn first_required_tokens(&self) -> Vec<&SyntaxKind> { + match self { + SyntaxDefinition::RequiredToken(k) => vec![k], + SyntaxDefinition::OneOf(kinds) => kinds.iter().collect(), + _ => vec![], + } + } +} + +#[derive(Debug)] +pub struct SyntaxBuilder { + parts: Vec, + is_complete: bool, +} + +impl SyntaxBuilder { + // Start a new builder, which will automatically create a Group + pub fn new() -> Self { + Self { + parts: Vec::new(), + is_complete: false, + } + } + + pub fn new_complete() -> Self { + Self { + parts: Vec::new(), + is_complete: true, + } + } + + pub fn any_token(mut self) -> Self { + self.parts.push(SyntaxDefinition::AnyToken); + self + } + + /// The name of an object is almost always an `Ident` token, but due to naming conflicts it can + /// also be a set of other tokens. This function adds those tokens to the list of possible + /// tokens. + pub fn ident_like(mut self) -> Self { + self.parts.push(SyntaxDefinition::OneOf(vec![ + SyntaxKind::Ident, + SyntaxKind::VersionP, + SyntaxKind::Cursor, + SyntaxKind::Simple, + SyntaxKind::Set, + SyntaxKind::Leakproof, + ])); + self + } + + pub fn any_tokens(mut self, tokens: Option>) -> Self { + self.parts.push(SyntaxDefinition::AnyTokens(tokens)); + self + } + + pub fn required_token(mut self, token: SyntaxKind) -> Self { + self.parts.push(SyntaxDefinition::RequiredToken(token)); + self + } + + pub fn optional_token(mut self, token: SyntaxKind) -> Self { + self.parts.push(SyntaxDefinition::OptionalToken(token)); + self + } + + pub fn optional_schema_name_group(self) -> Self { + self.optional_group(vec![SyntaxKind::Ident, SyntaxKind::Ascii46]) + } + + pub fn optional_if_exists_group(self) -> Self { + self.optional_group(vec![SyntaxKind::IfP, SyntaxKind::Exists]) + } + + pub fn optional_if_not_exists_group(self) -> Self { + self.optional_group(vec![SyntaxKind::IfP, SyntaxKind::Not, SyntaxKind::Exists]) + } + + pub fn optional_or_replace_group(self) -> Self { + self.optional_group(vec![SyntaxKind::Or, SyntaxKind::Replace]) + } + + pub fn one_of(mut self, tokens: Vec) -> Self { + self.parts.push(SyntaxDefinition::OneOf(tokens)); + self + } + + pub fn optional_group(mut self, tokens: Vec) -> Self { + self.parts.push(SyntaxDefinition::OptionalGroup(tokens)); + self + } + + pub fn optional_repeated_group(mut self, builder: SyntaxBuilder) -> Self { + let res = builder.build(); + match res.first() { + Some(SyntaxDefinition::RequiredToken(_)) => {} + Some(SyntaxDefinition::OneOf(_)) => {} + _ => panic!("First token in repeated group must be required or one of"), + } + self.parts + .push(SyntaxDefinition::OptionalRepeatedGroup(res)); + self + } + + pub fn cte(mut self) -> Self { + self.parts.extend( + SyntaxBuilder::new() + .required_token(SyntaxKind::With) + .optional_token(SyntaxKind::Recursive) + .ident_like() + .required_token(SyntaxKind::As) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41) + .optional_repeated_group( + SyntaxBuilder::new() + .required_token(SyntaxKind::Ascii44) + .ident_like() + .required_token(SyntaxKind::As) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + ) + .build(), + ); + self + } + + pub fn build(mut self) -> Vec { + if !self.is_complete { + self.parts.push(SyntaxDefinition::AnyTokens(None)); + } else { + self.parts + .push(SyntaxDefinition::OptionalToken(SyntaxKind::Ascii59)); + } + self.parts + } +} + +#[derive(Debug)] +pub struct StatementDefinition { + pub stmt: SyntaxKind, + pub tokens: Vec, + pub prohibited_following_statements: Vec, + pub prohibited_tokens: Vec, + pub ignore_if_prohibited: bool, +} + +impl StatementDefinition { + fn new(stmt: SyntaxKind, b: SyntaxBuilder) -> Self { + Self { + stmt, + tokens: b.build(), + prohibited_following_statements: Vec::new(), + prohibited_tokens: Vec::new(), + ignore_if_prohibited: false, + } + } + + fn with_prohibited_tokens(mut self, prohibited: Vec) -> Self { + self.prohibited_tokens = prohibited; + self + } + + fn with_prohibited_following_statements(mut self, prohibited: Vec) -> Self { + self.prohibited_following_statements = prohibited; + self + } + + fn with_ignore_if_prohibited(mut self) -> Self { + self.ignore_if_prohibited = true; + self + } +} + +pub static STATEMENT_BRIDGE_DEFINITIONS: LazyLock>> = + LazyLock::new(|| { + let mut m: Vec = Vec::new(); + + m.push(StatementDefinition::new( + SyntaxKind::SelectStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Intersect) + .optional_token(SyntaxKind::All), + )); + + m.push(StatementDefinition::new( + SyntaxKind::SelectStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Union) + .optional_token(SyntaxKind::All), + )); + + m.push(StatementDefinition::new( + SyntaxKind::SelectStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Except) + .optional_token(SyntaxKind::All), + )); + + let mut stmt_starts: HashMap> = HashMap::new(); + + for stmt in m { + let first_token = stmt + .tokens + .first() + .expect("Expected first token to be present"); + + if let SyntaxDefinition::RequiredToken(token) = first_token { + stmt_starts.entry(*token).or_insert(Vec::new()).push(stmt); + } else { + panic!("Expected first token to be a required token"); + } + } + + stmt_starts + }); + +pub static STATEMENT_DEFINITIONS: LazyLock>> = + LazyLock::new(|| { + let mut m: Vec = Vec::new(); + + m.push(StatementDefinition::new( + SyntaxKind::CreateTrigStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_token(SyntaxKind::Or) + .optional_token(SyntaxKind::Replace) + .optional_token(SyntaxKind::Constraint) + .required_token(SyntaxKind::Trigger) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::On) + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::Execute) + .one_of(vec![SyntaxKind::Function, SyntaxKind::Procedure]) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::SelectStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Select) + .any_token(), + )); + + m.push(StatementDefinition::new( + SyntaxKind::SelectStmt, + SyntaxBuilder::new() + .cte() + .required_token(SyntaxKind::Select) + .any_token(), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::SelectStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Ascii40) + .required_token(SyntaxKind::Select) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41) + .any_tokens(Some(vec![ + SyntaxKind::Union, + SyntaxKind::Except, + SyntaxKind::Intersect, + SyntaxKind::All, + ])) + .required_token(SyntaxKind::Ascii40) + .required_token(SyntaxKind::Select) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41) + .optional_repeated_group( + SyntaxBuilder::new() + .one_of(vec![ + SyntaxKind::Union, + SyntaxKind::Except, + SyntaxKind::Intersect, + ]) + .optional_token(SyntaxKind::All) + .required_token(SyntaxKind::Ascii40) + .required_token(SyntaxKind::Select) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + ), + ) + .with_ignore_if_prohibited(), + ); + + // // "TABLE t1;" + // // is syntactic sugar for "SELECT * FROM t1" + // m.push( + // StatementDefinition::new( + // SyntaxKind::SelectStmt, + // // we use "new_complete" here + // SyntaxBuilder::new_complete() + // .required_token(SyntaxKind::Table) + // .optional_schema_name_group() + // .ident_like(), + // ) + // // this pollutes the "prohibited following statements" logic too much + // // so we need to ignore it as a prohibited statement + // .with_ignore_if_prohibited(), + // ); + + // VALUES is also legal as a standalone query + // e.g. VALUES (1,2), (3,4+4), (7,77.7); + // todo use repeated group + m.push(StatementDefinition::new( + SyntaxKind::SelectStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Values) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::InsertStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Insert) + .required_token(SyntaxKind::Into) + .optional_schema_name_group() + .ident_like(), + ) + .with_prohibited_following_statements(vec![ + SyntaxKind::SelectStmt, + SyntaxKind::VariableSetStmt, + ]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::DeleteStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::DeleteP) + .required_token(SyntaxKind::From) + .optional_token(SyntaxKind::Only) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::UpdateStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Update) + .optional_token(SyntaxKind::Only) + .optional_schema_name_group() + .ident_like() + .any_tokens(None) + .required_token(SyntaxKind::Set) + .any_token(), + )); + + m.push(StatementDefinition::new( + SyntaxKind::MergeStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Merge) + .required_token(SyntaxKind::Into) + .optional_token(SyntaxKind::Only) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::AlterTableStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .optional_token(SyntaxKind::Materialized) + .optional_token(SyntaxKind::Foreign) + .one_of(vec![SyntaxKind::Table, SyntaxKind::Index, SyntaxKind::View]) + .optional_if_exists_group() + .optional_token(SyntaxKind::Only) + .optional_schema_name_group() + .ident_like() + .any_token(), + ) + .with_prohibited_tokens(vec![SyntaxKind::Rename]), + ); + + // no idea why this is an AlterTableStmt + m.push(StatementDefinition::new( + SyntaxKind::AlterTableStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Sequence) + .optional_if_exists_group() + .optional_schema_name_group() + .ident_like() + .required_token(SyntaxKind::Set) + .one_of(vec![SyntaxKind::Logged, SyntaxKind::Unlogged]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::RenameStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .any_tokens(None) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::Rename) + .required_token(SyntaxKind::To) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::RenameStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Table) + .optional_if_exists_group() + .optional_token(SyntaxKind::Only) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Rename), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterDomainStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::DomainP) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CallStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Call) + .optional_schema_name_group() + .ident_like() + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::AlterDefaultPrivilegesStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Default) + .required_token(SyntaxKind::Privileges), + ) + .with_prohibited_following_statements(vec![SyntaxKind::GrantStmt]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::ClusterStmt, + SyntaxBuilder::new().required_token(SyntaxKind::Cluster), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CopyStmt, + SyntaxBuilder::new().required_token(SyntaxKind::Copy), + )); + + m.push(StatementDefinition::new( + SyntaxKind::ExecuteStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Execute) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::CreateStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .any_tokens(Some(vec![ + SyntaxKind::Global, + SyntaxKind::Local, + SyntaxKind::Temporary, + SyntaxKind::Temp, + SyntaxKind::Unlogged, + ])) + .required_token(SyntaxKind::Table) + .optional_if_not_exists_group() + .optional_schema_name_group() + .ident_like(), + ) + .with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]) + .with_prohibited_tokens(vec![SyntaxKind::As]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::DefineStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_token(SyntaxKind::Or) + .optional_token(SyntaxKind::Replace) + .required_token(SyntaxKind::Aggregate), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateOpClassStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Operator) + .required_token(SyntaxKind::Class) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .optional_token(SyntaxKind::Default) + .required_token(SyntaxKind::For) + .required_token(SyntaxKind::TypeP) + .optional_schema_name_group() + .one_of(vec![SyntaxKind::Ident, SyntaxKind::TextP]) + .required_token(SyntaxKind::Using), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .one_of(vec![ + SyntaxKind::Server, + SyntaxKind::Collation, + SyntaxKind::ConversionP, + SyntaxKind::Extension, + SyntaxKind::Aggregate, + SyntaxKind::DomainP, + SyntaxKind::Sequence, + SyntaxKind::Table, + SyntaxKind::TypeP, + SyntaxKind::Routine, + SyntaxKind::Procedure, + SyntaxKind::Schema, + SyntaxKind::View, + SyntaxKind::Language, + SyntaxKind::Function, + ]) + .optional_if_exists_group() + .optional_schema_name_group() + .ident_like(), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::TextP) + .required_token(SyntaxKind::Search) + .one_of(vec![ + SyntaxKind::Parser, + SyntaxKind::Dictionary, + SyntaxKind::Template, + SyntaxKind::Configuration, + ]) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Materialized) + .required_token(SyntaxKind::View) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Event) + .required_token(SyntaxKind::Trigger) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .optional_token(SyntaxKind::Procedural) + .required_token(SyntaxKind::Language) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Operator) + .required_token(SyntaxKind::Class) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Using) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Access) + .required_token(SyntaxKind::Method) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .one_of(vec![SyntaxKind::Rule, SyntaxKind::Trigger]) + .optional_if_exists_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::On) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::TextP) + .required_token(SyntaxKind::Search) + .one_of(vec![ + SyntaxKind::Template, + SyntaxKind::Configuration, + SyntaxKind::Parser, + SyntaxKind::Dictionary, + ]) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Foreign) + .required_token(SyntaxKind::Table) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Cast) + .optional_if_exists_group() + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::As) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Foreign) + .required_token(SyntaxKind::DataP) + .required_token(SyntaxKind::Wrapper) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Index) + .optional_token(SyntaxKind::Concurrently) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Operator) + .optional_if_exists_group() + .optional_schema_name_group() + .one_of(vec![SyntaxKind::Ident, SyntaxKind::Op]) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Function) + .optional_if_exists_group() + .optional_schema_name_group() + .ident_like() + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Operator) + .required_token(SyntaxKind::Family) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Using) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DefineStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::TextP) + .required_token(SyntaxKind::Search) + .one_of(vec![ + SyntaxKind::Dictionary, + SyntaxKind::Configuration, + SyntaxKind::Template, + SyntaxKind::Parser, + ]) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DefineStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Operator), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DefineStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_or_replace_group() + .required_token(SyntaxKind::Aggregate) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DefineStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::TypeP) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CompositeTypeStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::TypeP) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::As), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateEnumStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::TypeP) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::As) + .required_token(SyntaxKind::EnumP), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateRangeStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::TypeP) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::As) + .required_token(SyntaxKind::Range), + )); + + m.push(StatementDefinition::new( + SyntaxKind::TruncateStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Truncate) + .optional_token(SyntaxKind::Table) + .optional_token(SyntaxKind::Only) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CommentStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Comment) + .required_token(SyntaxKind::On) + .any_tokens(None) + .required_token(SyntaxKind::Is) + .one_of(vec![ + SyntaxKind::Ident, + SyntaxKind::Sconst, + SyntaxKind::NullP, + ]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::FetchStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Fetch) + .any_tokens(None) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::FetchStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Move) + .any_tokens(None) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::VacuumStmt, + SyntaxBuilder::new().required_token(SyntaxKind::Analyze), + )); + + m.push(StatementDefinition::new( + SyntaxKind::IndexStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_token(SyntaxKind::Unique) + .required_token(SyntaxKind::Index) + .any_tokens(None) + .required_token(SyntaxKind::On) + .optional_token(SyntaxKind::Only) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::CreateFunctionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_token(SyntaxKind::Or) + .optional_token(SyntaxKind::Replace) + .one_of(vec![SyntaxKind::Function, SyntaxKind::Procedure]) + .any_tokens(None) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + ) + .with_prohibited_following_statements(vec![ + SyntaxKind::TransactionStmt, + SyntaxKind::VariableSetStmt, + ]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::AlterFunctionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .one_of(vec![SyntaxKind::Function, SyntaxKind::Procedure]) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + ) + .with_prohibited_following_statements(vec![SyntaxKind::VariableSetStmt]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::DoStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Do) + .optional_token(SyntaxKind::Language) + .optional_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Sconst), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::RuleStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_token(SyntaxKind::Or) + .optional_token(SyntaxKind::Replace) + .required_token(SyntaxKind::Rule) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::As) + .required_token(SyntaxKind::On) + .one_of(vec![ + SyntaxKind::Select, + SyntaxKind::Insert, + SyntaxKind::Update, + SyntaxKind::DeleteP, + ]) + .required_token(SyntaxKind::To) + .any_tokens(None) + .required_token(SyntaxKind::Do), + ) + .with_prohibited_following_statements(vec![ + SyntaxKind::SelectStmt, + SyntaxKind::InsertStmt, + SyntaxKind::UpdateStmt, + SyntaxKind::DeleteStmt, + SyntaxKind::VariableSetStmt, + ]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::NotifyStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Notify) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::ListenStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Listen) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::UnlistenStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Unlisten) + .one_of(vec![SyntaxKind::Ident, SyntaxKind::Ascii42]), + )); + + // DECLARE c CURSOR FOR SELECT ctid,cmin,* FROM combocidtest + m.push( + StatementDefinition::new( + SyntaxKind::DeclareCursorStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Declare) + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::Cursor) + .any_tokens(None) + .required_token(SyntaxKind::For), + ) + .with_prohibited_following_statements(vec![SyntaxKind::SelectStmt]), + ); + + // m.push(StatementDefinition::new( + // SyntaxKind::DeclareCursorStmt, + // SyntaxBuilder::new() + // .required_token(SyntaxKind::Declare) + // .required_token(SyntaxKind::Ident) + // .any_tokens(None) + // .required_token(SyntaxKind::Cursor) + // .any_tokens(None) + // .required_token(SyntaxKind::For) + // .one_of(vec![SyntaxKind::Select, SyntaxKind::With]) + // .any_token(), + // )); + + m.push( + StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Release) + .optional_token(SyntaxKind::Savepoint) + .required_token(SyntaxKind::Ident), + ) + .with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Savepoint) + .required_token(SyntaxKind::Ident), + ) + .with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new().required_token(SyntaxKind::BeginP), + )); + + m.push(StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new().required_token(SyntaxKind::EndP), + )); + + m.push(StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Prepare) + .required_token(SyntaxKind::Transaction) + .required_token(SyntaxKind::Sconst), + )); + + m.push(StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Start) + .required_token(SyntaxKind::Transaction) + .any_token(), + )); + + m.push(StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::BeginP) + .required_token(SyntaxKind::Transaction), + )); + + m.push(StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new().required_token(SyntaxKind::Commit), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new_complete() + .required_token(SyntaxKind::Rollback) + .any_tokens(None) + .required_token(SyntaxKind::To) + .optional_token(SyntaxKind::Savepoint) + .required_token(SyntaxKind::Ident), + ) + .with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::TransactionStmt, + SyntaxBuilder::new_complete().required_token(SyntaxKind::Rollback), + ) + .with_prohibited_following_statements(vec![SyntaxKind::TransactionStmt]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::ViewStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_or_replace_group() + .optional_token(SyntaxKind::Temporary) + .optional_token(SyntaxKind::Temp) + .optional_token(SyntaxKind::Recursive) + .required_token(SyntaxKind::View) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::As), + ) + .with_prohibited_following_statements(vec![SyntaxKind::SelectStmt]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::LoadStmt, + SyntaxBuilder::new().required_token(SyntaxKind::Load), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateDomainStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::DomainP) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreatedbStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Database) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropdbStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Database) + .optional_if_exists_group() + .required_token(SyntaxKind::Ident), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::VacuumStmt, + SyntaxBuilder::new().required_token(SyntaxKind::Vacuum), + ) + .with_prohibited_following_statements(vec![SyntaxKind::VacuumStmt]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::CreateTableAsStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Materialized) + .required_token(SyntaxKind::View) + .optional_if_not_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::As), + ) + .with_prohibited_following_statements(vec![ + SyntaxKind::SelectStmt, + SyntaxKind::ExecuteStmt, + ]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::CreateTableAsStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .any_tokens(Some(vec![ + SyntaxKind::Global, + SyntaxKind::Local, + SyntaxKind::Temporary, + SyntaxKind::Temp, + ])) + .required_token(SyntaxKind::Table) + .optional_if_not_exists_group() + .optional_schema_name_group() + .ident_like() + .any_tokens(None) + .required_token(SyntaxKind::As) + .one_of(vec![ + SyntaxKind::With, + SyntaxKind::Select, + SyntaxKind::Values, + SyntaxKind::Table, + SyntaxKind::Execute, + ]), + ) + .with_prohibited_following_statements(vec![ + SyntaxKind::SelectStmt, + SyntaxKind::ExecuteStmt, + ]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::ViewStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_token(SyntaxKind::Or) + .optional_token(SyntaxKind::Replace) + .optional_token(SyntaxKind::Temporary) + .optional_token(SyntaxKind::Temp) + .optional_token(SyntaxKind::Recursive) + .required_token(SyntaxKind::View) + .optional_if_not_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::As), + ) + .with_prohibited_following_statements(vec![SyntaxKind::SelectStmt]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::ExplainStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Explain) + .any_tokens(None) + .one_of(vec![ + SyntaxKind::Select, + SyntaxKind::Insert, + SyntaxKind::Update, + SyntaxKind::DeleteP, + SyntaxKind::Merge, + SyntaxKind::Execute, + SyntaxKind::Create, + SyntaxKind::Declare, + SyntaxKind::Create, + ]), + ) + .with_prohibited_following_statements(vec![ + // SyntaxKind::VacuumStmt, + // SyntaxKind::SelectStmt, + // SyntaxKind::CreateTableAsStmt, + // SyntaxKind::InsertStmt, + // SyntaxKind::DeleteStmt, + // SyntaxKind::UpdateStmt, + // SyntaxKind::MergeStmt, + // SyntaxKind::ExecuteStmt, + // SyntaxKind::CreateStmt, + // SyntaxKind::DeclareCursorStmt, + // todo remove this again when we include all deps + SyntaxKind::VariableSetStmt, + ]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::CreateSeqStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .any_tokens(Some(vec![ + SyntaxKind::Temporary, + SyntaxKind::Temp, + SyntaxKind::Unlogged, + ])) + .required_token(SyntaxKind::Sequence) + .optional_if_not_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterSeqStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Sequence) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::VariableSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Reset) + .one_of(vec![SyntaxKind::All, SyntaxKind::Ident, SyntaxKind::Role]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::VariableSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Reset) + .required_token(SyntaxKind::Session) + .required_token(SyntaxKind::Authorization), + )); + + m.push(StatementDefinition::new( + SyntaxKind::VariableSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Set) + .required_token(SyntaxKind::Transaction), + )); + + m.push(StatementDefinition::new( + SyntaxKind::VariableSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Set) + .required_token(SyntaxKind::Role) + .required_token(SyntaxKind::Ident), + )); + + // ref: https://www.postgresql.org/docs/current/sql-set-session-authorization.html + m.push(StatementDefinition::new( + SyntaxKind::VariableSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Set) + .optional_token(SyntaxKind::Local) + .required_token(SyntaxKind::Session) + .required_token(SyntaxKind::Authorization) + .one_of(vec![SyntaxKind::Ident, SyntaxKind::Sconst]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::VariableSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Set) + .any_tokens(Some(vec![ + SyntaxKind::Local, + SyntaxKind::Session, + SyntaxKind::Ident, + SyntaxKind::Ascii46, + ])) + .one_of(vec![SyntaxKind::To, SyntaxKind::Ascii61]) + .any_token(), + )); + + m.push(StatementDefinition::new( + SyntaxKind::VariableSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Set) + .optional_token(SyntaxKind::Session) + .optional_token(SyntaxKind::Local) + .required_token(SyntaxKind::Time) + .required_token(SyntaxKind::Zone) + .any_token(), + )); + + m.push(StatementDefinition::new( + SyntaxKind::VariableShowStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Show) + .one_of(vec![SyntaxKind::Ident, SyntaxKind::All]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DiscardStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Discard) + .one_of(vec![ + SyntaxKind::All, + SyntaxKind::Plans, + SyntaxKind::Sequences, + SyntaxKind::Temp, + SyntaxKind::Temporary, + ]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateRoleStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .one_of(vec![SyntaxKind::Role, SyntaxKind::GroupP, SyntaxKind::User]) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterRoleStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .one_of(vec![SyntaxKind::Role, SyntaxKind::User]) + .required_token(SyntaxKind::Ident), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::AlterRoleSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Role) + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Set), + ) + .with_prohibited_following_statements(vec![SyntaxKind::VariableSetStmt]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::DropRoleStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .one_of(vec![SyntaxKind::Role, SyntaxKind::User, SyntaxKind::GroupP]) + .optional_token(SyntaxKind::IfP) + .optional_token(SyntaxKind::Exists) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::LockStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::LockP) + .optional_token(SyntaxKind::Table) + .optional_token(SyntaxKind::Only) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::ConstraintsSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Set) + .required_token(SyntaxKind::Constraints), + )); + + m.push(StatementDefinition::new( + SyntaxKind::ReindexStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Reindex) + .any_tokens(None) + .one_of(vec![ + SyntaxKind::Table, + SyntaxKind::Index, + SyntaxKind::Schema, + ]) + .optional_token(SyntaxKind::Concurrently) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::ReindexStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Reindex) + .any_tokens(None) + .one_of(vec![SyntaxKind::Database, SyntaxKind::SystemP]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CheckPointStmt, + SyntaxBuilder::new().required_token(SyntaxKind::Checkpoint), + )); + + // CREATE TABLE, CREATE VIEW, CREATE INDEX, CREATE SEQUENCE, CREATE TRIGGER and GRANT + m.push( + StatementDefinition::new( + SyntaxKind::CreateSchemaStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Schema), + ) + .with_prohibited_following_statements(vec![ + SyntaxKind::CreateTableAsStmt, + SyntaxKind::CreateStmt, + SyntaxKind::SelectStmt, + SyntaxKind::IndexStmt, + SyntaxKind::CreateSeqStmt, + SyntaxKind::CreateTrigStmt, + SyntaxKind::GrantStmt, + ]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::AlterDatabaseStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Database) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterDatabaseRefreshCollStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Database) + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Refresh) + .required_token(SyntaxKind::Collation) + .required_token(SyntaxKind::VersionP), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterDatabaseSetStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Database) + .required_token(SyntaxKind::Ident) + .one_of(vec![SyntaxKind::Set, SyntaxKind::Reset]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateConversionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_token(SyntaxKind::Default) + .required_token(SyntaxKind::ConversionP) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::For) + .required_token(SyntaxKind::Sconst) + .required_token(SyntaxKind::To) + .required_token(SyntaxKind::Sconst) + .required_token(SyntaxKind::From) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateCastStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Cast) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::As) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateOpFamilyStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Operator) + .required_token(SyntaxKind::Family) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::AlterOpFamilyStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Operator) + .required_token(SyntaxKind::Family) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Using) + .required_token(SyntaxKind::Ident) + .one_of(vec![SyntaxKind::Drop, SyntaxKind::AddP, SyntaxKind::Rename]), + ) + .with_prohibited_tokens(vec![SyntaxKind::Rename]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::PrepareStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Prepare) + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::As) + .any_token(), + ) + .with_prohibited_following_statements(vec![ + SyntaxKind::SelectStmt, + SyntaxKind::InsertStmt, + SyntaxKind::UpdateStmt, + SyntaxKind::DeleteStmt, + ]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::ClosePortalStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Close) + .one_of(vec![SyntaxKind::Ident, SyntaxKind::All]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DeallocateStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Deallocate) + .optional_token(SyntaxKind::Prepare) + .one_of(vec![SyntaxKind::Ident, SyntaxKind::All]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateTableSpaceStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Tablespace) + .any_tokens(None) + .required_token(SyntaxKind::Location), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropTableSpaceStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Tablespace) + .optional_if_exists_group() + .optional_token(SyntaxKind::IfP) + .optional_token(SyntaxKind::Exists) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterOperatorStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Operator), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterTypeStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::TypeP) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropOwnedStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Owned) + .required_token(SyntaxKind::By), + )); + + m.push(StatementDefinition::new( + SyntaxKind::ReassignOwnedStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Reassign) + .required_token(SyntaxKind::Owned) + .required_token(SyntaxKind::By) + .any_tokens(None) + .required_token(SyntaxKind::To), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateFdwStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Foreign) + .required_token(SyntaxKind::DataP) + .required_token(SyntaxKind::Wrapper) + .required_token(SyntaxKind::Ident), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::AlterFdwStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Foreign) + .required_token(SyntaxKind::DataP) + .required_token(SyntaxKind::Wrapper) + .required_token(SyntaxKind::Ident), + ) + .with_prohibited_tokens(vec![SyntaxKind::Rename]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::CreateForeignServerStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Server) + .optional_if_not_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::Foreign) + .required_token(SyntaxKind::DataP) + .required_token(SyntaxKind::Wrapper) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterForeignServerStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Server) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateUserMappingStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::User) + .required_token(SyntaxKind::Mapping) + .optional_if_not_exists_group() + .required_token(SyntaxKind::For) + .any_tokens(None) + .required_token(SyntaxKind::Server) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterUserMappingStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::User) + .required_token(SyntaxKind::Mapping) + .optional_token(SyntaxKind::For) + .any_tokens(None) + .required_token(SyntaxKind::Server) + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Options), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropUserMappingStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::User) + .required_token(SyntaxKind::Mapping) + .optional_if_exists_group() + .optional_token(SyntaxKind::For) + .any_tokens(None) + .required_token(SyntaxKind::Server) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::SecLabelStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Security) + .required_token(SyntaxKind::Label) + .optional_token(SyntaxKind::For) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::On), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateForeignTableStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Foreign) + .required_token(SyntaxKind::Table) + .optional_if_not_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::Server) + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::ImportForeignSchemaStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::ImportP) + .required_token(SyntaxKind::Foreign) + .required_token(SyntaxKind::Schema) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::From) + .required_token(SyntaxKind::Server) + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Into) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateExtensionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Extension) + .optional_if_not_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterExtensionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Extension) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateEventTrigStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Event) + .required_token(SyntaxKind::Trigger) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::On) + .required_token(SyntaxKind::Ident) + .any_tokens(None) + .required_token(SyntaxKind::Execute) + .one_of(vec![SyntaxKind::Function, SyntaxKind::Procedure]) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Ascii40) + .required_token(SyntaxKind::Ascii41), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterEventTrigStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Event) + .required_token(SyntaxKind::Trigger) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::RefreshMatViewStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Refresh) + .required_token(SyntaxKind::Materialized) + .required_token(SyntaxKind::View) + .optional_token(SyntaxKind::Concurrently) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterSystemStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::SystemP) + .one_of(vec![SyntaxKind::Set, SyntaxKind::Reset]), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreatePolicyStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Policy) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::On) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterPolicyStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Policy) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::On) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateTransformStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_or_replace_group() + .required_token(SyntaxKind::Transform) + .required_token(SyntaxKind::For) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Language) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Ascii40) + .any_tokens(None) + .required_token(SyntaxKind::Ascii41), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateAmStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Access) + .required_token(SyntaxKind::Method) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::TypeP), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreatePublicationStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Publication) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterPublicationStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Publication) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateSubscriptionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Subscription) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident) + .required_token(SyntaxKind::Connection) + .required_token(SyntaxKind::Sconst) + .required_token(SyntaxKind::Publication) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::AlterSubscriptionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .required_token(SyntaxKind::Subscription) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::DropSubscriptionStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Drop) + .required_token(SyntaxKind::Subscription) + .optional_if_exists_group() + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::GrantStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Grant) + .any_tokens(None) + .required_token(SyntaxKind::On) + .any_tokens(None) + .required_token(SyntaxKind::To), + )); + + m.push( + StatementDefinition::new( + SyntaxKind::GrantStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Revoke) + .any_tokens(None) + .required_token(SyntaxKind::On), + ) + .with_prohibited_following_statements(vec![SyntaxKind::SelectStmt]), + ); + + m.push( + StatementDefinition::new( + SyntaxKind::AlterOwnerStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .any_tokens(None) + .required_token(SyntaxKind::Owner) + .required_token(SyntaxKind::To) + .required_token(SyntaxKind::Ident), + ) + // dont ask why, but it seems like tables are special + // and altering their owner is an AlterTableStmt + .with_prohibited_tokens(vec![SyntaxKind::Table]), + ); + + m.push(StatementDefinition::new( + SyntaxKind::AlterObjectSchemaStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Alter) + .any_tokens(None) + .required_token(SyntaxKind::Set) + .required_token(SyntaxKind::Schema) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreatePlangStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .optional_or_replace_group() + .optional_token(SyntaxKind::Trusted) + .optional_token(SyntaxKind::Procedural) + .required_token(SyntaxKind::Language) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + m.push(StatementDefinition::new( + SyntaxKind::CreateStatsStmt, + SyntaxBuilder::new() + .required_token(SyntaxKind::Create) + .required_token(SyntaxKind::Statistics) + .any_tokens(None) + .required_token(SyntaxKind::On) + .any_tokens(None) + .required_token(SyntaxKind::From) + .optional_schema_name_group() + .required_token(SyntaxKind::Ident), + )); + + let mut stmt_starts: HashMap> = HashMap::new(); + + for stmt in m { + let first_token = stmt + .tokens + .first() + .expect("Expected first token to be present"); + + if let SyntaxDefinition::RequiredToken(token) = first_token { + stmt_starts.entry(*token).or_insert(Vec::new()).push(stmt); + } else { + panic!("Expected first token to be a required token"); + } + } + + stmt_starts + }); diff --git a/crates/pg_statement_splitter/src/is_at_stmt_start.rs b/crates/pg_statement_splitter/src/is_at_stmt_start.rs deleted file mode 100644 index ec1b83ea..00000000 --- a/crates/pg_statement_splitter/src/is_at_stmt_start.rs +++ /dev/null @@ -1,1015 +0,0 @@ -use std::collections::HashMap; -use std::sync::LazyLock; - -use super::Parser; -use pg_lexer::SyntaxKind; - -pub enum SyntaxToken { - Required(SyntaxKind), - Optional(SyntaxKind), -} - -#[derive(Debug, Clone, Hash)] -pub enum TokenStatement { - // The respective token is the last token of the statement - EoS(SyntaxKind), - Any(SyntaxKind), -} - -impl TokenStatement { - fn is_eos(&self) -> bool { - match self { - TokenStatement::EoS(_) => true, - _ => false, - } - } - - fn kind(&self) -> SyntaxKind { - match self { - TokenStatement::EoS(k) => k.to_owned(), - TokenStatement::Any(k) => k.to_owned(), - } - } -} - -impl PartialEq for TokenStatement { - fn eq(&self, other: &Self) -> bool { - let a = match self { - TokenStatement::EoS(s) => s, - TokenStatement::Any(s) => s, - }; - - let b = match other { - TokenStatement::EoS(s) => s, - TokenStatement::Any(s) => s, - }; - - return a == b; - } -} - -// vector of hashmaps, where each hashmap returns the list of possible statements for a token at -// the respective index. -// -// For example, at idx 0, the hashmap contains a superset of -// ``` -//{ -// Create: [ -// IndexStmt, -// CreateFunctionStmt, -// CreateStmt, -// ViewStmt, -// ], -// Select: [ -// SelectStmt, -// ], -// }, -// ``` -// -// the idea is to trim down the possible options for each token, until only one statement is left. -// -// The vector is lazily constructed out of another vector of tuples, where each tuple contains a -// statement, and a list of `SyntaxToken`s that are to be found at the start of the statement. -pub static STATEMENT_START_TOKEN_MAPS: LazyLock>>> = - LazyLock::new(|| { - let mut m: Vec<(SyntaxKind, &'static [SyntaxToken])> = Vec::new(); - - m.push(( - SyntaxKind::InsertStmt, - &[ - SyntaxToken::Required(SyntaxKind::Insert), - SyntaxToken::Required(SyntaxKind::Into), - ], - )); - - m.push(( - SyntaxKind::DeleteStmt, - &[ - SyntaxToken::Required(SyntaxKind::DeleteP), - SyntaxToken::Required(SyntaxKind::From), - ], - )); - - m.push(( - SyntaxKind::UpdateStmt, - &[SyntaxToken::Required(SyntaxKind::Update)], - )); - - m.push(( - SyntaxKind::MergeStmt, - &[ - SyntaxToken::Required(SyntaxKind::Merge), - SyntaxToken::Required(SyntaxKind::Into), - ], - )); - - m.push(( - SyntaxKind::SelectStmt, - &[SyntaxToken::Required(SyntaxKind::Select)], - )); - - m.push(( - SyntaxKind::AlterTableStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Table), - SyntaxToken::Optional(SyntaxKind::IfP), - SyntaxToken::Optional(SyntaxKind::Exists), - SyntaxToken::Optional(SyntaxKind::Only), - SyntaxToken::Required(SyntaxKind::Ident), - ], - )); - - // ALTER TABLE x RENAME ... is different to e.g. alter table alter column... - m.push(( - SyntaxKind::RenameStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Table), - SyntaxToken::Optional(SyntaxKind::IfP), - SyntaxToken::Optional(SyntaxKind::Exists), - SyntaxToken::Optional(SyntaxKind::Only), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::Rename), - ], - )); - - m.push(( - SyntaxKind::AlterDomainStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::DomainP), - ], - )); - - m.push(( - SyntaxKind::AlterDefaultPrivilegesStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Default), - SyntaxToken::Required(SyntaxKind::Privileges), - ], - )); - - m.push(( - SyntaxKind::ClusterStmt, - &[SyntaxToken::Required(SyntaxKind::Cluster)], - )); - - m.push(( - SyntaxKind::CopyStmt, - &[SyntaxToken::Required(SyntaxKind::Copy)], - )); - - // CREATE [ [ GLOBAL | LOCAL ] { TEMPORARY | TEMP } | UNLOGGED ] TABLE - // this is overly simplified, but it should be good enough for now - m.push(( - SyntaxKind::CreateStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Global), - SyntaxToken::Optional(SyntaxKind::Local), - SyntaxToken::Optional(SyntaxKind::Temporary), - SyntaxToken::Optional(SyntaxKind::Temp), - SyntaxToken::Optional(SyntaxKind::Unlogged), - SyntaxToken::Optional(SyntaxKind::IfP), - SyntaxToken::Optional(SyntaxKind::Not), - SyntaxToken::Optional(SyntaxKind::Exists), - SyntaxToken::Required(SyntaxKind::Table), - SyntaxToken::Required(SyntaxKind::Ident), - ], - )); - - // CREATE [ OR REPLACE ] AGGREGATE - m.push(( - SyntaxKind::DefineStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Required(SyntaxKind::Aggregate), - ], - )); - - // CREATE OPERATOR - m.push(( - SyntaxKind::DefineStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Operator), - ], - )); - - // CREATE TYPE name - m.push(( - SyntaxKind::DefineStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::TypeP), - SyntaxToken::Required(SyntaxKind::Ident), - ], - )); - - // CREATE TYPE name AS - m.push(( - SyntaxKind::CompositeTypeStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::TypeP), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::As), - ], - )); - - // CREATE TYPE name AS ENUM - m.push(( - SyntaxKind::CreateEnumStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::TypeP), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::As), - SyntaxToken::Required(SyntaxKind::EnumP), - ], - )); - - // CREATE TYPE name AS RANGE - m.push(( - SyntaxKind::CreateRangeStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::TypeP), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::As), - SyntaxToken::Required(SyntaxKind::Range), - ], - )); - - // m.push(( - // SyntaxKind::DropStmt, - // &[ - // SyntaxToken::Required(SyntaxKind::Drop), - // ], - // )); - - m.push(( - SyntaxKind::TruncateStmt, - &[SyntaxToken::Required(SyntaxKind::Truncate)], - )); - - m.push(( - SyntaxKind::CommentStmt, - &[ - SyntaxToken::Required(SyntaxKind::Comment), - SyntaxToken::Required(SyntaxKind::On), - ], - )); - - m.push(( - SyntaxKind::FetchStmt, - &[SyntaxToken::Required(SyntaxKind::Fetch)], - )); - - // CREATE [ UNIQUE ] INDEX - m.push(( - SyntaxKind::IndexStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Unique), - SyntaxToken::Required(SyntaxKind::Index), - ], - )); - - // CREATE [ OR REPLACE ] FUNCTION - m.push(( - SyntaxKind::CreateFunctionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Required(SyntaxKind::Function), - ], - )); - - m.push(( - SyntaxKind::AlterFunctionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Function), - ], - )); - - m.push((SyntaxKind::DoStmt, &[SyntaxToken::Required(SyntaxKind::Do)])); - - // CREATE [ OR REPLACE ] RULE - m.push(( - SyntaxKind::RuleStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Required(SyntaxKind::Rule), - ], - )); - - m.push(( - SyntaxKind::NotifyStmt, - &[SyntaxToken::Required(SyntaxKind::Notify)], - )); - m.push(( - SyntaxKind::ListenStmt, - &[SyntaxToken::Required(SyntaxKind::Listen)], - )); - m.push(( - SyntaxKind::UnlistenStmt, - &[SyntaxToken::Required(SyntaxKind::Unlisten)], - )); - - // TransactionStmt can be Begin or Commit - m.push(( - SyntaxKind::TransactionStmt, - &[SyntaxToken::Required(SyntaxKind::BeginP)], - )); - m.push(( - SyntaxKind::TransactionStmt, - &[SyntaxToken::Required(SyntaxKind::Commit)], - )); - - // CREATE [ OR REPLACE ] [ TEMP | TEMPORARY ] [ RECURSIVE ] VIEW - // this is overly simplified, but it should be good enough for now - m.push(( - SyntaxKind::ViewStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Optional(SyntaxKind::Temporary), - SyntaxToken::Optional(SyntaxKind::Temp), - SyntaxToken::Optional(SyntaxKind::Recursive), - SyntaxToken::Required(SyntaxKind::View), - ], - )); - - m.push(( - SyntaxKind::LoadStmt, - &[SyntaxToken::Required(SyntaxKind::Load)], - )); - - m.push(( - SyntaxKind::CreateDomainStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::DomainP), - ], - )); - - m.push(( - SyntaxKind::CreatedbStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Database), - ], - )); - - m.push(( - SyntaxKind::DropdbStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Database), - ], - )); - - m.push(( - SyntaxKind::VacuumStmt, - &[SyntaxToken::Required(SyntaxKind::Vacuum)], - )); - - m.push(( - SyntaxKind::ExplainStmt, - &[SyntaxToken::Required(SyntaxKind::Explain)], - )); - - // CREATE [ [ GLOBAL | LOCAL ] { TEMPORARY | TEMP } ] TABLE AS - // this is overly simplified, but it should be good enough for now - m.push(( - SyntaxKind::CreateTableAsStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Global), - SyntaxToken::Optional(SyntaxKind::Local), - SyntaxToken::Optional(SyntaxKind::Temporary), - SyntaxToken::Optional(SyntaxKind::Temp), - SyntaxToken::Required(SyntaxKind::Table), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::As), - ], - )); - - m.push(( - SyntaxKind::CreateSeqStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Temporary), - SyntaxToken::Optional(SyntaxKind::Temp), - SyntaxToken::Optional(SyntaxKind::Unlogged), - SyntaxToken::Required(SyntaxKind::Sequence), - ], - )); - - m.push(( - SyntaxKind::AlterSeqStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Sequence), - ], - )); - - m.push(( - SyntaxKind::VariableSetStmt, - &[SyntaxToken::Required(SyntaxKind::Set)], - )); - - m.push(( - SyntaxKind::VariableShowStmt, - &[SyntaxToken::Required(SyntaxKind::Show)], - )); - - m.push(( - SyntaxKind::DiscardStmt, - &[SyntaxToken::Required(SyntaxKind::Discard)], - )); - - // CREATE [ OR REPLACE ] [ CONSTRAINT ] TRIGGER - m.push(( - SyntaxKind::CreateTrigStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Optional(SyntaxKind::Constraint), - SyntaxToken::Required(SyntaxKind::Trigger), - ], - )); - - m.push(( - SyntaxKind::CreateRoleStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Role), - ], - )); - - m.push(( - SyntaxKind::AlterRoleStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Role), - ], - )); - - m.push(( - SyntaxKind::DropRoleStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Role), - ], - )); - - m.push(( - SyntaxKind::LockStmt, - &[SyntaxToken::Required(SyntaxKind::LockP)], - )); - - m.push(( - SyntaxKind::ConstraintsSetStmt, - &[ - SyntaxToken::Required(SyntaxKind::Set), - SyntaxToken::Required(SyntaxKind::Constraints), - ], - )); - - m.push(( - SyntaxKind::ReindexStmt, - &[SyntaxToken::Required(SyntaxKind::Reindex)], - )); - - m.push(( - SyntaxKind::CheckPointStmt, - &[SyntaxToken::Required(SyntaxKind::Checkpoint)], - )); - - m.push(( - SyntaxKind::CreateSchemaStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Schema), - ], - )); - - m.push(( - SyntaxKind::AlterDatabaseStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Database), - SyntaxToken::Required(SyntaxKind::Ident), - ], - )); - - m.push(( - SyntaxKind::AlterDatabaseRefreshCollStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Database), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::Refresh), - SyntaxToken::Required(SyntaxKind::Collation), - SyntaxToken::Required(SyntaxKind::VersionP), - ], - )); - - m.push(( - SyntaxKind::AlterDatabaseSetStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Database), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::Set), - ], - )); - - m.push(( - SyntaxKind::AlterDatabaseSetStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Database), - SyntaxToken::Required(SyntaxKind::Ident), - SyntaxToken::Required(SyntaxKind::Reset), - ], - )); - - m.push(( - SyntaxKind::CreateConversionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Default), - SyntaxToken::Required(SyntaxKind::ConversionP), - ], - )); - - m.push(( - SyntaxKind::CreateCastStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Cast), - ], - )); - - m.push(( - SyntaxKind::CreateOpClassStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Operator), - SyntaxToken::Required(SyntaxKind::Class), - ], - )); - - m.push(( - SyntaxKind::CreateOpFamilyStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Operator), - SyntaxToken::Required(SyntaxKind::Family), - ], - )); - - m.push(( - SyntaxKind::AlterOpFamilyStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Operator), - SyntaxToken::Required(SyntaxKind::Family), - ], - )); - - m.push(( - SyntaxKind::PrepareStmt, - &[SyntaxToken::Required(SyntaxKind::Prepare)], - )); - - // m.push(( - // SyntaxKind::ExecuteStmt, - // &[SyntaxToken::Required(SyntaxKind::Execute)], - // )); - - m.push(( - SyntaxKind::DeallocateStmt, - &[SyntaxToken::Required(SyntaxKind::Deallocate)], - )); - - m.push(( - SyntaxKind::CreateTableSpaceStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Tablespace), - ], - )); - - m.push(( - SyntaxKind::DropTableSpaceStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Tablespace), - ], - )); - - m.push(( - SyntaxKind::AlterOperatorStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Operator), - ], - )); - - m.push(( - SyntaxKind::AlterTypeStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::TypeP), - ], - )); - - m.push(( - SyntaxKind::DropOwnedStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Owned), - SyntaxToken::Required(SyntaxKind::By), - ], - )); - - m.push(( - SyntaxKind::ReassignOwnedStmt, - &[ - SyntaxToken::Required(SyntaxKind::Reassign), - SyntaxToken::Required(SyntaxKind::Owned), - SyntaxToken::Required(SyntaxKind::By), - ], - )); - - m.push(( - SyntaxKind::CreateFdwStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Foreign), - SyntaxToken::Required(SyntaxKind::DataP), - SyntaxToken::Required(SyntaxKind::Wrapper), - ], - )); - - m.push(( - SyntaxKind::AlterFdwStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Foreign), - SyntaxToken::Required(SyntaxKind::DataP), - SyntaxToken::Required(SyntaxKind::Wrapper), - ], - )); - - m.push(( - SyntaxKind::CreateForeignServerStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Server), - ], - )); - - m.push(( - SyntaxKind::AlterForeignServerStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Server), - ], - )); - - m.push(( - SyntaxKind::CreateUserMappingStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::User), - SyntaxToken::Required(SyntaxKind::Mapping), - ], - )); - - m.push(( - SyntaxKind::AlterUserMappingStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::User), - SyntaxToken::Required(SyntaxKind::Mapping), - SyntaxToken::Required(SyntaxKind::For), - ], - )); - - m.push(( - SyntaxKind::DropUserMappingStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::User), - SyntaxToken::Required(SyntaxKind::Mapping), - ], - )); - - m.push(( - SyntaxKind::SecLabelStmt, - &[ - SyntaxToken::Required(SyntaxKind::Security), - SyntaxToken::Required(SyntaxKind::Label), - ], - )); - - m.push(( - SyntaxKind::CreateForeignTableStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Foreign), - SyntaxToken::Required(SyntaxKind::Table), - ], - )); - - m.push(( - SyntaxKind::ImportForeignSchemaStmt, - &[ - SyntaxToken::Required(SyntaxKind::ImportP), - SyntaxToken::Required(SyntaxKind::Foreign), - SyntaxToken::Required(SyntaxKind::Schema), - ], - )); - - m.push(( - SyntaxKind::CreateExtensionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Extension), - ], - )); - - m.push(( - SyntaxKind::AlterExtensionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Extension), - ], - )); - - m.push(( - SyntaxKind::CreateEventTrigStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Event), - SyntaxToken::Required(SyntaxKind::Trigger), - ], - )); - - m.push(( - SyntaxKind::AlterEventTrigStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Event), - SyntaxToken::Required(SyntaxKind::Trigger), - ], - )); - - m.push(( - SyntaxKind::RefreshMatViewStmt, - &[ - SyntaxToken::Required(SyntaxKind::Refresh), - SyntaxToken::Required(SyntaxKind::Materialized), - SyntaxToken::Required(SyntaxKind::View), - ], - )); - - m.push(( - SyntaxKind::AlterSystemStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::SystemP), - ], - )); - - m.push(( - SyntaxKind::CreatePolicyStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Policy), - ], - )); - - m.push(( - SyntaxKind::AlterPolicyStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Policy), - ], - )); - - m.push(( - SyntaxKind::CreateTransformStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Optional(SyntaxKind::Or), - SyntaxToken::Optional(SyntaxKind::Replace), - SyntaxToken::Required(SyntaxKind::Transform), - SyntaxToken::Required(SyntaxKind::For), - ], - )); - - m.push(( - SyntaxKind::CreateAmStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Access), - SyntaxToken::Required(SyntaxKind::Method), - ], - )); - - m.push(( - SyntaxKind::CreatePublicationStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Publication), - ], - )); - - m.push(( - SyntaxKind::AlterPublicationStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Publication), - ], - )); - - m.push(( - SyntaxKind::CreateSubscriptionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Subscription), - ], - )); - - m.push(( - SyntaxKind::AlterSubscriptionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Subscription), - ], - )); - - m.push(( - SyntaxKind::DropSubscriptionStmt, - &[ - SyntaxToken::Required(SyntaxKind::Drop), - SyntaxToken::Required(SyntaxKind::Subscription), - ], - )); - - m.push(( - SyntaxKind::CreateStatsStmt, - &[ - SyntaxToken::Required(SyntaxKind::Create), - SyntaxToken::Required(SyntaxKind::Statistics), - ], - )); - - m.push(( - SyntaxKind::AlterCollationStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Collation), - ], - )); - - m.push(( - SyntaxKind::CallStmt, - &[SyntaxToken::Required(SyntaxKind::Call)], - )); - - m.push(( - SyntaxKind::AlterStatsStmt, - &[ - SyntaxToken::Required(SyntaxKind::Alter), - SyntaxToken::Required(SyntaxKind::Statistics), - ], - )); - - let mut vec: Vec>> = Vec::new(); - - m.iter().for_each(|(statement, tokens)| { - let mut left_pull: usize = 0; - tokens.iter().enumerate().for_each(|(idx, token)| { - if vec.len() <= idx { - vec.push(HashMap::new()); - } - - let is_last = idx == tokens.len() - 1; - - match token { - SyntaxToken::Required(t) => { - for i in (idx - left_pull)..(idx + 1) { - let list_entry = vec[i].entry(t.to_owned()); - list_entry - .and_modify(|list| { - list.push(if is_last { - TokenStatement::EoS(statement.to_owned()) - } else { - TokenStatement::Any(statement.to_owned()) - }); - }) - .or_insert(vec![if is_last { - TokenStatement::EoS(statement.to_owned()) - } else { - TokenStatement::Any(statement.to_owned()) - }]); - } - } - SyntaxToken::Optional(t) => { - if is_last { - panic!("Optional token cannot be last token"); - } - for i in (idx - left_pull)..(idx + 1) { - let list_entry = vec[i].entry(t.to_owned()); - list_entry - .and_modify(|list| { - list.push(TokenStatement::Any(statement.to_owned())); - }) - .or_insert(vec![TokenStatement::Any(statement.to_owned())]); - } - left_pull += 1; - } - } - }); - }); - - vec - }); - -// TODO: complete the hashmap above with all statements: -// RETURN statement (inside SQL function body) -// ReturnStmt, -// SetOperationStmt, -// -// TODO: parsing ambiguity, check docs for solution -// GrantStmt(super::GrantStmt), -// GrantRoleStmt(super::GrantRoleStmt), -// ClosePortalStmt, -// CreatePlangStmt, -// AlterRoleSetStmt, -// DeclareCursorStmt, -// AlterObjectDependsStmt, -// AlterObjectSchemaStmt, -// AlterOwnerStmt, -// AlterEnumStmt, -// AlterTsdictionaryStmt, -// AlterTsconfigurationStmt, -// AlterTableSpaceOptionsStmt, -// AlterTableMoveAllStmt, -// AlterExtensionContentsStmt, -// ReplicaIdentityStmt, -// - -/// Returns the statement at which the parser is currently at, if any -pub fn is_at_stmt_start(parser: &mut Parser) -> Option { - let mut options = Vec::new(); - for i in 0..STATEMENT_START_TOKEN_MAPS.len() { - // important, else infinite loop: only ignore whitespaces after first token - let token = parser.nth(i, i != 0).kind; - if let Some(result) = STATEMENT_START_TOKEN_MAPS[i].get(&token) { - if i == 0 { - options = result.clone(); - } else { - options = result - .iter() - .filter(|o| options.contains(o)) - .cloned() - .collect(); - } - } else if options.len() > 1 { - // no result is found, and there is currently more than one option - // filter the options for all statements that are complete at this point - options.retain(|o| o.is_eos()); - } - - if options.len() == 0 { - break; - } else if options.len() == 1 && options.get(0).unwrap().is_eos() { - break; - } - } - if options.len() == 0 { - None - } else if options.len() == 1 && options.get(0).unwrap().is_eos() { - Some(options.get(0).unwrap().kind()) - } else { - panic!("Ambiguous statement"); - } -} diff --git a/crates/pg_statement_splitter/src/lib.rs b/crates/pg_statement_splitter/src/lib.rs index adaea475..5913c176 100644 --- a/crates/pg_statement_splitter/src/lib.rs +++ b/crates/pg_statement_splitter/src/lib.rs @@ -9,112 +9,27 @@ /// We should expand the definition map to include an `Any*`, which must be followed by at least /// one required token and allows the parser to search for the end tokens of the statement. This /// will hopefully be enough to reduce collisions to zero. -mod is_at_stmt_start; +mod data; mod parser; +mod statement_splitter; mod syntax_error; +mod tracker; +mod tracker_new; + +use statement_splitter::{StatementPosition, StatementSplitter}; +use text_size::TextRange; + +pub fn split(sql: &str) -> Vec { + StatementSplitter::new(sql) + .run() + .iter() + .map(|x| x.range) + .collect() +} -use is_at_stmt_start::{is_at_stmt_start, TokenStatement, STATEMENT_START_TOKEN_MAPS}; - -use parser::{Parse, Parser}; - -use pg_lexer::{lex, SyntaxKind}; - -pub fn split(sql: &str) -> Parse { - let mut parser = Parser::new(lex(sql)); - - while !parser.eof() { - match is_at_stmt_start(&mut parser) { - Some(stmt) => { - parser.start_stmt(); - - // advance over all start tokens of the statement - for i in 0..STATEMENT_START_TOKEN_MAPS.len() { - parser.eat_whitespace(); - let token = parser.nth(0, false); - if let Some(result) = STATEMENT_START_TOKEN_MAPS[i].get(&token.kind) { - let is_in_results = result - .iter() - .find(|x| match x { - TokenStatement::EoS(y) | TokenStatement::Any(y) => y == &stmt, - }) - .is_some(); - if i == 0 && !is_in_results { - panic!("Expected statement start"); - } else if is_in_results { - parser.expect(token.kind); - } else { - break; - } - } - } - - // move until the end of the statement, or until the next statement start - let mut is_sub_stmt = 0; - let mut is_sub_trx = 0; - let mut ignore_next_non_whitespace = false; - while !parser.at(SyntaxKind::Ascii59) && !parser.eof() { - match parser.nth(0, false).kind { - SyntaxKind::All => { - // ALL is never a statement start, but needs to be skipped when combining queries - // (e.g. UNION ALL) - parser.advance(); - } - SyntaxKind::BeginP => { - // BEGIN, consume until END - is_sub_trx += 1; - parser.advance(); - } - SyntaxKind::EndP => { - is_sub_trx -= 1; - parser.advance(); - } - // opening brackets "(", consume until closing bracket ")" - SyntaxKind::Ascii40 => { - is_sub_stmt += 1; - parser.advance(); - } - SyntaxKind::Ascii41 => { - is_sub_stmt -= 1; - parser.advance(); - } - SyntaxKind::As - | SyntaxKind::Union - | SyntaxKind::Intersect - | SyntaxKind::Except => { - // ignore the next non-whitespace token - ignore_next_non_whitespace = true; - parser.advance(); - } - _ => { - // if another stmt FIRST is encountered, break - // ignore if parsing sub stmt - if ignore_next_non_whitespace == false - && is_sub_stmt == 0 - && is_sub_trx == 0 - && is_at_stmt_start(&mut parser).is_some() - { - break; - } else { - if ignore_next_non_whitespace == true && !parser.at_whitespace() { - ignore_next_non_whitespace = false; - } - parser.advance(); - } - } - } - } - - parser.expect(SyntaxKind::Ascii59); - - parser.close_stmt(); - } - None => { - parser.advance(); - } - } - } - - parser.finish() +/// mostly used for testing +pub fn statements(sql: &str) -> Vec { + StatementSplitter::new(sql).run() } #[cfg(test)] @@ -126,12 +41,13 @@ mod tests { let input = "select 1 from contact;\nselect 1;\nalter table test drop column id;"; let res = split(input); - assert_eq!(res.ranges.len(), 3); - assert_eq!("select 1 from contact;", input[res.ranges[0]].to_string()); - assert_eq!("select 1;", input[res.ranges[1]].to_string()); + + assert_eq!(res.len(), 3); + assert_eq!("select 1 from contact;", input[res[0]].to_string()); + assert_eq!("select 1;", input[res[1]].to_string()); assert_eq!( "alter table test drop column id;", - input[res.ranges[2]].to_string() + input[res[2]].to_string() ); } } diff --git a/crates/pg_statement_splitter/src/parser.rs b/crates/pg_statement_splitter/src/parser.rs index 1b3d0f8b..fba0297e 100644 --- a/crates/pg_statement_splitter/src/parser.rs +++ b/crates/pg_statement_splitter/src/parser.rs @@ -1,18 +1,6 @@ -use std::cmp::min; - use pg_lexer::{SyntaxKind, Token, TokenType, WHITESPACE_TOKENS}; -use text_size::{TextRange, TextSize}; - -use crate::syntax_error::SyntaxError; -/// Main parser that exposes the `cstree` api, and collects errors and statements pub struct Parser { - /// The ranges of the statements - ranges: Vec<(usize, usize)>, - /// The syntax errors accumulated during parsing - errors: Vec, - /// The start of the current statement, if any - current_stmt_start: Option, /// The tokens to parse pub tokens: Vec, /// The current position in the token stream @@ -23,77 +11,21 @@ pub struct Parser { eof_token: Token, } -/// Result of Building -#[derive(Debug)] -pub struct Parse { - /// The ranges of the errors - pub ranges: Vec, - /// The syntax errors accumulated during parsing - pub errors: Vec, -} - impl Parser { pub fn new(tokens: Vec) -> Self { Self { eof_token: Token::eof(usize::from(tokens.last().unwrap().span.end())), - ranges: Vec::new(), - errors: Vec::new(), - current_stmt_start: None, tokens, pos: 0, whitespace_token_buffer: None, } } - pub fn finish(self) -> Parse { - Parse { - ranges: self - .ranges - .iter() - .map(|(start, end)| { - let from = self.tokens.get(*start); - let to = self.tokens.get(end - 1); - // get text range from token range - let text_start = from.unwrap().span.start(); - let text_end = to.unwrap().span.end(); - - TextRange::new( - TextSize::try_from(text_start).unwrap(), - TextSize::try_from(text_end).unwrap(), - ) - }) - .collect(), - errors: self.errors, - } - } - - pub fn start_stmt(&mut self) { - assert!(self.current_stmt_start.is_none()); - self.current_stmt_start = Some(self.pos); - } - - pub fn close_stmt(&mut self) { - assert!(self.current_stmt_start.is_some()); - self.ranges - .push((self.current_stmt_start.take().unwrap(), self.pos)); - } - - /// collects an SyntaxError with an `error` message at `pos` - pub fn error_at_pos(&mut self, error: String, pos: usize) { - self.errors.push(SyntaxError::new_at_offset( - error, - self.tokens - .get(min(self.tokens.len() - 1, pos)) - .unwrap() - .span - .start(), - )); - } - /// applies token and advances pub fn advance(&mut self) { assert!(!self.eof()); - if self.nth(0, false).kind == SyntaxKind::Whitespace { + let token = self.nth(0, false); + if token.kind == SyntaxKind::Whitespace { if self.whitespace_token_buffer.is_none() { self.whitespace_token_buffer = Some(self.pos); } @@ -137,6 +69,47 @@ impl Parser { self.pos == self.tokens.len() } + /// lookbehind method. + /// + /// if `ignore_whitespace` is true, it will skip all whitespace tokens + pub fn lookbehind( + &self, + lookbehind: usize, + ignore_whitespace: bool, + start_before: Option, + ) -> Option<&Token> { + if ignore_whitespace { + let mut idx = 0; + let mut non_whitespace_token_ctr = 0; + loop { + if idx > self.pos { + return None; + } + match self.tokens.get(self.pos - start_before.unwrap_or(0) - idx) { + Some(token) => { + if !WHITESPACE_TOKENS.contains(&token.kind) { + non_whitespace_token_ctr += 1; + if non_whitespace_token_ctr == lookbehind { + return Some(token); + } + } + idx += 1; + } + None => { + if (self.pos - idx - start_before.unwrap_or(0)) > 0 { + idx += 1; + } else { + return None; + } + } + } + } + } else { + self.tokens + .get(self.pos - lookbehind - start_before.unwrap_or(0)) + } + } + /// lookahead method. /// /// if `ignore_whitespace` is true, it will skip all whitespace tokens @@ -172,25 +145,4 @@ impl Parser { pub fn at(&self, kind: SyntaxKind) -> bool { self.nth(0, false).kind == kind } - - pub fn expect(&mut self, kind: SyntaxKind) { - if self.eat(kind) { - return; - } - if self.whitespace_token_buffer.is_some() { - self.error_at_pos( - format!( - "Expected {:#?}, found {:#?}", - kind, - self.tokens[self.whitespace_token_buffer.unwrap()].kind - ), - self.whitespace_token_buffer.unwrap(), - ); - } else { - self.error_at_pos( - format!("Expected {:#?}, found {:#?}", kind, self.nth(0, false)), - self.pos + 1, - ); - } - } } diff --git a/crates/pg_statement_splitter/src/statement_splitter.rs b/crates/pg_statement_splitter/src/statement_splitter.rs new file mode 100644 index 00000000..4c1fd947 --- /dev/null +++ b/crates/pg_statement_splitter/src/statement_splitter.rs @@ -0,0 +1,1651 @@ +use pg_lexer::{SyntaxKind, WHITESPACE_TOKENS}; +use text_size::{TextRange, TextSize}; + +use crate::{ + data::{STATEMENT_BRIDGE_DEFINITIONS, STATEMENT_DEFINITIONS}, + parser::Parser, + tracker_new::StatementTracker as Tracker, +}; + +pub(crate) struct StatementSplitter<'a> { + parser: Parser, + tracked_statements: Vec>, + active_bridges: Vec>, + ranges: Vec, + sub_trx_depth: usize, + sub_stmt_depth: usize, + is_within_atomic_block: bool, + sub_case_stmt_depth: usize, +} + +#[derive(Debug, Clone)] +pub struct StatementPosition { + pub kind: SyntaxKind, + pub range: TextRange, +} + +impl<'a> StatementSplitter<'a> { + pub fn new(sql: &str) -> Self { + Self { + parser: Parser::new(pg_lexer::lex(sql)), + tracked_statements: Vec::new(), + active_bridges: Vec::new(), + ranges: Vec::new(), + + sub_trx_depth: 0, + sub_stmt_depth: 0, + is_within_atomic_block: false, + sub_case_stmt_depth: 0, + } + } + + fn end_nesting(&mut self) { + match self.parser.nth(0, false).kind { + SyntaxKind::Ascii41 => { + // ")" + self.sub_stmt_depth -= 1; + } + SyntaxKind::EndP => { + self.is_within_atomic_block = false; + if self.sub_case_stmt_depth > 0 { + self.sub_case_stmt_depth -= 1; + } + } + _ => {} + }; + } + + fn start_nesting(&mut self) { + match self.parser.nth(0, false).kind { + SyntaxKind::Case => { + self.sub_case_stmt_depth += 1; + } + SyntaxKind::Ascii40 => { + // "(" + self.sub_stmt_depth += 1; + } + SyntaxKind::Atomic => { + if self.parser.lookbehind(2, true, None).map(|t| t.kind) == Some(SyntaxKind::BeginP) + { + self.is_within_atomic_block = true; + } + } + _ => {} + }; + } + + /// advance all tracked statements and return the earliest started_at value of the removed + /// statements + fn advance_tracker(&mut self) -> Option { + let mut removed_items = Vec::new(); + + self.tracked_statements.retain_mut(|stmt| { + println!( + "started at {:?}, parser pos {:?}", + stmt.started_at, self.parser.pos + ); + // dont advace if we started at the current position + if stmt.started_at == self.parser.pos { + return true; + } + + let keep = stmt.advance_with(&self.parser.nth(0, false).kind); + if !keep { + removed_items.push(stmt.started_at); + } + keep + }); + + println!("removed items: {:?}", removed_items); + + removed_items.iter().min().map(|i| *i) + } + + fn token_range(&self, token_pos: usize) -> TextRange { + self.parser.tokens.get(token_pos).unwrap().span + } + + fn add_incomplete_statement(&mut self, started_at: Option) { + if self.tracked_statements.len() > 0 || started_at.is_none() { + return; + } + + self.ranges.push(StatementPosition { + kind: SyntaxKind::Any, + range: TextRange::new( + self.token_range(started_at.unwrap()).start(), + self.parser.lookbehind(2, true, None).unwrap().span.end(), + ), + }); + } + + fn start_new_statements(&mut self) { + if self.sub_trx_depth != 0 + || self.sub_stmt_depth != 0 + || self.is_within_atomic_block + || self.sub_case_stmt_depth != 0 + { + return; + } + + // it onyl makes sense to start tracking new statements if at least one of the + // currently tracked statements could be complete. or if none are tracked yet. + // this is important for statements such as `explain select 1;` where `select 1` + // would mark a completed statement that would move `explain` into completed, + // even though the latter is part of the former. + if self.tracked_statements.len() != 0 + && self + .tracked_statements + .iter() + .all(|s| !s.could_be_complete()) + { + println!("reutning because none could be completed"); + return; + } else { + println!( + "{:?} {:?} could be complete", + self.tracked_statements + .iter() + .map(|x| x) + .collect::>(), + self.tracked_statements + .iter() + .map(|x| x.could_be_complete()) + .collect::>() + ); + } + + let new_stmts = STATEMENT_DEFINITIONS.get(&self.parser.nth(0, false).kind); + println!("potential new stmts {:?}", new_stmts); + + if let Some(new_stmts) = new_stmts { + let to_add = &mut new_stmts + .iter() + .filter_map(|stmt| { + if self.active_bridges.iter().any(|b| b.def.stmt == stmt.stmt) { + println!("not adding because of active bridges"); + None + } else if self.tracked_statements.iter_mut().any(|s| { + !s.can_start_stmt_after( + &stmt.stmt, + self.parser.pos, + stmt.ignore_if_prohibited, + ) + }) { + println!("not adding because cant start stmt after"); + None + } else { + println!("tracking new statement: {:?}", stmt.stmt); + Some(Tracker::new_at(stmt, self.parser.pos)) + } + }) + .collect(); + self.tracked_statements.append(to_add); + } + } + + fn advance_bridges(&mut self) { + self.active_bridges + .retain_mut(|stmt| stmt.advance_with(&self.parser.nth(0, false).kind)); + } + + fn start_new_bridges(&mut self) { + if let Some(bridges) = STATEMENT_BRIDGE_DEFINITIONS.get(&self.parser.nth(0, false).kind) { + self.active_bridges.append( + &mut bridges + .iter() + .map(|stmt| Tracker::new_at(stmt, self.parser.pos)) + .collect(), + ); + } + } + + fn close_stmt_with_semicolon(&mut self) { + let at_token = self.parser.nth(0, false); + assert_eq!(at_token.kind, SyntaxKind::Ascii59); + + // i didnt believe it myself at first, but there are statements where a ";" is valid + // within a sub statement, e.g.: + // "create rule qqq as on insert to copydml_test do instead (delete from copydml_test; delete from copydml_test);" + // so we need to check for sub statement depth here + if self.sub_stmt_depth != 0 || self.is_within_atomic_block { + return; + } + + println!( + "closing statement with semicolon {:?}", + self.tracked_statements + ); + + // get earliest statement + if let Some(earliest_complete_stmt_started_at) = self + .tracked_statements + .iter() + .filter(|s| s.could_be_complete()) + .min_by_key(|stmt| stmt.started_at) + .map(|stmt| stmt.started_at) + { + println!( + "earliest complete stmt started at: {}", + earliest_complete_stmt_started_at + ); + let earliest_complete_stmt = self + .tracked_statements + .iter() + .filter(|s| { + s.started_at == earliest_complete_stmt_started_at && s.could_be_complete() + }) + .max_by_key(|stmt| stmt.max_pos()) + .unwrap(); + + self.assert_single_complete_statement_at_position(earliest_complete_stmt); + + let end_pos = at_token.span.end(); + let start_pos = TextSize::try_from( + self.parser + .tokens + .get(earliest_complete_stmt.started_at) + .unwrap() + .span + .start(), + ) + .unwrap(); + self.ranges.push(StatementPosition { + kind: earliest_complete_stmt.def.stmt, + range: TextRange::new(start_pos, end_pos), + }); + } + + self.tracked_statements.clear(); + self.active_bridges.clear(); + } + + fn find_earliest_statement_start_pos(&self) -> Option { + self.tracked_statements + .iter() + .min_by_key(|stmt| stmt.started_at) + .map(|stmt| stmt.started_at) + } + + fn find_earliest_complete_statement_start_pos(&self) -> Option { + self.tracked_statements + .iter() + .filter(|s| s.could_be_complete()) + .min_by_key(|stmt| stmt.started_at) + .map(|stmt| stmt.started_at) + } + + fn find_latest_complete_statement_start_pos(&self) -> Option { + self.tracked_statements + .iter() + .filter(|s| s.could_be_complete()) + .max_by_key(|stmt| stmt.started_at) + .map(|stmt| stmt.started_at) + } + + fn find_latest_complete_statement_before_start_pos(&self, before: usize) -> Option { + self.tracked_statements + .iter() + .filter(|s| s.could_be_complete() && s.started_at < before) + .max_by_key(|stmt| stmt.started_at) + .map(|stmt| stmt.started_at) + } + + fn find_highest_positioned_complete_statement(&self, started_at: usize) -> &Tracker<'a> { + self.tracked_statements + .iter() + .filter(|s| s.started_at == started_at && s.could_be_complete()) + .max_by_key(|stmt| stmt.max_pos()) + .unwrap() + } + + fn assert_single_complete_statement_at_position(&self, tracker: &Tracker<'a>) { + let complete_stmts = self + .tracked_statements + .iter() + .filter(|s| { + s.started_at == tracker.started_at + && s.could_be_complete() + && s.current_positions() + .iter() + .any(|i| tracker.current_positions().contains(i)) + }) + .collect::>(); + assert_eq!( + 1, + complete_stmts.len(), + "multiple complete statements at the same position: {:?}", + complete_stmts + .iter() + .map(|s| s.def.stmt) + .collect::>() + ); + } + + pub fn run(mut self) -> Vec { + println!("parser pos {:?}", self.parser.pos); + while !self.parser.eof() { + if WHITESPACE_TOKENS.contains(&self.parser.nth(0, false).kind) { + self.parser.advance(); + continue; + } + + println!( + "############ current token: {:?}", + self.parser.nth(0, false).kind + ); + + println!( + "current stmts: {:?}", + self.tracked_statements + .iter() + .map(|s| s.def.stmt) + .collect::>() + ); + + // todo start new stmts first, then advance all others + + self.start_new_statements(); + + self.advance_bridges(); + + self.start_new_bridges(); + + let removed_items_min_started_at = self.advance_tracker(); + + self.add_incomplete_statement(removed_items_min_started_at); + + self.start_nesting(); + + if self.parser.nth(0, false).kind == SyntaxKind::Ascii59 { + self.close_stmt_with_semicolon(); + } + + self.end_nesting(); + + println!("stmts after: {:?}", self.tracked_statements); + + // # This is where the actual parsing happens + + // 1. Find the latest complete statement + if let Some(latest_completed_stmt_started_at) = + self.find_latest_complete_statement_start_pos() + { + println!( + "latest_completed_stmt_started_at: {:?}", + latest_completed_stmt_started_at + ); + + // Step 2: Find the latest complete statement before the latest completed statement + if let Some(latest_complete_before_started_at) = self + .find_latest_complete_statement_before_start_pos( + latest_completed_stmt_started_at, + ) + { + let latest_complete_before = self.find_highest_positioned_complete_statement( + latest_complete_before_started_at, + ); + + println!("latest_complete_before: {:?}", latest_complete_before); + + self.assert_single_complete_statement_at_position(&latest_complete_before); + + let stmt_kind = latest_complete_before.def.stmt; + let latest_complete_before_started_at = latest_complete_before.started_at; + + // Step 3: save range for the statement + let start_pos = self.token_range(latest_complete_before_started_at).start(); + + // the end position is the end() of the last non-whitespace token before the start + // of the latest complete statement + let latest_non_whitespace_token = self.parser.lookbehind( + 2, + true, + Some(self.parser.pos - latest_completed_stmt_started_at), + ); + let end_pos = latest_non_whitespace_token.unwrap().span.end(); + + println!("!!!! adding {:?}", stmt_kind); + + self.ranges.push(StatementPosition { + kind: stmt_kind, + range: TextRange::new(start_pos, end_pos), + }); + + // Step 4: remove all statements that started before or at the position + self.tracked_statements + .retain(|s| s.started_at > latest_complete_before_started_at); + } + } + + self.parser.advance(); + } + + println!("tracked statements: {:?}", self.tracked_statements); + + // we reached eof; add any remaining statements + + // get the earliest statement that is complete + if let Some(earliest_complete_stmt_started_at) = + self.find_earliest_complete_statement_start_pos() + { + let earliest_complete_stmt = + self.find_highest_positioned_complete_statement(earliest_complete_stmt_started_at); + + println!("earliest complete stmt: {:?}", earliest_complete_stmt); + + self.assert_single_complete_statement_at_position(earliest_complete_stmt); + + let start_pos = self.token_range(earliest_complete_stmt_started_at).start(); + + let end_token = self.parser.lookbehind(1, true, None).unwrap(); + let end_pos = end_token.span.end(); + + println!("!!!! adding {:?}", earliest_complete_stmt.def.stmt); + + self.ranges.push(StatementPosition { + kind: earliest_complete_stmt.def.stmt, + range: TextRange::new(start_pos, end_pos), + }); + + self.tracked_statements + .retain(|s| s.started_at > earliest_complete_stmt_started_at); + } + + if let Some(earliest_stmt_started_at) = self.find_earliest_statement_start_pos() { + let start_pos = self.token_range(earliest_stmt_started_at).start(); + + // end position is last non-whitespace token before or at the current position + let end_pos = self.parser.lookbehind(1, true, None).unwrap().span.end(); + + println!("!!!! adding any"); + + self.ranges.push(StatementPosition { + kind: SyntaxKind::Any, + range: TextRange::new(start_pos, end_pos), + }); + } + + self.ranges + } +} + +#[cfg(test)] +mod tests { + use pg_lexer::{lex, SyntaxKind}; + + use crate::statement_splitter::StatementSplitter; + + #[test] + fn test_simple_select() { + let input = " +select id, name, test1231234123, unknown from co; + +select 14433313331333 + +alter table test drop column id; + +select lower('test'); +"; + + let result = StatementSplitter::new(input).run(); + + for r in &result { + println!("{:?} {:?}", r.kind, r.range); + println!("'{}'", input[r.range].to_string()); + } + + assert_eq!(result.len(), 4); + assert_eq!( + "select id, name, test1231234123, unknown from co;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + assert_eq!("select 14433313331333", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + assert_eq!(SyntaxKind::AlterTableStmt, result[2].kind); + assert_eq!( + "alter table test drop column id;", + input[result[2].range].to_string() + ); + assert_eq!(SyntaxKind::SelectStmt, result[3].kind); + assert_eq!("select lower('test');", input[result[3].range].to_string()); + } + + #[test] + fn test_create_or_replace() { + let input = "CREATE OR REPLACE TRIGGER check_update + BEFORE UPDATE OF balance ON accounts + FOR EACH ROW + EXECUTE FUNCTION check_account_update();\nexecute test;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!( + "CREATE OR REPLACE TRIGGER check_update\n BEFORE UPDATE OF balance ON accounts\n FOR EACH ROW\n EXECUTE FUNCTION check_account_update();", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::CreateTrigStmt, result[0].kind); + assert_eq!("execute test;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::ExecuteStmt, result[1].kind); + } + + #[test] + fn test_prohibited_follow_up() { + let input = + "insert into public.test (id) select 1 from other.test where id = 2;\nselect 4;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!( + "insert into public.test (id) select 1 from other.test where id = 2;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::InsertStmt, result[0].kind); + assert_eq!("select 4;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_schema() { + let input = "delete from public.table where id = 2;\nselect 4;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!( + "delete from public.table where id = 2;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::DeleteStmt, result[0].kind); + assert_eq!("select 4;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_sub_statement() { + let input = "select 1 from (select 2 from contact) c;\nselect 4;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!( + "select 1 from (select 2 from contact) c;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + assert_eq!("select 4;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_semicolon_precedence() { + let input = "select 1 from ;\nselect 4;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!("select 1 from ;", input[result[0].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + assert_eq!("select 4;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_union_with_semicolon() { + let input = "select 1 from contact union;\nselect 4;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!( + "select 1 from contact union;", + input[result[0].range].to_string() + ); + assert_eq!("select 4;", input[result[1].range].to_string()); + } + + #[test] + fn test_union() { + let input = "select 1 from contact union select 1;\nselect 4;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!( + "select 1 from contact union select 1;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + assert_eq!("select 4;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_splitter() { + let input = "select 1 from contact;\nselect 1;\nselect 4;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!("select 1 from contact;", input[result[0].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + assert_eq!("select 1;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + assert_eq!("select 4;", input[result[2].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[2].kind); + } + + #[test] + fn test_no_semicolons() { + let input = "select 1 from contact\nselect 1\nselect 4"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!("select 1 from contact", input[result[0].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + assert_eq!("select 1", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + assert_eq!("select 4", input[result[2].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[2].kind); + } + + #[test] + fn test_explain() { + let input = "explain select 1 from contact\nselect 1\nselect 4"; + + let result = StatementSplitter::new(input).run(); + + for range in &result { + println!("Result: '{}'", input[range.range].to_string()); + } + + assert_eq!(result.len(), 3); + assert_eq!( + "explain select 1 from contact", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::ExplainStmt, result[0].kind); + assert_eq!("select 1", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + assert_eq!("select 4", input[result[2].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[2].kind); + } + + #[test] + fn test_explain_analyze() { + let input = "explain analyze select 1 from contact;\nselect 1;\nselect 4;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!( + "explain analyze select 1 from contact;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::ExplainStmt, result[0].kind); + assert_eq!("select 1;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + assert_eq!("select 4;", input[result[2].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[2].kind); + } + + #[test] + fn test_cast() { + let input = "SELECT CAST(42 AS float8);\nselect 1"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!( + "SELECT CAST(42 AS float8);", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + assert_eq!("select 1", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_create_conversion() { + let input = "CREATE CONVERSION myconv FOR 'UTF8' TO 'LATIN1' FROM myfunc;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!( + "CREATE CONVERSION myconv FOR 'UTF8' TO 'LATIN1' FROM myfunc;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::CreateConversionStmt, result[0].kind); + } + + #[test] + fn test_with_comment() { + let input = "--\n-- ADVISORY LOCKS\n--\n\nBEGIN;\n\nSELECT\n\tpg_advisory_xact_lock(1), pg_advisory_xact_lock_shared(2),\n\tpg_advisory_xact_lock(1, 1), pg_advisory_xact_lock_shared(2, 2);\n\nSELECT locktype, classid, objid, objsubid, mode, granted\n\tFROM pg_locks WHERE locktype = 'advisory'\n\tORDER BY classid, objid, objsubid;\n\n\n-- pg_advisory_unlock_all() shouldn't release xact locks\nSELECT pg_advisory_unlock_all();\n\nSELECT count(*) FROM pg_locks WHERE locktype = 'advisory';\n\n\n-- can't unlock xact locks\nSELECT\n\tpg_advisory_unlock(1), pg_advisory_unlock_shared(2),\n\tpg_advisory_unlock(1, 1), pg_advisory_unlock_shared(2, 2);\n\n\n-- automatically release xact locks at commit\nCOMMIT;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 7); + } + + #[test] + fn test_composite_type() { + let input = "create type avg_state as (total bigint, count bigint);\ncreate type test;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!( + "create type avg_state as (total bigint, count bigint);", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::CompositeTypeStmt, result[0].kind); + assert_eq!("create type test;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::DefineStmt, result[1].kind); + } + + #[test] + fn test_set() { + let input = "CREATE FUNCTION test_opclass_options_func(internal) + RETURNS void + AS :'regresslib', 'test_opclass_options_func' + LANGUAGE C; + +SET client_min_messages TO 'warning'; + +DROP ROLE IF EXISTS regress_alter_generic_user1;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!( + "CREATE FUNCTION test_opclass_options_func(internal)\n RETURNS void\n AS :'regresslib', 'test_opclass_options_func'\n LANGUAGE C;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::CreateFunctionStmt, result[0].kind); + assert_eq!( + "SET client_min_messages TO 'warning';", + input[result[1].range].to_string() + ); + assert_eq!(SyntaxKind::VariableSetStmt, result[1].kind); + assert_eq!( + "DROP ROLE IF EXISTS regress_alter_generic_user1;", + input[result[2].range].to_string() + ); + assert_eq!(SyntaxKind::DropRoleStmt, result[2].kind); + } + + #[test] + fn test_incomplete_statement() { + let input = "create\nselect 1;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!("create", input[result[0].range].to_string()); + assert_eq!(SyntaxKind::Any, result[0].kind); + assert_eq!("select 1;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_incomplete_statement_at_end() { + let input = "select 1;\ncreate"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!("select 1;", input[result[0].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + assert_eq!("create", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::Any, result[1].kind); + } + + #[test] + fn test_only_incomplete_statement_semicolon() { + let input = "create;\nselect 1;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!("create", input[result[0].range].to_string()); + assert_eq!(SyntaxKind::Any, result[0].kind); + assert_eq!("select 1;", input[result[1].range].to_string()); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_set_with_schema() { + let input = "SET custom.my_guc = 42;"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!( + "SET custom.my_guc = 42;", + input[result[0].range].to_string() + ); + assert_eq!(SyntaxKind::VariableSetStmt, result[0].kind); + } + + #[test] + fn test_only_incomplete_statement() { + let input = " create "; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!("create", input[result[0].range].to_string()); + assert_eq!(SyntaxKind::Any, result[0].kind); + } + + #[test] + fn test_reset() { + let input = " +DROP ROLE IF EXISTS regress_alter_generic_user3; + +RESET client_min_messages; + +CREATE USER regress_alter_generic_user3; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!(SyntaxKind::DropRoleStmt, result[0].kind); + assert_eq!(SyntaxKind::VariableSetStmt, result[1].kind); + assert_eq!(SyntaxKind::CreateRoleStmt, result[2].kind); + } + + #[test] + fn test_grant_and_set_session_auth() { + let input = " +CREATE SCHEMA alt_nsp2; + +GRANT ALL ON SCHEMA alt_nsp1, alt_nsp2 TO public; + +SET search_path = alt_nsp1, public; + +SET SESSION AUTHORIZATION regress_alter_generic_user1; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 4); + assert_eq!(SyntaxKind::CreateSchemaStmt, result[0].kind); + assert_eq!(SyntaxKind::GrantStmt, result[1].kind); + assert_eq!(SyntaxKind::VariableSetStmt, result[2].kind); + assert_eq!(SyntaxKind::VariableSetStmt, result[3].kind); + } + + #[test] + fn test_create_fn_and_agg() { + let input = " +CREATE FUNCTION alt_func1(int) RETURNS int LANGUAGE sql + AS 'SELECT $1 + 1'; +CREATE FUNCTION alt_func2(int) RETURNS int LANGUAGE sql + AS 'SELECT $1 - 1'; +CREATE AGGREGATE alt_agg1 ( + sfunc1 = int4pl, basetype = int4, stype1 = int4, initcond = 0 +); +CREATE AGGREGATE alt_agg2 ( + sfunc1 = int4mi, basetype = int4, stype1 = int4, initcond = 0 +); +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 4); + assert_eq!(SyntaxKind::CreateFunctionStmt, result[0].kind); + assert_eq!(SyntaxKind::CreateFunctionStmt, result[1].kind); + assert_eq!(SyntaxKind::DefineStmt, result[2].kind); + assert_eq!(SyntaxKind::DefineStmt, result[3].kind); + } + + #[test] + fn test_create_alter_agg() { + let input = " +CREATE AGGREGATE alt_agg2 ( + sfunc1 = int4mi, basetype = int4, stype1 = int4, initcond = 0 +); +ALTER AGGREGATE alt_func1(int) RENAME TO alt_func3; +ALTER AGGREGATE alt_func1(int) OWNER TO regress_alter_generic_user3; +ALTER AGGREGATE alt_func1(int) SET SCHEMA alt_nsp2; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 4); + assert_eq!(SyntaxKind::DefineStmt, result[0].kind); + assert_eq!(SyntaxKind::RenameStmt, result[1].kind); + assert_eq!(SyntaxKind::AlterOwnerStmt, result[2].kind); + assert_eq!(SyntaxKind::AlterObjectSchemaStmt, result[3].kind); + } + + #[test] + fn test_reset_session() { + let input = " +ALTER AGGREGATE alt_agg2(int) SET SCHEMA alt_nsp2; + +RESET SESSION AUTHORIZATION; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!(SyntaxKind::AlterObjectSchemaStmt, result[0].kind); + assert_eq!(SyntaxKind::VariableSetStmt, result[1].kind); + } + + #[test] + fn test_rename_fdw() { + let input = " +CREATE SERVER alt_fserv2 FOREIGN DATA WRAPPER alt_fdw2; + +ALTER FOREIGN DATA WRAPPER alt_fdw1 RENAME TO alt_fdw2; +ALTER FOREIGN DATA WRAPPER alt_fdw1 RENAME TO alt_fdw3; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!(SyntaxKind::CreateForeignServerStmt, result[0].kind); + assert_eq!(SyntaxKind::RenameStmt, result[1].kind); + assert_eq!(SyntaxKind::RenameStmt, result[2].kind); + } + + #[test] + fn test_ops() { + let input = " +ALTER OPERATOR FAMILY alt_opf4 USING btree DROP + -- int4 vs int2 + OPERATOR 1 (int4, int2) , + OPERATOR 2 (int4, int2) , + OPERATOR 3 (int4, int2) , + OPERATOR 4 (int4, int2) , + OPERATOR 5 (int4, int2) , + FUNCTION 1 (int4, int2) ; +DROP OPERATOR FAMILY alt_opf4 USING btree; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!(SyntaxKind::AlterOpFamilyStmt, result[0].kind); + assert_eq!(SyntaxKind::DropStmt, result[1].kind); + } + + #[test] + fn test_temp_table() { + let input = " +CREATE TEMP TABLE foo (f1 int, f2 int, f3 int, f4 int); + +CREATE INDEX fooindex ON foo (f1 desc, f2 asc, f3 nulls first, f4 nulls last); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!(SyntaxKind::CreateStmt, result[0].kind); + assert_eq!(SyntaxKind::IndexStmt, result[1].kind); + } + + #[test] + fn test_create_table_as() { + let input = " +CREATE TEMP TABLE point_tbl AS SELECT * FROM public.point_tbl; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateTableAsStmt, result[0].kind); + } + + #[test] + fn test_analyze() { + let input = " +ANALYZE array_op_test; +INSERT INTO arrtest (a[1:5], b[1:1][1:2][1:2], c, d, f, g) + VALUES ('{1,2,3,4,5}', '{{{0,0},{1,2}}}', '{}', '{}', '{}', '{}'); +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!(SyntaxKind::VacuumStmt, result[0].kind); + assert_eq!(SyntaxKind::InsertStmt, result[1].kind); + } + + #[test] + fn test_drop_operator() { + let input = " +DROP OPERATOR === (boolean, boolean); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::DropStmt, result[0].kind); + } + + #[test] + fn test_language() { + let input = " +CREATE LANGUAGE alt_lang1 HANDLER plpgsql_call_handler; +CREATE LANGUAGE alt_lang2 HANDLER plpgsql_call_handler; + +ALTER LANGUAGE alt_lang1 OWNER TO regress_alter_generic_user1; +ALTER LANGUAGE alt_lang2 OWNER TO regress_alter_generic_user2; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 4); + assert_eq!(SyntaxKind::CreatePlangStmt, result[0].kind); + assert_eq!(SyntaxKind::CreatePlangStmt, result[1].kind); + assert_eq!(SyntaxKind::AlterOwnerStmt, result[2].kind); + assert_eq!(SyntaxKind::AlterOwnerStmt, result[3].kind); + } + + #[test] + fn test_alter_op_family() { + let input = " +ALTER OPERATOR FAMILY alt_opf1 USING hash OWNER TO regress_alter_generic_user1; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::AlterOwnerStmt, result[0].kind); + } + + #[test] + fn test_drop_op_family() { + let input = " +DROP OPERATOR FAMILY alt_opf4 USING btree; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::DropStmt, result[0].kind); + } + + #[test] + fn test_set_role() { + let input = " +SET ROLE regress_alter_generic_user5; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::VariableSetStmt, result[0].kind); + } + + #[test] + fn test_revoke() { + let input = " +CREATE ROLE regress_alter_generic_user6; +CREATE SCHEMA alt_nsp6; +REVOKE ALL ON SCHEMA alt_nsp6 FROM regress_alter_generic_user6; +CREATE OPERATOR FAMILY alt_nsp6.alt_opf6 USING btree; +SET ROLE regress_alter_generic_user6; +ALTER OPERATOR FAMILY alt_nsp6.alt_opf6 USING btree ADD OPERATOR 1 < (int4, int2); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 6); + assert_eq!(SyntaxKind::CreateRoleStmt, result[0].kind); + assert_eq!(SyntaxKind::CreateSchemaStmt, result[1].kind); + assert_eq!(SyntaxKind::GrantStmt, result[2].kind); + assert_eq!(SyntaxKind::CreateOpFamilyStmt, result[3].kind); + assert_eq!(SyntaxKind::VariableSetStmt, result[4].kind); + assert_eq!(SyntaxKind::AlterOpFamilyStmt, result[5].kind); + } + + #[test] + fn test_alter_op_family_2() { + let input = " +CREATE OPERATOR FAMILY alt_opf4 USING btree; +ALTER OPERATOR FAMILY test.alt_opf4 USING btree ADD + -- int4 vs int2 + OPERATOR 1 < (int4, int2) , + OPERATOR 2 <= (int4, int2) , + OPERATOR 3 = (int4, int2) , + OPERATOR 4 >= (int4, int2) , + OPERATOR 5 > (int4, int2) , + FUNCTION 1 btint42cmp(int4, int2); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!(SyntaxKind::CreateOpFamilyStmt, result[0].kind); + assert_eq!(SyntaxKind::AlterOpFamilyStmt, result[1].kind); + } + + #[test] + fn test_create_stat() { + let input = " +CREATE STATISTICS alt_stat1 ON a, b FROM alt_regress_1; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateStatsStmt, result[0].kind); + } + + #[test] + fn test_create_text_search_dictionary() { + let input = " +CREATE TEXT SEARCH DICTIONARY alt_ts_dict1 (template=simple); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::DefineStmt, result[0].kind); + } + + #[test] + fn test_create_text_search_configuration() { + let input = " +CREATE TEXT SEARCH CONFIGURATION alt_ts_conf1 (copy=english); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::DefineStmt, result[0].kind); + } + + #[test] + fn test_alter_operator() { + let input = " +ALTER OPERATOR === (boolean, boolean) SET (RESTRICT = NONE); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::AlterOperatorStmt, result[0].kind); + } + + #[test] + fn test_drop_fdw() { + let input = " +DROP FOREIGN DATA WRAPPER alt_fdw2 CASCADE; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::DropStmt, result[0].kind); + } + + #[test] + fn test_insert_select() { + let input = " +insert into src select string_agg(random()::text,'') from generate_series(1,10000); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::InsertStmt, result[0].kind); + } + + #[test] + fn test_on_conflict() { + let input = " +insert into arr_pk_tbl values (1, '{3,4,5}') on conflict (pk)\n do update set f1[1] = excluded.f1[1], f1[3] = excluded.f1[3]\n returning pk, f1; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::InsertStmt, result[0].kind); + } + + #[test] + fn test_alter_index() { + let input = " +ALTER INDEX btree_tall_idx2 ALTER COLUMN id SET (n_distinct=100); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::AlterTableStmt, result[0].kind); + } + + #[test] + fn test_update_set() { + let input = " +UPDATE CASE_TBL\n SET i = CASE WHEN i >= 3 THEN (- i)\n ELSE (2 * i) END; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::UpdateStmt, result[0].kind); + } + + #[test] + fn test_savepoint() { + let input = " +SAVEPOINT s1; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::TransactionStmt, result[0].kind); + } + + #[test] + fn test_declare_cursor() { + let input = " +DECLARE c CURSOR FOR SELECT ctid,cmin,* FROM combocidtest; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::DeclareCursorStmt, result[0].kind); + } + + #[test] + fn test_create_empty_table() { + let input = " +CREATE TABLE IF NOT EXISTS testcase( +); +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateStmt, result[0].kind); + } + + #[test] + fn test_rollback_to() { + let input = " +ROLLBACK TO SAVEPOINT subxact; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::TransactionStmt, result[0].kind); + } + + #[test] + fn test_rule_delete_from() { + let input = " +create rule qqq as on insert to copydml_test do also delete from copydml_test; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::RuleStmt, result[0].kind); + } + + #[test] + fn test_create_cast() { + let input = " +CREATE CAST (text AS casttesttype) WITHOUT FUNCTION; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateCastStmt, result[0].kind); + } + + #[test] + fn test_begin_atomic() { + let input = " +CREATE PROCEDURE ptest1s(x text)\nLANGUAGE SQL\nBEGIN ATOMIC\n INSERT INTO cp_test VALUES (1, x);\nEND;\nselect 1; +"; + + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!(SyntaxKind::CreateFunctionStmt, result[0].kind); + assert_eq!(SyntaxKind::SelectStmt, result[1].kind); + } + + #[test] + fn test_drop_procedure() { + let input = " +CREATE PROCEDURE ptest4b(INOUT b int, INOUT a int) +LANGUAGE SQL +AS $$ +CALL ptest4a(a, b) +$$; + +DROP PROCEDURE ptest4a; + +CREATE OR REPLACE PROCEDURE ptest5(a int, b text, c int default 100) +LANGUAGE SQL +AS $$ +INSERT INTO cp_test VALUES(a, b) +$$; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!(SyntaxKind::CreateFunctionStmt, result[0].kind); + assert_eq!(SyntaxKind::DropStmt, result[1].kind); + assert_eq!(SyntaxKind::CreateFunctionStmt, result[2].kind); + } + + #[test] + fn test_prepare_as() { + let input = " +DROP VIEW fdv4; + +PREPARE foo AS + SELECT id, keywords, title, body, created + FROM articles + GROUP BY id; + +EXECUTE foo; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!(SyntaxKind::DropStmt, result[0].kind); + assert_eq!(SyntaxKind::PrepareStmt, result[1].kind); + assert_eq!(SyntaxKind::ExecuteStmt, result[2].kind); + } + + #[test] + fn create_function_set() { + let input = " +create function report_guc(text) returns text as\n$$ select current_setting($1) $$ language sql\nset work_mem = '1MB'; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateFunctionStmt, result[0].kind); + } + + #[test] + fn test_drop_function() { + let input = " +DROP FUNCTION set(name); +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::DropStmt, result[0].kind); + } + + #[test] + fn test_call_version() { + let input = " +CALL version(); +CALL sum(1); +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 2); + assert_eq!(SyntaxKind::CallStmt, result[0].kind); + assert_eq!(SyntaxKind::CallStmt, result[1].kind); + } + + #[test] + fn test_drop_lang() { + let input = " +DROP OPERATOR @#@ (int8, int8); +DROP LANGUAGE test_language_exists; +DROP LANGUAGE IF EXISTS test_language_exists; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 3); + assert_eq!(SyntaxKind::DropStmt, result[0].kind); + assert_eq!(SyntaxKind::DropStmt, result[1].kind); + assert_eq!(SyntaxKind::DropStmt, result[2].kind); + } + + #[test] + fn alter_mat_view() { + let input = " +ALTER MATERIALIZED VIEW mvtest_tvm SET SCHEMA mvtest_mvschema; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::AlterObjectSchemaStmt, result[0].kind); + } + + #[test] + fn move_backward() { + let input = " +MOVE BACKWARD ALL IN c1; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::FetchStmt, result[0].kind); + } + + #[test] + fn create_tbl_as_2() { + let input = " +create table simple as + select generate_series(1, 20000) AS id, 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateTableAsStmt, result[0].kind); + } + + #[test] + fn create_tbl_as() { + let input = " +CREATE TABLE tab_settings_flags AS SELECT name, category, + 'EXPLAIN' = ANY(flags) AS explain, + 'NO_RESET_ALL' = ANY(flags) AS no_reset_all, + 'NO_SHOW_ALL' = ANY(flags) AS no_show_all, + 'NOT_IN_SAMPLE' = ANY(flags) AS not_in_sample, + 'RUNTIME_COMPUTED' = ANY(flags) AS runtime_computed + FROM pg_show_all_settings() AS psas, + pg_settings_get_flags(psas.name) AS flags; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateTableAsStmt, result[0].kind); + } + + #[test] + fn alter_table_owner() { + let input = " +ALTER TABLE seclabel_tbl1 OWNER TO regress_seclabel_user1; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::AlterTableStmt, result[0].kind); + } + + #[test] + fn alter_table_rename() { + let input = " +ALTER TABLE foo_seq RENAME TO foo_seq_new; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::RenameStmt, result[0].kind); + } + + #[test] + fn alter_seq() { + let input = " +ALTER SEQUENCE sequence_test_unlogged SET LOGGED; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::AlterTableStmt, result[0].kind); + } + + #[test] + fn create_op_class() { + let input = " +create operator class part_test_text_ops for type text using hash as + operator 1 =, + function 2 part_hashtext_length(text, int8); +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateOpClassStmt, result[0].kind); + } + + #[test] + fn case_end() { + let input = " +SELECT q1, case when q1 > 0 then generate_series(1,3) else 0 end FROM int8_tbl; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + } + + #[test] + fn just_table() { + // wtf? + let input = " +TABLE t1; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + } + + #[test] + fn explain_create_table() { + let input = " +explain (costs off) create table parallel_write as select length(stringu1) from tenk1 group by length(stringu1); +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::ExplainStmt, result[0].kind); + } + + #[test] + fn create_table_as_execute() { + let input = " +create table parallel_write as execute prep_stmt; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::CreateTableAsStmt, result[0].kind); + } + + #[test] + fn cte_select() { + let input = " +WITH t1 AS ( + SELECT * FROM t1 +), t2 AS ( + SELECT * FROM t2 +) +SELECT 's'; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + } + + #[test] + fn cte_select_without_repeated() { + let input = " +WITH t1 AS ( + SELECT * FROM t1 +) +SELECT 's'; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + } + + #[test] + fn union_intersect() { + let input = " +(select 1) union (select 2) except (select 3) intersect (select 4); +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::SelectStmt, result[0].kind); + } + + #[test] + fn alter_table_cluster_on() { + let input = " +ALTER TABLE clstr_tst CLUSTER ON clstr_tst_b_c; +"; + let result = StatementSplitter::new(input).run(); + + assert_eq!(result.len(), 1); + assert_eq!(SyntaxKind::AlterTableStmt, result[0].kind); + } + + #[allow(clippy::must_use)] + fn debug(input: &str) { + for s in input.split(';').filter_map(|s| { + if s.trim().is_empty() { + None + } else { + Some(s.trim()) + } + }) { + println!("Statement: '{:?}'", s); + + let res = pg_query::parse(s) + .map(|parsed| { + parsed + .protobuf + .nodes() + .iter() + .find(|n| n.1 == 1) + .unwrap() + .0 + .to_enum() + }) + .unwrap(); + println!("Result: {:?}", res); + } + + let result = StatementSplitter::new(input).run(); + + for r in &result { + println!("{:?} {:?}", r.kind, input[r.range].to_string()); + } + + for t in lex(input) { + println!("{:?}", t.kind); + } + + assert!(false); + } +} diff --git a/crates/pg_statement_splitter/src/tracker.rs b/crates/pg_statement_splitter/src/tracker.rs new file mode 100644 index 00000000..57e312d6 --- /dev/null +++ b/crates/pg_statement_splitter/src/tracker.rs @@ -0,0 +1,366 @@ +use pg_lexer::{SyntaxKind, WHITESPACE_TOKENS}; + +use crate::data::{StatementDefinition, SyntaxDefinition}; + +#[derive(Debug)] +pub struct TokenTracker<'a> { + pub tokens: &'a Vec, + + /// position in the definition, and for each position we track the current token for that + /// position. required for groups. + pub positions: Vec, + + /// only for RepeatedGroup + child: Option>>, +} + +impl<'a> TokenTracker<'a> { + pub fn new(tokens: &'a Vec) -> Self { + Self { + tokens, + positions: vec![Position::new(1)], + child: None, + } + } + + pub fn advance_with(&mut self, kind: &SyntaxKind) -> bool { + let mut new_positions = Vec::with_capacity(self.positions.len()); + + for mut pos in self.positions.drain(..) { + match self.tokens.get(pos.idx) { + Some(SyntaxDefinition::OptionalRepeatedGroup(definitions)) => { + // if child does not exist, create it + if self.child.is_none() { + // check if we can spawn a new position for the next token + new_positions.extend(TokenTracker::next_possible_positions_from_with( + &self.tokens, + &pos, + kind, + )); + self.child = Some(Box::new(TokenTracker::new(definitions))); + new_positions.push(pos); + } else if self.child.as_mut().unwrap().advance_with(kind) { + if self.child.as_ref().unwrap().could_be_complete() { + new_positions.extend(TokenTracker::next_possible_positions_from_with( + &self.tokens, + &pos, + kind, + )); + } + // and advance it with the current token + new_positions.push(pos); + } + } + Some(SyntaxDefinition::RequiredToken(k)) => { + pos.advance(); + if k == kind { + new_positions.push(pos); + } + } + Some(SyntaxDefinition::AnyToken) => { + pos.advance(); + new_positions.push(pos); + } + Some(SyntaxDefinition::OneOf(kinds)) => { + if kinds.iter().any(|x| x == kind) { + pos.advance(); + new_positions.push(pos); + } + } + Some(SyntaxDefinition::OptionalToken(k)) => { + if k == kind { + pos.advance(); + new_positions.push(pos); + } else { + new_positions.extend(TokenTracker::next_possible_positions_from_with( + &self.tokens, + &pos, + kind, + )); + } + } + Some(SyntaxDefinition::AnyTokens(maybe_tokens)) => { + let next_positions = + TokenTracker::next_possible_positions_from_with(&self.tokens, &pos, kind); + + if next_positions.is_empty() { + // we only keep the current position if we either dont care about the + // tokens or the token is in the list of possible tokens + if let Some(tokens) = maybe_tokens { + if tokens.iter().any(|x| x == kind) { + new_positions.push(pos); + } + } else { + new_positions.push(pos); + } + } else { + new_positions.extend(next_positions); + } + } + Some(SyntaxDefinition::OptionalGroup(tokens)) => { + if pos.group_idx == 0 { + // if we are at the beginning of the group, we also need to spawn new + // trackers for every possible next token + new_positions.extend(TokenTracker::next_possible_positions_from_with( + &self.tokens, + &pos, + kind, + )); + } + + // advance group + let token = tokens.get(pos.group_idx).unwrap(); + if token == kind { + pos.advance_group(); + + // if we reached the end of the group, we advance the position + if pos.group_idx == tokens.len() { + pos.advance(); + } + + new_positions.push(pos); + } + } + None => { + // if we reached the end of the definition, we do nothing but keep the position + new_positions.push(pos); + + // TODO the problem with removing as position when there is no token anymore is + // that we will return false AT the last token, since the last token does not + // have any following tokens. even if the statement is complete at this point + // and still valid until the next token is added. + // + // i think to fix this, we need to track the CURRENT positions and not all + // possible NEXT positions. + } + }; + } + + self.positions = new_positions; + + self.positions.len() != 0 + } + + fn next_possible_positions_from_with( + tokens: &Vec, + pos: &Position, + kind: &SyntaxKind, + ) -> Vec { + let mut positions = Vec::new(); + + for (pos, token) in tokens.iter().enumerate().skip(pos.idx.to_owned()) { + match token { + SyntaxDefinition::RequiredToken(k) => { + if k == kind { + positions.push(Position::new(pos + 1)); + } + break; + } + SyntaxDefinition::OptionalToken(k) => { + if k == kind { + positions.push(Position::new(pos + 1)); + } + } + SyntaxDefinition::AnyTokens(_) => { + // + } + SyntaxDefinition::AnyToken => { + // + } + SyntaxDefinition::OneOf(kinds) => { + if kinds.iter().any(|x| x == kind) { + positions.push(Position::new(pos + 1)); + } + break; + } + SyntaxDefinition::OptionalGroup(t) => { + let first_token = t.first().unwrap(); + if first_token == kind { + positions.push(Position::new_with_group(pos + 1)); + } + } + SyntaxDefinition::OptionalRepeatedGroup(t) => { + let first_def = t.first().unwrap(); + match first_def { + SyntaxDefinition::RequiredToken(k) => { + if k == kind { + positions.push(Position::new(pos + 1)); + } + } + SyntaxDefinition::OneOf(kinds) => { + if kinds.iter().any(|x| x == kind) { + positions.push(Position::new(pos + 1)); + } + } + _ => { + panic!("OptionalRepeatedGroup must start with RequiredToken or OneOf"); + } + } + } + } + } + + positions + } + + pub fn could_be_complete(&self) -> bool { + self.tokens + .iter() + .skip( + self.positions + .iter() + .max_by_key(|p| p.idx) + .unwrap() + .to_owned() + .idx, + ) + .all(|x| match x { + SyntaxDefinition::RequiredToken(_) => false, + SyntaxDefinition::OneOf(_) => false, + SyntaxDefinition::AnyToken => false, + SyntaxDefinition::OptionalRepeatedGroup(_) => { + if self.child.is_none() { + true + } else { + self.child.as_ref().unwrap().could_be_complete() + } + } + _ => true, + }) + } + + pub fn current_positions(&self) -> Vec { + self.positions.iter().map(|x| x.idx).collect() + } + + /// Returns the max idx of all tracked positions while ignoring non-required tokens + pub fn max_pos(&self) -> usize { + self.positions + .iter() + .map(|p| { + // substract non-required tokens from the position count + (0..p.idx).fold(0, |acc, idx| { + let token = self.tokens.get(idx); + match token { + Some(SyntaxDefinition::RequiredToken(_)) => acc + 1, + Some(SyntaxDefinition::OneOf(_)) => acc + 1, + Some(SyntaxDefinition::AnyToken) => acc + 1, + _ => acc, + } + }) + }) + .max() + .unwrap() + } +} + +#[derive(Debug, Clone)] +pub struct Position { + idx: usize, + group_idx: usize, +} + +impl Position { + fn new(idx: usize) -> Self { + Self { idx, group_idx: 0 } + } + + fn new_with_group(idx: usize) -> Self { + Self { idx, group_idx: 1 } + } + + fn advance(&mut self) { + self.idx += 1; + self.group_idx = 0; + } + + fn advance_group(&mut self) { + self.group_idx += 1; + } +} + +#[derive(Debug)] +pub struct Tracker<'a> { + pub def: &'a StatementDefinition, + + /// position in the definition, and for each position we track the current token for that + /// position. required for groups. + // pub positions: Vec, + + /// position in the global token stream + pub started_at: usize, + + used_prohibited_statements: Vec<(usize, SyntaxKind)>, + + token_tracker: TokenTracker<'a>, +} + +impl<'a> Tracker<'a> { + pub fn new_at(def: &'a StatementDefinition, at: usize) -> Self { + Self { + def, + // positions: vec![Position::new(1)], + started_at: at, + used_prohibited_statements: Vec::new(), + token_tracker: TokenTracker::new(&def.tokens), + } + } + + pub fn can_start_stmt_after( + &mut self, + kind: &SyntaxKind, + at: usize, + ignore_if_prohibited: bool, + ) -> bool { + if let Some(x) = self + .used_prohibited_statements + .iter() + .find(|x| x.1 == *kind) + { + // we already used this prohibited statement, we we can start a new statement + // but only if we are not at the same position as the prohibited statement + // this is to prevent adding the second "VariableSetStmt" if the first was added to the + // used list if both start at the same position + println!("used prohibited statement: {:?}", x); + return x.0 != at; + } + + let res = + self.could_be_complete() && self.def.prohibited_following_statements.contains(kind); + + if res { + if !ignore_if_prohibited { + println!("prohibited statement: {:?}", kind); + self.used_prohibited_statements.push((at, kind.clone())); + } + return false; + } + + true + } + + /// Returns the max idx of all tracked positions while ignoring non-required tokens + pub fn max_pos(&self) -> usize { + self.token_tracker.max_pos() + } + + pub fn current_positions(&self) -> Vec { + self.token_tracker.current_positions() + } + + pub fn advance_with(&mut self, kind: &SyntaxKind) -> bool { + if WHITESPACE_TOKENS.contains(kind) { + return true; + } + + if self.def.prohibited_tokens.contains(kind) { + return false; + } + + self.token_tracker.advance_with(kind) + } + + pub fn could_be_complete(&self) -> bool { + self.token_tracker.could_be_complete() + } +} diff --git a/crates/pg_statement_splitter/src/tracker_new.rs b/crates/pg_statement_splitter/src/tracker_new.rs new file mode 100644 index 00000000..713f3637 --- /dev/null +++ b/crates/pg_statement_splitter/src/tracker_new.rs @@ -0,0 +1,587 @@ +use pg_lexer::{SyntaxKind, WHITESPACE_TOKENS}; + +use crate::data::{StatementDefinition, SyntaxDefinition}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StatementPosition { + at: usize, + + group_idx: Option, +} + +impl StatementPosition { + pub fn new(at: usize) -> Self { + StatementPosition { + at, + group_idx: None, + } + } + + fn new_within_group(at: usize, group_idx: usize) -> Self { + StatementPosition { + at, + group_idx: Some(group_idx), + } + } + + fn group_idx(&self) -> usize { + self.group_idx + .expect("Expected position pointing to a group to have a group index") + } +} + +#[derive(Debug)] +pub struct StatementTracker<'a> { + pub def: &'a StatementDefinition, + + /// position in the global token stream + pub started_at: usize, + + used_prohibited_statements: Vec<(usize, SyntaxKind)>, + + positions: Vec, +} + +impl<'a> StatementTracker<'a> { + pub fn new_at(def: &'a StatementDefinition, started_at: usize) -> Self { + StatementTracker { + def, + started_at, + used_prohibited_statements: vec![], + positions: vec![StatementPosition::new(0)], + } + } + + pub fn advance_with(&mut self, kind: &SyntaxKind) -> bool { + println!("advance with ${:?}", kind); + if WHITESPACE_TOKENS.contains(kind) { + return true; + } + + if self.def.prohibited_tokens.contains(kind) { + return false; + } + + let mut new_positions = Vec::new(); + + for pos in &self.positions { + let syntax = self.def.tokens.get(pos.at).expect("invalid position"); + match syntax { + def @ SyntaxDefinition::OptionalRepeatedGroup(defs) => { + if pos.group_idx() == defs.len() - 1 { + // if we are at the end of a repeated group, check next positions + new_positions.extend(next_positions(&self.def.tokens, pos.at, kind)); + // also check if we can restart + if def.first_required_tokens().iter().any(|x| x == &kind) { + new_positions.push(StatementPosition::new_within_group(pos.at, 0)); + } + } else { + // if we are within a repeated group, we need to check if we can advance within + let next_group_positions = next_positions(&defs, pos.group_idx(), kind); + + for next_pos in next_group_positions { + new_positions + .push(StatementPosition::new_within_group(pos.at, next_pos.at)); + } + } + } + SyntaxDefinition::OptionalGroup(tokens) => { + if pos.group_idx() == tokens.len() - 1 { + // if we are at the end of a group, check next positions + new_positions.extend(next_positions(&self.def.tokens, pos.at, kind)); + } else { + // if we are within a group, we need to check if we can advance within + if tokens[pos.group_idx() + 1] == *kind { + new_positions.push(StatementPosition::new_within_group( + pos.at, + pos.group_idx() + 1, + )); + } + } + } + SyntaxDefinition::AnyTokens(allowed) => { + let next_pos = next_positions(&self.def.tokens, pos.at, kind); + + // if within allowed or no next position, keep position + if (allowed.is_some() && allowed.as_ref().unwrap().contains(kind)) + || next_pos.is_empty() + { + new_positions.push(StatementPosition::new(pos.at)); + } + + // next positions + new_positions.extend(next_pos); + } + _ => { + new_positions.extend(next_positions(&self.def.tokens, pos.at, kind)); + } + } + } + + self.positions = new_positions; + + !self.positions.is_empty() + } + + pub fn can_start_stmt_after( + &mut self, + kind: &SyntaxKind, + at: usize, + ignore_if_prohibited: bool, + ) -> bool { + if let Some(x) = self + .used_prohibited_statements + .iter() + .find(|x| x.1 == *kind) + { + println!("used prohibited: {:?} at {}", x, at); + // we already used this prohibited statement, we we can start a new statement + // but only if we are not at the same position as the prohibited statement + // this is to prevent adding the second "VariableSetStmt" if the first was added to the + // used list if both start at the same position + return x.0 != at; + } + + let res = + self.could_be_complete() && self.def.prohibited_following_statements.contains(kind); + + println!("prohibited: res {} for {:?} at {}", res, kind, at); + if res { + if !ignore_if_prohibited { + self.used_prohibited_statements.push((at, kind.clone())); + } + return false; + } + + true + } + + pub fn current_positions(&self) -> Vec { + self.positions.iter().map(|x| x.at).collect() + } + + /// Returns the max idx of all tracked positions while ignoring non-required tokens + pub fn max_pos(&self) -> usize { + self.positions + .iter() + .map(|p| { + // substract non-required tokens from the position count + (0..p.at).fold(0, |acc, idx| { + let token = self.def.tokens.get(idx); + match token { + Some(SyntaxDefinition::RequiredToken(_)) => acc + 1, + Some(SyntaxDefinition::OneOf(_)) => acc + 1, + Some(SyntaxDefinition::AnyToken) => acc + 1, + _ => acc, + } + }) + }) + .max() + .unwrap() + } + + pub fn could_be_complete(&self) -> bool { + let res = self._could_be_complete(); + // println!( + // "{:?} could be complete: {} with {:?}", + // self.def.stmt, res, self.def.tokens + // ); + res + } + + pub fn _could_be_complete(&self) -> bool { + let max_pos = self.positions.iter().map(|p| p.at).max().unwrap(); + // println!("tokens: {:?}", self.def.tokens); + // println!("max pos: {} at {:?}", max_pos, self.def.tokens.get(max_pos)); + + // if max pos is at group and not at last group idx, we can't be complete + match self.def.tokens.get(max_pos) { + Some(SyntaxDefinition::OptionalGroup(tokens)) => { + if self + .positions + .iter() + .all(|x| x.group_idx() < tokens.len() - 1) + { + return false; + } + } + Some(SyntaxDefinition::OptionalRepeatedGroup(tokens)) => { + if self + .positions + .iter() + .all(|x| x.group_idx() < tokens.len() - 1) + { + return false; + } + } + _ => {} + } + // + // println!( + // "checking tokens after: {:?}", + // self.def.tokens.iter().skip(max_pos + 1).collect::>() + // ); + + self.def.tokens.iter().skip(max_pos + 1).all(|x| match x { + SyntaxDefinition::RequiredToken(_) => false, + SyntaxDefinition::OneOf(_) => false, + SyntaxDefinition::AnyToken => false, + _ => true, + }) + } +} + +fn next_positions( + tokens: &Vec, + pos: usize, + kind: &SyntaxKind, +) -> Vec { + let mut new_positions = Vec::new(); + + for (pos, token) in tokens.iter().enumerate().skip(pos + 1) { + match token { + SyntaxDefinition::RequiredToken(k) => { + if k == kind { + new_positions.push(StatementPosition::new(pos)); + } + break; + } + SyntaxDefinition::OptionalToken(k) => { + if k == kind { + new_positions.push(StatementPosition::new(pos)); + } + } + SyntaxDefinition::AnyTokens(expected) => { + if expected.is_none() || expected.as_ref().unwrap().contains(kind) { + new_positions.push(StatementPosition::new(pos)); + } + } + SyntaxDefinition::AnyToken => { + new_positions.push(StatementPosition::new(pos)); + break; + } + SyntaxDefinition::OneOf(kinds) => { + if kinds.iter().any(|x| x == kind) { + new_positions.push(StatementPosition::new(pos)); + } + break; + } + SyntaxDefinition::OptionalGroup(t) => { + let first_token = t.first().unwrap(); + if first_token == kind { + new_positions.push(StatementPosition::new_within_group(pos, 0)); + } + } + def @ SyntaxDefinition::OptionalRepeatedGroup(_) => { + if def.first_required_tokens().iter().any(|x| x == &kind) { + new_positions.push(StatementPosition::new_within_group(pos, 0)); + } + } + } + } + + new_positions +} + +#[cfg(test)] +mod tests { + use pg_lexer::{lex, SyntaxKind, WHITESPACE_TOKENS}; + + use crate::{ + data::{SyntaxDefinition, STATEMENT_DEFINITIONS}, + tracker_new::StatementPosition, + }; + + use super::StatementTracker; + + #[test] + fn test_optional_repeated_group() { + let input = " +WITH t1 AS ( + SELECT 1 +), t2 AS ( + SELECT 2 +) +SELECT 's'; + "; + + let stmt_def = STATEMENT_DEFINITIONS + .get(&SyntaxKind::With) + .unwrap() + .first() + .unwrap(); + + // TODO only go to any tokens if there is no other position! + println!("{:#?}", stmt_def.tokens); + + let lexed = lex(input); + + let tokens = lexed + .iter() + .filter(|x| !WHITESPACE_TOKENS.contains(&x.kind)) + .collect::>(); + let mut tokens_iter = tokens.iter(); + + while tokens_iter.next().unwrap().kind != SyntaxKind::With { + // skip until WITH + } + + let mut tracker = StatementTracker::new_at(stmt_def, 1); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 0, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 2, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 3, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 4, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 5, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 5, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 6, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 8, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 8, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 8, + group_idx: None + }] + ); + + tracker.advance_with(&tokens_iter.next().unwrap().kind); + + assert_eq!( + tracker.positions, + vec![StatementPosition { + at: 8, + group_idx: None + }] + ); + + println!( + "{:?}", + tracker + .positions + .iter() + .map(|x| stmt_def.tokens.get(x.at)) + .collect::>() + ); + + // tracker.advance_with(&SyntaxKind::Ascii42); + // + // assert_eq!(tracker.positions.len(), 1); + // + // assert_eq!( + // tracker.positions[0], + // StatementPosition { + // at: 1, + // group_idx: None + // } + // ); + // + // tracker.advance_with(&SyntaxKind::Whitespace); + // + // assert_eq!(tracker.positions.len(), 1); + // + // assert_eq!( + // tracker.positions[0], + // StatementPosition { + // at: 1, + // group_idx: None + // } + // ); + // + // tracker.advance_with(&SyntaxKind::From); + // + // assert_eq!(tracker.positions.len(), 1); + // + // assert_eq!( + // tracker.positions[0], + // StatementPosition { + // at: 2, + // group_idx: None + // } + // ); + // + // tracker.advance_with(&SyntaxKind::Whitespace); + // + // assert_eq!(tracker.positions.len(), 1); + // + // assert_eq!( + // tracker.positions[0], + // StatementPosition { + // at: 2, + // group_idx: None + // } + // ); + // + // tracker.advance_with(&SyntaxKind::Ident); + // + // assert_eq!(tracker.positions.len(), 1); + // + // assert_eq!( + // tracker.positions[0], + // StatementPosition { + // at: 2, + // group_idx: None + // } + // ); + } + + #[test] + fn test_advance_with() { + let new_stmts = STATEMENT_DEFINITIONS.get(&SyntaxKind::Select).unwrap(); + + let mut tracker = StatementTracker::new_at(new_stmts.first().unwrap(), 0); + + tracker.advance_with(&SyntaxKind::Whitespace); + + assert_eq!(tracker.positions.len(), 1); + + assert_eq!( + tracker.positions[0], + StatementPosition { + at: 0, + group_idx: None + } + ); + + tracker.advance_with(&SyntaxKind::Ascii42); + + assert_eq!(tracker.positions.len(), 1); + + assert_eq!( + tracker.positions[0], + StatementPosition { + at: 1, + group_idx: None + } + ); + + tracker.advance_with(&SyntaxKind::Whitespace); + + assert_eq!(tracker.positions.len(), 1); + + assert_eq!( + tracker.positions[0], + StatementPosition { + at: 1, + group_idx: None + } + ); + + tracker.advance_with(&SyntaxKind::From); + + assert_eq!(tracker.positions.len(), 1); + + assert_eq!( + tracker.positions[0], + StatementPosition { + at: 2, + group_idx: None + } + ); + + tracker.advance_with(&SyntaxKind::Whitespace); + + assert_eq!(tracker.positions.len(), 1); + + assert_eq!( + tracker.positions[0], + StatementPosition { + at: 2, + group_idx: None + } + ); + + tracker.advance_with(&SyntaxKind::Ident); + + assert_eq!(tracker.positions.len(), 1); + + assert_eq!( + tracker.positions[0], + StatementPosition { + at: 2, + group_idx: None + } + ); + } +} diff --git a/crates/pg_statement_splitter/tests/data/simple_select__4.sql b/crates/pg_statement_splitter/tests/data/simple_select.sql similarity index 100% rename from crates/pg_statement_splitter/tests/data/simple_select__4.sql rename to crates/pg_statement_splitter/tests/data/simple_select.sql diff --git a/crates/pg_statement_splitter/tests/skipped.txt b/crates/pg_statement_splitter/tests/skipped.txt index 480089b9..35a130b9 100644 --- a/crates/pg_statement_splitter/tests/skipped.txt +++ b/crates/pg_statement_splitter/tests/skipped.txt @@ -1,3 +1,5 @@ +alter_table + brin brin_bloom brin_multi @@ -10,3 +12,23 @@ drop_operator replica_identity unicode xmlmap +aggregates +comments +dependency +drop_if_exists +groupingsets +index_including_gist +inherit +insert +insert_conflict +numeric_big +opr_sanity +case +random +rangetypes +regproc +rowtypes +sanity_check +select_distinct +text +union diff --git a/crates/pg_statement_splitter/tests/skipped_statements.txt b/crates/pg_statement_splitter/tests/skipped_statements.txt new file mode 100644 index 00000000..edaaef9d --- /dev/null +++ b/crates/pg_statement_splitter/tests/skipped_statements.txt @@ -0,0 +1,2 @@ +alter table atacc1 SET WITH OIDS; +ALTER INDEX attmp_idx ALTER COLUMN 0 SET STATISTICS 1000; diff --git a/crates/pg_statement_splitter/tests/snapshots/data/simple_select.snap b/crates/pg_statement_splitter/tests/snapshots/data/simple_select.snap new file mode 100644 index 00000000..d27571f4 --- /dev/null +++ b/crates/pg_statement_splitter/tests/snapshots/data/simple_select.snap @@ -0,0 +1,26 @@ +--- +source: crates/pg_statement_splitter/tests/statement_splitter_tests.rs +description: "select id, name, test1231234123, unknown from co;\n\nselect 14433313331333\n\nalter table test drop column id;\n\nselect lower('test');\n\n" +--- +[ + ( + SelectStmt, + 0..49, + "select id, name, test1231234123, unknown from co;", + ), + ( + SelectStmt, + 51..72, + "select 14433313331333", + ), + ( + AlterTableStmt, + 74..106, + "alter table test drop column id;", + ), + ( + SelectStmt, + 108..129, + "select lower('test');", + ), +] diff --git a/crates/pg_statement_splitter/tests/statement_splitter_tests.rs b/crates/pg_statement_splitter/tests/statement_splitter_tests.rs index fb639fef..781b3b9c 100644 --- a/crates/pg_statement_splitter/tests/statement_splitter_tests.rs +++ b/crates/pg_statement_splitter/tests/statement_splitter_tests.rs @@ -1,8 +1,17 @@ -use std::fs::{self}; +use insta::{assert_debug_snapshot, Settings}; +use std::{ + fs::{self}, + panic, +}; + +use pg_lexer::SyntaxKind; const DATA_DIR_PATH: &str = "tests/data/"; const POSTGRES_REGRESS_PATH: &str = "../../libpg_query/test/sql/postgres_regress/"; const SKIPPED_REGRESS_TESTS: &str = include_str!("skipped.txt"); +const SKIPPED_STATEMENTS: &str = include_str!("skipped_statements.txt"); + +const SNAPSHOTS_PATH: &str = "snapshots/data"; #[test] fn test_postgres_regress() { @@ -28,39 +37,73 @@ fn test_postgres_regress() { continue; } + println!("Running test: {}", test_name); + // remove \commands because pg_query doesn't support them let contents = fs::read_to_string(&path) .unwrap() .lines() - .filter(|l| !l.starts_with("\\") && !l.ends_with("\\gset")) + .filter_map(|l| { + if !l.starts_with("\\") + && !l.ends_with("\\gset") + && !l.starts_with("--") + && !l.contains(":'") + && (l.starts_with("\t") || l.split("\t").count() <= 1) + && !SKIPPED_STATEMENTS.contains(l) + { + if let Some(index) = l.find("--") { + Some(l[..index].to_string()) + } else { + Some(l.to_string()) + } + } else { + None + } + }) .collect::>() - .join(" "); - - let libpg_query_split = pg_query::split_with_parser(&contents).unwrap(); - - let parser_split = pg_statement_splitter::split(&contents); - - assert_eq!( - parser_split.errors.len(), - 0, - "Unexpected errors when parsing file {}:\n{:#?}", - test_name, - parser_split.errors - ); - - assert_eq!( - libpg_query_split.len(), - parser_split.ranges.len(), - "Mismatch in statement count for file {}: Expected {} statements, got {}", - test_name, - libpg_query_split.len(), - parser_split.ranges.len() - ); - - for (libpg_query_stmt, parser_range) in - libpg_query_split.iter().zip(parser_split.ranges.iter()) - { - let parser_stmt = &contents[parser_range.clone()].trim(); + .join("\n"); + + let libpg_query_split_result = pg_query::split_with_parser(&contents); + + if libpg_query_split_result.is_err() { + panic!( + "'{}'\nFailed to split statements for test '{}': {:?}", + contents, test_name, libpg_query_split_result + ); + } + + let libpg_query_split = libpg_query_split_result.unwrap(); + + let result = panic::catch_unwind(|| pg_statement_splitter::statements(&contents)); + + if result.is_err() { + panic!( + "Failed to split statements for test '{}': {:?}", + test_name, + result.unwrap_err() + ); + } + + let split = result.unwrap(); + + // assert_eq!( + // libpg_query_split.len(), + // split.len(), + // "[{}] Mismatch in statement count: Expected {} statements, got {}. Contents:\n{}", + // test_name, + // libpg_query_split.len(), + // split.len(), + // contents + // ); + + for (libpg_query_stmt, parser_result) in libpg_query_split.iter().zip(split.iter()) { + let mut parser_stmt = contents[parser_result.range.clone()].trim().to_string(); + + if parser_stmt.ends_with(';') { + let mut s = parser_stmt.chars().rev().skip(1).collect::(); + s = s.chars().rev().collect(); + parser_stmt = format!("{}{}", s.trim(), ";"); + } let libpg_query_stmt = if libpg_query_stmt.ends_with(';') { libpg_query_stmt.to_string() @@ -71,11 +114,34 @@ fn test_postgres_regress() { let libpg_query_stmt_trimmed = libpg_query_stmt.trim(); let parser_stmt_trimmed = parser_stmt.trim(); + let root = pg_query::parse(libpg_query_stmt_trimmed) + .map(|parsed| { + parsed + .protobuf + .nodes() + .iter() + .find(|n| n.1 == 1) + .unwrap() + .0 + .to_enum() + }) + .expect("Failed to parse statement"); + assert_eq!( libpg_query_stmt_trimmed, parser_stmt_trimmed, - "Mismatch in statement {}:\nlibg_query: '{}'\nsplitter: '{}'", - test_name, libpg_query_stmt_trimmed, parser_stmt_trimmed + "[{}] Mismatch in statement:\nlibg_query: '{}'\nsplitter: '{}'\n Root Node: {:?}", + test_name, libpg_query_stmt_trimmed, parser_stmt_trimmed, root ); + + let syntax_kind = SyntaxKind::from(&root); + + assert_eq!( + syntax_kind, parser_result.kind, + "[{}] Mismatch in statement type. Expected {:?}, got {:?} for statement '{}'. Root Node: {:?}", + test_name, syntax_kind, parser_result.kind, parser_stmt_trimmed, root + ); + + println!("[{}] Matched {}", test_name, parser_stmt_trimmed); } } } @@ -91,24 +157,23 @@ fn test_statement_splitter() { for f in paths.iter() { let path = f.path(); let test_name = path.file_stem().unwrap().to_str().unwrap(); - let expected_count = test_name - .split("__") - .last() - .unwrap() - .parse::() - .unwrap(); let contents = fs::read_to_string(&path).unwrap(); - let split = pg_statement_splitter::split(&contents); + let statements = pg_statement_splitter::statements(&contents); + + let result = statements + .iter() + .map(|x| (x.kind, x.range, &contents[x.range.clone()])) + .collect::>(); + + let mut settings = Settings::clone_current(); + settings.set_input_file(&path); + settings.set_prepend_module_to_snapshot(false); + settings.set_description(contents.to_string()); + settings.set_omit_expression(true); + settings.set_snapshot_path(SNAPSHOTS_PATH); - assert_eq!( - split.ranges.len(), - expected_count, - "Mismatch in statement count for file {}: Expected {} statements, got {}", - test_name, - expected_count, - split.ranges.len() - ); + settings.bind(|| assert_debug_snapshot!(test_name, result)); } }