From 825939c6756c0a3be9d28312788b46779de12e23 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 29 Jun 2023 14:40:40 +0200 Subject: [PATCH 01/64] introduce libsqlx crate --- Cargo.lock | 169 ++- Cargo.toml | 1 + libsqlx/Cargo.toml | 36 + libsqlx/assets/test/simple_wallog | Bin 0 -> 28904 bytes libsqlx/src/analysis.rs | 288 +++++ libsqlx/src/connection.rs | 33 + libsqlx/src/database/frame.rs | 112 ++ libsqlx/src/database/libsql/connection.rs | 309 +++++ .../src/database/libsql/injector/headers.rs | 47 + libsqlx/src/database/libsql/injector/hook.rs | 175 +++ libsqlx/src/database/libsql/injector/mod.rs | 233 ++++ libsqlx/src/database/libsql/mod.rs | 365 ++++++ .../libsql/replication_log/frame_stream.rs | 114 ++ .../database/libsql/replication_log/logger.rs | 1022 +++++++++++++++++ .../database/libsql/replication_log/merger.rs | 137 +++ .../database/libsql/replication_log/mod.rs | 12 + .../libsql/replication_log/snapshot.rs | 334 ++++++ libsqlx/src/database/mod.rs | 135 +++ libsqlx/src/database/proxy/connection.rs | 222 ++++ libsqlx/src/database/proxy/database.rs | 50 + libsqlx/src/database/proxy/mod.rs | 11 + libsqlx/src/database/test_utils.rs | 66 ++ libsqlx/src/error.rs | 44 + libsqlx/src/lib.rs | 21 + libsqlx/src/program.rs | 60 + libsqlx/src/query.rs | 267 +++++ libsqlx/src/result_builder.rs | 711 ++++++++++++ libsqlx/src/seal.rs | 8 + libsqlx/src/semaphore.rs | 98 ++ sqld-libsql-bindings/src/wal_hook.rs | 4 +- sqld/src/replication/replica/hook.rs | 4 +- 31 files changed, 5044 insertions(+), 44 deletions(-) create mode 100644 libsqlx/Cargo.toml create mode 100644 libsqlx/assets/test/simple_wallog create mode 100644 libsqlx/src/analysis.rs create mode 100644 libsqlx/src/connection.rs create mode 100644 libsqlx/src/database/frame.rs create mode 100644 libsqlx/src/database/libsql/connection.rs create mode 100644 libsqlx/src/database/libsql/injector/headers.rs create mode 100644 libsqlx/src/database/libsql/injector/hook.rs create mode 100644 libsqlx/src/database/libsql/injector/mod.rs create mode 100644 libsqlx/src/database/libsql/mod.rs create mode 100644 libsqlx/src/database/libsql/replication_log/frame_stream.rs create mode 100644 libsqlx/src/database/libsql/replication_log/logger.rs create mode 100644 libsqlx/src/database/libsql/replication_log/merger.rs create mode 100644 libsqlx/src/database/libsql/replication_log/mod.rs create mode 100644 libsqlx/src/database/libsql/replication_log/snapshot.rs create mode 100644 libsqlx/src/database/mod.rs create mode 100644 libsqlx/src/database/proxy/connection.rs create mode 100644 libsqlx/src/database/proxy/database.rs create mode 100644 libsqlx/src/database/proxy/mod.rs create mode 100644 libsqlx/src/database/test_utils.rs create mode 100644 libsqlx/src/error.rs create mode 100644 libsqlx/src/lib.rs create mode 100644 libsqlx/src/program.rs create mode 100644 libsqlx/src/query.rs create mode 100644 libsqlx/src/result_builder.rs create mode 100644 libsqlx/src/seal.rs create mode 100644 libsqlx/src/semaphore.rs diff --git a/Cargo.lock b/Cargo.lock index 38ff0672..6d0349e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1176,7 +1176,7 @@ dependencies = [ "cranelift-entity", "fxhash", "hashbrown 0.12.3", - "indexmap", + "indexmap 1.9.3", "log", "smallvec", ] @@ -1228,7 +1228,7 @@ dependencies = [ "cranelift-codegen", "cranelift-entity", "cranelift-frontend", - "itertools", + "itertools 0.10.5", "log", "smallvec", "wasmparser", @@ -1470,6 +1470,12 @@ dependencies = [ "termcolor", ] +[[package]] +name = "equivalent" +version = "1.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "88bffebc5d80432c9b140ee17875ff173a8ab62faad5b257da912bd2f6c1c0a1" + [[package]] name = "errno" version = "0.2.8" @@ -1514,6 +1520,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + [[package]] name = "fallible-streaming-iterator" version = "0.1.9" @@ -1767,8 +1779,8 @@ version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22030e2c5a68ec659fde1e949a745124b48e6fa8b045b7ed5bd1fe4ccc5c4e5d" dependencies = [ - "fallible-iterator", - "indexmap", + "fallible-iterator 0.2.0", + "indexmap 1.9.3", "stable_deref_trait", ] @@ -1790,7 +1802,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -1815,6 +1827,12 @@ dependencies = [ "ahash 0.8.3", ] +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "hashlink" version = "0.8.2" @@ -2075,6 +2093,16 @@ dependencies = [ "serde", ] +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", +] + [[package]] name = "insta" version = "1.29.0" @@ -2177,6 +2205,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.6" @@ -2322,6 +2359,38 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libsqlx" +version = "0.1.0" +dependencies = [ + "anyhow", + "arbitrary", + "async-trait", + "bytemuck", + "bytes 1.4.0", + "bytesize", + "crc", + "crossbeam", + "fallible-iterator 0.3.0", + "futures", + "itertools 0.11.0", + "nix", + "once_cell", + "parking_lot", + "rand", + "regex", + "rusqlite", + "serde", + "serde_json", + "sqld-libsql-bindings", + "sqlite3-parser 0.9.0", + "tempfile", + "thiserror", + "tokio", + "tracing", + "uuid", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2611,7 +2680,7 @@ checksum = "21158b2c33aa6d4561f1c0a6ea283ca92bc54802a93b263e910746d679a7eb53" dependencies = [ "crc32fast", "hashbrown 0.12.3", - "indexmap", + "indexmap 1.9.3", "memchr", ] @@ -2646,9 +2715,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.1" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "openssl" @@ -2763,7 +2832,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" dependencies = [ "fixedbitset", - "indexmap", + "indexmap 1.9.3", ] [[package]] @@ -2932,7 +3001,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "5ca9c6be70d989d21a136eb86c2d83e4b328447fac4a88dace2143c179c86267" dependencies = [ "autocfg", - "indexmap", + "indexmap 1.9.3", ] [[package]] @@ -2982,7 +3051,7 @@ checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270" dependencies = [ "bytes 1.4.0", "heck", - "itertools", + "itertools 0.10.5", "lazy_static", "log", "multimap", @@ -3003,7 +3072,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.5", "proc-macro2", "quote", "syn 1.0.109", @@ -3155,9 +3224,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.2" +version = "1.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1a59b5d8e97dee33696bf13c5ba8ab85341c002922fba050069326b9c498974" +checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" dependencies = [ "aho-corasick", "memchr", @@ -3248,7 +3317,7 @@ version = "0.29.0" source = "git+https://github.com/psarna/rusqlite?rev=d9a97c0f25#d9a97c0f25d48272c91d3f8d93d46cb405c39037" dependencies = [ "bitflags 2.3.1", - "fallible-iterator", + "fallible-iterator 0.2.0", "fallible-streaming-iterator", "hashlink", "libsqlite3-sys", @@ -3452,18 +3521,18 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "serde" -version = "1.0.163" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.163" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" dependencies = [ "proc-macro2", "quote", @@ -3472,11 +3541,11 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3" dependencies = [ - "indexmap", + "indexmap 2.0.0", "itoa", "ryu", "serde", @@ -3635,7 +3704,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bytes 1.4.0", - "fallible-iterator", + "fallible-iterator 0.2.0", "fn-error-context", "libsql-client", "scram", @@ -3669,13 +3738,13 @@ dependencies = [ "crossbeam", "enclose", "env_logger", - "fallible-iterator", + "fallible-iterator 0.2.0", "futures", "hmac", "hyper", "hyper-tungstenite", "insta", - "itertools", + "itertools 0.10.5", "jsonwebtoken", "libsql-client", "memmap", @@ -3696,7 +3765,7 @@ dependencies = [ "sha2", "sha256", "sqld-libsql-bindings", - "sqlite3-parser", + "sqlite3-parser 0.8.0", "tempfile", "thiserror", "tokio", @@ -3731,8 +3800,27 @@ checksum = "c3995a6daa13c113217b6ad22154865fb06f9cb939bef398fd04f4a7aaaf5bd7" dependencies = [ "bitflags 2.3.1", "cc", - "fallible-iterator", - "indexmap", + "fallible-iterator 0.2.0", + "indexmap 1.9.3", + "log", + "memchr", + "phf", + "phf_codegen", + "phf_shared", + "smallvec", + "uncased", +] + +[[package]] +name = "sqlite3-parser" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"db68d3f0682b50197a408d65a3246b7d6173399d1325cf0208fb3fdb66e3229f" +dependencies = [ + "bitflags 2.3.1", + "cc", + "fallible-iterator 0.3.0", + "indexmap 1.9.3", "log", "memchr", "phf", @@ -3829,15 +3917,16 @@ checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" [[package]] name = "tempfile" -version = "3.5.0" +version = "3.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" +checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" dependencies = [ + "autocfg", "cfg-if", "fastrand", "redox_syscall 0.3.5", "rustix 0.37.19", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -3946,9 +4035,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.1" +version = "1.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0aa32867d44e6f2ce3385e89dceb990188b8bb0fb25b0cf576647a6f98ac5105" +checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" dependencies = [ "autocfg", "bytes 1.4.0", @@ -4145,7 +4234,7 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap", + "indexmap 1.9.3", "pin-project 1.1.0", "pin-project-lite", "rand", @@ -4380,9 +4469,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.3.3" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2" +checksum = "d023da39d1fde5a8a3fe1f3e01ca9632ada0a63e9797de55a879d6e2236277be" dependencies = [ "atomic", "getrandom", @@ -4579,7 +4668,7 @@ version = "0.93.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5a4460aa3e271fa180b6a5d003e728f3963fb30e3ba0fa7c9634caa06049328" dependencies = [ - "indexmap", + "indexmap 1.9.3", ] [[package]] @@ -4592,7 +4681,7 @@ dependencies = [ "async-trait", "bincode", "cfg-if", - "indexmap", + "indexmap 1.9.3", "libc", "log", "object", @@ -4672,7 +4761,7 @@ dependencies = [ "anyhow", "cranelift-entity", "gimli", - "indexmap", + "indexmap 1.9.3", "log", "object", "serde", @@ -4752,7 +4841,7 @@ dependencies = [ "anyhow", "cc", "cfg-if", - "indexmap", + "indexmap 1.9.3", "libc", "log", "mach", diff --git a/Cargo.toml b/Cargo.toml index 846999ef..238333f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "sqld", "sqld-libsql-bindings", "testing/end-to-end", + "libsqlx", ] [workspace.dependencies] diff --git a/libsqlx/Cargo.toml b/libsqlx/Cargo.toml new file mode 100644 index 00000000..b0f00521 --- /dev/null +++ b/libsqlx/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "libsqlx" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +async-trait = "0.1.68" +bytesize = "1.2.0" +serde = "1.0.164" +serde_json = "1.0.99" +rusqlite = { workspace = true } +anyhow = "1.0.71" +futures = "0.3.28" +tokio = { version = "1.28.2", features = ["sync", "time"] } +sqlite3-parser = "0.9.0" +fallible-iterator = "0.3.0" +bytes = "1.4.0" +tracing = "0.1.37" +bytemuck = { version = "1.13.1", features = ["derive"] } +parking_lot = "0.12.1" +uuid = { version = "1.4.0", features = ["v4"] } +sqld-libsql-bindings = { version = "0", path = "../sqld-libsql-bindings" } 
+crossbeam = "0.8.2" +thiserror = "1.0.40" +nix = "0.26.2" +crc = "3.0.1" +once_cell = "1.18.0" +regex = "1.8.4" +tempfile = "3.6.0" + +[dev-dependencies] +arbitrary = { version = "1.3.0", features = ["derive"] } +itertools = "0.11.0" +rand = "0.8.5" diff --git a/libsqlx/assets/test/simple_wallog b/libsqlx/assets/test/simple_wallog new file mode 100644 index 0000000000000000000000000000000000000000..42e5b3a914a0ad8f9cb3600fb11e4c2af98990a7 GIT binary patch literal 28904 zcmeI)O(+Cm9LMpQXLd-^xCkfXpzOh0ialEFZ4TZdN_lCO@>0yfLCRSi+;EZ{U7WZo zB@&U79F(_<(jK@-o@YIy5yw5)?fY+LpJ$$(d79^QdwyHn-RZUW#XTda<~wJg;q_?u zwDq=f`SD^W7|+)oq1xiF3rCNsX6$Nc;c25TW$Ll+PH)^#x?|Hblf%B-C?65RXmMTn zwVEP5jlNYq648Ct1XF~DeszUWJGED&Zg#E=c~A%-fB*srAb#I|=btU3G3AZQS>Lpw?;#SW^gH3$l@x%Pbb>()j=~sL9AV0{| zYcUT32q1s}0tg_000IagfB*sr{4asvhyTN}>y`b2%ad>}U^?%HJSYSZKmY**5I_I{ z1Q0*~0R#}pK>=Zk0!y!}=6toOd?~gb&IN??CUXJ#K_P$u0tg_000IagfB*srAb>zF z30Ojg0xA~>{up4eqkb{Ib#oui1uQ3(&IM#cA%Fk^2q1s}0tg_000IagfIv#? y(OQ4+%RyCtBAg4@&eL};AiER-2q1s}0tg_000IagfB*srWT}8<3)B817x)B5eNwRi literal 0 HcmV?d00001 diff --git a/libsqlx/src/analysis.rs b/libsqlx/src/analysis.rs new file mode 100644 index 00000000..0c7f7d43 --- /dev/null +++ b/libsqlx/src/analysis.rs @@ -0,0 +1,288 @@ +use anyhow::Result; +use fallible_iterator::FallibleIterator; +use sqlite3_parser::ast::{Cmd, PragmaBody, QualifiedName, Stmt}; +use sqlite3_parser::lexer::sql::{Parser, ParserError}; + +/// A group of statements to be executed together. +#[derive(Debug, Clone)] +pub struct Statement { + pub stmt: String, + pub kind: StmtKind, + /// Is the statement an INSERT, UPDATE or DELETE? + pub is_iud: bool, + pub is_insert: bool, +} + +impl Default for Statement { + fn default() -> Self { + Self::empty() + } +} + +/// Classify statement in categories of interest. +#[derive(Debug, PartialEq, Clone, Copy)] +pub enum StmtKind { + /// The begining of a transaction + TxnBegin, + /// The end of a transaction + TxnEnd, + Read, + Write, + Other, +} + +fn is_temp(name: &QualifiedName) -> bool { + name.db_name.as_ref().map(|n| n.0.as_str()) == Some("TEMP") +} + +fn is_reserved_tbl(name: &QualifiedName) -> bool { + let n = name.name.0.to_lowercase(); + n == "_litestream_seq" || n == "_litestream_lock" || n == "libsql_wasm_func_table" +} + +fn write_if_not_reserved(name: &QualifiedName) -> Option { + (!is_reserved_tbl(name)).then_some(StmtKind::Write) +} + +impl StmtKind { + fn kind(cmd: &Cmd) -> Option { + match cmd { + Cmd::Explain(Stmt::Pragma(name, body)) => Self::pragma_kind(name, body.as_ref()), + Cmd::Explain(_) => Some(Self::Other), + Cmd::ExplainQueryPlan(_) => Some(Self::Other), + Cmd::Stmt(Stmt::Begin { .. }) => Some(Self::TxnBegin), + Cmd::Stmt(Stmt::Commit { .. } | Stmt::Rollback { .. }) => Some(Self::TxnEnd), + Cmd::Stmt( + Stmt::CreateVirtualTable { tbl_name, .. } + | Stmt::CreateTable { + tbl_name, + temporary: false, + .. + }, + ) if !is_temp(tbl_name) => Some(Self::Write), + Cmd::Stmt( + Stmt::Insert { + with: _, + or_conflict: _, + tbl_name, + .. + } + | Stmt::Update { + with: _, + or_conflict: _, + tbl_name, + .. + }, + ) => write_if_not_reserved(tbl_name), + + Cmd::Stmt(Stmt::Delete { + with: _, tbl_name, .. + }) => write_if_not_reserved(tbl_name), + Cmd::Stmt(Stmt::DropTable { + if_exists: _, + tbl_name, + }) => write_if_not_reserved(tbl_name), + Cmd::Stmt(Stmt::AlterTable(tbl_name, _)) => write_if_not_reserved(tbl_name), + Cmd::Stmt( + Stmt::DropIndex { .. } + | Stmt::CreateTrigger { + temporary: false, .. + } + | Stmt::CreateIndex { .. 
}, + ) => Some(Self::Write), + Cmd::Stmt(Stmt::Select { .. }) => Some(Self::Read), + Cmd::Stmt(Stmt::Pragma(name, body)) => Self::pragma_kind(name, body.as_ref()), + _ => None, + } + } + + fn pragma_kind(name: &QualifiedName, body: Option<&PragmaBody>) -> Option { + let name = name.name.0.as_str(); + match name { + // always ok to be served by primary or replicas - pure readonly pragmas + "table_list" | "index_list" | "table_info" | "table_xinfo" | "index_xinfo" + | "pragma_list" | "compile_options" | "database_list" | "function_list" + | "module_list" => Some(Self::Read), + // special case for `encoding` - it's effectively readonly for connections + // that already created a database, which is always the case for sqld + "encoding" => Some(Self::Read), + // always ok to be served by primary + "foreign_keys" | "foreign_key_list" | "foreign_key_check" | "collation_list" + | "data_version" | "freelist_count" | "integrity_check" | "legacy_file_format" + | "page_count" | "quick_check" | "stats" => Some(Self::Write), + // ok to be served by primary without args + "analysis_limit" + | "application_id" + | "auto_vacuum" + | "automatic_index" + | "busy_timeout" + | "cache_size" + | "cache_spill" + | "cell_size_check" + | "checkpoint_fullfsync" + | "defer_foreign_keys" + | "fullfsync" + | "hard_heap_limit" + | "journal_mode" + | "journal_size_limit" + | "legacy_alter_table" + | "locking_mode" + | "max_page_count" + | "mmap_size" + | "page_size" + | "query_only" + | "read_uncommitted" + | "recursive_triggers" + | "reverse_unordered_selects" + | "schema_version" + | "secure_delete" + | "soft_heap_limit" + | "synchronous" + | "temp_store" + | "threads" + | "trusted_schema" + | "user_version" + | "wal_autocheckpoint" => { + match body { + Some(_) => None, + None => Some(Self::Write), + } + } + // changes the state of the connection, and can't be allowed rn: + "case_sensitive_like" | "ignore_check_constraints" | "incremental_vacuum" + // TODO: check if optimize can be safely performed + | "optimize" + | "parser_trace" + | "shrink_memory" + | "wal_checkpoint" => None, + _ => { + tracing::debug!("Unknown pragma: {name}"); + None + }, + } + } +} + +/// The state of a transaction for a series of statement +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum State { + /// The txn in an opened state + Txn, + /// The txn in a closed state + Init, + /// This is an invalid state for the state machine + Invalid, +} + +impl State { + pub fn step(&mut self, kind: StmtKind) { + *self = match (*self, kind) { + (State::Txn, StmtKind::TxnBegin) | (State::Init, StmtKind::TxnEnd) => State::Invalid, + (State::Txn, StmtKind::TxnEnd) => State::Init, + (state, StmtKind::Other | StmtKind::Write | StmtKind::Read) => state, + (State::Invalid, _) => State::Invalid, + (State::Init, StmtKind::TxnBegin) => State::Txn, + }; + } + + pub fn reset(&mut self) { + *self = State::Init + } +} + +impl Statement { + pub fn empty() -> Self { + Self { + stmt: String::new(), + // empty statement is arbitrarely made of the read kind so it is not send to a writer + kind: StmtKind::Read, + is_iud: false, + is_insert: false, + } + } + + pub fn parse(s: &str) -> impl Iterator> + '_ { + fn parse_inner( + original: &str, + stmt_count: u64, + has_more_stmts: bool, + c: Cmd, + ) -> Result { + let kind = + StmtKind::kind(&c).ok_or_else(|| anyhow::anyhow!("unsupported statement"))?; + + if stmt_count == 1 && !has_more_stmts { + // XXX: Temporary workaround for integration with Atlas + if let Cmd::Stmt(Stmt::CreateTable { .. 
}) = &c { + return Ok(Statement { + stmt: original.to_string(), + kind, + is_iud: false, + is_insert: false, + }); + } + } + + let is_iud = matches!( + c, + Cmd::Stmt(Stmt::Insert { .. } | Stmt::Update { .. } | Stmt::Delete { .. }) + ); + let is_insert = matches!(c, Cmd::Stmt(Stmt::Insert { .. })); + + Ok(Statement { + stmt: c.to_string(), + kind, + is_iud, + is_insert, + }) + } + // The parser needs to be boxed because it's large, and you don't want it on the stack. + // There's upstream work to make it smaller, but in the meantime the parser should remain + // on the heap: + // - https://github.com/gwenn/lemon-rs/issues/8 + // - https://github.com/gwenn/lemon-rs/pull/19 + let mut parser = Box::new(Parser::new(s.as_bytes()).peekable()); + let mut stmt_count = 0; + std::iter::from_fn(move || { + stmt_count += 1; + match parser.next() { + Ok(Some(cmd)) => Some(parse_inner( + s, + stmt_count, + parser.peek().map_or(true, |o| o.is_some()), + cmd, + )), + Ok(None) => None, + Err(sqlite3_parser::lexer::sql::Error::ParserError( + ParserError::SyntaxError { + token_type: _, + found: Some(found), + }, + Some((line, col)), + )) => Some(Err(anyhow::anyhow!( + "syntax error around L{line}:{col}: `{found}`" + ))), + Err(e) => Some(Err(e.into())), + } + }) + } + + pub fn is_read_only(&self) -> bool { + matches!( + self.kind, + StmtKind::Read | StmtKind::TxnEnd | StmtKind::TxnBegin + ) + } +} + +/// Given a an initial state and an array of queries, attempts to predict what the final state will +/// be +pub fn predict_final_state<'a>( + mut state: State, + stmts: impl Iterator, +) -> State { + for stmt in stmts { + state.step(stmt.kind); + } + state +} diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs new file mode 100644 index 00000000..d21ca9a6 --- /dev/null +++ b/libsqlx/src/connection.rs @@ -0,0 +1,33 @@ +use crate::program::Program; +use crate::result_builder::ResultBuilder; + +#[derive(Debug, Clone)] +pub struct DescribeResponse { + pub params: Vec, + pub cols: Vec, + pub is_explain: bool, + pub is_readonly: bool, +} + +#[derive(Debug, Clone)] +pub struct DescribeParam { + pub name: Option, +} + +#[derive(Debug, Clone)] +pub struct DescribeCol { + pub name: String, + pub decltype: Option, +} + +pub trait Connection { + /// Executes a query program + fn execute_program( + &mut self, + pgm: Program, + result_builder: B, + ) -> crate::Result; + + /// Parse the SQL statement and return information about it. + fn describe(&self, sql: String) -> crate::Result; +} diff --git a/libsqlx/src/database/frame.rs b/libsqlx/src/database/frame.rs new file mode 100644 index 00000000..337853fb --- /dev/null +++ b/libsqlx/src/database/frame.rs @@ -0,0 +1,112 @@ +use std::borrow::Cow; +use std::fmt; +use std::mem::{size_of, transmute}; +use std::ops::Deref; + +use bytemuck::{bytes_of, pod_read_unaligned, try_from_bytes, Pod, Zeroable}; +use bytes::{Bytes, BytesMut}; + +use super::libsql::replication_log::WAL_PAGE_SIZE; +use super::FrameNo; + +/// The file header for the WAL log. All fields are represented in little-endian ordering. +/// See `encode` and `decode` for actual layout. +// repr C for stable sizing +#[repr(C)] +#[derive(Debug, Clone, Copy, Zeroable, Pod)] +pub struct FrameHeader { + /// Incremental frame number + pub frame_no: FrameNo, + /// Rolling checksum of all the previous frames, including this one. + pub checksum: u64, + /// page number, if frame_type is FrameType::Page + pub page_no: u32, + /// Size of the database (in page) after commiting the transaction. 
This is passed from sqlite, + /// and serves as commit transaction boundary + pub size_after: u32, +} + +#[derive(Clone)] +/// The owned version of a replication frame. +/// Cloning this is cheap. +pub struct Frame { + data: Bytes, +} + +impl fmt::Debug for Frame { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Frame") + .field("header", &self.header()) + .field("data", &"[..]") + .finish() + } +} + +impl Frame { + /// size of a single frame + pub const SIZE: usize = size_of::() + WAL_PAGE_SIZE as usize; + + pub fn from_parts(header: &FrameHeader, data: &[u8]) -> Self { + assert_eq!(data.len(), WAL_PAGE_SIZE as usize); + let mut buf = BytesMut::with_capacity(Self::SIZE); + buf.extend_from_slice(bytes_of(header)); + buf.extend_from_slice(data); + + Self { data: buf.freeze() } + } + + pub fn try_from_bytes(data: Bytes) -> anyhow::Result { + anyhow::ensure!(data.len() == Self::SIZE, "invalid frame size"); + Ok(Self { data }) + } + + pub fn bytes(&self) -> Bytes { + self.data.clone() + } + + pub fn page_bytes(&self) -> Bytes { + let mut data = self.data.clone(); + let _ = data.split_to(size_of::()); + debug_assert_eq!(data.len(), WAL_PAGE_SIZE as usize); + data + } +} + +/// The borrowed version of Frame +#[repr(transparent)] +pub struct FrameBorrowed { + data: [u8], +} + +impl FrameBorrowed { + pub fn header(&self) -> Cow { + let data = &self.data[..size_of::()]; + try_from_bytes(data) + .map(Cow::Borrowed) + .unwrap_or_else(|_| Cow::Owned(pod_read_unaligned(data))) + } + + /// Returns the bytes for this frame. Includes the header bytes. + pub fn as_slice(&self) -> &[u8] { + &self.data + } + + pub fn from_bytes(data: &[u8]) -> &Self { + assert_eq!(data.len(), Frame::SIZE); + // SAFETY: &FrameBorrowed is equivalent to &[u8] + unsafe { transmute(data) } + } + + /// returns this frame's page data. + pub fn page(&self) -> &[u8] { + &self.data[size_of::()..] 
+ } +} + +impl Deref for Frame { + type Target = FrameBorrowed; + + fn deref(&self) -> &Self::Target { + FrameBorrowed::from_bytes(&self.data) + } +} diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs new file mode 100644 index 00000000..88632501 --- /dev/null +++ b/libsqlx/src/database/libsql/connection.rs @@ -0,0 +1,309 @@ +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Instant; + +use rusqlite::{OpenFlags, Statement, StatementStatus}; +use sqld_libsql_bindings::wal_hook::{WalHook, WalMethodsHook}; + +use crate::connection::{Connection, DescribeCol, DescribeParam, DescribeResponse}; +use crate::database::TXN_TIMEOUT; +use crate::error::Error; +use crate::program::{Cond, Program, Step}; +use crate::query::Query; +use crate::result_builder::{QueryBuilderConfig, ResultBuilder}; +use crate::seal::Seal; +use crate::Result; + +use super::RowStatsHandler; + +pub struct RowStats { + pub rows_read: u64, + pub rows_written: u64, +} + +impl From<&Statement<'_>> for RowStats { + fn from(stmt: &Statement) -> Self { + Self { + rows_read: stmt.get_status(StatementStatus::RowsRead) as u64, + rows_written: stmt.get_status(StatementStatus::RowsWritten) as u64, + } + } +} + +pub fn open_db<'a, W>( + path: &Path, + wal_methods: &'static WalMethodsHook, + hook_ctx: &'a mut W::Context, + flags: Option, +) -> std::result::Result, rusqlite::Error> +where + W: WalHook, +{ + let flags = flags.unwrap_or( + OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_URI + | OpenFlags::SQLITE_OPEN_NO_MUTEX, + ); + + sqld_libsql_bindings::Connection::open(path, flags, wal_methods, hook_ctx) +} + +pub struct LibsqlConnection { + timeout_deadline: Option, + conn: sqld_libsql_bindings::Connection<'static>, // holds a ref to _context, must be dropped first. 
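+    // NOTE: Rust drops struct fields in declaration order, so `conn` above is
+    // guaranteed to be dropped before `_context` below.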
+ row_stats_handler: Option>, + builder_config: QueryBuilderConfig, + _context: Seal>, +} + +impl LibsqlConnection { + pub(crate) fn new( + path: &Path, + extensions: Option>, + wal_methods: &'static WalMethodsHook, + hook_ctx: W::Context, + row_stats_callback: Option>, + builder_config: QueryBuilderConfig, + ) -> Result> { + let mut ctx = Box::new(hook_ctx); + let this = LibsqlConnection { + conn: open_db( + path, + wal_methods, + unsafe { &mut *(ctx.as_mut() as *mut _) }, + None, + )?, + timeout_deadline: None, + builder_config, + row_stats_handler: row_stats_callback, + _context: Seal::new(ctx), + }; + + if let Some(extensions) = extensions { + for ext in extensions.iter() { + unsafe { + let _guard = rusqlite::LoadExtensionGuard::new(&this.conn).unwrap(); + if let Err(e) = this.conn.load_extension(ext, None) { + tracing::error!("failed to load extension: {}", ext.display()); + Err(e)?; + } + tracing::debug!("Loaded extension {}", ext.display()); + } + } + } + + Ok(this) + } + + #[cfg(test)] + pub fn inner_connection(&self) -> &sqld_libsql_bindings::Connection<'static> { + &self.conn + } + + fn run(&mut self, pgm: Program, mut builder: B) -> Result { + let mut results = Vec::with_capacity(pgm.steps.len()); + + builder.init(&self.builder_config)?; + let is_autocommit_before = self.conn.is_autocommit(); + + for step in pgm.steps() { + let res = self.execute_step(step, &results, &mut builder)?; + results.push(res); + } + + // A transaction is still open, set up a timeout + if is_autocommit_before && !self.conn.is_autocommit() { + self.timeout_deadline = Some(Instant::now() + TXN_TIMEOUT) + } + + builder.finish(!self.conn.is_autocommit(), None)?; + + Ok(builder) + } + + fn execute_step( + &mut self, + step: &Step, + results: &[bool], + builder: &mut impl ResultBuilder, + ) -> Result { + builder.begin_step()?; + let mut enabled = match step.cond.as_ref() { + Some(cond) => match eval_cond(cond, results) { + Ok(enabled) => enabled, + Err(e) => { + builder.step_error(e).unwrap(); + false + } + }, + None => true, + }; + + let (affected_row_count, last_insert_rowid) = if enabled { + match self.execute_query(&step.query, builder) { + // builder error interupt the execution of query. we should exit immediately. + Err(e @ Error::BuilderError(_)) => return Err(e), + Err(e) => { + builder.step_error(e)?; + enabled = false; + (0, None) + } + Ok(x) => x, + } + } else { + (0, None) + }; + + builder.finish_step(affected_row_count, last_insert_rowid)?; + + Ok(enabled) + } + + fn execute_query( + &self, + query: &Query, + builder: &mut impl ResultBuilder, + ) -> Result<(u64, Option)> { + tracing::trace!("executing query: {}", query.stmt.stmt); + + let mut stmt = self.conn.prepare(&query.stmt.stmt)?; + + let cols = stmt.columns(); + let cols_count = cols.len(); + builder.cols_description(&mut cols.iter().map(Into::into))?; + drop(cols); + + query + .params + .bind(&mut stmt) + .map_err(Error::LibSqlInvalidQueryParams)?; + + let mut qresult = stmt.raw_query(); + builder.begin_rows()?; + while let Some(row) = qresult.next()? { + builder.begin_row()?; + for i in 0..cols_count { + let val = row.get_ref(i)?; + builder.add_row_value(val)?; + } + builder.finish_row()?; + } + + builder.finish_rows()?; + + // sqlite3_changes() is only modified for INSERT, UPDATE or DELETE; it is not reset for SELECT, + // but we want to return 0 in that case. 
+ let affected_row_count = match query.stmt.is_iud { + true => self.conn.changes(), + false => 0, + }; + + // sqlite3_last_insert_rowid() only makes sense for INSERTs into a rowid table. we can't detect + // a rowid table, but at least we can detect an INSERT + let last_insert_rowid = match query.stmt.is_insert { + true => Some(self.conn.last_insert_rowid()), + false => None, + }; + + drop(qresult); + + if let Some(ref handler) = self.row_stats_handler { + handler.handle_row_stats(RowStats::from(&stmt)) + } + + Ok((affected_row_count, last_insert_rowid)) + } +} + +fn eval_cond(cond: &Cond, results: &[bool]) -> Result { + let get_step_res = |step: usize| -> Result { + let res = results.get(step).ok_or(Error::InvalidBatchStep(step))?; + + Ok(*res) + }; + + Ok(match cond { + Cond::Ok { step } => get_step_res(*step)?, + Cond::Err { step } => !get_step_res(*step)?, + Cond::Not { cond } => !eval_cond(cond, results)?, + Cond::And { conds } => conds + .iter() + .try_fold(true, |x, cond| eval_cond(cond, results).map(|y| x & y))?, + Cond::Or { conds } => conds + .iter() + .try_fold(false, |x, cond| eval_cond(cond, results).map(|y| x | y))?, + }) +} + +impl Connection for LibsqlConnection { + fn execute_program(&mut self, pgm: Program, builder: B) -> crate::Result { + self.run(pgm, builder) + } + + fn describe(&self, sql: String) -> crate::Result { + let stmt = self.conn.prepare(&sql)?; + + let params = (1..=stmt.parameter_count()) + .map(|param_i| { + let name = stmt.parameter_name(param_i).map(|n| n.into()); + DescribeParam { name } + }) + .collect(); + + let cols = stmt + .columns() + .into_iter() + .map(|col| { + let name = col.name().into(); + let decltype = col.decl_type().map(|t| t.into()); + DescribeCol { name, decltype } + }) + .collect(); + + let is_explain = stmt.is_explain() != 0; + let is_readonly = stmt.readonly(); + + Ok(DescribeResponse { + params, + cols, + is_explain, + is_readonly, + }) + } +} + +#[cfg(test)] +mod test { + // use itertools::Itertools; + // + // use crate::result_builder::{test::test_driver, IgnoreResult}; + // + // use super::*; + + // fn setup_test_conn(ctx: &mut ()) -> Conn { + // let mut conn = Conn { + // timeout_deadline: None, + // conn: sqld_libsql_bindings::Connection::test(ctx), + // timed_out: false, + // builder_config: QueryBuilderConfig::default(), + // row_stats_callback: None, + // }; + // + // let stmts = std::iter::once("create table test (x)") + // .chain(std::iter::repeat("insert into test values ('hello world')").take(100)) + // .collect_vec(); + // conn.run(Program::seq(&stmts), IgnoreResult).unwrap(); + // + // conn + // } + // + // #[test] + // fn test_libsql_conn_builder_driver() { + // test_driver(1000, |b| { + // let ctx = &mut (); + // let mut conn = setup_test_conn(ctx); + // conn.run(Program::seq(&["select * from test"]), b) + // }) + // } +} diff --git a/libsqlx/src/database/libsql/injector/headers.rs b/libsqlx/src/database/libsql/injector/headers.rs new file mode 100644 index 00000000..0973d65b --- /dev/null +++ b/libsqlx/src/database/libsql/injector/headers.rs @@ -0,0 +1,47 @@ +use std::marker::PhantomData; + +use rusqlite::ffi::PgHdr; + +pub struct Headers<'a> { + ptr: *mut PgHdr, + _pth: PhantomData<&'a ()>, +} + +impl<'a> Headers<'a> { + // safety: ptr is guaranteed to be valid for 'a + pub(crate) unsafe fn new(ptr: *mut PgHdr) -> Self { + Self { + ptr, + _pth: PhantomData, + } + } + + pub(crate) fn as_ptr(&mut self) -> *mut PgHdr { + self.ptr + } + + pub(crate) fn all_applied(&self) -> bool { + let mut current = self.ptr; + while 
!current.is_null() {
+            unsafe {
+                // WAL appended
+                if (*current).flags & 0x040 == 0 {
+                    return false;
+                }
+                current = (*current).pDirty;
+            }
+        }
+
+        true
+    }
+}
+
+impl Drop for Headers<'_> {
+    fn drop(&mut self) {
+        let mut current = self.ptr;
+        while !current.is_null() {
+            let h: Box<PgHdr> = unsafe { Box::from_raw(current as _) };
+            current = h.pDirty;
+        }
+    }
+}
diff --git a/libsqlx/src/database/libsql/injector/hook.rs b/libsqlx/src/database/libsql/injector/hook.rs
new file mode 100644
index 00000000..0479fb2d
--- /dev/null
+++ b/libsqlx/src/database/libsql/injector/hook.rs
@@ -0,0 +1,175 @@
+use std::ffi::{c_int, CStr};
+
+use rusqlite::ffi::{libsql_wal as Wal, PgHdr};
+use sqld_libsql_bindings::ffi::types::XWalFrameFn;
+use sqld_libsql_bindings::init_static_wal_method;
+use sqld_libsql_bindings::wal_hook::WalHook;
+
+use crate::database::frame::FrameBorrowed;
+use crate::database::libsql::replication_log::WAL_PAGE_SIZE;
+
+use super::headers::Headers;
+use super::{FrameBuffer, InjectorCommitHandler};
+
+// These are custom error codes returned by the replicator hook.
+pub const LIBSQL_INJECT_FATAL: c_int = 200;
+/// Injection succeeded, left in an open txn state
+pub const LIBSQL_INJECT_OK_TXN: c_int = 201;
+/// Injection succeeded
+pub const LIBSQL_INJECT_OK: c_int = 202;
+
+pub struct InjectorHookCtx {
+    /// shared frame buffer
+    buffer: FrameBuffer,
+    /// currently in a txn
+    is_txn: bool,
+    commit_handler: Box<dyn InjectorCommitHandler>,
+}
+
+impl InjectorHookCtx {
+    pub fn new(
+        buffer: FrameBuffer,
+        injector_commit_handler: impl InjectorCommitHandler + 'static,
+    ) -> Self {
+        Self {
+            buffer,
+            is_txn: false,
+            commit_handler: Box::new(injector_commit_handler),
+        }
+    }
+
+    fn inject_pages(
+        &mut self,
+        sync_flags: i32,
+        orig: XWalFrameFn,
+        wal: *mut Wal,
+    ) -> anyhow::Result<()> {
+        self.is_txn = true;
+        let buffer = self.buffer.borrow();
+        let (mut headers, last_frame_no, size_after) =
+            make_page_header(buffer.iter().map(|f| &**f));
+        if size_after != 0 {
+            self.commit_handler.pre_commit(last_frame_no)?;
+        }
+
+        let ret = unsafe {
+            orig(
+                wal,
+                WAL_PAGE_SIZE,
+                headers.as_ptr(),
+                size_after,
+                (size_after != 0) as _,
+                sync_flags,
+            )
+        };
+
+        if ret == 0 {
+            debug_assert!(headers.all_applied());
+            drop(headers);
+            if size_after != 0 {
+                self.commit_handler.post_commit(last_frame_no)?;
+                self.is_txn = false;
+            }
+            tracing::trace!("applied frame batch");
+
+            Ok(())
+        } else {
+            anyhow::bail!("failed to apply pages");
+        }
+    }
+}
+
+/// Turn a list of `WalFrame` into a linked list of `PgHdr`.
+/// The caller has the responsibility to free the returned headers.
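+/// Freeing happens through the returned `Headers` wrapper, whose `Drop` impl walks
+/// the `pDirty` list and reclaims each boxed `PgHdr`.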
+/// Returns `(headers, last_frame_no, size_after)`.
+fn make_page_header<'a>(
+    frames: impl Iterator<Item = &'a FrameBorrowed>,
+) -> (Headers<'a>, u64, u32) {
+    let mut first_pg: *mut PgHdr = std::ptr::null_mut();
+    let mut current_pg;
+    let mut last_frame_no = 0;
+    let mut size_after = 0;
+
+    let mut headers_count = 0;
+    let mut prev_pg: *mut PgHdr = std::ptr::null_mut();
+    for frame in frames {
+        if frame.header().frame_no > last_frame_no {
+            last_frame_no = frame.header().frame_no;
+            size_after = frame.header().size_after;
+        }
+
+        let page = PgHdr {
+            pPage: std::ptr::null_mut(),
+            pData: frame.page().as_ptr() as _,
+            pExtra: std::ptr::null_mut(),
+            pCache: std::ptr::null_mut(),
+            pDirty: std::ptr::null_mut(),
+            pPager: std::ptr::null_mut(),
+            pgno: frame.header().page_no,
+            pageHash: 0,
+            flags: 0x02, // PGHDR_DIRTY - it works without the flag, but why risk it
+            nRef: 0,
+            pDirtyNext: std::ptr::null_mut(),
+            pDirtyPrev: std::ptr::null_mut(),
+        };
+        headers_count += 1;
+        current_pg = Box::into_raw(Box::new(page));
+        if first_pg.is_null() {
+            first_pg = current_pg;
+        }
+        if !prev_pg.is_null() {
+            unsafe {
+                (*prev_pg).pDirty = current_pg;
+            }
+        }
+        prev_pg = current_pg;
+    }
+
+    tracing::trace!("built {headers_count} page headers");
+
+    let headers = unsafe { Headers::new(first_pg) };
+    (headers, last_frame_no, size_after)
+}
+
+init_static_wal_method!(INJECTOR_METHODS, InjectorHook);
+
+/// The injector hook hijacks a call to xFrames, and replaces the content of the call with its own
+/// frames.
+/// The caller must first call `set_frames`, passing the frames to be injected, then trigger a call
+/// to xFrames from the libsql connection (see the dummy write in `injector`), and can then collect
+/// the result of the injection with `take_result`.
+pub enum InjectorHook {}
+
+unsafe impl WalHook for InjectorHook {
+    type Context = InjectorHookCtx;
+
+    fn on_frames(
+        wal: &mut Wal,
+        _page_size: c_int,
+        _page_headers: *mut PgHdr,
+        _size_after: u32,
+        _is_commit: c_int,
+        sync_flags: c_int,
+        orig: XWalFrameFn,
+    ) -> c_int {
+        let wal_ptr = wal as *mut _;
+        let ctx = Self::wal_extract_ctx(wal);
+        let ret = ctx.inject_pages(sync_flags, orig, wal_ptr);
+        if let Err(e) = ret {
+            tracing::error!("fatal replication error: {e}");
+            return LIBSQL_INJECT_FATAL;
+        }
+
+        ctx.buffer.borrow_mut().clear();
+
+        if !ctx.is_txn {
+            LIBSQL_INJECT_OK
+        } else {
+            LIBSQL_INJECT_OK_TXN
+        }
+    }
+
+    fn name() -> &'static CStr {
+        CStr::from_bytes_with_nul(b"frame_injector_hook\0").unwrap()
+    }
+}
diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs
new file mode 100644
index 00000000..df01cd34
--- /dev/null
+++ b/libsqlx/src/database/libsql/injector/mod.rs
@@ -0,0 +1,233 @@
+use std::cell::RefCell;
+use std::collections::VecDeque;
+use std::path::Path;
+use std::rc::Rc;
+
+use rusqlite::OpenFlags;
+
+use crate::database::frame::Frame;
+use crate::database::libsql::injector::hook::{
+    INJECTOR_METHODS, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN,
+};
+use crate::database::FrameNo;
+use crate::seal::Seal;
+
+use hook::InjectorHookCtx;
+
+mod headers;
+mod hook;
+
+pub type FrameBuffer = Rc<RefCell<VecDeque<Frame>>>;
+
+pub struct Injector {
+    /// The injector is in a transaction state
+    is_txn: bool,
+    /// Buffer for holding current transaction frames
+    buffer: FrameBuffer,
+    /// Maximum capacity of the frame buffer
+    capacity: usize,
+    /// Injector connection
+    // connection must be dropped before the hook context
+    connection: sqld_libsql_bindings::Connection<'static>,
+    /// Pointer to the hook
+    _hook_ctx: Seal<Box<InjectorHookCtx>>,
+}
+
+/// Methods from this trait are called before and after performing a frame injection.
+/// This trait is used to record the last committed frame_no to the log.
+/// The implementer can persist the pre and post commit frame no, and compare them in the event of
+/// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted.
+pub trait InjectorCommitHandler: 'static {
+    fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>;
+    fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>;
+}
+
+#[cfg(test)]
+impl InjectorCommitHandler for () {
+    fn pre_commit(&mut self, _frame_no: FrameNo) -> anyhow::Result<()> {
+        Ok(())
+    }
+
+    fn post_commit(&mut self, _frame_no: FrameNo) -> anyhow::Result<()> {
+        Ok(())
+    }
+}
+
+impl Injector {
+    pub fn new(
+        path: &Path,
+        injector_commit_handler: impl InjectorCommitHandler + 'static,
+        buffer_capacity: usize,
+    ) -> crate::Result<Self> {
+        let buffer = FrameBuffer::default();
+        let ctx = InjectorHookCtx::new(buffer.clone(), injector_commit_handler);
+        let mut ctx = Box::new(ctx);
+        let connection = sqld_libsql_bindings::Connection::open(
+            path,
+            OpenFlags::SQLITE_OPEN_READ_WRITE
+                | OpenFlags::SQLITE_OPEN_CREATE
+                | OpenFlags::SQLITE_OPEN_URI
+                | OpenFlags::SQLITE_OPEN_NO_MUTEX,
+            &INJECTOR_METHODS,
+            // safety: hook is dropped after connection
+            unsafe { &mut *(ctx.as_mut() as *mut _) },
+        )?;
+
+        Ok(Self {
+            is_txn: false,
+            buffer,
+            capacity: buffer_capacity,
+            connection,
+            _hook_ctx: Seal::new(ctx),
+        })
+    }
+
+    /// Inject one frame into the log. If this was a commit frame, returns `Ok(Some(frame_no))`.
+    pub(crate) fn inject_frame(&mut self, frame: Frame) -> crate::Result<Option<FrameNo>> {
+        let frame_close_txn = frame.header().size_after != 0;
+        self.buffer.borrow_mut().push_back(frame);
+        if frame_close_txn || self.buffer.borrow().len() >= self.capacity {
+            if !self.is_txn {
+                self.begin_txn();
+            }
+            return self.flush();
+        }
+
+        Ok(None)
+    }
+
+    /// Flush the buffer to the libsql WAL.
+    /// Triggers a dummy write and flushes the cache to force a call to xFrames; the buffer's
+    /// frames are then injected into the WAL.
+    fn flush(&mut self) -> crate::Result<Option<FrameNo>> {
+        let last_frame_no = match self.buffer.borrow().back() {
+            Some(f) => f.header().frame_no,
+            None => {
+                tracing::trace!("nothing to inject");
+                return Ok(None);
+            }
+        };
+        self.connection
+            .execute("INSERT INTO __DUMMY__ VALUES (42)", ())?;
+        // force call to xframe
+        match self.connection.cache_flush() {
+            Ok(_) => panic!("replication hook was not called"),
+            Err(e) => {
+                if let Some(e) = e.sqlite_error() {
+                    if e.extended_code == LIBSQL_INJECT_OK {
+                        // refresh schema
+                        self.connection
+                            .pragma_update(None, "writable_schema", "reset")?;
+                        self.commit();
+                        self.is_txn = false;
+                        assert!(self.buffer.borrow().is_empty());
+                        return Ok(Some(last_frame_no));
+                    } else if e.extended_code == LIBSQL_INJECT_OK_TXN {
+                        self.is_txn = true;
+                        assert!(self.buffer.borrow().is_empty());
+                        return Ok(None);
+                    } else if e.extended_code == LIBSQL_INJECT_FATAL {
+                        todo!("handle fatal error");
+                    }
+                }
+
+                todo!("handle fatal error");
+            }
+        }
+    }
+
+    fn commit(&mut self) {
+        // TODO: error?
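+        // (If the COMMIT fails, the connection stays inside the transaction, and the
+        // next `begin_txn` will error on BEGIN IMMEDIATE.)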
+ let _ = self.connection.execute("COMMIT", ()); + } + + fn begin_txn(&mut self) { + self.connection.execute("BEGIN IMMEDIATE", ()).unwrap(); + self.connection + .execute("CREATE TABLE __DUMMY__ (__dummy__)", ()) + .unwrap(); + } +} + +#[cfg(test)] +mod test { + use std::fs::File; + + use crate::database::libsql::injector::Injector; + use crate::database::libsql::replication_log::logger::LogFile; + + #[test] + fn test_simple_inject_frames() { + let file = File::open("assets/test/simple_wallog").unwrap(); + let log = LogFile::new(file).unwrap(); + let temp = tempfile::tempdir().unwrap(); + + let mut injector = Injector::new(temp.path(), (), 10).unwrap(); + for frame in log.frames_iter().unwrap() { + let frame = frame.unwrap(); + injector.inject_frame(frame).unwrap(); + } + + let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); + + conn.query_row("SELECT COUNT(*) FROM test", (), |row| { + assert_eq!(row.get::<_, usize>(0).unwrap(), 5); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_inject_frames_split_txn() { + let file = File::open("assets/test/simple_wallog").unwrap(); + let log = LogFile::new(file).unwrap(); + let temp = tempfile::tempdir().unwrap(); + + // inject one frame at a time + let mut injector = Injector::new(temp.path(), (), 1).unwrap(); + for frame in log.frames_iter().unwrap() { + let frame = frame.unwrap(); + injector.inject_frame(frame).unwrap(); + } + + let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); + + conn.query_row("SELECT COUNT(*) FROM test", (), |row| { + assert_eq!(row.get::<_, usize>(0).unwrap(), 5); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_inject_partial_txn_isolated() { + let file = File::open("assets/test/simple_wallog").unwrap(); + let log = LogFile::new(file).unwrap(); + let temp = tempfile::tempdir().unwrap(); + + // inject one frame at a time + let mut injector = Injector::new(temp.path(), (), 10).unwrap(); + let mut iter = log.frames_iter().unwrap(); + + assert!(injector + .inject_frame(iter.next().unwrap().unwrap()) + .unwrap() + .is_none()); + let conn = rusqlite::Connection::open(temp.path().join("data")).unwrap(); + assert!(conn + .query_row("SELECT COUNT(*) FROM test", (), |_| Ok(())) + .is_err()); + + while injector + .inject_frame(iter.next().unwrap().unwrap()) + .unwrap() + .is_none() + {} + + // reset schema + conn.pragma_update(None, "writable_schema", "reset") + .unwrap(); + conn.query_row("SELECT COUNT(*) FROM test", (), |_| Ok(())) + .unwrap(); + } +} diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs new file mode 100644 index 00000000..e5f4ad0a --- /dev/null +++ b/libsqlx/src/database/libsql/mod.rs @@ -0,0 +1,365 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalHook, TRANSPARENT_METHODS}; +use sqld_libsql_bindings::WalMethodsHook; + +use crate::database::frame::Frame; +use crate::database::{Database, InjectError, InjectableDatabase}; +use crate::error::Error; +use crate::result_builder::QueryBuilderConfig; + +use connection::{LibsqlConnection, RowStats}; +use injector::Injector; +use replication_log::logger::{ + ReplicationLogger, ReplicationLoggerHook, ReplicationLoggerHookCtx, REPLICATION_METHODS, +}; + +use self::injector::InjectorCommitHandler; +use self::replication_log::logger::LogCompactor; + +pub use replication_log::merger::SnapshotMerger; + +mod connection; +mod injector; +pub(crate) mod replication_log; + +pub struct PrimaryType { + logger: Arc, +} + +impl LibsqlDbType for 
PrimaryType {
+    type ConnectionHook = ReplicationLoggerHook;
+
+    fn hook() -> &'static WalMethodsHook<Self::ConnectionHook> {
+        &REPLICATION_METHODS
+    }
+
+    fn hook_context(&self) -> <Self::ConnectionHook as WalHook>::Context {
+        ReplicationLoggerHookCtx {
+            buffer: Vec::new(),
+            logger: self.logger.clone(),
+        }
+    }
+}
+
+pub struct ReplicaType {
+    // frame injector for the database
+    injector: Injector,
+}
+
+impl LibsqlDbType for ReplicaType {
+    type ConnectionHook = TransparentMethods;
+
+    fn hook() -> &'static WalMethodsHook<Self::ConnectionHook> {
+        &TRANSPARENT_METHODS
+    }
+
+    fn hook_context(&self) -> <Self::ConnectionHook as WalHook>::Context {}
+}
+
+pub trait LibsqlDbType {
+    type ConnectionHook: WalHook;
+
+    /// Return a static reference to the instantiated WAL hook
+    fn hook() -> &'static WalMethodsHook<Self::ConnectionHook>;
+    /// Returns a new context for the WAL hook
+    fn hook_context(&self) -> <Self::ConnectionHook as WalHook>::Context;
+}
+
+/// A generic wrapper around a libsql database.
+/// `LibsqlDatabase` can be specialized into either a `ReplicaType` or a `PrimaryType`.
+/// In `PrimaryType` mode, the LibsqlDatabase maintains a replication log that can be replicated to
+/// a `LibsqlDatabase` in replica mode, thanks to the methods provided by `InjectableDatabase`
+/// implemented for `LibsqlDatabase`.
+pub struct LibsqlDatabase<T> {
+    /// The connection factory for this database
+    db_path: PathBuf,
+    extensions: Option<Vec<PathBuf>>,
+    response_size_limit: u64,
+    row_stats_callback: Option<Arc<dyn RowStatsHandler>>,
+    /// type-specific data for the database
+    ty: T,
+}
+
+/// Handler trait for gathering row stats when executing queries.
+pub trait RowStatsHandler {
+    fn handle_row_stats(&self, stats: RowStats);
+}
+
+impl<F> RowStatsHandler for F
+where
+    F: Fn(RowStats),
+{
+    fn handle_row_stats(&self, stats: RowStats) {
+        (self)(stats)
+    }
+}
+
+impl LibsqlDatabase<ReplicaType> {
+    /// Creates a new replica type database
+    pub fn new_replica(
+        db_path: PathBuf,
+        injector_buffer_capacity: usize,
+        injector_commit_handler: impl InjectorCommitHandler,
+    ) -> crate::Result<Self> {
+        let ty = ReplicaType {
+            injector: Injector::new(&db_path, injector_commit_handler, injector_buffer_capacity)?,
+        };
+
+        Ok(Self::new(db_path, ty))
+    }
+}
+
+impl LibsqlDatabase<PrimaryType> {
+    pub fn new_primary(
+        db_path: PathBuf,
+        compactor: impl LogCompactor,
+        // whether the log is dirty and might need repair
+        dirty: bool,
+    ) -> crate::Result<Self> {
+        let ty = PrimaryType {
+            logger: Arc::new(ReplicationLogger::open(&db_path, dirty, compactor)?),
+        };
+        Ok(Self::new(db_path, ty))
+    }
+}
+
+impl<T: LibsqlDbType> LibsqlDatabase<T> {
+    /// Create a new instance with the passed `LibsqlDbType`.
+    fn new(db_path: PathBuf, ty: T) -> Self {
+        Self {
+            db_path,
+            extensions: None,
+            response_size_limit: u64::MAX,
+            row_stats_callback: None,
+            ty,
+        }
+    }
+
+    /// Load extensions for connections to this database.
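+    ///
+    /// A minimal usage sketch (the extension path and the `compactor` value are
+    /// illustrative, not part of this patch):
+    ///
+    /// ```ignore
+    /// let db = LibsqlDatabase::new_primary(db_path, compactor, false)?
+    ///     .with_extensions(vec![PathBuf::from("./ext/vector0.so")]);
+    /// ```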
+ pub fn with_extensions(mut self, ext: impl IntoIterator) -> Self { + self.extensions = Some(ext.into_iter().collect()); + self + } + + /// Register a callback + pub fn with_row_stats_handler(mut self, handler: Arc) -> Self { + self.row_stats_callback = Some(handler); + self + } +} + +impl Database for LibsqlDatabase { + type Connection = LibsqlConnection<::Context>; + + fn connect(&self) -> Result { + LibsqlConnection::<::Context>::new( + &self.db_path, + self.extensions.clone(), + T::hook(), + self.ty.hook_context(), + self.row_stats_callback.clone(), + QueryBuilderConfig { + max_size: Some(self.response_size_limit), + }, + ) + } +} + +impl InjectableDatabase for LibsqlDatabase { + fn inject_frame(&mut self, frame: Frame) -> Result<(), InjectError> { + self.ty.injector.inject_frame(frame).unwrap(); + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::cell::Cell; + use std::fs::File; + use std::rc::Rc; + + use rusqlite::types::Value; + + use crate::connection::Connection; + use crate::database::libsql::replication_log::logger::LogFile; + use crate::program::Program; + use crate::result_builder::{QueryResultBuilderError, ResultBuilder}; + + use super::*; + + struct ReadRowBuilder(Vec); + + impl ResultBuilder for ReadRowBuilder { + fn add_row_value( + &mut self, + v: rusqlite::types::ValueRef, + ) -> Result<(), QueryResultBuilderError> { + self.0.push(v.into()); + Ok(()) + } + } + + #[test] + fn inject_libsql_db() { + let temp = tempfile::tempdir().unwrap(); + let replica = ReplicaType { + injector: Injector::new(temp.path(), (), 10).unwrap(), + }; + let mut db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); + + let mut conn = db.connect().unwrap(); + let res = conn + .execute_program( + Program::seq(&["select count(*) from test"]), + ReadRowBuilder(Vec::new()), + ) + .unwrap(); + assert!(res.0.is_empty()); + + let file = File::open("assets/test/simple_wallog").unwrap(); + let log = LogFile::new(file).unwrap(); + log.frames_iter() + .unwrap() + .for_each(|f| db.inject_frame(f.unwrap()).unwrap()); + + let res = conn + .execute_program( + Program::seq(&["select count(*) from test"]), + ReadRowBuilder(Vec::new()), + ) + .unwrap(); + assert_eq!(res.0[0], Value::Integer(5)); + } + + #[test] + fn roundtrip_primary_replica() { + let temp_primary = tempfile::tempdir().unwrap(); + let temp_replica = tempfile::tempdir().unwrap(); + + let primary = LibsqlDatabase::new( + temp_primary.path().to_path_buf(), + PrimaryType { + logger: Arc::new(ReplicationLogger::open(temp_primary.path(), false, ()).unwrap()), + }, + ); + + let mut replica = LibsqlDatabase::new( + temp_replica.path().to_path_buf(), + ReplicaType { + injector: Injector::new(temp_replica.path(), (), 10).unwrap(), + }, + ); + + let mut primary_conn = primary.connect().unwrap(); + primary_conn + .execute_program( + Program::seq(&["create table test (x)", "insert into test values (42)"]), + (), + ) + .unwrap(); + + let logfile = primary.ty.logger.log_file.read(); + + for frame in logfile.frames_iter().unwrap() { + let frame = frame.unwrap(); + replica.inject_frame(frame).unwrap(); + } + + let mut replica_conn = replica.connect().unwrap(); + let result = replica_conn + .execute_program( + Program::seq(&["select * from test limit 1"]), + ReadRowBuilder(Vec::new()), + ) + .unwrap(); + + assert_eq!(result.0.len(), 1); + assert_eq!(result.0[0], Value::Integer(42)); + } + + #[test] + fn primary_compact_log() { + struct Compactor(Rc>); + + impl LogCompactor for Compactor { + fn should_compact(&self, log: &LogFile) -> bool { + 
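+                // compact as soon as the log holds more than two frames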
log.header().frame_count > 2 + } + + fn compact( + &self, + _file: LogFile, + _path: PathBuf, + _size_after: u32, + ) -> anyhow::Result<()> { + self.0.set(true); + Ok(()) + } + } + + let temp = tempfile::tempdir().unwrap(); + let compactor_called = Rc::new(Cell::new(false)); + let db = LibsqlDatabase::new_primary( + temp.path().to_path_buf(), + Compactor(compactor_called.clone()), + false, + ) + .unwrap(); + + let mut conn = db.connect().unwrap(); + conn.execute_program( + Program::seq(&["create table test (x)", "insert into test values (12)"]), + (), + ) + .unwrap(); + assert!(compactor_called.get()); + } + + #[test] + fn no_compaction_uncommited_frames() { + struct Compactor(Rc>); + + impl LogCompactor for Compactor { + fn should_compact(&self, log: &LogFile) -> bool { + assert_eq!(log.uncommitted_frame_count, 0); + self.0.set(true); + false + } + + fn compact( + &self, + _file: LogFile, + _path: PathBuf, + _size_after: u32, + ) -> anyhow::Result<()> { + unreachable!() + } + } + + let temp = tempfile::tempdir().unwrap(); + let compactor_called = Rc::new(Cell::new(false)); + let db = LibsqlDatabase::new_primary( + temp.path().to_path_buf(), + Compactor(compactor_called.clone()), + false, + ) + .unwrap(); + + let mut conn = db.connect().unwrap(); + conn.execute_program( + Program::seq(&[ + "begin", + "create table test (x)", + "insert into test values (12)", + ]), + (), + ) + .unwrap(); + conn.inner_connection().cache_flush().unwrap(); + assert!(!compactor_called.get()); + conn.execute_program(Program::seq(&["commit"]), ()).unwrap(); + assert!(compactor_called.get()); + } +} diff --git a/libsqlx/src/database/libsql/replication_log/frame_stream.rs b/libsqlx/src/database/libsql/replication_log/frame_stream.rs new file mode 100644 index 00000000..79436f99 --- /dev/null +++ b/libsqlx/src/database/libsql/replication_log/frame_stream.rs @@ -0,0 +1,114 @@ +use std::sync::Arc; +use std::task::{ready, Poll}; +use std::{pin::Pin, task::Context}; + +use futures::future::BoxFuture; +use futures::Stream; + +use crate::database::frame::Frame; + +use super::FrameNo; +use super::logger::{ReplicationLogger, LogReadError}; + + +/// Streams frames from the replication log starting at `current_frame_no`. +/// Only stops if the current frame is not in the log anymore. 
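+///
+/// A minimal usage sketch, assuming a `logger: Arc<ReplicationLogger>` is at hand and
+/// `futures::StreamExt` is in scope:
+///
+/// ```ignore
+/// let mut stream = FrameStream::new(logger.clone(), 0);
+/// while let Some(res) = stream.next().await {
+///     let frame = res?;
+///     // forward `frame` to the replica...
+/// }
+/// ```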
+pub struct FrameStream { + next_frame_no: FrameNo, + max_available_frame_no: FrameNo, + logger: Arc, + state: FrameStreamState, +} + +impl FrameStream { + pub fn new(logger: Arc, next_frame_no: FrameNo) -> Self { + let max_available_frame_no = *logger.new_frame_notifier.subscribe().borrow(); + Self { + next_frame_no, + max_available_frame_no, + logger, + state: FrameStreamState::Init, + } + } + + fn transition_state_next_frame(&mut self) { + if matches!(self.state, FrameStreamState::Closed) { + return; + } + + let next_frameno = self.next_frame_no; + let logger = self.logger.clone(); + let fut = async move { + let res = tokio::task::spawn_blocking(move || logger.get_frame(next_frameno)).await; + match res { + Ok(Ok(frame)) => Ok(frame), + Ok(Err(e)) => Err(e), + Err(e) => Err(LogReadError::Error(e.into())), + } + }; + + self.state = FrameStreamState::WaitingFrame(Box::pin(fut)); + } +} + +enum FrameStreamState { + Init, + /// waiting for new frames to replicate + WaitingFrameNo(BoxFuture<'static, anyhow::Result>), + WaitingFrame(BoxFuture<'static, Result>), + Closed, +} + +impl Stream for FrameStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.state { + FrameStreamState::Init => { + self.transition_state_next_frame(); + self.poll_next(cx) + } + FrameStreamState::WaitingFrameNo(ref mut fut) => { + self.max_available_frame_no = match ready!(fut.as_mut().poll(cx)) { + Ok(frame_no) => frame_no, + Err(e) => { + self.state = FrameStreamState::Closed; + return Poll::Ready(Some(Err(LogReadError::Error(e)))); + } + }; + self.transition_state_next_frame(); + self.poll_next(cx) + } + FrameStreamState::WaitingFrame(ref mut fut) => match ready!(fut.as_mut().poll(cx)) { + Ok(frame) => { + self.next_frame_no += 1; + self.transition_state_next_frame(); + Poll::Ready(Some(Ok(frame))) + } + + Err(LogReadError::Ahead) => { + let mut notifier = self.logger.new_frame_notifier.subscribe(); + let max_available_frame_no = *notifier.borrow(); + // check in case value has already changed, otherwise we'll be notified later + if max_available_frame_no > self.max_available_frame_no { + self.max_available_frame_no = max_available_frame_no; + self.transition_state_next_frame(); + self.poll_next(cx) + } else { + let fut = async move { + notifier.changed().await?; + Ok(*notifier.borrow()) + }; + self.state = FrameStreamState::WaitingFrameNo(Box::pin(fut)); + self.poll_next(cx) + } + } + Err(e) => { + self.state = FrameStreamState::Closed; + Poll::Ready(Some(Err(e))) + } + }, + FrameStreamState::Closed => Poll::Ready(None), + } + } +} diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs new file mode 100644 index 00000000..25546e36 --- /dev/null +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -0,0 +1,1022 @@ +use std::ffi::{c_int, c_void, CStr}; +use std::fs::{remove_dir_all, File, OpenOptions}; +use std::io::Write; +use std::mem::size_of; +use std::os::unix::prelude::FileExt; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use anyhow::{bail, ensure}; +use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable}; +use bytes::{Bytes, BytesMut}; +use parking_lot::RwLock; +use rusqlite::ffi::{ + libsql_wal as Wal, sqlite3, PgHdr, SQLITE_CHECKPOINT_TRUNCATE, SQLITE_IOERR, SQLITE_OK, +}; +use sqld_libsql_bindings::ffi::types::{ + XWalCheckpointFn, XWalFrameFn, XWalSavePointUndoFn, XWalUndoFn, +}; +use sqld_libsql_bindings::ffi::PageHdrIter; +use 
diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs
new file mode 100644
index 00000000..25546e36
--- /dev/null
+++ b/libsqlx/src/database/libsql/replication_log/logger.rs
@@ -0,0 +1,1022 @@
+use std::ffi::{c_int, c_void, CStr};
+use std::fs::{remove_dir_all, File, OpenOptions};
+use std::io::Write;
+use std::mem::size_of;
+use std::os::unix::prelude::FileExt;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
+
+use anyhow::{bail, ensure};
+use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable};
+use bytes::{Bytes, BytesMut};
+use parking_lot::RwLock;
+use rusqlite::ffi::{
+    libsql_wal as Wal, sqlite3, PgHdr, SQLITE_CHECKPOINT_TRUNCATE, SQLITE_IOERR, SQLITE_OK,
+};
+use sqld_libsql_bindings::ffi::types::{
+    XWalCheckpointFn, XWalFrameFn, XWalSavePointUndoFn, XWalUndoFn,
+};
+use sqld_libsql_bindings::ffi::PageHdrIter;
+use sqld_libsql_bindings::init_static_wal_method;
+use sqld_libsql_bindings::wal_hook::WalHook;
+use tokio::sync::watch;
+use uuid::Uuid;
+
+use crate::database::frame::{Frame, FrameHeader};
+#[cfg(feature = "bottomless")]
+use crate::libsql::ffi::SQLITE_IOERR_WRITE;
+
+use super::snapshot::{find_snapshot_file, SnapshotFile};
+use super::{FrameNo, CRC_64_GO_ISO, WAL_MAGIC, WAL_PAGE_SIZE};
+
+init_static_wal_method!(REPLICATION_METHODS, ReplicationLoggerHook);
+
+#[derive(PartialEq, Eq)]
+struct Version([u16; 4]);
+
+impl Version {
+    fn current() -> Self {
+        let major = env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap();
+        let minor = env!("CARGO_PKG_VERSION_MINOR").parse().unwrap();
+        let patch = env!("CARGO_PKG_VERSION_PATCH").parse().unwrap();
+        Self([0, major, minor, patch])
+    }
+}
+
+pub enum ReplicationLoggerHook {}
+
+#[derive(Clone)]
+pub struct ReplicationLoggerHookCtx {
+    pub(crate) buffer: Vec<WalPage>,
+    pub(crate) logger: Arc<ReplicationLogger>,
+    #[cfg(feature = "bottomless")]
+    bottomless_replicator: Option<Arc<std::sync::Mutex<bottomless::replicator::Replicator>>>,
+}
+
+/// This implementation of WalHook intercepts calls to `on_frames`, and writes them to a
+/// shadow wal. Writing to the shadow wal is done in three steps:
+/// i. append the new pages at the offset pointed to by header.start_frame_no + header.frame_count
+/// ii. call the underlying implementation of on_frames
+/// iii. if the call to the underlying method was successful, update the log header to the new
+/// frame count.
+///
+/// If writing either to the database or to the shadow wal fails, the whole operation must be a
+/// no-op.
+unsafe impl WalHook for ReplicationLoggerHook {
+    type Context = ReplicationLoggerHookCtx;
+
+    fn name() -> &'static CStr {
+        CStr::from_bytes_with_nul(b"replication_logger_hook\0").unwrap()
+    }
+
+    fn on_frames(
+        wal: &mut Wal,
+        page_size: c_int,
+        page_headers: *mut PgHdr,
+        ntruncate: u32,
+        is_commit: c_int,
+        sync_flags: c_int,
+        orig: XWalFrameFn,
+    ) -> c_int {
+        assert_eq!(page_size, 4096);
+        let wal_ptr = wal as *mut _;
+        #[cfg(feature = "bottomless")]
+        let last_valid_frame = wal.hdr.mxFrame;
+        #[cfg(feature = "bottomless")]
+        let _frame_checksum = wal.hdr.aFrameCksum;
+        let ctx = Self::wal_extract_ctx(wal);
+
+        for (page_no, data) in PageHdrIter::new(page_headers, page_size as _) {
+            ctx.write_frame(page_no, data)
+        }
+        if let Err(e) = ctx.flush(ntruncate) {
+            tracing::error!("error writing to replication log: {e}");
+            // returning IO_ERR ensures that xUndo will be called by sqlite.
+            return SQLITE_IOERR;
+        }
+
+        let rc = unsafe {
+            orig(
+                wal_ptr,
+                page_size,
+                page_headers,
+                ntruncate,
+                is_commit,
+                sync_flags,
+            )
+        };
+
+        // FIXME: instead of block_on, we should consider replicating asynchronously in the background,
+        // e.g. by sending the data to another fiber by an unbounded channel (which allows sync insertions).
+        #[allow(clippy::await_holding_lock)] // uncontended -> only gets called under a libSQL write lock
+        #[cfg(feature = "bottomless")]
+        if rc == 0 {
+            let runtime = tokio::runtime::Handle::current();
+            if let Some(replicator) = ctx.bottomless_replicator.as_mut() {
+                match runtime.block_on(async move {
+                    let mut replicator = replicator.lock().unwrap();
+                    replicator.register_last_valid_frame(last_valid_frame);
+                    // In theory it's enough to set the page size only once, but in practice
+                    // it's a very cheap operation anyway, and the page size is not always known
+                    // upfront and can change dynamically.
+                    // FIXME: changing the page size in the middle of operation is *not*
+                    // supported by bottomless storage.
+ replicator.set_page_size(page_size as usize)?; + let frame_count = PageHdrIter::new(page_headers, page_size as usize).count(); + replicator.submit_frames(frame_count as u32); + Ok::<(), anyhow::Error>(()) + }) { + Ok(()) => {} + Err(e) => { + tracing::error!("error writing to bottomless: {e}"); + return SQLITE_IOERR_WRITE; + } + } + } + } + + if is_commit != 0 && rc == 0 { + if let Err(e) = ctx.commit() { + // If we reach this point, it means that we have commited a transaction to sqlite wal, + // but failed to commit it to the shadow WAL, which leaves us in an inconsistent state. + tracing::error!( + "fatal error: log failed to commit: inconsistent replication log: {e}" + ); + std::process::abort(); + } + + if let Err(e) = ctx.logger.log_file.write().maybe_compact( + &*ctx.logger.compactor, + ntruncate, + &ctx.logger.db_path, + ) { + tracing::error!("fatal error: {e}, exiting"); + std::process::abort() + } + } + + rc + } + + fn on_undo( + wal: &mut Wal, + func: Option i32>, + undo_ctx: *mut c_void, + orig: XWalUndoFn, + ) -> i32 { + let ctx = Self::wal_extract_ctx(wal); + ctx.rollback(); + + #[cfg(feature = "bottomless")] + tracing::error!( + "fixme: implement bottomless undo for {:?}", + ctx.bottomless_replicator + ); + + unsafe { orig(wal, func, undo_ctx) } + } + + fn on_savepoint_undo(wal: &mut Wal, wal_data: *mut u32, orig: XWalSavePointUndoFn) -> i32 { + let rc = unsafe { orig(wal, wal_data) }; + if rc != SQLITE_OK { + return rc; + }; + + #[cfg(feature = "bottomless")] + { + let ctx = Self::wal_extract_ctx(wal); + if let Some(replicator) = ctx.bottomless_replicator.as_mut() { + let last_valid_frame = unsafe { *wal_data }; + let mut replicator = replicator.lock().unwrap(); + let prev_valid_frame = replicator.peek_last_valid_frame(); + tracing::trace!( + "Savepoint: rolling back from frame {prev_valid_frame} to {last_valid_frame}", + ); + replicator.rollback_to_frame(last_valid_frame); + } + } + + rc + } + + #[allow(clippy::too_many_arguments)] + fn on_checkpoint( + wal: &mut Wal, + db: *mut sqlite3, + emode: i32, + busy_handler: Option i32>, + busy_arg: *mut c_void, + sync_flags: i32, + n_buf: i32, + z_buf: *mut u8, + frames_in_wal: *mut i32, + backfilled_frames: *mut i32, + orig: XWalCheckpointFn, + ) -> i32 { + #[cfg(feature = "bottomless")] + { + tracing::trace!("bottomless checkpoint"); + + /* In order to avoid partial checkpoints, passive checkpoint + ** mode is not allowed. Only TRUNCATE checkpoints are accepted, + ** because these are guaranteed to block writes, copy all WAL pages + ** back into the main database file and reset the frame number. + ** In order to avoid autocheckpoint on close (that's too often), + ** checkpoint attempts weaker than TRUNCATE are ignored. 
+ */ + if emode < SQLITE_CHECKPOINT_TRUNCATE { + tracing::trace!("Ignoring a checkpoint request weaker than TRUNCATE"); + return SQLITE_OK; + } + } + let rc = unsafe { + orig( + wal, + db, + emode, + busy_handler, + busy_arg, + sync_flags, + n_buf, + z_buf, + frames_in_wal, + backfilled_frames, + ) + }; + + if rc != SQLITE_OK { + return rc; + } + + #[allow(clippy::await_holding_lock)] // uncontended -> only gets called under a libSQL write lock + #[cfg(feature = "bottomless")] + { + let ctx = Self::wal_extract_ctx(wal); + let runtime = tokio::runtime::Handle::current(); + if let Some(replicator) = ctx.bottomless_replicator.as_mut() { + let mut replicator = replicator.lock().unwrap(); + if replicator.commits_in_current_generation() == 0 { + tracing::debug!("No commits happened in this generation, not snapshotting"); + return SQLITE_OK; + } + let last_known_frame = replicator.last_known_frame(); + replicator.request_flush(); + if let Err(e) = runtime.block_on(replicator.wait_until_committed(last_known_frame)) + { + tracing::error!( + "Failed to wait for S3 replicator to confirm {} frames backup: {}", + last_known_frame, + e + ); + return SQLITE_IOERR_WRITE; + } + replicator.new_generation(); + if let Err(e) = + runtime.block_on(async move { replicator.snapshot_main_db_file().await }) + { + tracing::error!("Failed to snapshot the main db file during checkpoint: {e}"); + return SQLITE_IOERR_WRITE; + } + } + } + SQLITE_OK + } +} + +#[derive(Clone)] +pub struct WalPage { + pub page_no: u32, + /// 0 for non-commit frames + pub size_after: u32, + pub data: Bytes, +} + +impl ReplicationLoggerHookCtx { + pub fn new( + logger: Arc, + #[cfg(feature = "bottomless")] bottomless_replicator: Option< + Arc>, + >, + ) -> Self { + #[cfg(feature = "bottomless")] + tracing::trace!("bottomless replication enabled: {bottomless_replicator:?}"); + Self { + buffer: Default::default(), + logger, + #[cfg(feature = "bottomless")] + bottomless_replicator, + } + } + + fn write_frame(&mut self, page_no: u32, data: &[u8]) { + let entry = WalPage { + page_no, + size_after: 0, + data: Bytes::copy_from_slice(data), + }; + self.buffer.push(entry); + } + + /// write buffered pages to the logger, without commiting. + fn flush(&mut self, size_after: u32) -> anyhow::Result<()> { + if !self.buffer.is_empty() { + self.buffer.last_mut().unwrap().size_after = size_after; + self.logger.write_pages(&self.buffer)?; + self.buffer.clear(); + } + + Ok(()) + } + + fn commit(&self) -> anyhow::Result<()> { + let new_frame_no = self.logger.commit()?; + let _ = self.logger.new_frame_notifier.send(new_frame_no); + Ok(()) + } + + fn rollback(&mut self) { + self.logger.log_file.write().rollback(); + self.buffer.clear(); + } +} + +/// Represent a LogFile, and operations that can be performed on it. +/// A log file must only ever be opened by a single instance of LogFile, since it caches the file +/// header. +#[derive(Debug)] +pub struct LogFile { + file: File, + pub header: LogFileHeader, + /// number of frames in the log that have not been commited yet. On commit the header's frame + /// count is incremented by that ammount. New pages are written after the last + /// header.frame_count + uncommit_frame_count. 
+ /// On rollback, this is reset to 0, so that everything that was written after the previous + /// header.frame_count is ignored and can be overwritten + pub(crate) uncommitted_frame_count: u64, + uncommitted_checksum: u64, + + /// checksum of the last commited frame + commited_checksum: u64, +} + +#[derive(thiserror::Error, Debug)] +pub enum LogReadError { + #[error("could not fetch log entry, snapshot required")] + SnapshotRequired, + #[error("requested entry is ahead of log")] + Ahead, + #[error(transparent)] + Error(#[from] anyhow::Error), +} + +impl LogFile { + /// size of a single frame + pub const FRAME_SIZE: usize = size_of::() + WAL_PAGE_SIZE as usize; + + pub fn new(file: File) -> crate::Result { + // FIXME: we should probably take a lock on this file, to prevent anybody else to write to + // it. + let file_end = file.metadata()?.len(); + + if file_end == 0 { + let db_id = Uuid::new_v4(); + let header = LogFileHeader { + version: 2, + start_frame_no: 0, + magic: WAL_MAGIC, + page_size: WAL_PAGE_SIZE, + start_checksum: 0, + db_id: db_id.as_u128(), + frame_count: 0, + sqld_version: Version::current().0, + }; + + let mut this = Self { + file, + header, + uncommitted_frame_count: 0, + uncommitted_checksum: 0, + commited_checksum: 0, + }; + + this.write_header()?; + + Ok(this) + } else { + let header = Self::read_header(&file)?; + let mut this = Self { + file, + header, + uncommitted_frame_count: 0, + uncommitted_checksum: 0, + commited_checksum: 0, + }; + + if let Some(last_commited) = this.last_commited_frame_no() { + // file is not empty, the starting checksum is the checksum from the last entry + let last_frame = this.frame(last_commited).unwrap(); + this.commited_checksum = last_frame.header().checksum; + this.uncommitted_checksum = last_frame.header().checksum; + } else { + // file contains no entry, start with the initial checksum from the file header. 
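+                // Note: frame checksums form a chain. Each frame stores
+                // CRC64(previous_checksum, page_data) (see `compute_checksum` below), so
+                // seeding from the last committed frame, or from `start_checksum` for an
+                // empty log, lets the chain resume exactly where it stopped.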
+ this.commited_checksum = this.header.start_checksum; + this.uncommitted_checksum = this.header.start_checksum; + } + + Ok(this) + } + } + + pub fn read_header(file: &File) -> crate::Result { + let mut buf = [0; size_of::()]; + file.read_exact_at(&mut buf, 0)?; + let header: LogFileHeader = pod_read_unaligned(&buf); + if header.magic != WAL_MAGIC { + return Err(crate::error::Error::InvalidLogHeader); + } + + Ok(header) + } + + pub fn header(&self) -> &LogFileHeader { + &self.header + } + + pub fn commit(&mut self) -> crate::Result<()> { + self.header.frame_count += self.uncommitted_frame_count; + self.uncommitted_frame_count = 0; + self.commited_checksum = self.uncommitted_checksum; + self.write_header()?; + + Ok(()) + } + + fn rollback(&mut self) { + self.uncommitted_frame_count = 0; + self.uncommitted_checksum = self.commited_checksum; + } + + pub fn write_header(&mut self) -> crate::Result<()> { + self.file.write_all_at(bytes_of(&self.header), 0)?; + self.file.flush()?; + + Ok(()) + } + + /// Returns an iterator over the WAL frame headers + pub fn frames_iter(&self) -> anyhow::Result> + '_> { + let mut current_frame_offset = 0; + Ok(std::iter::from_fn(move || { + if current_frame_offset >= self.header.frame_count { + return None; + } + let read_byte_offset = Self::absolute_byte_offset(current_frame_offset); + current_frame_offset += 1; + Some(self.read_frame_byte_offset(read_byte_offset)) + })) + } + + /// Returns an iterator over the WAL frame headers + pub fn rev_frames_iter( + &self, + ) -> anyhow::Result> + '_> { + let mut current_frame_offset = self.header.frame_count; + + Ok(std::iter::from_fn(move || { + if current_frame_offset == 0 { + return None; + } + current_frame_offset -= 1; + let read_byte_offset = Self::absolute_byte_offset(current_frame_offset); + let frame = self.read_frame_byte_offset(read_byte_offset); + Some(frame) + })) + } + + fn compute_checksum(&self, page: &WalPage) -> u64 { + let mut digest = CRC_64_GO_ISO.digest_with_initial(self.uncommitted_checksum); + digest.update(&page.data); + digest.finalize() + } + + pub fn push_page(&mut self, page: &WalPage) -> crate::Result<()> { + let checksum = self.compute_checksum(page); + let frame = Frame::from_parts( + &FrameHeader { + frame_no: self.next_frame_no(), + checksum, + page_no: page.page_no, + size_after: page.size_after, + }, + &page.data, + ); + + let byte_offset = self.next_byte_offset(); + tracing::trace!( + "writing frame {} at offset {byte_offset}", + frame.header().frame_no + ); + self.file.write_all_at(frame.as_slice(), byte_offset)?; + + self.uncommitted_frame_count += 1; + self.uncommitted_checksum = checksum; + + Ok(()) + } + + /// offset in bytes at which to write the next frame + fn next_byte_offset(&self) -> u64 { + Self::absolute_byte_offset(self.header().frame_count + self.uncommitted_frame_count) + } + + fn next_frame_no(&self) -> FrameNo { + self.header().start_frame_no + self.header().frame_count + self.uncommitted_frame_count + } + + /// Returns the bytes position of the `nth` entry in the log + fn absolute_byte_offset(nth: u64) -> u64 { + std::mem::size_of::() as u64 + nth * Self::FRAME_SIZE as u64 + } + + fn byte_offset(&self, id: FrameNo) -> anyhow::Result> { + if id < self.header.start_frame_no + || id > self.header.start_frame_no + self.header.frame_count + { + return Ok(None); + } + Ok(Self::absolute_byte_offset(id - self.header.start_frame_no).into()) + } + + /// Returns bytes represening a WalFrame for frame `frame_no` + /// + /// If the requested frame is before the first frame in 
the log, or after the last frame, + /// Ok(None) is returned. + pub fn frame(&self, frame_no: FrameNo) -> std::result::Result { + if frame_no < self.header.start_frame_no { + return Err(LogReadError::SnapshotRequired); + } + + if frame_no >= self.header.start_frame_no + self.header.frame_count { + return Err(LogReadError::Ahead); + } + + let frame = self.read_frame_byte_offset(self.byte_offset(frame_no)?.unwrap())?; + + Ok(frame) + } + + fn maybe_compact( + &mut self, + compactor: &dyn LogCompactor, + size_after: u32, + path: &Path, + ) -> anyhow::Result<()> { + if compactor.should_compact(self) { + return self.do_compaction(compactor, size_after, path); + } + + Ok(()) + } + + fn do_compaction( + &mut self, + compactor: &dyn LogCompactor, + size_after: u32, + path: &Path, + ) -> anyhow::Result<()> { + tracing::info!("performing log compaction"); + let temp_log_path = path.join("temp_log"); + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(&temp_log_path)?; + let mut new_log_file = LogFile::new(file)?; + let new_header = LogFileHeader { + start_frame_no: self.header.start_frame_no + self.header.frame_count, + frame_count: 0, + start_checksum: self.commited_checksum, + ..self.header + }; + new_log_file.header = new_header; + new_log_file.write_header().unwrap(); + // swap old and new snapshot + atomic_rename(&temp_log_path, path.join("wallog")).unwrap(); + let old_log_file = std::mem::replace(self, new_log_file); + compactor.compact(old_log_file, temp_log_path, size_after)?; + + Ok(()) + } + + fn read_frame_byte_offset(&self, offset: u64) -> anyhow::Result { + let mut buffer = BytesMut::zeroed(LogFile::FRAME_SIZE); + self.file.read_exact_at(&mut buffer, offset)?; + let buffer = buffer.freeze(); + + Frame::try_from_bytes(buffer) + } + + fn last_commited_frame_no(&self) -> Option { + if self.header.frame_count == 0 { + None + } else { + Some(self.header.start_frame_no + self.header.frame_count - 1) + } + } + + fn reset(self) -> crate::Result { + // truncate file + self.file.set_len(0)?; + Self::new(self.file) + } +} + +#[cfg(target_os = "macos")] +fn atomic_rename(p1: impl AsRef, p2: impl AsRef) -> anyhow::Result<()> { + use std::ffi::CString; + use std::os::unix::prelude::OsStrExt; + + use nix::libc::renamex_np; + use nix::libc::RENAME_SWAP; + + let p1 = CString::new(p1.as_ref().as_os_str().as_bytes())?; + let p2 = CString::new(p2.as_ref().as_os_str().as_bytes())?; + unsafe { + let ret = renamex_np(p1.as_ptr(), p2.as_ptr(), RENAME_SWAP); + + if ret != 0 { + bail!( + "failed to perform snapshot file swap: {ret}, errno: {}", + std::io::Error::last_os_error() + ); + } + } + + Ok(()) +} + +#[cfg(target_os = "linux")] +fn atomic_rename(p1: impl AsRef, p2: impl AsRef) -> anyhow::Result<()> { + use anyhow::Context; + use nix::fcntl::{renameat2, RenameFlags}; + + renameat2( + None, + p1.as_ref(), + None, + p2.as_ref(), + RenameFlags::RENAME_EXCHANGE, + ) + .context("failed to perform snapshot file swap")?; + + Ok(()) +} + +#[derive(Debug, Clone, Copy, Zeroable, Pod)] +#[repr(C)] +pub struct LogFileHeader { + /// magic number: b"SQLDWAL\0" as u64 + pub magic: u64, + /// Initial checksum value for the rolling CRC checksum + /// computed with the 64 bits CRC_64_GO_ISO + pub start_checksum: u64, + /// Uuid of the database associated with this log. 
+ pub db_id: u128, + /// Frame_no of the first frame in the log + pub start_frame_no: FrameNo, + /// entry count in file + pub frame_count: u64, + /// Wal file version number, currently: 2 + pub version: u32, + /// page size: 4096 + pub page_size: i32, + /// sqld version when creating this log + pub sqld_version: [u16; 4], +} + +impl LogFileHeader { + pub fn last_frame_no(&self) -> FrameNo { + self.start_frame_no + self.frame_count + } + + fn sqld_version(&self) -> Version { + Version(self.sqld_version) + } +} + +pub struct Generation { + pub id: Uuid, + pub start_index: u64, +} + +impl Generation { + fn new(start_index: u64) -> Self { + Self { + id: Uuid::new_v4(), + start_index, + } + } +} + +pub trait LogCompactor: 'static { + /// returns whether the passed log file should be compacted. If this method returns true, + /// compact should be called next. + fn should_compact(&self, log: &LogFile) -> bool; + /// Compact the given snapshot + fn compact(&self, log: LogFile, path: PathBuf, size_after: u32) -> anyhow::Result<()>; +} + +#[cfg(test)] +impl LogCompactor for () { + fn compact(&self, _file: LogFile, _path: PathBuf, _size_after: u32) -> anyhow::Result<()> { + Ok(()) + } + + fn should_compact(&self, _file: &LogFile) -> bool { + false + } +} + +pub struct ReplicationLogger { + pub generation: Generation, + pub log_file: RwLock, + compactor: Box, + db_path: PathBuf, + /// a notifier channel other tasks can subscribe to, and get notified when new frames become + /// available. + pub new_frame_notifier: watch::Sender, +} + +impl ReplicationLogger { + pub fn open(db_path: &Path, dirty: bool, compactor: impl LogCompactor) -> crate::Result { + let log_path = db_path.join("wallog"); + let data_path = db_path.join("data"); + + let fresh = !log_path.exists(); + + let file = OpenOptions::new() + .create(true) + .write(true) + .read(true) + .open(log_path)?; + + let log_file = LogFile::new(file)?; + let header = log_file.header(); + + let should_recover = if dirty { + tracing::info!("Replication log is dirty, recovering from database file."); + true + } else if header.version < 2 || header.sqld_version() != Version::current() { + tracing::info!("replication log version not compatible with current sqld version, recovering from database file."); + true + } else if fresh && data_path.exists() { + tracing::info!("replication log not found, recovering from database file."); + true + } else { + false + }; + + if should_recover { + Self::recover(log_file, data_path, compactor) + } else { + Self::from_log_file(db_path.to_path_buf(), log_file, compactor) + } + } + + fn from_log_file( + db_path: PathBuf, + log_file: LogFile, + compactor: impl LogCompactor, + ) -> crate::Result { + let header = log_file.header(); + let generation_start_frame_no = header.start_frame_no + header.frame_count; + + let (new_frame_notifier, _) = watch::channel(generation_start_frame_no); + + Ok(Self { + generation: Generation::new(generation_start_frame_no), + compactor: Box::new(compactor), + log_file: RwLock::new(log_file), + db_path, + new_frame_notifier, + }) + } + + fn recover( + log_file: LogFile, + mut data_path: PathBuf, + compactor: impl LogCompactor, + ) -> crate::Result { + // It is necessary to checkpoint before we restore the replication log, since the WAL may + // contain pages that are not in the database file. 
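+        // In other words: checkpoint first (TRUNCATE mode, see `checkpoint_db` below),
+        // then rebuild the log by re-injecting every page of the database file, one
+        // committed frame per page, with `size_after` set on the last frame to record
+        // the final page count.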
+ checkpoint_db(&data_path)?; + let mut log_file = log_file.reset()?; + let snapshot_path = data_path.parent().unwrap().join("snapshots"); + // best effort, there may be no snapshots + let _ = remove_dir_all(snapshot_path); + + let data_file = File::open(&data_path)?; + let size = data_path.metadata()?.len(); + assert!( + size % WAL_PAGE_SIZE as u64 == 0, + "database file size is not a multiple of page size" + ); + let num_page = size / WAL_PAGE_SIZE as u64; + let mut buf = [0; WAL_PAGE_SIZE as usize]; + let mut page_no = 1; // page numbering starts at 1 + for i in 0..num_page { + data_file.read_exact_at(&mut buf, i * WAL_PAGE_SIZE as u64)?; + log_file.push_page(&WalPage { + page_no, + size_after: if i == num_page - 1 { num_page as _ } else { 0 }, + data: Bytes::copy_from_slice(&buf), + })?; + log_file.commit()?; + + page_no += 1; + } + + assert!(data_path.pop()); + + Self::from_log_file(data_path, log_file, compactor) + } + + pub fn database_id(&self) -> anyhow::Result { + Ok(Uuid::from_u128((self.log_file.read()).header().db_id)) + } + + /// Write pages to the log, without updating the file header. + /// Returns the new frame count and checksum to commit + fn write_pages(&self, pages: &[WalPage]) -> anyhow::Result<()> { + let mut log_file = self.log_file.write(); + for page in pages.iter() { + log_file.push_page(page)?; + } + + Ok(()) + } + + #[allow(dead_code)] + fn compute_checksum(wal_header: &LogFileHeader, log_file: &LogFile) -> anyhow::Result { + tracing::debug!("computing WAL log running checksum..."); + let mut iter = log_file.frames_iter()?; + iter.try_fold(wal_header.start_checksum, |sum, frame| { + let frame = frame?; + let mut digest = CRC_64_GO_ISO.digest_with_initial(sum); + digest.update(frame.page()); + let cs = digest.finalize(); + ensure!( + cs == frame.header().checksum, + "invalid WAL file: invalid checksum" + ); + Ok(cs) + }) + } + + /// commit the current transaction and returns the new top frame number + fn commit(&self) -> anyhow::Result { + let mut log_file = self.log_file.write(); + log_file.commit()?; + Ok(log_file.header().last_frame_no()) + } + + pub fn get_snapshot_file(&self, from: FrameNo) -> anyhow::Result> { + find_snapshot_file(&self.db_path, from) + } + + pub fn get_frame(&self, frame_no: FrameNo) -> Result { + self.log_file.read().frame(frame_no) + } +} + +fn checkpoint_db(data_path: &Path) -> crate::Result<()> { + unsafe { + let conn = rusqlite::Connection::open(data_path)?; + conn.pragma_query(None, "page_size", |row| { + let page_size = row.get::<_, i32>(0).unwrap(); + assert_eq!( + page_size, WAL_PAGE_SIZE, + "invalid database file, expected page size to be {}, but found {} instead", + WAL_PAGE_SIZE, page_size + ); + Ok(()) + })?; + let mut num_checkpointed: c_int = 0; + let rc = rusqlite::ffi::sqlite3_wal_checkpoint_v2( + conn.handle(), + std::ptr::null(), + SQLITE_CHECKPOINT_TRUNCATE, + &mut num_checkpointed as *mut _, + std::ptr::null_mut(), + ); + + // TODO: ensure correct page size + assert!( + rc == 0 && num_checkpointed >= 0, + "failed to checkpoint database while recovering replication log" + ); + + conn.execute("VACUUM", ())?; + } + + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn write_and_read_from_frame_log() { + let dir = tempfile::tempdir().unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + + let frames = (0..10) + .map(|i| WalPage { + page_no: i, + size_after: 0, + data: Bytes::from(vec![i as _; 4096]), + }) + .collect::>(); + logger.write_pages(&frames).unwrap(); + 
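+        // write_pages() only appends after the committed region of the log: the new
+        // frames become visible to frame()/frames_iter() only once commit() below
+        // folds uncommitted_frame_count into the header's frame count.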
logger.commit().unwrap(); + + let log_file = logger.log_file.write(); + for i in 0..10 { + let frame = log_file.frame(i).unwrap(); + assert_eq!(frame.header().page_no, i as u32); + assert!(frame.page().iter().all(|x| i as u8 == *x)); + } + + assert_eq!( + log_file.header.start_frame_no + log_file.header.frame_count, + 10 + ); + } + + #[test] + fn index_out_of_bounds() { + let dir = tempfile::tempdir().unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + let log_file = logger.log_file.write(); + assert!(matches!(log_file.frame(1), Err(LogReadError::Ahead))); + } + + #[test] + #[should_panic] + fn incorrect_frame_size() { + let dir = tempfile::tempdir().unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + let entry = WalPage { + page_no: 0, + size_after: 0, + data: vec![0; 3].into(), + }; + + logger.write_pages(&[entry]).unwrap(); + logger.commit().unwrap(); + } + + #[test] + fn log_file_test_rollback() { + let f = tempfile::tempfile().unwrap(); + let mut log_file = LogFile::new(f).unwrap(); + (0..5) + .map(|i| WalPage { + page_no: i, + size_after: 5, + data: Bytes::from_static(&[1; 4096]), + }) + .for_each(|p| { + log_file.push_page(&p).unwrap(); + }); + + assert_eq!(log_file.frames_iter().unwrap().count(), 0); + + log_file.commit().unwrap(); + + (0..5) + .map(|i| WalPage { + page_no: i, + size_after: 5, + data: Bytes::from_static(&[1; 4096]), + }) + .for_each(|p| { + log_file.push_page(&p).unwrap(); + }); + + log_file.rollback(); + assert_eq!(log_file.frames_iter().unwrap().count(), 5); + + log_file + .push_page(&WalPage { + page_no: 42, + size_after: 5, + data: Bytes::from_static(&[1; 4096]), + }) + .unwrap(); + + assert_eq!(log_file.frames_iter().unwrap().count(), 5); + log_file.commit().unwrap(); + assert_eq!(log_file.frames_iter().unwrap().count(), 6); + } +} diff --git a/libsqlx/src/database/libsql/replication_log/merger.rs b/libsqlx/src/database/libsql/replication_log/merger.rs new file mode 100644 index 00000000..d098a9e5 --- /dev/null +++ b/libsqlx/src/database/libsql/replication_log/merger.rs @@ -0,0 +1,137 @@ +use std::path::Path; +use std::sync::mpsc; +use std::thread::JoinHandle; + +use crate::database::frame::Frame; + +use super::snapshot::{ + parse_snapshot_name, snapshot_dir_path, snapshot_list, SnapshotBuilder, SnapshotFile, + MAX_SNAPSHOT_NUMBER, SNAPHOT_SPACE_AMPLIFICATION_FACTOR, +}; + +pub struct SnapshotMerger { + /// Sending part of a channel of (snapshot_name, snapshot_frame_count, db_page_count) to the merger thread + sender: mpsc::Sender<(String, u64, u32)>, + handle: Option>>, +} + +impl SnapshotMerger { + pub fn new(db_path: &Path, db_id: u128) -> anyhow::Result { + let (sender, receiver) = mpsc::channel(); + + let db_path = db_path.to_path_buf(); + let handle = + std::thread::spawn(move || Self::run_snapshot_merger_loop(receiver, &db_path, db_id)); + + Ok(Self { + sender, + handle: Some(handle), + }) + } + + fn should_compact(snapshots: &[(String, u64)], db_page_count: u32) -> bool { + let snapshots_size: u64 = snapshots.iter().map(|(_, s)| *s).sum(); + snapshots_size >= SNAPHOT_SPACE_AMPLIFICATION_FACTOR * db_page_count as u64 + || snapshots.len() > MAX_SNAPSHOT_NUMBER + } + + fn run_snapshot_merger_loop( + receiver: mpsc::Receiver<(String, u64, u32)>, + db_path: &Path, + db_id: u128, + ) -> anyhow::Result<()> { + let mut snapshots = Self::init_snapshot_info_list(db_path)?; + while let Ok((name, size, db_page_count)) = receiver.recv() { + snapshots.push((name, size)); + if 
Self::should_compact(&snapshots, db_page_count) { + let compacted_snapshot_info = Self::merge_snapshots(&snapshots, db_path, db_id)?; + snapshots.clear(); + snapshots.push(compacted_snapshot_info); + } + } + + Ok(()) + } + + /// Reads the snapshot dir and returns the list of snapshots along with their size, sorted in + /// chronological order. + /// + /// TODO: if the process was killed in the midst of merging snapshot, then the compacted snapshot + /// can exist alongside the snapshots it's supposed to have compacted. This is the place to + /// perform the cleanup. + fn init_snapshot_info_list(db_path: &Path) -> anyhow::Result> { + let snapshot_dir_path = snapshot_dir_path(db_path); + if !snapshot_dir_path.exists() { + return Ok(Vec::new()); + } + + let mut temp = Vec::new(); + for snapshot_name in snapshot_list(db_path)? { + let snapshot_path = snapshot_dir_path.join(&snapshot_name); + let snapshot = SnapshotFile::open(&snapshot_path)?; + temp.push(( + snapshot_name, + snapshot.header.frame_count, + snapshot.header.start_frame_no, + )) + } + + temp.sort_by_key(|(_, _, id)| *id); + + Ok(temp + .into_iter() + .map(|(name, count, _)| (name, count)) + .collect()) + } + + fn merge_snapshots( + snapshots: &[(String, u64)], + db_path: &Path, + db_id: u128, + ) -> anyhow::Result<(String, u64)> { + let mut builder = SnapshotBuilder::new(db_path, db_id)?; + let snapshot_dir_path = snapshot_dir_path(db_path); + for (name, _) in snapshots.iter().rev() { + let snapshot = SnapshotFile::open(&snapshot_dir_path.join(name))?; + let iter = snapshot.frames_iter().map(|b| Frame::try_from_bytes(b?)); + builder.append_frames(iter)?; + } + + let (_, start_frame_no, _) = parse_snapshot_name(&snapshots[0].0).unwrap(); + let (_, _, end_frame_no) = parse_snapshot_name(&snapshots.last().unwrap().0).unwrap(); + + builder.header.start_frame_no = start_frame_no; + builder.header.end_frame_no = end_frame_no; + + let compacted_snapshot_infos = builder.finish()?; + + for (name, _) in snapshots.iter() { + std::fs::remove_file(&snapshot_dir_path.join(name))?; + } + + Ok(compacted_snapshot_infos) + } + + pub fn register_snapshot( + &mut self, + snapshot_name: String, + snapshot_frame_count: u64, + db_page_count: u32, + ) -> anyhow::Result<()> { + if self + .sender + .send((snapshot_name, snapshot_frame_count, db_page_count)) + .is_err() + { + if let Some(handle) = self.handle.take() { + handle + .join() + .map_err(|_| anyhow::anyhow!("snapshot merger thread panicked"))??; + } + + anyhow::bail!("failed to register snapshot with log merger: thread exited"); + } + + Ok(()) + } +} diff --git a/libsqlx/src/database/libsql/replication_log/mod.rs b/libsqlx/src/database/libsql/replication_log/mod.rs new file mode 100644 index 00000000..42b2a03f --- /dev/null +++ b/libsqlx/src/database/libsql/replication_log/mod.rs @@ -0,0 +1,12 @@ +use crc::Crc; + +pub mod logger; +pub mod merger; +pub mod snapshot; + +pub const WAL_PAGE_SIZE: i32 = 4096; +pub const WAL_MAGIC: u64 = u64::from_le_bytes(*b"SQLDWAL\0"); +const CRC_64_GO_ISO: Crc = Crc::::new(&crc::CRC_64_GO_ISO); + +/// The frame uniquely identifying, monotonically increasing number +pub type FrameNo = u64; diff --git a/libsqlx/src/database/libsql/replication_log/snapshot.rs b/libsqlx/src/database/libsql/replication_log/snapshot.rs new file mode 100644 index 00000000..c5f58ea3 --- /dev/null +++ b/libsqlx/src/database/libsql/replication_log/snapshot.rs @@ -0,0 +1,334 @@ +use std::collections::HashSet; +use std::fs::File; +use std::io::BufWriter; +use std::io::Write; +use 
std::mem::size_of;
+use std::os::unix::prelude::FileExt;
+use std::path::{Path, PathBuf};
+use std::str::FromStr;
+
+use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable};
+use bytes::{Bytes, BytesMut};
+use once_cell::sync::Lazy;
+use regex::Regex;
+use tempfile::NamedTempFile;
+use uuid::Uuid;
+
+use crate::database::frame::Frame;
+
+use super::logger::LogFile;
+use super::FrameNo;
+
+/// This is the ratio of the space required to store snapshots vs the size of the actual database.
+/// When this ratio is exceeded, compaction is triggered.
+pub const SNAPHOT_SPACE_AMPLIFICATION_FACTOR: u64 = 2;
+/// The maximum number of snapshots allowed before a compaction is required
+pub const MAX_SNAPSHOT_NUMBER: usize = 32;
+
+#[derive(Debug, Copy, Clone, Zeroable, Pod, PartialEq, Eq)]
+#[repr(C)]
+pub struct SnapshotFileHeader {
+    /// id of the database
+    pub db_id: u128,
+    /// first frame in the snapshot
+    pub start_frame_no: u64,
+    /// end frame in the snapshot
+    pub end_frame_no: u64,
+    /// number of frames in the snapshot
+    pub frame_count: u64,
+    /// size of the database after applying the snapshot
+    pub size_after: u32,
+    pub _pad: u32,
+}
+
+pub struct SnapshotFile {
+    pub file: File,
+    pub header: SnapshotFileHeader,
+}
+
+/// returns (db_id, start_frame_no, end_frame_no) for the given snapshot name
+pub fn parse_snapshot_name(name: &str) -> Option<(Uuid, u64, u64)> {
+    static SNAPSHOT_FILE_MATCHER: Lazy<Regex> = Lazy::new(|| {
+        Regex::new(
+            r#"(?x)
+            # match database id
+            (\w{8}-\w{4}-\w{4}-\w{4}-\w{12})-
+            # match start frame_no
+            (\d*)-
+            # match end frame_no
+            (\d*).snap"#,
+        )
+        .unwrap()
+    });
+    let Some(captures) = SNAPSHOT_FILE_MATCHER.captures(name) else { return None };
+    let db_id = captures.get(1).unwrap();
+    let start_index: u64 = captures.get(2).unwrap().as_str().parse().unwrap();
+    let end_index: u64 = captures.get(3).unwrap().as_str().parse().unwrap();
+
+    Some((
+        Uuid::from_str(db_id.as_str()).unwrap(),
+        start_index,
+        end_index,
+    ))
+}
+
+pub fn snapshot_list(db_path: &Path) -> anyhow::Result<impl Iterator<Item = String>> {
+    let mut entries = std::fs::read_dir(snapshot_dir_path(db_path))?;
+    Ok(std::iter::from_fn(move || {
+        for entry in entries.by_ref() {
+            let Ok(entry) = entry else { continue; };
+            let path = entry.path();
+            let Some(name) = path.file_name() else { continue; };
+            let Some(name_str) = name.to_str() else { continue; };
+
+            return Some(name_str.to_string());
+        }
+        None
+    }))
+}
+
+/// Return the snapshot file that "logically" contains `frame_no`
+pub fn find_snapshot_file(
+    db_path: &Path,
+    frame_no: FrameNo,
+) -> anyhow::Result<Option<SnapshotFile>> {
+    let snapshot_dir_path = snapshot_dir_path(db_path);
+    for name in snapshot_list(db_path)? {
+        let Some((_, start_frame_no, end_frame_no)) = parse_snapshot_name(&name) else { continue; };
+        // we're looking for the frame right after the last applied frame on the replica
+        if (start_frame_no..=end_frame_no).contains(&frame_no) {
+            let snapshot_path = snapshot_dir_path.join(&name);
+            tracing::debug!("found snapshot for frame {frame_no} at {snapshot_path:?}");
+            let snapshot_file = SnapshotFile::open(&snapshot_path)?;
+            return Ok(Some(snapshot_file));
+        }
+    }
+
+    Ok(None)
+}
+
+impl SnapshotFile {
+    pub fn open(path: &Path) -> anyhow::Result<Self> {
+        let file = File::open(path)?;
+        let mut header_buf = [0; size_of::<SnapshotFileHeader>()];
+        file.read_exact_at(&mut header_buf, 0)?;
+        let header: SnapshotFileHeader = pod_read_unaligned(&header_buf);
+
+        Ok(Self { file, header })
+    }
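+    // Layout reminder: a `SnapshotFileHeader` followed by `frame_count` fixed-size
+    // frames in strictly descending frame_no order, with only the newest version of
+    // each page retained. The descending order is what lets `frames_iter_from`
+    // below stop early once it reaches an already-replicated frame.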
+    /// Returns an iterator over the frames contained in the snapshot file, in reverse frame_no order.
+    pub fn frames_iter(&self) -> impl Iterator<Item = anyhow::Result<Bytes>> + '_ {
+        let mut current_offset = 0;
+        std::iter::from_fn(move || {
+            if current_offset >= self.header.frame_count {
+                return None;
+            }
+            let read_offset = size_of::<SnapshotFileHeader>() as u64
+                + current_offset * LogFile::FRAME_SIZE as u64;
+            current_offset += 1;
+            let mut buf = BytesMut::zeroed(LogFile::FRAME_SIZE);
+            match self.file.read_exact_at(&mut buf, read_offset as _) {
+                Ok(_) => Some(Ok(buf.freeze())),
+                Err(e) => Some(Err(e.into())),
+            }
+        })
+    }
+
+    /// Like `frames_iter`, but stops as soon as a frame with frame_no <= `frame_no` is reached
+    pub fn frames_iter_from(
+        &self,
+        frame_no: u64,
+    ) -> impl Iterator<Item = anyhow::Result<Bytes>> + '_ {
+        let mut iter = self.frames_iter();
+        std::iter::from_fn(move || match iter.next() {
+            Some(Ok(bytes)) => match Frame::try_from_bytes(bytes.clone()) {
+                Ok(frame) => {
+                    if frame.header().frame_no < frame_no {
+                        None
+                    } else {
+                        Some(Ok(bytes))
+                    }
+                }
+                Err(e) => Some(Err(e)),
+            },
+            other => other,
+        })
+    }
+}
+
+/// A utility to build snapshots from log frames
+pub struct SnapshotBuilder {
+    seen_pages: HashSet<u32>,
+    pub header: SnapshotFileHeader,
+    snapshot_file: BufWriter<NamedTempFile>,
+    db_path: PathBuf,
+    last_seen_frame_no: u64,
+}
+
+pub fn snapshot_dir_path(db_path: &Path) -> PathBuf {
+    db_path.join("snapshots")
+}
+
+impl SnapshotBuilder {
+    pub fn new(db_path: &Path, db_id: u128) -> anyhow::Result<Self> {
+        let snapshot_dir_path = snapshot_dir_path(db_path);
+        std::fs::create_dir_all(&snapshot_dir_path)?;
+        let mut target = BufWriter::new(NamedTempFile::new_in(&snapshot_dir_path)?);
+        // reserve header space
+        target.write_all(&[0; size_of::<SnapshotFileHeader>()])?;
+
+        Ok(Self {
+            seen_pages: HashSet::new(),
+            header: SnapshotFileHeader {
+                db_id,
+                start_frame_no: u64::MAX,
+                end_frame_no: u64::MIN,
+                frame_count: 0,
+                size_after: 0,
+                _pad: 0,
+            },
+            snapshot_file: target,
+            db_path: db_path.to_path_buf(),
+            last_seen_frame_no: u64::MAX,
+        })
+    }
+
+    /// Append frames to the snapshot. Frames must be in decreasing frame_no order.
+    pub fn append_frames(
+        &mut self,
+        frames: impl Iterator<Item = anyhow::Result<Frame>>,
+    ) -> anyhow::Result<()> {
+        // We iterate on the frames starting from the end of the log and working our way backward.
+        // We make sure that only the most recent version of each page is present in the resulting
+        // snapshot.
+        //
+        // The snapshot file contains the most recent version of each page, in descending frame
+        // number order. That last part is important for when we read it later on.
+        for frame in frames {
+            let frame = frame?;
+            assert!(frame.header().frame_no < self.last_seen_frame_no);
+            self.last_seen_frame_no = frame.header().frame_no;
+            if frame.header().frame_no < self.header.start_frame_no {
+                self.header.start_frame_no = frame.header().frame_no;
+            }
+
+            if frame.header().frame_no > self.header.end_frame_no {
+                self.header.end_frame_no = frame.header().frame_no;
+                self.header.size_after = frame.header().size_after;
+            }
+
+            if !self.seen_pages.contains(&frame.header().page_no) {
+                self.seen_pages.insert(frame.header().page_no);
+                self.snapshot_file.write_all(frame.as_slice())?;
+                self.header.frame_count += 1;
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Persists the snapshot, and returns its name and the number of frames it contains.
+ pub fn finish(mut self) -> anyhow::Result<(String, u64)> { + self.snapshot_file.flush()?; + let file = self.snapshot_file.into_inner()?; + file.as_file().write_all_at(bytes_of(&self.header), 0)?; + let snapshot_name = format!( + "{}-{}-{}.snap", + Uuid::from_u128(self.header.db_id), + self.header.start_frame_no, + self.header.end_frame_no, + ); + + file.persist(snapshot_dir_path(&self.db_path).join(&snapshot_name))?; + + Ok((snapshot_name, self.header.frame_count)) + } +} + +// #[cfg(test)] +// mod test { +// use std::fs::read; +// use std::{thread, time::Duration}; +// +// use bytemuck::pod_read_unaligned; +// use bytes::Bytes; +// use tempfile::tempdir; +// +// use crate::database::frame::FrameHeader; +// use crate::database::libsql::replication_log::logger::WalPage; +// +// use super::*; +// +// #[test] +// fn compact_file_create_snapshot() { +// let temp = tempfile::NamedTempFile::new().unwrap(); +// let mut log_file = LogFile::new(temp.as_file().try_clone().unwrap(), 0).unwrap(); +// let db_id = Uuid::new_v4(); +// log_file.header.db_id = db_id.as_u128(); +// log_file.write_header().unwrap(); +// +// // add 50 pages, each one in two versions +// for _ in 0..2 { +// for i in 0..25 { +// let data = std::iter::repeat(0).take(4096).collect::(); +// let page = WalPage { +// page_no: i, +// size_after: i + 1, +// data, +// }; +// log_file.push_page(&page).unwrap(); +// } +// } +// +// log_file.commit().unwrap(); +// +// let dump_dir = tempdir().unwrap(); +// let compactor = LogCompactor::new(dump_dir.path(), db_id.as_u128()).unwrap(); +// compactor +// .compact(log_file, temp.path().to_path_buf(), 25) +// .unwrap(); +// +// thread::sleep(Duration::from_secs(1)); +// +// let snapshot_path = +// snapshot_dir_path(dump_dir.path()).join(format!("{}-{}-{}.snap", db_id, 0, 49)); +// let snapshot = read(&snapshot_path).unwrap(); +// let header: SnapshotFileHeader = +// pod_read_unaligned(&snapshot[..std::mem::size_of::()]); +// +// assert_eq!(header.start_frame_no, 0); +// assert_eq!(header.end_frame_no, 49); +// assert_eq!(header.frame_count, 25); +// assert_eq!(header.db_id, db_id.as_u128()); +// assert_eq!(header.size_after, 25); +// +// let mut seen_frames = HashSet::new(); +// let mut seen_page_no = HashSet::new(); +// let data = &snapshot[std::mem::size_of::()..]; +// data.chunks(LogFile::FRAME_SIZE).for_each(|f| { +// let frame = Frame::try_from_bytes(Bytes::copy_from_slice(f)).unwrap(); +// assert!(!seen_frames.contains(&frame.header().frame_no)); +// assert!(!seen_page_no.contains(&frame.header().page_no)); +// seen_page_no.insert(frame.header().page_no); +// seen_frames.insert(frame.header().frame_no); +// assert!(frame.header().frame_no >= 25); +// }); +// +// assert_eq!(seen_frames.len(), 25); +// assert_eq!(seen_page_no.len(), 25); +// +// let snapshot_file = SnapshotFile::open(&snapshot_path).unwrap(); +// +// let frames = snapshot_file.frames_iter_from(0); +// let mut expected_frame_no = 49; +// for frame in frames { +// let frame = frame.unwrap(); +// let header: FrameHeader = pod_read_unaligned(&frame[..size_of::()]); +// assert_eq!(header.frame_no, expected_frame_no); +// expected_frame_no -= 1; +// } +// +// assert_eq!(expected_frame_no, 24); +// } +// } diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs new file mode 100644 index 00000000..edfb0b8e --- /dev/null +++ b/libsqlx/src/database/mod.rs @@ -0,0 +1,135 @@ +use std::time::Duration; + +use crate::connection::{Connection, DescribeResponse}; +use crate::error::Error; +use crate::program::Program; +use 
crate::result_builder::ResultBuilder;
+use crate::semaphore::{Permit, Semaphore};
+
+use self::frame::Frame;
+
+mod frame;
+pub mod libsql;
+pub mod proxy;
+#[cfg(test)]
+mod test_utils;
+
+pub type FrameNo = u64;
+
+pub const TXN_TIMEOUT: Duration = Duration::from_secs(5);
+
+#[derive(Debug)]
+pub enum InjectError {}
+
+pub trait Database {
+    type Connection: Connection;
+    /// Create a new connection to the database
+    fn connect(&self) -> Result<Self::Connection, Error>;
+
+    /// Returns a database with a limit on the number of concurrent connections
+    fn throttled(self, limit: usize, timeout: Option<Duration>) -> ThrottledDatabase<Self>
+    where
+        Self: Sized,
+    {
+        ThrottledDatabase::new(limit, self, timeout)
+    }
+}
+
+// Trait implemented by databases that support frame injection
+pub trait InjectableDatabase {
+    fn inject_frame(&mut self, frame: Frame) -> Result<(), InjectError>;
+}
+
+/// A Database that limits the number of concurrent connections to the underlying database.
+pub struct ThrottledDatabase<T> {
+    semaphore: Semaphore,
+    db: T,
+    timeout: Option<Duration>,
+}
+
+impl<T> ThrottledDatabase<T> {
+    fn new(concurrency: usize, db: T, timeout: Option<Duration>) -> Self {
+        Self {
+            semaphore: Semaphore::new(concurrency),
+            db,
+            timeout,
+        }
+    }
+}
+
+impl<T: Database> Database for ThrottledDatabase<T> {
+    type Connection = TrackedDb<T::Connection>;
+
+    fn connect(&self) -> Result<Self::Connection, Error> {
+        let permit = match self.timeout {
+            Some(t) => self
+                .semaphore
+                .acquire_timeout(t)
+                .ok_or(Error::DbCreateTimeout)?,
+            None => self.semaphore.acquire(),
+        };
+
+        let inner = self.db.connect()?;
+        Ok(TrackedDb { permit, inner })
+    }
+}
+
+pub struct TrackedDb<DB> {
+    inner: DB,
+    #[allow(dead_code)] // just hold on to it
+    permit: Permit,
+}
+
+impl<DB: Connection> Connection for TrackedDb<DB> {
+    #[inline]
+    fn execute_program<B: ResultBuilder>(&mut self, pgm: Program, builder: B) -> crate::Result<B> {
+        self.inner.execute_program(pgm, builder)
+    }
+
+    #[inline]
+    fn describe(&self, sql: String) -> crate::Result<DescribeResponse> {
+        self.inner.describe(sql)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    struct DummyConn;
+
+    impl Connection for DummyConn {
+        fn execute_program<B>(&mut self, _pgm: Program, _builder: B) -> crate::Result<B>
+        where
+            B: ResultBuilder,
+        {
+            unreachable!()
+        }
+
+        fn describe(&self, _sql: String) -> crate::Result<DescribeResponse> {
+            unreachable!()
+        }
+    }
+
+    struct DummyDatabase;
+
+    impl Database for DummyDatabase {
+        type Connection = DummyConn;
+
+        fn connect(&self) -> Result<Self::Connection, Error> {
+            Ok(DummyConn)
+        }
+    }
+
+    #[test]
+    fn throttle_db_creation() {
+        let db = DummyDatabase.throttled(1, Some(Duration::from_millis(100)));
+        let conn = db.connect().unwrap();
+
+        assert!(db.connect().is_err());
+
+        drop(conn);
+
+        assert!(db.connect().is_ok());
+    }
+}
diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs
new file mode 100644
index 00000000..6fec7d23
--- /dev/null
+++ b/libsqlx/src/database/proxy/connection.rs
@@ -0,0 +1,222 @@
+use crate::connection::{Connection, DescribeResponse};
+use crate::database::FrameNo;
+use crate::program::Program;
+use crate::result_builder::{QueryBuilderConfig, ResultBuilder};
+use crate::Result;
+
+use super::WaitFrameNoCb;
+
+#[derive(Debug, Default)]
+pub(crate) struct ConnState {
+    is_txn: bool,
+    last_frame_no: Option<FrameNo>,
+}
+
+/// A connection that proxies write operations to the `WriteDb` and the read operations to the
+/// `ReadDb`
+pub struct WriteProxyConnection<ReadDb, WriteDb> {
+    pub(crate) read_db: ReadDb,
+    pub(crate) write_db: WriteDb,
+    pub(crate) wait_frame_no_cb: WaitFrameNoCb,
+    pub(crate) state: parking_lot::Mutex<ConnState>,
+}
+
+impl<ReadDb, WriteDb> Connection for WriteProxyConnection<ReadDb, WriteDb>
+where
ReadDb: Connection, + WriteDb: Connection, +{ + fn execute_program(&mut self, pgm: Program, builder: B) -> Result { + let mut state = self.state.lock(); + let builder = ExtractFrameNoBuilder::new(builder); + if !state.is_txn && pgm.is_read_only() { + if let Some(frame_no) = state.last_frame_no { + (self.wait_frame_no_cb)(frame_no); + } + // We know that this program won't perform any writes. We attempt to run it on the + // replica. If it leaves an open transaction, then this program is an interactive + // transaction, so we rollback the replica, and execute again on the primary. + let builder = self.read_db.execute_program(pgm.clone(), builder)?; + + // still in transaction state after running a read-only txn + if builder.is_txn { + // TODO: rollback + // self.read_db.rollback().await?; + let builder = self.write_db.execute_program(pgm, builder)?; + state.is_txn = builder.is_txn; + state.last_frame_no = builder.frame_no; + Ok(builder.inner) + } else { + Ok(builder.inner) + } + } else { + let builder = self.write_db.execute_program(pgm, builder)?; + state.is_txn = builder.is_txn; + state.last_frame_no = builder.frame_no; + Ok(builder.inner) + } + } + + fn describe(&self, sql: String) -> Result { + if let Some(frame_no) = self.state.lock().last_frame_no { + (self.wait_frame_no_cb)(frame_no); + } + self.read_db.describe(sql) + } +} + +struct ExtractFrameNoBuilder { + inner: B, + frame_no: Option, + is_txn: bool, +} + +impl ExtractFrameNoBuilder { + fn new(inner: B) -> Self { + Self { + inner, + frame_no: None, + is_txn: false, + } + } +} + +impl ResultBuilder for ExtractFrameNoBuilder { + fn init( + &mut self, + config: &QueryBuilderConfig, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.init(config) + } + + fn begin_step( + &mut self, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.begin_step() + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner + .finish_step(affected_row_count, last_insert_rowid) + } + + fn step_error( + &mut self, + error: crate::error::Error, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.step_error(error) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.cols_description(cols) + } + + fn begin_rows( + &mut self, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.begin_rows() + } + + fn begin_row( + &mut self, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.begin_row() + } + + fn add_row_value( + &mut self, + v: rusqlite::types::ValueRef, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.add_row_value(v) + } + + fn finish_row( + &mut self, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.finish_row() + } + + fn finish_rows( + &mut self, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.inner.finish_rows() + } + + fn finish( + &mut self, + is_txn: bool, + frame_no: Option, + ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { + self.frame_no = frame_no; + self.is_txn = is_txn; + self.inner.finish(is_txn, frame_no) + } +} + +#[cfg(test)] +mod test { + 
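+    // These tests drive a WriteProxyConnection over mock databases and check the
+    // dispatch rules: writes are proxied to the primary, reads stay on the replica,
+    // and a read first waits (via the callback) for the last written frame_no to be
+    // replicated back.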
use std::cell::Cell; + use std::rc::Rc; + use std::sync::Arc; + + use crate::connection::Connection; + use crate::database::test_utils::MockDatabase; + use crate::database::{proxy::database::WriteProxyDatabase, Database}; + use crate::program::Program; + + #[test] + fn simple_write_proxied() { + let write_called = Rc::new(Cell::new(false)); + let write_db = MockDatabase::new().with_execute({ + let write_called = write_called.clone(); + move |_, b| { + b.finish(false, Some(42)).unwrap(); + write_called.set(true); + Ok(()) + } + }); + + let read_called = Rc::new(Cell::new(false)); + let read_db = MockDatabase::new().with_execute({ + let read_called = read_called.clone(); + move |_, _| { + read_called.set(true); + Ok(()) + } + }); + + let wait_called = Rc::new(Cell::new(false)); + let db = WriteProxyDatabase::new( + read_db, + write_db, + Arc::new({ + let wait_called = wait_called.clone(); + move |fno| { + assert_eq!(fno, 42); + wait_called.set(true); + } + }), + ); + + let mut conn = db.connect().unwrap(); + conn.execute_program(Program::seq(&["insert into test values (12)"]), ()) + .unwrap(); + + assert!(!wait_called.get()); + assert!(!read_called.get()); + assert!(write_called.get()); + + conn.execute_program(Program::seq(&["select * from test"]), ()) + .unwrap(); + + assert!(read_called.get()); + assert!(wait_called.get()); + } +} diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs new file mode 100644 index 00000000..fedd0ef7 --- /dev/null +++ b/libsqlx/src/database/proxy/database.rs @@ -0,0 +1,50 @@ +use crate::database::frame::Frame; +use crate::database::{Database, InjectableDatabase}; +use crate::error::Error; + +use super::connection::WriteProxyConnection; +use super::WaitFrameNoCb; + +pub struct WriteProxyDatabase { + read_db: RDB, + write_db: WDB, + wait_frame_no_cb: WaitFrameNoCb, +} + +impl WriteProxyDatabase { + pub fn new(read_db: RDB, write_db: WDB, wait_frame_no_cb: WaitFrameNoCb) -> Self { + Self { + read_db, + write_db, + wait_frame_no_cb, + } + } +} + +impl Database for WriteProxyDatabase +where + RDB: Database, + WDB: Database, +{ + type Connection = WriteProxyConnection; + + /// Create a new connection to the database + fn connect(&self) -> Result { + Ok(WriteProxyConnection { + read_db: self.read_db.connect()?, + write_db: self.write_db.connect()?, + wait_frame_no_cb: self.wait_frame_no_cb.clone(), + state: Default::default(), + }) + } +} + +impl InjectableDatabase for WriteProxyDatabase +where + RDB: InjectableDatabase, +{ + fn inject_frame(&mut self, frame: Frame) -> Result<(), crate::database::InjectError> { + // TODO: handle frame index + self.read_db.inject_frame(frame) + } +} diff --git a/libsqlx/src/database/proxy/mod.rs b/libsqlx/src/database/proxy/mod.rs new file mode 100644 index 00000000..1b5b3226 --- /dev/null +++ b/libsqlx/src/database/proxy/mod.rs @@ -0,0 +1,11 @@ +use std::sync::Arc; + +use super::FrameNo; + +mod connection; +mod database; + +pub use database::WriteProxyDatabase; + +// Waits until passed frameno has been replicated back to the database +type WaitFrameNoCb = Arc; diff --git a/libsqlx/src/database/test_utils.rs b/libsqlx/src/database/test_utils.rs new file mode 100644 index 00000000..86c072ea --- /dev/null +++ b/libsqlx/src/database/test_utils.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use crate::{ + connection::{Connection, DescribeResponse}, + program::Program, + result_builder::ResultBuilder, +}; + +use super::Database; + +pub struct MockDatabase { + #[allow(clippy::type_complexity)] + 
+    describe_fn: Arc<dyn Fn(String) -> crate::Result<DescribeResponse>>,
+    #[allow(clippy::type_complexity)]
+    execute_fn: Arc<dyn Fn(Program, &mut dyn ResultBuilder) -> crate::Result<()>>,
+}
+
+pub struct MockConnection {
+    #[allow(clippy::type_complexity)]
+    describe_fn: Arc<dyn Fn(String) -> crate::Result<DescribeResponse>>,
+    #[allow(clippy::type_complexity)]
+    execute_fn: Arc<dyn Fn(Program, &mut dyn ResultBuilder) -> crate::Result<()>>,
+}
+
+impl MockDatabase {
+    pub fn new() -> Self {
+        MockDatabase {
+            describe_fn: Arc::new(|_| panic!("describe fn not set")),
+            execute_fn: Arc::new(|_, _| panic!("execute fn not set")),
+        }
+    }
+
+    pub fn with_execute(
+        mut self,
+        f: impl Fn(Program, &mut dyn ResultBuilder) -> crate::Result<()> + 'static,
+    ) -> Self {
+        self.execute_fn = Arc::new(f);
+        self
+    }
+}
+
+impl Database for MockDatabase {
+    type Connection = MockConnection;
+
+    fn connect(&self) -> Result<Self::Connection, crate::error::Error> {
+        Ok(MockConnection {
+            describe_fn: self.describe_fn.clone(),
+            execute_fn: self.execute_fn.clone(),
+        })
+    }
+}
+
+impl Connection for MockConnection {
+    fn execute_program<B: ResultBuilder>(
+        &mut self,
+        pgm: crate::program::Program,
+        mut response_builder: B,
+    ) -> crate::Result<B> {
+        (self.execute_fn)(pgm, &mut response_builder)?;
+        Ok(response_builder)
+    }
+
+    fn describe(&self, sql: String) -> crate::Result<DescribeResponse> {
+        (self.describe_fn)(sql)
+    }
+}
diff --git a/libsqlx/src/error.rs b/libsqlx/src/error.rs
new file mode 100644
index 00000000..6e35e217
--- /dev/null
+++ b/libsqlx/src/error.rs
@@ -0,0 +1,44 @@
+use crate::result_builder::QueryResultBuilderError;
+
+#[allow(clippy::enum_variant_names)]
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error("LibSQL failed to bind provided query parameters: `{0}`")]
+    LibSqlInvalidQueryParams(anyhow::Error),
+    #[error("Transaction timed-out")]
+    LibSqlTxTimeout,
+    #[error("Server can't handle additional transactions")]
+    LibSqlTxBusy,
+    #[error(transparent)]
+    IOError(#[from] std::io::Error),
+    #[error(transparent)]
+    RusqliteError(#[from] rusqlite::Error),
+    #[error("Database value error: `{0}`")]
+    DbValueError(String),
+    // Dedicated for most generic internal errors. Please use it sparingly.
+    // Consider creating a dedicated enum value for your error.
+ #[error("Internal Error: `{0}`")] + Internal(String), + #[error("Invalid batch step: {0}")] + InvalidBatchStep(usize), + #[error("Not authorized to execute query: {0}")] + NotAuthorized(String), + #[error("The replicator exited, instance cannot make any progress.")] + ReplicatorExited, + #[error("Timed out while openning database connection")] + DbCreateTimeout, + #[error(transparent)] + BuilderError(#[from] QueryResultBuilderError), + #[error("Operation was blocked{}", .0.as_ref().map(|msg| format!(": {}", msg)).unwrap_or_default())] + Blocked(Option), + #[error("invalid replication log header")] + InvalidLogHeader, +} + +impl From for Error { + fn from(inner: tokio::sync::oneshot::error::RecvError) -> Self { + Self::Internal(format!( + "Failed to receive response via oneshot channel: {inner}" + )) + } +} diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs new file mode 100644 index 00000000..986044e2 --- /dev/null +++ b/libsqlx/src/lib.rs @@ -0,0 +1,21 @@ +pub mod analysis; +pub mod error; +pub mod query; + +mod connection; +mod database; +mod program; +mod result_builder; +mod seal; +mod semaphore; + +pub type Result = std::result::Result; + +pub use connection::Connection; +pub use database::libsql; +pub use database::proxy; +pub use database::Database; +pub use program::Program; +pub use result_builder::{ + Column, QueryBuilderConfig, QueryResultBuilderError, ResultBuilder, ResultBuilderExt, +}; diff --git a/libsqlx/src/program.rs b/libsqlx/src/program.rs new file mode 100644 index 00000000..3eb2f551 --- /dev/null +++ b/libsqlx/src/program.rs @@ -0,0 +1,60 @@ +use std::sync::Arc; + +use crate::query::Query; + +#[derive(Debug, Clone)] +pub struct Program { + pub steps: Arc>, +} + +impl Program { + pub fn new(steps: Vec) -> Self { + Self { + steps: Arc::new(steps), + } + } + + pub fn is_read_only(&self) -> bool { + self.steps.iter().all(|s| s.query.stmt.is_read_only()) + } + + pub fn steps(&self) -> &[Step] { + self.steps.as_slice() + } + + #[cfg(test)] + pub fn seq(stmts: &[&str]) -> Self { + use crate::{analysis::Statement, query::Params}; + + let mut steps = Vec::with_capacity(stmts.len()); + for stmt in stmts { + let step = Step { + cond: None, + query: Query { + stmt: Statement::parse(stmt).next().unwrap().unwrap(), + params: Params::empty(), + want_rows: true, + }, + }; + + steps.push(step); + } + + Self::new(steps) + } +} + +#[derive(Debug, Clone)] +pub struct Step { + pub cond: Option, + pub query: Query, +} + +#[derive(Debug, Clone)] +pub enum Cond { + Ok { step: usize }, + Err { step: usize }, + Not { cond: Box }, + Or { conds: Vec }, + And { conds: Vec }, +} diff --git a/libsqlx/src/query.rs b/libsqlx/src/query.rs new file mode 100644 index 00000000..2d37e514 --- /dev/null +++ b/libsqlx/src/query.rs @@ -0,0 +1,267 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, ensure, Context}; +use rusqlite::types::{ToSqlOutput, ValueRef}; +use rusqlite::ToSql; +use serde::{Deserialize, Serialize}; + +use crate::analysis::Statement; + +/// Mirrors rusqlite::Value, but implement extra traits +#[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(test, derive(arbitrary::Arbitrary))] +pub enum Value { + Null, + Integer(i64), + Real(f64), + Text(String), + Blob(Vec), +} + +impl<'a> From<&'a Value> for ValueRef<'a> { + fn from(value: &'a Value) -> Self { + match value { + Value::Null => ValueRef::Null, + Value::Integer(i) => ValueRef::Integer(*i), + Value::Real(x) => ValueRef::Real(*x), + Value::Text(s) => ValueRef::Text(s.as_bytes()), + Value::Blob(b) => 
diff --git a/libsqlx/src/query.rs b/libsqlx/src/query.rs
new file mode 100644
index 00000000..2d37e514
--- /dev/null
+++ b/libsqlx/src/query.rs
@@ -0,0 +1,267 @@
+use std::collections::HashMap;
+
+use anyhow::{anyhow, ensure, Context};
+use rusqlite::types::{ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
+use serde::{Deserialize, Serialize};
+
+use crate::analysis::Statement;
+
+/// Mirrors rusqlite::Value, but implements extra traits
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[cfg_attr(test, derive(arbitrary::Arbitrary))]
+pub enum Value {
+    Null,
+    Integer(i64),
+    Real(f64),
+    Text(String),
+    Blob(Vec<u8>),
+}
+
+impl<'a> From<&'a Value> for ValueRef<'a> {
+    fn from(value: &'a Value) -> Self {
+        match value {
+            Value::Null => ValueRef::Null,
+            Value::Integer(i) => ValueRef::Integer(*i),
+            Value::Real(x) => ValueRef::Real(*x),
+            Value::Text(s) => ValueRef::Text(s.as_bytes()),
+            Value::Blob(b) => ValueRef::Blob(b.as_slice()),
+        }
+    }
+}
+
+impl TryFrom<rusqlite::types::ValueRef<'_>> for Value {
+    type Error = anyhow::Error;
+
+    fn try_from(value: rusqlite::types::ValueRef<'_>) -> anyhow::Result<Self> {
+        let val = match value {
+            rusqlite::types::ValueRef::Null => Value::Null,
+            rusqlite::types::ValueRef::Integer(i) => Value::Integer(i),
+            rusqlite::types::ValueRef::Real(x) => Value::Real(x),
+            rusqlite::types::ValueRef::Text(s) => Value::Text(String::from_utf8(Vec::from(s))?),
+            rusqlite::types::ValueRef::Blob(b) => Value::Blob(Vec::from(b)),
+        };
+
+        Ok(val)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Query {
+    pub stmt: Statement,
+    pub params: Params,
+    pub want_rows: bool,
+}
+
+impl ToSql for Value {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        let val = match self {
+            Value::Null => ToSqlOutput::Owned(rusqlite::types::Value::Null),
+            Value::Integer(i) => ToSqlOutput::Owned(rusqlite::types::Value::Integer(*i)),
+            Value::Real(x) => ToSqlOutput::Owned(rusqlite::types::Value::Real(*x)),
+            Value::Text(s) => ToSqlOutput::Borrowed(rusqlite::types::ValueRef::Text(s.as_bytes())),
+            Value::Blob(b) => ToSqlOutput::Borrowed(rusqlite::types::ValueRef::Blob(b)),
+        };
+
+        Ok(val)
+    }
+}
+
+#[derive(Debug, Serialize, Clone)]
+pub enum Params {
+    Named(HashMap<String, Value>),
+    Positional(Vec<Value>),
+}
+
+impl Params {
+    pub fn empty() -> Self {
+        Self::Positional(Vec::new())
+    }
+
+    pub fn new_named(values: HashMap<String, Value>) -> Self {
+        Self::Named(values)
+    }
+
+    pub fn new_positional(values: Vec<Value>) -> Self {
+        Self::Positional(values)
+    }
+
+    pub fn get_pos(&self, pos: usize) -> Option<&Value> {
+        assert!(pos > 0);
+        match self {
+            Params::Named(_) => None,
+            Params::Positional(params) => params.get(pos - 1),
+        }
+    }
+
+    pub fn get_named(&self, name: &str) -> Option<&Value> {
+        match self {
+            Params::Named(params) => params.get(name),
+            Params::Positional(_) => None,
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        match self {
+            Params::Named(params) => params.len(),
+            Params::Positional(params) => params.len(),
+        }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    pub fn bind(&self, stmt: &mut rusqlite::Statement) -> anyhow::Result<()> {
+        let param_count = stmt.parameter_count();
+        ensure!(
+            param_count >= self.len(),
+            "too many parameters, expected {param_count} found {}",
+            self.len()
+        );
+
+        if param_count > 0 {
+            for index in 1..=param_count {
+                let mut param_name = None;
+                // get by name
+                let maybe_value = match stmt.parameter_name(index) {
+                    Some(name) => {
+                        param_name = Some(name);
+                        let mut chars = name.chars();
+                        match chars.next() {
+                            Some('?') => {
+                                let pos = chars.as_str().parse::<usize>().context(
+                                    "invalid parameter {name}: expected a numerical position after `?`",
+                                )?;
+                                self.get_pos(pos)
+                            }
+                            _ => self
+                                .get_named(name)
+                                .or_else(|| self.get_named(chars.as_str())),
+                        }
+                    }
+                    None => self.get_pos(index),
+                };
+
+                if let Some(value) = maybe_value {
+                    stmt.raw_bind_parameter(index, value)?;
+                } else if let Some(name) = param_name {
+                    return Err(anyhow!("value for parameter {} not found", name));
+                } else {
+                    return Err(anyhow!("value for parameter {} not found", index));
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_bind_params_positional_simple() {
+        let con = rusqlite::Connection::open_in_memory().unwrap();
+        let mut stmt = con.prepare("SELECT ?").unwrap();
+        let params = Params::new_positional(vec![Value::Integer(10)]);
+        params.bind(&mut stmt).unwrap();
+
+        assert_eq!(stmt.expanded_sql().unwrap(), "SELECT 10");
+    }
+
+    #[test]
+    fn
test_bind_params_positional_numbered() { + let con = rusqlite::Connection::open_in_memory().unwrap(); + let mut stmt = con.prepare("SELECT ? || ?2 || ?1").unwrap(); + let params = Params::new_positional(vec![Value::Integer(10), Value::Integer(20)]); + params.bind(&mut stmt).unwrap(); + + assert_eq!(stmt.expanded_sql().unwrap(), "SELECT 10 || 20 || 10"); + } + + #[test] + fn test_bind_params_positional_named() { + let con = rusqlite::Connection::open_in_memory().unwrap(); + let mut stmt = con.prepare("SELECT :first || $second").unwrap(); + let mut params = HashMap::new(); + params.insert(":first".to_owned(), Value::Integer(10)); + params.insert("$second".to_owned(), Value::Integer(20)); + let params = Params::new_named(params); + params.bind(&mut stmt).unwrap(); + + assert_eq!(stmt.expanded_sql().unwrap(), "SELECT 10 || 20"); + } + + #[test] + fn test_bind_params_positional_named_no_prefix() { + let con = rusqlite::Connection::open_in_memory().unwrap(); + let mut stmt = con.prepare("SELECT :first || $second").unwrap(); + let mut params = HashMap::new(); + params.insert("first".to_owned(), Value::Integer(10)); + params.insert("second".to_owned(), Value::Integer(20)); + let params = Params::new_named(params); + params.bind(&mut stmt).unwrap(); + + assert_eq!(stmt.expanded_sql().unwrap(), "SELECT 10 || 20"); + } + + #[test] + fn test_bind_params_positional_named_conflict() { + let con = rusqlite::Connection::open_in_memory().unwrap(); + let mut stmt = con.prepare("SELECT :first || $first").unwrap(); + let mut params = HashMap::new(); + params.insert("first".to_owned(), Value::Integer(10)); + params.insert("$first".to_owned(), Value::Integer(20)); + let params = Params::new_named(params); + params.bind(&mut stmt).unwrap(); + + assert_eq!(stmt.expanded_sql().unwrap(), "SELECT 10 || 20"); + } + + #[test] + fn test_bind_params_positional_named_repeated() { + let con = rusqlite::Connection::open_in_memory().unwrap(); + let mut stmt = con + .prepare("SELECT :first || $second || $first || $second") + .unwrap(); + let mut params = HashMap::new(); + params.insert("first".to_owned(), Value::Integer(10)); + params.insert("$second".to_owned(), Value::Integer(20)); + let params = Params::new_named(params); + params.bind(&mut stmt).unwrap(); + + assert_eq!(stmt.expanded_sql().unwrap(), "SELECT 10 || 20 || 10 || 20"); + } + + #[test] + fn test_bind_params_too_many_params() { + let con = rusqlite::Connection::open_in_memory().unwrap(); + let mut stmt = con.prepare("SELECT :first || $second").unwrap(); + let mut params = HashMap::new(); + params.insert(":first".to_owned(), Value::Integer(10)); + params.insert("$second".to_owned(), Value::Integer(20)); + params.insert("$oops".to_owned(), Value::Integer(20)); + let params = Params::new_named(params); + assert!(params.bind(&mut stmt).is_err()); + } + + #[test] + fn test_bind_params_too_few_params() { + let con = rusqlite::Connection::open_in_memory().unwrap(); + let mut stmt = con.prepare("SELECT :first || $second").unwrap(); + let mut params = HashMap::new(); + params.insert(":first".to_owned(), Value::Integer(10)); + let params = Params::new_named(params); + assert!(params.bind(&mut stmt).is_err()); + } + + #[test] + fn test_bind_params_invalid_positional() { + let con = rusqlite::Connection::open_in_memory().unwrap(); + let mut stmt = con.prepare("SELECT ?invalid").unwrap(); + let params = Params::empty(); + assert!(params.bind(&mut stmt).is_err()); + } +} diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs new file mode 100644 index 
00000000..2784274c
--- /dev/null
+++ b/libsqlx/src/result_builder.rs
@@ -0,0 +1,711 @@
+use std::fmt;
+use std::io::{self, ErrorKind};
+
+use bytesize::ByteSize;
+use rusqlite::types::ValueRef;
+
+use crate::database::FrameNo;
+
+#[derive(Debug)]
+pub enum QueryResultBuilderError {
+    ResponseTooLarge(u64),
+    Internal(anyhow::Error),
+}
+
+impl fmt::Display for QueryResultBuilderError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            QueryResultBuilderError::ResponseTooLarge(s) => {
+                write!(f, "query response exceeds the maximum size of {}. Try reducing the number of queried rows.", ByteSize(*s))
+            }
+            QueryResultBuilderError::Internal(e) => e.fmt(f),
+        }
+    }
+}
+
+impl std::error::Error for QueryResultBuilderError {}
+
+impl From<anyhow::Error> for QueryResultBuilderError {
+    fn from(value: anyhow::Error) -> Self {
+        Self::Internal(value)
+    }
+}
+
+impl QueryResultBuilderError {
+    pub fn from_any<E: Into<anyhow::Error>>(e: E) -> Self {
+        Self::Internal(e.into())
+    }
+}
+
+impl From<io::Error> for QueryResultBuilderError {
+    fn from(value: io::Error) -> Self {
+        if value.kind() == ErrorKind::OutOfMemory
+            && value.get_ref().is_some()
+            && value.get_ref().unwrap().is::<QueryResultBuilderError>()
+        {
+            return *value
+                .into_inner()
+                .unwrap()
+                .downcast::<QueryResultBuilderError>()
+                .unwrap();
+        }
+        Self::Internal(value.into())
+    }
+}
+
+/// Identical to rusqlite::Column, with visible fields.
+#[cfg_attr(test, derive(arbitrary::Arbitrary))]
+pub struct Column<'a> {
+    pub name: &'a str,
+    pub decl_ty: Option<&'a str>,
+}
+
+impl<'a> From<(&'a str, Option<&'a str>)> for Column<'a> {
+    fn from((name, decl_ty): (&'a str, Option<&'a str>)) -> Self {
+        Self { name, decl_ty }
+    }
+}
+
+impl<'a> From<&'a rusqlite::Column<'a>> for Column<'a> {
+    fn from(value: &'a rusqlite::Column<'a>) -> Self {
+        Self {
+            name: value.name(),
+            decl_ty: value.decl_type(),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+pub struct QueryBuilderConfig {
+    pub max_size: Option<u64>,
+}
+
+pub trait ResultBuilder: Send + 'static {
+    /// (Re)initialize the builder. This method can be called multiple times.
+    fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// start serializing new step
+    fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// finish serializing current step
+    fn finish_step(
+        &mut self,
+        _affected_row_count: u64,
+        _last_insert_rowid: Option<i64>,
+    ) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// emit an error to serialize.
+    fn step_error(&mut self, _error: crate::error::Error) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// add cols description for current step.
+    /// This is called at most once per step, and is always the first method being called
+    fn cols_description(
+        &mut self,
+        _cols: &mut dyn Iterator<Item = Column>,
+    ) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// start adding rows
+    fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// begin a new row for the current step
+    fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// add value to current row
+    fn add_row_value(&mut self, _v: ValueRef) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// finish current row
+    fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// end adding rows
+    fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+    /// finish the builder, and pass the transaction state.
+    fn finish(
+        &mut self,
+        _is_txn: bool,
+        _frame_no: Option<FrameNo>,
+    ) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+}
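+
+// All methods default to no-ops, so an implementation only overrides the
+// callbacks it needs. A sketch of a builder that merely counts rows:
+//
+//     #[derive(Default)]
+//     struct RowCounter(u64);
+//
+//     impl ResultBuilder for RowCounter {
+//         fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> {
+//             self.0 += 1;
+//             Ok(())
+//         }
+//     }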
+
+pub trait ResultBuilderExt: ResultBuilder {
+    /// Returns a `Take` builder that wraps `Self` and forwards at most `limit` steps
+    fn take(self, limit: usize) -> Take<Self>
+    where
+        Self: Sized,
+    {
+        Take {
+            limit,
+            count: 0,
+            inner: self,
+        }
+    }
+}
+
+impl<T: ResultBuilder> ResultBuilderExt for T {}
+
+#[derive(Debug)]
+pub enum StepResult {
+    Ok,
+    Err(crate::error::Error),
+    Skipped,
+}
+
+/// A `ResultBuilder` that ignores rows, but records the outcome of each step in a `StepResult`
+#[derive(Debug, Default)]
+pub struct StepResultsBuilder {
+    current: Option<crate::error::Error>,
+    step_results: Vec<StepResult>,
+    is_skipped: bool,
+}
+
+impl ResultBuilder for StepResultsBuilder {
+    fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> {
+        *self = Default::default();
+        Ok(())
+    }
+
+    fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> {
+        self.is_skipped = true;
+        Ok(())
+    }
+
+    fn finish_step(
+        &mut self,
+        _affected_row_count: u64,
+        _last_insert_rowid: Option<i64>,
+    ) -> Result<(), QueryResultBuilderError> {
+        let res = match self.current.take() {
+            Some(e) => StepResult::Err(e),
+            None if self.is_skipped => StepResult::Skipped,
+            None => StepResult::Ok,
+        };
+
+        self.step_results.push(res);
+
+        Ok(())
+    }
+
+    fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> {
+        assert!(self.current.is_none());
+        self.current = Some(error);
+
+        Ok(())
+    }
+
+    fn cols_description(
+        &mut self,
+        _cols: &mut dyn Iterator<Item = Column>,
+    ) -> Result<(), QueryResultBuilderError> {
+        self.is_skipped = false;
+        Ok(())
+    }
+
+    fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+
+    fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+
+    fn add_row_value(&mut self, _v: ValueRef) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+
+    fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+
+    fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+
+    fn finish(
+        &mut self,
+        _is_txn: bool,
+        _frame_no: Option<FrameNo>,
+    ) -> Result<(), QueryResultBuilderError> {
+        Ok(())
+    }
+}
+
+impl ResultBuilder for () {}
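+
+// Typical use of `StepResultsBuilder` (a sketch; the patch does not yet
+// expose an accessor for the recorded results, so field access is shown
+// loosely here):
+//
+//     let builder = conn.execute_program(pgm, StepResultsBuilder::default())?;
+//     // builder.step_results now holds one StepResult per program step:
+//     // Ok, Err(_) or Skipped.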
+
+// A builder that wraps another builder, but takes at most `limit` steps
+pub struct Take<B> {
+    limit: usize,
+    count: usize,
+    inner: B,
+}
+
+impl<B> Take<B> {
+    pub fn into_inner(self) -> B {
+        self.inner
+    }
+}
+
+impl<B: ResultBuilder> ResultBuilder for Take<B> {
+    fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> {
+        self.count = 0;
+        self.inner.init(config)
+    }
+
+    fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner.begin_step()
+        } else {
+            Ok(())
+        }
+    }
+
+    fn finish_step(
+        &mut self,
+        affected_row_count: u64,
+        last_insert_rowid: Option<i64>,
+    ) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner
+                .finish_step(affected_row_count, last_insert_rowid)?;
+            self.count += 1;
+        }
+
+        Ok(())
+    }
+
+    fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner.step_error(error)
+        } else {
+            Ok(())
+        }
+    }
+
+    fn cols_description(
+        &mut self,
+        cols: &mut dyn Iterator<Item = Column>,
+    ) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner.cols_description(cols)
+        } else {
+            Ok(())
+        }
+    }
+
+    fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner.begin_rows()
+        } else {
+            Ok(())
+        }
+    }
+
+    fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner.begin_row()
+        } else {
+            Ok(())
+        }
+    }
+
+    fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner.add_row_value(v)
+        } else {
+            Ok(())
+        }
+    }
+
+    fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner.finish_row()
+        } else {
+            Ok(())
+        }
+    }
+
+    fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> {
+        if self.count < self.limit {
+            self.inner.finish_rows()
+        } else {
+            Ok(())
+        }
+    }
+
+    fn finish(
+        &mut self,
+        is_txn: bool,
+        frame_no: Option<FrameNo>,
+    ) -> Result<(), QueryResultBuilderError> {
+        self.inner.finish(is_txn, frame_no)
+    }
+}
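+
+// `Take` is meant to be obtained through `ResultBuilderExt::take`; e.g. to
+// record at most one step of a program (sketch):
+//
+//     let take_one = StepResultsBuilder::default().take(1);
+//     let take_one = conn.execute_program(pgm, take_one)?;
+//     let inner = take_one.into_inner();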
+
+#[cfg(test)]
+pub mod test {
+    #![allow(dead_code)]
+    use std::fmt;
+
+    use arbitrary::{Arbitrary, Unstructured};
+    use itertools::Itertools;
+    use rand::{
+        distributions::{Standard, WeightedIndex},
+        prelude::Distribution,
+        thread_rng, Fill, Rng,
+    };
+    use FsmState::*;
+
+    use super::*;
+
+    /// a dummy ResultBuilder that encodes the ResultBuilder FSM. It can be passed to a
+    /// driver to ensure that it is not mis-used
+
+    #[derive(Debug, PartialEq, Eq, Clone, Copy)]
+    #[repr(usize)]
+    // do not reorder!
+    enum FsmState {
+        Init = 0,
+        Finish,
+        BeginStep,
+        FinishStep,
+        StepError,
+        ColsDescription,
+        FinishRows,
+        BeginRows,
+        FinishRow,
+        BeginRow,
+        AddRowValue,
+        BuilderError,
+    }
+
+    #[rustfmt::skip]
+    static TRANSITION_TABLE: [[bool; 12]; 12] = [
+        //FROM:
+        //Init   Finish BeginStep FinishStep StepError ColsDes FinishRows BegRows FinishRow BegRow AddRowVal BuilderErr  TO:
+        [true , true , true , true , true , true , true , true , true , true , true , false], // Init,
+        [true , false, false, true , false, false, false, false, false, false, false, false], // Finish,
+        [true , false, false, true , false, false, false, false, false, false, false, false], // BeginStep
+        [false, false, true , false, true , false, true , false, false, false, false, false], // FinishStep
+        [false, false, true , false, false, true , true , true , true , true , true , false], // StepError
+        [false, false, true , false, false, false, false, false, false, false, false, false], // ColsDescr
+        [false, false, false, false, false, false, false, true , true , false, false, false], // FinishRows
+        [false, false, false, false, false, true , false, false, false, false, false, false], // BeginRows
+        [false, false, false, false, false, false, false, false, false, true , true , false], // FinishRow
+        [false, false, false, false, false, false, false, true , true , false, false, false], // BeginRow,
+        [false, false, false, false, false, false, false, false, false, true , true , false], // AddRowValue
+        [true , true , true , true , true , true , true , true , true , true , true , false], // BuilderError
+    ];
+
+    impl FsmState {
+        /// returns a random valid transition from the current state
+        fn rand_transition(self, allow_init: bool) -> Self {
+            let valid_next_states = TRANSITION_TABLE[..TRANSITION_TABLE.len() - 1] // ignore builder error
+                .iter()
+                .enumerate()
+                .skip(if allow_init { 0 } else { 1 })
+                .filter_map(|(i, ss)| ss[self as usize].then_some(i))
+                .collect_vec();
+            // distribution is somewhat tweaked to be biased towards more real-world test cases
+            let weights = valid_next_states
+                .iter()
+                .enumerate()
+                .map(|(p, i)| i.pow(p as _))
+                .collect_vec();
+            let dist = WeightedIndex::new(weights).unwrap();
+            unsafe { std::mem::transmute(valid_next_states[dist.sample(&mut thread_rng())]) }
+        }
+
+        /// moves towards the finish step as fast as possible
+        fn toward_finish(self) -> Self {
+            match self {
+                Init => Finish,
+                BeginStep => FinishStep,
+                FinishStep => Finish,
+                StepError => FinishStep,
+                BeginRows | BeginRow | AddRowValue | FinishRow | FinishRows | ColsDescription => {
+                    StepError
+                }
+                Finish => Finish,
+                BuilderError => Finish,
+            }
+        }
+    }
+
+    pub fn random_builder_driver<B: ResultBuilder>(mut max_steps: usize, mut b: B) -> B {
+        let mut rand_data = [0; 10_000];
+        rand_data.try_fill(&mut rand::thread_rng()).unwrap();
+        let mut u = Unstructured::new(&rand_data);
+        let mut trace = Vec::new();
+
+        #[derive(Arbitrary)]
+        pub enum ValueRef<'a> {
+            Null,
+            Integer(i64),
+            Real(f64),
+            Text(&'a str),
+            Blob(&'a [u8]),
+        }
+
+        impl<'a> From<ValueRef<'a>> for rusqlite::types::ValueRef<'a> {
+            fn from(value: ValueRef<'a>) -> Self {
+                match value {
+                    ValueRef::Null => rusqlite::types::ValueRef::Null,
+                    ValueRef::Integer(i) => rusqlite::types::ValueRef::Integer(i),
+                    ValueRef::Real(x) => rusqlite::types::ValueRef::Real(x),
+                    ValueRef::Text(s) => rusqlite::types::ValueRef::Text(s.as_bytes()),
+                    ValueRef::Blob(b) => rusqlite::types::ValueRef::Blob(b),
+                }
+            }
+        }
+
+        let mut state = Init;
+        trace.push(state);
+        loop {
+            match state {
+                Init => b.init(&QueryBuilderConfig::default()).unwrap(),
+                BeginStep => b.begin_step().unwrap(),
+                FinishStep => b
+                    .finish_step(
+                        Arbitrary::arbitrary(&mut u).unwrap(),
+                        Arbitrary::arbitrary(&mut u).unwrap(),
+                    )
+                    .unwrap(),
+                StepError => b.step_error(crate::error::Error::LibSqlTxBusy).unwrap(),
+                ColsDescription => b
+                    .cols_description(&mut <Vec<Column>>::arbitrary(&mut u).unwrap().into_iter())
+                    .unwrap(),
+                BeginRows => b.begin_rows().unwrap(),
+                BeginRow => b.begin_row().unwrap(),
+                AddRowValue => b
+                    .add_row_value(ValueRef::arbitrary(&mut u).unwrap().into())
+                    .unwrap(),
+                FinishRow => b.finish_row().unwrap(),
+                FinishRows => b.finish_rows().unwrap(),
+                Finish => {
+                    b.finish(false, None).unwrap();
+                    break;
+                }
+                BuilderError => return b,
+            }
+
+            if max_steps > 0 {
+                state = state.rand_transition(false);
+            } else {
+                state = state.toward_finish()
+            }
+
+            trace.push(state);
+
+            max_steps = max_steps.saturating_sub(1);
+        }
+
+        // this can be useful to help debug the generated test case
+        dbg!(trace);
+
+        b
+    }
+
+    pub struct FsmQueryBuilder {
+        state: FsmState,
+        inject_errors: bool,
+    }
+
+    impl fmt::Display for FsmState {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            let s = match self {
+                Init => "init",
+                BeginStep => "begin_step",
+                FinishStep => "finish_step",
+                StepError => "step_error",
+                ColsDescription => "cols_description",
+                BeginRows => "begin_rows",
+                BeginRow => "begin_row",
+                AddRowValue => "add_row_value",
+                FinishRow => "finish_row",
+                FinishRows => "finish_rows",
+                Finish => "finish",
+                BuilderError => "a builder error",
+            };
+
+            f.write_str(s)
+        }
+    }
+
+    impl FsmQueryBuilder {
+        fn new(inject_errors: bool) -> Self {
+            Self {
+                state: Init,
+                inject_errors,
+            }
+        }
+
+        fn transition(&mut self, to: FsmState) -> Result<(), QueryResultBuilderError> {
+            let from = self.state as usize;
+            if TRANSITION_TABLE[to as usize][from] {
+                self.state = to;
+            } else {
+                panic!("{} can't be called after {}", to, self.state);
+            }
+
+            Ok(())
+        }
+
+        fn maybe_inject_error(&mut self) -> Result<(), QueryResultBuilderError> {
+            if self.inject_errors {
+                let val: f32 = thread_rng().sample(Standard);
+                // < 0.1%
change to generate error + if val < 0.001 { + self.state = BuilderError; + Err(anyhow::anyhow!("dummy"))?; + } + } + + Ok(()) + } + } + + impl ResultBuilder for FsmQueryBuilder { + fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(Init) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(BeginStep) + } + + fn finish_step( + &mut self, + _affected_row_count: u64, + _last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(FinishStep) + } + + fn step_error( + &mut self, + _error: crate::error::Error, + ) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(StepError) + } + + fn cols_description( + &mut self, + _cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(ColsDescription) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(BeginRows) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(BeginRow) + } + + fn add_row_value(&mut self, _v: ValueRef) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(AddRowValue) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(FinishRow) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(FinishRows) + } + + fn finish( + &mut self, + _is_txn: bool, + _frame_no: Option, + ) -> Result<(), QueryResultBuilderError> { + self.maybe_inject_error()?; + self.transition(Finish) + } + } + + pub fn test_driver( + iter: usize, + f: impl Fn(FsmQueryBuilder) -> Result, + ) { + for _ in 0..iter { + // inject random errors + let builder = FsmQueryBuilder::new(true); + match f(builder) { + Ok(b) => { + assert_eq!(b.state, Finish); + } + Err(e) => { + assert!(matches!(e, crate::error::Error::BuilderError(_))); + } + } + } + } + + #[test] + fn test_fsm_ok() { + let mut builder = FsmQueryBuilder::new(false); + builder.init(&QueryBuilderConfig::default()).unwrap(); + + builder.begin_step().unwrap(); + builder + .cols_description(&mut [("hello", None).into()].into_iter()) + .unwrap(); + builder.begin_rows().unwrap(); + builder.begin_row().unwrap(); + builder.add_row_value(ValueRef::Null).unwrap(); + builder.finish_row().unwrap(); + builder + .step_error(crate::error::Error::LibSqlTxBusy) + .unwrap(); + builder.finish_step(0, None).unwrap(); + + builder.begin_step().unwrap(); + builder + .cols_description(&mut [("hello", None).into()].into_iter()) + .unwrap(); + builder.begin_rows().unwrap(); + builder.begin_row().unwrap(); + builder.add_row_value(ValueRef::Null).unwrap(); + builder.finish_row().unwrap(); + builder.finish_rows().unwrap(); + builder.finish_step(0, None).unwrap(); + + builder.finish(false, None).unwrap(); + } + + #[test] + #[should_panic] + fn test_fsm_invalid() { + let mut builder = FsmQueryBuilder::new(false); + builder.init(&QueryBuilderConfig::default()).unwrap(); + builder.begin_step().unwrap(); + builder.begin_rows().unwrap(); + } + + #[allow(dead_code)] + fn is_trait_objectifiable(_: Box) {} +} diff --git a/libsqlx/src/seal.rs b/libsqlx/src/seal.rs new file mode 100644 index 00000000..393accc4 --- /dev/null +++ 
b/libsqlx/src/seal.rs
@@ -0,0 +1,8 @@
+/// Hold some type, but prevent any access to it
+pub struct Seal<T>(T);
+
+impl<T> Seal<T> {
+    pub fn new(t: T) -> Self {
+        Seal(t)
+    }
+}
diff --git a/libsqlx/src/semaphore.rs b/libsqlx/src/semaphore.rs
new file mode 100644
index 00000000..a47a4eb1
--- /dev/null
+++ b/libsqlx/src/semaphore.rs
@@ -0,0 +1,98 @@
+use std::sync::Arc;
+use std::time::Duration;
+use std::time::Instant;
+
+use parking_lot::Condvar;
+use parking_lot::Mutex;
+
+struct SemaphoreInner {
+    max_permits: usize,
+    permits: Mutex<usize>,
+    condvar: Condvar,
+}
+
+#[derive(Clone)]
+pub struct Semaphore {
+    inner: Arc<SemaphoreInner>,
+}
+
+pub struct Permit(Semaphore);
+
+impl Drop for Permit {
+    fn drop(&mut self) {
+        *self.0.inner.permits.lock() -= 1;
+        self.0.inner.condvar.notify_one();
+    }
+}
+
+impl Semaphore {
+    pub fn new(max_permits: usize) -> Self {
+        Self {
+            inner: Arc::new(SemaphoreInner {
+                max_permits,
+                permits: Mutex::new(0),
+                condvar: Condvar::new(),
+            }),
+        }
+    }
+
+    pub fn acquire(&self) -> Permit {
+        let mut permits = self.inner.permits.lock();
+        self.inner
+            .condvar
+            .wait_while(&mut permits, |permits| *permits >= self.inner.max_permits);
+        *permits += 1;
+        assert!(*permits <= self.inner.max_permits);
+        Permit(self.clone())
+    }
+
+    pub fn acquire_timeout(&self, timeout: Duration) -> Option<Permit> {
+        let deadline = Instant::now() + timeout;
+        let mut permits = self.inner.permits.lock();
+        if self
+            .inner
+            .condvar
+            .wait_while_until(
+                &mut permits,
+                |permits| *permits >= self.inner.max_permits,
+                deadline,
+            )
+            .timed_out()
+        {
+            return None;
+        }
+
+        *permits += 1;
+        assert!(*permits <= self.inner.max_permits);
+        Some(Permit(self.clone()))
+    }
+
+    #[cfg(test)]
+    fn try_acquire(&self) -> Option<Permit> {
+        let mut permits = self.inner.permits.lock();
+        if *permits >= self.inner.max_permits {
+            None
+        } else {
+            *permits += 1;
+            Some(Permit(self.clone()))
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn semaphore() {
+        let sem = Semaphore::new(2);
+        let permit1 = sem.acquire();
+        let _permit2 = sem.acquire();
+
+        assert!(sem.try_acquire().is_none());
+        drop(permit1);
+        let perm = sem.try_acquire();
+        assert!(perm.is_some());
+        assert!(sem.acquire_timeout(Duration::from_millis(100)).is_none());
+    }
+}
diff --git a/sqld-libsql-bindings/src/wal_hook.rs b/sqld-libsql-bindings/src/wal_hook.rs
index 7f09ad31..30b21995 100644
--- a/sqld-libsql-bindings/src/wal_hook.rs
+++ b/sqld-libsql-bindings/src/wal_hook.rs
@@ -15,7 +15,7 @@ use crate::get_orig_wal_methods;
 macro_rules! init_static_wal_method {
     ($name:ident, $ty:path) => {
         pub static $name: $crate::Lazy<&'static $crate::WalMethodsHook<$ty>> =
-            once_cell::sync::Lazy::new(|| {
+            $crate::Lazy::new(|| {
                 // we need a 'static address before we can register the methods.
                 static METHODS: $crate::Lazy<$crate::WalMethodsHook<$ty>> =
                     $crate::Lazy::new(|| $crate::WalMethodsHook::<$ty>::new());
@@ -45,7 +45,7 @@ macro_rules! init_static_wal_method {
 ///
 /// # Safety
 /// The implementer is responsible for calling the orig method with valid arguments.
-pub unsafe trait WalHook { +pub unsafe trait WalHook: Send + Sync + 'static { type Context; fn name() -> &'static CStr; diff --git a/sqld/src/replication/replica/hook.rs b/sqld/src/replication/replica/hook.rs index 0cad303c..57241896 100644 --- a/sqld/src/replication/replica/hook.rs +++ b/sqld/src/replication/replica/hook.rs @@ -78,9 +78,9 @@ pub struct InjectorHookCtx { /// currently in a txn pub is_txn: bool, /// invoked before injecting frames - pre_commit: Box anyhow::Result<()>>, + pre_commit: Box anyhow::Result<()> + Send + 'static>, /// invoked after injecting frames - post_commit: Box anyhow::Result<()>>, + post_commit: Box anyhow::Result<()> + Send + 'static>, } impl InjectorHookCtx { From e25ea8f285515872def6df28b0c7d34e091e8de6 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 6 Jul 2023 15:19:36 +0200 Subject: [PATCH 02/64] setup logging --- Cargo.lock | 154 +++++++++++++++++++++++++++++++------ Cargo.toml | 1 + libsqlx-server/Cargo.toml | 14 ++++ libsqlx-server/src/main.rs | 29 +++++++ 4 files changed, 176 insertions(+), 22 deletions(-) create mode 100644 libsqlx-server/Cargo.toml create mode 100644 libsqlx-server/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 6d0349e7..fc45ee48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,7 +8,16 @@ version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b" dependencies = [ - "gimli", + "gimli 0.26.2", +] + +[[package]] +name = "addr2line" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +dependencies = [ + "gimli 0.27.3", ] [[package]] @@ -643,6 +652,21 @@ dependencies = [ "tower-service", ] +[[package]] +name = "backtrace" +version = "0.3.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +dependencies = [ + "addr2line 0.20.0", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object 0.31.1", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.13.1" @@ -970,9 +994,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.3.0" +version = "4.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93aae7a4192245f70fe75dd9157fc7b4a5bf53e88d30bd4396f7d8f9284d5acc" +checksum = "1640e5cc7fb47dbb8338fd471b105e7ed6c3cb2aeb00c2e067127ffd3764a05d" dependencies = [ "clap_builder", "clap_derive", @@ -981,22 +1005,21 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.3.0" +version = "4.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f423e341edefb78c9caba2d9c7f7687d0e72e89df3ce3394554754393ac3990" +checksum = "98c59138d527eeaf9b53f35a77fcc1fad9d883116070c63d5de1c7dc7b00c72b" dependencies = [ "anstream", "anstyle", - "bitflags 1.3.2", "clap_lex", "strsim", ] [[package]] name = "clap_derive" -version = "4.3.0" +version = "4.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "191d9573962933b4027f932c600cd252ce27a8ad5979418fe78e43c07996f27b" +checksum = "b8cd2b2a819ad6eec39e8f1d6b53001af1e5469f8c177579cdaeb313115b825f" dependencies = [ "heck", "proc-macro2", @@ -1010,6 +1033,33 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +[[package]] +name = "color-eyre" +version = "0.6.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a667583cca8c4f8436db8de46ea8233c42a7d9ae424a82d338f2e4675229204" +dependencies = [ + "backtrace", + "color-spantrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", + "tracing-error", +] + +[[package]] +name = "color-spantrace" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba75b3d9449ecdccb27ecbc479fdc0b87fa2dd43d2f8298f9bf0e59aacc8dce" +dependencies = [ + "once_cell", + "owo-colors", + "tracing-core", + "tracing-error", +] + [[package]] name = "colorchoice" version = "1.0.0" @@ -1145,7 +1195,7 @@ dependencies = [ "cranelift-egraph", "cranelift-entity", "cranelift-isle", - "gimli", + "gimli 0.26.2", "log", "regalloc2", "smallvec", @@ -1514,6 +1564,16 @@ version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +[[package]] +name = "eyre" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c2b6b5a29c02cdc822728b7d7b8ae1bab3e3b05d44522770ddd49722eeac7eb" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1784,6 +1844,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "gimli" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" + [[package]] name = "glob" version = "0.3.1" @@ -2082,6 +2148,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "indexmap" version = "1.9.3" @@ -2292,9 +2364,9 @@ checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" [[package]] name = "libc" -version = "0.2.144" +version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b00cc1c228a6782d0f076e7b232802e0c5689d41bb5df366f2a6b6621cfdfe1" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "libloading" @@ -2391,6 +2463,18 @@ dependencies = [ "uuid", ] +[[package]] +name = "libsqlx-server" +version = "0.1.0" +dependencies = [ + "axum", + "clap", + "color-eyre", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2684,6 +2768,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "object" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +dependencies = [ + "memchr", +] + [[package]] name = "octopod" version = "0.1.0" @@ -2775,6 +2868,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "owo-colors" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" + [[package]] name = "parking_lot" version = "0.12.1" @@ -4035,11 +4134,12 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.28.2" +version = "1.29.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" +checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" dependencies = [ "autocfg", + "backtrace", "bytes 1.4.0", "libc", "mio", @@ -4314,6 +4414,16 @@ dependencies = [ "valuable", ] +[[package]] +name = "tracing-error" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d686ec1c0f384b1277f097b2f279a2ecc11afe8c133c1aabf036a27cb4cd206e" +dependencies = [ + "tracing", + "tracing-subscriber", +] + [[package]] name = "tracing-futures" version = "0.2.5" @@ -4684,7 +4794,7 @@ dependencies = [ "indexmap 1.9.3", "libc", "log", - "object", + "object 0.29.0", "once_cell", "paste", "psm", @@ -4743,9 +4853,9 @@ dependencies = [ "cranelift-frontend", "cranelift-native", "cranelift-wasm", - "gimli", + "gimli 0.26.2", "log", - "object", + "object 0.29.0", "target-lexicon", "thiserror", "wasmparser", @@ -4760,10 +4870,10 @@ checksum = "754b97f7441ac780a7fa738db5b9c23c1b70ef4abccd8ad205ada5669d196ba2" dependencies = [ "anyhow", "cranelift-entity", - "gimli", + "gimli 0.26.2", "indexmap 1.9.3", "log", - "object", + "object 0.29.0", "serde", "target-lexicon", "thiserror", @@ -4790,15 +4900,15 @@ version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32800cb6e29faabab7056593f70a4c00c65c75c365aaf05406933f2169d0c22f" dependencies = [ - "addr2line", + "addr2line 0.17.0", "anyhow", "bincode", "cfg-if", "cpp_demangle", - "gimli", + "gimli 0.26.2", "ittapi", "log", - "object", + "object 0.29.0", "rustc-demangle", "serde", "target-lexicon", @@ -4816,7 +4926,7 @@ version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe057012a0ba6cee3685af1e923d6e0a6cb9baf15fb3ffa4be3d7f712c7dec42" dependencies = [ - "object", + "object 0.29.0", "once_cell", "rustix 0.35.13", ] diff --git a/Cargo.toml b/Cargo.toml index 238333f0..26e14f1e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "sqld-libsql-bindings", "testing/end-to-end", "libsqlx", + "libsqlx-server", ] [workspace.dependencies] diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml new file mode 100644 index 00000000..66ba92b7 --- /dev/null +++ b/libsqlx-server/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "libsqlx-server" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +axum = "0.6.18" +clap = { version = "4.3.11", features = ["derive"] } +color-eyre = "0.6.2" +tokio = { version = "1.29.1", features = ["full"] } +tracing = "0.1.37" +tracing-subscriber = "0.3.17" diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs new file mode 100644 index 00000000..a4eeb1e1 --- /dev/null +++ b/libsqlx-server/src/main.rs @@ -0,0 +1,29 @@ +use color_eyre::eyre::Result; +use tracing::metadata::LevelFilter; +use tracing_subscriber::prelude::*; + +#[tokio::main] +async fn main() -> Result<()> { + init(); + + Ok(()) +} + +fn init() { + let registry = tracing_subscriber::registry(); + + registry + .with( + tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_filter( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(), + ), + ) + .init(); + + color_eyre::install().unwrap(); +} + From e301b80d2c05d1d8100f53e3d0807b77228d7591 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 6 Jul 2023 15:34:58 
+0200 Subject: [PATCH 03/64] admin server boilerplate --- Cargo.lock | 5 +++-- libsqlx-server/Cargo.toml | 1 + libsqlx-server/src/http/admin.rs | 21 +++++++++++++++++++++ libsqlx-server/src/http/mod.rs | 1 + libsqlx-server/src/main.rs | 3 ++- 5 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 libsqlx-server/src/http/admin.rs create mode 100644 libsqlx-server/src/http/mod.rs diff --git a/Cargo.lock b/Cargo.lock index fc45ee48..06c5b482 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2014,9 +2014,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.26" +version = "0.14.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4" +checksum = "ffb1cfd654a8219eaef89881fdb3bb3b1cdc5fa75ded05d6933b2b382e395468" dependencies = [ "bytes 1.4.0", "futures-channel", @@ -2470,6 +2470,7 @@ dependencies = [ "axum", "clap", "color-eyre", + "hyper", "tokio", "tracing", "tracing-subscriber", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 66ba92b7..7fe6b737 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" axum = "0.6.18" clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" +hyper = { version = "0.14.27", features = ["h2"] } tokio = { version = "1.29.1", features = ["full"] } tracing = "0.1.37" tracing-subscriber = "0.3.17" diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs new file mode 100644 index 00000000..cda7ee55 --- /dev/null +++ b/libsqlx-server/src/http/admin.rs @@ -0,0 +1,21 @@ +use std::sync::Arc; + +use axum::Router; +use color_eyre::eyre::Result; +use hyper::server::accept::Accept; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub struct AdminServerConfig { } + +struct AdminServerState { } + +pub async fn run_admin_server(_config: AdminServerConfig, listener: I) -> Result<()> +where I: Accept, + I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + let state = AdminServerState { }; + let app = Router::new().with_state(Arc::new(state)); + axum::Server::builder(listener).serve(app.into_make_service()).await?; + + Ok(()) +} diff --git a/libsqlx-server/src/http/mod.rs b/libsqlx-server/src/http/mod.rs new file mode 100644 index 00000000..92918b09 --- /dev/null +++ b/libsqlx-server/src/http/mod.rs @@ -0,0 +1 @@ +pub mod admin; diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index a4eeb1e1..15e33e24 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -2,6 +2,8 @@ use color_eyre::eyre::Result; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; +mod http; + #[tokio::main] async fn main() -> Result<()> { init(); @@ -26,4 +28,3 @@ fn init() { color_eyre::install().unwrap(); } - From 92c18367b791379f465c732d241f02924ecfe75a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 7 Jul 2023 10:46:15 +0200 Subject: [PATCH 04/64] introduce allocation type to manages a single database --- libsqlx-server/src/allocation/config.rs | 9 ++ libsqlx-server/src/allocation/mod.rs | 124 ++++++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 libsqlx-server/src/allocation/config.rs create mode 100644 libsqlx-server/src/allocation/mod.rs diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs new file mode 100644 index 00000000..19a6396b --- /dev/null +++ b/libsqlx-server/src/allocation/config.rs 
@@ -0,0 +1,9 @@ +use serde::{Serialize, Deserialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub enum AllocConfig { + Primary { }, + Replica { + primary_node_id: String, + } +} diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs new file mode 100644 index 00000000..11ae61ae --- /dev/null +++ b/libsqlx-server/src/allocation/mod.rs @@ -0,0 +1,124 @@ +use std::collections::HashMap; + +use libsqlx::Database; +use tokio::{sync::{mpsc, oneshot}, task::{JoinSet, block_in_place}}; + +pub mod config; + +type ExecFn = Box; + +#[derive(Clone)] +struct ConnectionId { + id: u32, + close_sender: mpsc::Sender<()>, +} + +enum AllocationMessage { + /// Execute callback against connection + Exec { + connection_id: ConnectionId, + exec: ExecFn, + }, + /// Create a new connection, execute the callback and return the connection id. + NewConnExec { + exec: ExecFn, + ret: oneshot::Sender, + } +} + +pub struct Allocation { + inbox: mpsc::Receiver, + database: Box, + /// senders to the spawned connections + connections: HashMap>, + /// spawned connection futures, returning their connection id on completion. + connections_futs: JoinSet, + next_conn_id: u32, + max_concurrent_connections: u32, +} + +impl Allocation { + async fn run(mut self) { + loop { + tokio::select! { + Some(msg) = self.inbox.recv() => { + match msg { + AllocationMessage::Exec { connection_id, exec } => { + if let Some(sender) = self.connections.get(&connection_id.id) { + if let Err(_) = sender.send(exec).await { + tracing::debug!("connection {} closed.", connection_id.id); + self.connections.remove_entry(&connection_id.id); + } + } + }, + AllocationMessage::NewConnExec { exec, ret } => { + let id = self.new_conn_exec(exec).await; + let _ = ret.send(id); + }, + } + }, + maybe_id = self.connections_futs.join_next() => { + if let Some(Ok(id)) = maybe_id { + self.connections.remove_entry(&id); + } + }, + else => break, + } + } + } + + async fn new_conn_exec(&mut self, exec: ExecFn) -> ConnectionId { + let id = self.next_conn_id(); + let conn = block_in_place(|| self.database.connect()).unwrap(); + let (close_sender, exit) = mpsc::channel(1); + let (exec_sender, exec_receiver) = mpsc::channel(1); + let conn = Connection { + id, + conn, + exit, + exec: exec_receiver, + }; + + + self.connections_futs.spawn(conn.run()); + // This should never block! + assert!(exec_sender.try_send(exec).is_ok()); + assert!(self.connections.insert(id, exec_sender).is_none()); + + ConnectionId { + id, + close_sender, + } + } + + fn next_conn_id(&mut self) -> u32 { + loop { + self.next_conn_id = self.next_conn_id.wrapping_add(1); + if !self.connections.contains_key(&self.next_conn_id) { + return self.next_conn_id + } + } + } +} + +struct Connection { + id: u32, + conn: Box, + exit: mpsc::Receiver<()>, + exec: mpsc::Receiver, +} + +impl Connection { + async fn run(mut self) -> u32 { + loop { + tokio::select! 
{ + _ = self.exit.recv() => break, + Some(exec) = self.exec.recv() => { + tokio::task::block_in_place(|| exec(&mut *self.conn)); + } + } + } + + self.id + } +} From 110c5fa0404e61186a71d454470d1a302c36cf42 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 7 Jul 2023 10:47:46 +0200 Subject: [PATCH 05/64] sketch meta store --- libsqlx-server/src/meta.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 libsqlx-server/src/meta.rs diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs new file mode 100644 index 00000000..7f48a456 --- /dev/null +++ b/libsqlx-server/src/meta.rs @@ -0,0 +1,10 @@ +use uuid::Uuid; + +use crate::allocation::config::AllocConfig; + +pub struct MetaStore {} + +impl MetaStore { + pub async fn allocate(&self, alloc_id: &str, meta: &AllocConfig) {} + pub async fn deallocate(&self, alloc_id: Uuid) {} +} From 09f7f0955b93c0010ef77279a57215906085d1d3 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 7 Jul 2023 10:48:08 +0200 Subject: [PATCH 06/64] changes to libsqlx --- Cargo.lock | 58 +++++----- libsqlx-server/Cargo.toml | 4 + libsqlx-server/src/allocation/mod.rs | 15 ++- libsqlx-server/src/http/admin.rs | 48 ++++++-- libsqlx-server/src/main.rs | 3 + libsqlx/src/connection.rs | 20 +++- libsqlx/src/database/libsql/connection.rs | 12 +- libsqlx/src/database/libsql/injector/hook.rs | 7 +- libsqlx/src/database/libsql/injector/mod.rs | 20 +++- libsqlx/src/database/libsql/mod.rs | 73 +++++++----- libsqlx/src/database/mod.rs | 112 +------------------ libsqlx/src/database/proxy/connection.rs | 30 ++--- libsqlx/src/database/proxy/database.rs | 7 +- libsqlx/src/database/test_utils.rs | 10 +- libsqlx/src/lib.rs | 1 - libsqlx/src/result_builder.rs | 2 +- libsqlx/src/semaphore.rs | 98 ---------------- 17 files changed, 203 insertions(+), 317 deletions(-) delete mode 100644 libsqlx/src/semaphore.rs diff --git a/Cargo.lock b/Cargo.lock index 06c5b482..63b0f30a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,7 +205,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -216,7 +216,7 @@ checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -716,7 +716,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -836,7 +836,7 @@ checksum = "fdde5c9cd29ebd706ce1b35600920a33550e402fc998a2e53ad3b42c3c47a192" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -1024,7 +1024,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -1647,7 +1647,7 @@ checksum = "2cd66269887534af4b0c3e3337404591daa8dc8b9b2b3db71f9523beb4bafb41" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -1747,7 +1747,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -1830,7 +1830,7 @@ checksum = "e77ac7b51b8e6313251737fcef4b1c01a2ea102bde68415b62c0ee9268fec357" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -2470,10 +2470,14 @@ dependencies = [ "axum", "clap", "color-eyre", + "futures", "hyper", + "libsqlx", + "serde", "tokio", "tracing", "tracing-subscriber", + "uuid", ] [[package]] @@ -2836,7 +2840,7 @@ checksum = 
"a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -3011,7 +3015,7 @@ checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -3091,7 +3095,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b69d39aab54d069e7f2fe8cb970493e7834601ca2d8c65fd7bbd183578080d1" dependencies = [ "proc-macro2", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -3106,9 +3110,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.58" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa1fb82fc0c281dd9671101b66b771ebbe1eaf967b96ac8740dcba4b70005ca8" +checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" dependencies = [ "unicode-ident", ] @@ -3204,9 +3208,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.27" +version = "1.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f4f29d145265ec1c483c7c654450edde0bfe043d3938d6972630663356d9500" +checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" dependencies = [ "proc-macro2", ] @@ -3621,22 +3625,22 @@ checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" [[package]] name = "serde" -version = "1.0.164" +version = "1.0.166" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" +checksum = "d01b7404f9d441d3ad40e6a636a7782c377d2abdbe4fa2440e2edcc2f4f10db8" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.164" +version = "1.0.166" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" +checksum = "5dd83d6dde2b6b2d466e14d9d1acce8816dedee94f735eac6395808b3483c6d6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -3967,9 +3971,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.16" +version = "2.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6f671d4b5ffdb8eadec19c0ae67fe2639df8684bd7bc4b83d986b8db549cf01" +checksum = "59fb7d6d8281a51045d62b8eb3a7d1ce347b76f312af50cd3dc0af39c87c1737" dependencies = [ "proc-macro2", "quote", @@ -4067,7 +4071,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -4172,7 +4176,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -4402,7 +4406,7 @@ checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", ] [[package]] @@ -4719,7 +4723,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", "wasm-bindgen-shared", ] @@ -4753,7 +4757,7 @@ checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.16", + "syn 2.0.23", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 
7fe6b737..90f2ca0b 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -9,7 +9,11 @@ edition = "2021" axum = "0.6.18" clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" +futures = "0.3.28" hyper = { version = "0.14.27", features = ["h2"] } +libsqlx = { version = "0.1.0", path = "../libsqlx" } +serde = { version = "1.0.166", features = ["derive"] } tokio = { version = "1.29.1", features = ["full"] } tracing = "0.1.37" tracing-subscriber = "0.3.17" +uuid = { version = "1.4.0", features = ["v4"] } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 11ae61ae..2fa9a4cd 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use libsqlx::Database; use tokio::{sync::{mpsc, oneshot}, task::{JoinSet, block_in_place}}; pub mod config; @@ -26,9 +25,17 @@ enum AllocationMessage { } } +enum Database {} + +impl Database { + fn connect(&self) -> Box { + todo!(); + } +} + pub struct Allocation { inbox: mpsc::Receiver, - database: Box, + database: Database, /// senders to the spawned connections connections: HashMap>, /// spawned connection futures, returning their connection id on completion. @@ -69,7 +76,7 @@ impl Allocation { async fn new_conn_exec(&mut self, exec: ExecFn) -> ConnectionId { let id = self.next_conn_id(); - let conn = block_in_place(|| self.database.connect()).unwrap(); + let conn = block_in_place(|| self.database.connect()); let (close_sender, exit) = mpsc::channel(1); let (exec_sender, exec_receiver) = mpsc::channel(1); let conn = Connection { @@ -103,7 +110,7 @@ impl Allocation { struct Connection { id: u32, - conn: Box, + conn: Box, exit: mpsc::Receiver<()>, exec: mpsc::Receiver, } diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index cda7ee55..2d9c8054 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,21 +1,53 @@ use std::sync::Arc; -use axum::Router; +use axum::{extract::State, routing::post, Json, Router}; use color_eyre::eyre::Result; use hyper::server::accept::Accept; +use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -pub struct AdminServerConfig { } +use crate::{meta::MetaStore, allocation::config::AllocConfig}; -struct AdminServerState { } +pub struct AdminServerConfig {} + +struct AdminServerState { + meta_store: Arc, +} pub async fn run_admin_server(_config: AdminServerConfig, listener: I) -> Result<()> -where I: Accept, - I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, +where + I: Accept, + I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - let state = AdminServerState { }; - let app = Router::new().with_state(Arc::new(state)); - axum::Server::builder(listener).serve(app.into_make_service()).await?; + let state = AdminServerState { + meta_store: todo!(), + }; + let app = Router::new() + .route("/manage/allocation/create", post(allocate)) + .with_state(Arc::new(state)); + axum::Server::builder(listener) + .serve(app.into_make_service()) + .await?; Ok(()) } + +#[derive(Serialize, Debug)] +struct ErrorResponse {} + +#[derive(Serialize, Debug)] +struct AllocateResp { } + +#[derive(Deserialize, Debug)] +struct AllocateReq { + alloc_id: String, + config: AllocConfig, +} + +async fn allocate( + State(state): State>, + Json(req): Json, +) -> Result, Json> { + state.meta_store.allocate(&req.alloc_id, &req.config).await; + todo!(); +} diff --git a/libsqlx-server/src/main.rs 
b/libsqlx-server/src/main.rs index 15e33e24..1ee047bf 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -2,7 +2,10 @@ use color_eyre::eyre::Result; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; +mod allocation; +mod databases; mod http; +mod meta; #[tokio::main] async fn main() -> Result<()> { diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index d21ca9a6..cc4776c4 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -22,12 +22,26 @@ pub struct DescribeCol { pub trait Connection { /// Executes a query program - fn execute_program( + fn execute_program( &mut self, pgm: Program, - result_builder: B, - ) -> crate::Result; + result_builder: &mut dyn ResultBuilder, + ) -> crate::Result<()>; /// Parse the SQL statement and return information about it. fn describe(&self, sql: String) -> crate::Result; } + +impl Connection for Box { + fn execute_program( + &mut self, + pgm: Program, + result_builder: &mut dyn ResultBuilder, + ) -> crate::Result<()> { + self.as_mut().execute_program(pgm, result_builder) + } + + fn describe(&self, sql: String) -> crate::Result { + self.as_ref().describe(sql) + } +} diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 88632501..8a1c8c55 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -101,14 +101,14 @@ impl LibsqlConnection { &self.conn } - fn run(&mut self, pgm: Program, mut builder: B) -> Result { + fn run(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> Result<()> { let mut results = Vec::with_capacity(pgm.steps.len()); builder.init(&self.builder_config)?; let is_autocommit_before = self.conn.is_autocommit(); for step in pgm.steps() { - let res = self.execute_step(step, &results, &mut builder)?; + let res = self.execute_step(step, &results, builder)?; results.push(res); } @@ -119,14 +119,14 @@ impl LibsqlConnection { builder.finish(!self.conn.is_autocommit(), None)?; - Ok(builder) + Ok(()) } fn execute_step( &mut self, step: &Step, results: &[bool], - builder: &mut impl ResultBuilder, + builder: &mut dyn ResultBuilder, ) -> Result { builder.begin_step()?; let mut enabled = match step.cond.as_ref() { @@ -163,7 +163,7 @@ impl LibsqlConnection { fn execute_query( &self, query: &Query, - builder: &mut impl ResultBuilder, + builder: &mut dyn ResultBuilder, ) -> Result<(u64, Option)> { tracing::trace!("executing query: {}", query.stmt.stmt); @@ -237,7 +237,7 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { } impl Connection for LibsqlConnection { - fn execute_program(&mut self, pgm: Program, builder: B) -> crate::Result { + fn execute_program(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> crate::Result<()> { self.run(pgm, builder) } diff --git a/libsqlx/src/database/libsql/injector/hook.rs b/libsqlx/src/database/libsql/injector/hook.rs index 0479fb2d..f87172db 100644 --- a/libsqlx/src/database/libsql/injector/hook.rs +++ b/libsqlx/src/database/libsql/injector/hook.rs @@ -27,14 +27,11 @@ pub struct InjectorHookCtx { } impl InjectorHookCtx { - pub fn new( - buffer: FrameBuffer, - injector_commit_handler: impl InjectorCommitHandler + 'static, - ) -> Self { + pub fn new(buffer: FrameBuffer, commit_handler: Box) -> Self { Self { buffer, is_txn: false, - commit_handler: Box::new(injector_commit_handler), + commit_handler, } } diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index 
df01cd34..1682e3b4 100644 --- a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -42,6 +42,16 @@ pub trait InjectorCommitHandler: 'static { fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; } +impl InjectorCommitHandler for Box { + fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()> { + self.as_mut().pre_commit(frame_no) + } + + fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()> { + self.as_mut().post_commit(frame_no) + } +} + #[cfg(test)] impl InjectorCommitHandler for () { fn pre_commit(&mut self, _frame_no: FrameNo) -> anyhow::Result<()> { @@ -56,11 +66,11 @@ impl InjectorCommitHandler for () { impl Injector { pub fn new( path: &Path, - injector_commit_hanlder: impl InjectorCommitHandler + 'static, + injector_commit_handler: Box, buffer_capacity: usize, ) -> crate::Result { let buffer = FrameBuffer::default(); - let ctx = InjectorHookCtx::new(buffer.clone(), injector_commit_hanlder); + let ctx = InjectorHookCtx::new(buffer.clone(), injector_commit_handler); let mut ctx = Box::new(ctx); let connection = sqld_libsql_bindings::Connection::open( path, @@ -162,7 +172,7 @@ mod test { let log = LogFile::new(file).unwrap(); let temp = tempfile::tempdir().unwrap(); - let mut injector = Injector::new(temp.path(), (), 10).unwrap(); + let mut injector = Injector::new(temp.path(), Box::new(()), 10).unwrap(); for frame in log.frames_iter().unwrap() { let frame = frame.unwrap(); injector.inject_frame(frame).unwrap(); @@ -184,7 +194,7 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = Injector::new(temp.path(), (), 1).unwrap(); + let mut injector = Injector::new(temp.path(), Box::new(()), 1).unwrap(); for frame in log.frames_iter().unwrap() { let frame = frame.unwrap(); injector.inject_frame(frame).unwrap(); @@ -206,7 +216,7 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = Injector::new(temp.path(), (), 10).unwrap(); + let mut injector = Injector::new(temp.path(), Box::new(()), 10).unwrap(); let mut iter = log.frames_iter().unwrap(); assert!(injector diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index e5f4ad0a..27397663 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -44,8 +44,8 @@ impl LibsqlDbType for PrimaryType { } pub struct ReplicaType { - // frame injector for the database - injector: Injector, + commit_handler: Option>, + injector_buffer_capacity: usize, } impl LibsqlDbType for ReplicaType { @@ -83,13 +83,13 @@ pub struct LibsqlDatabase { } /// Handler trait for gathering row stats when executing queries. 
-pub trait RowStatsHandler { +pub trait RowStatsHandler: Send + Sync { fn handle_row_stats(&self, stats: RowStats); } impl RowStatsHandler for F where - F: Fn(RowStats), + F: Fn(RowStats) + Send + Sync, { fn handle_row_stats(&self, stats: RowStats) { (self)(stats) @@ -104,7 +104,8 @@ impl LibsqlDatabase { injector_commit_handler: impl InjectorCommitHandler, ) -> crate::Result { let ty = ReplicaType { - injector: Injector::new(&db_path, injector_commit_handler, injector_buffer_capacity)?, + commit_handler: Some(Box::new(injector_commit_handler)), + injector_buffer_capacity, }; Ok(Self::new(db_path, ty)) @@ -154,7 +155,7 @@ impl Database for LibsqlDatabase { type Connection = LibsqlConnection<::Context>; fn connect(&self) -> Result { - LibsqlConnection::<::Context>::new( + Ok(LibsqlConnection::<::Context>::new( &self.db_path, self.extensions.clone(), T::hook(), @@ -163,13 +164,24 @@ impl Database for LibsqlDatabase { QueryBuilderConfig { max_size: Some(self.response_size_limit), }, - ) + )?) } } impl InjectableDatabase for LibsqlDatabase { - fn inject_frame(&mut self, frame: Frame) -> Result<(), InjectError> { - self.ty.injector.inject_frame(frame).unwrap(); + fn injector(&mut self) -> crate::Result> { + let Some(commit_handler) = self.ty.commit_handler.take() else { panic!("there can be only one injector") }; + Ok(Box::new(Injector::new( + &self.db_path, + commit_handler, + self.ty.injector_buffer_capacity, + )?)) + } +} + +impl super::Injector for Injector { + fn inject(&mut self, frame: Frame) -> Result<(), InjectError> { + self.inject_frame(frame).unwrap(); Ok(()) } } @@ -205,32 +217,36 @@ mod test { fn inject_libsql_db() { let temp = tempfile::tempdir().unwrap(); let replica = ReplicaType { - injector: Injector::new(temp.path(), (), 10).unwrap(), + commit_handler: Some(Box::new(())), + injector_buffer_capacity: 10, }; let mut db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); let mut conn = db.connect().unwrap(); - let res = conn + let mut builder = ReadRowBuilder(Vec::new()); + conn .execute_program( Program::seq(&["select count(*) from test"]), - ReadRowBuilder(Vec::new()), + &mut builder ) .unwrap(); - assert!(res.0.is_empty()); + assert!(builder.0.is_empty()); let file = File::open("assets/test/simple_wallog").unwrap(); let log = LogFile::new(file).unwrap(); + let mut injector = db.injector().unwrap(); log.frames_iter() .unwrap() - .for_each(|f| db.inject_frame(f.unwrap()).unwrap()); + .for_each(|f| injector.inject(f.unwrap()).unwrap()); - let res = conn + let mut builder = ReadRowBuilder(Vec::new()); + conn .execute_program( Program::seq(&["select count(*) from test"]), - ReadRowBuilder(Vec::new()), + &mut builder ) .unwrap(); - assert_eq!(res.0[0], Value::Integer(5)); + assert_eq!(builder.0[0], Value::Integer(5)); } #[test] @@ -248,7 +264,8 @@ mod test { let mut replica = LibsqlDatabase::new( temp_replica.path().to_path_buf(), ReplicaType { - injector: Injector::new(temp_replica.path(), (), 10).unwrap(), + commit_handler: Some(Box::new(())), + injector_buffer_capacity: 10, }, ); @@ -256,27 +273,29 @@ mod test { primary_conn .execute_program( Program::seq(&["create table test (x)", "insert into test values (42)"]), - (), + &mut (), ) .unwrap(); let logfile = primary.ty.logger.log_file.read(); + let mut injector = replica.injector().unwrap(); for frame in logfile.frames_iter().unwrap() { let frame = frame.unwrap(); - replica.inject_frame(frame).unwrap(); + injector.inject(frame).unwrap(); } let mut replica_conn = replica.connect().unwrap(); - let result = replica_conn + 
let mut builder = ReadRowBuilder(Vec::new()); + replica_conn .execute_program( Program::seq(&["select * from test limit 1"]), - ReadRowBuilder(Vec::new()), + &mut builder ) .unwrap(); - assert_eq!(result.0.len(), 1); - assert_eq!(result.0[0], Value::Integer(42)); + assert_eq!(builder.0.len(), 1); + assert_eq!(builder.0[0], Value::Integer(42)); } #[test] @@ -311,7 +330,7 @@ mod test { let mut conn = db.connect().unwrap(); conn.execute_program( Program::seq(&["create table test (x)", "insert into test values (12)"]), - (), + &mut (), ) .unwrap(); assert!(compactor_called.get()); @@ -354,12 +373,12 @@ mod test { "create table test (x)", "insert into test values (12)", ]), - (), + &mut (), ) .unwrap(); conn.inner_connection().cache_flush().unwrap(); assert!(!compactor_called.get()); - conn.execute_program(Program::seq(&["commit"]), ()).unwrap(); + conn.execute_program(Program::seq(&["commit"]), &mut ()).unwrap(); assert!(compactor_called.get()); } } diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index edfb0b8e..fa1ce874 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -1,11 +1,7 @@ use std::time::Duration; -use crate::connection::{Connection, DescribeResponse}; +use crate::connection::Connection; use crate::error::Error; -use crate::program::Program; -use crate::result_builder::ResultBuilder; -use crate::semaphore::{Permit, Semaphore}; - use self::frame::Frame; mod frame; @@ -25,111 +21,13 @@ pub trait Database { type Connection: Connection; /// Create a new connection to the database fn connect(&self) -> Result; - - /// returns a database with a limit on the number of conccurent connections - fn throttled(self, limit: usize, timeout: Option) -> ThrottledDatabase - where - Self: Sized, - { - ThrottledDatabase::new(limit, self, timeout) - } } -// Trait implemented by databases that support frame injection pub trait InjectableDatabase { - fn inject_frame(&mut self, frame: Frame) -> Result<(), InjectError>; -} - -/// A Database that limits the number of conccurent connections to the underlying database. 
-pub struct ThrottledDatabase { - semaphore: Semaphore, - db: T, - timeout: Option, -} - -impl ThrottledDatabase { - fn new(conccurency: usize, db: T, timeout: Option) -> Self { - Self { - semaphore: Semaphore::new(conccurency), - db, - timeout, - } - } -} - -impl Database for ThrottledDatabase { - type Connection = TrackedDb; - - fn connect(&self) -> Result { - let permit = match self.timeout { - Some(t) => self - .semaphore - .acquire_timeout(t) - .ok_or(Error::DbCreateTimeout)?, - None => self.semaphore.acquire(), - }; - - let inner = self.db.connect()?; - Ok(TrackedDb { permit, inner }) - } -} - -pub struct TrackedDb { - inner: DB, - #[allow(dead_code)] // just hold on to it - permit: Permit, -} - -impl Connection for TrackedDb { - #[inline] - fn execute_program(&mut self, pgm: Program, builder: B) -> crate::Result { - self.inner.execute_program(pgm, builder) - } - - #[inline] - fn describe(&self, sql: String) -> crate::Result { - self.inner.describe(sql) - } + fn injector(&mut self) -> crate::Result>; } -#[cfg(test)] -mod test { - use super::*; - - struct DummyConn; - - impl Connection for DummyConn { - fn execute_program(&mut self, _pgm: Program, _builder: B) -> crate::Result - where - B: ResultBuilder, - { - unreachable!() - } - - fn describe(&self, _sql: String) -> crate::Result { - unreachable!() - } - } - - struct DummyDatabase; - - impl Database for DummyDatabase { - type Connection = DummyConn; - - fn connect(&self) -> Result { - Ok(DummyConn) - } - } - - #[test] - fn throttle_db_creation() { - let db = DummyDatabase.throttled(1, Some(Duration::from_millis(100))); - let conn = db.connect().unwrap(); - - assert!(db.connect().is_err()); - - drop(conn); - - assert!(db.connect().is_ok()); - } +// Trait implemented by databases that support frame injection +pub trait Injector { + fn inject(&mut self, frame: Frame) -> Result<(), InjectError>; } diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index 6fec7d23..24c10a47 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -26,9 +26,9 @@ where ReadDb: Connection, WriteDb: Connection, { - fn execute_program(&mut self, pgm: Program, builder: B) -> Result { + fn execute_program(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> Result<()> { let mut state = self.state.lock(); - let builder = ExtractFrameNoBuilder::new(builder); + let mut builder = ExtractFrameNoBuilder::new(builder); if !state.is_txn && pgm.is_read_only() { if let Some(frame_no) = state.last_frame_no { (self.wait_frame_no_cb)(frame_no); @@ -36,24 +36,24 @@ where // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. 
- let builder = self.read_db.execute_program(pgm.clone(), builder)?; + self.read_db.execute_program(pgm.clone(), &mut builder)?; // still in transaction state after running a read-only txn if builder.is_txn { // TODO: rollback // self.read_db.rollback().await?; - let builder = self.write_db.execute_program(pgm, builder)?; + self.write_db.execute_program(pgm, &mut builder)?; state.is_txn = builder.is_txn; state.last_frame_no = builder.frame_no; - Ok(builder.inner) + Ok(()) } else { - Ok(builder.inner) + Ok(()) } } else { - let builder = self.write_db.execute_program(pgm, builder)?; + self.write_db.execute_program(pgm, &mut builder)?; state.is_txn = builder.is_txn; state.last_frame_no = builder.frame_no; - Ok(builder.inner) + Ok(()) } } @@ -65,14 +65,14 @@ where } } -struct ExtractFrameNoBuilder { - inner: B, +struct ExtractFrameNoBuilder<'a> { + inner: &'a mut dyn ResultBuilder, frame_no: Option, is_txn: bool, } -impl ExtractFrameNoBuilder { - fn new(inner: B) -> Self { +impl<'a> ExtractFrameNoBuilder<'a> { + fn new(inner: &'a mut dyn ResultBuilder) -> Self { Self { inner, frame_no: None, @@ -81,7 +81,7 @@ impl ExtractFrameNoBuilder { } } -impl ResultBuilder for ExtractFrameNoBuilder { +impl<'a> ResultBuilder for ExtractFrameNoBuilder<'a> { fn init( &mut self, config: &QueryBuilderConfig, @@ -206,14 +206,14 @@ mod test { ); let mut conn = db.connect().unwrap(); - conn.execute_program(Program::seq(&["insert into test values (12)"]), ()) + conn.execute_program(Program::seq(&["insert into test values (12)"]), &mut ()) .unwrap(); assert!(!wait_called.get()); assert!(!read_called.get()); assert!(write_called.get()); - conn.execute_program(Program::seq(&["select * from test"]), ()) + conn.execute_program(Program::seq(&["select * from test"]), &mut ()) .unwrap(); assert!(read_called.get()); diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs index fedd0ef7..129cc5e2 100644 --- a/libsqlx/src/database/proxy/database.rs +++ b/libsqlx/src/database/proxy/database.rs @@ -1,4 +1,3 @@ -use crate::database::frame::Frame; use crate::database::{Database, InjectableDatabase}; use crate::error::Error; @@ -27,7 +26,6 @@ where WDB: Database, { type Connection = WriteProxyConnection; - /// Create a new connection to the database fn connect(&self) -> Result { Ok(WriteProxyConnection { @@ -43,8 +41,7 @@ impl InjectableDatabase for WriteProxyDatabase where RDB: InjectableDatabase, { - fn inject_frame(&mut self, frame: Frame) -> Result<(), crate::database::InjectError> { - // TODO: handle frame index - self.read_db.inject_frame(frame) + fn injector(&mut self) -> crate::Result> { + self.read_db.injector() } } diff --git a/libsqlx/src/database/test_utils.rs b/libsqlx/src/database/test_utils.rs index 86c072ea..a46aa2ac 100644 --- a/libsqlx/src/database/test_utils.rs +++ b/libsqlx/src/database/test_utils.rs @@ -51,13 +51,13 @@ impl Database for MockDatabase { } impl Connection for MockConnection { - fn execute_program( + fn execute_program( &mut self, pgm: crate::program::Program, - mut reponse_builder: B, - ) -> crate::Result { - (self.execute_fn)(pgm, &mut reponse_builder)?; - Ok(reponse_builder) + reponse_builder: &mut dyn ResultBuilder, + ) -> crate::Result<()> { + (self.execute_fn)(pgm, reponse_builder)?; + Ok(()) } fn describe(&self, sql: String) -> crate::Result { diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index 986044e2..a6e3c3a2 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -7,7 +7,6 @@ mod database; mod program; mod result_builder; mod seal; -mod 
semaphore; pub type Result = std::result::Result; diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index 2784274c..ae299b1e 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -80,7 +80,7 @@ pub struct QueryBuilderConfig { pub max_size: Option, } -pub trait ResultBuilder: Send + 'static { +pub trait ResultBuilder { /// (Re)initialize the builder. This method can be called multiple times. fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { Ok(()) diff --git a/libsqlx/src/semaphore.rs b/libsqlx/src/semaphore.rs deleted file mode 100644 index a47a4eb1..00000000 --- a/libsqlx/src/semaphore.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::sync::Arc; -use std::time::Duration; -use std::time::Instant; - -use parking_lot::Condvar; -use parking_lot::Mutex; - -struct SemaphoreInner { - max_permits: usize, - permits: Mutex, - condvar: Condvar, -} - -#[derive(Clone)] -pub struct Semaphore { - inner: Arc, -} - -pub struct Permit(Semaphore); - -impl Drop for Permit { - fn drop(&mut self) { - *self.0.inner.permits.lock() -= 1; - self.0.inner.condvar.notify_one(); - } -} - -impl Semaphore { - pub fn new(max_permits: usize) -> Self { - Self { - inner: Arc::new(SemaphoreInner { - max_permits, - permits: Mutex::new(0), - condvar: Condvar::new(), - }), - } - } - - pub fn acquire(&self) -> Permit { - let mut permits = self.inner.permits.lock(); - self.inner - .condvar - .wait_while(&mut permits, |permits| *permits >= self.inner.max_permits); - *permits += 1; - assert!(*permits <= self.inner.max_permits); - Permit(self.clone()) - } - - pub fn acquire_timeout(&self, timeout: Duration) -> Option { - let deadline = Instant::now() + timeout; - let mut permits = self.inner.permits.lock(); - if self - .inner - .condvar - .wait_while_until( - &mut permits, - |permits| *permits >= self.inner.max_permits, - deadline, - ) - .timed_out() - { - return None; - } - - *permits += 1; - assert!(*permits <= self.inner.max_permits); - Some(Permit(self.clone())) - } - - #[cfg(test)] - fn try_acquire(&self) -> Option { - let mut permits = self.inner.permits.lock(); - if *permits >= self.inner.max_permits { - None - } else { - *permits += 1; - Some(Permit(self.clone())) - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn semaphore() { - let sem = Semaphore::new(2); - let permit1 = sem.acquire(); - let _permit2 = sem.acquire(); - - assert!(sem.try_acquire().is_none()); - drop(permit1); - let perm = sem.try_acquire(); - assert!(perm.is_some()); - assert!(sem.acquire_timeout(Duration::from_millis(100)).is_none()); - } -} From 19c8f4fecfafbad6b7af7cb13d323ca87bcdaf62 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 7 Jul 2023 16:48:59 +0200 Subject: [PATCH 07/64] bare bone allocation manager/store --- Cargo.lock | 317 +++++++++++++++++- libsqlx-server/Cargo.toml | 7 +- libsqlx-server/src/allocation/config.rs | 24 +- libsqlx-server/src/allocation/mod.rs | 71 ++-- libsqlx-server/src/databases/mod.rs | 5 + libsqlx-server/src/databases/store.rs | 12 + libsqlx-server/src/http/admin.rs | 69 +++- libsqlx-server/src/main.rs | 15 + libsqlx-server/src/manager.rs | 50 +++ libsqlx-server/src/meta.rs | 53 ++- libsqlx/src/connection.rs | 61 +++- libsqlx/src/database/libsql/connection.rs | 6 +- libsqlx/src/database/libsql/mod.rs | 91 ++--- .../database/libsql/replication_log/logger.rs | 22 +- libsqlx/src/database/mod.rs | 2 +- 15 files changed, 712 insertions(+), 93 deletions(-) create mode 100644 libsqlx-server/src/databases/mod.rs create mode 
100644 libsqlx-server/src/databases/store.rs create mode 100644 libsqlx-server/src/manager.rs diff --git a/Cargo.lock b/Cargo.lock index 63b0f30a..0864a55e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -177,6 +177,26 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-io" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" +dependencies = [ + "async-lock", + "autocfg", + "cfg-if", + "concurrent-queue", + "futures-lite", + "log", + "parking", + "polling", + "rustix 0.37.19", + "slab", + "socket2", + "waker-fn", +] + [[package]] name = "async-lock" version = "2.7.0" @@ -819,6 +839,12 @@ version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +[[package]] +name = "bytecount" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c676a478f63e9fa2dd5368a42f28bba0d6c560b775f38583c8bbaa7fcd67c9c" + [[package]] name = "bytemuck" version = "1.13.1" @@ -876,6 +902,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38fcc2979eff34a4b84e1cf9a1e3da42a7d44b3b690a40cdcb23e3d556cfb2e5" +[[package]] +name = "camino" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c530edf18f37068ac2d977409ed5cd50d53d73bc653c7647b48eb78976ac9ae2" +dependencies = [ + "serde", +] + [[package]] name = "cap-fs-ext" version = "0.26.1" @@ -941,6 +976,28 @@ dependencies = [ "winx", ] +[[package]] +name = "cargo-platform" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbdb825da8a5df079a43676dbe042702f1707b1109f713a01420fbb4cc71fa27" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo_metadata" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa" +dependencies = [ + "camino", + "cargo-platform", + "semver", + "serde", + "serde_json", +] + [[package]] name = "cc" version = "1.0.79" @@ -1066,6 +1123,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "concurrent-queue" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62ec6771ecfa0762d24683ee5a32ad78487a3d3afdc0fb8cae19d2c5deb50b7c" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console" version = "0.15.7" @@ -1558,6 +1624,15 @@ dependencies = [ "libc", ] +[[package]] +name = "error-chain" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" +dependencies = [ + "version_check", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -1691,6 +1766,16 @@ dependencies = [ "windows-sys 0.36.1", ] +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "futures" version = "0.3.28" @@ -1739,6 +1824,21 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +[[package]] +name = "futures-lite" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "memchr", + "parking", + "pin-project-lite", + "waker-fn", +] + [[package]] name = "futures-macro" version = "0.3.28" @@ -2448,7 +2548,7 @@ dependencies = [ "itertools 0.11.0", "nix", "once_cell", - "parking_lot", + "parking_lot 0.12.1", "rand", "regex", "rusqlite", @@ -2468,12 +2568,15 @@ name = "libsqlx-server" version = "0.1.0" dependencies = [ "axum", + "bincode", "clap", "color-eyre", "futures", "hyper", "libsqlx", + "moka", "serde", + "sled", "tokio", "tracing", "tracing-subscriber", @@ -2526,6 +2629,15 @@ dependencies = [ "libc", ] +[[package]] +name = "mach2" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8" +dependencies = [ + "libc", +] + [[package]] name = "maplit" version = "1.0.2" @@ -2656,6 +2768,31 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "moka" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "206bf83f415b0579fd885fe0804eb828e727636657dc1bf73d80d2f1218e14a1" +dependencies = [ + "async-io", + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "futures-util", + "once_cell", + "parking_lot 0.12.1", + "quanta", + "rustc_version", + "scheduled-thread-pool", + "skeptic", + "smallvec", + "tagptr", + "thiserror", + "triomphe", + "uuid", +] + [[package]] name = "multimap" version = "0.8.3" @@ -2879,6 +3016,23 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" +[[package]] +name = "parking" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14f2252c834a40ed9bb5422029649578e63aa341ac401f74e719dd1afda8394e" + +[[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core 0.8.6", +] + [[package]] name = "parking_lot" version = "0.12.1" @@ -2886,7 +3040,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core", + "parking_lot_core 0.9.7", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall 0.2.16", + "smallvec", + "winapi", ] [[package]] @@ -3072,6 +3240,22 @@ dependencies = [ "serde_json", ] +[[package]] +name = "polling" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" +dependencies = [ + "autocfg", + "bitflags 1.3.2", + "cfg-if", + "concurrent-queue", + "libc", + "log", + "pin-project-lite", + "windows-sys 0.48.0", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -3200,6 +3384,33 @@ dependencies = [ 
"cc", ] +[[package]] +name = "pulldown-cmark" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a1a2f1f0a7ecff9c31abbe177637be0e97a0aef46cf8738ece09327985d998" +dependencies = [ + "bitflags 1.3.2", + "memchr", + "unicase", +] + +[[package]] +name = "quanta" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +dependencies = [ + "crossbeam-utils", + "libc", + "mach2", + "once_cell", + "raw-cpuid", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -3254,6 +3465,15 @@ dependencies = [ "rand_core", ] +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "rayon" version = "1.7.0" @@ -3558,6 +3778,15 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.21" @@ -3567,6 +3796,15 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot 0.12.1", +] + [[package]] name = "scopeguard" version = "1.1.0" @@ -3622,6 +3860,9 @@ name = "semver" version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bebd363326d05ec3e2f532ab7660680f3b02130d780c299bca73469d521bc0ed" +dependencies = [ + "serde", +] [[package]] name = "serde" @@ -3765,6 +4006,21 @@ version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" +[[package]] +name = "skeptic" +version = "0.13.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d23b015676c90a0f01c197bfdc786c20342c73a0afdda9025adb0bc42940a8" +dependencies = [ + "bytecount", + "cargo_metadata", + "error-chain", + "glob", + "pulldown-cmark", + "tempfile", + "walkdir", +] + [[package]] name = "slab" version = "0.4.8" @@ -3774,6 +4030,22 @@ dependencies = [ "autocfg", ] +[[package]] +name = "sled" +version = "0.34.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935" +dependencies = [ + "crc32fast", + "crossbeam-epoch", + "crossbeam-utils", + "fs2", + "fxhash", + "libc", + "log", + "parking_lot 0.11.2", +] + [[package]] name = "slice-group-by" version = "0.3.1" @@ -3855,7 +4127,7 @@ dependencies = [ "mimalloc", "nix", "once_cell", - "parking_lot", + "parking_lot 0.12.1", "priority-queue", "proptest", "prost", @@ -4002,6 +4274,12 @@ dependencies = [ "winx", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tar" 
version = "0.4.38" @@ -4149,7 +4427,7 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot", + "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2", @@ -4468,6 +4746,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "triomphe" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee8098afad3fb0c54a9007aab6804558410503ad676d4633f9c2559a00ac0f" + [[package]] name = "try-lock" version = "0.2.4" @@ -4514,6 +4798,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "unicase" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.13" @@ -4637,6 +4930,22 @@ dependencies = [ "libc", ] +[[package]] +name = "waker-fn" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" + +[[package]] +name = "walkdir" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.0" diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 90f2ca0b..26eb60ff 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -7,13 +7,16 @@ edition = "2021" [dependencies] axum = "0.6.18" +bincode = "1.3.3" clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" futures = "0.3.28" -hyper = { version = "0.14.27", features = ["h2"] } +hyper = { version = "0.14.27", features = ["h2", "server"] } libsqlx = { version = "0.1.0", path = "../libsqlx" } +moka = { version = "0.11.2", features = ["future"] } serde = { version = "1.0.166", features = ["derive"] } +sled = "0.34.7" tokio = { version = "1.29.1", features = ["full"] } tracing = "0.1.37" -tracing-subscriber = "0.3.17" +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } uuid = { version = "1.4.0", features = ["v4"] } diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs index 19a6396b..f5839e9c 100644 --- a/libsqlx-server/src/allocation/config.rs +++ b/libsqlx-server/src/allocation/config.rs @@ -1,9 +1,21 @@ -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; +/// Structural supertype of AllocConfig, used for checking the meta version. Subsequent version of +/// AllocConfig need to conform to this prototype. 
#[derive(Debug, Serialize, Deserialize)] -pub enum AllocConfig { - Primary { }, - Replica { - primary_node_id: String, - } +struct ConfigVersion { + config_version: u32, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AllocConfig { + pub max_conccurent_connection: u32, + pub id: String, + pub db_config: DbConfig, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum DbConfig { + Primary {}, + Replica { primary_node_id: String }, } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 2fa9a4cd..21e4c97c 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,18 +1,24 @@ use std::collections::HashMap; +use std::path::PathBuf; -use tokio::{sync::{mpsc, oneshot}, task::{JoinSet, block_in_place}}; +use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType}; +use libsqlx::Database as _; +use tokio::sync::{mpsc, oneshot}; +use tokio::task::{block_in_place, JoinSet}; + +use self::config::{AllocConfig, DbConfig}; pub mod config; type ExecFn = Box; #[derive(Clone)] -struct ConnectionId { +pub struct ConnectionId { id: u32, close_sender: mpsc::Sender<()>, } -enum AllocationMessage { +pub enum AllocationMessage { /// Execute callback against connection Exec { connection_id: ConnectionId, @@ -22,30 +28,61 @@ enum AllocationMessage { NewConnExec { exec: ExecFn, ret: oneshot::Sender, - } + }, +} + +pub enum Database { + Primary(libsqlx::libsql::LibsqlDatabase), } -enum Database {} +struct Compactor; + +impl LogCompactor for Compactor { + fn should_compact(&self, _log: &LogFile) -> bool { + false + } + + fn compact( + &self, + _log: LogFile, + _path: std::path::PathBuf, + _size_after: u32, + ) -> Result<(), Box> { + todo!() + } +} impl Database { + pub fn from_config(config: &AllocConfig, path: PathBuf) -> Self { + match config.db_config { + DbConfig::Primary {} => { + let db = LibsqlDatabase::new_primary(path, Compactor, false).unwrap(); + Self::Primary(db) + } + DbConfig::Replica { .. } => todo!(), + } + } + fn connect(&self) -> Box { - todo!(); + match self { + Database::Primary(db) => Box::new(db.connect().unwrap()), + } } } pub struct Allocation { - inbox: mpsc::Receiver, - database: Database, + pub inbox: mpsc::Receiver, + pub database: Database, /// senders to the spawned connections - connections: HashMap>, + pub connections: HashMap>, /// spawned connection futures, returning their connection id on completion. - connections_futs: JoinSet, - next_conn_id: u32, - max_concurrent_connections: u32, + pub connections_futs: JoinSet, + pub next_conn_id: u32, + pub max_concurrent_connections: u32, } impl Allocation { - async fn run(mut self) { + pub async fn run(mut self) { loop { tokio::select! { Some(msg) = self.inbox.recv() => { @@ -86,23 +123,19 @@ impl Allocation { exec: exec_receiver, }; - self.connections_futs.spawn(conn.run()); // This should never block! 
assert!(exec_sender.try_send(exec).is_ok()); assert!(self.connections.insert(id, exec_sender).is_none()); - ConnectionId { - id, - close_sender, - } + ConnectionId { id, close_sender } } fn next_conn_id(&mut self) -> u32 { loop { self.next_conn_id = self.next_conn_id.wrapping_add(1); if !self.connections.contains_key(&self.next_conn_id) { - return self.next_conn_id + return self.next_conn_id; } } } diff --git a/libsqlx-server/src/databases/mod.rs b/libsqlx-server/src/databases/mod.rs new file mode 100644 index 00000000..0494174b --- /dev/null +++ b/libsqlx-server/src/databases/mod.rs @@ -0,0 +1,5 @@ +use uuid::Uuid; + +mod store; + +pub type DatabaseId = Uuid; diff --git a/libsqlx-server/src/databases/store.rs b/libsqlx-server/src/databases/store.rs new file mode 100644 index 00000000..206beb34 --- /dev/null +++ b/libsqlx-server/src/databases/store.rs @@ -0,0 +1,12 @@ +use std::collections::HashMap; + +use super::DatabaseId; + +pub enum Database { + Replica, + Primary, +} + +pub struct DatabaseManager { + databases: HashMap, +} diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 2d9c8054..51ba1b7f 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{path::PathBuf, sync::Arc}; use axum::{extract::State, routing::post, Json, Router}; use color_eyre::eyre::Result; @@ -6,24 +6,30 @@ use hyper::server::accept::Accept; use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::{meta::MetaStore, allocation::config::AllocConfig}; +use crate::{ + allocation::config::{AllocConfig, DbConfig}, + meta::Store, +}; -pub struct AdminServerConfig {} +pub struct AdminServerConfig { + pub db_path: PathBuf, +} struct AdminServerState { - meta_store: Arc, + meta_store: Arc, } -pub async fn run_admin_server(_config: AdminServerConfig, listener: I) -> Result<()> +pub async fn run_admin_server(config: AdminServerConfig, listener: I) -> Result<()> where I: Accept, I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, { let state = AdminServerState { - meta_store: todo!(), + meta_store: Arc::new(Store::new(&config.db_path)), }; + let app = Router::new() - .route("/manage/allocation/create", post(allocate)) + .route("/manage/allocation", post(allocate).get(list_allocs)) .with_state(Arc::new(state)); axum::Server::builder(listener) .serve(app.into_make_service()) @@ -36,18 +42,59 @@ where struct ErrorResponse {} #[derive(Serialize, Debug)] -struct AllocateResp { } +struct AllocateResp {} #[derive(Deserialize, Debug)] struct AllocateReq { alloc_id: String, - config: AllocConfig, + max_conccurent_connection: Option, + config: DbConfigReq, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum DbConfigReq { + Primary { }, + Replica { primary_node_id: String }, } async fn allocate( State(state): State>, Json(req): Json, ) -> Result, Json> { - state.meta_store.allocate(&req.alloc_id, &req.config).await; - todo!(); + let config = AllocConfig { + max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16), + id: req.alloc_id.clone(), + db_config: match req.config { + DbConfigReq::Primary { } => DbConfig::Primary { }, + DbConfigReq::Replica { primary_node_id } => DbConfig::Replica { primary_node_id }, + }, + }; + state.meta_store.allocate(&req.alloc_id, &config).await; + + Ok(Json(AllocateResp {})) +} + +#[derive(Serialize, Debug)] +struct ListAllocResp { + allocs: Vec, +} + +#[derive(Serialize, Debug)] +struct 
AllocView { + id: String, +} + +async fn list_allocs( + State(state): State>, +) -> Result, Json> { + let allocs = state + .meta_store + .list_allocs() + .await + .into_iter() + .map(|cfg| AllocView { id: cfg.id }) + .collect(); + + Ok(Json(ListAllocResp { allocs })) } diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 1ee047bf..fb213397 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -1,16 +1,31 @@ +use std::path::PathBuf; + use color_eyre::eyre::Result; +use http::admin::{run_admin_server, AdminServerConfig}; +use hyper::server::conn::AddrIncoming; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; mod allocation; mod databases; mod http; +mod manager; mod meta; #[tokio::main] async fn main() -> Result<()> { init(); + let admin_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3456").await?; + run_admin_server( + AdminServerConfig { + db_path: PathBuf::from("database"), + }, + AddrIncoming::from_listener(admin_api_listener)?, + ) + .await + .unwrap(); + Ok(()) } diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs new file mode 100644 index 00000000..8d7737a7 --- /dev/null +++ b/libsqlx-server/src/manager.rs @@ -0,0 +1,50 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use moka::future::Cache; +use tokio::sync::mpsc; +use tokio::task::JoinSet; + +use crate::allocation::{Allocation, AllocationMessage, Database}; +use crate::meta::Store; + +pub struct Manager { + cache: Cache>, + meta_store: Arc, + db_path: PathBuf, +} + +const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32; + +impl Manager { + pub async fn alloc(&self, alloc_id: &str) -> mpsc::Sender { + if let Some(sender) = self.cache.get(alloc_id) { + return sender.clone(); + } + + if let Some(config) = self.meta_store.meta(alloc_id).await { + let path = self.db_path.join("dbs").join(alloc_id); + tokio::fs::create_dir_all(&path).await.unwrap(); + let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); + let alloc = Allocation { + inbox, + database: Database::from_config(&config, path), + connections: HashMap::new(), + connections_futs: JoinSet::new(), + next_conn_id: 0, + max_concurrent_connections: config.max_conccurent_connection, + }; + + tokio::spawn(alloc.run()); + + self.cache + .insert(alloc_id.to_string(), alloc_sender.clone()) + .await; + + return alloc_sender; + } + + todo!("alloc doesn't exist") + } +} diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 7f48a456..475a0250 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -1,10 +1,57 @@ +use std::path::Path; + +use sled::Tree; use uuid::Uuid; use crate::allocation::config::AllocConfig; -pub struct MetaStore {} +type ExecFn = Box)>; + +pub struct Store { + meta_store: Tree, +} + +impl Store { + pub fn new(path: &Path) -> Self { + std::fs::create_dir_all(&path).unwrap(); + let path = path.join("store"); + let db = sled::open(path).unwrap(); + let meta_store = db.open_tree("meta_store").unwrap(); + + Self { meta_store } + } + + pub async fn allocate(&self, alloc_id: &str, meta: &AllocConfig) { + //TODO: Handle conflict + tokio::task::block_in_place(|| { + let meta_bytes = bincode::serialize(meta).unwrap(); + self.meta_store + .compare_and_swap(alloc_id, None as Option<&[u8]>, Some(meta_bytes)) + .unwrap() + .unwrap(); + }); + } -impl MetaStore { - pub async fn allocate(&self, alloc_id: &str, meta: &AllocConfig) {} pub async fn deallocate(&self, alloc_id: Uuid) {} + + pub async fn 
meta(&self, alloc_id: &str) -> Option { + tokio::task::block_in_place(|| { + let config = self.meta_store.get(alloc_id).unwrap()?; + let config = bincode::deserialize(config.as_ref()).unwrap(); + Some(config) + }) + } + + pub async fn list_allocs(&self) -> Vec { + tokio::task::block_in_place(|| { + let mut out = Vec::new(); + for kv in self.meta_store.iter() { + let (k, v) = kv.unwrap(); + let alloc = bincode::deserialize(&v).unwrap(); + out.push(alloc); + } + + out + }) + } } diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index cc4776c4..a5eb7e60 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -1,5 +1,9 @@ -use crate::program::Program; +use rusqlite::types::Value; + +use crate::program::{Program, Step}; +use crate::query::Query; use crate::result_builder::ResultBuilder; +use crate::QueryBuilderConfig; #[derive(Debug, Clone)] pub struct DescribeResponse { @@ -30,6 +34,61 @@ pub trait Connection { /// Parse the SQL statement and return information about it. fn describe(&self, sql: String) -> crate::Result; + + /// execute a single query + fn execute(&mut self, query: Query) -> crate::Result>> { + #[derive(Default)] + struct RowsBuilder { + error: Option, + rows: Vec>, + current_row: Vec, + } + + impl ResultBuilder for RowsBuilder { + fn init( + &mut self, + _config: &QueryBuilderConfig, + ) -> std::result::Result<(), crate::QueryResultBuilderError> { + self.error = None; + self.rows.clear(); + self.current_row.clear(); + + Ok(()) + } + + fn add_row_value( + &mut self, + v: rusqlite::types::ValueRef, + ) -> Result<(), crate::QueryResultBuilderError> { + self.current_row.push(v.into()); + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), crate::QueryResultBuilderError> { + let row = std::mem::take(&mut self.current_row); + self.rows.push(row); + + Ok(()) + } + + fn step_error( + &mut self, + error: crate::error::Error, + ) -> Result<(), crate::QueryResultBuilderError> { + self.error.replace(error); + Ok(()) + } + } + + let pgm = Program::new(vec![Step { cond: None, query }]); + let mut builder = RowsBuilder::default(); + self.execute_program(pgm, &mut builder)?; + if let Some(err) = builder.error.take() { + Err(err) + } else { + Ok(builder.rows) + } + } } impl Connection for Box { diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 8a1c8c55..0a2cb6b0 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -237,7 +237,11 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { } impl Connection for LibsqlConnection { - fn execute_program(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> crate::Result<()> { + fn execute_program( + &mut self, + pgm: Program, + builder: &mut dyn ResultBuilder, + ) -> crate::Result<()> { self.run(pgm, builder) } diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 27397663..41de3569 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -9,15 +9,16 @@ use crate::database::{Database, InjectError, InjectableDatabase}; use crate::error::Error; use crate::result_builder::QueryBuilderConfig; -use connection::{LibsqlConnection, RowStats}; +use connection::RowStats; use injector::Injector; use replication_log::logger::{ ReplicationLogger, ReplicationLoggerHook, ReplicationLoggerHookCtx, REPLICATION_METHODS, }; use self::injector::InjectorCommitHandler; -use self::replication_log::logger::LogCompactor; +pub use 
connection::LibsqlConnection; +pub use replication_log::logger::{LogCompactor, LogFile}; pub use replication_log::merger::SnapshotMerger; mod connection; @@ -67,6 +68,18 @@ pub trait LibsqlDbType { fn hook_context(&self) -> ::Context; } +pub struct PlainType; + +impl LibsqlDbType for PlainType { + type ConnectionHook = TransparentMethods; + + fn hook() -> &'static WalMethodsHook { + &TRANSPARENT_METHODS + } + + fn hook_context(&self) -> ::Context {} +} + /// A generic wrapper around a libsql database. /// `LibsqlDatabase` can be specialized into either a `ReplicaType` or a `PrimaryType`. /// In `PrimaryType` mode, the LibsqlDatabase maintains a replication log that can be replicated to @@ -112,6 +125,12 @@ impl LibsqlDatabase { } } +impl LibsqlDatabase { + pub fn new_plain(db_path: PathBuf) -> crate::Result { + Ok(Self::new(db_path, PlainType)) + } +} + impl LibsqlDatabase { pub fn new_primary( db_path: PathBuf, @@ -155,16 +174,18 @@ impl Database for LibsqlDatabase { type Connection = LibsqlConnection<::Context>; fn connect(&self) -> Result { - Ok(LibsqlConnection::<::Context>::new( - &self.db_path, - self.extensions.clone(), - T::hook(), - self.ty.hook_context(), - self.row_stats_callback.clone(), - QueryBuilderConfig { - max_size: Some(self.response_size_limit), - }, - )?) + Ok( + LibsqlConnection::<::Context>::new( + &self.db_path, + self.extensions.clone(), + T::hook(), + self.ty.hook_context(), + self.row_stats_callback.clone(), + QueryBuilderConfig { + max_size: Some(self.response_size_limit), + }, + )?, + ) } } @@ -188,9 +209,9 @@ impl super::Injector for Injector { #[cfg(test)] mod test { - use std::cell::Cell; use std::fs::File; - use std::rc::Rc; + use std::sync::atomic::AtomicBool; + use std::sync::atomic::Ordering::Relaxed; use rusqlite::types::Value; @@ -224,11 +245,7 @@ mod test { let mut conn = db.connect().unwrap(); let mut builder = ReadRowBuilder(Vec::new()); - conn - .execute_program( - Program::seq(&["select count(*) from test"]), - &mut builder - ) + conn.execute_program(Program::seq(&["select count(*) from test"]), &mut builder) .unwrap(); assert!(builder.0.is_empty()); @@ -240,11 +257,7 @@ mod test { .for_each(|f| injector.inject(f.unwrap()).unwrap()); let mut builder = ReadRowBuilder(Vec::new()); - conn - .execute_program( - Program::seq(&["select count(*) from test"]), - &mut builder - ) + conn.execute_program(Program::seq(&["select count(*) from test"]), &mut builder) .unwrap(); assert_eq!(builder.0[0], Value::Integer(5)); } @@ -288,10 +301,7 @@ mod test { let mut replica_conn = replica.connect().unwrap(); let mut builder = ReadRowBuilder(Vec::new()); replica_conn - .execute_program( - Program::seq(&["select * from test limit 1"]), - &mut builder - ) + .execute_program(Program::seq(&["select * from test limit 1"]), &mut builder) .unwrap(); assert_eq!(builder.0.len(), 1); @@ -300,7 +310,7 @@ mod test { #[test] fn primary_compact_log() { - struct Compactor(Rc>); + struct Compactor(Arc); impl LogCompactor for Compactor { fn should_compact(&self, log: &LogFile) -> bool { @@ -312,14 +322,14 @@ mod test { _file: LogFile, _path: PathBuf, _size_after: u32, - ) -> anyhow::Result<()> { - self.0.set(true); + ) -> Result<(), Box> { + self.0.store(true, Relaxed); Ok(()) } } let temp = tempfile::tempdir().unwrap(); - let compactor_called = Rc::new(Cell::new(false)); + let compactor_called = Arc::new(AtomicBool::new(false)); let db = LibsqlDatabase::new_primary( temp.path().to_path_buf(), Compactor(compactor_called.clone()), @@ -333,17 +343,17 @@ mod test { &mut (), ) 
.unwrap(); - assert!(compactor_called.get()); + assert!(compactor_called.load(Relaxed)); } #[test] fn no_compaction_uncommited_frames() { - struct Compactor(Rc>); + struct Compactor(Arc); impl LogCompactor for Compactor { fn should_compact(&self, log: &LogFile) -> bool { assert_eq!(log.uncommitted_frame_count, 0); - self.0.set(true); + self.0.store(true, Relaxed); false } @@ -352,13 +362,13 @@ mod test { _file: LogFile, _path: PathBuf, _size_after: u32, - ) -> anyhow::Result<()> { + ) -> Result<(), Box> { unreachable!() } } let temp = tempfile::tempdir().unwrap(); - let compactor_called = Rc::new(Cell::new(false)); + let compactor_called = Arc::new(AtomicBool::new(false)); let db = LibsqlDatabase::new_primary( temp.path().to_path_buf(), Compactor(compactor_called.clone()), @@ -377,8 +387,9 @@ mod test { ) .unwrap(); conn.inner_connection().cache_flush().unwrap(); - assert!(!compactor_called.get()); - conn.execute_program(Program::seq(&["commit"]), &mut ()).unwrap(); - assert!(compactor_called.get()); + assert!(!compactor_called.load(Relaxed)); + conn.execute_program(Program::seq(&["commit"]), &mut ()) + .unwrap(); + assert!(compactor_called.load(Relaxed)); } } diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 25546e36..aebff0db 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -602,7 +602,9 @@ impl LogFile { // swap old and new snapshot atomic_rename(&temp_log_path, path.join("wallog")).unwrap(); let old_log_file = std::mem::replace(self, new_log_file); - compactor.compact(old_log_file, temp_log_path, size_after)?; + compactor + .compact(old_log_file, temp_log_path, size_after) + .unwrap(); Ok(()) } @@ -717,17 +719,27 @@ impl Generation { } } -pub trait LogCompactor: 'static { +pub trait LogCompactor: Sync + Send + 'static { /// returns whether the passed log file should be compacted. If this method returns true, /// compact should be called next. fn should_compact(&self, log: &LogFile) -> bool; /// Compact the given snapshot - fn compact(&self, log: LogFile, path: PathBuf, size_after: u32) -> anyhow::Result<()>; + fn compact( + &self, + log: LogFile, + path: PathBuf, + size_after: u32, + ) -> Result<(), Box>; } #[cfg(test)] impl LogCompactor for () { - fn compact(&self, _file: LogFile, _path: PathBuf, _size_after: u32) -> anyhow::Result<()> { + fn compact( + &self, + _file: LogFile, + _path: PathBuf, + _size_after: u32, + ) -> Result<(), Box> { Ok(()) } @@ -739,7 +751,7 @@ impl LogCompactor for () { pub struct ReplicationLogger { pub generation: Generation, pub log_file: RwLock, - compactor: Box, + compactor: Box, db_path: PathBuf, /// a notifier channel other tasks can subscribe to, and get notified when new frames become /// available. 
diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index fa1ce874..62581402 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -1,8 +1,8 @@ use std::time::Duration; +use self::frame::Frame; use crate::connection::Connection; use crate::error::Error; -use self::frame::Frame; mod frame; pub mod libsql; From 52c566c4d86dca7de1c2e7f7d1ec350b7d7e366a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 10 Jul 2023 09:56:41 +0200 Subject: [PATCH 08/64] user API & database extractor --- Cargo.lock | 34 ++++++++--- libsqlx-server/Cargo.toml | 2 + libsqlx-server/src/http/admin.rs | 12 ++-- libsqlx-server/src/http/mod.rs | 1 + libsqlx-server/src/http/user.rs | 101 +++++++++++++++++++++++++++++++ libsqlx-server/src/main.rs | 33 +++++++--- libsqlx-server/src/manager.rs | 17 ++++-- libsqlx-server/src/meta.rs | 4 +- 8 files changed, 174 insertions(+), 30 deletions(-) create mode 100644 libsqlx-server/src/http/user.rs diff --git a/Cargo.lock b/Cargo.lock index 0864a55e..b4364605 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2575,8 +2575,10 @@ dependencies = [ "hyper", "libsqlx", "moka", + "regex", "serde", "sled", + "thiserror", "tokio", "tracing", "tracing-subscriber", @@ -2650,7 +2652,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" dependencies = [ - "regex-automata", + "regex-automata 0.1.10", ] [[package]] @@ -3548,13 +3550,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.8.4" +version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +checksum = "b2eae68fc220f7cf2532e4494aded17545fce192d59cd996e0fe7887f4ceb575" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.7.2", + "regex-automata 0.3.2", + "regex-syntax 0.7.3", ] [[package]] @@ -3566,6 +3569,17 @@ dependencies = [ "regex-syntax 0.6.29", ] +[[package]] +name = "regex-automata" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83d3daa6976cffb758ec878f108ba0e062a45b2d6ca3a2cca965338855476caf" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.7.3", +] + [[package]] name = "regex-syntax" version = "0.6.29" @@ -3574,9 +3588,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" +checksum = "2ab07dc67230e4a4718e70fd5c20055a4334b121f1f9db8fe63ef39ce9b8c846" [[package]] name = "reqwest" @@ -4334,18 +4348,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" dependencies = [ "proc-macro2", "quote", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 
26eb60ff..4b4668ba 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -14,8 +14,10 @@ futures = "0.3.28" hyper = { version = "0.14.27", features = ["h2", "server"] } libsqlx = { version = "0.1.0", path = "../libsqlx" } moka = { version = "0.11.2", features = ["future"] } +regex = "1.9.1" serde = { version = "1.0.166", features = ["derive"] } sled = "0.34.7" +thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 51ba1b7f..80e787d6 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -11,21 +11,21 @@ use crate::{ meta::Store, }; -pub struct AdminServerConfig { - pub db_path: PathBuf, +pub struct AdminApiConfig { + pub meta_store: Arc, } struct AdminServerState { meta_store: Arc, } -pub async fn run_admin_server(config: AdminServerConfig, listener: I) -> Result<()> +pub async fn run_admin_api(config: AdminApiConfig, listener: I) -> Result<()> where I: Accept, I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, { let state = AdminServerState { - meta_store: Arc::new(Store::new(&config.db_path)), + meta_store: config.meta_store, }; let app = Router::new() @@ -54,7 +54,7 @@ struct AllocateReq { #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum DbConfigReq { - Primary { }, + Primary {}, Replica { primary_node_id: String }, } @@ -66,7 +66,7 @@ async fn allocate( max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16), id: req.alloc_id.clone(), db_config: match req.config { - DbConfigReq::Primary { } => DbConfig::Primary { }, + DbConfigReq::Primary {} => DbConfig::Primary {}, DbConfigReq::Replica { primary_node_id } => DbConfig::Replica { primary_node_id }, }, }; diff --git a/libsqlx-server/src/http/mod.rs b/libsqlx-server/src/http/mod.rs index 92918b09..1e6bf65b 100644 --- a/libsqlx-server/src/http/mod.rs +++ b/libsqlx-server/src/http/mod.rs @@ -1 +1,2 @@ pub mod admin; +pub mod user; diff --git a/libsqlx-server/src/http/user.rs b/libsqlx-server/src/http/user.rs new file mode 100644 index 00000000..040f5a66 --- /dev/null +++ b/libsqlx-server/src/http/user.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; + +use axum::{async_trait, extract::FromRequestParts, response::IntoResponse, routing::get, Router, Json}; +use color_eyre::Result; +use hyper::{http::request::Parts, server::accept::Accept, StatusCode}; +use serde::Serialize; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc, +}; + +use crate::{allocation::AllocationMessage, manager::Manager}; + +pub struct UserApiConfig { + pub manager: Arc, +} + +struct UserApiState { + manager: Arc, +} + +pub async fn run_user_api(config: UserApiConfig, listener: I) -> Result<()> +where + I: Accept, + I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + let state = UserApiState { manager: config.manager }; + + let app = Router::new() + .route("/", get(test_database)) + .with_state(Arc::new(state)); + + axum::Server::builder(listener) + .serve(app.into_make_service()) + .await?; + + Ok(()) +} + +struct Database { + sender: mpsc::Sender, +} + +#[derive(Debug, thiserror::Error)] +enum UserApiError { + #[error("missing host header")] + MissingHost, + #[error("invalid host header format")] + InvalidHost, + #[error("Database `{0}` doesn't exist")] + UnknownDatabase(String), +} + +impl UserApiError { + fn http_status(&self) -> StatusCode 
{
+        match self {
+            UserApiError::MissingHost
+            | UserApiError::InvalidHost
+            | UserApiError::UnknownDatabase(_) => StatusCode::BAD_REQUEST,
+        }
+    }
+}
+
+#[derive(Debug, Serialize)]
+struct ApiError {
+    error: String,
+}
+
+impl IntoResponse for UserApiError {
+    fn into_response(self) -> axum::response::Response {
+        let mut resp = Json(ApiError {
+            error: self.to_string(),
+        })
+        .into_response();
+        *resp.status_mut() = self.http_status();
+
+        resp
+    }
+}
+
+#[async_trait]
+impl FromRequestParts<Arc<UserApiState>> for Database {
+    type Rejection = UserApiError;
+
+    async fn from_request_parts(
+        parts: &mut Parts,
+        state: &Arc<UserApiState>,
+    ) -> Result<Self, Self::Rejection> {
+        let Some(host) = parts.headers.get("host") else { return Err(UserApiError::MissingHost) };
+        // A host header that isn't valid UTF-8 is an invalid host, not a missing one.
+        let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else { return Err(UserApiError::InvalidHost) };
+        let db_id = parse_host(host_str)?;
+        let Some(sender) = state.manager.alloc(db_id).await else { return Err(UserApiError::UnknownDatabase(db_id.to_owned())) };
+
+        Ok(Database { sender })
+    }
+}
+
+fn parse_host(host: &str) -> Result<&str, UserApiError> {
+    let mut split = host.split('.');
+    let Some(db_id) = split.next() else { return Err(UserApiError::InvalidHost) };
+    Ok(db_id)
+}
diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs
index fb213397..2e9411cf 100644
--- a/libsqlx-server/src/main.rs
+++ b/libsqlx-server/src/main.rs
@@ -1,13 +1,20 @@
-use std::path::PathBuf;
+use std::{path::PathBuf, sync::Arc};
 
 use color_eyre::eyre::Result;
-use http::admin::{run_admin_server, AdminServerConfig};
+use http::{
+    admin::{run_admin_api, AdminApiConfig},
+    user::{run_user_api, UserApiConfig},
+};
 use hyper::server::conn::AddrIncoming;
+use manager::Manager;
+use meta::Store;
+use tokio::task::JoinSet;
 use tracing::metadata::LevelFilter;
 use tracing_subscriber::prelude::*;
 
 mod allocation;
 mod databases;
+mod hrana;
 mod http;
 mod manager;
 mod meta;
@@ -15,16 +22,24 @@ mod meta;
 #[tokio::main]
 async fn main() -> Result<()> {
     init();
+    let mut join_set = JoinSet::new();
+    let db_path = PathBuf::from("database");
+    let store = Arc::new(Store::new(&db_path));
 
     let admin_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3456").await?;
-    run_admin_server(
-        AdminServerConfig {
-            db_path: PathBuf::from("database"),
-        },
+    join_set.spawn(run_admin_api(
+        AdminApiConfig { meta_store: store.clone() },
         AddrIncoming::from_listener(admin_api_listener)?,
-    )
-    .await
-    .unwrap();
+    ));
+
+    let manager = Arc::new(Manager::new(db_path.clone(), store, 100));
+    let user_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3457").await?;
+    join_set.spawn(run_user_api(
+        UserApiConfig { manager },
+        AddrIncoming::from_listener(user_api_listener)?,
+    ));
+
+    join_set.join_next().await;
 
     Ok(())
 }
diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs
index 8d7737a7..81ac3b72 100644
--- a/libsqlx-server/src/manager.rs
+++ b/libsqlx-server/src/manager.rs
@@ -18,9 +18,18 @@ pub struct Manager {
 const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32;
 
 impl Manager {
-    pub async fn alloc(&self, alloc_id: &str) -> mpsc::Sender<AllocationMessage> {
+    pub fn new(db_path: PathBuf, meta_store: Arc<Store>, max_concurrent_allocs: u64) -> Self {
+        Self {
+            cache: Cache::new(max_concurrent_allocs),
+            meta_store,
+            db_path,
+        }
+    }
+
+    /// Returns a handle to an allocation, lazily initializing it if it isn't already loaded.
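+    ///
+    /// A minimal caller sketch (hedged: `manager` and the `"db-id"` string are hypothetical;
+    /// the returned sender is the allocation's `AllocationMessage` inbox):
+    ///
+    /// ```ignore
+    /// match manager.alloc("db-id").await {
+    ///     Some(sender) => { /* send AllocationMessage values to the allocation */ }
+    ///     None => { /* no entry in the meta store: the database doesn't exist */ }
+    /// }
+    /// ```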
+ pub async fn alloc(&self, alloc_id: &str) -> Option> { if let Some(sender) = self.cache.get(alloc_id) { - return sender.clone(); + return Some(sender.clone()); } if let Some(config) = self.meta_store.meta(alloc_id).await { @@ -42,9 +51,9 @@ impl Manager { .insert(alloc_id.to_string(), alloc_sender.clone()) .await; - return alloc_sender; + return Some(alloc_sender); } - todo!("alloc doesn't exist") + None } } diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 475a0250..4eade1b0 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -32,7 +32,9 @@ impl Store { }); } - pub async fn deallocate(&self, alloc_id: Uuid) {} + pub async fn deallocate(&self, alloc_id: Uuid) { + todo!() + } pub async fn meta(&self, alloc_id: &str) -> Option { tokio::task::block_in_place(|| { From c2fed59890d12c5914a7e7151237f63276872fd0 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 10 Jul 2023 15:38:08 +0200 Subject: [PATCH 09/64] port hrana for libsqlx server we can now allocate primaries, and query them --- Cargo.lock | 36 +- libsqlx-server/Cargo.toml | 10 +- libsqlx-server/src/allocation/mod.rs | 101 ++++-- libsqlx-server/src/database.rs | 19 + libsqlx-server/src/databases/mod.rs | 5 - libsqlx-server/src/databases/store.rs | 12 - libsqlx-server/src/hrana/batch.rs | 131 +++++++ libsqlx-server/src/hrana/http/mod.rs | 118 ++++++ libsqlx-server/src/hrana/http/proto.rs | 115 ++++++ libsqlx-server/src/hrana/http/request.rs | 115 ++++++ libsqlx-server/src/hrana/http/stream.rs | 404 +++++++++++++++++++++ libsqlx-server/src/hrana/mod.rs | 68 ++++ libsqlx-server/src/hrana/proto.rs | 160 ++++++++ libsqlx-server/src/hrana/result_builder.rs | 320 ++++++++++++++++ libsqlx-server/src/hrana/stmt.rs | 289 +++++++++++++++ libsqlx-server/src/hrana/ws/conn.rs | 301 +++++++++++++++ libsqlx-server/src/hrana/ws/handshake.rs | 140 +++++++ libsqlx-server/src/hrana/ws/mod.rs | 104 ++++++ libsqlx-server/src/hrana/ws/proto.rs | 127 +++++++ libsqlx-server/src/hrana/ws/session.rs | 329 +++++++++++++++++ libsqlx-server/src/http/admin.rs | 2 +- libsqlx-server/src/http/user.rs | 101 ------ libsqlx-server/src/http/user/error.rs | 41 +++ libsqlx-server/src/http/user/extractors.rs | 32 ++ libsqlx-server/src/http/user/mod.rs | 48 +++ libsqlx-server/src/main.rs | 6 +- libsqlx-server/src/manager.rs | 4 +- libsqlx-server/src/meta.rs | 4 +- libsqlx/src/analysis.rs | 11 +- libsqlx/src/connection.rs | 11 +- libsqlx/src/database/libsql/connection.rs | 2 +- libsqlx/src/database/libsql/mod.rs | 1 + libsqlx/src/error.rs | 12 +- libsqlx/src/lib.rs | 11 +- libsqlx/src/program.rs | 19 +- libsqlx/src/result_builder.rs | 8 +- 36 files changed, 3015 insertions(+), 202 deletions(-) create mode 100644 libsqlx-server/src/database.rs delete mode 100644 libsqlx-server/src/databases/mod.rs delete mode 100644 libsqlx-server/src/databases/store.rs create mode 100644 libsqlx-server/src/hrana/batch.rs create mode 100644 libsqlx-server/src/hrana/http/mod.rs create mode 100644 libsqlx-server/src/hrana/http/proto.rs create mode 100644 libsqlx-server/src/hrana/http/request.rs create mode 100644 libsqlx-server/src/hrana/http/stream.rs create mode 100644 libsqlx-server/src/hrana/mod.rs create mode 100644 libsqlx-server/src/hrana/proto.rs create mode 100644 libsqlx-server/src/hrana/result_builder.rs create mode 100644 libsqlx-server/src/hrana/stmt.rs create mode 100644 libsqlx-server/src/hrana/ws/conn.rs create mode 100644 libsqlx-server/src/hrana/ws/handshake.rs create mode 100644 libsqlx-server/src/hrana/ws/mod.rs create mode 100644 
libsqlx-server/src/hrana/ws/proto.rs create mode 100644 libsqlx-server/src/hrana/ws/session.rs delete mode 100644 libsqlx-server/src/http/user.rs create mode 100644 libsqlx-server/src/http/user/error.rs create mode 100644 libsqlx-server/src/http/user/extractors.rs create mode 100644 libsqlx-server/src/http/user/mod.rs diff --git a/Cargo.lock b/Cargo.lock index b4364605..9752e54f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -695,9 +695,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.1" +version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f1e31e207a6b8fb791a38ea3105e6cb541f55e4d029902d3039a4ad07cc4105" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" [[package]] name = "base64-simd" @@ -2436,7 +2436,7 @@ version = "8.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" dependencies = [ - "base64 0.21.1", + "base64 0.21.2", "pem", "ring", "serde", @@ -2502,7 +2502,7 @@ checksum = "9c7b1c078b4d3d45ba0db91accc23dcb8d2761d67f819efd94293065597b7ac8" dependencies = [ "anyhow", "async-trait", - "base64 0.21.1", + "base64 0.21.2", "num-traits", "reqwest", "serde_json", @@ -2568,15 +2568,23 @@ name = "libsqlx-server" version = "0.1.0" dependencies = [ "axum", + "base64 0.21.2", "bincode", + "bytes 1.4.0", "clap", "color-eyre", "futures", + "hmac", "hyper", "libsqlx", "moka", + "parking_lot 0.12.1", + "priority-queue", + "rand", "regex", "serde", + "serde_json", + "sha2", "sled", "thiserror", "tokio", @@ -3286,9 +3294,9 @@ dependencies = [ [[package]] name = "priority-queue" -version = "1.3.1" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca9c6be70d989d21a136eb86c2d83e4b328447fac4a88dace2143c179c86267" +checksum = "fff39edfcaec0d64e8d0da38564fad195d2d51b680940295fcc307366e101e61" dependencies = [ "autocfg", "indexmap 1.9.3", @@ -3598,7 +3606,7 @@ version = "0.11.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" dependencies = [ - "base64 0.21.1", + "base64 0.21.2", "bytes 1.4.0", "encoding_rs", "futures-core", @@ -3755,7 +3763,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.21.1", + "base64 0.21.2", ] [[package]] @@ -3900,9 +3908,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.99" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46266871c240a00b8f503b877622fe33430b3c7d963bdc0f2adc511e54a1eae3" +checksum = "0f1e14e89be7aa4c4b78bdbdc9eb5bf8517829a600ae8eaa39a6e1d960b5185c" dependencies = [ "indexmap 2.0.0", "itoa", @@ -3944,9 +3952,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.6" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8" dependencies = [ "cfg-if", "cpufeatures", @@ -4116,7 +4124,7 @@ dependencies = [ "aws-config", "aws-sdk-s3", "axum", - "base64 0.21.1", + "base64 0.21.2", "bincode", "bottomless", "bytemuck", @@ -4590,7 +4598,7 @@ checksum = 
"3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a" dependencies = [ "async-trait", "axum", - "base64 0.21.1", + "base64 0.21.2", "bytes 1.4.0", "futures-core", "futures-util", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 4b4668ba..5e6d5c15 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -7,15 +7,23 @@ edition = "2021" [dependencies] axum = "0.6.18" +base64 = "0.21.2" bincode = "1.3.3" +bytes = "1.4.0" clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" futures = "0.3.28" +hmac = "0.12.1" hyper = { version = "0.14.27", features = ["h2", "server"] } libsqlx = { version = "0.1.0", path = "../libsqlx" } moka = { version = "0.11.2", features = ["future"] } +parking_lot = "0.12.1" +priority-queue = "1.3.2" +rand = "0.8.5" regex = "1.9.1" -serde = { version = "1.0.166", features = ["derive"] } +serde = { version = "1.0.166", features = ["derive", "rc"] } +serde_json = "1.0.100" +sha2 = "0.10.7" sled = "0.34.7" thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 21e4c97c..a086f479 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,11 +1,15 @@ -use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType}; use libsqlx::Database as _; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; +use crate::hrana; +use crate::hrana::http::handle_pipeline; +use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; + use self::config::{AllocConfig, DbConfig}; pub mod config; @@ -19,16 +23,11 @@ pub struct ConnectionId { } pub enum AllocationMessage { - /// Execute callback against connection - Exec { - connection_id: ConnectionId, - exec: ExecFn, - }, - /// Create a new connection, execute the callback and return the connection id. - NewConnExec { - exec: ExecFn, - ret: oneshot::Sender, - }, + NewConnection(oneshot::Sender), + HranaPipelineReq { + req: PipelineRequestBody, + ret: oneshot::Sender>, + } } pub enum Database { @@ -73,12 +72,34 @@ impl Database { pub struct Allocation { pub inbox: mpsc::Receiver, pub database: Database, - /// senders to the spawned connections - pub connections: HashMap>, /// spawned connection futures, returning their connection id on completion. pub connections_futs: JoinSet, pub next_conn_id: u32, pub max_concurrent_connections: u32, + + pub hrana_server: Arc, +} + +pub struct ConnectionHandle { + exec: mpsc::Sender, + exit: oneshot::Sender<()>, +} + +impl ConnectionHandle { + pub async fn exec(&self, f: F) -> crate::Result + where F: for<'a> FnOnce(&'a mut (dyn libsqlx::Connection + 'a)) -> R + Send + 'static, + R: Send + 'static, + { + let (sender, ret) = oneshot::channel(); + let cb = move |conn: &mut dyn libsqlx::Connection| { + let res = f(conn); + let _ = sender.send(res); + }; + + self.exec.send(Box::new(cb)).await.unwrap(); + + Ok(ret.await?) + } } impl Allocation { @@ -87,23 +108,22 @@ impl Allocation { tokio::select! 
{
             Some(msg) = self.inbox.recv() => {
                 match msg {
-                    AllocationMessage::Exec { connection_id, exec } => {
-                        if let Some(sender) = self.connections.get(&connection_id.id) {
-                            if let Err(_) = sender.send(exec).await {
-                                tracing::debug!("connection {} closed.", connection_id.id);
-                                self.connections.remove_entry(&connection_id.id);
-                            }
-                        }
-                    },
-                    AllocationMessage::NewConnExec { exec, ret } => {
-                        let id = self.new_conn_exec(exec).await;
-                        let _ = ret.send(id);
+                    AllocationMessage::NewConnection(ret) => {
+                        let _ = ret.send(self.new_conn().await);
                     },
+                    AllocationMessage::HranaPipelineReq { req, ret } => {
+                        let res = handle_pipeline(&self.hrana_server.clone(), req, || async {
+                            let conn = self.new_conn().await;
+                            Ok(conn)
+                        }).await;
+                        let _ = ret.send(res);
+                    }
                 }
             },
             maybe_id = self.connections_futs.join_next() => {
-                if let Some(Ok(id)) = maybe_id {
-                    self.connections.remove_entry(&id);
+                if let Some(Ok(_id)) = maybe_id {
+                    // self.connections.remove_entry(&id);
                 }
             },
             else => break,
@@ -111,10 +131,13 @@ impl Allocation {
         }
     }
 
-    async fn new_conn_exec(&mut self, exec: ExecFn) -> ConnectionId {
+    async fn new_conn(&mut self) -> ConnectionHandle {
         let id = self.next_conn_id();
         let conn = block_in_place(|| self.database.connect());
-        let (close_sender, exit) = mpsc::channel(1);
+        let (close_sender, exit) = oneshot::channel();
         let (exec_sender, exec_receiver) = mpsc::channel(1);
         let conn = Connection {
             id,
@@ -123,20 +146,24 @@ impl Allocation {
             exec: exec_receiver,
         };
 
         self.connections_futs.spawn(conn.run());
-        // This should never block!
-        assert!(exec_sender.try_send(exec).is_ok());
-        assert!(self.connections.insert(id, exec_sender).is_none());
+
+        ConnectionHandle {
+            exec: exec_sender,
+            exit: close_sender,
+        }
-
-        ConnectionId { id, close_sender }
     }
 
     fn next_conn_id(&mut self) -> u32 {
         loop {
             self.next_conn_id = self.next_conn_id.wrapping_add(1);
-            if !self.connections.contains_key(&self.next_conn_id) {
-                return self.next_conn_id;
-            }
+            return self.next_conn_id;
+            // if !self.connections.contains_key(&self.next_conn_id) {
+            //     return self.next_conn_id;
+            // }
         }
     }
 }
@@ -144,7 +171,7 @@ impl Allocation {
 struct Connection {
     id: u32,
     conn: Box<dyn libsqlx::Connection>,
-    exit: mpsc::Receiver<()>,
+    exit: oneshot::Receiver<()>,
     exec: mpsc::Receiver<ExecFn>,
 }
 
@@ -152,7 +179,7 @@ impl Connection {
     async fn run(mut self) -> u32 {
         loop {
             tokio::select! {
-                _ = self.exit.recv() => break,
+                _ = &mut self.exit => break,
                 Some(exec) = self.exec.recv() => {
                     tokio::task::block_in_place(|| exec(&mut *self.conn));
                 }
diff --git a/libsqlx-server/src/database.rs b/libsqlx-server/src/database.rs
new file mode 100644
index 00000000..d0c979cc
--- /dev/null
+++ b/libsqlx-server/src/database.rs
@@ -0,0 +1,19 @@
+use tokio::sync::{mpsc, oneshot};
+
+use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody};
+use crate::allocation::{AllocationMessage, ConnectionHandle};
+
+pub struct Database {
+    pub sender: mpsc::Sender<AllocationMessage>,
+}
+
+impl Database {
+    pub async fn hrana_pipeline(&self, req: PipelineRequestBody) -> crate::Result<PipelineResponseBody> {
+        let (sender, ret) = oneshot::channel();
+        self.sender.send(AllocationMessage::HranaPipelineReq { req, ret: sender }).await.unwrap();
+        ret.await.unwrap()
+    }
+}
diff --git a/libsqlx-server/src/databases/mod.rs b/libsqlx-server/src/databases/mod.rs
deleted file mode 100644
index 0494174b..00000000
--- a/libsqlx-server/src/databases/mod.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-use uuid::Uuid;
-
-mod store;
-
-pub type DatabaseId = Uuid;
diff --git a/libsqlx-server/src/databases/store.rs b/libsqlx-server/src/databases/store.rs
deleted file mode 100644
index 206beb34..00000000
--- a/libsqlx-server/src/databases/store.rs
+++ /dev/null
@@ -1,12 +0,0 @@
-use std::collections::HashMap;
-
-use super::DatabaseId;
-
-pub enum Database {
-    Replica,
-    Primary,
-}
-
-pub struct DatabaseManager {
-    databases: HashMap<DatabaseId, Database>,
-}
diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs
new file mode 100644
index 00000000..7d2a1f0c
--- /dev/null
+++ b/libsqlx-server/src/hrana/batch.rs
@@ -0,0 +1,131 @@
+use std::collections::HashMap;
+
+use crate::allocation::ConnectionHandle;
+use crate::hrana::stmt::StmtError;
+
+use super::result_builder::HranaBatchProtoBuilder;
+use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error};
+use super::{proto, ProtocolError, Version};
+
+use color_eyre::eyre::anyhow;
+use libsqlx::analysis::Statement;
+use libsqlx::program::{Cond, Program, Step};
+use libsqlx::query::{Query, Params};
+use libsqlx::result_builder::{StepResult, StepResultsBuilder};
+
+fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre::Result<Cond> {
+    let try_convert_step = |step: i32| -> Result<usize, ProtocolError> {
+        let step = usize::try_from(step).map_err(|_| ProtocolError::BatchCondBadStep)?;
+        if step >= max_step_i {
+            return Err(ProtocolError::BatchCondBadStep);
+        }
+        Ok(step)
+    };
+    let cond = match cond {
+        proto::BatchCond::Ok { step } => Cond::Ok {
+            step: try_convert_step(*step)?,
+        },
+        proto::BatchCond::Error { step } => Cond::Err {
+            step: try_convert_step(*step)?,
+        },
+        proto::BatchCond::Not { cond } => Cond::Not {
+            cond: proto_cond_to_cond(cond, max_step_i)?.into(),
+        },
+        proto::BatchCond::And { conds } => Cond::And {
+            conds: conds
+                .iter()
+                .map(|cond| proto_cond_to_cond(cond, max_step_i))
+                .collect::<Result<_, _>>()?,
+        },
+        proto::BatchCond::Or { conds } => Cond::Or {
+            conds: conds
+                .iter()
+                .map(|cond| proto_cond_to_cond(cond, max_step_i))
+                .collect::<Result<_, _>>()?,
+        },
+    };
+
+    Ok(cond)
+}
+
+pub fn proto_batch_to_program(
+    batch: &proto::Batch,
+    sqls: &HashMap<i32, String>,
+    version: Version,
+) -> color_eyre::Result<Program> {
+    let mut steps = Vec::with_capacity(batch.steps.len());
+    for (step_i, step) in batch.steps.iter().enumerate() {
+        let query = proto_stmt_to_query(&step.stmt, sqls, version)?;
+        let cond = step
+            .condition
+            .as_ref()
+            .map(|cond| proto_cond_to_cond(cond, step_i))
+ .transpose()?; + let step = Step { query, cond }; + + steps.push(step); + } + + Ok(Program::new(steps)) +} + +pub async fn execute_batch( + db: &ConnectionHandle, + pgm: Program, +) -> color_eyre::Result { + let builder = db.exec(move |conn| -> color_eyre::Result<_> { + let mut builder = HranaBatchProtoBuilder::default(); + conn.execute_program(pgm, &mut builder)?; + Ok(builder) + }).await??; + + Ok(builder.into_ret()) +} + +pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { + let stmts = Statement::parse(sql) + .collect::>>() + .map_err(|err| anyhow!(StmtError::SqlParse { source: err.into() }))?; + + let steps = stmts + .into_iter() + .enumerate() + .map(|(step_i, stmt)| { + let cond = match step_i { + 0 => None, + _ => Some(Cond::Ok { step: step_i - 1 }), + }; + let query = Query { + stmt, + params: Params::empty(), + want_rows: false, + }; + Step { cond, query } + }) + .collect(); + + Ok(Program { + steps, + }) +} + +pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { + let builder = conn.exec(move |conn| -> color_eyre::Result<_> { + let mut builder = StepResultsBuilder::default(); + conn.execute_program(pgm, &mut builder)?; + + Ok(builder) + }).await??; + + builder + .into_ret() + .into_iter() + .try_for_each(|result| match result { + StepResult::Ok => Ok(()), + StepResult::Err(e) => match stmt_error_from_sqld_error(e) { + Ok(stmt_err) => Err(anyhow!(stmt_err)), + Err(sqld_err) => Err(anyhow!(sqld_err)), + }, + StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), + }) +} diff --git a/libsqlx-server/src/hrana/http/mod.rs b/libsqlx-server/src/hrana/http/mod.rs new file mode 100644 index 00000000..5e22bedc --- /dev/null +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -0,0 +1,118 @@ +use color_eyre::eyre::Context; +use futures::Future; +use parking_lot::Mutex; +use serde::{de::DeserializeOwned, Serialize}; + +use crate::allocation::ConnectionHandle; + +use self::proto::{PipelineRequestBody, PipelineResponseBody}; + +use super::ProtocolError; + +pub mod proto; +mod request; +mod stream; + +pub struct Server { + self_url: Option, + baton_key: [u8; 32], + stream_state: Mutex, +} + +#[derive(Debug)] +pub enum Route { + GetIndex, + PostPipeline, +} + +impl Server { + pub fn new(self_url: Option) -> Self { + Self { + self_url, + baton_key: rand::random(), + stream_state: Mutex::new(stream::ServerStreamState::new()), + } + } + + pub async fn run_expire(&self) { + stream::run_expire(self).await + } +} + +fn handle_index() -> color_eyre::Result> { + Ok(text_response( + hyper::StatusCode::OK, + "Hello, this is HTTP API v2 (Hrana over HTTP)".into(), + )) +} + +pub async fn handle_pipeline( + server: &Server, + req: PipelineRequestBody, + mk_conn: F +) -> color_eyre::Result +where F: FnOnce() -> Fut, + Fut: Future>, +{ + let mut stream_guard = stream::acquire(server, req.baton.as_deref(), mk_conn).await?; + + let mut results = Vec::with_capacity(req.requests.len()); + for request in req.requests.into_iter() { + let result = request::handle(&mut stream_guard, request) + .await + .context("Could not execute a request in pipeline")?; + results.push(result); + } + + let resp_body = proto::PipelineResponseBody { + baton: stream_guard.release(), + base_url: server.self_url.clone(), + results, + }; + + Ok(resp_body) +} + +async fn read_request_json(req: hyper::Request) -> color_eyre::Result { + let req_body = hyper::body::to_bytes(req.into_body()) + .await + .context("Could not read request body")?; + let req_body = 
serde_json::from_slice(&req_body) + .map_err(|err| ProtocolError::Deserialize { source: err }) + .context("Could not deserialize JSON request body")?; + Ok(req_body) +} + +fn protocol_error_response(err: ProtocolError) -> hyper::Response { + text_response(hyper::StatusCode::BAD_REQUEST, err.to_string()) +} + +fn stream_error_response(err: stream::StreamError) -> hyper::Response { + json_response( + hyper::StatusCode::INTERNAL_SERVER_ERROR, + &proto::Error { + message: err.to_string(), + code: err.code().into(), + }, + ) +} + +fn json_response( + status: hyper::StatusCode, + resp_body: &T, +) -> hyper::Response { + let resp_body = serde_json::to_vec(resp_body).unwrap(); + hyper::Response::builder() + .status(status) + .header(hyper::http::header::CONTENT_TYPE, "application/json") + .body(hyper::Body::from(resp_body)) + .unwrap() +} + +fn text_response(status: hyper::StatusCode, resp_body: String) -> hyper::Response { + hyper::Response::builder() + .status(status) + .header(hyper::http::header::CONTENT_TYPE, "text/plain") + .body(hyper::Body::from(resp_body)) + .unwrap() +} diff --git a/libsqlx-server/src/hrana/http/proto.rs b/libsqlx-server/src/hrana/http/proto.rs new file mode 100644 index 00000000..ba1285f1 --- /dev/null +++ b/libsqlx-server/src/hrana/http/proto.rs @@ -0,0 +1,115 @@ +//! Structures for Hrana-over-HTTP. + +pub use super::super::proto::*; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Debug)] +pub struct PipelineRequestBody { + pub baton: Option, + pub requests: Vec, +} + +#[derive(Serialize, Debug)] +pub struct PipelineResponseBody { + pub baton: Option, + pub base_url: Option, + pub results: Vec, +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamResult { + Ok { response: StreamResponse }, + Error { error: Error }, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamRequest { + Close(CloseStreamReq), + Execute(ExecuteStreamReq), + Batch(BatchStreamReq), + Sequence(SequenceStreamReq), + Describe(DescribeStreamReq), + StoreSql(StoreSqlStreamReq), + CloseSql(CloseSqlStreamReq), +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamResponse { + Close(CloseStreamResp), + Execute(ExecuteStreamResp), + Batch(BatchStreamResp), + Sequence(SequenceStreamResp), + Describe(DescribeStreamResp), + StoreSql(StoreSqlStreamResp), + CloseSql(CloseSqlStreamResp), +} + +#[derive(Deserialize, Debug)] +pub struct CloseStreamReq {} + +#[derive(Serialize, Debug)] +pub struct CloseStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct ExecuteStreamReq { + pub stmt: Stmt, +} + +#[derive(Serialize, Debug)] +pub struct ExecuteStreamResp { + pub result: StmtResult, +} + +#[derive(Deserialize, Debug)] +pub struct BatchStreamReq { + pub batch: Batch, +} + +#[derive(Serialize, Debug)] +pub struct BatchStreamResp { + pub result: BatchResult, +} + +#[derive(Deserialize, Debug)] +pub struct SequenceStreamReq { + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, +} + +#[derive(Serialize, Debug)] +pub struct SequenceStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct DescribeStreamReq { + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, +} + +#[derive(Serialize, Debug)] +pub struct DescribeStreamResp { + pub result: DescribeResult, +} + +#[derive(Deserialize, Debug)] +pub struct StoreSqlStreamReq { + pub sql_id: i32, + pub sql: String, +} + +#[derive(Serialize, Debug)] +pub 
struct StoreSqlStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct CloseSqlStreamReq { + pub sql_id: i32, +} + +#[derive(Serialize, Debug)] +pub struct CloseSqlStreamResp {} diff --git a/libsqlx-server/src/hrana/http/request.rs b/libsqlx-server/src/hrana/http/request.rs new file mode 100644 index 00000000..ac6d8912 --- /dev/null +++ b/libsqlx-server/src/hrana/http/request.rs @@ -0,0 +1,115 @@ +use color_eyre::eyre::{anyhow, bail}; + +use super::super::{batch, stmt, ProtocolError, Version}; +use super::{proto, stream}; + +/// An error from executing a [`proto::StreamRequest`] +#[derive(thiserror::Error, Debug)] +pub enum StreamResponseError { + #[error("The server already stores {count} SQL texts, it cannot store more")] + SqlTooMany { count: usize }, + #[error(transparent)] + Stmt(stmt::StmtError), +} + +pub async fn handle( + stream_guard: &mut stream::Guard<'_>, + request: proto::StreamRequest, +) -> color_eyre::Result { + let result = match try_handle(stream_guard, request).await { + Ok(response) => proto::StreamResult::Ok { response }, + Err(err) => { + let resp_err = err.downcast::()?; + let error = proto::Error { + message: resp_err.to_string(), + code: resp_err.code().into(), + }; + proto::StreamResult::Error { error } + } + }; + Ok(result) +} + +async fn try_handle( + stream_guard: &mut stream::Guard<'_>, + request: proto::StreamRequest, +) -> color_eyre::Result { + Ok(match request { + proto::StreamRequest::Close(_req) => { + stream_guard.close_db(); + proto::StreamResponse::Close(proto::CloseStreamResp {}) + } + proto::StreamRequest::Execute(req) => { + let db = stream_guard.get_db()?; + let sqls = stream_guard.sqls(); + let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2) + .map_err(catch_stmt_error)?; + let result = stmt::execute_stmt(db, query) + .await + .map_err(catch_stmt_error)?; + proto::StreamResponse::Execute(proto::ExecuteStreamResp { result }) + } + proto::StreamRequest::Batch(req) => { + let db = stream_guard.get_db()?; + let sqls = stream_guard.sqls(); + let pgm = batch::proto_batch_to_program(&req.batch, sqls, Version::Hrana2)?; + let result = batch::execute_batch(db, pgm).await?; + proto::StreamResponse::Batch(proto::BatchStreamResp { result }) + } + proto::StreamRequest::Sequence(req) => { + let db = stream_guard.get_db()?; + let sqls = stream_guard.sqls(); + let sql = + stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; + let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; + batch::execute_sequence(db, pgm) + .await + .map_err(catch_stmt_error)?; + proto::StreamResponse::Sequence(proto::SequenceStreamResp {}) + } + proto::StreamRequest::Describe(req) => { + let db = stream_guard.get_db()?; + let sqls = stream_guard.sqls(); + let sql = + stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; + let result = stmt::describe_stmt(db, sql.into()) + .await + .map_err(catch_stmt_error)?; + proto::StreamResponse::Describe(proto::DescribeStreamResp { result }) + } + proto::StreamRequest::StoreSql(req) => { + let sqls = stream_guard.sqls_mut(); + let sql_id = req.sql_id; + if sqls.contains_key(&sql_id) { + bail!(ProtocolError::SqlExists { sql_id }) + } else if sqls.len() >= MAX_SQL_COUNT { + bail!(StreamResponseError::SqlTooMany { count: sqls.len() }) + } + sqls.insert(sql_id, req.sql); + proto::StreamResponse::StoreSql(proto::StoreSqlStreamResp {}) + } + proto::StreamRequest::CloseSql(req) => { + let sqls = stream_guard.sqls_mut(); + sqls.remove(&req.sql_id); + 
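+            // Note: removing an unknown `sql_id` is a no-op rather than an error, so
+            // `close_sql` is effectively idempotent and safe for clients to retry.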
proto::StreamResponse::CloseSql(proto::CloseSqlStreamResp {}) + } + }) +} + +const MAX_SQL_COUNT: usize = 50; + +fn catch_stmt_error(err: color_eyre::eyre::Error) -> color_eyre::eyre::Error { + match err.downcast::() { + Ok(stmt_err) => anyhow!(StreamResponseError::Stmt(stmt_err)), + Err(err) => err, + } +} + +impl StreamResponseError { + pub fn code(&self) -> &'static str { + match self { + Self::SqlTooMany { .. } => "SQL_STORE_TOO_MANY", + Self::Stmt(err) => err.code(), + } + } +} diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs new file mode 100644 index 00000000..1261e7c2 --- /dev/null +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -0,0 +1,404 @@ +use std::cmp::Reverse; +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::{future, mem, task}; + +use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; +use color_eyre::eyre::{anyhow, WrapErr}; +use futures::Future; +use hmac::Mac as _; +use priority_queue::PriorityQueue; +use tokio::time::{Duration, Instant}; + +use super::super::ProtocolError; +use super::Server; +use crate::allocation::ConnectionHandle; + +/// Mutable state related to streams, owned by [`Server`] and protected with a mutex. +pub struct ServerStreamState { + /// Map from stream ids to stream handles. The stream ids are random integers. + handles: HashMap, + /// Queue of streams ordered by the instant when they should expire. All these stream ids + /// should refer to handles in the [`Handle::Available`] variant. + expire_queue: PriorityQueue>, + /// Queue of expired streams that are still stored as [`Handle::Expired`], together with the + /// instant when we should remove them completely. + cleanup_queue: VecDeque<(u64, Instant)>, + /// The timer that we use to wait for the next item in `expire_queue`. + expire_sleep: Pin>, + /// A waker to wake up the task that expires streams from the `expire_queue`. + expire_waker: Option, + /// See [`roundup_instant()`]. + expire_round_base: Instant, +} + +/// Handle to a stream, owned by the [`ServerStreamState`]. +enum Handle { + /// A stream that is open and ready to be used by requests. [`Stream::db`] should always be + /// `Some`. + Available(Box), + /// A stream that has been acquired by a request that hasn't finished processing. This will be + /// replaced with `Available` when the request completes and releases the stream. + Acquired, + /// A stream that has been expired. This stream behaves as closed, but we keep this around for + /// some time to provide a nicer error messages (i.e., if the stream is expired, we return a + /// "stream expired" error rather than "invalid baton" error). + Expired, +} + +/// State of a Hrana-over-HTTP stream. +/// +/// The stream is either owned by [`Handle::Available`] (when it's not in use) or by [`Guard`] +/// (when it's being used by a request). +struct Stream { + /// The database connection that corresponds to this stream. This is `None` after the `"close"` + /// request was executed. + conn: Option, + /// The cache of SQL texts stored on the server with `"store_sql"` requests. + sqls: HashMap, + /// Stream id of this stream. The id is generated randomly (it should be unguessable). + stream_id: u64, + /// Sequence number that is expected in the next baton. To make sure that clients issue stream + /// requests sequentially, the baton returned from each HTTP request includes this sequence + /// number, and the following HTTP request must show a baton with the same sequence number. 
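+    /// Concretely: `acquire()` rejects a baton whose sequence number differs from this
+    /// field (as `BatonReused`), then bumps the field with `wrapping_add(1)` before the
+    /// stream is handed out again.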
+ baton_seq: u64, +} + +/// Guard object that is used to access a stream from the outside. The guard makes sure that the +/// stream's entry in [`ServerStreamState::handles`] is either removed or replaced with +/// [`Handle::Available`] after the guard goes out of scope. +pub struct Guard<'srv> { + server: &'srv Server, + /// The guarded stream. This is only set to `None` in the destructor. + stream: Option>, + /// If set to `true`, the destructor will release the stream for further use (saving it as + /// [`Handle::Available`] in [`ServerStreamState::handles`]. If false, the stream is removed on + /// drop. + release: bool, +} + +/// An unrecoverable error that should close the stream. The difference from [`ProtocolError`] is +/// that a correct client may trigger this error, it does not mean that the protocol has been +/// violated. +#[derive(thiserror::Error, Debug)] +pub enum StreamError { + #[error("The stream has expired due to inactivity")] + StreamExpired, +} + +impl ServerStreamState { + pub fn new() -> Self { + Self { + handles: HashMap::new(), + expire_queue: PriorityQueue::new(), + cleanup_queue: VecDeque::new(), + expire_sleep: Box::pin(tokio::time::sleep(Duration::ZERO)), + expire_waker: None, + expire_round_base: Instant::now(), + } + } +} + +/// Acquire a guard to a new or existing stream. If baton is `Some`, we try to look up the stream, +/// otherwise we create a new stream. +pub async fn acquire<'srv, F, Fut>( + server: &'srv Server, + baton: Option<&str>, + mk_conn: F, +) -> color_eyre::Result> +where F: FnOnce() -> Fut, + Fut: Future>, +{ + let stream = match baton { + Some(baton) => { + let (stream_id, baton_seq) = decode_baton(server, baton)?; + + let mut state = server.stream_state.lock(); + let handle = state.handles.get_mut(&stream_id); + match handle { + None => { + return Err(ProtocolError::BatonInvalid(format!("Stream handle for {stream_id} was not found")).into()) + } + Some(Handle::Acquired) => { + return Err(ProtocolError::BatonReused) + .context(format!("Stream handle for {stream_id} is acquired")); + } + Some(Handle::Expired) => { + return Err(StreamError::StreamExpired) + .context(format!("Stream handle for {stream_id} is expired")); + } + Some(Handle::Available(stream)) => { + if stream.baton_seq != baton_seq { + return Err(ProtocolError::BatonReused).context(format!( + "Expected baton seq {}, received {baton_seq}", + stream.baton_seq + )); + } + } + }; + + let Handle::Available(mut stream) = mem::replace(handle.unwrap(), Handle::Acquired) else { + unreachable!() + }; + + tracing::debug!("Stream {stream_id} was acquired with baton seq {baton_seq}"); + // incrementing the sequence number forces the next HTTP request to use a different + // baton + stream.baton_seq = stream.baton_seq.wrapping_add(1); + unmark_expire(&mut state, stream.stream_id); + stream + } + None => { + let conn = mk_conn().await.context("Could not create a database connection")?; + + let mut state = server.stream_state.lock(); + let stream = Box::new(Stream { + conn: Some(conn), + sqls: HashMap::new(), + stream_id: gen_stream_id(&mut state), + // initializing the sequence number randomly makes it much harder to exploit + // collisions in batons + baton_seq: rand::random(), + }); + state.handles.insert(stream.stream_id, Handle::Acquired); + tracing::debug!( + "Stream {} was created with baton seq {}", + stream.stream_id, + stream.baton_seq + ); + stream + } + }; + Ok(Guard { + server, + stream: Some(stream), + release: false, + }) +} + +impl<'srv> Guard<'srv> { + pub fn get_db(&self) -> 
Result<&ConnectionHandle, ProtocolError> { + let stream = self.stream.as_ref().unwrap(); + stream.conn.as_ref().ok_or(ProtocolError::BatonStreamClosed) + } + + /// Closes the database connection. The next call to [`Guard::release()`] will then remove the + /// stream. + pub fn close_db(&mut self) { + let stream = self.stream.as_mut().unwrap(); + stream.conn = None; + } + + pub fn sqls(&self) -> &HashMap { + &self.stream.as_ref().unwrap().sqls + } + + pub fn sqls_mut(&mut self) -> &mut HashMap { + &mut self.stream.as_mut().unwrap().sqls + } + + /// Releases the guard and returns the baton that can be used to access this stream in the next + /// HTTP request. Returns `None` if the stream has been closed (and thus cannot be accessed + /// again). + pub fn release(mut self) -> Option { + let stream = self.stream.as_ref().unwrap(); + if stream.conn.is_some() { + self.release = true; // tell destructor to make the stream available again + Some(encode_baton( + self.server, + stream.stream_id, + stream.baton_seq, + )) + } else { + None + } + } +} + +impl<'srv> Drop for Guard<'srv> { + fn drop(&mut self) { + let stream = self.stream.take().unwrap(); + let stream_id = stream.stream_id; + + let mut state = self.server.stream_state.lock(); + let Some(handle) = state.handles.remove(&stream_id) else { + panic!("Dropped a Guard for stream {stream_id}, \ + but Server does not contain a handle to it"); + }; + if !matches!(handle, Handle::Acquired) { + panic!( + "Dropped a Guard for stream {stream_id}, \ + but Server contained handle that is not acquired" + ); + } + + if self.release { + state.handles.insert(stream_id, Handle::Available(stream)); + mark_expire(&mut state, stream_id); + tracing::debug!("Stream {stream_id} was released for further use"); + } else { + tracing::debug!("Stream {stream_id} was closed"); + } + } +} + +fn gen_stream_id(state: &mut ServerStreamState) -> u64 { + for _ in 0..10 { + let stream_id = rand::random(); + if !state.handles.contains_key(&stream_id) { + return stream_id; + } + } + panic!("Failed to generate a free stream id with rejection sampling") +} + +/// Encodes the baton. +/// +/// The baton is base64-encoded byte string that is composed from: +/// +/// - payload (16 bytes): +/// - `stream_id` (8 bytes, big endian) +/// - `baton_seq` (8 bytes, big endian) +/// - MAC (32 bytes): an authentication code generated with HMAC-SHA256 +/// +/// The MAC is used to cryptographically verify that the baton was generated by this server. It is +/// unlikely that we ever issue the same baton twice, because there are 2^128 possible combinations +/// for payload (note that both `stream_id` and the initial `baton_seq` are generated randomly). +fn encode_baton(server: &Server, stream_id: u64, baton_seq: u64) -> String { + let mut payload = [0; 16]; + payload[0..8].copy_from_slice(&stream_id.to_be_bytes()); + payload[8..16].copy_from_slice(&baton_seq.to_be_bytes()); + + let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); + hmac.update(&payload); + let mac = hmac.finalize().into_bytes(); + + let mut baton_data = [0; 48]; + baton_data[0..16].copy_from_slice(&payload); + baton_data[16..48].copy_from_slice(&mac); + BASE64_STANDARD_NO_PAD.encode(baton_data) +} + +/// Decodes a baton encoded with `encode_baton()` and returns `(stream_id, baton_seq)`. Always +/// returns a [`ProtocolError::BatonInvalid`] if the baton is invalid, but it attaches an anyhow +/// context that describes the precise cause. 
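///
/// A round-trip sketch (hedged: any `Server` works, since both functions derive the MAC
/// from the same `baton_key`):
///
/// ```ignore
/// let baton = encode_baton(&server, stream_id, baton_seq);
/// assert_eq!(decode_baton(&server, &baton)?, (stream_id, baton_seq));
/// ```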
+fn decode_baton(server: &Server, baton_str: &str) -> color_eyre::Result<(u64, u64)> { + let baton_data = BASE64_STANDARD_NO_PAD.decode(baton_str).map_err(|err| { + ProtocolError::BatonInvalid(format!("Could not base64-decode baton: {err}")) + })?; + + if baton_data.len() != 48 { + return Err(ProtocolError::BatonInvalid(format!( + "Baton has invalid size of {} bytes", + baton_data.len() + )).into()); + } + + let payload = &baton_data[0..16]; + let received_mac = &baton_data[16..48]; + + let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); + hmac.update(payload); + hmac.verify_slice(received_mac) + .map_err(|_| anyhow!(ProtocolError::BatonInvalid("Invalid MAC on baton".to_string())))?; + + let stream_id = u64::from_be_bytes(payload[0..8].try_into().unwrap()); + let baton_seq = u64::from_be_bytes(payload[8..16].try_into().unwrap()); + Ok((stream_id, baton_seq)) +} + +/// How long do we keep a stream in [`Handle::Available`] state before expiration. Note that every +/// HTTP request resets the timer to beginning, so the client can keep a stream alive for a long +/// time, as long as it pings regularly. +const EXPIRATION: Duration = Duration::from_secs(10); + +/// How long do we keep an expired stream in [`Handle::Expired`] state before removing it for good. +const CLEANUP: Duration = Duration::from_secs(300); + +fn mark_expire(state: &mut ServerStreamState, stream_id: u64) { + let expire_at = roundup_instant(state, Instant::now() + EXPIRATION); + if state.expire_sleep.deadline() > expire_at { + if let Some(waker) = state.expire_waker.take() { + waker.wake(); + } + } + state.expire_queue.push(stream_id, Reverse(expire_at)); +} + +fn unmark_expire(state: &mut ServerStreamState, stream_id: u64) { + state.expire_queue.remove(&stream_id); +} + +/// Handles stream expiration (and cleanup). The returned future is never resolved. 
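+///
+/// It is meant to be polled for the whole lifetime of the server, e.g. as a background
+/// task (sketch, assuming an `Arc<Server>` named `server` and the `Server::run_expire`
+/// wrapper defined in the HTTP module):
+///
+/// ```ignore
+/// tokio::spawn(async move { server.run_expire().await });
+/// ```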
+pub async fn run_expire(server: &Server) { + future::poll_fn(|cx| { + let mut state = server.stream_state.lock(); + pump_expire(&mut state, cx); + task::Poll::Pending + }) + .await +} + +fn pump_expire(state: &mut ServerStreamState, cx: &mut task::Context) { + let now = Instant::now(); + + // expire all streams in the `expire_queue` that have passed their expiration time + let wakeup_at = loop { + let stream_id = match state.expire_queue.peek() { + Some((&stream_id, &Reverse(expire_at))) => { + if expire_at <= now { + stream_id + } else { + break expire_at; + } + } + None => break now + Duration::from_secs(60), + }; + state.expire_queue.pop(); + + match state.handles.get_mut(&stream_id) { + Some(handle @ Handle::Available(_)) => { + *handle = Handle::Expired; + } + _ => continue, + } + tracing::debug!("Stream {stream_id} was expired"); + + let cleanup_at = roundup_instant(state, now + CLEANUP); + state.cleanup_queue.push_back((stream_id, cleanup_at)); + }; + + // completely remove streams that are due in `cleanup_queue` + loop { + let stream_id = match state.cleanup_queue.front() { + Some(&(stream_id, cleanup_at)) if cleanup_at <= now => stream_id, + _ => break, + }; + state.cleanup_queue.pop_front(); + + let handle = state.handles.remove(&stream_id); + assert!(matches!(handle, Some(Handle::Expired))); + tracing::debug!("Stream {stream_id} was cleaned up after expiration"); + } + + // make sure that this function is called again no later than at time `wakeup_at` + state.expire_sleep.as_mut().reset(wakeup_at); + state.expire_waker = Some(cx.waker().clone()); + let _: task::Poll<()> = state.expire_sleep.as_mut().poll(cx); +} + +/// Rounds the `instant` to the next second. This is used to ensure that streams that expire close +/// together are expired at exactly the same instant, thus reducing the number of times that +/// [`pump_expire()`] is called during periods of high load. +fn roundup_instant(state: &ServerStreamState, instant: Instant) -> Instant { + let duration_s = (instant - state.expire_round_base).as_secs(); + state.expire_round_base + Duration::from_secs(duration_s + 1) +} + +impl StreamError { + pub fn code(&self) -> &'static str { + match self { + Self::StreamExpired => "STREAM_EXPIRED", + } + } +} diff --git a/libsqlx-server/src/hrana/mod.rs b/libsqlx-server/src/hrana/mod.rs new file mode 100644 index 00000000..fc85fcfe --- /dev/null +++ b/libsqlx-server/src/hrana/mod.rs @@ -0,0 +1,68 @@ +use std::fmt; + +pub mod batch; +pub mod http; +pub mod proto; +mod result_builder; +pub mod stmt; +// pub mod ws; + +#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] +pub enum Version { + Hrana1, + Hrana2, +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Version::Hrana1 => write!(f, "hrana1"), + Version::Hrana2 => write!(f, "hrana2"), + } + } +} + +/// An unrecoverable protocol error that should close the WebSocket or HTTP stream. A correct +/// client should never trigger any of these errors. 
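+///
+/// Contrast with `StreamError` in the HTTP stream module, which a well-behaved client
+/// *can* run into (e.g. by letting a stream expire).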
+#[derive(thiserror::Error, Debug)] +pub enum ProtocolError { + #[error("Cannot deserialize client message: {source}")] + Deserialize { source: serde_json::Error }, + #[error("Received a binary WebSocket message, which is not supported")] + BinaryWebSocketMessage, + #[error("Received a request before hello message")] + RequestBeforeHello, + + #[error("Stream {stream_id} not found")] + StreamNotFound { stream_id: i32 }, + #[error("Stream {stream_id} already exists")] + StreamExists { stream_id: i32 }, + + #[error("Either `sql` or `sql_id` are required, but not both")] + SqlIdAndSqlGiven, + #[error("Either `sql` or `sql_id` are required")] + SqlIdOrSqlNotGiven, + #[error("SQL text {sql_id} not found")] + SqlNotFound { sql_id: i32 }, + #[error("SQL text {sql_id} already exists")] + SqlExists { sql_id: i32 }, + + #[error("Invalid reference to step in a batch condition")] + BatchCondBadStep, + + #[error("Received an invalid baton: {0}")] + BatonInvalid(String), + #[error("Received a baton that has already been used")] + BatonReused, + #[error("Stream for this baton was closed")] + BatonStreamClosed, + + #[error("{what} is only supported in protocol version {min_version} and higher")] + NotSupported { + what: &'static str, + min_version: Version, + }, + + #[error("{0}")] + ResponseTooLarge(String), +} diff --git a/libsqlx-server/src/hrana/proto.rs b/libsqlx-server/src/hrana/proto.rs new file mode 100644 index 00000000..8d544a07 --- /dev/null +++ b/libsqlx-server/src/hrana/proto.rs @@ -0,0 +1,160 @@ +//! Structures in Hrana that are common for WebSockets and HTTP. + +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +#[derive(Serialize, Debug)] +pub struct Error { + pub message: String, + pub code: String, +} + +#[derive(Deserialize, Debug)] +pub struct Stmt { + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, + #[serde(default)] + pub args: Vec, + #[serde(default)] + pub named_args: Vec, + #[serde(default)] + pub want_rows: Option, +} + +#[derive(Deserialize, Debug)] +pub struct NamedArg { + pub name: String, + pub value: Value, +} + +#[derive(Serialize, Debug)] +pub struct StmtResult { + pub cols: Vec, + pub rows: Vec>, + pub affected_row_count: u64, + #[serde(with = "option_i64_as_str")] + pub last_insert_rowid: Option, +} + +#[derive(Serialize, Debug)] +pub struct Col { + pub name: Option, + pub decltype: Option, +} + +#[derive(Deserialize, Debug)] +pub struct Batch { + pub steps: Vec, +} + +#[derive(Deserialize, Debug)] +pub struct BatchStep { + pub stmt: Stmt, + #[serde(default)] + pub condition: Option, +} + +#[derive(Serialize, Debug)] +pub struct BatchResult { + pub step_results: Vec>, + pub step_errors: Vec>, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BatchCond { + Ok { step: i32 }, + Error { step: i32 }, + Not { cond: Box }, + And { conds: Vec }, + Or { conds: Vec }, +} + +#[derive(Serialize, Debug)] +pub struct DescribeResult { + pub params: Vec, + pub cols: Vec, + pub is_explain: bool, + pub is_readonly: bool, +} + +#[derive(Serialize, Debug)] +pub struct DescribeParam { + pub name: Option, +} + +#[derive(Serialize, Debug)] +pub struct DescribeCol { + pub name: String, + pub decltype: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Value { + Null, + Integer { + #[serde(with = "i64_as_str")] + value: i64, + }, + Float { + value: f64, + }, + Text { + value: Arc, + }, + Blob { + #[serde(with = 
"bytes_as_base64", rename = "base64")] + value: Bytes, + }, +} + +mod i64_as_str { + use serde::{de, ser}; + use serde::{de::Error as _, Serialize as _}; + + pub fn serialize(value: &i64, ser: S) -> Result { + value.to_string().serialize(ser) + } + + pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result { + let str_value = <&'de str as de::Deserialize>::deserialize(de)?; + str_value.parse().map_err(|_| { + D::Error::invalid_value( + de::Unexpected::Str(str_value), + &"decimal integer as a string", + ) + }) + } +} + +mod option_i64_as_str { + use serde::{ser, Serialize as _}; + + pub fn serialize(value: &Option, ser: S) -> Result { + value.map(|v| v.to_string()).serialize(ser) + } +} + +mod bytes_as_base64 { + use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine as _}; + use bytes::Bytes; + use serde::{de, ser}; + use serde::{de::Error as _, Serialize as _}; + + pub fn serialize(value: &Bytes, ser: S) -> Result { + STANDARD_NO_PAD.encode(value).serialize(ser) + } + + pub fn deserialize<'de, D: de::Deserializer<'de>>(de: D) -> Result { + let text = <&'de str as de::Deserialize>::deserialize(de)?; + let text = text.trim_end_matches('='); + let bytes = STANDARD_NO_PAD.decode(text).map_err(|_| { + D::Error::invalid_value(de::Unexpected::Str(text), &"binary data encoded as base64") + })?; + Ok(Bytes::from(bytes)) + } +} diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs new file mode 100644 index 00000000..94b23775 --- /dev/null +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -0,0 +1,320 @@ +use std::fmt::{self, Write as _}; +use std::io; + +use bytes::Bytes; +use libsqlx::{result_builder::*, FrameNo}; + +use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; + +use super::proto; + +#[derive(Debug, Default)] +pub struct SingleStatementBuilder { + has_step: bool, + cols: Vec, + rows: Vec>, + err: Option, + affected_row_count: u64, + last_insert_rowid: Option, + current_size: u64, + max_response_size: u64, +} + +impl SingleStatementBuilder { + pub fn into_ret(self) -> Result { + match self.err { + Some(err) => Err(err), + None => Ok(proto::StmtResult { + cols: self.cols, + rows: self.rows, + affected_row_count: self.affected_row_count, + last_insert_rowid: self.last_insert_rowid, + }), + } + } +} + +struct SizeFormatter(u64); + +impl io::Write for SizeFormatter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0 += buf.len() as u64; + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl fmt::Write for SizeFormatter { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.0 += s.len() as u64; + Ok(()) + } +} + +fn value_json_size(v: &ValueRef) -> u64 { + let mut f = SizeFormatter(0); + match v { + ValueRef::Null => write!(&mut f, r#"{{"type":"null"}}"#).unwrap(), + ValueRef::Integer(i) => write!(&mut f, r#"{{"type":"integer", "value": "{i}"}}"#).unwrap(), + ValueRef::Real(x) => write!(&mut f, r#"{{"type":"integer","value": {x}"}}"#).unwrap(), + ValueRef::Text(s) => { + // error will be caught later. 
+ if let Ok(s) = std::str::from_utf8(s) { + write!(&mut f, r#"{{"type":"text","value":"{s}"}}"#).unwrap() + } + } + ValueRef::Blob(b) => return b.len() as u64, + } + + f.0 +} + +impl ResultBuilder for SingleStatementBuilder { + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + *self = Self { + max_response_size: config.max_size.unwrap_or(u64::MAX), + ..Default::default() + }; + + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + // SingleStatementBuilder only builds a single statement + assert!(!self.has_step); + self.has_step = true; + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.last_insert_rowid = last_insert_rowid; + self.affected_row_count = affected_row_count; + + Ok(()) + } + + fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + let mut f = SizeFormatter(0); + write!(&mut f, "{error}").unwrap(); + self.current_size = f.0; + + self.err = Some(error); + + Ok(()) + } + + fn cols_description<'a>( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + assert!(self.cols.is_empty()); + + let mut cols_size = 0; + + self.cols.extend(cols.into_iter().map(Into::into).map(|c| { + cols_size += estimate_cols_json_size(&c); + proto::Col { + name: Some(c.name.to_owned()), + decltype: c.decl_ty.map(ToString::to_string), + } + })); + + self.current_size += cols_size; + if self.current_size > self.max_response_size { + return Err(QueryResultBuilderError::ResponseTooLarge( + self.max_response_size, + )); + } + + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + assert!(self.rows.is_empty()); + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + self.rows.push(Vec::with_capacity(self.cols.len())); + Ok(()) + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + let estimate_size = value_json_size(&v); + if self.current_size + estimate_size > self.max_response_size { + return Err(QueryResultBuilderError::ResponseTooLarge( + self.max_response_size, + )); + } + + self.current_size += estimate_size; + + let val = match v { + ValueRef::Null => proto::Value::Null, + ValueRef::Integer(value) => proto::Value::Integer { value }, + ValueRef::Real(value) => proto::Value::Float { value }, + ValueRef::Text(s) => proto::Value::Text { + value: String::from_utf8(s.to_vec()) + .map_err(QueryResultBuilderError::from_any)? 
+ .into(), + }, + ValueRef::Blob(d) => proto::Value::Blob { + value: Bytes::copy_from_slice(d), + }, + }; + + self.rows + .last_mut() + .expect("row must be initialized") + .push(val); + + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + assert!(self.err.is_none()); + Ok(()) + } + + fn finish( + &mut self, + _is_txn: bool, + _frame_no: Option, + ) -> Result<(), QueryResultBuilderError> { + Ok(()) + } +} + +fn estimate_cols_json_size(c: &Column) -> u64 { + let mut f = SizeFormatter(0); + write!( + &mut f, + r#"{{"name":"{}","decltype":"{}"}}"#, + c.name, + c.decl_ty.unwrap_or("null") + ) + .unwrap(); + f.0 +} + +#[derive(Debug, Default)] +pub struct HranaBatchProtoBuilder { + step_results: Vec>, + step_errors: Vec>, + stmt_builder: SingleStatementBuilder, + current_size: u64, + max_response_size: u64, + step_empty: bool, +} + +impl HranaBatchProtoBuilder { + pub fn into_ret(self) -> proto::BatchResult { + proto::BatchResult { + step_results: self.step_results, + step_errors: self.step_errors, + } + } +} + +impl ResultBuilder for HranaBatchProtoBuilder { + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + *self = Self { + max_response_size: config.max_size.unwrap_or(u64::MAX), + ..Default::default() + }; + self.stmt_builder.init(config)?; + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.step_empty = true; + self.stmt_builder.begin_step() + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.stmt_builder + .finish_step(affected_row_count, last_insert_rowid)?; + self.current_size += self.stmt_builder.current_size; + + let new_builder = SingleStatementBuilder { + current_size: 0, + max_response_size: self.max_response_size - self.current_size, + ..Default::default() + }; + match std::mem::replace(&mut self.stmt_builder, new_builder).into_ret() { + Ok(res) => { + self.step_results.push((!self.step_empty).then_some(res)); + self.step_errors.push(None); + } + Err(e) => { + self.step_results.push(None); + self.step_errors.push(Some(proto_error_from_stmt_error( + &stmt_error_from_sqld_error(e).map_err(QueryResultBuilderError::from_any)?, + ))); + } + } + + Ok(()) + } + + fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.step_error(error) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + self.step_empty = false; + self.stmt_builder.cols_description(cols) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.begin_rows() + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.begin_row() + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.add_row_value(v) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.stmt_builder.finish_row() + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + Ok(()) + } + + fn finish( + &mut self, + _is_txn: bool, + _frame_no: Option, + ) -> Result<(), QueryResultBuilderError> { + Ok(()) + } +} diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs new file mode 100644 index 00000000..e74c3d42 --- /dev/null +++ 
b/libsqlx-server/src/hrana/stmt.rs
@@ -0,0 +1,289 @@
+use std::collections::HashMap;
+
+use color_eyre::eyre::{bail, anyhow};
+use libsqlx::analysis::Statement;
+use libsqlx::query::{Query, Params, Value};
+
+use super::result_builder::SingleStatementBuilder;
+use super::{proto, ProtocolError, Version};
+use crate::allocation::ConnectionHandle;
+use crate::hrana;
+
+/// An error during execution of an SQL statement.
+#[derive(thiserror::Error, Debug)]
+pub enum StmtError {
+    #[error("SQL string could not be parsed: {source}")]
+    SqlParse { source: color_eyre::eyre::Error },
+    #[error("SQL string does not contain any statement")]
+    SqlNoStmt,
+    #[error("SQL string contains more than one statement")]
+    SqlManyStmts,
+    #[error("Arguments do not match SQL parameters: {msg}")]
+    ArgsInvalid { msg: String },
+    #[error("Specifying both positional and named arguments is not supported")]
+    ArgsBothPositionalAndNamed,
+
+    #[error("Transaction timed out")]
+    TransactionTimeout,
+    #[error("Server cannot handle additional transactions")]
+    TransactionBusy,
+    #[error("SQLite error: {message}")]
+    SqliteError {
+        source: libsqlx::rusqlite::ffi::Error,
+        message: String,
+    },
+    #[error("SQL input error: {message} (at offset {offset})")]
+    SqlInputError {
+        source: color_eyre::eyre::Error,
+        message: String,
+        offset: i32,
+    },
+
+    #[error("Operation was blocked{}", .reason.as_ref().map(|msg| format!(": {}", msg)).unwrap_or_default())]
+    Blocked { reason: Option<String> },
+}
+
+pub async fn execute_stmt(
+    conn: &ConnectionHandle,
+    query: Query,
+) -> color_eyre::Result<proto::StmtResult> {
+    let builder = conn
+        .exec(move |conn| -> color_eyre::Result<_> {
+            let mut builder = SingleStatementBuilder::default();
+            let pgm = libsqlx::program::Program::from_queries(std::iter::once(query));
+            conn.execute_program(pgm, &mut builder)?;
+
+            Ok(builder)
+        })
+        .await??;
+
+    builder
+        .into_ret()
+        .map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) {
+            Ok(stmt_error) => anyhow!(stmt_error),
+            Err(sqld_error) => anyhow!(sqld_error),
+        })
+}
+
+pub async fn describe_stmt(
+    _db: &ConnectionHandle,
+    _sql: String,
+) -> color_eyre::Result<proto::DescribeResult> {
+    todo!();
+    // match db.describe(sql).await? {
+    //     Ok(describe_response) => todo!(),
+    //     // Ok(proto_describe_result_from_describe_response(
+    //     //     describe_response,
+    //     // )),
+    //     Err(sqld_error) => match stmt_error_from_sqld_error(sqld_error) {
+    //         Ok(stmt_error) => bail!(stmt_error),
+    //         Err(sqld_error) => bail!(sqld_error),
+    //     },
+    // }
+}
+
+pub fn proto_stmt_to_query(
+    proto_stmt: &proto::Stmt,
+    sqls: &HashMap<i32, String>,
+    version: Version,
+) -> color_eyre::Result<Query> {
+    let sql = proto_sql_to_sql(proto_stmt.sql.as_deref(), proto_stmt.sql_id, sqls, version)?;
+
+    let mut stmt_iter = Statement::parse(sql);
+    let stmt = match stmt_iter.next() {
+        Some(Ok(stmt)) => stmt,
+        Some(Err(err)) => bail!(StmtError::SqlParse { source: err.into() }),
+        None => bail!(StmtError::SqlNoStmt),
+    };
+
+    if stmt_iter.next().is_some() {
+        bail!(StmtError::SqlManyStmts)
+    }
+
+    let params = if proto_stmt.named_args.is_empty() {
+        let values = proto_stmt.args.iter().map(proto_value_to_value).collect();
+        Params::Positional(values)
+    } else if proto_stmt.args.is_empty() {
+        let values = proto_stmt
+            .named_args
+            .iter()
+            .map(|arg| (arg.name.clone(), proto_value_to_value(&arg.value)))
+            .collect();
+        Params::Named(values)
+    } else {
+        bail!(StmtError::ArgsBothPositionalAndNamed)
+    };
+
+    let want_rows = proto_stmt.want_rows.unwrap_or(true);
+    Ok(Query {
+        stmt,
+        params,
+        want_rows,
+    })
+}
+
+pub fn proto_sql_to_sql<'s>(
+    proto_sql: Option<&'s str>,
+    proto_sql_id: Option<i32>,
+    sqls: &'s HashMap<i32, String>,
+    version: Version,
+) -> Result<&'s str, ProtocolError> {
+    if proto_sql_id.is_some() && version < Version::Hrana2 {
+        return Err(ProtocolError::NotSupported {
+            what: "`sql_id`",
+            min_version: Version::Hrana2,
+        });
+    }
+
+    match (proto_sql, proto_sql_id) {
+        (Some(sql), None) => Ok(sql),
+        (None, Some(sql_id)) => match sqls.get(&sql_id) {
+            Some(sql) => Ok(sql),
+            None => Err(ProtocolError::SqlNotFound { sql_id }),
+        },
+        (Some(_), Some(_)) => Err(ProtocolError::SqlIdAndSqlGiven),
+        (None, None) => Err(ProtocolError::SqlIdOrSqlNotGiven),
+    }
+}
+
+fn proto_value_to_value(proto_value: &proto::Value) -> Value {
+    match proto_value {
+        proto::Value::Null => Value::Null,
+        proto::Value::Integer { value } => Value::Integer(*value),
+        proto::Value::Float { value } => Value::Real(*value),
+        proto::Value::Text { value } => Value::Text(value.as_ref().into()),
+        proto::Value::Blob { value } => Value::Blob(value.as_ref().into()),
+    }
+}
+
+fn proto_value_from_value(value: Value) -> proto::Value {
+    match value {
+        Value::Null => proto::Value::Null,
+        Value::Integer(value) => proto::Value::Integer { value },
+        Value::Real(value) => proto::Value::Float { value },
+        Value::Text(value) => proto::Value::Text {
+            value: value.into(),
+        },
+        Value::Blob(value) => proto::Value::Blob {
+            value: value.into(),
+        },
+    }
+}
+
+// fn proto_describe_result_from_describe_response(
+//     response: DescribeResponse,
+// ) -> proto::DescribeResult {
+//     proto::DescribeResult {
+//         params: response
+//             .params
+//             .into_iter()
+//             .map(|p| proto::DescribeParam { name: p.name })
+//             .collect(),
+//         cols: response
+//             .cols
+//             .into_iter()
+//             .map(|c| proto::DescribeCol {
+//                 name: c.name,
+//                 decltype: c.decltype,
+//             })
+//             .collect(),
+//         is_explain: response.is_explain,
+//         is_readonly: response.is_readonly,
+//     }
+// }
+
+pub fn stmt_error_from_sqld_error(
+    sqld_error: libsqlx::error::Error,
+) -> Result<StmtError, libsqlx::error::Error> {
+    Ok(match sqld_error {
+        libsqlx::error::Error::LibSqlInvalidQueryParams(msg) => StmtError::ArgsInvalid { msg },
+        libsqlx::error::Error::LibSqlTxTimeout =>
StmtError::TransactionTimeout, + libsqlx::error::Error::LibSqlTxBusy => StmtError::TransactionBusy, + libsqlx::error::Error::Blocked(reason) => StmtError::Blocked { reason }, + libsqlx::error::Error::RusqliteError(rusqlite_error) => match rusqlite_error { + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => StmtError::SqliteError { + source: sqlite_error, + message, + }, + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => StmtError::SqliteError { + message: sqlite_error.to_string(), + source: sqlite_error, + }, + libsqlx::error::RusqliteError::SqlInputError { + error: sqlite_error, + msg: message, + offset, + .. + } => StmtError::SqlInputError { + source: sqlite_error.into(), + message, + offset, + }, + rusqlite_error => return Err(libsqlx::error::Error::RusqliteError(rusqlite_error)), + }, + sqld_error => return Err(sqld_error), + }) +} + +pub fn proto_error_from_stmt_error(error: &StmtError) -> hrana::proto::Error { + hrana::proto::Error { + message: error.to_string(), + code: error.code().into(), + } +} + +impl StmtError { + pub fn code(&self) -> &'static str { + match self { + Self::SqlParse { .. } => "SQL_PARSE_ERROR", + Self::SqlNoStmt => "SQL_NO_STATEMENT", + Self::SqlManyStmts => "SQL_MANY_STATEMENTS", + Self::ArgsInvalid { .. } => "ARGS_INVALID", + Self::ArgsBothPositionalAndNamed => "ARGS_BOTH_POSITIONAL_AND_NAMED", + Self::TransactionTimeout => "TRANSACTION_TIMEOUT", + Self::TransactionBusy => "TRANSACTION_BUSY", + Self::SqliteError { source, .. } => sqlite_error_code(source.code), + Self::SqlInputError { .. } => "SQL_INPUT_ERROR", + Self::Blocked { .. } => "BLOCKED", + } + } +} + +fn sqlite_error_code(code: libsqlx::error::ErrorCode) -> &'static str { + match code { + libsqlx::error::ErrorCode::InternalMalfunction => "SQLITE_INTERNAL", + libsqlx::error::ErrorCode::PermissionDenied => "SQLITE_PERM", + libsqlx::error::ErrorCode::OperationAborted => "SQLITE_ABORT", + libsqlx::error::ErrorCode::DatabaseBusy => "SQLITE_BUSY", + libsqlx::error::ErrorCode::DatabaseLocked => "SQLITE_LOCKED", + libsqlx::error::ErrorCode::OutOfMemory => "SQLITE_NOMEM", + libsqlx::error::ErrorCode::ReadOnly => "SQLITE_READONLY", + libsqlx::error::ErrorCode::OperationInterrupted => "SQLITE_INTERRUPT", + libsqlx::error::ErrorCode::SystemIoFailure => "SQLITE_IOERR", + libsqlx::error::ErrorCode::DatabaseCorrupt => "SQLITE_CORRUPT", + libsqlx::error::ErrorCode::NotFound => "SQLITE_NOTFOUND", + libsqlx::error::ErrorCode::DiskFull => "SQLITE_FULL", + libsqlx::error::ErrorCode::CannotOpen => "SQLITE_CANTOPEN", + libsqlx::error::ErrorCode::FileLockingProtocolFailed => "SQLITE_PROTOCOL", + libsqlx::error::ErrorCode::SchemaChanged => "SQLITE_SCHEMA", + libsqlx::error::ErrorCode::TooBig => "SQLITE_TOOBIG", + libsqlx::error::ErrorCode::ConstraintViolation => "SQLITE_CONSTRAINT", + libsqlx::error::ErrorCode::TypeMismatch => "SQLITE_MISMATCH", + libsqlx::error::ErrorCode::ApiMisuse => "SQLITE_MISUSE", + libsqlx::error::ErrorCode::NoLargeFileSupport => "SQLITE_NOLFS", + libsqlx::error::ErrorCode::AuthorizationForStatementDenied => "SQLITE_AUTH", + libsqlx::error::ErrorCode::ParameterOutOfRange => "SQLITE_RANGE", + libsqlx::error::ErrorCode::NotADatabase => "SQLITE_NOTADB", + libsqlx::error::ErrorCode::Unknown => "SQLITE_UNKNOWN", + _ => "SQLITE_UNKNOWN", + } +} + +impl From<&proto::Value> for Value { + fn from(proto_value: &proto::Value) -> Value { + proto_value_to_value(proto_value) + } +} + +impl From for proto::Value { + fn from(value: Value) -> proto::Value { + 
proto_value_from_value(value) + } +} diff --git a/libsqlx-server/src/hrana/ws/conn.rs b/libsqlx-server/src/hrana/ws/conn.rs new file mode 100644 index 00000000..44daf98f --- /dev/null +++ b/libsqlx-server/src/hrana/ws/conn.rs @@ -0,0 +1,301 @@ +use std::borrow::Cow; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use anyhow::{bail, Context as _, Result}; +use futures::stream::FuturesUnordered; +use futures::{ready, FutureExt as _, StreamExt as _}; +use tokio::sync::oneshot; +use tokio_tungstenite::tungstenite; +use tungstenite::protocol::frame::coding::CloseCode; + +use crate::database::Database; + +use super::super::{ProtocolError, Version}; +use super::handshake::WebSocket; +use super::{handshake, proto, session, Server, Upgrade}; + +/// State of a Hrana connection. +struct Conn { + conn_id: u64, + server: Arc>, + ws: WebSocket, + ws_closed: bool, + /// The version of the protocol that has been negotiated in the WebSocket handshake. + version: Version, + /// After a successful authentication, this contains the session-level state of the connection. + session: Option>, + /// Join set for all tasks that were spawned to handle the connection. + join_set: tokio::task::JoinSet<()>, + /// Future responses to requests that we have received but are evaluating asynchronously. + responses: FuturesUnordered, +} + +/// A `Future` that stores a handle to a future response to request which is being evaluated +/// asynchronously. +struct ResponseFuture { + /// The request id, which must be included in the response. + request_id: i32, + /// The future that will be resolved with the response. + response_rx: futures::future::Fuse>>, +} + +pub(super) async fn handle_tcp( + server: Arc>, + socket: tokio::net::TcpStream, + conn_id: u64, +) -> Result<()> { + let (ws, version) = handshake::handshake_tcp(socket) + .await + .context("Could not perform the WebSocket handshake on TCP connection")?; + handle_ws(server, ws, version, conn_id).await +} + +pub(super) async fn handle_upgrade( + server: Arc>, + upgrade: Upgrade, + conn_id: u64, +) -> Result<()> { + let (ws, version) = handshake::handshake_upgrade(upgrade) + .await + .context("Could not perform the WebSocket handshake on HTTP connection")?; + handle_ws(server, ws, version, conn_id).await +} + +async fn handle_ws( + server: Arc>, + ws: WebSocket, + version: Version, + conn_id: u64, +) -> Result<()> { + let mut conn = Conn { + conn_id, + server, + ws, + ws_closed: false, + version, + session: None, + join_set: tokio::task::JoinSet::new(), + responses: FuturesUnordered::new(), + }; + + loop { + if let Some(kicker) = conn.server.idle_kicker.as_ref() { + kicker.kick(); + } + + tokio::select! 
{ + Some(client_msg_res) = conn.ws.recv() => { + let client_msg = client_msg_res + .context("Could not receive a WebSocket message")?; + match handle_msg(&mut conn, client_msg).await { + Ok(true) => continue, + Ok(false) => break, + Err(err) => { + match err.downcast::() { + Ok(proto_err) => { + tracing::warn!( + "Connection #{} terminated due to protocol error: {}", + conn.conn_id, + proto_err, + ); + let close_code = protocol_error_to_close_code(&proto_err); + close(&mut conn, close_code, proto_err.to_string()).await; + return Ok(()) + } + Err(err) => { + close(&mut conn, CloseCode::Error, "Internal server error".into()).await; + return Err(err); + } + } + } + } + }, + Some(task_res) = conn.join_set.join_next() => { + task_res.expect("Connection subtask failed") + }, + Some(response_res) = conn.responses.next() => { + let response_msg = response_res?; + send_msg(&mut conn, &response_msg).await?; + }, + else => break, + } + } + + close( + &mut conn, + CloseCode::Normal, + "Thank you for using sqld".into(), + ) + .await; + Ok(()) +} + +async fn handle_msg( + conn: &mut Conn, + client_msg: tungstenite::Message, +) -> Result { + match client_msg { + tungstenite::Message::Text(client_msg) => { + // client messages are received as text WebSocket messages that encode the `ClientMsg` + // in JSON + let client_msg: proto::ClientMsg = match serde_json::from_str(&client_msg) { + Ok(client_msg) => client_msg, + Err(err) => bail!(ProtocolError::Deserialize { source: err }), + }; + + match client_msg { + proto::ClientMsg::Hello { jwt } => handle_hello_msg(conn, jwt).await, + proto::ClientMsg::Request { + request_id, + request, + } => handle_request_msg(conn, request_id, request).await, + } + } + tungstenite::Message::Binary(_) => bail!(ProtocolError::BinaryWebSocketMessage), + tungstenite::Message::Ping(ping_data) => { + let pong_msg = tungstenite::Message::Pong(ping_data); + conn.ws + .send(pong_msg) + .await + .context("Could not send pong to the WebSocket")?; + Ok(true) + } + tungstenite::Message::Pong(_) => Ok(true), + tungstenite::Message::Close(_) => Ok(false), + tungstenite::Message::Frame(_) => panic!("Received a tungstenite::Message::Frame"), + } +} + +async fn handle_hello_msg(conn: &mut Conn, jwt: Option) -> Result { + let hello_res = match conn.session.as_mut() { + None => session::handle_initial_hello(&conn.server, conn.version, jwt) + .map(|session| conn.session = Some(session)), + Some(session) => session::handle_repeated_hello(&conn.server, session, jwt), + }; + + match hello_res { + Ok(_) => { + send_msg(conn, &proto::ServerMsg::HelloOk {}).await?; + Ok(true) + } + Err(err) => match downcast_error(err) { + Ok(error) => { + send_msg(conn, &proto::ServerMsg::HelloError { error }).await?; + Ok(false) + } + Err(err) => Err(err), + }, + } +} + +async fn handle_request_msg( + conn: &mut Conn, + request_id: i32, + request: proto::Request, +) -> Result { + let Some(session) = conn.session.as_mut() else { + bail!(ProtocolError::RequestBeforeHello) + }; + + let response_rx = session::handle_request(&conn.server, session, &mut conn.join_set, request) + .await + .unwrap_or_else(|err| { + // we got an error immediately, but let's treat it as a special case of the general + // flow + let (tx, rx) = oneshot::channel(); + tx.send(Err(err)).unwrap(); + rx + }); + + conn.responses.push(ResponseFuture { + request_id, + response_rx: response_rx.fuse(), + }); + Ok(true) +} + +impl Future for ResponseFuture { + type Output = Result; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match 
ready!(Pin::new(&mut self.response_rx).poll(cx)) { + Ok(Ok(response)) => Poll::Ready(Ok(proto::ServerMsg::ResponseOk { + request_id: self.request_id, + response, + })), + Ok(Err(err)) => match downcast_error(err) { + Ok(error) => Poll::Ready(Ok(proto::ServerMsg::ResponseError { + request_id: self.request_id, + error, + })), + Err(err) => Poll::Ready(Err(err)), + }, + Err(_recv_err) => { + // do not propagate this error, because the error that caused the receiver to drop + // is very likely propagating from another task at this moment, and we don't want + // to hide it. + // this is also the reason why we need to use `Fuse` in self.response_rx + tracing::warn!("Response sender was dropped"); + Poll::Pending + } + } + } +} + +fn downcast_error(err: anyhow::Error) -> Result { + match err.downcast_ref::() { + Some(error) => Ok(proto::Error { + message: error.to_string(), + code: error.code().into(), + }), + None => Err(err), + } +} + +async fn send_msg(conn: &mut Conn, msg: &proto::ServerMsg) -> Result<()> { + let msg = serde_json::to_string(&msg).context("Could not serialize response message")?; + let msg = tungstenite::Message::Text(msg); + conn.ws + .send(msg) + .await + .context("Could not send response to the WebSocket") +} + +async fn close(conn: &mut Conn, code: CloseCode, reason: String) { + if conn.ws_closed { + return; + } + + let close_frame = tungstenite::protocol::frame::CloseFrame { + code, + reason: Cow::Owned(reason), + }; + if let Err(err) = conn + .ws + .send(tungstenite::Message::Close(Some(close_frame))) + .await + { + if !matches!( + err, + tungstenite::Error::AlreadyClosed | tungstenite::Error::ConnectionClosed + ) { + tracing::warn!( + "Could not send close frame to WebSocket of connection #{}: {:?}", + conn.conn_id, + err + ); + } + } + + conn.ws_closed = true; +} + +fn protocol_error_to_close_code(err: &ProtocolError) -> CloseCode { + match err { + ProtocolError::Deserialize { .. 
} => CloseCode::Invalid, + ProtocolError::BinaryWebSocketMessage => CloseCode::Unsupported, + _ => CloseCode::Policy, + } +} diff --git a/libsqlx-server/src/hrana/ws/handshake.rs b/libsqlx-server/src/hrana/ws/handshake.rs new file mode 100644 index 00000000..ef187a6a --- /dev/null +++ b/libsqlx-server/src/hrana/ws/handshake.rs @@ -0,0 +1,140 @@ +use anyhow::{anyhow, bail, Context as _, Result}; +use futures::{SinkExt as _, StreamExt as _}; +use tokio_tungstenite::tungstenite; +use tungstenite::http; + +use super::super::Version; +use super::Upgrade; + +#[derive(Debug)] +pub enum WebSocket { + Tcp(tokio_tungstenite::WebSocketStream), + Upgraded(tokio_tungstenite::WebSocketStream), +} + +pub async fn handshake_tcp(socket: tokio::net::TcpStream) -> Result<(WebSocket, Version)> { + let mut version = None; + let callback = |req: &http::Request<()>, resp: http::Response<()>| { + let (mut resp_parts, _) = resp.into_parts(); + resp_parts + .headers + .insert("server", http::HeaderValue::from_static("sqld-hrana-tcp")); + + match negotiate_version(req.headers(), &mut resp_parts.headers) { + Ok(version_) => { + version = Some(version_); + Ok(http::Response::from_parts(resp_parts, ())) + } + Err(resp_body) => Err(http::Response::from_parts(resp_parts, Some(resp_body))), + } + }; + + let ws_config = Some(get_ws_config()); + let stream = + tokio_tungstenite::accept_hdr_async_with_config(socket, callback, ws_config).await?; + Ok((WebSocket::Tcp(stream), version.unwrap())) +} + +pub async fn handshake_upgrade(upgrade: Upgrade) -> Result<(WebSocket, Version)> { + let mut req = upgrade.request; + + let ws_config = Some(get_ws_config()); + let (mut resp, stream_fut_version_res) = match hyper_tungstenite::upgrade(&mut req, ws_config) { + Ok((mut resp, stream_fut)) => match negotiate_version(req.headers(), resp.headers_mut()) { + Ok(version) => (resp, Ok((stream_fut, version))), + Err(msg) => { + *resp.status_mut() = http::StatusCode::BAD_REQUEST; + *resp.body_mut() = hyper::Body::from(msg.clone()); + ( + resp, + Err(anyhow!("Could not negotiate subprotocol: {}", msg)), + ) + } + }, + Err(err) => { + let resp = http::Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body(hyper::Body::from(format!("{err}"))) + .unwrap(); + ( + resp, + Err(anyhow!(err).context("Protocol error in HTTP upgrade")), + ) + } + }; + + resp.headers_mut().insert( + "server", + http::HeaderValue::from_static("sqld-hrana-upgrade"), + ); + if upgrade.response_tx.send(resp).is_err() { + bail!("Could not send the HTTP upgrade response") + } + + let (stream_fut, version) = stream_fut_version_res?; + let stream = stream_fut + .await + .context("Could not upgrade HTTP request to a WebSocket")?; + Ok((WebSocket::Upgraded(stream), version)) +} + +fn negotiate_version( + req_headers: &http::HeaderMap, + resp_headers: &mut http::HeaderMap, +) -> Result { + if let Some(protocol_hdr) = req_headers.get("sec-websocket-protocol") { + let supported_by_client = protocol_hdr + .to_str() + .unwrap_or("") + .split(',') + .map(|p| p.trim()); + + let mut hrana1_supported = false; + let mut hrana2_supported = false; + for protocol_str in supported_by_client { + hrana1_supported |= protocol_str.eq_ignore_ascii_case("hrana1"); + hrana2_supported |= protocol_str.eq_ignore_ascii_case("hrana2"); + } + + let version = if hrana2_supported { + Version::Hrana2 + } else if hrana1_supported { + Version::Hrana1 + } else { + return Err("Only 'hrana1' and 'hrana2' subprotocols are supported".into()); + }; + + resp_headers.append( + "sec-websocket-protocol", + 
http::HeaderValue::from_str(&version.to_string()).unwrap(), + ); + Ok(version) + } else { + // Sec-WebSocket-Protocol header not present, assume that the client wants hrana1 + // According to RFC 6455, we must not set the Sec-WebSocket-Protocol response header + Ok(Version::Hrana1) + } +} + +fn get_ws_config() -> tungstenite::protocol::WebSocketConfig { + tungstenite::protocol::WebSocketConfig { + max_send_queue: Some(1 << 20), + ..Default::default() + } +} + +impl WebSocket { + pub async fn recv(&mut self) -> Option> { + match self { + Self::Tcp(stream) => stream.next().await, + Self::Upgraded(stream) => stream.next().await, + } + } + + pub async fn send(&mut self, msg: tungstenite::Message) -> tungstenite::Result<()> { + match self { + Self::Tcp(stream) => stream.send(msg).await, + Self::Upgraded(stream) => stream.send(msg).await, + } + } +} diff --git a/libsqlx-server/src/hrana/ws/mod.rs b/libsqlx-server/src/hrana/ws/mod.rs new file mode 100644 index 00000000..32a34957 --- /dev/null +++ b/libsqlx-server/src/hrana/ws/mod.rs @@ -0,0 +1,104 @@ +use crate::auth::Auth; +use crate::database::Database; +use crate::utils::services::idle_shutdown::IdleKicker; +use anyhow::{Context as _, Result}; +use enclose::enclose; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot}; + +pub mod proto; + +mod conn; +mod handshake; +mod session; + +struct Server { + db_factory: Arc>, + auth: Arc, + idle_kicker: Option, + next_conn_id: AtomicU64, +} + +#[derive(Debug)] +pub struct Accept { + pub socket: tokio::net::TcpStream, + pub peer_addr: SocketAddr, +} + +#[derive(Debug)] +pub struct Upgrade { + pub request: hyper::Request, + pub response_tx: oneshot::Sender>, +} + +pub async fn serve( + db_factory: Arc>, + auth: Arc, + idle_kicker: Option, + mut accept_rx: mpsc::Receiver, + mut upgrade_rx: mpsc::Receiver, +) -> Result<()> { + let server = Arc::new(Server { + db_factory, + auth, + idle_kicker, + next_conn_id: AtomicU64::new(0), + }); + + let mut join_set = tokio::task::JoinSet::new(); + loop { + if let Some(kicker) = server.idle_kicker.as_ref() { + kicker.kick(); + } + + tokio::select! 
{ + Some(accept) = accept_rx.recv() => { + let conn_id = server.next_conn_id.fetch_add(1, Ordering::AcqRel); + tracing::info!("Received TCP connection #{} from {}", conn_id, accept.peer_addr); + + join_set.spawn(enclose!{(server, conn_id) async move { + match conn::handle_tcp(server, accept.socket, conn_id).await { + Ok(_) => tracing::info!("TCP connection #{} was terminated", conn_id), + Err(err) => tracing::error!("TCP connection #{} failed: {:?}", conn_id, err), + } + }}); + }, + Some(upgrade) = upgrade_rx.recv() => { + let conn_id = server.next_conn_id.fetch_add(1, Ordering::AcqRel); + tracing::info!("Received HTTP upgrade connection #{}", conn_id); + + join_set.spawn(enclose!{(server, conn_id) async move { + match conn::handle_upgrade(server, upgrade, conn_id).await { + Ok(_) => tracing::info!("HTTP upgrade connection #{} was terminated", conn_id), + Err(err) => tracing::error!("HTTP upgrade connection #{} failed: {:?}", conn_id, err), + } + }}); + }, + Some(task_res) = join_set.join_next() => { + task_res.expect("Hrana connection task failed") + }, + else => { + tracing::error!("hrana server loop exited"); + return Ok(()) + } + } + } +} + +pub async fn listen(bind_addr: SocketAddr, accept_tx: mpsc::Sender) -> Result<()> { + let listener = tokio::net::TcpListener::bind(bind_addr) + .await + .context("Could not bind TCP listener")?; + let local_addr = listener.local_addr()?; + tracing::info!("Listening for Hrana connections on {}", local_addr); + + loop { + let (socket, peer_addr) = listener + .accept() + .await + .context("Could not accept a TCP connection")?; + let _: Result<_, _> = accept_tx.send(Accept { socket, peer_addr }).await; + } +} diff --git a/libsqlx-server/src/hrana/ws/proto.rs b/libsqlx-server/src/hrana/ws/proto.rs new file mode 100644 index 00000000..6bb88367 --- /dev/null +++ b/libsqlx-server/src/hrana/ws/proto.rs @@ -0,0 +1,127 @@ +//! Structures for Hrana-over-WebSockets. 
+ +pub use super::super::proto::*; +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ClientMsg { + Hello { jwt: Option }, + Request { request_id: i32, request: Request }, +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ServerMsg { + HelloOk {}, + HelloError { error: Error }, + ResponseOk { request_id: i32, response: Response }, + ResponseError { request_id: i32, error: Error }, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Request { + OpenStream(OpenStreamReq), + CloseStream(CloseStreamReq), + Execute(ExecuteReq), + Batch(BatchReq), + Sequence(SequenceReq), + Describe(DescribeReq), + StoreSql(StoreSqlReq), + CloseSql(CloseSqlReq), +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Response { + OpenStream(OpenStreamResp), + CloseStream(CloseStreamResp), + Execute(ExecuteResp), + Batch(BatchResp), + Sequence(SequenceResp), + Describe(DescribeResp), + StoreSql(StoreSqlResp), + CloseSql(CloseSqlResp), +} + +#[derive(Deserialize, Debug)] +pub struct OpenStreamReq { + pub stream_id: i32, +} + +#[derive(Serialize, Debug)] +pub struct OpenStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct CloseStreamReq { + pub stream_id: i32, +} + +#[derive(Serialize, Debug)] +pub struct CloseStreamResp {} + +#[derive(Deserialize, Debug)] +pub struct ExecuteReq { + pub stream_id: i32, + pub stmt: Stmt, +} + +#[derive(Serialize, Debug)] +pub struct ExecuteResp { + pub result: StmtResult, +} + +#[derive(Deserialize, Debug)] +pub struct BatchReq { + pub stream_id: i32, + pub batch: Batch, +} + +#[derive(Serialize, Debug)] +pub struct BatchResp { + pub result: BatchResult, +} + +#[derive(Deserialize, Debug)] +pub struct SequenceReq { + pub stream_id: i32, + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, +} + +#[derive(Serialize, Debug)] +pub struct SequenceResp {} + +#[derive(Deserialize, Debug)] +pub struct DescribeReq { + pub stream_id: i32, + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub sql_id: Option, +} + +#[derive(Serialize, Debug)] +pub struct DescribeResp { + pub result: DescribeResult, +} + +#[derive(Deserialize, Debug)] +pub struct StoreSqlReq { + pub sql_id: i32, + pub sql: String, +} + +#[derive(Serialize, Debug)] +pub struct StoreSqlResp {} + +#[derive(Deserialize, Debug)] +pub struct CloseSqlReq { + pub sql_id: i32, +} + +#[derive(Serialize, Debug)] +pub struct CloseSqlResp {} diff --git a/libsqlx-server/src/hrana/ws/session.rs b/libsqlx-server/src/hrana/ws/session.rs new file mode 100644 index 00000000..f59bcecc --- /dev/null +++ b/libsqlx-server/src/hrana/ws/session.rs @@ -0,0 +1,329 @@ +use std::collections::HashMap; + +use anyhow::{anyhow, bail, Context as _, Result}; +use futures::future::BoxFuture; +use tokio::sync::{mpsc, oneshot}; + +use super::super::{batch, stmt, ProtocolError, Version}; +use super::{proto, Server}; +use crate::auth::{AuthError, Authenticated}; +use crate::database::Database; + +/// Session-level state of an authenticated Hrana connection. +pub struct Session { + authenticated: Authenticated, + version: Version, + streams: HashMap>, + sqls: HashMap, +} + +struct StreamHandle { + job_tx: mpsc::Sender>, +} + +/// An arbitrary job that is executed on a [`Stream`]. +/// +/// All jobs are executed sequentially on a single task (as evidenced by the `&mut Stream` passed +/// to `f`). 
+struct StreamJob { + /// The async function which performs the job. + #[allow(clippy::type_complexity)] + f: Box FnOnce(&'s mut Stream) -> BoxFuture<'s, Result> + Send>, + /// The result of `f` will be sent here. + resp_tx: oneshot::Sender>, +} + +/// State of a Hrana stream, which corresponds to a standalone database connection. +struct Stream { + /// The database handle is `None` when the stream is created, and normally set to `Some` by the + /// first job executed on the stream by the [`proto::OpenStreamReq`] request. However, if that + /// request returns an error, the following requests may encounter a `None` here. + db: Option, +} + +/// An error which can be converted to a Hrana [Error][proto::Error]. +#[derive(thiserror::Error, Debug)] +pub enum ResponseError { + #[error("Authentication failed: {source}")] + Auth { source: AuthError }, + #[error("Stream {stream_id} has failed to open")] + StreamNotOpen { stream_id: i32 }, + #[error("The server already stores {count} SQL texts, it cannot store more")] + SqlTooMany { count: usize }, + #[error(transparent)] + Stmt(stmt::StmtError), +} + +pub(super) fn handle_initial_hello( + server: &Server, + version: Version, + jwt: Option, +) -> Result> { + let authenticated = server + .auth + .authenticate_jwt(jwt.as_deref()) + .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; + + Ok(Session { + authenticated, + version, + streams: HashMap::new(), + sqls: HashMap::new(), + }) +} + +pub(super) fn handle_repeated_hello( + server: &Server, + session: &mut Session, + jwt: Option, +) -> Result<()> { + if session.version < Version::Hrana2 { + bail!(ProtocolError::NotSupported { + what: "Repeated hello message", + min_version: Version::Hrana2, + }) + } + + session.authenticated = server + .auth + .authenticate_jwt(jwt.as_deref()) + .map_err(|err| anyhow!(ResponseError::Auth { source: err }))?; + Ok(()) +} + +pub(super) async fn handle_request( + server: &Server, + session: &mut Session, + join_set: &mut tokio::task::JoinSet<()>, + req: proto::Request, +) -> Result>> { + // TODO: this function has rotten: it is too long and contains too much duplicated code. It + // should be refactored at the next opportunity, together with code in stmt.rs and batch.rs + + let (resp_tx, resp_rx) = oneshot::channel(); + + macro_rules! stream_respond { + ($stream_hnd:expr, async move |$stream:ident| { $($body:tt)* }) => { + stream_respond($stream_hnd, resp_tx, move |$stream| { + Box::pin(async move { $($body)* }) + }) + .await + }; + } + + macro_rules! respond { + ($value:expr) => { + resp_tx.send(Ok($value)).unwrap() + }; + } + + macro_rules! ensure_version { + ($min_version:expr, $what:expr) => { + if session.version < $min_version { + bail!(ProtocolError::NotSupported { + what: $what, + min_version: $min_version, + }) + } + }; + } + + macro_rules! get_stream_mut { + ($stream_id:expr) => { + match session.streams.get_mut(&$stream_id) { + Some(stream_hdn) => stream_hdn, + None => bail!(ProtocolError::StreamNotFound { + stream_id: $stream_id + }), + } + }; + } + + macro_rules! 
get_stream_db { + ($stream:expr, $stream_id:expr) => { + match $stream.db.as_ref() { + Some(db) => db, + None => bail!(ResponseError::StreamNotOpen { + stream_id: $stream_id + }), + } + }; + } + + match req { + proto::Request::OpenStream(req) => { + let stream_id = req.stream_id; + if session.streams.contains_key(&stream_id) { + bail!(ProtocolError::StreamExists { stream_id }) + } + + let mut stream_hnd = stream_spawn(join_set, Stream { db: None }); + let db_factory = server.db_factory.clone(); + + stream_respond!(&mut stream_hnd, async move |stream| { + let db = db_factory + .create() + .await + .context("Could not create a database connection")?; + stream.db = Some(db); + Ok(proto::Response::OpenStream(proto::OpenStreamResp {})) + }); + + session.streams.insert(stream_id, stream_hnd); + } + proto::Request::CloseStream(req) => { + let stream_id = req.stream_id; + let Some(mut stream_hnd) = session.streams.remove(&stream_id) else { + bail!(ProtocolError::StreamNotFound { stream_id }) + }; + + stream_respond!(&mut stream_hnd, async move |_stream| { + Ok(proto::Response::CloseStream(proto::CloseStreamResp {})) + }); + } + proto::Request::Execute(req) => { + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + + let query = stmt::proto_stmt_to_query(&req.stmt, &session.sqls, session.version) + .map_err(catch_stmt_error)?; + let auth = session.authenticated; + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + let result = stmt::execute_stmt(db, auth, query) + .await + .map_err(catch_stmt_error)?; + Ok(proto::Response::Execute(proto::ExecuteResp { result })) + }); + } + proto::Request::Batch(req) => { + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + + let pgm = batch::proto_batch_to_program(&req.batch, &session.sqls, session.version) + .map_err(catch_stmt_error)?; + let auth = session.authenticated; + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + let result = batch::execute_batch(db, auth, pgm).await?; + Ok(proto::Response::Batch(proto::BatchResp { result })) + }); + } + proto::Request::Sequence(req) => { + ensure_version!(Version::Hrana2, "The `sequence` request"); + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + + let sql = stmt::proto_sql_to_sql( + req.sql.as_deref(), + req.sql_id, + &session.sqls, + session.version, + )?; + let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; + let auth = session.authenticated; + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + batch::execute_sequence(db, auth, pgm) + .await + .map_err(catch_stmt_error)?; + Ok(proto::Response::Sequence(proto::SequenceResp {})) + }); + } + proto::Request::Describe(req) => { + ensure_version!(Version::Hrana2, "The `describe` request"); + let stream_id = req.stream_id; + let stream_hnd = get_stream_mut!(stream_id); + + let sql = stmt::proto_sql_to_sql( + req.sql.as_deref(), + req.sql_id, + &session.sqls, + session.version, + )? 
+ .into(); + let auth = session.authenticated; + + stream_respond!(stream_hnd, async move |stream| { + let db = get_stream_db!(stream, stream_id); + let result = stmt::describe_stmt(db, auth, sql) + .await + .map_err(catch_stmt_error)?; + Ok(proto::Response::Describe(proto::DescribeResp { result })) + }); + } + proto::Request::StoreSql(req) => { + ensure_version!(Version::Hrana2, "The `store_sql` request"); + let sql_id = req.sql_id; + if session.sqls.contains_key(&sql_id) { + bail!(ProtocolError::SqlExists { sql_id }) + } else if session.sqls.len() >= MAX_SQL_COUNT { + bail!(ResponseError::SqlTooMany { + count: session.sqls.len() + }) + } + + session.sqls.insert(sql_id, req.sql); + respond!(proto::Response::StoreSql(proto::StoreSqlResp {})); + } + proto::Request::CloseSql(req) => { + ensure_version!(Version::Hrana2, "The `close_sql` request"); + session.sqls.remove(&req.sql_id); + respond!(proto::Response::CloseSql(proto::CloseSqlResp {})); + } + } + Ok(resp_rx) +} + +const MAX_SQL_COUNT: usize = 150; + +fn stream_spawn( + join_set: &mut tokio::task::JoinSet<()>, + stream: Stream, +) -> StreamHandle { + let (job_tx, mut job_rx) = mpsc::channel::>(8); + join_set.spawn(async move { + let mut stream = stream; + while let Some(job) = job_rx.recv().await { + let res = (job.f)(&mut stream).await; + let _: Result<_, _> = job.resp_tx.send(res); + } + }); + StreamHandle { job_tx } +} + +async fn stream_respond( + stream_hnd: &mut StreamHandle, + resp_tx: oneshot::Sender>, + f: F, +) where + for<'s> F: FnOnce(&'s mut Stream) -> BoxFuture<'s, Result>, + F: Send + 'static, +{ + let job = StreamJob { + f: Box::new(f), + resp_tx, + }; + let _: Result<_, _> = stream_hnd.job_tx.send(job).await; +} + +fn catch_stmt_error(err: anyhow::Error) -> anyhow::Error { + match err.downcast::() { + Ok(stmt_err) => anyhow!(ResponseError::Stmt(stmt_err)), + Err(err) => err, + } +} + +impl ResponseError { + pub fn code(&self) -> &'static str { + match self { + Self::Auth { source } => source.code(), + Self::SqlTooMany { .. } => "SQL_STORE_TOO_MANY", + Self::StreamNotOpen { .. 
} => "STREAM_NOT_OPEN", + Self::Stmt(err) => err.code(), + } + } +} diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 80e787d6..6b23ef58 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,4 +1,4 @@ -use std::{path::PathBuf, sync::Arc}; +use std::sync::Arc; use axum::{extract::State, routing::post, Json, Router}; use color_eyre::eyre::Result; diff --git a/libsqlx-server/src/http/user.rs b/libsqlx-server/src/http/user.rs deleted file mode 100644 index 040f5a66..00000000 --- a/libsqlx-server/src/http/user.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::sync::Arc; - -use axum::{async_trait, extract::FromRequestParts, response::IntoResponse, routing::get, Router, Json}; -use color_eyre::Result; -use hyper::{http::request::Parts, server::accept::Accept, StatusCode}; -use serde::Serialize; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc, -}; - -use crate::{allocation::AllocationMessage, manager::Manager}; - -pub struct UserApiConfig { - pub manager: Arc, -} - -struct UserApiState { - manager: Arc, -} - -pub async fn run_user_api(config: UserApiConfig, listener: I) -> Result<()> -where - I: Accept, - I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - let state = UserApiState { manager: config.manager }; - - let app = Router::new() - .route("/", get(test_database)) - .with_state(Arc::new(state)); - - axum::Server::builder(listener) - .serve(app.into_make_service()) - .await?; - - Ok(()) -} - -struct Database { - sender: mpsc::Sender, -} - -#[derive(Debug, thiserror::Error)] -enum UserApiError { - #[error("missing host header")] - MissingHost, - #[error("invalid host header format")] - InvalidHost, - #[error("Database `{0}` doesn't exist")] - UnknownDatabase(String), -} - -impl UserApiError { - fn http_status(&self) -> StatusCode { - match self { - UserApiError::MissingHost - | UserApiError::InvalidHost - | UserApiError::UnknownDatabase(_) => StatusCode::BAD_REQUEST, - } - } -} - -#[derive(Debug, Serialize)] -struct ApiError { - error: String, -} - -impl IntoResponse for UserApiError { - fn into_response(self) -> axum::response::Response { - let mut resp = Json(ApiError { - error: self.to_string() - }).into_response(); - *resp.status_mut() = self.http_status(); - - resp - } -} - -#[async_trait] -impl FromRequestParts> for Database { - type Rejection = UserApiError; - - async fn from_request_parts( - parts: &mut Parts, - state: &Arc, - ) -> Result { - let Some(host) = parts.headers.get("host") else { return Err(UserApiError::MissingHost) }; - let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; - let db_id = parse_host(host_str)?; - let Some(sender) = state.manager.alloc(db_id).await else { return Err(UserApiError::UnknownDatabase(db_id.to_owned())) }; - - Ok(Database { sender }) - } -} - -fn parse_host(host: &str) -> Result<&str, UserApiError> { - let mut split = host.split("."); - let Some(db_id) = split.next() else { return Err(UserApiError::InvalidHost) }; - Ok(db_id) -} diff --git a/libsqlx-server/src/http/user/error.rs b/libsqlx-server/src/http/user/error.rs new file mode 100644 index 00000000..9aab9a71 --- /dev/null +++ b/libsqlx-server/src/http/user/error.rs @@ -0,0 +1,41 @@ +use axum::response::IntoResponse; +use axum::Json; +use hyper::StatusCode; +use serde::Serialize; + +#[derive(Debug, thiserror::Error)] +pub enum UserApiError { + #[error("missing host header")] + MissingHost, + #[error("invalid host header format")] + InvalidHost, + 
#[error("Database `{0}` doesn't exist")] + UnknownDatabase(String), +} + +impl UserApiError { + fn http_status(&self) -> StatusCode { + match self { + UserApiError::MissingHost + | UserApiError::InvalidHost + | UserApiError::UnknownDatabase(_) => StatusCode::BAD_REQUEST, + } + } +} + +#[derive(Debug, Serialize)] +pub struct ApiError { + error: String, +} + +impl IntoResponse for UserApiError { + fn into_response(self) -> axum::response::Response { + let mut resp = Json(ApiError { + error: self.to_string(), + }) + .into_response(); + *resp.status_mut() = self.http_status(); + + resp + } +} diff --git a/libsqlx-server/src/http/user/extractors.rs b/libsqlx-server/src/http/user/extractors.rs new file mode 100644 index 00000000..2b3f5a14 --- /dev/null +++ b/libsqlx-server/src/http/user/extractors.rs @@ -0,0 +1,32 @@ +use std::sync::Arc; + +use axum::async_trait; +use axum::extract::FromRequestParts; +use hyper::http::request::Parts; + +use crate::database::Database; + +use super::{error::UserApiError, UserApiState}; + +#[async_trait] +impl FromRequestParts> for Database { + type Rejection = UserApiError; + + async fn from_request_parts( + parts: &mut Parts, + state: &Arc, + ) -> Result { + let Some(host) = parts.headers.get("host") else { return Err(UserApiError::MissingHost) }; + let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; + let db_id = parse_host(host_str)?; + let Some(sender) = state.manager.alloc(db_id).await else { return Err(UserApiError::UnknownDatabase(db_id.to_owned())) }; + + Ok(Database { sender }) + } +} + +fn parse_host(host: &str) -> Result<&str, UserApiError> { + let mut split = host.split("."); + let Some(db_id) = split.next() else { return Err(UserApiError::InvalidHost) }; + Ok(db_id) +} diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs new file mode 100644 index 00000000..4c314a39 --- /dev/null +++ b/libsqlx-server/src/http/user/mod.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use axum::routing::post; +use axum::{Json, Router}; +use color_eyre::Result; +use hyper::server::accept::Accept; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::database::Database; +use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::manager::Manager; + +mod error; +mod extractors; + +pub struct UserApiConfig { + pub manager: Arc, +} + +struct UserApiState { + manager: Arc, +} + +pub async fn run_user_api(config: UserApiConfig, listener: I) -> Result<()> +where + I: Accept, + I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + let state = UserApiState { + manager: config.manager, + }; + + let app = Router::new() + .route("/v2/pipeline", post(handle_hrana_pipeline)) + .with_state(Arc::new(state)); + + axum::Server::builder(listener) + .serve(app.into_make_service()) + .await?; + + Ok(()) +} + +async fn handle_hrana_pipeline(db: Database, Json(req): Json) -> Json { + let resp = db.hrana_pipeline(req).await; + dbg!(); + Json(resp.unwrap()) +} diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 2e9411cf..a8829093 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -13,7 +13,7 @@ use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; mod allocation; -mod databases; +mod database; mod hrana; mod http; mod manager; @@ -28,7 +28,9 @@ async fn main() -> Result<()> { let store = Arc::new(Store::new(&db_path)); let admin_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3456").await?; 
join_set.spawn(run_admin_api( - AdminApiConfig { meta_store: store.clone() }, + AdminApiConfig { + meta_store: store.clone(), + }, AddrIncoming::from_listener(admin_api_listener)?, )); diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 81ac3b72..48315e0a 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -7,6 +6,7 @@ use tokio::sync::mpsc; use tokio::task::JoinSet; use crate::allocation::{Allocation, AllocationMessage, Database}; +use crate::hrana; use crate::meta::Store; pub struct Manager { @@ -39,10 +39,10 @@ impl Manager { let alloc = Allocation { inbox, database: Database::from_config(&config, path), - connections: HashMap::new(), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, + hrana_server: Arc::new(hrana::http::Server::new(None)), // TODO: handle self URL? }; tokio::spawn(alloc.run()); diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 4eade1b0..06e37a76 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -32,7 +32,7 @@ impl Store { }); } - pub async fn deallocate(&self, alloc_id: Uuid) { + pub async fn deallocate(&self, _alloc_id: Uuid) { todo!() } @@ -48,7 +48,7 @@ impl Store { tokio::task::block_in_place(|| { let mut out = Vec::new(); for kv in self.meta_store.iter() { - let (k, v) = kv.unwrap(); + let (_k, v) = kv.unwrap(); let alloc = bincode::deserialize(&v).unwrap(); out.push(alloc); } diff --git a/libsqlx/src/analysis.rs b/libsqlx/src/analysis.rs index 0c7f7d43..fccbf3dc 100644 --- a/libsqlx/src/analysis.rs +++ b/libsqlx/src/analysis.rs @@ -1,4 +1,3 @@ -use anyhow::Result; use fallible_iterator::FallibleIterator; use sqlite3_parser::ast::{Cmd, PragmaBody, QualifiedName, Stmt}; use sqlite3_parser::lexer::sql::{Parser, ParserError}; @@ -201,15 +200,15 @@ impl Statement { } } - pub fn parse(s: &str) -> impl Iterator> + '_ { + pub fn parse(s: &str) -> impl Iterator> + '_ { fn parse_inner( original: &str, stmt_count: u64, has_more_stmts: bool, c: Cmd, - ) -> Result { + ) -> crate::Result { let kind = - StmtKind::kind(&c).ok_or_else(|| anyhow::anyhow!("unsupported statement"))?; + StmtKind::kind(&c).ok_or_else(|| crate::error::Error::UnsupportedStatement)?; if stmt_count == 1 && !has_more_stmts { // XXX: Temporary workaround for integration with Atlas @@ -259,9 +258,7 @@ impl Statement { found: Some(found), }, Some((line, col)), - )) => Some(Err(anyhow::anyhow!( - "syntax error around L{line}:{col}: `{found}`" - ))), + )) => Some(Err(crate::error::Error::SyntaxError { line, col, found})), Err(e) => Some(Err(e.into())), } }) diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index a5eb7e60..38d31964 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -2,8 +2,7 @@ use rusqlite::types::Value; use crate::program::{Program, Step}; use crate::query::Query; -use crate::result_builder::ResultBuilder; -use crate::QueryBuilderConfig; +use crate::result_builder::{ResultBuilder, QueryBuilderConfig, QueryResultBuilderError}; #[derive(Debug, Clone)] pub struct DescribeResponse { @@ -48,7 +47,7 @@ pub trait Connection { fn init( &mut self, _config: &QueryBuilderConfig, - ) -> std::result::Result<(), crate::QueryResultBuilderError> { + ) -> std::result::Result<(), QueryResultBuilderError> { self.error = None; self.rows.clear(); self.current_row.clear(); @@ -59,12 +58,12 @@ pub trait Connection { 
fn add_row_value( &mut self, v: rusqlite::types::ValueRef, - ) -> Result<(), crate::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.current_row.push(v.into()); Ok(()) } - fn finish_row(&mut self) -> Result<(), crate::QueryResultBuilderError> { + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { let row = std::mem::take(&mut self.current_row); self.rows.push(row); @@ -74,7 +73,7 @@ pub trait Connection { fn step_error( &mut self, error: crate::error::Error, - ) -> Result<(), crate::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.error.replace(error); Ok(()) } diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 0a2cb6b0..554a22da 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -177,7 +177,7 @@ impl LibsqlConnection { query .params .bind(&mut stmt) - .map_err(Error::LibSqlInvalidQueryParams)?; + .map_err(|e|Error::LibSqlInvalidQueryParams(e.to_string()))?; let mut qresult = stmt.raw_query(); builder.begin_rows()?; diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 41de3569..2844a204 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -174,6 +174,7 @@ impl Database for LibsqlDatabase { type Connection = LibsqlConnection<::Context>; fn connect(&self) -> Result { + dbg!(); Ok( LibsqlConnection::<::Context>::new( &self.db_path, diff --git a/libsqlx/src/error.rs b/libsqlx/src/error.rs index 6e35e217..47fde1ae 100644 --- a/libsqlx/src/error.rs +++ b/libsqlx/src/error.rs @@ -1,10 +1,12 @@ use crate::result_builder::QueryResultBuilderError; +pub use rusqlite::Error as RusqliteError; +pub use rusqlite::ffi::ErrorCode; #[allow(clippy::enum_variant_names)] #[derive(Debug, thiserror::Error)] pub enum Error { #[error("LibSQL failed to bind provided query parameters: `{0}`")] - LibSqlInvalidQueryParams(anyhow::Error), + LibSqlInvalidQueryParams(String), #[error("Transaction timed-out")] LibSqlTxTimeout, #[error("Server can't handle additional transactions")] @@ -33,6 +35,14 @@ pub enum Error { Blocked(Option), #[error("invalid replication log header")] InvalidLogHeader, + #[error("unsupported statement")] + UnsupportedStatement, + #[error("Syntax error at {line}:{col}: {found}")] + SyntaxError { + line: u64, col: usize, found: String + }, + #[error(transparent)] + LexerError(#[from] sqlite3_parser::lexer::sql::Error) } impl From for Error { diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index a6e3c3a2..f9ef106d 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -4,8 +4,8 @@ pub mod query; mod connection; mod database; -mod program; -mod result_builder; +pub mod program; +pub mod result_builder; mod seal; pub type Result = std::result::Result; @@ -14,7 +14,6 @@ pub use connection::Connection; pub use database::libsql; pub use database::proxy; pub use database::Database; -pub use program::Program; -pub use result_builder::{ - Column, QueryBuilderConfig, QueryResultBuilderError, ResultBuilder, ResultBuilderExt, -}; +pub use database::libsql::replication_log::FrameNo; + +pub use rusqlite; diff --git a/libsqlx/src/program.rs b/libsqlx/src/program.rs index 3eb2f551..131dd125 100644 --- a/libsqlx/src/program.rs +++ b/libsqlx/src/program.rs @@ -4,13 +4,13 @@ use crate::query::Query; #[derive(Debug, Clone)] pub struct Program { - pub steps: Arc>, + pub steps: Arc<[Step]>, } impl Program { pub fn new(steps: Vec) -> Self { Self { - steps: 
Arc::new(steps),
+            steps: steps.into(),
         }
     }
 
@@ -19,7 +19,20 @@ impl Program {
     }
 
     pub fn steps(&self) -> &[Step] {
-        self.steps.as_slice()
+        &self.steps
+    }
+
+    /// Transforms a collection of queries into a batch program. The execution of each query
+    /// depends on the success of the previous one.
+    pub fn from_queries(qs: impl IntoIterator<Item = Query>) -> Self {
+        let steps = qs
+            .into_iter()
+            .enumerate()
+            .map(|(idx, query)| Step {
+                cond: (idx > 0).then(|| Cond::Ok { step: idx - 1 }),
+                query,
+            })
+            .collect();
+
+        Self { steps }
+    }
 
     #[cfg(test)]
diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs
index ae299b1e..be5e27a7 100644
--- a/libsqlx/src/result_builder.rs
+++ b/libsqlx/src/result_builder.rs
@@ -2,7 +2,7 @@ use std::fmt;
 use std::io::{self, ErrorKind};
 
 use bytesize::ByteSize;
-use rusqlite::types::ValueRef;
+pub use rusqlite::types::ValueRef;
 
 use crate::database::FrameNo;
 
@@ -170,6 +170,12 @@ pub struct StepResultsBuilder {
     is_skipped: bool,
 }
 
+impl StepResultsBuilder {
+    pub fn into_ret(self) -> Vec<StepResult> {
+        self.step_results
+    }
+}
+
 impl ResultBuilder for StepResultsBuilder {
     fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> {
         *self = Default::default();

From 6cefe75ec8bc4f8f0761cda758d9bc3a3659ae73 Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Tue, 27 Jun 2023 19:03:14 +0200
Subject: [PATCH 10/64] LINC protocol spec

---
 docs/LINC.md | 282 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 282 insertions(+)
 create mode 100644 docs/LINC.md

diff --git a/docs/LINC.md b/docs/LINC.md
new file mode 100644
index 00000000..9a915c65
--- /dev/null
+++ b/docs/LINC.md
@@ -0,0 +1,282 @@
+# Libsql Inter-Node Communication protocol (LINC)
+
+## Overview
+
+This document describes version 1 of the Libsql Inter-Node Communication (LINC)
+protocol.
+
+The first version of the protocol aims to merge the two existing
+protocols (proxy and replication) into a single one, and adds support for multi-tenancy.
+
+LINC v1 is designed to handle three tasks:
+- inter-node communication
+- database replication
+- proxying of requests from replicas to primaries
+
+LINC makes use of streams to multiplex messages between databases on different nodes.
+
+LINC v1 is implemented on top of TCP.
+
+LINC uses bincode for message serialization and deserialization.
+
+## Connection protocol
+
+Each node is identified by a `node_id` and an address.
+At startup, a sqld node is configured with a list of peers (`(node_id, node_addr)`). A connection between two peers is initiated by the peer with the greatest `node_id`.
+
+```mermaid
+graph TD
+node4 --> node3
+node4 --> node2
+node4 --> node1
+node3 --> node2
+node3 --> node1
+node2 --> node1
+node1
+```
+
+A new node can be added to the cluster with no reconfiguration as long as its `node_id` is greater than all other `node_id`s in the cluster and it has the address of all the other nodes. In this case, the new node will initiate a connection with all the other nodes.
+
+On disconnection, the initiator of the connection attempts to reconnect.
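+
+To make the initiation rule concrete, here is a small, non-normative Rust
+sketch (the ordering rule is from this spec; the function name and the use of
+`u64` ids are illustrative only):
+
+```rust
+/// Illustrative only: the peer with the greatest node_id initiates the
+/// connection, so a node dials a peer exactly when its own id is greater.
+fn should_initiate(local_node_id: u64, peer_node_id: u64) -> bool {
+    local_node_id > peer_node_id
+}
+
+fn main() {
+    // In the diagram above, node4 dials node1, node2 and node3,
+    // while node1 dials no one and only accepts connections.
+    assert!(should_initiate(4, 1));
+    assert!(!should_initiate(1, 4));
+}
+```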
+
+## Messages
+
+```rust
+enum Message {
+    /// Messages destined to a node
+    Node(NodeMessage),
+    /// Message destined to a stream
+    Stream {
+        stream_id: StreamId,
+        payload: StreamMessage,
+    },
+}
+
+enum NodeMessage {
+    /// Initial message exchanged between nodes when connecting
+    Handshake {
+        protocol_version: String,
+        node_id: String,
+    },
+    /// Request to open a bi-directional stream between the client and the server
+    OpenStream {
+        /// Id to give to the newly opened stream
+        stream_id: StreamId,
+        /// Id of the database to open the stream to.
+        database_id: Uuid,
+    },
+    /// Close a previously opened stream
+    CloseStream {
+        id: StreamId,
+    },
+    /// Error type returned while handling a node message
+    Error(NodeError),
+}
+
+enum NodeError {
+    UnknownStream(StreamId),
+    HandshakeVersionMismatch { expected: u32 },
+    StreamAlreadyExist(StreamId),
+    UnknownDatabase(DatabaseId, StreamId),
+}
+
+enum StreamMessage {
+    /// Replication message between a replica and a primary
+    Replication(ReplicationMessage),
+    /// Proxy message between a replica and a primary
+    Proxy(ProxyMessage),
+    Error(StreamError),
+}
+
+enum ReplicationMessage {
+    HandshakeResponse {
+        /// id of the replication log
+        log_id: Uuid,
+        /// current frame_no of the primary
+        current_frame_no: u64,
+    },
+    /// Replication request
+    Replicate {
+        /// next frame no to send
+        next_frame_no: u64,
+    },
+    /// A batch of frames that are part of the same transaction
+    Transaction {
+        /// If not None, then the last frame is a commit frame, and this is the new size of the database.
+        size_after: Option<u32>,
+        /// frame_no of the last frame in `frames`
+        end_frame_no: u64,
+        /// A batch of frames that are part of the transaction.
+        frames: Vec<Frame>,
+    },
+    /// Error occurred handling a replication message
+    Error(StreamError),
+}
+
+struct Frame {
+    /// Page id of that frame
+    page_id: u32,
+    /// Data
+    data: Bytes,
+}
+
+enum ProxyMessage {
+    /// Proxy a query to a primary
+    ProxyRequest {
+        /// id of the connection to perform the query against.
+        /// If the connection doesn't already exist, it is created.
+        connection_id: u32,
+        /// Id of the request.
+        /// Responses to this request must have the same id.
+        req_id: u32,
+        query: Query,
+    },
+    /// Response to a proxied query
+    ProxyResponse {
+        /// id of the request this message is a response to.
+        req_id: u32,
+        /// Collection of steps to drive the query builder transducer.
+        row_step: [BuilderStep],
+    },
+    /// Stop processing request `id`.
+    CancelRequest {
+        req_id: u32,
+    },
+    /// Close the connection with the passed id.
+    CloseConnection {
+        connection_id: u32,
+    },
+}
+
+/// Steps applied to the query builder transducer to build a response to a proxied query.
+/// These types closely mirror those of the `QueryBuilderTrait`.
+enum BuilderStep {
+    BeginStep,
+    FinishStep(u64, Option<i64>),
+    StepError(StepError),
+    ColsDesc([Column]),
+    BeginRows,
+    BeginRow,
+    AddRowValue(Value),
+    FinishRow,
+    FinishRows,
+    Finish(ConnectionState),
+}
+
+/// State of the connection after a query was executed
+enum ConnectionState {
+    /// The connection is still in an open transaction state
+    OpenTxn,
+    /// The connection is idle.
+    Idle,
+}
+
+struct Column {
+    /// name of the column
+    name: String,
+    /// Declared type of the column, if any.
+    decl_ty: Option<String>,
+}
+
+/// For now, the stringified version of a sqld::error::Error.
+struct StepError(String);
+
+enum StreamError {
+    NotAPrimary,
+    AlreadyReplicating,
+}
+```
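+
+As an illustration of the wire encoding, here is a small, non-normative
+sketch of round-tripping a handshake with bincode (`serde` derives are
+assumed; how messages are framed on the TCP stream, e.g. with a length
+prefix, is not specified by this document):
+
+```rust
+use serde::{Deserialize, Serialize};
+
+// Mirrors NodeMessage::Handshake above; the other variants are elided
+// for brevity.
+#[derive(Serialize, Deserialize, Debug, PartialEq)]
+enum NodeMessage {
+    Handshake {
+        protocol_version: String,
+        node_id: String,
+    },
+}
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let msg = NodeMessage::Handshake {
+        protocol_version: "1".into(),
+        node_id: "node4".into(),
+    };
+
+    // bincode 1.x API: enums are encoded as a variant index followed by
+    // the variant's fields.
+    let bytes = bincode::serialize(&msg)?;
+    let decoded: NodeMessage = bincode::deserialize(&bytes)?;
+    assert_eq!(msg, decoded);
+    Ok(())
+}
+```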
The
+handshake is initiated by the initiator of the connection, which sends the
+following message:
+
+```typescript
+type NodeHandshake = {
+    version: string, // protocol version
+    node_id: string,
+}
+```
+
+If a peer receives a connection from a peer with an id smaller than its own, it must reject the handshake with an `IllegalConnection` error.
+
+## Streams
+
+Messages destined to a particular database are sent as part of a stream. A
+stream is created by sending a `NodeMessage::OpenStream`, specifying the id of
+the stream to open, along with the id of the database for which to open this
+stream. If the requested database is not on the destination node, the
+destination node responds with a `NodeError::UnknownDatabase` error, and the stream is not
+opened.
+
+If a node receives a message for a stream that was not opened before, it responds with a `NodeError::UnknownStream` error.
+
+A stream is closed by sending a `CloseStream` with the id of the stream. If the
+stream does not exist, a `NodeError::UnknownStream` error is returned.
+
+Streams can be opened by either peer. Each stream is identified by an `i32`
+stream id. The peer that initiated the original connection allocates positive
+stream ids, while the accepting peer allocates negative ids. 0 is not a legal
+value for a stream_id. The receiver of a request for a stream with id 0 must
+close the connection immediately.
+
+The peer opening a stream is responsible for sending the close message. The
+other peer can close the stream at any point, but must not send a close message
+for that stream. On subsequent messages to that stream, it will respond with an
+`UnknownStream` error, forcing the initiator to recreate the
+stream if necessary.
+
+## Sub-protocols
+
+### Replication
+
+The replica is responsible for initiating the replication protocol. This is
+done by opening a stream to a primary. If the destination of the stream is not a
+primary database, it responds with a `StreamError::NotAPrimary` error and immediately closes
+the stream. If the destination database is a primary, it responds to the stream
+open request with a `ReplicationMessage::HandshakeResponse` message. This message informs the
+replica of the current log id, and of the primary's current replication
+index (frame_no).
+
+The replica compares the log id it received from the primary with the one it has, if any. If the
+ids don't match, the replica deletes its state and starts replicating again from the beginning.
+
+After a successful handshake, the replica sends a `ReplicationMessage::Replicate` message with the
+next frame_no it's expecting. For example, if the replica has not replicated any
+frame yet, it sends `ReplicationMessage::Replicate { next_frame_no: 0 }` to
+signify to the primary that it's expecting to be sent frame 0. The primary
+sends the smallest frame with a `frame_no` satisfying `frame_no >=
+next_frame_no`. Because logs can be compacted, the next frame_no the primary
+sends to the replica isn't necessarily the one the replica is expecting. It's correct to send
+the smallest frame >= next_frame_no because frame_nos only move forward in the event of a compaction: a
+frame can only be missing if it was written to more recently, hence _moving
+forward_ in the log. The primary ensures consistency by moving commit points
+accordingly. It is an error for the primary to send a frame_no strictly less
+than the requested frame_no; beyond that, frame_nos can be received in any order.
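+
+As a sketch of this selection rule (using an ordered map as a simplified,
+in-memory stand-in for the real file-backed log; the names here are
+illustrative, not part of the protocol):
+
+```rust
+use std::collections::BTreeMap;
+
+/// Picks the first frame to serve for a `Replicate { next_frame_no }` request.
+/// After a compaction some frame_nos no longer exist, so the primary serves
+/// the smallest surviving frame_no >= next_frame_no.
+fn first_frame_to_send<F>(log: &BTreeMap<u64, F>, next_frame_no: u64) -> Option<(u64, &F)> {
+    log.range(next_frame_no..)
+        .next()
+        .map(|(frame_no, frame)| (*frame_no, frame))
+}
+```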
+
+In the event of a disconnection, it is the replica's duty to re-initiate the replication protocol.
+
+Sending a replicate request twice on the same stream is an error. If a primary
+receives more than a single `Replicate` request, it closes the stream and sends
+a `StreamError::AlreadyReplicating` error. The replica can re-open a stream and start
+replicating again if necessary.
+
+### Proxy
+
+Replicas can proxy queries to their primary. A replica can start sending proxy requests after it has sent a replication request.
+
+To proxy a query, a replica sends a `ProxyRequest`. Proxied queries on the same connection are serialized. The replica sets the connection id
+and the request id for the proxied query. If no connection exists for the
+passed id on the primary, one is created. The query is executed on the primary,
+and the result rows are returned in `ProxyResponse`. The result rows can be split
+into multiple `ProxyResponse` messages, enabling row streaming. A replica can send a `CancelRequest` to interrupt a request. Any
+`ProxyResponse` for that `req_id` can be dropped by the replica, and the
+primary should stop sending any more `ProxyResponse` messages upon receiving the
+cancel request. The primary must roll back a cancelled request.
+
+The primary can reduce the number of concurrently open transactions by closing the
+underlying SQLite connection for proxied connections that are not in an open
+transaction state (`is_autocommit` is true). Subsequent requests on that
+connection id will re-open a connection if necessary.

From da0f74356c7c75d5c6bce7c1f48dbadf2c18a3e9 Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Tue, 27 Jun 2023 19:04:17 +0200
Subject: [PATCH 11/64] implement LINC protocol

---
 Cargo.lock                                    | 1172 +++++++++--------
 libsqlx-server/Cargo.toml                     |    7 +
 libsqlx-server/src/database.rs                |    5 +-
 libsqlx-server/src/linc/bus.rs                |  186 +++
 libsqlx-server/src/linc/connection.rs         |  723 ++++++++++
 libsqlx-server/src/linc/connection_manager.rs |    0
 libsqlx-server/src/linc/connection_pool.rs    |  202 +++
 libsqlx-server/src/linc/mod.rs                |   38 +
 libsqlx-server/src/linc/net.rs                |   81 ++
 libsqlx-server/src/linc/proto.rs              |  214 +++
 libsqlx-server/src/linc/server.rs             |  347 +++++
 libsqlx-server/src/main.rs                    |    1 +
 12 files changed, 2421 insertions(+), 555 deletions(-)
 create mode 100644 libsqlx-server/src/linc/bus.rs
 create mode 100644 libsqlx-server/src/linc/connection.rs
 create mode 100644 libsqlx-server/src/linc/connection_manager.rs
 create mode 100644 libsqlx-server/src/linc/connection_pool.rs
 create mode 100644 libsqlx-server/src/linc/mod.rs
 create mode 100644 libsqlx-server/src/linc/net.rs
 create mode 100644 libsqlx-server/src/linc/proto.rs
 create mode 100644 libsqlx-server/src/linc/server.rs

diff --git a/Cargo.lock b/Cargo.lock
index 9752e54f..1e3d5a3d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4,11 +4,11 @@ version = 3
 
 [[package]]
 name = "addr2line"
-version = "0.17.0"
+version = "0.19.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b"
+checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97"
 dependencies = [
- "gimli 0.26.2",
+ "gimli",
 ]
 
 [[package]]
 name = "addr2line"
 version = "0.20.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3"
 dependencies = [
- "gimli 0.27.3",
+ "gimli",
 ]
 
 [[package]]
 name = "adler"
 version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum
= "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" -[[package]] -name = "ahash" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" -dependencies = [ - "getrandom", - "once_cell", - "version_check", -] - [[package]] name = "ahash" version = "0.8.3" @@ -50,9 +39,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67fc08ce920c31afb70f013dcce1bfc3a3195de6a228474e45e1f145b36f8d04" +checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" dependencies = [ "memchr", ] @@ -72,11 +61,23 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "allocator-api2" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56fc6cf8dc8c4158eed8649f9b8b0ea1518eb62b544fe9490d66fa0b349eafe9" + [[package]] name = "ambient-authority" -version = "0.0.1" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9d4ee0d472d1cd2e28c97dfa124b3d8d992e10eb0a035f33f5d12e3a177ba3b" + +[[package]] +name = "android-tzdata" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec8ad6edb4840b78c5c3d88de606b22252d552b55f3a4699fbb10fc070ec3049" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" [[package]] name = "android_system_properties" @@ -98,21 +99,21 @@ dependencies = [ "anstyle-query", "anstyle-wincon", "colorchoice", - "is-terminal 0.4.7", + "is-terminal", "utf8parse", ] [[package]] name = "anstyle" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d" +checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" [[package]] name = "anstyle-parse" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e765fd216e48e067936442276d1d57399e37bce53c264d6fefbe298080cb57ee" +checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" dependencies = [ "utf8parse", ] @@ -158,10 +159,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" [[package]] -name = "arrayvec" -version = "0.7.2" +name = "async-bincode" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" +checksum = "0688a53af69da2208017b6d68ea675f073fbbc2488e71cc9b40af48ad9404fc2" +dependencies = [ + "bincode", + "byteorder", + "bytes 1.4.0", + "futures-core", + "futures-sink", + "serde", + "tokio", +] [[package]] name = "async-compression" @@ -191,7 +201,7 @@ dependencies = [ "log", "parking", "polling", - "rustix 0.37.19", + "rustix 0.37.23", "slab", "socket2", "waker-fn", @@ -225,18 +235,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] name = "async-trait" -version = "0.1.68" +version = "0.1.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" +checksum = 
"a564d521dd56509c4c47480d00b80ee55f7e385ae48db5744c67ad50c92d2ebf" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -245,17 +255,6 @@ version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba" -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -285,7 +284,7 @@ dependencies = [ "http", "hyper", "ring", - "time 0.3.21", + "time 0.3.23", "tokio", "tower", "tracing", @@ -455,7 +454,7 @@ dependencies = [ "percent-encoding", "regex", "sha2", - "time 0.3.21", + "time 0.3.23", "tracing", ] @@ -595,7 +594,7 @@ dependencies = [ "itoa", "num-integer", "ryu", - "time 0.3.21", + "time 0.3.23", ] [[package]] @@ -730,13 +729,13 @@ dependencies = [ "lazy_static", "lazycell", "peeking_take_while", - "prettyplease 0.2.6", + "prettyplease 0.2.10", "proc-macro2", "quote", "regex", "rustc-hash", "shlex", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -762,9 +761,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.1" +version = "2.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6776fc96284a0bb647b615056fc496d1fe1644a7ab01829818a6d91cae888b84" +checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" [[package]] name = "block-buffer" @@ -862,7 +861,7 @@ checksum = "fdde5c9cd29ebd706ce1b35600920a33550e402fc998a2e53ad3b42c3c47a192" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -913,39 +912,38 @@ dependencies = [ [[package]] name = "cap-fs-ext" -version = "0.26.1" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b0e103ce36d217d568903ad27b14ec2238ecb5d65bad2e756a8f3c0d651506e" +checksum = "58bc48200a1a0fa6fba138b1802ad7def18ec1cdd92f7b2a04e21f1bd887f7b9" dependencies = [ "cap-primitives", "cap-std", - "io-lifetimes 0.7.5", - "windows-sys 0.36.1", + "io-lifetimes 1.0.11", + "windows-sys 0.48.0", ] [[package]] name = "cap-primitives" -version = "0.26.1" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af3f336aa91cce16033ed3c94ac91d98956c49b420e6d6cd0dd7d0e386a57085" +checksum = "a4b6df5b295dca8d56f35560be8c391d59f0420f72e546997154e24e765e6451" dependencies = [ "ambient-authority", "fs-set-times", "io-extras", - "io-lifetimes 0.7.5", + "io-lifetimes 1.0.11", "ipnet", "maybe-owned", - "rustix 0.35.13", - "winapi-util", - "windows-sys 0.36.1", - "winx", + "rustix 0.37.23", + "windows-sys 0.48.0", + "winx 0.35.1", ] [[package]] name = "cap-rand" -version = "0.26.1" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d14b9606aa9550d34651bc481443203bc014237bdb992d201d2afa62d2ec6dea" +checksum = "4d25555efacb0b5244cf1d35833d55d21abc916fff0eaad254b8e2453ea9b8ab" dependencies = [ "ambient-authority", "rand", @@ -953,27 +951,26 @@ dependencies = [ [[package]] name = "cap-std" -version = "0.26.1" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9d6e70b626eceac9d6fc790fe2d72cc3f2f7bc3c35f467690c54a526b0f56db" +checksum = 
"3373a62accd150b4fcba056d4c5f3b552127f0ec86d3c8c102d60b978174a012" dependencies = [ "cap-primitives", "io-extras", - "io-lifetimes 0.7.5", - "ipnet", - "rustix 0.35.13", + "io-lifetimes 1.0.11", + "rustix 0.37.23", ] [[package]] name = "cap-time-ext" -version = "0.26.1" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a0524f7c4cff2ea547ae2b652bf7a348fd3e48f76556dc928d8b45ab2f1d50" +checksum = "e95002993b7baee6b66c8950470e59e5226a23b3af39fc59c47fe416dd39821a" dependencies = [ "cap-primitives", "once_cell", - "rustix 0.35.13", - "winx", + "rustix 0.37.23", + "winx 0.35.1", ] [[package]] @@ -1024,13 +1021,13 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.24" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b" +checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" dependencies = [ + "android-tzdata", "iana-time-zone", "js-sys", - "num-integer", "num-traits", "serde", "time 0.1.45", @@ -1081,7 +1078,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -1158,9 +1155,9 @@ dependencies = [ [[package]] name = "console-subscriber" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57ab2224a0311582eb03adba4caaf18644f7b1f10a760803a803b9b605187fc7" +checksum = "d4cf42660ac07fcebed809cfe561dd8730bcd35b075215e6479c516bcd0d11cb" dependencies = [ "console-api", "crossbeam-channel", @@ -1195,7 +1192,7 @@ dependencies = [ "log", "mime", "paste", - "pin-project 1.1.0", + "pin-project 1.1.2", "serde", "serde_json", "tar", @@ -1231,37 +1228,37 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.7" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1" dependencies = [ "libc", ] [[package]] name = "cranelift-bforest" -version = "0.90.1" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62c772976416112fa4484cbd688cb6fb35fd430005c1c586224fc014018abad" +checksum = "182b82f78049f54d3aee5a19870d356ef754226665a695ce2fcdd5d55379718e" dependencies = [ "cranelift-entity", ] [[package]] name = "cranelift-codegen" -version = "0.90.1" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b40ed2dd13c2ac7e24f88a3090c68ad3414eb1d066a95f8f1f7b3b819cb4e46" +checksum = "e7c027bf04ecae5b048d3554deb888061bc26f426afff47bf06d6ac933dce0a6" dependencies = [ - "arrayvec", "bumpalo", "cranelift-bforest", "cranelift-codegen-meta", "cranelift-codegen-shared", - "cranelift-egraph", + "cranelift-control", "cranelift-entity", "cranelift-isle", - "gimli 0.26.2", + "gimli", + "hashbrown 0.13.2", "log", "regalloc2", "smallvec", @@ -1270,47 +1267,42 @@ dependencies = [ [[package]] name = "cranelift-codegen-meta" -version = "0.90.1" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb927a8f1c27c34ee3759b6b0ffa528d2330405d5cc4511f0cab33fe2279f4b5" +checksum = "649f70038235e4c81dba5680d7e5ae83e1081f567232425ab98b55b03afd9904" dependencies = [ "cranelift-codegen-shared", ] [[package]] name = "cranelift-codegen-shared" -version = "0.90.1" 
+version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43dfa417b884a9ab488d95fd6b93b25e959321fe7bfd7a0a960ba5d7fb7ab927" +checksum = "7a1d1c5ee2611c6a0bdc8d42d5d3dc5ce8bf53a8040561e26e88b9b21f966417" [[package]] -name = "cranelift-egraph" -version = "0.90.1" +name = "cranelift-control" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a66b39785efd8513d2cca967ede56d6cc57c8d7986a595c7c47d0c78de8dce" +checksum = "da66a68b1f48da863d1d53209b8ddb1a6236411d2d72a280ffa8c2f734f7219e" dependencies = [ - "cranelift-entity", - "fxhash", - "hashbrown 0.12.3", - "indexmap 1.9.3", - "log", - "smallvec", + "arbitrary", ] [[package]] name = "cranelift-entity" -version = "0.90.1" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0637ffde963cb5d759bc4d454cfa364b6509e6c74cdaa21298add0ed9276f346" +checksum = "9bd897422dbb66621fa558f4d9209875530c53e3c8f4b13b2849fbb667c431a6" dependencies = [ "serde", ] [[package]] name = "cranelift-frontend" -version = "0.90.1" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb72b8342685e850cb037350418f62cc4fc55d6c2eb9c7ca01b82f9f1a6f3d56" +checksum = "05db883114c98cfcd6959f72278d2fec42e01ea6a6982cfe4f20e88eebe86653" dependencies = [ "cranelift-codegen", "log", @@ -1320,15 +1312,15 @@ dependencies = [ [[package]] name = "cranelift-isle" -version = "0.90.1" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "850579cb9e4b448f7c301f1e6e6cbad99abe3f1f1d878a4994cb66e33c6db8cd" +checksum = "84559de86e2564152c87e299c8b2559f9107e9c6d274b24ebeb04fb0a5f4abf8" [[package]] name = "cranelift-native" -version = "0.90.1" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d0a279e5bcba3e0466c734d8d8eb6bfc1ad29e95c37f3e4955b492b5616335e" +checksum = "3f40b57f187f0fe1ffaf281df4adba2b4bc623a0f6651954da9f3c184be72761" dependencies = [ "cranelift-codegen", "libc", @@ -1337,9 +1329,9 @@ dependencies = [ [[package]] name = "cranelift-wasm" -version = "0.90.1" +version = "0.96.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b8c5e7ffb754093fb89ec4bd4f9dbb9f1c955427299e334917d284745835c2" +checksum = "f3eab6084cc789b9dd0b1316241efeb2968199fee709f4bb4fe0fb0923bb468b" dependencies = [ "cranelift-codegen", "cranelift-entity", @@ -1421,14 +1413,14 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.14" +version = "0.9.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" dependencies = [ "autocfg", "cfg-if", "crossbeam-utils", - "memoffset 0.8.0", + "memoffset 0.9.0", "scopeguard", ] @@ -1444,9 +1436,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.15" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" dependencies = [ "cfg-if", ] @@ -1467,15 +1459,24 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" +[[package]] +name = "debugid" +version = "0.8.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" +dependencies = [ + "uuid", +] + [[package]] name = "derive_arbitrary" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cdeb9ec472d588e539a818b2dee436825730da08ad0017c4b1a17676bdc8b7" +checksum = "53e0efad4403bfc52dc201159c4b842a246a14b98c64b55dfd0f2d89729dfeb8" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.25", ] [[package]] @@ -1580,7 +1581,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0" dependencies = [ "humantime", - "is-terminal 0.4.7", + "is-terminal", "log", "regex", "termcolor", @@ -1592,17 +1593,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88bffebc5d80432c9b140ee17875ff173a8ab62faad5b257da912bd2f6c1c0a1" -[[package]] -name = "errno" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" -dependencies = [ - "errno-dragonfly", - "libc", - "winapi", -] - [[package]] name = "errno" version = "0.3.1" @@ -1676,6 +1666,17 @@ dependencies = [ "instant", ] +[[package]] +name = "fd-lock" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b0377f1edc77dbd1118507bc7a66e4ab64d2b90c66f90726dc801e73a8c68f9" +dependencies = [ + "cfg-if", + "rustix 0.38.3", + "windows-sys 0.48.0", +] + [[package]] name = "file-per-thread-logger" version = "0.1.6" @@ -1722,7 +1723,7 @@ checksum = "2cd66269887534af4b0c3e3337404591daa8dc8b9b2b3db71f9523beb4bafb41" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -1748,22 +1749,22 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" +checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" dependencies = [ "percent-encoding", ] [[package]] name = "fs-set-times" -version = "0.17.1" +version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a267b6a9304912e018610d53fe07115d8b530b160e85db4d2d3a59f3ddde1aec" +checksum = "6d167b646a876ba8fda6b50ac645cfd96242553cbaf0ca4fccaa39afcbf0801f" dependencies = [ - "io-lifetimes 0.7.5", - "rustix 0.35.13", - "windows-sys 0.36.1", + "io-lifetimes 1.0.11", + "rustix 0.38.3", + "windows-sys 0.48.0", ] [[package]] @@ -1847,7 +1848,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -1901,6 +1902,19 @@ dependencies = [ "byteorder", ] +[[package]] +name = "fxprof-processed-profile" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27d12c0aed7f1e24276a241aadc4cb8ea9f83000f34bc062b7cc2d51e3b0fabd" +dependencies = [ + "bitflags 2.3.3", + "debugid", + "fxhash", + "serde", + "serde_json", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1913,43 +1927,26 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.10" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", "libc", "wasi 0.11.0+wasi-snapshot-preview1", ] -[[package]] -name = "ghost" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e77ac7b51b8e6313251737fcef4b1c01a2ea102bde68415b62c0ee9268fec357" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.23", -] - [[package]] name = "gimli" -version = "0.26.2" +version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22030e2c5a68ec659fde1e949a745124b48e6fa8b045b7ed5bd1fe4ccc5c4e5d" +checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" dependencies = [ "fallible-iterator 0.2.0", "indexmap 1.9.3", "stable_deref_trait", ] -[[package]] -name = "gimli" -version = "0.27.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" - [[package]] name = "glob" version = "0.3.1" @@ -1958,9 +1955,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d357c7ae988e7d2182f7d7871d0b963962420b0678b0997ce7de72001aeab782" +checksum = "97ec8491ebaf99c8eaa73058b045fe58073cd6be7f596ac993ced0b0a0c01049" dependencies = [ "bytes 1.4.0", "fnv", @@ -1980,9 +1977,6 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -dependencies = [ - "ahash 0.7.6", -] [[package]] name = "hashbrown" @@ -1990,7 +1984,7 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" dependencies = [ - "ahash 0.8.3", + "ahash", ] [[package]] @@ -1998,14 +1992,18 @@ name = "hashbrown" version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "hashlink" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa" +checksum = "312f66718a2d7789ffef4f4b7b213138ed9f1eb3aa1d0d82fc99f88fb3ffd26f" dependencies = [ - "hashbrown 0.13.2", + "hashbrown 0.14.0", ] [[package]] @@ -2029,27 +2027,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - -[[package]] -name = "hermit-abi" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" -dependencies = [ - "libc", -] - -[[package]] -name = "hermit-abi" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +checksum = 
"443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" [[package]] name = "hex" @@ -2153,15 +2133,16 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7" +checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97" dependencies = [ + "futures-util", "http", "hyper", - "rustls 0.21.1", + "rustls 0.21.3", "tokio", - "tokio-rustls 0.24.0", + "tokio-rustls 0.24.1", ] [[package]] @@ -2196,7 +2177,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "226df6fd0aece319a325419d770aa9d947defa60463f142cd82b329121f906a3" dependencies = [ "hyper", - "pin-project 1.1.0", + "pin-project 1.1.2", "tokio", "tokio-tungstenite", "tungstenite", @@ -2211,15 +2192,15 @@ dependencies = [ "futures-util", "hex", "hyper", - "pin-project 1.1.0", + "pin-project 1.1.2", "tokio", ] [[package]] name = "iana-time-zone" -version = "0.1.56" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c" +checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -2238,11 +2219,17 @@ dependencies = [ "cc", ] +[[package]] +name = "id-arena" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" + [[package]] name = "idna" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" dependencies = [ "unicode-bidi", "unicode-normalization", @@ -2277,9 +2264,9 @@ dependencies = [ [[package]] name = "insta" -version = "1.29.0" +version = "1.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a28d25139df397cbca21408bb742cf6837e04cdbebf1b07b760caf971d6a972" +checksum = "28491f7753051e5704d4d0ae7860d45fae3238d7d235bc4289dcd45c48d3cec3" dependencies = [ "console", "lazy_static", @@ -2300,31 +2287,18 @@ dependencies = [ [[package]] name = "inventory" -version = "0.3.6" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0539b5de9241582ce6bd6b0ba7399313560151e58c9aaf8b74b711b1bdce644" -dependencies = [ - "ghost", -] +checksum = "c38a87a1e0e2752433cd4b26019a469112a25fb43b30f5ee9b3b898925c5a0f9" [[package]] name = "io-extras" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5d8c2ab5becd8720e30fd25f8fa5500d8dc3fceadd8378f05859bd7b46fc49" -dependencies = [ - "io-lifetimes 0.7.5", - "windows-sys 0.36.1", -] - -[[package]] -name = "io-lifetimes" -version = "0.7.5" +version = "0.17.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ce5ef949d49ee85593fc4d3f3f95ad61657076395cbbce23e2121fc5542074" +checksum = "fde93d48f0d9277f977a333eca8313695ddd5301dc96f7e02aeddcb0dd99096f" dependencies = [ - "libc", - "windows-sys 0.42.0", + "io-lifetimes 1.0.11", + "windows-sys 0.48.0", ] [[package]] @@ -2333,38 +2307,31 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", "windows-sys 0.48.0", ] [[package]] -name = "ipnet" -version = "2.7.2" +name = "io-lifetimes" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" +checksum = "bffb4def18c48926ccac55c1223e02865ce1a821751a95920448662696e7472c" [[package]] -name = "is-terminal" -version = "0.3.0" +name = "ipnet" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d508111813f9af3afd2f92758f77e4ed2cc9371b642112c6a48d22eb73105c5" -dependencies = [ - "hermit-abi 0.2.6", - "io-lifetimes 0.7.5", - "rustix 0.35.13", - "windows-sys 0.36.1", -] +checksum = "28b29a3cd74f0f4598934efe3aeba42bae0eb4680554128851ebbecb02af14e6" [[package]] name = "is-terminal" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" +checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ - "hermit-abi 0.3.1", - "io-lifetimes 1.0.11", - "rustix 0.37.19", + "hermit-abi", + "rustix 0.38.3", "windows-sys 0.48.0", ] @@ -2388,9 +2355,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" [[package]] name = "ittapi" @@ -2423,9 +2390,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.63" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f37a4a5928311ac501dee68b3c7613a1037d0edb30c8e5427bd832d55d1b790" +checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" dependencies = [ "wasm-bindgen", ] @@ -2511,9 +2478,9 @@ dependencies = [ [[package]] name = "libsql-wasmtime-bindings" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcb56f5849df5085e99b7e1bea2e87ff3f93c4143d0922ab43682f904d9cbf59" +checksum = "5c4794ff21e37f83839dad45f8c7977b071315f18705cf73badc9850b9fb6b6f" dependencies = [ "wasmtime", "wasmtime-wasi", @@ -2567,6 +2534,7 @@ dependencies = [ name = "libsqlx-server" version = "0.1.0" dependencies = [ + "async-bincode", "axum", "base64 0.21.2", "bincode", @@ -2576,6 +2544,7 @@ dependencies = [ "futures", "hmac", "hyper", + "itertools 0.11.0", "libsqlx", "moka", "parking_lot 0.12.1", @@ -2588,8 +2557,11 @@ dependencies = [ "sled", "thiserror", "tokio", + "tokio-stream", + "tokio-util", "tracing", "tracing-subscriber", + "turmoil", "uuid", ] @@ -2601,21 +2573,21 @@ checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" [[package]] name = "linux-raw-sys" -version = "0.0.46" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d2456c373231a208ad294c33dc5bff30051eafd954cd4caae83a712b12854d" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.3.8" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" +checksum = 
"09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" [[package]] name = "lock_api" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "435011366fe56583b16cf956f9df0095b405b82d76425bc8981c0e22e60ec4df" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" dependencies = [ "autocfg", "scopeguard", @@ -2623,12 +2595,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.17" +version = "0.4.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] +checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" [[package]] name = "mach" @@ -2696,7 +2665,7 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffc89ccdc6e10d6907450f753537ebc5c5d3460d2e4e62ea74bd571db62c0f9e" dependencies = [ - "rustix 0.37.19", + "rustix 0.37.23", ] [[package]] @@ -2711,27 +2680,27 @@ dependencies = [ [[package]] name = "memoffset" -version = "0.6.5" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" dependencies = [ "autocfg", ] [[package]] name = "memoffset" -version = "0.7.1" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" dependencies = [ "autocfg", ] [[package]] name = "memoffset" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" dependencies = [ "autocfg", ] @@ -2768,14 +2737,13 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.6" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" dependencies = [ "libc", - "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -2894,11 +2862,20 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "num_threads" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" dependencies = [ - "hermit-abi 0.2.6", "libc", ] @@ -2910,12 +2887,12 @@ checksum = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef" [[package]] name = "object" -version = "0.29.0" +version = "0.30.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21158b2c33aa6d4561f1c0a6ea283ca92bc54802a93b263e910746d679a7eb53" +checksum = "03b4680b86d9cfafba8fc491dc9b6df26b68cf40e9e6cd73909194759a63c385" dependencies = [ 
"crc32fast", - "hashbrown 0.12.3", + "hashbrown 0.13.2", "indexmap 1.9.3", "memchr", ] @@ -2966,9 +2943,9 @@ checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "openssl" -version = "0.10.52" +version = "0.10.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01b8574602df80f7b85fdfc5392fa884a4e3b3f4f35402c070ab34c3d3f78d56" +checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" dependencies = [ "bitflags 1.3.2", "cfg-if", @@ -2987,7 +2964,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -2998,9 +2975,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.87" +version = "0.9.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e17f59264b2809d77ae94f0e1ebabc434773f370d6ca667bd223ea10e06cc7e" +checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" dependencies = [ "cc", "libc", @@ -3050,7 +3027,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.7", + "parking_lot_core 0.9.8", ] [[package]] @@ -3069,22 +3046,22 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.7" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.2.16", + "redox_syscall 0.3.5", "smallvec", - "windows-sys 0.45.0", + "windows-targets 0.48.1", ] [[package]] name = "paste" -version = "1.0.12" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +checksum = "b4b27ab7be369122c218afc2079489cdcb4b517c0a3fc386ff11e1fedfcc2b35" [[package]] name = "peeking_take_while" @@ -3103,9 +3080,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" [[package]] name = "petgraph" @@ -3119,18 +3096,18 @@ dependencies = [ [[package]] name = "phf" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "928c6535de93548188ef63bb7c4036bd415cd8f36ad25af44b9789b2ee72a48c" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56ac890c5e3ca598bbdeaa99964edb5b0258a583a9eb6ef4e89fc85d9224770" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" dependencies = [ "phf_generator", "phf_shared", @@ -3138,9 +3115,9 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"b1181c94580fa345f50f19d738aaa39c0ed30a600d95cb2d3e23f94266f14fbf" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" dependencies = [ "phf_shared", "rand", @@ -3148,9 +3125,9 @@ dependencies = [ [[package]] name = "phf_shared" -version = "0.11.1" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fb5f6f826b772a8d4c0394209441e7d37cbbb967ae9c7e0e8134365c9ee676" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" dependencies = [ "siphasher", "uncased", @@ -3167,11 +3144,11 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.0" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c95a7476719eab1e366eaf73d0260af3021184f18177925b07f54b30089ceead" +checksum = "030ad2bc4db10a8944cb0d837f158bdfec4d4a4873ab701a95046770d11f8842" dependencies = [ - "pin-project-internal 1.1.0", + "pin-project-internal 1.1.2", ] [[package]] @@ -3187,20 +3164,20 @@ dependencies = [ [[package]] name = "pin-project-internal" -version = "1.1.0" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39407670928234ebc5e6e580247dd567ad73a3578460c5990f9503df207e8f07" +checksum = "ec2e072ecce94ec471b13398d5402c188e76ac03cf74dd1a975161b23a3f6d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] name = "pin-project-lite" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +checksum = "4c40d25201921e5ff0c862a505c6557ea88568a4e3ace775ab55e93f2f4f9d57" [[package]] name = "pin-utils" @@ -3284,12 +3261,12 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.6" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b69d39aab54d069e7f2fe8cb970493e7834601ca2d8c65fd7bbd183578080d1" +checksum = "92139198957b410250d43fad93e630d956499a625c527eda65175c8680f83387" dependencies = [ "proc-macro2", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -3304,9 +3281,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b368fba921b0dce7e60f5e04ec15e565b3303972b42bcfde1d0713b881959eb" +checksum = "78803b62cbf1f46fde80d7c0e803111524b9877184cfe7c3033659490ac7a7da" dependencies = [ "unicode-ident", ] @@ -3394,6 +3371,17 @@ dependencies = [ "cc", ] +[[package]] +name = "pulldown-cmark" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffade02495f22453cd593159ea2f59827aae7f53fa8323f756799b670881dcf8" +dependencies = [ + "bitflags 1.3.2", + "memchr", + "unicase", +] + [[package]] name = "pulldown-cmark" version = "0.9.3" @@ -3466,6 +3454,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "rand_xorshift" version = "0.3.0" @@ -3546,12 +3544,13 @@ dependencies = [ [[package]] name = "regalloc2" -version = "0.4.2" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91b2eab54204ea0117fe9a060537e0b07a4e72f7c7d182361ecc346cab2240e5" +checksum = 
"d4a52e724646c6c0800fc456ec43b4165d2f91fba88ceaca06d9e0b400023478" dependencies = [ - "fxhash", + "hashbrown 0.13.2", "log", + "rustc-hash", "slice-group-by", "smallvec", ] @@ -3615,7 +3614,7 @@ dependencies = [ "http", "http-body", "hyper", - "hyper-rustls 0.24.0", + "hyper-rustls 0.24.1", "hyper-tls", "ipnet", "js-sys", @@ -3625,14 +3624,14 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.21.1", + "rustls 0.21.3", "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "tokio", "tokio-native-tls", - "tokio-rustls 0.24.0", + "tokio-rustls 0.24.1", "tower-service", "url", "wasm-bindgen", @@ -3662,7 +3661,7 @@ name = "rusqlite" version = "0.29.0" source = "git+https://github.com/psarna/rusqlite?rev=d9a97c0f25#d9a97c0f25d48272c91d3f8d93d46cb405c39037" dependencies = [ - "bitflags 2.3.1", + "bitflags 2.3.3", "fallible-iterator 0.2.0", "fallible-streaming-iterator", "hashlink", @@ -3693,31 +3692,30 @@ dependencies = [ [[package]] name = "rustix" -version = "0.35.13" +version = "0.37.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "727a1a6d65f786ec22df8a81ca3121107f235970dc1705ed681d3e6e8b9cd5f9" +checksum = "4d69718bf81c6127a49dc64e44a742e8bb9213c0ff8869a22c308f84c1d4ab06" dependencies = [ "bitflags 1.3.2", - "errno 0.2.8", - "io-lifetimes 0.7.5", + "errno", + "io-lifetimes 1.0.11", "itoa", "libc", - "linux-raw-sys 0.0.46", + "linux-raw-sys 0.3.8", "once_cell", - "windows-sys 0.42.0", + "windows-sys 0.48.0", ] [[package]] name = "rustix" -version = "0.37.19" +version = "0.38.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" +checksum = "ac5ffa1efe7548069688cd7028f32591853cd7b5b756d41bcffd2353e4fc75b4" dependencies = [ - "bitflags 1.3.2", - "errno 0.3.1", - "io-lifetimes 1.0.11", + "bitflags 2.3.3", + "errno", "libc", - "linux-raw-sys 0.3.8", + "linux-raw-sys 0.4.3", "windows-sys 0.48.0", ] @@ -3735,9 +3733,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.1" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c911ba11bc8433e811ce56fde130ccf32f5127cab0e0194e9c68c5a5b671791e" +checksum = "b19faa85ecb5197342b54f987b142fb3e30d0c90da40f80ef4fa9a726e6676ed" dependencies = [ "log", "ring", @@ -3747,9 +3745,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" dependencies = [ "openssl-probe", "rustls-pemfile", @@ -3759,18 +3757,18 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ "base64 0.21.2", ] [[package]] name = "rustls-webpki" -version = "0.100.1" +version = "0.101.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6207cd5ed3d8dca7816f8f3725513a34609c0c765bf652b8c3cb4cfd87db46b" +checksum = "15f36a6828982f422756984e47912a7a51dcbc2a197aa791158f8ca61cd8204e" dependencies = [ "ring", "untrusted", @@ -3778,9 +3776,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.12" 
+version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06" +checksum = "dc31bd9b61a32c31f9650d18add92aa83a49ba979c143eefd27fe7177b05bd5f" [[package]] name = "rusty-fork" @@ -3796,9 +3794,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +checksum = "fe232bdf6be8c8de797b22184ee71118d63780ea42ac85b61d1baa6d3b782ae9" [[package]] name = "same-file" @@ -3811,11 +3809,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.21" +version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys 0.42.0", + "windows-sys 0.48.0", ] [[package]] @@ -3827,6 +3825,12 @@ dependencies = [ "parking_lot 0.12.1", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.1.0" @@ -3888,22 +3892,22 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.166" +version = "1.0.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d01b7404f9d441d3ad40e6a636a7782c377d2abdbe4fa2440e2edcc2f4f10db8" +checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.166" +version = "1.0.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dd83d6dde2b6b2d466e14d9d1acce8816dedee94f735eac6395808b3483c6d6" +checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -3920,10 +3924,11 @@ dependencies = [ [[package]] name = "serde_path_to_error" -version = "0.1.11" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7f05c1d5476066defcdfacce1f52fc3cae3af1d3089727100c02ae92e5abbe0" +checksum = "8acc4422959dd87a76cb117c191dcbffc20467f06c9100b76721dab370f24d3a" dependencies = [ + "itoa", "serde", ] @@ -3963,9 +3968,9 @@ dependencies = [ [[package]] name = "sha256" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f9f8b5de2bac3a4ae28e9b611072a8e326d9b26c8189c0972d4c321fa684f1f" +checksum = "08a975c1bc0941703000eaf232c4d8ce188d8d5408d6344b6b2c8c6262772828" dependencies = [ "hex", "sha2", @@ -4019,7 +4024,7 @@ dependencies = [ "num-bigint", "num-traits", "thiserror", - "time 0.3.21", + "time 0.3.23", ] [[package]] @@ -4038,7 +4043,7 @@ dependencies = [ "cargo_metadata", "error-chain", "glob", - "pulldown-cmark", + "pulldown-cmark 0.9.3", "tempfile", "walkdir", ] @@ -4076,9 +4081,9 @@ checksum = "826167069c09b99d56f31e9ae5c99049e932a98c9dc2dac47645b08dbbf76ba7" [[package]] name = "smallvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" [[package]] name = 
"socket2" @@ -4119,6 +4124,7 @@ version = "0.15.0" dependencies = [ "anyhow", "arbitrary", + "async-bincode", "async-lock", "async-trait", "aws-config", @@ -4169,12 +4175,14 @@ dependencies = [ "tokio", "tokio-stream", "tokio-tungstenite", + "tokio-util", "tonic 0.8.3", "tonic-build", "tower", "tower-http", "tracing", "tracing-subscriber", + "turmoil", "url", "uuid", "vergen", @@ -4196,7 +4204,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3995a6daa13c113217b6ad22154865fb06f9cb939bef398fd04f4a7aaaf5bd7" dependencies = [ - "bitflags 2.3.1", + "bitflags 2.3.3", "cc", "fallible-iterator 0.2.0", "indexmap 1.9.3", @@ -4215,7 +4223,7 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db68d3f0682b50197a408d65a3246b7d6173399d1325cf0208fb3fdb66e3229f" dependencies = [ - "bitflags 2.3.1", + "bitflags 2.3.3", "cc", "fallible-iterator 0.3.0", "indexmap 1.9.3", @@ -4265,9 +4273,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.23" +version = "2.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59fb7d6d8281a51045d62b8eb3a7d1ce347b76f312af50cd3dc0af39c87c1737" +checksum = "15e3fc8c0c74267e2df136e5e5fb656a464158aa57624053375eb9c8c6e25ae2" dependencies = [ "proc-macro2", "quote", @@ -4282,18 +4290,18 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "system-interface" -version = "0.23.0" +version = "0.25.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92adbaf536f5aff6986e1e62ba36cee72b1718c5153eee08b9e728ddde3f6029" +checksum = "10081a99cbecbc363d381b9503563785f0b02735fccbb0d4c1a2cb3d39f7e7fe" dependencies = [ - "atty", - "bitflags 1.3.2", + "bitflags 2.3.3", "cap-fs-ext", "cap-std", - "io-lifetimes 0.7.5", - "rustix 0.35.13", - "windows-sys 0.36.1", - "winx", + "fd-lock", + "io-lifetimes 2.0.2", + "rustix 0.38.3", + "windows-sys 0.48.0", + "winx 0.36.1", ] [[package]] @@ -4315,9 +4323,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.7" +version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1ba337640d60c3e96bc6f0638a939b9c9a7f2c316a1598c279828b3d1dc8c5" +checksum = "1b1c7f239eb94671427157bd93b3694320f3668d4e1eff08c7285366fd777fac" [[package]] name = "tempfile" @@ -4329,7 +4337,7 @@ dependencies = [ "cfg-if", "fastrand", "redox_syscall 0.3.5", - "rustix 0.37.19", + "rustix 0.37.23", "windows-sys 0.48.0", ] @@ -4371,7 +4379,7 @@ checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -4397,11 +4405,13 @@ dependencies = [ [[package]] name = "time" -version = "0.3.21" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3403384eaacbca9923fa06940178ac13e4edb725486d70e8e15881d0c836cc" +checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446" dependencies = [ "itoa", + "libc", + "num_threads", "serde", "time-core", "time-macros", @@ -4415,9 +4425,9 @@ checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" [[package]] name = "time-macros" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" +checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4" dependencies = [ 
"time-core", ] @@ -4476,7 +4486,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -4502,11 +4512,11 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls 0.21.1", + "rustls 0.21.3", "tokio", ] @@ -4521,6 +4531,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3" +dependencies = [ + "async-stream", + "bytes 1.4.0", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-tungstenite" version = "0.19.0" @@ -4575,7 +4598,7 @@ dependencies = [ "hyper", "hyper-timeout", "percent-encoding", - "pin-project 1.1.0", + "pin-project 1.1.2", "prost", "prost-derive", "rustls-pemfile", @@ -4608,7 +4631,7 @@ dependencies = [ "hyper", "hyper-timeout", "percent-encoding", - "pin-project 1.1.0", + "pin-project 1.1.2", "prost", "tokio", "tokio-stream", @@ -4640,7 +4663,7 @@ dependencies = [ "futures-core", "futures-util", "indexmap 1.9.3", - "pin-project 1.1.0", + "pin-project 1.1.2", "pin-project-lite", "rand", "slab", @@ -4700,13 +4723,13 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74" +checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", ] [[package]] @@ -4735,7 +4758,7 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" dependencies = [ - "pin-project 1.1.0", + "pin-project 1.1.2", "tracing", ] @@ -4799,6 +4822,26 @@ dependencies = [ "utf-8", ] +[[package]] +name = "turmoil" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e72ab712288bd737d0abc60712e031fb48488bbce7810ac2135da067e916469" +dependencies = [ + "bytes 1.4.0", + "futures", + "indexmap 1.9.3", + "rand", + "rand_distr", + "scoped-tls", + "tokio", + "tokio-stream", + "tokio-test", + "tokio-util", + "tracing", + "tracing-subscriber", +] + [[package]] name = "typenum" version = "1.16.0" @@ -4837,9 +4880,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" +checksum = "22049a19f4a68748a168c0fc439f9516686aa045927ff767eca0a85101fb6e73" [[package]] name = "unicode-normalization" @@ -4856,6 +4899,12 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + 
[[package]] name = "untrusted" version = "0.7.1" @@ -4870,9 +4919,9 @@ checksum = "2fbfe96089af082b3c856f83bdd0b6866241377d9dbea803fb39481151e5742d" [[package]] name = "url" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" +checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" dependencies = [ "form_urlencoded", "idna", @@ -4922,13 +4971,13 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vergen" -version = "8.2.1" +version = "8.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b3c89c2c7e50f33e4d35527e5bf9c11d6d132226dbbd1753f0fbe9f19ef88c6" +checksum = "bbc5ad0d9d26b2c49a5ab7da76c3e79d3ee37e7821799f8223fcb8f2f391a2e7" dependencies = [ "anyhow", "rustversion", - "time 0.3.21", + "time 0.3.23", ] [[package]] @@ -4970,11 +5019,10 @@ dependencies = [ [[package]] name = "want" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" dependencies = [ - "log", "try-lock", ] @@ -4992,9 +5040,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasi-cap-std-sync" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecbeebb8985a5423f36f976b2f4a0b3c6ce38d7d9a7247e1ce07aa2880e4f29b" +checksum = "5d29c5da3b5cfc9212a7fa824224875cb67fb89d2a8392db655e4c59b8ab2ae7" dependencies = [ "anyhow", "async-trait", @@ -5004,40 +5052,41 @@ dependencies = [ "cap-time-ext", "fs-set-times", "io-extras", - "io-lifetimes 0.7.5", - "is-terminal 0.3.0", + "io-lifetimes 1.0.11", + "is-terminal", "once_cell", - "rustix 0.35.13", + "rustix 0.37.23", "system-interface", "tracing", "wasi-common", - "windows-sys 0.36.1", + "windows-sys 0.48.0", ] [[package]] name = "wasi-common" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81e2171f3783fe6600ee24ff6c58ca1b329c55e458cc1622ecc1fd0427648607" +checksum = "f8bd905dcec1448664bf63d42d291cbae0feeea3ad41631817b8819e096d76bd" dependencies = [ "anyhow", "bitflags 1.3.2", "cap-rand", "cap-std", "io-extras", - "rustix 0.35.13", + "log", + "rustix 0.37.23", "thiserror", "tracing", "wasmtime", "wiggle", - "windows-sys 0.36.1", + "windows-sys 0.48.0", ] [[package]] name = "wasm-bindgen" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bba0e8cb82ba49ff4e229459ff22a191bbe9a1cb3a341610c9c33efc27ddf73" +checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -5045,24 +5094,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b04bc93f9d6bdee709f6bd2118f57dd6679cf1176a1af464fca3ab0d66d8fb" +checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.36" +version = "0.4.37" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d1985d03709c53167ce907ff394f5316aa22cb4e12761295c5dc57dacb6297e" +checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03" dependencies = [ "cfg-if", "js-sys", @@ -5072,9 +5121,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14d6b024f1a526bb0234f52840389927257beb670610081360e5a03c5df9c258" +checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5082,134 +5131,178 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" +checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.23", + "syn 2.0.25", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.86" +version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed9d5b4305409d1fc9482fee2d7f9bcbf24b3972bf59817ef757e23982242a93" +checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" [[package]] name = "wasm-encoder" -version = "0.28.0" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c94f464d50e31da425794a02da1a82d4b96a657dcb152a6664e8aa915be517" +checksum = "18c41dbd92eaebf3612a39be316540b8377c871cb9bde6b064af962984912881" dependencies = [ "leb128", ] [[package]] name = "wasmparser" -version = "0.93.0" +version = "0.103.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5a4460aa3e271fa180b6a5d003e728f3963fb30e3ba0fa7c9634caa06049328" +checksum = "2c437373cac5ea84f1113d648d51f71751ffbe3d90c00ae67618cf20d0b5ee7b" dependencies = [ "indexmap 1.9.3", + "url", ] [[package]] name = "wasmtime" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d18265705b1c49218776577d9f301d79ab06888c7f4a32e2ed24e68a55738ce7" +checksum = "634357e8668774b24c80b210552f3f194e2342a065d6d83845ba22c5817d0770" dependencies = [ "anyhow", "async-trait", "bincode", + "bumpalo", "cfg-if", + "fxprof-processed-profile", "indexmap 1.9.3", "libc", "log", - "object 0.29.0", + "object 0.30.4", "once_cell", "paste", "psm", "rayon", "serde", + "serde_json", "target-lexicon", "wasmparser", "wasmtime-cache", + "wasmtime-component-macro", "wasmtime-cranelift", "wasmtime-environ", "wasmtime-fiber", "wasmtime-jit", "wasmtime-runtime", "wat", - "windows-sys 0.36.1", + "windows-sys 0.48.0", ] [[package]] name = "wasmtime-asm-macros" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a201583f6c79b96e74dcce748fa44fb2958f474ef13c93f880ea4d3bed31ae4f" +checksum = "d33c73c24ce79b0483a3b091a9acf88871f4490b88998e8974b22236264d304c" dependencies = [ "cfg-if", ] [[package]] name = "wasmtime-cache" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f37efc6945b08fcb634cffafc438dd299bac55a27c836954656c634d3e63c31" +checksum = "6107809b2d9f5b2fd3ddbaddb3bb92ff8048b62f4030debf1408119ffd38c6cb" dependencies = [ "anyhow", - "base64 0.13.1", + "base64 0.21.2", 
"bincode", "directories-next", "file-per-thread-logger", "log", - "rustix 0.35.13", + "rustix 0.37.23", "serde", "sha2", "toml", - "windows-sys 0.36.1", + "windows-sys 0.48.0", "zstd", ] +[[package]] +name = "wasmtime-component-macro" +version = "9.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ba489850d9c91c6c5b9e1696ee89e7a69d9796236a005f7e9131b6746e13b6" +dependencies = [ + "anyhow", + "proc-macro2", + "quote", + "syn 1.0.109", + "wasmtime-component-util", + "wasmtime-wit-bindgen", + "wit-parser", +] + +[[package]] +name = "wasmtime-component-util" +version = "9.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fa88f9e77d80f828c9d684741a9da649366c6d1cceb814755dd9cab7112d1d1" + [[package]] name = "wasmtime-cranelift" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe208297e045ea0ee6702be88772ea40f918d55fbd4163981a4699aff034b634" +checksum = "5800616a28ed6bd5e8b99ea45646c956d798ae030494ac0689bc3e45d3b689c1" dependencies = [ "anyhow", "cranelift-codegen", + "cranelift-control", "cranelift-entity", "cranelift-frontend", "cranelift-native", "cranelift-wasm", - "gimli 0.26.2", + "gimli", "log", - "object 0.29.0", + "object 0.30.4", "target-lexicon", "thiserror", "wasmparser", + "wasmtime-cranelift-shared", + "wasmtime-environ", +] + +[[package]] +name = "wasmtime-cranelift-shared" +version = "9.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27e4030b959ac5c5d6ee500078977e813f8768fa2b92fc12be01856cd0c76c55" +dependencies = [ + "anyhow", + "cranelift-codegen", + "cranelift-control", + "cranelift-native", + "gimli", + "object 0.30.4", + "target-lexicon", "wasmtime-environ", ] [[package]] name = "wasmtime-environ" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "754b97f7441ac780a7fa738db5b9c23c1b70ef4abccd8ad205ada5669d196ba2" +checksum = "9ec815d01a8d38aceb7ed4678f9ba551ae6b8a568a63810ac3ad9293b0fd01c8" dependencies = [ "anyhow", "cranelift-entity", - "gimli 0.26.2", + "gimli", "indexmap 1.9.3", "log", - "object 0.29.0", + "object 0.30.4", "serde", "target-lexicon", "thiserror", @@ -5219,70 +5312,69 @@ dependencies = [ [[package]] name = "wasmtime-fiber" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5f54abc960b4a055ba16b942cbbd1da641e0ad44cc97a7608f3d43c069b120e" +checksum = "23c5127908fdf720614891ec741c13dd70c844e102caa393e2faca1ee68e9bfb" dependencies = [ "cc", "cfg-if", - "rustix 0.35.13", + "rustix 0.37.23", "wasmtime-asm-macros", - "windows-sys 0.36.1", + "windows-sys 0.48.0", ] [[package]] name = "wasmtime-jit" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32800cb6e29faabab7056593f70a4c00c65c75c365aaf05406933f2169d0c22f" +checksum = "2712eafe829778b426cad0e1769fef944898923dd29f0039e34e0d53ba72b234" dependencies = [ - "addr2line 0.17.0", + "addr2line 0.19.0", "anyhow", "bincode", "cfg-if", "cpp_demangle", - "gimli 0.26.2", + "gimli", "ittapi", "log", - "object 0.29.0", + "object 0.30.4", "rustc-demangle", "serde", "target-lexicon", - "thiserror", "wasmtime-environ", "wasmtime-jit-debug", "wasmtime-jit-icache-coherence", "wasmtime-runtime", - "windows-sys 0.36.1", + "windows-sys 0.48.0", ] [[package]] name = "wasmtime-jit-debug" -version = "3.0.1" +version = "9.0.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe057012a0ba6cee3685af1e923d6e0a6cb9baf15fb3ffa4be3d7f712c7dec42" +checksum = "65fb78eacf4a6e47260d8ef8cc81ea8ddb91397b2e848b3fb01567adebfe89b5" dependencies = [ - "object 0.29.0", + "object 0.30.4", "once_cell", - "rustix 0.35.13", + "rustix 0.37.23", ] [[package]] name = "wasmtime-jit-icache-coherence" -version = "2.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6bbabb309c06cc238ee91b1455b748c45f0bdcab0dda2c2db85b0a1e69fcb66" +checksum = "d1364900b05f7d6008516121e8e62767ddb3e176bdf4c84dfa85da1734aeab79" dependencies = [ "cfg-if", "libc", - "windows-sys 0.36.1", + "windows-sys 0.48.0", ] [[package]] name = "wasmtime-runtime" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a23b6e138e89594c0189162e524a29e217aec8f9a4e1959a34f74c64e8d17d" +checksum = "4a16ffe4de9ac9669175c0ea5c6c51ffc596dfb49320aaa6f6c57eff58cef069" dependencies = [ "anyhow", "cc", @@ -5292,23 +5384,22 @@ dependencies = [ "log", "mach", "memfd", - "memoffset 0.6.5", + "memoffset 0.8.0", "paste", "rand", - "rustix 0.35.13", - "thiserror", + "rustix 0.37.23", "wasmtime-asm-macros", "wasmtime-environ", "wasmtime-fiber", "wasmtime-jit-debug", - "windows-sys 0.36.1", + "windows-sys 0.48.0", ] [[package]] name = "wasmtime-types" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68ec7615fde8c79737f1345d81f0b18da83b3db929a87b4604f27c932246d1e2" +checksum = "19961c9a3b04d5e766875a5c467f6f5d693f508b3e81f8dc4a1444aa94f041c9" dependencies = [ "cranelift-entity", "serde", @@ -5318,17 +5409,29 @@ dependencies = [ [[package]] name = "wasmtime-wasi" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca539adf155dca1407aa3656e5661bf2364b1f3ebabc7f0a8bd62629d876acfa" +checksum = "21080ff62878f1d7c53d9571053dbe96552c0f982f9f29eac65ea89974fabfd7" dependencies = [ "anyhow", + "libc", "wasi-cap-std-sync", "wasi-common", "wasmtime", "wiggle", ] +[[package]] +name = "wasmtime-wit-bindgen" +version = "9.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "421f0d16cc5c612b35ae53a0be3d3124c72296f18e5be3468263c745d56d37ab" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + [[package]] name = "wast" version = "35.0.2" @@ -5340,9 +5443,9 @@ dependencies = [ [[package]] name = "wast" -version = "59.0.0" +version = "60.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38462178c91e3f990df95f12bf48abe36018e03550a58a65c53975f4e704fc35" +checksum = "bd06cc744b536e30387e72a48fdd492105b9c938bb4f415c39c616a7a0a697ad" dependencies = [ "leb128", "memchr", @@ -5352,18 +5455,18 @@ dependencies = [ [[package]] name = "wat" -version = "1.0.65" +version = "1.0.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c936a025be0417a94d6e9bf92bfdf9e06dbf63debf187b650d9c73a5add701f1" +checksum = "5abe520f0ab205366e9ac7d3e6b2fc71de44e32a2b58f2ec871b6b575bdcea3b" dependencies = [ - "wast 59.0.0", + "wast 60.0.0", ] [[package]] name = "web-sys" -version = "0.3.63" +version = "0.3.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bdd9ef4e984da1187bf8110c5cf5b845fbc87a23602cdf912386a76fcd3a7c2" +checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" dependencies = [ "js-sys", "wasm-bindgen", @@ -5401,9 
+5504,9 @@ dependencies = [ [[package]] name = "wiggle" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da09ca5b8bb9278a2123e8c36342166b9aaa55a0dbab18b231f46d6f6ab85bc" +checksum = "5b34e40b7b17a920d03449ca78b0319984379eed01a9a11c1def9c3d3832d85a" dependencies = [ "anyhow", "async-trait", @@ -5416,9 +5519,9 @@ dependencies = [ [[package]] name = "wiggle-generate" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba5796f53b429df7d44cfdaae8f6d9cd981d82aec3516561352ca9c5e73ee185" +checksum = "9eefda132eaa84fe5f15d23a55a912f8417385aee65d0141d78a3b65e46201ed" dependencies = [ "anyhow", "heck", @@ -5431,9 +5534,9 @@ dependencies = [ [[package]] name = "wiggle-macro" -version = "3.0.1" +version = "9.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b830eb7203d48942fb8bc8bb105f76e7d09c33a082d638e990e02143bb2facd" +checksum = "6ca1a344a0ba781e2a94b27be5bb78f23e43d52336bd663b810d49d7189ad334" dependencies = [ "proc-macro2", "quote", @@ -5478,35 +5581,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" dependencies = [ - "windows-targets 0.48.0", -] - -[[package]] -name = "windows-sys" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" -dependencies = [ - "windows_aarch64_msvc 0.36.1", - "windows_i686_gnu 0.36.1", - "windows_i686_msvc 0.36.1", - "windows_x86_64_gnu 0.36.1", - "windows_x86_64_msvc 0.36.1", -] - -[[package]] -name = "windows-sys" -version = "0.42.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" -dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", + "windows-targets 0.48.1", ] [[package]] @@ -5524,7 +5599,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.0", + "windows-targets 0.48.1", ] [[package]] @@ -5544,9 +5619,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.48.0" +version = "0.48.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +checksum = "05d4b17490f70499f20b9e791dcf6a299785ce8af4d709018206dc5b4953e95f" dependencies = [ "windows_aarch64_gnullvm 0.48.0", "windows_aarch64_msvc 0.48.0", @@ -5569,12 +5644,6 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" -[[package]] -name = "windows_aarch64_msvc" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" - [[package]] name = "windows_aarch64_msvc" version = "0.42.2" @@ -5587,12 +5656,6 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" -[[package]] 
-name = "windows_i686_gnu" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" - [[package]] name = "windows_i686_gnu" version = "0.42.2" @@ -5605,12 +5668,6 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" -[[package]] -name = "windows_i686_msvc" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" - [[package]] name = "windows_i686_msvc" version = "0.42.2" @@ -5623,12 +5680,6 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" -[[package]] -name = "windows_x86_64_gnu" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" - [[package]] name = "windows_x86_64_gnu" version = "0.42.2" @@ -5653,12 +5704,6 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" -[[package]] -name = "windows_x86_64_msvc" -version = "0.36.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" - [[package]] name = "windows_x86_64_msvc" version = "0.42.2" @@ -5682,13 +5727,38 @@ dependencies = [ [[package]] name = "winx" -version = "0.33.0" +version = "0.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7b01e010390eb263a4518c8cebf86cb67469d1511c00b749a47b64c39e8054d" +checksum = "1c52a121f0fbf9320d5f2a9a5d82f6cb7557eda5e8b47fc3e7f359ec866ae960" dependencies = [ "bitflags 1.3.2", - "io-lifetimes 0.7.5", - "windows-sys 0.36.1", + "io-lifetimes 1.0.11", + "windows-sys 0.48.0", +] + +[[package]] +name = "winx" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4857cedf8371f690bb6782a3e2b065c54d1b6661be068aaf3eac8b45e813fdf8" +dependencies = [ + "bitflags 2.3.3", + "windows-sys 0.48.0", +] + +[[package]] +name = "wit-parser" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca2581061573ef6d1754983d7a9b3ed5871ef859d52708ea9a0f5af32919172" +dependencies = [ + "anyhow", + "id-arena", + "indexmap 1.9.3", + "log", + "pulldown-cmark 0.8.0", + "unicode-xid", + "url", ] [[package]] diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 5e6d5c15..b4f9d366 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-bincode = { version = "0.7.1", features = ["tokio"] } axum = "0.6.18" base64 = "0.21.2" bincode = "1.3.3" @@ -15,6 +16,7 @@ color-eyre = "0.6.2" futures = "0.3.28" hmac = "0.12.1" hyper = { version = "0.14.27", features = ["h2", "server"] } +itertools = "0.11.0" libsqlx = { version = "0.1.0", path = "../libsqlx" } moka = { version = "0.11.2", features = ["future"] } parking_lot = "0.12.1" @@ -27,6 +29,11 @@ sha2 = "0.10.7" sled = "0.34.7" thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } +tokio-stream = "0.1.14" +tokio-util = "0.7.8" 
 tracing = "0.1.37"
 tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
 uuid = { version = "1.4.0", features = ["v4"] }
+
+[dev-dependencies]
+turmoil = "0.5.5"
diff --git a/libsqlx-server/src/database.rs b/libsqlx-server/src/database.rs
index d0c979cc..e4971b74 100644
--- a/libsqlx-server/src/database.rs
+++ b/libsqlx-server/src/database.rs
@@ -1,7 +1,7 @@
 use tokio::sync::{mpsc, oneshot};
 
 use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody};
-use crate::allocation::{AllocationMessage, ConnectionHandle};
+use crate::allocation::AllocationMessage;
 
 pub struct Database {
     pub sender: mpsc::Sender<AllocationMessage>,
 }
@@ -9,11 +9,8 @@ impl Database {
     pub async fn hrana_pipeline(&self, req: PipelineRequestBody) -> crate::Result<PipelineResponseBody> {
-        dbg!();
         let (sender, ret) = oneshot::channel();
-        dbg!();
         self.sender.send(AllocationMessage::HranaPipelineReq { req, ret: sender }).await.unwrap();
-        dbg!();
         ret.await.unwrap()
     }
 }
diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs
new file mode 100644
index 00000000..6fcc3238
--- /dev/null
+++ b/libsqlx-server/src/linc/bus.rs
@@ -0,0 +1,186 @@
+use std::collections::{hash_map::Entry, HashMap};
+use std::sync::Arc;
+
+use color_eyre::eyre::{bail, anyhow};
+use parking_lot::Mutex;
+use tokio::sync::{mpsc, Notify};
+use uuid::Uuid;
+
+use super::connection::{ConnectionHandle, Stream};
+
+type NodeId = Uuid;
+type DatabaseId = Uuid;
+
+#[must_use]
+pub struct Subscription {
+    receiver: mpsc::Receiver<Stream>,
+    bus: Bus,
+    database_id: DatabaseId,
+}
+
+impl Drop for Subscription {
+    fn drop(&mut self) {
+        self.bus
+            .inner
+            .lock()
+            .subscriptions
+            .remove(&self.database_id);
+    }
+}
+
+impl futures::Stream for Subscription {
+    type Item = Stream;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.receiver.poll_recv(cx)
+    }
+}
+
+#[derive(Clone)]
+pub struct Bus {
+    inner: Arc<Mutex<BusInner>>,
+    pub node_id: NodeId,
+}
+
+enum ConnectionSlot {
+    Handle(ConnectionHandle),
+    // Interest in the connection when it becomes available
+    Interest(Arc<Notify>),
+}
+
+struct BusInner {
+    connections: HashMap<NodeId, ConnectionSlot>,
+    subscriptions: HashMap<DatabaseId, mpsc::Sender<Stream>>,
+}
+
+impl Bus {
+    pub fn new(node_id: NodeId) -> Self {
+        Self {
+            node_id,
+            inner: Arc::new(Mutex::new(BusInner {
+                connections: HashMap::new(),
+                subscriptions: HashMap::new(),
+            })),
+        }
+    }
+
+    /// Open a new stream to the database at `database_id` on the node `node_id`
+    pub async fn new_stream(
+        &self,
+        node_id: NodeId,
+        database_id: DatabaseId,
+    ) -> color_eyre::Result<Stream> {
+        let get_conn = || {
+            let mut lock = self.inner.lock();
+            match lock.connections.entry(node_id) {
+                Entry::Occupied(mut e) => match e.get_mut() {
+                    ConnectionSlot::Handle(h) => Ok(h.clone()),
+                    ConnectionSlot::Interest(notify) => Err(notify.clone()),
+                },
+                Entry::Vacant(e) => {
+                    let notify = Arc::new(Notify::new());
+                    e.insert(ConnectionSlot::Interest(notify.clone()));
+                    Err(notify)
+                }
+            }
+        };
+
+        let conn = match get_conn() {
+            Ok(conn) => conn,
+            Err(notify) => {
+                notify.notified().await;
+                get_conn().map_err(|_| anyhow!("failed to create stream"))?
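+                // NOTE: a second failure here means the peer registered and
+                // dropped again between the wakeup and this retry; report the
+                // error rather than waiting in a loop.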
+            }
+        };
+
+        conn.new_stream(database_id).await
+    }
+
+    /// Notify a subscription that a new stream was opened
+    pub async fn notify_subscription(
+        &mut self,
+        database_id: DatabaseId,
+        stream: Stream,
+    ) -> color_eyre::Result<()> {
+        let maybe_sender = self.inner.lock().subscriptions.get(&database_id).cloned();
+
+        match maybe_sender {
+            Some(sender) => {
+                if sender.send(stream).await.is_err() {
+                    bail!("subscription for {database_id} closed");
+                }
+
+                Ok(())
+            }
+            None => {
+                bail!("no subscription for {database_id}")
+            }
+        }
+    }
+
+    #[cfg(test)]
+    pub fn is_empty(&self) -> bool {
+        self.inner.lock().connections.is_empty()
+    }
+
+    #[must_use]
+    pub fn register_connection(&self, node_id: NodeId, conn: ConnectionHandle) -> Registration {
+        let mut lock = self.inner.lock();
+        match lock.connections.entry(node_id) {
+            Entry::Occupied(mut e) => {
+                if let ConnectionSlot::Interest(ref notify) = e.get() {
+                    notify.notify_waiters();
+                }
+
+                *e.get_mut() = ConnectionSlot::Handle(conn);
+            }
+            Entry::Vacant(e) => {
+                e.insert(ConnectionSlot::Handle(conn));
+            }
+        }
+
+        Registration {
+            bus: self.clone(),
+            node_id,
+        }
+    }
+
+    pub fn subscribe(&self, database_id: DatabaseId) -> color_eyre::Result<Subscription> {
+        let (sender, receiver) = mpsc::channel(1);
+        {
+            let mut inner = self.inner.lock();
+
+            if inner.subscriptions.contains_key(&database_id) {
+                bail!("a subscription already exists for that database");
+            }
+
+            inner.subscriptions.insert(database_id, sender);
+        }
+
+        Ok(Subscription {
+            receiver,
+            bus: self.clone(),
+            database_id,
+        })
+    }
+}
+
+pub struct Registration {
+    bus: Bus,
+    node_id: NodeId,
+}
+
+impl Drop for Registration {
+    fn drop(&mut self) {
+        assert!(self
+            .bus
+            .inner
+            .lock()
+            .connections
+            .remove(&self.node_id)
+            .is_some());
+    }
+}
diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs
new file mode 100644
index 00000000..b5a8ce25
--- /dev/null
+++ b/libsqlx-server/src/linc/connection.rs
@@ -0,0 +1,723 @@
+use std::collections::HashMap;
+
+use async_bincode::tokio::AsyncBincodeStream;
+use async_bincode::AsyncDestination;
+use color_eyre::eyre::{bail, anyhow};
+use futures::{SinkExt, StreamExt};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::sync::mpsc::error::TrySendError;
+use tokio::sync::{mpsc, oneshot};
+use tokio::time::{Duration, Instant};
+use tokio_util::sync::PollSender;
+
+use crate::linc::proto::{NodeError, NodeMessage};
+use crate::linc::CURRENT_PROTO_VERSION;
+
+use super::bus::{Bus, Registration};
+use super::proto::{Message, StreamId, StreamMessage};
+use super::{DatabaseId, NodeId};
+use super::{StreamIdAllocator, MAX_STREAM_MSG};
+
+#[derive(Debug, Clone)]
+pub struct ConnectionHandle {
+    connection_sender: mpsc::Sender<ConnectionMessage>,
+}
+
+impl ConnectionHandle {
+    pub async fn new_stream(&self, database_id: DatabaseId) -> color_eyre::eyre::Result<Stream> {
+        let (send, ret) = oneshot::channel();
+        self.connection_sender
+            .send(ConnectionMessage::StreamCreate {
+                database_id,
+                ret: send,
+            })
+            .await
+            .unwrap();
+
+        Ok(ret.await?)
+    }
+}
+
+/// A bidirectional stream between databases on two nodes.
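+///
+/// A minimal usage sketch (assuming an already connected `Bus`; `peer_id`,
+/// `database_id` and `payload` are placeholders):
+///
+/// ```ignore
+/// let mut stream = bus.new_stream(peer_id, database_id).await?;
+/// stream.send(payload).await?;   // wrapped into a `Message::Stream` for the peer
+/// while let Some(msg) = stream.next().await {
+///     // incoming messages are routed here by `stream_id`
+/// }
+/// ```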
+#[derive(Debug)]
+pub struct Stream {
+    stream_id: StreamId,
+    /// sender to the connection
+    sender: tokio_util::sync::PollSender<ConnectionMessage>,
+    /// incoming messages for this stream
+    recv: tokio_stream::wrappers::ReceiverStream<StreamMessage>,
+}
+
+impl futures::Sink<StreamMessage> for Stream {
+    type Error = tokio_util::sync::PollSendError<ConnectionMessage>;
+
+    fn poll_ready(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), Self::Error>> {
+        self.sender.poll_ready_unpin(cx)
+    }
+
+    fn start_send(
+        mut self: std::pin::Pin<&mut Self>,
+        payload: StreamMessage,
+    ) -> Result<(), Self::Error> {
+        let stream_id = self.stream_id;
+        self.sender
+            .start_send_unpin(ConnectionMessage::Message(Message::Stream {
+                stream_id,
+                payload,
+            }))
+    }
+
+    fn poll_flush(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), Self::Error>> {
+        self.sender.poll_flush_unpin(cx)
+    }
+
+    fn poll_close(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Result<(), Self::Error>> {
+        self.sender.poll_close_unpin(cx)
+    }
+}
+
+impl futures::Stream for Stream {
+    type Item = StreamMessage;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.recv.poll_next_unpin(cx)
+    }
+}
+
+impl Drop for Stream {
+    fn drop(&mut self) {
+        self.recv.close();
+        assert!(self.recv.as_mut().try_recv().is_err());
+        let mut sender = self.sender.clone();
+        let id = self.stream_id;
+        if let Some(sender_ref) = sender.get_ref() {
+            // Try send here is mostly for turmoil, since it stops polling the future as soon as
+            // the test future returns, which causes spawn to panic. In the tests, the channel will
+            // always have capacity.
+            if let Err(TrySendError::Full(m)) =
+                sender_ref.try_send(ConnectionMessage::CloseStream(id))
+            {
+                tokio::task::spawn(async move {
+                    let _ = sender.send(m).await;
+                });
+            }
+        }
+    }
+}
+
+struct StreamState {
+    sender: mpsc::Sender<StreamMessage>,
+}
+
+/// A connection to another node. Manages the connection state, and (de)registers streams with the
+/// `Bus`
+pub struct Connection<S> {
+    /// Id of the peer node
+    pub peer: Option<NodeId>,
+    /// State of the connection
+    pub state: ConnectionState,
+    /// Sink/Stream for network messages
+    conn: AsyncBincodeStream<S, Message, Message, AsyncDestination>,
+    /// Collection of streams for that connection
+    streams: HashMap<StreamId, StreamState>,
+    /// internal connection messages
+    connection_messages: mpsc::Receiver<ConnectionMessage>,
+    connection_messages_sender: mpsc::Sender<ConnectionMessage>,
+    /// Are we the initiator of this connection?
+    is_initiator: bool,
+    bus: Bus,
+    stream_id_allocator: StreamIdAllocator,
+    /// handle to the registration of this connection to the bus.
+    /// Dropping this deregisters the connection from the bus
+    registration: Option<Registration>,
+}
+
+#[derive(Debug)]
+pub enum ConnectionMessage {
+    StreamCreate {
+        database_id: DatabaseId,
+        ret: oneshot::Sender<Stream>,
+    },
+    CloseStream(StreamId),
+    Message(Message),
+}
+
+#[derive(Debug)]
+pub enum ConnectionState {
+    Init,
+    Connecting,
+    Connected,
+    // Closing the connection with an error
+    CloseError(color_eyre::eyre::Error),
+    // Graceful connection shutdown
+    Close,
+}
+
+pub fn handshake_deadline() -> Instant {
+    const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(5);
+    Instant::now() + HANDSHAKE_TIMEOUT
+}
+
+impl<S> Connection<S>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+{
+    const MAX_CONNECTION_MESSAGES: usize = 128;
+
+    pub fn new_initiator(stream: S, bus: Bus) -> Self {
+        let (connection_messages_sender, connection_messages) =
+            mpsc::channel(Self::MAX_CONNECTION_MESSAGES);
+        Self {
+            peer: None,
+            state: ConnectionState::Init,
+            conn: AsyncBincodeStream::from(stream).for_async(),
+            streams: HashMap::new(),
+            is_initiator: true,
+            bus,
+            stream_id_allocator: StreamIdAllocator::new(true),
+            connection_messages,
+            connection_messages_sender,
+            registration: None,
+        }
+    }
+
+    pub fn new_acceptor(stream: S, bus: Bus) -> Self {
+        let (connection_messages_sender, connection_messages) =
+            mpsc::channel(Self::MAX_CONNECTION_MESSAGES);
+        Connection {
+            peer: None,
+            state: ConnectionState::Connecting,
+            streams: HashMap::new(),
+            connection_messages,
+            connection_messages_sender,
+            is_initiator: false,
+            bus,
+            conn: AsyncBincodeStream::from(stream).for_async(),
+            stream_id_allocator: StreamIdAllocator::new(false),
+            registration: None,
+        }
+    }
+
+    pub fn handle(&self) -> ConnectionHandle {
+        ConnectionHandle {
+            connection_sender: self.connection_messages_sender.clone(),
+        }
+    }
+
+    pub async fn run(mut self) {
+        while self.tick().await {}
+    }
+
+    pub async fn tick(&mut self) -> bool {
+        match self.state {
+            ConnectionState::Connected => self.tick_connected().await,
+            ConnectionState::Init => match self.initiate_connection().await {
+                Ok(_) => {
+                    self.state = ConnectionState::Connecting;
+                }
+                Err(e) => {
+                    self.state = ConnectionState::CloseError(e);
+                }
+            },
+            ConnectionState::Connecting => {
+                if let Err(e) = self
+                    .wait_handshake_response_with_deadline(handshake_deadline())
+                    .await
+                {
+                    self.state = ConnectionState::CloseError(e);
+                }
+            }
+            ConnectionState::CloseError(ref e) => {
+                tracing::error!("closing connection with {:?}: {e}", self.peer);
+                return false;
+            }
+            ConnectionState::Close => return false,
+        }
+        true
+    }
+
+    async fn tick_connected(&mut self) {
+        tokio::select! {
+            m = self.conn.next() => {
+                match m {
+                    Some(Ok(m)) => {
+                        self.handle_message(m).await;
+                    }
+                    Some(Err(e)) => {
+                        self.state = ConnectionState::CloseError(e.into());
+                    }
+                    None => {
+                        self.state = ConnectionState::Close;
+                    }
+                }
+            }
+            Some(command) = self.connection_messages.recv() => {
+                self.handle_command(command).await;
+            },
+            else => {
+                self.state = ConnectionState::Close;
+            }
+        }
+    }
+
+    async fn handle_message(&mut self, message: Message) {
+        match message {
+            Message::Node(NodeMessage::OpenStream {
+                stream_id,
+                database_id,
+            }) => {
+                if self.streams.contains_key(&stream_id) {
+                    self.send_message(Message::Node(NodeMessage::Error(
+                        NodeError::StreamAlreadyExist(stream_id),
+                    )))
+                    .await;
+                    return;
+                }
+                let stream = self.create_stream(stream_id);
+                if let Err(e) = self.bus.notify_subscription(database_id, stream).await {
+                    tracing::error!("{e}");
+                    self.send_message(Message::Node(NodeMessage::Error(
+                        NodeError::UnknownDatabase(database_id, stream_id),
+                    )))
+                    .await;
+                }
+            }
+            Message::Node(NodeMessage::Handshake { .. }) => {
+                self.close_error(anyhow!("unexpected handshake: closing connection"));
+            }
+            Message::Node(NodeMessage::CloseStream { stream_id: id }) => {
+                self.close_stream(id);
+            }
+            Message::Node(NodeMessage::Error(e @ NodeError::HandshakeVersionMismatch { .. })) => {
+                self.close_error(anyhow!("unexpected peer error: {e}"));
+            }
+            Message::Node(NodeMessage::Error(NodeError::UnknownStream(id))) => {
+                tracing::error!("unknown stream: {id}");
+                self.close_stream(id);
+            }
+            Message::Node(NodeMessage::Error(e @ NodeError::StreamAlreadyExist(_))) => {
+                self.state = ConnectionState::CloseError(e.into());
+            }
+            Message::Node(NodeMessage::Error(ref e @ NodeError::UnknownDatabase(_, stream_id))) => {
+                tracing::error!("{e}");
+                self.close_stream(stream_id);
+            }
+            Message::Stream { stream_id, payload } => {
+                match self.streams.get_mut(&stream_id) {
+                    Some(s) => {
+                        // TODO: there is no stream-independent flow control for now.
+                        // When/if flow control is implemented, it will be handled here.
+                        if s.sender.send(payload).await.is_err() {
+                            self.close_stream(stream_id);
+                        }
+                    }
+                    None => {
+                        self.send_message(Message::Node(NodeMessage::Error(
+                            NodeError::UnknownStream(stream_id),
+                        )))
+                        .await;
+                    }
+                }
+            }
+        }
+    }
+
+    fn close_error(&mut self, error: color_eyre::eyre::Error) {
+        self.state = ConnectionState::CloseError(error);
+    }
+
+    fn close_stream(&mut self, id: StreamId) {
+        self.streams.remove(&id);
+    }
+
+    async fn handle_command(&mut self, command: ConnectionMessage) {
+        match command {
+            ConnectionMessage::Message(m) => {
+                self.send_message(m).await;
+            }
+            ConnectionMessage::CloseStream(stream_id) => {
+                self.close_stream(stream_id);
+                self.send_message(Message::Node(NodeMessage::CloseStream { stream_id }))
+                    .await;
+            }
+            ConnectionMessage::StreamCreate { database_id, ret } => {
+                let Some(stream_id) = self.stream_id_allocator.allocate() else {
+                    // TODO: We close the connection here, which will cause a reconnection, and
+                    // reset the stream_id allocator. If that happens in practice, it should be
+                    // very quick to re-establish a connection. If this is an issue, we can either
+                    // start using i64 stream_ids, or use a smarter id allocator.
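+                    // (With one i32 id per stream and the allocator stepping by
+                    // ±1, a single connection can open ~2^31 streams per side
+                    // before this branch is ever taken.)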
+ self.state = ConnectionState::CloseError(anyhow!("Ran out of stream ids")); + return + }; + assert_eq!(stream_id.is_positive(), self.is_initiator); + assert!(!self.streams.contains_key(&stream_id)); + let stream = self.create_stream(stream_id); + self.send_message(Message::Node(NodeMessage::OpenStream { + stream_id, + database_id, + })) + .await; + let _ = ret.send(stream); + } + } + } + + async fn send_message(&mut self, message: Message) { + if let Err(e) = self.conn.send(message).await { + self.close_error(e.into()); + } + } + + fn create_stream(&mut self, stream_id: StreamId) -> Stream { + let (sender, recv) = mpsc::channel(MAX_STREAM_MSG); + let stream = Stream { + stream_id, + sender: PollSender::new(self.connection_messages_sender.clone()), + recv: recv.into(), + }; + self.streams.insert(stream_id, StreamState { sender }); + stream + } + + /// wait for a handshake response from peer + pub async fn wait_handshake_response_with_deadline( + &mut self, + deadline: Instant, + ) -> color_eyre::Result<()> { + assert!(matches!(self.state, ConnectionState::Connecting)); + + match tokio::time::timeout_at(deadline, self.conn.next()).await { + Ok(Some(Ok(Message::Node(NodeMessage::Handshake { + protocol_version, + node_id, + })))) => { + if protocol_version != CURRENT_PROTO_VERSION { + let _ = self + .conn + .send(Message::Node(NodeMessage::Error( + NodeError::HandshakeVersionMismatch { + expected: CURRENT_PROTO_VERSION, + }, + ))) + .await; + + bail!("handshake error: invalid peer protocol version"); + } else { + // when not initiating a connection, respond to handshake message with a + // handshake message + if !self.is_initiator { + self.conn + .send(Message::Node(NodeMessage::Handshake { + protocol_version: CURRENT_PROTO_VERSION, + node_id: self.bus.node_id, + })) + .await?; + } + + self.peer = Some(node_id); + self.state = ConnectionState::Connected; + self.registration = Some(self.bus.register_connection(node_id, self.handle())); + + Ok(()) + } + } + Ok(Some(Ok(Message::Node(NodeMessage::Error(e))))) => { + bail!("handshake error: {e}"); + } + Ok(Some(Ok(_))) => { + bail!("unexpected message from peer during handshake."); + } + Ok(Some(Err(e))) => { + bail!("failed to perform handshake with peer: {e}"); + } + Ok(None) => { + bail!("failed to perform handshake with peer: connection closed"); + } + Err(_e) => { + bail!("failed to perform handshake with peer: timed out"); + } + } + } + + async fn initiate_connection(&mut self) -> color_eyre::Result<()> { + self.conn + .send(Message::Node(NodeMessage::Handshake { + protocol_version: CURRENT_PROTO_VERSION, + node_id: self.bus.node_id, + })) + .await?; + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use tokio::sync::Notify; + use turmoil::net::{TcpListener, TcpStream}; + use uuid::Uuid; + + use super::*; + + #[test] + fn invalid_handshake() { + let mut sim = turmoil::Builder::new().build(); + + let host_node_id = NodeId::new_v4(); + sim.host("host", move || async move { + let bus = Bus::new(host_node_id); + let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") + .await + .unwrap(); + let (s, _) = listener.accept().await.unwrap(); + let mut connection = Connection::new_acceptor(s, bus); + connection.tick().await; + + Ok(()) + }); + + sim.client("client", async move { + let s = TcpStream::connect("host:1234").await.unwrap(); + let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async(); + + s.send(Message::Node(NodeMessage::Handshake { + protocol_version: 1234, + node_id: Uuid::new_v4(), + })) 
+            .await
+            .unwrap();
+            let m = s.next().await.unwrap().unwrap();
+
+            assert!(matches!(
+                m,
+                Message::Node(NodeMessage::Error(
+                    NodeError::HandshakeVersionMismatch { .. }
+                ))
+            ));
+
+            Ok(())
+        });
+
+        sim.run().unwrap();
+    }
+
+    #[test]
+    fn stream_closed() {
+        let mut sim = turmoil::Builder::new().build();
+
+        let database_id = DatabaseId::new_v4();
+        let host_node_id = NodeId::new_v4();
+        let notify = Arc::new(Notify::new());
+        sim.host("host", {
+            let notify = notify.clone();
+            move || {
+                let notify = notify.clone();
+                async move {
+                    let bus = Bus::new(host_node_id);
+                    let mut sub = bus.subscribe(database_id).unwrap();
+                    let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
+                        .await
+                        .unwrap();
+                    let (s, _) = listener.accept().await.unwrap();
+                    let connection = Connection::new_acceptor(s, bus);
+                    tokio::task::spawn_local(connection.run());
+                    let mut streams = Vec::new();
+                    loop {
+                        tokio::select! {
+                            Some(mut stream) = sub.next() => {
+                                let m = stream.next().await.unwrap();
+                                stream.send(m).await.unwrap();
+                                streams.push(stream);
+                            }
+                            _ = notify.notified() => {
+                                break;
+                            }
+                        }
+                    }
+
+                    Ok(())
+                }
+            }
+        });
+
+        sim.client("client", async move {
+            let stream_id = StreamId::new(1);
+            let node_id = NodeId::new_v4();
+            let s = TcpStream::connect("host:1234").await.unwrap();
+            let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async();
+
+            s.send(Message::Node(NodeMessage::Handshake {
+                protocol_version: CURRENT_PROTO_VERSION,
+                node_id,
+            }))
+            .await
+            .unwrap();
+            let m = s.next().await.unwrap().unwrap();
+            assert!(matches!(m, Message::Node(NodeMessage::Handshake { .. })));
+
+            // send a message to a nonexistent stream:
+            s.send(Message::Stream {
+                stream_id,
+                payload: StreamMessage::Dummy,
+            })
+            .await
+            .unwrap();
+            let m = s.next().await.unwrap().unwrap();
+            assert_eq!(
+                m,
+                Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id)))
+            );
+
+            // open a stream, then send a message
+            s.send(Message::Node(NodeMessage::OpenStream {
+                stream_id,
+                database_id,
+            }))
+            .await
+            .unwrap();
+            s.send(Message::Stream {
+                stream_id,
+                payload: StreamMessage::Dummy,
+            })
+            .await
+            .unwrap();
+            let m = s.next().await.unwrap().unwrap();
+            assert_eq!(
+                m,
+                Message::Stream {
+                    stream_id,
+                    payload: StreamMessage::Dummy
+                }
+            );
+
+            s.send(Message::Node(NodeMessage::CloseStream {
+                stream_id: StreamId::new(1),
+            }))
+            .await
+            .unwrap();
+            s.send(Message::Stream {
+                stream_id,
+                payload: StreamMessage::Dummy,
+            })
+            .await
+            .unwrap();
+            let m = s.next().await.unwrap().unwrap();
+            assert_eq!(
+                m,
+                Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id)))
+            );
+
+            notify.notify_waiters();
+
+            Ok(())
+        });
+
+        sim.run().unwrap();
+    }
+
+    #[test]
+    fn connection_closed_by_peer_close_connection() {
+        let mut sim = turmoil::Builder::new().build();
+
+        let notify = Arc::new(Notify::new());
+        sim.host("host", {
+            let notify = notify.clone();
+            move || {
+                let notify = notify.clone();
+                async move {
+                    let listener = TcpListener::bind("0.0.0.0:1234").await.unwrap();
+                    let (stream, _) = listener.accept().await.unwrap();
+                    notify.notified().await;
+
+                    // drop connection
+                    drop(stream);
+
+                    Ok(())
+                }
+            }
+        });
+
+        sim.client("client", async move {
+            let stream = TcpStream::connect("host:1234").await.unwrap();
+            let bus = Bus::new(NodeId::new_v4());
+            let mut conn = Connection::new_acceptor(stream, bus);
+
+            notify.notify_waiters();
+
+            conn.tick().await;
+
+            assert!(matches!(conn.state, ConnectionState::CloseError(_)));
+
+            Ok(())
+        });
+
+        sim.run().unwrap();
+    }
+
+    #[test]
+    fn zero_stream_id() {
+        let mut sim = turmoil::Builder::new().build();
+
+        let notify = Arc::new(Notify::new());
+        sim.host("host", {
+            let notify = notify.clone();
+            move || {
+                let notify = notify.clone();
+                async move {
+                    let listener = TcpListener::bind("0.0.0.0:1234").await.unwrap();
+                    let (stream, _) = listener.accept().await.unwrap();
+                    let (connection_messages_sender, connection_messages) = mpsc::channel(1);
+                    let conn = Connection {
+                        peer: Some(NodeId::new_v4()),
+                        state: ConnectionState::Connected,
+                        conn: AsyncBincodeStream::from(stream).for_async(),
+                        streams: HashMap::new(),
+                        connection_messages,
+                        connection_messages_sender,
+                        is_initiator: false,
+                        bus: Bus::new(NodeId::new_v4()),
+                        stream_id_allocator: StreamIdAllocator::new(false),
+                        registration: None,
+                    };
+
+                    conn.run().await;
+
+                    Ok(())
+                }
+            }
+        });
+
+        sim.client("client", async move {
+            let stream = TcpStream::connect("host:1234").await.unwrap();
+            let mut stream = AsyncBincodeStream::<_, Message, Message, _>::from(stream).for_async();
+
+            stream
+                .send(Message::Stream {
+                    stream_id: StreamId::new_unchecked(0),
+                    payload: StreamMessage::Dummy,
+                })
+                .await
+                .unwrap();
+
+            assert!(stream.next().await.is_none());
+
+            Ok(())
+        });
+
+        sim.run().unwrap();
+    }
+}
diff --git a/libsqlx-server/src/linc/connection_manager.rs b/libsqlx-server/src/linc/connection_manager.rs
new file mode 100644
index 00000000..e69de29b
diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs
new file mode 100644
index 00000000..f5f29c61
--- /dev/null
+++ b/libsqlx-server/src/linc/connection_pool.rs
@@ -0,0 +1,202 @@
+use std::collections::HashMap;
+
+use itertools::Itertools;
+use tokio::task::JoinSet;
+use tokio::time::Duration;
+
+use super::connection::Connection;
+use super::net::Connector;
+use super::{bus::Bus, NodeId};
+
+/// Manages a pool of connections to other peers, handling re-connection.
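+///
+/// A minimal usage sketch (assuming a plain TCP connector; `bus` and `peers`
+/// are placeholders):
+///
+/// ```ignore
+/// let pool = ConnectionPool::new(bus.clone(), peers /* (NodeId, addr) pairs */);
+/// tokio::spawn(pool.run::<tokio::net::TcpStream>());
+/// ```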
+struct ConnectionPool {
+    managed_peers: HashMap<NodeId, String>,
+    connections: JoinSet<NodeId>,
+    bus: Bus,
+}
+
+impl ConnectionPool {
+    pub fn new(bus: Bus, managed_peers: impl IntoIterator<Item = (NodeId, String)>) -> Self {
+        Self {
+            managed_peers: managed_peers.into_iter().collect(),
+            connections: JoinSet::new(),
+            bus,
+        }
+    }
+
+    pub async fn run<C: Connector>(mut self) {
+        self.init::<C>().await;
+
+        while self.tick::<C>().await {}
+    }
+
+    pub async fn tick<C: Connector>(&mut self) -> bool {
+        if let Some(maybe_to_restart) = self.connections.join_next().await {
+            if let Ok(to_restart) = maybe_to_restart {
+                self.connect::<C>(to_restart);
+            }
+            true
+        } else {
+            false
+        }
+    }
+
+    async fn init<C: Connector>(&mut self) {
+        let peers = self.managed_peers.keys().copied().collect_vec();
+        peers.into_iter().for_each(|p| self.connect::<C>(p));
+    }
+
+    fn connect<C: Connector>(&mut self, peer_id: NodeId) {
+        let bus = self.bus.clone();
+        let peer_addr = self.managed_peers[&peer_id].clone();
+        let fut = async move {
+            let stream = match C::connect(peer_addr.clone()).await {
+                Ok(stream) => stream,
+                Err(e) => {
+                    tracing::error!("error connecting to peer {peer_id}@{peer_addr}: {e}");
+                    tokio::time::sleep(Duration::from_secs(1)).await;
+                    return peer_id;
+                }
+            };
+            let connection = Connection::new_initiator(stream, bus.clone());
+            connection.run().await;
+
+            peer_id
+        };
+
+        self.connections.spawn(fut);
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use futures::SinkExt;
+    use tokio::sync::Notify;
+    use tokio_stream::StreamExt;
+
+    use crate::linc::{server::Server, DatabaseId};
+
+    use super::*;
+
+    #[test]
+    fn manage_connections() {
+        let mut sim = turmoil::Builder::new().build();
+        let database_id = DatabaseId::new_v4();
+        let notify = Arc::new(Notify::new());
+
+        let expected_msg = crate::linc::proto::StreamMessage::Proxy(
+            crate::linc::proto::ProxyMessage::ProxyRequest {
+                connection_id: 42,
+                req_id: 42,
+                program: "foobar".into(),
+            },
+        );
+
+        let spawn_host = |node_id| {
+            let notify = notify.clone();
+            let expected_msg = expected_msg.clone();
+            move || {
+                let notify = notify.clone();
+                let expected_msg = expected_msg.clone();
+                async move {
+                    let bus = Bus::new(node_id);
+                    let mut sub = bus.subscribe(database_id).unwrap();
+                    let mut server = Server::new(bus.clone());
+                    let mut listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
+                        .await
+                        .unwrap();
+
+                    let mut has_closed = false;
+                    let mut streams = Vec::new();
+                    loop {
+                        tokio::select! {
+                            _ = notify.notified() => {
+                                if !has_closed {
+                                    streams.clear();
+                                    server.close_connections().await;
+                                    has_closed = true;
+                                } else {
+                                    break;
+                                }
+                            },
+                            _ = server.tick(&mut listener) => (),
+                            Some(mut stream) = sub.next() => {
+                                stream
+                                    .send(expected_msg.clone())
+                                    .await
+                                    .unwrap();
+                                streams.push(stream);
+                            }
+                        }
+                    }
+
+                    Ok(())
+                }
+            }
+        };
+
+        let host1_id = NodeId::new_v4();
+        sim.host("host1", spawn_host(host1_id));
+
+        let host2_id = NodeId::new_v4();
+        sim.host("host2", spawn_host(host2_id));
+
+        let host3_id = NodeId::new_v4();
+        sim.host("host3", spawn_host(host3_id));
+
+        sim.client("client", async move {
+            let bus = Bus::new(NodeId::new_v4());
+            let pool = ConnectionPool::new(
+                bus.clone(),
+                vec![
+                    (host1_id, "host1:1234".into()),
+                    (host2_id, "host2:1234".into()),
+                    (host3_id, "host3:1234".into()),
+                ],
+            );
+
+            tokio::task::spawn_local(pool.run::<turmoil::net::TcpStream>());
+
+            // all three hosts are reachable:
+            let mut stream1 = bus.new_stream(host1_id, database_id).await.unwrap();
+            let m = stream1.next().await.unwrap();
+            assert_eq!(m, expected_msg);
+
+            let mut stream2 = bus.new_stream(host2_id, database_id).await.unwrap();
+            let m = stream2.next().await.unwrap();
+            assert_eq!(m, expected_msg);
+
+            let mut stream3 = bus.new_stream(host3_id, database_id).await.unwrap();
+            let m = stream3.next().await.unwrap();
+            assert_eq!(m, expected_msg);
+
+            // sever connections
+            notify.notify_waiters();
+
+            assert!(stream1.next().await.is_none());
+            assert!(stream2.next().await.is_none());
+            assert!(stream3.next().await.is_none());
+
+            let mut stream1 = bus.new_stream(host1_id, database_id).await.unwrap();
+            let m = stream1.next().await.unwrap();
+            assert_eq!(m, expected_msg);
+
+            let mut stream2 = bus.new_stream(host2_id, database_id).await.unwrap();
+            let m = stream2.next().await.unwrap();
+            assert_eq!(m, expected_msg);
+
+            let mut stream3 = bus.new_stream(host3_id, database_id).await.unwrap();
+            let m = stream3.next().await.unwrap();
+            assert_eq!(m, expected_msg);
+
+            // terminate test
+            notify.notify_waiters();
+
+            Ok(())
+        });
+
+        sim.run().unwrap();
+    }
+}
diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs
new file mode 100644
index 00000000..30b06285
--- /dev/null
+++ b/libsqlx-server/src/linc/mod.rs
@@ -0,0 +1,38 @@
+use uuid::Uuid;
+
+use self::proto::StreamId;
+
+pub mod bus;
+pub mod connection;
+pub mod connection_pool;
+pub mod net;
+pub mod proto;
+pub mod server;
+
+type NodeId = Uuid;
+type DatabaseId = Uuid;
+
+const CURRENT_PROTO_VERSION: u32 = 1;
+const MAX_STREAM_MSG: usize = 64;
+
+#[derive(Debug)]
+pub struct StreamIdAllocator {
+    direction: i32,
+    next_id: i32,
+}
+
+impl StreamIdAllocator {
+    fn new(positive: bool) -> Self {
+        let direction = if positive { 1 } else { -1 };
+        Self {
+            direction,
+            next_id: direction,
+        }
+    }
+
+    pub fn allocate(&mut self) -> Option<StreamId> {
+        let id = self.next_id;
+        self.next_id = id.checked_add(self.direction)?;
+        Some(StreamId::new(id))
+    }
+}
diff --git a/libsqlx-server/src/linc/net.rs b/libsqlx-server/src/linc/net.rs
new file mode 100644
index 00000000..430b6d08
--- /dev/null
+++ b/libsqlx-server/src/linc/net.rs
@@ -0,0 +1,81 @@
+use std::io;
+use std::net::SocketAddr;
+use std::pin::Pin;
+
+use futures::Future;
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::net::TcpListener;
+use tokio::net::TcpStream;
+
+pub trait Connector
+where
+    Self: Sized + AsyncRead + AsyncWrite + Unpin + 'static + Send,
+{
+    type Future: Future<Output = io::Result<Self>> + Send;
+
+    fn connect(addr: String) -> Self::Future;
+}
+
+impl Connector for TcpStream {
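+    // The connect future is boxed so it can be named as an associated type on
+    // stable Rust (`async fn` in traits was not available when this was written).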
+    type Future = Pin<Box<dyn Future<Output = io::Result<Self>> + Send>>;
+
+    fn connect(addr: String) -> Self::Future {
+        Box::pin(TcpStream::connect(addr))
+    }
+}
+
+pub trait Listener {
+    type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
+    type Future<'a>: Future<Output = io::Result<(Self::Stream, SocketAddr)>> + 'a
+    where
+        Self: 'a;
+
+    fn accept(&self) -> Self::Future<'_>;
+}
+
+pub struct AcceptFut<'a>(&'a TcpListener);
+
+impl<'a> Future for AcceptFut<'a> {
+    type Output = io::Result<(TcpStream, SocketAddr)>;
+
+    fn poll(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Self::Output> {
+        self.0.poll_accept(cx)
+    }
+}
+
+impl Listener for TcpListener {
+    type Stream = TcpStream;
+    type Future<'a> = AcceptFut<'a>;
+
+    fn accept(&self) -> Self::Future<'_> {
+        AcceptFut(self)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    use turmoil::net::{TcpListener, TcpStream};
+
+    impl Listener for TcpListener {
+        type Stream = TcpStream;
+        type Future<'a> =
+            Pin<Box<dyn Future<Output = io::Result<(Self::Stream, SocketAddr)>> + 'a>>;
+
+        fn accept(&self) -> Self::Future<'_> {
+            Box::pin(self.accept())
+        }
+    }
+
+    impl Connector for TcpStream {
+        type Future = Pin<Box<dyn Future<Output = io::Result<Self>> + Send + 'static>>;
+
+        fn connect(addr: String) -> Self::Future {
+            Box::pin(Self::connect(addr))
+        }
+    }
+}
diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs
new file mode 100644
index 00000000..7de1002a
--- /dev/null
+++ b/libsqlx-server/src/linc/proto.rs
@@ -0,0 +1,214 @@
+use std::fmt;
+
+use bytes::Bytes;
+use serde::{de::Error, Deserialize, Deserializer, Serialize};
+use uuid::Uuid;
+
+use super::DatabaseId;
+
+pub type Program = String;
+
+#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
+pub struct StreamId(#[serde(deserialize_with = "non_zero")] i32);
+
+impl fmt::Display for StreamId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
+fn non_zero<'de, D>(d: D) -> Result<i32, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let value = i32::deserialize(d)?;
+
+    if value == 0 {
+        return Err(D::Error::custom("invalid stream_id"));
+    }
+
+    Ok(value)
+}
+
+impl StreamId {
+    /// Creates a new stream_id.
+    /// Panics if `val` is zero.
+    pub fn new(val: i32) -> Self {
+        assert!(val != 0);
+        Self(val)
+    }
+
+    pub fn is_positive(&self) -> bool {
+        self.0.is_positive()
+    }
+
+    #[cfg(test)]
+    pub fn new_unchecked(i: i32) -> Self {
+        Self(i)
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
+pub enum Message {
+    /// Message destined to a node
+    Node(NodeMessage),
+    /// Message destined to a database stream
+    Stream {
+        stream_id: StreamId,
+        payload: StreamMessage,
+    },
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
+pub enum NodeMessage {
+    /// Initial message exchanged between nodes when connecting
+    Handshake {
+        protocol_version: u32,
+        node_id: Uuid,
+    },
+    /// Request to open a bi-directional stream between the client and the server
+    OpenStream {
+        /// Id to give to the newly opened stream.
+        /// The initiator of the connection creates streams with positive ids,
+        /// and the acceptor of the connection creates streams with negative ids.
+        stream_id: StreamId,
+        /// Id of the database to open the stream to.
+        database_id: Uuid,
+    },
+    /// Close a previously opened stream
+    CloseStream { stream_id: StreamId },
+    /// Error type returned while handling a node message
+    Error(NodeError),
+}
+
+#[derive(Debug, Serialize, Deserialize, thiserror::Error, PartialEq, Eq)]
+pub enum NodeError {
+    /// The requested stream does not exist
+    #[error("unknown stream: {0}")]
+    UnknownStream(StreamId),
+    /// Incompatible protocol versions
+    #[error("invalid protocol version, expected: {expected}")]
+    HandshakeVersionMismatch { expected: u32 },
+    #[error("stream {0} already exists")]
+    StreamAlreadyExist(StreamId),
+    #[error("cannot open stream {1}: unknown database {0}")]
+    UnknownDatabase(DatabaseId, StreamId),
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub enum StreamMessage {
+    /// Replication message between a replica and a primary
+    Replication(ReplicationMessage),
+    /// Proxy message between a replica and a primary
+    Proxy(ProxyMessage),
+    #[cfg(test)]
+    Dummy,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub enum ReplicationMessage {
+    HandshakeResponse {
+        /// id of the replication log
+        log_id: Uuid,
+        /// current frame_no of the primary
+        current_frame_no: u64,
+    },
+    Replicate {
+        /// next frame no to send
+        next_frame_no: u64,
+    },
+    /// a batch of frames that are part of the same transaction
+    Transaction {
+        /// if not None, then the last frame is a commit frame, and this is the new size of the database.
+        size_after: Option<u32>,
+        /// frame_no of the last frame in frames
+        end_frame_no: u64,
+        /// a batch of frames part of the transaction.
+        frames: Vec<Frame>,
+    },
+    /// Error occurred handling a replication message
+    Error(ReplicationError),
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub struct Frame {
+    /// Page id of that frame
+    page_id: u32,
+    /// Data
+    data: Bytes,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub enum ProxyMessage {
+    /// Proxy a query to a primary
+    ProxyRequest {
+        /// id of the connection to perform the query against.
+        /// If the connection doesn't already exist, it is created.
+        connection_id: u32,
+        /// Id of the request.
+        /// Responses to this request must have the same id.
+        req_id: u32,
+        program: Program,
+    },
+    /// Response to a proxied query
+    ProxyResponse {
+        /// id of the request this message is a response to.
+        req_id: u32,
+        /// Collection of steps to drive the query builder transducer.
+        row_step: Vec<BuilderStep>,
+    },
+    /// Stop processing request `id`.
+    CancelRequest { req_id: u32 },
+    /// Close Connection with passed id.
+    CloseConnection { connection_id: u32 },
+    /// Error returned when handling a proxied query message.
+    Error(ProxyError),
+}
+
+/// Steps applied to the query builder transducer to build a response to a proxied query.
+/// Those types closely mirror those of the `QueryBuilderTrait`.
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub enum BuilderStep {
+    BeginStep,
+    FinishStep(u64, Option<i64>),
+    StepError(StepError),
+    ColsDesc(Vec<Column>),
+    BeginRows,
+    BeginRow,
+    AddRowValue(Value),
+    FinishRow,
+    FinishRows,
+    Finish(ConnectionState),
+}
+
+// State of the connection after a query was executed
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub enum ConnectionState {
+    /// The connection is still in an open transaction state
+    OpenTxn,
+    /// The connection is idle.
+    Idle,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub enum Value {}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub struct Column {
+    /// name of the column
+    name: String,
+    /// Declared type of the column, if any.
+    decl_ty: Option<String>,
+}
+
+/// for now, the stringified version of a sqld::error::Error.
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub struct StepError(String);
+
+/// TBD
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub enum ProxyError {}
+
+/// TBD
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
+pub enum ReplicationError {}
diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs
new file mode 100644
index 00000000..08c205ef
--- /dev/null
+++ b/libsqlx-server/src/linc/server.rs
@@ -0,0 +1,347 @@
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::task::JoinSet;
+
+use crate::linc::connection::Connection;
+
+use super::bus::Bus;
+
+pub struct Server {
+    /// reference to the bus
+    bus: Bus,
+    /// Connection tasks owned by the server
+    connections: JoinSet<color_eyre::Result<()>>,
+}
+
+impl Server {
+    pub fn new(bus: Bus) -> Self {
+        Self {
+            bus,
+            connections: JoinSet::new(),
+        }
+    }
+
+    /// Close all connections
+    #[cfg(test)]
+    pub async fn close_connections(&mut self) {
+        self.connections.abort_all();
+        while self.connections.join_next().await.is_some() {}
+        assert!(self.bus.is_empty());
+    }
+
+    pub async fn run<L>(mut self, mut listener: L)
+    where
+        L: super::net::Listener,
+    {
+        while self.tick(&mut listener).await {}
+    }
+
+    pub async fn tick<L>(&mut self, listener: &mut L) -> bool
+    where
+        L: super::net::Listener,
+    {
+        match listener.accept().await {
+            Ok((stream, _addr)) => {
+                self.make_connection(stream).await;
+                true
+            }
+            Err(e) => {
+                tracing::error!("error creating connection: {e}");
+                false
+            }
+        }
+    }
+
+    async fn make_connection<S>(&mut self, stream: S)
+    where
+        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
+    {
+        let bus = self.bus.clone();
+        let fut = async move {
+            let connection = Connection::new_acceptor(stream, bus.clone());
+            connection.run().await;
+            Ok(())
+        };
+
+        self.connections.spawn(fut);
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use crate::linc::{
+        proto::{ProxyMessage, StreamMessage},
+        DatabaseId, NodeId,
+    };
+
+    use super::*;
+
+    use futures::{SinkExt, StreamExt};
+    use tokio::sync::Notify;
+    use turmoil::net::TcpStream;
+
+    #[test]
+    fn server_respond_to_handshake() {
+        let mut sim = turmoil::Builder::new().build();
+
+        let host_node_id = NodeId::new_v4();
+        let notify = Arc::new(tokio::sync::Notify::new());
+        sim.host("host", move || {
+            let notify = notify.clone();
+            async move {
+                let bus = Bus::new(host_node_id);
+                let mut server = Server::new(bus);
+                let mut listener = turmoil::net::TcpListener::bind("0.0.0.0:1234")
+                    .await
+                    .unwrap();
+                server.tick(&mut listener).await;
+                notify.notified().await;
+
+                Ok(())
+            }
+        });
+
+        sim.client("client", async move {
+            let node_id = NodeId::new_v4();
+            let mut c = Connection::new_initiator(
+                TcpStream::connect("host:1234").await.unwrap(),
+                Bus::new(node_id),
+            );
+
+            c.tick().await;
+            c.tick().await;
+
+            assert_eq!(c.peer, Some(host_node_id));
+
+            Ok(())
+        });
+
+        sim.run().unwrap();
+    }
+
+    #[test]
+    fn client_create_stream_client_close() {
+        let mut sim = turmoil::Builder::new().build();
+
+        let host_node_id = NodeId::new_v4();
+        let stream_db_id = DatabaseId::new_v4();
+        let notify = Arc::new(Notify::new());
+        let expected_msg =
StreamMessage::Proxy(ProxyMessage::ProxyRequest { + connection_id: 12, + req_id: 1, + program: "hello".to_string(), + }); + + sim.host("host", { + let notify = notify.clone(); + let expected_msg = expected_msg.clone(); + move || { + let notify = notify.clone(); + let expected_msg = expected_msg.clone(); + async move { + let bus = Bus::new(host_node_id); + let server = Server::new(bus.clone()); + let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") + .await + .unwrap(); + let mut subs = bus.subscribe(stream_db_id).unwrap(); + tokio::task::spawn_local(server.run(listener)); + + let mut stream = subs.next().await.unwrap(); + + let msg = stream.next().await.unwrap(); + + assert_eq!(msg, expected_msg); + + notify.notify_waiters(); + + assert!(stream.next().await.is_none()); + + notify.notify_waiters(); + + Ok(()) + } + } + }); + + sim.client("client", async move { + let node_id = NodeId::new_v4(); + let bus = Bus::new(node_id); + let mut c = Connection::new_initiator( + TcpStream::connect("host:1234").await.unwrap(), + bus.clone(), + ); + c.tick().await; + c.tick().await; + let _h = tokio::spawn(c.run()); + let mut stream = bus.new_stream(host_node_id, stream_db_id).await.unwrap(); + stream.send(expected_msg).await.unwrap(); + + notify.notified().await; + + drop(stream); + + notify.notified().await; + + Ok(()) + }); + + sim.run().unwrap(); + } + + #[test] + fn client_create_stream_server_close() { + let mut sim = turmoil::Builder::new().build(); + + let host_node_id = NodeId::new_v4(); + let database_id = DatabaseId::new_v4(); + let notify = Arc::new(Notify::new()); + + sim.host("host", { + let notify = notify.clone(); + move || { + let notify = notify.clone(); + async move { + let bus = Bus::new(host_node_id); + let server = Server::new(bus.clone()); + let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") + .await + .unwrap(); + let mut subs = bus.subscribe(database_id).unwrap(); + tokio::task::spawn_local(server.run(listener)); + + let stream = subs.next().await.unwrap(); + drop(stream); + + notify.notify_waiters(); + notify.notified().await; + + Ok(()) + } + } + }); + + sim.client("client", async move { + let node_id = NodeId::new_v4(); + let bus = Bus::new(node_id); + let mut c = Connection::new_initiator( + TcpStream::connect("host:1234").await.unwrap(), + bus.clone(), + ); + c.tick().await; + c.tick().await; + let _h = tokio::spawn(c.run()); + let mut stream = bus.new_stream(host_node_id, database_id).await.unwrap(); + + notify.notified().await; + assert!(stream.next().await.is_none()); + notify.notify_waiters(); + + Ok(()) + }); + + sim.run().unwrap(); + } + + #[test] + fn server_create_stream_server_close() { + let mut sim = turmoil::Builder::new().build(); + + let host_node_id = NodeId::new_v4(); + let notify = Arc::new(Notify::new()); + let client_id = NodeId::new_v4(); + let database_id = DatabaseId::new_v4(); + let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { + connection_id: 12, + req_id: 1, + program: "hello".to_string(), + }); + + sim.host("host", { + let notify = notify.clone(); + let expected_msg = expected_msg.clone(); + move || { + let notify = notify.clone(); + let expected_msg = expected_msg.clone(); + async move { + let bus = Bus::new(host_node_id); + let server = Server::new(bus.clone()); + let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") + .await + .unwrap(); + tokio::task::spawn_local(server.run(listener)); + + let mut stream = bus.new_stream(client_id, database_id).await.unwrap(); + 
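+                    // `new_stream` registers the stream locally and sends an
+                    // OpenStream to the peer; the wire exchange for this test
+                    // then looks roughly like this (sketch derived from the
+                    // proto definitions above, handshake omitted):
+                    //
+                    //   host                                        client
+                    //    |-- Node(OpenStream { stream_id, db }) ----->|
+                    //    |-- Stream { stream_id, Proxy(ProxyRequest) }|
+                    //    |-- Node(CloseStream { stream_id }) -------->|  (on drop)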
stream.send(expected_msg).await.unwrap(); + notify.notified().await; + drop(stream); + + Ok(()) + } + } + }); + + sim.client("client", async move { + let bus = Bus::new(client_id); + let mut subs = bus.subscribe(database_id).unwrap(); + let c = Connection::new_initiator( + TcpStream::connect("host:1234").await.unwrap(), + bus.clone(), + ); + let _h = tokio::spawn(c.run()); + + let mut stream = subs.next().await.unwrap(); + let msg = stream.next().await.unwrap(); + assert_eq!(msg, expected_msg); + notify.notify_waiters(); + assert!(stream.next().await.is_none()); + + Ok(()) + }); + + sim.run().unwrap(); + } + + #[test] + fn server_create_stream_client_close() { + let mut sim = turmoil::Builder::new().build(); + + let host_node_id = NodeId::new_v4(); + let client_id = NodeId::new_v4(); + let database_id = DatabaseId::new_v4(); + + sim.host("host", { + move || async move { + let bus = Bus::new(host_node_id); + let server = Server::new(bus.clone()); + let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") + .await + .unwrap(); + tokio::task::spawn_local(server.run(listener)); + + let mut stream = bus.new_stream(client_id, database_id).await.unwrap(); + assert!(stream.next().await.is_none()); + + Ok(()) + } + }); + + sim.client("client", async move { + let bus = Bus::new(client_id); + let mut subs = bus.subscribe(database_id).unwrap(); + let c = Connection::new_initiator( + TcpStream::connect("host:1234").await.unwrap(), + bus.clone(), + ); + let _h = tokio::spawn(c.run()); + + let stream = subs.next().await.unwrap(); + drop(stream); + + Ok(()) + }); + + sim.run().unwrap(); + } +} diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index a8829093..16e6a38e 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -18,6 +18,7 @@ mod hrana; mod http; mod manager; mod meta; +mod linc; #[tokio::main] async fn main() -> Result<()> { From 4dad85ef7681ef7a7c6c94d8f9506bdc52fd02f8 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 11 Jul 2023 11:58:30 +0200 Subject: [PATCH 12/64] base configuration --- Cargo.lock | 3 --- libsqlx-server/src/config.rs | 26 ++++++++++++++++++++++++++ libsqlx-server/src/main.rs | 1 + 3 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 libsqlx-server/src/config.rs diff --git a/Cargo.lock b/Cargo.lock index 1e3d5a3d..5da186ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4124,7 +4124,6 @@ version = "0.15.0" dependencies = [ "anyhow", "arbitrary", - "async-bincode", "async-lock", "async-trait", "aws-config", @@ -4175,14 +4174,12 @@ dependencies = [ "tokio", "tokio-stream", "tokio-tungstenite", - "tokio-util", "tonic 0.8.3", "tonic-build", "tower", "tower-http", "tracing", "tracing-subscriber", - "turmoil", "url", "uuid", "vergen", diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs new file mode 100644 index 00000000..f163be0e --- /dev/null +++ b/libsqlx-server/src/config.rs @@ -0,0 +1,26 @@ +use std::net::SocketAddr; + +use serde::Deserialize; + +#[derive(Deserialize, Debug, Clone)] +pub struct Config { + cluster_config: ClusterConfig, + user_api_config: UserApiConfig, + admin_api_config: AdminApiConfig, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct ClusterConfig { + addr: SocketAddr, + peers: Vec<(u64, String)>, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct UserApiConfig { + addr: SocketAddr, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct AdminApiConfig { + addr: SocketAddr, +} diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 16e6a38e..957606a4 
100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -19,6 +19,7 @@ mod http; mod manager; mod meta; mod linc; +mod config; #[tokio::main] async fn main() -> Result<()> { From 5433c7508273634842e39eb62b5dec90630bf180 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 11 Jul 2023 12:19:14 +0200 Subject: [PATCH 13/64] load config --- Cargo.lock | 55 +++++++++++++++- libsqlx-server/Cargo.toml | 1 + libsqlx-server/src/allocation/mod.rs | 8 +-- libsqlx-server/src/config.rs | 23 +++++-- libsqlx-server/src/database.rs | 12 +++- libsqlx-server/src/hrana/batch.rs | 30 +++++---- libsqlx-server/src/hrana/http/mod.rs | 11 ++-- libsqlx-server/src/hrana/http/stream.rs | 24 +++++-- libsqlx-server/src/hrana/result_builder.rs | 4 +- libsqlx-server/src/hrana/stmt.rs | 43 +++++++----- libsqlx-server/src/http/admin.rs | 4 +- libsqlx-server/src/http/user/mod.rs | 9 ++- libsqlx-server/src/linc/bus.rs | 2 +- libsqlx-server/src/linc/connection.rs | 2 +- libsqlx-server/src/main.rs | 76 +++++++++++++++------- libsqlx/src/analysis.rs | 2 +- libsqlx/src/connection.rs | 2 +- libsqlx/src/database/libsql/connection.rs | 2 +- libsqlx/src/error.rs | 8 ++- libsqlx/src/lib.rs | 2 +- libsqlx/src/program.rs | 14 ++-- 21 files changed, 232 insertions(+), 102 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5da186ff..2eb87fc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2559,6 +2559,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", + "toml 0.7.6", "tracing", "tracing-subscriber", "turmoil", @@ -3932,6 +3933,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96426c9936fd7a0124915f9185ea1d20aa9445cc9821142f0a73bc9207a2e186" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -4576,6 +4586,40 @@ dependencies = [ "serde", ] +[[package]] +name = "toml" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17e963a819c331dcacd7ab957d80bc2b9a9c1e71c804826d2f283dd65306542" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.19.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c500344a19072298cd05a7224b3c0c629348b78692bf48466c5238656e315a78" +dependencies = [ + "indexmap 2.0.0", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tonic" version = "0.8.3" @@ -5223,7 +5267,7 @@ dependencies = [ "rustix 0.37.23", "serde", "sha2", - "toml", + "toml 0.5.11", "windows-sys 0.48.0", "zstd", ] @@ -5713,6 +5757,15 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "winnow" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81a2094c43cc94775293eaa0e499fbc30048a6d824ac82c0351a8c0bf9112529" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.10.1" diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index b4f9d366..6393f6cb 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -31,6 +31,7 @@ 
thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } tokio-stream = "0.1.14" tokio-util = "0.7.8" +toml = "0.7.6" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } uuid = { version = "1.4.0", features = ["v4"] } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index a086f479..b0393165 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -27,7 +27,7 @@ pub enum AllocationMessage { HranaPipelineReq { req: PipelineRequestBody, ret: oneshot::Sender>, - } + }, } pub enum Database { @@ -87,8 +87,9 @@ pub struct ConnectionHandle { impl ConnectionHandle { pub async fn exec(&self, f: F) -> crate::Result - where F: for<'a> FnOnce(&'a mut (dyn libsqlx::Connection + 'a)) -> R + Send + 'static, - R: Send + 'static, + where + F: for<'a> FnOnce(&'a mut (dyn libsqlx::Connection + 'a)) -> R + Send + 'static, + R: Send + 'static, { let (sender, ret) = oneshot::channel(); let cb = move |conn: &mut dyn libsqlx::Connection| { @@ -154,7 +155,6 @@ impl Allocation { exec: exec_sender, exit: close_sender, } - } fn next_conn_id(&mut self) -> u32 { diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs index f163be0e..bd7778b8 100644 --- a/libsqlx-server/src/config.rs +++ b/libsqlx-server/src/config.rs @@ -1,26 +1,35 @@ use std::net::SocketAddr; +use std::path::PathBuf; use serde::Deserialize; #[derive(Deserialize, Debug, Clone)] pub struct Config { - cluster_config: ClusterConfig, - user_api_config: UserApiConfig, - admin_api_config: AdminApiConfig, + pub db_path: PathBuf, + pub cluster_config: ClusterConfig, + pub user_api_config: UserApiConfig, + pub admin_api_config: AdminApiConfig, +} + +impl Config { + pub fn validate(&self) -> color_eyre::Result<()> { + // TODO: implement validation + Ok(()) + } } #[derive(Deserialize, Debug, Clone)] pub struct ClusterConfig { - addr: SocketAddr, - peers: Vec<(u64, String)>, + pub addr: SocketAddr, + pub peers: Vec<(u64, String)>, } #[derive(Deserialize, Debug, Clone)] pub struct UserApiConfig { - addr: SocketAddr, + pub addr: SocketAddr, } #[derive(Deserialize, Debug, Clone)] pub struct AdminApiConfig { - addr: SocketAddr, + pub addr: SocketAddr, } diff --git a/libsqlx-server/src/database.rs b/libsqlx-server/src/database.rs index e4971b74..4945cd70 100644 --- a/libsqlx-server/src/database.rs +++ b/libsqlx-server/src/database.rs @@ -1,16 +1,22 @@ use tokio::sync::{mpsc, oneshot}; -use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use crate::allocation::AllocationMessage; +use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; pub struct Database { pub sender: mpsc::Sender, } impl Database { - pub async fn hrana_pipeline(&self, req: PipelineRequestBody) -> crate::Result { + pub async fn hrana_pipeline( + &self, + req: PipelineRequestBody, + ) -> crate::Result { let (sender, ret) = oneshot::channel(); - self.sender.send(AllocationMessage::HranaPipelineReq { req, ret: sender }).await.unwrap(); + self.sender + .send(AllocationMessage::HranaPipelineReq { req, ret: sender }) + .await + .unwrap(); ret.await.unwrap() } } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index 7d2a1f0c..1368991e 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -10,7 +10,7 @@ use super::{proto, ProtocolError, Version}; use color_eyre::eyre::anyhow; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; 
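+// `ConnectionHandle::exec` (allocation/mod.rs above) ships a closure to the
+// connection task and returns its result over a oneshot channel; the helpers
+// in this file all build on it. A minimal call looks like this (sketch only,
+// assuming `conn: &ConnectionHandle` and a `Send + 'static` return type):
+//
+//   let n = conn
+//       .exec(|c| {
+//           // `c` is `&mut dyn libsqlx::Connection` for the closure's duration
+//           1u64
+//       })
+//       .await?;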
-use libsqlx::query::{Query, Params}; +use libsqlx::query::{Params, Query}; use libsqlx::result_builder::{StepResult, StepResultsBuilder}; fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre::Result { @@ -73,11 +73,13 @@ pub async fn execute_batch( db: &ConnectionHandle, pgm: Program, ) -> color_eyre::Result { - let builder = db.exec(move |conn| -> color_eyre::Result<_> { - let mut builder = HranaBatchProtoBuilder::default(); - conn.execute_program(pgm, &mut builder)?; - Ok(builder) - }).await??; + let builder = db + .exec(move |conn| -> color_eyre::Result<_> { + let mut builder = HranaBatchProtoBuilder::default(); + conn.execute_program(pgm, &mut builder)?; + Ok(builder) + }) + .await??; Ok(builder.into_ret()) } @@ -104,18 +106,18 @@ pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { }) .collect(); - Ok(Program { - steps, - }) + Ok(Program { steps }) } pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { - let builder = conn.exec(move |conn| -> color_eyre::Result<_> { - let mut builder = StepResultsBuilder::default(); - conn.execute_program(pgm, &mut builder)?; + let builder = conn + .exec(move |conn| -> color_eyre::Result<_> { + let mut builder = StepResultsBuilder::default(); + conn.execute_program(pgm, &mut builder)?; - Ok(builder) - }).await??; + Ok(builder) + }) + .await??; builder .into_ret() diff --git a/libsqlx-server/src/hrana/http/mod.rs b/libsqlx-server/src/hrana/http/mod.rs index 5e22bedc..651ab3f0 100644 --- a/libsqlx-server/src/hrana/http/mod.rs +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -49,10 +49,11 @@ fn handle_index() -> color_eyre::Result> { pub async fn handle_pipeline( server: &Server, req: PipelineRequestBody, - mk_conn: F + mk_conn: F, ) -> color_eyre::Result -where F: FnOnce() -> Fut, - Fut: Future>, +where + F: FnOnce() -> Fut, + Fut: Future>, { let mut stream_guard = stream::acquire(server, req.baton.as_deref(), mk_conn).await?; @@ -73,7 +74,9 @@ where F: FnOnce() -> Fut, Ok(resp_body) } -async fn read_request_json(req: hyper::Request) -> color_eyre::Result { +async fn read_request_json( + req: hyper::Request, +) -> color_eyre::Result { let req_body = hyper::body::to_bytes(req.into_body()) .await .context("Could not read request body")?; diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs index 1261e7c2..5f40537e 100644 --- a/libsqlx-server/src/hrana/http/stream.rs +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -106,8 +106,9 @@ pub async fn acquire<'srv, F, Fut>( baton: Option<&str>, mk_conn: F, ) -> color_eyre::Result> -where F: FnOnce() -> Fut, - Fut: Future>, +where + F: FnOnce() -> Fut, + Fut: Future>, { let stream = match baton { Some(baton) => { @@ -117,7 +118,10 @@ where F: FnOnce() -> Fut, let handle = state.handles.get_mut(&stream_id); match handle { None => { - return Err(ProtocolError::BatonInvalid(format!("Stream handle for {stream_id} was not found")).into()) + return Err(ProtocolError::BatonInvalid(format!( + "Stream handle for {stream_id} was not found" + )) + .into()) } Some(Handle::Acquired) => { return Err(ProtocolError::BatonReused) @@ -149,7 +153,9 @@ where F: FnOnce() -> Fut, stream } None => { - let conn = mk_conn().await.context("Could not create a database connection")?; + let conn = mk_conn() + .await + .context("Could not create a database connection")?; let mut state = server.stream_state.lock(); let stream = Box::new(Stream { @@ -291,7 +297,8 @@ fn decode_baton(server: &Server, baton_str: &str) -> 
color_eyre::Result<(u64, u6 return Err(ProtocolError::BatonInvalid(format!( "Baton has invalid size of {} bytes", baton_data.len() - )).into()); + )) + .into()); } let payload = &baton_data[0..16]; @@ -299,8 +306,11 @@ fn decode_baton(server: &Server, baton_str: &str) -> color_eyre::Result<(u64, u6 let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); hmac.update(payload); - hmac.verify_slice(received_mac) - .map_err(|_| anyhow!(ProtocolError::BatonInvalid("Invalid MAC on baton".to_string())))?; + hmac.verify_slice(received_mac).map_err(|_| { + anyhow!(ProtocolError::BatonInvalid( + "Invalid MAC on baton".to_string() + )) + })?; let stream_id = u64::from_be_bytes(payload[0..8].try_into().unwrap()); let baton_seq = u64::from_be_bytes(payload[8..16].try_into().unwrap()); diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index 94b23775..b6b8c635 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -194,10 +194,10 @@ impl ResultBuilder for SingleStatementBuilder { } fn finish( - &mut self, + &mut self, _is_txn: bool, _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { Ok(()) } } diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index e74c3d42..5453ab5c 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; -use color_eyre::eyre::{bail, anyhow}; +use color_eyre::eyre::{anyhow, bail}; use libsqlx::analysis::Statement; -use libsqlx::query::{Query, Params, Value}; +use libsqlx::query::{Params, Query, Value}; use super::result_builder::SingleStatementBuilder; use super::{proto, ProtocolError, Version}; @@ -47,14 +47,15 @@ pub async fn execute_stmt( conn: &ConnectionHandle, query: Query, ) -> color_eyre::Result { - let builder = conn.exec(move |conn| -> color_eyre::Result<_> { - let mut builder = SingleStatementBuilder::default(); - let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute_program(pgm, &mut builder)?; + let builder = conn + .exec(move |conn| -> color_eyre::Result<_> { + let mut builder = SingleStatementBuilder::default(); + let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); + conn.execute_program(pgm, &mut builder)?; - Ok(builder) - - }).await??; + Ok(builder) + }) + .await??; builder .into_ret() @@ -191,21 +192,27 @@ fn proto_value_from_value(value: Value) -> proto::Value { // } // } -pub fn stmt_error_from_sqld_error(sqld_error: libsqlx::error::Error) -> Result { +pub fn stmt_error_from_sqld_error( + sqld_error: libsqlx::error::Error, +) -> Result { Ok(match sqld_error { libsqlx::error::Error::LibSqlInvalidQueryParams(msg) => StmtError::ArgsInvalid { msg }, libsqlx::error::Error::LibSqlTxTimeout => StmtError::TransactionTimeout, libsqlx::error::Error::LibSqlTxBusy => StmtError::TransactionBusy, libsqlx::error::Error::Blocked(reason) => StmtError::Blocked { reason }, libsqlx::error::Error::RusqliteError(rusqlite_error) => match rusqlite_error { - libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => StmtError::SqliteError { - source: sqlite_error, - message, - }, - libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => StmtError::SqliteError { - message: sqlite_error.to_string(), - source: sqlite_error, - }, + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => { + 
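+                // SQLite sometimes attaches its own message to a failure;
+                // forward it verbatim when present, and fall back to the
+                // error code's Display output in the `None` arm below.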
StmtError::SqliteError { + source: sqlite_error, + message, + } + } + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => { + StmtError::SqliteError { + message: sqlite_error.to_string(), + source: sqlite_error, + } + } libsqlx::error::RusqliteError::SqlInputError { error: sqlite_error, msg: message, diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 6b23ef58..346987c4 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -11,7 +11,7 @@ use crate::{ meta::Store, }; -pub struct AdminApiConfig { +pub struct Config { pub meta_store: Arc, } @@ -19,7 +19,7 @@ struct AdminServerState { meta_store: Arc, } -pub async fn run_admin_api(config: AdminApiConfig, listener: I) -> Result<()> +pub async fn run_admin_api(config: Config, listener: I) -> Result<()> where I: Accept, I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs index 4c314a39..bc3265e9 100644 --- a/libsqlx-server/src/http/user/mod.rs +++ b/libsqlx-server/src/http/user/mod.rs @@ -13,7 +13,7 @@ use crate::manager::Manager; mod error; mod extractors; -pub struct UserApiConfig { +pub struct Config { pub manager: Arc, } @@ -21,7 +21,7 @@ struct UserApiState { manager: Arc, } -pub async fn run_user_api(config: UserApiConfig, listener: I) -> Result<()> +pub async fn run_user_api(config: Config, listener: I) -> Result<()> where I: Accept, I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, @@ -41,7 +41,10 @@ where Ok(()) } -async fn handle_hrana_pipeline(db: Database, Json(req): Json) -> Json { +async fn handle_hrana_pipeline( + db: Database, + Json(req): Json, +) -> Json { let resp = db.hrana_pipeline(req).await; dbg!(); Json(resp.unwrap()) diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 6fcc3238..f9533347 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -1,7 +1,7 @@ use std::collections::{hash_map::Entry, HashMap}; use std::sync::Arc; -use color_eyre::eyre::{bail, anyhow}; +use color_eyre::eyre::{anyhow, bail}; use parking_lot::Mutex; use tokio::sync::{mpsc, Notify}; use uuid::Uuid; diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index b5a8ce25..1d598cef 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use async_bincode::tokio::AsyncBincodeStream; use async_bincode::AsyncDestination; -use color_eyre::eyre::{bail, anyhow}; +use color_eyre::eyre::{anyhow, bail}; use futures::{SinkExt, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::mpsc::error::TrySendError; diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 957606a4..cab52d5d 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -1,10 +1,12 @@ -use std::{path::PathBuf, sync::Arc}; +use std::fs::read_to_string; +use std::path::PathBuf; +use std::sync::Arc; +use clap::Parser; use color_eyre::eyre::Result; -use http::{ - admin::{run_admin_api, AdminApiConfig}, - user::{run_user_api, UserApiConfig}, -}; +use config::{AdminApiConfig, UserApiConfig}; +use http::admin::run_admin_api; +use http::user::run_user_api; use hyper::server::conn::AddrIncoming; use manager::Manager; use meta::Store; @@ -13,35 +15,65 @@ use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; mod allocation; +mod config; mod database; mod hrana; mod http; +mod 
linc; mod manager; mod meta; -mod linc; -mod config; + +#[derive(Debug, Parser)] +struct Args { + /// Path to the node configuration file + #[clap(long, short)] + config: PathBuf, +} + +async fn spawn_admin_api( + set: &mut JoinSet>, + config: &AdminApiConfig, + meta_store: Arc, +) -> Result<()> { + let admin_api_listener = tokio::net::TcpListener::bind(config.addr).await?; + let fut = run_admin_api( + http::admin::Config { meta_store }, + AddrIncoming::from_listener(admin_api_listener)?, + ); + set.spawn(fut); + + Ok(()) +} + +async fn spawn_user_api( + set: &mut JoinSet>, + config: &UserApiConfig, + manager: Arc, +) -> Result<()> { + let user_api_listener = tokio::net::TcpListener::bind(config.addr).await?; + set.spawn(run_user_api( + http::user::Config { manager }, + AddrIncoming::from_listener(user_api_listener)?, + )); + + Ok(()) +} #[tokio::main] async fn main() -> Result<()> { init(); + let args = Args::parse(); + let config_str = read_to_string(args.config)?; + let config: config::Config = toml::from_str(&config_str)?; + config.validate()?; + let mut join_set = JoinSet::new(); - let db_path = PathBuf::from("database"); - let store = Arc::new(Store::new(&db_path)); - let admin_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3456").await?; - join_set.spawn(run_admin_api( - AdminApiConfig { - meta_store: store.clone(), - }, - AddrIncoming::from_listener(admin_api_listener)?, - )); + let store = Arc::new(Store::new(&config.db_path)); + let manager = Arc::new(Manager::new(config.db_path.clone(), store.clone(), 100)); - let manager = Arc::new(Manager::new(db_path.clone(), store, 100)); - let user_api_listener = tokio::net::TcpListener::bind("0.0.0.0:3457").await?; - join_set.spawn(run_user_api( - UserApiConfig { manager }, - AddrIncoming::from_listener(user_api_listener)?, - )); + spawn_admin_api(&mut join_set, &config.admin_api_config, store.clone()).await?; + spawn_user_api(&mut join_set, &config.user_api_config, manager).await?; join_set.join_next().await; diff --git a/libsqlx/src/analysis.rs b/libsqlx/src/analysis.rs index fccbf3dc..97ef5f5b 100644 --- a/libsqlx/src/analysis.rs +++ b/libsqlx/src/analysis.rs @@ -258,7 +258,7 @@ impl Statement { found: Some(found), }, Some((line, col)), - )) => Some(Err(crate::error::Error::SyntaxError { line, col, found})), + )) => Some(Err(crate::error::Error::SyntaxError { line, col, found })), Err(e) => Some(Err(e.into())), } }) diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index 38d31964..e2fd05f8 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -2,7 +2,7 @@ use rusqlite::types::Value; use crate::program::{Program, Step}; use crate::query::Query; -use crate::result_builder::{ResultBuilder, QueryBuilderConfig, QueryResultBuilderError}; +use crate::result_builder::{QueryBuilderConfig, QueryResultBuilderError, ResultBuilder}; #[derive(Debug, Clone)] pub struct DescribeResponse { diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 554a22da..cd5bcff3 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -177,7 +177,7 @@ impl LibsqlConnection { query .params .bind(&mut stmt) - .map_err(|e|Error::LibSqlInvalidQueryParams(e.to_string()))?; + .map_err(|e| Error::LibSqlInvalidQueryParams(e.to_string()))?; let mut qresult = stmt.raw_query(); builder.begin_rows()?; diff --git a/libsqlx/src/error.rs b/libsqlx/src/error.rs index 47fde1ae..07a71831 100644 --- a/libsqlx/src/error.rs +++ b/libsqlx/src/error.rs 
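The spawn_admin_api/spawn_user_api helpers above share one shape: bind a TCP
listener, wrap it in hyper's AddrIncoming, and push the server future onto the
caller's JoinSet. A generic version could look like this (sketch only;
`spawn_api` is not part of this patch, and it assumes the same hyper/tokio
APIs the helpers already use):

    async fn spawn_api<F>(
        set: &mut JoinSet<Result<()>>,
        addr: std::net::SocketAddr,
        serve: impl FnOnce(AddrIncoming) -> F,
    ) -> Result<()>
    where
        F: std::future::Future<Output = Result<()>> + Send + 'static,
    {
        let listener = tokio::net::TcpListener::bind(addr).await?;
        set.spawn(serve(AddrIncoming::from_listener(listener)?));
        Ok(())
    }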
@@ -1,6 +1,6 @@ use crate::result_builder::QueryResultBuilderError; -pub use rusqlite::Error as RusqliteError; pub use rusqlite::ffi::ErrorCode; +pub use rusqlite::Error as RusqliteError; #[allow(clippy::enum_variant_names)] #[derive(Debug, thiserror::Error)] @@ -39,10 +39,12 @@ pub enum Error { UnsupportedStatement, #[error("Syntax error at {line}:{col}: {found}")] SyntaxError { - line: u64, col: usize, found: String + line: u64, + col: usize, + found: String, }, #[error(transparent)] - LexerError(#[from] sqlite3_parser::lexer::sql::Error) + LexerError(#[from] sqlite3_parser::lexer::sql::Error), } impl From for Error { diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index f9ef106d..a89c2771 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -12,8 +12,8 @@ pub type Result = std::result::Result; pub use connection::Connection; pub use database::libsql; +pub use database::libsql::replication_log::FrameNo; pub use database::proxy; pub use database::Database; -pub use database::libsql::replication_log::FrameNo; pub use rusqlite; diff --git a/libsqlx/src/program.rs b/libsqlx/src/program.rs index 131dd125..0b5c7980 100644 --- a/libsqlx/src/program.rs +++ b/libsqlx/src/program.rs @@ -25,14 +25,16 @@ impl Program { /// transforms a collection of queries into a batch program. The execution of each query /// depends on the success of the previous one. pub fn from_queries(qs: impl IntoIterator) -> Self { - let steps = qs.into_iter().enumerate().map(|(idx, query)| Step { - cond: (idx > 0).then(|| Cond::Ok { step: idx - 1 }), - query, - }) - .collect(); + let steps = qs + .into_iter() + .enumerate() + .map(|(idx, query)| Step { + cond: (idx > 0).then(|| Cond::Ok { step: idx - 1 }), + query, + }) + .collect(); Self { steps } - } #[cfg(test)] From 23ca93feab6fe7fb84e7bc7c86f6d73ba16360af Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 11 Jul 2023 12:22:57 +0200 Subject: [PATCH 14/64] add config defaults --- libsqlx-server/src/config.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs index bd7778b8..0511d4f0 100644 --- a/libsqlx-server/src/config.rs +++ b/libsqlx-server/src/config.rs @@ -5,6 +5,7 @@ use serde::Deserialize; #[derive(Deserialize, Debug, Clone)] pub struct Config { + #[serde(default = "default_db_path")] pub db_path: PathBuf, pub cluster_config: ClusterConfig, pub user_api_config: UserApiConfig, @@ -26,10 +27,24 @@ pub struct ClusterConfig { #[derive(Deserialize, Debug, Clone)] pub struct UserApiConfig { + #[serde(default = "default_user_addr")] pub addr: SocketAddr, } #[derive(Deserialize, Debug, Clone)] pub struct AdminApiConfig { + #[serde(default = "default_admin_addr")] pub addr: SocketAddr, } + +fn default_db_path() -> PathBuf { + PathBuf::from("data.sqld") +} + +fn default_admin_addr() -> SocketAddr { + "0.0.0.0:8081".parse().unwrap() +} + +fn default_user_addr() -> SocketAddr { + "0.0.0.0:8080".parse().unwrap() +} From 4cab1b0c0dc293c4fb99f9974a60aca554043c98 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 11 Jul 2023 12:35:55 +0200 Subject: [PATCH 15/64] peer config --- libsqlx-server/src/config.rs | 56 +++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs index 0511d4f0..4772b53f 100644 --- a/libsqlx-server/src/config.rs +++ b/libsqlx-server/src/config.rs @@ -2,14 +2,19 @@ use std::net::SocketAddr; use std::path::PathBuf; use serde::Deserialize; +use serde::de::Visitor; 
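+// A full config file for the struct below might look like this TOML (values
+// illustrative; `peers` uses the `<id>:<addr>` string form parsed by the
+// `Peer` visitor at the bottom of this file, and fields carrying a
+// `#[serde(default = ...)]` attribute can be omitted):
+//
+//   db_path = "data.sqld"
+//
+//   [cluster]
+//   addr = "0.0.0.0:5001"
+//   peers = ["1:10.0.0.1", "2:10.0.0.2"]
+//
+//   [user_api]
+//   addr = "0.0.0.0:8080"
+//
+//   [admin_api]
+//   addr = "0.0.0.0:8081"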
#[derive(Deserialize, Debug, Clone)] pub struct Config { + /// Database path #[serde(default = "default_db_path")] pub db_path: PathBuf, - pub cluster_config: ClusterConfig, - pub user_api_config: UserApiConfig, - pub admin_api_config: AdminApiConfig, + /// Cluster configuration + pub cluster: ClusterConfig, + /// User API configuration + pub user_api: UserApiConfig, + /// Admin API configuration + pub admin_api: AdminApiConfig, } impl Config { @@ -21,8 +26,11 @@ impl Config { #[derive(Deserialize, Debug, Clone)] pub struct ClusterConfig { + /// Address to bind this node to + #[serde(default = "default_linc_addr")] pub addr: SocketAddr, - pub peers: Vec<(u64, String)>, + /// List of peers in the format `:` + pub peers: Vec, } #[derive(Deserialize, Debug, Clone)] @@ -48,3 +56,43 @@ fn default_admin_addr() -> SocketAddr { fn default_user_addr() -> SocketAddr { "0.0.0.0:8080".parse().unwrap() } + +fn default_linc_addr() -> SocketAddr { + "0.0.0.0:5001".parse().unwrap() +} + +#[derive(Debug, Clone)] +struct Peer { + id: u64, + addr: String, +} + +impl<'de> Deserialize<'de> for Peer { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + struct V; + + impl Visitor<'_> for V { + type Value = Peer; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string in the format :") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, { + + let mut iter = v.split(":"); + let Some(id) = iter.next() else { return Err(E::custom("node id is missing")) }; + let Ok(id) = id.parse::() else { return Err(E::custom("failed to parse node id")) }; + let Some(addr) = iter.next() else { return Err(E::custom("node address is missing")) }; + Ok(Peer { id, addr: addr.to_string() }) + } + } + + deserializer.deserialize_str(V) + } +} + From c40482ddc9a64505fae378a6a3e26352c1ec1445 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 11 Jul 2023 12:35:55 +0200 Subject: [PATCH 16/64] peer config --- libsqlx-server/src/main.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index cab52d5d..a8360402 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -72,8 +72,8 @@ async fn main() -> Result<()> { let store = Arc::new(Store::new(&config.db_path)); let manager = Arc::new(Manager::new(config.db_path.clone(), store.clone(), 100)); - spawn_admin_api(&mut join_set, &config.admin_api_config, store.clone()).await?; - spawn_user_api(&mut join_set, &config.user_api_config, manager).await?; + spawn_admin_api(&mut join_set, &config.admin_api, store.clone()).await?; + spawn_user_api(&mut join_set, &config.user_api, manager).await?; join_set.join_next().await; From 198ed752f98f9d65df500c67a4283c7ff5bfd2e2 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 11 Jul 2023 19:20:23 +0200 Subject: [PATCH 17/64] changes to proto --- libsqlx-server/src/allocation/mod.rs | 5 - libsqlx-server/src/config.rs | 51 +-- libsqlx-server/src/http/user/mod.rs | 1 - libsqlx-server/src/linc/bus.rs | 189 ++-------- libsqlx-server/src/linc/connection.rs | 398 ++++++--------------- libsqlx-server/src/linc/connection_pool.rs | 9 +- libsqlx-server/src/linc/handler.rs | 6 + libsqlx-server/src/linc/mod.rs | 30 +- libsqlx-server/src/linc/proto.rs | 110 +----- libsqlx-server/src/linc/server.rs | 16 +- 10 files changed, 203 insertions(+), 612 deletions(-) create mode 100644 libsqlx-server/src/linc/handler.rs diff --git a/libsqlx-server/src/allocation/mod.rs 
b/libsqlx-server/src/allocation/mod.rs index b0393165..38c29ff3 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -133,11 +133,8 @@ impl Allocation { } async fn new_conn(&mut self) -> ConnectionHandle { - dbg!(); let id = self.next_conn_id(); - dbg!(); let conn = block_in_place(|| self.database.connect()); - dbg!(); let (close_sender, exit) = oneshot::channel(); let (exec_sender, exec_receiver) = mpsc::channel(1); let conn = Connection { @@ -147,9 +144,7 @@ impl Allocation { exec: exec_receiver, }; - dbg!(); self.connections_futs.spawn(conn.run()); - dbg!(); ConnectionHandle { exec: exec_sender, diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs index 4772b53f..cb2d68b5 100644 --- a/libsqlx-server/src/config.rs +++ b/libsqlx-server/src/config.rs @@ -1,8 +1,8 @@ use std::net::SocketAddr; use std::path::PathBuf; -use serde::Deserialize; use serde::de::Visitor; +use serde::Deserialize; #[derive(Deserialize, Debug, Clone)] pub struct Config { @@ -62,7 +62,7 @@ fn default_linc_addr() -> SocketAddr { } #[derive(Debug, Clone)] -struct Peer { +pub struct Peer { id: u64, addr: String, } @@ -70,29 +70,32 @@ struct Peer { impl<'de> Deserialize<'de> for Peer { fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de> { - struct V; - - impl Visitor<'_> for V { - type Value = Peer; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a string in the format :") - } - - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, { - - let mut iter = v.split(":"); - let Some(id) = iter.next() else { return Err(E::custom("node id is missing")) }; - let Ok(id) = id.parse::() else { return Err(E::custom("failed to parse node id")) }; - let Some(addr) = iter.next() else { return Err(E::custom("node address is missing")) }; - Ok(Peer { id, addr: addr.to_string() }) - } + D: serde::Deserializer<'de>, + { + struct V; + + impl Visitor<'_> for V { + type Value = Peer; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string in the format :") } - deserializer.deserialize_str(V) + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + let mut iter = v.split(":"); + let Some(id) = iter.next() else { return Err(E::custom("node id is missing")) }; + let Ok(id) = id.parse::() else { return Err(E::custom("failed to parse node id")) }; + let Some(addr) = iter.next() else { return Err(E::custom("node address is missing")) }; + Ok(Peer { + id, + addr: addr.to_string(), + }) + } } -} + deserializer.deserialize_str(V) + } +} diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs index bc3265e9..f357499c 100644 --- a/libsqlx-server/src/http/user/mod.rs +++ b/libsqlx-server/src/http/user/mod.rs @@ -46,6 +46,5 @@ async fn handle_hrana_pipeline( Json(req): Json, ) -> Json { let resp = db.hrana_pipeline(req).await; - dbg!(); Json(resp.unwrap()) } diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index f9533347..7c52ec42 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -1,186 +1,59 @@ -use std::collections::{hash_map::Entry, HashMap}; use std::sync::Arc; -use color_eyre::eyre::{anyhow, bail}; -use parking_lot::Mutex; -use tokio::sync::{mpsc, Notify}; use uuid::Uuid; -use super::connection::{ConnectionHandle, Stream}; +use super::{connection::SendQueue, handler::Handler, Outbound, Inbound}; type NodeId = Uuid; type 
DatabaseId = Uuid;
 
-#[must_use]
-pub struct Subscription {
-    receiver: mpsc::Receiver<Stream>,
-    bus: Bus,
-    database_id: DatabaseId,
+pub struct Bus<H> {
+    inner: Arc<BusInner<H>>,
 }
 
-impl Drop for Subscription {
-    fn drop(&mut self) {
-        self.bus
-            .inner
-            .lock()
-            .subscriptions
-            .remove(&self.database_id);
+impl<H> Clone for Bus<H> {
+    fn clone(&self) -> Self {
+        Self { inner: self.inner.clone() }
     }
 }
 
-impl futures::Stream for Subscription {
-    type Item = Stream;
-
-    fn poll_next(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        self.receiver.poll_recv(cx)
-    }
-}
-
-#[derive(Clone)]
-pub struct Bus {
-    inner: Arc<Mutex<BusInner>>,
-    pub node_id: NodeId,
-}
-
-enum ConnectionSlot {
-    Handle(ConnectionHandle),
-    // Interest in the connection when it becomes available
-    Interest(Arc<Notify>),
-}
-
-struct BusInner {
-    connections: HashMap<NodeId, ConnectionSlot>,
-    subscriptions: HashMap<DatabaseId, mpsc::Sender<Stream>>,
+struct BusInner<H> {
+    node_id: NodeId,
+    handler: H,
+    send_queue: SendQueue,
 }
 
-impl Bus {
-    pub fn new(node_id: NodeId) -> Self {
+impl<H: Handler> Bus<H> {
+    pub fn new(node_id: NodeId, handler: H) -> Self {
+        let send_queue = SendQueue::new();
         Self {
-            node_id,
-            inner: Arc::new(Mutex::new(BusInner {
-                connections: HashMap::new(),
-                subscriptions: HashMap::new(),
-            })),
+            inner: Arc::new(BusInner {
+                node_id,
+                handler,
+                send_queue,
+            }),
         }
     }
 
-    /// open a new stream to the database at `database_id` on the node `node_id`
-    pub async fn new_stream(
-        &self,
-        node_id: NodeId,
-        database_id: DatabaseId,
-    ) -> color_eyre::Result<Stream> {
-        let get_conn = || {
-            let mut lock = self.inner.lock();
-            match lock.connections.entry(node_id) {
-                Entry::Occupied(mut e) => match e.get_mut() {
-                    ConnectionSlot::Handle(h) => Ok(h.clone()),
-                    ConnectionSlot::Interest(notify) => Err(notify.clone()),
-                },
-                Entry::Vacant(e) => {
-                    let notify = Arc::new(Notify::new());
-                    e.insert(ConnectionSlot::Interest(notify.clone()));
-                    Err(notify)
-                }
-            }
-        };
-
-        let conn = match get_conn() {
-            Ok(conn) => conn,
-            Err(notify) => {
-                notify.notified().await;
-                get_conn().map_err(|_| anyhow!("failed to create stream"))?
-            }
-        };
-
-        conn.new_stream(database_id).await
+    pub fn node_id(&self) -> NodeId {
+        self.inner.node_id
     }
 
-    /// Notify a subscription that new stream was openned
-    pub async fn notify_subscription(
-        &mut self,
-        database_id: DatabaseId,
-        stream: Stream,
-    ) -> color_eyre::Result<()> {
-        let maybe_sender = self.inner.lock().subscriptions.get(&database_id).cloned();
-
-        match maybe_sender {
-            Some(sender) => {
-                if sender.send(stream).await.is_err() {
-                    bail!("subscription for {database_id} closed");
-                }
-
-                Ok(())
-            }
-            None => {
-                bail!("no subscription for {database_id}")
-            }
-        }
+    pub async fn incomming(&self, incomming: Inbound) {
+        self.inner.handler.handle(self, incomming);
     }
 
-    #[cfg(test)]
-    pub fn is_empty(&self) -> bool {
-        self.inner.lock().connections.is_empty()
+    pub async fn dispatch(&self, msg: Outbound) {
+        assert!(
+            msg.to != self.node_id(),
+            "trying to send a message to ourself!"
+        );
+        // This message is outbound.
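+        // On the peer, the mirror-image path is `incomming` above: the
+        // connection feeds received enveloppes to the bus, which hands them
+        // to its `Handler`. A minimal handler echoing everything back could
+        // look like this (sketch only; `Echo` is not part of this patch):
+        //
+        //   struct Echo;
+        //   impl Handler for Echo {
+        //       fn handle(&self, bus: &Bus<Self>, msg: Inbound) {
+        //           let reply = Outbound { to: msg.from, enveloppe: msg.enveloppe };
+        //           // `dispatch` is async; a real handler would hand `reply`
+        //           // off to a spawned task that calls `bus.dispatch(reply)`.
+        //       }
+        //   }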
+ self.inner.send_queue.enqueue(msg).await; } - #[must_use] - pub fn register_connection(&self, node_id: NodeId, conn: ConnectionHandle) -> Registration { - let mut lock = self.inner.lock(); - match lock.connections.entry(node_id) { - Entry::Occupied(mut e) => { - if let ConnectionSlot::Interest(ref notify) = e.get() { - notify.notify_waiters(); - } - - *e.get_mut() = ConnectionSlot::Handle(conn); - } - Entry::Vacant(e) => { - e.insert(ConnectionSlot::Handle(conn)); - } - } - - Registration { - bus: self.clone(), - node_id, - } - } - - pub fn subscribe(&self, database_id: DatabaseId) -> color_eyre::Result { - let (sender, receiver) = mpsc::channel(1); - { - let mut inner = self.inner.lock(); - - if inner.subscriptions.contains_key(&database_id) { - bail!("a subscription already exist for that database"); - } - - inner.subscriptions.insert(database_id, sender); - } - - Ok(Subscription { - receiver, - bus: self.clone(), - database_id, - }) - } -} - -pub struct Registration { - bus: Bus, - node_id: NodeId, -} + pub fn send_queue(&self) -> &SendQueue { + &self.inner.send_queue -impl Drop for Registration { - fn drop(&mut self) { - assert!(self - .bus - .inner - .lock() - .connections - .remove(&self.node_id) - .is_some()); } } diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index 1d598cef..55977623 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -2,156 +2,35 @@ use std::collections::HashMap; use async_bincode::tokio::AsyncBincodeStream; use async_bincode::AsyncDestination; -use color_eyre::eyre::{anyhow, bail}; +use color_eyre::eyre::bail; use futures::{SinkExt, StreamExt}; +use parking_lot::RwLock; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::mpsc; use tokio::time::{Duration, Instant}; -use tokio_util::sync::PollSender; -use crate::linc::proto::{NodeError, NodeMessage}; +use crate::linc::proto::ProtoError; use crate::linc::CURRENT_PROTO_VERSION; -use super::bus::{Bus, Registration}; -use super::proto::{Message, StreamId, StreamMessage}; -use super::{DatabaseId, NodeId}; -use super::{StreamIdAllocator, MAX_STREAM_MSG}; - -#[derive(Debug, Clone)] -pub struct ConnectionHandle { - connection_sender: mpsc::Sender, -} - -impl ConnectionHandle { - pub async fn new_stream(&self, database_id: DatabaseId) -> color_eyre::eyre::Result { - let (send, ret) = oneshot::channel(); - self.connection_sender - .send(ConnectionMessage::StreamCreate { - database_id, - ret: send, - }) - .await - .unwrap(); - - Ok(ret.await?) - } -} - -/// A Bidirectional stream between databases on two nodes. 
-#[derive(Debug)] -pub struct Stream { - stream_id: StreamId, - /// sender to the connection - sender: tokio_util::sync::PollSender, - /// incoming message for this stream - recv: tokio_stream::wrappers::ReceiverStream, -} - -impl futures::Sink for Stream { - type Error = tokio_util::sync::PollSendError; - - fn poll_ready( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.sender.poll_ready_unpin(cx) - } - - fn start_send( - mut self: std::pin::Pin<&mut Self>, - payload: StreamMessage, - ) -> Result<(), Self::Error> { - let stream_id = self.stream_id; - self.sender - .start_send_unpin(ConnectionMessage::Message(Message::Stream { - stream_id, - payload, - })) - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.sender.poll_flush_unpin(cx) - } - - fn poll_close( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.sender.poll_close_unpin(cx) - } -} - -impl futures::Stream for Stream { - type Item = StreamMessage; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.recv.poll_next_unpin(cx) - } -} - -impl Drop for Stream { - fn drop(&mut self) { - self.recv.close(); - assert!(self.recv.as_mut().try_recv().is_err()); - let mut sender = self.sender.clone(); - let id = self.stream_id; - if let Some(sender_ref) = sender.get_ref() { - // Try send here is mostly for turmoil, since it stops polling the future as soon as - // the test future returns which causes spawn to panic. In the tests, the channel will - // always have capacity. - if let Err(TrySendError::Full(m)) = - sender_ref.try_send(ConnectionMessage::CloseStream(id)) - { - tokio::task::spawn(async move { - let _ = sender.send(m).await; - }); - } - } - } -} - -struct StreamState { - sender: mpsc::Sender, -} +use super::bus::{Bus}; +use super::handler::Handler; +use super::proto::{Enveloppe, Message}; +use super::{NodeId, Outbound, Inbound}; /// A connection to another node. Manage the connection state, and (de)register streams with the /// `Bus` -pub struct Connection { +pub struct Connection { /// Id of the current node pub peer: Option, /// State of the connection pub state: ConnectionState, /// Sink/Stream for network messages - conn: AsyncBincodeStream, - /// Collection of streams for that connection - streams: HashMap, - /// internal connection messages - connection_messages: mpsc::Receiver, - connection_messages_sender: mpsc::Sender, + conn: AsyncBincodeStream, /// Are we the initiator of this connection? is_initiator: bool, - bus: Bus, - stream_id_allocator: StreamIdAllocator, - /// handle to the registration of this connection to the bus. 
- /// Dropping this deregister this connection from the bus - registration: Option, -} - -#[derive(Debug)] -pub enum ConnectionMessage { - StreamCreate { - database_id: DatabaseId, - ret: oneshot::Sender, - }, - CloseStream(StreamId), - Message(Message), + /// send queue for this connection + send_queue: Option>, + bus: Bus, } #[derive(Debug)] @@ -170,49 +49,61 @@ pub fn handshake_deadline() -> Instant { Instant::now() + HANDSHAKE_TIMEOUT } -impl Connection +// TODO: limit send queue depth +pub struct SendQueue { + senders: RwLock>>, +} + +impl SendQueue { + pub fn new() -> Self { + Self { + senders: Default::default(), + } + } + + pub async fn enqueue(&self, msg: Outbound) { + let sender = match self.senders.read().get(&msg.to) { + Some(sender) => sender.clone(), + None => todo!("no queue"), + }; + + sender.send(msg.enveloppe); + } + + pub fn register(&self, node_id: NodeId) -> mpsc::UnboundedReceiver { + let (sender, receiver) = mpsc::unbounded_channel(); + self.senders.write().insert(node_id, sender); + + receiver + } +} + +impl Connection where S: AsyncRead + AsyncWrite + Unpin, + H: Handler, { const MAX_CONNECTION_MESSAGES: usize = 128; - pub fn new_initiator(stream: S, bus: Bus) -> Self { - let (connection_messages_sender, connection_messages) = - mpsc::channel(Self::MAX_CONNECTION_MESSAGES); + pub fn new_initiator(stream: S, bus: Bus) -> Self { Self { peer: None, state: ConnectionState::Init, conn: AsyncBincodeStream::from(stream).for_async(), - streams: HashMap::new(), is_initiator: true, - bus, - stream_id_allocator: StreamIdAllocator::new(true), - connection_messages, - connection_messages_sender, - registration: None, + send_queue: None, + bus, } } - pub fn new_acceptor(stream: S, bus: Bus) -> Self { - let (connection_messages_sender, connection_messages) = - mpsc::channel(Self::MAX_CONNECTION_MESSAGES); + pub fn new_acceptor(stream: S, bus: Bus) -> Self { Connection { peer: None, state: ConnectionState::Connecting, - streams: HashMap::new(), - connection_messages, - connection_messages_sender, is_initiator: false, bus, + send_queue: None, conn: AsyncBincodeStream::from(stream).for_async(), - stream_id_allocator: StreamIdAllocator::new(false), - registration: None, - } - } - - pub fn handle(&self) -> ConnectionHandle { - ConnectionHandle { - connection_sender: self.connection_messages_sender.clone(), } } @@ -262,135 +153,34 @@ where self.state = ConnectionState::Close; } } - } - Some(command) = self.connection_messages.recv() => { - self.handle_command(command).await; }, + // TODO: pop send queue + Some(m) = self.send_queue.as_mut().unwrap().recv() => { + self.conn.feed(m).await.unwrap(); + // send as many as possible + while let Ok(m) = self.send_queue.as_mut().unwrap().try_recv() { + self.conn.feed(m).await.unwrap(); + } + self.conn.flush().await.unwrap(); + } else => { self.state = ConnectionState::Close; } } } - async fn handle_message(&mut self, message: Message) { - match message { - Message::Node(NodeMessage::OpenStream { - stream_id, - database_id, - }) => { - if self.streams.contains_key(&stream_id) { - self.send_message(Message::Node(NodeMessage::Error( - NodeError::StreamAlreadyExist(stream_id), - ))) - .await; - return; - } - let stream = self.create_stream(stream_id); - if let Err(e) = self.bus.notify_subscription(database_id, stream).await { - tracing::error!("{e}"); - self.send_message(Message::Node(NodeMessage::Error( - NodeError::UnknownDatabase(database_id, stream_id), - ))) - .await; - } - } - Message::Node(NodeMessage::Handshake { .. 
}) => { - self.close_error(anyhow!("unexpected handshake: closing connection")); - } - Message::Node(NodeMessage::CloseStream { stream_id: id }) => { - self.close_stream(id); - } - Message::Node(NodeMessage::Error(e @ NodeError::HandshakeVersionMismatch { .. })) => { - self.close_error(anyhow!("unexpected peer error: {e}")); - } - Message::Node(NodeMessage::Error(NodeError::UnknownStream(id))) => { - tracing::error!("unkown stream: {id}"); - self.close_stream(id); - } - Message::Node(NodeMessage::Error(e @ NodeError::StreamAlreadyExist(_))) => { - self.state = ConnectionState::CloseError(e.into()); - } - Message::Node(NodeMessage::Error(ref e @ NodeError::UnknownDatabase(_, stream_id))) => { - tracing::error!("{e}"); - self.close_stream(stream_id); - } - Message::Stream { stream_id, payload } => { - match self.streams.get_mut(&stream_id) { - Some(s) => { - // TODO: there is not stream-independant control-flow for now. - // When/if control-flow is implemented, it will be handled here. - if s.sender.send(payload).await.is_err() { - self.close_stream(stream_id); - } - } - None => { - self.send_message(Message::Node(NodeMessage::Error( - NodeError::UnknownStream(stream_id), - ))) - .await; - } - } - } - } + async fn handle_message(&mut self, enveloppe: Enveloppe) { + let incomming = Inbound { + from: self.peer.expect("peer id should be known at this point"), + enveloppe, + }; + self.bus.incomming(incomming).await; } fn close_error(&mut self, error: color_eyre::eyre::Error) { self.state = ConnectionState::CloseError(error); } - fn close_stream(&mut self, id: StreamId) { - self.streams.remove(&id); - } - - async fn handle_command(&mut self, command: ConnectionMessage) { - match command { - ConnectionMessage::Message(m) => { - self.send_message(m).await; - } - ConnectionMessage::CloseStream(stream_id) => { - self.close_stream(stream_id); - self.send_message(Message::Node(NodeMessage::CloseStream { stream_id })) - .await; - } - ConnectionMessage::StreamCreate { database_id, ret } => { - let Some(stream_id) = self.stream_id_allocator.allocate() else { - // TODO: We close the connection here, which will cause a reconnections, and - // reset the stream_id allocator. If that happens in practice, it should be very quick to - // re-establish a connection. If this is an issue, we can either start using - // i64 stream_ids, or use a smarter id allocator. 
- self.state = ConnectionState::CloseError(anyhow!("Ran out of stream ids")); - return - }; - assert_eq!(stream_id.is_positive(), self.is_initiator); - assert!(!self.streams.contains_key(&stream_id)); - let stream = self.create_stream(stream_id); - self.send_message(Message::Node(NodeMessage::OpenStream { - stream_id, - database_id, - })) - .await; - let _ = ret.send(stream); - } - } - } - - async fn send_message(&mut self, message: Message) { - if let Err(e) = self.conn.send(message).await { - self.close_error(e.into()); - } - } - - fn create_stream(&mut self, stream_id: StreamId) -> Stream { - let (sender, recv) = mpsc::channel(MAX_STREAM_MSG); - let stream = Stream { - stream_id, - sender: PollSender::new(self.connection_messages_sender.clone()), - recv: recv.into(), - }; - self.streams.insert(stream_id, StreamState { sender }); - stream - } - /// wait for a handshake response from peer pub async fn wait_handshake_response_with_deadline( &mut self, @@ -399,41 +189,49 @@ where assert!(matches!(self.state, ConnectionState::Connecting)); match tokio::time::timeout_at(deadline, self.conn.next()).await { - Ok(Some(Ok(Message::Node(NodeMessage::Handshake { - protocol_version, - node_id, - })))) => { + Ok(Some(Ok(Enveloppe { + message: + Message::Handshake { + protocol_version, + node_id, + }, + .. + }))) => { if protocol_version != CURRENT_PROTO_VERSION { - let _ = self - .conn - .send(Message::Node(NodeMessage::Error( - NodeError::HandshakeVersionMismatch { - expected: CURRENT_PROTO_VERSION, - }, - ))) - .await; + let msg = Enveloppe { + from: None, + to: None, + message: Message::Error(ProtoError::HandshakeVersionMismatch { + expected: CURRENT_PROTO_VERSION, + }), + }; + + let _ = self.conn.send(msg).await; bail!("handshake error: invalid peer protocol version"); } else { // when not initiating a connection, respond to handshake message with a // handshake message if !self.is_initiator { - self.conn - .send(Message::Node(NodeMessage::Handshake { + let msg = Enveloppe { + from: None, + to: None, + message: Message::Handshake { protocol_version: CURRENT_PROTO_VERSION, - node_id: self.bus.node_id, - })) - .await?; + node_id: self.bus.node_id(), + }, + }; + self.conn.send(msg).await?; } self.peer = Some(node_id); self.state = ConnectionState::Connected; - self.registration = Some(self.bus.register_connection(node_id, self.handle())); + self.send_queue = Some(self.bus.send_queue().register(node_id)); Ok(()) } } - Ok(Some(Ok(Message::Node(NodeMessage::Error(e))))) => { + Ok(Some(Ok(Enveloppe { message: Message::Error(e), ..}))) => { bail!("handshake error: {e}"); } Ok(Some(Ok(_))) => { @@ -452,12 +250,16 @@ where } async fn initiate_connection(&mut self) -> color_eyre::Result<()> { - self.conn - .send(Message::Node(NodeMessage::Handshake { + let msg = Enveloppe { + from: None, + to: None, + message: Message::Handshake { protocol_version: CURRENT_PROTO_VERSION, - node_id: self.bus.node_id, - })) - .await?; + node_id: self.bus.node_id(), + }, + }; + + self.conn.send(msg).await?; Ok(()) } diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index f5f29c61..812745d3 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -5,18 +5,19 @@ use tokio::task::JoinSet; use tokio::time::Duration; use super::connection::Connection; +use super::handler::Handler; use super::net::Connector; use super::{bus::Bus, NodeId}; /// Manages a pool of connections to other peers, handling re-connection. 
-struct ConnectionPool {
+struct ConnectionPool<H> {
     managed_peers: HashMap<NodeId, String>,
     connections: JoinSet<NodeId>,
-    bus: Bus,
+    bus: Bus<H>,
 }
 
-impl ConnectionPool {
-    pub fn new(bus: Bus, managed_peers: impl IntoIterator<Item = (NodeId, String)>) -> Self {
+impl<H: Handler> ConnectionPool<H> {
+    pub fn new(bus: Bus<H>, managed_peers: impl IntoIterator<Item = (NodeId, String)>) -> Self {
         Self {
             managed_peers: managed_peers.into_iter().collect(),
             connections: JoinSet::new(),
diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs
new file mode 100644
index 00000000..c8db9a89
--- /dev/null
+++ b/libsqlx-server/src/linc/handler.rs
@@ -0,0 +1,6 @@
+use super::{bus::{Bus}, Inbound};
+
+pub trait Handler: Sized + Send + Sync + 'static {
+    fn handle(&self, bus: &Bus<Self>, msg: Inbound);
+}
+
diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs
index 30b06285..8a3747bd 100644
--- a/libsqlx-server/src/linc/mod.rs
+++ b/libsqlx-server/src/linc/mod.rs
@@ -1,6 +1,6 @@
 use uuid::Uuid;
 
-use self::proto::StreamId;
+use self::proto::Enveloppe;
 
 pub mod bus;
 pub mod connection;
@@ -8,6 +8,7 @@ pub mod connection_pool;
 pub mod net;
 pub mod proto;
 pub mod server;
+pub mod handler;
 
 type NodeId = Uuid;
 type DatabaseId = Uuid;
@@ -16,23 +17,16 @@ const CURRENT_PROTO_VERSION: u32 = 1;
 const MAX_STREAM_MSG: usize = 64;
 
 #[derive(Debug)]
-pub struct StreamIdAllocator {
-    direction: i32,
-    next_id: i32,
+pub struct Inbound {
+    /// Id of the node sending the message
+    pub from: NodeId,
+    /// payload
+    pub enveloppe: Enveloppe,
 }
 
-impl StreamIdAllocator {
-    fn new(positive: bool) -> Self {
-        let direction = if positive { 1 } else { -1 };
-        Self {
-            direction,
-            next_id: direction,
-        }
-    }
-
-    pub fn allocate(&mut self) -> Option<StreamId> {
-        let id = self.next_id;
-        self.next_id = id.checked_add(self.direction)?;
-        Some(StreamId::new(id))
-    }
+#[derive(Debug)]
+pub struct Outbound {
+    pub to: NodeId,
+    pub enveloppe: Enveloppe,
 }
+
diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs
index 7de1002a..617f2d87 100644
--- a/libsqlx-server/src/linc/proto.rs
+++ b/libsqlx-server/src/linc/proto.rs
@@ -1,112 +1,42 @@
-use std::fmt;
-
 use bytes::Bytes;
-use serde::{de::Error, Deserialize, Deserializer, Serialize};
+use serde::{Deserialize, Serialize};
 use uuid::Uuid;
 
-use super::DatabaseId;
+use super::{DatabaseId};
 
 pub type Program = String;
 
-#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
-pub struct StreamId(#[serde(deserialize_with = "non_zero")] i32);
-
-impl fmt::Display for StreamId {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        self.0.fmt(f)
-    }
-}
-
-fn non_zero<'de, D>(d: D) -> Result<i32, D::Error>
-where
-    D: Deserializer<'de>,
-{
-    let value = i32::deserialize(d)?;
-
-    if value == 0 {
-        return Err(D::Error::custom("invalid stream_id"));
-    }
-
-    Ok(value)
-}
-
-impl StreamId {
-    /// creates a new stream_id.
-    /// panics if val is zero.
- pub fn new(val: i32) -> Self { - assert!(val != 0); - Self(val) - } - - pub fn is_positive(&self) -> bool { - self.0.is_positive() - } - - #[cfg(test)] - pub fn new_unchecked(i: i32) -> Self { - Self(i) - } -} - #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -pub enum Message { - /// Messages destined to a node - Node(NodeMessage), - /// message destined to a database - Stream { - stream_id: StreamId, - payload: StreamMessage, - }, +pub struct Enveloppe { + pub from: Option, + pub to: Option, + pub message: Message, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -pub enum NodeMessage { +pub enum Message { /// Initial message exchanged between nodes when connecting Handshake { protocol_version: u32, node_id: Uuid, }, - /// Request to open a bi-directional stream between the client and the server - OpenStream { - /// Id to give to the newly opened stream - /// Initiator of the connection create streams with positive ids, - /// and acceptor of the connection create streams with negative ids. - stream_id: StreamId, - /// Id of the database to open the stream to. - database_id: Uuid, - }, - /// Close a previously opened stream - CloseStream { stream_id: StreamId }, - /// Error type returned while handling a node message - Error(NodeError), + Replication(ReplicationMessage), + Proxy(ProxyMessage), + Error(ProtoError), } #[derive(Debug, Serialize, Deserialize, thiserror::Error, PartialEq, Eq)] -pub enum NodeError { - /// The requested stream does not exist - #[error("unknown stream: {0}")] - UnknownStream(StreamId), +pub enum ProtoError { /// Incompatible protocol versions #[error("invalid protocol version, expected: {expected}")] HandshakeVersionMismatch { expected: u32 }, - #[error("stream {0} already exists")] - StreamAlreadyExist(StreamId), - #[error("cannot open stream {1}: unknown database {0}")] - UnknownDatabase(DatabaseId, StreamId), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum StreamMessage { - /// Replication message between a replica and a primary - Replication(ReplicationMessage), - /// Proxy message between a replica and a primary - Proxy(ProxyMessage), - #[cfg(test)] - Dummy, + #[error("unknown database {0}")] + UnknownDatabase(DatabaseId), } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub enum ReplicationMessage { + Handshake {}, HandshakeResponse { /// id of the replication log log_id: Uuid, @@ -126,8 +56,6 @@ pub enum ReplicationMessage { /// a batch of frames part of the transaction. frames: Vec, }, - /// Error occurred handling a replication message - Error(ReplicationError), } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] @@ -161,8 +89,6 @@ pub enum ProxyMessage { CancelRequest { req_id: u32 }, /// Close Connection with passed id. CloseConnection { connection_id: u32 }, - /// Error returned when handling a proxied query message. - Error(ProxyError), } /// Steps applied to the query builder transducer to build a response to a proxied query. @@ -204,11 +130,3 @@ pub struct Column { /// for now, the stringified version of a sqld::error::Error. 
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct StepError(String); - -/// TBD -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ProxyError {} - -/// TBD -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ReplicationError {} diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs index 08c205ef..0594bd9e 100644 --- a/libsqlx-server/src/linc/server.rs +++ b/libsqlx-server/src/linc/server.rs @@ -3,17 +3,18 @@ use tokio::task::JoinSet; use crate::linc::connection::Connection; -use super::bus::Bus; +use super::bus::{Bus}; +use super::handler::Handler; -pub struct Server { +pub struct Server { /// reference to the bus - bus: Bus, + bus: Bus, /// Connection tasks owned by the server connections: JoinSet>, } -impl Server { - pub fn new(bus: Bus) -> Self { +impl Server { + pub fn new(bus: Bus) -> Self { Self { bus, connections: JoinSet::new(), @@ -25,7 +26,6 @@ impl Server { pub async fn close_connections(&mut self) { self.connections.abort_all(); while self.connections.join_next().await.is_some() {} - assert!(self.bus.is_empty()); } pub async fn run(mut self, mut listener: L) @@ -57,7 +57,7 @@ impl Server { { let bus = self.bus.clone(); let fut = async move { - let connection = Connection::new_acceptor(stream, bus.clone()); + let connection = Connection::new_acceptor(stream, bus); connection.run().await; Ok(()) }; @@ -71,7 +71,7 @@ mod test { use std::sync::Arc; use crate::linc::{ - proto::{ProxyMessage, StreamMessage}, + proto::{ProxyMessage}, DatabaseId, NodeId, }; From 58b3aaf74dadab21657eb4a6eede994576716465 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 12 Jul 2023 11:56:53 +0200 Subject: [PATCH 18/64] deliver message to allocations --- Cargo.lock | 23 ++++++- libsqlx-server/Cargo.toml | 2 + libsqlx-server/src/allocation/config.rs | 6 +- libsqlx-server/src/allocation/mod.rs | 35 ++++++++++- libsqlx-server/src/config.rs | 1 + libsqlx-server/src/http/admin.rs | 7 ++- libsqlx-server/src/http/user/extractors.rs | 7 ++- libsqlx-server/src/http/user/mod.rs | 4 ++ libsqlx-server/src/linc/bus.rs | 52 +++++++-------- libsqlx-server/src/linc/connection.rs | 18 +++--- libsqlx-server/src/linc/connection_pool.rs | 12 ++-- libsqlx-server/src/linc/handler.rs | 9 ++- libsqlx-server/src/linc/mod.rs | 23 ++++--- libsqlx-server/src/linc/proto.rs | 63 ++++++++++++++++--- libsqlx-server/src/linc/server.rs | 21 +++---- libsqlx-server/src/main.rs | 7 ++- libsqlx-server/src/manager.rs | 40 +++++++++--- libsqlx-server/src/meta.rs | 49 ++++++++++++--- libsqlx/Cargo.toml | 2 - libsqlx/src/database/libsql/injector/mod.rs | 2 +- libsqlx/src/database/libsql/mod.rs | 8 ++- .../database/libsql/replication_log/logger.rs | 29 ++++++--- libsqlx/src/error.rs | 8 --- sqld/src/query_result_builder.rs | 1 - 24 files changed, 303 insertions(+), 126 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2eb87fc8..2a3cc0f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2411,6 +2411,15 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "keccak" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f6d5ed8676d904364de097082f4e7d240b571b67989ced0240f08b7f966f940" +dependencies = [ + "cpufeatures", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -2511,7 +2520,6 @@ dependencies = [ "crc", "crossbeam", "fallible-iterator 0.3.0", - "futures", "itertools 0.11.0", "nix", "once_cell", @@ -2525,7 +2533,6 @@ dependencies = [ "sqlite3-parser 0.9.0", "tempfile", "thiserror", - "tokio", 
"tracing", "uuid", ] @@ -2535,6 +2542,7 @@ name = "libsqlx-server" version = "0.1.0" dependencies = [ "async-bincode", + "async-trait", "axum", "base64 0.21.2", "bincode", @@ -2554,6 +2562,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "sha3", "sled", "thiserror", "tokio", @@ -3986,6 +3995,16 @@ dependencies = [ "sha2", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "sharded-slab" version = "0.1.4" diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 6393f6cb..efff738e 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] async-bincode = { version = "0.7.1", features = ["tokio"] } +async-trait = "0.1.71" axum = "0.6.18" base64 = "0.21.2" bincode = "1.3.3" @@ -26,6 +27,7 @@ regex = "1.9.1" serde = { version = "1.0.166", features = ["derive", "rc"] } serde_json = "1.0.100" sha2 = "0.10.7" +sha3 = "0.10.8" sled = "0.34.7" thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs index f5839e9c..9d1bab34 100644 --- a/libsqlx-server/src/allocation/config.rs +++ b/libsqlx-server/src/allocation/config.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::linc::NodeId; + /// Structural supertype of AllocConfig, used for checking the meta version. Subsequent version of /// AllocConfig need to conform to this prototype. #[derive(Debug, Serialize, Deserialize)] @@ -10,12 +12,12 @@ struct ConfigVersion { #[derive(Debug, Serialize, Deserialize)] pub struct AllocConfig { pub max_conccurent_connection: u32, - pub id: String, + pub db_name: String, pub db_config: DbConfig, } #[derive(Debug, Serialize, Deserialize)] pub enum DbConfig { Primary {}, - Replica { primary_node_id: String }, + Replica { primary_node_id: NodeId }, } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 38c29ff3..c9f88ed0 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use std::sync::Arc; -use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType}; +use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; use libsqlx::Database as _; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; @@ -9,6 +9,9 @@ use tokio::task::{block_in_place, JoinSet}; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::linc::bus::Dispatch; +use crate::linc::{Inbound, NodeId}; +use crate::meta::DatabaseId; use self::config::{AllocConfig, DbConfig}; @@ -28,10 +31,15 @@ pub enum AllocationMessage { req: PipelineRequestBody, ret: oneshot::Sender>, }, + Inbound(Inbound), } pub enum Database { Primary(libsqlx::libsql::LibsqlDatabase), + Replica { + db: libsqlx::libsql::LibsqlDatabase, + primary_node_id: NodeId, + }, } struct Compactor; @@ -65,6 +73,7 @@ impl Database { fn connect(&self) -> Box { match self { Database::Primary(db) => Box::new(db.connect().unwrap()), + Database::Replica { db, .. 
} => Box::new(db.connect().unwrap()), } } } @@ -78,6 +87,9 @@ pub struct Allocation { pub max_concurrent_connections: u32, pub hrana_server: Arc, + /// handle to the message bus, to send messages + pub dispatcher: Arc, + pub db_name: String, } pub struct ConnectionHandle { @@ -115,11 +127,13 @@ impl Allocation { AllocationMessage::HranaPipelineReq { req, ret} => { let res = handle_pipeline(&self.hrana_server.clone(), req, || async { let conn= self.new_conn().await; - dbg!(); Ok(conn) }).await; let _ = ret.send(res); } + AllocationMessage::Inbound(msg) => { + self.handle_inbound(msg).await; + } } }, maybe_id = self.connections_futs.join_next() => { @@ -132,6 +146,23 @@ impl Allocation { } } + async fn handle_inbound(&mut self, msg: Inbound) { + debug_assert_eq!(msg.enveloppe.to, Some(DatabaseId::from_name(&self.db_name))); + + match msg.enveloppe.message { + crate::linc::proto::Message::Handshake { .. } => todo!(), + crate::linc::proto::Message::ReplicationHandshake { .. } => todo!(), + crate::linc::proto::Message::ReplicationHandshakeResponse { .. } => todo!(), + crate::linc::proto::Message::Replicate { .. } => todo!(), + crate::linc::proto::Message::Transaction { .. } => todo!(), + crate::linc::proto::Message::ProxyRequest { .. } => todo!(), + crate::linc::proto::Message::ProxyResponse { .. } => todo!(), + crate::linc::proto::Message::CancelRequest { .. } => todo!(), + crate::linc::proto::Message::CloseConnection { .. } => todo!(), + crate::linc::proto::Message::Error(_) => todo!(), + } + } + async fn new_conn(&mut self) -> ConnectionHandle { let id = self.next_conn_id(); let conn = block_in_place(|| self.database.connect()); diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs index cb2d68b5..f0f9ca0c 100644 --- a/libsqlx-server/src/config.rs +++ b/libsqlx-server/src/config.rs @@ -26,6 +26,7 @@ impl Config { #[derive(Deserialize, Debug, Clone)] pub struct ClusterConfig { + pub id: u64, /// Address to bind this node to #[serde(default = "default_linc_addr")] pub addr: SocketAddr, diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 346987c4..8a08187e 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -8,6 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ allocation::config::{AllocConfig, DbConfig}, + linc::NodeId, meta::Store, }; @@ -55,7 +56,7 @@ struct AllocateReq { #[serde(tag = "type", rename_all = "snake_case")] pub enum DbConfigReq { Primary {}, - Replica { primary_node_id: String }, + Replica { primary_node_id: NodeId }, } async fn allocate( @@ -64,7 +65,7 @@ async fn allocate( ) -> Result, Json> { let config = AllocConfig { max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16), - id: req.alloc_id.clone(), + db_name: req.alloc_id.clone(), db_config: match req.config { DbConfigReq::Primary {} => DbConfig::Primary {}, DbConfigReq::Replica { primary_node_id } => DbConfig::Replica { primary_node_id }, @@ -93,7 +94,7 @@ async fn list_allocs( .list_allocs() .await .into_iter() - .map(|cfg| AllocView { id: cfg.id }) + .map(|cfg| AllocView { id: cfg.db_name }) .collect(); Ok(Json(ListAllocResp { allocs })) diff --git a/libsqlx-server/src/http/user/extractors.rs b/libsqlx-server/src/http/user/extractors.rs index 2b3f5a14..962eb060 100644 --- a/libsqlx-server/src/http/user/extractors.rs +++ b/libsqlx-server/src/http/user/extractors.rs @@ -4,7 +4,7 @@ use axum::async_trait; use axum::extract::FromRequestParts; use hyper::http::request::Parts; -use crate::database::Database; 
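For reference, the extractor change below swaps the raw database name for a fixed-width key: the name parsed out of the Host header is hashed into a 16-byte `DatabaseId`. The hashing mirrors `DatabaseId::from_name` as added to meta.rs further down in this patch:

    use sha3::digest::{ExtendableOutput, Update, XofReader};
    use sha3::Shake128;

    // Shake128 is an extendable-output hash; squeezing exactly 16 bytes
    // yields a stable, fixed-width id for a human-readable database name.
    fn database_id_from_name(name: &str) -> [u8; 16] {
        let mut hasher = Shake128::default();
        hasher.update(name.as_bytes());
        let mut reader = hasher.finalize_xof();
        let mut out = [0; 16];
        reader.read(&mut out);
        out
    }

Both the sled meta-store keys and enveloppe addressing use this same 16-byte id, so lookups stay uniform regardless of name length.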
+use crate::{database::Database, meta::DatabaseId}; use super::{error::UserApiError, UserApiState}; @@ -18,8 +18,9 @@ impl FromRequestParts> for Database { ) -> Result { let Some(host) = parts.headers.get("host") else { return Err(UserApiError::MissingHost) }; let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; - let db_id = parse_host(host_str)?; - let Some(sender) = state.manager.alloc(db_id).await else { return Err(UserApiError::UnknownDatabase(db_id.to_owned())) }; + let db_name = parse_host(host_str)?; + let db_id = DatabaseId::from_name(db_name); + let Some(sender) = state.manager.alloc(db_id, state.bus.clone()).await else { return Err(UserApiError::UnknownDatabase(db_name.to_owned())) }; Ok(Database { sender }) } diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs index f357499c..c947fb8b 100644 --- a/libsqlx-server/src/http/user/mod.rs +++ b/libsqlx-server/src/http/user/mod.rs @@ -8,6 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::database::Database; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::linc::bus::Bus; use crate::manager::Manager; mod error; @@ -15,10 +16,12 @@ mod extractors; pub struct Config { pub manager: Arc, + pub bus: Arc>>, } struct UserApiState { manager: Arc, + bus: Arc>>, } pub async fn run_user_api(config: Config, listener: I) -> Result<()> @@ -28,6 +31,7 @@ where { let state = UserApiState { manager: config.manager, + bus: config.bus, }; let app = Router::new() diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 7c52ec42..bf9a2cc4 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -1,23 +1,8 @@ use std::sync::Arc; -use uuid::Uuid; - -use super::{connection::SendQueue, handler::Handler, Outbound, Inbound}; - -type NodeId = Uuid; -type DatabaseId = Uuid; +use super::{connection::SendQueue, handler::Handler, Inbound, NodeId, Outbound}; pub struct Bus { - inner: Arc>, -} - -impl Clone for Bus { - fn clone(&self) -> Self { - Self { inner: self.inner.clone() } - } -} - -struct BusInner { node_id: NodeId, handler: H, send_queue: SendQueue, @@ -27,33 +12,38 @@ impl Bus { pub fn new(node_id: NodeId, handler: H) -> Self { let send_queue = SendQueue::new(); Self { - inner: Arc::new(BusInner { - node_id, - handler, - send_queue, - }), + node_id, + handler, + send_queue, } } pub fn node_id(&self) -> NodeId { - self.inner.node_id + self.node_id } - pub async fn incomming(&self, incomming: Inbound) { - self.inner.handler.handle(self, incomming); + pub async fn incomming(self: &Arc, incomming: Inbound) { + self.handler.handle(self.clone(), incomming); } - pub async fn dispatch(&self, msg: Outbound) { + pub fn send_queue(&self) -> &SendQueue { + &self.send_queue + } +} + +#[async_trait::async_trait] +pub trait Dispatch: Send + Sync + 'static { + async fn dispatch(&self, msg: Outbound); +} + +#[async_trait::async_trait] +impl Dispatch for Bus { + async fn dispatch(&self, msg: Outbound) { assert!( msg.to != self.node_id(), "trying to send a message to ourself!" ); // This message is outbound. 
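Extracting the `Dispatch` trait in this hunk turns the bus into a swappable dependency: an allocation only needs an `Arc<dyn Dispatch>` to emit messages. That also gives tests a seam; a sketch of a test double follows, where the recorder type is hypothetical and not part of this patch:

    use tokio::sync::Mutex;

    // Hypothetical test double: records outbound messages instead of
    // pushing them onto a connection's send queue.
    #[derive(Default)]
    struct RecordingDispatcher {
        sent: Mutex<Vec<Outbound>>,
    }

    #[async_trait::async_trait]
    impl Dispatch for RecordingDispatcher {
        async fn dispatch(&self, msg: Outbound) {
            self.sent.lock().await.push(msg);
        }
    }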
- self.inner.send_queue.enqueue(msg).await; - } - - pub fn send_queue(&self) -> &SendQueue { - &self.inner.send_queue - + self.send_queue.enqueue(msg).await; } } diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index 55977623..a96b8179 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use async_bincode::tokio::AsyncBincodeStream; use async_bincode::AsyncDestination; @@ -12,10 +13,10 @@ use tokio::time::{Duration, Instant}; use crate::linc::proto::ProtoError; use crate::linc::CURRENT_PROTO_VERSION; -use super::bus::{Bus}; +use super::bus::Bus; use super::handler::Handler; use super::proto::{Enveloppe, Message}; -use super::{NodeId, Outbound, Inbound}; +use super::{Inbound, NodeId, Outbound}; /// A connection to another node. Manage the connection state, and (de)register streams with the /// `Bus` @@ -30,7 +31,7 @@ pub struct Connection { is_initiator: bool, /// send queue for this connection send_queue: Option>, - bus: Bus, + bus: Arc>, } #[derive(Debug)] @@ -85,18 +86,18 @@ where { const MAX_CONNECTION_MESSAGES: usize = 128; - pub fn new_initiator(stream: S, bus: Bus) -> Self { + pub fn new_initiator(stream: S, bus: Arc>) -> Self { Self { peer: None, state: ConnectionState::Init, conn: AsyncBincodeStream::from(stream).for_async(), is_initiator: true, send_queue: None, - bus, + bus, } } - pub fn new_acceptor(stream: S, bus: Bus) -> Self { + pub fn new_acceptor(stream: S, bus: Arc>) -> Self { Connection { peer: None, state: ConnectionState::Connecting, @@ -231,7 +232,10 @@ where Ok(()) } } - Ok(Some(Ok(Enveloppe { message: Message::Error(e), ..}))) => { + Ok(Some(Ok(Enveloppe { + message: Message::Error(e), + .. 
+ }))) => { bail!("handshake error: {e}"); } Ok(Some(Ok(_))) => { diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index 812745d3..26c5d923 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use itertools::Itertools; use tokio::task::JoinSet; @@ -13,11 +14,14 @@ use super::{bus::Bus, NodeId}; struct ConnectionPool { managed_peers: HashMap, connections: JoinSet, - bus: Bus, + bus: Arc>, } impl ConnectionPool { - pub fn new(bus: Bus, managed_peers: impl IntoIterator) -> Self { + pub fn new( + bus: Arc>, + managed_peers: impl IntoIterator, + ) -> Self { Self { managed_peers: managed_peers.into_iter().collect(), connections: JoinSet::new(), @@ -77,14 +81,14 @@ mod test { use tokio::sync::Notify; use tokio_stream::StreamExt; - use crate::linc::{server::Server, DatabaseId}; + use crate::linc::{server::Server, AllocId}; use super::*; #[test] fn manage_connections() { let mut sim = turmoil::Builder::new().build(); - let database_id = DatabaseId::new_v4(); + let database_id = AllocId::new_v4(); let notify = Arc::new(Notify::new()); let expected_msg = crate::linc::proto::StreamMessage::Proxy( diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index c8db9a89..6a6ae6f8 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -1,6 +1,9 @@ -use super::{bus::{Bus}, Inbound}; +use std::sync::Arc; +use super::bus::Bus; +use super::Inbound; + +#[async_trait::async_trait] pub trait Handler: Sized + Send + Sync + 'static { - fn handle(&self, bus: &Bus, msg: Inbound); + async fn handle(&self, bus: Arc>, msg: Inbound); } - diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs index 8a3747bd..fa787e87 100644 --- a/libsqlx-server/src/linc/mod.rs +++ b/libsqlx-server/src/linc/mod.rs @@ -1,17 +1,14 @@ -use uuid::Uuid; - -use self::proto::Enveloppe; +use self::proto::{Enveloppe, Message}; pub mod bus; pub mod connection; pub mod connection_pool; +pub mod handler; pub mod net; pub mod proto; pub mod server; -pub mod handler; -type NodeId = Uuid; -type DatabaseId = Uuid; +pub type NodeId = u64; const CURRENT_PROTO_VERSION: u32 = 1; const MAX_STREAM_MSG: usize = 64; @@ -24,9 +21,21 @@ pub struct Inbound { pub enveloppe: Enveloppe, } +impl Inbound { + pub fn respond(&self, message: Message) -> Outbound { + Outbound { + to: self.from, + enveloppe: Enveloppe { + from: self.enveloppe.to, + to: self.enveloppe.from, + message, + }, + } + } +} + #[derive(Debug)] pub struct Outbound { pub to: NodeId, pub enveloppe: Enveloppe, } - diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index 617f2d87..c099cbd1 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -2,7 +2,9 @@ use bytes::Bytes; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{DatabaseId}; +use crate::meta::DatabaseId; + +use super::NodeId; pub type Program = String; @@ -18,10 +20,55 @@ pub enum Message { /// Initial message exchanged between nodes when connecting Handshake { protocol_version: u32, - node_id: Uuid, + node_id: NodeId, + }, + ReplicationHandshake { + database_name: String, + }, + ReplicationHandshakeResponse { + /// id of the replication log + log_id: Uuid, + /// current frame_no of the primary + current_frame_no: u64, + }, + Replicate { + /// next frame no to send + next_frame_no: u64, + }, + /// a batch of 
frames that are part of the same transaction + Transaction { + /// if not None, then the last frame is a commit frame, and this is the new size of the database. + size_after: Option, + /// frame_no of the last frame in frames + end_frame_no: u64, + /// a batch of frames part of the transaction. + frames: Vec, + }, + /// Proxy a query to a primary + ProxyRequest { + /// id of the connection to perform the query against + /// If the connection doesn't already exist it is created + /// Id of the request. + /// Responses to this request must have the same id. + connection_id: u32, + req_id: u32, + program: Program, + }, + /// Response to a proxied query + ProxyResponse { + /// id of the request this message is a response to. + req_id: u32, + /// Collection of steps to drive the query builder transducer. + row_step: Vec, + }, + /// Stop processing request `id`. + CancelRequest { + req_id: u32, + }, + /// Close Connection with passed id. + CloseConnection { + connection_id: u32, }, - Replication(ReplicationMessage), - Proxy(ProxyMessage), Error(ProtoError), } @@ -31,13 +78,15 @@ pub enum ProtoError { #[error("invalid protocol version, expected: {expected}")] HandshakeVersionMismatch { expected: u32 }, #[error("unknown database {0}")] - UnknownDatabase(DatabaseId), + UnknownDatabase(String), } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub enum ReplicationMessage { - Handshake {}, - HandshakeResponse { + ReplicationHandshake { + database_name: String, + }, + ReplicationHandshakeResponse { /// id of the replication log log_id: Uuid, /// current frame_no of the primary diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs index 0594bd9e..b462d0a1 100644 --- a/libsqlx-server/src/linc/server.rs +++ b/libsqlx-server/src/linc/server.rs @@ -1,20 +1,22 @@ +use std::sync::Arc; + use tokio::io::{AsyncRead, AsyncWrite}; use tokio::task::JoinSet; use crate::linc::connection::Connection; -use super::bus::{Bus}; +use super::bus::Bus; use super::handler::Handler; pub struct Server { /// reference to the bus - bus: Bus, + bus: Arc>, /// Connection tasks owned by the server connections: JoinSet>, } impl Server { - pub fn new(bus: Bus) -> Self { + pub fn new(bus: Arc>) -> Self { Self { bus, connections: JoinSet::new(), @@ -70,10 +72,7 @@ impl Server { mod test { use std::sync::Arc; - use crate::linc::{ - proto::{ProxyMessage}, - DatabaseId, NodeId, - }; + use crate::linc::{proto::ProxyMessage, AllocId, NodeId}; use super::*; @@ -125,7 +124,7 @@ mod test { let mut sim = turmoil::Builder::new().build(); let host_node_id = NodeId::new_v4(); - let stream_db_id = DatabaseId::new_v4(); + let stream_db_id = AllocId::new_v4(); let notify = Arc::new(Notify::new()); let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { connection_id: 12, @@ -195,7 +194,7 @@ mod test { let mut sim = turmoil::Builder::new().build(); let host_node_id = NodeId::new_v4(); - let database_id = DatabaseId::new_v4(); + let database_id = AllocId::new_v4(); let notify = Arc::new(Notify::new()); sim.host("host", { @@ -251,7 +250,7 @@ mod test { let host_node_id = NodeId::new_v4(); let notify = Arc::new(Notify::new()); let client_id = NodeId::new_v4(); - let database_id = DatabaseId::new_v4(); + let database_id = AllocId::new_v4(); let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { connection_id: 12, req_id: 1, @@ -309,7 +308,7 @@ mod test { let host_node_id = NodeId::new_v4(); let client_id = NodeId::new_v4(); - let database_id = DatabaseId::new_v4(); + let database_id = 
AllocId::new_v4(); sim.host("host", { move || async move { diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index a8360402..d5b0c35f 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -8,6 +8,7 @@ use config::{AdminApiConfig, UserApiConfig}; use http::admin::run_admin_api; use http::user::run_user_api; use hyper::server::conn::AddrIncoming; +use linc::bus::Bus; use manager::Manager; use meta::Store; use tokio::task::JoinSet; @@ -49,10 +50,11 @@ async fn spawn_user_api( set: &mut JoinSet>, config: &UserApiConfig, manager: Arc, + bus: Arc>>, ) -> Result<()> { let user_api_listener = tokio::net::TcpListener::bind(config.addr).await?; set.spawn(run_user_api( - http::user::Config { manager }, + http::user::Config { manager, bus }, AddrIncoming::from_listener(user_api_listener)?, )); @@ -71,9 +73,10 @@ async fn main() -> Result<()> { let store = Arc::new(Store::new(&config.db_path)); let manager = Arc::new(Manager::new(config.db_path.clone(), store.clone(), 100)); + let bus = Arc::new(Bus::new(config.cluster.id, manager.clone())); spawn_admin_api(&mut join_set, &config.admin_api, store.clone()).await?; - spawn_user_api(&mut join_set, &config.user_api, manager).await?; + spawn_user_api(&mut join_set, &config.user_api, manager, bus).await?; join_set.join_next().await; diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 48315e0a..62f86479 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -7,10 +7,13 @@ use tokio::task::JoinSet; use crate::allocation::{Allocation, AllocationMessage, Database}; use crate::hrana; -use crate::meta::Store; +use crate::linc::bus::Bus; +use crate::linc::handler::Handler; +use crate::linc::Inbound; +use crate::meta::{DatabaseId, Store}; pub struct Manager { - cache: Cache>, + cache: Cache>, meta_store: Arc, db_path: PathBuf, } @@ -27,13 +30,17 @@ impl Manager { } /// Returns a handle to an allocation, lazily initializing if it isn't already loaded. - pub async fn alloc(&self, alloc_id: &str) -> Option> { - if let Some(sender) = self.cache.get(alloc_id) { + pub async fn alloc( + self: &Arc, + database_id: DatabaseId, + bus: Arc>>, + ) -> Option> { + if let Some(sender) = self.cache.get(&database_id) { return Some(sender.clone()); } - if let Some(config) = self.meta_store.meta(alloc_id).await { - let path = self.db_path.join("dbs").join(alloc_id); + if let Some(config) = self.meta_store.meta(&database_id).await { + let path = self.db_path.join("dbs").join(database_id.to_string()); tokio::fs::create_dir_all(&path).await.unwrap(); let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { @@ -42,14 +49,14 @@ impl Manager { connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, - hrana_server: Arc::new(hrana::http::Server::new(None)), // TODO: handle self URL? + hrana_server: Arc::new(hrana::http::Server::new(None)), + dispatcher: bus, // TODO: handle self URL? 
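+                // the human-readable name stays on the allocation; hashing it
+                // with DatabaseId::from_name gives the id that inbound
+                // messages are routed by (see the Handler impl below)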
+ db_name: config.db_name, }; tokio::spawn(alloc.run()); - self.cache - .insert(alloc_id.to_string(), alloc_sender.clone()) - .await; + self.cache.insert(database_id, alloc_sender.clone()).await; return Some(alloc_sender); } @@ -57,3 +64,16 @@ impl Manager { None } } + +#[async_trait::async_trait] +impl Handler for Arc { + async fn handle(&self, bus: Arc>, msg: Inbound) { + if let Some(sender) = self + .clone() + .alloc(msg.enveloppe.to.unwrap(), bus.clone()) + .await + { + let _ = sender.send(AllocationMessage::Inbound(msg)).await; + } + } +} diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 06e37a76..b71b33eb 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -1,7 +1,11 @@ +use core::fmt; use std::path::Path; +use serde::{Deserialize, Serialize}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake128; use sled::Tree; -use uuid::Uuid; +use tokio::task::block_in_place; use crate::allocation::config::AllocConfig; @@ -11,6 +15,32 @@ pub struct Store { meta_store: Tree, } +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Hash, Clone, Copy)] +pub struct DatabaseId([u8; 16]); + +impl DatabaseId { + pub fn from_name(name: &str) -> Self { + let mut hasher = Shake128::default(); + hasher.update(name.as_bytes()); + let mut reader = hasher.finalize_xof(); + let mut out = [0; 16]; + reader.read(&mut out); + Self(out) + } +} + +impl fmt::Display for DatabaseId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:x}", u128::from_be_bytes(self.0)) + } +} + +impl AsRef<[u8]> for DatabaseId { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + impl Store { pub fn new(path: &Path) -> Self { std::fs::create_dir_all(&path).unwrap(); @@ -21,31 +51,32 @@ impl Store { Self { meta_store } } - pub async fn allocate(&self, alloc_id: &str, meta: &AllocConfig) { + pub async fn allocate(&self, database_name: &str, meta: &AllocConfig) { //TODO: Handle conflict - tokio::task::block_in_place(|| { + block_in_place(|| { let meta_bytes = bincode::serialize(meta).unwrap(); + let id = DatabaseId::from_name(database_name); self.meta_store - .compare_and_swap(alloc_id, None as Option<&[u8]>, Some(meta_bytes)) + .compare_and_swap(id, None as Option<&[u8]>, Some(meta_bytes)) .unwrap() .unwrap(); }); } - pub async fn deallocate(&self, _alloc_id: Uuid) { + pub async fn deallocate(&self, _database_name: &str) { todo!() } - pub async fn meta(&self, alloc_id: &str) -> Option { - tokio::task::block_in_place(|| { - let config = self.meta_store.get(alloc_id).unwrap()?; + pub async fn meta(&self, database_id: &DatabaseId) -> Option { + block_in_place(|| { + let config = self.meta_store.get(database_id).unwrap()?; let config = bincode::deserialize(config.as_ref()).unwrap(); Some(config) }) } pub async fn list_allocs(&self) -> Vec { - tokio::task::block_in_place(|| { + block_in_place(|| { let mut out = Vec::new(); for kv in self.meta_store.iter() { let (_k, v) = kv.unwrap(); diff --git a/libsqlx/Cargo.toml b/libsqlx/Cargo.toml index b0f00521..abdb39ad 100644 --- a/libsqlx/Cargo.toml +++ b/libsqlx/Cargo.toml @@ -12,8 +12,6 @@ serde = "1.0.164" serde_json = "1.0.99" rusqlite = { workspace = true } anyhow = "1.0.71" -futures = "0.3.28" -tokio = { version = "1.28.2", features = ["sync", "time"] } sqlite3-parser = "0.9.0" fallible-iterator = "0.3.0" bytes = "1.4.0" diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index 1682e3b4..19fd51ce 100644 --- 
a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -37,7 +37,7 @@ pub struct Injector { /// This trait trait is used to record the last committed frame_no to the log. /// The implementer can persist the pre and post commit frame no, and compare them in the event of /// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted. -pub trait InjectorCommitHandler: 'static { +pub trait InjectorCommitHandler: Send + 'static { fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; } diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 2844a204..9cb32c20 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -139,7 +139,12 @@ impl LibsqlDatabase { dirty: bool, ) -> crate::Result { let ty = PrimaryType { - logger: Arc::new(ReplicationLogger::open(&db_path, dirty, compactor)?), + logger: Arc::new(ReplicationLogger::open( + &db_path, + dirty, + compactor, + Box::new(|_| ()), + )?), }; Ok(Self::new(db_path, ty)) } @@ -174,7 +179,6 @@ impl Database for LibsqlDatabase { type Connection = LibsqlConnection<::Context>; fn connect(&self) -> Result { - dbg!(); Ok( LibsqlConnection::<::Context>::new( &self.db_path, diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index aebff0db..fe371258 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -19,7 +19,6 @@ use sqld_libsql_bindings::ffi::types::{ use sqld_libsql_bindings::ffi::PageHdrIter; use sqld_libsql_bindings::init_static_wal_method; use sqld_libsql_bindings::wal_hook::WalHook; -use tokio::sync::watch; use uuid::Uuid; use crate::database::frame::{Frame, FrameHeader}; @@ -329,7 +328,7 @@ impl ReplicationLoggerHookCtx { fn commit(&self) -> anyhow::Result<()> { let new_frame_no = self.logger.commit()?; - let _ = self.logger.new_frame_notifier.send(new_frame_no); + let _ = (self.logger.new_frame_notifier)(new_frame_no); Ok(()) } @@ -748,6 +747,8 @@ impl LogCompactor for () { } } +pub type FrameNotifierCb = Box; + pub struct ReplicationLogger { pub generation: Generation, pub log_file: RwLock, @@ -755,11 +756,16 @@ pub struct ReplicationLogger { db_path: PathBuf, /// a notifier channel other tasks can subscribe to, and get notified when new frames become /// available. 
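With tokio dropped from libsqlx in this patch, the notifier below becomes a plain callback (`FrameNotifierCb`) instead of a watch channel. An embedder that still wants an awaitable stream of frame numbers can bridge the callback back into a channel on its own side. A sketch, assuming `FrameNotifierCb` is `Box<dyn Fn(FrameNo) + Send + Sync>` and `FrameNo` is `u64`:

    use tokio::sync::watch;

    // Illustrative bridge, not part of the patch: turn the synchronous
    // notifier callback back into an async watch channel.
    fn watch_notifier(start: u64) -> (Box<dyn Fn(u64) + Send + Sync>, watch::Receiver<u64>) {
        let (tx, rx) = watch::channel(start);
        let cb: Box<dyn Fn(u64) + Send + Sync> = Box::new(move |frame_no| {
            let _ = tx.send(frame_no);
        });
        (cb, rx)
    }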
- pub new_frame_notifier: watch::Sender, + pub new_frame_notifier: FrameNotifierCb, } impl ReplicationLogger { - pub fn open(db_path: &Path, dirty: bool, compactor: impl LogCompactor) -> crate::Result { + pub fn open( + db_path: &Path, + dirty: bool, + compactor: impl LogCompactor, + new_frame_notifier: FrameNotifierCb, + ) -> crate::Result { let log_path = db_path.join("wallog"); let data_path = db_path.join("data"); @@ -788,9 +794,14 @@ impl ReplicationLogger { }; if should_recover { - Self::recover(log_file, data_path, compactor) + Self::recover(log_file, data_path, compactor, new_frame_notifier) } else { - Self::from_log_file(db_path.to_path_buf(), log_file, compactor) + Self::from_log_file( + db_path.to_path_buf(), + log_file, + compactor, + new_frame_notifier, + ) } } @@ -798,12 +809,11 @@ impl ReplicationLogger { db_path: PathBuf, log_file: LogFile, compactor: impl LogCompactor, + new_frame_notifier: FrameNotifierCb, ) -> crate::Result { let header = log_file.header(); let generation_start_frame_no = header.start_frame_no + header.frame_count; - let (new_frame_notifier, _) = watch::channel(generation_start_frame_no); - Ok(Self { generation: Generation::new(generation_start_frame_no), compactor: Box::new(compactor), @@ -817,6 +827,7 @@ impl ReplicationLogger { log_file: LogFile, mut data_path: PathBuf, compactor: impl LogCompactor, + new_frame_notifier: FrameNotifierCb, ) -> crate::Result { // It is necessary to checkpoint before we restore the replication log, since the WAL may // contain pages that are not in the database file. @@ -849,7 +860,7 @@ impl ReplicationLogger { assert!(data_path.pop()); - Self::from_log_file(data_path, log_file, compactor) + Self::from_log_file(data_path, log_file, compactor, new_frame_notifier) } pub fn database_id(&self) -> anyhow::Result { diff --git a/libsqlx/src/error.rs b/libsqlx/src/error.rs index 07a71831..fd7828c1 100644 --- a/libsqlx/src/error.rs +++ b/libsqlx/src/error.rs @@ -46,11 +46,3 @@ pub enum Error { #[error(transparent)] LexerError(#[from] sqlite3_parser::lexer::sql::Error), } - -impl From for Error { - fn from(inner: tokio::sync::oneshot::error::RecvError) -> Self { - Self::Internal(format!( - "Failed to receive response via oneshot channel: {inner}" - )) - } -} diff --git a/sqld/src/query_result_builder.rs b/sqld/src/query_result_builder.rs index a9aeadd7..29a2dc91 100644 --- a/sqld/src/query_result_builder.rs +++ b/sqld/src/query_result_builder.rs @@ -642,7 +642,6 @@ pub mod test { } // this can be usefull to help debug the generated test case - dbg!(trace); b } From 20dc440c563e0ee5a2c3d643ade19815d8a9ca72 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 12 Jul 2023 19:50:20 +0200 Subject: [PATCH 19/64] replica: send replication request to primary --- libsqlx-server/src/allocation/mod.rs | 174 +++++++++++++++++-- libsqlx-server/src/linc/bus.rs | 18 +- libsqlx-server/src/linc/connection.rs | 10 +- libsqlx-server/src/linc/handler.rs | 1 + libsqlx-server/src/linc/mod.rs | 3 +- libsqlx-server/src/linc/proto.rs | 26 +-- libsqlx-server/src/manager.rs | 8 +- libsqlx/src/database/libsql/injector/hook.rs | 4 +- libsqlx/src/database/libsql/injector/mod.rs | 32 ++-- libsqlx/src/database/libsql/mod.rs | 30 +--- libsqlx/src/database/mod.rs | 9 +- libsqlx/src/database/proxy/database.rs | 2 +- libsqlx/src/database/proxy/mod.rs | 2 +- libsqlx/src/lib.rs | 5 +- 14 files changed, 236 insertions(+), 88 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index c9f88ed0..2b6c5faf 100644 --- 
a/libsqlx-server/src/allocation/mod.rs
+++ b/libsqlx-server/src/allocation/mod.rs
@@ -1,16 +1,20 @@
 use std::path::PathBuf;
 use std::sync::Arc;
+use std::time::{Duration, Instant};
 
 use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType};
-use libsqlx::Database as _;
+use libsqlx::proxy::WriteProxyDatabase;
+use libsqlx::{Database as _, DescribeResponse, Frame, InjectableDatabase, Injector, FrameNo};
 use tokio::sync::{mpsc, oneshot};
 use tokio::task::{block_in_place, JoinSet};
+use tokio::time::timeout;
 
 use crate::hrana;
 use crate::hrana::http::handle_pipeline;
 use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody};
 use crate::linc::bus::Dispatch;
-use crate::linc::{Inbound, NodeId};
+use crate::linc::proto::{Enveloppe, Message, Frames};
+use crate::linc::{Inbound, NodeId, Outbound};
 use crate::meta::DatabaseId;
 
 use self::config::{AllocConfig, DbConfig};
@@ -24,7 +28,6 @@ pub struct ConnectionId {
     id: u32,
     close_sender: mpsc::Sender<()>,
 }
-
 pub enum AllocationMessage {
     NewConnection(oneshot::Sender),
     HranaPipelineReq {
         req: PipelineRequestBody,
         ret: oneshot::Sender>,
     },
     Inbound(Inbound),
 }
 
+pub struct DummyDb;
+pub struct DummyConn;
+
+impl libsqlx::Connection for DummyConn {
+    fn execute_program(
+        &mut self,
+        _pgm: libsqlx::program::Program,
+        _result_builder: &mut dyn libsqlx::result_builder::ResultBuilder,
+    ) -> libsqlx::Result<()> {
+        todo!()
+    }
+
+    fn describe(&self, _sql: String) -> libsqlx::Result {
+        todo!()
+    }
+}
+
+impl libsqlx::Database for DummyDb {
+    type Connection = DummyConn;
+
+    fn connect(&self) -> Result {
+        todo!()
+    }
+}
+
+type ProxyDatabase = WriteProxyDatabase, DummyDb>;
+
 pub enum Database {
-    Primary(libsqlx::libsql::LibsqlDatabase),
+    Primary(LibsqlDatabase),
+    Replica {
+        db: ProxyDatabase,
+        injector_handle: mpsc::Sender,
         primary_node_id: NodeId,
+        last_received_frame_ts: Option,
     },
 }
 
 struct Compactor;
@@ -59,14 +91,107 @@ impl LogCompactor for Compactor {
     }
 }
 
+const MAX_INJECTOR_BUFFER_CAP: usize = 32;
+
+struct Replicator {
+    dispatcher: Arc,
+    req_id: u32,
+    last_committed: FrameNo,
+    next_seq: u32,
+    database_id: DatabaseId,
+    primary_node_id: NodeId,
+    injector: Box,
+    receiver: mpsc::Receiver,
+}
+
+impl Replicator {
+    async fn run(mut self) {
+        loop {
+            match timeout(Duration::from_secs(5), self.receiver.recv()).await {
+                Ok(Some(Frames {
+                    req_id,
+                    seq,
+                    frames,
+                })) => {
+                    // ignore frames from a previous call to Replicate
+                    if req_id != self.req_id { continue }
+                    if seq != self.next_seq {
+                        // this is not the batch of frames we were expecting, drop what we have, and
+                        // ask again from last checkpoint
+                        self.query_replicate().await;
+                        continue;
+                    };
+                    self.next_seq += 1;
+
+                    for bytes in frames {
+                        let frame = Frame::try_from_bytes(bytes).unwrap();
+                        block_in_place(|| {
+                            if let Some(last_committed) = self.injector.inject(frame).unwrap() {
+                                self.last_committed = last_committed;
+                            }
+                        });
+                    }
+                }
+                Err(_) => self.query_replicate().await,
+                Ok(None) => break,
+            }
+        }
+    }
+
+    async fn query_replicate(&mut self) {
+        self.req_id += 1;
+        self.next_seq = 0;
+        // clear buffered, uncommitted frames
+        self.injector.clear();
+        self.dispatcher
+            .dispatch(Outbound {
+                to: self.primary_node_id,
+                enveloppe: Enveloppe {
+                    database_id: Some(self.database_id),
+                    message: Message::Replicate {
+                        next_frame_no: self.last_committed + 1,
+                        req_id: self.req_id - 1,
+                    },
+                },
+            })
+            .await;
+    }
+}
+
 impl Database {
-    pub fn from_config(config: &AllocConfig, path: PathBuf) -> Self {
+    pub fn
from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc) -> Self {
         match config.db_config {
             DbConfig::Primary {} => {
                 let db = LibsqlDatabase::new_primary(path, Compactor, false).unwrap();
                 Self::Primary(db)
             }
-            DbConfig::Replica { .. } => todo!(),
+            DbConfig::Replica { primary_node_id } => {
+                let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap();
+                let wdb = DummyDb;
+                let mut db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ()));
+                let injector = db.injector().unwrap();
+                let (sender, receiver) = mpsc::channel(16);
+                let database_id = DatabaseId::from_name(&config.db_name);
+
+                let replicator = Replicator {
+                    dispatcher,
+                    req_id: 0,
+                    last_committed: 0, // TODO: load the last committed frame_no from the meta file
+                    next_seq: 0,
+                    database_id,
+                    primary_node_id,
+                    injector,
+                    receiver,
+                };
+
+                tokio::spawn(replicator.run());
+
+                Self::Replica {
+                    db,
+                    injector_handle: sender,
+                    primary_node_id,
+                    last_received_frame_ts: None,
+                }
+            }
         }
     }
@@ -147,19 +272,32 @@ impl Allocation {
     }
 
     async fn handle_inbound(&mut self, msg: Inbound) {
-        debug_assert_eq!(msg.enveloppe.to, Some(DatabaseId::from_name(&self.db_name)));
+        debug_assert_eq!(
+            msg.enveloppe.database_id,
+            Some(DatabaseId::from_name(&self.db_name))
+        );
 
         match msg.enveloppe.message {
-            crate::linc::proto::Message::Handshake { .. } => todo!(),
-            crate::linc::proto::Message::ReplicationHandshake { .. } => todo!(),
-            crate::linc::proto::Message::ReplicationHandshakeResponse { .. } => todo!(),
-            crate::linc::proto::Message::Replicate { .. } => todo!(),
-            crate::linc::proto::Message::Transaction { .. } => todo!(),
-            crate::linc::proto::Message::ProxyRequest { .. } => todo!(),
-            crate::linc::proto::Message::ProxyResponse { .. } => todo!(),
-            crate::linc::proto::Message::CancelRequest { .. } => todo!(),
-            crate::linc::proto::Message::CloseConnection { .. } => todo!(),
-            crate::linc::proto::Message::Error(_) => todo!(),
+            Message::Handshake { .. } => todo!(),
+            Message::ReplicationHandshake { .. } => todo!(),
+            Message::ReplicationHandshakeResponse { .. } => todo!(),
+            Message::Replicate { .. } => todo!(),
+            Message::Frames(frames) => match &mut self.database {
+                Database::Replica {
+                    injector_handle,
+                    last_received_frame_ts,
+                    ..
+                } => {
+                    *last_received_frame_ts = Some(Instant::now());
+                    injector_handle.send(frames).await;
+                }
+                Database::Primary(_) => todo!("handle primary receiving txn"),
+            },
+            Message::ProxyRequest { .. } => todo!(),
+            Message::ProxyResponse { .. } => todo!(),
+            Message::CancelRequest { .. } => todo!(),
+            Message::CloseConnection { ..
} => todo!(), + Message::Error(_) => todo!(), } } diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index bf9a2cc4..8beae22d 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -1,10 +1,16 @@ +use std::collections::HashSet; use std::sync::Arc; -use super::{connection::SendQueue, handler::Handler, Inbound, NodeId, Outbound}; +use parking_lot::RwLock; + +use super::connection::SendQueue; +use super::handler::Handler; +use super::{Inbound, NodeId, Outbound}; pub struct Bus { node_id: NodeId, handler: H, + peers: RwLock>, send_queue: SendQueue, } @@ -15,6 +21,7 @@ impl Bus { node_id, handler, send_queue, + peers: Default::default(), } } @@ -29,6 +36,15 @@ impl Bus { pub fn send_queue(&self) -> &SendQueue { &self.send_queue } + + pub fn connect(&self, node_id: NodeId) { + // TODO: handle peer already exists + self.peers.write().insert(node_id); + } + + pub fn disconnect(&self, node_id: NodeId) { + self.peers.write().remove(&node_id); + } } #[async_trait::async_trait] diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index a96b8179..170d1f2a 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -200,8 +200,7 @@ where }))) => { if protocol_version != CURRENT_PROTO_VERSION { let msg = Enveloppe { - from: None, - to: None, + database_id: None, message: Message::Error(ProtoError::HandshakeVersionMismatch { expected: CURRENT_PROTO_VERSION, }), @@ -215,8 +214,7 @@ where // handshake message if !self.is_initiator { let msg = Enveloppe { - from: None, - to: None, + database_id: None, message: Message::Handshake { protocol_version: CURRENT_PROTO_VERSION, node_id: self.bus.node_id(), @@ -228,6 +226,7 @@ where self.peer = Some(node_id); self.state = ConnectionState::Connected; self.send_queue = Some(self.bus.send_queue().register(node_id)); + self.bus.connect(node_id); Ok(()) } @@ -255,8 +254,7 @@ where async fn initiate_connection(&mut self) -> color_eyre::Result<()> { let msg = Enveloppe { - from: None, - to: None, + database_id: None, message: Message::Handshake { protocol_version: CURRENT_PROTO_VERSION, node_id: self.bus.node_id(), diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index 6a6ae6f8..6403906e 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -5,5 +5,6 @@ use super::Inbound; #[async_trait::async_trait] pub trait Handler: Sized + Send + Sync + 'static { + /// Handle inbound message async fn handle(&self, bus: Arc>, msg: Inbound); } diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs index fa787e87..638f56e2 100644 --- a/libsqlx-server/src/linc/mod.rs +++ b/libsqlx-server/src/linc/mod.rs @@ -26,8 +26,7 @@ impl Inbound { Outbound { to: self.from, enveloppe: Enveloppe { - from: self.enveloppe.to, - to: self.enveloppe.from, + database_id: None, message, }, } diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index c099cbd1..93ac445e 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -10,11 +10,21 @@ pub type Program = String; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct Enveloppe { - pub from: Option, - pub to: Option, + pub database_id: Option, pub message: Message, } +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +/// a batch of frames to inject +pub struct Frames{ + /// must match the Replicate request id + pub req_id: u32, + /// sequence id, 
monotonically incremented, reset when req_id changes. + /// Used to detect gaps in received frames. + pub seq: u32, + pub frames: Vec, +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] pub enum Message { /// Initial message exchanged between nodes when connecting @@ -32,18 +42,12 @@ pub enum Message { current_frame_no: u64, }, Replicate { + /// incremental request id, used when responding with a Frames message + req_id: u32, /// next frame no to send next_frame_no: u64, }, - /// a batch of frames that are part of the same transaction - Transaction { - /// if not None, then the last frame is a commit frame, and this is the new size of the database. - size_after: Option, - /// frame_no of the last frame in frames - end_frame_no: u64, - /// a batch of frames part of the transaction. - frames: Vec, - }, + Frames(Frames), /// Proxy a query to a primary ProxyRequest { /// id of the connection to perform the query against diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 62f86479..89604569 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -20,6 +20,10 @@ pub struct Manager { const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32; +trait IsSync: Sync {} + +impl IsSync for Allocation {} + impl Manager { pub fn new(db_path: PathBuf, meta_store: Arc, max_conccurent_allocs: u64) -> Self { Self { @@ -45,7 +49,7 @@ impl Manager { let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { inbox, - database: Database::from_config(&config, path), + database: Database::from_config(&config, path, bus.clone()), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, @@ -70,7 +74,7 @@ impl Handler for Arc { async fn handle(&self, bus: Arc>, msg: Inbound) { if let Some(sender) = self .clone() - .alloc(msg.enveloppe.to.unwrap(), bus.clone()) + .alloc(msg.enveloppe.database_id.unwrap(), bus.clone()) .await { let _ = sender.send(AllocationMessage::Inbound(msg)).await; diff --git a/libsqlx/src/database/libsql/injector/hook.rs b/libsqlx/src/database/libsql/injector/hook.rs index f87172db..2cb5348d 100644 --- a/libsqlx/src/database/libsql/injector/hook.rs +++ b/libsqlx/src/database/libsql/injector/hook.rs @@ -42,7 +42,7 @@ impl InjectorHookCtx { wal: *mut Wal, ) -> anyhow::Result<()> { self.is_txn = true; - let buffer = self.buffer.borrow(); + let buffer = self.buffer.lock(); let (mut headers, last_frame_no, size_after) = make_page_header(buffer.iter().map(|f| &**f)); if size_after != 0 { @@ -157,7 +157,7 @@ unsafe impl WalHook for InjectorHook { return LIBSQL_INJECT_FATAL; } - ctx.buffer.borrow_mut().clear(); + ctx.buffer.lock().clear(); if !ctx.is_txn { LIBSQL_INJECT_OK diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index 19fd51ce..0c2c2207 100644 --- a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -1,15 +1,15 @@ -use std::cell::RefCell; use std::collections::VecDeque; use std::path::Path; -use std::rc::Rc; +use std::sync::Arc; +use parking_lot::Mutex; use rusqlite::OpenFlags; use crate::database::frame::Frame; use crate::database::libsql::injector::hook::{ INJECTOR_METHODS, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN, }; -use crate::database::FrameNo; +use crate::database::{FrameNo, InjectError}; use crate::seal::Seal; use hook::InjectorHookCtx; @@ -17,7 +17,7 @@ use hook::InjectorHookCtx; mod headers; mod hook; -pub type FrameBuffer = 
Rc<RefCell<VecDeque<Frame>>>;
+pub type FrameBuffer = Arc<Mutex<VecDeque<Frame>>>;
 
 pub struct Injector {
     /// The injector is in a transaction state
@@ -33,11 +33,22 @@ pub struct Injector {
     _hook_ctx: Seal<Box<InjectorHookCtx>>,
 }
 
+impl crate::database::Injector for Injector {
+    fn inject(&mut self, frame: Frame) -> Result<Option<FrameNo>, InjectError> {
+        let res = self.inject_frame(frame).unwrap();
+        Ok(res)
+    }
+
+    fn clear(&mut self) {
+        self.buffer.lock().clear();
+    }
+}
+
 /// Methods from this trait are called before and after performing a frame injection.
 /// This trait is used to record the last committed frame_no to the log.
 /// The implementer can persist the pre and post commit frame_no, and compare them in the event of
 /// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted.
-pub trait InjectorCommitHandler: Send + 'static {
+pub trait InjectorCommitHandler: Send + Sync + 'static {
     fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>;
     fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>;
 }
@@ -52,7 +63,6 @@ impl InjectorCommitHandler for Box {
     }
 }
 
-#[cfg(test)]
 impl InjectorCommitHandler for () {
     fn pre_commit(&mut self, _frame_no: FrameNo) -> anyhow::Result<()> {
         Ok(())
@@ -95,8 +105,8 @@ impl Injector {
     /// Inject one frame into the log. If this was a commit frame, returns Ok(Some(FrameNo)).
     pub(crate) fn inject_frame(&mut self, frame: Frame) -> crate::Result<Option<FrameNo>> {
         let frame_close_txn = frame.header().size_after != 0;
-        self.buffer.borrow_mut().push_back(frame);
-        if frame_close_txn || self.buffer.borrow().len() >= self.capacity {
+        self.buffer.lock().push_back(frame);
+        if frame_close_txn || self.buffer.lock().len() >= self.capacity {
             if !self.is_txn {
                 self.begin_txn();
             }
@@ -110,7 +120,7 @@ impl Injector {
     /// Trigger a dummy write, and flush the cache to trigger a call to xFrame. The buffer's frames
     /// are then injected into the wal.
     fn flush(&mut self) -> crate::Result<Option<FrameNo>> {
-        let last_frame_no = match self.buffer.borrow().back() {
+        let last_frame_no = match self.buffer.lock().back() {
             Some(f) => f.header().frame_no,
             None => {
                 tracing::trace!("nothing to inject");
@@ -130,11 +140,11 @@ impl Injector {
                     .pragma_update(None, "writable_schema", "reset")?;
                 self.commit();
                 self.is_txn = false;
-                assert!(self.buffer.borrow().is_empty());
+                assert!(self.buffer.lock().is_empty());
                 return Ok(Some(last_frame_no));
             } else if e.extended_code == LIBSQL_INJECT_OK_TXN {
                 self.is_txn = true;
-                assert!(self.buffer.borrow().is_empty());
+                assert!(self.buffer.lock().is_empty());
                 return Ok(None);
             } else if e.extended_code == LIBSQL_INJECT_FATAL {
                 todo!("handle fatal error");
diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs
index 9cb32c20..dbd1d285 100644
--- a/libsqlx/src/database/libsql/mod.rs
+++ b/libsqlx/src/database/libsql/mod.rs
@@ -4,8 +4,7 @@ use std::sync::Arc;
 use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalHook, TRANSPARENT_METHODS};
 use sqld_libsql_bindings::WalMethodsHook;
 
-use crate::database::frame::Frame;
-use crate::database::{Database, InjectError, InjectableDatabase};
+use crate::database::{Database, InjectableDatabase};
 use crate::error::Error;
 use crate::result_builder::QueryBuilderConfig;
 
@@ -68,18 +67,6 @@ pub trait LibsqlDbType {
     fn hook_context(&self) -> <Self::ConnectionHook as WalHook>::Context;
 }
 
-pub struct PlainType;
-
-impl LibsqlDbType for PlainType {
-    type ConnectionHook = TransparentMethods;
-
-    fn hook() -> &'static WalMethodsHook<TransparentMethods> {
-        &TRANSPARENT_METHODS
-    }
-
-    fn hook_context(&self) -> <Self::ConnectionHook as WalHook>::Context {}
-}
-
 /// A generic wrapper around a libsql database.
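// The pre/post commit protocol above lends itself to a small durable implementation.
// The sketch below is illustrative only and is not part of this patch: it assumes a
// dedicated meta file holding the two frame numbers at fixed offsets, and it reuses
// the `InjectorCommitHandler` trait and `FrameNo` alias defined in the code above.

use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;

struct FileCommitHandler {
    file: File,
}

impl FileCommitHandler {
    fn open(path: &Path) -> anyhow::Result<Self> {
        let file = OpenOptions::new().read(true).write(true).create(true).open(path)?;
        // Ensure both 8-byte slots exist; a fresh file is zero-filled.
        file.set_len(16)?;
        Ok(Self { file })
    }

    // Durably record `frame_no` at `offset` before returning control to the injector.
    fn write_at(&mut self, offset: u64, frame_no: FrameNo) -> anyhow::Result<()> {
        self.file.seek(SeekFrom::Start(offset))?;
        self.file.write_all(&frame_no.to_le_bytes())?;
        self.file.sync_data()?;
        Ok(())
    }

    // On startup: a mismatch means a crash happened between pre_commit and
    // post_commit, so the log must be treated as potentially corrupted.
    fn needs_repair(&mut self) -> anyhow::Result<bool> {
        let mut buf = [0u8; 16];
        self.file.seek(SeekFrom::Start(0))?;
        self.file.read_exact(&mut buf)?;
        let pre = u64::from_le_bytes(buf[..8].try_into().unwrap());
        let post = u64::from_le_bytes(buf[8..].try_into().unwrap());
        Ok(pre != post)
    }
}

impl InjectorCommitHandler for FileCommitHandler {
    fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()> {
        self.write_at(0, frame_no)
    }

    fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()> {
        self.write_at(8, frame_no)
    }
}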
/// `LibsqlDatabase` can be specialized into either a `ReplicaType` or a `PrimaryType`. /// In `PrimaryType` mode, the LibsqlDatabase maintains a replication log that can be replicated to @@ -125,12 +112,6 @@ impl LibsqlDatabase { } } -impl LibsqlDatabase { - pub fn new_plain(db_path: PathBuf) -> crate::Result { - Ok(Self::new(db_path, PlainType)) - } -} - impl LibsqlDatabase { pub fn new_primary( db_path: PathBuf, @@ -195,7 +176,7 @@ impl Database for LibsqlDatabase { } impl InjectableDatabase for LibsqlDatabase { - fn injector(&mut self) -> crate::Result> { + fn injector(&mut self) -> crate::Result> { let Some(commit_handler) = self.ty.commit_handler.take() else { panic!("there can be only one injector") }; Ok(Box::new(Injector::new( &self.db_path, @@ -205,13 +186,6 @@ impl InjectableDatabase for LibsqlDatabase { } } -impl super::Injector for Injector { - fn inject(&mut self, frame: Frame) -> Result<(), InjectError> { - self.inject_frame(frame).unwrap(); - Ok(()) - } -} - #[cfg(test)] mod test { use std::fs::File; diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index 62581402..368ac5ac 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -1,6 +1,5 @@ use std::time::Duration; -use self::frame::Frame; use crate::connection::Connection; use crate::error::Error; @@ -10,6 +9,8 @@ pub mod proxy; #[cfg(test)] mod test_utils; +pub use frame::Frame; + pub type FrameNo = u64; pub const TXN_TIMEOUT: Duration = Duration::from_secs(5); @@ -24,10 +25,12 @@ pub trait Database { } pub trait InjectableDatabase { - fn injector(&mut self) -> crate::Result>; + fn injector(&mut self) -> crate::Result>; } // Trait implemented by databases that support frame injection pub trait Injector { - fn inject(&mut self, frame: Frame) -> Result<(), InjectError>; + fn inject(&mut self, frame: Frame) -> Result, InjectError>; + /// clear internal buffer + fn clear(&mut self); } diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs index 129cc5e2..e9add71f 100644 --- a/libsqlx/src/database/proxy/database.rs +++ b/libsqlx/src/database/proxy/database.rs @@ -41,7 +41,7 @@ impl InjectableDatabase for WriteProxyDatabase where RDB: InjectableDatabase, { - fn injector(&mut self) -> crate::Result> { + fn injector(&mut self) -> crate::Result> { self.read_db.injector() } } diff --git a/libsqlx/src/database/proxy/mod.rs b/libsqlx/src/database/proxy/mod.rs index 1b5b3226..62c6925d 100644 --- a/libsqlx/src/database/proxy/mod.rs +++ b/libsqlx/src/database/proxy/mod.rs @@ -8,4 +8,4 @@ mod database; pub use database::WriteProxyDatabase; // Waits until passed frameno has been replicated back to the database -type WaitFrameNoCb = Arc; +type WaitFrameNoCb = Arc; diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index a89c2771..e004317e 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -10,10 +10,11 @@ mod seal; pub type Result = std::result::Result; -pub use connection::Connection; +pub use connection::{Connection, DescribeResponse}; pub use database::libsql; pub use database::libsql::replication_log::FrameNo; pub use database::proxy; -pub use database::Database; +pub use database::Frame; +pub use database::{Database, InjectableDatabase, Injector}; pub use rusqlite; From 1989ae8b4b0d1165c6f4391b042fceffcee01d87 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 13 Jul 2023 10:21:22 +0200 Subject: [PATCH 20/64] replica sends replicate request to primary --- libsqlx-server/Cargo.toml | 4 ++-- libsqlx-server/src/allocation/mod.rs | 12 +++++++--- 
libsqlx-server/src/config.rs | 8 +++---- libsqlx-server/src/linc/bus.rs | 2 +- libsqlx-server/src/linc/connection.rs | 10 +++++++-- libsqlx-server/src/linc/connection_pool.rs | 13 ++++++++--- libsqlx-server/src/linc/net.rs | 5 +++++ libsqlx-server/src/linc/server.rs | 5 ++++- libsqlx-server/src/main.rs | 26 +++++++++++++++++++--- 9 files changed, 66 insertions(+), 19 deletions(-) diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index efff738e..86beceda 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -11,7 +11,7 @@ async-trait = "0.1.71" axum = "0.6.18" base64 = "0.21.2" bincode = "1.3.3" -bytes = "1.4.0" +bytes = { version = "1.4.0", features = ["serde"] } clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" futures = "0.3.28" @@ -36,7 +36,7 @@ tokio-util = "0.7.8" toml = "0.7.6" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } -uuid = { version = "1.4.0", features = ["v4"] } +uuid = { version = "1.4.0", features = ["v4", "serde"] } [dev-dependencies] turmoil = "0.5.5" diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 2b6c5faf..743dbdd4 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -58,7 +58,7 @@ impl libsqlx::Database for DummyDb { type Connection = DummyConn; fn connect(&self) -> Result { - todo!() + Ok(DummyConn) } } @@ -106,6 +106,9 @@ struct Replicator { impl Replicator { async fn run(mut self) { + dbg!(); + self.query_replicate().await; + dbg!(); loop { match timeout(Duration::from_secs(5), self.receiver.recv()).await { Ok(Some(Frames { @@ -281,7 +284,10 @@ impl Allocation { Message::Handshake { .. } => todo!(), Message::ReplicationHandshake { .. } => todo!(), Message::ReplicationHandshakeResponse { .. } => todo!(), - Message::Replicate { .. } => todo!(), + Message::Replicate { .. } => match &mut self.database { + Database::Primary(_) => todo!(), + Database::Replica { .. } => (), + }, Message::Frames(frames) => match &mut self.database { Database::Replica { injector_handle, @@ -289,7 +295,7 @@ impl Allocation { .. 
} => {
                     *last_received_frame_ts = Some(Instant::now());
-                    injector_handle.send(frames).await;
+                    injector_handle.send(frames).await.unwrap();
                 }
                 Database::Primary(_) => todo!("handle primary receiving txn"),
             },
diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs
index f0f9ca0c..84b961eb 100644
--- a/libsqlx-server/src/config.rs
+++ b/libsqlx-server/src/config.rs
@@ -30,7 +30,7 @@ pub struct ClusterConfig {
     /// Address to bind this node to
     #[serde(default = "default_linc_addr")]
     pub addr: SocketAddr,
-    /// List of peers in the format `<id>:<addr>`
+    /// List of peers in the format `<id>@<addr>`
     pub peers: Vec<Peer>,
 }
 
@@ -64,8 +64,8 @@ fn default_linc_addr() -> SocketAddr {
 
 #[derive(Debug, Clone)]
 pub struct Peer {
-    id: u64,
-    addr: String,
+    pub id: u64,
+    pub addr: String,
 }
 
 impl<'de> Deserialize<'de> for Peer {
@@ -86,7 +86,7 @@ impl<'de> Deserialize<'de> for Peer {
             where
                 E: serde::de::Error,
             {
-                let mut iter = v.split(":");
+                let mut iter = v.split("@");
                 let Some(id) = iter.next() else { return Err(E::custom("node id is missing")) };
                 let Ok(id) = id.parse::<u64>() else { return Err(E::custom("failed to parse node id")) };
                 let Some(addr) = iter.next() else { return Err(E::custom("node address is missing")) };
diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs
index 8beae22d..4707c989 100644
--- a/libsqlx-server/src/linc/bus.rs
+++ b/libsqlx-server/src/linc/bus.rs
@@ -30,7 +30,7 @@ impl Bus {
     }
 
     pub async fn incomming(self: &Arc<Self>, incomming: Inbound) {
-        self.handler.handle(self.clone(), incomming);
+        self.handler.handle(self.clone(), incomming).await;
     }
 
     pub fn send_queue(&self) -> &SendQueue {
diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs
index 170d1f2a..e12838cd 100644
--- a/libsqlx-server/src/linc/connection.rs
+++ b/libsqlx-server/src/linc/connection.rs
@@ -68,7 +68,8 @@ impl SendQueue {
             None => todo!("no queue"),
         };
 
-        sender.send(msg.enveloppe);
+        dbg!();
+        sender.send(msg.enveloppe).unwrap();
     }
 
     pub fn register(&self, node_id: NodeId) -> mpsc::UnboundedReceiver<Enveloppe> {
@@ -145,6 +146,7 @@ where
                 m = self.conn.next() => {
                     match m {
                         Some(Ok(m)) => {
+                            dbg!();
                             self.handle_message(m).await;
                         }
                         Some(Err(e)) => {
@@ -157,11 +159,13 @@ where
                 },
                 // TODO: pop send queue
                 Some(m) = self.send_queue.as_mut().unwrap().recv() => {
+                    dbg!();
                    self.conn.feed(m).await.unwrap();
                     // send as many as possible
                     while let Ok(m) = self.send_queue.as_mut().unwrap().try_recv() {
                         self.conn.feed(m).await.unwrap();
                     }
+                    dbg!();
                     self.conn.flush().await.unwrap();
                 }
                 else => {
@@ -216,13 +220,15 @@ where
             let msg = Enveloppe {
                 database_id: None,
                 message: Message::Handshake {
-                    protocol_version: CURRENT_PROTO_VERSION, 
+                    protocol_version: CURRENT_PROTO_VERSION,
                     node_id: self.bus.node_id(),
                 },
             };
             self.conn.send(msg).await?;
         }
 
+        tracing::info!("Connected to peer {node_id}");
+
         self.peer = Some(node_id);
         self.state = ConnectionState::Connected;
         self.send_queue = Some(self.bus.send_queue().register(node_id));
diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs
index 26c5d923..89a43a15 100644
--- a/libsqlx-server/src/linc/connection_pool.rs
+++ b/libsqlx-server/src/linc/connection_pool.rs
@@ -11,7 +11,7 @@ use super::net.Connector;
 use super::{bus::Bus, NodeId};
 
 /// Manages a pool of connections to other peers, handling re-connection.
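// The pool below only dials peers whose node id is lower than the local one (note
// the `filter` on `managed_peers` in the constructor that follows): for any pair of
// nodes, exactly one side initiates and the other merely accepts, so no duplicate
// link is ever established. A standalone illustration of that rule, with made-up
// node ids:

/// True when `local` should initiate the connection to `remote`.
fn should_initiate(local: u64, remote: u64) -> bool {
    remote < local
}

fn main() {
    // With nodes {1, 2, 3}: node 1 dials nobody, node 2 dials 1, node 3 dials 1 and 2.
    for local in [1u64, 2, 3] {
        let dials: Vec<u64> = [1u64, 2, 3]
            .into_iter()
            .filter(|&remote| remote != local && should_initiate(local, remote))
            .collect();
        println!("node {local} dials {dials:?}");
    }
}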
-struct ConnectionPool { +pub struct ConnectionPool { managed_peers: HashMap, connections: JoinSet, bus: Arc>, @@ -23,16 +23,22 @@ impl ConnectionPool { managed_peers: impl IntoIterator, ) -> Self { Self { - managed_peers: managed_peers.into_iter().collect(), + managed_peers: managed_peers.into_iter().filter(|(id, _)| *id < bus.node_id()).collect(), connections: JoinSet::new(), bus, } } - pub async fn run(mut self) { + pub fn managed_count(&self) -> usize { + self.managed_peers.len() + } + + pub async fn run(mut self) -> color_eyre::Result<()> { self.init::().await; while self.tick::().await {} + + Ok(()) } pub async fn tick(&mut self) -> bool { @@ -66,6 +72,7 @@ impl ConnectionPool { let connection = Connection::new_initiator(stream, bus.clone()); connection.run().await; + dbg!(); peer_id }; diff --git a/libsqlx-server/src/linc/net.rs b/libsqlx-server/src/linc/net.rs index 430b6d08..2123c041 100644 --- a/libsqlx-server/src/linc/net.rs +++ b/libsqlx-server/src/linc/net.rs @@ -31,6 +31,7 @@ pub trait Listener { Self: 'a; fn accept(&self) -> Self::Future<'_>; + fn local_addr(&self) -> color_eyre::Result; } pub struct AcceptFut<'a>(&'a TcpListener); @@ -53,6 +54,10 @@ impl Listener for TcpListener { fn accept(&self) -> Self::Future<'_> { AcceptFut(self) } + + fn local_addr(&self) -> color_eyre::Result { + Ok(self.local_addr()?) + } } #[cfg(test)] diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs index b462d0a1..f3eacec2 100644 --- a/libsqlx-server/src/linc/server.rs +++ b/libsqlx-server/src/linc/server.rs @@ -30,11 +30,14 @@ impl Server { while self.connections.join_next().await.is_some() {} } - pub async fn run(mut self, mut listener: L) + pub async fn run(mut self, mut listener: L) -> color_eyre::Result<()> where L: super::net::Listener, { + tracing::info!("Cluster server listening on {}", listener.local_addr()?); while self.tick(&mut listener).await {} + + Ok(()) } pub async fn tick(&mut self, listener: &mut L) -> bool diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index d5b0c35f..0d89f6f4 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -4,13 +4,14 @@ use std::sync::Arc; use clap::Parser; use color_eyre::eyre::Result; -use config::{AdminApiConfig, UserApiConfig}; +use config::{AdminApiConfig, UserApiConfig, ClusterConfig}; use http::admin::run_admin_api; use http::user::run_user_api; use hyper::server::conn::AddrIncoming; use linc::bus::Bus; use manager::Manager; use meta::Store; +use tokio::net::{TcpListener, TcpStream}; use tokio::task::JoinSet; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; @@ -36,7 +37,7 @@ async fn spawn_admin_api( config: &AdminApiConfig, meta_store: Arc, ) -> Result<()> { - let admin_api_listener = tokio::net::TcpListener::bind(config.addr).await?; + let admin_api_listener = TcpListener::bind(config.addr).await?; let fut = run_admin_api( http::admin::Config { meta_store }, AddrIncoming::from_listener(admin_api_listener)?, @@ -52,7 +53,7 @@ async fn spawn_user_api( manager: Arc, bus: Arc>>, ) -> Result<()> { - let user_api_listener = tokio::net::TcpListener::bind(config.addr).await?; + let user_api_listener = TcpListener::bind(config.addr).await?; set.spawn(run_user_api( http::user::Config { manager, bus }, AddrIncoming::from_listener(user_api_listener)?, @@ -61,6 +62,24 @@ async fn spawn_user_api( Ok(()) } +async fn spawn_cluster_networking( + set: &mut JoinSet>, + config: &ClusterConfig, + bus: Arc>>, +) -> Result<()> { + let server = 
linc::server::Server::new(bus.clone()); + + let listener = TcpListener::bind(config.addr).await?; + set.spawn(server.run(listener)); + + let pool = linc::connection_pool::ConnectionPool::new(bus, config.peers.iter().map(|p| (p.id, dbg!(p.addr.clone())))); + if pool.managed_count() > 0 { + set.spawn(pool.run::()); + } + + Ok(()) +} + #[tokio::main] async fn main() -> Result<()> { init(); @@ -75,6 +94,7 @@ async fn main() -> Result<()> { let manager = Arc::new(Manager::new(config.db_path.clone(), store.clone(), 100)); let bus = Arc::new(Bus::new(config.cluster.id, manager.clone())); + spawn_cluster_networking(&mut join_set, &config.cluster, bus.clone()).await?; spawn_admin_api(&mut join_set, &config.admin_api, store.clone()).await?; spawn_user_api(&mut join_set, &config.user_api, manager, bus).await?; From cc9c2c1f576f0c30102c35328dd4ac79387adc2a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 13 Jul 2023 15:15:26 +0200 Subject: [PATCH 21/64] primary replicate to replica --- libsqlx-server/src/allocation/mod.rs | 167 +++++++++++++++--- libsqlx-server/src/linc/connection.rs | 6 +- libsqlx-server/src/linc/connection_pool.rs | 6 +- libsqlx-server/src/linc/proto.rs | 8 +- libsqlx-server/src/main.rs | 9 +- libsqlx/src/database/libsql/mod.rs | 8 +- .../database/libsql/replication_log/logger.rs | 13 +- libsqlx/src/lib.rs | 1 + 8 files changed, 176 insertions(+), 42 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 743dbdd4..fdd08a88 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,10 +1,16 @@ +use std::collections::HashMap; +use std::collections::hash_map::Entry; use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, Instant}; +use bytes::Bytes; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; use libsqlx::proxy::WriteProxyDatabase; -use libsqlx::{Database as _, DescribeResponse, Frame, InjectableDatabase, Injector, FrameNo}; +use libsqlx::{ + Database as _, DescribeResponse, Frame, FrameNo, InjectableDatabase, Injector, LogReadError, + ReplicationLogger, +}; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; use tokio::time::timeout; @@ -13,7 +19,7 @@ use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use crate::linc::bus::Dispatch; -use crate::linc::proto::{Enveloppe, Message, Frames}; +use crate::linc::proto::{Enveloppe, Frames, Message}; use crate::linc::{Inbound, NodeId, Outbound}; use crate::meta::DatabaseId; @@ -65,7 +71,11 @@ impl libsqlx::Database for DummyDb { type ProxyDatabase = WriteProxyDatabase, DummyDb>; pub enum Database { - Primary(LibsqlDatabase), + Primary { + db: LibsqlDatabase, + replica_streams: HashMap)>, + frame_notifier: tokio::sync::watch::Receiver, + }, Replica { db: ProxyDatabase, injector_handle: mpsc::Sender, @@ -96,7 +106,7 @@ const MAX_INJECTOR_BUFFER_CAP: usize = 32; struct Replicator { dispatcher: Arc, req_id: u32, - last_committed: FrameNo, + next_frame_no: FrameNo, next_seq: u32, database_id: DatabaseId, primary_node_id: NodeId, @@ -106,30 +116,36 @@ struct Replicator { impl Replicator { async fn run(mut self) { - dbg!(); self.query_replicate().await; - dbg!(); loop { match timeout(Duration::from_secs(5), self.receiver.recv()).await { Ok(Some(Frames { - req_id, - seq, + req_no: req_id, + seq_no: seq, frames, })) => { // ignore frames from a previous call to Replicate - if req_id != self.req_id 
{ continue }
+                    if req_id != self.req_id {
+                        tracing::debug!(req_id, self.req_id, "wrong req_id");
+                        continue;
+                    }
                     if seq != self.next_seq {
                         // this is not the batch of frames we were expecting, drop what we have, and
                         // ask again from last checkpoint
+                        tracing::debug!(seq, self.next_seq, "wrong seq");
                         self.query_replicate().await;
                         continue;
                     };
                     self.next_seq += 1;
+
+                    tracing::debug!("injecting {} frames", frames.len());
+
                     for bytes in frames {
                         let frame = Frame::try_from_bytes(bytes).unwrap();
                         block_in_place(|| {
                             if let Some(last_committed) = self.injector.inject(frame).unwrap() {
-                                self.last_committed = last_committed;
+                                tracing::debug!(last_committed);
+                                self.next_frame_no = last_committed + 1;
                             }
                         });
                     }
@@ -151,12 +167,71 @@ impl Replicator {
                 enveloppe: Enveloppe {
                     database_id: Some(self.database_id),
                     message: Message::Replicate {
-                        next_frame_no: self.last_committed + 1,
-                        req_id: self.req_id - 1,
+                        next_frame_no: self.next_frame_no,
+                        req_no: self.req_id,
                     },
                 },
             })
-            .await;
+            .await;
+    }
+}
+
+struct FrameStreamer {
+    logger: Arc<ReplicationLogger>,
+    database_id: DatabaseId,
+    node_id: NodeId,
+    next_frame_no: FrameNo,
+    req_no: u32,
+    seq_no: u32,
+    dipatcher: Arc<dyn Dispatch>,
+    notifier: tokio::sync::watch::Receiver<FrameNo>,
+    buffer: Vec<Bytes>,
+}
+
+// the maximum number of frames a Frames message is allowed to contain
+const FRAMES_MESSAGE_MAX_COUNT: usize = 5;
+
+impl FrameStreamer {
+    async fn run(mut self) {
+        loop {
+            match block_in_place(|| self.logger.get_frame(self.next_frame_no)) {
+                Ok(frame) => {
+                    if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT {
+                        self.send_frames().await;
+                    }
+                    self.buffer.push(frame.bytes());
+                    self.next_frame_no += 1;
+                }
+                Err(LogReadError::Ahead) => {
+                    tracing::debug!("frame {} not yet available", self.next_frame_no);
+                    if !self.buffer.is_empty() {
+                        self.send_frames().await;
+                    }
+                    if self.notifier.wait_for(|fno| dbg!(*fno) >= self.next_frame_no).await.is_err() {
+                        break;
+                    }
+                }
+                Err(LogReadError::Error(_)) => todo!("handle log read error"),
+                Err(LogReadError::SnapshotRequired) => todo!("handle reading from snapshot"),
+            }
+        }
+    }
+
+    async fn send_frames(&mut self) {
+        let frames = std::mem::take(&mut self.buffer);
+        let outbound = Outbound {
+            to: self.node_id,
+            enveloppe: Enveloppe {
+                database_id: Some(self.database_id),
+                message: Message::Frames(Frames {
+                    req_no: self.req_no,
+                    seq_no: self.seq_no,
+                    frames,
+                }),
+            },
+        };
+        self.seq_no += 1;
+        self.dipatcher.dispatch(outbound).await;
+    }
+}
 
@@ -164,9 +239,21 @@ impl Database {
     pub fn from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc<dyn Dispatch>) -> Self {
         match config.db_config {
             DbConfig::Primary {} => {
-                let db = LibsqlDatabase::new_primary(path, Compactor, false).unwrap();
-                Self::Primary(db)
-            }
+                let (sender, receiver) = tokio::sync::watch::channel(0);
+                let db = LibsqlDatabase::new_primary(
+                    path,
+                    Compactor,
+                    false,
+                    Box::new(move |fno| { let _ = sender.send(fno); } ),
+                )
+                .unwrap();
+
+                Self::Primary {
+                    db,
+                    replica_streams: HashMap::new(),
+                    frame_notifier: receiver,
+                }
+            },
             DbConfig::Replica { primary_node_id } => {
                 let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap();
                 let wdb = DummyDb;
@@ -178,7 +265,7 @@ impl Database {
                 let replicator = Replicator {
                     dispatcher,
                     req_id: 0,
-                    last_committed: 0, // TODO: load the last committed from meta file
+                    next_frame_no: 0, // TODO: load the last committed from meta file
                     next_seq: 0,
                     database_id,
                     primary_node_id,
@@ -200,7 +287,7 @@ impl Database {
     fn connect(&self) -> Box {
         match self {
-            Database::Primary(db) => 
Box::new(db.connect().unwrap()), + Database::Primary { db, .. } => Box::new(db.connect().unwrap()), Database::Replica { db, .. } => Box::new(db.connect().unwrap()), } } @@ -281,12 +368,44 @@ impl Allocation { ); match msg.enveloppe.message { - Message::Handshake { .. } => todo!(), + Message::Handshake { .. } => unreachable!("handshake should have been caught earlier"), Message::ReplicationHandshake { .. } => todo!(), Message::ReplicationHandshakeResponse { .. } => todo!(), - Message::Replicate { .. } => match &mut self.database { - Database::Primary(_) => todo!(), - Database::Replica { .. } => (), + Message::Replicate { req_no, next_frame_no } => match &mut self.database { + Database::Primary { db, replica_streams, frame_notifier } => { + dbg!(next_frame_no); + let streamer = FrameStreamer { + logger: db.logger(), + database_id: DatabaseId::from_name(&self.db_name), + node_id: msg.from, + next_frame_no, + req_no, + seq_no: 0, + dipatcher: self.dispatcher.clone(), + notifier: frame_notifier.clone(), + buffer: Vec::new(), + }; + + match replica_streams.entry(msg.from) { + Entry::Occupied(mut e) => { + let (old_req_no, old_handle) = e.get_mut(); + // ignore req_no older that the current req_no + if *old_req_no < req_no { + let handle = tokio::spawn(streamer.run()); + let old_handle = std::mem::replace(old_handle, handle); + *old_req_no = req_no; + old_handle.abort(); + } + }, + Entry::Vacant(e) => { + let handle = tokio::spawn(streamer.run()); + // For some reason, not yielding causes the task not to be spawned + tokio::task::yield_now().await; + e.insert((req_no, handle)); + }, + } + }, + Database::Replica { .. } => todo!("not a primary!"), }, Message::Frames(frames) => match &mut self.database { Database::Replica { @@ -297,7 +416,7 @@ impl Allocation { *last_received_frame_ts = Some(Instant::now()); injector_handle.send(frames).await.unwrap(); } - Database::Primary(_) => todo!("handle primary receiving txn"), + Database::Primary { .. } => todo!("handle primary receiving txn"), }, Message::ProxyRequest { .. } => todo!(), Message::ProxyResponse { .. 
} => todo!(), diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index e12838cd..bf5bd97e 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -68,7 +68,6 @@ impl SendQueue { None => todo!("no queue"), }; - dbg!(); sender.send(msg.enveloppe).unwrap(); } @@ -146,7 +145,6 @@ where m = self.conn.next() => { match m { Some(Ok(m)) => { - dbg!(); self.handle_message(m).await; } Some(Err(e)) => { @@ -159,13 +157,11 @@ where }, // TODO: pop send queue Some(m) = self.send_queue.as_mut().unwrap().recv() => { - dbg!(); self.conn.feed(m).await.unwrap(); // send as many as possible while let Ok(m) = self.send_queue.as_mut().unwrap().try_recv() { self.conn.feed(m).await.unwrap(); } - dbg!(); self.conn.flush().await.unwrap(); } else => { @@ -220,7 +216,7 @@ where let msg = Enveloppe { database_id: None, message: Message::Handshake { - protocol_version: CURRENT_PROTO_VERSION, + protocol_version: CURRENT_PROTO_VERSION, node_id: self.bus.node_id(), }, }; diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index 89a43a15..b6113a80 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -23,7 +23,10 @@ impl ConnectionPool { managed_peers: impl IntoIterator, ) -> Self { Self { - managed_peers: managed_peers.into_iter().filter(|(id, _)| *id < bus.node_id()).collect(), + managed_peers: managed_peers + .into_iter() + .filter(|(id, _)| *id < bus.node_id()) + .collect(), connections: JoinSet::new(), bus, } @@ -72,7 +75,6 @@ impl ConnectionPool { let connection = Connection::new_initiator(stream, bus.clone()); connection.run().await; - dbg!(); peer_id }; diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index 93ac445e..bec6ff7a 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -16,12 +16,12 @@ pub struct Enveloppe { #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] /// a batch of frames to inject -pub struct Frames{ +pub struct Frames { /// must match the Replicate request id - pub req_id: u32, + pub req_no: u32, /// sequence id, monotonically incremented, reset when req_id changes. /// Used to detect gaps in received frames. 
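// Together, req_no and seq_no implement the gap detection used by `Replicator::run`
// in allocation/mod.rs above: a batch carrying a stale req_no is dropped, and a
// batch whose seq_no is not the expected one triggers a fresh Replicate request.
// A condensed, self-contained sketch of that decision (names are illustrative):

#[derive(Debug, PartialEq)]
enum BatchAction {
    Drop,      // stale: belongs to a previous Replicate request
    ReRequest, // gap detected: re-ask the primary from the last applied frame
    Apply,     // contiguous: inject the frames and bump the expected seq_no
}

fn classify(current_req: u32, expected_seq: u32, req_no: u32, seq_no: u32) -> BatchAction {
    if req_no != current_req {
        BatchAction::Drop
    } else if seq_no != expected_seq {
        BatchAction::ReRequest
    } else {
        BatchAction::Apply
    }
}

fn main() {
    assert_eq!(classify(1, 4, 0, 4), BatchAction::Drop);
    assert_eq!(classify(1, 4, 1, 6), BatchAction::ReRequest);
    assert_eq!(classify(1, 4, 1, 4), BatchAction::Apply);
}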
- pub seq: u32, + pub seq_no: u32, pub frames: Vec, } @@ -43,7 +43,7 @@ pub enum Message { }, Replicate { /// incremental request id, used when responding with a Frames message - req_id: u32, + req_no: u32, /// next frame no to send next_frame_no: u64, }, diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 0d89f6f4..454ae954 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use clap::Parser; use color_eyre::eyre::Result; -use config::{AdminApiConfig, UserApiConfig, ClusterConfig}; +use config::{AdminApiConfig, ClusterConfig, UserApiConfig}; use http::admin::run_admin_api; use http::user::run_user_api; use hyper::server::conn::AddrIncoming; @@ -72,7 +72,10 @@ async fn spawn_cluster_networking( let listener = TcpListener::bind(config.addr).await?; set.spawn(server.run(listener)); - let pool = linc::connection_pool::ConnectionPool::new(bus, config.peers.iter().map(|p| (p.id, dbg!(p.addr.clone())))); + let pool = linc::connection_pool::ConnectionPool::new( + bus, + config.peers.iter().map(|p| (p.id, p.addr.clone())), + ); if pool.managed_count() > 0 { set.spawn(pool.run::()); } @@ -80,7 +83,7 @@ async fn spawn_cluster_networking( Ok(()) } -#[tokio::main] +#[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() -> Result<()> { init(); let args = Args::parse(); diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index dbd1d285..44952df6 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -15,6 +15,7 @@ use replication_log::logger::{ }; use self::injector::InjectorCommitHandler; +use self::replication_log::logger::FrameNotifierCb; pub use connection::LibsqlConnection; pub use replication_log::logger::{LogCompactor, LogFile}; @@ -118,17 +119,22 @@ impl LibsqlDatabase { compactor: impl LogCompactor, // whether the log is dirty and might need repair dirty: bool, + new_frame_notifier: FrameNotifierCb, ) -> crate::Result { let ty = PrimaryType { logger: Arc::new(ReplicationLogger::open( &db_path, dirty, compactor, - Box::new(|_| ()), + new_frame_notifier, )?), }; Ok(Self::new(db_path, ty)) } + + pub fn logger(&self) -> Arc { + self.ty.logger.clone() + } } impl LibsqlDatabase { diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index fe371258..e17c286c 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -441,6 +441,7 @@ impl LogFile { } pub fn commit(&mut self) -> crate::Result<()> { + dbg!(&self); self.header.frame_count += self.uncommitted_frame_count; self.uncommitted_frame_count = 0; self.commited_checksum = self.uncommitted_checksum; @@ -550,6 +551,7 @@ impl LogFile { /// If the requested frame is before the first frame in the log, or after the last frame, /// Ok(None) is returned. 
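// `new_primary` now threads a frame-notifier callback into the replication logger,
// and the allocation code above wires it to a tokio watch channel so that each
// FrameStreamer can sleep until the log has grown past the frame it needs. A
// reduced, standalone sketch of that notify/wait pattern (values are illustrative):

use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(0u64);

    // Primary side: the logger invokes this callback with the last committed frame no.
    let notify = move |frame_no: u64| {
        let _ = tx.send(frame_no);
    };

    tokio::spawn(async move {
        for fno in 1..=5u64 {
            notify(fno);
        }
    });

    // Streamer side: block until at least frame 5 has been committed.
    rx.wait_for(|fno| *fno >= 5).await.unwrap();
    println!("frame 5 committed, resume streaming");
}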
pub fn frame(&self, frame_no: FrameNo) -> std::result::Result { + dbg!(frame_no); if frame_no < self.header.start_frame_no { return Err(LogReadError::SnapshotRequired); } @@ -695,8 +697,12 @@ pub struct LogFileHeader { } impl LogFileHeader { - pub fn last_frame_no(&self) -> FrameNo { - self.start_frame_no + self.frame_count + pub fn last_frame_no(&self) -> Option { + if self.start_frame_no == 0 && self.frame_count == 0 { + None + } else { + Some(self.start_frame_no + self.frame_count - 1) + } } fn sqld_version(&self) -> Version { @@ -871,6 +877,7 @@ impl ReplicationLogger { /// Returns the new frame count and checksum to commit fn write_pages(&self, pages: &[WalPage]) -> anyhow::Result<()> { let mut log_file = self.log_file.write(); + dbg!(); for page in pages.iter() { log_file.push_page(page)?; } @@ -899,7 +906,7 @@ impl ReplicationLogger { fn commit(&self) -> anyhow::Result { let mut log_file = self.log_file.write(); log_file.commit()?; - Ok(log_file.header().last_frame_no()) + Ok(log_file.header().last_frame_no().expect("there should be at least one frame after commit")) } pub fn get_snapshot_file(&self, from: FrameNo) -> anyhow::Result> { diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index e004317e..13223d22 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -12,6 +12,7 @@ pub type Result = std::result::Result; pub use connection::{Connection, DescribeResponse}; pub use database::libsql; +pub use database::libsql::replication_log::logger::{LogReadError, ReplicationLogger}; pub use database::libsql::replication_log::FrameNo; pub use database::proxy; pub use database::Frame; From d1e98452e9e81d6bfe9502276fe0c8a94174b819 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 18 Jul 2023 10:56:13 +0200 Subject: [PATCH 22/64] fully async result builder --- Cargo.lock | 3 + libsqlx-server/Cargo.toml | 3 +- libsqlx-server/src/allocation/mod.rs | 63 +++-- libsqlx-server/src/hrana/batch.rs | 24 +- libsqlx-server/src/hrana/result_builder.rs | 212 +++++++++------- libsqlx-server/src/hrana/stmt.rs | 12 +- libsqlx/Cargo.toml | 5 + libsqlx/src/connection.rs | 90 ++----- libsqlx/src/database/libsql/connection.rs | 41 ++-- libsqlx/src/database/libsql/mod.rs | 24 +- .../database/libsql/replication_log/logger.rs | 8 +- libsqlx/src/database/proxy/connection.rs | 227 +++++++++++------- libsqlx/src/database/proxy/mod.rs | 1 + libsqlx/src/lib.rs | 2 + libsqlx/src/result_builder.rs | 87 +++---- 15 files changed, 442 insertions(+), 360 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2a3cc0f6..d08d5034 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2519,6 +2519,7 @@ dependencies = [ "bytesize", "crc", "crossbeam", + "either", "fallible-iterator 0.3.0", "itertools 0.11.0", "nix", @@ -2533,6 +2534,7 @@ dependencies = [ "sqlite3-parser 0.9.0", "tempfile", "thiserror", + "tokio", "tracing", "uuid", ] @@ -2549,6 +2551,7 @@ dependencies = [ "bytes 1.4.0", "clap", "color-eyre", + "either", "futures", "hmac", "hyper", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 86beceda..a5a11437 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -14,11 +14,12 @@ bincode = "1.3.3" bytes = { version = "1.4.0", features = ["serde"] } clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" +either = "1.8.1" futures = "0.3.28" hmac = "0.12.1" hyper = { version = "0.14.27", features = ["h2", "server"] } itertools = "0.11.0" -libsqlx = { version = "0.1.0", path = "../libsqlx" } +libsqlx = { version = "0.1.0", path = "../libsqlx", features = ["tokio"] } 
moka = { version = "0.11.2", features = ["future"] } parking_lot = "0.12.1" priority-queue = "1.3.2" diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index fdd08a88..7e7e2b9b 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,12 +1,14 @@ -use std::collections::HashMap; use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, Instant}; use bytes::Bytes; +use either::Either; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; -use libsqlx::proxy::WriteProxyDatabase; +use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase}; +use libsqlx::result_builder::ResultBuilder; use libsqlx::{ Database as _, DescribeResponse, Frame, FrameNo, InjectableDatabase, Injector, LogReadError, ReplicationLogger, @@ -27,7 +29,11 @@ use self::config::{AllocConfig, DbConfig}; pub mod config; -type ExecFn = Box; +type LibsqlConnection = Either< + libsqlx::libsql::LibsqlConnection, + WriteProxyConnection, DummyConn>, +>; +type ExecFn = Box; #[derive(Clone)] pub struct ConnectionId { @@ -47,10 +53,10 @@ pub struct DummyDb; pub struct DummyConn; impl libsqlx::Connection for DummyConn { - fn execute_program( + fn execute_program( &mut self, - _pgm: libsqlx::program::Program, - _result_builder: &mut dyn libsqlx::result_builder::ResultBuilder, + _pgm: &libsqlx::program::Program, + _result_builder: B, ) -> libsqlx::Result<()> { todo!() } @@ -207,7 +213,12 @@ impl FrameStreamer { if !self.buffer.is_empty() { self.send_frames().await; } - if self.notifier.wait_for(|fno| dbg!(*fno) >= self.next_frame_no).await.is_err() { + if self + .notifier + .wait_for(|fno| *fno >= self.next_frame_no) + .await + .is_err() + { break; } } @@ -244,7 +255,9 @@ impl Database { path, Compactor, false, - Box::new(move |fno| { let _ = sender.send(fno); } ), + Box::new(move |fno| { + let _ = sender.send(fno); + }), ) .unwrap(); @@ -253,7 +266,7 @@ impl Database { replica_streams: HashMap::new(), frame_notifier: receiver, } - }, + } DbConfig::Replica { primary_node_id } => { let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap(); let wdb = DummyDb; @@ -285,10 +298,10 @@ impl Database { } } - fn connect(&self) -> Box { + fn connect(&self) -> LibsqlConnection { match self { - Database::Primary { db, .. } => Box::new(db.connect().unwrap()), - Database::Replica { db, .. } => Box::new(db.connect().unwrap()), + Database::Primary { db, .. } => Either::Left(db.connect().unwrap()), + Database::Replica { db, .. } => Either::Right(db.connect().unwrap()), } } } @@ -315,11 +328,11 @@ pub struct ConnectionHandle { impl ConnectionHandle { pub async fn exec(&self, f: F) -> crate::Result where - F: for<'a> FnOnce(&'a mut (dyn libsqlx::Connection + 'a)) -> R + Send + 'static, + F: for<'a> FnOnce(&'a mut LibsqlConnection) -> R + Send + 'static, R: Send + 'static, { let (sender, ret) = oneshot::channel(); - let cb = move |conn: &mut dyn libsqlx::Connection| { + let cb = move |conn: &mut LibsqlConnection| { let res = f(conn); let _ = sender.send(res); }; @@ -371,9 +384,15 @@ impl Allocation { Message::Handshake { .. } => unreachable!("handshake should have been caught earlier"), Message::ReplicationHandshake { .. } => todo!(), Message::ReplicationHandshakeResponse { .. 
} => todo!(), - Message::Replicate { req_no, next_frame_no } => match &mut self.database { - Database::Primary { db, replica_streams, frame_notifier } => { - dbg!(next_frame_no); + Message::Replicate { + req_no, + next_frame_no, + } => match &mut self.database { + Database::Primary { + db, + replica_streams, + frame_notifier, + } => { let streamer = FrameStreamer { logger: db.logger(), database_id: DatabaseId::from_name(&self.db_name), @@ -396,15 +415,15 @@ impl Allocation { *old_req_no = req_no; old_handle.abort(); } - }, + } Entry::Vacant(e) => { let handle = tokio::spawn(streamer.run()); // For some reason, not yielding causes the task not to be spawned tokio::task::yield_now().await; e.insert((req_no, handle)); - }, + } } - }, + } Database::Replica { .. } => todo!("not a primary!"), }, Message::Frames(frames) => match &mut self.database { @@ -459,7 +478,7 @@ impl Allocation { struct Connection { id: u32, - conn: Box, + conn: LibsqlConnection, exit: oneshot::Receiver<()>, exec: mpsc::Receiver, } @@ -470,7 +489,7 @@ impl Connection { tokio::select! { _ = &mut self.exit => break, Some(exec) = self.exec.recv() => { - tokio::task::block_in_place(|| exec(&mut *self.conn)); + tokio::task::block_in_place(|| exec(&mut self.conn)); } } } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index 1368991e..c4131c45 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -8,10 +8,12 @@ use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error}; use super::{proto, ProtocolError, Version}; use color_eyre::eyre::anyhow; +use libsqlx::Connection; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; use libsqlx::query::{Params, Query}; use libsqlx::result_builder::{StepResult, StepResultsBuilder}; +use tokio::sync::oneshot; fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre::Result { let try_convert_step = |step: i32| -> Result { @@ -73,15 +75,15 @@ pub async fn execute_batch( db: &ConnectionHandle, pgm: Program, ) -> color_eyre::Result { - let builder = db + let fut = db .exec(move |conn| -> color_eyre::Result<_> { - let mut builder = HranaBatchProtoBuilder::default(); - conn.execute_program(pgm, &mut builder)?; - Ok(builder) + let (builder, ret) = HranaBatchProtoBuilder::new(); + conn.execute_program(&pgm, builder)?; + Ok(ret) }) .await??; - Ok(builder.into_ret()) + Ok(fut.await?) } pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { @@ -110,17 +112,17 @@ pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { } pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { - let builder = conn + let fut = conn .exec(move |conn| -> color_eyre::Result<_> { - let mut builder = StepResultsBuilder::default(); - conn.execute_program(pgm, &mut builder)?; + let (snd, rcv) = oneshot::channel(); + let builder = StepResultsBuilder::new(snd); + conn.execute_program(&pgm, builder)?; - Ok(builder) + Ok(rcv) }) .await??; - builder - .into_ret() + fut.await? 
.into_iter() .try_for_each(|result| match result { StepResult::Ok => Ok(()), diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index b6b8c635..1047f091 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -3,76 +3,90 @@ use std::io; use bytes::Bytes; use libsqlx::{result_builder::*, FrameNo}; +use tokio::sync::oneshot; use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; use super::proto; -#[derive(Debug, Default)] pub struct SingleStatementBuilder { - has_step: bool, - cols: Vec, - rows: Vec>, - err: Option, - affected_row_count: u64, - last_insert_rowid: Option, - current_size: u64, - max_response_size: u64, + builder: StatementBuilder, + ret: oneshot::Sender>, } impl SingleStatementBuilder { - pub fn into_ret(self) -> Result { - match self.err { - Some(err) => Err(err), - None => Ok(proto::StmtResult { - cols: self.cols, - rows: self.rows, - affected_row_count: self.affected_row_count, - last_insert_rowid: self.last_insert_rowid, - }), - } + pub fn new() -> (Self, oneshot::Receiver>) { + let (ret, rcv) = oneshot::channel(); + (Self { + builder: StatementBuilder::default(), + ret, + }, rcv) } } -struct SizeFormatter(u64); +impl ResultBuilder for SingleStatementBuilder { + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.builder.init(config) + } -impl io::Write for SizeFormatter { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0 += buf.len() as u64; - Ok(buf.len()) + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_step() } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.builder.finish_step(affected_row_count, last_insert_rowid) } -} -impl fmt::Write for SizeFormatter { - fn write_str(&mut self, s: &str) -> fmt::Result { - self.0 += s.len() as u64; - Ok(()) + fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { + self.builder.step_error(error) } -} -fn value_json_size(v: &ValueRef) -> u64 { - let mut f = SizeFormatter(0); - match v { - ValueRef::Null => write!(&mut f, r#"{{"type":"null"}}"#).unwrap(), - ValueRef::Integer(i) => write!(&mut f, r#"{{"type":"integer", "value": "{i}"}}"#).unwrap(), - ValueRef::Real(x) => write!(&mut f, r#"{{"type":"integer","value": {x}"}}"#).unwrap(), - ValueRef::Text(s) => { - // error will be caught later. 
- if let Ok(s) = std::str::from_utf8(s) { - write!(&mut f, r#"{{"type":"text","value":"{s}"}}"#).unwrap() - } - } - ValueRef::Blob(b) => return b.len() as u64, + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + self.builder.cols_description(cols) } - f.0 + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_row() + } + + fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { + self.builder.add_row_value(v) + } + + fn finnalize( + self, + _is_txn: bool, + _frame_no: Option, + ) -> Result + where Self: Sized + { + let res = self.builder.into_ret(); + let _ = self.ret.send(res); + Ok(true) + } } -impl ResultBuilder for SingleStatementBuilder { + +#[derive(Debug, Default)] +struct StatementBuilder { + has_step: bool, + cols: Vec, + rows: Vec>, + err: Option, + affected_row_count: u64, + last_insert_rowid: Option, + current_size: u64, + max_response_size: u64, +} + +impl StatementBuilder { fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { *self = Self { max_response_size: config.max_size.unwrap_or(u64::MAX), @@ -138,12 +152,6 @@ impl ResultBuilder for SingleStatementBuilder { Ok(()) } - fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - assert!(self.err.is_none()); - assert!(self.rows.is_empty()); - Ok(()) - } - fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { assert!(self.err.is_none()); self.rows.push(Vec::with_capacity(self.cols.len())); @@ -183,25 +191,57 @@ impl ResultBuilder for SingleStatementBuilder { Ok(()) } - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - assert!(self.err.is_none()); - Ok(()) + pub fn into_ret(self) -> Result { + match self.err { + Some(err) => Err(err), + None => Ok(proto::StmtResult { + cols: self.cols, + rows: self.rows, + affected_row_count: self.affected_row_count, + last_insert_rowid: self.last_insert_rowid, + }), + } } +} - fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - assert!(self.err.is_none()); +struct SizeFormatter(u64); + +impl io::Write for SizeFormatter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0 += buf.len() as u64; + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { Ok(()) } +} - fn finish( - &mut self, - _is_txn: bool, - _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { +impl fmt::Write for SizeFormatter { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.0 += s.len() as u64; Ok(()) } } +fn value_json_size(v: &ValueRef) -> u64 { + let mut f = SizeFormatter(0); + match v { + ValueRef::Null => write!(&mut f, r#"{{"type":"null"}}"#).unwrap(), + ValueRef::Integer(i) => write!(&mut f, r#"{{"type":"integer", "value": "{i}"}}"#).unwrap(), + ValueRef::Real(x) => write!(&mut f, r#"{{"type":"integer","value": {x}"}}"#).unwrap(), + ValueRef::Text(s) => { + // error will be caught later. 
+ if let Ok(s) = std::str::from_utf8(s) { + write!(&mut f, r#"{{"type":"text","value":"{s}"}}"#).unwrap() + } + } + ValueRef::Blob(b) => return b.len() as u64, + } + + f.0 +} + fn estimate_cols_json_size(c: &Column) -> u64 { let mut f = SizeFormatter(0); write!( @@ -214,17 +254,32 @@ fn estimate_cols_json_size(c: &Column) -> u64 { f.0 } -#[derive(Debug, Default)] +#[derive(Debug)] pub struct HranaBatchProtoBuilder { step_results: Vec>, step_errors: Vec>, - stmt_builder: SingleStatementBuilder, + stmt_builder: StatementBuilder, current_size: u64, max_response_size: u64, step_empty: bool, + ret: oneshot::Sender } impl HranaBatchProtoBuilder { + pub fn new() -> (Self, oneshot::Receiver) { + let (ret, rcv) = oneshot::channel(); + (Self { + step_results: Vec::new(), + step_errors: Vec::new(), + stmt_builder: StatementBuilder::default(), + current_size: 0, + max_response_size: u64::MAX, + step_empty: false, + ret, + }, + rcv) + + } pub fn into_ret(self) -> proto::BatchResult { proto::BatchResult { step_results: self.step_results, @@ -235,10 +290,7 @@ impl HranaBatchProtoBuilder { impl ResultBuilder for HranaBatchProtoBuilder { fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - *self = Self { - max_response_size: config.max_size.unwrap_or(u64::MAX), - ..Default::default() - }; + self.max_response_size = config.max_size.unwrap_or(u64::MAX); self.stmt_builder.init(config)?; Ok(()) } @@ -257,7 +309,7 @@ impl ResultBuilder for HranaBatchProtoBuilder { .finish_step(affected_row_count, last_insert_rowid)?; self.current_size += self.stmt_builder.current_size; - let new_builder = SingleStatementBuilder { + let new_builder = StatementBuilder { current_size: 0, max_response_size: self.max_response_size - self.current_size, ..Default::default() @@ -290,10 +342,6 @@ impl ResultBuilder for HranaBatchProtoBuilder { self.stmt_builder.cols_description(cols) } - fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.stmt_builder.begin_rows() - } - fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { self.stmt_builder.begin_row() } @@ -301,20 +349,4 @@ impl ResultBuilder for HranaBatchProtoBuilder { fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { self.stmt_builder.add_row_value(v) } - - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.stmt_builder.finish_row() - } - - fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn finish( - &mut self, - _is_txn: bool, - _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { - Ok(()) - } } diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index 5453ab5c..e6c002a1 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use color_eyre::eyre::{anyhow, bail}; use libsqlx::analysis::Statement; use libsqlx::query::{Params, Query, Value}; +use libsqlx::Connection; use super::result_builder::SingleStatementBuilder; use super::{proto, ProtocolError, Version}; @@ -47,18 +48,17 @@ pub async fn execute_stmt( conn: &ConnectionHandle, query: Query, ) -> color_eyre::Result { - let builder = conn + let fut = conn .exec(move |conn| -> color_eyre::Result<_> { - let mut builder = SingleStatementBuilder::default(); + let (builder, ret) = SingleStatementBuilder::new(); let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute_program(pgm, &mut builder)?; + conn.execute_program(&pgm, 
builder)?; - Ok(builder) + Ok(ret) }) .await??; - builder - .into_ret() + fut.await? .map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) { Ok(stmt_error) => anyhow!(stmt_error), Err(sqld_error) => anyhow!(sqld_error), diff --git a/libsqlx/Cargo.toml b/libsqlx/Cargo.toml index abdb39ad..85fd7a9d 100644 --- a/libsqlx/Cargo.toml +++ b/libsqlx/Cargo.toml @@ -27,8 +27,13 @@ crc = "3.0.1" once_cell = "1.18.0" regex = "1.8.4" tempfile = "3.6.0" +either = "1.8.1" +tokio = { version = "1", optional = true, features = ["sync"] } [dev-dependencies] arbitrary = { version = "1.3.0", features = ["derive"] } itertools = "0.11.0" rand = "0.8.5" + +[features] +tokio = ["dep:tokio"] diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index e2fd05f8..fa027997 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -1,8 +1,7 @@ -use rusqlite::types::Value; +use either::Either; -use crate::program::{Program, Step}; -use crate::query::Query; -use crate::result_builder::{QueryBuilderConfig, QueryResultBuilderError, ResultBuilder}; +use crate::program::Program; +use crate::result_builder::ResultBuilder; #[derive(Debug, Clone)] pub struct DescribeResponse { @@ -25,81 +24,36 @@ pub struct DescribeCol { pub trait Connection { /// Executes a query program - fn execute_program( + fn execute_program( &mut self, - pgm: Program, - result_builder: &mut dyn ResultBuilder, + pgm: &Program, + result_builder: B, ) -> crate::Result<()>; /// Parse the SQL statement and return information about it. fn describe(&self, sql: String) -> crate::Result; - - /// execute a single query - fn execute(&mut self, query: Query) -> crate::Result>> { - #[derive(Default)] - struct RowsBuilder { - error: Option, - rows: Vec>, - current_row: Vec, - } - - impl ResultBuilder for RowsBuilder { - fn init( - &mut self, - _config: &QueryBuilderConfig, - ) -> std::result::Result<(), QueryResultBuilderError> { - self.error = None; - self.rows.clear(); - self.current_row.clear(); - - Ok(()) - } - - fn add_row_value( - &mut self, - v: rusqlite::types::ValueRef, - ) -> Result<(), QueryResultBuilderError> { - self.current_row.push(v.into()); - Ok(()) - } - - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - let row = std::mem::take(&mut self.current_row); - self.rows.push(row); - - Ok(()) - } - - fn step_error( - &mut self, - error: crate::error::Error, - ) -> Result<(), QueryResultBuilderError> { - self.error.replace(error); - Ok(()) - } - } - - let pgm = Program::new(vec![Step { cond: None, query }]); - let mut builder = RowsBuilder::default(); - self.execute_program(pgm, &mut builder)?; - if let Some(err) = builder.error.take() { - Err(err) - } else { - Ok(builder.rows) - } - } } -impl Connection for Box { - fn execute_program( +impl Connection for Either +where + T: Connection, + X: Connection, +{ + fn execute_program( &mut self, - pgm: Program, - result_builder: &mut dyn ResultBuilder, + pgm: &Program, + result_builder: B, ) -> crate::Result<()> { - self.as_mut().execute_program(pgm, result_builder) + match self { + Either::Left(c) => c.execute_program(pgm, result_builder), + Either::Right(c) => c.execute_program(pgm, result_builder), + } } fn describe(&self, sql: String) -> crate::Result { - self.as_ref().describe(sql) + match self { + Either::Left(c) => c.describe(sql), + Either::Right(c) => c.describe(sql), + } } } diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index cd5bcff3..27ee59e1 100644 --- 
a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -14,7 +14,7 @@ use crate::result_builder::{QueryBuilderConfig, ResultBuilder}; use crate::seal::Seal; use crate::Result; -use super::RowStatsHandler; +use super::{LibsqlDbType, RowStatsHandler}; pub struct RowStats { pub rows_read: u64, @@ -49,23 +49,23 @@ where sqld_libsql_bindings::Connection::open(path, flags, wal_methods, hook_ctx) } -pub struct LibsqlConnection { +pub struct LibsqlConnection { timeout_deadline: Option, conn: sqld_libsql_bindings::Connection<'static>, // holds a ref to _context, must be dropped first. row_stats_handler: Option>, builder_config: QueryBuilderConfig, - _context: Seal>, + _context: Seal::Context>>, } -impl LibsqlConnection { - pub(crate) fn new( +impl LibsqlConnection { + pub(crate) fn new( path: &Path, extensions: Option>, - wal_methods: &'static WalMethodsHook, - hook_ctx: W::Context, + wal_methods: &'static WalMethodsHook, + hook_ctx: ::Context, row_stats_callback: Option>, builder_config: QueryBuilderConfig, - ) -> Result> { + ) -> Result> { let mut ctx = Box::new(hook_ctx); let this = LibsqlConnection { conn: open_db( @@ -101,14 +101,14 @@ impl LibsqlConnection { &self.conn } - fn run(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> Result<()> { + fn run(&mut self, pgm: &Program, mut builder: B) -> Result<()> { let mut results = Vec::with_capacity(pgm.steps.len()); builder.init(&self.builder_config)?; let is_autocommit_before = self.conn.is_autocommit(); for step in pgm.steps() { - let res = self.execute_step(step, &results, builder)?; + let res = self.execute_step(step, &results, &mut builder)?; results.push(res); } @@ -117,16 +117,19 @@ impl LibsqlConnection { self.timeout_deadline = Some(Instant::now() + TXN_TIMEOUT) } - builder.finish(!self.conn.is_autocommit(), None)?; + let is_txn = !self.conn.is_autocommit(); + if !builder.finnalize(is_txn, None)? 
&& is_txn { + let _ = self.conn.execute("ROLLBACK", ()); + } Ok(()) } - fn execute_step( + fn execute_step( &mut self, step: &Step, results: &[bool], - builder: &mut dyn ResultBuilder, + builder: &mut B, ) -> Result { builder.begin_step()?; let mut enabled = match step.cond.as_ref() { @@ -160,10 +163,10 @@ impl LibsqlConnection { Ok(enabled) } - fn execute_query( + fn execute_query( &self, query: &Query, - builder: &mut dyn ResultBuilder, + builder: &mut B, ) -> Result<(u64, Option)> { tracing::trace!("executing query: {}", query.stmt.stmt); @@ -236,11 +239,11 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { }) } -impl Connection for LibsqlConnection { - fn execute_program( +impl Connection for LibsqlConnection { + fn execute_program( &mut self, - pgm: Program, - builder: &mut dyn ResultBuilder, + pgm: &Program, + builder: B, ) -> crate::Result<()> { self.run(pgm, builder) } diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 44952df6..c0aaed79 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -163,21 +163,19 @@ impl LibsqlDatabase { } impl Database for LibsqlDatabase { - type Connection = LibsqlConnection<::Context>; + type Connection = LibsqlConnection; fn connect(&self) -> Result { - Ok( - LibsqlConnection::<::Context>::new( - &self.db_path, - self.extensions.clone(), - T::hook(), - self.ty.hook_context(), - self.row_stats_callback.clone(), - QueryBuilderConfig { - max_size: Some(self.response_size_limit), - }, - )?, - ) + Ok(LibsqlConnection::::new( + &self.db_path, + self.extensions.clone(), + T::hook(), + self.ty.hook_context(), + self.row_stats_callback.clone(), + QueryBuilderConfig { + max_size: Some(self.response_size_limit), + }, + )?) } } diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index e17c286c..187f3b25 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -441,7 +441,6 @@ impl LogFile { } pub fn commit(&mut self) -> crate::Result<()> { - dbg!(&self); self.header.frame_count += self.uncommitted_frame_count; self.uncommitted_frame_count = 0; self.commited_checksum = self.uncommitted_checksum; @@ -551,7 +550,6 @@ impl LogFile { /// If the requested frame is before the first frame in the log, or after the last frame, /// Ok(None) is returned. 
pub fn frame(&self, frame_no: FrameNo) -> std::result::Result { - dbg!(frame_no); if frame_no < self.header.start_frame_no { return Err(LogReadError::SnapshotRequired); } @@ -877,7 +875,6 @@ impl ReplicationLogger { /// Returns the new frame count and checksum to commit fn write_pages(&self, pages: &[WalPage]) -> anyhow::Result<()> { let mut log_file = self.log_file.write(); - dbg!(); for page in pages.iter() { log_file.push_page(page)?; } @@ -906,7 +903,10 @@ impl ReplicationLogger { fn commit(&self) -> anyhow::Result { let mut log_file = self.log_file.write(); log_file.commit()?; - Ok(log_file.header().last_frame_no().expect("there should be at least one frame after commit")) + Ok(log_file + .header() + .last_frame_no() + .expect("there should be at least one frame after commit")) } pub fn get_snapshot_file(&self, from: FrameNo) -> anyhow::Result> { diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index 24c10a47..68d10c00 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -1,7 +1,7 @@ use crate::connection::{Connection, DescribeResponse}; use crate::database::FrameNo; use crate::program::Program; -use crate::result_builder::{QueryBuilderConfig, ResultBuilder}; +use crate::result_builder::{Column, QueryBuilderConfig, QueryResultBuilderError, ResultBuilder}; use crate::Result; use super::WaitFrameNoCb; @@ -18,7 +18,92 @@ pub struct WriteProxyConnection { pub(crate) read_db: ReadDb, pub(crate) write_db: WriteDb, pub(crate) wait_frame_no_cb: WaitFrameNoCb, - pub(crate) state: parking_lot::Mutex, + pub(crate) state: ConnState, +} + +struct MaybeRemoteExecBuilder<'a, 'b, B, W> { + builder: B, + conn: &'a mut W, + pgm: &'b Program, + state: &'a mut ConnState, +} + +impl<'a, 'b, B, W> ResultBuilder for MaybeRemoteExecBuilder<'a, 'b, B, W> +where + W: Connection, + B: ResultBuilder, +{ + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.builder.init(config) + } + + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_step() + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), QueryResultBuilderError> { + self.builder + .finish_step(affected_row_count, last_insert_rowid) + } + + fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { + self.builder.step_error(error) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), QueryResultBuilderError> { + self.builder.cols_description(cols) + } + + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_rows() + } + + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_row() + } + + fn add_row_value( + &mut self, + v: rusqlite::types::ValueRef, + ) -> Result<(), QueryResultBuilderError> { + self.builder.add_row_value(v) + } + + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.finish_row() + } + + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.finish_rows() + } + + fn finnalize( + self, + is_txn: bool, + frame_no: Option, + ) -> Result { + if is_txn { + // a read only connection is not allowed to leave an open transaction. We mispredicted the + // final state of the connection, so we rollback, and execute again on the write proxy. 
+ let builder = ExtractFrameNoBuilder { + builder: self.builder, + state: self.state, + }; + + self.conn.execute_program(self.pgm, builder).unwrap(); + + Ok(false) + } else { + self.builder.finnalize(is_txn, frame_no) + } + } } impl Connection for WriteProxyConnection @@ -26,137 +111,111 @@ where ReadDb: Connection, WriteDb: Connection, { - fn execute_program(&mut self, pgm: Program, builder: &mut dyn ResultBuilder) -> Result<()> { - let mut state = self.state.lock(); - let mut builder = ExtractFrameNoBuilder::new(builder); - if !state.is_txn && pgm.is_read_only() { - if let Some(frame_no) = state.last_frame_no { + fn execute_program( + &mut self, + pgm: &Program, + builder: B, + ) -> crate::Result<()> { + if !self.state.is_txn && pgm.is_read_only() { + if let Some(frame_no) = self.state.last_frame_no { (self.wait_frame_no_cb)(frame_no); } + + let builder = MaybeRemoteExecBuilder { + builder, + conn: &mut self.write_db, + state: &mut self.state, + pgm, + }; // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. - self.read_db.execute_program(pgm.clone(), &mut builder)?; - - // still in transaction state after running a read-only txn - if builder.is_txn { - // TODO: rollback - // self.read_db.rollback().await?; - self.write_db.execute_program(pgm, &mut builder)?; - state.is_txn = builder.is_txn; - state.last_frame_no = builder.frame_no; - Ok(()) - } else { - Ok(()) - } + self.read_db.execute_program(pgm, builder)?; + // rollback(&mut self.conn.read_db); + Ok(()) } else { - self.write_db.execute_program(pgm, &mut builder)?; - state.is_txn = builder.is_txn; - state.last_frame_no = builder.frame_no; + let builder = ExtractFrameNoBuilder { + builder, + state: &mut self.state, + }; + self.write_db.execute_program(pgm, builder)?; Ok(()) } } - fn describe(&self, sql: String) -> Result { - if let Some(frame_no) = self.state.lock().last_frame_no { + fn describe(&self, sql: String) -> crate::Result { + if let Some(frame_no) = self.state.last_frame_no { (self.wait_frame_no_cb)(frame_no); } self.read_db.describe(sql) } } -struct ExtractFrameNoBuilder<'a> { - inner: &'a mut dyn ResultBuilder, - frame_no: Option, - is_txn: bool, +struct ExtractFrameNoBuilder<'a, B> { + builder: B, + state: &'a mut ConnState, } -impl<'a> ExtractFrameNoBuilder<'a> { - fn new(inner: &'a mut dyn ResultBuilder) -> Self { - Self { - inner, - frame_no: None, - is_txn: false, - } +impl ResultBuilder for ExtractFrameNoBuilder<'_, B> { + fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { + self.builder.init(config) } -} -impl<'a> ResultBuilder for ExtractFrameNoBuilder<'a> { - fn init( - &mut self, - config: &QueryBuilderConfig, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.init(config) - } - - fn begin_step( - &mut self, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner.begin_step() + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { + self.builder.begin_step() } fn finish_step( &mut self, affected_row_count: u64, last_insert_rowid: Option, - ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> { - self.inner + ) -> Result<(), QueryResultBuilderError> { + self.builder .finish_step(affected_row_count, last_insert_rowid) } - fn step_error( - &mut self, - error: 
crate::error::Error,
-    ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> {
-        self.inner.step_error(error)
+    fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> {
+        self.builder.step_error(error)
     }

     fn cols_description(
         &mut self,
-        cols: &mut dyn Iterator,
-    ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> {
-        self.inner.cols_description(cols)
+        cols: &mut dyn Iterator,
+    ) -> Result<(), QueryResultBuilderError> {
+        self.builder.cols_description(cols)
     }

-    fn begin_rows(
-        &mut self,
-    ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> {
-        self.inner.begin_rows()
+    fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> {
+        self.builder.begin_rows()
     }

-    fn begin_row(
-        &mut self,
-    ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> {
-        self.inner.begin_row()
+    fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> {
+        self.builder.begin_row()
     }

     fn add_row_value(
         &mut self,
         v: rusqlite::types::ValueRef,
-    ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> {
-        self.inner.add_row_value(v)
+    ) -> Result<(), QueryResultBuilderError> {
+        self.builder.add_row_value(v)
     }

-    fn finish_row(
-        &mut self,
-    ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> {
-        self.inner.finish_row()
+    fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> {
+        self.builder.finish_row()
     }

-    fn finish_rows(
-        &mut self,
-    ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> {
-        self.inner.finish_rows()
+    fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> {
+        self.builder.finish_rows()
     }

-    fn finish(
-        &mut self,
+    fn finnalize(
+        self,
         is_txn: bool,
         frame_no: Option,
-    ) -> std::result::Result<(), crate::result_builder::QueryResultBuilderError> {
-        self.frame_no = frame_no;
-        self.is_txn = is_txn;
-        self.inner.finish(is_txn, frame_no)
+    ) -> Result {
+        self.state.last_frame_no = frame_no;
+        self.state.is_txn = is_txn;
+        self.builder.finnalize(is_txn, frame_no)
     }
 }

@@ -177,7 +236,7 @@ mod test {
         let write_db = MockDatabase::new().with_execute({
             let write_called = write_called.clone();
             move |_, b| {
-                b.finish(false, Some(42)).unwrap();
+                b.finnalize(false, Some(42)).unwrap();
                 write_called.set(true);
                 Ok(())
             }
diff --git a/libsqlx/src/database/proxy/mod.rs b/libsqlx/src/database/proxy/mod.rs
index 62c6925d..0fdf7ceb 100644
--- a/libsqlx/src/database/proxy/mod.rs
+++ b/libsqlx/src/database/proxy/mod.rs
@@ -5,6 +5,7 @@ use super::FrameNo;
 mod connection;
 mod database;

+pub use connection::WriteProxyConnection;
 pub use database::WriteProxyDatabase;

 // Waits until passed frameno has been replicated back to the database
diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs
index 13223d22..899a7912 100644
--- a/libsqlx/src/lib.rs
+++ b/libsqlx/src/lib.rs
@@ -18,4 +18,6 @@ pub use database::proxy;
 pub use database::Frame;
 pub use database::{Database, InjectableDatabase, Injector};

+pub use sqld_libsql_bindings::wal_hook::WalHook;
+
 pub use rusqlite;
diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs
index be5e27a7..98f598c1 100644
--- a/libsqlx/src/result_builder.rs
+++ b/libsqlx/src/result_builder.rs
@@ -130,12 +130,16 @@ pub trait ResultBuilder {
         Ok(())
     }

     /// finish the builder, and pass the transaction state.
-    fn finish(
-        &mut self,
+    /// If false is returned, and is_txn is true, then the transaction is rolled back.
+ fn finnalize( + self, _is_txn: bool, _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { - Ok(()) + ) -> Result + where + Self: Sized, + { + Ok(true) } } @@ -163,22 +167,40 @@ pub enum StepResult { } /// A `QueryResultBuilder` that ignores rows, but records the outcome of each step in a `StepResult` -#[derive(Debug, Default)] -pub struct StepResultsBuilder { +pub struct StepResultsBuilder { current: Option, step_results: Vec, is_skipped: bool, + ret: R +} + +pub trait RetChannel { + fn send(self, t: T); +} + +#[cfg(feature = "tokio")] +impl RetChannel for tokio::sync::oneshot::Sender { + fn send(self, t: T) { + let _ = self.send(t); + } } -impl StepResultsBuilder { - pub fn into_ret(self) -> Vec { - self.step_results +impl StepResultsBuilder { + pub fn new(ret: R) -> Self { + Self { + current: None, + step_results: Vec::new(), + is_skipped: false, + ret, + } } } -impl ResultBuilder for StepResultsBuilder { +impl>> ResultBuilder for StepResultsBuilder { fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - *self = Default::default(); + self.current = None; + self.step_results.clear(); + self.is_skipped = false; Ok(()) } @@ -218,32 +240,13 @@ impl ResultBuilder for StepResultsBuilder { Ok(()) } - fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn add_row_value(&mut self, _v: ValueRef) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - Ok(()) - } - - fn finish( - &mut self, + fn finnalize( + self, _is_txn: bool, _frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { - Ok(()) + ) -> Result { + self.ret.send(self.step_results); + Ok(true) } } @@ -349,12 +352,12 @@ impl ResultBuilder for Take { } } - fn finish( - &mut self, + fn finnalize( + self, is_txn: bool, frame_no: Option, - ) -> Result<(), QueryResultBuilderError> { - self.inner.finish(is_txn, frame_no) + ) -> Result { + self.inner.finnalize(is_txn, frame_no) } } @@ -500,7 +503,7 @@ pub mod test { FinishRow => b.finish_row().unwrap(), FinishRows => b.finish_rows().unwrap(), Finish => { - b.finish(false, None).unwrap(); + b.finnalize(false, None).unwrap(); break; } BuilderError => return b, @@ -643,7 +646,7 @@ pub mod test { self.transition(FinishRows) } - fn finish( + fn finnalize( &mut self, _is_txn: bool, _frame_no: Option, @@ -700,7 +703,7 @@ pub mod test { builder.finish_rows().unwrap(); builder.finish_step(0, None).unwrap(); - builder.finish(false, None).unwrap(); + builder.finnalize(false, None).unwrap(); } #[test] From 95a28583715e0c0b6819d1fc44bf155991758c76 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 13 Jul 2023 17:17:49 +0200 Subject: [PATCH 23/64] proxy writes --- libsqlx-server/src/allocation/mod.rs | 615 ++++++++++++++++++--- libsqlx-server/src/hrana/batch.rs | 23 +- libsqlx-server/src/hrana/http/mod.rs | 48 +- libsqlx-server/src/hrana/http/request.rs | 4 +- libsqlx-server/src/hrana/http/stream.rs | 19 +- libsqlx-server/src/hrana/result_builder.rs | 67 +-- libsqlx-server/src/hrana/stmt.rs | 3 +- libsqlx-server/src/linc/connection.rs | 2 +- libsqlx-server/src/linc/proto.rs | 154 +++--- libsqlx-server/src/manager.rs | 8 +- libsqlx/src/analysis.rs | 5 +- libsqlx/src/connection.rs | 8 +- libsqlx/src/database/libsql/connection.rs | 18 +- libsqlx/src/database/proxy/connection.rs | 128 +++-- 
libsqlx/src/database/proxy/database.rs    |   5 +-
 libsqlx/src/program.rs                    |   8 +-
 libsqlx/src/query.rs                      |   4 +-
 libsqlx/src/result_builder.rs             |  26 +-
 18 files changed, 808 insertions(+), 337 deletions(-)

diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs
index 7e7e2b9b..7d1e8fe6 100644
--- a/libsqlx-server/src/allocation/mod.rs
+++ b/libsqlx-server/src/allocation/mod.rs
@@ -1,5 +1,7 @@
 use std::collections::hash_map::Entry;
 use std::collections::HashMap;
+use std::mem::size_of;
+use std::ops::Deref;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
@@ -7,12 +9,14 @@
 use bytes::Bytes;
 use either::Either;
 use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType};
+use libsqlx::program::Program;
 use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase};
-use libsqlx::result_builder::ResultBuilder;
+use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder};
 use libsqlx::{
     Database as _, DescribeResponse, Frame, FrameNo, InjectableDatabase, Injector, LogReadError,
     ReplicationLogger,
 };
+use parking_lot::Mutex;
 use tokio::sync::{mpsc, oneshot};
 use tokio::task::{block_in_place, JoinSet};
 use tokio::time::timeout;
@@ -20,28 +24,26 @@ use tokio::time::timeout;
 use crate::hrana;
 use crate::hrana::http::handle_pipeline;
 use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody};
-use crate::linc::bus::Dispatch;
-use crate::linc::proto::{Enveloppe, Frames, Message};
+use crate::linc::bus::{Bus, Dispatch};
+use crate::linc::proto::{
+    BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value,
+};
 use crate::linc::{Inbound, NodeId, Outbound};
+use crate::manager::Manager;
 use crate::meta::DatabaseId;

 use self::config::{AllocConfig, DbConfig};

 pub mod config;

-type LibsqlConnection = Either<
-    libsqlx::libsql::LibsqlConnection,
-    WriteProxyConnection, DummyConn>,
->;
-type ExecFn = Box;
+/// the maximum number of frames a Frames message is allowed to contain
+const FRAMES_MESSAGE_MAX_COUNT: usize = 5;
+
+type ProxyConnection =
+    WriteProxyConnection, RemoteConn>;
+type ExecFn = Box;

-#[derive(Clone)]
-pub struct ConnectionId {
-    id: u32,
-    close_sender: mpsc::Sender<()>,
-}
 pub enum AllocationMessage {
-    NewConnection(oneshot::Sender),
     HranaPipelineReq {
         req: PipelineRequestBody,
         ret: oneshot::Sender>,
@@ -49,43 +51,240 @@ pub enum AllocationMessage {
     Inbound(Inbound),
 }

-pub struct DummyDb;
-pub struct DummyConn;
+pub struct RemoteDb;

-impl libsqlx::Connection for DummyConn {
-    fn execute_program(
+#[derive(Clone)]
+pub struct RemoteConn {
+    inner: Arc,
+}
+
+struct Request {
+    id: Option,
+    builder: Box,
+    pgm: Option,
+    next_seq_no: u32,
+}
+
+pub struct RemoteConnInner {
+    current_req: Mutex>,
+}
+
+impl Deref for RemoteConn {
+    type Target = RemoteConnInner;
+
+    fn deref(&self) -> &Self::Target {
+        self.inner.as_ref()
+    }
+}
+
+impl libsqlx::Connection for RemoteConn {
+    fn execute_program(
         &mut self,
-        _pgm: &libsqlx::program::Program,
-        _result_builder: B,
+        program: &libsqlx::program::Program,
+        builder: Box,
     ) -> libsqlx::Result<()> {
-        todo!()
+        // When we need to proxy a query, we place it in the current request slot. When we are
+        // back in an async context, we'll send it to the primary, and asynchronously drive the
+        // builder.
+ let mut lock = self.inner.current_req.lock(); + *lock = match *lock { + Some(_) => unreachable!("conccurent request on the same connection!"), + None => Some(Request { + id: None, + builder, + pgm: Some(program.clone()), + next_seq_no: 0, + }), + }; + + Ok(()) } fn describe(&self, _sql: String) -> libsqlx::Result { - todo!() + unreachable!("Describe request should not be proxied") } } -impl libsqlx::Database for DummyDb { - type Connection = DummyConn; +impl libsqlx::Database for RemoteDb { + type Connection = RemoteConn; fn connect(&self) -> Result { - Ok(DummyConn) + Ok(RemoteConn { + inner: Arc::new(RemoteConnInner { + current_req: Default::default(), + }), + }) } } -type ProxyDatabase = WriteProxyDatabase, DummyDb>; +pub type ProxyDatabase = WriteProxyDatabase, RemoteDb>; + +pub struct PrimaryDatabase { + pub db: LibsqlDatabase, + pub replica_streams: HashMap)>, + pub frame_notifier: tokio::sync::watch::Receiver, +} + +struct ProxyResponseBuilder { + dispatcher: Arc, + buffer: Vec, + to: NodeId, + database_id: DatabaseId, + req_id: u32, + connection_id: u32, + next_seq_no: u32, +} + +const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100kb + +impl ProxyResponseBuilder { + fn maybe_send(&mut self) { + // FIXME: this is stupid: compute current buffer size on the go instead + let size = self + .buffer + .iter() + .map(|s| match s { + BuilderStep::FinishStep(_, _) => 2 * 8, + BuilderStep::StepError(StepError(s)) => s.len(), + BuilderStep::ColsDesc(ref d) => d + .iter() + .map(|c| c.name.len() + c.decl_ty.as_ref().map(|t| t.len()).unwrap_or_default()) + .sum(), + BuilderStep::Finnalize { .. } => 9, + BuilderStep::AddRowValue(v) => match v { + crate::linc::proto::Value::Text(s) | crate::linc::proto::Value::Blob(s) => { + s.len() + } + _ => size_of::(), + }, + _ => 8, + }) + .sum::(); + + if size > MAX_STEP_BATCH_SIZE { + self.send() + } + } + + fn send(&mut self) { + let msg = Outbound { + to: self.to, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyResponse(crate::linc::proto::ProxyResponse { + connection_id: self.connection_id, + req_id: self.req_id, + row_steps: std::mem::take(&mut self.buffer), + seq_no: self.next_seq_no, + }), + }, + }; + + self.next_seq_no += 1; + tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg)); + } +} + +impl ResultBuilder for ProxyResponseBuilder { + fn init( + &mut self, + _config: &libsqlx::result_builder::QueryBuilderConfig, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::Init); + self.maybe_send(); + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginStep); + self.maybe_send(); + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishStep( + affected_row_count, + last_insert_rowid, + )); + self.maybe_send(); + Ok(()) + } + + fn step_error( + &mut self, + error: libsqlx::error::Error, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer + .push(BuilderStep::StepError(StepError(error.to_string()))); + self.maybe_send(); + Ok(()) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer + .push(BuilderStep::ColsDesc(cols.map(Into::into).collect())); + self.maybe_send(); + Ok(()) + 
} + + fn begin_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginRows); + self.maybe_send(); + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginRow); + self.maybe_send(); + Ok(()) + } + + fn add_row_value( + &mut self, + v: libsqlx::result_builder::ValueRef, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::AddRowValue(v.into())); + self.maybe_send(); + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishRow); + self.maybe_send(); + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishRows); + self.maybe_send(); + Ok(()) + } + + fn finnalize( + &mut self, + is_txn: bool, + frame_no: Option, + ) -> Result { + self.buffer + .push(BuilderStep::Finnalize { is_txn, frame_no }); + self.send(); + Ok(true) + } +} pub enum Database { - Primary { - db: LibsqlDatabase, - replica_streams: HashMap)>, - frame_notifier: tokio::sync::watch::Receiver, - }, + Primary(PrimaryDatabase), Replica { db: ProxyDatabase, injector_handle: mpsc::Sender, - primary_node_id: NodeId, + primary_id: NodeId, last_received_frame_ts: Option, }, } @@ -194,9 +393,6 @@ struct FrameStreamer { buffer: Vec, } -// the maximum number of frame a Frame messahe is allowed to contain -const FRAMES_MESSAGE_MAX_COUNT: usize = 5; - impl FrameStreamer { async fn run(mut self) { loop { @@ -261,15 +457,15 @@ impl Database { ) .unwrap(); - Self::Primary { + Self::Primary(PrimaryDatabase { db, replica_streams: HashMap::new(), frame_notifier: receiver, - } + }) } DbConfig::Replica { primary_node_id } => { let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap(); - let wdb = DummyDb; + let wdb = RemoteDb; let mut db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ())); let injector = db.injector().unwrap(); let (sender, receiver) = mpsc::channel(16); @@ -291,17 +487,191 @@ impl Database { Self::Replica { db, injector_handle: sender, - primary_node_id, + primary_id: primary_node_id, last_received_frame_ts: None, } } } } - fn connect(&self) -> LibsqlConnection { + fn connect(&self, connection_id: u32, alloc: &Allocation) -> impl ConnectionHandler { + match self { + Database::Primary(PrimaryDatabase { db, .. }) => Either::Right(PrimaryConnection { + conn: db.connect().unwrap(), + }), + Database::Replica { db, primary_id, .. 
} => Either::Left(ReplicaConnection { + conn: db.connect().unwrap(), + connection_id, + next_req_id: 0, + primary_id: *primary_id, + database_id: DatabaseId::from_name(&alloc.db_name), + dispatcher: alloc.bus.clone(), + }), + } + } + + pub fn is_primary(&self) -> bool { + matches!(self, Self::Primary(..)) + } +} + +struct PrimaryConnection { + conn: libsqlx::libsql::LibsqlConnection, +} + +#[async_trait::async_trait] +impl ConnectionHandler for PrimaryConnection { + fn exec_ready(&self) -> bool { + true + } + + async fn handle_exec(&mut self, exec: ExecFn) { + block_in_place(|| exec(&mut self.conn)); + } + + async fn handle_inbound(&mut self, _msg: Inbound) { + tracing::debug!("primary connection received message, ignoring.") + } +} + +struct ReplicaConnection { + conn: ProxyConnection, + connection_id: u32, + next_req_id: u32, + primary_id: NodeId, + database_id: DatabaseId, + dispatcher: Arc, +} + +impl ReplicaConnection { + fn handle_proxy_response(&mut self, resp: ProxyResponse) { + let mut lock = self.conn.writer().inner.current_req.lock(); + let finnalized = match *lock { + Some(ref mut req) if req.id == Some(resp.req_id) && resp.seq_no == req.next_seq_no => { + self.next_req_id += 1; + // TODO: pass actual config + let config = QueryBuilderConfig { max_size: None }; + let mut finnalized = false; + for step in resp.row_steps.iter() { + if finnalized { break }; + match step { + BuilderStep::Init => req.builder.init(&config).unwrap(), + BuilderStep::BeginStep => req.builder.begin_step().unwrap(), + BuilderStep::FinishStep(affected_row_count, last_insert_rowid) => req + .builder + .finish_step(*affected_row_count, *last_insert_rowid) + .unwrap(), + BuilderStep::StepError(e) => req.builder.step_error(todo!()).unwrap(), + BuilderStep::ColsDesc(cols) => req + .builder + .cols_description(&mut cols.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decl_ty.as_deref(), + })) + .unwrap(), + BuilderStep::BeginRows => req.builder.begin_rows().unwrap(), + BuilderStep::BeginRow => req.builder.begin_row().unwrap(), + BuilderStep::AddRowValue(v) => req.builder.add_row_value(v.into()).unwrap(), + BuilderStep::FinishRow => req.builder.finish_row().unwrap(), + BuilderStep::FinishRows => req.builder.finish_rows().unwrap(), + BuilderStep::Finnalize { is_txn, frame_no } => { + let _ = req.builder.finnalize(*is_txn, *frame_no).unwrap(); + finnalized = true; + } + } + } + finnalized + } + Some(_) => todo!("error processing response"), + None => { + tracing::error!("received builder message, but there is no pending request"); + false + } + }; + + if finnalized { + *lock = None; + } + } +} + +#[async_trait::async_trait] +impl ConnectionHandler for ReplicaConnection { + fn exec_ready(&self) -> bool { + // we are currently handling a request on this connection + self.conn.writer().current_req.lock().is_none() + } + + async fn handle_exec(&mut self, exec: ExecFn) { + block_in_place(|| exec(&mut self.conn)); + let msg = { + let mut lock = self.conn.writer().inner.current_req.lock(); + match *lock { + Some(ref mut req) if req.id.is_none() => { + let program = req + .pgm + .take() + .expect("unsent request should have a program"); + let req_id = self.next_req_id; + self.next_req_id += 1; + req.id = Some(req_id); + + let msg = Outbound { + to: self.primary_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyRequest { + connection_id: self.connection_id, + req_id, + program, + }, + }, + }; + + Some(msg) + } + _ => None, + } + }; + + if let Some(msg) = msg { + 
self.dispatcher.dispatch(msg).await; + } + } + + async fn handle_inbound(&mut self, msg: Inbound) { + match msg.enveloppe.message { + Message::ProxyResponse(resp) => { + self.handle_proxy_response(resp); + } + _ => (), // ignore anything else + } + } +} + +#[async_trait::async_trait] +impl ConnectionHandler for Either +where + L: ConnectionHandler, + R: ConnectionHandler, +{ + fn exec_ready(&self) -> bool { + match self { + Either::Left(l) => l.exec_ready(), + Either::Right(r) => r.exec_ready(), + } + } + + async fn handle_exec(&mut self, exec: ExecFn) { + match self { + Either::Left(l) => l.handle_exec(exec).await, + Either::Right(r) => r.handle_exec(exec).await, + } + } + async fn handle_inbound(&mut self, msg: Inbound) { match self { - Database::Primary { db, .. } => Either::Left(db.connect().unwrap()), - Database::Replica { db, .. } => Either::Right(db.connect().unwrap()), + Either::Left(l) => l.handle_inbound(msg).await, + Either::Right(r) => r.handle_inbound(msg).await, } } } @@ -310,29 +680,31 @@ pub struct Allocation { pub inbox: mpsc::Receiver, pub database: Database, /// spawned connection futures, returning their connection id on completion. - pub connections_futs: JoinSet, + pub connections_futs: JoinSet<(NodeId, u32)>, pub next_conn_id: u32, pub max_concurrent_connections: u32, + pub connections: HashMap>, pub hrana_server: Arc, - /// handle to the message bus, to send messages - pub dispatcher: Arc, + /// handle to the message bus + pub bus: Arc>>, pub db_name: String, } +#[derive(Clone)] pub struct ConnectionHandle { exec: mpsc::Sender, - exit: oneshot::Sender<()>, + inbound: mpsc::Sender, } impl ConnectionHandle { pub async fn exec(&self, f: F) -> crate::Result where - F: for<'a> FnOnce(&'a mut LibsqlConnection) -> R + Send + 'static, + F: for<'a> FnOnce(&'a mut dyn libsqlx::Connection) -> R + Send + 'static, R: Send + 'static, { let (sender, ret) = oneshot::channel(); - let cb = move |conn: &mut LibsqlConnection| { + let cb = move |conn: &mut dyn libsqlx::Connection| { let res = f(conn); let _ = sender.send(res); }; @@ -349,15 +721,12 @@ impl Allocation { tokio::select! { Some(msg) = self.inbox.recv() => { match msg { - AllocationMessage::NewConnection(ret) => { - let _ =ret.send(self.new_conn().await); - }, - AllocationMessage::HranaPipelineReq { req, ret} => { - let res = handle_pipeline(&self.hrana_server.clone(), req, || async { - let conn= self.new_conn().await; + AllocationMessage::HranaPipelineReq { req, ret } => { + let server = self.hrana_server.clone(); + handle_pipeline(server, req, ret, || async { + let conn = self.new_conn(None).await; Ok(conn) - }).await; - let _ = ret.send(res); + }).await.unwrap(); } AllocationMessage::Inbound(msg) => { self.handle_inbound(msg).await; @@ -388,11 +757,12 @@ impl Allocation { req_no, next_frame_no, } => match &mut self.database { - Database::Primary { + Database::Primary(PrimaryDatabase { db, replica_streams, frame_notifier, - } => { + .. + }) => { let streamer = FrameStreamer { logger: db.logger(), database_id: DatabaseId::from_name(&self.db_name), @@ -400,7 +770,7 @@ impl Allocation { next_frame_no, req_no, seq_no: 0, - dipatcher: self.dispatcher.clone(), + dipatcher: self.bus.clone() as _, notifier: frame_notifier.clone(), buffer: Vec::new(), }; @@ -435,62 +805,139 @@ impl Allocation { *last_received_frame_ts = Some(Instant::now()); injector_handle.send(frames).await.unwrap(); } - Database::Primary { .. } => todo!("handle primary receiving txn"), + Database::Primary(PrimaryDatabase { .. 
}) => todo!("handle primary receiving txn"), }, - Message::ProxyRequest { .. } => todo!(), - Message::ProxyResponse { .. } => todo!(), + Message::ProxyRequest { + connection_id, + req_id, + program, + } => { + self.handle_proxy(msg.from, connection_id, req_id, program) + .await + } + Message::ProxyResponse(ref r) => { + if let Some(conn) = self + .connections + .get(&self.bus.node_id()) + .and_then(|m| m.get(&r.connection_id).cloned()) + { + conn.inbound.send(msg).await.unwrap(); + } + } Message::CancelRequest { .. } => todo!(), Message::CloseConnection { .. } => todo!(), Message::Error(_) => todo!(), } } - async fn new_conn(&mut self) -> ConnectionHandle { - let id = self.next_conn_id(); - let conn = block_in_place(|| self.database.connect()); - let (close_sender, exit) = oneshot::channel(); + async fn handle_proxy( + &mut self, + node_id: NodeId, + connection_id: u32, + req_id: u32, + program: Program, + ) { + let dispatcher = self.bus.clone(); + let database_id = DatabaseId::from_name(&self.db_name); + let exec = |conn: ConnectionHandle| async move { + let _ = conn + .exec(move |conn| { + let builder = ProxyResponseBuilder { + dispatcher, + req_id, + buffer: Vec::new(), + to: node_id, + database_id, + connection_id, + next_seq_no: 0, + }; + conn.execute_program(&program, Box::new(builder)).unwrap(); + }) + .await; + }; + + if self.database.is_primary() { + match self + .connections + .get(&node_id) + .and_then(|m| m.get(&connection_id).cloned()) + { + Some(handle) => { + tokio::spawn(exec(handle)); + } + None => { + let handle = self.new_conn(Some((node_id, connection_id))).await; + tokio::spawn(exec(handle)); + } + } + } + } + + async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { + let conn_id = self.next_conn_id(); + let conn = block_in_place(|| self.database.connect(conn_id, self)); let (exec_sender, exec_receiver) = mpsc::channel(1); + let (inbound_sender, inbound_receiver) = mpsc::channel(1); + let id = remote.unwrap_or((self.bus.node_id(), conn_id)); let conn = Connection { id, conn, - exit, exec: exec_receiver, + inbound: inbound_receiver, }; self.connections_futs.spawn(conn.run()); - - ConnectionHandle { + let handle = ConnectionHandle { exec: exec_sender, - exit: close_sender, - } + inbound: inbound_sender, + }; + self.connections + .entry(id.0) + .or_insert_with(HashMap::new) + .insert(id.1, handle.clone()); + handle } fn next_conn_id(&mut self) -> u32 { loop { self.next_conn_id = self.next_conn_id.wrapping_add(1); - return self.next_conn_id; - // if !self.connections.contains_key(&self.next_conn_id) { - // return self.next_conn_id; - // } + if self + .connections + .get(&self.bus.node_id()) + .and_then(|m| m.get(&self.next_conn_id)) + .is_none() + { + return self.next_conn_id; + } } } } -struct Connection { - id: u32, - conn: LibsqlConnection, - exit: oneshot::Receiver<()>, +struct Connection { + id: (NodeId, u32), + conn: C, exec: mpsc::Receiver, + inbound: mpsc::Receiver, +} + +#[async_trait::async_trait] +trait ConnectionHandler: Send { + fn exec_ready(&self) -> bool; + async fn handle_exec(&mut self, exec: ExecFn); + async fn handle_inbound(&mut self, msg: Inbound); } -impl Connection { - async fn run(mut self) -> u32 { +impl Connection { + async fn run(mut self) -> (NodeId, u32) { loop { tokio::select! 
{ - _ = &mut self.exit => break, - Some(exec) = self.exec.recv() => { - tokio::task::block_in_place(|| exec(&mut self.conn)); + Some(inbound) = self.inbound.recv() => { + self.conn.handle_inbound(inbound).await; } + Some(exec) = self.exec.recv(), if self.conn.exec_ready() => { + self.conn.handle_exec(exec).await; + }, + else => break, } } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index c4131c45..a9ed0553 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -8,7 +8,6 @@ use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error}; use super::{proto, ProtocolError, Version}; use color_eyre::eyre::anyhow; -use libsqlx::Connection; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; use libsqlx::query::{Params, Query}; @@ -78,7 +77,7 @@ pub async fn execute_batch( let fut = db .exec(move |conn| -> color_eyre::Result<_> { let (builder, ret) = HranaBatchProtoBuilder::new(); - conn.execute_program(&pgm, builder)?; + conn.execute_program(&pgm, Box::new(builder))?; Ok(ret) }) .await??; @@ -116,20 +115,18 @@ pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_ey .exec(move |conn| -> color_eyre::Result<_> { let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); - conn.execute_program(&pgm, builder)?; + conn.execute_program(&pgm, Box::new(builder))?; Ok(rcv) }) .await??; - fut.await? - .into_iter() - .try_for_each(|result| match result { - StepResult::Ok => Ok(()), - StepResult::Err(e) => match stmt_error_from_sqld_error(e) { - Ok(stmt_err) => Err(anyhow!(stmt_err)), - Err(sqld_err) => Err(anyhow!(sqld_err)), - }, - StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), - }) + fut.await?.into_iter().try_for_each(|result| match result { + StepResult::Ok => Ok(()), + StepResult::Err(e) => match stmt_error_from_sqld_error(e) { + Ok(stmt_err) => Err(anyhow!(stmt_err)), + Err(sqld_err) => Err(anyhow!(sqld_err)), + }, + StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), + }) } diff --git a/libsqlx-server/src/hrana/http/mod.rs b/libsqlx-server/src/hrana/http/mod.rs index 651ab3f0..521d33ff 100644 --- a/libsqlx-server/src/hrana/http/mod.rs +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -1,7 +1,10 @@ +use std::sync::Arc; + use color_eyre::eyre::Context; use futures::Future; use parking_lot::Mutex; use serde::{de::DeserializeOwned, Serialize}; +use tokio::sync::oneshot; use crate::allocation::ConnectionHandle; @@ -47,31 +50,38 @@ fn handle_index() -> color_eyre::Result> { } pub async fn handle_pipeline( - server: &Server, + server: Arc, req: PipelineRequestBody, + ret: oneshot::Sender>, mk_conn: F, -) -> color_eyre::Result +) -> color_eyre::Result<()> where F: FnOnce() -> Fut, Fut: Future>, { - let mut stream_guard = stream::acquire(server, req.baton.as_deref(), mk_conn).await?; - - let mut results = Vec::with_capacity(req.requests.len()); - for request in req.requests.into_iter() { - let result = request::handle(&mut stream_guard, request) - .await - .context("Could not execute a request in pipeline")?; - results.push(result); - } - - let resp_body = proto::PipelineResponseBody { - baton: stream_guard.release(), - base_url: server.self_url.clone(), - results, - }; - - Ok(resp_body) + let mut stream_guard = stream::acquire(server.clone(), req.baton.as_deref(), mk_conn).await?; + + tokio::spawn(async move { + let f = async move { + let mut results = Vec::with_capacity(req.requests.len()); + for request 
in req.requests.into_iter() { + let result = request::handle(&mut stream_guard, request) + .await + .context("Could not execute a request in pipeline")?; + results.push(result); + } + + Ok(proto::PipelineResponseBody { + baton: stream_guard.release(), + base_url: server.self_url.clone(), + results, + }) + }; + + let _ = ret.send(f.await); + }); + + Ok(()) } async fn read_request_json( diff --git a/libsqlx-server/src/hrana/http/request.rs b/libsqlx-server/src/hrana/http/request.rs index ac6d8912..eb1623cd 100644 --- a/libsqlx-server/src/hrana/http/request.rs +++ b/libsqlx-server/src/hrana/http/request.rs @@ -13,7 +13,7 @@ pub enum StreamResponseError { } pub async fn handle( - stream_guard: &mut stream::Guard<'_>, + stream_guard: &mut stream::Guard, request: proto::StreamRequest, ) -> color_eyre::Result { let result = match try_handle(stream_guard, request).await { @@ -31,7 +31,7 @@ pub async fn handle( } async fn try_handle( - stream_guard: &mut stream::Guard<'_>, + stream_guard: &mut stream::Guard, request: proto::StreamRequest, ) -> color_eyre::Result { Ok(match request { diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs index 5f40537e..25c1e719 100644 --- a/libsqlx-server/src/hrana/http/stream.rs +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -1,6 +1,7 @@ use std::cmp::Reverse; use std::collections::{HashMap, VecDeque}; use std::pin::Pin; +use std::sync::Arc; use std::{future, mem, task}; use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; @@ -67,8 +68,8 @@ struct Stream { /// Guard object that is used to access a stream from the outside. The guard makes sure that the /// stream's entry in [`ServerStreamState::handles`] is either removed or replaced with /// [`Handle::Available`] after the guard goes out of scope. -pub struct Guard<'srv> { - server: &'srv Server, +pub struct Guard { + server: Arc, /// The guarded stream. This is only set to `None` in the destructor. stream: Option>, /// If set to `true`, the destructor will release the stream for further use (saving it as @@ -101,18 +102,18 @@ impl ServerStreamState { /// Acquire a guard to a new or existing stream. If baton is `Some`, we try to look up the stream, /// otherwise we create a new stream. 
-pub async fn acquire<'srv, F, Fut>( - server: &'srv Server, +pub async fn acquire( + server: Arc, baton: Option<&str>, mk_conn: F, -) -> color_eyre::Result> +) -> color_eyre::Result where F: FnOnce() -> Fut, Fut: Future>, { let stream = match baton { Some(baton) => { - let (stream_id, baton_seq) = decode_baton(server, baton)?; + let (stream_id, baton_seq) = decode_baton(&server, baton)?; let mut state = server.stream_state.lock(); let handle = state.handles.get_mut(&stream_id); @@ -182,7 +183,7 @@ where }) } -impl<'srv> Guard<'srv> { +impl Guard { pub fn get_db(&self) -> Result<&ConnectionHandle, ProtocolError> { let stream = self.stream.as_ref().unwrap(); stream.conn.as_ref().ok_or(ProtocolError::BatonStreamClosed) @@ -211,7 +212,7 @@ impl<'srv> Guard<'srv> { if stream.conn.is_some() { self.release = true; // tell destructor to make the stream available again Some(encode_baton( - self.server, + &self.server, stream.stream_id, stream.baton_seq, )) @@ -221,7 +222,7 @@ impl<'srv> Guard<'srv> { } } -impl<'srv> Drop for Guard<'srv> { +impl Drop for Guard { fn drop(&mut self) { let stream = self.stream.take().unwrap(); let stream_id = stream.stream_id; diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index 1047f091..e91bca28 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -11,16 +11,22 @@ use super::proto; pub struct SingleStatementBuilder { builder: StatementBuilder, - ret: oneshot::Sender>, + ret: Option>>, } impl SingleStatementBuilder { - pub fn new() -> (Self, oneshot::Receiver>) { + pub fn new() -> ( + Self, + oneshot::Receiver>, + ) { let (ret, rcv) = oneshot::channel(); - (Self { - builder: StatementBuilder::default(), - ret, - }, rcv) + ( + Self { + builder: StatementBuilder::default(), + ret: Some(ret), + }, + rcv, + ) } } @@ -38,7 +44,8 @@ impl ResultBuilder for SingleStatementBuilder { affected_row_count: u64, last_insert_rowid: Option, ) -> Result<(), QueryResultBuilderError> { - self.builder.finish_step(affected_row_count, last_insert_rowid) + self.builder + .finish_step(affected_row_count, last_insert_rowid) } fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { @@ -61,19 +68,16 @@ impl ResultBuilder for SingleStatementBuilder { } fn finnalize( - self, + &mut self, _is_txn: bool, _frame_no: Option, - ) -> Result - where Self: Sized - { - let res = self.builder.into_ret(); - let _ = self.ret.send(res); + ) -> Result { + let res = self.builder.take_ret(); + let _ = self.ret.take().unwrap().send(res); Ok(true) } } - #[derive(Debug, Default)] struct StatementBuilder { has_step: bool, @@ -191,12 +195,12 @@ impl StatementBuilder { Ok(()) } - pub fn into_ret(self) -> Result { - match self.err { + pub fn take_ret(&mut self) -> Result { + match self.err.take() { Some(err) => Err(err), None => Ok(proto::StmtResult { - cols: self.cols, - rows: self.rows, + cols: std::mem::take(&mut self.cols), + rows: std::mem::take(&mut self.rows), affected_row_count: self.affected_row_count, last_insert_rowid: self.last_insert_rowid, }), @@ -262,23 +266,24 @@ pub struct HranaBatchProtoBuilder { current_size: u64, max_response_size: u64, step_empty: bool, - ret: oneshot::Sender + ret: oneshot::Sender, } impl HranaBatchProtoBuilder { pub fn new() -> (Self, oneshot::Receiver) { let (ret, rcv) = oneshot::channel(); - (Self { - step_results: Vec::new(), - step_errors: Vec::new(), - stmt_builder: StatementBuilder::default(), - current_size: 0, - 
max_response_size: u64::MAX, - step_empty: false, - ret, - }, - rcv) - + ( + Self { + step_results: Vec::new(), + step_errors: Vec::new(), + stmt_builder: StatementBuilder::default(), + current_size: 0, + max_response_size: u64::MAX, + step_empty: false, + ret, + }, + rcv, + ) } pub fn into_ret(self) -> proto::BatchResult { proto::BatchResult { @@ -314,7 +319,7 @@ impl ResultBuilder for HranaBatchProtoBuilder { max_response_size: self.max_response_size - self.current_size, ..Default::default() }; - match std::mem::replace(&mut self.stmt_builder, new_builder).into_ret() { + match std::mem::replace(&mut self.stmt_builder, new_builder).take_ret() { Ok(res) => { self.step_results.push((!self.step_empty).then_some(res)); self.step_errors.push(None); diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index e6c002a1..1a8c03f6 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use color_eyre::eyre::{anyhow, bail}; use libsqlx::analysis::Statement; use libsqlx::query::{Params, Query, Value}; -use libsqlx::Connection; use super::result_builder::SingleStatementBuilder; use super::{proto, ProtocolError, Version}; @@ -52,7 +51,7 @@ pub async fn execute_stmt( .exec(move |conn| -> color_eyre::Result<_> { let (builder, ret) = SingleStatementBuilder::new(); let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute_program(&pgm, builder)?; + conn.execute_program(&pgm, Box::new(builder))?; Ok(ret) }) diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index bf5bd97e..09e2ec44 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -163,7 +163,7 @@ where self.conn.feed(m).await.unwrap(); } self.conn.flush().await.unwrap(); - } + }, else => { self.state = ConnectionState::Close; } diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index bec6ff7a..a9aa529d 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use libsqlx::{program::Program, FrameNo}; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -6,9 +7,7 @@ use crate::meta::DatabaseId; use super::NodeId; -pub type Program = String; - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize)] pub struct Enveloppe { pub database_id: Option, pub message: Message, @@ -25,7 +24,18 @@ pub struct Frames { pub frames: Vec, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize)] +/// Response to a proxied query +pub struct ProxyResponse { + pub connection_id: u32, + /// id of the request this message is a response to. + pub req_id: u32, + pub seq_no: u32, + /// Collection of steps to drive the query builder transducer. + pub row_steps: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] pub enum Message { /// Initial message exchanged between nodes when connecting Handshake { @@ -58,13 +68,7 @@ pub enum Message { req_id: u32, program: Program, }, - /// Response to a proxied query - ProxyResponse { - /// id of the request this message is a response to. - req_id: u32, - /// Collection of steps to drive the query builder transducer. - row_step: Vec, - }, + ProxyResponse(ProxyResponse), /// Stop processing request `id`. 
CancelRequest { req_id: u32, @@ -85,101 +89,79 @@ pub enum ProtoError { UnknownDatabase(String), } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ReplicationMessage { - ReplicationHandshake { - database_name: String, - }, - ReplicationHandshakeResponse { - /// id of the replication log - log_id: Uuid, - /// current frame_no of the primary - current_frame_no: u64, - }, - Replicate { - /// next frame no to send - next_frame_no: u64, - }, - /// a batch of frames that are part of the same transaction - Transaction { - /// if not None, then the last frame is a commit frame, and this is the new size of the database. - size_after: Option, - /// frame_no of the last frame in frames - end_frame_no: u64, - /// a batch of frames part of the transaction. - frames: Vec, - }, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub struct Frame { - /// Page id of that frame - page_id: u32, - /// Data - data: Bytes, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ProxyMessage { - /// Proxy a query to a primary - ProxyRequest { - /// id of the connection to perform the query against - /// If the connection doesn't already exist it is created - /// Id of the request. - /// Responses to this request must have the same id. - connection_id: u32, - req_id: u32, - program: Program, - }, - /// Response to a proxied query - ProxyResponse { - /// id of the request this message is a response to. - req_id: u32, - /// Collection of steps to drive the query builder transducer. - row_step: Vec, - }, - /// Stop processing request `id`. - CancelRequest { req_id: u32 }, - /// Close Connection with passed id. - CloseConnection { connection_id: u32 }, -} - /// Steps applied to the query builder transducer to build a response to a proxied query. /// Those types closely mirror those of the `QueryBuilderTrait`. -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum BuilderStep { + Init, BeginStep, - FinishStep(u64, Option), + FinishStep(u64, Option), StepError(StepError), ColsDesc(Vec), BeginRows, BeginRow, AddRowValue(Value), FinishRow, - FinishRos, - Finish(ConnectionState), + FinishRows, + Finnalize { + is_txn: bool, + frame_no: Option, + }, } -// State of the connection after a query was executed -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ConnectionState { - /// The connection is still in a open transaction state - OpenTxn, - /// The connection is idle. - Idle, +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum Value { + Null, + Integer(i64), + Real(f64), + // TODO: how to stream blobs/string??? 
+ Text(Vec), + Blob(Vec), } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum Value {} +impl<'a> Into> for &'a Value { + fn into(self) -> libsqlx::result_builder::ValueRef<'a> { + use libsqlx::result_builder::ValueRef; + match self { + Value::Null => ValueRef::Null, + Value::Integer(i) => ValueRef::Integer(*i), + Value::Real(x) => ValueRef::Real(*x), + Value::Text(ref t) => ValueRef::Text(t), + Value::Blob(ref b) => ValueRef::Blob(b), + } + } +} + +impl From> for Value { + fn from(value: libsqlx::result_builder::ValueRef) -> Self { + use libsqlx::result_builder::ValueRef; + match value { + ValueRef::Null => Self::Null, + ValueRef::Integer(i) => Self::Integer(i), + ValueRef::Real(x) => Self::Real(x), + ValueRef::Text(s) => Self::Text(s.into()), + ValueRef::Blob(b) => Self::Blob(b.into()), + } + } +} #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct Column { /// name of the column - name: String, + pub name: String, /// Declared type of the column, if any. - decl_ty: Option, + pub decl_ty: Option, +} + +impl From> for Column { + fn from(value: libsqlx::result_builder::Column) -> Self { + Self { + name: value.name.to_string(), + decl_ty: value.decl_ty.map(ToOwned::to_owned), + } + } } /// for now, the stringified version of a sqld::error::Error. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub struct StepError(String); +pub struct StepError(pub String); diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 89604569..01870144 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -20,10 +21,6 @@ pub struct Manager { const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32; -trait IsSync: Sync {} - -impl IsSync for Allocation {} - impl Manager { pub fn new(db_path: PathBuf, meta_store: Arc, max_conccurent_allocs: u64) -> Self { Self { @@ -54,8 +51,9 @@ impl Manager { next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, hrana_server: Arc::new(hrana::http::Server::new(None)), - dispatcher: bus, // TODO: handle self URL? + bus, // TODO: handle self URL? db_name: config.db_name, + connections: HashMap::new(), }; tokio::spawn(alloc.run()); diff --git a/libsqlx/src/analysis.rs b/libsqlx/src/analysis.rs index 97ef5f5b..0706ebff 100644 --- a/libsqlx/src/analysis.rs +++ b/libsqlx/src/analysis.rs @@ -1,9 +1,10 @@ use fallible_iterator::FallibleIterator; +use serde::{Deserialize, Serialize}; use sqlite3_parser::ast::{Cmd, PragmaBody, QualifiedName, Stmt}; use sqlite3_parser::lexer::sql::{Parser, ParserError}; /// A group of statements to be executed together. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Statement { pub stmt: String, pub kind: StmtKind, @@ -19,7 +20,7 @@ impl Default for Statement { } /// Classify statement in categories of interest. 
-#[derive(Debug, PartialEq, Clone, Copy)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub enum StmtKind { /// The begining of a transaction TxnBegin, diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index fa027997..e767073a 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -24,10 +24,10 @@ pub struct DescribeCol { pub trait Connection { /// Executes a query program - fn execute_program( + fn execute_program( &mut self, pgm: &Program, - result_builder: B, + result_builder: Box, ) -> crate::Result<()>; /// Parse the SQL statement and return information about it. @@ -39,10 +39,10 @@ where T: Connection, X: Connection, { - fn execute_program( + fn execute_program( &mut self, pgm: &Program, - result_builder: B, + result_builder: Box, ) -> crate::Result<()> { match self { Either::Left(c) => c.execute_program(pgm, result_builder), diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 27ee59e1..1f0ea7ab 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -101,14 +101,14 @@ impl LibsqlConnection { &self.conn } - fn run(&mut self, pgm: &Program, mut builder: B) -> Result<()> { + fn run(&mut self, pgm: &Program, builder: &mut dyn ResultBuilder) -> Result<()> { let mut results = Vec::with_capacity(pgm.steps.len()); builder.init(&self.builder_config)?; let is_autocommit_before = self.conn.is_autocommit(); for step in pgm.steps() { - let res = self.execute_step(step, &results, &mut builder)?; + let res = self.execute_step(step, &results, builder)?; results.push(res); } @@ -125,11 +125,11 @@ impl LibsqlConnection { Ok(()) } - fn execute_step( + fn execute_step( &mut self, step: &Step, results: &[bool], - builder: &mut B, + builder: &mut dyn ResultBuilder, ) -> Result { builder.begin_step()?; let mut enabled = match step.cond.as_ref() { @@ -163,10 +163,10 @@ impl LibsqlConnection { Ok(enabled) } - fn execute_query( + fn execute_query( &self, query: &Query, - builder: &mut B, + builder: &mut dyn ResultBuilder, ) -> Result<(u64, Option)> { tracing::trace!("executing query: {}", query.stmt.stmt); @@ -240,12 +240,12 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { } impl Connection for LibsqlConnection { - fn execute_program( + fn execute_program( &mut self, pgm: &Program, - builder: B, + mut builder: Box, ) -> crate::Result<()> { - self.run(pgm, builder) + self.run(pgm, &mut *builder) } fn describe(&self, sql: String) -> crate::Result { diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index 68d10c00..a06b6620 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use parking_lot::Mutex; + use crate::connection::{Connection, DescribeResponse}; use crate::database::FrameNo; use crate::program::Program; @@ -14,31 +18,48 @@ pub(crate) struct ConnState { /// A connection that proxies write operations to the `WriteDb` and the read operations to the /// `ReadDb` -pub struct WriteProxyConnection { - pub(crate) read_db: ReadDb, - pub(crate) write_db: WriteDb, +pub struct WriteProxyConnection { + pub(crate) read_conn: R, + pub(crate) write_conn: W, pub(crate) wait_frame_no_cb: WaitFrameNoCb, - pub(crate) state: ConnState, + pub(crate) state: Arc>, +} + +impl WriteProxyConnection { + pub fn writer_mut(&mut self) -> &mut W { + &mut self.write_conn + } + + pub fn writer(&self) -> &W { + &self.write_conn + } + + 
pub fn reader_mut(&mut self) -> &mut R { + &mut self.read_conn + } + + pub fn reader(&self) -> &R { + &self.read_conn + } } -struct MaybeRemoteExecBuilder<'a, 'b, B, W> { - builder: B, - conn: &'a mut W, - pgm: &'b Program, - state: &'a mut ConnState, +struct MaybeRemoteExecBuilder { + builder: Option>, + conn: W, + pgm: Program, + state: Arc>, } -impl<'a, 'b, B, W> ResultBuilder for MaybeRemoteExecBuilder<'a, 'b, B, W> +impl ResultBuilder for MaybeRemoteExecBuilder where - W: Connection, - B: ResultBuilder, + W: Connection + Send + 'static, { fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - self.builder.init(config) + self.builder.as_mut().unwrap().init(config) } fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.begin_step() + self.builder.as_mut().unwrap().begin_step() } fn finish_step( @@ -47,45 +68,47 @@ where last_insert_rowid: Option, ) -> Result<(), QueryResultBuilderError> { self.builder + .as_mut() + .unwrap() .finish_step(affected_row_count, last_insert_rowid) } fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { - self.builder.step_error(error) + self.builder.as_mut().unwrap().step_error(error) } fn cols_description( &mut self, cols: &mut dyn Iterator, ) -> Result<(), QueryResultBuilderError> { - self.builder.cols_description(cols) + self.builder.as_mut().unwrap().cols_description(cols) } fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.begin_rows() + self.builder.as_mut().unwrap().begin_rows() } fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.begin_row() + self.builder.as_mut().unwrap().begin_row() } fn add_row_value( &mut self, v: rusqlite::types::ValueRef, ) -> Result<(), QueryResultBuilderError> { - self.builder.add_row_value(v) + self.builder.as_mut().unwrap().add_row_value(v) } fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.finish_row() + self.builder.as_mut().unwrap().finish_row() } fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.finish_rows() + self.builder.as_mut().unwrap().finish_rows() } fn finnalize( - self, + &mut self, is_txn: bool, frame_no: Option, ) -> Result { @@ -93,70 +116,75 @@ where // a read only connection is not allowed to leave an open transaction. We mispredicted the // final state of the connection, so we rollback, and execute again on the write proxy. 
let builder = ExtractFrameNoBuilder { - builder: self.builder, - state: self.state, + builder: self + .builder + .take() + .expect("finnalize called more than once"), + state: self.state.clone(), }; - self.conn.execute_program(self.pgm, builder).unwrap(); + self.conn + .execute_program(&self.pgm, Box::new(builder)) + .unwrap(); Ok(false) } else { - self.builder.finnalize(is_txn, frame_no) + self.builder.as_mut().unwrap().finnalize(is_txn, frame_no) } } } -impl Connection for WriteProxyConnection +impl Connection for WriteProxyConnection where - ReadDb: Connection, - WriteDb: Connection, + R: Connection, + W: Connection + Clone + Send + 'static, { - fn execute_program( + fn execute_program( &mut self, pgm: &Program, - builder: B, + builder: Box, ) -> crate::Result<()> { - if !self.state.is_txn && pgm.is_read_only() { - if let Some(frame_no) = self.state.last_frame_no { + if !self.state.lock().is_txn && pgm.is_read_only() { + if let Some(frame_no) = self.state.lock().last_frame_no { (self.wait_frame_no_cb)(frame_no); } let builder = MaybeRemoteExecBuilder { - builder, - conn: &mut self.write_db, - state: &mut self.state, - pgm, + builder: Some(builder), + conn: self.write_conn.clone(), + state: self.state.clone(), + pgm: pgm.clone(), }; // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. - self.read_db.execute_program(pgm, builder)?; + self.read_conn.execute_program(pgm, Box::new(builder))?; // rollback(&mut self.conn.read_db); Ok(()) } else { let builder = ExtractFrameNoBuilder { builder, - state: &mut self.state, + state: self.state.clone(), }; - self.write_db.execute_program(pgm, builder)?; + self.write_conn.execute_program(pgm, Box::new(builder))?; Ok(()) } } fn describe(&self, sql: String) -> crate::Result { - if let Some(frame_no) = self.state.last_frame_no { + if let Some(frame_no) = self.state.lock().last_frame_no { (self.wait_frame_no_cb)(frame_no); } - self.read_db.describe(sql) + self.read_conn.describe(sql) } } -struct ExtractFrameNoBuilder<'a, B> { - builder: B, - state: &'a mut ConnState, +struct ExtractFrameNoBuilder { + builder: Box, + state: Arc>, } -impl ResultBuilder for ExtractFrameNoBuilder<'_, B> { +impl ResultBuilder for ExtractFrameNoBuilder { fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { self.builder.init(config) } @@ -209,12 +237,13 @@ impl ResultBuilder for ExtractFrameNoBuilder<'_, B> { } fn finnalize( - self, + &mut self, is_txn: bool, frame_no: Option, ) -> Result { - self.state.last_frame_no = frame_no; - self.state.is_txn = is_txn; + let mut state = self.state.lock(); + state.last_frame_no = frame_no; + state.is_txn = is_txn; self.builder.finnalize(is_txn, frame_no) } } @@ -225,7 +254,6 @@ mod test { use std::rc::Rc; use std::sync::Arc; - use crate::connection::Connection; use crate::database::test_utils::MockDatabase; use crate::database::{proxy::database::WriteProxyDatabase, Database}; use crate::program::Program; diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs index e9add71f..adc82ed8 100644 --- a/libsqlx/src/database/proxy/database.rs +++ b/libsqlx/src/database/proxy/database.rs @@ -24,13 +24,14 @@ impl Database for WriteProxyDatabase where RDB: Database, WDB: Database, + WDB::Connection: Clone + Send + 'static, { type Connection = WriteProxyConnection; /// Create a new connection 
to the database fn connect(&self) -> Result { Ok(WriteProxyConnection { - read_db: self.read_db.connect()?, - write_db: self.write_db.connect()?, + read_conn: self.read_db.connect()?, + write_conn: self.write_db.connect()?, wait_frame_no_cb: self.wait_frame_no_cb.clone(), state: Default::default(), }) diff --git a/libsqlx/src/program.rs b/libsqlx/src/program.rs index 0b5c7980..b2b627af 100644 --- a/libsqlx/src/program.rs +++ b/libsqlx/src/program.rs @@ -1,8 +1,10 @@ use std::sync::Arc; +use serde::{Deserialize, Serialize}; + use crate::query::Query; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Program { pub steps: Arc<[Step]>, } @@ -59,13 +61,13 @@ impl Program { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Step { pub cond: Option, pub query: Query, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Cond { Ok { step: usize }, Err { step: usize }, diff --git a/libsqlx/src/query.rs b/libsqlx/src/query.rs index 2d37e514..d3b1e5eb 100644 --- a/libsqlx/src/query.rs +++ b/libsqlx/src/query.rs @@ -46,7 +46,7 @@ impl TryFrom> for Value { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Query { pub stmt: Statement, pub params: Params, @@ -67,7 +67,7 @@ impl ToSql for Value { } } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize, Clone, Deserialize)] pub enum Params { Named(HashMap), Positional(Vec), diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index 98f598c1..fed13fd3 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -80,7 +80,7 @@ pub struct QueryBuilderConfig { pub max_size: Option, } -pub trait ResultBuilder { +pub trait ResultBuilder: Send + 'static { /// (Re)initialize the builder. This method can be called multiple times. fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { Ok(()) @@ -132,13 +132,10 @@ pub trait ResultBuilder { /// finish the builder, and pass the transaction state. /// If false is returned, and is_txn is true, then the transaction is rolledback. 
fn finnalize( - self, + &mut self, _is_txn: bool, _frame_no: Option, - ) -> Result - where - Self: Sized, - { + ) -> Result { Ok(true) } } @@ -171,15 +168,15 @@ pub struct StepResultsBuilder { current: Option, step_results: Vec, is_skipped: bool, - ret: R + ret: Option, } -pub trait RetChannel { +pub trait RetChannel: Send + 'static { fn send(self, t: T); } #[cfg(feature = "tokio")] -impl RetChannel for tokio::sync::oneshot::Sender { +impl RetChannel for tokio::sync::oneshot::Sender { fn send(self, t: T) { let _ = self.send(t); } @@ -191,7 +188,7 @@ impl StepResultsBuilder { current: None, step_results: Vec::new(), is_skipped: false, - ret, + ret: Some(ret), } } } @@ -241,11 +238,14 @@ impl>> ResultBuilder for StepResultsBuilder { } fn finnalize( - self, + &mut self, _is_txn: bool, _frame_no: Option, ) -> Result { - self.ret.send(self.step_results); + self.ret + .take() + .expect("finnalize called more than once") + .send(std::mem::take(&mut self.step_results)); Ok(true) } } @@ -353,7 +353,7 @@ impl ResultBuilder for Take { } fn finnalize( - self, + &mut self, is_txn: bool, frame_no: Option, ) -> Result { From b68a3c9edac50cb316760d44f5e1407a02498045 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 20 Jul 2023 11:13:28 +0200 Subject: [PATCH 24/64] fix tests --- libsqlx-server/src/allocation/mod.rs | 17 +- libsqlx-server/src/linc/bus.rs | 5 + libsqlx-server/src/linc/connection.rs | 228 +++-------------- libsqlx-server/src/linc/connection_pool.rs | 133 ---------- libsqlx-server/src/linc/handler.rs | 15 +- libsqlx-server/src/linc/net.rs | 4 + libsqlx-server/src/linc/server.rs | 237 +----------------- libsqlx-server/src/manager.rs | 10 +- libsqlx-server/src/meta.rs | 5 + libsqlx/Cargo.toml | 2 +- libsqlx/src/database/libsql/mod.rs | 50 ++-- .../database/libsql/replication_log/logger.rs | 6 +- libsqlx/src/database/proxy/connection.rs | 33 +-- libsqlx/src/database/test_utils.rs | 15 +- libsqlx/src/result_builder.rs | 5 +- 15 files changed, 142 insertions(+), 623 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 7d1e8fe6..fbe0d98e 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -24,12 +24,11 @@ use tokio::time::timeout; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; -use crate::linc::bus::{Bus, Dispatch}; +use crate::linc::bus::{Dispatch}; use crate::linc::proto::{ BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value, }; use crate::linc::{Inbound, NodeId, Outbound}; -use crate::manager::Manager; use crate::meta::DatabaseId; use self::config::{AllocConfig, DbConfig}; @@ -505,7 +504,7 @@ impl Database { next_req_id: 0, primary_id: *primary_id, database_id: DatabaseId::from_name(&alloc.db_name), - dispatcher: alloc.bus.clone(), + dispatcher: alloc.dispatcher.clone(), }), } } @@ -687,7 +686,7 @@ pub struct Allocation { pub hrana_server: Arc, /// handle to the message bus - pub bus: Arc>>, + pub dispatcher: Arc, pub db_name: String, } @@ -770,7 +769,7 @@ impl Allocation { next_frame_no, req_no, seq_no: 0, - dipatcher: self.bus.clone() as _, + dipatcher: self.dispatcher.clone() as _, notifier: frame_notifier.clone(), buffer: Vec::new(), }; @@ -818,7 +817,7 @@ impl Allocation { Message::ProxyResponse(ref r) => { if let Some(conn) = self .connections - .get(&self.bus.node_id()) + .get(&self.dispatcher.node_id()) .and_then(|m| m.get(&r.connection_id).cloned()) { 
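                    // route the response back to the local connection that
                    // issued the proxied request; connections are keyed by
                    // node id first, then by connection id.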
conn.inbound.send(msg).await.unwrap(); @@ -837,7 +836,7 @@ impl Allocation { req_id: u32, program: Program, ) { - let dispatcher = self.bus.clone(); + let dispatcher = self.dispatcher.clone(); let database_id = DatabaseId::from_name(&self.db_name); let exec = |conn: ConnectionHandle| async move { let _ = conn @@ -878,7 +877,7 @@ impl Allocation { let conn = block_in_place(|| self.database.connect(conn_id, self)); let (exec_sender, exec_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); - let id = remote.unwrap_or((self.bus.node_id(), conn_id)); + let id = remote.unwrap_or((self.dispatcher.node_id(), conn_id)); let conn = Connection { id, conn, @@ -903,7 +902,7 @@ impl Allocation { self.next_conn_id = self.next_conn_id.wrapping_add(1); if self .connections - .get(&self.bus.node_id()) + .get(&self.dispatcher.node_id()) .and_then(|m| m.get(&self.next_conn_id)) .is_none() { diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 4707c989..a31c3368 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -50,6 +50,7 @@ impl Bus { #[async_trait::async_trait] pub trait Dispatch: Send + Sync + 'static { async fn dispatch(&self, msg: Outbound); + fn node_id(&self) -> NodeId; } #[async_trait::async_trait] @@ -62,4 +63,8 @@ impl Dispatch for Bus { // This message is outbound. self.send_queue.enqueue(msg).await; } + + fn node_id(&self) -> NodeId { + self.node_id + } } diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index 09e2ec44..5f5d9f24 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -273,9 +273,9 @@ where mod test { use std::sync::Arc; + use futures::{future, pin_mut}; use tokio::sync::Notify; use turmoil::net::{TcpListener, TcpStream}; - use uuid::Uuid; use super::*; @@ -283,151 +283,50 @@ mod test { fn invalid_handshake() { let mut sim = turmoil::Builder::new().build(); - let host_node_id = NodeId::new_v4(); - sim.host("host", move || async move { - let bus = Bus::new(host_node_id); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - let (s, _) = listener.accept().await.unwrap(); - let mut connection = Connection::new_acceptor(s, bus); - connection.tick().await; - - Ok(()) + let host_node_id = 0; + let done = Arc::new(Notify::new()); + let done_clone = done.clone(); + sim.host("host", move || { + let done_clone = done_clone.clone(); + async move { + let bus = Arc::new(Bus::new(host_node_id, |_, _| async {})); + let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") + .await + .unwrap(); + let (s, _) = listener.accept().await.unwrap(); + let connection = Connection::new_acceptor(s, bus); + let done = done_clone.notified(); + let run = connection.run(); + pin_mut!(done); + pin_mut!(run); + future::select(run, done).await; + + Ok(()) + } }); sim.client("client", async move { let s = TcpStream::connect("host:1234").await.unwrap(); - let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async(); - - s.send(Message::Node(NodeMessage::Handshake { - protocol_version: 1234, - node_id: Uuid::new_v4(), - })) - .await - .unwrap(); + let mut s = AsyncBincodeStream::<_, Enveloppe, Enveloppe, _>::from(s).for_async(); + + let msg = Enveloppe { + database_id: None, + message: Message::Handshake { + protocol_version: 1234, + node_id: 1, + }, + }; + s.send(msg).await.unwrap(); let m = s.next().await.unwrap().unwrap(); assert!(matches!( - m, - 
Message::Node(NodeMessage::Error( - NodeError::HandshakeVersionMismatch { .. } - )) + m.message, + Message::Error( + ProtoError::HandshakeVersionMismatch { .. } + ) )); - Ok(()) - }); - - sim.run().unwrap(); - } - - #[test] - fn stream_closed() { - let mut sim = turmoil::Builder::new().build(); - - let database_id = DatabaseId::new_v4(); - let host_node_id = NodeId::new_v4(); - let notify = Arc::new(Notify::new()); - sim.host("host", { - let notify = notify.clone(); - move || { - let notify = notify.clone(); - async move { - let bus = Bus::new(host_node_id); - let mut sub = bus.subscribe(database_id).unwrap(); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - let (s, _) = listener.accept().await.unwrap(); - let connection = Connection::new_acceptor(s, bus); - tokio::task::spawn_local(connection.run()); - let mut streams = Vec::new(); - loop { - tokio::select! { - Some(mut stream) = sub.next() => { - let m = stream.next().await.unwrap(); - stream.send(m).await.unwrap(); - streams.push(stream); - } - _ = notify.notified() => { - break; - } - } - } - - Ok(()) - } - } - }); - - sim.client("client", async move { - let stream_id = StreamId::new(1); - let node_id = NodeId::new_v4(); - let s = TcpStream::connect("host:1234").await.unwrap(); - let mut s = AsyncBincodeStream::<_, Message, Message, _>::from(s).for_async(); - - s.send(Message::Node(NodeMessage::Handshake { - protocol_version: CURRENT_PROTO_VERSION, - node_id, - })) - .await - .unwrap(); - let m = s.next().await.unwrap().unwrap(); - assert!(matches!(m, Message::Node(NodeMessage::Handshake { .. }))); - - // send message to unexisting stream: - s.send(Message::Stream { - stream_id, - payload: StreamMessage::Dummy, - }) - .await - .unwrap(); - let m = s.next().await.unwrap().unwrap(); - assert_eq!( - m, - Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id))) - ); - - // open stream then send message - s.send(Message::Node(NodeMessage::OpenStream { - stream_id, - database_id, - })) - .await - .unwrap(); - s.send(Message::Stream { - stream_id, - payload: StreamMessage::Dummy, - }) - .await - .unwrap(); - let m = s.next().await.unwrap().unwrap(); - assert_eq!( - m, - Message::Stream { - stream_id, - payload: StreamMessage::Dummy - } - ); - - s.send(Message::Node(NodeMessage::CloseStream { - stream_id: StreamId::new(1), - })) - .await - .unwrap(); - s.send(Message::Stream { - stream_id, - payload: StreamMessage::Dummy, - }) - .await - .unwrap(); - let m = s.next().await.unwrap().unwrap(); - assert_eq!( - m, - Message::Node(NodeMessage::Error(NodeError::UnknownStream(stream_id))) - ); - - notify.notify_waiters(); + done.notify_waiters(); Ok(()) }); @@ -459,7 +358,7 @@ mod test { sim.client("client", async move { let stream = TcpStream::connect("host:1234").await.unwrap(); - let bus = Bus::new(NodeId::new_v4()); + let bus = Arc::new(Bus::new(1, |_, _| async {})); let mut conn = Connection::new_acceptor(stream, bus); notify.notify_waiters(); @@ -473,57 +372,4 @@ mod test { sim.run().unwrap(); } - - #[test] - fn zero_stream_id() { - let mut sim = turmoil::Builder::new().build(); - - let notify = Arc::new(Notify::new()); - sim.host("host", { - let notify = notify.clone(); - move || { - let notify = notify.clone(); - async move { - let listener = TcpListener::bind("0.0.0.0:1234").await.unwrap(); - let (stream, _) = listener.accept().await.unwrap(); - let (connection_messages_sender, connection_messages) = mpsc::channel(1); - let conn = Connection { - peer: Some(NodeId::new_v4()), - state: 
ConnectionState::Connected, - conn: AsyncBincodeStream::from(stream).for_async(), - streams: HashMap::new(), - connection_messages, - connection_messages_sender, - is_initiator: false, - bus: Bus::new(NodeId::new_v4()), - stream_id_allocator: StreamIdAllocator::new(false), - registration: None, - }; - - conn.run().await; - - Ok(()) - } - } - }); - - sim.client("client", async move { - let stream = TcpStream::connect("host:1234").await.unwrap(); - let mut stream = AsyncBincodeStream::<_, Message, Message, _>::from(stream).for_async(); - - stream - .send(Message::Stream { - stream_id: StreamId::new_unchecked(0), - payload: StreamMessage::Dummy, - }) - .await - .unwrap(); - - assert!(stream.next().await.is_none()); - - Ok(()) - }); - - sim.run().unwrap(); - } } diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index b6113a80..3415dee4 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -81,136 +81,3 @@ impl ConnectionPool { self.connections.spawn(fut); } } - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use futures::SinkExt; - use tokio::sync::Notify; - use tokio_stream::StreamExt; - - use crate::linc::{server::Server, AllocId}; - - use super::*; - - #[test] - fn manage_connections() { - let mut sim = turmoil::Builder::new().build(); - let database_id = AllocId::new_v4(); - let notify = Arc::new(Notify::new()); - - let expected_msg = crate::linc::proto::StreamMessage::Proxy( - crate::linc::proto::ProxyMessage::ProxyRequest { - connection_id: 42, - req_id: 42, - program: "foobar".into(), - }, - ); - - let spawn_host = |node_id| { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - move || { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - async move { - let bus = Bus::new(node_id); - let mut sub = bus.subscribe(database_id).unwrap(); - let mut server = Server::new(bus.clone()); - let mut listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - - let mut has_closed = false; - let mut streams = Vec::new(); - loop { - tokio::select! 
{ - _ = notify.notified() => { - if !has_closed { - streams.clear(); - server.close_connections().await; - has_closed = true; - } else { - break; - } - }, - _ = server.tick(&mut listener) => (), - Some(mut stream) = sub.next() => { - stream - .send(expected_msg.clone()) - .await - .unwrap(); - streams.push(stream); - } - } - } - - Ok(()) - } - } - }; - - let host1_id = NodeId::new_v4(); - sim.host("host1", spawn_host(host1_id)); - - let host2_id = NodeId::new_v4(); - sim.host("host2", spawn_host(host2_id)); - - let host3_id = NodeId::new_v4(); - sim.host("host3", spawn_host(host3_id)); - - sim.client("client", async move { - let bus = Bus::new(NodeId::new_v4()); - let pool = ConnectionPool::new( - bus.clone(), - vec![ - (host1_id, "host1:1234".into()), - (host2_id, "host2:1234".into()), - (host3_id, "host3:1234".into()), - ], - ); - - tokio::task::spawn_local(pool.run::()); - - // all three hosts are reachable: - let mut stream1 = bus.new_stream(host1_id, database_id).await.unwrap(); - let m = stream1.next().await.unwrap(); - assert_eq!(m, expected_msg); - - let mut stream2 = bus.new_stream(host2_id, database_id).await.unwrap(); - let m = stream2.next().await.unwrap(); - assert_eq!(m, expected_msg); - - let mut stream3 = bus.new_stream(host3_id, database_id).await.unwrap(); - let m = stream3.next().await.unwrap(); - assert_eq!(m, expected_msg); - - // sever connections - notify.notify_waiters(); - - assert!(stream1.next().await.is_none()); - assert!(stream2.next().await.is_none()); - assert!(stream3.next().await.is_none()); - - let mut stream1 = bus.new_stream(host1_id, database_id).await.unwrap(); - let m = stream1.next().await.unwrap(); - assert_eq!(m, expected_msg); - - let mut stream2 = bus.new_stream(host2_id, database_id).await.unwrap(); - let m = stream2.next().await.unwrap(); - assert_eq!(m, expected_msg); - - let mut stream3 = bus.new_stream(host3_id, database_id).await.unwrap(); - let m = stream3.next().await.unwrap(); - assert_eq!(m, expected_msg); - - // terminate test - notify.notify_waiters(); - - Ok(()) - }); - - sim.run().unwrap(); - } -} diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index 6403906e..828c8bb6 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -1,10 +1,21 @@ use std::sync::Arc; -use super::bus::Bus; +use super::bus::{Dispatch}; use super::Inbound; #[async_trait::async_trait] pub trait Handler: Sized + Send + Sync + 'static { /// Handle inbound message - async fn handle(&self, bus: Arc>, msg: Inbound); + async fn handle(&self, bus: Arc, msg: Inbound); +} + +#[cfg(test)] +#[async_trait::async_trait] +impl Handler for F +where F: Fn(Arc, Inbound) -> Fut + Send + Sync + 'static, + Fut: std::future::Future + Send, +{ + async fn handle(&self, bus: Arc, msg: Inbound) { + (self)(bus, msg).await + } } diff --git a/libsqlx-server/src/linc/net.rs b/libsqlx-server/src/linc/net.rs index 2123c041..a7fa87af 100644 --- a/libsqlx-server/src/linc/net.rs +++ b/libsqlx-server/src/linc/net.rs @@ -74,6 +74,10 @@ mod test { fn accept(&self) -> Self::Future<'_> { Box::pin(self.accept()) } + + fn local_addr(&self) -> color_eyre::Result { + Ok(self.local_addr()?) 
+ } } impl Connector for TcpStream { diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs index f3eacec2..6371a059 100644 --- a/libsqlx-server/src/linc/server.rs +++ b/libsqlx-server/src/linc/server.rs @@ -75,24 +75,20 @@ impl Server { mod test { use std::sync::Arc; - use crate::linc::{proto::ProxyMessage, AllocId, NodeId}; - use super::*; - use futures::{SinkExt, StreamExt}; - use tokio::sync::Notify; use turmoil::net::TcpStream; #[test] fn server_respond_to_handshake() { let mut sim = turmoil::Builder::new().build(); - let host_node_id = NodeId::new_v4(); + let host_node_id = 0; let notify = Arc::new(tokio::sync::Notify::new()); sim.host("host", move || { let notify = notify.clone(); async move { - let bus = Bus::new(host_node_id); + let bus = Arc::new(Bus::new(host_node_id, |_, _| async {})); let mut server = Server::new(bus); let mut listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") .await @@ -105,10 +101,10 @@ mod test { }); sim.client("client", async move { - let node_id = NodeId::new_v4(); + let node_id = 1; let mut c = Connection::new_initiator( TcpStream::connect("host:1234").await.unwrap(), - Bus::new(node_id), + Arc::new(Bus::new(node_id, |_, _| async {})), ); c.tick().await; @@ -121,229 +117,4 @@ mod test { sim.run().unwrap(); } - - #[test] - fn client_create_stream_client_close() { - let mut sim = turmoil::Builder::new().build(); - - let host_node_id = NodeId::new_v4(); - let stream_db_id = AllocId::new_v4(); - let notify = Arc::new(Notify::new()); - let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { - connection_id: 12, - req_id: 1, - program: "hello".to_string(), - }); - - sim.host("host", { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - move || { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - async move { - let bus = Bus::new(host_node_id); - let server = Server::new(bus.clone()); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - let mut subs = bus.subscribe(stream_db_id).unwrap(); - tokio::task::spawn_local(server.run(listener)); - - let mut stream = subs.next().await.unwrap(); - - let msg = stream.next().await.unwrap(); - - assert_eq!(msg, expected_msg); - - notify.notify_waiters(); - - assert!(stream.next().await.is_none()); - - notify.notify_waiters(); - - Ok(()) - } - } - }); - - sim.client("client", async move { - let node_id = NodeId::new_v4(); - let bus = Bus::new(node_id); - let mut c = Connection::new_initiator( - TcpStream::connect("host:1234").await.unwrap(), - bus.clone(), - ); - c.tick().await; - c.tick().await; - let _h = tokio::spawn(c.run()); - let mut stream = bus.new_stream(host_node_id, stream_db_id).await.unwrap(); - stream.send(expected_msg).await.unwrap(); - - notify.notified().await; - - drop(stream); - - notify.notified().await; - - Ok(()) - }); - - sim.run().unwrap(); - } - - #[test] - fn client_create_stream_server_close() { - let mut sim = turmoil::Builder::new().build(); - - let host_node_id = NodeId::new_v4(); - let database_id = AllocId::new_v4(); - let notify = Arc::new(Notify::new()); - - sim.host("host", { - let notify = notify.clone(); - move || { - let notify = notify.clone(); - async move { - let bus = Bus::new(host_node_id); - let server = Server::new(bus.clone()); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - let mut subs = bus.subscribe(database_id).unwrap(); - tokio::task::spawn_local(server.run(listener)); - - let stream = 
subs.next().await.unwrap(); - drop(stream); - - notify.notify_waiters(); - notify.notified().await; - - Ok(()) - } - } - }); - - sim.client("client", async move { - let node_id = NodeId::new_v4(); - let bus = Bus::new(node_id); - let mut c = Connection::new_initiator( - TcpStream::connect("host:1234").await.unwrap(), - bus.clone(), - ); - c.tick().await; - c.tick().await; - let _h = tokio::spawn(c.run()); - let mut stream = bus.new_stream(host_node_id, database_id).await.unwrap(); - - notify.notified().await; - assert!(stream.next().await.is_none()); - notify.notify_waiters(); - - Ok(()) - }); - - sim.run().unwrap(); - } - - #[test] - fn server_create_stream_server_close() { - let mut sim = turmoil::Builder::new().build(); - - let host_node_id = NodeId::new_v4(); - let notify = Arc::new(Notify::new()); - let client_id = NodeId::new_v4(); - let database_id = AllocId::new_v4(); - let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { - connection_id: 12, - req_id: 1, - program: "hello".to_string(), - }); - - sim.host("host", { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - move || { - let notify = notify.clone(); - let expected_msg = expected_msg.clone(); - async move { - let bus = Bus::new(host_node_id); - let server = Server::new(bus.clone()); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - tokio::task::spawn_local(server.run(listener)); - - let mut stream = bus.new_stream(client_id, database_id).await.unwrap(); - stream.send(expected_msg).await.unwrap(); - notify.notified().await; - drop(stream); - - Ok(()) - } - } - }); - - sim.client("client", async move { - let bus = Bus::new(client_id); - let mut subs = bus.subscribe(database_id).unwrap(); - let c = Connection::new_initiator( - TcpStream::connect("host:1234").await.unwrap(), - bus.clone(), - ); - let _h = tokio::spawn(c.run()); - - let mut stream = subs.next().await.unwrap(); - let msg = stream.next().await.unwrap(); - assert_eq!(msg, expected_msg); - notify.notify_waiters(); - assert!(stream.next().await.is_none()); - - Ok(()) - }); - - sim.run().unwrap(); - } - - #[test] - fn server_create_stream_client_close() { - let mut sim = turmoil::Builder::new().build(); - - let host_node_id = NodeId::new_v4(); - let client_id = NodeId::new_v4(); - let database_id = AllocId::new_v4(); - - sim.host("host", { - move || async move { - let bus = Bus::new(host_node_id); - let server = Server::new(bus.clone()); - let listener = turmoil::net::TcpListener::bind("0.0.0.0:1234") - .await - .unwrap(); - tokio::task::spawn_local(server.run(listener)); - - let mut stream = bus.new_stream(client_id, database_id).await.unwrap(); - assert!(stream.next().await.is_none()); - - Ok(()) - } - }); - - sim.client("client", async move { - let bus = Bus::new(client_id); - let mut subs = bus.subscribe(database_id).unwrap(); - let c = Connection::new_initiator( - TcpStream::connect("host:1234").await.unwrap(), - bus.clone(), - ); - let _h = tokio::spawn(c.run()); - - let stream = subs.next().await.unwrap(); - drop(stream); - - Ok(()) - }); - - sim.run().unwrap(); - } } diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 01870144..414e17bf 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -8,7 +8,7 @@ use tokio::task::JoinSet; use crate::allocation::{Allocation, AllocationMessage, Database}; use crate::hrana; -use crate::linc::bus::Bus; +use crate::linc::bus::Dispatch; use crate::linc::handler::Handler; use 
crate::linc::Inbound; use crate::meta::{DatabaseId, Store}; @@ -34,7 +34,7 @@ impl Manager { pub async fn alloc( self: &Arc, database_id: DatabaseId, - bus: Arc>>, + dispatcher: Arc, ) -> Option> { if let Some(sender) = self.cache.get(&database_id) { return Some(sender.clone()); @@ -46,12 +46,12 @@ impl Manager { let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { inbox, - database: Database::from_config(&config, path, bus.clone()), + database: Database::from_config(&config, path, dispatcher.clone()), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, hrana_server: Arc::new(hrana::http::Server::new(None)), - bus, // TODO: handle self URL? + dispatcher, // TODO: handle self URL? db_name: config.db_name, connections: HashMap::new(), }; @@ -69,7 +69,7 @@ impl Manager { #[async_trait::async_trait] impl Handler for Arc { - async fn handle(&self, bus: Arc>, msg: Inbound) { + async fn handle(&self, bus: Arc, msg: Inbound) { if let Some(sender) = self .clone() .alloc(msg.enveloppe.database_id.unwrap(), bus.clone()) diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index b71b33eb..0167497b 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -27,6 +27,11 @@ impl DatabaseId { reader.read(&mut out); Self(out) } + + #[cfg(test)] + pub fn random() -> Self { + Self(uuid::Uuid::new_v4().into_bytes()) + } } impl fmt::Display for DatabaseId { diff --git a/libsqlx/Cargo.toml b/libsqlx/Cargo.toml index 85fd7a9d..519339ec 100644 --- a/libsqlx/Cargo.toml +++ b/libsqlx/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] async-trait = "0.1.68" bytesize = "1.2.0" -serde = "1.0.164" +serde = { version = "1", features = ["rc"] } serde_json = "1.0.99" rusqlite = { workspace = true } anyhow = "1.0.71" diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index c0aaed79..5c060e4d 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -196,6 +196,7 @@ mod test { use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::Relaxed; + use parking_lot::Mutex; use rusqlite::types::Value; use crate::connection::Connection; @@ -205,14 +206,14 @@ mod test { use super::*; - struct ReadRowBuilder(Vec); + struct ReadRowBuilder(Arc>>); impl ResultBuilder for ReadRowBuilder { fn add_row_value( &mut self, v: rusqlite::types::ValueRef, ) -> Result<(), QueryResultBuilderError> { - self.0.push(v.into()); + self.0.lock().push(v.into()); Ok(()) } } @@ -227,22 +228,26 @@ mod test { let mut db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); let mut conn = db.connect().unwrap(); - let mut builder = ReadRowBuilder(Vec::new()); - conn.execute_program(Program::seq(&["select count(*) from test"]), &mut builder) + let row: Arc>> = Default::default(); + let builder = Box::new(ReadRowBuilder(row.clone())); + conn.execute_program(&Program::seq(&["select count(*) from test"]), builder) .unwrap(); - assert!(builder.0.is_empty()); + assert!(row.lock().is_empty()); let file = File::open("assets/test/simple_wallog").unwrap(); let log = LogFile::new(file).unwrap(); let mut injector = db.injector().unwrap(); log.frames_iter() .unwrap() - .for_each(|f| injector.inject(f.unwrap()).unwrap()); + .for_each(|f| { + injector.inject(f.unwrap()).unwrap(); + }); - let mut builder = ReadRowBuilder(Vec::new()); - conn.execute_program(Program::seq(&["select count(*) from test"]), &mut builder) + let row: Arc>> = 
Default::default(); + let builder = Box::new(ReadRowBuilder(row.clone())); + conn.execute_program(&Program::seq(&["select count(*) from test"]), builder) .unwrap(); - assert_eq!(builder.0[0], Value::Integer(5)); + assert_eq!(row.lock()[0], Value::Integer(5)); } #[test] @@ -253,7 +258,7 @@ mod test { let primary = LibsqlDatabase::new( temp_primary.path().to_path_buf(), PrimaryType { - logger: Arc::new(ReplicationLogger::open(temp_primary.path(), false, ()).unwrap()), + logger: Arc::new(ReplicationLogger::open(temp_primary.path(), false, (), Box::new(|_| ())).unwrap()), }, ); @@ -268,8 +273,8 @@ mod test { let mut primary_conn = primary.connect().unwrap(); primary_conn .execute_program( - Program::seq(&["create table test (x)", "insert into test values (42)"]), - &mut (), + &Program::seq(&["create table test (x)", "insert into test values (42)"]), + Box::new(()), ) .unwrap(); @@ -282,13 +287,14 @@ mod test { } let mut replica_conn = replica.connect().unwrap(); - let mut builder = ReadRowBuilder(Vec::new()); + let row: Arc>> = Default::default(); + let builder = Box::new(ReadRowBuilder(row.clone())); replica_conn - .execute_program(Program::seq(&["select * from test limit 1"]), &mut builder) + .execute_program(&Program::seq(&["select * from test limit 1"]), builder) .unwrap(); - assert_eq!(builder.0.len(), 1); - assert_eq!(builder.0[0], Value::Integer(42)); + assert_eq!(row.lock().len(), 1); + assert_eq!(row.lock()[0], Value::Integer(42)); } #[test] @@ -317,13 +323,14 @@ mod test { temp.path().to_path_buf(), Compactor(compactor_called.clone()), false, + Box::new(|_| ()), ) .unwrap(); let mut conn = db.connect().unwrap(); conn.execute_program( - Program::seq(&["create table test (x)", "insert into test values (12)"]), - &mut (), + &Program::seq(&["create table test (x)", "insert into test values (12)"]), + Box::new(()), ) .unwrap(); assert!(compactor_called.load(Relaxed)); @@ -356,22 +363,23 @@ mod test { temp.path().to_path_buf(), Compactor(compactor_called.clone()), false, + Box::new(|_| ()) ) .unwrap(); let mut conn = db.connect().unwrap(); conn.execute_program( - Program::seq(&[ + &Program::seq(&[ "begin", "create table test (x)", "insert into test values (12)", ]), - &mut (), + Box::new(()) ) .unwrap(); conn.inner_connection().cache_flush().unwrap(); assert!(!compactor_called.load(Relaxed)); - conn.execute_program(Program::seq(&["commit"]), &mut ()) + conn.execute_program(&Program::seq(&["commit"]), Box::new(())) .unwrap(); assert!(compactor_called.load(Relaxed)); } diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 187f3b25..7bcfb0bf 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -958,7 +958,7 @@ mod test { #[test] fn write_and_read_from_frame_log() { let dir = tempfile::tempdir().unwrap(); - let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, (), Box::new(|_| ())).unwrap(); let frames = (0..10) .map(|i| WalPage { @@ -986,7 +986,7 @@ mod test { #[test] fn index_out_of_bounds() { let dir = tempfile::tempdir().unwrap(); - let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, (), Box::new(|_| ())).unwrap(); let log_file = logger.log_file.write(); assert!(matches!(log_file.frame(1), Err(LogReadError::Ahead))); } @@ -995,7 +995,7 @@ mod test { #[should_panic] fn incorrect_frame_size() { 
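        // Note on the repeated trailing `Box::new(|_| ())` argument in these
        // tests: it appears to be a new-frame notification callback added to
        // `ReplicationLogger::open` in this patch series; its exact type is
        // not shown here, so the no-op closure is the only safe stub.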
let dir = tempfile::tempdir().unwrap(); - let logger = ReplicationLogger::open(dir.path(), false, ()).unwrap(); + let logger = ReplicationLogger::open(dir.path(), false, (), Box::new(|_| ())).unwrap(); let entry = WalPage { page_no: 0, size_after: 0, diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index a06b6620..2d576387 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -250,36 +250,37 @@ impl ResultBuilder for ExtractFrameNoBuilder { #[cfg(test)] mod test { - use std::cell::Cell; - use std::rc::Rc; use std::sync::Arc; + use parking_lot::Mutex; + + use crate::Connection; use crate::database::test_utils::MockDatabase; use crate::database::{proxy::database::WriteProxyDatabase, Database}; use crate::program::Program; #[test] fn simple_write_proxied() { - let write_called = Rc::new(Cell::new(false)); + let write_called = Arc::new(Mutex::new(false)); let write_db = MockDatabase::new().with_execute({ let write_called = write_called.clone(); - move |_, b| { + move |_, mut b| { b.finnalize(false, Some(42)).unwrap(); - write_called.set(true); + *write_called.lock() =true; Ok(()) } }); - let read_called = Rc::new(Cell::new(false)); + let read_called = Arc::new(Mutex::new(false)); let read_db = MockDatabase::new().with_execute({ let read_called = read_called.clone(); move |_, _| { - read_called.set(true); + *read_called.lock() = true; Ok(()) } }); - let wait_called = Rc::new(Cell::new(false)); + let wait_called = Arc::new(Mutex::new(false)); let db = WriteProxyDatabase::new( read_db, write_db, @@ -287,23 +288,23 @@ mod test { let wait_called = wait_called.clone(); move |fno| { assert_eq!(fno, 42); - wait_called.set(true); + *wait_called.lock() = true; } }), ); let mut conn = db.connect().unwrap(); - conn.execute_program(Program::seq(&["insert into test values (12)"]), &mut ()) + conn.execute_program(&Program::seq(&["insert into test values (12)"]), Box::new(())) .unwrap(); - assert!(!wait_called.get()); - assert!(!read_called.get()); - assert!(write_called.get()); + assert!(!*wait_called.lock()); + assert!(!*read_called.lock()); + assert!(*write_called.lock()); - conn.execute_program(Program::seq(&["select * from test"]), &mut ()) + conn.execute_program(&Program::seq(&["select * from test"]), Box::new(())) .unwrap(); - assert!(read_called.get()); - assert!(wait_called.get()); + assert!(*read_called.lock()); + assert!(*wait_called.lock()); } } diff --git a/libsqlx/src/database/test_utils.rs b/libsqlx/src/database/test_utils.rs index a46aa2ac..93bf3b1d 100644 --- a/libsqlx/src/database/test_utils.rs +++ b/libsqlx/src/database/test_utils.rs @@ -10,16 +10,17 @@ use super::Database; pub struct MockDatabase { #[allow(clippy::type_complexity)] - describe_fn: Arc crate::Result>, + describe_fn: Arc crate::Result +Send +Sync>, #[allow(clippy::type_complexity)] - execute_fn: Arc crate::Result<()>>, + execute_fn: Arc) -> crate::Result<()> +Send +Sync>, } +#[derive(Clone)] pub struct MockConnection { #[allow(clippy::type_complexity)] - describe_fn: Arc crate::Result>, + describe_fn: Arc crate::Result + Send +Sync>, #[allow(clippy::type_complexity)] - execute_fn: Arc crate::Result<()>>, + execute_fn: Arc) -> crate::Result<()> + Send +Sync>, } impl MockDatabase { @@ -32,7 +33,7 @@ impl MockDatabase { pub fn with_execute( mut self, - f: impl Fn(Program, &mut dyn ResultBuilder) -> crate::Result<()> + 'static, + f: impl Fn(&Program, Box) -> crate::Result<()> + Send + Sync +'static, ) -> Self { self.execute_fn = 
Arc::new(f);
         self
     }
@@ -53,8 +54,8 @@ impl Database for MockDatabase {
 impl Connection for MockConnection {
     fn execute_program(
         &mut self,
-        pgm: crate::program::Program,
-        reponse_builder: &mut dyn ResultBuilder,
+        pgm: &crate::program::Program,
+        reponse_builder: Box<dyn ResultBuilder>,
     ) -> crate::Result<()> {
         (self.execute_fn)(pgm, reponse_builder)?;
         Ok(())
diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs
index fed13fd3..458b50cc 100644
--- a/libsqlx/src/result_builder.rs
+++ b/libsqlx/src/result_builder.rs
@@ -650,9 +650,10 @@ pub mod test {
         &mut self,
         _is_txn: bool,
         _frame_no: Option<FrameNo>,
-    ) -> Result<(), QueryResultBuilderError> {
+    ) -> Result<bool, QueryResultBuilderError> {
         self.maybe_inject_error()?;
-        self.transition(Finish)
+        self.transition(Finish)?;
+        Ok(true)
     }
 }

From 06d5d648f98cac9849a0fb8a0794d1d41c83e0cf Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Thu, 20 Jul 2023 11:36:32 +0200
Subject: [PATCH 25/64] set connection state to unknown before sending proxy
 request
---
 libsqlx/src/database/proxy/connection.rs | 34 +++++++++++++++++++++---
 1 file changed, 31 insertions(+), 3 deletions(-)

diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs
index 2d576387..bee33d17 100644
--- a/libsqlx/src/database/proxy/connection.rs
+++ b/libsqlx/src/database/proxy/connection.rs
@@ -10,9 +10,27 @@ use crate::Result;

 use super::WaitFrameNoCb;

+#[derive(Debug, Default)]
+enum State {
+    Txn,
+    #[default]
+    Idle,
+    Unknown
+}
+
+impl State {
+    /// Returns `true` if the state is [`Idle`].
+    ///
+    /// [`Idle`]: State::Idle
+    #[must_use]
+    fn is_idle(&self) -> bool {
+        matches!(self, Self::Idle)
+    }
+}
+
 #[derive(Debug, Default)]
 pub(crate) struct ConnState {
-    is_txn: bool,
+    state: State,
     last_frame_no: Option<FrameNo>,
 }

@@ -123,6 +141,9 @@ where
             state: self.state.clone(),
         };

+        // set the connection state to unknown before executing on the remote
+        self.state.lock().state = State::Unknown;
+
         self.conn
             .execute_program(&self.pgm, Box::new(builder))
             .unwrap();
@@ -144,7 +165,7 @@ where
         pgm: &Program,
         builder: Box<dyn ResultBuilder>,
     ) -> crate::Result<()> {
-        if !self.state.lock().is_txn && pgm.is_read_only() {
+        if self.state.lock().state.is_idle() && pgm.is_read_only() {
             if let Some(frame_no) = self.state.lock().last_frame_no {
                 (self.wait_frame_no_cb)(frame_no);
             }
@@ -162,6 +183,9 @@ where
             // rollback(&mut self.conn.read_db);
             Ok(())
         } else {
+            // we set the state to unknown until we have received the actual
+            // connection state from the primary.
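+            // Until the primary's reply drives `finnalize` on this state,
+            // neither Idle nor Txn would be accurate: the reply may be
+            // delayed or never arrive, so the connection is marked Unknown.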
+ self.state.lock().state = State::Unknown; let builder = ExtractFrameNoBuilder { builder, state: self.state.clone(), @@ -243,7 +267,11 @@ impl ResultBuilder for ExtractFrameNoBuilder { ) -> Result { let mut state = self.state.lock(); state.last_frame_no = frame_no; - state.is_txn = is_txn; + if is_txn { + state.state = State::Txn; + } else { + state.state = State::Idle; + } self.builder.finnalize(is_txn, frame_no) } } From 1010a68faead94ee36eb11927707ce09bdbf94fd Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 20 Jul 2023 15:57:22 +0200 Subject: [PATCH 26/64] proxy request timeout --- Cargo.lock | 2 + libsqlx-server/Cargo.toml | 2 + libsqlx-server/src/allocation/config.rs | 7 +- libsqlx-server/src/allocation/mod.rs | 155 +++++++++++++++++---- libsqlx-server/src/hrana/result_builder.rs | 8 ++ libsqlx-server/src/http/admin.rs | 63 +++++++-- libsqlx-server/src/linc/bus.rs | 5 +- libsqlx-server/src/linc/connection.rs | 7 +- libsqlx-server/src/linc/handler.rs | 9 +- libsqlx-server/src/linc/proto.rs | 1 + libsqlx/src/database/libsql/mod.rs | 17 +-- libsqlx/src/database/proxy/connection.rs | 21 ++- libsqlx/src/database/test_utils.rs | 10 +- libsqlx/src/program.rs | 1 - libsqlx/src/result_builder.rs | 3 + 15 files changed, 247 insertions(+), 64 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d08d5034..23ee18ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2554,6 +2554,7 @@ dependencies = [ "either", "futures", "hmac", + "humantime", "hyper", "itertools 0.11.0", "libsqlx", @@ -2567,6 +2568,7 @@ dependencies = [ "sha2", "sha3", "sled", + "tempfile", "thiserror", "tokio", "tokio-stream", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index a5a11437..42a508c6 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -17,6 +17,7 @@ color-eyre = "0.6.2" either = "1.8.1" futures = "0.3.28" hmac = "0.12.1" +humantime = "2.1.0" hyper = { version = "0.14.27", features = ["h2", "server"] } itertools = "0.11.0" libsqlx = { version = "0.1.0", path = "../libsqlx", features = ["tokio"] } @@ -30,6 +31,7 @@ serde_json = "1.0.100" sha2 = "0.10.7" sha3 = "0.10.8" sled = "0.34.7" +tempfile = "3.6.0" thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } tokio-stream = "0.1.14" diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs index 9d1bab34..ac21efa7 100644 --- a/libsqlx-server/src/allocation/config.rs +++ b/libsqlx-server/src/allocation/config.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use serde::{Deserialize, Serialize}; use crate::linc::NodeId; @@ -19,5 +21,8 @@ pub struct AllocConfig { #[derive(Debug, Serialize, Deserialize)] pub enum DbConfig { Primary {}, - Replica { primary_node_id: NodeId }, + Replica { + primary_node_id: NodeId, + proxy_request_timeout_duration: Duration, + }, } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index fbe0d98e..a41f2cd7 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,13 +1,17 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::future::poll_fn; use std::mem::size_of; use std::ops::Deref; use std::path::PathBuf; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use bytes::Bytes; use either::Either; +use futures::Future; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; use libsqlx::program::Program; use libsqlx::proxy::{WriteProxyConnection, 
WriteProxyDatabase}; @@ -19,12 +23,12 @@ use libsqlx::{ use parking_lot::Mutex; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; -use tokio::time::timeout; +use tokio::time::{timeout, Sleep}; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; -use crate::linc::bus::{Dispatch}; +use crate::linc::bus::Dispatch; use crate::linc::proto::{ BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value, }; @@ -50,7 +54,9 @@ pub enum AllocationMessage { Inbound(Inbound), } -pub struct RemoteDb; +pub struct RemoteDb { + proxy_request_timeout_duration: Duration, +} #[derive(Clone)] pub struct RemoteConn { @@ -62,10 +68,12 @@ struct Request { builder: Box, pgm: Option, next_seq_no: u32, + timeout: Pin>, } pub struct RemoteConnInner { current_req: Mutex>, + request_timeout_duration: Duration, } impl Deref for RemoteConn { @@ -93,6 +101,7 @@ impl libsqlx::Connection for RemoteConn { builder, pgm: Some(program.clone()), next_seq_no: 0, + timeout: Box::pin(tokio::time::sleep(self.inner.request_timeout_duration)), }), }; @@ -111,6 +120,7 @@ impl libsqlx::Database for RemoteDb { Ok(RemoteConn { inner: Arc::new(RemoteConnInner { current_req: Default::default(), + request_timeout_duration: self.proxy_request_timeout_duration, }), }) } @@ -462,9 +472,14 @@ impl Database { frame_notifier: receiver, }) } - DbConfig::Replica { primary_node_id } => { + DbConfig::Replica { + primary_node_id, + proxy_request_timeout_duration, + } => { let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap(); - let wdb = RemoteDb; + let wdb = RemoteDb { + proxy_request_timeout_duration, + }; let mut db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ())); let injector = db.injector().unwrap(); let (sender, receiver) = mpsc::channel(16); @@ -502,7 +517,7 @@ impl Database { conn: db.connect().unwrap(), connection_id, next_req_id: 0, - primary_id: *primary_id, + primary_node_id: *primary_id, database_id: DatabaseId::from_name(&alloc.db_name), dispatcher: alloc.dispatcher.clone(), }), @@ -520,8 +535,8 @@ struct PrimaryConnection { #[async_trait::async_trait] impl ConnectionHandler for PrimaryConnection { - fn exec_ready(&self) -> bool { - true + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) } async fn handle_exec(&mut self, exec: ExecFn) { @@ -537,7 +552,7 @@ struct ReplicaConnection { conn: ProxyConnection, connection_id: u32, next_req_id: u32, - primary_id: NodeId, + primary_node_id: NodeId, database_id: DatabaseId, dispatcher: Arc, } @@ -551,16 +566,21 @@ impl ReplicaConnection { // TODO: pass actual config let config = QueryBuilderConfig { max_size: None }; let mut finnalized = false; - for step in resp.row_steps.iter() { - if finnalized { break }; + for step in resp.row_steps.into_iter() { + if finnalized { + break; + }; match step { BuilderStep::Init => req.builder.init(&config).unwrap(), BuilderStep::BeginStep => req.builder.begin_step().unwrap(), BuilderStep::FinishStep(affected_row_count, last_insert_rowid) => req .builder - .finish_step(*affected_row_count, *last_insert_rowid) + .finish_step(affected_row_count, last_insert_rowid) + .unwrap(), + BuilderStep::StepError(e) => req + .builder + .step_error(todo!("handle proxy step error")) .unwrap(), - BuilderStep::StepError(e) => req.builder.step_error(todo!()).unwrap(), BuilderStep::ColsDesc(cols) => req .builder .cols_description(&mut cols.iter().map(|c| Column { @@ -570,11 +590,15 @@ impl 
ReplicaConnection { .unwrap(), BuilderStep::BeginRows => req.builder.begin_rows().unwrap(), BuilderStep::BeginRow => req.builder.begin_row().unwrap(), - BuilderStep::AddRowValue(v) => req.builder.add_row_value(v.into()).unwrap(), + BuilderStep::AddRowValue(v) => req.builder.add_row_value((&v).into()).unwrap(), BuilderStep::FinishRow => req.builder.finish_row().unwrap(), BuilderStep::FinishRows => req.builder.finish_rows().unwrap(), BuilderStep::Finnalize { is_txn, frame_no } => { - let _ = req.builder.finnalize(*is_txn, *frame_no).unwrap(); + let _ = req.builder.finnalize(is_txn, frame_no).unwrap(); + finnalized = true; + }, + BuilderStep::FinnalizeError(e) => { + req.builder.finnalize_error(e); finnalized = true; } } @@ -596,9 +620,28 @@ impl ReplicaConnection { #[async_trait::async_trait] impl ConnectionHandler for ReplicaConnection { - fn exec_ready(&self) -> bool { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { // we are currently handling a request on this connection - self.conn.writer().current_req.lock().is_none() + // self.conn.writer().current_req.timeout.poll() + let mut req = self.conn.writer().current_req.lock(); + let should_abort_query = match &mut *req { + Some(ref mut req) => { + match req.timeout.as_mut().poll(cx) { + Poll::Ready(_) => { + req.builder.finnalize_error("request timed out".to_string()); + true + } + Poll::Pending => return Poll::Pending, + } + } + None => return Poll::Ready(()), + }; + + if should_abort_query { + *req = None + } + + Poll::Ready(()) } async fn handle_exec(&mut self, exec: ExecFn) { @@ -616,7 +659,7 @@ impl ConnectionHandler for ReplicaConnection { req.id = Some(req_id); let msg = Outbound { - to: self.primary_id, + to: self.primary_node_id, enveloppe: Enveloppe { database_id: Some(self.database_id), message: Message::ProxyRequest { @@ -654,10 +697,10 @@ where L: ConnectionHandler, R: ConnectionHandler, { - fn exec_ready(&self) -> bool { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { match self { - Either::Left(l) => l.exec_ready(), - Either::Right(r) => r.exec_ready(), + Either::Left(l) => l.poll_ready(cx), + Either::Right(r) => r.poll_ready(cx), } } @@ -852,7 +895,7 @@ impl Allocation { }; conn.execute_program(&program, Box::new(builder)).unwrap(); }) - .await; + .await; }; if self.database.is_primary() { @@ -921,7 +964,7 @@ struct Connection { #[async_trait::async_trait] trait ConnectionHandler: Send { - fn exec_ready(&self) -> bool; + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>; async fn handle_exec(&mut self, exec: ExecFn); async fn handle_inbound(&mut self, msg: Inbound); } @@ -929,11 +972,13 @@ trait ConnectionHandler: Send { impl Connection { async fn run(mut self) -> (NodeId, u32) { loop { + let fut = + futures::future::join(self.exec.recv(), poll_fn(|cx| self.conn.poll_ready(cx))); tokio::select! 
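            // joining `exec.recv()` with `poll_ready` gates new exec requests
            // on readiness: a replica connection stays pending while a proxied
            // request is in flight and its timeout has not yet fired.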
{ Some(inbound) = self.inbound.recv() => { self.conn.handle_inbound(inbound).await; } - Some(exec) = self.exec.recv(), if self.conn.exec_ready() => { + (Some(exec), _) = fut => { self.conn.handle_exec(exec).await; }, else => break, @@ -943,3 +988,65 @@ impl Connection { self.id } } + +#[cfg(test)] +mod test { + use tokio::sync::Notify; + + use crate::linc::bus::Bus; + + use super::*; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn proxy_request_timeout() { + let bus = Arc::new(Bus::new(0, |_, _| async {})); + let _queue = bus.connect(1); // pretend connection to node 1 + let tmp = tempfile::TempDir::new().unwrap(); + let read_db = LibsqlDatabase::new_replica(tmp.path().to_path_buf(), 1, ()).unwrap(); + let write_db = RemoteDb { + proxy_request_timeout_duration: Duration::from_millis(100), + }; + let db = WriteProxyDatabase::new(read_db, write_db, Arc::new(|_| ())); + let conn = db.connect().unwrap(); + let conn = ReplicaConnection { + conn, + connection_id: 0, + next_req_id: 0, + primary_node_id: 1, + database_id: DatabaseId::random(), + dispatcher: bus, + }; + + let (exec_sender, exec) = mpsc::channel(1); + let (_inbound_sender, inbound) = mpsc::channel(1); + let connection = Connection { + id: (0, 0), + conn, + exec, + inbound, + }; + + let handle = tokio::spawn(connection.run()); + + let notify = Arc::new(Notify::new()); + struct Builder(Arc); + impl ResultBuilder for Builder { + fn finnalize_error(&mut self, _e: String) { + self.0.notify_waiters() + } + } + + let builder = Box::new(Builder(notify.clone())); + exec_sender + .send(Box::new(move |conn| { + conn.execute_program(&Program::seq(&["create table test (c)"]), builder) + .unwrap(); + })) + .await + .unwrap(); + + notify.notified().await; + + handle.abort(); + } +} diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index e91bca28..c0c597bf 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -76,6 +76,10 @@ impl ResultBuilder for SingleStatementBuilder { let _ = self.ret.take().unwrap().send(res); Ok(true) } + + fn finnalize_error(&mut self, _e: String) { + todo!() + } } #[derive(Debug, Default)] @@ -354,4 +358,8 @@ impl ResultBuilder for HranaBatchProtoBuilder { fn add_row_value(&mut self, v: ValueRef) -> Result<(), QueryResultBuilderError> { self.stmt_builder.add_row_value(v) } + + fn finnalize_error(&mut self, _e: String) { + todo!() + } } diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 8a08187e..9323bcdd 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,16 +1,18 @@ use std::sync::Arc; +use std::str::FromStr; +use std::time::Duration; -use axum::{extract::State, routing::post, Json, Router}; +use axum::{Json, Router}; +use axum::routing::post; +use axum::extract::State; use color_eyre::eyre::Result; use hyper::server::accept::Accept; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::{ - allocation::config::{AllocConfig, DbConfig}, - linc::NodeId, - meta::Store, -}; +use crate::meta::Store; +use crate::allocation::config::{AllocConfig, DbConfig}; +use crate::linc::NodeId; pub struct Config { pub meta_store: Arc, @@ -52,11 +54,46 @@ struct AllocateReq { config: DbConfigReq, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum DbConfigReq { Primary {}, 
- Replica { primary_node_id: NodeId }, + Replica { + primary_node_id: NodeId, + #[serde(deserialize_with = "deserialize_duration", default = "default_proxy_timeout")] + proxy_request_timeout_duration: Duration, + }, +} + +const fn default_proxy_timeout() -> Duration { + Duration::from_secs(5) +} + +fn deserialize_duration<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + struct Visitor; + impl serde::de::Visitor<'_> for Visitor { + type Value = Duration; + + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + match humantime::Duration::from_str(v) { + Ok(d) => Ok(*d), + Err(e) => Err(E::custom(e.to_string())), + } + } + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a duration, in a string format") + } + + } + + deserializer.deserialize_str(Visitor) } async fn allocate( @@ -68,7 +105,13 @@ async fn allocate( db_name: req.alloc_id.clone(), db_config: match req.config { DbConfigReq::Primary {} => DbConfig::Primary {}, - DbConfigReq::Replica { primary_node_id } => DbConfig::Replica { primary_node_id }, + DbConfigReq::Replica { + primary_node_id, + proxy_request_timeout_duration, + } => DbConfig::Replica { + primary_node_id, + proxy_request_timeout_duration, + }, }, }; state.meta_store.allocate(&req.alloc_id, &config).await; diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index a31c3368..7c7b70dd 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -2,9 +2,11 @@ use std::collections::HashSet; use std::sync::Arc; use parking_lot::RwLock; +use tokio::sync::mpsc; use super::connection::SendQueue; use super::handler::Handler; +use super::proto::Enveloppe; use super::{Inbound, NodeId, Outbound}; pub struct Bus { @@ -37,9 +39,10 @@ impl Bus { &self.send_queue } - pub fn connect(&self, node_id: NodeId) { + pub fn connect(&self, node_id: NodeId) -> mpsc::UnboundedReceiver { // TODO: handle peer already exists self.peers.write().insert(node_id); + self.send_queue.register(node_id) } pub fn disconnect(&self, node_id: NodeId) { diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index 5f5d9f24..b979c437 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -227,8 +227,7 @@ where self.peer = Some(node_id); self.state = ConnectionState::Connected; - self.send_queue = Some(self.bus.send_queue().register(node_id)); - self.bus.connect(node_id); + self.send_queue = Some(self.bus.connect(node_id)); Ok(()) } @@ -321,9 +320,7 @@ mod test { assert!(matches!( m.message, - Message::Error( - ProtoError::HandshakeVersionMismatch { .. } - ) + Message::Error(ProtoError::HandshakeVersionMismatch { .. 
}) )); done.notify_waiters(); diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index 828c8bb6..2d17ff96 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use super::bus::{Dispatch}; +use super::bus::Dispatch; use super::Inbound; #[async_trait::async_trait] @@ -11,9 +11,10 @@ pub trait Handler: Sized + Send + Sync + 'static { #[cfg(test)] #[async_trait::async_trait] -impl Handler for F -where F: Fn(Arc, Inbound) -> Fut + Send + Sync + 'static, - Fut: std::future::Future + Send, +impl Handler for F +where + F: Fn(Arc, Inbound) -> Fut + Send + Sync + 'static, + Fut: std::future::Future + Send, { async fn handle(&self, bus: Arc, msg: Inbound) { (self)(bus, msg).await diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index a9aa529d..7e3a583d 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -107,6 +107,7 @@ pub enum BuilderStep { is_txn: bool, frame_no: Option, }, + FinnalizeError(String), } #[derive(Debug, Serialize, Deserialize, Clone)] diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 5c060e4d..382a4177 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -237,11 +237,9 @@ mod test { let file = File::open("assets/test/simple_wallog").unwrap(); let log = LogFile::new(file).unwrap(); let mut injector = db.injector().unwrap(); - log.frames_iter() - .unwrap() - .for_each(|f| { - injector.inject(f.unwrap()).unwrap(); - }); + log.frames_iter().unwrap().for_each(|f| { + injector.inject(f.unwrap()).unwrap(); + }); let row: Arc>> = Default::default(); let builder = Box::new(ReadRowBuilder(row.clone())); @@ -258,7 +256,10 @@ mod test { let primary = LibsqlDatabase::new( temp_primary.path().to_path_buf(), PrimaryType { - logger: Arc::new(ReplicationLogger::open(temp_primary.path(), false, (), Box::new(|_| ())).unwrap()), + logger: Arc::new( + ReplicationLogger::open(temp_primary.path(), false, (), Box::new(|_| ())) + .unwrap(), + ), }, ); @@ -363,7 +364,7 @@ mod test { temp.path().to_path_buf(), Compactor(compactor_called.clone()), false, - Box::new(|_| ()) + Box::new(|_| ()), ) .unwrap(); @@ -374,7 +375,7 @@ mod test { "create table test (x)", "insert into test values (12)", ]), - Box::new(()) + Box::new(()), ) .unwrap(); conn.inner_connection().cache_flush().unwrap(); diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index bee33d17..a7638e56 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -15,7 +15,7 @@ enum State { Txn, #[default] Idle, - Unknown + Unknown, } impl State { @@ -153,6 +153,10 @@ where self.builder.as_mut().unwrap().finnalize(is_txn, frame_no) } } + + fn finnalize_error(&mut self, e: String) { + self.builder.take().unwrap().finnalize_error(e) + } } impl Connection for WriteProxyConnection @@ -274,6 +278,10 @@ impl ResultBuilder for ExtractFrameNoBuilder { } self.builder.finnalize(is_txn, frame_no) } + + fn finnalize_error(&mut self, e: String) { + self.builder.finnalize_error(e) + } } #[cfg(test)] @@ -282,10 +290,10 @@ mod test { use parking_lot::Mutex; - use crate::Connection; use crate::database::test_utils::MockDatabase; use crate::database::{proxy::database::WriteProxyDatabase, Database}; use crate::program::Program; + use crate::Connection; #[test] fn simple_write_proxied() { @@ -294,7 +302,7 @@ mod test { let write_called = 
write_called.clone(); move |_, mut b| { b.finnalize(false, Some(42)).unwrap(); - *write_called.lock() =true; + *write_called.lock() = true; Ok(()) } }); @@ -322,8 +330,11 @@ mod test { ); let mut conn = db.connect().unwrap(); - conn.execute_program(&Program::seq(&["insert into test values (12)"]), Box::new(())) - .unwrap(); + conn.execute_program( + &Program::seq(&["insert into test values (12)"]), + Box::new(()), + ) + .unwrap(); assert!(!*wait_called.lock()); assert!(!*read_called.lock()); diff --git a/libsqlx/src/database/test_utils.rs b/libsqlx/src/database/test_utils.rs index 93bf3b1d..3034ca93 100644 --- a/libsqlx/src/database/test_utils.rs +++ b/libsqlx/src/database/test_utils.rs @@ -10,17 +10,17 @@ use super::Database; pub struct MockDatabase { #[allow(clippy::type_complexity)] - describe_fn: Arc crate::Result +Send +Sync>, + describe_fn: Arc crate::Result + Send + Sync>, #[allow(clippy::type_complexity)] - execute_fn: Arc) -> crate::Result<()> +Send +Sync>, + execute_fn: Arc) -> crate::Result<()> + Send + Sync>, } #[derive(Clone)] pub struct MockConnection { #[allow(clippy::type_complexity)] - describe_fn: Arc crate::Result + Send +Sync>, + describe_fn: Arc crate::Result + Send + Sync>, #[allow(clippy::type_complexity)] - execute_fn: Arc) -> crate::Result<()> + Send +Sync>, + execute_fn: Arc) -> crate::Result<()> + Send + Sync>, } impl MockDatabase { @@ -33,7 +33,7 @@ impl MockDatabase { pub fn with_execute( mut self, - f: impl Fn(&Program, Box) -> crate::Result<()> + Send + Sync +'static, + f: impl Fn(&Program, Box) -> crate::Result<()> + Send + Sync + 'static, ) -> Self { self.execute_fn = Arc::new(f); self diff --git a/libsqlx/src/program.rs b/libsqlx/src/program.rs index b2b627af..fc30a4bf 100644 --- a/libsqlx/src/program.rs +++ b/libsqlx/src/program.rs @@ -39,7 +39,6 @@ impl Program { Self { steps } } - #[cfg(test)] pub fn seq(stmts: &[&str]) -> Self { use crate::{analysis::Statement, query::Params}; diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index 458b50cc..d69ac35b 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -138,6 +138,9 @@ pub trait ResultBuilder: Send + 'static { ) -> Result { Ok(true) } + + /// There was a fatal error and the request was aborted + fn finnalize_error(&mut self, _e: String) {} } pub trait ResultBuilderExt: ResultBuilder { From a65f77182b3a4600b5ba39a553298ce6ee4c9998 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 20 Jul 2023 16:23:15 +0200 Subject: [PATCH 27/64] eager database schedule on allocate --- libsqlx-server/src/http/admin.rs | 13 ++++++++----- libsqlx-server/src/http/user/extractors.rs | 2 +- libsqlx-server/src/linc/bus.rs | 4 ++++ libsqlx-server/src/main.rs | 2 +- libsqlx-server/src/manager.rs | 14 ++++++++++++-- libsqlx-server/src/meta.rs | 5 +++-- 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 9323bcdd..8a28f4bc 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -10,16 +10,17 @@ use hyper::server::accept::Accept; use serde::{Deserialize, Deserializer, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::meta::Store; +use crate::linc::bus::Bus; +use crate::manager::Manager; use crate::allocation::config::{AllocConfig, DbConfig}; use crate::linc::NodeId; pub struct Config { - pub meta_store: Arc, + pub bus: Arc>>, } struct AdminServerState { - meta_store: Arc, + bus: Arc>>, } pub async fn run_admin_api(config: Config, listener: I) -> 
Result<()> @@ -28,7 +29,7 @@ where I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, { let state = AdminServerState { - meta_store: config.meta_store, + bus: config.bus, }; let app = Router::new() @@ -114,7 +115,9 @@ async fn allocate( }, }, }; - state.meta_store.allocate(&req.alloc_id, &config).await; + + let dispatcher = state.bus.clone(); + state.bus.handler().allocate(&req.alloc_id, &config, dispatcher).await; Ok(Json(AllocateResp {})) } diff --git a/libsqlx-server/src/http/user/extractors.rs b/libsqlx-server/src/http/user/extractors.rs index 962eb060..582b0fd6 100644 --- a/libsqlx-server/src/http/user/extractors.rs +++ b/libsqlx-server/src/http/user/extractors.rs @@ -20,7 +20,7 @@ impl FromRequestParts> for Database { let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; let db_name = parse_host(host_str)?; let db_id = DatabaseId::from_name(db_name); - let Some(sender) = state.manager.alloc(db_id, state.bus.clone()).await else { return Err(UserApiError::UnknownDatabase(db_name.to_owned())) }; + let Some(sender) = state.manager.schedule(db_id, state.bus.clone()).await else { return Err(UserApiError::UnknownDatabase(db_name.to_owned())) }; Ok(Database { sender }) } diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 7c7b70dd..5072c8ae 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -31,6 +31,10 @@ impl Bus { self.node_id } + pub fn handler(&self) -> &H { + &self.handler + } + pub async fn incomming(self: &Arc, incomming: Inbound) { self.handler.handle(self.clone(), incomming).await; } diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 454ae954..742be0b5 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -39,7 +39,7 @@ async fn spawn_admin_api( ) -> Result<()> { let admin_api_listener = TcpListener::bind(config.addr).await?; let fut = run_admin_api( - http::admin::Config { meta_store }, + http::admin::Config { manager: meta_store }, AddrIncoming::from_listener(admin_api_listener)?, ); set.spawn(fut); diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 414e17bf..64f63437 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -6,6 +6,7 @@ use moka::future::Cache; use tokio::sync::mpsc; use tokio::task::JoinSet; +use crate::allocation::config::AllocConfig; use crate::allocation::{Allocation, AllocationMessage, Database}; use crate::hrana; use crate::linc::bus::Dispatch; @@ -31,7 +32,7 @@ impl Manager { } /// Returns a handle to an allocation, lazily initializing if it isn't already loaded. 
- pub async fn alloc( + pub async fn schedule( self: &Arc, database_id: DatabaseId, dispatcher: Arc, @@ -65,6 +66,15 @@ impl Manager { None } + + pub async fn allocate(self: &Arc, database_name: &str, meta: &AllocConfig, dispatcher: Arc) { + let id = self.store().allocate(database_name, meta).await; + self.schedule(id, dispatcher).await; + } + + pub fn store(&self) -> &Store { + &self.meta_store + } } #[async_trait::async_trait] @@ -72,7 +82,7 @@ impl Handler for Arc { async fn handle(&self, bus: Arc, msg: Inbound) { if let Some(sender) = self .clone() - .alloc(msg.enveloppe.database_id.unwrap(), bus.clone()) + .schedule(msg.enveloppe.database_id.unwrap(), bus.clone()) .await { let _ = sender.send(AllocationMessage::Inbound(msg)).await; diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 0167497b..0770e7ed 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -56,16 +56,17 @@ impl Store { Self { meta_store } } - pub async fn allocate(&self, database_name: &str, meta: &AllocConfig) { + pub async fn allocate(&self, database_name: &str, meta: &AllocConfig) -> DatabaseId { //TODO: Handle conflict + let id = DatabaseId::from_name(database_name); block_in_place(|| { let meta_bytes = bincode::serialize(meta).unwrap(); - let id = DatabaseId::from_name(database_name); self.meta_store .compare_and_swap(id, None as Option<&[u8]>, Some(meta_bytes)) .unwrap() .unwrap(); }); + id } pub async fn deallocate(&self, _database_name: &str) { From 5d16027c9e24f1a95a1edf4f70d235bec0741eaf Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 20 Jul 2023 16:50:49 +0200 Subject: [PATCH 28/64] deallocate database --- libsqlx-server/src/allocation/mod.rs | 22 ++++++------- libsqlx-server/src/http/admin.rs | 47 ++++++++++++++++++---------- libsqlx-server/src/main.rs | 6 ++-- libsqlx-server/src/manager.rs | 18 +++++++++-- libsqlx-server/src/meta.rs | 8 ++--- 5 files changed, 63 insertions(+), 38 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index a41f2cd7..7bea3ddd 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -590,14 +590,16 @@ impl ReplicaConnection { .unwrap(), BuilderStep::BeginRows => req.builder.begin_rows().unwrap(), BuilderStep::BeginRow => req.builder.begin_row().unwrap(), - BuilderStep::AddRowValue(v) => req.builder.add_row_value((&v).into()).unwrap(), + BuilderStep::AddRowValue(v) => { + req.builder.add_row_value((&v).into()).unwrap() + } BuilderStep::FinishRow => req.builder.finish_row().unwrap(), BuilderStep::FinishRows => req.builder.finish_rows().unwrap(), BuilderStep::Finnalize { is_txn, frame_no } => { let _ = req.builder.finnalize(is_txn, frame_no).unwrap(); finnalized = true; - }, - BuilderStep::FinnalizeError(e) => { + } + BuilderStep::FinnalizeError(e) => { req.builder.finnalize_error(e); finnalized = true; } @@ -625,15 +627,13 @@ impl ConnectionHandler for ReplicaConnection { // self.conn.writer().current_req.timeout.poll() let mut req = self.conn.writer().current_req.lock(); let should_abort_query = match &mut *req { - Some(ref mut req) => { - match req.timeout.as_mut().poll(cx) { - Poll::Ready(_) => { - req.builder.finnalize_error("request timed out".to_string()); - true - } - Poll::Pending => return Poll::Pending, + Some(ref mut req) => match req.timeout.as_mut().poll(cx) { + Poll::Ready(_) => { + req.builder.finnalize_error("request timed out".to_string()); + true } - } + Poll::Pending => return Poll::Pending, + }, None => return Poll::Ready(()), 
}; diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 8a28f4bc..0e263ddf 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,19 +1,20 @@ -use std::sync::Arc; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; +use axum::extract::{Path, State}; +use axum::routing::{delete, post}; use axum::{Json, Router}; -use axum::routing::post; -use axum::extract::State; use color_eyre::eyre::Result; use hyper::server::accept::Accept; use serde::{Deserialize, Deserializer, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; -use crate::linc::bus::Bus; -use crate::manager::Manager; use crate::allocation::config::{AllocConfig, DbConfig}; +use crate::linc::bus::Bus; use crate::linc::NodeId; +use crate::manager::Manager; +use crate::meta::DatabaseId; pub struct Config { pub bus: Arc>>, @@ -28,12 +29,11 @@ where I: Accept, I::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - let state = AdminServerState { - bus: config.bus, - }; + let state = AdminServerState { bus: config.bus }; let app = Router::new() .route("/manage/allocation", post(allocate).get(list_allocs)) + .route("/manage/allocation/:db_name", delete(deallocate)) .with_state(Arc::new(state)); axum::Server::builder(listener) .serve(app.into_make_service()) @@ -50,7 +50,7 @@ struct AllocateResp {} #[derive(Deserialize, Debug)] struct AllocateReq { - alloc_id: String, + database_name: String, max_conccurent_connection: Option, config: DbConfigReq, } @@ -61,7 +61,10 @@ pub enum DbConfigReq { Primary {}, Replica { primary_node_id: NodeId, - #[serde(deserialize_with = "deserialize_duration", default = "default_proxy_timeout")] + #[serde( + deserialize_with = "deserialize_duration", + default = "default_proxy_timeout" + )] proxy_request_timeout_duration: Duration, }, } @@ -79,8 +82,8 @@ where type Value = Duration; fn visit_str(self, v: &str) -> std::result::Result - where - E: serde::de::Error, + where + E: serde::de::Error, { match humantime::Duration::from_str(v) { Ok(d) => Ok(*d), @@ -91,7 +94,6 @@ where fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.write_str("a duration, in a string format") } - } deserializer.deserialize_str(Visitor) @@ -103,7 +105,7 @@ async fn allocate( ) -> Result, Json> { let config = AllocConfig { max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16), - db_name: req.alloc_id.clone(), + db_name: req.database_name.clone(), db_config: match req.config { DbConfigReq::Primary {} => DbConfig::Primary {}, DbConfigReq::Replica { @@ -117,7 +119,18 @@ async fn allocate( }; let dispatcher = state.bus.clone(); - state.bus.handler().allocate(&req.alloc_id, &config, dispatcher).await; + let id = DatabaseId::from_name(&req.database_name); + state.bus.handler().allocate(id, &config, dispatcher).await; + + Ok(Json(AllocateResp {})) +} + +async fn deallocate( + State(state): State>, + Path(database_name): Path, +) -> Result, Json> { + let id = DatabaseId::from_name(&database_name); + state.bus.handler().deallocate(id).await; Ok(Json(AllocateResp {})) } @@ -136,7 +149,9 @@ async fn list_allocs( State(state): State>, ) -> Result, Json> { let allocs = state - .meta_store + .bus + .handler() + .store() .list_allocs() .await .into_iter() diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 742be0b5..296c54c9 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -35,11 +35,11 @@ struct Args { async fn spawn_admin_api( set: &mut JoinSet>, config: &AdminApiConfig, - 
meta_store: Arc, + bus: Arc>>, ) -> Result<()> { let admin_api_listener = TcpListener::bind(config.addr).await?; let fut = run_admin_api( - http::admin::Config { manager: meta_store }, + http::admin::Config { bus }, AddrIncoming::from_listener(admin_api_listener)?, ); set.spawn(fut); @@ -98,7 +98,7 @@ async fn main() -> Result<()> { let bus = Arc::new(Bus::new(config.cluster.id, manager.clone())); spawn_cluster_networking(&mut join_set, &config.cluster, bus.clone()).await?; - spawn_admin_api(&mut join_set, &config.admin_api, store.clone()).await?; + spawn_admin_api(&mut join_set, &config.admin_api, bus.clone()).await?; spawn_user_api(&mut join_set, &config.user_api, manager, bus).await?; join_set.join_next().await; diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 64f63437..69d1376f 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -67,9 +67,21 @@ impl Manager { None } - pub async fn allocate(self: &Arc, database_name: &str, meta: &AllocConfig, dispatcher: Arc) { - let id = self.store().allocate(database_name, meta).await; - self.schedule(id, dispatcher).await; + pub async fn allocate( + self: &Arc, + database_id: DatabaseId, + meta: &AllocConfig, + dispatcher: Arc, + ) { + self.store().allocate(database_id, meta).await; + self.schedule(database_id, dispatcher).await; + } + + pub async fn deallocate(&self, database_id: DatabaseId) { + self.meta_store.deallocate(database_id).await; + self.cache.remove(&database_id).await; + let db_path = self.db_path.join("dbs").join(database_id.to_string()); + tokio::fs::remove_dir_all(db_path).await.unwrap(); } pub fn store(&self) -> &Store { diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 0770e7ed..0d61d04f 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -56,9 +56,8 @@ impl Store { Self { meta_store } } - pub async fn allocate(&self, database_name: &str, meta: &AllocConfig) -> DatabaseId { + pub async fn allocate(&self, id: DatabaseId, meta: &AllocConfig) { //TODO: Handle conflict - let id = DatabaseId::from_name(database_name); block_in_place(|| { let meta_bytes = bincode::serialize(meta).unwrap(); self.meta_store @@ -66,11 +65,10 @@ impl Store { .unwrap() .unwrap(); }); - id } - pub async fn deallocate(&self, _database_name: &str) { - todo!() + pub async fn deallocate(&self, id: DatabaseId) { + block_in_place(|| self.meta_store.remove(id).unwrap()); } pub async fn meta(&self, database_id: &DatabaseId) -> Option { From b587be53332bedd53994f6e273b1654e6c0a9cc6 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 20 Jul 2023 17:55:55 +0200 Subject: [PATCH 29/64] reorganize allocation file --- libsqlx-server/src/allocation/mod.rs | 660 ++--------------------- libsqlx-server/src/allocation/primary.rs | 275 ++++++++++ libsqlx-server/src/allocation/replica.rs | 342 ++++++++++++ 3 files changed, 676 insertions(+), 601 deletions(-) create mode 100644 libsqlx-server/src/allocation/primary.rs create mode 100644 libsqlx-server/src/allocation/replica.rs diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 7bea3ddd..8c7ba873 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,49 +1,40 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::future::poll_fn; -use std::mem::size_of; -use std::ops::Deref; use std::path::PathBuf; -use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use std::time::{Duration, Instant}; +use 
std::time::Instant;

-use bytes::Bytes;
 use either::Either;
-use futures::Future;
-use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType};
+use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile};
 use libsqlx::program::Program;
-use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase};
-use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder};
-use libsqlx::{
-    Database as _, DescribeResponse, Frame, FrameNo, InjectableDatabase, Injector, LogReadError,
-    ReplicationLogger,
-};
-use parking_lot::Mutex;
+use libsqlx::proxy::WriteProxyDatabase;
+use libsqlx::{Database as _, InjectableDatabase};
 use tokio::sync::{mpsc, oneshot};
 use tokio::task::{block_in_place, JoinSet};
-use tokio::time::{timeout, Sleep};

+use crate::allocation::primary::FrameStreamer;
 use crate::hrana;
 use crate::hrana::http::handle_pipeline;
 use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody};
 use crate::linc::bus::Dispatch;
-use crate::linc::proto::{
-    BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value,
-};
-use crate::linc::{Inbound, NodeId, Outbound};
+use crate::linc::proto::{Frames, Message};
+use crate::linc::{Inbound, NodeId};
 use crate::meta::DatabaseId;

 use self::config::{AllocConfig, DbConfig};
+use self::primary::{PrimaryConnection, PrimaryDatabase, ProxyResponseBuilder};
+use self::replica::{ProxyDatabase, RemoteDb, ReplicaConnection, Replicator};

 pub mod config;
+mod primary;
+mod replica;

 /// the maximum number of frames a Frames message is allowed to contain
 const FRAMES_MESSAGE_MAX_COUNT: usize = 5;
+const MAX_INJECTOR_BUFFER_CAP: usize = 32;

-type ProxyConnection =
-    WriteProxyConnection<libsqlx::libsql::LibsqlConnection<ReplicaType>, RemoteConn>;
 type ExecFn = Box<dyn FnOnce(&mut dyn libsqlx::Connection) + Send>;

 pub enum AllocationMessage {
@@ -54,240 +45,6 @@ pub enum AllocationMessage {
     Inbound(Inbound),
 }

-pub struct RemoteDb {
-    proxy_request_timeout_duration: Duration,
-}
-
-#[derive(Clone)]
-pub struct RemoteConn {
-    inner: Arc<RemoteConnInner>,
-}
-
-struct Request {
-    id: Option<u32>,
-    builder: Box<dyn ResultBuilder>,
-    pgm: Option<Program>,
-    next_seq_no: u32,
-    timeout: Pin<Box<Sleep>>,
-}
-
-pub struct RemoteConnInner {
-    current_req: Mutex<Option<Request>>,
-    request_timeout_duration: Duration,
-}
-
-impl Deref for RemoteConn {
-    type Target = RemoteConnInner;
-
-    fn deref(&self) -> &Self::Target {
-        self.inner.as_ref()
-    }
-}
-
-impl libsqlx::Connection for RemoteConn {
-    fn execute_program(
-        &mut self,
-        program: &libsqlx::program::Program,
-        builder: Box<dyn ResultBuilder>,
-    ) -> libsqlx::Result<()> {
-        // When we need to proxy a query, we place it in the current request slot. When we are
-        // back in an async context, we'll send it to the primary, and asynchronously drive the
-        // builder.
- let mut lock = self.inner.current_req.lock(); - *lock = match *lock { - Some(_) => unreachable!("conccurent request on the same connection!"), - None => Some(Request { - id: None, - builder, - pgm: Some(program.clone()), - next_seq_no: 0, - timeout: Box::pin(tokio::time::sleep(self.inner.request_timeout_duration)), - }), - }; - - Ok(()) - } - - fn describe(&self, _sql: String) -> libsqlx::Result { - unreachable!("Describe request should not be proxied") - } -} - -impl libsqlx::Database for RemoteDb { - type Connection = RemoteConn; - - fn connect(&self) -> Result { - Ok(RemoteConn { - inner: Arc::new(RemoteConnInner { - current_req: Default::default(), - request_timeout_duration: self.proxy_request_timeout_duration, - }), - }) - } -} - -pub type ProxyDatabase = WriteProxyDatabase, RemoteDb>; - -pub struct PrimaryDatabase { - pub db: LibsqlDatabase, - pub replica_streams: HashMap)>, - pub frame_notifier: tokio::sync::watch::Receiver, -} - -struct ProxyResponseBuilder { - dispatcher: Arc, - buffer: Vec, - to: NodeId, - database_id: DatabaseId, - req_id: u32, - connection_id: u32, - next_seq_no: u32, -} - -const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100kb - -impl ProxyResponseBuilder { - fn maybe_send(&mut self) { - // FIXME: this is stupid: compute current buffer size on the go instead - let size = self - .buffer - .iter() - .map(|s| match s { - BuilderStep::FinishStep(_, _) => 2 * 8, - BuilderStep::StepError(StepError(s)) => s.len(), - BuilderStep::ColsDesc(ref d) => d - .iter() - .map(|c| c.name.len() + c.decl_ty.as_ref().map(|t| t.len()).unwrap_or_default()) - .sum(), - BuilderStep::Finnalize { .. } => 9, - BuilderStep::AddRowValue(v) => match v { - crate::linc::proto::Value::Text(s) | crate::linc::proto::Value::Blob(s) => { - s.len() - } - _ => size_of::(), - }, - _ => 8, - }) - .sum::(); - - if size > MAX_STEP_BATCH_SIZE { - self.send() - } - } - - fn send(&mut self) { - let msg = Outbound { - to: self.to, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::ProxyResponse(crate::linc::proto::ProxyResponse { - connection_id: self.connection_id, - req_id: self.req_id, - row_steps: std::mem::take(&mut self.buffer), - seq_no: self.next_seq_no, - }), - }, - }; - - self.next_seq_no += 1; - tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg)); - } -} - -impl ResultBuilder for ProxyResponseBuilder { - fn init( - &mut self, - _config: &libsqlx::result_builder::QueryBuilderConfig, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::Init); - self.maybe_send(); - Ok(()) - } - - fn begin_step(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::BeginStep); - self.maybe_send(); - Ok(()) - } - - fn finish_step( - &mut self, - affected_row_count: u64, - last_insert_rowid: Option, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::FinishStep( - affected_row_count, - last_insert_rowid, - )); - self.maybe_send(); - Ok(()) - } - - fn step_error( - &mut self, - error: libsqlx::error::Error, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer - .push(BuilderStep::StepError(StepError(error.to_string()))); - self.maybe_send(); - Ok(()) - } - - fn cols_description( - &mut self, - cols: &mut dyn Iterator, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer - .push(BuilderStep::ColsDesc(cols.map(Into::into).collect())); - 
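        // Every builder callback above and below is buffered as one BuilderStep;
        // maybe_send estimates the encoded size of the buffered steps and flushes them
        // to the replica as a ProxyResponse once the estimate crosses
        // MAX_STEP_BATCH_SIZE, so a large result set streams in bounded batches.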
self.maybe_send(); - Ok(()) - } - - fn begin_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::BeginRows); - self.maybe_send(); - Ok(()) - } - - fn begin_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::BeginRow); - self.maybe_send(); - Ok(()) - } - - fn add_row_value( - &mut self, - v: libsqlx::result_builder::ValueRef, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::AddRowValue(v.into())); - self.maybe_send(); - Ok(()) - } - - fn finish_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::FinishRow); - self.maybe_send(); - Ok(()) - } - - fn finish_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::FinishRows); - self.maybe_send(); - Ok(()) - } - - fn finnalize( - &mut self, - is_txn: bool, - frame_no: Option, - ) -> Result { - self.buffer - .push(BuilderStep::Finnalize { is_txn, frame_no }); - self.send(); - Ok(true) - } -} - pub enum Database { Primary(PrimaryDatabase), Replica { @@ -315,142 +72,6 @@ impl LogCompactor for Compactor { } } -const MAX_INJECTOR_BUFFER_CAP: usize = 32; - -struct Replicator { - dispatcher: Arc, - req_id: u32, - next_frame_no: FrameNo, - next_seq: u32, - database_id: DatabaseId, - primary_node_id: NodeId, - injector: Box, - receiver: mpsc::Receiver, -} - -impl Replicator { - async fn run(mut self) { - self.query_replicate().await; - loop { - match timeout(Duration::from_secs(5), self.receiver.recv()).await { - Ok(Some(Frames { - req_no: req_id, - seq_no: seq, - frames, - })) => { - // ignore frames from a previous call to Replicate - if req_id != self.req_id { - tracing::debug!(req_id, self.req_id, "wrong req_id"); - continue; - } - if seq != self.next_seq { - // this is not the batch of frame we were expecting, drop what we have, and - // ask again from last checkpoint - tracing::debug!(seq, self.next_seq, "wrong seq"); - self.query_replicate().await; - continue; - }; - self.next_seq += 1; - - tracing::debug!("injecting {} frames", frames.len()); - - for bytes in frames { - let frame = Frame::try_from_bytes(bytes).unwrap(); - block_in_place(|| { - if let Some(last_committed) = self.injector.inject(frame).unwrap() { - tracing::debug!(last_committed); - self.next_frame_no = last_committed + 1; - } - }); - } - } - Err(_) => self.query_replicate().await, - Ok(None) => break, - } - } - } - - async fn query_replicate(&mut self) { - self.req_id += 1; - self.next_seq = 0; - // clear buffered, uncommitted frames - self.injector.clear(); - self.dispatcher - .dispatch(Outbound { - to: self.primary_node_id, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::Replicate { - next_frame_no: self.next_frame_no, - req_no: self.req_id, - }, - }, - }) - .await; - } -} - -struct FrameStreamer { - logger: Arc, - database_id: DatabaseId, - node_id: NodeId, - next_frame_no: FrameNo, - req_no: u32, - seq_no: u32, - dipatcher: Arc, - notifier: tokio::sync::watch::Receiver, - buffer: Vec, -} - -impl FrameStreamer { - async fn run(mut self) { - loop { - match block_in_place(|| self.logger.get_frame(self.next_frame_no)) { - Ok(frame) => { - if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT { - self.send_frames().await; - } - self.buffer.push(frame.bytes()); - self.next_frame_no += 1; - } - Err(LogReadError::Ahead) => { - tracing::debug!("frame {} not yet 
avaiblable", self.next_frame_no); - if !self.buffer.is_empty() { - self.send_frames().await; - } - if self - .notifier - .wait_for(|fno| *fno >= self.next_frame_no) - .await - .is_err() - { - break; - } - } - Err(LogReadError::Error(_)) => todo!("handle log read error"), - Err(LogReadError::SnapshotRequired) => todo!("handle reading from snapshot"), - } - } - } - - async fn send_frames(&mut self) { - let frames = std::mem::take(&mut self.buffer); - let outbound = Outbound { - to: self.node_id, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::Frames(Frames { - req_no: self.req_no, - seq_no: self.seq_no, - frames, - }), - }, - }; - self.seq_no += 1; - self.dipatcher.dispatch(outbound).await; - } -} - impl Database { pub fn from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc) -> Self { match config.db_config { @@ -485,16 +106,14 @@ impl Database { let (sender, receiver) = mpsc::channel(16); let database_id = DatabaseId::from_name(&config.db_name); - let replicator = Replicator { + let replicator = Replicator::new( dispatcher, - req_id: 0, - next_frame_no: 0, // TODO: load the last commited from meta file - next_seq: 0, + 0, database_id, primary_node_id, injector, receiver, - }; + ); tokio::spawn(replicator.run()); @@ -529,195 +148,6 @@ impl Database { } } -struct PrimaryConnection { - conn: libsqlx::libsql::LibsqlConnection, -} - -#[async_trait::async_trait] -impl ConnectionHandler for PrimaryConnection { - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<()> { - Poll::Ready(()) - } - - async fn handle_exec(&mut self, exec: ExecFn) { - block_in_place(|| exec(&mut self.conn)); - } - - async fn handle_inbound(&mut self, _msg: Inbound) { - tracing::debug!("primary connection received message, ignoring.") - } -} - -struct ReplicaConnection { - conn: ProxyConnection, - connection_id: u32, - next_req_id: u32, - primary_node_id: NodeId, - database_id: DatabaseId, - dispatcher: Arc, -} - -impl ReplicaConnection { - fn handle_proxy_response(&mut self, resp: ProxyResponse) { - let mut lock = self.conn.writer().inner.current_req.lock(); - let finnalized = match *lock { - Some(ref mut req) if req.id == Some(resp.req_id) && resp.seq_no == req.next_seq_no => { - self.next_req_id += 1; - // TODO: pass actual config - let config = QueryBuilderConfig { max_size: None }; - let mut finnalized = false; - for step in resp.row_steps.into_iter() { - if finnalized { - break; - }; - match step { - BuilderStep::Init => req.builder.init(&config).unwrap(), - BuilderStep::BeginStep => req.builder.begin_step().unwrap(), - BuilderStep::FinishStep(affected_row_count, last_insert_rowid) => req - .builder - .finish_step(affected_row_count, last_insert_rowid) - .unwrap(), - BuilderStep::StepError(e) => req - .builder - .step_error(todo!("handle proxy step error")) - .unwrap(), - BuilderStep::ColsDesc(cols) => req - .builder - .cols_description(&mut cols.iter().map(|c| Column { - name: &c.name, - decl_ty: c.decl_ty.as_deref(), - })) - .unwrap(), - BuilderStep::BeginRows => req.builder.begin_rows().unwrap(), - BuilderStep::BeginRow => req.builder.begin_row().unwrap(), - BuilderStep::AddRowValue(v) => { - req.builder.add_row_value((&v).into()).unwrap() - } - BuilderStep::FinishRow => req.builder.finish_row().unwrap(), - BuilderStep::FinishRows => req.builder.finish_rows().unwrap(), - BuilderStep::Finnalize { is_txn, frame_no } => { - let _ = req.builder.finnalize(is_txn, frame_no).unwrap(); - finnalized = true; - } - BuilderStep::FinnalizeError(e) => { - 
req.builder.finnalize_error(e); - finnalized = true; - } - } - } - finnalized - } - Some(_) => todo!("error processing response"), - None => { - tracing::error!("received builder message, but there is no pending request"); - false - } - }; - - if finnalized { - *lock = None; - } - } -} - -#[async_trait::async_trait] -impl ConnectionHandler for ReplicaConnection { - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { - // we are currently handling a request on this connection - // self.conn.writer().current_req.timeout.poll() - let mut req = self.conn.writer().current_req.lock(); - let should_abort_query = match &mut *req { - Some(ref mut req) => match req.timeout.as_mut().poll(cx) { - Poll::Ready(_) => { - req.builder.finnalize_error("request timed out".to_string()); - true - } - Poll::Pending => return Poll::Pending, - }, - None => return Poll::Ready(()), - }; - - if should_abort_query { - *req = None - } - - Poll::Ready(()) - } - - async fn handle_exec(&mut self, exec: ExecFn) { - block_in_place(|| exec(&mut self.conn)); - let msg = { - let mut lock = self.conn.writer().inner.current_req.lock(); - match *lock { - Some(ref mut req) if req.id.is_none() => { - let program = req - .pgm - .take() - .expect("unsent request should have a program"); - let req_id = self.next_req_id; - self.next_req_id += 1; - req.id = Some(req_id); - - let msg = Outbound { - to: self.primary_node_id, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::ProxyRequest { - connection_id: self.connection_id, - req_id, - program, - }, - }, - }; - - Some(msg) - } - _ => None, - } - }; - - if let Some(msg) = msg { - self.dispatcher.dispatch(msg).await; - } - } - - async fn handle_inbound(&mut self, msg: Inbound) { - match msg.enveloppe.message { - Message::ProxyResponse(resp) => { - self.handle_proxy_response(resp); - } - _ => (), // ignore anything else - } - } -} - -#[async_trait::async_trait] -impl ConnectionHandler for Either -where - L: ConnectionHandler, - R: ConnectionHandler, -{ - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { - match self { - Either::Left(l) => l.poll_ready(cx), - Either::Right(r) => r.poll_ready(cx), - } - } - - async fn handle_exec(&mut self, exec: ExecFn) { - match self { - Either::Left(l) => l.handle_exec(exec).await, - Either::Right(r) => r.handle_exec(exec).await, - } - } - async fn handle_inbound(&mut self, msg: Inbound) { - match self { - Either::Left(l) => l.handle_inbound(msg).await, - Either::Right(r) => r.handle_inbound(msg).await, - } - } -} - pub struct Allocation { pub inbox: mpsc::Receiver, pub database: Database, @@ -874,7 +304,7 @@ impl Allocation { async fn handle_proxy( &mut self, - node_id: NodeId, + to: NodeId, connection_id: u32, req_id: u32, program: Program, @@ -884,15 +314,13 @@ impl Allocation { let exec = |conn: ConnectionHandle| async move { let _ = conn .exec(move |conn| { - let builder = ProxyResponseBuilder { + let builder = ProxyResponseBuilder::new( dispatcher, - req_id, - buffer: Vec::new(), - to: node_id, database_id, + to, + req_id, connection_id, - next_seq_no: 0, - }; + ); conn.execute_program(&program, Box::new(builder)).unwrap(); }) .await; @@ -901,14 +329,14 @@ impl Allocation { if self.database.is_primary() { match self .connections - .get(&node_id) + .get(&to) .and_then(|m| m.get(&connection_id).cloned()) { Some(handle) => { tokio::spawn(exec(handle)); } None => { - let handle = self.new_conn(Some((node_id, connection_id))).await; + let handle = self.new_conn(Some((to, connection_id))).await; 
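                    // First request from this remote connection: a local connection is
                    // created lazily, keyed by the remote (node id, connection id)
                    // pair, so follow-up proxied requests from the same remote
                    // connection are routed to it.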
tokio::spawn(exec(handle)); } } @@ -955,13 +383,6 @@ impl Allocation { } } -struct Connection { - id: (NodeId, u32), - conn: C, - exec: mpsc::Receiver, - inbound: mpsc::Receiver, -} - #[async_trait::async_trait] trait ConnectionHandler: Send { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>; @@ -969,6 +390,40 @@ trait ConnectionHandler: Send { async fn handle_inbound(&mut self, msg: Inbound); } +#[async_trait::async_trait] +impl ConnectionHandler for Either +where + L: ConnectionHandler, + R: ConnectionHandler, +{ + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { + match self { + Either::Left(l) => l.poll_ready(cx), + Either::Right(r) => r.poll_ready(cx), + } + } + + async fn handle_exec(&mut self, exec: ExecFn) { + match self { + Either::Left(l) => l.handle_exec(exec).await, + Either::Right(r) => r.handle_exec(exec).await, + } + } + async fn handle_inbound(&mut self, msg: Inbound) { + match self { + Either::Left(l) => l.handle_inbound(msg).await, + Either::Right(r) => r.handle_inbound(msg).await, + } + } +} + +struct Connection { + id: (NodeId, u32), + conn: C, + exec: mpsc::Receiver, + inbound: mpsc::Receiver, +} + impl Connection { async fn run(mut self) -> (NodeId, u32) { loop { @@ -991,8 +446,11 @@ impl Connection { #[cfg(test)] mod test { + use std::time::Duration; + use tokio::sync::Notify; + use crate::allocation::replica::ReplicaConnection; use crate::linc::bus::Bus; use super::*; diff --git a/libsqlx-server/src/allocation/primary.rs b/libsqlx-server/src/allocation/primary.rs new file mode 100644 index 00000000..15ac4dbd --- /dev/null +++ b/libsqlx-server/src/allocation/primary.rs @@ -0,0 +1,275 @@ +use std::collections::HashMap; +use std::mem::size_of; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use libsqlx::libsql::{LibsqlDatabase, PrimaryType}; +use libsqlx::result_builder::ResultBuilder; +use libsqlx::{FrameNo, LogReadError, ReplicationLogger}; +use tokio::task::block_in_place; + +use crate::linc::bus::Dispatch; +use crate::linc::proto::{BuilderStep, Enveloppe, Frames, Message, StepError, Value}; +use crate::linc::{Inbound, NodeId, Outbound}; +use crate::meta::DatabaseId; + +use super::{ConnectionHandler, ExecFn, FRAMES_MESSAGE_MAX_COUNT}; + +const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100kb + // +pub struct PrimaryDatabase { + pub db: LibsqlDatabase, + pub replica_streams: HashMap)>, + pub frame_notifier: tokio::sync::watch::Receiver, +} + +pub struct ProxyResponseBuilder { + dispatcher: Arc, + buffer: Vec, + database_id: DatabaseId, + to: NodeId, + req_id: u32, + connection_id: u32, + next_seq_no: u32, +} + +impl ProxyResponseBuilder { + pub fn new( + dispatcher: Arc, + database_id: DatabaseId, + to: NodeId, + req_id: u32, + connection_id: u32, + ) -> Self { + Self { + dispatcher, + buffer: Vec::new(), + database_id, + to, + req_id, + connection_id, + next_seq_no: 0, + } + } + + fn maybe_send(&mut self) { + // FIXME: this is stupid: compute current buffer size on the go instead + let size = self + .buffer + .iter() + .map(|s| match s { + BuilderStep::FinishStep(_, _) => 2 * 8, + BuilderStep::StepError(StepError(s)) => s.len(), + BuilderStep::ColsDesc(ref d) => d + .iter() + .map(|c| c.name.len() + c.decl_ty.as_ref().map(|t| t.len()).unwrap_or_default()) + .sum(), + BuilderStep::Finnalize { .. 
} => 9, + BuilderStep::AddRowValue(v) => match v { + crate::linc::proto::Value::Text(s) | crate::linc::proto::Value::Blob(s) => { + s.len() + } + _ => size_of::(), + }, + _ => 8, + }) + .sum::(); + + if size > MAX_STEP_BATCH_SIZE { + self.send() + } + } + + fn send(&mut self) { + let msg = Outbound { + to: self.to, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyResponse(crate::linc::proto::ProxyResponse { + connection_id: self.connection_id, + req_id: self.req_id, + row_steps: std::mem::take(&mut self.buffer), + seq_no: self.next_seq_no, + }), + }, + }; + + self.next_seq_no += 1; + tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg)); + } +} + +impl ResultBuilder for ProxyResponseBuilder { + fn init( + &mut self, + _config: &libsqlx::result_builder::QueryBuilderConfig, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::Init); + self.maybe_send(); + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginStep); + self.maybe_send(); + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishStep( + affected_row_count, + last_insert_rowid, + )); + self.maybe_send(); + Ok(()) + } + + fn step_error( + &mut self, + error: libsqlx::error::Error, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer + .push(BuilderStep::StepError(StepError(error.to_string()))); + self.maybe_send(); + Ok(()) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer + .push(BuilderStep::ColsDesc(cols.map(Into::into).collect())); + self.maybe_send(); + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginRows); + self.maybe_send(); + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginRow); + self.maybe_send(); + Ok(()) + } + + fn add_row_value( + &mut self, + v: libsqlx::result_builder::ValueRef, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::AddRowValue(v.into())); + self.maybe_send(); + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishRow); + self.maybe_send(); + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishRows); + self.maybe_send(); + Ok(()) + } + + fn finnalize( + &mut self, + is_txn: bool, + frame_no: Option, + ) -> Result { + self.buffer + .push(BuilderStep::Finnalize { is_txn, frame_no }); + self.send(); + Ok(true) + } +} + +pub struct FrameStreamer { + pub logger: Arc, + pub database_id: DatabaseId, + pub node_id: NodeId, + pub next_frame_no: FrameNo, + pub req_no: u32, + pub seq_no: u32, + pub dipatcher: Arc, + pub notifier: tokio::sync::watch::Receiver, + pub buffer: Vec, +} + +impl FrameStreamer { + pub async fn run(mut self) { + loop { + match block_in_place(|| self.logger.get_frame(self.next_frame_no)) { + Ok(frame) => { + if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT { + self.send_frames().await; + } + 
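                        // the buffer is flushed once it holds more than
                        // FRAMES_MESSAGE_MAX_COUNT frames; the frame just read is then
                        // queued and the cursor advances, so the log streams to the
                        // replica as a sequence of small Frames batches.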
self.buffer.push(frame.bytes());
+                    self.next_frame_no += 1;
+                }
+                Err(LogReadError::Ahead) => {
+                    tracing::debug!("frame {} not yet available", self.next_frame_no);
+                    if !self.buffer.is_empty() {
+                        self.send_frames().await;
+                    }
+                    if self
+                        .notifier
+                        .wait_for(|fno| *fno >= self.next_frame_no)
+                        .await
+                        .is_err()
+                    {
+                        break;
+                    }
+                }
+                Err(LogReadError::Error(_)) => todo!("handle log read error"),
+                Err(LogReadError::SnapshotRequired) => todo!("handle reading from snapshot"),
+            }
+        }
+    }
+
+    async fn send_frames(&mut self) {
+        let frames = std::mem::take(&mut self.buffer);
+        let outbound = Outbound {
+            to: self.node_id,
+            enveloppe: Enveloppe {
+                database_id: Some(self.database_id),
+                message: Message::Frames(Frames {
+                    req_no: self.req_no,
+                    seq_no: self.seq_no,
+                    frames,
+                }),
+            },
+        };
+        self.seq_no += 1;
+        self.dipatcher.dispatch(outbound).await;
+    }
+}
+
+pub struct PrimaryConnection {
+    pub conn: libsqlx::libsql::LibsqlConnection<PrimaryType>,
+}
+
+#[async_trait::async_trait]
+impl ConnectionHandler for PrimaryConnection {
+    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<()> {
+        Poll::Ready(())
+    }
+
+    async fn handle_exec(&mut self, exec: ExecFn) {
+        block_in_place(|| exec(&mut self.conn));
+    }
+
+    async fn handle_inbound(&mut self, _msg: Inbound) {
+        tracing::debug!("primary connection received message, ignoring.")
+    }
+}
diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs
new file mode 100644
index 00000000..297d27fb
--- /dev/null
+++ b/libsqlx-server/src/allocation/replica.rs
@@ -0,0 +1,342 @@
+use std::ops::Deref;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::time::Duration;
+
+use futures::Future;
+use libsqlx::libsql::{LibsqlConnection, LibsqlDatabase, ReplicaType};
+use libsqlx::program::Program;
+use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase};
+use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder};
+use libsqlx::{DescribeResponse, Frame, FrameNo, Injector};
+use parking_lot::Mutex;
+use tokio::{
+    sync::mpsc,
+    task::block_in_place,
+    time::{timeout, Sleep},
+};
+
+use crate::linc::proto::{BuilderStep, ProxyResponse};
+use crate::linc::Inbound;
+use crate::{
+    linc::{
+        bus::Dispatch,
+        proto::{Enveloppe, Frames, Message},
+        NodeId, Outbound,
+    },
+    meta::DatabaseId,
+};
+
+use super::{ConnectionHandler, ExecFn};
+
+type ProxyConnection = WriteProxyConnection<LibsqlConnection<ReplicaType>, RemoteConn>;
+pub type ProxyDatabase = WriteProxyDatabase<LibsqlDatabase<ReplicaType>, RemoteDb>;
+
+pub struct RemoteDb {
+    pub proxy_request_timeout_duration: Duration,
+}
+
+#[derive(Clone)]
+pub struct RemoteConn {
+    inner: Arc<RemoteConnInner>,
+}
+
+struct Request {
+    id: Option<u32>,
+    builder: Box<dyn ResultBuilder>,
+    pgm: Option<Program>,
+    next_seq_no: u32,
+    timeout: Pin<Box<Sleep>>,
+}
+
+pub struct RemoteConnInner {
+    current_req: Mutex<Option<Request>>,
+    request_timeout_duration: Duration,
+}
+
+impl Deref for RemoteConn {
+    type Target = RemoteConnInner;
+
+    fn deref(&self) -> &Self::Target {
+        self.inner.as_ref()
+    }
+}
+
+impl libsqlx::Connection for RemoteConn {
+    fn execute_program(
+        &mut self,
+        program: &libsqlx::program::Program,
+        builder: Box<dyn ResultBuilder>,
+    ) -> libsqlx::Result<()> {
+        // When we need to proxy a query, we place it in the current request slot. When we are
+        // back in an async context, we'll send it to the primary, and asynchronously drive the
+        // builder.
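        //
        // In short: execute_program only parks the request in the slot; the async side
        // (ReplicaConnection::handle_exec below) later picks it up, assigns it a
        // req_id, and dispatches a ProxyRequest to the primary, while the timeout
        // started here bounds how long we wait for the matching ProxyResponse.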
+ let mut lock = self.inner.current_req.lock(); + *lock = match *lock { + Some(_) => unreachable!("conccurent request on the same connection!"), + None => Some(Request { + id: None, + builder, + pgm: Some(program.clone()), + next_seq_no: 0, + timeout: Box::pin(tokio::time::sleep(self.inner.request_timeout_duration)), + }), + }; + + Ok(()) + } + + fn describe(&self, _sql: String) -> libsqlx::Result { + unreachable!("Describe request should not be proxied") + } +} + +impl libsqlx::Database for RemoteDb { + type Connection = RemoteConn; + + fn connect(&self) -> Result { + Ok(RemoteConn { + inner: Arc::new(RemoteConnInner { + current_req: Default::default(), + request_timeout_duration: self.proxy_request_timeout_duration, + }), + }) + } +} + +pub struct Replicator { + dispatcher: Arc, + req_id: u32, + next_frame_no: FrameNo, + next_seq: u32, + database_id: DatabaseId, + primary_node_id: NodeId, + injector: Box, + receiver: mpsc::Receiver, +} + +impl Replicator { + pub fn new( + dispatcher: Arc, + next_frame_no: FrameNo, + database_id: DatabaseId, + primary_node_id: NodeId, + injector: Box, + receiver: mpsc::Receiver, + ) -> Self { + Self { + dispatcher, + req_id: 0, + next_frame_no, + next_seq: 0, + database_id, + primary_node_id, + injector, + receiver, + } + } + + pub async fn run(mut self) { + self.query_replicate().await; + loop { + match timeout(Duration::from_secs(5), self.receiver.recv()).await { + Ok(Some(Frames { + req_no: req_id, + seq_no: seq, + frames, + })) => { + // ignore frames from a previous call to Replicate + if req_id != self.req_id { + tracing::debug!(req_id, self.req_id, "wrong req_id"); + continue; + } + if seq != self.next_seq { + // this is not the batch of frame we were expecting, drop what we have, and + // ask again from last checkpoint + tracing::debug!(seq, self.next_seq, "wrong seq"); + self.query_replicate().await; + continue; + }; + self.next_seq += 1; + + tracing::debug!("injecting {} frames", frames.len()); + + for bytes in frames { + let frame = Frame::try_from_bytes(bytes).unwrap(); + block_in_place(|| { + if let Some(last_committed) = self.injector.inject(frame).unwrap() { + tracing::debug!(last_committed); + self.next_frame_no = last_committed + 1; + } + }); + } + } + Err(_) => self.query_replicate().await, + Ok(None) => break, + } + } + } + + async fn query_replicate(&mut self) { + self.req_id += 1; + self.next_seq = 0; + // clear buffered, uncommitted frames + self.injector.clear(); + self.dispatcher + .dispatch(Outbound { + to: self.primary_node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::Replicate { + next_frame_no: self.next_frame_no, + req_no: self.req_id, + }, + }, + }) + .await; + } +} + +pub struct ReplicaConnection { + pub conn: ProxyConnection, + pub connection_id: u32, + pub next_req_id: u32, + pub primary_node_id: NodeId, + pub database_id: DatabaseId, + pub dispatcher: Arc, +} + +impl ReplicaConnection { + fn handle_proxy_response(&mut self, resp: ProxyResponse) { + let mut lock = self.conn.writer().inner.current_req.lock(); + let finnalized = match *lock { + Some(ref mut req) if req.id == Some(resp.req_id) && resp.seq_no == req.next_seq_no => { + self.next_req_id += 1; + // TODO: pass actual config + let config = QueryBuilderConfig { max_size: None }; + let mut finnalized = false; + for step in resp.row_steps.into_iter() { + if finnalized { + break; + }; + match step { + BuilderStep::Init => req.builder.init(&config).unwrap(), + BuilderStep::BeginStep => req.builder.begin_step().unwrap(), + 
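                        // each BuilderStep mirrors one ResultBuilder callback: the
                        // primary recorded the calls made against its own builder, and
                        // the replica replays them here, in order, against the pending
                        // request's builder (responses are matched on req_id and
                        // seq_no above).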
BuilderStep::FinishStep(affected_row_count, last_insert_rowid) => req + .builder + .finish_step(affected_row_count, last_insert_rowid) + .unwrap(), + BuilderStep::StepError(e) => req + .builder + .step_error(todo!("handle proxy step error")) + .unwrap(), + BuilderStep::ColsDesc(cols) => req + .builder + .cols_description(&mut &mut cols.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decl_ty.as_deref(), + })) + .unwrap(), + BuilderStep::BeginRows => req.builder.begin_rows().unwrap(), + BuilderStep::BeginRow => req.builder.begin_row().unwrap(), + BuilderStep::AddRowValue(v) => { + req.builder.add_row_value((&v).into()).unwrap() + } + BuilderStep::FinishRow => req.builder.finish_row().unwrap(), + BuilderStep::FinishRows => req.builder.finish_rows().unwrap(), + BuilderStep::Finnalize { is_txn, frame_no } => { + let _ = req.builder.finnalize(is_txn, frame_no).unwrap(); + finnalized = true; + } + BuilderStep::FinnalizeError(e) => { + req.builder.finnalize_error(e); + finnalized = true; + } + } + } + finnalized + } + Some(_) => todo!("error processing response"), + None => { + tracing::error!("received builder message, but there is no pending request"); + false + } + }; + + if finnalized { + *lock = None; + } + } +} + +#[async_trait::async_trait] +impl ConnectionHandler for ReplicaConnection { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // we are currently handling a request on this connection + // self.conn.writer().current_req.timeout.poll() + let mut req = self.conn.writer().current_req.lock(); + let should_abort_query = match &mut *req { + Some(ref mut req) => match req.timeout.as_mut().poll(cx) { + Poll::Ready(_) => { + req.builder.finnalize_error("request timed out".to_string()); + true + } + Poll::Pending => return Poll::Pending, + }, + None => return Poll::Ready(()), + }; + + if should_abort_query { + *req = None + } + + Poll::Ready(()) + } + + async fn handle_exec(&mut self, exec: ExecFn) { + block_in_place(|| exec(&mut self.conn)); + let msg = { + let mut lock = self.conn.writer().inner.current_req.lock(); + match *lock { + Some(ref mut req) if req.id.is_none() => { + let program = req + .pgm + .take() + .expect("unsent request should have a program"); + let req_id = self.next_req_id; + self.next_req_id += 1; + req.id = Some(req_id); + + let msg = Outbound { + to: self.primary_node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyRequest { + connection_id: self.connection_id, + req_id, + program, + }, + }, + }; + + Some(msg) + } + _ => None, + } + }; + + if let Some(msg) = msg { + self.dispatcher.dispatch(msg).await; + } + } + + async fn handle_inbound(&mut self, msg: Inbound) { + match msg.enveloppe.message { + Message::ProxyResponse(resp) => { + self.handle_proxy_response(resp); + } + _ => (), // ignore anything else + } + } +} From 8997c3f83442239f62af9036676d2a7a718745e4 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 20 Jul 2023 18:40:03 +0200 Subject: [PATCH 30/64] add compact method to LibsqlDatabase --- libsqlx-server/src/allocation/replica.rs | 34 ++++++------- libsqlx/src/database/libsql/mod.rs | 4 ++ .../database/libsql/replication_log/logger.rs | 49 ++++++++++++------- 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs index 297d27fb..441a3d40 100644 --- a/libsqlx-server/src/allocation/replica.rs +++ b/libsqlx-server/src/allocation/replica.rs @@ -1,7 +1,7 @@ use std::ops::Deref; use std::pin::Pin; use 
std::sync::Arc;
-use std::task::{Context, Poll};
+use std::task::{Context, Poll, ready};
 use std::time::Duration;

 use futures::Future;
@@ -11,22 +11,16 @@ use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase};
 use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder};
 use libsqlx::{DescribeResponse, Frame, FrameNo, Injector};
 use parking_lot::Mutex;
-use tokio::{
-    sync::mpsc,
-    task::block_in_place,
-    time::{timeout, Sleep},
-};
+use tokio::sync::mpsc;
+use tokio::task::block_in_place;
+use tokio::time::{timeout, Sleep};

+use crate::linc::bus::Dispatch;
 use crate::linc::proto::{BuilderStep, ProxyResponse};
+use crate::linc::proto::{Enveloppe, Frames, Message};
 use crate::linc::Inbound;
-use crate::{
-    linc::{
-        bus::Dispatch,
-        proto::{Enveloppe, Frames, Message},
-        NodeId, Outbound,
-    },
-    meta::DatabaseId,
-};
+use crate::linc::{NodeId, Outbound};
+use crate::meta::DatabaseId;

 use super::{ConnectionHandler, ExecFn};

@@ -227,7 +221,7 @@ impl ReplicaConnection {
                         .builder
                         .finish_step(affected_row_count, last_insert_rowid)
                         .unwrap(),
-                    BuilderStep::StepError(e) => req
+                    BuilderStep::StepError(_e) => req
                         .builder
                         .step_error(todo!("handle proxy step error"))
                         .unwrap(),
@@ -274,15 +268,15 @@ impl ReplicaConnection {
 impl ConnectionHandler for ReplicaConnection {
     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
         // we are currently handling a request on this connection
-        // self.conn.writer().current_req.timeout.poll()
         let mut req = self.conn.writer().current_req.lock();
         let should_abort_query = match &mut *req {
-            Some(ref mut req) => match req.timeout.as_mut().poll(cx) {
-                Poll::Ready(_) => {
+            Some(ref mut req) => {
+                ready!(req.timeout.as_mut().poll(cx));
+                // the request has timed out; we finalize the builder with an error and clear the
+                // current request.
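                // (ready! yields Poll::Pending while the timeout future is still
                // pending, so the connection accepts no new exec work until the
                // in-flight request completes or its deadline fires)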
req.builder.finnalize_error("request timed out".to_string()); true - } - Poll::Pending => return Poll::Pending, + }, None => return Poll::Ready(()), }; diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 382a4177..a6fdceb1 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -132,6 +132,10 @@ impl LibsqlDatabase { Ok(Self::new(db_path, ty)) } + pub fn compact_log(&self) { + self.ty.logger.compact(); + } + pub fn logger(&self) -> Arc { self.ty.logger.clone() } diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 7bcfb0bf..4f899a0d 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -143,11 +143,12 @@ unsafe impl WalHook for ReplicationLoggerHook { std::process::abort(); } - if let Err(e) = ctx.logger.log_file.write().maybe_compact( - &*ctx.logger.compactor, - ntruncate, - &ctx.logger.db_path, - ) { + if let Err(e) = ctx + .logger + .log_file + .write() + .maybe_compact(&*ctx.logger.compactor, &ctx.logger.db_path) + { tracing::error!("fatal error: {e}, exiting"); std::process::abort() } @@ -425,6 +426,10 @@ impl LogFile { } } + pub fn can_compact(&mut self) -> bool { + self.header.frame_count > 0 && self.uncommitted_frame_count == 0 + } + pub fn read_header(file: &File) -> crate::Result { let mut buf = [0; size_of::()]; file.read_exact_at(&mut buf, 0)?; @@ -563,27 +568,24 @@ impl LogFile { Ok(frame) } - fn maybe_compact( - &mut self, - compactor: &dyn LogCompactor, - size_after: u32, - path: &Path, - ) -> anyhow::Result<()> { - if compactor.should_compact(self) { - return self.do_compaction(compactor, size_after, path); + fn maybe_compact(&mut self, compactor: &dyn LogCompactor, path: &Path) -> anyhow::Result<()> { + if self.can_compact() && compactor.should_compact(self) { + return self.do_compaction(compactor, path); } Ok(()) } - fn do_compaction( - &mut self, - compactor: &dyn LogCompactor, - size_after: u32, - path: &Path, - ) -> anyhow::Result<()> { + fn do_compaction(&mut self, compactor: &dyn LogCompactor, path: &Path) -> anyhow::Result<()> { tracing::info!("performing log compaction"); let temp_log_path = path.join("temp_log"); + let last_frame = self + .rev_frames_iter()? 
+            .next()
+            .expect("there should be at least one frame to perform compaction")?;
+        let size_after = last_frame.header().size_after;
+        assert!(size_after != 0);
+
         let file = OpenOptions::new()
             .read(true)
             .write(true)
@@ -916,6 +918,15 @@ impl ReplicationLogger {
     pub fn get_frame(&self, frame_no: FrameNo) -> Result<Frame, LogReadError> {
         self.log_file.read().frame(frame_no)
     }
+
+    pub fn compact(&self) {
+        let mut log_file = self.log_file.write();
+        if log_file.can_compact() {
+            log_file
+                .do_compaction(&*self.compactor, &self.db_path)
+                .unwrap();
+        }
+    }
 }
 
 fn checkpoint_db(data_path: &Path) -> crate::Result<()> {
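A sketch, not part of the patch, of the new compaction entry point added above. `db_path` and `compactor` are placeholders, and the boolean flag and frame-notifier callback mirror the `new_primary` call in `from_config`:

    // Illustrative only: force a compaction outside the WAL hook.
    // `compact_log` forwards to `ReplicationLogger::compact`, which re-checks
    // `can_compact` under the log-file write lock, so calling it while the
    // log is empty or a transaction is in flight is a safe no-op.
    let db = LibsqlDatabase::new_primary(
        db_path,
        compactor,              // any `LogCompactor` implementation
        false,
        Box::new(|_fno| ()),    // frame-notifier callback, a no-op here
    )?;
    db.compact_log();           // runs `do_compaction` synchronously if allowed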
From 2b3316ebaa7e106473832991bea1908041b8abdf Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Fri, 21 Jul 2023 10:21:53 +0200
Subject: [PATCH 31/64] implement should_compact

---
 libsqlx-server/src/allocation/mod.rs          | 28 +++++++++++++++----
 .../database/libsql/replication_log/logger.rs | 23 +++++++++------
 2 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs
index 8c7ba873..836bbb50 100644
--- a/libsqlx-server/src/allocation/mod.rs
+++ b/libsqlx-server/src/allocation/mod.rs
@@ -4,7 +4,7 @@ use std::future::poll_fn;
 use std::path::PathBuf;
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use std::time::Instant;
+use std::time::{Instant, Duration};
 
 use either::Either;
 use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile};
@@ -55,19 +55,31 @@ pub enum Database {
     },
 }
 
-struct Compactor;
+struct Compactor {
+    max_log_size: usize,
+    last_compacted_at: Instant,
+    compact_interval: Option<Duration>,
+}
 
 impl LogCompactor for Compactor {
-    fn should_compact(&self, _log: &LogFile) -> bool {
-        false
+    fn should_compact(&self, log: &LogFile) -> bool {
+        let mut should_compact = false;
+        if let Some(compact_interval) = self.compact_interval {
+            should_compact |= self.last_compacted_at.elapsed() >= compact_interval
+        }
+
+        should_compact |= log.size() >= self.max_log_size;
+
+        should_compact
     }
 
     fn compact(
-        &self,
+        &mut self,
         _log: LogFile,
         _path: std::path::PathBuf,
         _size_after: u32,
     ) -> Result<(), Box<dyn std::error::Error + Sync + Send + 'static>> {
+        self.last_compacted_at = Instant::now();
         todo!()
     }
 }
@@ -79,7 +91,11 @@ impl Database {
             let (sender, receiver) = tokio::sync::watch::channel(0);
             let db = LibsqlDatabase::new_primary(
                 path,
-                Compactor,
+                Compactor {
+                    max_log_size: usize::MAX,
+                    last_compacted_at: Instant::now(),
+                    compact_interval: None,
+                },
                 false,
                 Box::new(move |fno| {
                     let _ = sender.send(fno);
diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs
index 4f899a0d..cbb9acca 100644
--- a/libsqlx/src/database/libsql/replication_log/logger.rs
+++ b/libsqlx/src/database/libsql/replication_log/logger.rs
@@ -9,7 +9,7 @@ use std::sync::Arc;
 use anyhow::{bail, ensure};
 use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable};
 use bytes::{Bytes, BytesMut};
-use parking_lot::RwLock;
+use parking_lot::{RwLock, Mutex};
 use rusqlite::ffi::{
     libsql_wal as Wal, sqlite3, PgHdr, SQLITE_CHECKPOINT_TRUNCATE, SQLITE_IOERR, SQLITE_OK,
 };
@@ -147,7 +147,7 @@ unsafe impl WalHook for ReplicationLoggerHook {
             .logger
             .log_file
             .write()
-            .maybe_compact(&*ctx.logger.compactor, &ctx.logger.db_path)
+            .maybe_compact(&mut *ctx.logger.compactor.lock(), &ctx.logger.db_path)
         {
             tracing::error!("fatal error: {e}, exiting");
             std::process::abort()
@@ -568,7 +568,7 @@ impl LogFile {
         Ok(frame)
     }
 
-    fn maybe_compact(&mut self, compactor: &dyn LogCompactor, path: &Path) -> anyhow::Result<()> {
+    fn maybe_compact(&mut self, compactor: &mut dyn LogCompactor, path: &Path) -> anyhow::Result<()> {
         if self.can_compact() && compactor.should_compact(self) {
             return self.do_compaction(compactor, path);
         }
@@ -576,7 +576,7 @@ impl LogFile {
         Ok(())
     }
 
-    fn do_compaction(&mut self, compactor: &dyn LogCompactor, path: &Path) -> anyhow::Result<()> {
+    fn do_compaction(&mut self, compactor: &mut dyn LogCompactor, path: &Path) -> anyhow::Result<()> {
         tracing::info!("performing log compaction");
         let temp_log_path = path.join("temp_log");
         let last_frame = self
@@ -631,6 +631,11 @@ impl LogFile {
         self.file.set_len(0)?;
         Self::new(self.file)
     }
+
+    /// return the size in bytes of the log
+    pub fn size(&self) -> usize {
+        size_of::<LogFileHeader>() + Frame::SIZE * self.header().frame_count as usize
+    }
 }
 
 #[cfg(target_os = "macos")]
@@ -730,7 +735,7 @@ pub trait LogCompactor: Sync + Send + 'static {
     fn should_compact(&self, log: &LogFile) -> bool;
     /// Compact the given log
     fn compact(
-        &self,
+        &mut self,
         log: LogFile,
         path: PathBuf,
         size_after: u32,
@@ -740,7 +745,7 @@
 #[cfg(test)]
 impl LogCompactor for () {
     fn compact(
-        &self,
+        &mut self,
         _file: LogFile,
         _path: PathBuf,
         _size_after: u32,
@@ -758,7 +763,7 @@ pub type FrameNotifierCb = Box<dyn Fn(FrameNo) + Send + Sync + 'static>;
 
 pub struct ReplicationLogger {
     pub generation: Generation,
     pub log_file: RwLock<LogFile>,
-    compactor: Box<dyn LogCompactor>,
+    compactor: Box<Mutex<dyn LogCompactor>>,
     db_path: PathBuf,
     /// a notifier channel other tasks can subscribe to, and get notified when new frames become
     /// available.
@@ -822,7 +827,7 @@ impl ReplicationLogger {
 
         Ok(Self {
             generation: Generation::new(generation_start_frame_no),
-            compactor: Box::new(compactor),
+            compactor: Box::new(Mutex::new(compactor)),
             log_file: RwLock::new(log_file),
             db_path,
             new_frame_notifier,
@@ -923,7 +928,7 @@ impl ReplicationLogger {
         let mut log_file = self.log_file.write();
         if log_file.can_compact() {
             log_file
-                .do_compaction(&*self.compactor, &self.db_path)
+                .do_compaction(&mut *self.compactor.lock(), &self.db_path)
                 .unwrap();
         }
     }
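The trigger semantics introduced here, illustrated with made-up values (`Compactor` fields as defined in the patch above; `log.size()` is the new helper, header size plus frame count times frame size):

    // Compact when the log outgrows 128 MiB, or once an hour has elapsed
    // since the last compaction, whichever comes first.
    let compactor = Compactor {
        max_log_size: 128 * 1024 * 1024,
        last_compacted_at: Instant::now(),
        compact_interval: Some(Duration::from_secs(3600)),
    };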
From b9cb88ec02e1974211ffeb2a0bf53dbbab9e92ae Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Fri, 21 Jul 2023 10:39:29 +0200
Subject: [PATCH 32/64] periodic compaction

---
 libsqlx-server/src/allocation/mod.rs     | 58 +++++++++++++++++-------
 libsqlx-server/src/allocation/primary.rs |  2 +-
 2 files changed, 43 insertions(+), 17 deletions(-)

diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs
index 836bbb50..d71f128c 100644
--- a/libsqlx-server/src/allocation/mod.rs
+++ b/libsqlx-server/src/allocation/mod.rs
@@ -2,8 +2,9 @@ use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::future::poll_fn;
 use std::path::PathBuf;
+use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
+use std::task::{Context, Poll, ready};
 use std::time::{Instant, Duration};
 
 use either::Either;
@@ -13,6 +14,7 @@ use libsqlx::proxy::WriteProxyDatabase;
 use libsqlx::{Database as _, InjectableDatabase};
 use tokio::sync::{mpsc, oneshot};
 use tokio::task::{block_in_place, JoinSet};
+use tokio::time::Interval;
 
 use crate::allocation::primary::FrameStreamer;
 use crate::hrana;
@@ -46,7 +48,10 @@ pub enum AllocationMessage {
 }
 
 pub enum Database {
-    Primary(PrimaryDatabase),
+    Primary {
+        db: PrimaryDatabase,
+        compact_interval: Option<Pin<Box<Interval>>>,
+    },
     Replica {
         db: ProxyDatabase,
         injector_handle: mpsc::Sender<Frames>,
         primary_id: NodeId,
         last_received_frame_ts: Option<Instant>,
     },
 }
 
+impl Database {
+    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> {
+        if let Self::Primary { compact_interval: Some(ref mut interval), db } = self
+        {
+            ready!(interval.poll_tick(cx));
+            let db = db.db.clone();
+            tokio::task::spawn_blocking(move || {
+                db.compact_log();
+            });
+        }
+
+        Poll::Pending
+    }
+}
+
 struct Compactor {
     max_log_size: usize,
     last_compacted_at: Instant,
@@ -103,11 +122,14 @@ impl Database {
             )
             .unwrap();
 
-            Self::Primary(PrimaryDatabase {
-                db,
-                replica_streams: HashMap::new(),
-                frame_notifier: receiver,
-            })
+            Self::Primary {
+                db: PrimaryDatabase {
+                    db: Arc::new(db),
+                    replica_streams: HashMap::new(),
+                    frame_notifier: receiver,
+                },
+                compact_interval: None,
+            }
         }
         DbConfig::Replica {
             primary_node_id,
@@ -145,7 +167,7 @@ impl Database {
 
     fn connect(&self, connection_id: u32, alloc: &Allocation) -> impl ConnectionHandler {
         match self {
-            Database::Primary(PrimaryDatabase { db, .. }) => Either::Right(PrimaryConnection {
+            Database::Primary { db: PrimaryDatabase { db, .. }, .. } => Either::Right(PrimaryConnection {
                 conn: db.connect().unwrap(),
             }),
             Database::Replica { db, primary_id, .. } => Either::Left(ReplicaConnection {
@@ -160,7 +182,7 @@
     }
 
     pub fn is_primary(&self) -> bool {
-        matches!(self, Self::Primary(..))
+        matches!(self, Self::Primary { .. })
     }
 }
 
@@ -206,7 +228,9 @@ impl ConnectionHandle {
 impl Allocation {
     pub async fn run(mut self) {
         loop {
+            let fut = poll_fn(|cx| self.database.poll(cx));
             tokio::select! {
+                _ = fut => (),
                 Some(msg) = self.inbox.recv() => {
                     match msg {
                         AllocationMessage::HranaPipelineReq { req, ret } => {
@@ -245,12 +269,14 @@
                 req_no,
                 next_frame_no,
             } => match &mut self.database {
-                Database::Primary(PrimaryDatabase {
-                    db,
-                    replica_streams,
-                    frame_notifier,
-                    ..
-                }) => {
+                Database::Primary {
+                    db: PrimaryDatabase {
+                        db,
+                        replica_streams,
+                        frame_notifier,
+                        ..
+                    },
+                    ..
+                } => {
                     let streamer = FrameStreamer {
                         logger: db.logger(),
                         database_id: DatabaseId::from_name(&self.db_name),
@@ -293,7 +319,7 @@
                 *last_received_frame_ts = Some(Instant::now());
                 injector_handle.send(frames).await.unwrap();
             }
-            Database::Primary(PrimaryDatabase { .. }) => todo!("handle primary receiving txn"),
+            Database::Primary { db: PrimaryDatabase { .. }, ..
} => todo!("handle primary receiving txn"), }, Message::ProxyRequest { connection_id, diff --git a/libsqlx-server/src/allocation/primary.rs b/libsqlx-server/src/allocation/primary.rs index 15ac4dbd..f66ed4dd 100644 --- a/libsqlx-server/src/allocation/primary.rs +++ b/libsqlx-server/src/allocation/primary.rs @@ -19,7 +19,7 @@ use super::{ConnectionHandler, ExecFn, FRAMES_MESSAGE_MAX_COUNT}; const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100kb // pub struct PrimaryDatabase { - pub db: LibsqlDatabase, + pub db: Arc>, pub replica_streams: HashMap)>, pub frame_notifier: tokio::sync::watch::Receiver, } From 3d8c5e70b1a755bf5ef17accab5b9d92b6277c5a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 21 Jul 2023 11:45:37 +0200 Subject: [PATCH 33/64] add log compaction config to primary config --- Cargo.lock | 4 + libsqlx-server/Cargo.toml | 1 + libsqlx-server/src/allocation/config.rs | 7 +- libsqlx-server/src/allocation/mod.rs | 54 ++++++---- libsqlx-server/src/allocation/replica.rs | 13 ++- libsqlx-server/src/http/admin.rs | 98 ++++++++++++------- .../database/libsql/replication_log/logger.rs | 14 ++- 7 files changed, 128 insertions(+), 63 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 23ee18ef..d369d037 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -900,6 +900,9 @@ name = "bytesize" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38fcc2979eff34a4b84e1cf9a1e3da42a7d44b3b690a40cdcb23e3d556cfb2e5" +dependencies = [ + "serde", +] [[package]] name = "camino" @@ -2549,6 +2552,7 @@ dependencies = [ "base64 0.21.2", "bincode", "bytes 1.4.0", + "bytesize", "clap", "color-eyre", "either", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 42a508c6..90b27680 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -12,6 +12,7 @@ axum = "0.6.18" base64 = "0.21.2" bincode = "1.3.3" bytes = { version = "1.4.0", features = ["serde"] } +bytesize = { version = "1.2.0", features = ["serde"] } clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" either = "1.8.1" diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs index ac21efa7..f0c13870 100644 --- a/libsqlx-server/src/allocation/config.rs +++ b/libsqlx-server/src/allocation/config.rs @@ -20,7 +20,12 @@ pub struct AllocConfig { #[derive(Debug, Serialize, Deserialize)] pub enum DbConfig { - Primary {}, + Primary { + /// maximum size the replication log is allowed to grow, before it is compacted. 
+ max_log_size: usize, + /// Interval at which to force compaction + replication_log_compact_interval: Option, + }, Replica { primary_node_id: NodeId, proxy_request_timeout_duration: Duration, diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index d71f128c..33bde27d 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -4,8 +4,8 @@ use std::future::poll_fn; use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll, ready}; -use std::time::{Instant, Duration}; +use std::task::{ready, Context, Poll}; +use std::time::{Duration, Instant}; use either::Either; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile}; @@ -36,6 +36,7 @@ mod replica; /// the maximum number of frame a Frame messahe is allowed to contain const FRAMES_MESSAGE_MAX_COUNT: usize = 5; const MAX_INJECTOR_BUFFER_CAP: usize = 32; +const DEFAULT_MAX_LOG_SIZE: usize = 100 * 1024 * 1024; // 100Mb type ExecFn = Box; @@ -62,7 +63,11 @@ pub enum Database { impl Database { fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> { - if let Self::Primary { compact_interval: Some(ref mut interval), db } = self { + if let Self::Primary { + compact_interval: Some(ref mut interval), + db, + } = self + { ready!(interval.poll_tick(cx)); let db = db.db.clone(); tokio::task::spawn_blocking(move || { @@ -83,10 +88,10 @@ struct Compactor { impl LogCompactor for Compactor { fn should_compact(&self, log: &LogFile) -> bool { let mut should_compact = false; - if let Some(compact_interval)= self.compact_interval { + if let Some(compact_interval) = self.compact_interval { should_compact |= self.last_compacted_at.elapsed() >= compact_interval } - + should_compact |= log.size() >= self.max_log_size; should_compact @@ -106,14 +111,17 @@ impl LogCompactor for Compactor { impl Database { pub fn from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc) -> Self { match config.db_config { - DbConfig::Primary {} => { + DbConfig::Primary { + max_log_size, + replication_log_compact_interval, + } => { let (sender, receiver) = tokio::sync::watch::channel(0); let db = LibsqlDatabase::new_primary( path, Compactor { - max_log_size: usize::MAX, last_compacted_at: Instant::now(), - compact_interval: None, + max_log_size, + compact_interval: replication_log_compact_interval, }, false, Box::new(move |fno| { @@ -122,12 +130,12 @@ impl Database { ) .unwrap(); - Self::Primary{ + Self::Primary { db: PrimaryDatabase { db: Arc::new(db), replica_streams: HashMap::new(), frame_notifier: receiver, - } , + }, compact_interval: None, } } @@ -167,7 +175,10 @@ impl Database { fn connect(&self, connection_id: u32, alloc: &Allocation) -> impl ConnectionHandler { match self { - Database::Primary { db: PrimaryDatabase { db, .. }, .. } => Either::Right(PrimaryConnection { + Database::Primary { + db: PrimaryDatabase { db, .. }, + .. + } => Either::Right(PrimaryConnection { conn: db.connect().unwrap(), }), Database::Replica { db, primary_id, .. } => Either::Left(ReplicaConnection { @@ -269,13 +280,15 @@ impl Allocation { req_no, next_frame_no, } => match &mut self.database { - Database::Primary{ - db: PrimaryDatabase { - db, - replica_streams, - frame_notifier, - .. - }, .. + Database::Primary { + db: + PrimaryDatabase { + db, + replica_streams, + frame_notifier, + .. + }, + .. 
} => { let streamer = FrameStreamer { logger: db.logger(), @@ -319,7 +332,10 @@ impl Allocation { *last_received_frame_ts = Some(Instant::now()); injector_handle.send(frames).await.unwrap(); } - Database::Primary { db: PrimaryDatabase { .. }, .. } => todo!("handle primary receiving txn"), + Database::Primary { + db: PrimaryDatabase { .. }, + .. + } => todo!("handle primary receiving txn"), }, Message::ProxyRequest { connection_id, diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs index 441a3d40..ee8008d6 100644 --- a/libsqlx-server/src/allocation/replica.rs +++ b/libsqlx-server/src/allocation/replica.rs @@ -1,7 +1,7 @@ use std::ops::Deref; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll, ready}; +use std::task::{ready, Context, Poll}; use std::time::Duration; use futures::Future; @@ -272,12 +272,11 @@ impl ConnectionHandler for ReplicaConnection { let should_abort_query = match &mut *req { Some(ref mut req) => { ready!(req.timeout.as_mut().poll(cx)); - // the request has timedout, we finalize the builder with a error, and clean the - // current request. - req.builder.finnalize_error("request timed out".to_string()); - true - - }, + // the request has timedout, we finalize the builder with a error, and clean the + // current request. + req.builder.finnalize_error("request timed out".to_string()); + true + } None => return Poll::Ready(()), }; diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 0e263ddf..4dd27944 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -1,3 +1,4 @@ +use std::ops::Deref; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; @@ -57,46 +58,69 @@ struct AllocateReq { #[derive(Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] -pub enum DbConfigReq { - Primary {}, - Replica { - primary_node_id: NodeId, - #[serde( - deserialize_with = "deserialize_duration", - default = "default_proxy_timeout" - )] - proxy_request_timeout_duration: Duration, - }, +pub struct Primary { + /// The maximum size the replication is allowed to grow. Expects a string like 200mb. 
+ #[serde(default = "default_max_log_size")] + pub max_replication_log_size: bytesize::ByteSize, + pub replication_log_compact_interval: Option, } -const fn default_proxy_timeout() -> Duration { - Duration::from_secs(5) +#[derive(Debug)] +pub struct HumanDuration(Duration); + +impl Deref for HumanDuration { + type Target = Duration; + + fn deref(&self) -> &Self::Target { + &self.0 + } } -fn deserialize_duration<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - struct Visitor; - impl serde::de::Visitor<'_> for Visitor { - type Value = Duration; - - fn visit_str(self, v: &str) -> std::result::Result - where - E: serde::de::Error, - { - match humantime::Duration::from_str(v) { - Ok(d) => Ok(*d), - Err(e) => Err(E::custom(e.to_string())), +impl<'de> Deserialize<'de> for HumanDuration { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + struct DurationVisitor; + impl serde::de::Visitor<'_> for DurationVisitor { + type Value = HumanDuration; + + fn visit_str(self, v: &str) -> std::result::Result + where + E: serde::de::Error, + { + match humantime::Duration::from_str(v) { + Ok(d) => Ok(HumanDuration(*d)), + Err(e) => Err(E::custom(e.to_string())), + } } - } - fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.write_str("a duration, in a string format") + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str("a duration, in a string format") + } } + + deserializer.deserialize_str(DurationVisitor) } +} - deserializer.deserialize_str(Visitor) +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum DbConfigReq { + Primary(Primary), + Replica { + primary_node_id: NodeId, + #[serde(default = "default_proxy_timeout")] + proxy_request_timeout_duration: HumanDuration, + }, +} + +const fn default_max_log_size() -> bytesize::ByteSize { + bytesize::ByteSize::mb(100) +} + +const fn default_proxy_timeout() -> HumanDuration { + HumanDuration(Duration::from_secs(5)) } async fn allocate( @@ -107,13 +131,21 @@ async fn allocate( max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16), db_name: req.database_name.clone(), db_config: match req.config { - DbConfigReq::Primary {} => DbConfig::Primary {}, + DbConfigReq::Primary(Primary { + max_replication_log_size, + replication_log_compact_interval, + }) => DbConfig::Primary { + max_log_size: max_replication_log_size.as_u64() as usize, + replication_log_compact_interval: replication_log_compact_interval + .as_deref() + .copied(), + }, DbConfigReq::Replica { primary_node_id, proxy_request_timeout_duration, } => DbConfig::Replica { primary_node_id, - proxy_request_timeout_duration, + proxy_request_timeout_duration: *proxy_request_timeout_duration, }, }, }; diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index cbb9acca..48d3916e 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use anyhow::{bail, ensure}; use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable}; use bytes::{Bytes, BytesMut}; -use parking_lot::{RwLock, Mutex}; +use parking_lot::{Mutex, RwLock}; use rusqlite::ffi::{ libsql_wal as Wal, sqlite3, PgHdr, SQLITE_CHECKPOINT_TRUNCATE, SQLITE_IOERR, SQLITE_OK, }; @@ -568,7 +568,11 @@ impl LogFile { Ok(frame) } - fn maybe_compact(&mut self, compactor: &mut dyn LogCompactor, path: &Path) -> anyhow::Result<()> { 
+ fn maybe_compact( + &mut self, + compactor: &mut dyn LogCompactor, + path: &Path, + ) -> anyhow::Result<()> { if self.can_compact() && compactor.should_compact(self) { return self.do_compaction(compactor, path); } @@ -576,7 +580,11 @@ impl LogFile { Ok(()) } - fn do_compaction(&mut self, compactor: &mut dyn LogCompactor, path: &Path) -> anyhow::Result<()> { + fn do_compaction( + &mut self, + compactor: &mut dyn LogCompactor, + path: &Path, + ) -> anyhow::Result<()> { tracing::info!("performing log compaction"); let temp_log_path = path.join("temp_log"); let last_frame = self From 773bd730601704e7d2c20c25b58f49ba8c97d11e Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 21 Jul 2023 13:11:52 +0200 Subject: [PATCH 34/64] move Compactor to own file --- libsqlx-server/src/allocation/mod.rs | 49 ++++--------------- .../src/allocation/primary/compactor.rs | 42 ++++++++++++++++ .../allocation/{primary.rs => primary/mod.rs} | 2 + 3 files changed, 53 insertions(+), 40 deletions(-) create mode 100644 libsqlx-server/src/allocation/primary/compactor.rs rename libsqlx-server/src/allocation/{primary.rs => primary/mod.rs} (99%) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 33bde27d..fbc4049f 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -5,10 +5,10 @@ use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; -use std::time::{Duration, Instant}; +use std::time::Instant; use either::Either; -use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile}; +use libsqlx::libsql::LibsqlDatabase; use libsqlx::program::Program; use libsqlx::proxy::WriteProxyDatabase; use libsqlx::{Database as _, InjectableDatabase}; @@ -26,6 +26,7 @@ use crate::linc::{Inbound, NodeId}; use crate::meta::DatabaseId; use self::config::{AllocConfig, DbConfig}; +use self::primary::compactor::Compactor; use self::primary::{PrimaryConnection, PrimaryDatabase, ProxyResponseBuilder}; use self::replica::{ProxyDatabase, RemoteDb, ReplicaConnection, Replicator}; @@ -33,10 +34,10 @@ pub mod config; mod primary; mod replica; -/// the maximum number of frame a Frame messahe is allowed to contain +/// Maximum number of frame a Frame message is allowed to contain const FRAMES_MESSAGE_MAX_COUNT: usize = 5; -const MAX_INJECTOR_BUFFER_CAP: usize = 32; -const DEFAULT_MAX_LOG_SIZE: usize = 100 * 1024 * 1024; // 100Mb +/// Maximum number of frames in the injector buffer +const MAX_INJECTOR_BUFFER_CAPACITY: usize = 32; type ExecFn = Box; @@ -79,35 +80,6 @@ impl Database { } } -struct Compactor { - max_log_size: usize, - last_compacted_at: Instant, - compact_interval: Option, -} - -impl LogCompactor for Compactor { - fn should_compact(&self, log: &LogFile) -> bool { - let mut should_compact = false; - if let Some(compact_interval) = self.compact_interval { - should_compact |= self.last_compacted_at.elapsed() >= compact_interval - } - - should_compact |= log.size() >= self.max_log_size; - - should_compact - } - - fn compact( - &mut self, - _log: LogFile, - _path: std::path::PathBuf, - _size_after: u32, - ) -> Result<(), Box> { - self.last_compacted_at = Instant::now(); - todo!() - } -} - impl Database { pub fn from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc) -> Self { match config.db_config { @@ -118,11 +90,7 @@ impl Database { let (sender, receiver) = tokio::sync::watch::channel(0); let db = LibsqlDatabase::new_primary( path, - Compactor { - last_compacted_at: Instant::now(), - max_log_size, - 
compact_interval: replication_log_compact_interval, - }, + Compactor::new(max_log_size, replication_log_compact_interval), false, Box::new(move |fno| { let _ = sender.send(fno); @@ -143,7 +111,8 @@ impl Database { primary_node_id, proxy_request_timeout_duration, } => { - let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap(); + let rdb = + LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAPACITY, ()).unwrap(); let wdb = RemoteDb { proxy_request_timeout_duration, }; diff --git a/libsqlx-server/src/allocation/primary/compactor.rs b/libsqlx-server/src/allocation/primary/compactor.rs new file mode 100644 index 00000000..62b9c0ed --- /dev/null +++ b/libsqlx-server/src/allocation/primary/compactor.rs @@ -0,0 +1,42 @@ +use std::time::{Duration, Instant}; + +use libsqlx::libsql::{LogCompactor, LogFile}; + +pub struct Compactor { + max_log_size: usize, + last_compacted_at: Instant, + compact_interval: Option, +} + +impl Compactor { + pub fn new(max_log_size: usize, compact_interval: Option) -> Self { + Self { + max_log_size, + last_compacted_at: Instant::now(), + compact_interval, + } + } +} + +impl LogCompactor for Compactor { + fn should_compact(&self, log: &LogFile) -> bool { + let mut should_compact = false; + if let Some(compact_interval) = self.compact_interval { + should_compact |= self.last_compacted_at.elapsed() >= compact_interval + } + + should_compact |= log.size() >= self.max_log_size; + + should_compact + } + + fn compact( + &mut self, + _log: LogFile, + _path: std::path::PathBuf, + _size_after: u32, + ) -> Result<(), Box> { + self.last_compacted_at = Instant::now(); + todo!() + } +} diff --git a/libsqlx-server/src/allocation/primary.rs b/libsqlx-server/src/allocation/primary/mod.rs similarity index 99% rename from libsqlx-server/src/allocation/primary.rs rename to libsqlx-server/src/allocation/primary/mod.rs index f66ed4dd..480a0e6a 100644 --- a/libsqlx-server/src/allocation/primary.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -16,6 +16,8 @@ use crate::meta::DatabaseId; use super::{ConnectionHandler, ExecFn, FRAMES_MESSAGE_MAX_COUNT}; +pub mod compactor; + const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100kb // pub struct PrimaryDatabase { From 74a39e8909f31591a10222043ef3ca640cddc57b Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 24 Jul 2023 11:38:35 +0200 Subject: [PATCH 35/64] migrate to lmdb for meta storage --- Cargo.lock | 96 ++++++++++++++++++++++++++++++++ libsqlx-server/Cargo.toml | 3 + libsqlx-server/src/http/admin.rs | 1 - libsqlx-server/src/main.rs | 9 ++- libsqlx-server/src/manager.rs | 6 +- libsqlx-server/src/meta.rs | 81 ++++++++++++++++----------- 6 files changed, 159 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d369d037..7b525e06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1534,6 +1534,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "doxygen-rs" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bff670ea0c9bbb8414e7efa6e23ebde2b8f520a7eef78273a3918cf1903e7505" +dependencies = [ + "phf", +] + [[package]] name = "either" version = "1.8.1" @@ -2028,6 +2037,45 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "heed" +version = "0.20.0-alpha.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ae92e0d788e4b5608cddd89bec1128ad9f424e365025bd2454915aab450d7a2" +dependencies = [ 
+ "bytemuck", + "byteorder", + "heed-traits", + "heed-types", + "libc", + "lmdb-master-sys", + "once_cell", + "page_size", + "serde", + "synchronoise", + "url", +] + +[[package]] +name = "heed-traits" +version = "0.20.0-alpha.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44055e6d049fb62b58671059045fe4a8a083d78ef04347818cc8a87a62d6fa1f" + +[[package]] +name = "heed-types" +version = "0.20.0-alpha.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d419e64e429f0bbe6d8ef3507bf137105c8cebee0e5fba59cf556de93c8ab57" +dependencies = [ + "bincode", + "bytemuck", + "byteorder", + "heed-traits", + "serde", + "serde_json", +] + [[package]] name = "hermit-abi" version = "0.3.2" @@ -2551,12 +2599,15 @@ dependencies = [ "axum", "base64 0.21.2", "bincode", + "bytemuck", "bytes 1.4.0", "bytesize", "clap", "color-eyre", "either", "futures", + "heed", + "heed-types", "hmac", "humantime", "hyper", @@ -2602,6 +2653,18 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" +[[package]] +name = "lmdb-master-sys" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "629c123f5321b48fa4f8f4d3b868165b748d9ba79c7103fb58e3a94f736bcedd" +dependencies = [ + "cc", + "doxygen-rs", + "libc", + "pkg-config", +] + [[package]] name = "lock_api" version = "0.4.10" @@ -3022,6 +3085,16 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" +[[package]] +name = "page_size" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b7663cbd190cfd818d08efa8497f6cd383076688c49a391ef7c0d03cd12b561" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "parking" version = "2.1.0" @@ -3119,6 +3192,7 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" dependencies = [ + "phf_macros", "phf_shared", ] @@ -3142,6 +3216,19 @@ dependencies = [ "rand", ] +[[package]] +name = "phf_macros" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", + "syn 2.0.25", +] + [[package]] name = "phf_shared" version = "0.11.2" @@ -4323,6 +4410,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "synchronoise" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dbc01390fc626ce8d1cffe3376ded2b72a11bb70e1c75f404a210e4daa4def2" +dependencies = [ + "crossbeam-queue", +] + [[package]] name = "system-interface" version = "0.25.9" diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 90b27680..f4eabace 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -11,12 +11,15 @@ async-trait = "0.1.71" axum = "0.6.18" base64 = "0.21.2" bincode = "1.3.3" +bytemuck = { version = "1.13.1", features = ["derive"] } bytes = { version = "1.4.0", features = ["serde"] } bytesize = { version = "1.2.0", features = ["serde"] } clap = { version = "4.3.11", features = ["derive"] } color-eyre = 
"0.6.2" either = "1.8.1" futures = "0.3.28" +heed = { version = "0.20.0-alpha.3", features = ["serde-bincode"] } +heed-types = "0.20.0-alpha.3" hmac = "0.12.1" humantime = "2.1.0" hyper = { version = "0.14.27", features = ["h2", "server"] } diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 4dd27944..ac5c9ede 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -185,7 +185,6 @@ async fn list_allocs( .handler() .store() .list_allocs() - .await .into_iter() .map(|cfg| AllocView { id: cfg.db_name }) .collect(); diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 296c54c9..5a9e26da 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -93,7 +93,14 @@ async fn main() -> Result<()> { let mut join_set = JoinSet::new(); - let store = Arc::new(Store::new(&config.db_path)); + + let meta_path = config.db_path.join("meta"); + tokio::fs::create_dir_all(&meta_path).await?; + let env = heed::EnvOpenOptions::new() + .max_dbs(1000) + .map_size(100 * 1024 * 1024) + .open(meta_path)?; + let store = Arc::new(Store::new(env.clone())); let manager = Arc::new(Manager::new(config.db_path.clone(), store.clone(), 100)); let bus = Arc::new(Bus::new(config.cluster.id, manager.clone())); diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 69d1376f..fb44414d 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -41,7 +41,7 @@ impl Manager { return Some(sender.clone()); } - if let Some(config) = self.meta_store.meta(&database_id).await { + if let Some(config) = self.meta_store.meta(&database_id) { let path = self.db_path.join("dbs").join(database_id.to_string()); tokio::fs::create_dir_all(&path).await.unwrap(); let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); @@ -73,12 +73,12 @@ impl Manager { meta: &AllocConfig, dispatcher: Arc, ) { - self.store().allocate(database_id, meta).await; + self.store().allocate(&database_id, meta); self.schedule(database_id, dispatcher).await; } pub async fn deallocate(&self, database_id: DatabaseId) { - self.meta_store.deallocate(database_id).await; + self.meta_store.deallocate(&database_id); self.cache.remove(&database_id).await; let db_path = self.db_path.join("dbs").join(database_id.to_string()); tokio::fs::remove_dir_all(db_path).await.unwrap(); diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 0d61d04f..84a8856c 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -1,10 +1,10 @@ -use core::fmt; -use std::path::Path; +use std::fmt; +use heed::bytemuck::{Pod, Zeroable}; +use heed::types::{SerdeBincode, OwnedType}; use serde::{Deserialize, Serialize}; use sha3::digest::{ExtendableOutput, Update, XofReader}; use sha3::Shake128; -use sled::Tree; use tokio::task::block_in_place; use crate::allocation::config::AllocConfig; @@ -12,10 +12,12 @@ use crate::allocation::config::AllocConfig; type ExecFn = Box)>; pub struct Store { - meta_store: Tree, + env: heed::Env, + alloc_config_db: heed::Database, SerdeBincode>, } -#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Hash, Clone, Copy)] +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Hash, Clone, Copy, Pod, Zeroable)] +#[repr(transparent)] pub struct DatabaseId([u8; 16]); impl DatabaseId { @@ -47,48 +49,63 @@ impl AsRef<[u8]> for DatabaseId { } impl Store { - pub fn new(path: &Path) -> Self { - std::fs::create_dir_all(&path).unwrap(); - let path = path.join("store"); - let db = 
sled::open(path).unwrap(); - let meta_store = db.open_tree("meta_store").unwrap(); - - Self { meta_store } + const ALLOC_CONFIG_DB_NAME: &'static str = "alloc_conf_db"; + + pub fn new(env: heed::Env) -> Self { + let mut txn = env.write_txn().unwrap(); + let alloc_config_db = env + .create_database(&mut txn, Some(Self::ALLOC_CONFIG_DB_NAME)) + .unwrap(); + txn.commit().unwrap(); + + Self { + env, + alloc_config_db, + } } - pub async fn allocate(&self, id: DatabaseId, meta: &AllocConfig) { + pub fn allocate(&self, id: &DatabaseId, meta: &AllocConfig) { //TODO: Handle conflict block_in_place(|| { - let meta_bytes = bincode::serialize(meta).unwrap(); - self.meta_store - .compare_and_swap(id, None as Option<&[u8]>, Some(meta_bytes)) + let mut txn = self.env.write_txn().unwrap(); + if self + .alloc_config_db + .lazily_decode_data() + .get(&txn, id) .unwrap() - .unwrap(); + .is_some() + { + panic!("alloc already exists"); + }; + self.alloc_config_db.put(&mut txn, id, meta).unwrap(); + txn.commit().unwrap(); }); } - pub async fn deallocate(&self, id: DatabaseId) { - block_in_place(|| self.meta_store.remove(id).unwrap()); + pub fn deallocate(&self, id: &DatabaseId) { + block_in_place(|| { + let mut txn = self.env.write_txn().unwrap(); + self.alloc_config_db.delete(&mut txn, id).unwrap(); + txn.commit().unwrap(); + }); } - pub async fn meta(&self, database_id: &DatabaseId) -> Option { + pub fn meta(&self, id: &DatabaseId) -> Option { block_in_place(|| { - let config = self.meta_store.get(database_id).unwrap()?; - let config = bincode::deserialize(config.as_ref()).unwrap(); - Some(config) + let txn = self.env.read_txn().unwrap(); + self.alloc_config_db.get(&txn, id).unwrap() }) } - pub async fn list_allocs(&self) -> Vec { + pub fn list_allocs(&self) -> Vec { block_in_place(|| { - let mut out = Vec::new(); - for kv in self.meta_store.iter() { - let (_k, v) = kv.unwrap(); - let alloc = bincode::deserialize(&v).unwrap(); - out.push(alloc); - } - - out + let txn = self.env.read_txn().unwrap(); + self.alloc_config_db + .iter(&txn) + .unwrap() + .map(Result::unwrap) + .map(|x| x.1) + .collect() }) } } From 2f27a77ca62bb56ffb3a2345815dcfc81e64db9f Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 24 Jul 2023 11:39:28 +0200 Subject: [PATCH 36/64] remove sled --- Cargo.lock | 66 +++++---------------------------------- libsqlx-server/Cargo.toml | 1 - 2 files changed, 7 insertions(+), 60 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7b525e06..94c80f4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1779,16 +1779,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "fs2" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" -dependencies = [ - "libc", - "winapi", -] - [[package]] name = "futures" version = "0.3.28" @@ -2575,7 +2565,7 @@ dependencies = [ "itertools 0.11.0", "nix", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "rand", "regex", "rusqlite", @@ -2614,7 +2604,7 @@ dependencies = [ "itertools 0.11.0", "libsqlx", "moka", - "parking_lot 0.12.1", + "parking_lot", "priority-queue", "rand", "regex", @@ -2622,7 +2612,6 @@ dependencies = [ "serde_json", "sha2", "sha3", - "sled", "tempfile", "thiserror", "tokio", @@ -2841,7 +2830,7 @@ dependencies = [ "crossbeam-utils", "futures-util", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "quanta", "rustc_version", "scheduled-thread-pool", @@ -3101,17 +3090,6 @@ version = "2.1.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "14f2252c834a40ed9bb5422029649578e63aa341ac401f74e719dd1afda8394e" -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -3119,21 +3097,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall 0.2.16", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -3928,7 +3892,7 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" dependencies = [ - "parking_lot 0.12.1", + "parking_lot", ] [[package]] @@ -4182,22 +4146,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "sled" -version = "0.34.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935" -dependencies = [ - "crc32fast", - "crossbeam-epoch", - "crossbeam-utils", - "fs2", - "fxhash", - "libc", - "log", - "parking_lot 0.11.2", -] - [[package]] name = "slice-group-by" version = "0.3.1" @@ -4279,7 +4227,7 @@ dependencies = [ "mimalloc", "nix", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "priority-queue", "proptest", "prost", @@ -4590,7 +4538,7 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index f4eabace..89b34c74 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -34,7 +34,6 @@ serde = { version = "1.0.166", features = ["derive", "rc"] } serde_json = "1.0.100" sha2 = "0.10.7" sha3 = "0.10.8" -sled = "0.34.7" tempfile = "3.6.0" thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } From 0f11e308fea8626d7c9bb85c3953be85d34d721f Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 10:04:42 +0200 Subject: [PATCH 37/64] compaction queue --- Cargo.lock | 12 +- libsqlx-server/Cargo.toml | 6 +- libsqlx-server/src/allocation/mod.rs | 39 ++- .../src/allocation/primary/compactor.rs | 38 ++- libsqlx-server/src/compactor.rs | 224 ++++++++++++++++++ libsqlx-server/src/http/admin.rs | 2 +- libsqlx-server/src/main.rs | 38 ++- libsqlx-server/src/manager.rs | 17 +- libsqlx-server/src/meta.rs | 8 +- libsqlx-server/src/snapshot_store.rs | 75 ++++++ libsqlx/src/database/frame.rs | 7 +- .../database/libsql/replication_log/logger.rs | 110 ++++----- .../libsql/replication_log/snapshot.rs | 6 +- libsqlx/src/error.rs | 2 + 14 files changed, 494 insertions(+), 90 deletions(-) create mode 100644 libsqlx-server/src/compactor.rs create mode 100644 libsqlx-server/src/snapshot_store.rs diff --git a/Cargo.lock b/Cargo.lock index 94c80f4a..a46f82e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2030,8 +2030,7 @@ checksum = 
"95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "heed" version = "0.20.0-alpha.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ae92e0d788e4b5608cddd89bec1128ad9f424e365025bd2454915aab450d7a2" +source = "git+https://github.com/MarinPostma/heed.git?rev=2ae9a14#2ae9a14ce2270118e23f069ba6999212353d94aa" dependencies = [ "bytemuck", "byteorder", @@ -2049,14 +2048,12 @@ dependencies = [ [[package]] name = "heed-traits" version = "0.20.0-alpha.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44055e6d049fb62b58671059045fe4a8a083d78ef04347818cc8a87a62d6fa1f" +source = "git+https://github.com/MarinPostma/heed.git?rev=2ae9a14#2ae9a14ce2270118e23f069ba6999212353d94aa" [[package]] name = "heed-types" version = "0.20.0-alpha.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d419e64e429f0bbe6d8ef3507bf137105c8cebee0e5fba59cf556de93c8ab57" +source = "git+https://github.com/MarinPostma/heed.git?rev=2ae9a14#2ae9a14ce2270118e23f069ba6999212353d94aa" dependencies = [ "bincode", "bytemuck", @@ -2645,8 +2642,7 @@ checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" [[package]] name = "lmdb-master-sys" version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "629c123f5321b48fa4f8f4d3b868165b748d9ba79c7103fb58e3a94f736bcedd" +source = "git+https://github.com/MarinPostma/heed.git?rev=2ae9a14#2ae9a14ce2270118e23f069ba6999212353d94aa" dependencies = [ "cc", "doxygen-rs", diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 89b34c74..b925732a 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -18,8 +18,10 @@ clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" either = "1.8.1" futures = "0.3.28" -heed = { version = "0.20.0-alpha.3", features = ["serde-bincode"] } -heed-types = "0.20.0-alpha.3" +# heed = { version = "0.20.0-alpha.3", features = ["serde-bincode", "sync-read-txn"] } +heed = { git = "https://github.com/MarinPostma/heed.git", rev = "2ae9a14", features = ["serde-bincode", "sync-read-txn"] } +heed-types = { git = "https://github.com/MarinPostma/heed.git", rev = "2ae9a14" } +# heed-types = "0.20.0-alpha.3" hmac = "0.12.1" humantime = "2.1.0" hyper = { version = "0.14.27", features = ["h2", "server"] } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index fbc4049f..6b3e15d7 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -17,6 +17,7 @@ use tokio::task::{block_in_place, JoinSet}; use tokio::time::Interval; use crate::allocation::primary::FrameStreamer; +use crate::compactor::CompactionQueue; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; @@ -70,10 +71,12 @@ impl Database { } = self { ready!(interval.poll_tick(cx)); + tracing::debug!("attempting periodic log compaction"); let db = db.db.clone(); tokio::task::spawn_blocking(move || { db.compact_log(); }); + return Poll::Ready(()) } Poll::Pending @@ -81,7 +84,14 @@ impl Database { } impl Database { - pub fn from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc) -> Self { + pub fn from_config( + config: &AllocConfig, + path: PathBuf, + dispatcher: Arc, + compaction_queue: Arc, + ) -> Self { + let database_id = DatabaseId::from_name(&config.db_name); + match config.db_config { DbConfig::Primary { max_log_size, @@ -90,7 
+100,12 @@ impl Database { let (sender, receiver) = tokio::sync::watch::channel(0); let db = LibsqlDatabase::new_primary( path, - Compactor::new(max_log_size, replication_log_compact_interval), + Compactor::new( + max_log_size, + replication_log_compact_interval, + compaction_queue, + database_id, + ), false, Box::new(move |fno| { let _ = sender.send(fno); @@ -98,13 +113,19 @@ impl Database { ) .unwrap(); + let compact_interval = replication_log_compact_interval.map(|d| { + let mut i = tokio::time::interval(d / 2); + i.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); + Box::pin(i) + }); + Self::Primary { db: PrimaryDatabase { db: Arc::new(db), replica_streams: HashMap::new(), frame_notifier: receiver, }, - compact_interval: None, + compact_interval, } } DbConfig::Replica { @@ -119,7 +140,6 @@ impl Database { let mut db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ())); let injector = db.injector().unwrap(); let (sender, receiver) = mpsc::channel(16); - let database_id = DatabaseId::from_name(&config.db_name); let replicator = Replicator::new( dispatcher, @@ -208,9 +228,10 @@ impl ConnectionHandle { impl Allocation { pub async fn run(mut self) { loop { + dbg!(); let fut = poll_fn(|cx| self.database.poll(cx)); tokio::select! { - _ = fut => (), + _ = fut => dbg!(), Some(msg) = self.inbox.recv() => { match msg { AllocationMessage::HranaPipelineReq { req, ret } => { @@ -225,12 +246,16 @@ impl Allocation { } } }, - maybe_id = self.connections_futs.join_next() => { + maybe_id = self.connections_futs.join_next(), if !self.connections_futs.is_empty() => { + dbg!(); if let Some(Ok(_id)) = maybe_id { // self.connections.remove_entry(&id); } }, - else => break, + else => { + dbg!(); + break + }, } } } diff --git a/libsqlx-server/src/allocation/primary/compactor.rs b/libsqlx-server/src/allocation/primary/compactor.rs index 62b9c0ed..5bc4c9a3 100644 --- a/libsqlx-server/src/allocation/primary/compactor.rs +++ b/libsqlx-server/src/allocation/primary/compactor.rs @@ -1,19 +1,38 @@ -use std::time::{Duration, Instant}; +use std::{ + path::PathBuf, + sync::Arc, + time::{Duration, Instant}, +}; use libsqlx::libsql::{LogCompactor, LogFile}; +use uuid::Uuid; + +use crate::{ + compactor::{CompactionJob, CompactionQueue}, + meta::DatabaseId, +}; pub struct Compactor { max_log_size: usize, last_compacted_at: Instant, compact_interval: Option, + queue: Arc, + database_id: DatabaseId, } impl Compactor { - pub fn new(max_log_size: usize, compact_interval: Option) -> Self { + pub fn new( + max_log_size: usize, + compact_interval: Option, + queue: Arc, + database_id: DatabaseId, + ) -> Self { Self { max_log_size, last_compacted_at: Instant::now(), compact_interval, + queue, + database_id, } } } @@ -32,11 +51,18 @@ impl LogCompactor for Compactor { fn compact( &mut self, - _log: LogFile, - _path: std::path::PathBuf, - _size_after: u32, + log_id: Uuid, ) -> Result<(), Box> { self.last_compacted_at = Instant::now(); - todo!() + self.queue.push(&CompactionJob { + database_id: self.database_id, + log_id, + }); + + Ok(()) + } + + fn snapshot_dir(&self) -> PathBuf { + self.queue.snapshot_queue_dir() } } diff --git a/libsqlx-server/src/compactor.rs b/libsqlx-server/src/compactor.rs new file mode 100644 index 00000000..22f31c6a --- /dev/null +++ b/libsqlx-server/src/compactor.rs @@ -0,0 +1,224 @@ +use std::io::{BufWriter, Write}; +use std::mem::size_of; +use std::os::unix::prelude::FileExt; +use std::path::{Path, PathBuf}; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; + +use 
bytemuck::{bytes_of, Pod, Zeroable};
+use heed::byteorder::BigEndian;
+use heed_types::{SerdeBincode, U64};
+use libsqlx::libsql::LogFile;
+use libsqlx::{Frame, FrameNo};
+use serde::{Deserialize, Serialize};
+use tempfile::NamedTempFile;
+use tokio::sync::watch;
+use tokio::task::block_in_place;
+use uuid::Uuid;
+
+use crate::meta::DatabaseId;
+use crate::snapshot_store::SnapshotStore;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CompactionJob {
+    /// Id of the database whose log needs to be compacted
+    pub database_id: DatabaseId,
+    /// id of the log to compact
+    pub log_id: Uuid,
+}
+
+pub struct CompactionQueue {
+    env: heed::Env,
+    queue: heed::Database<U64<BigEndian>, SerdeBincode<CompactionJob>>,
+    next_id: AtomicU64,
+    notify: watch::Sender<Option<u64>>,
+    db_path: PathBuf,
+    snapshot_store: Arc<SnapshotStore>,
+}
+
+impl CompactionQueue {
+    const COMPACTION_QUEUE_DB_NAME: &str = "compaction_queue_db";
+
+    pub fn new(
+        env: heed::Env,
+        db_path: PathBuf,
+        snapshot_store: Arc<SnapshotStore>,
+    ) -> color_eyre::Result<Self> {
+        let mut txn = env.write_txn()?;
+        let queue = env.create_database(&mut txn, Some(Self::COMPACTION_QUEUE_DB_NAME))?;
+        let next_id = match queue.last(&mut txn)? {
+            Some((id, _)) => id + 1,
+            None => 0,
+        };
+        txn.commit()?;
+
+        let (notify, _) = watch::channel((next_id > 0).then(|| next_id - 1));
+        Ok(Self {
+            env,
+            queue,
+            next_id: next_id.into(),
+            notify,
+            db_path,
+            snapshot_store,
+        })
+    }
+
+    pub fn push(&self, job: &CompactionJob) {
+        tracing::debug!("new compaction job available: {job:?}");
+        let mut txn = self.env.write_txn().unwrap();
+        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
+        self.queue.put(&mut txn, &id, job).unwrap();
+        txn.commit().unwrap();
+        self.notify.send_replace(Some(id));
+    }
+
+    pub async fn peek(&self) -> (u64, CompactionJob) {
+        let id = self.next_id.load(Ordering::Relaxed);
+        let txn = block_in_place(|| self.env.read_txn().unwrap());
+        match block_in_place(|| self.queue.first(&txn).unwrap()) {
+            Some(job) => job,
+            None => {
+                drop(txn);
+                self.notify
+                    .subscribe()
+                    .wait_for(|x| x.map(|x| x >= id).unwrap_or_default())
+                    .await
+                    .unwrap();
+                block_in_place(|| {
+                    let txn = self.env.read_txn().unwrap();
+                    self.queue.first(&txn).unwrap().unwrap()
+                })
+            }
+        }
+    }
+
+    fn complete(&self, txn: &mut heed::RwTxn, job_id: u64) {
+        block_in_place(|| {
+            self.queue.delete(txn, &job_id).unwrap();
+        });
+    }
+
+    async fn compact(&self) -> color_eyre::Result<()> {
+        let (job_id, job) = self.peek().await;
+        tracing::debug!("starting new compaction job: {job:?}");
+        let to_compact_path = self.snapshot_queue_dir().join(job.log_id.to_string());
+        let (snapshot_id, start_fno, end_fno) = tokio::task::spawn_blocking({
+            let to_compact_path = to_compact_path.clone();
+            let db_path = self.db_path.clone();
+            move || {
+                let mut builder = SnapshotBuilder::new(&db_path, job.database_id)?;
+                let log = LogFile::new(to_compact_path)?;
+                for frame in log.rev_deduped() {
+                    let frame = frame?;
+                    builder.push_frame(frame)?;
+                }
+                builder.finish()
+            }
+        })
+        .await??;
+
+        let mut txn = self.env.write_txn()?;
+        self.complete(&mut txn, job_id);
+        self.snapshot_store
+            .register(&mut txn, job.database_id, start_fno, end_fno, snapshot_id);
+        txn.commit()?;
+
+        std::fs::remove_file(to_compact_path)?;
+
+        Ok(())
+    }
+
+    pub fn snapshot_queue_dir(&self) -> PathBuf {
+        self.db_path.join("snapshot_queue")
+    }
+}
+
+pub async fn run_compactor_loop(compactor: Arc<CompactionQueue>) -> color_eyre::Result<()> {
+    loop {
+        compactor.compact().await?;
+    }
+}
+
+#[derive(Debug, Copy, Clone, Zeroable, Pod, PartialEq, Eq)]
+#[repr(C)]
+/// header of a snapshot file
+pub struct SnapshotFileHeader {
+    /// id of the database
+    pub db_id: DatabaseId,
+    /// first frame in the snapshot
+    pub start_frame_no: u64,
+    /// end frame in the snapshot
+    pub end_frame_no: u64,
+    /// number of frames in the snapshot
+    pub frame_count: u64,
+    /// size of the database after applying the snapshot
+    pub size_after: u32,
+    pub _pad: u32,
+}
+
+/// A utility to build snapshots from log frames
+pub struct SnapshotBuilder {
+    pub header: SnapshotFileHeader,
+    snapshot_file: BufWriter<NamedTempFile>,
+    db_path: PathBuf,
+    last_seen_frame_no: u64,
+}
+
+impl SnapshotBuilder {
+    pub fn new(db_path: &Path, db_id: DatabaseId) -> color_eyre::Result<Self> {
+        let temp_dir = db_path.join("tmp");
+        let mut target = BufWriter::new(NamedTempFile::new_in(&temp_dir)?);
+        // reserve header space
+        target.write_all(&[0; size_of::<SnapshotFileHeader>()])?;
+
+        Ok(Self {
+            header: SnapshotFileHeader {
+                db_id,
+                start_frame_no: u64::MAX,
+                end_frame_no: u64::MIN,
+                frame_count: 0,
+                size_after: 0,
+                _pad: 0,
+            },
+            snapshot_file: target,
+            db_path: db_path.to_path_buf(),
+            last_seen_frame_no: u64::MAX,
+        })
+    }
+
+    pub fn push_frame(&mut self, frame: Frame) -> color_eyre::Result<()> {
+        assert!(frame.header().frame_no < self.last_seen_frame_no);
+        self.last_seen_frame_no = frame.header().frame_no;
+        if frame.header().frame_no < self.header.start_frame_no {
+            self.header.start_frame_no = frame.header().frame_no;
+        }
+
+        if frame.header().frame_no > self.header.end_frame_no {
+            self.header.end_frame_no = frame.header().frame_no;
+            self.header.size_after = frame.header().size_after;
+        }
+
+        self.snapshot_file.write_all(frame.as_slice())?;
+        self.header.frame_count += 1;
+
+        Ok(())
+    }
+
+    /// Persist the snapshot, and return its id along with the first and last
+    /// frame numbers it contains.
+    pub fn finish(mut self) -> color_eyre::Result<(Uuid, FrameNo, FrameNo)> {
+        self.snapshot_file.flush()?;
+        let file = self.snapshot_file.into_inner()?;
+        file.as_file().write_all_at(bytes_of(&self.header), 0)?;
+        let snapshot_id = Uuid::new_v4();
+
+        let path = self.db_path.join("snapshots").join(snapshot_id.to_string());
+        file.persist(path)?;
+
+        Ok((
+            snapshot_id,
+            self.header.start_frame_no,
+            self.header.end_frame_no,
+        ))
+    }
+}
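A sketch of how the builder is driven, mirroring `CompactionQueue::compact` above (paths and ids are placeholders; error handling elided):

    // Frames arrive newest-first from `rev_deduped`, so the builder sees each
    // page exactly once; `finish` backfills the header and persists the
    // snapshot under a fresh UUID in `<db_path>/snapshots`.
    let mut builder = SnapshotBuilder::new(&db_path, database_id)?;
    for frame in LogFile::new(log_path)?.rev_deduped() {
        builder.push_frame(frame?)?;
    }
    let (snapshot_id, start_frame_no, end_frame_no) = builder.finish()?;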
diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs
index ac5c9ede..9b51b7ed 100644
--- a/libsqlx-server/src/http/admin.rs
+++ b/libsqlx-server/src/http/admin.rs
@@ -57,7 +57,7 @@ struct AllocateReq {
 }
 
 #[derive(Debug, Deserialize)]
-#[serde(tag = "type", rename_all = "snake_case")]
+#[serde(tag = "type", rename_all = "snake_case", deny_unknown_fields)]
 pub struct Primary {
     /// The maximum size the replication is allowed to grow. Expects a string like 200mb.
     #[serde(default = "default_max_log_size")]
diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs
index 5a9e26da..b1856e9e 100644
--- a/libsqlx-server/src/main.rs
+++ b/libsqlx-server/src/main.rs
@@ -1,9 +1,10 @@
 use std::fs::read_to_string;
-use std::path::PathBuf;
+use std::path::{PathBuf, Path};
 use std::sync::Arc;
 
 use clap::Parser;
 use color_eyre::eyre::Result;
+use compactor::{CompactionQueue, run_compactor_loop};
 use config::{AdminApiConfig, ClusterConfig, UserApiConfig};
 use http::admin::run_admin_api;
 use http::user::run_user_api;
@@ -11,12 +12,15 @@ use hyper::server::conn::AddrIncoming;
 use linc::bus::Bus;
 use manager::Manager;
 use meta::Store;
+use snapshot_store::SnapshotStore;
+use tokio::fs::create_dir_all;
 use tokio::net::{TcpListener, TcpStream};
 use tokio::task::JoinSet;
 use tracing::metadata::LevelFilter;
 use tracing_subscriber::prelude::*;
 
 mod allocation;
+mod compactor;
 mod config;
 mod database;
 mod hrana;
@@ -24,6 +28,7 @@ mod http;
 mod linc;
 mod manager;
 mod meta;
+mod snapshot_store;
 
 #[derive(Debug, Parser)]
 struct Args {
@@ -83,6 +88,17 @@ async fn spawn_cluster_networking(
     Ok(())
 }
 
+async fn init_dirs(db_path: &Path) -> color_eyre::Result<()> {
+    create_dir_all(&db_path).await?;
+    create_dir_all(db_path.join("tmp")).await?;
+    create_dir_all(db_path.join("snapshot_queue")).await?;
+    create_dir_all(db_path.join("snapshots")).await?;
+    create_dir_all(db_path.join("dbs")).await?;
+    create_dir_all(db_path.join("meta")).await?;
+
+    Ok(())
+}
+
 #[tokio::main(flavor = "multi_thread", worker_threads = 10)]
 async fn main() -> Result<()> {
     init();
@@ -93,17 +109,29 @@ async fn main() -> Result<()> {
 
     let mut join_set = JoinSet::new();
 
+    init_dirs(&config.db_path).await?;
 
-    let meta_path = config.db_path.join("meta");
-    tokio::fs::create_dir_all(&meta_path).await?;
     let env = heed::EnvOpenOptions::new()
         .max_dbs(1000)
        .map_size(100 * 1024 * 1024)
-        .open(meta_path)?;
+        .open(config.db_path.join("meta"))?;
+
+    let snapshot_store = Arc::new(SnapshotStore::new(config.db_path.clone(), &env)?);
+    let compaction_queue = Arc::new(CompactionQueue::new(
+        env.clone(),
+        config.db_path.clone(),
+        snapshot_store,
+    )?);
     let store = Arc::new(Store::new(env.clone()));
-    let manager = Arc::new(Manager::new(config.db_path.clone(), store.clone(), 100));
+    let manager = Arc::new(Manager::new(
+        config.db_path.clone(),
+        store.clone(),
+        100,
+        compaction_queue.clone(),
+    ));
     let bus = Arc::new(Bus::new(config.cluster.id, manager.clone()));
 
+    join_set.spawn(run_compactor_loop(compaction_queue));
     spawn_cluster_networking(&mut join_set, &config.cluster, bus.clone()).await?;
     spawn_admin_api(&mut join_set, &config.admin_api, bus.clone()).await?;
     spawn_user_api(&mut join_set, &config.user_api, manager, bus).await?;
diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs
index fb44414d..a3bb68d8 100644
--- a/libsqlx-server/src/manager.rs
+++ b/libsqlx-server/src/manager.rs
@@ -8,6 +8,7 @@ use tokio::task::JoinSet;
 
 use crate::allocation::config::AllocConfig;
 use crate::allocation::{Allocation, AllocationMessage, Database};
+use crate::compactor::CompactionQueue;
 use crate::hrana;
 use crate::linc::bus::Dispatch;
 use crate::linc::handler::Handler;
@@ -18,16 +19,23 @@ pub struct Manager {
     cache: Cache<DatabaseId, mpsc::Sender<AllocationMessage>>,
     meta_store: Arc<Store>,
     db_path: PathBuf,
+    compaction_queue: Arc<CompactionQueue>,
 }
 
 const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32;
 
 impl Manager {
-    pub fn new(db_path: PathBuf, meta_store: Arc<Store>, max_conccurent_allocs: u64) -> Self {
+    pub fn new(
+        db_path: PathBuf,
+        meta_store: Arc<Store>,
max_conccurent_allocs: u64, + compaction_queue: Arc, + ) -> Self { Self { cache: Cache::new(max_conccurent_allocs), meta_store, db_path, + compaction_queue, } } @@ -47,7 +55,12 @@ impl Manager { let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { inbox, - database: Database::from_config(&config, path, dispatcher.clone()), + database: Database::from_config( + &config, + path, + dispatcher.clone(), + self.compaction_queue.clone(), + ), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 84a8856c..38a1c30b 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -1,7 +1,8 @@ use std::fmt; +use std::mem::size_of; use heed::bytemuck::{Pod, Zeroable}; -use heed::types::{SerdeBincode, OwnedType}; +use heed_types::{OwnedType, SerdeBincode}; use serde::{Deserialize, Serialize}; use sha3::digest::{ExtendableOutput, Update, XofReader}; use sha3::Shake128; @@ -30,6 +31,11 @@ impl DatabaseId { Self(out) } + pub fn from_bytes(bytes: &[u8]) -> Self { + assert_eq!(bytes.len(), size_of::()); + Self(bytes.try_into().unwrap()) + } + #[cfg(test)] pub fn random() -> Self { Self(uuid::Uuid::new_v4().into_bytes()) diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs new file mode 100644 index 00000000..a8803711 --- /dev/null +++ b/libsqlx-server/src/snapshot_store.rs @@ -0,0 +1,75 @@ +use std::mem::size_of; +use std::path::PathBuf; + +use bytemuck::{Pod, Zeroable}; +use heed_types::{CowType, SerdeBincode}; +use libsqlx::FrameNo; +use serde::Serialize; +use uuid::Uuid; + +use crate::meta::DatabaseId; + +#[derive(Clone, Copy, Zeroable, Pod, Debug)] +#[repr(transparent)] +struct BEU64([u8; size_of::()]); + +impl From for BEU64 { + fn from(value: u64) -> Self { + Self(value.to_be_bytes()) + } +} + +impl From for u64 { + fn from(value: BEU64) -> Self { + u64::from_be_bytes(value.0) + } +} + +#[derive(Clone, Copy, Zeroable, Pod, Debug)] +#[repr(C)] +struct SnapshotKey { + database_id: DatabaseId, + start_frame_no: BEU64, + end_frame_no: FrameNo, +} + +#[derive(Debug, Serialize)] +struct SnapshotMeta { + snapshot_id: Uuid, +} + +pub struct SnapshotStore { + database: heed::Database, SerdeBincode>, + db_path: PathBuf, +} + +impl SnapshotStore { + const SNAPSHOT_STORE_NAME: &str = "snapshot-store-db"; + + pub fn new(db_path: PathBuf, env: &heed::Env) -> color_eyre::Result { + let mut txn = env.write_txn().unwrap(); + let database = env.create_database(&mut txn, Some(Self::SNAPSHOT_STORE_NAME))?; + txn.commit()?; + + Ok(Self { database, db_path }) + } + + pub fn register( + &self, + txn: &mut heed::RwTxn, + database_id: DatabaseId, + start_frame_no: FrameNo, + end_frame_no: FrameNo, + snapshot_id: Uuid, + ) { + let key = SnapshotKey { + database_id, + start_frame_no: start_frame_no.into(), + end_frame_no: end_frame_no.into(), + }; + + let data = SnapshotMeta { snapshot_id }; + + self.database.put(txn, &key, &data).unwrap(); + } +} diff --git a/libsqlx/src/database/frame.rs b/libsqlx/src/database/frame.rs index 337853fb..ba2d638d 100644 --- a/libsqlx/src/database/frame.rs +++ b/libsqlx/src/database/frame.rs @@ -55,8 +55,11 @@ impl Frame { Self { data: buf.freeze() } } - pub fn try_from_bytes(data: Bytes) -> anyhow::Result { - anyhow::ensure!(data.len() == Self::SIZE, "invalid frame size"); + pub fn try_from_bytes(data: Bytes) -> crate::Result { + if data.len() != Self::SIZE { + return 
Err(crate::error::Error::InvalidFrame); + } + Ok(Self { data }) } diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 48d3916e..96520ef8 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::ffi::{c_int, c_void, CStr}; use std::fs::{remove_dir_all, File, OpenOptions}; use std::io::Write; @@ -147,7 +148,7 @@ unsafe impl WalHook for ReplicationLoggerHook { .logger .log_file .write() - .maybe_compact(&mut *ctx.logger.compactor.lock(), &ctx.logger.db_path) + .maybe_compact(&mut *ctx.logger.compactor.lock()) { tracing::error!("fatal error: {e}, exiting"); std::process::abort() @@ -345,6 +346,8 @@ impl ReplicationLoggerHookCtx { #[derive(Debug)] pub struct LogFile { file: File, + /// Path of the LogFile + path: PathBuf, pub header: LogFileHeader, /// number of frames in the log that have not been commited yet. On commit the header's frame /// count is incremented by that ammount. New pages are written after the last @@ -365,16 +368,22 @@ pub enum LogReadError { #[error("requested entry is ahead of log")] Ahead, #[error(transparent)] - Error(#[from] anyhow::Error), + Error(#[from] crate::error::Error), } impl LogFile { /// size of a single frame pub const FRAME_SIZE: usize = size_of::() + WAL_PAGE_SIZE as usize; - pub fn new(file: File) -> crate::Result { + pub fn new(path: PathBuf) -> crate::Result { // FIXME: we should probably take a lock on this file, to prevent anybody else to write to // it. + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(&path)?; + let file_end = file.metadata()?.len(); if file_end == 0 { @@ -391,6 +400,7 @@ impl LogFile { }; let mut this = Self { + path, file, header, uncommitted_frame_count: 0, @@ -409,6 +419,7 @@ impl LogFile { uncommitted_frame_count: 0, uncommitted_checksum: 0, commited_checksum: 0, + path, }; if let Some(last_commited) = this.last_commited_frame_no() { @@ -467,7 +478,7 @@ impl LogFile { } /// Returns an iterator over the WAL frame headers - pub fn frames_iter(&self) -> anyhow::Result> + '_> { + pub fn frames_iter(&self) -> anyhow::Result> + '_> { let mut current_frame_offset = 0; Ok(std::iter::from_fn(move || { if current_frame_offset >= self.header.frame_count { @@ -480,12 +491,10 @@ impl LogFile { } /// Returns an iterator over the WAL frame headers - pub fn rev_frames_iter( - &self, - ) -> anyhow::Result> + '_> { + pub fn rev_frames_iter(&self) -> impl Iterator> + '_ { let mut current_frame_offset = self.header.frame_count; - Ok(std::iter::from_fn(move || { + std::iter::from_fn(move || { if current_frame_offset == 0 { return None; } @@ -493,7 +502,24 @@ impl LogFile { let read_byte_offset = Self::absolute_byte_offset(current_frame_offset); let frame = self.read_frame_byte_offset(read_byte_offset); Some(frame) - })) + }) + } + + /// Return a reversed iterator over the deduplicated frames in the log file. + pub fn rev_deduped(&self) -> impl Iterator> + '_ { + let mut iter = self.rev_frames_iter(); + let mut seen = HashSet::new(); + std::iter::from_fn(move || loop { + match iter.next()? 
{ + Ok(frame) => { + if !seen.contains(&frame.header().page_no) { + seen.insert(frame.header().page_no); + return Some(Ok(frame)); + } + } + Err(e) => return Some(Err(e)), + } + }) } fn compute_checksum(&self, page: &WalPage) -> u64 { @@ -541,7 +567,7 @@ impl LogFile { std::mem::size_of::() as u64 + nth * Self::FRAME_SIZE as u64 } - fn byte_offset(&self, id: FrameNo) -> anyhow::Result> { + fn byte_offset(&self, id: FrameNo) -> crate::Result> { if id < self.header.start_frame_no || id > self.header.start_frame_no + self.header.frame_count { @@ -550,7 +576,7 @@ impl LogFile { Ok(Self::absolute_byte_offset(id - self.header.start_frame_no).into()) } - /// Returns bytes represening a WalFrame for frame `frame_no` + /// Returns bytes representing a WalFrame for frame `frame_no` /// /// If the requested frame is before the first frame in the log, or after the last frame, /// Ok(None) is returned. @@ -568,38 +594,26 @@ impl LogFile { Ok(frame) } - fn maybe_compact( - &mut self, - compactor: &mut dyn LogCompactor, - path: &Path, - ) -> anyhow::Result<()> { + fn maybe_compact(&mut self, compactor: &mut dyn LogCompactor) -> anyhow::Result<()> { if self.can_compact() && compactor.should_compact(self) { - return self.do_compaction(compactor, path); + return self.do_compaction(compactor); } Ok(()) } - fn do_compaction( - &mut self, - compactor: &mut dyn LogCompactor, - path: &Path, - ) -> anyhow::Result<()> { + fn do_compaction(&mut self, compactor: &mut dyn LogCompactor) -> anyhow::Result<()> { tracing::info!("performing log compaction"); - let temp_log_path = path.join("temp_log"); + let log_id = Uuid::new_v4(); + let temp_log_path = compactor.snapshot_dir().join(log_id.to_string()); let last_frame = self - .rev_frames_iter()? + .rev_frames_iter() .next() .expect("there should be at least one frame to perform compaction")?; let size_after = last_frame.header().size_after; assert!(size_after != 0); - let file = OpenOptions::new() - .read(true) - .write(true) - .create(true) - .open(&temp_log_path)?; - let mut new_log_file = LogFile::new(file)?; + let mut new_log_file = LogFile::new(temp_log_path.clone())?; let new_header = LogFileHeader { start_frame_no: self.header.start_frame_no + self.header.frame_count, frame_count: 0, @@ -609,16 +623,15 @@ impl LogFile { new_log_file.header = new_header; new_log_file.write_header().unwrap(); // swap old and new snapshot - atomic_rename(&temp_log_path, path.join("wallog")).unwrap(); - let old_log_file = std::mem::replace(self, new_log_file); - compactor - .compact(old_log_file, temp_log_path, size_after) - .unwrap(); + atomic_rename(dbg!(&temp_log_path), dbg!(&self.path)).unwrap(); + std::mem::swap(&mut new_log_file.path, &mut self.path); + let _ = std::mem::replace(self, new_log_file); + compactor.compact(log_id).unwrap(); Ok(()) } - fn read_frame_byte_offset(&self, offset: u64) -> anyhow::Result { + fn read_frame_byte_offset(&self, offset: u64) -> crate::Result { let mut buffer = BytesMut::zeroed(LogFile::FRAME_SIZE); self.file.read_exact_at(&mut buffer, offset)?; let buffer = buffer.freeze(); @@ -637,7 +650,7 @@ impl LogFile { fn reset(self) -> crate::Result { // truncate file self.file.set_len(0)?; - Self::new(self.file) + Self::new(self.path) } /// return the size in bytes of the log @@ -744,19 +757,18 @@ pub trait LogCompactor: Sync + Send + 'static { /// Compact the given snapshot fn compact( &mut self, - log: LogFile, - path: PathBuf, - size_after: u32, + log_id: Uuid, ) -> Result<(), Box>; + + fn snapshot_dir(&self) -> PathBuf; } #[cfg(test)] impl 
LogCompactor for () { fn compact( &mut self, - _file: LogFile, - _path: PathBuf, - _size_after: u32, + log_name: String, + path: PathBuf, ) -> Result<(), Box> { Ok(()) } @@ -790,13 +802,7 @@ impl ReplicationLogger { let fresh = !log_path.exists(); - let file = OpenOptions::new() - .create(true) - .write(true) - .read(true) - .open(log_path)?; - - let log_file = LogFile::new(file)?; + let log_file = LogFile::new(log_path)?; let header = log_file.header(); let should_recover = if dirty { @@ -935,9 +941,7 @@ impl ReplicationLogger { pub fn compact(&self) { let mut log_file = self.log_file.write(); if log_file.can_compact() { - log_file - .do_compaction(&mut *self.compactor.lock(), &self.db_path) - .unwrap(); + log_file.do_compaction(&mut *self.compactor.lock()).unwrap(); } } } diff --git a/libsqlx/src/database/libsql/replication_log/snapshot.rs b/libsqlx/src/database/libsql/replication_log/snapshot.rs index c5f58ea3..b6bac186 100644 --- a/libsqlx/src/database/libsql/replication_log/snapshot.rs +++ b/libsqlx/src/database/libsql/replication_log/snapshot.rs @@ -118,7 +118,7 @@ impl SnapshotFile { } /// Iterator on the frames contained in the snapshot file, in reverse frame_no order. - pub fn frames_iter(&self) -> impl Iterator> + '_ { + pub fn frames_iter(&self) -> impl Iterator> + '_ { let mut current_offset = 0; std::iter::from_fn(move || { if current_offset >= self.header.frame_count { @@ -139,7 +139,7 @@ impl SnapshotFile { pub fn frames_iter_from( &self, frame_no: u64, - ) -> impl Iterator> + '_ { + ) -> impl Iterator> + '_ { let mut iter = self.frames_iter(); std::iter::from_fn(move || match iter.next() { Some(Ok(bytes)) => match Frame::try_from_bytes(bytes.clone()) { @@ -197,7 +197,7 @@ impl SnapshotBuilder { /// append frames to the snapshot. Frames must be in decreasing frame_no order. pub fn append_frames( &mut self, - frames: impl Iterator>, + frames: impl Iterator>, ) -> anyhow::Result<()> { // We iterate on the frames starting from the end of the log and working our way backward. 
We // make sure that only the most recent version of each file is present in the resulting diff --git a/libsqlx/src/error.rs b/libsqlx/src/error.rs index fd7828c1..091152b9 100644 --- a/libsqlx/src/error.rs +++ b/libsqlx/src/error.rs @@ -45,4 +45,6 @@ pub enum Error { }, #[error(transparent)] LexerError(#[from] sqlite3_parser::lexer::sql::Error), + #[error("invalid frame")] + InvalidFrame, } From 4617dc58b656a2add910d58e0c65929501fcd4fe Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 14:53:12 +0200 Subject: [PATCH 38/64] add locate method to SnapshotStore --- libsqlx-server/src/snapshot_store.rs | 59 +++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs index a8803711..f4a008f2 100644 --- a/libsqlx-server/src/snapshot_store.rs +++ b/libsqlx-server/src/snapshot_store.rs @@ -2,9 +2,11 @@ use std::mem::size_of; use std::path::PathBuf; use bytemuck::{Pod, Zeroable}; -use heed_types::{CowType, SerdeBincode}; +use heed::BytesDecode; +use heed_types::{ByteSlice, CowType, SerdeBincode}; use libsqlx::FrameNo; -use serde::Serialize; +use serde::{Deserialize, Serialize}; +use tokio::task::block_in_place; use uuid::Uuid; use crate::meta::DatabaseId; @@ -30,15 +32,16 @@ impl From for u64 { struct SnapshotKey { database_id: DatabaseId, start_frame_no: BEU64, - end_frame_no: FrameNo, + end_frame_no: BEU64, } -#[derive(Debug, Serialize)] -struct SnapshotMeta { - snapshot_id: Uuid, +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct SnapshotMeta { + pub snapshot_id: Uuid, } pub struct SnapshotStore { + env: heed::Env, database: heed::Database, SerdeBincode>, db_path: PathBuf, } @@ -46,12 +49,16 @@ impl SnapshotStore { const SNAPSHOT_STORE_NAME: &str = "snapshot-store-db"; - pub fn new(db_path: PathBuf, env: &heed::Env) -> color_eyre::Result { + pub fn new(db_path: PathBuf, env: heed::Env) -> color_eyre::Result { let mut txn = env.write_txn().unwrap(); let database = env.create_database(&mut txn, Some(Self::SNAPSHOT_STORE_NAME))?; txn.commit()?; - Ok(Self { database, db_path }) + Ok(Self { + database, + db_path, + env, + }) } pub fn register( @@ -70,6 +77,40 @@ impl SnapshotStore { let data = SnapshotMeta { snapshot_id }; - self.database.put(txn, &key, &data).unwrap(); + block_in_place(|| self.database.put(txn, &key, &data).unwrap()); + } + + /// Locate a snapshot for `database_id` that contains `frame_no` + pub fn locate(&self, database_id: DatabaseId, frame_no: FrameNo) -> Option { + let txn = self.env.read_txn().unwrap(); + // Snapshot keys being lexicographically ordered, looking for the first key less than or + // equal to (db_id, frame_no, FrameNo::MAX) will always return the entry we're looking for + // if it exists. + let key = SnapshotKey { + database_id, + start_frame_no: frame_no.into(), + end_frame_no: u64::MAX.into(), + }; + + match self + .database + .get_lower_than_or_equal_to(&txn, &key) + .transpose()? 
+ { + Ok((key, v)) => { + if key.database_id != database_id { + return None; + } else if frame_no >= key.start_frame_no.into() + && frame_no <= key.end_frame_no.into() + { + return Some(v); + } else { + None + } + } + Err(_) => todo!(), + } + } +} } } From 898974ea5140e472cbcbd8928fb85b393413abee Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 14:53:33 +0200 Subject: [PATCH 39/64] add test for locate method --- libsqlx-server/src/snapshot_store.rs | 81 ++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs index f4a008f2..68610c10 100644 --- a/libsqlx-server/src/snapshot_store.rs +++ b/libsqlx-server/src/snapshot_store.rs @@ -112,5 +112,86 @@ impl SnapshotStore { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn insert_and_locate() { + let temp = tempfile::tempdir().unwrap(); + let env = heed::EnvOpenOptions::new() + .max_dbs(10) + .map_size(1000 * 4096) + .open(temp.path()) + .unwrap(); + let store = SnapshotStore::new(temp.path().to_path_buf(), env).unwrap(); + let mut txn = store.env.write_txn().unwrap(); + let db_id = DatabaseId::random(); + let snapshot_id = Uuid::new_v4(); + store.register(&mut txn, db_id, 0, 51, snapshot_id); + txn.commit().unwrap(); + + assert!(store.locate(db_id, 0).is_some()); + assert!(store.locate(db_id, 17).is_some()); + assert!(store.locate(db_id, 51).is_some()); + assert!(store.locate(db_id, 52).is_none()); + } + + #[test] + fn multiple_snapshots() { + let temp = tempfile::tempdir().unwrap(); + let env = heed::EnvOpenOptions::new() + .max_dbs(10) + .map_size(1000 * 4096) + .open(temp.path()) + .unwrap(); + let store = SnapshotStore::new(temp.path().to_path_buf(), env).unwrap(); + let mut txn = store.env.write_txn().unwrap(); + let db_id = DatabaseId::random(); + let snapshot_1_id = Uuid::new_v4(); + store.register(&mut txn, db_id, 0, 51, snapshot_1_id); + let snapshot_2_id = Uuid::new_v4(); + store.register(&mut txn, db_id, 52, 112, snapshot_2_id); + txn.commit().unwrap(); + + assert_eq!(store.locate(db_id, 0).unwrap().snapshot_id, snapshot_1_id); + assert_eq!(store.locate(db_id, 17).unwrap().snapshot_id, snapshot_1_id); + assert_eq!(store.locate(db_id, 51).unwrap().snapshot_id, snapshot_1_id); + assert_eq!(store.locate(db_id, 52).unwrap().snapshot_id, snapshot_2_id); + assert_eq!(store.locate(db_id, 100).unwrap().snapshot_id, snapshot_2_id); + assert_eq!(store.locate(db_id, 112).unwrap().snapshot_id, snapshot_2_id); + assert!(store.locate(db_id, 12345).is_none()); + } + + #[test] + fn multiple_databases() { + let temp = tempfile::tempdir().unwrap(); + let env = heed::EnvOpenOptions::new() + .max_dbs(10) + .map_size(1000 * 4096) + .open(temp.path()) + .unwrap(); + let store = SnapshotStore::new(temp.path().to_path_buf(), env).unwrap(); + let mut txn = store.env.write_txn().unwrap(); + let db_id1 = DatabaseId::random(); + let db_id2 = DatabaseId::random(); + let snapshot_id1 = Uuid::new_v4(); + let snapshot_id2 = Uuid::new_v4(); + store.register(&mut txn, db_id1, 0, 51, snapshot_id1); + store.register(&mut txn, db_id2, 0, 51, snapshot_id2); + txn.commit().unwrap(); + + assert_eq!(store.locate(db_id1, 0).unwrap().snapshot_id, snapshot_id1); + assert_eq!(store.locate(db_id2, 0).unwrap().snapshot_id, snapshot_id2); + + assert_eq!(store.locate(db_id1, 12).unwrap().snapshot_id, snapshot_id1); + assert_eq!(store.locate(db_id2, 18).unwrap().snapshot_id, snapshot_id2); + + assert_eq!(store.locate(db_id1, 51).unwrap().snapshot_id, 
snapshot_id1); + assert_eq!(store.locate(db_id2, 51).unwrap().snapshot_id, snapshot_id2); + + assert!(store.locate(db_id1, 52).is_none()); + assert!(store.locate(db_id2, 52).is_none()); } } From 992256838b4d1d2e2c424913a06dc3634652c5c6 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 14:54:09 +0200 Subject: [PATCH 40/64] add compactor test --- libsqlx-server/src/compactor.rs | 155 +++++++++++++++++++++++++++++--- 1 file changed, 141 insertions(+), 14 deletions(-) diff --git a/libsqlx-server/src/compactor.rs b/libsqlx-server/src/compactor.rs index 22f31c6a..ac88a455 100644 --- a/libsqlx-server/src/compactor.rs +++ b/libsqlx-server/src/compactor.rs @@ -1,3 +1,4 @@ +use std::fs::File; use std::io::{BufWriter, Write}; use std::mem::size_of; use std::os::unix::prelude::FileExt; @@ -7,7 +8,8 @@ use std::sync::{ Arc, }; -use bytemuck::{bytes_of, Pod, Zeroable}; +use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable}; +use bytes::{Bytes, BytesMut}; use heed::byteorder::BigEndian; use heed_types::{SerdeBincode, U64}; use libsqlx::libsql::LogFile; @@ -85,7 +87,7 @@ impl CompactionQueue { .wait_for(|x| x.map(|x| x >= id).unwrap_or_default()) .await .unwrap(); - block_in_place(|| { + block_in_place(|| { let txn = self.env.read_txn().unwrap(); self.queue.first(&txn).unwrap().unwrap() }) @@ -103,11 +105,11 @@ impl CompactionQueue { let (job_id, job) = self.peek().await; tracing::debug!("starting new compaction job: {job:?}"); let to_compact_path = self.snapshot_queue_dir().join(job.log_id.to_string()); - let (snapshot_id, start_fno, end_fno) = tokio::task::spawn_blocking({ + let (start_fno, end_fno) = tokio::task::spawn_blocking({ let to_compact_path = to_compact_path.clone(); let db_path = self.db_path.clone(); move || { - let mut builder = SnapshotBuilder::new(&db_path, job.database_id)?; + let mut builder = SnapshotBuilder::new(&db_path, job.database_id, job.log_id)?; let log = LogFile::new(to_compact_path)?; for frame in log.rev_deduped() { let frame = frame?; @@ -121,7 +123,7 @@ impl CompactionQueue { let mut txn = self.env.write_txn()?; self.complete(&mut txn, job_id); self.snapshot_store - .register(&mut txn, job.database_id, start_fno, end_fno, snapshot_id); + .register(&mut txn, job.database_id, start_fno, end_fno, job.log_id); txn.commit()?; std::fs::remove_file(to_compact_path)?; @@ -160,13 +162,14 @@ pub struct SnapshotFileHeader { /// A utility to build snapshots from log frames pub struct SnapshotBuilder { pub header: SnapshotFileHeader, + snapshot_id: Uuid, snapshot_file: BufWriter, db_path: PathBuf, last_seen_frame_no: u64, } impl SnapshotBuilder { - pub fn new(db_path: &Path, db_id: DatabaseId) -> color_eyre::Result { + pub fn new(db_path: &Path, db_id: DatabaseId, snapshot_id: Uuid) -> color_eyre::Result { + let temp_dir = db_path.join("tmp"); let mut target = BufWriter::new(NamedTempFile::new_in(&temp_dir)?); // reserve header space @@ -184,6 +187,7 @@ impl SnapshotBuilder { snapshot_file: target, db_path: db_path.to_path_buf(), last_seen_frame_no: u64::MAX, + snapshot_id, }) } @@ -206,19 +210,142 @@ impl SnapshotBuilder { } /// Persist the snapshot, and return the snapshot id along with its start and end frame numbers. 
- pub fn finish(mut self) -> color_eyre::Result<(Uuid, FrameNo, FrameNo)> { + pub fn finish(mut self) -> color_eyre::Result<(FrameNo, FrameNo)> { self.snapshot_file.flush()?; let file = self.snapshot_file.into_inner()?; file.as_file().write_all_at(bytes_of(&self.header), 0)?; - let snapshot_id = Uuid::new_v4(); - let path = self.db_path.join("snapshots").join(snapshot_id.to_string()); + let path = self + .db_path + .join("snapshots") + .join(self.snapshot_id.to_string()); file.persist(path)?; - Ok(( - snapshot_id, - self.header.start_frame_no, - self.header.end_frame_no, - )) + Ok((self.header.start_frame_no, self.header.end_frame_no)) + } +} + +pub struct SnapshotFile { + pub file: File, + pub header: SnapshotFileHeader, +} + +impl SnapshotFile { + pub fn open(path: &Path) -> color_eyre::Result { + let file = File::open(path)?; + let mut header_buf = [0; size_of::()]; + file.read_exact_at(&mut header_buf, 0)?; + let header: SnapshotFileHeader = pod_read_unaligned(&header_buf); + + Ok(Self { file, header }) + } + + /// Iterator on the frames contained in the snapshot file, in reverse frame_no order. + pub fn frames_iter(&self) -> impl Iterator> + '_ { + let mut current_offset = 0; + std::iter::from_fn(move || { + if current_offset >= self.header.frame_count { + return None; + } + let read_offset = size_of::() as u64 + + current_offset * LogFile::FRAME_SIZE as u64; + current_offset += 1; + let mut buf = BytesMut::zeroed(LogFile::FRAME_SIZE); + match self.file.read_exact_at(&mut buf, read_offset as _) { + Ok(_) => Some(Ok(buf.freeze())), + Err(e) => Some(Err(e.into())), + } + }) + } + + /// Like `frames_iter`, but stops as soon as a frame with frame_no <= `frame_no` is reached + pub fn frames_iter_from( + &self, + frame_no: u64, + ) -> impl Iterator> + '_ { + let mut iter = self.frames_iter(); + std::iter::from_fn(move || match iter.next() { + Some(Ok(bytes)) => match Frame::try_from_bytes(bytes.clone()) { + Ok(frame) => { + if frame.header().frame_no < frame_no { + None + } else { + Some(Ok(bytes)) + } + } + Err(e) => Some(Err(e)), + }, + other => other, + }) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashSet; + + use crate::init_dirs; + + use super::*; + + #[tokio::test(flavor = "multi_thread")] + async fn create_snapshot() { + let temp = tempfile::tempdir().unwrap(); + init_dirs(temp.path()).await.unwrap(); + let env = heed::EnvOpenOptions::new() + .max_dbs(100) + .map_size(1000 * 4096) + .open(temp.path().join("meta")) + .unwrap(); + let snapshot_store = SnapshotStore::new(temp.path().to_path_buf(), env.clone()).unwrap(); + let store = Arc::new(snapshot_store); + let queue = CompactionQueue::new(env, temp.path().to_path_buf(), store.clone()).unwrap(); + let log_id = Uuid::new_v4(); + let database_id = DatabaseId::random(); + + let log_path = temp.path().join("snapshot_queue").join(log_id.to_string()); + tokio::fs::copy("assets/test/simple-log", &log_path) + .await + .unwrap(); + + let log_file = LogFile::new(log_path).unwrap(); + let expected_start_frameno = log_file.header().start_frame_no; + let expected_end_frameno = + log_file.header().start_frame_no + log_file.header().frame_count - 1; + let mut expected_page_content = log_file + .frames_iter() + .unwrap() + .map(|f| f.unwrap().header().page_no) + .collect::>(); + + queue.push(&CompactionJob { + database_id, + log_id, + }); + + queue.compact().await.unwrap(); + + let snapshot_path = temp.path().join("snapshots").join(log_id.to_string()); + assert!(snapshot_path.exists()); + + let snapshot_file = 
SnapshotFile::open(&snapshot_path).unwrap(); + assert_eq!(snapshot_file.header.start_frame_no, expected_start_frameno); + assert_eq!(snapshot_file.header.end_frame_no, expected_end_frameno); + assert!(snapshot_file.frames_iter().all(|f| expected_page_content + .remove(&Frame::try_from_bytes(f.unwrap()).unwrap().header().page_no))); + assert!(expected_page_content.is_empty()); + + assert_eq!(snapshot_file + .frames_iter() + .map(Result::unwrap) + .map(Frame::try_from_bytes) + .map(Result::unwrap) + .map(|f| f.header().frame_no) + .reduce(|prev, new| { + assert!(new < prev); + new + }).unwrap(), 0); + + assert_eq!(store.locate(database_id, 0).unwrap().snapshot_id, log_id); } } From 884b2c84d37a8a8cb581485b9c998ff459e91584 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 14:54:24 +0200 Subject: [PATCH 41/64] remove snapshot for libsqlx --- .../libsql/replication_log/snapshot.rs | 334 ------------------ 1 file changed, 334 deletions(-) delete mode 100644 libsqlx/src/database/libsql/replication_log/snapshot.rs diff --git a/libsqlx/src/database/libsql/replication_log/snapshot.rs b/libsqlx/src/database/libsql/replication_log/snapshot.rs deleted file mode 100644 index b6bac186..00000000 --- a/libsqlx/src/database/libsql/replication_log/snapshot.rs +++ /dev/null @@ -1,334 +0,0 @@ -use std::collections::HashSet; -use std::fs::File; -use std::io::BufWriter; -use std::io::Write; -use std::mem::size_of; -use std::os::unix::prelude::FileExt; -use std::path::{Path, PathBuf}; -use std::str::FromStr; - -use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable}; -use bytes::{Bytes, BytesMut}; -use once_cell::sync::Lazy; -use regex::Regex; -use tempfile::NamedTempFile; -use uuid::Uuid; - -use crate::database::frame::Frame; - -use super::logger::LogFile; -use super::FrameNo; - -/// This is the ratio of the space required to store snapshot vs size of the actual database. -/// When this ratio is exceeded, compaction is triggered. 
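Taken together, the comment above and the two constants below amount to a compaction trigger; a sketch of that predicate follows (the function and its arguments are illustrative, only the constants come from the code being removed):

    // Compaction is due once snapshots occupy more than twice the database
    // size, or once too many snapshot files have accumulated.
    fn compaction_due(snapshot_bytes: u64, db_bytes: u64, snapshot_count: usize) -> bool {
        const SNAPHOT_SPACE_AMPLIFICATION_FACTOR: u64 = 2;
        const MAX_SNAPSHOT_NUMBER: usize = 32;
        snapshot_bytes > SNAPHOT_SPACE_AMPLIFICATION_FACTOR * db_bytes
            || snapshot_count > MAX_SNAPSHOT_NUMBER
    }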
-pub const SNAPHOT_SPACE_AMPLIFICATION_FACTOR: u64 = 2; -/// The maximum amount of snapshot allowed before a compaction is required -pub const MAX_SNAPSHOT_NUMBER: usize = 32; - -#[derive(Debug, Copy, Clone, Zeroable, Pod, PartialEq, Eq)] -#[repr(C)] -pub struct SnapshotFileHeader { - /// id of the database - pub db_id: u128, - /// first frame in the snapshot - pub start_frame_no: u64, - /// end frame in the snapshot - pub end_frame_no: u64, - /// number of frames in the snapshot - pub frame_count: u64, - /// safe of the database after applying the snapshot - pub size_after: u32, - pub _pad: u32, -} - -pub struct SnapshotFile { - pub file: File, - pub header: SnapshotFileHeader, -} - -/// returns (db_id, start_frame_no, end_frame_no) for the given snapshot name -pub fn parse_snapshot_name(name: &str) -> Option<(Uuid, u64, u64)> { - static SNAPSHOT_FILE_MATCHER: Lazy = Lazy::new(|| { - Regex::new( - r#"(?x) - # match database id - (\w{8}-\w{4}-\w{4}-\w{4}-\w{12})- - # match start frame_no - (\d*)- - # match end frame_no - (\d*).snap"#, - ) - .unwrap() - }); - let Some(captures) = SNAPSHOT_FILE_MATCHER.captures(name) else { return None}; - let db_id = captures.get(1).unwrap(); - let start_index: u64 = captures.get(2).unwrap().as_str().parse().unwrap(); - let end_index: u64 = captures.get(3).unwrap().as_str().parse().unwrap(); - - Some(( - Uuid::from_str(db_id.as_str()).unwrap(), - start_index, - end_index, - )) -} - -pub fn snapshot_list(db_path: &Path) -> anyhow::Result> { - let mut entries = std::fs::read_dir(snapshot_dir_path(db_path))?; - Ok(std::iter::from_fn(move || { - for entry in entries.by_ref() { - let Ok(entry) = entry else { continue; }; - let path = entry.path(); - let Some(name) = path.file_name() else {continue;}; - let Some(name_str) = name.to_str() else { continue;}; - - return Some(name_str.to_string()); - } - None - })) -} - -/// Return snapshot file containing "logically" frame_no -pub fn find_snapshot_file( - db_path: &Path, - frame_no: FrameNo, -) -> anyhow::Result> { - let snapshot_dir_path = snapshot_dir_path(db_path); - for name in snapshot_list(db_path)? { - let Some((_, start_frame_no, end_frame_no)) = parse_snapshot_name(&name) else { continue; }; - // we're looking for the frame right after the last applied frame on the replica - if (start_frame_no..=end_frame_no).contains(&frame_no) { - let snapshot_path = snapshot_dir_path.join(&name); - tracing::debug!("found snapshot for frame {frame_no} at {snapshot_path:?}"); - let snapshot_file = SnapshotFile::open(&snapshot_path)?; - return Ok(Some(snapshot_file)); - } - } - - Ok(None) -} - -impl SnapshotFile { - pub fn open(path: &Path) -> anyhow::Result { - let file = File::open(path)?; - let mut header_buf = [0; size_of::()]; - file.read_exact_at(&mut header_buf, 0)?; - let header: SnapshotFileHeader = pod_read_unaligned(&header_buf); - - Ok(Self { file, header }) - } - - /// Iterator on the frames contained in the snapshot file, in reverse frame_no order. 
- pub fn frames_iter(&self) -> impl Iterator> + '_ { - let mut current_offset = 0; - std::iter::from_fn(move || { - if current_offset >= self.header.frame_count { - return None; - } - let read_offset = size_of::() as u64 - + current_offset * LogFile::FRAME_SIZE as u64; - current_offset += 1; - let mut buf = BytesMut::zeroed(LogFile::FRAME_SIZE); - match self.file.read_exact_at(&mut buf, read_offset as _) { - Ok(_) => Some(Ok(buf.freeze())), - Err(e) => Some(Err(e.into())), - } - }) - } - - /// Like `frames_iter`, but stops as soon as a frame with frame_no <= `frame_no` is reached - pub fn frames_iter_from( - &self, - frame_no: u64, - ) -> impl Iterator> + '_ { - let mut iter = self.frames_iter(); - std::iter::from_fn(move || match iter.next() { - Some(Ok(bytes)) => match Frame::try_from_bytes(bytes.clone()) { - Ok(frame) => { - if frame.header().frame_no < frame_no { - None - } else { - Some(Ok(bytes)) - } - } - Err(e) => Some(Err(e)), - }, - other => other, - }) - } -} - -/// An utility to build a snapshots from log frames -pub struct SnapshotBuilder { - seen_pages: HashSet, - pub header: SnapshotFileHeader, - snapshot_file: BufWriter, - db_path: PathBuf, - last_seen_frame_no: u64, -} - -pub fn snapshot_dir_path(db_path: &Path) -> PathBuf { - db_path.join("snapshots") -} - -impl SnapshotBuilder { - pub fn new(db_path: &Path, db_id: u128) -> anyhow::Result { - let snapshot_dir_path = snapshot_dir_path(db_path); - std::fs::create_dir_all(&snapshot_dir_path)?; - let mut target = BufWriter::new(NamedTempFile::new_in(&snapshot_dir_path)?); - // reserve header space - target.write_all(&[0; size_of::()])?; - - Ok(Self { - seen_pages: HashSet::new(), - header: SnapshotFileHeader { - db_id, - start_frame_no: u64::MAX, - end_frame_no: u64::MIN, - frame_count: 0, - size_after: 0, - _pad: 0, - }, - snapshot_file: target, - db_path: db_path.to_path_buf(), - last_seen_frame_no: u64::MAX, - }) - } - - /// append frames to the snapshot. Frames must be in decreasing frame_no order. - pub fn append_frames( - &mut self, - frames: impl Iterator>, - ) -> anyhow::Result<()> { - // We iterate on the frames starting from the end of the log and working our way backward. We - // make sure that only the most recent version of each file is present in the resulting - // snapshot. - // - // The snapshot file contains the most recent version of each page, in descending frame - // number order. That last part is important for when we read it later on. - for frame in frames { - let frame = frame?; - assert!(frame.header().frame_no < self.last_seen_frame_no); - self.last_seen_frame_no = frame.header().frame_no; - if frame.header().frame_no < self.header.start_frame_no { - self.header.start_frame_no = frame.header().frame_no; - } - - if frame.header().frame_no > self.header.end_frame_no { - self.header.end_frame_no = frame.header().frame_no; - self.header.size_after = frame.header().size_after; - } - - if !self.seen_pages.contains(&frame.header().page_no) { - self.seen_pages.insert(frame.header().page_no); - self.snapshot_file.write_all(frame.as_slice())?; - self.header.frame_count += 1; - } - } - - Ok(()) - } - - /// Persist the snapshot, and returns the name and size is frame on the snapshot. 
- pub fn finish(mut self) -> anyhow::Result<(String, u64)> { - self.snapshot_file.flush()?; - let file = self.snapshot_file.into_inner()?; - file.as_file().write_all_at(bytes_of(&self.header), 0)?; - let snapshot_name = format!( - "{}-{}-{}.snap", - Uuid::from_u128(self.header.db_id), - self.header.start_frame_no, - self.header.end_frame_no, - ); - - file.persist(snapshot_dir_path(&self.db_path).join(&snapshot_name))?; - - Ok((snapshot_name, self.header.frame_count)) - } -} - -// #[cfg(test)] -// mod test { -// use std::fs::read; -// use std::{thread, time::Duration}; -// -// use bytemuck::pod_read_unaligned; -// use bytes::Bytes; -// use tempfile::tempdir; -// -// use crate::database::frame::FrameHeader; -// use crate::database::libsql::replication_log::logger::WalPage; -// -// use super::*; -// -// #[test] -// fn compact_file_create_snapshot() { -// let temp = tempfile::NamedTempFile::new().unwrap(); -// let mut log_file = LogFile::new(temp.as_file().try_clone().unwrap(), 0).unwrap(); -// let db_id = Uuid::new_v4(); -// log_file.header.db_id = db_id.as_u128(); -// log_file.write_header().unwrap(); -// -// // add 50 pages, each one in two versions -// for _ in 0..2 { -// for i in 0..25 { -// let data = std::iter::repeat(0).take(4096).collect::(); -// let page = WalPage { -// page_no: i, -// size_after: i + 1, -// data, -// }; -// log_file.push_page(&page).unwrap(); -// } -// } -// -// log_file.commit().unwrap(); -// -// let dump_dir = tempdir().unwrap(); -// let compactor = LogCompactor::new(dump_dir.path(), db_id.as_u128()).unwrap(); -// compactor -// .compact(log_file, temp.path().to_path_buf(), 25) -// .unwrap(); -// -// thread::sleep(Duration::from_secs(1)); -// -// let snapshot_path = -// snapshot_dir_path(dump_dir.path()).join(format!("{}-{}-{}.snap", db_id, 0, 49)); -// let snapshot = read(&snapshot_path).unwrap(); -// let header: SnapshotFileHeader = -// pod_read_unaligned(&snapshot[..std::mem::size_of::()]); -// -// assert_eq!(header.start_frame_no, 0); -// assert_eq!(header.end_frame_no, 49); -// assert_eq!(header.frame_count, 25); -// assert_eq!(header.db_id, db_id.as_u128()); -// assert_eq!(header.size_after, 25); -// -// let mut seen_frames = HashSet::new(); -// let mut seen_page_no = HashSet::new(); -// let data = &snapshot[std::mem::size_of::()..]; -// data.chunks(LogFile::FRAME_SIZE).for_each(|f| { -// let frame = Frame::try_from_bytes(Bytes::copy_from_slice(f)).unwrap(); -// assert!(!seen_frames.contains(&frame.header().frame_no)); -// assert!(!seen_page_no.contains(&frame.header().page_no)); -// seen_page_no.insert(frame.header().page_no); -// seen_frames.insert(frame.header().frame_no); -// assert!(frame.header().frame_no >= 25); -// }); -// -// assert_eq!(seen_frames.len(), 25); -// assert_eq!(seen_page_no.len(), 25); -// -// let snapshot_file = SnapshotFile::open(&snapshot_path).unwrap(); -// -// let frames = snapshot_file.frames_iter_from(0); -// let mut expected_frame_no = 49; -// for frame in frames { -// let frame = frame.unwrap(); -// let header: FrameHeader = pod_read_unaligned(&frame[..size_of::()]); -// assert_eq!(header.frame_no, expected_frame_no); -// expected_frame_no -= 1; -// } -// -// assert_eq!(expected_frame_no, 24); -// } -// } From c987f5cca4783bfd721e9010dc953333ec88853a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 14:54:53 +0200 Subject: [PATCH 42/64] add missing Poll::Ready for Database::Poll --- libsqlx-server/src/allocation/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 6b3e15d7..69a3a3a1 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -76,7 +76,7 @@ impl Database { tokio::task::spawn_blocking(move || { db.compact_log(); }); - return Poll::Ready(()) + return Poll::Ready(()); } Poll::Pending From 6337d9cc3c30bc3d6c6417ac4e043757106a8bbc Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 14:57:16 +0200 Subject: [PATCH 43/64] fix periodic compaction --- libsqlx-server/src/allocation/mod.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 69a3a3a1..f6d87b24 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -113,7 +113,7 @@ impl Database { ) .unwrap(); - let compact_interval = replication_log_compact_interval.map(|d| { + let compact_interval = replication_log_compact_interval.map(|d| { let mut i = tokio::time::interval(d / 2); i.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); Box::pin(i) @@ -228,10 +228,9 @@ impl ConnectionHandle { impl Allocation { pub async fn run(mut self) { loop { - dbg!(); let fut = poll_fn(|cx| self.database.poll(cx)); tokio::select! { - _ = fut => dbg!(), + _ = fut => (), Some(msg) = self.inbox.recv() => { match msg { AllocationMessage::HranaPipelineReq { req, ret } => { @@ -247,15 +246,11 @@ impl Allocation { } }, maybe_id = self.connections_futs.join_next(), if !self.connections_futs.is_empty() => { - dbg!(); - if let Some(Ok(_id)) = maybe_id { - // self.connections.remove_entry(&id); + if let Some(Ok((node_id, conn_id))) = maybe_id { + self.connections.get_mut(&node_id).map(|m| m.remove(&conn_id)); } }, - else => { - dbg!(); - break - }, + else => break, } } } @@ -500,6 +495,7 @@ impl Connection { mod test { use std::time::Duration; + use libsqlx::result_builder::ResultBuilder; use tokio::sync::Notify; use crate::allocation::replica::ReplicaConnection; From c05dc09e2b3a43ab86f7d81825afa5dfd4c0c181 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 14:57:32 +0200 Subject: [PATCH 44/64] spawn compactor loop --- libsqlx-server/src/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index b1856e9e..1392e325 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -1,10 +1,10 @@ use std::fs::read_to_string; -use std::path::{PathBuf, Path}; +use std::path::{Path, PathBuf}; use std::sync::Arc; use clap::Parser; use color_eyre::eyre::Result; -use compactor::{CompactionQueue, run_compactor_loop}; +use compactor::{run_compactor_loop, CompactionQueue}; use config::{AdminApiConfig, ClusterConfig, UserApiConfig}; use http::admin::run_admin_api; use http::user::run_user_api; @@ -116,7 +116,7 @@ async fn main() -> Result<()> { .map_size(100 * 1024 * 1024) .open(config.db_path.join("meta"))?; - let snapshot_store = Arc::new(SnapshotStore::new(config.db_path.clone(), &env)?); + let snapshot_store = Arc::new(SnapshotStore::new(config.db_path.clone(), env.clone())?); let compaction_queue = Arc::new(CompactionQueue::new( env.clone(), config.db_path.clone(), From f86954c32f9f2ad48e6daa07b5a5363bd184bc68 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 14:57:47 +0200 Subject: [PATCH 45/64] changes to compactor trait --- Cargo.lock | 1 + libsqlx-server/Cargo.toml | 1 + libsqlx/src/database/libsql/injector/mod.rs | 11 
+++----- libsqlx/src/database/libsql/mod.rs | 26 +++++++++-------- .../database/libsql/replication_log/logger.rs | 28 ++++++------------- .../database/libsql/replication_log/mod.rs | 3 +- 6 files changed, 30 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a46f82e5..53a34d6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2619,6 +2619,7 @@ dependencies = [ "tracing-subscriber", "turmoil", "uuid", + "walkdir", ] [[package]] diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index b925732a..1ce8c6c8 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -48,3 +48,4 @@ uuid = { version = "1.4.0", features = ["v4", "serde"] } [dev-dependencies] turmoil = "0.5.5" +walkdir = "2.3.3" diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index 0c2c2207..7580bd5f 100644 --- a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -171,15 +171,14 @@ impl Injector { #[cfg(test)] mod test { - use std::fs::File; + use std::path::PathBuf; use crate::database::libsql::injector::Injector; use crate::database::libsql::replication_log::logger::LogFile; #[test] fn test_simple_inject_frames() { - let file = File::open("assets/test/simple_wallog").unwrap(); - let log = LogFile::new(file).unwrap(); + let log = LogFile::new(PathBuf::from("assets/test/simple_wallog")).unwrap(); let temp = tempfile::tempdir().unwrap(); let mut injector = Injector::new(temp.path(), Box::new(()), 10).unwrap(); @@ -199,8 +198,7 @@ mod test { #[test] fn test_inject_frames_split_txn() { - let file = File::open("assets/test/simple_wallog").unwrap(); - let log = LogFile::new(file).unwrap(); + let log = LogFile::new(PathBuf::from("assets/test/simple_wallog")).unwrap(); let temp = tempfile::tempdir().unwrap(); // inject one frame at a time @@ -221,8 +219,7 @@ mod test { #[test] fn test_inject_partial_txn_isolated() { - let file = File::open("assets/test/simple_wallog").unwrap(); - let log = LogFile::new(file).unwrap(); + let log = LogFile::new(PathBuf::from("assets/test/simple_wallog")).unwrap(); let temp = tempfile::tempdir().unwrap(); // inject one frame at a time diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index a6fdceb1..1cc884a8 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -19,7 +19,6 @@ use self::replication_log::logger::FrameNotifierCb; pub use connection::LibsqlConnection; pub use replication_log::logger::{LogCompactor, LogFile}; -pub use replication_log::merger::SnapshotMerger; mod connection; mod injector; @@ -196,12 +195,12 @@ impl InjectableDatabase for LibsqlDatabase { #[cfg(test)] mod test { - use std::fs::File; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering::Relaxed; use parking_lot::Mutex; use rusqlite::types::Value; + use uuid::Uuid; use crate::connection::Connection; use crate::database::libsql::replication_log::logger::LogFile; @@ -238,8 +237,7 @@ mod test { .unwrap(); assert!(row.lock().is_empty()); - let file = File::open("assets/test/simple_wallog").unwrap(); - let log = LogFile::new(file).unwrap(); + let log = LogFile::new(PathBuf::from("assets/test/simple_wallog")).unwrap(); let mut injector = db.injector().unwrap(); log.frames_iter().unwrap().for_each(|f| { injector.inject(f.unwrap()).unwrap(); @@ -312,14 +310,16 @@ mod test { } fn compact( - &self, - _file: LogFile, - _path: PathBuf, - _size_after: u32, + &mut self, + _id: Uuid, ) -> Result<(), Box> { 
self.0.store(true, Relaxed); Ok(()) } + + fn snapshot_dir(&self) -> PathBuf { + todo!(); + } } let temp = tempfile::tempdir().unwrap(); @@ -353,13 +353,15 @@ mod test { } fn compact( - &self, - _file: LogFile, - _path: PathBuf, - _size_after: u32, + &mut self, + _id: Uuid, ) -> Result<(), Box> { unreachable!() } + + fn snapshot_dir(&self) -> PathBuf { + todo!() + } } let temp = tempfile::tempdir().unwrap(); diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 96520ef8..914153b5 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -26,7 +26,6 @@ use crate::database::frame::{Frame, FrameHeader}; #[cfg(feature = "bottomless")] use crate::libsql::ffi::SQLITE_IOERR_WRITE; -use super::snapshot::{find_snapshot_file, SnapshotFile}; use super::{FrameNo, CRC_64_GO_ISO, WAL_MAGIC, WAL_PAGE_SIZE}; init_static_wal_method!(REPLICATION_METHODS, ReplicationLoggerHook); @@ -767,8 +766,7 @@ pub trait LogCompactor: Sync + Send + 'static { impl LogCompactor for () { fn compact( &mut self, - log_name: String, - path: PathBuf, + _log_id: Uuid, ) -> Result<(), Box> { Ok(()) } @@ -776,6 +774,10 @@ impl LogCompactor for () { fn should_compact(&self, _file: &LogFile) -> bool { false } + + fn snapshot_dir(&self) -> PathBuf { + todo!() + } } pub type FrameNotifierCb = Box; @@ -784,7 +786,6 @@ pub struct ReplicationLogger { pub generation: Generation, pub log_file: RwLock, compactor: Box>, - db_path: PathBuf, /// a notifier channel other tasks can subscribe to, and get notified when new frames become /// available. pub new_frame_notifier: FrameNotifierCb, @@ -821,17 +822,11 @@ impl ReplicationLogger { if should_recover { Self::recover(log_file, data_path, compactor, new_frame_notifier) } else { - Self::from_log_file( - db_path.to_path_buf(), - log_file, - compactor, - new_frame_notifier, - ) + Self::from_log_file(log_file, compactor, new_frame_notifier) } } fn from_log_file( - db_path: PathBuf, log_file: LogFile, compactor: impl LogCompactor, new_frame_notifier: FrameNotifierCb, @@ -843,7 +838,6 @@ impl ReplicationLogger { generation: Generation::new(generation_start_frame_no), compactor: Box::new(Mutex::new(compactor)), log_file: RwLock::new(log_file), - db_path, new_frame_notifier, }) } @@ -885,7 +879,7 @@ impl ReplicationLogger { assert!(data_path.pop()); - Self::from_log_file(data_path, log_file, compactor, new_frame_notifier) + Self::from_log_file(log_file, compactor, new_frame_notifier) } pub fn database_id(&self) -> anyhow::Result { @@ -930,10 +924,6 @@ impl ReplicationLogger { .expect("there should be at least one frame after commit")) } - pub fn get_snapshot_file(&self, from: FrameNo) -> anyhow::Result> { - find_snapshot_file(&self.db_path, from) - } - pub fn get_frame(&self, frame_no: FrameNo) -> Result { self.log_file.read().frame(frame_no) } @@ -1036,8 +1026,8 @@ mod test { #[test] fn log_file_test_rollback() { - let f = tempfile::tempfile().unwrap(); - let mut log_file = LogFile::new(f).unwrap(); + let f = tempfile::NamedTempFile::new().unwrap(); + let mut log_file = LogFile::new(f.path().to_path_buf()).unwrap(); (0..5) .map(|i| WalPage { page_no: i, diff --git a/libsqlx/src/database/libsql/replication_log/mod.rs b/libsqlx/src/database/libsql/replication_log/mod.rs index 42b2a03f..32120285 100644 --- a/libsqlx/src/database/libsql/replication_log/mod.rs +++ b/libsqlx/src/database/libsql/replication_log/mod.rs @@ -1,8 +1,7 @@ use crc::Crc; pub mod logger; 
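After this commit a compactor no longer receives the sealed `LogFile`, its path, and the database size: the logger renames the sealed log into the compactor's `snapshot_dir()` under a fresh `Uuid` and hands over only that id. A minimal sketch of an implementor under the new trait shape (the exact boxed error bound is an assumption here; the queue-backed compactor in libsqlx-server is the real implementation):

    use std::path::PathBuf;
    use uuid::Uuid;
    use libsqlx::libsql::{LogCompactor, LogFile};

    struct DirCompactor {
        snapshot_dir: PathBuf,
    }

    impl LogCompactor for DirCompactor {
        fn should_compact(&self, file: &LogFile) -> bool {
            // E.g. compact as soon as the log holds any committed frames.
            file.header().frame_count > 0
        }

        // By the time this is called, the logger has already swapped the live
        // log out; the sealed one sits at `snapshot_dir().join(log_id.to_string())`.
        fn compact(
            &mut self,
            log_id: Uuid,
        ) -> Result<(), Box<dyn std::error::Error + Sync + Send + 'static>> {
            tracing::info!("log {log_id} sealed and ready for snapshot compaction");
            Ok(())
        }

        fn snapshot_dir(&self) -> PathBuf {
            self.snapshot_dir.clone()
        }
    }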
-pub mod merger; -pub mod snapshot; +// pub mod merger; pub const WAL_PAGE_SIZE: i32 = 4096; pub const WAL_MAGIC: u64 = u64::from_le_bytes(*b"SQLDWAL\0"); From c815ea82e545903caf946fe8cf20ce750f9d9ba5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 15:58:44 +0200 Subject: [PATCH 46/64] introduce SnapshotFrame --- libsqlx-server/src/compactor.rs | 72 +++++++++++++++++++++------- libsqlx-server/src/snapshot_store.rs | 3 +- 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/libsqlx-server/src/compactor.rs b/libsqlx-server/src/compactor.rs index ac88a455..517fca77 100644 --- a/libsqlx-server/src/compactor.rs +++ b/libsqlx-server/src/compactor.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fs::File; use std::io::{BufWriter, Write}; use std::mem::size_of; @@ -8,7 +9,7 @@ use std::sync::{ Arc, }; -use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable}; +use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable, try_from_bytes}; use bytes::{Bytes, BytesMut}; use heed::byteorder::BigEndian; use heed_types::{SerdeBincode, U64}; @@ -168,6 +169,39 @@ pub struct SnapshotBuilder { last_seen_frame_no: u64, } +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct SnapshotFrameHeader { + frame_no: FrameNo, + page_no: u32, + _pad: u32, +} + +#[derive(Clone)] +pub struct SnapshotFrame { + data: Bytes +} + +impl SnapshotFrame { + const SIZE: usize = size_of::() + 4096; + + pub fn try_from_bytes(data: Bytes) -> crate::Result { + if data.len() != Self::SIZE { + color_eyre::eyre::bail!("invalid snapshot frame") + } + + Ok(Self { data }) + + } + + pub fn header(&self) -> Cow { + let data = &self.data[..size_of::()]; + try_from_bytes(data) + .map(Cow::Borrowed) + .unwrap_or_else(|_| Cow::Owned(pod_read_unaligned(data))) + } +
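Each record in a snapshot file is therefore fixed-size: a 16-byte `SnapshotFrameHeader` (`frame_no`, `page_no`, padding) immediately followed by the 4096-byte page image, for `SnapshotFrame::SIZE` bytes in total. A sketch of assembling one such record from inside this module (the helper name is illustrative):

    use bytemuck::bytes_of;
    use bytes::{BufMut, Bytes, BytesMut};

    // Builds a buffer that `SnapshotFrame::try_from_bytes` accepts:
    // header first, then the raw page image.
    fn make_snapshot_frame(frame_no: FrameNo, page_no: u32, page: &[u8; 4096]) -> Bytes {
        let header = SnapshotFrameHeader { frame_no, page_no, _pad: 0 };
        let mut buf = BytesMut::with_capacity(SnapshotFrame::SIZE);
        buf.put_slice(bytes_of(&header)); // 16 bytes
        buf.put_slice(page); // 4096 bytes
        buf.freeze()
    }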
- pub fn frames_iter(&self) -> impl Iterator> + '_ { + pub fn frames_iter(&self) -> impl Iterator> + '_ { let mut current_offset = 0; std::iter::from_fn(move || { if current_offset >= self.header.frame_count { return None; } let read_offset = size_of::() as u64 - + current_offset * LogFile::FRAME_SIZE as u64; + + current_offset * SnapshotFrame::SIZE as u64; current_offset += 1; - let mut buf = BytesMut::zeroed(LogFile::FRAME_SIZE); + let mut buf = BytesMut::zeroed(SnapshotFrame::SIZE); match self.file.read_exact_at(&mut buf, read_offset as _) { - Ok(_) => Some(Ok(buf.freeze())), + Ok(_) => Some(Ok(SnapshotFrame { data: buf.freeze() })), Err(e) => Some(Err(e.into())), } }) @@ -262,18 +303,15 @@ impl SnapshotFile { pub fn frames_iter_from( &self, frame_no: u64, - ) -> impl Iterator> + '_ { + ) -> impl Iterator> + '_ { let mut iter = self.frames_iter(); std::iter::from_fn(move || match iter.next() { - Some(Ok(bytes)) => match Frame::try_from_bytes(bytes.clone()) { - Ok(frame) => { - if frame.header().frame_no < frame_no { - None - } else { - Some(Ok(bytes)) - } + Some(Ok(frame)) => { + if frame.header().frame_no < frame_no { + None + } else { + Some(Ok(frame)) } - Err(e) => Some(Err(e)), }, other => other, }) @@ -332,14 +370,12 @@ mod test { assert_eq!(snapshot_file.header.start_frame_no, expected_start_frameno); assert_eq!(snapshot_file.header.end_frame_no, expected_end_frameno); assert!(snapshot_file.frames_iter().all(|f| expected_page_content - .remove(&Frame::try_from_bytes(f.unwrap()).unwrap().header().page_no))); + .remove(&f.unwrap().header().page_no))); assert!(expected_page_content.is_empty()); assert_eq!(snapshot_file .frames_iter() .map(Result::unwrap) - .map(Frame::try_from_bytes) - .map(Result::unwrap) .map(|f| f.header().frame_no) .reduce(|prev, new| { assert!(new < prev); diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs index 68610c10..72977966 100644 --- a/libsqlx-server/src/snapshot_store.rs +++ b/libsqlx-server/src/snapshot_store.rs @@ -2,8 +2,7 @@ use std::mem::size_of; use std::path::PathBuf; use bytemuck::{Pod, Zeroable}; -use heed::BytesDecode; -use heed_types::{ByteSlice, CowType, SerdeBincode}; +use heed_types::{CowType, SerdeBincode}; use libsqlx::FrameNo; use serde::{Deserialize, Serialize}; use tokio::task::block_in_place; From e10d2f341f34de7b9d14c76626e23e61f1b36d94 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 26 Jul 2023 15:58:59 +0200 Subject: [PATCH 47/64] add missing test assets --- libsqlx-server/assets/test/simple-log | Bin 0 -> 20664 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 libsqlx-server/assets/test/simple-log diff --git a/libsqlx-server/assets/test/simple-log b/libsqlx-server/assets/test/simple-log new file mode 100644 index 0000000000000000000000000000000000000000..0cf5b0539c802b0e595616afe0e8cdf18e936254 GIT binary patch literal 20664 zcmeI)ze@sP9LMo{p6kbuaV-*I!y`dtLk$fR1kv-TPaL#5guK6;_dI^_9`0V-^Eoh^@s~D|zGLfVZvNxpw0;~} zy&SJxtUTI&Z+jxmIMn(g+C4Lr8u{95efr(@vKfCXD5_FJEb8GVVA&Vo| Date: Thu, 27 Jul 2023 10:40:47 +0200 Subject: [PATCH 48/64] replicate from snapshot --- libsqlx-server/src/allocation/mod.rs | 5 +- libsqlx-server/src/allocation/primary/mod.rs | 54 +++++++++++- libsqlx-server/src/compactor.rs | 75 ++++++++++------ libsqlx-server/src/snapshot_store.rs | 20 ++++- libsqlx/src/database/frame.rs | 2 - .../database/libsql/replication_log/logger.rs | 85 +++++-------------- .../database/libsql/replication_log/mod.rs | 3 - libsqlx/src/database/mod.rs | 2 +- libsqlx/src/lib.rs | 2 
+- 9 files changed, 146 insertions(+), 102 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index f6d87b24..0978b549 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -103,7 +103,7 @@ impl Database { Compactor::new( max_log_size, replication_log_compact_interval, - compaction_queue, + compaction_queue.clone(), database_id, ), false, @@ -124,6 +124,7 @@ impl Database { db: Arc::new(db), replica_streams: HashMap::new(), frame_notifier: receiver, + snapshot_store: compaction_queue.snapshot_store.clone(), }, compact_interval, } @@ -275,6 +276,7 @@ impl Allocation { db, replica_streams, frame_notifier, + snapshot_store, .. }, .. @@ -289,6 +291,7 @@ impl Allocation { dipatcher: self.dispatcher.clone() as _, notifier: frame_notifier.clone(), buffer: Vec::new(), + snapshot_store: snapshot_store.clone(), }; match replica_streams.entry(msg.from) { diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index 480a0e6a..ccd67c55 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -2,17 +2,19 @@ use std::collections::HashMap; use std::mem::size_of; use std::sync::Arc; use std::task::{Context, Poll}; +use std::time::Duration; use bytes::Bytes; use libsqlx::libsql::{LibsqlDatabase, PrimaryType}; use libsqlx::result_builder::ResultBuilder; -use libsqlx::{FrameNo, LogReadError, ReplicationLogger}; +use libsqlx::{Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger}; use tokio::task::block_in_place; use crate::linc::bus::Dispatch; use crate::linc::proto::{BuilderStep, Enveloppe, Frames, Message, StepError, Value}; use crate::linc::{Inbound, NodeId, Outbound}; use crate::meta::DatabaseId; +use crate::snapshot_store::SnapshotStore; use super::{ConnectionHandler, ExecFn, FRAMES_MESSAGE_MAX_COUNT}; @@ -24,6 +26,7 @@ pub struct PrimaryDatabase { pub db: Arc>, pub replica_streams: HashMap)>, pub frame_notifier: tokio::sync::watch::Receiver, + pub snapshot_store: Arc, } pub struct ProxyResponseBuilder { @@ -206,6 +209,7 @@ pub struct FrameStreamer { pub dipatcher: Arc, pub notifier: tokio::sync::watch::Receiver, pub buffer: Vec, + pub snapshot_store: Arc, } impl FrameStreamer { @@ -234,7 +238,53 @@ impl FrameStreamer { } } Err(LogReadError::Error(_)) => todo!("handle log read error"), - Err(LogReadError::SnapshotRequired) => todo!("handle reading from snapshot"), + Err(LogReadError::SnapshotRequired) => self.send_snapshot().await, + } + } + } + + async fn send_snapshot(&mut self) { + tracing::debug!("sending frames from snapshot"); + loop { + match self + .snapshot_store + .locate_file(self.database_id, self.next_frame_no) + { + Some(file) => { + let mut iter = file.frames_iter_from(self.next_frame_no).peekable(); + + while let Some(frame) = block_in_place(|| iter.next()) { + let frame = frame.unwrap(); + // TODO: factorize in maybe_send + if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT { + self.send_frames().await; + } + let size_after = iter + .peek() + .is_none() + .then_some(file.header.size_after) + .unwrap_or(0); + let frame = Frame::from_parts( + &FrameHeader { + frame_no: frame.header().frame_no, + page_no: frame.header().page_no, + size_after, + }, + frame.page(), + ); + self.next_frame_no = frame.header().frame_no + 1; + self.buffer.push(frame.bytes()); + + tokio::task::yield_now().await; + } + + break; + } + None => { + // snapshot is not ready yet, wait a bit + // FIXME: notify when 
snapshot becomes ready instead of using loop + tokio::time::sleep(Duration::from_millis(100)).await; + } } } } diff --git a/libsqlx-server/src/compactor.rs b/libsqlx-server/src/compactor.rs index 517fca77..2039343a 100644 --- a/libsqlx-server/src/compactor.rs +++ b/libsqlx-server/src/compactor.rs @@ -9,7 +9,7 @@ use std::sync::{ Arc, }; -use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable, try_from_bytes}; +use bytemuck::{bytes_of, pod_read_unaligned, try_from_bytes, Pod, Zeroable}; use bytes::{Bytes, BytesMut}; use heed::byteorder::BigEndian; use heed_types::{SerdeBincode, U64}; @@ -38,7 +38,7 @@ pub struct CompactionQueue { next_id: AtomicU64, notify: watch::Sender>, db_path: PathBuf, - snapshot_store: Arc, + pub snapshot_store: Arc, } impl CompactionQueue { @@ -110,9 +110,17 @@ impl CompactionQueue { let to_compact_path = to_compact_path.clone(); let db_path = self.db_path.clone(); move || { - let mut builder = SnapshotBuilder::new(&db_path, job.database_id, job.log_id)?; let log = LogFile::new(to_compact_path)?; - for frame in log.rev_deduped() { + let (start_fno, end_fno, iter) = + log.rev_deduped().expect("compaction job with no frames!"); + let mut builder = SnapshotBuilder::new( + &db_path, + job.database_id, + job.log_id, + start_fno, + end_fno, + )?; + for frame in iter { let frame = frame?; builder.push_frame(frame)?; } @@ -172,14 +180,14 @@ pub struct SnapshotBuilder { #[derive(Debug, Clone, Copy, Pod, Zeroable)] #[repr(C)] pub struct SnapshotFrameHeader { - frame_no: FrameNo, - page_no: u32, + pub frame_no: FrameNo, + pub page_no: u32, _pad: u32, } #[derive(Clone)] pub struct SnapshotFrame { - data: Bytes + data: Bytes, } impl SnapshotFrame { @@ -191,7 +199,6 @@ impl SnapshotFrame { } Ok(Self { data }) - } pub fn header(&self) -> Cow { @@ -200,10 +207,20 @@ impl SnapshotFrame { .map(Cow::Borrowed) .unwrap_or_else(|_| Cow::Owned(pod_read_unaligned(data))) } + + pub(crate) fn page(&self) -> &[u8] { + &self.data[size_of::()..] 
+ } } impl SnapshotBuilder { - pub fn new(db_path: &Path, db_id: DatabaseId, snapshot_id: Uuid) -> color_eyre::Result { + pub fn new( + db_path: &Path, + db_id: DatabaseId, + snapshot_id: Uuid, + start_fno: FrameNo, + end_fno: FrameNo, + ) -> color_eyre::Result { let temp_dir = db_path.join("tmp"); let mut target = BufWriter::new(NamedTempFile::new_in(&temp_dir)?); // reserve header space @@ -212,8 +229,8 @@ impl SnapshotBuilder { Ok(Self { header: SnapshotFileHeader { db_id, - start_frame_no: u64::MAX, - end_frame_no: u64::MIN, + start_frame_no: start_fno, + end_frame_no: end_fno, frame_count: 0, size_after: 0, _pad: 0, @@ -228,14 +245,11 @@ impl SnapshotBuilder { pub fn push_frame(&mut self, frame: Frame) -> color_eyre::Result<()> { assert!(frame.header().frame_no < self.last_seen_frame_no); self.last_seen_frame_no = frame.header().frame_no; - if frame.header().frame_no < self.header.start_frame_no { - self.header.start_frame_no = frame.header().frame_no; - } - if frame.header().frame_no > self.header.end_frame_no { - self.header.end_frame_no = frame.header().frame_no; + if frame.header().frame_no == self.header.end_frame_no { self.header.size_after = frame.header().size_after; } + let header = SnapshotFrameHeader { frame_no: frame.header().frame_no, page_no: frame.header().page_no, @@ -306,13 +320,13 @@ impl SnapshotFile { ) -> impl Iterator> + '_ { let mut iter = self.frames_iter(); std::iter::from_fn(move || match iter.next() { - Some(Ok(frame)) => { + Some(Ok(frame)) => { if frame.header().frame_no < frame_no { None } else { Some(Ok(frame)) } - }, + } other => other, }) } @@ -369,18 +383,23 @@ mod test { let snapshot_file = SnapshotFile::open(&snapshot_path).unwrap(); assert_eq!(snapshot_file.header.start_frame_no, expected_start_frameno); assert_eq!(snapshot_file.header.end_frame_no, expected_end_frameno); - assert!(snapshot_file.frames_iter().all(|f| expected_page_content - .remove(&f.unwrap().header().page_no))); + assert!(snapshot_file + .frames_iter() + .all(|f| expected_page_content.remove(&f.unwrap().header().page_no))); assert!(expected_page_content.is_empty()); - assert_eq!(snapshot_file - .frames_iter() - .map(Result::unwrap) - .map(|f| f.header().frame_no) - .reduce(|prev, new| { - assert!(new < prev); - new - }).unwrap(), 0); + assert_eq!( + snapshot_file + .frames_iter() + .map(Result::unwrap) + .map(|f| f.header().frame_no) + .reduce(|prev, new| { + assert!(new < prev); + new + }) + .unwrap(), + 0 + ); assert_eq!(store.locate(database_id, 0).unwrap().snapshot_id, log_id); } diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs index 72977966..965c65a9 100644 --- a/libsqlx-server/src/snapshot_store.rs +++ b/libsqlx-server/src/snapshot_store.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; use tokio::task::block_in_place; use uuid::Uuid; -use crate::meta::DatabaseId; +use crate::{compactor::SnapshotFile, meta::DatabaseId}; #[derive(Clone, Copy, Zeroable, Pod, Debug)] #[repr(transparent)] @@ -91,6 +91,10 @@ impl SnapshotStore { end_frame_no: u64::MAX.into(), }; + for entry in self.database.lazily_decode_data().iter(&txn).unwrap() { + let (k, _) = entry.unwrap(); + } + match self .database .get_lower_than_or_equal_to(&txn, &key) @@ -102,6 +106,11 @@ impl SnapshotStore { } else if frame_no >= key.start_frame_no.into() && frame_no <= key.end_frame_no.into() { + tracing::debug!( + "found snapshot for {frame_no}; {}-{}", + u64::from(key.start_frame_no), + u64::from(key.end_frame_no) + ); return Some(v); } else { None @@ -110,6 +119,15 
@@ impl SnapshotStore { Err(_) => todo!(), } } + + pub fn locate_file(&self, database_id: DatabaseId, frame_no: FrameNo) -> Option { + let meta = self.locate(database_id, frame_no)?; + let path = self + .db_path + .join("snapshots") + .join(meta.snapshot_id.to_string()); + Some(SnapshotFile::open(&path).unwrap()) + } } #[cfg(test)] diff --git a/libsqlx/src/database/frame.rs b/libsqlx/src/database/frame.rs index ba2d638d..d3dd4a44 100644 --- a/libsqlx/src/database/frame.rs +++ b/libsqlx/src/database/frame.rs @@ -17,8 +17,6 @@ use super::FrameNo; pub struct FrameHeader { /// Incremental frame number pub frame_no: FrameNo, - /// Rolling checksum of all the previous frames, including this one. - pub checksum: u64, /// page number, if frame_type is FrameType::Page pub page_no: u32, /// Size of the database (in page) after commiting the transaction. This is passed from sqlite, diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 914153b5..d12af50d 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -7,7 +7,7 @@ use std::os::unix::prelude::FileExt; use std::path::{Path, PathBuf}; use std::sync::Arc; -use anyhow::{bail, ensure}; +use anyhow::bail; use bytemuck::{bytes_of, pod_read_unaligned, Pod, Zeroable}; use bytes::{Bytes, BytesMut}; use parking_lot::{Mutex, RwLock}; @@ -26,7 +26,7 @@ use crate::database::frame::{Frame, FrameHeader}; #[cfg(feature = "bottomless")] use crate::libsql::ffi::SQLITE_IOERR_WRITE; -use super::{FrameNo, CRC_64_GO_ISO, WAL_MAGIC, WAL_PAGE_SIZE}; +use super::{FrameNo, WAL_MAGIC, WAL_PAGE_SIZE}; init_static_wal_method!(REPLICATION_METHODS, ReplicationLoggerHook); @@ -354,10 +354,6 @@ pub struct LogFile { /// On rollback, this is reset to 0, so that everything that was written after the previous /// header.frame_count is ignored and can be overwritten pub(crate) uncommitted_frame_count: u64, - uncommitted_checksum: u64, - - /// checksum of the last commited frame - commited_checksum: u64, } #[derive(thiserror::Error, Debug)] @@ -392,10 +388,10 @@ impl LogFile { start_frame_no: 0, magic: WAL_MAGIC, page_size: WAL_PAGE_SIZE, - start_checksum: 0, db_id: db_id.as_u128(), frame_count: 0, sqld_version: Version::current().0, + _pad: 0, }; let mut this = Self { @@ -403,8 +399,6 @@ impl LogFile { file, header, uncommitted_frame_count: 0, - uncommitted_checksum: 0, - commited_checksum: 0, }; this.write_header()?; @@ -412,27 +406,12 @@ impl LogFile { Ok(this) } else { let header = Self::read_header(&file)?; - let mut this = Self { + Ok(Self { file, header, uncommitted_frame_count: 0, - uncommitted_checksum: 0, - commited_checksum: 0, path, - }; - - if let Some(last_commited) = this.last_commited_frame_no() { - // file is not empty, the starting checksum is the checksum from the last entry - let last_frame = this.frame(last_commited).unwrap(); - this.commited_checksum = last_frame.header().checksum; - this.uncommitted_checksum = last_frame.header().checksum; - } else { - // file contains no entry, start with the initial checksum from the file header. 
- this.commited_checksum = this.header.start_checksum; - this.uncommitted_checksum = this.header.start_checksum; - } - - Ok(this) + }) } } @@ -458,7 +437,6 @@ impl LogFile { pub fn commit(&mut self) -> crate::Result<()> { self.header.frame_count += self.uncommitted_frame_count; self.uncommitted_frame_count = 0; - self.commited_checksum = self.uncommitted_checksum; self.write_header()?; Ok(()) @@ -466,7 +444,6 @@ impl LogFile { fn rollback(&mut self) { self.uncommitted_frame_count = 0; - self.uncommitted_checksum = self.commited_checksum; } pub fn write_header(&mut self) -> crate::Result<()> { @@ -504,11 +481,20 @@ impl LogFile { }) } - /// Return a reversed iterator over the deduplicated frames in the log file. - pub fn rev_deduped(&self) -> impl Iterator> + '_ { + /// If the log contains any frames, returns (start_frameno, end_frameno, iter), where iter, is + /// a deduplicated reversed iterator over the frames in the log + pub fn rev_deduped( + &self, + ) -> Option<( + FrameNo, + FrameNo, + impl Iterator> + '_, + )> { let mut iter = self.rev_frames_iter(); let mut seen = HashSet::new(); - std::iter::from_fn(move || loop { + let start_fno = self.header().start_frame_no; + let end_fno = self.header().last_frame_no()?; + let iter = std::iter::from_fn(move || loop { match iter.next()? { Ok(frame) => { if !seen.contains(&frame.header().page_no) { @@ -518,21 +504,15 @@ impl LogFile { } Err(e) => return Some(Err(e)), } - }) - } + }); - fn compute_checksum(&self, page: &WalPage) -> u64 { - let mut digest = CRC_64_GO_ISO.digest_with_initial(self.uncommitted_checksum); - digest.update(&page.data); - digest.finalize() + Some((start_fno, end_fno, iter)) } pub fn push_page(&mut self, page: &WalPage) -> crate::Result<()> { - let checksum = self.compute_checksum(page); let frame = Frame::from_parts( &FrameHeader { frame_no: self.next_frame_no(), - checksum, page_no: page.page_no, size_after: page.size_after, }, @@ -547,7 +527,6 @@ impl LogFile { self.file.write_all_at(frame.as_slice(), byte_offset)?; self.uncommitted_frame_count += 1; - self.uncommitted_checksum = checksum; Ok(()) } @@ -616,13 +595,12 @@ impl LogFile { let new_header = LogFileHeader { start_frame_no: self.header.start_frame_no + self.header.frame_count, frame_count: 0, - start_checksum: self.commited_checksum, ..self.header }; new_log_file.header = new_header; new_log_file.write_header().unwrap(); - // swap old and new snapshot - atomic_rename(dbg!(&temp_log_path), dbg!(&self.path)).unwrap(); + // swap old and new log + atomic_rename(&temp_log_path, &self.path).unwrap(); std::mem::swap(&mut new_log_file.path, &mut self.path); let _ = std::mem::replace(self, new_log_file); compactor.compact(log_id).unwrap(); @@ -704,9 +682,7 @@ fn atomic_rename(p1: impl AsRef, p2: impl AsRef) -> anyhow::Result<( pub struct LogFileHeader { /// magic number: b"SQLDWAL\0" as u64 pub magic: u64, - /// Initial checksum value for the rolling CRC checksum - /// computed with the 64 bits CRC_64_GO_ISO - pub start_checksum: u64, + _pad: u64, /// Uuid of the database associated with this log. 
pub db_id: u128, /// Frame_no of the first frame in the log @@ -897,23 +873,6 @@ impl ReplicationLogger { Ok(()) } - #[allow(dead_code)] - fn compute_checksum(wal_header: &LogFileHeader, log_file: &LogFile) -> anyhow::Result { - tracing::debug!("computing WAL log running checksum..."); - let mut iter = log_file.frames_iter()?; - iter.try_fold(wal_header.start_checksum, |sum, frame| { - let frame = frame?; - let mut digest = CRC_64_GO_ISO.digest_with_initial(sum); - digest.update(frame.page()); - let cs = digest.finalize(); - ensure!( - cs == frame.header().checksum, - "invalid WAL file: invalid checksum" - ); - Ok(cs) - }) - } - /// commit the current transaction and returns the new top frame number fn commit(&self) -> anyhow::Result { let mut log_file = self.log_file.write(); diff --git a/libsqlx/src/database/libsql/replication_log/mod.rs b/libsqlx/src/database/libsql/replication_log/mod.rs index 32120285..a7f006ae 100644 --- a/libsqlx/src/database/libsql/replication_log/mod.rs +++ b/libsqlx/src/database/libsql/replication_log/mod.rs @@ -1,11 +1,8 @@ -use crc::Crc; - pub mod logger; // pub mod merger; pub const WAL_PAGE_SIZE: i32 = 4096; pub const WAL_MAGIC: u64 = u64::from_le_bytes(*b"SQLDWAL\0"); -const CRC_64_GO_ISO: Crc = Crc::::new(&crc::CRC_64_GO_ISO); /// The frame uniquely identifying, monotonically increasing number pub type FrameNo = u64; diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index 368ac5ac..61c39c64 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -9,7 +9,7 @@ pub mod proxy; #[cfg(test)] mod test_utils; -pub use frame::Frame; +pub use frame::{Frame, FrameHeader}; pub type FrameNo = u64; diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index 899a7912..24441571 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -15,8 +15,8 @@ pub use database::libsql; pub use database::libsql::replication_log::logger::{LogReadError, ReplicationLogger}; pub use database::libsql::replication_log::FrameNo; pub use database::proxy; -pub use database::Frame; pub use database::{Database, InjectableDatabase, Injector}; +pub use database::{Frame, FrameHeader}; pub use sqld_libsql_bindings::wal_hook::WalHook; From 4594545d012458acd801a21aab7e1611ccb5e6f9 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 27 Jul 2023 13:00:34 +0200 Subject: [PATCH 49/64] add txn status change callback to libsql connection --- libsqlx-server/src/snapshot_store.rs | 4 ---- libsqlx/src/database/libsql/connection.rs | 25 ++++++++++++++--------- libsqlx/src/database/libsql/mod.rs | 1 + libsqlx/src/database/mod.rs | 4 ---- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs index 965c65a9..32f7f0e9 100644 --- a/libsqlx-server/src/snapshot_store.rs +++ b/libsqlx-server/src/snapshot_store.rs @@ -91,10 +91,6 @@ impl SnapshotStore { end_frame_no: u64::MAX.into(), }; - for entry in self.database.lazily_decode_data().iter(&txn).unwrap() { - let (k, _) = entry.unwrap(); - } - match self .database .get_lower_than_or_equal_to(&txn, &key) diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 1f0ea7ab..384b907c 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -1,12 +1,10 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::Instant; use rusqlite::{OpenFlags, Statement, StatementStatus}; use sqld_libsql_bindings::wal_hook::{WalHook, WalMethodsHook}; use 
crate::connection::{Connection, DescribeCol, DescribeParam, DescribeResponse};
-use crate::database::TXN_TIMEOUT;
 use crate::error::Error;
 use crate::program::{Cond, Program, Step};
 use crate::query::Query;
@@ -50,10 +48,12 @@ where
 }
 
 pub struct LibsqlConnection {
-    timeout_deadline: Option,
     conn: sqld_libsql_bindings::Connection<'static>, // holds a ref to _context, must be dropped first.
     row_stats_handler: Option>,
     builder_config: QueryBuilderConfig,
+    /// `true` if the connection is in an open transaction state
+    is_txn: bool,
+    on_txn_status_change_cb: Option>,
     _context: Seal::Context>>,
 }
 
@@ -65,6 +65,7 @@ impl LibsqlConnection {
         hook_ctx: ::Context,
         row_stats_callback: Option>,
         builder_config: QueryBuilderConfig,
+        on_txn_status_change_cb: Option>,
     ) -> Result> {
         let mut ctx = Box::new(hook_ctx);
         let this = LibsqlConnection {
@@ -74,9 +75,10 @@ impl LibsqlConnection {
                 unsafe { &mut *(ctx.as_mut() as *mut _) },
                 None,
             )?,
-            timeout_deadline: None,
+            on_txn_status_change_cb,
             builder_config,
             row_stats_handler: row_stats_callback,
+            is_txn: false,
             _context: Seal::new(ctx),
         };
 
@@ -105,18 +107,12 @@ impl LibsqlConnection {
         let mut results = Vec::with_capacity(pgm.steps.len());
         builder.init(&self.builder_config)?;
 
-        let is_autocommit_before = self.conn.is_autocommit();
-
         for step in pgm.steps() {
             let res = self.execute_step(step, &results, builder)?;
             results.push(res);
         }
 
-        // A transaction is still open, set up a timeout
-        if is_autocommit_before && !self.conn.is_autocommit() {
-            self.timeout_deadline = Some(Instant::now() + TXN_TIMEOUT)
-        }
-
         let is_txn = !self.conn.is_autocommit();
         if !builder.finnalize(is_txn, None)? && is_txn {
             let _ = self.conn.execute("ROLLBACK", ());
@@ -160,6 +156,15 @@ impl LibsqlConnection {
 
         builder.finish_step(affected_row_count, last_insert_rowid)?;
 
+        let is_txn = !self.conn.is_autocommit();
+        if self.is_txn != is_txn {
+            // txn status changed
+            if let Some(ref cb) = self.on_txn_status_change_cb {
+                cb(is_txn)
+            }
+        }
+        self.is_txn = is_txn;
+
         Ok(enabled)
     }
 
diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs
index 1cc884a8..a42bfdc7 100644
--- a/libsqlx/src/database/libsql/mod.rs
+++ b/libsqlx/src/database/libsql/mod.rs
@@ -178,6 +178,7 @@ impl Database for LibsqlDatabase {
             QueryBuilderConfig {
                 max_size: Some(self.response_size_limit),
             },
+            None,
         )?)
     }
 }
 
diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs
index 61c39c64..43fa0dac 100644
--- a/libsqlx/src/database/mod.rs
+++ b/libsqlx/src/database/mod.rs
@@ -1,5 +1,3 @@
-use std::time::Duration;
-
 use crate::connection::Connection;
 use crate::error::Error;
 
@@ -13,8 +11,6 @@ pub use frame::{Frame, FrameHeader};
 
 pub type FrameNo = u64;
 
-pub const TXN_TIMEOUT: Duration = Duration::from_secs(5);
-
 #[derive(Debug)]
 pub enum InjectError {}

From d0fb11192e839e7bd341fb1989ef4fd822c537d0 Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Thu, 27 Jul 2023 13:17:00 +0200
Subject: [PATCH 50/64] add method to register txn status change callback

---
 libsqlx/src/database/libsql/connection.rs | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs
index 384b907c..0ad8b780 100644
--- a/libsqlx/src/database/libsql/connection.rs
+++ b/libsqlx/src/database/libsql/connection.rs
@@ -222,6 +222,10 @@ impl LibsqlConnection {
 
         Ok((affected_row_count, last_insert_rowid))
     }
+
+    pub fn set_on_txn_status_change_cb(&mut self, cb: impl Fn(bool) + Send + Sync + 'static) {
+        self.on_txn_status_change_cb = Some(Box::new(cb));
+    }
 }
 
 fn eval_cond(cond: &Cond, results: &[bool]) -> Result {

From c8f095fd8efd26a4ba69996b62c4c55f39f0c80c Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Thu, 27 Jul 2023 13:31:54 +0200
Subject: [PATCH 51/64] set on txn status change callback for allocation
 connection

---
 libsqlx-server/src/allocation/mod.rs | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs
index 0978b549..75275335 100644
--- a/libsqlx-server/src/allocation/mod.rs
+++ b/libsqlx-server/src/allocation/mod.rs
@@ -163,22 +163,28 @@ impl Database {
         }
     }
 
-    fn connect(&self, connection_id: u32, alloc: &Allocation) -> impl ConnectionHandler {
+    fn connect(&self, connection_id: u32, alloc: &Allocation, on_txn_status_change_cb: impl Fn(bool) + Send + Sync + 'static) -> impl ConnectionHandler {
         match self {
             Database::Primary {
                 db: PrimaryDatabase { db, .. },
                 ..
-            } => Either::Right(PrimaryConnection {
-                conn: db.connect().unwrap(),
-            }),
-            Database::Replica { db, primary_id, .. } => Either::Left(ReplicaConnection {
-                conn: db.connect().unwrap(),
+            } => {
+                let mut conn = db.connect().unwrap();
+                conn.set_on_txn_status_change_cb(on_txn_status_change_cb);
+                Either::Right(PrimaryConnection {
+                    conn,
+                }) },
+            Database::Replica { db, primary_id, ..
} => { + let mut conn = db.connect().unwrap(); + conn.reader_mut().set_on_txn_status_change_cb(on_txn_status_change_cb); + Either::Left(ReplicaConnection { + conn, connection_id, next_req_id: 0, primary_node_id: *primary_id, database_id: DatabaseId::from_name(&alloc.db_name), dispatcher: alloc.dispatcher.clone(), - }), + }) }, } } @@ -395,7 +401,7 @@ impl Allocation { async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { let conn_id = self.next_conn_id(); - let conn = block_in_place(|| self.database.connect(conn_id, self)); + let conn = block_in_place(|| self.database.connect(conn_id, self, |_|())); let (exec_sender, exec_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); let id = remote.unwrap_or((self.dispatcher.node_id(), conn_id)); From 734f7ccd7076c014b0f97573e5612e6ce5b0c2e4 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 27 Jul 2023 15:33:10 +0200 Subject: [PATCH 52/64] connection monitor txn timeout --- libsqlx-server/src/allocation/mod.rs | 147 +++++++++++++------ libsqlx-server/src/allocation/primary/mod.rs | 16 +- libsqlx-server/src/allocation/replica.rs | 72 ++++----- libsqlx-server/src/hrana/batch.rs | 25 +--- libsqlx-server/src/hrana/stmt.rs | 16 +- 5 files changed, 168 insertions(+), 108 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 75275335..1aff2665 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -4,17 +4,20 @@ use std::future::poll_fn; use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; -use std::task::{ready, Context, Poll}; +use std::task::{ready, Context, Poll, Waker}; use std::time::Instant; use either::Either; +use futures::{Future, FutureExt}; use libsqlx::libsql::LibsqlDatabase; use libsqlx::program::Program; use libsqlx::proxy::WriteProxyDatabase; +use libsqlx::result_builder::ResultBuilder; use libsqlx::{Database as _, InjectableDatabase}; +use parking_lot::Mutex; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; -use tokio::time::Interval; +use tokio::time::{Interval, Sleep, sleep_until}; use crate::allocation::primary::FrameStreamer; use crate::compactor::CompactionQueue; @@ -40,7 +43,13 @@ const FRAMES_MESSAGE_MAX_COUNT: usize = 5; /// Maximum number of frames in the injector buffer const MAX_INJECTOR_BUFFER_CAPACITY: usize = 32; -type ExecFn = Box; +pub enum ConnectionMessage { + Execute { + pgm: Program, + builder: Box, + }, + Describe, +} pub enum AllocationMessage { HranaPipelineReq { @@ -210,25 +219,15 @@ pub struct Allocation { #[derive(Clone)] pub struct ConnectionHandle { - exec: mpsc::Sender, + messages: mpsc::Sender, inbound: mpsc::Sender, } impl ConnectionHandle { - pub async fn exec(&self, f: F) -> crate::Result - where - F: for<'a> FnOnce(&'a mut dyn libsqlx::Connection) -> R + Send + 'static, - R: Send + 'static, + pub async fn execute(&self, pgm: Program, builder: Box) -> crate::Result<()> { - let (sender, ret) = oneshot::channel(); - let cb = move |conn: &mut dyn libsqlx::Connection| { - let res = f(conn); - let _ = sender.send(res); - }; - - self.exec.send(Box::new(cb)).await.unwrap(); - - Ok(ret.await?) 
+ self.messages.send(ConnectionMessage::Execute { pgm, builder }).await.unwrap(); + Ok(()) } } @@ -368,18 +367,14 @@ impl Allocation { let dispatcher = self.dispatcher.clone(); let database_id = DatabaseId::from_name(&self.db_name); let exec = |conn: ConnectionHandle| async move { - let _ = conn - .exec(move |conn| { - let builder = ProxyResponseBuilder::new( - dispatcher, - database_id, - to, - req_id, - connection_id, - ); - conn.execute_program(&program, Box::new(builder)).unwrap(); - }) - .await; + let builder = ProxyResponseBuilder::new( + dispatcher, + database_id, + to, + req_id, + connection_id, + ); + conn.execute(program, Box::new(builder)).await.unwrap(); }; if self.database.is_primary() { @@ -402,19 +397,22 @@ impl Allocation { async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { let conn_id = self.next_conn_id(); let conn = block_in_place(|| self.database.connect(conn_id, self, |_|())); - let (exec_sender, exec_receiver) = mpsc::channel(1); + let (messages_sender, messages_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); let id = remote.unwrap_or((self.dispatcher.node_id(), conn_id)); + let (timeout_monitor, _) = timeout_monitor(); let conn = Connection { id, conn, - exec: exec_receiver, + messages: messages_receiver, inbound: inbound_receiver, + last_txn_timedout: false, + timeout_monitor, }; self.connections_futs.spawn(conn.run()); let handle = ConnectionHandle { - exec: exec_sender, + messages: messages_sender, inbound: inbound_sender, }; self.connections @@ -442,7 +440,7 @@ impl Allocation { #[async_trait::async_trait] trait ConnectionHandler: Send { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>; - async fn handle_exec(&mut self, exec: ExecFn); + async fn handle_conn_message(&mut self, exec: ConnectionMessage); async fn handle_inbound(&mut self, msg: Inbound); } @@ -459,10 +457,10 @@ where } } - async fn handle_exec(&mut self, exec: ExecFn) { + async fn handle_conn_message(&mut self, msg: ConnectionMessage) { match self { - Either::Left(l) => l.handle_exec(exec).await, - Either::Right(r) => r.handle_exec(exec).await, + Either::Left(l) => l.handle_conn_message(msg).await, + Either::Right(r) => r.handle_conn_message(msg).await, } } async fn handle_inbound(&mut self, msg: Inbound) { @@ -473,24 +471,89 @@ where } } + +fn timeout_monitor() -> (TimeoutMonitor, TimeoutNotifier) { + let inner = Arc::new(Mutex::new(TimeoutInner { + sleep: Box::pin(sleep_until(Instant::now().into())), + enabled: false, + waker: None, + })); + + (TimeoutMonitor { inner: inner.clone()}, TimeoutNotifier { inner }) +} + +struct TimeoutMonitor { + inner: Arc> +} + +struct TimeoutNotifier { + inner: Arc> +} + +impl TimeoutNotifier { + pub fn disable(&self) { + self.inner.lock().enabled = false; + } + + pub fn timeout_at(&self, at: Instant) { + let mut inner = self.inner.lock(); + inner.enabled = true; + inner.sleep.as_mut().reset(at.into()); + if let Some(waker) = inner.waker.take() { + waker.wake() + } + } +} + +struct TimeoutInner { + sleep: Pin>, + enabled: bool, + waker: Option, +} + +impl Future for TimeoutMonitor { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut inner = self.inner.lock(); + if inner.enabled { + inner.sleep.poll_unpin(cx) + } else { + inner.waker.replace(cx.waker().clone()); + Poll::Pending + } + } +} + struct Connection { id: (NodeId, u32), conn: C, - exec: mpsc::Receiver, + messages: mpsc::Receiver, inbound: mpsc::Receiver, + last_txn_timedout: bool, + 
timeout_monitor: TimeoutMonitor, } impl Connection { async fn run(mut self) -> (NodeId, u32) { loop { - let fut = - futures::future::join(self.exec.recv(), poll_fn(|cx| self.conn.poll_ready(cx))); + let message_ready = + futures::future::join(self.messages.recv(), poll_fn(|cx| self.conn.poll_ready(cx))); + tokio::select! { + _ = &mut self.timeout_monitor => { + self.last_txn_timedout = true; + } Some(inbound) = self.inbound.recv() => { self.conn.handle_inbound(inbound).await; } - (Some(exec), _) = fut => { - self.conn.handle_exec(exec).await; + (Some(msg), _) = message_ready => { + if self.last_txn_timedout{ + self.last_txn_timedout = false; + todo!("handle txn timeout"); + } else { + self.conn.handle_conn_message(msg).await; + } }, else => break, } @@ -537,7 +600,7 @@ mod test { let connection = Connection { id: (0, 0), conn, - exec, + messages, inbound, }; diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index ccd67c55..b06f3967 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -7,7 +7,7 @@ use std::time::Duration; use bytes::Bytes; use libsqlx::libsql::{LibsqlDatabase, PrimaryType}; use libsqlx::result_builder::ResultBuilder; -use libsqlx::{Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger}; +use libsqlx::{Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger, Connection}; use tokio::task::block_in_place; use crate::linc::bus::Dispatch; @@ -16,7 +16,7 @@ use crate::linc::{Inbound, NodeId, Outbound}; use crate::meta::DatabaseId; use crate::snapshot_store::SnapshotStore; -use super::{ConnectionHandler, ExecFn, FRAMES_MESSAGE_MAX_COUNT}; +use super::{ConnectionHandler, FRAMES_MESSAGE_MAX_COUNT, ConnectionMessage}; pub mod compactor; @@ -317,8 +317,16 @@ impl ConnectionHandler for PrimaryConnection { Poll::Ready(()) } - async fn handle_exec(&mut self, exec: ExecFn) { - block_in_place(|| exec(&mut self.conn)); + async fn handle_conn_message(&mut self, msg: ConnectionMessage) { + match msg { + ConnectionMessage::Execute { pgm, builder } => { + self.conn.execute_program(&pgm, builder).unwrap() + } + ConnectionMessage::Describe => { + todo!() + } + + } } async fn handle_inbound(&mut self, _msg: Inbound) { diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs index ee8008d6..433bf2eb 100644 --- a/libsqlx-server/src/allocation/replica.rs +++ b/libsqlx-server/src/allocation/replica.rs @@ -9,7 +9,7 @@ use libsqlx::libsql::{LibsqlConnection, LibsqlDatabase, ReplicaType}; use libsqlx::program::Program; use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase}; use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder}; -use libsqlx::{DescribeResponse, Frame, FrameNo, Injector}; +use libsqlx::{DescribeResponse, Frame, FrameNo, Injector, Connection}; use parking_lot::Mutex; use tokio::sync::mpsc; use tokio::task::block_in_place; @@ -22,7 +22,7 @@ use crate::linc::Inbound; use crate::linc::{NodeId, Outbound}; use crate::meta::DatabaseId; -use super::{ConnectionHandler, ExecFn}; +use super::{ConnectionHandler, ConnectionMessage}; type ProxyConnection = WriteProxyConnection, RemoteConn>; pub type ProxyDatabase = WriteProxyDatabase, RemoteDb>; @@ -287,40 +287,46 @@ impl ConnectionHandler for ReplicaConnection { Poll::Ready(()) } - async fn handle_exec(&mut self, exec: ExecFn) { - block_in_place(|| exec(&mut self.conn)); - let msg = { - let mut lock = self.conn.writer().inner.current_req.lock(); - match *lock { - 
Some(ref mut req) if req.id.is_none() => { - let program = req - .pgm - .take() - .expect("unsent request should have a program"); - let req_id = self.next_req_id; - self.next_req_id += 1; - req.id = Some(req_id); - - let msg = Outbound { - to: self.primary_node_id, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::ProxyRequest { - connection_id: self.connection_id, - req_id, - program, - }, - }, - }; + async fn handle_conn_message(&mut self, msg: ConnectionMessage) { + match msg { + ConnectionMessage::Execute { pgm, builder } => { + self.conn.execute_program(&pgm, builder).unwrap(); + let msg = { + let mut lock = self.conn.writer().inner.current_req.lock(); + match *lock { + Some(ref mut req) if req.id.is_none() => { + let program = req + .pgm + .take() + .expect("unsent request should have a program"); + let req_id = self.next_req_id; + self.next_req_id += 1; + req.id = Some(req_id); + + let msg = Outbound { + to: self.primary_node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyRequest { + connection_id: self.connection_id, + req_id, + program, + }, + }, + }; + + Some(msg) + } + _ => None, + } + }; - Some(msg) + if let Some(msg) = msg { + self.dispatcher.dispatch(msg).await; } - _ => None, - } - }; - if let Some(msg) = msg { - self.dispatcher.dispatch(msg).await; + } + ConnectionMessage::Describe => (), } } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index a9ed0553..fab788d0 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -74,15 +74,10 @@ pub async fn execute_batch( db: &ConnectionHandle, pgm: Program, ) -> color_eyre::Result { - let fut = db - .exec(move |conn| -> color_eyre::Result<_> { - let (builder, ret) = HranaBatchProtoBuilder::new(); - conn.execute_program(&pgm, Box::new(builder))?; - Ok(ret) - }) - .await??; + let (builder, ret) = HranaBatchProtoBuilder::new(); + db.execute(pgm, Box::new(builder)).await?; - Ok(fut.await?) + Ok(ret.await?) } pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { @@ -111,17 +106,11 @@ pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { } pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { - let fut = conn - .exec(move |conn| -> color_eyre::Result<_> { - let (snd, rcv) = oneshot::channel(); - let builder = StepResultsBuilder::new(snd); - conn.execute_program(&pgm, Box::new(builder))?; - - Ok(rcv) - }) - .await??; + let (snd, rcv) = oneshot::channel(); + let builder = StepResultsBuilder::new(snd); + conn.execute(pgm, Box::new(builder)).await?; - fut.await?.into_iter().try_for_each(|result| match result { + rcv.await?.into_iter().try_for_each(|result| match result { StepResult::Ok => Ok(()), StepResult::Err(e) => match stmt_error_from_sqld_error(e) { Ok(stmt_err) => Err(anyhow!(stmt_err)), diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index 1a8c03f6..7753fa81 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -47,17 +47,11 @@ pub async fn execute_stmt( conn: &ConnectionHandle, query: Query, ) -> color_eyre::Result { - let fut = conn - .exec(move |conn| -> color_eyre::Result<_> { - let (builder, ret) = SingleStatementBuilder::new(); - let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute_program(&pgm, Box::new(builder))?; - - Ok(ret) - }) - .await??; - - fut.await? 
+ + let (builder, ret) = SingleStatementBuilder::new(); + let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); + conn.execute(pgm, Box::new(builder)).await?; + ret.await? .map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) { Ok(stmt_error) => anyhow!(stmt_error), Err(sqld_error) => anyhow!(sqld_error), From dbe2520fd651a29f587857f22b10fffd4d500773 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 27 Jul 2023 15:41:07 +0200 Subject: [PATCH 53/64] notify on transaction status change --- libsqlx-server/src/allocation/mod.rs | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 1aff2665..f1fb4733 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -5,7 +5,7 @@ use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll, Waker}; -use std::time::Instant; +use std::time::{Instant, Duration}; use either::Either; use futures::{Future, FutureExt}; @@ -395,12 +395,22 @@ impl Allocation { } async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { + // TODO: make that configurable + const TXN_TIMEOUT_DURATION: Duration = Duration::from_secs(5); + let conn_id = self.next_conn_id(); - let conn = block_in_place(|| self.database.connect(conn_id, self, |_|())); + let (timeout_monitor, notifier) = timeout_monitor(); + let conn = block_in_place(|| self.database.connect(conn_id, self, move |is_txn| { + if is_txn { + notifier.timeout_at(Instant::now() + TXN_TIMEOUT_DURATION); + } else { + notifier.disable(); + } + })); + let (messages_sender, messages_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); let id = remote.unwrap_or((self.dispatcher.node_id(), conn_id)); - let (timeout_monitor, _) = timeout_monitor(); let conn = Connection { id, conn, @@ -548,9 +558,14 @@ impl Connection { self.conn.handle_inbound(inbound).await; } (Some(msg), _) = message_ready => { - if self.last_txn_timedout{ + if self.last_txn_timedout { self.last_txn_timedout = false; - todo!("handle txn timeout"); + match msg { + ConnectionMessage::Execute { mut builder, .. 
} => { + let _ = builder.finnalize_error("transaction has timed out".into()); + }, + ConnectionMessage::Describe => todo!(), + } } else { self.conn.handle_conn_message(msg).await; } From 96e52d4b1af1d7cac728a82b127df777820c57e1 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 27 Jul 2023 15:56:41 +0200 Subject: [PATCH 54/64] move timeout_notifier to own module --- libsqlx-server/src/allocation/mod.rs | 151 +++++++----------- libsqlx-server/src/allocation/primary/mod.rs | 5 +- libsqlx-server/src/allocation/replica.rs | 3 +- .../src/allocation/timeout_notifier.rs | 88 ++++++++++ libsqlx-server/src/hrana/stmt.rs | 1 - 5 files changed, 145 insertions(+), 103 deletions(-) create mode 100644 libsqlx-server/src/allocation/timeout_notifier.rs diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index f1fb4733..2b348915 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -4,22 +4,21 @@ use std::future::poll_fn; use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; -use std::task::{ready, Context, Poll, Waker}; -use std::time::{Instant, Duration}; +use std::task::{ready, Context, Poll}; +use std::time::{Duration, Instant}; use either::Either; -use futures::{Future, FutureExt}; use libsqlx::libsql::LibsqlDatabase; use libsqlx::program::Program; use libsqlx::proxy::WriteProxyDatabase; use libsqlx::result_builder::ResultBuilder; use libsqlx::{Database as _, InjectableDatabase}; -use parking_lot::Mutex; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; -use tokio::time::{Interval, Sleep, sleep_until}; +use tokio::time::Interval; use crate::allocation::primary::FrameStreamer; +use crate::allocation::timeout_notifier::timeout_monitor; use crate::compactor::CompactionQueue; use crate::hrana; use crate::hrana::http::handle_pipeline; @@ -33,10 +32,12 @@ use self::config::{AllocConfig, DbConfig}; use self::primary::compactor::Compactor; use self::primary::{PrimaryConnection, PrimaryDatabase, ProxyResponseBuilder}; use self::replica::{ProxyDatabase, RemoteDb, ReplicaConnection, Replicator}; +use self::timeout_notifier::TimeoutMonitor; pub mod config; mod primary; mod replica; +mod timeout_notifier; /// Maximum number of frame a Frame message is allowed to contain const FRAMES_MESSAGE_MAX_COUNT: usize = 5; @@ -172,28 +173,34 @@ impl Database { } } - fn connect(&self, connection_id: u32, alloc: &Allocation, on_txn_status_change_cb: impl Fn(bool) + Send + Sync + 'static) -> impl ConnectionHandler { + fn connect( + &self, + connection_id: u32, + alloc: &Allocation, + on_txn_status_change_cb: impl Fn(bool) + Send + Sync + 'static, + ) -> impl ConnectionHandler { match self { Database::Primary { db: PrimaryDatabase { db, .. }, .. - } => { + } => { let mut conn = db.connect().unwrap(); conn.set_on_txn_status_change_cb(on_txn_status_change_cb); - Either::Right(PrimaryConnection { - conn, - }) }, + Either::Right(PrimaryConnection { conn }) + } Database::Replica { db, primary_id, .. 
} => { let mut conn = db.connect().unwrap(); - conn.reader_mut().set_on_txn_status_change_cb(on_txn_status_change_cb); + conn.reader_mut() + .set_on_txn_status_change_cb(on_txn_status_change_cb); Either::Left(ReplicaConnection { - conn, - connection_id, - next_req_id: 0, - primary_node_id: *primary_id, - database_id: DatabaseId::from_name(&alloc.db_name), - dispatcher: alloc.dispatcher.clone(), - }) }, + conn, + connection_id, + next_req_id: 0, + primary_node_id: *primary_id, + database_id: DatabaseId::from_name(&alloc.db_name), + dispatcher: alloc.dispatcher.clone(), + }) + } } } @@ -224,9 +231,15 @@ pub struct ConnectionHandle { } impl ConnectionHandle { - pub async fn execute(&self, pgm: Program, builder: Box) -> crate::Result<()> - { - self.messages.send(ConnectionMessage::Execute { pgm, builder }).await.unwrap(); + pub async fn execute( + &self, + pgm: Program, + builder: Box, + ) -> crate::Result<()> { + self.messages + .send(ConnectionMessage::Execute { pgm, builder }) + .await + .unwrap(); Ok(()) } } @@ -367,13 +380,8 @@ impl Allocation { let dispatcher = self.dispatcher.clone(); let database_id = DatabaseId::from_name(&self.db_name); let exec = |conn: ConnectionHandle| async move { - let builder = ProxyResponseBuilder::new( - dispatcher, - database_id, - to, - req_id, - connection_id, - ); + let builder = + ProxyResponseBuilder::new(dispatcher, database_id, to, req_id, connection_id); conn.execute(program, Box::new(builder)).await.unwrap(); }; @@ -400,13 +408,15 @@ impl Allocation { let conn_id = self.next_conn_id(); let (timeout_monitor, notifier) = timeout_monitor(); - let conn = block_in_place(|| self.database.connect(conn_id, self, move |is_txn| { - if is_txn { - notifier.timeout_at(Instant::now() + TXN_TIMEOUT_DURATION); - } else { - notifier.disable(); - } - })); + let conn = block_in_place(|| { + self.database.connect(conn_id, self, move |is_txn| { + if is_txn { + notifier.timeout_at(Instant::now() + TXN_TIMEOUT_DURATION); + } else { + notifier.disable(); + } + }) + }); let (messages_sender, messages_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); @@ -481,60 +491,6 @@ where } } - -fn timeout_monitor() -> (TimeoutMonitor, TimeoutNotifier) { - let inner = Arc::new(Mutex::new(TimeoutInner { - sleep: Box::pin(sleep_until(Instant::now().into())), - enabled: false, - waker: None, - })); - - (TimeoutMonitor { inner: inner.clone()}, TimeoutNotifier { inner }) -} - -struct TimeoutMonitor { - inner: Arc> -} - -struct TimeoutNotifier { - inner: Arc> -} - -impl TimeoutNotifier { - pub fn disable(&self) { - self.inner.lock().enabled = false; - } - - pub fn timeout_at(&self, at: Instant) { - let mut inner = self.inner.lock(); - inner.enabled = true; - inner.sleep.as_mut().reset(at.into()); - if let Some(waker) = inner.waker.take() { - waker.wake() - } - } -} - -struct TimeoutInner { - sleep: Pin>, - enabled: bool, - waker: Option, -} - -impl Future for TimeoutMonitor { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut inner = self.inner.lock(); - if inner.enabled { - inner.sleep.poll_unpin(cx) - } else { - inner.waker.replace(cx.waker().clone()); - Poll::Pending - } - } -} - struct Connection { id: (NodeId, u32), conn: C, @@ -610,13 +566,16 @@ mod test { dispatcher: bus, }; - let (exec_sender, exec) = mpsc::channel(1); + let (messages_sender, messages) = mpsc::channel(1); let (_inbound_sender, inbound) = mpsc::channel(1); + let (timeout_monitor, _) = timeout_monitor(); let connection = Connection { id: 
(0, 0), conn, messages, inbound, + timeout_monitor, + last_txn_timedout: false, }; let handle = tokio::spawn(connection.run()); @@ -630,13 +589,11 @@ mod test { } let builder = Box::new(Builder(notify.clone())); - exec_sender - .send(Box::new(move |conn| { - conn.execute_program(&Program::seq(&["create table test (c)"]), builder) - .unwrap(); - })) - .await - .unwrap(); + let msg = ConnectionMessage::Execute { + pgm: Program::seq(&["create table test (c)"]), + builder, + }; + messages_sender.send(msg).await.unwrap(); notify.notified().await; diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index b06f3967..505333cb 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -7,7 +7,7 @@ use std::time::Duration; use bytes::Bytes; use libsqlx::libsql::{LibsqlDatabase, PrimaryType}; use libsqlx::result_builder::ResultBuilder; -use libsqlx::{Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger, Connection}; +use libsqlx::{Connection, Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger}; use tokio::task::block_in_place; use crate::linc::bus::Dispatch; @@ -16,7 +16,7 @@ use crate::linc::{Inbound, NodeId, Outbound}; use crate::meta::DatabaseId; use crate::snapshot_store::SnapshotStore; -use super::{ConnectionHandler, FRAMES_MESSAGE_MAX_COUNT, ConnectionMessage}; +use super::{ConnectionHandler, ConnectionMessage, FRAMES_MESSAGE_MAX_COUNT}; pub mod compactor; @@ -325,7 +325,6 @@ impl ConnectionHandler for PrimaryConnection { ConnectionMessage::Describe => { todo!() } - } } diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs index 433bf2eb..5ff7d897 100644 --- a/libsqlx-server/src/allocation/replica.rs +++ b/libsqlx-server/src/allocation/replica.rs @@ -9,7 +9,7 @@ use libsqlx::libsql::{LibsqlConnection, LibsqlDatabase, ReplicaType}; use libsqlx::program::Program; use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase}; use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder}; -use libsqlx::{DescribeResponse, Frame, FrameNo, Injector, Connection}; +use libsqlx::{Connection, DescribeResponse, Frame, FrameNo, Injector}; use parking_lot::Mutex; use tokio::sync::mpsc; use tokio::task::block_in_place; @@ -324,7 +324,6 @@ impl ConnectionHandler for ReplicaConnection { if let Some(msg) = msg { self.dispatcher.dispatch(msg).await; } - } ConnectionMessage::Describe => (), } diff --git a/libsqlx-server/src/allocation/timeout_notifier.rs b/libsqlx-server/src/allocation/timeout_notifier.rs new file mode 100644 index 00000000..a9cbc49d --- /dev/null +++ b/libsqlx-server/src/allocation/timeout_notifier.rs @@ -0,0 +1,88 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; +use std::time::Instant; + +use futures::{Future, FutureExt}; +use parking_lot::Mutex; +use tokio::time::{sleep_until, Sleep}; + +pub fn timeout_monitor() -> (TimeoutMonitor, TimeoutNotifier) { + let inner = Arc::new(Mutex::new(TimeoutInner { + sleep: Box::pin(sleep_until(Instant::now().into())), + enabled: false, + waker: None, + })); + + ( + TimeoutMonitor { + inner: inner.clone(), + }, + TimeoutNotifier { inner }, + ) +} + +pub struct TimeoutMonitor { + inner: Arc>, +} + +pub struct TimeoutNotifier { + inner: Arc>, +} + +impl TimeoutNotifier { + pub fn disable(&self) { + self.inner.lock().enabled = false; + } + + pub fn timeout_at(&self, at: Instant) { + let mut inner = self.inner.lock(); + inner.enabled = true; + 
inner.sleep.as_mut().reset(at.into()); + if let Some(waker) = inner.waker.take() { + waker.wake() + } + } +} + +struct TimeoutInner { + sleep: Pin>, + enabled: bool, + waker: Option, +} + +impl Future for TimeoutMonitor { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut inner = self.inner.lock(); + if inner.enabled { + inner.sleep.poll_unpin(cx) + } else { + inner.waker.replace(cx.waker().clone()); + Poll::Pending + } + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use super::*; + + #[tokio::test] + async fn set_timeout() { + let (monitor, notifier) = timeout_monitor(); + notifier.timeout_at(Instant::now() + Duration::from_millis(100)); + monitor.await; + } + + #[tokio::test] + async fn disable_timeout() { + let (monitor, notifier) = timeout_monitor(); + notifier.timeout_at(Instant::now() + Duration::from_millis(1)); + notifier.disable(); + assert!(tokio::time::timeout(Duration::from_millis(10), monitor).await.is_err()); + } +} diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index 7753fa81..1b843367 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -47,7 +47,6 @@ pub async fn execute_stmt( conn: &ConnectionHandle, query: Query, ) -> color_eyre::Result { - let (builder, ret) = SingleStatementBuilder::new(); let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); conn.execute(pgm, Box::new(builder)).await?; From 14b098f5fafa428c507b7cd9043cfb656bb56153 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 27 Jul 2023 16:53:59 +0200 Subject: [PATCH 55/64] test txn timeout --- libsqlx-server/src/allocation/config.rs | 2 + libsqlx-server/src/allocation/mod.rs | 77 +++++++++++++++++-- .../src/allocation/timeout_notifier.rs | 4 +- libsqlx-server/src/hrana/batch.rs | 2 +- libsqlx-server/src/http/admin.rs | 12 +++ libsqlx/src/result_builder.rs | 15 +++- 6 files changed, 103 insertions(+), 9 deletions(-) diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs index f0c13870..13de097d 100644 --- a/libsqlx-server/src/allocation/config.rs +++ b/libsqlx-server/src/allocation/config.rs @@ -25,9 +25,11 @@ pub enum DbConfig { max_log_size: usize, /// Interval at which to force compaction replication_log_compact_interval: Option, + transaction_timeout_duration: Duration, }, Replica { primary_node_id: NodeId, proxy_request_timeout_duration: Duration, + transaction_timeout_duration: Duration, }, } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 2b348915..7300feea 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -64,12 +64,14 @@ pub enum Database { Primary { db: PrimaryDatabase, compact_interval: Option>>, + transaction_timeout_duration: Duration, }, Replica { db: ProxyDatabase, injector_handle: mpsc::Sender, primary_id: NodeId, last_received_frame_ts: Option, + transaction_timeout_duration: Duration, }, } @@ -78,6 +80,7 @@ impl Database { if let Self::Primary { compact_interval: Some(ref mut interval), db, + .. } = self { ready!(interval.poll_tick(cx)); @@ -91,6 +94,13 @@ impl Database { Poll::Pending } + + fn txn_timeout_duration(&self) -> Duration { + match self { + Database::Primary { transaction_timeout_duration, .. } => *transaction_timeout_duration, + Database::Replica { transaction_timeout_duration, .. 
} => *transaction_timeout_duration, + } + } } impl Database { @@ -106,6 +116,7 @@ impl Database { DbConfig::Primary { max_log_size, replication_log_compact_interval, + transaction_timeout_duration, } => { let (sender, receiver) = tokio::sync::watch::channel(0); let db = LibsqlDatabase::new_primary( @@ -137,11 +148,13 @@ impl Database { snapshot_store: compaction_queue.snapshot_store.clone(), }, compact_interval, + transaction_timeout_duration, } } DbConfig::Replica { primary_node_id, proxy_request_timeout_duration, + transaction_timeout_duration, } => { let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAPACITY, ()).unwrap(); @@ -168,6 +181,7 @@ impl Database { injector_handle: sender, primary_id: primary_node_id, last_received_frame_ts: None, + transaction_timeout_duration, } } } @@ -403,15 +417,13 @@ impl Allocation { } async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { - // TODO: make that configurable - const TXN_TIMEOUT_DURATION: Duration = Duration::from_secs(5); - let conn_id = self.next_conn_id(); let (timeout_monitor, notifier) = timeout_monitor(); + let timeout = self.database.txn_timeout_duration(); let conn = block_in_place(|| { self.database.connect(conn_id, self, move |is_txn| { if is_txn { - notifier.timeout_at(Instant::now() + TXN_TIMEOUT_DURATION); + notifier.timeout_at(Instant::now() + timeout); } else { notifier.disable(); } @@ -538,11 +550,15 @@ impl Connection { mod test { use std::time::Duration; - use libsqlx::result_builder::ResultBuilder; + use heed::EnvOpenOptions; + use libsqlx::result_builder::{ResultBuilder, StepResultsBuilder}; + use tempfile::tempdir; use tokio::sync::Notify; use crate::allocation::replica::ReplicaConnection; + use crate::init_dirs; use crate::linc::bus::Bus; + use crate::snapshot_store::SnapshotStore; use super::*; @@ -599,4 +615,55 @@ mod test { handle.abort(); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn txn_timeout() { + let bus = Arc::new(Bus::new(0, |_, _| async {})); + let tmp = tempdir().unwrap(); + init_dirs(tmp.path()).await.unwrap(); + let config = AllocConfig { + max_conccurent_connection: 10, + db_name: "test/db".to_owned(), + db_config: DbConfig::Primary { + max_log_size: 100000, + replication_log_compact_interval: None, + transaction_timeout_duration: Duration::from_millis(100), + }, + }; + let (sender, inbox) = mpsc::channel(10); + let env = EnvOpenOptions::new().max_dbs(10).map_size(4096 * 100).open(tmp.path()).unwrap(); + let store = Arc::new(SnapshotStore::new(tmp.path().to_path_buf(), env.clone()).unwrap()); + let queue = Arc::new(CompactionQueue::new(env, tmp.path().to_path_buf(), store).unwrap()); + let mut alloc = Allocation { + inbox, + database: Database::from_config( + &config, + tmp.path().to_path_buf(), + bus.clone(), + queue, + ), + connections_futs: JoinSet::new(), + next_conn_id: 0, + max_concurrent_connections: config.max_conccurent_connection, + hrana_server: Arc::new(hrana::http::Server::new(None)), + dispatcher: bus, // TODO: handle self URL? 
+ db_name: config.db_name, + connections: HashMap::new(), + }; + + let conn = alloc.new_conn(None).await; + tokio::spawn(alloc.run()); + + let (snd, rcv) = oneshot::channel(); + let builder = StepResultsBuilder::new(snd); + conn.execute(Program::seq(&["begin"]), Box::new(builder)).await.unwrap(); + rcv.await.unwrap().unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let (snd, rcv) = oneshot::channel(); + let builder = StepResultsBuilder::new(snd); + conn.execute(Program::seq(&["create table test (x)"]), Box::new(builder)).await.unwrap(); + assert!(rcv.await.unwrap().is_err()); + } } diff --git a/libsqlx-server/src/allocation/timeout_notifier.rs b/libsqlx-server/src/allocation/timeout_notifier.rs index a9cbc49d..b64c71a8 100644 --- a/libsqlx-server/src/allocation/timeout_notifier.rs +++ b/libsqlx-server/src/allocation/timeout_notifier.rs @@ -83,6 +83,8 @@ mod test { let (monitor, notifier) = timeout_monitor(); notifier.timeout_at(Instant::now() + Duration::from_millis(1)); notifier.disable(); - assert!(tokio::time::timeout(Duration::from_millis(10), monitor).await.is_err()); + assert!(tokio::time::timeout(Duration::from_millis(10), monitor) + .await + .is_err()); } } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index fab788d0..6d41c8b4 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -110,7 +110,7 @@ pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_ey let builder = StepResultsBuilder::new(snd); conn.execute(pgm, Box::new(builder)).await?; - rcv.await?.into_iter().try_for_each(|result| match result { + rcv.await?.map_err(|e| anyhow!("{e}"))?.into_iter().try_for_each(|result| match result { StepResult::Ok => Ok(()), StepResult::Err(e) => match stmt_error_from_sqld_error(e) { Ok(stmt_err) => Err(anyhow!(stmt_err)), diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 9b51b7ed..ff718674 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -63,6 +63,8 @@ pub struct Primary { #[serde(default = "default_max_log_size")] pub max_replication_log_size: bytesize::ByteSize, pub replication_log_compact_interval: Option, + #[serde(default = "default_txn_timeout")] + transaction_timeout_duration: HumanDuration, } #[derive(Debug)] @@ -112,6 +114,8 @@ pub enum DbConfigReq { primary_node_id: NodeId, #[serde(default = "default_proxy_timeout")] proxy_request_timeout_duration: HumanDuration, + #[serde(default = "default_txn_timeout")] + transaction_timeout_duration: HumanDuration, }, } @@ -123,6 +127,10 @@ const fn default_proxy_timeout() -> HumanDuration { HumanDuration(Duration::from_secs(5)) } +const fn default_txn_timeout() -> HumanDuration { + HumanDuration(Duration::from_secs(5)) +} + async fn allocate( State(state): State>, Json(req): Json, @@ -134,18 +142,22 @@ async fn allocate( DbConfigReq::Primary(Primary { max_replication_log_size, replication_log_compact_interval, + transaction_timeout_duration, }) => DbConfig::Primary { max_log_size: max_replication_log_size.as_u64() as usize, replication_log_compact_interval: replication_log_compact_interval .as_deref() .copied(), + transaction_timeout_duration: *transaction_timeout_duration, }, DbConfigReq::Replica { primary_node_id, proxy_request_timeout_duration, + transaction_timeout_duration, } => DbConfig::Replica { primary_node_id, proxy_request_timeout_duration: *proxy_request_timeout_duration, + transaction_timeout_duration: *transaction_timeout_duration, }, }, }; 
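
The hunk above exposes the transaction timeout through the admin allocation API. As an aside, a minimal sketch of how a request against this surface might be assembled; `DbConfigReq`, `Primary`, and `HumanDuration` are the types from the hunk, while the helper name and the concrete values are illustrative assumptions (omitted fields fall back to the serde defaults such as `default_txn_timeout`):

    use std::time::Duration;

    // Hypothetical helper: build the body for an admin `allocate` request.
    // Field names and types mirror the serde structs in the hunk above.
    fn example_primary_config() -> DbConfigReq {
        DbConfigReq::Primary(Primary {
            // compact once the replication log passes ~16 MiB (example value)
            max_replication_log_size: bytesize::ByteSize::mib(16),
            // rely on the size threshold alone; no periodic compaction
            replication_log_compact_interval: None,
            // same value default_txn_timeout() would supply if omitted
            transaction_timeout_duration: HumanDuration(Duration::from_secs(5)),
        })
    }

This config value replaces the hard-coded five-second TXN_TIMEOUT_DURATION constant that the earlier hunk in allocation/mod.rs removes.
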
diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index d69ac35b..c5f159e7 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -196,7 +196,7 @@ impl StepResultsBuilder { } } -impl>> ResultBuilder for StepResultsBuilder { +impl, String>>> ResultBuilder for StepResultsBuilder { fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { self.current = None; self.step_results.clear(); @@ -248,9 +248,16 @@ impl>> ResultBuilder for StepResultsBuilder { self.ret .take() .expect("finnalize called more than once") - .send(std::mem::take(&mut self.step_results)); + .send(Ok(std::mem::take(&mut self.step_results))); Ok(true) } + + fn finnalize_error(&mut self, e: String) { + self.ret + .take() + .expect("finnalize called more than once") + .send(Err(e)); + } } impl ResultBuilder for () {} @@ -362,6 +369,10 @@ impl ResultBuilder for Take { ) -> Result { self.inner.finnalize(is_txn, frame_no) } + + fn finnalize_error(&mut self, e: String) { + self.inner.finnalize_error(e) + } } #[cfg(test)] From 9cd8c1aa6df07ed2cfd155ecd5dc9bfcda96ef29 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 28 Jul 2023 09:17:53 +0200 Subject: [PATCH 56/64] replace commit handler with commit callback --- libsqlx-server/src/allocation/mod.rs | 43 +++++++++++++------- libsqlx-server/src/allocation/replica.rs | 7 ++-- libsqlx-server/src/hrana/batch.rs | 19 +++++---- libsqlx-server/src/meta.rs | 2 - libsqlx/src/database/libsql/injector/hook.rs | 13 +++--- libsqlx/src/database/libsql/injector/mod.rs | 36 ++++------------ libsqlx/src/database/libsql/mod.rs | 15 ++++--- 7 files changed, 62 insertions(+), 73 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 7300feea..1b7fb5c4 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -97,8 +97,14 @@ impl Database { fn txn_timeout_duration(&self) -> Duration { match self { - Database::Primary { transaction_timeout_duration, .. } => *transaction_timeout_duration, - Database::Replica { transaction_timeout_duration, .. } => *transaction_timeout_duration, + Database::Primary { + transaction_timeout_duration, + .. + } => *transaction_timeout_duration, + Database::Replica { + transaction_timeout_duration, + .. 
+ } => *transaction_timeout_duration, } } } @@ -156,8 +162,13 @@ impl Database { proxy_request_timeout_duration, transaction_timeout_duration, } => { - let rdb = - LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAPACITY, ()).unwrap(); + // TODO: set commit handler + let rdb = LibsqlDatabase::new_replica( + path, + MAX_INJECTOR_BUFFER_CAPACITY, + Arc::new(|_| ()), + ) + .unwrap(); let wdb = RemoteDb { proxy_request_timeout_duration, }; @@ -567,7 +578,8 @@ mod test { let bus = Arc::new(Bus::new(0, |_, _| async {})); let _queue = bus.connect(1); // pretend connection to node 1 let tmp = tempfile::TempDir::new().unwrap(); - let read_db = LibsqlDatabase::new_replica(tmp.path().to_path_buf(), 1, ()).unwrap(); + let read_db = + LibsqlDatabase::new_replica(tmp.path().to_path_buf(), 1, Arc::new(|_| ())).unwrap(); let write_db = RemoteDb { proxy_request_timeout_duration: Duration::from_millis(100), }; @@ -631,17 +643,16 @@ mod test { }, }; let (sender, inbox) = mpsc::channel(10); - let env = EnvOpenOptions::new().max_dbs(10).map_size(4096 * 100).open(tmp.path()).unwrap(); + let env = EnvOpenOptions::new() + .max_dbs(10) + .map_size(4096 * 100) + .open(tmp.path()) + .unwrap(); let store = Arc::new(SnapshotStore::new(tmp.path().to_path_buf(), env.clone()).unwrap()); let queue = Arc::new(CompactionQueue::new(env, tmp.path().to_path_buf(), store).unwrap()); let mut alloc = Allocation { inbox, - database: Database::from_config( - &config, - tmp.path().to_path_buf(), - bus.clone(), - queue, - ), + database: Database::from_config(&config, tmp.path().to_path_buf(), bus.clone(), queue), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, @@ -656,14 +667,18 @@ mod test { let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); - conn.execute(Program::seq(&["begin"]), Box::new(builder)).await.unwrap(); + conn.execute(Program::seq(&["begin"]), Box::new(builder)) + .await + .unwrap(); rcv.await.unwrap().unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); - conn.execute(Program::seq(&["create table test (x)"]), Box::new(builder)).await.unwrap(); + conn.execute(Program::seq(&["create table test (x)"]), Box::new(builder)) + .await + .unwrap(); assert!(rcv.await.unwrap().is_err()); } } diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs index 5ff7d897..a31c8116 100644 --- a/libsqlx-server/src/allocation/replica.rs +++ b/libsqlx-server/src/allocation/replica.rs @@ -16,10 +16,8 @@ use tokio::task::block_in_place; use tokio::time::{timeout, Sleep}; use crate::linc::bus::Dispatch; -use crate::linc::proto::{BuilderStep, ProxyResponse}; -use crate::linc::proto::{Enveloppe, Frames, Message}; -use crate::linc::Inbound; -use crate::linc::{NodeId, Outbound}; +use crate::linc::proto::{BuilderStep, Enveloppe, Frames, Message, ProxyResponse}; +use crate::linc::{Inbound, NodeId, Outbound}; use crate::meta::DatabaseId; use super::{ConnectionHandler, ConnectionMessage}; @@ -166,6 +164,7 @@ impl Replicator { }); } } + // no news from primary for the past 5 secs, send a request again Err(_) => self.query_replicate().await, Ok(None) => break, } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index 6d41c8b4..14cfb1c3 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -110,12 +110,15 @@ pub async fn execute_sequence(conn: &ConnectionHandle, 
pgm: Program) -> color_ey let builder = StepResultsBuilder::new(snd); conn.execute(pgm, Box::new(builder)).await?; - rcv.await?.map_err(|e| anyhow!("{e}"))?.into_iter().try_for_each(|result| match result { - StepResult::Ok => Ok(()), - StepResult::Err(e) => match stmt_error_from_sqld_error(e) { - Ok(stmt_err) => Err(anyhow!(stmt_err)), - Err(sqld_err) => Err(anyhow!(sqld_err)), - }, - StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), - }) + rcv.await? + .map_err(|e| anyhow!("{e}"))? + .into_iter() + .try_for_each(|result| match result { + StepResult::Ok => Ok(()), + StepResult::Err(e) => match stmt_error_from_sqld_error(e) { + Ok(stmt_err) => Err(anyhow!(stmt_err)), + Err(sqld_err) => Err(anyhow!(sqld_err)), + }, + StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), + }) } diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 38a1c30b..2436839b 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -10,8 +10,6 @@ use tokio::task::block_in_place; use crate::allocation::config::AllocConfig; -type ExecFn = Box)>; - pub struct Store { env: heed::Env, alloc_config_db: heed::Database, SerdeBincode>, diff --git a/libsqlx/src/database/libsql/injector/hook.rs b/libsqlx/src/database/libsql/injector/hook.rs index 2cb5348d..57c34fbc 100644 --- a/libsqlx/src/database/libsql/injector/hook.rs +++ b/libsqlx/src/database/libsql/injector/hook.rs @@ -9,7 +9,7 @@ use crate::database::frame::FrameBorrowed; use crate::database::libsql::replication_log::WAL_PAGE_SIZE; use super::headers::Headers; -use super::{FrameBuffer, InjectorCommitHandler}; +use super::{FrameBuffer, OnCommitCb}; // Those are custom error codes returned by the replicator hook. pub const LIBSQL_INJECT_FATAL: c_int = 200; @@ -23,15 +23,15 @@ pub struct InjectorHookCtx { buffer: FrameBuffer, /// currently in a txn is_txn: bool, - commit_handler: Box, + on_commit_cb: OnCommitCb, } impl InjectorHookCtx { - pub fn new(buffer: FrameBuffer, commit_handler: Box) -> Self { + pub fn new(buffer: FrameBuffer, commit_handler: OnCommitCb) -> Self { Self { buffer, is_txn: false, - commit_handler, + on_commit_cb: commit_handler, } } @@ -45,9 +45,6 @@ impl InjectorHookCtx { let buffer = self.buffer.lock(); let (mut headers, last_frame_no, size_after) = make_page_header(buffer.iter().map(|f| &**f)); - if size_after != 0 { - self.commit_handler.pre_commit(last_frame_no)?; - } let ret = unsafe { orig( @@ -64,7 +61,7 @@ impl InjectorHookCtx { debug_assert!(headers.all_applied()); drop(headers); if size_after != 0 { - self.commit_handler.post_commit(last_frame_no)?; + (self.on_commit_cb)(last_frame_no); self.is_txn = false; } tracing::trace!("applied frame batch"); diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index 7580bd5f..cbc9dc80 100644 --- a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -18,6 +18,7 @@ mod headers; mod hook; pub type FrameBuffer = Arc>>; +pub type OnCommitCb = Arc; pub struct Injector { /// The injector is in a transaction state @@ -48,39 +49,15 @@ impl crate::database::Injector for Injector { /// This trait trait is used to record the last committed frame_no to the log. /// The implementer can persist the pre and post commit frame no, and compare them in the event of /// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted. 
-pub trait InjectorCommitHandler: Send + Sync + 'static { - fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; - fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; -} - -impl InjectorCommitHandler for Box { - fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()> { - self.as_mut().pre_commit(frame_no) - } - - fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()> { - self.as_mut().post_commit(frame_no) - } -} - -impl InjectorCommitHandler for () { - fn pre_commit(&mut self, _frame_no: FrameNo) -> anyhow::Result<()> { - Ok(()) - } - - fn post_commit(&mut self, _frame_no: FrameNo) -> anyhow::Result<()> { - Ok(()) - } -} impl Injector { pub fn new( path: &Path, - injector_commit_handler: Box, + on_commit_cb: OnCommitCb, buffer_capacity: usize, ) -> crate::Result { let buffer = FrameBuffer::default(); - let ctx = InjectorHookCtx::new(buffer.clone(), injector_commit_handler); + let ctx = InjectorHookCtx::new(buffer.clone(), on_commit_cb); let mut ctx = Box::new(ctx); let connection = sqld_libsql_bindings::Connection::open( path, @@ -172,6 +149,7 @@ impl Injector { #[cfg(test)] mod test { use std::path::PathBuf; + use std::sync::Arc; use crate::database::libsql::injector::Injector; use crate::database::libsql::replication_log::logger::LogFile; @@ -181,7 +159,7 @@ mod test { let log = LogFile::new(PathBuf::from("assets/test/simple_wallog")).unwrap(); let temp = tempfile::tempdir().unwrap(); - let mut injector = Injector::new(temp.path(), Box::new(()), 10).unwrap(); + let mut injector = Injector::new(temp.path(), Arc::new(|_| ()), 10).unwrap(); for frame in log.frames_iter().unwrap() { let frame = frame.unwrap(); injector.inject_frame(frame).unwrap(); @@ -202,7 +180,7 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = Injector::new(temp.path(), Box::new(()), 1).unwrap(); + let mut injector = Injector::new(temp.path(), Arc::new(|_| ()), 1).unwrap(); for frame in log.frames_iter().unwrap() { let frame = frame.unwrap(); injector.inject_frame(frame).unwrap(); @@ -223,7 +201,7 @@ mod test { let temp = tempfile::tempdir().unwrap(); // inject one frame at a time - let mut injector = Injector::new(temp.path(), Box::new(()), 10).unwrap(); + let mut injector = Injector::new(temp.path(), Arc::new(|_| ()), 10).unwrap(); let mut iter = log.frames_iter().unwrap(); assert!(injector diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index a42bfdc7..9582cf2c 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -14,7 +14,7 @@ use replication_log::logger::{ ReplicationLogger, ReplicationLoggerHook, ReplicationLoggerHookCtx, REPLICATION_METHODS, }; -use self::injector::InjectorCommitHandler; +use self::injector::OnCommitCb; use self::replication_log::logger::FrameNotifierCb; pub use connection::LibsqlConnection; @@ -44,7 +44,7 @@ impl LibsqlDbType for PrimaryType { } pub struct ReplicaType { - commit_handler: Option>, + on_commit_cb: OnCommitCb, injector_buffer_capacity: usize, } @@ -101,10 +101,10 @@ impl LibsqlDatabase { pub fn new_replica( db_path: PathBuf, injector_buffer_capacity: usize, - injector_commit_handler: impl InjectorCommitHandler, + on_commit_cb: OnCommitCb, ) -> crate::Result { let ty = ReplicaType { - commit_handler: Some(Box::new(injector_commit_handler)), + on_commit_cb, injector_buffer_capacity, }; @@ -185,10 +185,9 @@ impl Database for LibsqlDatabase { impl InjectableDatabase for LibsqlDatabase { fn injector(&mut 
self) -> crate::Result> { - let Some(commit_handler) = self.ty.commit_handler.take() else { panic!("there can be only one injector") }; Ok(Box::new(Injector::new( &self.db_path, - commit_handler, + self.ty.on_commit_cb.clone(), self.ty.injector_buffer_capacity, )?)) } @@ -226,7 +225,7 @@ mod test { fn inject_libsql_db() { let temp = tempfile::tempdir().unwrap(); let replica = ReplicaType { - commit_handler: Some(Box::new(())), + on_commit_cb: Arc::new(|_| ()), injector_buffer_capacity: 10, }; let mut db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); @@ -269,7 +268,7 @@ mod test { let mut replica = LibsqlDatabase::new( temp_replica.path().to_path_buf(), ReplicaType { - commit_handler: Some(Box::new(())), + on_commit_cb: Arc::new(|_| ()), injector_buffer_capacity: 10, }, ); From 7031828db31ffdcad2d83c6d61ca25f8b9c1619a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 28 Jul 2023 10:46:46 +0200 Subject: [PATCH 57/64] record replica commit index --- libsqlx-server/src/allocation/mod.rs | 32 ++++++++++++++++---- libsqlx-server/src/main.rs | 4 +++ libsqlx-server/src/manager.rs | 5 ++++ libsqlx-server/src/replica_commit_store.rs | 34 ++++++++++++++++++++++ 4 files changed, 69 insertions(+), 6 deletions(-) create mode 100644 libsqlx-server/src/replica_commit_store.rs diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 1b7fb5c4..e1086d41 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -27,6 +27,7 @@ use crate::linc::bus::Dispatch; use crate::linc::proto::{Frames, Message}; use crate::linc::{Inbound, NodeId}; use crate::meta::DatabaseId; +use crate::replica_commit_store::ReplicaCommitStore; use self::config::{AllocConfig, DbConfig}; use self::primary::compactor::Compactor; @@ -115,6 +116,7 @@ impl Database { path: PathBuf, dispatcher: Arc, compaction_queue: Arc, + replica_commit_store: Arc, ) -> Self { let database_id = DatabaseId::from_name(&config.db_name); @@ -162,13 +164,22 @@ impl Database { proxy_request_timeout_duration, transaction_timeout_duration, } => { - // TODO: set commit handler + let next_frame_no = + block_in_place(|| replica_commit_store.get_commit_index(database_id)) + .map(|fno| fno + 1) + .unwrap_or(0); + + let commit_callback = Arc::new(move |fno| { + replica_commit_store.commit(database_id, fno); + }); + let rdb = LibsqlDatabase::new_replica( path, MAX_INJECTOR_BUFFER_CAPACITY, - Arc::new(|_| ()), + commit_callback, ) .unwrap(); + let wdb = RemoteDb { proxy_request_timeout_duration, }; @@ -178,7 +189,7 @@ impl Database { let replicator = Replicator::new( dispatcher, - 0, + next_frame_no, database_id, primary_node_id, injector, @@ -567,9 +578,10 @@ mod test { use tokio::sync::Notify; use crate::allocation::replica::ReplicaConnection; - use crate::init_dirs; use crate::linc::bus::Bus; + use crate::replica_commit_store::ReplicaCommitStore; use crate::snapshot_store::SnapshotStore; + use crate::{init_dirs, replica_commit_store}; use super::*; @@ -649,10 +661,18 @@ mod test { .open(tmp.path()) .unwrap(); let store = Arc::new(SnapshotStore::new(tmp.path().to_path_buf(), env.clone()).unwrap()); - let queue = Arc::new(CompactionQueue::new(env, tmp.path().to_path_buf(), store).unwrap()); + let queue = + Arc::new(CompactionQueue::new(env.clone(), tmp.path().to_path_buf(), store).unwrap()); + let replica_commit_store = Arc::new(ReplicaCommitStore::new(env.clone())); let mut alloc = Allocation { inbox, - database: Database::from_config(&config, tmp.path().to_path_buf(), bus.clone(), 
queue), + database: Database::from_config( + &config, + tmp.path().to_path_buf(), + bus.clone(), + queue, + replica_commit_store, + ), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 1392e325..ff5c415a 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -12,6 +12,7 @@ use hyper::server::conn::AddrIncoming; use linc::bus::Bus; use manager::Manager; use meta::Store; +use replica_commit_store::ReplicaCommitStore; use snapshot_store::SnapshotStore; use tokio::fs::create_dir_all; use tokio::net::{TcpListener, TcpStream}; @@ -28,6 +29,7 @@ mod http; mod linc; mod manager; mod meta; +mod replica_commit_store; mod snapshot_store; #[derive(Debug, Parser)] @@ -123,11 +125,13 @@ async fn main() -> Result<()> { snapshot_store, )?); let store = Arc::new(Store::new(env.clone())); + let replica_commit_store = Arc::new(ReplicaCommitStore::new(env.clone())); let manager = Arc::new(Manager::new( config.db_path.clone(), store.clone(), 100, compaction_queue.clone(), + replica_commit_store, )); let bus = Arc::new(Bus::new(config.cluster.id, manager.clone())); diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index a3bb68d8..1b0ca7d1 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -14,12 +14,14 @@ use crate::linc::bus::Dispatch; use crate::linc::handler::Handler; use crate::linc::Inbound; use crate::meta::{DatabaseId, Store}; +use crate::replica_commit_store::ReplicaCommitStore; pub struct Manager { cache: Cache>, meta_store: Arc, db_path: PathBuf, compaction_queue: Arc, + replica_commit_store: Arc, } const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32; @@ -30,12 +32,14 @@ impl Manager { meta_store: Arc, max_conccurent_allocs: u64, compaction_queue: Arc, + replica_commit_store: Arc, ) -> Self { Self { cache: Cache::new(max_conccurent_allocs), meta_store, db_path, compaction_queue, + replica_commit_store, } } @@ -60,6 +64,7 @@ impl Manager { path, dispatcher.clone(), self.compaction_queue.clone(), + self.replica_commit_store.clone(), ), connections_futs: JoinSet::new(), next_conn_id: 0, diff --git a/libsqlx-server/src/replica_commit_store.rs b/libsqlx-server/src/replica_commit_store.rs new file mode 100644 index 00000000..18c0aeed --- /dev/null +++ b/libsqlx-server/src/replica_commit_store.rs @@ -0,0 +1,34 @@ +use heed_types::OwnedType; +use libsqlx::FrameNo; + +use crate::meta::DatabaseId; + +/// Stores replica last injected commit index +pub struct ReplicaCommitStore { + env: heed::Env, + database: heed::Database, OwnedType>, +} + +impl ReplicaCommitStore { + const DB_NAME: &str = "replica-commit-store"; + pub fn new(env: heed::Env) -> Self { + let mut txn = env.write_txn().unwrap(); + let database = env.create_database(&mut txn, Some(Self::DB_NAME)).unwrap(); + txn.commit().unwrap(); + + Self { env, database } + } + + pub fn commit(&self, database_id: DatabaseId, frame_no: FrameNo) { + let mut txn = self.env.write_txn().unwrap(); + self.database + .put(&mut txn, &database_id, &frame_no) + .unwrap(); + txn.commit().unwrap(); + } + + pub fn get_commit_index(&self, database_id: DatabaseId) -> Option { + let txn = self.env.read_txn().unwrap(); + self.database.get(&txn, &database_id).unwrap() + } +} From e6aaeb4bba1c8a9f4ee1cdce723a81dbe1fe96dd Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 28 Jul 2023 11:17:48 +0200 Subject: [PATCH 58/64] fix replication bootstrap bug --- 
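Note: the bug fixed below is a bootstrap ambiguity. The primary's "new frame" watch channel was created with an initial value of 0, which a freshly created log cannot distinguish from "frame 0 is durable"; the channel now carries an `Option<FrameNo>` that starts at `None`. A minimal sketch of the pattern, with simplified names that are not taken from the patch:

    use tokio::sync::watch;

    type FrameNo = u64;

    #[tokio::main]
    async fn main() {
        // Starting at None distinguishes "nothing committed yet" from
        // "frame 0 was committed", which a bare FrameNo starting at 0 cannot.
        let (tx, mut rx) = watch::channel::<Option<FrameNo>>(None);

        tokio::spawn(async move {
            // The logger notifies as frames become durable.
            let _ = tx.send(Some(42));
        });

        // A frame streamer waiting for frame 42: the initial None never
        // satisfies the predicate, so an empty log cannot wake it spuriously.
        rx.wait_for(|fno| fno.map(|f| f >= 42).unwrap_or(false))
            .await
            .unwrap();
    }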
libsqlx-server/src/allocation/mod.rs | 4 ++-- libsqlx-server/src/allocation/primary/mod.rs | 6 +++--- .../src/database/libsql/replication_log/logger.rs | 15 ++++++++------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index e1086d41..f5e19ba6 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -126,7 +126,7 @@ impl Database { replication_log_compact_interval, transaction_timeout_duration, } => { - let (sender, receiver) = tokio::sync::watch::channel(0); + let (sender, receiver) = tokio::sync::watch::channel(None); let db = LibsqlDatabase::new_primary( path, Compactor::new( @@ -137,7 +137,7 @@ impl Database { ), false, Box::new(move |fno| { - let _ = sender.send(fno); + let _ = sender.send(Some(fno)); }), ) .unwrap(); diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index 505333cb..606171c9 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -25,7 +25,7 @@ const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100kb pub struct PrimaryDatabase { pub db: Arc>, pub replica_streams: HashMap)>, - pub frame_notifier: tokio::sync::watch::Receiver, + pub frame_notifier: tokio::sync::watch::Receiver>, pub snapshot_store: Arc, } @@ -207,7 +207,7 @@ pub struct FrameStreamer { pub req_no: u32, pub seq_no: u32, pub dipatcher: Arc, - pub notifier: tokio::sync::watch::Receiver, + pub notifier: tokio::sync::watch::Receiver>, pub buffer: Vec, pub snapshot_store: Arc, } @@ -230,7 +230,7 @@ impl FrameStreamer { } if self .notifier - .wait_for(|fno| *fno >= self.next_frame_no) + .wait_for(|fno| fno.map(|f| f >= self.next_frame_no).unwrap_or(false)) .await .is_err() { diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index d12af50d..38112c37 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -759,7 +759,6 @@ impl LogCompactor for () { pub type FrameNotifierCb = Box; pub struct ReplicationLogger { - pub generation: Generation, pub log_file: RwLock, compactor: Box>, /// a notifier channel other tasks can subscribe to, and get notified when new frames become @@ -807,15 +806,17 @@ impl ReplicationLogger { compactor: impl LogCompactor, new_frame_notifier: FrameNotifierCb, ) -> crate::Result { - let header = log_file.header(); - let generation_start_frame_no = header.start_frame_no + header.frame_count; - - Ok(Self { - generation: Generation::new(generation_start_frame_no), + let this = Self { compactor: Box::new(Mutex::new(compactor)), log_file: RwLock::new(log_file), new_frame_notifier, - }) + }; + + if let Some(last_frame) = this.log_file.read().last_commited_frame_no() { + (this.new_frame_notifier)(last_frame); + } + + Ok(this) } fn recover( From 59556790ba62011a00995b733864799dc803872d Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 28 Jul 2023 11:40:49 +0200 Subject: [PATCH 59/64] add missing block_in_place --- libsqlx-server/src/allocation/primary/mod.rs | 4 +++- libsqlx-server/src/allocation/replica.rs | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index 606171c9..ee9e0d96 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -320,7 +320,9 @@ 
impl ConnectionHandler for PrimaryConnection {
     async fn handle_conn_message(&mut self, msg: ConnectionMessage) {
         match msg {
             ConnectionMessage::Execute { pgm, builder } => {
-                self.conn.execute_program(&pgm, builder).unwrap()
+                block_in_place(|| {
+                    self.conn.execute_program(&pgm, builder).unwrap()
+                })
             }
             ConnectionMessage::Describe => {
                 todo!()
             }
diff --git a/libsqlx-server/src/allocation/replica.rs index a31c8116..b1422e5b 100644
--- a/libsqlx-server/src/allocation/replica.rs
+++ b/libsqlx-server/src/allocation/replica.rs
@@ -172,6 +172,7 @@ impl Replicator {
     }

     async fn query_replicate(&mut self) {
+        tracing::debug!("sending replication request");
         self.req_id += 1;
         self.next_seq = 0;
         // clear buffered, uncommitted frames

From 75f77a017637bbf2aa5d852925c51e08b600a911 Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Fri, 28 Jul 2023 14:00:36 +0200
Subject: [PATCH 60/64] handle allocation errors

---
 libsqlx-server/src/allocation/error.rs | 9 ++
 libsqlx-server/src/allocation/mod.rs | 135 ++++++++++--------
 libsqlx-server/src/compactor.rs | 1 +
 libsqlx-server/src/hrana/batch.rs | 4 +-
 libsqlx-server/src/hrana/stmt.rs | 2 +-
 libsqlx-server/src/manager.rs | 2 +-
 libsqlx/src/database/libsql/mod.rs | 9 +-
 .../database/libsql/replication_log/logger.rs | 14 --
 libsqlx/src/database/mod.rs | 2 +-
 libsqlx/src/database/proxy/database.rs | 2 +-
 10 files changed, 93 insertions(+), 87 deletions(-)
 create mode 100644 libsqlx-server/src/allocation/error.rs

diff --git a/libsqlx-server/src/allocation/error.rs new file mode 100644 index 00000000..09c700aa
--- /dev/null
+++ b/libsqlx-server/src/allocation/error.rs
@@ -0,0 +1,9 @@
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error(transparent)]
+    Libsqlx(#[from] libsqlx::error::Error),
+    #[error("replica injector loop exited")]
+    InjectorExited,
+    #[error("connection closed")]
+    ConnectionClosed,
+}
diff --git a/libsqlx-server/src/allocation/mod.rs index f5e19ba6..0ec1e271 100644
--- a/libsqlx-server/src/allocation/mod.rs
+++ b/libsqlx-server/src/allocation/mod.rs
@@ -39,6 +39,9 @@ pub mod config;
 mod primary;
 mod replica;
 mod timeout_notifier;
+mod error;
+
+pub type Result = std::result::Result;

 /// Maximum number of frame a Frame message is allowed to contain
 const FRAMES_MESSAGE_MAX_COUNT: usize = 5;
@@ -117,7 +120,7 @@ impl Database {
         dispatcher: Arc,
         compaction_queue: Arc,
         replica_commit_store: Arc,
-    ) -> Self {
+    ) -> Result {
         let database_id = DatabaseId::from_name(&config.db_name);

         match config.db_config {
@@ -139,8 +142,7 @@ impl Database {
                 Box::new(move |fno| {
                     let _ = sender.send(Some(fno));
                 }),
-            )
-            .unwrap();
+            )?;

                 let compact_interval = replication_log_compact_interval.map(|d| {
                     let mut i = tokio::time::interval(d / 2);
@@ -148,7 +150,7 @@ impl Database {
                     Box::pin(i)
                 });

-                Self::Primary {
+                Ok(Self::Primary {
                     db: PrimaryDatabase {
                         db: Arc::new(db),
                         replica_streams: HashMap::new(),
@@ -157,7 +159,7 @@ impl Database {
                     },
                     compact_interval,
                     transaction_timeout_duration,
-                }
+                })
             }
             DbConfig::Replica {
                 primary_node_id,
@@ -177,14 +179,13 @@ impl Database {
                     path,
                     MAX_INJECTOR_BUFFER_CAPACITY,
                     commit_callback,
-                )
-                .unwrap();
+                )?;

                 let wdb = RemoteDb {
                     proxy_request_timeout_duration,
                 };
-                let mut db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ()));
-                let injector = db.injector().unwrap();
+                let db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ()));
+                let injector = db.injector()?;
                 let (sender, receiver) =
mpsc::channel(16); let replicator = Replicator::new( @@ -198,13 +199,13 @@ impl Database { tokio::spawn(replicator.run()); - Self::Replica { + Ok(Self::Replica { db, injector_handle: sender, primary_id: primary_node_id, last_received_frame_ts: None, transaction_timeout_duration, - } + }) } } } @@ -214,28 +215,28 @@ impl Database { connection_id: u32, alloc: &Allocation, on_txn_status_change_cb: impl Fn(bool) + Send + Sync + 'static, - ) -> impl ConnectionHandler { + ) -> Result { match self { Database::Primary { db: PrimaryDatabase { db, .. }, .. } => { - let mut conn = db.connect().unwrap(); + let mut conn = db.connect()?; conn.set_on_txn_status_change_cb(on_txn_status_change_cb); - Either::Right(PrimaryConnection { conn }) + Ok(Either::Right(PrimaryConnection { conn })) } Database::Replica { db, primary_id, .. } => { - let mut conn = db.connect().unwrap(); + let mut conn = db.connect()?; conn.reader_mut() .set_on_txn_status_change_cb(on_txn_status_change_cb); - Either::Left(ReplicaConnection { + Ok(Either::Left(ReplicaConnection { conn, connection_id, next_req_id: 0, primary_node_id: *primary_id, database_id: DatabaseId::from_name(&alloc.db_name), dispatcher: alloc.dispatcher.clone(), - }) + })) } } } @@ -271,12 +272,12 @@ impl ConnectionHandle { &self, pgm: Program, builder: Box, - ) -> crate::Result<()> { - self.messages - .send(ConnectionMessage::Execute { pgm, builder }) - .await - .unwrap(); - Ok(()) + ) { + let msg = ConnectionMessage::Execute { pgm, builder }; + if let Err(e) = self.messages.send(msg).await { + let ConnectionMessage::Execute { mut builder, .. } = e.0 else { unreachable!() }; + builder.finnalize_error("connection closed".to_string()); + } } } @@ -290,13 +291,18 @@ impl Allocation { match msg { AllocationMessage::HranaPipelineReq { req, ret } => { let server = self.hrana_server.clone(); - handle_pipeline(server, req, ret, || async { - let conn = self.new_conn(None).await; + if let Err(e) = handle_pipeline(server, req, ret, || async { + let conn = self.new_conn(None).await?; Ok(conn) - }).await.unwrap(); + }).await { + tracing::error!("error handling request: {e}") + }; } AllocationMessage::Inbound(msg) => { - self.handle_inbound(msg).await; + if let Err(e) = self.handle_inbound(msg).await { + tracing::error!("allocation loop finished with error: {e}"); + return + } } } }, @@ -310,7 +316,7 @@ impl Allocation { } } - async fn handle_inbound(&mut self, msg: Inbound) { + async fn handle_inbound(&mut self, msg: Inbound) -> Result<()> { debug_assert_eq!( msg.enveloppe.database_id, Some(DatabaseId::from_name(&self.db_name)) @@ -361,7 +367,7 @@ impl Allocation { } Entry::Vacant(e) => { let handle = tokio::spawn(streamer.run()); - // For some reason, not yielding causes the task not to be spawned + // For some reason, yielding here is necessary for the task to start running tokio::task::yield_now().await; e.insert((req_no, handle)); } @@ -369,19 +375,15 @@ impl Allocation { } Database::Replica { .. } => todo!("not a primary!"), }, - Message::Frames(frames) => match &mut self.database { - Database::Replica { - injector_handle, - last_received_frame_ts, - .. - } => { - *last_received_frame_ts = Some(Instant::now()); - injector_handle.send(frames).await.unwrap(); + Message::Frames(frames) => if let Database::Replica { + injector_handle, + last_received_frame_ts, + .. 
+ } = &mut self.database { + *last_received_frame_ts = Some(Instant::now()); + if injector_handle.send(frames).await.is_err() { + return Err(error::Error::InjectorExited) } - Database::Primary { - db: PrimaryDatabase { .. }, - .. - } => todo!("handle primary receiving txn"), }, Message::ProxyRequest { connection_id, @@ -397,13 +399,17 @@ impl Allocation { .get(&self.dispatcher.node_id()) .and_then(|m| m.get(&r.connection_id).cloned()) { - conn.inbound.send(msg).await.unwrap(); + if conn.inbound.send(msg).await.is_err() { + tracing::error!("cannot process message: connection is closed"); + } } } Message::CancelRequest { .. } => todo!(), Message::CloseConnection { .. } => todo!(), Message::Error(_) => todo!(), } + + Ok(()) } async fn handle_proxy( @@ -415,11 +421,8 @@ impl Allocation { ) { let dispatcher = self.dispatcher.clone(); let database_id = DatabaseId::from_name(&self.db_name); - let exec = |conn: ConnectionHandle| async move { - let builder = - ProxyResponseBuilder::new(dispatcher, database_id, to, req_id, connection_id); - conn.execute(program, Box::new(builder)).await.unwrap(); - }; + let mut builder = + ProxyResponseBuilder::new(dispatcher, database_id, to, req_id, connection_id); if self.database.is_primary() { match self @@ -428,17 +431,22 @@ impl Allocation { .and_then(|m| m.get(&connection_id).cloned()) { Some(handle) => { - tokio::spawn(exec(handle)); + tokio::spawn(async move { handle.execute(program, Box::new(builder)).await }); } None => { - let handle = self.new_conn(Some((to, connection_id))).await; - tokio::spawn(exec(handle)); + match self.new_conn(Some((to, connection_id))).await { + Ok(handle) => { + tokio::spawn(async move { handle.execute(program, Box::new(builder)).await }); + }, + Err(e) => builder.finnalize_error(format!("error creating connection: {e}")), + + } } } } } - async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { + async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> Result { let conn_id = self.next_conn_id(); let (timeout_monitor, notifier) = timeout_monitor(); let timeout = self.database.txn_timeout_duration(); @@ -450,7 +458,7 @@ impl Allocation { notifier.disable(); } }) - }); + })?; let (messages_sender, messages_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); @@ -465,15 +473,18 @@ impl Allocation { }; self.connections_futs.spawn(conn.run()); + let handle = ConnectionHandle { messages: messages_sender, inbound: inbound_sender, }; + self.connections .entry(id.0) .or_insert_with(HashMap::new) .insert(id.1, handle.clone()); - handle + + Ok(handle) } fn next_conn_id(&mut self) -> u32 { @@ -564,6 +575,8 @@ impl Connection { } } + tracing::debug!("connection exited: {:?}", self.id); + self.id } } @@ -581,7 +594,7 @@ mod test { use crate::linc::bus::Bus; use crate::replica_commit_store::ReplicaCommitStore; use crate::snapshot_store::SnapshotStore; - use crate::{init_dirs, replica_commit_store}; + use crate::init_dirs; use super::*; @@ -654,7 +667,7 @@ mod test { transaction_timeout_duration: Duration::from_millis(100), }, }; - let (sender, inbox) = mpsc::channel(10); + let (_sender, inbox) = mpsc::channel(10); let env = EnvOpenOptions::new() .max_dbs(10) .map_size(4096 * 100) @@ -672,24 +685,23 @@ mod test { bus.clone(), queue, replica_commit_store, - ), + ).unwrap(), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, hrana_server: Arc::new(hrana::http::Server::new(None)), - dispatcher: bus, // TODO: handle self 
URL? + dispatcher: bus, db_name: config.db_name, connections: HashMap::new(), }; - let conn = alloc.new_conn(None).await; + let conn = alloc.new_conn(None).await.unwrap(); tokio::spawn(alloc.run()); let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); conn.execute(Program::seq(&["begin"]), Box::new(builder)) - .await - .unwrap(); + .await; rcv.await.unwrap().unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; @@ -697,8 +709,7 @@ mod test { let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); conn.execute(Program::seq(&["create table test (x)"]), Box::new(builder)) - .await - .unwrap(); + .await; assert!(rcv.await.unwrap().is_err()); } } diff --git a/libsqlx-server/src/compactor.rs b/libsqlx-server/src/compactor.rs index 2039343a..523e0d39 100644 --- a/libsqlx-server/src/compactor.rs +++ b/libsqlx-server/src/compactor.rs @@ -193,6 +193,7 @@ pub struct SnapshotFrame { impl SnapshotFrame { const SIZE: usize = size_of::() + 4096; + #[cfg(test)] pub fn try_from_bytes(data: Bytes) -> crate::Result { if data.len() != Self::SIZE { color_eyre::eyre::bail!("invalid snapshot frame") diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index 14cfb1c3..50f31b0f 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -75,7 +75,7 @@ pub async fn execute_batch( pgm: Program, ) -> color_eyre::Result { let (builder, ret) = HranaBatchProtoBuilder::new(); - db.execute(pgm, Box::new(builder)).await?; + db.execute(pgm, Box::new(builder)).await; Ok(ret.await?) } @@ -108,7 +108,7 @@ pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); - conn.execute(pgm, Box::new(builder)).await?; + conn.execute(pgm, Box::new(builder)).await; rcv.await? .map_err(|e| anyhow!("{e}"))? diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index 1b843367..becb075f 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -49,7 +49,7 @@ pub async fn execute_stmt( ) -> color_eyre::Result { let (builder, ret) = SingleStatementBuilder::new(); let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute(pgm, Box::new(builder)).await?; + conn.execute(pgm, Box::new(builder)).await; ret.await? 
.map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) { Ok(stmt_error) => anyhow!(stmt_error), diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 1b0ca7d1..c4fe1bae 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -65,7 +65,7 @@ impl Manager { dispatcher.clone(), self.compaction_queue.clone(), self.replica_commit_store.clone(), - ), + ).unwrap(), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 9582cf2c..a56c10ae 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -184,7 +184,7 @@ impl Database for LibsqlDatabase { } impl InjectableDatabase for LibsqlDatabase { - fn injector(&mut self) -> crate::Result> { + fn injector(&self) -> crate::Result> { Ok(Box::new(Injector::new( &self.db_path, self.ty.on_commit_cb.clone(), @@ -228,7 +228,7 @@ mod test { on_commit_cb: Arc::new(|_| ()), injector_buffer_capacity: 10, }; - let mut db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); + let db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); let mut conn = db.connect().unwrap(); let row: Arc>> = Default::default(); @@ -265,7 +265,7 @@ mod test { }, ); - let mut replica = LibsqlDatabase::new( + let replica = LibsqlDatabase::new( temp_replica.path().to_path_buf(), ReplicaType { on_commit_cb: Arc::new(|_| ()), @@ -278,8 +278,7 @@ mod test { .execute_program( &Program::seq(&["create table test (x)", "insert into test values (42)"]), Box::new(()), - ) - .unwrap(); + ).unwrap(); let logfile = primary.ty.logger.log_file.read(); diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 38112c37..998fada7 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -711,20 +711,6 @@ impl LogFileHeader { } } -pub struct Generation { - pub id: Uuid, - pub start_index: u64, -} - -impl Generation { - fn new(start_index: u64) -> Self { - Self { - id: Uuid::new_v4(), - start_index, - } - } -} - pub trait LogCompactor: Sync + Send + 'static { /// returns whether the passed log file should be compacted. If this method returns true, /// compact should be called next. 
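Note: the `ConnectionHandle::execute` change earlier in this patch is the core of the error-handling story here: instead of returning a `Result` that callers tended to unwrap, a failed send recovers the builder from the error value and finalizes it, so whoever waits on the builder's completion channel always resolves. A minimal sketch of that pattern with simplified stand-ins (`Msg` plays the role of `ConnectionMessage`, a oneshot the role of the result builder):

    use tokio::sync::{mpsc, oneshot};

    enum Msg {
        Execute { ret: oneshot::Sender<Result<(), String>> },
    }

    async fn execute(messages: &mpsc::Sender<Msg>, ret: oneshot::Sender<Result<(), String>>) {
        if let Err(e) = messages.send(Msg::Execute { ret }).await {
            // The connection loop is gone: take the message back out of the
            // failed send and resolve the waiter instead of leaving it hanging.
            let Msg::Execute { ret } = e.0;
            let _ = ret.send(Err("connection closed".to_string()));
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = mpsc::channel::<Msg>(1);
        drop(rx); // simulate a connection task that already exited

        let (snd, rcv) = oneshot::channel();
        execute(&tx, snd).await;
        assert!(rcv.await.unwrap().is_err());
    }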
diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index 43fa0dac..868eaf7f 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -21,7 +21,7 @@ pub trait Database { } pub trait InjectableDatabase { - fn injector(&mut self) -> crate::Result>; + fn injector(&self) -> crate::Result>; } // Trait implemented by databases that support frame injection diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs index adc82ed8..d0e7c5a0 100644 --- a/libsqlx/src/database/proxy/database.rs +++ b/libsqlx/src/database/proxy/database.rs @@ -42,7 +42,7 @@ impl InjectableDatabase for WriteProxyDatabase where RDB: InjectableDatabase, { - fn injector(&mut self) -> crate::Result> { + fn injector(&self) -> crate::Result> { self.read_db.injector() } } From 0d6005c4aaf39a733f53acb49f10dc52061ccf4f Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 28 Jul 2023 16:01:00 +0200 Subject: [PATCH 61/64] wip --- libsqlx-server/src/allocation/mod.rs | 62 ++++--- libsqlx-server/src/allocation/primary/mod.rs | 4 +- libsqlx-server/src/database.rs | 14 +- libsqlx-server/src/{allocation => }/error.rs | 6 + libsqlx-server/src/hrana/batch.rs | 36 ++-- libsqlx-server/src/hrana/error.rs | 168 +++++++++++++++++++ libsqlx-server/src/hrana/http/mod.rs | 13 +- libsqlx-server/src/hrana/http/request.rs | 59 +++---- libsqlx-server/src/hrana/http/stream.rs | 39 ++--- libsqlx-server/src/hrana/mod.rs | 46 +---- libsqlx-server/src/hrana/stmt.rs | 114 ++----------- libsqlx-server/src/main.rs | 18 +- libsqlx-server/src/manager.rs | 3 +- libsqlx/src/database/libsql/mod.rs | 3 +- 14 files changed, 309 insertions(+), 276 deletions(-) rename libsqlx-server/src/{allocation => }/error.rs (58%) create mode 100644 libsqlx-server/src/hrana/error.rs diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 0ec1e271..8a757339 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -20,7 +20,9 @@ use tokio::time::Interval; use crate::allocation::primary::FrameStreamer; use crate::allocation::timeout_notifier::timeout_monitor; use crate::compactor::CompactionQueue; +use crate::error::Error; use crate::hrana; +use crate::hrana::error::HranaError; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use crate::linc::bus::Dispatch; @@ -39,9 +41,6 @@ pub mod config; mod primary; mod replica; mod timeout_notifier; -mod error; - -pub type Result = std::result::Result; /// Maximum number of frame a Frame message is allowed to contain const FRAMES_MESSAGE_MAX_COUNT: usize = 5; @@ -59,7 +58,7 @@ pub enum ConnectionMessage { pub enum AllocationMessage { HranaPipelineReq { req: PipelineRequestBody, - ret: oneshot::Sender>, + ret: oneshot::Sender>, }, Inbound(Inbound), } @@ -120,7 +119,7 @@ impl Database { dispatcher: Arc, compaction_queue: Arc, replica_commit_store: Arc, - ) -> Result { + ) -> crate::Result { let database_id = DatabaseId::from_name(&config.db_name); match config.db_config { @@ -215,7 +214,7 @@ impl Database { connection_id: u32, alloc: &Allocation, on_txn_status_change_cb: impl Fn(bool) + Send + Sync + 'static, - ) -> Result { + ) -> crate::Result { match self { Database::Primary { db: PrimaryDatabase { db, .. 
}, @@ -268,11 +267,7 @@ pub struct ConnectionHandle { } impl ConnectionHandle { - pub async fn execute( - &self, - pgm: Program, - builder: Box, - ) { + pub async fn execute(&self, pgm: Program, builder: Box) { let msg = ConnectionMessage::Execute { pgm, builder }; if let Err(e) = self.messages.send(msg).await { let ConnectionMessage::Execute { mut builder, .. } = e.0 else { unreachable!() }; @@ -316,7 +311,7 @@ impl Allocation { } } - async fn handle_inbound(&mut self, msg: Inbound) -> Result<()> { + async fn handle_inbound(&mut self, msg: Inbound) -> crate::Result<()> { debug_assert_eq!( msg.enveloppe.database_id, Some(DatabaseId::from_name(&self.db_name)) @@ -375,16 +370,19 @@ impl Allocation { } Database::Replica { .. } => todo!("not a primary!"), }, - Message::Frames(frames) => if let Database::Replica { - injector_handle, - last_received_frame_ts, - .. - } = &mut self.database { - *last_received_frame_ts = Some(Instant::now()); - if injector_handle.send(frames).await.is_err() { - return Err(error::Error::InjectorExited) + Message::Frames(frames) => { + if let Database::Replica { + injector_handle, + last_received_frame_ts, + .. + } = &mut self.database + { + *last_received_frame_ts = Some(Instant::now()); + if injector_handle.send(frames).await.is_err() { + return Err(Error::InjectorExited); + } } - }, + } Message::ProxyRequest { connection_id, req_id, @@ -433,20 +431,19 @@ impl Allocation { Some(handle) => { tokio::spawn(async move { handle.execute(program, Box::new(builder)).await }); } - None => { - match self.new_conn(Some((to, connection_id))).await { - Ok(handle) => { - tokio::spawn(async move { handle.execute(program, Box::new(builder)).await }); - }, - Err(e) => builder.finnalize_error(format!("error creating connection: {e}")), - + None => match self.new_conn(Some((to, connection_id))).await { + Ok(handle) => { + tokio::spawn( + async move { handle.execute(program, Box::new(builder)).await }, + ); } - } + Err(e) => builder.finnalize_error(format!("error creating connection: {e}")), + }, } } } - async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> Result { + async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> crate::Result { let conn_id = self.next_conn_id(); let (timeout_monitor, notifier) = timeout_monitor(); let timeout = self.database.txn_timeout_duration(); @@ -591,10 +588,10 @@ mod test { use tokio::sync::Notify; use crate::allocation::replica::ReplicaConnection; + use crate::init_dirs; use crate::linc::bus::Bus; use crate::replica_commit_store::ReplicaCommitStore; use crate::snapshot_store::SnapshotStore; - use crate::init_dirs; use super::*; @@ -685,7 +682,8 @@ mod test { bus.clone(), queue, replica_commit_store, - ).unwrap(), + ) + .unwrap(), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index ee9e0d96..dbc0387f 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -320,9 +320,7 @@ impl ConnectionHandler for PrimaryConnection { async fn handle_conn_message(&mut self, msg: ConnectionMessage) { match msg { ConnectionMessage::Execute { pgm, builder } => { - block_in_place(|| { - self.conn.execute_program(&pgm, builder).unwrap() - }) + block_in_place(|| self.conn.execute_program(&pgm, builder).unwrap()) } ConnectionMessage::Describe => { todo!() diff --git a/libsqlx-server/src/database.rs b/libsqlx-server/src/database.rs 
index 4945cd70..94e9b724 100644 --- a/libsqlx-server/src/database.rs +++ b/libsqlx-server/src/database.rs @@ -1,6 +1,7 @@ use tokio::sync::{mpsc, oneshot}; use crate::allocation::AllocationMessage; +use crate::error::Error; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; pub struct Database { @@ -13,10 +14,17 @@ impl Database { req: PipelineRequestBody, ) -> crate::Result { let (sender, ret) = oneshot::channel(); - self.sender + if self + .sender .send(AllocationMessage::HranaPipelineReq { req, ret: sender }) .await - .unwrap(); - ret.await.unwrap() + .is_err() + { + return Err(Error::AllocationClosed); + } + + ret.await.map_err(|_| { + Error::Internal(String::from("response builder dropped by connection")) + }) } } diff --git a/libsqlx-server/src/allocation/error.rs b/libsqlx-server/src/error.rs similarity index 58% rename from libsqlx-server/src/allocation/error.rs rename to libsqlx-server/src/error.rs index 09c700aa..df485ac5 100644 --- a/libsqlx-server/src/allocation/error.rs +++ b/libsqlx-server/src/error.rs @@ -6,4 +6,10 @@ pub enum Error { InjectorExited, #[error("connection closed")] ConnectionClosed, + #[error(transparent)] + Io(#[from] std::io::Error), + #[error("allocation closed")] + AllocationClosed, + #[error("internal error: {0}")] + Internal(String), } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index 50f31b0f..b6ed24fe 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -1,20 +1,19 @@ use std::collections::HashMap; use crate::allocation::ConnectionHandle; -use crate::hrana::stmt::StmtError; +use super::error::{ProtocolError, StmtError, HranaError}; use super::result_builder::HranaBatchProtoBuilder; use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error}; -use super::{proto, ProtocolError, Version}; +use super::{proto, Version}; -use color_eyre::eyre::anyhow; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; use libsqlx::query::{Params, Query}; use libsqlx::result_builder::{StepResult, StepResultsBuilder}; use tokio::sync::oneshot; -fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre::Result { +fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> crate::Result { let try_convert_step = |step: i32| -> Result { let step = usize::try_from(step).map_err(|_| ProtocolError::BatchCondBadStep)?; if step >= max_step_i { @@ -36,13 +35,13 @@ fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre: conds: conds .iter() .map(|cond| proto_cond_to_cond(cond, max_step_i)) - .collect::>()?, + .collect::>()?, }, proto::BatchCond::Or { conds } => Cond::Or { conds: conds .iter() .map(|cond| proto_cond_to_cond(cond, max_step_i)) - .collect::>()?, + .collect::>()?, }, }; @@ -53,7 +52,7 @@ pub fn proto_batch_to_program( batch: &proto::Batch, sqls: &HashMap, version: Version, -) -> color_eyre::Result { +) -> crate::Result { let mut steps = Vec::with_capacity(batch.steps.len()); for (step_i, step) in batch.steps.iter().enumerate() { let query = proto_stmt_to_query(&step.stmt, sqls, version)?; @@ -73,17 +72,17 @@ pub fn proto_batch_to_program( pub async fn execute_batch( db: &ConnectionHandle, pgm: Program, -) -> color_eyre::Result { +) -> crate::Result { let (builder, ret) = HranaBatchProtoBuilder::new(); db.execute(pgm, Box::new(builder)).await; - Ok(ret.await?) 
+ Ok(ret.await.unwrap()) } -pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { +pub fn proto_sequence_to_program(sql: &str) -> crate::Result { let stmts = Statement::parse(sql) .collect::>>() - .map_err(|err| anyhow!(StmtError::SqlParse { source: err.into() }))?; + .map_err(|err| StmtError::SqlParse { source: err.into() })?; let steps = stmts .into_iter() @@ -105,20 +104,21 @@ pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { Ok(Program { steps }) } -pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { +pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> crate::Result<(), HranaError> { let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); conn.execute(pgm, Box::new(builder)).await; - rcv.await? - .map_err(|e| anyhow!("{e}"))? + rcv.await + .map_err(|e| HranaError::Internal(e.into()))? + .map_err(|e| HranaError::Stmt(StmtError::QueryError(e)))? .into_iter() .try_for_each(|result| match result { StepResult::Ok => Ok(()), - StepResult::Err(e) => match stmt_error_from_sqld_error(e) { - Ok(stmt_err) => Err(anyhow!(stmt_err)), - Err(sqld_err) => Err(anyhow!(sqld_err)), + StepResult::Err(e) => { + let stmt_err = stmt_error_from_sqld_error(e)?; + Err(stmt_err)? }, - StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), + StepResult::Skipped => Err(HranaError::StatementSkipped), }) } diff --git a/libsqlx-server/src/hrana/error.rs b/libsqlx-server/src/hrana/error.rs new file mode 100644 index 00000000..4092690d --- /dev/null +++ b/libsqlx-server/src/hrana/error.rs @@ -0,0 +1,168 @@ +use super::Version; + +#[derive(thiserror::Error, Debug)] +pub enum HranaError { + #[error(transparent)] + Protocol(#[from] ProtocolError), + #[error(transparent)] + Stmt(#[from] StmtError), + #[error(transparent)] + Internal(color_eyre::eyre::Error), + #[error("Statement in sequence was not executed")] + StatementSkipped, + #[error(transparent)] + Libsqlx(#[from] libsqlx::error::Error), + #[error(transparent)] + StreamResponse(#[from] StreamResponseError), + #[error(transparent)] + Stream(#[from] StreamError) +} + +/// An error from executing a [`proto::StreamRequest`] +#[derive(thiserror::Error, Debug)] +pub enum StreamResponseError { + #[error("The server already stores {count} SQL texts, it cannot store more")] + SqlTooMany { count: usize }, + #[error(transparent)] + Stmt(StmtError), +} + + +/// An unrecoverable protocol error that should close the WebSocket or HTTP stream. A correct +/// client should never trigger any of these errors. 
+#[derive(thiserror::Error, Debug)] +pub enum ProtocolError { + #[error("Cannot deserialize client message: {source}")] + Deserialize { source: serde_json::Error }, + #[error("Received a binary WebSocket message, which is not supported")] + BinaryWebSocketMessage, + #[error("Received a request before hello message")] + RequestBeforeHello, + + #[error("Stream {stream_id} not found")] + StreamNotFound { stream_id: i32 }, + #[error("Stream {stream_id} already exists")] + StreamExists { stream_id: i32 }, + + #[error("Either `sql` or `sql_id` are required, but not both")] + SqlIdAndSqlGiven, + #[error("Either `sql` or `sql_id` are required")] + SqlIdOrSqlNotGiven, + #[error("SQL text {sql_id} not found")] + SqlNotFound { sql_id: i32 }, + #[error("SQL text {sql_id} already exists")] + SqlExists { sql_id: i32 }, + + #[error("Invalid reference to step in a batch condition")] + BatchCondBadStep, + + #[error("Received an invalid baton: {0}")] + BatonInvalid(String), + #[error("Received a baton that has already been used")] + BatonReused, + #[error("Stream for this baton was closed")] + BatonStreamClosed, + + #[error("{what} is only supported in protocol version {min_version} and higher")] + NotSupported { + what: &'static str, + min_version: Version, + }, + + #[error("{0}")] + ResponseTooLarge(String), +} + +/// An error during execution of an SQL statement. +#[derive(thiserror::Error, Debug)] +pub enum StmtError { + #[error("SQL string could not be parsed: {source}")] + SqlParse { source: color_eyre::eyre::Error }, + #[error("SQL string does not contain any statement")] + SqlNoStmt, + #[error("SQL string contains more than one statement")] + SqlManyStmts, + #[error("Arguments do not match SQL parameters: {msg}")] + ArgsInvalid { msg: String }, + #[error("Specifying both positional and named arguments is not supported")] + ArgsBothPositionalAndNamed, + + #[error("Transaction timed out")] + TransactionTimeout, + #[error("Server cannot handle additional transactions")] + TransactionBusy, + #[error("SQLite error: {message}")] + SqliteError { + source: libsqlx::rusqlite::ffi::Error, + message: String, + }, + #[error("SQL input error: {message} (at offset {offset})")] + SqlInputError { + source: color_eyre::eyre::Error, + message: String, + offset: i32, + }, + + #[error("Operation was blocked{}", .reason.as_ref().map(|msg| format!(": {}", msg)).unwrap_or_default())] + Blocked { reason: Option }, + #[error("query error: {0}")] + QueryError(String) +} + +impl StmtError { + pub fn code(&self) -> &'static str { + match self { + Self::SqlParse { .. } => "SQL_PARSE_ERROR", + Self::SqlNoStmt => "SQL_NO_STATEMENT", + Self::SqlManyStmts => "SQL_MANY_STATEMENTS", + Self::ArgsInvalid { .. } => "ARGS_INVALID", + Self::ArgsBothPositionalAndNamed => "ARGS_BOTH_POSITIONAL_AND_NAMED", + Self::TransactionTimeout => "TRANSACTION_TIMEOUT", + Self::TransactionBusy => "TRANSACTION_BUSY", + Self::SqliteError { source, .. } => sqlite_error_code(source.code), + Self::SqlInputError { .. } => "SQL_INPUT_ERROR", + Self::Blocked { .. 
} => "BLOCKED", + Self::QueryError(_) => todo!(), + } + } +} + +fn sqlite_error_code(code: libsqlx::error::ErrorCode) -> &'static str { + match code { + libsqlx::error::ErrorCode::InternalMalfunction => "SQLITE_INTERNAL", + libsqlx::error::ErrorCode::PermissionDenied => "SQLITE_PERM", + libsqlx::error::ErrorCode::OperationAborted => "SQLITE_ABORT", + libsqlx::error::ErrorCode::DatabaseBusy => "SQLITE_BUSY", + libsqlx::error::ErrorCode::DatabaseLocked => "SQLITE_LOCKED", + libsqlx::error::ErrorCode::OutOfMemory => "SQLITE_NOMEM", + libsqlx::error::ErrorCode::ReadOnly => "SQLITE_READONLY", + libsqlx::error::ErrorCode::OperationInterrupted => "SQLITE_INTERRUPT", + libsqlx::error::ErrorCode::SystemIoFailure => "SQLITE_IOERR", + libsqlx::error::ErrorCode::DatabaseCorrupt => "SQLITE_CORRUPT", + libsqlx::error::ErrorCode::NotFound => "SQLITE_NOTFOUND", + libsqlx::error::ErrorCode::DiskFull => "SQLITE_FULL", + libsqlx::error::ErrorCode::CannotOpen => "SQLITE_CANTOPEN", + libsqlx::error::ErrorCode::FileLockingProtocolFailed => "SQLITE_PROTOCOL", + libsqlx::error::ErrorCode::SchemaChanged => "SQLITE_SCHEMA", + libsqlx::error::ErrorCode::TooBig => "SQLITE_TOOBIG", + libsqlx::error::ErrorCode::ConstraintViolation => "SQLITE_CONSTRAINT", + libsqlx::error::ErrorCode::TypeMismatch => "SQLITE_MISMATCH", + libsqlx::error::ErrorCode::ApiMisuse => "SQLITE_MISUSE", + libsqlx::error::ErrorCode::NoLargeFileSupport => "SQLITE_NOLFS", + libsqlx::error::ErrorCode::AuthorizationForStatementDenied => "SQLITE_AUTH", + libsqlx::error::ErrorCode::ParameterOutOfRange => "SQLITE_RANGE", + libsqlx::error::ErrorCode::NotADatabase => "SQLITE_NOTADB", + libsqlx::error::ErrorCode::Unknown => "SQLITE_UNKNOWN", + _ => "SQLITE_UNKNOWN", + } +} + +/// An unrecoverable error that should close the stream. The difference from [`ProtocolError`] is +/// that a correct client may trigger this error, it does not mean that the protocol has been +/// violated. 
+#[derive(thiserror::Error, Debug)] +pub enum StreamError { + #[error("The stream has expired due to inactivity")] + StreamExpired, +} + diff --git a/libsqlx-server/src/hrana/http/mod.rs b/libsqlx-server/src/hrana/http/mod.rs index 521d33ff..ea316aea 100644 --- a/libsqlx-server/src/hrana/http/mod.rs +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -10,7 +10,7 @@ use crate::allocation::ConnectionHandle; use self::proto::{PipelineRequestBody, PipelineResponseBody}; -use super::ProtocolError; +use super::error::{HranaError, ProtocolError, StreamError}; pub mod proto; mod request; @@ -42,7 +42,7 @@ impl Server { } } -fn handle_index() -> color_eyre::Result> { +fn handle_index() -> crate::Result, HranaError> { Ok(text_response( hyper::StatusCode::OK, "Hello, this is HTTP API v2 (Hrana over HTTP)".into(), @@ -52,9 +52,9 @@ fn handle_index() -> color_eyre::Result> { pub async fn handle_pipeline( server: Arc, req: PipelineRequestBody, - ret: oneshot::Sender>, + ret: oneshot::Sender>, mk_conn: F, -) -> color_eyre::Result<()> +) -> crate::Result<(), HranaError> where F: FnOnce() -> Fut, Fut: Future>, @@ -66,8 +66,7 @@ where let mut results = Vec::with_capacity(req.requests.len()); for request in req.requests.into_iter() { let result = request::handle(&mut stream_guard, request) - .await - .context("Could not execute a request in pipeline")?; + .await?; results.push(result); } @@ -100,7 +99,7 @@ fn protocol_error_response(err: ProtocolError) -> hyper::Response { text_response(hyper::StatusCode::BAD_REQUEST, err.to_string()) } -fn stream_error_response(err: stream::StreamError) -> hyper::Response { +fn stream_error_response(err: StreamError) -> hyper::Response { json_response( hyper::StatusCode::INTERNAL_SERVER_ERROR, &proto::Error { diff --git a/libsqlx-server/src/hrana/http/request.rs b/libsqlx-server/src/hrana/http/request.rs index eb1623cd..c9369315 100644 --- a/libsqlx-server/src/hrana/http/request.rs +++ b/libsqlx-server/src/hrana/http/request.rs @@ -1,39 +1,34 @@ -use color_eyre::eyre::{anyhow, bail}; +use crate::hrana::error::{HranaError, ProtocolError, StreamResponseError}; -use super::super::{batch, stmt, ProtocolError, Version}; +use super::super::{batch, stmt, Version}; use super::{proto, stream}; -/// An error from executing a [`proto::StreamRequest`] -#[derive(thiserror::Error, Debug)] -pub enum StreamResponseError { - #[error("The server already stores {count} SQL texts, it cannot store more")] - SqlTooMany { count: usize }, - #[error(transparent)] - Stmt(stmt::StmtError), -} - pub async fn handle( stream_guard: &mut stream::Guard, request: proto::StreamRequest, -) -> color_eyre::Result { +) -> crate::Result { let result = match try_handle(stream_guard, request).await { Ok(response) => proto::StreamResult::Ok { response }, Err(err) => { - let resp_err = err.downcast::()?; - let error = proto::Error { - message: resp_err.to_string(), - code: resp_err.code().into(), - }; - proto::StreamResult::Error { error } + if let HranaError::StreamResponse(resp_err) = err { + let error = proto::Error { + message: resp_err.to_string(), + code: resp_err.code().into(), + }; + proto::StreamResult::Error { error } + } else { + return Err(err); + } } }; + Ok(result) } async fn try_handle( stream_guard: &mut stream::Guard, request: proto::StreamRequest, -) -> color_eyre::Result { +) -> crate::Result { Ok(match request { proto::StreamRequest::Close(_req) => { stream_guard.close_db(); @@ -42,11 +37,8 @@ async fn try_handle( proto::StreamRequest::Execute(req) => { let db = stream_guard.get_db()?; let sqls = 
stream_guard.sqls(); - let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2) - .map_err(catch_stmt_error)?; - let result = stmt::execute_stmt(db, query) - .await - .map_err(catch_stmt_error)?; + let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2)?; + let result = stmt::execute_stmt(db, query).await?; proto::StreamResponse::Execute(proto::ExecuteStreamResp { result }) } proto::StreamRequest::Batch(req) => { @@ -61,10 +53,9 @@ async fn try_handle( let sqls = stream_guard.sqls(); let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; - let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; + let pgm = batch::proto_sequence_to_program(sql)?; batch::execute_sequence(db, pgm) - .await - .map_err(catch_stmt_error)?; + .await?; proto::StreamResponse::Sequence(proto::SequenceStreamResp {}) } proto::StreamRequest::Describe(req) => { @@ -73,17 +64,16 @@ async fn try_handle( let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; let result = stmt::describe_stmt(db, sql.into()) - .await - .map_err(catch_stmt_error)?; + .await?; proto::StreamResponse::Describe(proto::DescribeStreamResp { result }) } proto::StreamRequest::StoreSql(req) => { let sqls = stream_guard.sqls_mut(); let sql_id = req.sql_id; if sqls.contains_key(&sql_id) { - bail!(ProtocolError::SqlExists { sql_id }) + Err(ProtocolError::SqlExists { sql_id })? } else if sqls.len() >= MAX_SQL_COUNT { - bail!(StreamResponseError::SqlTooMany { count: sqls.len() }) + Err(StreamResponseError::SqlTooMany { count: sqls.len() })? } sqls.insert(sql_id, req.sql); proto::StreamResponse::StoreSql(proto::StoreSqlStreamResp {}) @@ -98,13 +88,6 @@ async fn try_handle( const MAX_SQL_COUNT: usize = 50; -fn catch_stmt_error(err: color_eyre::eyre::Error) -> color_eyre::eyre::Error { - match err.downcast::() { - Ok(stmt_err) => anyhow!(StreamResponseError::Stmt(stmt_err)), - Err(err) => err, - } -} - impl StreamResponseError { pub fn code(&self) -> &'static str { match self { diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs index 25c1e719..3d87a96e 100644 --- a/libsqlx-server/src/hrana/http/stream.rs +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -5,15 +5,14 @@ use std::sync::Arc; use std::{future, mem, task}; use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; -use color_eyre::eyre::{anyhow, WrapErr}; use futures::Future; use hmac::Mac as _; use priority_queue::PriorityQueue; use tokio::time::{Duration, Instant}; -use super::super::ProtocolError; use super::Server; use crate::allocation::ConnectionHandle; +use crate::hrana::error::{ProtocolError, HranaError, StreamError}; /// Mutable state related to streams, owned by [`Server`] and protected with a mutex. pub struct ServerStreamState { @@ -78,15 +77,6 @@ pub struct Guard { release: bool, } -/// An unrecoverable error that should close the stream. The difference from [`ProtocolError`] is -/// that a correct client may trigger this error, it does not mean that the protocol has been -/// violated. 
-#[derive(thiserror::Error, Debug)] -pub enum StreamError { - #[error("The stream has expired due to inactivity")] - StreamExpired, -} - impl ServerStreamState { pub fn new() -> Self { Self { @@ -106,7 +96,7 @@ pub async fn acquire( server: Arc, baton: Option<&str>, mk_conn: F, -) -> color_eyre::Result +) -> crate::Result where F: FnOnce() -> Fut, Fut: Future>, @@ -125,19 +115,20 @@ where .into()) } Some(Handle::Acquired) => { - return Err(ProtocolError::BatonReused) - .context(format!("Stream handle for {stream_id} is acquired")); + Err(ProtocolError::BatonReused)? + // .context(format!("Stream handle for {stream_id} is acquired")); } Some(Handle::Expired) => { - return Err(StreamError::StreamExpired) - .context(format!("Stream handle for {stream_id} is expired")); + Err(StreamError::StreamExpired)? + // .context(format!("Stream handle for {stream_id} is expired")); } Some(Handle::Available(stream)) => { if stream.baton_seq != baton_seq { - return Err(ProtocolError::BatonReused).context(format!( - "Expected baton seq {}, received {baton_seq}", - stream.baton_seq - )); + Err(ProtocolError::BatonReused)? + // .context(format!( + // "Expected baton seq {}, received {baton_seq}", + // stream.baton_seq + // )); } } }; @@ -156,7 +147,7 @@ where None => { let conn = mk_conn() .await - .context("Could not create a database connection")?; + .map_err(|e| HranaError::Internal(e.into()))?; let mut state = server.stream_state.lock(); let stream = Box::new(Stream { @@ -289,7 +280,7 @@ fn encode_baton(server: &Server, stream_id: u64, baton_seq: u64) -> String { /// Decodes a baton encoded with `encode_baton()` and returns `(stream_id, baton_seq)`. Always /// returns a [`ProtocolError::BatonInvalid`] if the baton is invalid, but it attaches an anyhow /// context that describes the precise cause. -fn decode_baton(server: &Server, baton_str: &str) -> color_eyre::Result<(u64, u64)> { +fn decode_baton(server: &Server, baton_str: &str) -> crate::Result<(u64, u64), HranaError> { let baton_data = BASE64_STANDARD_NO_PAD.decode(baton_str).map_err(|err| { ProtocolError::BatonInvalid(format!("Could not base64-decode baton: {err}")) })?; @@ -308,9 +299,9 @@ fn decode_baton(server: &Server, baton_str: &str) -> color_eyre::Result<(u64, u6 let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); hmac.update(payload); hmac.verify_slice(received_mac).map_err(|_| { - anyhow!(ProtocolError::BatonInvalid( + ProtocolError::BatonInvalid( "Invalid MAC on baton".to_string() - )) + ) })?; let stream_id = u64::from_be_bytes(payload[0..8].try_into().unwrap()); diff --git a/libsqlx-server/src/hrana/mod.rs b/libsqlx-server/src/hrana/mod.rs index fc85fcfe..54c65327 100644 --- a/libsqlx-server/src/hrana/mod.rs +++ b/libsqlx-server/src/hrana/mod.rs @@ -5,6 +5,7 @@ pub mod http; pub mod proto; mod result_builder; pub mod stmt; +pub mod error; // pub mod ws; #[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] @@ -21,48 +22,3 @@ impl fmt::Display for Version { } } } - -/// An unrecoverable protocol error that should close the WebSocket or HTTP stream. A correct -/// client should never trigger any of these errors. 
-#[derive(thiserror::Error, Debug)] -pub enum ProtocolError { - #[error("Cannot deserialize client message: {source}")] - Deserialize { source: serde_json::Error }, - #[error("Received a binary WebSocket message, which is not supported")] - BinaryWebSocketMessage, - #[error("Received a request before hello message")] - RequestBeforeHello, - - #[error("Stream {stream_id} not found")] - StreamNotFound { stream_id: i32 }, - #[error("Stream {stream_id} already exists")] - StreamExists { stream_id: i32 }, - - #[error("Either `sql` or `sql_id` are required, but not both")] - SqlIdAndSqlGiven, - #[error("Either `sql` or `sql_id` are required")] - SqlIdOrSqlNotGiven, - #[error("SQL text {sql_id} not found")] - SqlNotFound { sql_id: i32 }, - #[error("SQL text {sql_id} already exists")] - SqlExists { sql_id: i32 }, - - #[error("Invalid reference to step in a batch condition")] - BatchCondBadStep, - - #[error("Received an invalid baton: {0}")] - BatonInvalid(String), - #[error("Received a baton that has already been used")] - BatonReused, - #[error("Stream for this baton was closed")] - BatonStreamClosed, - - #[error("{what} is only supported in protocol version {min_version} and higher")] - NotSupported { - what: &'static str, - min_version: Version, - }, - - #[error("{0}")] - ResponseTooLarge(String), -} diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index becb075f..2053b532 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -1,66 +1,35 @@ use std::collections::HashMap; -use color_eyre::eyre::{anyhow, bail}; use libsqlx::analysis::Statement; use libsqlx::query::{Params, Query, Value}; +use super::error::{HranaError, StmtError, ProtocolError}; use super::result_builder::SingleStatementBuilder; -use super::{proto, ProtocolError, Version}; +use super::{proto, Version}; use crate::allocation::ConnectionHandle; use crate::hrana; -/// An error during execution of an SQL statement. -#[derive(thiserror::Error, Debug)] -pub enum StmtError { - #[error("SQL string could not be parsed: {source}")] - SqlParse { source: color_eyre::eyre::Error }, - #[error("SQL string does not contain any statement")] - SqlNoStmt, - #[error("SQL string contains more than one statement")] - SqlManyStmts, - #[error("Arguments do not match SQL parameters: {msg}")] - ArgsInvalid { msg: String }, - #[error("Specifying both positional and named arguments is not supported")] - ArgsBothPositionalAndNamed, - - #[error("Transaction timed out")] - TransactionTimeout, - #[error("Server cannot handle additional transactions")] - TransactionBusy, - #[error("SQLite error: {message}")] - SqliteError { - source: libsqlx::rusqlite::ffi::Error, - message: String, - }, - #[error("SQL input error: {message} (at offset {offset})")] - SqlInputError { - source: color_eyre::eyre::Error, - message: String, - offset: i32, - }, - - #[error("Operation was blocked{}", .reason.as_ref().map(|msg| format!(": {}", msg)).unwrap_or_default())] - Blocked { reason: Option }, -} - pub async fn execute_stmt( conn: &ConnectionHandle, query: Query, -) -> color_eyre::Result { +) -> crate::Result { let (builder, ret) = SingleStatementBuilder::new(); let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); conn.execute(pgm, Box::new(builder)).await; - ret.await? 
- .map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) { - Ok(stmt_error) => anyhow!(stmt_error), - Err(sqld_error) => anyhow!(sqld_error), - }) + ret.await + .unwrap() + .map_err(|sqld_error| { + match stmt_error_from_sqld_error(sqld_error) { + Ok(e) => e.into(), + Err(e) => e.into(), + } + }) } pub async fn describe_stmt( _db: &ConnectionHandle, _sql: String, -) -> color_eyre::Result { +) -> crate::Result { todo!(); // match db.describe(sql).await? { // Ok(describe_response) => todo!(), @@ -78,18 +47,18 @@ pub fn proto_stmt_to_query( proto_stmt: &proto::Stmt, sqls: &HashMap, version: Version, -) -> color_eyre::Result { +) -> crate::Result { let sql = proto_sql_to_sql(proto_stmt.sql.as_deref(), proto_stmt.sql_id, sqls, version)?; let mut stmt_iter = Statement::parse(sql); let stmt = match stmt_iter.next() { Some(Ok(stmt)) => stmt, - Some(Err(err)) => bail!(StmtError::SqlParse { source: err.into() }), - None => bail!(StmtError::SqlNoStmt), + Some(Err(err)) => Err(StmtError::SqlParse { source: err.into() })?, + None => Err(StmtError::SqlNoStmt)?, }; if stmt_iter.next().is_some() { - bail!(StmtError::SqlManyStmts) + Err(StmtError::SqlManyStmts)? } let params = if proto_stmt.named_args.is_empty() { @@ -103,7 +72,7 @@ pub fn proto_stmt_to_query( .collect(); Params::Named(values) } else { - bail!(StmtError::ArgsBothPositionalAndNamed) + Err(StmtError::ArgsBothPositionalAndNamed)? }; let want_rows = proto_stmt.want_rows.unwrap_or(true); @@ -119,7 +88,7 @@ pub fn proto_sql_to_sql<'s>( proto_sql_id: Option, sqls: &'s HashMap, verion: Version, -) -> Result<&'s str, ProtocolError> { +) -> crate::Result<&'s str, ProtocolError> { if proto_sql_id.is_some() && verion < Version::Hrana2 { return Err(ProtocolError::NotSupported { what: "`sql_id`", @@ -228,53 +197,6 @@ pub fn proto_error_from_stmt_error(error: &StmtError) -> hrana::proto::Error { } } -impl StmtError { - pub fn code(&self) -> &'static str { - match self { - Self::SqlParse { .. } => "SQL_PARSE_ERROR", - Self::SqlNoStmt => "SQL_NO_STATEMENT", - Self::SqlManyStmts => "SQL_MANY_STATEMENTS", - Self::ArgsInvalid { .. } => "ARGS_INVALID", - Self::ArgsBothPositionalAndNamed => "ARGS_BOTH_POSITIONAL_AND_NAMED", - Self::TransactionTimeout => "TRANSACTION_TIMEOUT", - Self::TransactionBusy => "TRANSACTION_BUSY", - Self::SqliteError { source, .. } => sqlite_error_code(source.code), - Self::SqlInputError { .. } => "SQL_INPUT_ERROR", - Self::Blocked { .. 
} => "BLOCKED", - } - } -} - -fn sqlite_error_code(code: libsqlx::error::ErrorCode) -> &'static str { - match code { - libsqlx::error::ErrorCode::InternalMalfunction => "SQLITE_INTERNAL", - libsqlx::error::ErrorCode::PermissionDenied => "SQLITE_PERM", - libsqlx::error::ErrorCode::OperationAborted => "SQLITE_ABORT", - libsqlx::error::ErrorCode::DatabaseBusy => "SQLITE_BUSY", - libsqlx::error::ErrorCode::DatabaseLocked => "SQLITE_LOCKED", - libsqlx::error::ErrorCode::OutOfMemory => "SQLITE_NOMEM", - libsqlx::error::ErrorCode::ReadOnly => "SQLITE_READONLY", - libsqlx::error::ErrorCode::OperationInterrupted => "SQLITE_INTERRUPT", - libsqlx::error::ErrorCode::SystemIoFailure => "SQLITE_IOERR", - libsqlx::error::ErrorCode::DatabaseCorrupt => "SQLITE_CORRUPT", - libsqlx::error::ErrorCode::NotFound => "SQLITE_NOTFOUND", - libsqlx::error::ErrorCode::DiskFull => "SQLITE_FULL", - libsqlx::error::ErrorCode::CannotOpen => "SQLITE_CANTOPEN", - libsqlx::error::ErrorCode::FileLockingProtocolFailed => "SQLITE_PROTOCOL", - libsqlx::error::ErrorCode::SchemaChanged => "SQLITE_SCHEMA", - libsqlx::error::ErrorCode::TooBig => "SQLITE_TOOBIG", - libsqlx::error::ErrorCode::ConstraintViolation => "SQLITE_CONSTRAINT", - libsqlx::error::ErrorCode::TypeMismatch => "SQLITE_MISMATCH", - libsqlx::error::ErrorCode::ApiMisuse => "SQLITE_MISUSE", - libsqlx::error::ErrorCode::NoLargeFileSupport => "SQLITE_NOLFS", - libsqlx::error::ErrorCode::AuthorizationForStatementDenied => "SQLITE_AUTH", - libsqlx::error::ErrorCode::ParameterOutOfRange => "SQLITE_RANGE", - libsqlx::error::ErrorCode::NotADatabase => "SQLITE_NOTADB", - libsqlx::error::ErrorCode::Unknown => "SQLITE_UNKNOWN", - _ => "SQLITE_UNKNOWN", - } -} - impl From<&proto::Value> for Value { fn from(proto_value: &proto::Value) -> Value { proto_value_to_value(proto_value) diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index ff5c415a..465673f3 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -3,7 +3,6 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use clap::Parser; -use color_eyre::eyre::Result; use compactor::{run_compactor_loop, CompactionQueue}; use config::{AdminApiConfig, ClusterConfig, UserApiConfig}; use http::admin::run_admin_api; @@ -24,6 +23,7 @@ mod allocation; mod compactor; mod config; mod database; +mod error; mod hrana; mod http; mod linc; @@ -32,6 +32,8 @@ mod meta; mod replica_commit_store; mod snapshot_store; +pub type Result = std::result::Result; + #[derive(Debug, Parser)] struct Args { /// Path to the node configuration file @@ -40,10 +42,10 @@ struct Args { } async fn spawn_admin_api( - set: &mut JoinSet>, + set: &mut JoinSet>, config: &AdminApiConfig, bus: Arc>>, -) -> Result<()> { +) -> color_eyre::Result<()> { let admin_api_listener = TcpListener::bind(config.addr).await?; let fut = run_admin_api( http::admin::Config { bus }, @@ -55,11 +57,11 @@ async fn spawn_admin_api( } async fn spawn_user_api( - set: &mut JoinSet>, + set: &mut JoinSet>, config: &UserApiConfig, manager: Arc, bus: Arc>>, -) -> Result<()> { +) -> color_eyre::Result<()> { let user_api_listener = TcpListener::bind(config.addr).await?; set.spawn(run_user_api( http::user::Config { manager, bus }, @@ -70,10 +72,10 @@ async fn spawn_user_api( } async fn spawn_cluster_networking( - set: &mut JoinSet>, + set: &mut JoinSet>, config: &ClusterConfig, bus: Arc>>, -) -> Result<()> { +) -> color_eyre::Result<()> { let server = linc::server::Server::new(bus.clone()); let listener = TcpListener::bind(config.addr).await?; @@ -102,7 +104,7 @@ 
async fn init_dirs(db_path: &Path) -> color_eyre::Result<()> { } #[tokio::main(flavor = "multi_thread", worker_threads = 10)] -async fn main() -> Result<()> { +async fn main() -> color_eyre::Result<()> { init(); let args = Args::parse(); let config_str = read_to_string(args.config)?; diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index c4fe1bae..33d6c8ba 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -65,7 +65,8 @@ impl Manager { dispatcher.clone(), self.compaction_queue.clone(), self.replica_commit_store.clone(), - ).unwrap(), + ) + .unwrap(), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index a56c10ae..3a1191f1 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -278,7 +278,8 @@ mod test { .execute_program( &Program::seq(&["create table test (x)", "insert into test values (42)"]), Box::new(()), - ).unwrap(); + ) + .unwrap(); let logfile = primary.ty.logger.log_file.read(); From 88433f97e1a4fa12052b9ed463f2a310ea110d32 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Sun, 30 Jul 2023 19:45:23 +0200 Subject: [PATCH 62/64] shared hrana server --- libsqlx-server/src/allocation/mod.rs | 27 +-- libsqlx-server/src/database.rs | 25 +-- libsqlx-server/src/hrana/batch.rs | 72 +++--- libsqlx-server/src/hrana/error.rs | 166 +------------- libsqlx-server/src/hrana/http/mod.rs | 116 ++-------- libsqlx-server/src/hrana/http/request.rs | 51 +++-- libsqlx-server/src/hrana/http/stream.rs | 92 ++++---- libsqlx-server/src/hrana/mod.rs | 47 +++- libsqlx-server/src/hrana/result_builder.rs | 13 +- libsqlx-server/src/hrana/stmt.rs | 242 ++++++++++++++------- libsqlx-server/src/hrana/ws/mod.rs | 6 +- libsqlx-server/src/http/user/mod.rs | 10 +- libsqlx-server/src/main.rs | 10 +- libsqlx-server/src/manager.rs | 4 +- 14 files changed, 394 insertions(+), 487 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 8a757339..6fa4df7f 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -21,12 +21,9 @@ use crate::allocation::primary::FrameStreamer; use crate::allocation::timeout_notifier::timeout_monitor; use crate::compactor::CompactionQueue; use crate::error::Error; -use crate::hrana; -use crate::hrana::error::HranaError; -use crate::hrana::http::handle_pipeline; -use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::hrana::proto::DescribeResult; use crate::linc::bus::Dispatch; -use crate::linc::proto::{Frames, Message}; +use crate::linc::proto::{Message, Frames}; use crate::linc::{Inbound, NodeId}; use crate::meta::DatabaseId; use crate::replica_commit_store::ReplicaCommitStore; @@ -56,9 +53,8 @@ pub enum ConnectionMessage { } pub enum AllocationMessage { - HranaPipelineReq { - req: PipelineRequestBody, - ret: oneshot::Sender>, + Connect { + ret: oneshot::Sender>, }, Inbound(Inbound), } @@ -254,7 +250,6 @@ pub struct Allocation { pub max_concurrent_connections: u32, pub connections: HashMap>, - pub hrana_server: Arc, /// handle to the message bus pub dispatcher: Arc, pub db_name: String, @@ -274,6 +269,10 @@ impl ConnectionHandle { builder.finnalize_error("connection closed".to_string()); } } + + pub async fn describe(&self, sql: String) -> crate::Result { + todo!() + } } impl Allocation { @@ -284,14 +283,8 @@ impl Allocation { _ = fut => (), 
Some(msg) = self.inbox.recv() => { match msg { - AllocationMessage::HranaPipelineReq { req, ret } => { - let server = self.hrana_server.clone(); - if let Err(e) = handle_pipeline(server, req, ret, || async { - let conn = self.new_conn(None).await?; - Ok(conn) - }).await { - tracing::error!("error handling request: {e}") - }; + AllocationMessage::Connect { ret } => { + let _ = ret.send(self.new_conn(None).await); } AllocationMessage::Inbound(msg) => { if let Err(e) = self.handle_inbound(msg).await { diff --git a/libsqlx-server/src/database.rs b/libsqlx-server/src/database.rs index 94e9b724..0b75dd9c 100644 --- a/libsqlx-server/src/database.rs +++ b/libsqlx-server/src/database.rs @@ -1,30 +1,15 @@ use tokio::sync::{mpsc, oneshot}; -use crate::allocation::AllocationMessage; -use crate::error::Error; -use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::allocation::{AllocationMessage, ConnectionHandle}; pub struct Database { pub sender: mpsc::Sender, } impl Database { - pub async fn hrana_pipeline( - &self, - req: PipelineRequestBody, - ) -> crate::Result { - let (sender, ret) = oneshot::channel(); - if self - .sender - .send(AllocationMessage::HranaPipelineReq { req, ret: sender }) - .await - .is_err() - { - return Err(Error::AllocationClosed); - } - - ret.await.map_err(|_| { - Error::Internal(String::from("response builder dropped by connection")) - }) + pub async fn connect(&self) -> crate::Result { + let (ret, conn) = oneshot::channel(); + self.sender.send(AllocationMessage::Connect { ret }).await.unwrap(); + conn.await.unwrap() } } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index b6ed24fe..7a5f7c87 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -1,19 +1,22 @@ use std::collections::HashMap; +use std::sync::Arc; -use crate::allocation::ConnectionHandle; - -use super::error::{ProtocolError, StmtError, HranaError}; -use super::result_builder::HranaBatchProtoBuilder; -use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error}; -use super::{proto, Version}; +// use crate::auth::Authenticated; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; -use libsqlx::query::{Params, Query}; +use libsqlx::query::{Query, Params}; use libsqlx::result_builder::{StepResult, StepResultsBuilder}; use tokio::sync::oneshot; -fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> crate::Result { +use crate::allocation::ConnectionHandle; + +use super::error::HranaError; +use super::result_builder::HranaBatchProtoBuilder; +use super::stmt::{proto_stmt_to_query, StmtError}; +use super::{proto, ProtocolError, Version}; + +fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> Result { let try_convert_step = |step: i32| -> Result { let step = usize::try_from(step).map_err(|_| ProtocolError::BatchCondBadStep)?; if step >= max_step_i { @@ -21,6 +24,7 @@ fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> crate::Resu } Ok(step) }; + let cond = match cond { proto::BatchCond::Ok { step } => Cond::Ok { step: try_convert_step(*step)?, @@ -35,13 +39,13 @@ fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> crate::Resu conds: conds .iter() .map(|cond| proto_cond_to_cond(cond, max_step_i)) - .collect::>()?, + .collect::>()?, }, proto::BatchCond::Or { conds } => Cond::Or { conds: conds .iter() .map(|cond| proto_cond_to_cond(cond, max_step_i)) - .collect::>()?, + .collect::>()?, }, }; @@ -52,7 +56,7 @@ pub fn 
proto_batch_to_program( batch: &proto::Batch, sqls: &HashMap, version: Version, -) -> crate::Result { +) -> Result { let mut steps = Vec::with_capacity(batch.steps.len()); for (step_i, step) in batch.steps.iter().enumerate() { let query = proto_stmt_to_query(&step.stmt, sqls, version)?; @@ -70,18 +74,22 @@ pub fn proto_batch_to_program( } pub async fn execute_batch( - db: &ConnectionHandle, + conn: &ConnectionHandle, + // auth: Authenticated, pgm: Program, -) -> crate::Result { +) -> Result { let (builder, ret) = HranaBatchProtoBuilder::new(); - db.execute(pgm, Box::new(builder)).await; + conn.execute( + pgm, + // auth, + Box::new(builder)).await; Ok(ret.await.unwrap()) } -pub fn proto_sequence_to_program(sql: &str) -> crate::Result { +pub fn proto_sequence_to_program(sql: &str) -> Result { let stmts = Statement::parse(sql) - .collect::>>() + .collect::, libsqlx::error::Error>>() .map_err(|err| StmtError::SqlParse { source: err.into() })?; let steps = stmts @@ -99,26 +107,30 @@ pub fn proto_sequence_to_program(sql: &str) -> crate::Result }; Step { cond, query } }) - .collect(); + .collect::>(); - Ok(Program { steps }) + Ok(Program { + steps, + }) } -pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> crate::Result<(), HranaError> { - let (snd, rcv) = oneshot::channel(); - let builder = StepResultsBuilder::new(snd); - conn.execute(pgm, Box::new(builder)).await; +pub async fn execute_sequence( + conn: &ConnectionHandle, + // auth: Authenticated, + pgm: Program) -> Result<(), HranaError> { + let (send, ret) = oneshot::channel(); + let builder = StepResultsBuilder::new(send); + conn.execute(pgm, + // auth, + Box::new(builder)).await; - rcv.await - .map_err(|e| HranaError::Internal(e.into()))? - .map_err(|e| HranaError::Stmt(StmtError::QueryError(e)))? + ret.await + .unwrap() + .unwrap() .into_iter() .try_for_each(|result| match result { StepResult::Ok => Ok(()), - StepResult::Err(e) => { - let stmt_err = stmt_error_from_sqld_error(e)?; - Err(stmt_err)? - }, - StepResult::Skipped => Err(HranaError::StatementSkipped), + StepResult::Err(e) => Err(crate::error::Error::from(e))?, + StepResult::Skipped => todo!(), // Err(anyhow!("Statement in sequence was not executed")), }) } diff --git a/libsqlx-server/src/hrana/error.rs b/libsqlx-server/src/hrana/error.rs index 4092690d..0874a91e 100644 --- a/libsqlx-server/src/hrana/error.rs +++ b/libsqlx-server/src/hrana/error.rs @@ -1,168 +1,18 @@ -use super::Version; +use super::ProtocolError; +use super::stmt::StmtError; +use super::http::StreamError; +use super::http::request::StreamResponseError; -#[derive(thiserror::Error, Debug)] +#[derive(Debug, thiserror::Error)] pub enum HranaError { - #[error(transparent)] - Protocol(#[from] ProtocolError), #[error(transparent)] Stmt(#[from] StmtError), #[error(transparent)] - Internal(color_eyre::eyre::Error), - #[error("Statement in sequence was not executed")] - StatementSkipped, + Proto(#[from] ProtocolError), #[error(transparent)] - Libsqlx(#[from] libsqlx::error::Error), + Stream(#[from] StreamError), #[error(transparent)] StreamResponse(#[from] StreamResponseError), #[error(transparent)] - Stream(#[from] StreamError) -} - -/// An error from executing a [`proto::StreamRequest`] -#[derive(thiserror::Error, Debug)] -pub enum StreamResponseError { - #[error("The server already stores {count} SQL texts, it cannot store more")] - SqlTooMany { count: usize }, - #[error(transparent)] - Stmt(StmtError), -} - - -/// An unrecoverable protocol error that should close the WebSocket or HTTP stream. 
A correct -/// client should never trigger any of these errors. -#[derive(thiserror::Error, Debug)] -pub enum ProtocolError { - #[error("Cannot deserialize client message: {source}")] - Deserialize { source: serde_json::Error }, - #[error("Received a binary WebSocket message, which is not supported")] - BinaryWebSocketMessage, - #[error("Received a request before hello message")] - RequestBeforeHello, - - #[error("Stream {stream_id} not found")] - StreamNotFound { stream_id: i32 }, - #[error("Stream {stream_id} already exists")] - StreamExists { stream_id: i32 }, - - #[error("Either `sql` or `sql_id` are required, but not both")] - SqlIdAndSqlGiven, - #[error("Either `sql` or `sql_id` are required")] - SqlIdOrSqlNotGiven, - #[error("SQL text {sql_id} not found")] - SqlNotFound { sql_id: i32 }, - #[error("SQL text {sql_id} already exists")] - SqlExists { sql_id: i32 }, - - #[error("Invalid reference to step in a batch condition")] - BatchCondBadStep, - - #[error("Received an invalid baton: {0}")] - BatonInvalid(String), - #[error("Received a baton that has already been used")] - BatonReused, - #[error("Stream for this baton was closed")] - BatonStreamClosed, - - #[error("{what} is only supported in protocol version {min_version} and higher")] - NotSupported { - what: &'static str, - min_version: Version, - }, - - #[error("{0}")] - ResponseTooLarge(String), -} - -/// An error during execution of an SQL statement. -#[derive(thiserror::Error, Debug)] -pub enum StmtError { - #[error("SQL string could not be parsed: {source}")] - SqlParse { source: color_eyre::eyre::Error }, - #[error("SQL string does not contain any statement")] - SqlNoStmt, - #[error("SQL string contains more than one statement")] - SqlManyStmts, - #[error("Arguments do not match SQL parameters: {msg}")] - ArgsInvalid { msg: String }, - #[error("Specifying both positional and named arguments is not supported")] - ArgsBothPositionalAndNamed, - - #[error("Transaction timed out")] - TransactionTimeout, - #[error("Server cannot handle additional transactions")] - TransactionBusy, - #[error("SQLite error: {message}")] - SqliteError { - source: libsqlx::rusqlite::ffi::Error, - message: String, - }, - #[error("SQL input error: {message} (at offset {offset})")] - SqlInputError { - source: color_eyre::eyre::Error, - message: String, - offset: i32, - }, - - #[error("Operation was blocked{}", .reason.as_ref().map(|msg| format!(": {}", msg)).unwrap_or_default())] - Blocked { reason: Option }, - #[error("query error: {0}")] - QueryError(String) -} - -impl StmtError { - pub fn code(&self) -> &'static str { - match self { - Self::SqlParse { .. } => "SQL_PARSE_ERROR", - Self::SqlNoStmt => "SQL_NO_STATEMENT", - Self::SqlManyStmts => "SQL_MANY_STATEMENTS", - Self::ArgsInvalid { .. } => "ARGS_INVALID", - Self::ArgsBothPositionalAndNamed => "ARGS_BOTH_POSITIONAL_AND_NAMED", - Self::TransactionTimeout => "TRANSACTION_TIMEOUT", - Self::TransactionBusy => "TRANSACTION_BUSY", - Self::SqliteError { source, .. } => sqlite_error_code(source.code), - Self::SqlInputError { .. } => "SQL_INPUT_ERROR", - Self::Blocked { .. 
} => "BLOCKED", - Self::QueryError(_) => todo!(), - } - } -} - -fn sqlite_error_code(code: libsqlx::error::ErrorCode) -> &'static str { - match code { - libsqlx::error::ErrorCode::InternalMalfunction => "SQLITE_INTERNAL", - libsqlx::error::ErrorCode::PermissionDenied => "SQLITE_PERM", - libsqlx::error::ErrorCode::OperationAborted => "SQLITE_ABORT", - libsqlx::error::ErrorCode::DatabaseBusy => "SQLITE_BUSY", - libsqlx::error::ErrorCode::DatabaseLocked => "SQLITE_LOCKED", - libsqlx::error::ErrorCode::OutOfMemory => "SQLITE_NOMEM", - libsqlx::error::ErrorCode::ReadOnly => "SQLITE_READONLY", - libsqlx::error::ErrorCode::OperationInterrupted => "SQLITE_INTERRUPT", - libsqlx::error::ErrorCode::SystemIoFailure => "SQLITE_IOERR", - libsqlx::error::ErrorCode::DatabaseCorrupt => "SQLITE_CORRUPT", - libsqlx::error::ErrorCode::NotFound => "SQLITE_NOTFOUND", - libsqlx::error::ErrorCode::DiskFull => "SQLITE_FULL", - libsqlx::error::ErrorCode::CannotOpen => "SQLITE_CANTOPEN", - libsqlx::error::ErrorCode::FileLockingProtocolFailed => "SQLITE_PROTOCOL", - libsqlx::error::ErrorCode::SchemaChanged => "SQLITE_SCHEMA", - libsqlx::error::ErrorCode::TooBig => "SQLITE_TOOBIG", - libsqlx::error::ErrorCode::ConstraintViolation => "SQLITE_CONSTRAINT", - libsqlx::error::ErrorCode::TypeMismatch => "SQLITE_MISMATCH", - libsqlx::error::ErrorCode::ApiMisuse => "SQLITE_MISUSE", - libsqlx::error::ErrorCode::NoLargeFileSupport => "SQLITE_NOLFS", - libsqlx::error::ErrorCode::AuthorizationForStatementDenied => "SQLITE_AUTH", - libsqlx::error::ErrorCode::ParameterOutOfRange => "SQLITE_RANGE", - libsqlx::error::ErrorCode::NotADatabase => "SQLITE_NOTADB", - libsqlx::error::ErrorCode::Unknown => "SQLITE_UNKNOWN", - _ => "SQLITE_UNKNOWN", - } + Libsqlx(crate::error::Error), } - -/// An unrecoverable error that should close the stream. The difference from [`ProtocolError`] is -/// that a correct client may trigger this error, it does not mean that the protocol has been -/// violated. 
-#[derive(thiserror::Error, Debug)] -pub enum StreamError { - #[error("The stream has expired due to inactivity")] - StreamExpired, -} - diff --git a/libsqlx-server/src/hrana/http/mod.rs b/libsqlx-server/src/hrana/http/mod.rs index ea316aea..65686fbe 100644 --- a/libsqlx-server/src/hrana/http/mod.rs +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -1,19 +1,12 @@ -use std::sync::Arc; - -use color_eyre::eyre::Context; -use futures::Future; use parking_lot::Mutex; -use serde::{de::DeserializeOwned, Serialize}; -use tokio::sync::oneshot; - -use crate::allocation::ConnectionHandle; - -use self::proto::{PipelineRequestBody, PipelineResponseBody}; -use super::error::{HranaError, ProtocolError, StreamError}; +use super::error::HranaError; +// use crate::auth::Authenticated; +use crate::database::Database; +pub use stream::StreamError; pub mod proto; -mod request; +pub mod request; mod stream; pub struct Server { @@ -42,89 +35,24 @@ impl Server { } } -fn handle_index() -> crate::Result, HranaError> { - Ok(text_response( - hyper::StatusCode::OK, - "Hello, this is HTTP API v2 (Hrana over HTTP)".into(), - )) -} - -pub async fn handle_pipeline( - server: Arc, - req: PipelineRequestBody, - ret: oneshot::Sender>, - mk_conn: F, -) -> crate::Result<(), HranaError> -where - F: FnOnce() -> Fut, - Fut: Future>, +pub async fn handle_pipeline( + server: &Server, + // auth: Authenticated, + req: proto::PipelineRequestBody, + db: Database, +) -> crate::Result { - let mut stream_guard = stream::acquire(server.clone(), req.baton.as_deref(), mk_conn).await?; - - tokio::spawn(async move { - let f = async move { - let mut results = Vec::with_capacity(req.requests.len()); - for request in req.requests.into_iter() { - let result = request::handle(&mut stream_guard, request) - .await?; - results.push(result); - } - - Ok(proto::PipelineResponseBody { - baton: stream_guard.release(), - base_url: server.self_url.clone(), - results, - }) - }; - - let _ = ret.send(f.await); - }); - - Ok(()) -} - -async fn read_request_json( - req: hyper::Request, -) -> color_eyre::Result { - let req_body = hyper::body::to_bytes(req.into_body()) - .await - .context("Could not read request body")?; - let req_body = serde_json::from_slice(&req_body) - .map_err(|err| ProtocolError::Deserialize { source: err }) - .context("Could not deserialize JSON request body")?; - Ok(req_body) -} - -fn protocol_error_response(err: ProtocolError) -> hyper::Response { - text_response(hyper::StatusCode::BAD_REQUEST, err.to_string()) -} + let mut stream_guard = stream::acquire(server, req.baton.as_deref(), db).await?; -fn stream_error_response(err: StreamError) -> hyper::Response { - json_response( - hyper::StatusCode::INTERNAL_SERVER_ERROR, - &proto::Error { - message: err.to_string(), - code: err.code().into(), - }, - ) -} - -fn json_response( - status: hyper::StatusCode, - resp_body: &T, -) -> hyper::Response { - let resp_body = serde_json::to_vec(resp_body).unwrap(); - hyper::Response::builder() - .status(status) - .header(hyper::http::header::CONTENT_TYPE, "application/json") - .body(hyper::Body::from(resp_body)) - .unwrap() -} + let mut results = Vec::with_capacity(req.requests.len()); + for request in req.requests.into_iter() { + let result = request::handle(&mut stream_guard, /*auth,*/ request).await?; + results.push(result); + } -fn text_response(status: hyper::StatusCode, resp_body: String) -> hyper::Response { - hyper::Response::builder() - .status(status) - .header(hyper::http::header::CONTENT_TYPE, "text/plain") - .body(hyper::Body::from(resp_body)) - 
.unwrap() + Ok(proto::PipelineResponseBody { + baton: stream_guard.release(), + base_url: server.self_url.clone(), + results, + }) } diff --git a/libsqlx-server/src/hrana/http/request.rs b/libsqlx-server/src/hrana/http/request.rs index c9369315..c99c493b 100644 --- a/libsqlx-server/src/hrana/http/request.rs +++ b/libsqlx-server/src/hrana/http/request.rs @@ -1,69 +1,82 @@ -use crate::hrana::error::{HranaError, ProtocolError, StreamResponseError}; +use crate::hrana::ProtocolError; +use crate::hrana::error::HranaError; use super::super::{batch, stmt, Version}; use super::{proto, stream}; +// use crate::auth::Authenticated; + +/// An error from executing a [`proto::StreamRequest`] +#[derive(thiserror::Error, Debug)] +pub enum StreamResponseError { + #[error("The server already stores {count} SQL texts, it cannot store more")] + SqlTooMany { count: usize }, + #[error(transparent)] + Stmt(stmt::StmtError), +} pub async fn handle( - stream_guard: &mut stream::Guard, + stream_guard: &mut stream::Guard<'_>, + // auth: Authenticated, request: proto::StreamRequest, -) -> crate::Result { - let result = match try_handle(stream_guard, request).await { +) -> Result { + let result = match try_handle(stream_guard/*, auth*/, request).await { Ok(response) => proto::StreamResult::Ok { response }, Err(err) => { - if let HranaError::StreamResponse(resp_err) = err { + if let HranaError::StreamResponse(err) = err { let error = proto::Error { - message: resp_err.to_string(), - code: resp_err.code().into(), + message: err.to_string(), + code: err.code().into(), }; proto::StreamResult::Error { error } } else { - return Err(err); + Err(err)? } } }; - Ok(result) } async fn try_handle( - stream_guard: &mut stream::Guard, + stream_guard: &mut stream::Guard<'_>, + // auth: Authenticated, request: proto::StreamRequest, ) -> crate::Result { Ok(match request { proto::StreamRequest::Close(_req) => { - stream_guard.close_db(); + stream_guard.close_conn(); proto::StreamResponse::Close(proto::CloseStreamResp {}) } proto::StreamRequest::Execute(req) => { - let db = stream_guard.get_db()?; + let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2)?; - let result = stmt::execute_stmt(db, query).await?; + let result = stmt::execute_stmt(db, /*auth,*/ query) + .await?; proto::StreamResponse::Execute(proto::ExecuteStreamResp { result }) } proto::StreamRequest::Batch(req) => { - let db = stream_guard.get_db()?; + let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); let pgm = batch::proto_batch_to_program(&req.batch, sqls, Version::Hrana2)?; - let result = batch::execute_batch(db, pgm).await?; + let result = batch::execute_batch(db, /*auth,*/ pgm).await?; proto::StreamResponse::Batch(proto::BatchStreamResp { result }) } proto::StreamRequest::Sequence(req) => { - let db = stream_guard.get_db()?; + let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; let pgm = batch::proto_sequence_to_program(sql)?; - batch::execute_sequence(db, pgm) + batch::execute_sequence(db, /*auth,*/ pgm) .await?; proto::StreamResponse::Sequence(proto::SequenceStreamResp {}) } proto::StreamRequest::Describe(req) => { - let db = stream_guard.get_db()?; + let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; - let result = stmt::describe_stmt(db, sql.into()) + let 
result = stmt::describe_stmt(db, /* auth,*/ sql.into()) .await?; proto::StreamResponse::Describe(proto::DescribeStreamResp { result }) } diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs index 3d87a96e..f204f799 100644 --- a/libsqlx-server/src/hrana/http/stream.rs +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -1,18 +1,18 @@ -use std::cmp::Reverse; -use std::collections::{HashMap, VecDeque}; -use std::pin::Pin; -use std::sync::Arc; -use std::{future, mem, task}; - use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; use futures::Future; use hmac::Mac as _; use priority_queue::PriorityQueue; +use std::cmp::Reverse; +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::{future, mem, task}; use tokio::time::{Duration, Instant}; +use super::super::ProtocolError; use super::Server; use crate::allocation::ConnectionHandle; -use crate::hrana::error::{ProtocolError, HranaError, StreamError}; +use crate::database::Database; +use crate::hrana::error::HranaError; /// Mutable state related to streams, owned by [`Server`] and protected with a mutex. pub struct ServerStreamState { @@ -67,8 +67,8 @@ struct Stream { /// Guard object that is used to access a stream from the outside. The guard makes sure that the /// stream's entry in [`ServerStreamState::handles`] is either removed or replaced with /// [`Handle::Available`] after the guard goes out of scope. -pub struct Guard { - server: Arc, +pub struct Guard<'srv> { + server: &'srv Server, /// The guarded stream. This is only set to `None` in the destructor. stream: Option>, /// If set to `true`, the destructor will release the stream for further use (saving it as @@ -77,6 +77,15 @@ pub struct Guard { release: bool, } +/// An unrecoverable error that should close the stream. The difference from [`ProtocolError`] is +/// that a correct client may trigger this error, it does not mean that the protocol has been +/// violated. +#[derive(thiserror::Error, Debug)] +pub enum StreamError { + #[error("The stream has expired due to inactivity")] + StreamExpired, +} + impl ServerStreamState { pub fn new() -> Self { Self { @@ -92,43 +101,33 @@ impl ServerStreamState { /// Acquire a guard to a new or existing stream. If baton is `Some`, we try to look up the stream, /// otherwise we create a new stream. -pub async fn acquire( - server: Arc, +pub async fn acquire<'srv>( + server: &'srv Server, baton: Option<&str>, - mk_conn: F, -) -> crate::Result -where - F: FnOnce() -> Fut, - Fut: Future>, -{ + db: Database, +) -> Result, HranaError> { let stream = match baton { Some(baton) => { - let (stream_id, baton_seq) = decode_baton(&server, baton)?; + let (stream_id, baton_seq) = decode_baton(server, baton)?; let mut state = server.stream_state.lock(); let handle = state.handles.get_mut(&stream_id); match handle { None => { - return Err(ProtocolError::BatonInvalid(format!( - "Stream handle for {stream_id} was not found" - )) - .into()) + Err(ProtocolError::BatonInvalid { reason: format!("Stream handle for {stream_id} was not found")})?; } Some(Handle::Acquired) => { - Err(ProtocolError::BatonReused)? - // .context(format!("Stream handle for {stream_id} is acquired")); + Err(ProtocolError::BatonReused { reason: format!("Stream handle for {stream_id} is acquired")})?; } Some(Handle::Expired) => { Err(StreamError::StreamExpired)? - // .context(format!("Stream handle for {stream_id} is expired")); } Some(Handle::Available(stream)) => { if stream.baton_seq != baton_seq { - Err(ProtocolError::BatonReused)? 
- // .context(format!( - // "Expected baton seq {}, received {baton_seq}", - // stream.baton_seq - // )); + Err(ProtocolError::BatonReused { reason: format!( + "Expected baton seq {}, received {baton_seq}", + stream.baton_seq + )})?; } } }; @@ -145,10 +144,7 @@ where stream } None => { - let conn = mk_conn() - .await - .map_err(|e| HranaError::Internal(e.into()))?; - + let conn = db.connect().await.unwrap(); let mut state = server.stream_state.lock(); let stream = Box::new(Stream { conn: Some(conn), @@ -174,15 +170,15 @@ where }) } -impl Guard { - pub fn get_db(&self) -> Result<&ConnectionHandle, ProtocolError> { +impl<'srv> Guard<'srv> { + pub fn get_conn(&self) -> Result<&ConnectionHandle, ProtocolError> { let stream = self.stream.as_ref().unwrap(); stream.conn.as_ref().ok_or(ProtocolError::BatonStreamClosed) } /// Closes the database connection. The next call to [`Guard::release()`] will then remove the /// stream. - pub fn close_db(&mut self) { + pub fn close_conn(&mut self) { let stream = self.stream.as_mut().unwrap(); stream.conn = None; } @@ -203,7 +199,7 @@ impl Guard { if stream.conn.is_some() { self.release = true; // tell destructor to make the stream available again Some(encode_baton( - &self.server, + self.server, stream.stream_id, stream.baton_seq, )) @@ -213,7 +209,7 @@ impl Guard { } } -impl Drop for Guard { +impl<'srv> Drop for Guard<'srv> { fn drop(&mut self) { let stream = self.stream.take().unwrap(); let stream_id = stream.stream_id; @@ -282,15 +278,14 @@ fn encode_baton(server: &Server, stream_id: u64, baton_seq: u64) -> String { /// context that describes the precise cause. fn decode_baton(server: &Server, baton_str: &str) -> crate::Result<(u64, u64), HranaError> { let baton_data = BASE64_STANDARD_NO_PAD.decode(baton_str).map_err(|err| { - ProtocolError::BatonInvalid(format!("Could not base64-decode baton: {err}")) + ProtocolError::BatonInvalid { reason: format!("Could not base64-decode baton: {err}") } })?; if baton_data.len() != 48 { - return Err(ProtocolError::BatonInvalid(format!( - "Baton has invalid size of {} bytes", - baton_data.len() - )) - .into()); + Err(ProtocolError::BatonInvalid { reason: format!( + "Baton has invalid size of {} bytes", + baton_data.len() + )})?; } let payload = &baton_data[0..16]; @@ -298,11 +293,8 @@ fn decode_baton(server: &Server, baton_str: &str) -> crate::Result<(u64, u64), H let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); hmac.update(payload); - hmac.verify_slice(received_mac).map_err(|_| { - ProtocolError::BatonInvalid( - "Invalid MAC on baton".to_string() - ) - })?; + hmac.verify_slice(received_mac) + .map_err(|_| ProtocolError::BatonInvalid { reason: "Invalid MAC on baton".into() })?; let stream_id = u64::from_be_bytes(payload[0..8].try_into().unwrap()); let baton_seq = u64::from_be_bytes(payload[8..16].try_into().unwrap()); diff --git a/libsqlx-server/src/hrana/mod.rs b/libsqlx-server/src/hrana/mod.rs index 54c65327..8f4c7f68 100644 --- a/libsqlx-server/src/hrana/mod.rs +++ b/libsqlx-server/src/hrana/mod.rs @@ -5,8 +5,8 @@ pub mod http; pub mod proto; mod result_builder; pub mod stmt; -pub mod error; // pub mod ws; +pub mod error; #[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] pub enum Version { @@ -22,3 +22,48 @@ impl fmt::Display for Version { } } } + +/// An unrecoverable protocol error that should close the WebSocket or HTTP stream. A correct +/// client should never trigger any of these errors. 
+#[derive(thiserror::Error, Debug)] +pub enum ProtocolError { + #[error("Cannot deserialize client message: {source}")] + Deserialize { source: serde_json::Error }, + #[error("Received a binary WebSocket message, which is not supported")] + BinaryWebSocketMessage, + #[error("Received a request before hello message")] + RequestBeforeHello, + + #[error("Stream {stream_id} not found")] + StreamNotFound { stream_id: i32 }, + #[error("Stream {stream_id} already exists")] + StreamExists { stream_id: i32 }, + + #[error("Either `sql` or `sql_id` are required, but not both")] + SqlIdAndSqlGiven, + #[error("Either `sql` or `sql_id` are required")] + SqlIdOrSqlNotGiven, + #[error("SQL text {sql_id} not found")] + SqlNotFound { sql_id: i32 }, + #[error("SQL text {sql_id} already exists")] + SqlExists { sql_id: i32 }, + + #[error("Invalid reference to step in a batch condition")] + BatchCondBadStep, + + #[error("Received an invalid baton: {reason}")] + BatonInvalid { reason: String }, + #[error("Received a baton that has already been used: {reason}")] + BatonReused { reason: String }, + #[error("Stream for this baton was closed")] + BatonStreamClosed, + + #[error("{what} is only supported in protocol version {min_version} and higher")] + NotSupported { + what: &'static str, + min_version: Version, + }, + + #[error("{0}")] + ResponseTooLarge(String), +} diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index c0c597bf..75d84683 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -5,19 +5,20 @@ use bytes::Bytes; use libsqlx::{result_builder::*, FrameNo}; use tokio::sync::oneshot; -use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; +use crate::hrana::stmt::proto_error_from_stmt_error; +use super::error::HranaError; use super::proto; pub struct SingleStatementBuilder { builder: StatementBuilder, - ret: Option>>, + ret: Option>>, } impl SingleStatementBuilder { pub fn new() -> ( Self, - oneshot::Receiver>, + oneshot::Receiver>, ) { let (ret, rcv) = oneshot::channel(); ( @@ -199,9 +200,9 @@ impl StatementBuilder { Ok(()) } - pub fn take_ret(&mut self) -> Result { + pub fn take_ret(&mut self) -> crate::Result { match self.err.take() { - Some(err) => Err(err), + Some(err) => Err(crate::error::Error::from(err))?, None => Ok(proto::StmtResult { cols: std::mem::take(&mut self.cols), rows: std::mem::take(&mut self.rows), @@ -331,7 +332,7 @@ impl ResultBuilder for HranaBatchProtoBuilder { Err(e) => { self.step_results.push(None); self.step_errors.push(Some(proto_error_from_stmt_error( - &stmt_error_from_sqld_error(e).map_err(QueryResultBuilderError::from_any)?, + Err(HranaError::from(e)).map_err(QueryResultBuilderError::from_any)?, ))); } } diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index 2053b532..f0021028 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -1,44 +1,75 @@ use std::collections::HashMap; +use libsqlx::DescribeResponse; use libsqlx::analysis::Statement; -use libsqlx::query::{Params, Query, Value}; +use libsqlx::program::Program; +use libsqlx::query::{Query, Params, Value}; -use super::error::{HranaError, StmtError, ProtocolError}; +use super::error::HranaError; use super::result_builder::SingleStatementBuilder; -use super::{proto, Version}; +use super::{proto, ProtocolError, Version}; use crate::allocation::ConnectionHandle; +// use crate::auth::Authenticated; use crate::hrana; +/// An error 
during execution of an SQL statement. +#[derive(thiserror::Error, Debug)] +pub enum StmtError { + #[error("SQL string could not be parsed: {source}")] + SqlParse { source: color_eyre::eyre::Error }, + #[error("SQL string does not contain any statement")] + SqlNoStmt, + #[error("SQL string contains more than one statement")] + SqlManyStmts, + #[error("Arguments do not match SQL parameters: {source}")] + ArgsInvalid { source: color_eyre::eyre::Error }, + #[error("Specifying both positional and named arguments is not supported")] + ArgsBothPositionalAndNamed, + + #[error("Transaction timed out")] + TransactionTimeout, + #[error("Server cannot handle additional transactions")] + TransactionBusy, + #[error("SQLite error: {message}")] + SqliteError { + source: libsqlx::rusqlite::ffi::Error, + message: String, + }, + #[error("SQL input error: {message} (at offset {offset})")] + SqlInputError { + source: libsqlx::rusqlite::ffi::Error, + message: String, + offset: i32, + }, + + #[error("Operation was blocked{}", .reason.as_ref().map(|msg| format!(": {}", msg)).unwrap_or_default())] + Blocked { reason: Option }, +} + pub async fn execute_stmt( conn: &ConnectionHandle, + // auth: Authenticated, query: Query, ) -> crate::Result { let (builder, ret) = SingleStatementBuilder::new(); - let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute(pgm, Box::new(builder)).await; - ret.await - .unwrap() - .map_err(|sqld_error| { - match stmt_error_from_sqld_error(sqld_error) { - Ok(e) => e.into(), - Err(e) => e.into(), - } - }) + conn.execute(Program::from_queries(Some(query))/*, auth*/, Box::new(builder)).await; + let stmt_res = ret.await; + Ok(stmt_res.unwrap()?) } pub async fn describe_stmt( - _db: &ConnectionHandle, - _sql: String, + db: &ConnectionHandle, + // auth: Authenticated, + sql: String, ) -> crate::Result { - todo!(); - // match db.describe(sql).await? { - // Ok(describe_response) => todo!(), - // // Ok(proto_describe_result_from_describe_response( - // // describe_response, - // // )), + todo!() + // match db.describe(sql/*, auth*/).await? 
{
+    //     Ok(describe_response) => Ok(proto_describe_result_from_describe_response(
+    //         describe_response,
+    //     )),
     //     Err(sqld_error) => match stmt_error_from_sqld_error(sqld_error) {
-    //         Ok(stmt_error) => bail!(stmt_error),
-    //         Err(sqld_error) => bail!(sqld_error),
+    //         Ok(stmt_error) => Err(stmt_error)?,
+    //         Err(sqld_error) => Err(sqld_error)?,
     //     },
     // }
 }
@@ -46,9 +77,9 @@ pub async fn describe_stmt(
 
 pub fn proto_stmt_to_query(
     proto_stmt: &proto::Stmt,
     sqls: &HashMap<i32, String>,
     version: Version,
 ) -> crate::Result<Query, HranaError> {
     let sql = proto_sql_to_sql(proto_stmt.sql.as_deref(), proto_stmt.sql_id, sqls, version)?;
 
     let mut stmt_iter = Statement::parse(sql);
 
     let stmt = match stmt_iter.next() {
@@ -88,7 +119,7 @@ pub fn proto_sql_to_sql<'s>(
     proto_sql_id: Option<i32>,
     sqls: &'s HashMap<i32, String>,
     verion: Version,
-) -> crate::Result<&'s str, ProtocolError> {
+) -> Result<&'s str, ProtocolError> {
     if proto_sql_id.is_some() && verion < Version::Hrana2 {
         return Err(ProtocolError::NotSupported {
             what: "`sql_id`",
@@ -131,63 +162,63 @@ fn proto_value_from_value(value: Value) -> proto::Value {
     }
 }
 
-// fn proto_describe_result_from_describe_response(
-//     response: DescribeResponse,
-// ) -> proto::DescribeResult {
-//     proto::DescribeResult {
-//         params: response
-//             .params
-//             .into_iter()
-//             .map(|p| proto::DescribeParam { name: p.name })
-//             .collect(),
-//         cols: response
-//             .cols
-//             .into_iter()
-//             .map(|c| proto::DescribeCol {
-//                 name: c.name,
-//                 decltype: c.decltype,
-//             })
-//             .collect(),
-//         is_explain: response.is_explain,
-//         is_readonly: response.is_readonly,
-//     }
-// }
-
-pub fn stmt_error_from_sqld_error(
-    sqld_error: libsqlx::error::Error,
-) -> Result<StmtError, libsqlx::error::Error> {
-    Ok(match sqld_error {
-        libsqlx::error::Error::LibSqlInvalidQueryParams(msg) => StmtError::ArgsInvalid { msg },
-        libsqlx::error::Error::LibSqlTxTimeout => StmtError::TransactionTimeout,
-        libsqlx::error::Error::LibSqlTxBusy => StmtError::TransactionBusy,
-        libsqlx::error::Error::Blocked(reason) => StmtError::Blocked { reason },
-        libsqlx::error::Error::RusqliteError(rusqlite_error) => match rusqlite_error {
-            libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => {
-                StmtError::SqliteError {
-                    source: sqlite_error,
-                    message,
-                }
-            }
-            libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => {
-                StmtError::SqliteError {
-                    message: sqlite_error.to_string(),
-                    source: sqlite_error,
-                }
+fn proto_describe_result_from_describe_response(
+    response: DescribeResponse,
+) -> proto::DescribeResult {
+    proto::DescribeResult {
+        params: response
+            .params
+            .into_iter()
+            .map(|p| proto::DescribeParam { name: p.name })
+            .collect(),
+        cols: response
+            .cols
+            .into_iter()
+            .map(|c| proto::DescribeCol {
+                name: c.name,
+                decltype: c.decltype,
+            })
+            .collect(),
+        is_explain: response.is_explain,
+        is_readonly: response.is_readonly,
+    }
+}
+
+impl From<crate::error::Error> for HranaError {
+    fn from(error: crate::error::Error) -> Self {
+        if let crate::error::Error::Libsqlx(e) = error {
+            match e {
+                libsqlx::error::Error::LibSqlInvalidQueryParams(source) => StmtError::ArgsInvalid {
+                    source: color_eyre::eyre::anyhow!("{source}"),
+                }
+                .into(),
+                libsqlx::error::Error::LibSqlTxTimeout => StmtError::TransactionTimeout.into(),
+                libsqlx::error::Error::LibSqlTxBusy => StmtError::TransactionBusy.into(),
+                libsqlx::error::Error::Blocked(reason) => StmtError::Blocked { reason }.into(),
+                libsqlx::error::Error::RusqliteError(rusqlite_error) => match
rusqlite_error { + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => StmtError::SqliteError { + source: sqlite_error, + message, + }.into(), + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => StmtError::SqliteError { + message: sqlite_error.to_string(), + source: sqlite_error, + }.into(), + libsqlx::error::RusqliteError::SqlInputError { + error: sqlite_error, + msg: message, + offset, + .. + } => StmtError::SqlInputError { + source: sqlite_error, + message, + offset, + }.into(), + rusqlite_error => return crate::error::Error::from(libsqlx::error::Error::RusqliteError(rusqlite_error)).into(), + }, + sqld_error => return crate::error::Error::from(sqld_error).into(), } - libsqlx::error::RusqliteError::SqlInputError { - error: sqlite_error, - msg: message, - offset, - .. - } => StmtError::SqlInputError { - source: sqlite_error.into(), - message, - offset, - }, - rusqlite_error => return Err(libsqlx::error::Error::RusqliteError(rusqlite_error)), - }, - sqld_error => return Err(sqld_error), - }) + } else { + Self::Libsqlx(error) + } + } } pub fn proto_error_from_stmt_error(error: &StmtError) -> hrana::proto::Error { @@ -197,6 +228,53 @@ pub fn proto_error_from_stmt_error(error: &StmtError) -> hrana::proto::Error { } } +impl StmtError { + pub fn code(&self) -> &'static str { + match self { + Self::SqlParse { .. } => "SQL_PARSE_ERROR", + Self::SqlNoStmt => "SQL_NO_STATEMENT", + Self::SqlManyStmts => "SQL_MANY_STATEMENTS", + Self::ArgsInvalid { .. } => "ARGS_INVALID", + Self::ArgsBothPositionalAndNamed => "ARGS_BOTH_POSITIONAL_AND_NAMED", + Self::TransactionTimeout => "TRANSACTION_TIMEOUT", + Self::TransactionBusy => "TRANSACTION_BUSY", + Self::SqliteError { source, .. } => sqlite_error_code(source.code), + Self::SqlInputError { .. } => "SQL_INPUT_ERROR", + Self::Blocked { .. 
} => "BLOCKED", + } + } +} + +fn sqlite_error_code(code: libsqlx::error::ErrorCode) -> &'static str { + match code { + libsqlx::error::ErrorCode::InternalMalfunction => "SQLITE_INTERNAL", + libsqlx::error::ErrorCode::PermissionDenied => "SQLITE_PERM", + libsqlx::error::ErrorCode::OperationAborted => "SQLITE_ABORT", + libsqlx::error::ErrorCode::DatabaseBusy => "SQLITE_BUSY", + libsqlx::error::ErrorCode::DatabaseLocked => "SQLITE_LOCKED", + libsqlx::error::ErrorCode::OutOfMemory => "SQLITE_NOMEM", + libsqlx::error::ErrorCode::ReadOnly => "SQLITE_READONLY", + libsqlx::error::ErrorCode::OperationInterrupted => "SQLITE_INTERRUPT", + libsqlx::error::ErrorCode::SystemIoFailure => "SQLITE_IOERR", + libsqlx::error::ErrorCode::DatabaseCorrupt => "SQLITE_CORRUPT", + libsqlx::error::ErrorCode::NotFound => "SQLITE_NOTFOUND", + libsqlx::error::ErrorCode::DiskFull => "SQLITE_FULL", + libsqlx::error::ErrorCode::CannotOpen => "SQLITE_CANTOPEN", + libsqlx::error::ErrorCode::FileLockingProtocolFailed => "SQLITE_PROTOCOL", + libsqlx::error::ErrorCode::SchemaChanged => "SQLITE_SCHEMA", + libsqlx::error::ErrorCode::TooBig => "SQLITE_TOOBIG", + libsqlx::error::ErrorCode::ConstraintViolation => "SQLITE_CONSTRAINT", + libsqlx::error::ErrorCode::TypeMismatch => "SQLITE_MISMATCH", + libsqlx::error::ErrorCode::ApiMisuse => "SQLITE_MISUSE", + libsqlx::error::ErrorCode::NoLargeFileSupport => "SQLITE_NOLFS", + libsqlx::error::ErrorCode::AuthorizationForStatementDenied => "SQLITE_AUTH", + libsqlx::error::ErrorCode::ParameterOutOfRange => "SQLITE_RANGE", + libsqlx::error::ErrorCode::NotADatabase => "SQLITE_NOTADB", + libsqlx::error::ErrorCode::Unknown => "SQLITE_UNKNOWN", + _ => "SQLITE_UNKNOWN", + } +} + impl From<&proto::Value> for Value { fn from(proto_value: &proto::Value) -> Value { proto_value_to_value(proto_value) diff --git a/libsqlx-server/src/hrana/ws/mod.rs b/libsqlx-server/src/hrana/ws/mod.rs index 32a34957..a4b96ed5 100644 --- a/libsqlx-server/src/hrana/ws/mod.rs +++ b/libsqlx-server/src/hrana/ws/mod.rs @@ -1,7 +1,5 @@ -use crate::auth::Auth; +// use crate::auth::Auth; use crate::database::Database; -use crate::utils::services::idle_shutdown::IdleKicker; -use anyhow::{Context as _, Result}; use enclose::enclose; use std::net::SocketAddr; use std::sync::atomic::{AtomicU64, Ordering}; @@ -17,7 +15,7 @@ mod session; struct Server { db_factory: Arc>, auth: Arc, - idle_kicker: Option, + // idle_kicker: Option, next_conn_id: AtomicU64, } diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs index c947fb8b..abf62fd1 100644 --- a/libsqlx-server/src/http/user/mod.rs +++ b/libsqlx-server/src/http/user/mod.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use axum::extract::State; use axum::routing::post; use axum::{Json, Router}; use color_eyre::Result; @@ -7,6 +8,7 @@ use hyper::server::accept::Accept; use tokio::io::{AsyncRead, AsyncWrite}; use crate::database::Database; +use crate::hrana; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use crate::linc::bus::Bus; use crate::manager::Manager; @@ -17,11 +19,13 @@ mod extractors; pub struct Config { pub manager: Arc, pub bus: Arc>>, + pub hrana_server: Arc, } struct UserApiState { manager: Arc, bus: Arc>>, + hrana_server: Arc, } pub async fn run_user_api(config: Config, listener: I) -> Result<()> @@ -32,6 +36,7 @@ where let state = UserApiState { manager: config.manager, bus: config.bus, + hrana_server: config.hrana_server, }; let app = Router::new() @@ -46,9 +51,10 @@ where } async fn handle_hrana_pipeline( + 
State(state): State>, db: Database, Json(req): Json, ) -> Json { - let resp = db.hrana_pipeline(req).await; - Json(resp.unwrap()) + let ret = hrana::http::handle_pipeline(&state.hrana_server, req, db).await.unwrap(); + Json(ret) } diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 465673f3..6fdf52a5 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -63,8 +63,16 @@ async fn spawn_user_api( bus: Arc>>, ) -> color_eyre::Result<()> { let user_api_listener = TcpListener::bind(config.addr).await?; + let hrana_server = Arc::new(hrana::http::Server::new(None)); + set.spawn({ + let hrana_server = hrana_server.clone(); + async move { + hrana_server.run_expire().await; + Ok(()) + } + }); set.spawn(run_user_api( - http::user::Config { manager, bus }, + http::user::Config { manager, bus, hrana_server }, AddrIncoming::from_listener(user_api_listener)?, )); diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 33d6c8ba..429493eb 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -9,7 +9,6 @@ use tokio::task::JoinSet; use crate::allocation::config::AllocConfig; use crate::allocation::{Allocation, AllocationMessage, Database}; use crate::compactor::CompactionQueue; -use crate::hrana; use crate::linc::bus::Dispatch; use crate::linc::handler::Handler; use crate::linc::Inbound; @@ -70,8 +69,7 @@ impl Manager { connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, - hrana_server: Arc::new(hrana::http::Server::new(None)), - dispatcher, // TODO: handle self URL? + dispatcher, db_name: config.db_name, connections: HashMap::new(), }; From f1e5fad342cedd7244e1d56a428519e487841bd5 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Mon, 31 Jul 2023 14:15:52 +0200 Subject: [PATCH 63/64] improve error handling --- docs/LINC.md | 282 ------------------ libsqlx-server/src/allocation/mod.rs | 9 +- .../src/allocation/primary/compactor.rs | 2 +- libsqlx-server/src/allocation/primary/mod.rs | 96 +++--- libsqlx-server/src/allocation/replica.rs | 49 ++- libsqlx-server/src/compactor.rs | 82 ++--- libsqlx-server/src/database.rs | 9 +- libsqlx-server/src/error.rs | 8 +- libsqlx-server/src/hrana/batch.rs | 24 +- libsqlx-server/src/hrana/error.rs | 18 +- libsqlx-server/src/hrana/http/mod.rs | 9 +- libsqlx-server/src/hrana/http/request.rs | 13 +- libsqlx-server/src/hrana/http/stream.rs | 44 +-- libsqlx-server/src/hrana/result_builder.rs | 23 +- libsqlx-server/src/hrana/stmt.rs | 52 +++- libsqlx-server/src/hrana/ws/mod.rs | 3 +- libsqlx-server/src/http/admin.rs | 1 + libsqlx-server/src/http/user/error.rs | 3 + libsqlx-server/src/http/user/extractors.rs | 2 +- libsqlx-server/src/http/user/mod.rs | 32 +- libsqlx-server/src/linc/bus.rs | 21 +- libsqlx-server/src/linc/connection.rs | 39 ++- libsqlx-server/src/linc/connection_manager.rs | 0 libsqlx-server/src/linc/handler.rs | 7 +- libsqlx-server/src/linc/mod.rs | 15 +- libsqlx-server/src/main.rs | 19 +- libsqlx-server/src/manager.rs | 48 +-- libsqlx-server/src/meta.rs | 75 +++-- libsqlx-server/src/replica_commit_store.rs | 28 +- libsqlx-server/src/snapshot_store.rs | 48 +-- libsqlx/src/connection.rs | 12 +- libsqlx/src/database/libsql/connection.rs | 10 +- libsqlx/src/database/libsql/injector/mod.rs | 14 +- libsqlx/src/database/proxy/connection.rs | 16 +- 34 files changed, 492 insertions(+), 621 deletions(-) delete mode 100644 docs/LINC.md delete mode 100644 libsqlx-server/src/linc/connection_manager.rs diff --git a/docs/LINC.md 
b/docs/LINC.md deleted file mode 100644 index 9a915c65..00000000 --- a/docs/LINC.md +++ /dev/null @@ -1,282 +0,0 @@ -# Libsql Inter-Node Communication protocol: LINC protocol - -## Overview - -This document describes version 1 of the Libsql Inter-Node Communication (LINC) -protocol. - -The first version of the protocol aims to merge the two existing -protocols (proxy and replication) into a single one, and adds support for multi-tenancy. - -LINC v1 is designed to handle 3 tasks: -- inter-node communication -- database replication -- proxying of requests from replicas to primaries - -LINC makes use of streams to multiplex messages between databases on different nodes. - -LINC v1 is implemented on top of TCP. - -LINC uses bincode for message serialization and deserialization. - -## Connection protocol - -Each node is identified by a `node_id` and an address. -At startup, a sqld node is configured with a list of peers (`(node_id, node_addr)`). A connection between two peers is initiated by the peer with the greatest node_id. - -```mermaid -graph TD -node4 --> node3 -node4 --> node2 -node4 --> node1 -node3 --> node2 -node3 --> node1 -node2 --> node1 -node1 -``` - -A new node can be added to the cluster with no reconfiguration as long as its `node_id` is greater than all other `node_id`s in the cluster and it has the address of all the other nodes. In this case, the new node will initiate a connection with all the other nodes. - -On disconnection, the initiator of the connection attempts to reconnect. - -## Messages - -```rust -enum Message { - /// Message destined to a node - Node(NodeMessage), - /// Message destined to a stream - Stream { - stream_id: StreamId, - payload: StreamMessage, - }, -} - -enum NodeMessage { - /// Initial message exchanged between nodes when connecting - Handshake { - protocol_version: String, - node_id: String, - }, - /// Request to open a bi-directional stream between the client and the server - OpenStream { - /// Id to give to the newly opened stream - stream_id: StreamId, - /// Id of the database to open the stream to. - database_id: Uuid, - }, - /// Close a previously opened stream - CloseStream { - id: StreamId, - }, - /// Error type returned while handling a node message - Error(NodeError), -} - -enum NodeError { - UnknownStream(StreamId), - HandshakeVersionMismatch { expected: u32 }, - StreamAlreadyExist(StreamId), - UnknownDatabase(DatabaseId, StreamId), -} - -enum StreamMessage { - /// Replication message between a replica and a primary - Replication(ReplicationMessage), - /// Proxy message between a replica and a primary - Proxy(ProxyMessage), - Error(StreamError), -} - -enum ReplicationMessage { - HandshakeResponse { - /// id of the replication log - log_id: Uuid, - /// current frame_no of the primary - current_frame_no: u64, - }, - /// Replication request - Replicate { - /// next frame no to send - next_frame_no: u64, - }, - /// a batch of frames that are part of the same transaction - Transaction { - /// if not None, then the last frame is a commit frame, and this is the new size of the database. - size_after: Option<u32>, - /// frame_no of the last frame in frames - end_frame_no: u64, - /// a batch of frames part of the transaction.
- frames: Vec<Frame>, - }, - /// Error occurred while handling a replication message - Error(StreamError), -} - -struct Frame { - /// Page id of that frame - page_id: u32, - /// Data - data: Bytes, -} - -enum ProxyMessage { - /// Proxy a query to a primary - ProxyRequest { - /// Id of the connection to perform the query against. - /// If the connection doesn't already exist, it is created. - connection_id: u32, - /// Id of the request. - /// Responses to this request must have the same id. - req_id: u32, - query: Query, - }, - /// Response to a proxied query - ProxyResponse { - /// id of the request this message is a response to. - req_id: u32, - /// Collection of steps to drive the query builder transducer. - row_step: [BuilderStep], - }, - /// Stop processing request `req_id`. - CancelRequest { - req_id: u32, - }, - /// Close the connection with the passed id. - CloseConnection { - connection_id: u32, - }, -} - -/// Steps applied to the query builder transducer to build a response to a proxied query. -/// Those types closely mirror those of the `QueryBuilderTrait`. -enum BuilderStep { - BeginStep, - FinishStep(u64, Option<i64>), - StepError(StepError), - ColsDesc([Column]), - BeginRows, - BeginRow, - AddRowValue(Value), - FinishRow, - FinishRows, - Finish(ConnectionState) -} - -// State of the connection after a query was executed -enum ConnectionState { - /// The connection is still in an open transaction state - OpenTxn, - /// The connection is idle. - Idle, -} - -struct Column { - /// name of the column - name: String, - /// Declared type of the column, if any. - decl_ty: Option<String>, -} - -/// For now, the stringified version of a sqld::error::Error. -struct StepError(String); - -enum StreamError { - NotAPrimary, - AlreadyReplicating, -} -``` - -## Node Handshake - -When a node connects to another node, it first needs to perform a handshake. The -handshake is initiated by the initiator of the connection. It sends the -following message: - -```typescript -type NodeHandshake = { - version: string, // protocol version - node_id: string, -} -``` - -If a peer receives a connection from a peer with an id smaller than its own, it must reject the handshake with an `IllegalConnection` error. - -## Streams - -Messages destined to a particular database are sent as part of a stream. A -stream is created by sending a `NodeMessage::OpenStream`, specifying the id of -the stream to open, along with the id of the database for which to open this -stream. If the requested database is not on the destination node, the -destination node responds with a `NodeError::UnknownDatabase` error, and the stream is not -opened. - -If a node receives a message for a stream that was not opened before, it responds with a `NodeError::UnknownStream` error. - -A stream is closed by sending a `CloseStream` with the id of the stream. If the -stream does not exist, a `NodeError::UnknownStream` error is returned. - -Streams can be opened by either peer. Each stream is identified by an `i32` -stream id. The peer that initiated the original connection allocates positive -stream ids, while the acceptor peer allocates negative ids. 0 is not a legal -value for a stream_id. The receiver of a request for a stream with id 0 must -close the connection immediately. - -The peer opening a stream is responsible for sending the close message. The -other peer can close the stream at any point, but must not send a close message -for that stream. On subsequent messages to that stream, it will respond with an -`UnknownStream` message, forcing the initiator to deal with recreating a -stream if necessary.
- -## Sub-protocols - -### Replication - -The replica is responsible for initiating the replication protocol. This is -done by opening a stream to a primary. If the destination of the stream is not a -primary database, it responds with a `StreamError::NotAPrimary` error and immediately closes -the stream. If the destination database is a primary, it responds to the stream -open request with a `ReplicationMessage::HandshakeResponse` message. This message informs the -replica of the current log version, and of the primary's current replication -index (frame_no). - -The replica compares the log version it received from the primary with the one it has, if any. If the -versions don't match, the replica deletes its state and starts replicating again from the beginning. - -After a successful handshake, the replica sends a `ReplicationMessage::Replicate` message with the -next frame_no it's expecting. For example, if the replica has not replicated any -frame yet, it sends `ReplicationMessage::Replicate { next_frame_no: 0 }` to -signify to the primary that it's expecting to be sent frame 0. The primary -sends the smallest frame with a `frame_no` satisfying `frame_no >= -next_frame_no`. Because logs can be compacted, the next frame_no the primary -sends to the replica isn't necessarily the one the replica is expecting. It's correct to send -the smallest frame >= next_frame_no because frame_nos only move forward in the event of a compaction: a -frame can only be missing if it was written to more recently, hence _moving -forward_ in the log. The primary ensures consistency by moving commit points -accordingly. It is an error for the primary to send a frame_no strictly less -than the requested frame_no; frame_nos can otherwise be received in any order. - -In the event of a disconnection, it is the replica's duty to re-initiate the replication protocol. - -Sending a replicate request twice on the same stream is an error. If a primary -receives more than a single `Replicate` request, it closes the stream and sends -a `StreamError::AlreadyReplicating` error. The replica can re-open a stream and start -replicating again if necessary. - -### Proxy - -Replicas can proxy queries to their primary. Replicas can start sending proxy requests after they have sent a replication request. - -To proxy a query, a replica sends a `ProxyRequest`. Proxied queries on the same connection are serialized. The replica sets the connection id -and the request id for the proxied query. If no connection exists for the -passed id on the primary, one is created. The query is executed on the primary, -and the result rows are returned in `ProxyResponse` messages. The result rows can be split -into multiple `ProxyResponse` messages, enabling row streaming. A replica can send a `CancelRequest` to interrupt a request. Any -`ProxyResponse` for that `request_id` can be dropped by the replica, and the -primary should stop sending any more `ProxyResponse` messages upon receiving the -cancel request. The primary must roll back a cancelled request. - -The primary can reduce the number of concurrently open transactions by closing the -underlying SQLite connection for proxied connections that are not in an open -transaction state (`is_autocommit` is true). Subsequent requests on that -connection id will re-open a connection, if necessary.
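For reviewers of this deletion, a minimal sketch of two rules the removed document described: the peer with the strictly greater `node_id` initiates the connection, and the connection initiator allocates positive stream ids while the acceptor allocates negative ones (0 is never legal). `should_initiate` and `StreamIdAllocator` are hypothetical names used only for illustration, not items from this tree.

```rust
/// Hypothetical helper: the peer with the strictly greater node_id
/// is the one that dials the TCP connection.
fn should_initiate(self_id: u64, peer_id: u64) -> bool {
    self_id > peer_id
}

/// Illustrative stream-id allocator: the connection initiator hands out
/// positive ids, the acceptor negative ones; 0 is never a legal stream id.
struct StreamIdAllocator {
    initiator: bool,
    next: i32,
}

impl StreamIdAllocator {
    fn new(initiator: bool) -> Self {
        Self { initiator, next: 1 }
    }

    fn allocate(&mut self) -> i32 {
        let id = if self.initiator { self.next } else { -self.next };
        self.next += 1;
        debug_assert_ne!(id, 0);
        id
    }
}
```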
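Likewise, a sketch of the frame-lookup rule from the replication section, assuming a primary log kept sorted by ascending `frame_no`; `frames_to_send` is a hypothetical helper and `Frame` a simplified stand-in for the protocol struct.

```rust
type FrameNo = u64;

/// Simplified stand-in for the protocol's `Frame` (payload elided).
struct Frame {
    frame_no: FrameNo,
    page_id: u32,
}

/// Given a log sorted by ascending frame_no, return the suffix to send for a
/// `Replicate { next_frame_no }` request. Because of compaction, the first
/// frame returned may have frame_no > next_frame_no, but never a smaller one:
/// sending a frame_no strictly below the requested one is a protocol error.
fn frames_to_send(log: &[Frame], next_frame_no: FrameNo) -> &[Frame] {
    // partition_point returns the index of the first frame whose
    // frame_no is >= next_frame_no.
    let start = log.partition_point(|f| f.frame_no < next_frame_no);
    &log[start..]
}
```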
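Finally, a sketch of the primary-side bookkeeping the proxy section implies: requests on one connection are serialized, a missing connection is created on first use, and a cancelled request is rolled back and produces no further responses. `PrimaryProxy` and its fields are hypothetical; query payloads and actual execution are elided.

```rust
use std::collections::{HashMap, HashSet};

/// Simplified stand-ins for the protocol's proxy messages.
enum ProxyMessage {
    ProxyRequest { connection_id: u32, req_id: u32 },
    CancelRequest { req_id: u32 },
    CloseConnection { connection_id: u32 },
}

#[derive(Default)]
struct PrimaryProxy {
    /// In-flight request ids per connection; requests on the same
    /// connection are executed serially, in arrival order.
    connections: HashMap<u32, Vec<u32>>,
    /// Requests that must be rolled back and produce no further
    /// `ProxyResponse` messages.
    cancelled: HashSet<u32>,
}

impl PrimaryProxy {
    fn handle(&mut self, msg: ProxyMessage) {
        match msg {
            ProxyMessage::ProxyRequest { connection_id, req_id } => {
                // A missing connection is created on first use.
                self.connections.entry(connection_id).or_default().push(req_id);
            }
            ProxyMessage::CancelRequest { req_id } => {
                self.cancelled.insert(req_id);
            }
            ProxyMessage::CloseConnection { connection_id } => {
                self.connections.remove(&connection_id);
            }
        }
    }
}
```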
diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 6fa4df7f..20be7cc4 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -23,7 +23,7 @@ use crate::compactor::CompactionQueue; use crate::error::Error; use crate::hrana::proto::DescribeResult; use crate::linc::bus::Dispatch; -use crate::linc::proto::{Message, Frames}; +use crate::linc::proto::{Frames, Message}; use crate::linc::{Inbound, NodeId}; use crate::meta::DatabaseId; use crate::replica_commit_store::ReplicaCommitStore; @@ -162,13 +162,12 @@ impl Database { transaction_timeout_duration, } => { let next_frame_no = - block_in_place(|| replica_commit_store.get_commit_index(database_id)) + block_in_place(|| replica_commit_store.get_commit_index(database_id))? .map(|fno| fno + 1) .unwrap_or(0); - let commit_callback = Arc::new(move |fno| { - replica_commit_store.commit(database_id, fno); - }); + let commit_callback = + Arc::new(move |fno| replica_commit_store.commit(database_id, fno).is_ok()); let rdb = LibsqlDatabase::new_replica( path, diff --git a/libsqlx-server/src/allocation/primary/compactor.rs b/libsqlx-server/src/allocation/primary/compactor.rs index 5bc4c9a3..3ee2aca4 100644 --- a/libsqlx-server/src/allocation/primary/compactor.rs +++ b/libsqlx-server/src/allocation/primary/compactor.rs @@ -57,7 +57,7 @@ impl LogCompactor for Compactor { self.queue.push(&CompactionJob { database_id: self.database_id, log_id, - }); + })?; Ok(()) } diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index dbc0387f..63e60e61 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -6,7 +6,7 @@ use std::time::Duration; use bytes::Bytes; use libsqlx::libsql::{LibsqlDatabase, PrimaryType}; -use libsqlx::result_builder::ResultBuilder; +use libsqlx::result_builder::{QueryResultBuilderError, ResultBuilder}; use libsqlx::{Connection, Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger}; use tokio::task::block_in_place; @@ -58,7 +58,7 @@ impl ProxyResponseBuilder { } } - fn maybe_send(&mut self) { + fn maybe_send(&mut self) -> crate::Result<()> { // FIXME: this is stupid: compute current buffer size on the go instead let size = self .buffer @@ -82,11 +82,13 @@ impl ProxyResponseBuilder { .sum::(); if size > MAX_STEP_BATCH_SIZE { - self.send() + self.send()?; } + + Ok(()) } - fn send(&mut self) { + fn send(&mut self) -> crate::Result<()> { let msg = Outbound { to: self.to, enveloppe: Enveloppe { @@ -101,7 +103,9 @@ impl ProxyResponseBuilder { }; self.next_seq_no += 1; - tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg)); + tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg))?; + + Ok(()) } } @@ -109,15 +113,17 @@ impl ResultBuilder for ProxyResponseBuilder { fn init( &mut self, _config: &libsqlx::result_builder::QueryBuilderConfig, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::Init); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn begin_step(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::BeginStep); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } @@ -125,65 +131,70 @@ impl ResultBuilder 
for ProxyResponseBuilder { &mut self, affected_row_count: u64, last_insert_rowid: Option, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::FinishStep( affected_row_count, last_insert_rowid, )); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn step_error( - &mut self, - error: libsqlx::error::Error, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { self.buffer .push(BuilderStep::StepError(StepError(error.to_string()))); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } fn cols_description( &mut self, cols: &mut dyn Iterator, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.buffer .push(BuilderStep::ColsDesc(cols.map(Into::into).collect())); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn begin_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::BeginRows); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn begin_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::BeginRow); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } fn add_row_value( &mut self, v: libsqlx::result_builder::ValueRef, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::AddRowValue(v.into())); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn finish_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::FinishRow); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn finish_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::FinishRows); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } @@ -191,10 +202,10 @@ impl ResultBuilder for ProxyResponseBuilder { &mut self, is_txn: bool, frame_no: Option, - ) -> Result { + ) -> Result { self.buffer .push(BuilderStep::Finnalize { is_txn, frame_no }); - self.send(); + self.send().map_err(QueryResultBuilderError::from_any)?; Ok(true) } } @@ -238,26 +249,31 @@ impl FrameStreamer { } } Err(LogReadError::Error(_)) => todo!("handle log read error"), - Err(LogReadError::SnapshotRequired) => self.send_snapshot().await, + Err(LogReadError::SnapshotRequired) => { + if let Err(e) = self.send_snapshot().await { + tracing::error!("error sending snapshot: {e}"); + break; + } + } } } } - async fn send_snapshot(&mut self) { + async fn send_snapshot(&mut self) -> crate::Result<()> { tracing::debug!("sending frames from snapshot"); loop { match self .snapshot_store - .locate_file(self.database_id, self.next_frame_no) + 
.locate_file(self.database_id, self.next_frame_no)? { Some(file) => { let mut iter = file.frames_iter_from(self.next_frame_no).peekable(); while let Some(frame) = block_in_place(|| iter.next()) { - let frame = frame.unwrap(); + let frame = frame?; // TODO: factorize in maybe_send if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT { - self.send_frames().await; + self.send_frames().await?; } let size_after = iter .peek() @@ -287,9 +303,11 @@ impl FrameStreamer { } } } + + Ok(()) } - async fn send_frames(&mut self) { + async fn send_frames(&mut self) -> crate::Result<()> { let frames = std::mem::take(&mut self.buffer); let outbound = Outbound { to: self.node_id, @@ -303,7 +321,9 @@ impl FrameStreamer { }, }; self.seq_no += 1; - self.dipatcher.dispatch(outbound).await; + self.dipatcher.dispatch(outbound).await?; + + Ok(()) } } @@ -320,7 +340,7 @@ impl ConnectionHandler for PrimaryConnection { async fn handle_conn_message(&mut self, msg: ConnectionMessage) { match msg { ConnectionMessage::Execute { pgm, builder } => { - block_in_place(|| self.conn.execute_program(&pgm, builder).unwrap()) + block_in_place(|| self.conn.execute_program(&pgm, builder)) } ConnectionMessage::Describe => { todo!() diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs index b1422e5b..3d6e81e0 100644 --- a/libsqlx-server/src/allocation/replica.rs +++ b/libsqlx-server/src/allocation/replica.rs @@ -60,7 +60,7 @@ impl libsqlx::Connection for RemoteConn { &mut self, program: &libsqlx::program::Program, builder: Box, - ) -> libsqlx::Result<()> { + ) { // When we need to proxy a query, we place it in the current request slot. When we are // back in a async context, we'll send it to the primary, and asynchrously drive the // builder. @@ -75,8 +75,6 @@ impl libsqlx::Connection for RemoteConn { timeout: Box::pin(tokio::time::sleep(self.inner.request_timeout_duration)), }), }; - - Ok(()) } fn describe(&self, _sql: String) -> libsqlx::Result { @@ -130,7 +128,15 @@ impl Replicator { } pub async fn run(mut self) { - self.query_replicate().await; + macro_rules! 
ok_or_log { + ($e:expr) => { + if let Err(e) = $e { + tracing::warn!("failed to start replication process: {e}"); + } + }; + } + + ok_or_log!(self.query_replicate().await); loop { match timeout(Duration::from_secs(5), self.receiver.recv()).await { Ok(Some(Frames { @@ -147,7 +153,7 @@ impl Replicator { // this is not the batch of frame we were expecting, drop what we have, and // ask again from last checkpoint tracing::debug!(seq, self.next_seq, "wrong seq"); - self.query_replicate().await; + ok_or_log!(self.query_replicate().await); continue; }; self.next_seq += 1; @@ -155,23 +161,32 @@ impl Replicator { tracing::debug!("injecting {} frames", frames.len()); for bytes in frames { - let frame = Frame::try_from_bytes(bytes).unwrap(); - block_in_place(|| { - if let Some(last_committed) = self.injector.inject(frame).unwrap() { - tracing::debug!(last_committed); - self.next_frame_no = last_committed + 1; - } - }); + let inject = || -> crate::Result<()> { + let frame = Frame::try_from_bytes(bytes)?; + block_in_place(|| { + if let Some(last_committed) = self.injector.inject(frame).unwrap() { + tracing::debug!(last_committed); + self.next_frame_no = last_committed + 1; + } + Ok(()) + }) + }; + + if let Err(e) = inject() { + tracing::error!("error injecting frames: {e}"); + ok_or_log!(self.query_replicate().await); + break; + } } } // no news from primary for the past 5 secs, send a request again - Err(_) => self.query_replicate().await, + Err(_) => ok_or_log!(self.query_replicate().await), Ok(None) => break, } } } - async fn query_replicate(&mut self) { + async fn query_replicate(&mut self) -> crate::Result<()> { tracing::debug!("seinding replication request"); self.req_id += 1; self.next_seq = 0; @@ -188,7 +203,9 @@ impl Replicator { }, }, }) - .await; + .await?; + + Ok(()) } } @@ -290,7 +307,7 @@ impl ConnectionHandler for ReplicaConnection { async fn handle_conn_message(&mut self, msg: ConnectionMessage) { match msg { ConnectionMessage::Execute { pgm, builder } => { - self.conn.execute_program(&pgm, builder).unwrap(); + self.conn.execute_program(&pgm, builder); let msg = { let mut lock = self.conn.writer().inner.current_req.lock(); match *lock { diff --git a/libsqlx-server/src/compactor.rs b/libsqlx-server/src/compactor.rs index 523e0d39..060a1dda 100644 --- a/libsqlx-server/src/compactor.rs +++ b/libsqlx-server/src/compactor.rs @@ -47,7 +47,7 @@ impl CompactionQueue { env: heed::Env, db_path: PathBuf, snapshot_store: Arc, - ) -> color_eyre::Result { + ) -> crate::Result { let mut txn = env.write_txn()?; let queue = env.create_database(&mut txn, Some(Self::COMPACTION_QUEUE_DB_NAME))?; let next_id = match queue.last(&mut txn)? 
{ @@ -67,43 +67,47 @@ impl CompactionQueue { }) } - pub fn push(&self, job: &CompactionJob) { + pub fn push(&self, job: &CompactionJob) -> crate::Result<()> { tracing::debug!("new compaction job available: {job:?}"); - let mut txn = self.env.write_txn().unwrap(); + let mut txn = self.env.write_txn()?; let id = self.next_id.fetch_add(1, Ordering::Relaxed); - self.queue.put(&mut txn, &id, job).unwrap(); - txn.commit().unwrap(); + self.queue.put(&mut txn, &id, job)?; + txn.commit()?; self.notify.send_replace(Some(id)); + + Ok(()) } - pub async fn peek(&self) -> (u64, CompactionJob) { + pub async fn peek(&self) -> crate::Result<(u64, CompactionJob)> { let id = self.next_id.load(Ordering::Relaxed); - let txn = block_in_place(|| self.env.read_txn().unwrap()); - match block_in_place(|| self.queue.first(&txn).unwrap()) { - Some(job) => job, - None => { - drop(txn); - self.notify - .subscribe() - .wait_for(|x| x.map(|x| x >= id).unwrap_or_default()) - .await - .unwrap(); - block_in_place(|| { - let txn = self.env.read_txn().unwrap(); - self.queue.first(&txn).unwrap().unwrap() - }) + let peek = || { + block_in_place(|| -> crate::Result<_> { + let txn = self.env.read_txn()?; + Ok(self.queue.first(&txn)?) + }) + }; + + loop { + match peek()? { + Some(job) => return Ok(job), + None => { + self.notify + .subscribe() + .wait_for(|x| x.map(|x| x >= id).unwrap_or_default()) + .await + .expect("we're holding the other side of the channel!"); + } } } } - fn complete(&self, txn: &mut heed::RwTxn, job_id: u64) { - block_in_place(|| { - self.queue.delete(txn, &job_id).unwrap(); - }); + fn complete(&self, txn: &mut heed::RwTxn, job_id: u64) -> crate::Result<()> { + block_in_place(|| self.queue.delete(txn, &job_id))?; + Ok(()) } - async fn compact(&self) -> color_eyre::Result<()> { - let (job_id, job) = self.peek().await; + async fn compact(&self) -> crate::Result<()> { + let (job_id, job) = self.peek().await?; tracing::debug!("starting new compaction job: {job:?}"); let to_compact_path = self.snapshot_queue_dir().join(job.log_id.to_string()); let (start_fno, end_fno) = tokio::task::spawn_blocking({ @@ -127,12 +131,15 @@ impl CompactionQueue { builder.finish() } }) - .await??; + .await + .map_err(|_| { + crate::error::Error::Internal(color_eyre::eyre::anyhow!("compaction thread panicked")) + })??; let mut txn = self.env.write_txn()?; - self.complete(&mut txn, job_id); + self.complete(&mut txn, job_id)?; self.snapshot_store - .register(&mut txn, job.database_id, start_fno, end_fno, job.log_id); + .register(&mut txn, job.database_id, start_fno, end_fno, job.log_id)?; txn.commit()?; std::fs::remove_file(to_compact_path)?; @@ -221,7 +228,7 @@ impl SnapshotBuilder { snapshot_id: Uuid, start_fno: FrameNo, end_fno: FrameNo, - ) -> color_eyre::Result { + ) -> crate::Result { let temp_dir = db_path.join("tmp"); let mut target = BufWriter::new(NamedTempFile::new_in(&temp_dir)?); // reserve header space @@ -243,7 +250,7 @@ impl SnapshotBuilder { }) } - pub fn push_frame(&mut self, frame: Frame) -> color_eyre::Result<()> { + pub fn push_frame(&mut self, frame: Frame) -> crate::Result<()> { assert!(frame.header().frame_no < self.last_seen_frame_no); self.last_seen_frame_no = frame.header().frame_no; @@ -266,16 +273,21 @@ impl SnapshotBuilder { } /// Persist the snapshot, and return the start and end frame_nos of the snapshot.
- pub fn finish(mut self) -> color_eyre::Result<(FrameNo, FrameNo)> { + pub fn finish(mut self) -> crate::Result<(FrameNo, FrameNo)> { self.snapshot_file.flush()?; - let file = self.snapshot_file.into_inner()?; + let file = self + .snapshot_file + .into_inner() + .map_err(|e| crate::error::Error::Internal(e.into()))?; + file.as_file().write_all_at(bytes_of(&self.header), 0)?; let path = self .db_path .join("snapshots") .join(self.snapshot_id.to_string()); - file.persist(path)?; + file.persist(path) + .map_err(|e| crate::error::Error::Internal(e.into()))?; Ok((self.header.start_frame_no, self.header.end_frame_no)) } @@ -287,7 +299,7 @@ pub struct SnapshotFile { } impl SnapshotFile { - pub fn open(path: &Path) -> color_eyre::Result { + pub fn open(path: &Path) -> crate::Result { let file = File::open(path)?; let mut header_buf = [0; size_of::()]; file.read_exact_at(&mut header_buf, 0)?; diff --git a/libsqlx-server/src/database.rs b/libsqlx-server/src/database.rs index 0b75dd9c..2147a85d 100644 --- a/libsqlx-server/src/database.rs +++ b/libsqlx-server/src/database.rs @@ -9,7 +9,12 @@ pub struct Database { impl Database { pub async fn connect(&self) -> crate::Result { let (ret, conn) = oneshot::channel(); - self.sender.send(AllocationMessage::Connect { ret }).await.unwrap(); - conn.await.unwrap() + self.sender + .send(AllocationMessage::Connect { ret }) + .await + .map_err(|_| crate::error::Error::AllocationClosed)?; + + conn.await + .map_err(|_| crate::error::Error::ConnectionClosed)? } } diff --git a/libsqlx-server/src/error.rs b/libsqlx-server/src/error.rs index df485ac5..f62bf20c 100644 --- a/libsqlx-server/src/error.rs +++ b/libsqlx-server/src/error.rs @@ -1,3 +1,5 @@ +use crate::meta::AllocationError; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] @@ -11,5 +13,9 @@ pub enum Error { #[error("allocation closed")] AllocationClosed, #[error("internal error: {0}")] - Internal(String), + Internal(color_eyre::eyre::Error), + #[error(transparent)] + Heed(#[from] heed::Error), + #[error(transparent)] + Allocation(#[from] AllocationError), } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index 7a5f7c87..d01bc8ac 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; -use libsqlx::query::{Query, Params}; +use libsqlx::query::{Params, Query}; use libsqlx::result_builder::{StepResult, StepResultsBuilder}; use tokio::sync::oneshot; @@ -82,9 +82,13 @@ pub async fn execute_batch( conn.execute( pgm, // auth, - Box::new(builder)).await; + Box::new(builder), + ) + .await; - Ok(ret.await.unwrap()) + Ok(ret + .await + .map_err(|_| crate::error::Error::ConnectionClosed)?) 
} pub fn proto_sequence_to_program(sql: &str) -> Result { @@ -109,20 +113,22 @@ pub fn proto_sequence_to_program(sql: &str) -> Result { }) .collect::>(); - Ok(Program { - steps, - }) + Ok(Program { steps }) } pub async fn execute_sequence( conn: &ConnectionHandle, // auth: Authenticated, - pgm: Program) -> Result<(), HranaError> { + pgm: Program, +) -> Result<(), HranaError> { let (send, ret) = oneshot::channel(); let builder = StepResultsBuilder::new(send); - conn.execute(pgm, + conn.execute( + pgm, // auth, - Box::new(builder)).await; + Box::new(builder), + ) + .await; ret.await .unwrap() diff --git a/libsqlx-server/src/hrana/error.rs b/libsqlx-server/src/hrana/error.rs index 0874a91e..2324887a 100644 --- a/libsqlx-server/src/hrana/error.rs +++ b/libsqlx-server/src/hrana/error.rs @@ -1,7 +1,7 @@ -use super::ProtocolError; -use super::stmt::StmtError; -use super::http::StreamError; use super::http::request::StreamResponseError; +use super::http::StreamError; +use super::stmt::StmtError; +use super::ProtocolError; #[derive(Debug, thiserror::Error)] pub enum HranaError { @@ -16,3 +16,15 @@ pub enum HranaError { #[error(transparent)] Libsqlx(crate::error::Error), } + +impl HranaError { + pub fn code(&self) -> Option<&str>{ + match self { + HranaError::Stmt(e) => Some(e.code()), + HranaError::StreamResponse(e) => Some(e.code()), + HranaError::Stream(_) + | HranaError::Libsqlx(_) + | HranaError::Proto(_) => None, + } + } +} diff --git a/libsqlx-server/src/hrana/http/mod.rs b/libsqlx-server/src/hrana/http/mod.rs index 65686fbe..790a3c4f 100644 --- a/libsqlx-server/src/hrana/http/mod.rs +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -15,12 +15,6 @@ pub struct Server { stream_state: Mutex, } -#[derive(Debug)] -pub enum Route { - GetIndex, - PostPipeline, -} - impl Server { pub fn new(self_url: Option) -> Self { Self { @@ -40,8 +34,7 @@ pub async fn handle_pipeline( // auth: Authenticated, req: proto::PipelineRequestBody, db: Database, -) -> crate::Result -{ +) -> crate::Result { let mut stream_guard = stream::acquire(server, req.baton.as_deref(), db).await?; let mut results = Vec::with_capacity(req.requests.len()); diff --git a/libsqlx-server/src/hrana/http/request.rs b/libsqlx-server/src/hrana/http/request.rs index c99c493b..eaf84eb5 100644 --- a/libsqlx-server/src/hrana/http/request.rs +++ b/libsqlx-server/src/hrana/http/request.rs @@ -1,5 +1,5 @@ -use crate::hrana::ProtocolError; use crate::hrana::error::HranaError; +use crate::hrana::ProtocolError; use super::super::{batch, stmt, Version}; use super::{proto, stream}; @@ -19,7 +19,7 @@ pub async fn handle( // auth: Authenticated, request: proto::StreamRequest, ) -> Result { - let result = match try_handle(stream_guard/*, auth*/, request).await { + let result = match try_handle(stream_guard /*, auth*/, request).await { Ok(response) => proto::StreamResult::Ok { response }, Err(err) => { if let HranaError::StreamResponse(err) = err { @@ -50,8 +50,7 @@ async fn try_handle( let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2)?; - let result = stmt::execute_stmt(db, /*auth,*/ query) - .await?; + let result = stmt::execute_stmt(db, /*auth,*/ query).await?; proto::StreamResponse::Execute(proto::ExecuteStreamResp { result }) } proto::StreamRequest::Batch(req) => { @@ -67,8 +66,7 @@ async fn try_handle( let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; let pgm = batch::proto_sequence_to_program(sql)?; - batch::execute_sequence(db, 
/*auth,*/ pgm) - .await?; + batch::execute_sequence(db, /*auth,*/ pgm).await?; proto::StreamResponse::Sequence(proto::SequenceStreamResp {}) } proto::StreamRequest::Describe(req) => { @@ -76,8 +74,7 @@ async fn try_handle( let sqls = stream_guard.sqls(); let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; - let result = stmt::describe_stmt(db, /* auth,*/ sql.into()) - .await?; + let result = stmt::describe_stmt(db, /* auth,*/ sql.into()).await?; proto::StreamResponse::Describe(proto::DescribeStreamResp { result }) } proto::StreamRequest::StoreSql(req) => { diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs index f204f799..1320df90 100644 --- a/libsqlx-server/src/hrana/http/stream.rs +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -114,20 +114,24 @@ pub async fn acquire<'srv>( let handle = state.handles.get_mut(&stream_id); match handle { None => { - Err(ProtocolError::BatonInvalid { reason: format!("Stream handle for {stream_id} was not found")})?; + Err(ProtocolError::BatonInvalid { + reason: format!("Stream handle for {stream_id} was not found"), + })?; } Some(Handle::Acquired) => { - Err(ProtocolError::BatonReused { reason: format!("Stream handle for {stream_id} is acquired")})?; - } - Some(Handle::Expired) => { - Err(StreamError::StreamExpired)? + Err(ProtocolError::BatonReused { + reason: format!("Stream handle for {stream_id} is acquired"), + })?; } + Some(Handle::Expired) => Err(StreamError::StreamExpired)?, Some(Handle::Available(stream)) => { if stream.baton_seq != baton_seq { - Err(ProtocolError::BatonReused { reason: format!( - "Expected baton seq {}, received {baton_seq}", - stream.baton_seq - )})?; + Err(ProtocolError::BatonReused { + reason: format!( + "Expected baton seq {}, received {baton_seq}", + stream.baton_seq + ), + })?; } } }; @@ -144,7 +148,7 @@ pub async fn acquire<'srv>( stream } None => { - let conn = db.connect().await.unwrap(); + let conn = db.connect().await?; let mut state = server.stream_state.lock(); let stream = Box::new(Stream { conn: Some(conn), @@ -277,15 +281,17 @@ fn encode_baton(server: &Server, stream_id: u64, baton_seq: u64) -> String { /// returns a [`ProtocolError::BatonInvalid`] if the baton is invalid, but it attaches an anyhow /// context that describes the precise cause. 
fn decode_baton(server: &Server, baton_str: &str) -> crate::Result<(u64, u64), HranaError> { - let baton_data = BASE64_STANDARD_NO_PAD.decode(baton_str).map_err(|err| { - ProtocolError::BatonInvalid { reason: format!("Could not base64-decode baton: {err}") } - })?; + let baton_data = + BASE64_STANDARD_NO_PAD + .decode(baton_str) + .map_err(|err| ProtocolError::BatonInvalid { + reason: format!("Could not base64-decode baton: {err}"), + })?; if baton_data.len() != 48 { - Err(ProtocolError::BatonInvalid { reason: format!( - "Baton has invalid size of {} bytes", - baton_data.len() - )})?; + Err(ProtocolError::BatonInvalid { + reason: format!("Baton has invalid size of {} bytes", baton_data.len()), + })?; } let payload = &baton_data[0..16]; @@ -294,7 +300,9 @@ fn decode_baton(server: &Server, baton_str: &str) -> crate::Result<(u64, u64), H let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); hmac.update(payload); hmac.verify_slice(received_mac) - .map_err(|_| ProtocolError::BatonInvalid { reason: "Invalid MAC on baton".into() })?; + .map_err(|_| ProtocolError::BatonInvalid { + reason: "Invalid MAC on baton".into(), + })?; let stream_id = u64::from_be_bytes(payload[0..8].try_into().unwrap()); let baton_seq = u64::from_be_bytes(payload[8..16].try_into().unwrap()); diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index 75d84683..3985795b 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -271,7 +271,7 @@ pub struct HranaBatchProtoBuilder { current_size: u64, max_response_size: u64, step_empty: bool, - ret: oneshot::Sender, + ret: Option>, } impl HranaBatchProtoBuilder { @@ -285,15 +285,16 @@ impl HranaBatchProtoBuilder { current_size: 0, max_response_size: u64::MAX, step_empty: false, - ret, + ret: Some(ret), }, rcv, ) } - pub fn into_ret(self) -> proto::BatchResult { + + pub fn into_ret(&mut self) -> proto::BatchResult { proto::BatchResult { - step_results: self.step_results, - step_errors: self.step_errors, + step_results: std::mem::take(&mut self.step_results), + step_errors: std::mem::take(&mut self.step_errors), } } } @@ -360,6 +361,18 @@ impl ResultBuilder for HranaBatchProtoBuilder { self.stmt_builder.add_row_value(v) } + fn finnalize( + &mut self, + _is_txn: bool, + _frame_no: Option, + ) -> Result { + if let Some(ret) = self.ret.take() { + let _ = ret.send(self.into_ret()); + } + + Ok(false) + } + fn finnalize_error(&mut self, _e: String) { todo!() } diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index f0021028..d1a7799e 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; -use libsqlx::DescribeResponse; +use futures::FutureExt; use libsqlx::analysis::Statement; use libsqlx::program::Program; -use libsqlx::query::{Query, Params, Value}; +use libsqlx::query::{Params, Query, Value}; +use libsqlx::DescribeResponse; use super::error::HranaError; use super::result_builder::SingleStatementBuilder; @@ -52,9 +53,13 @@ pub async fn execute_stmt( query: Query, ) -> crate::Result { let (builder, ret) = SingleStatementBuilder::new(); - conn.execute(Program::from_queries(Some(query))/*, auth*/, Box::new(builder)).await; - let stmt_res = ret.await; - Ok(stmt_res.unwrap()?) + conn.execute( + Program::from_queries(Some(query)), /*, auth*/ + Box::new(builder), + ) + .await; + ret.await + .map_err(|_| crate::error::Error::ConnectionClosed)? 
} pub async fn describe_stmt( @@ -188,19 +193,28 @@ impl From for HranaError { fn from(error: crate::error::Error) -> Self { if let crate::error::Error::Libsqlx(e) = error { match e { - libsqlx::error::Error::LibSqlInvalidQueryParams(source) => StmtError::ArgsInvalid { source: color_eyre::eyre::anyhow!("{source}") }.into(), + libsqlx::error::Error::LibSqlInvalidQueryParams(source) => StmtError::ArgsInvalid { + source: color_eyre::eyre::anyhow!("{source}"), + } + .into(), libsqlx::error::Error::LibSqlTxTimeout => StmtError::TransactionTimeout.into(), libsqlx::error::Error::LibSqlTxBusy => StmtError::TransactionBusy.into(), libsqlx::error::Error::Blocked(reason) => StmtError::Blocked { reason }.into(), libsqlx::error::Error::RusqliteError(rusqlite_error) => match rusqlite_error { - libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => StmtError::SqliteError { - source: sqlite_error, - message, - }.into(), - libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => StmtError::SqliteError { - message: sqlite_error.to_string(), - source: sqlite_error, - }.into(), + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => { + StmtError::SqliteError { + source: sqlite_error, + message, + } + .into() + } + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => { + StmtError::SqliteError { + message: sqlite_error.to_string(), + source: sqlite_error, + } + .into() + } libsqlx::error::RusqliteError::SqlInputError { error: sqlite_error, msg: message, @@ -210,8 +224,14 @@ impl From for HranaError { source: sqlite_error, message, offset, - }.into(), - rusqlite_error => return crate::error::Error::from(libsqlx::error::Error::RusqliteError(rusqlite_error)).into(), + } + .into(), + rusqlite_error => { + return crate::error::Error::from(libsqlx::error::Error::RusqliteError( + rusqlite_error, + )) + .into() + } }, sqld_error => return crate::error::Error::from(sqld_error).into(), } diff --git a/libsqlx-server/src/hrana/ws/mod.rs b/libsqlx-server/src/hrana/ws/mod.rs index a4b96ed5..bcdb5209 100644 --- a/libsqlx-server/src/hrana/ws/mod.rs +++ b/libsqlx-server/src/hrana/ws/mod.rs @@ -12,8 +12,7 @@ mod conn; mod handshake; mod session; -struct Server { - db_factory: Arc>, +struct Server { auth: Arc, // idle_kicker: Option, next_conn_id: AtomicU64, diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index ff718674..2bac534c 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -197,6 +197,7 @@ async fn list_allocs( .handler() .store() .list_allocs() + .unwrap() .into_iter() .map(|cfg| AllocView { id: cfg.db_name }) .collect(); diff --git a/libsqlx-server/src/http/user/error.rs b/libsqlx-server/src/http/user/error.rs index 9aab9a71..81a9ea2b 100644 --- a/libsqlx-server/src/http/user/error.rs +++ b/libsqlx-server/src/http/user/error.rs @@ -11,6 +11,8 @@ pub enum UserApiError { InvalidHost, #[error("Database `{0}` doesn't exist")] UnknownDatabase(String), + #[error(transparent)] + LibsqlxServer(#[from] crate::error::Error), } impl UserApiError { @@ -19,6 +21,7 @@ impl UserApiError { UserApiError::MissingHost | UserApiError::InvalidHost | UserApiError::UnknownDatabase(_) => StatusCode::BAD_REQUEST, + UserApiError::LibsqlxServer(_) => StatusCode::INTERNAL_SERVER_ERROR, } } } diff --git a/libsqlx-server/src/http/user/extractors.rs b/libsqlx-server/src/http/user/extractors.rs index 582b0fd6..fc84900e 100644 --- a/libsqlx-server/src/http/user/extractors.rs +++ 
b/libsqlx-server/src/http/user/extractors.rs @@ -20,7 +20,7 @@ impl FromRequestParts> for Database { let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; let db_name = parse_host(host_str)?; let db_id = DatabaseId::from_name(db_name); - let Some(sender) = state.manager.schedule(db_id, state.bus.clone()).await else { return Err(UserApiError::UnknownDatabase(db_name.to_owned())) }; + let Some(sender) = state.manager.schedule(db_id, state.bus.clone()).await? else { return Err(UserApiError::UnknownDatabase(db_name.to_owned())) }; Ok(Database { sender }) } diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs index abf62fd1..3653377b 100644 --- a/libsqlx-server/src/http/user/mod.rs +++ b/libsqlx-server/src/http/user/mod.rs @@ -1,14 +1,18 @@ use std::sync::Arc; use axum::extract::State; +use axum::response::IntoResponse; use axum::routing::post; use axum::{Json, Router}; use color_eyre::Result; +use hyper::StatusCode; use hyper::server::accept::Accept; +use serde::Serialize; use tokio::io::{AsyncRead, AsyncWrite}; use crate::database::Database; use crate::hrana; +use crate::hrana::error::HranaError; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use crate::linc::bus::Bus; use crate::manager::Manager; @@ -16,6 +20,28 @@ use crate::manager::Manager; mod error; mod extractors; +#[derive(Debug, Serialize)] +struct ErrorResponseBody { + pub message: String, + pub code: String, +} + +impl IntoResponse for HranaError { + fn into_response(self) -> axum::response::Response { + let (message, code) = match self.code() { + Some(code) => (self.to_string(), code.to_owned()), + None => ("internal error, please check the logs".to_owned(), "INTERNAL_ERROR".to_owned()), + }; + let resp = ErrorResponseBody { + message, + code, + }; + let mut resp = Json(resp).into_response(); + *resp.status_mut() = StatusCode::BAD_REQUEST; + resp + } +} + pub struct Config { pub manager: Arc, pub bus: Arc>>, @@ -54,7 +80,7 @@ async fn handle_hrana_pipeline( State(state): State>, db: Database, Json(req): Json, -) -> Json { - let ret = hrana::http::handle_pipeline(&state.hrana_server, req, db).await.unwrap(); - Json(ret) +) -> crate::Result, HranaError> { + let ret = hrana::http::handle_pipeline(&state.hrana_server, req, db).await?; + Ok(Json(ret)) } diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 5072c8ae..c74ba267 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -36,11 +36,9 @@ impl Bus { } pub async fn incomming(self: &Arc, incomming: Inbound) { - self.handler.handle(self.clone(), incomming).await; - } - - pub fn send_queue(&self) -> &SendQueue { - &self.send_queue + if let Err(e) = self.handler.handle(self.clone(), incomming).await { + tracing::error!("error handling message: {e}") + } } pub fn connect(&self, node_id: NodeId) -> mpsc::UnboundedReceiver { @@ -48,27 +46,26 @@ impl Bus { self.peers.write().insert(node_id); self.send_queue.register(node_id) } - - pub fn disconnect(&self, node_id: NodeId) { - self.peers.write().remove(&node_id); - } } #[async_trait::async_trait] pub trait Dispatch: Send + Sync + 'static { - async fn dispatch(&self, msg: Outbound); + async fn dispatch(&self, msg: Outbound) -> crate::Result<()>; + /// id of the current node fn node_id(&self) -> NodeId; } #[async_trait::async_trait] impl Dispatch for Bus { - async fn dispatch(&self, msg: Outbound) { + async fn dispatch(&self, msg: Outbound) -> crate::Result<()> { assert!( 
msg.to != self.node_id(), "trying to send a message to ourselves!" ); // This message is outbound. - self.send_queue.enqueue(msg).await; + self.send_queue.enqueue(msg).await?; + + Ok(()) } fn node_id(&self) -> NodeId { diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index b979c437..9fa41bb1 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -62,13 +62,24 @@ impl SendQueue { } } - pub async fn enqueue(&self, msg: Outbound) { + pub async fn enqueue(&self, msg: Outbound) -> crate::Result<()> { let sender = match self.senders.read().get(&msg.to) { Some(sender) => sender.clone(), - None => todo!("no queue"), + None => { + return Err(crate::error::Error::Internal(color_eyre::eyre::anyhow!( + "failed to deliver message: unknown node id `{}`", + msg.to + ))) + } }; - sender.send(msg.enveloppe).unwrap(); + sender.send(msg.enveloppe).map_err(|_| { + crate::error::Error::Internal(color_eyre::eyre::anyhow!( + "failed to deliver message: connection closed" + )) + })?; + + Ok(()) } pub fn register(&self, node_id: NodeId) -> mpsc::UnboundedReceiver { @@ -155,14 +166,22 @@ where } } }, - // TODO: pop send queue - Some(m) = self.send_queue.as_mut().unwrap().recv() => { - self.conn.feed(m).await.unwrap(); - // send as many as possible - while let Ok(m) = self.send_queue.as_mut().unwrap().try_recv() { - self.conn.feed(m).await.unwrap(); + Some(m) = self.send_queue.as_mut().expect("no send_queue in connected state").recv() => { + let feed = || async { + self.conn.feed(m).await?; + // send as many as possible + while let Ok(m) = self.send_queue.as_mut().expect("no send_queue in connected state").try_recv() { + self.conn.feed(m).await?; + } + self.conn.flush().await?; + + Ok(()) + }; + + if let Err(e) = feed().await { + tracing::error!("error flushing send queue for {}; closing connection", self.peer.unwrap()); + self.state = ConnectionState::CloseError(e) } - self.conn.flush().await.unwrap(); }, else => { self.state = ConnectionState::Close; diff --git a/libsqlx-server/src/linc/connection_manager.rs b/libsqlx-server/src/linc/connection_manager.rs deleted file mode 100644 index e69de29b..00000000 diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index 2d17ff96..410e4c24 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -6,7 +6,7 @@ use super::Inbound; #[async_trait::async_trait] pub trait Handler: Sized + Send + Sync + 'static { /// Handle inbound message - async fn handle(&self, bus: Arc, msg: Inbound); + async fn handle(&self, bus: Arc, msg: Inbound) -> crate::Result<()>; } #[cfg(test)] @@ -16,7 +16,8 @@ where F: Fn(Arc, Inbound) -> Fut + Send + Sync + 'static, Fut: std::future::Future + Send, { - async fn handle(&self, bus: Arc, msg: Inbound) { - (self)(bus, msg).await + async fn handle(&self, bus: Arc, msg: Inbound) -> crate::Result<()> { + (self)(bus, msg).await; + Ok(()) } } diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs index 638f56e2..2ee07790 100644 --- a/libsqlx-server/src/linc/mod.rs +++ b/libsqlx-server/src/linc/mod.rs @@ -1,4 +1,4 @@ -use self::proto::{Enveloppe, Message}; +use self::proto::Enveloppe; pub mod bus; pub mod connection; @@ -11,7 +11,6 @@ pub mod server; pub type NodeId = u64; const CURRENT_PROTO_VERSION: u32 = 1; -const MAX_STREAM_MSG: usize = 64; #[derive(Debug)] pub struct Inbound { @@ -21,18 +20,6 @@ pub struct Inbound { pub enveloppe: Enveloppe, } -impl Inbound { - pub fn respond(&self,
message: Message) -> Outbound { - Outbound { - to: self.from, - enveloppe: Enveloppe { - database_id: None, - message, - }, - } - } -} - #[derive(Debug)] pub struct Outbound { pub to: NodeId, diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 6fdf52a5..6162df26 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -72,7 +72,11 @@ async fn spawn_user_api( } }); set.spawn(run_user_api( - http::user::Config { manager, bus, hrana_server }, + http::user::Config { + manager, + bus, + hrana_server, + }, AddrIncoming::from_listener(user_api_listener)?, )); @@ -113,7 +117,8 @@ async fn init_dirs(db_path: &Path) -> color_eyre::Result<()> { #[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() -> color_eyre::Result<()> { - init(); + init()?; + let args = Args::parse(); let config_str = read_to_string(args.config)?; let config: config::Config = toml::from_str(&config_str)?; @@ -134,8 +139,8 @@ async fn main() -> color_eyre::Result<()> { config.db_path.clone(), snapshot_store, )?); - let store = Arc::new(Store::new(env.clone())); - let replica_commit_store = Arc::new(ReplicaCommitStore::new(env.clone())); + let store = Arc::new(Store::new(env.clone())?); + let replica_commit_store = Arc::new(ReplicaCommitStore::new(env.clone())?); let manager = Arc::new(Manager::new( config.db_path.clone(), store.clone(), @@ -155,7 +160,7 @@ async fn main() -> color_eyre::Result<()> { Ok(()) } -fn init() { +fn init() -> color_eyre::Result<()> { let registry = tracing_subscriber::registry(); registry @@ -170,5 +175,7 @@ fn init() { ) .init(); - color_eyre::install().unwrap(); + color_eyre::install()?; + + Ok(()) } diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 429493eb..686e6fca 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -47,14 +47,14 @@ impl Manager { self: &Arc, database_id: DatabaseId, dispatcher: Arc, - ) -> Option> { + ) -> crate::Result>> { if let Some(sender) = self.cache.get(&database_id) { - return Some(sender.clone()); + return Ok(Some(sender.clone())); } - if let Some(config) = self.meta_store.meta(&database_id) { + if let Some(config) = self.meta_store.meta(&database_id)? 
{ let path = self.db_path.join("dbs").join(database_id.to_string()); - tokio::fs::create_dir_all(&path).await.unwrap(); + tokio::fs::create_dir_all(&path).await?; let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { inbox, @@ -64,8 +64,7 @@ impl Manager { dispatcher.clone(), self.compaction_queue.clone(), self.replica_commit_store.clone(), - ) - .unwrap(), + )?, connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, @@ -78,10 +77,10 @@ impl Manager { self.cache.insert(database_id, alloc_sender.clone()).await; - return Some(alloc_sender); + return Ok(Some(alloc_sender)); } - None + Ok(None) } pub async fn allocate( @@ -89,16 +88,21 @@ impl Manager { database_id: DatabaseId, meta: &AllocConfig, dispatcher: Arc, - ) { - self.store().allocate(&database_id, meta); - self.schedule(database_id, dispatcher).await; + ) -> crate::Result<()> { + self.store().allocate(&database_id, meta)?; + self.schedule(database_id, dispatcher).await?; + Ok(()) } - pub async fn deallocate(&self, database_id: DatabaseId) { - self.meta_store.deallocate(&database_id); + pub async fn deallocate(&self, database_id: DatabaseId) -> crate::Result<()> { + self.meta_store.deallocate(&database_id)?; self.cache.remove(&database_id).await; let db_path = self.db_path.join("dbs").join(database_id.to_string()); - tokio::fs::remove_dir_all(db_path).await.unwrap(); + if db_path.exists() { + tokio::fs::remove_dir_all(db_path).await?; + } + + Ok(()) } pub fn store(&self) -> &Store { @@ -108,13 +112,15 @@ impl Manager { #[async_trait::async_trait] impl Handler for Arc { - async fn handle(&self, bus: Arc, msg: Inbound) { - if let Some(sender) = self - .clone() - .schedule(msg.enveloppe.database_id.unwrap(), bus.clone()) - .await - { - let _ = sender.send(AllocationMessage::Inbound(msg)).await; + async fn handle(&self, bus: Arc, msg: Inbound) -> crate::Result<()> { + if let Some(database_id) = msg.enveloppe.database_id { + if let Some(sender) = self.clone().schedule(database_id, bus.clone()).await? 
{ + sender + .send(AllocationMessage::Inbound(msg)) + .await + .map_err(|_| crate::error::Error::AllocationClosed)?; + } } + Ok(()) } } diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 2436839b..baefaeaf 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -3,6 +3,7 @@ use std::mem::size_of; use heed::bytemuck::{Pod, Zeroable}; use heed_types::{OwnedType, SerdeBincode}; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use sha3::digest::{ExtendableOutput, Update, XofReader}; use sha3::Shake128; @@ -52,64 +53,72 @@ impl AsRef<[u8]> for DatabaseId { } } +#[derive(Debug, thiserror::Error)] +pub enum AllocationError { + #[error("an allocation already exists for {0}")] + AlreadyExist(String), +} + impl Store { const ALLOC_CONFIG_DB_NAME: &'static str = "alloc_conf_db"; - pub fn new(env: heed::Env) -> Self { - let mut txn = env.write_txn().unwrap(); - let alloc_config_db = env - .create_database(&mut txn, Some(Self::ALLOC_CONFIG_DB_NAME)) - .unwrap(); - txn.commit().unwrap(); + pub fn new(env: heed::Env) -> crate::Result { + let mut txn = env.write_txn()?; + let alloc_config_db = env.create_database(&mut txn, Some(Self::ALLOC_CONFIG_DB_NAME))?; + txn.commit()?; - Self { + Ok(Self { env, alloc_config_db, - } + }) } - pub fn allocate(&self, id: &DatabaseId, meta: &AllocConfig) { - //TODO: Handle conflict + pub fn allocate(&self, id: &DatabaseId, meta: &AllocConfig) -> crate::Result<()> { block_in_place(|| { - let mut txn = self.env.write_txn().unwrap(); + let mut txn = self.env.write_txn()?; if self .alloc_config_db .lazily_decode_data() - .get(&txn, id) - .unwrap() + .get(&txn, id)? .is_some() { - panic!("alloc already exists"); + Err(AllocationError::AlreadyExist(meta.db_name.clone()))?; }; - self.alloc_config_db.put(&mut txn, id, meta).unwrap(); - txn.commit().unwrap(); - }); + + self.alloc_config_db.put(&mut txn, id, meta)?; + + txn.commit()?; + + Ok(()) + }) } - pub fn deallocate(&self, id: &DatabaseId) { + pub fn deallocate(&self, id: &DatabaseId) -> crate::Result<()> { block_in_place(|| { - let mut txn = self.env.write_txn().unwrap(); - self.alloc_config_db.delete(&mut txn, id).unwrap(); - txn.commit().unwrap(); - }); + let mut txn = self.env.write_txn()?; + self.alloc_config_db.delete(&mut txn, id)?; + txn.commit()?; + + Ok(()) + }) } - pub fn meta(&self, id: &DatabaseId) -> Option { + pub fn meta(&self, id: &DatabaseId) -> crate::Result> { block_in_place(|| { - let txn = self.env.read_txn().unwrap(); - self.alloc_config_db.get(&txn, id).unwrap() + let txn = self.env.read_txn()?; + Ok(self.alloc_config_db.get(&txn, id)?) }) } - pub fn list_allocs(&self) -> Vec { + pub fn list_allocs(&self) -> crate::Result> { block_in_place(|| { - let txn = self.env.read_txn().unwrap(); - self.alloc_config_db - .iter(&txn) - .unwrap() - .map(Result::unwrap) - .map(|x| x.1) - .collect() + let txn = self.env.read_txn()?; + let res = self + .alloc_config_db + .iter(&txn)? 
+                .map(|x| x.map(|x| x.1))
+                .try_collect()?;
+            Ok(res)
         })
     }
 }
diff --git a/libsqlx-server/src/replica_commit_store.rs b/libsqlx-server/src/replica_commit_store.rs
index 18c0aeed..2598c3b0 100644
--- a/libsqlx-server/src/replica_commit_store.rs
+++ b/libsqlx-server/src/replica_commit_store.rs
@@ -11,24 +11,24 @@ pub struct ReplicaCommitStore {
 impl ReplicaCommitStore {
     const DB_NAME: &str = "replica-commit-store";
 
-    pub fn new(env: heed::Env) -> Self {
-        let mut txn = env.write_txn().unwrap();
-        let database = env.create_database(&mut txn, Some(Self::DB_NAME)).unwrap();
-        txn.commit().unwrap();
+    pub fn new(env: heed::Env) -> crate::Result<Self> {
+        let mut txn = env.write_txn()?;
+        let database = env.create_database(&mut txn, Some(Self::DB_NAME))?;
+        txn.commit()?;
 
-        Self { env, database }
+        Ok(Self { env, database })
     }
 
-    pub fn commit(&self, database_id: DatabaseId, frame_no: FrameNo) {
-        let mut txn = self.env.write_txn().unwrap();
-        self.database
-            .put(&mut txn, &database_id, &frame_no)
-            .unwrap();
-        txn.commit().unwrap();
+    pub fn commit(&self, database_id: DatabaseId, frame_no: FrameNo) -> crate::Result<()> {
+        let mut txn = self.env.write_txn()?;
+        self.database.put(&mut txn, &database_id, &frame_no)?;
+        txn.commit()?;
+
+        Ok(())
     }
 
-    pub fn get_commit_index(&self, database_id: DatabaseId) -> Option<FrameNo> {
-        let txn = self.env.read_txn().unwrap();
-        self.database.get(&txn, &database_id).unwrap()
+    pub fn get_commit_index(&self, database_id: DatabaseId) -> crate::Result<Option<FrameNo>> {
+        let txn = self.env.read_txn()?;
+        Ok(self.database.get(&txn, &database_id)?)
     }
 }
diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs
index 32f7f0e9..73cbc0cb 100644
--- a/libsqlx-server/src/snapshot_store.rs
+++ b/libsqlx-server/src/snapshot_store.rs
@@ -10,6 +10,8 @@ use uuid::Uuid;
 
 use crate::{compactor::SnapshotFile, meta::DatabaseId};
 
+/// Equivalent to a u64, but stored in big-endian ordering.
+/// Used for storing values whose bytes need to be lexicographically ordered.
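+///
+/// For example, big-endian bytes compare in the same order as the integers they
+/// encode, while little-endian bytes do not:
+///
+/// ```
+/// assert!(1u64.to_be_bytes() < 256u64.to_be_bytes()); // matches 1 < 256
+/// assert!(1u64.to_le_bytes() > 256u64.to_le_bytes()); // sorts in the wrong order
+/// ```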
 #[derive(Clone, Copy, Zeroable, Pod, Debug)]
 #[repr(transparent)]
 struct BEU64([u8; size_of::<u64>()]);
 
@@ -48,8 +50,8 @@ pub struct SnapshotStore {
 impl SnapshotStore {
     const SNAPSHOT_STORE_NAME: &str = "snapshot-store-db";
 
-    pub fn new(db_path: PathBuf, env: heed::Env) -> color_eyre::Result<Self> {
-        let mut txn = env.write_txn().unwrap();
+    pub fn new(db_path: PathBuf, env: heed::Env) -> crate::Result<Self> {
+        let mut txn = env.write_txn()?;
         let database = env.create_database(&mut txn, Some(Self::SNAPSHOT_STORE_NAME))?;
         txn.commit()?;
 
@@ -67,7 +69,7 @@ impl SnapshotStore {
         start_frame_no: FrameNo,
         end_frame_no: FrameNo,
         snapshot_id: Uuid,
-    ) {
+    ) -> crate::Result<()> {
         let key = SnapshotKey {
             database_id,
             start_frame_no: start_frame_no.into(),
@@ -76,13 +78,19 @@ impl SnapshotStore {
 
         let data = SnapshotMeta { snapshot_id };
 
-        block_in_place(|| self.database.put(txn, &key, &data).unwrap());
+        block_in_place(|| self.database.put(txn, &key, &data))?;
+
+        Ok(())
     }
 
     /// Locate a snapshot for `database_id` that contains `frame_no`
-    pub fn locate(&self, database_id: DatabaseId, frame_no: FrameNo) -> Option<SnapshotMeta> {
-        let txn = self.env.read_txn().unwrap();
-        // Snapshot keys being lexicographically ordered, looking for the first key less than of
+    pub fn locate(
+        &self,
+        database_id: DatabaseId,
+        frame_no: FrameNo,
+    ) -> crate::Result<Option<SnapshotMeta>> {
+        let txn = self.env.read_txn()?;
+        // Snapshot keys are lexicographically ordered, so looking for the first key less than or
         // equal to (db_id, frame_no, FrameNo::MAX) will always return the entry we're looking for
         // if it exists.
         let key = SnapshotKey {
@@ -91,14 +99,10 @@ impl SnapshotStore {
             end_frame_no: u64::MAX.into(),
         };
 
-        match self
-            .database
-            .get_lower_than_or_equal_to(&txn, &key)
-            .transpose()?
-        {
-            Ok((key, v)) => {
+        match self.database.get_lower_than_or_equal_to(&txn, &key)? {
+            Some((key, v)) => {
                 if key.database_id != database_id {
-                    return None;
+                    return Ok(None);
                 } else if frame_no >= key.start_frame_no.into()
                     && frame_no <= key.end_frame_no.into()
                 {
@@ -107,22 +111,26 @@ impl SnapshotStore {
                         u64::from(key.start_frame_no),
                         u64::from(key.end_frame_no)
                     );
-                    return Some(v);
+                    return Ok(Some(v));
                 } else {
-                    None
+                    Ok(None)
                 }
             }
-            Err(_) => todo!(),
+            None => Ok(None),
         }
     }
 
-    pub fn locate_file(&self, database_id: DatabaseId, frame_no: FrameNo) -> Option<SnapshotFile> {
-        let meta = self.locate(database_id, frame_no)?;
+    pub fn locate_file(
+        &self,
+        database_id: DatabaseId,
+        frame_no: FrameNo,
+    ) -> crate::Result<Option<SnapshotFile>> {
+        let Some(meta) = self.locate(database_id, frame_no)? else { return Ok(None) };
         let path = self
             .db_path
             .join("snapshots")
             .join(meta.snapshot_id.to_string());
 
-        Some(SnapshotFile::open(&path).unwrap())
+        Ok(Some(SnapshotFile::open(&path)?))
     }
 }
diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs
index e767073a..9c3fcdb0 100644
--- a/libsqlx/src/connection.rs
+++ b/libsqlx/src/connection.rs
@@ -24,11 +24,7 @@ pub struct DescribeCol {
 
 pub trait Connection {
     /// Executes a query program
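+    ///
+    /// Execution errors are reported to `result_builder` (see `finnalize_error`
+    /// in the `LibsqlConnection` impl below) rather than returned, which is why
+    /// this method no longer returns a `Result`.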
-    fn execute_program(
-        &mut self,
-        pgm: &Program,
-        result_builder: Box<dyn ResultBuilder>,
-    ) -> crate::Result<()>;
+    fn execute_program(&mut self, pgm: &Program, result_builder: Box<dyn ResultBuilder>);
 
     /// Parse the SQL statement and return information about it.
     fn describe(&self, sql: String) -> crate::Result<DescribeResponse>;
 
@@ -39,11 +35,7 @@ where
     T: Connection,
     X: Connection,
 {
-    fn execute_program(
-        &mut self,
-        pgm: &Program,
-        result_builder: Box<dyn ResultBuilder>,
-    ) -> crate::Result<()> {
+    fn execute_program(&mut self, pgm: &Program, result_builder: Box<dyn ResultBuilder>) {
         match self {
             Either::Left(c) => c.execute_program(pgm, result_builder),
             Either::Right(c) => c.execute_program(pgm, result_builder),
diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs
index 0ad8b780..9579dc94 100644
--- a/libsqlx/src/database/libsql/connection.rs
+++ b/libsqlx/src/database/libsql/connection.rs
@@ -249,12 +249,10 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result<bool> {
 }
 
 impl Connection for LibsqlConnection {
-    fn execute_program(
-        &mut self,
-        pgm: &Program,
-        mut builder: Box<dyn ResultBuilder>,
-    ) -> crate::Result<()> {
-        self.run(pgm, &mut *builder)
+    fn execute_program(&mut self, pgm: &Program, mut builder: Box<dyn ResultBuilder>) {
+        if let Err(e) = self.run(pgm, &mut *builder) {
+            builder.finnalize_error(e.to_string());
+        }
     }
 
     fn describe(&self, sql: String) -> crate::Result<DescribeResponse> {
diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs
index cbc9dc80..5318b997 100644
--- a/libsqlx/src/database/libsql/injector/mod.rs
+++ b/libsqlx/src/database/libsql/injector/mod.rs
@@ -18,7 +18,7 @@ mod headers;
 mod hook;
 
 pub type FrameBuffer = Arc<Mutex<VecDeque<Frame>>>;
-pub type OnCommitCb = Arc<dyn Fn(FrameNo) + Send + Sync + 'static>;
+pub type OnCommitCb = Arc<dyn Fn(FrameNo) -> bool + Send + Sync + 'static>;
 
 pub struct Injector {
     /// The injector is in a transaction state
@@ -85,7 +85,7 @@ impl Injector {
         self.buffer.lock().push_back(frame);
         if frame_close_txn || self.buffer.lock().len() >= self.capacity {
             if !self.is_txn {
-                self.begin_txn();
+                self.begin_txn()?;
             }
             return self.flush();
         }
@@ -135,14 +135,14 @@ impl Injector {
 
     fn commit(&mut self) {
         // TODO: error?
         let _ = self.connection.execute("COMMIT", ());
     }
 
-    fn begin_txn(&mut self) {
-        self.connection.execute("BEGIN IMMEDIATE", ()).unwrap();
+    fn begin_txn(&mut self) -> crate::Result<()> {
+        self.connection.execute("BEGIN IMMEDIATE", ())?;
         self.connection
-            .execute("CREATE TABLE __DUMMY__ (__dummy__)", ())
-            .unwrap();
+            .execute("CREATE TABLE __DUMMY__ (__dummy__)", ())?;
+        Ok(())
     }
 }
diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs
index a7638e56..4e433d7b 100644
--- a/libsqlx/src/database/proxy/connection.rs
+++ b/libsqlx/src/database/proxy/connection.rs
@@ -144,9 +144,7 @@ where
             // set the connection state to unknown before executing on the remote
             self.state.lock().state = State::Unknown;
 
-            self.conn
-                .execute_program(&self.pgm, Box::new(builder))
-                .unwrap();
+            self.conn.execute_program(&self.pgm, Box::new(builder));
 
             Ok(false)
         } else {
@@ -164,11 +162,7 @@ where
     R: Connection,
     W: Connection + Clone + Send + 'static,
 {
-    fn execute_program(
-        &mut self,
-        pgm: &Program,
-        builder: Box<dyn ResultBuilder>,
-    ) -> crate::Result<()> {
+    fn execute_program(&mut self, pgm: &Program, builder: Box<dyn ResultBuilder>) {
         if self.state.lock().state.is_idle() && pgm.is_read_only() {
             if let Some(frame_no) = self.state.lock().last_frame_no {
                 (self.wait_frame_no_cb)(frame_no);
@@ -183,9 +177,8 @@ where
             // We know that this program won't perform any writes. We attempt to run it on the
             // replica. If it leaves an open transaction, then this program is an interactive
             // transaction, so we rollback the replica, and execute again on the primary.
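+            // (e.g. an illustrative `BEGIN; SELECT 1;` is read-only yet leaves a
+            // transaction open on the replica, so it takes the rollback-and-retry
+            // path and is replayed on the primary through `write_conn`.)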
-            self.read_conn.execute_program(pgm, Box::new(builder))?;
+            self.read_conn.execute_program(pgm, Box::new(builder));
             // rollback(&mut self.conn.read_db);
-            Ok(())
         } else {
             // we set the state to unknown until we receive the actual connection
             // state from the primary.
             let builder = ExtractFrameNoBuilder {
@@ -194,8 +187,7 @@ where
                 builder,
                 state: self.state.clone(),
             };
-            self.write_conn.execute_program(pgm, Box::new(builder))?;
-            Ok(())
+            self.write_conn.execute_program(pgm, Box::new(builder));
         }
     }
 
From 137051c68c9f315ce59593b501579bbf24a2af10 Mon Sep 17 00:00:00 2001
From: ad hoc
Date: Mon, 31 Jul 2023 16:28:31 +0200
Subject: [PATCH 64/64] allocation response payload

---
 Cargo.lock                          |  1 +
 libsqlx-server/Cargo.toml           |  1 +
 libsqlx-server/src/hrana/error.rs   |  6 +--
 libsqlx-server/src/http/admin.rs    | 75 ++++++++++++++++++++++++-----
 libsqlx-server/src/http/user/mod.rs | 12 ++---
 libsqlx-server/src/manager.rs       | 19 ++++----
 libsqlx-server/src/meta.rs          | 26 +++++++---
 7 files changed, 102 insertions(+), 38 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 53a34d6f..ed3c9c27 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2589,6 +2589,7 @@ dependencies = [
  "bytemuck",
  "bytes 1.4.0",
  "bytesize",
+ "chrono",
  "clap",
  "color-eyre",
  "either",
diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml
index 1ce8c6c8..243cb0ea 100644
--- a/libsqlx-server/Cargo.toml
+++ b/libsqlx-server/Cargo.toml
@@ -14,6 +14,7 @@ bincode = "1.3.3"
 bytemuck = { version = "1.13.1", features = ["derive"] }
 bytes = { version = "1.4.0", features = ["serde"] }
 bytesize = { version = "1.2.0", features = ["serde"] }
+chrono = { version = "0.4.26", features = ["serde"] }
 clap = { version = "4.3.11", features = ["derive"] }
 color-eyre = "0.6.2"
 either = "1.8.1"
diff --git a/libsqlx-server/src/hrana/error.rs b/libsqlx-server/src/hrana/error.rs
index 2324887a..8f8711a1 100644
--- a/libsqlx-server/src/hrana/error.rs
+++ b/libsqlx-server/src/hrana/error.rs
@@ -18,13 +18,11 @@ pub enum HranaError {
 }
 
 impl HranaError {
-    pub fn code(&self) -> Option<&str>{
+    pub fn code(&self) -> Option<&str> {
         match self {
             HranaError::Stmt(e) => Some(e.code()),
             HranaError::StreamResponse(e) => Some(e.code()),
-            HranaError::Stream(_)
-            | HranaError::Libsqlx(_)
-            | HranaError::Proto(_) => None,
+            HranaError::Stream(_) | HranaError::Libsqlx(_) | HranaError::Proto(_) => None,
         }
     }
 }
diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs
index 2bac534c..e90ecfe2 100644
--- a/libsqlx-server/src/http/admin.rs
+++ b/libsqlx-server/src/http/admin.rs
@@ -4,10 +4,13 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use axum::extract::{Path, State};
+use axum::response::IntoResponse;
 use axum::routing::{delete, post};
 use axum::{Json, Router};
+use chrono::{DateTime, Utc};
 use color_eyre::eyre::Result;
 use hyper::server::accept::Accept;
+use hyper::StatusCode;
 use serde::{Deserialize, Deserializer, Serialize};
 use tokio::io::{AsyncRead, AsyncWrite};
 
@@ -15,7 +18,35 @@ use crate::allocation::config::{AllocConfig, DbConfig};
 use crate::linc::bus::Bus;
 use crate::linc::NodeId;
 use crate::manager::Manager;
-use crate::meta::DatabaseId;
+use crate::meta::{AllocationError, DatabaseId};
+
+impl IntoResponse for crate::error::Error {
+    fn into_response(self) -> axum::response::Response {
+        #[derive(Serialize)]
+        struct ErrorBody {
+            message: String,
+        }
+
+        let mut resp = Json(ErrorBody {
+            message: self.to_string(),
+        })
+        .into_response();
+        *resp.status_mut() = match self {
+            crate::error::Error::Libsqlx(_)
+            | crate::error::Error::InjectorExited
+            | crate::error::Error::ConnectionClosed
+            | crate::error::Error::Io(_)
+            | crate::error::Error::AllocationClosed
+            | crate::error::Error::Internal(_)
+            | crate::error::Error::Heed(_) => StatusCode::INTERNAL_SERVER_ERROR,
+            crate::error::Error::Allocation(AllocationError::AlreadyExist(_)) => {
+                StatusCode::BAD_REQUEST
+            }
+        };
+
+        resp
+    }
+}
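+
+// e.g. allocating a name that already exists produces a 400 response with a body
+// like `{"message":"an allocation already exists for foo"}` (illustrative values).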
 
 pub struct Config {
     pub bus: Arc<Bus<Arc<Manager>>>,
@@ -47,7 +78,19 @@ where
 struct ErrorResponse {}
 
 #[derive(Serialize, Debug)]
-struct AllocateResp {}
+#[serde(rename_all = "lowercase")]
+enum DbType {
+    Primary,
+    Replica,
+}
+
+#[derive(Serialize, Debug)]
+struct AllocationSummaryView {
+    created_at: DateTime<Utc>,
+    database_name: String,
+    #[serde(rename = "type")]
+    ty: DbType,
+}
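+
+// An `AllocationSummaryView` serializes as, e.g. (illustrative values):
+// `{"created_at":"2023-07-31T16:28:31Z","database_name":"foo","type":"primary"}`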
 
 #[derive(Deserialize, Debug)]
 struct AllocateReq {
@@ -134,7 +177,7 @@ const fn default_txn_timeout() -> HumanDuration {
 async fn allocate(
     State(state): State<Arc<Config>>,
     Json(req): Json<AllocateReq>,
-) -> Result<Json<AllocateResp>, Json<ErrorResponse>> {
+) -> crate::Result<Json<AllocationSummaryView>> {
     let config = AllocConfig {
         max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16),
         db_name: req.database_name.clone(),
@@ -164,19 +207,26 @@ async fn allocate(
     let dispatcher = state.bus.clone();
     let id = DatabaseId::from_name(&req.database_name);
 
-    state.bus.handler().allocate(id, &config, dispatcher).await;
+    let meta = state.bus.handler().allocate(id, config, dispatcher).await?;
 
-    Ok(Json(AllocateResp {}))
+    Ok(Json(AllocationSummaryView {
+        created_at: meta.created_at,
+        database_name: meta.config.db_name,
+        ty: match meta.config.db_config {
+            DbConfig::Primary { .. } => DbType::Primary,
+            DbConfig::Replica { .. } => DbType::Replica,
+        },
+    }))
 }
 
 async fn deallocate(
     State(state): State<Arc<Config>>,
     Path(database_name): Path<String>,
-) -> Result<Json<AllocateResp>, Json<ErrorResponse>> {
+) -> crate::Result<()> {
     let id = DatabaseId::from_name(&database_name);
-    state.bus.handler().deallocate(id).await;
+    state.bus.handler().deallocate(id).await?;
 
-    Ok(Json(AllocateResp {}))
+    Ok(())
 }
 
 #[derive(Serialize, Debug)]
@@ -191,15 +241,16 @@ struct AllocView {
 
 async fn list_allocs(
     State(state): State<Arc<Config>>,
-) -> Result<Json<ListAllocResp>, Json<ErrorResponse>> {
+) -> crate::Result<Json<ListAllocResp>> {
     let allocs = state
         .bus
         .handler()
         .store()
-        .list_allocs()
-        .unwrap()
+        .list_allocs()?
         .into_iter()
-        .map(|cfg| AllocView { id: cfg.db_name })
+        .map(|meta| AllocView {
+            id: meta.config.db_name,
+        })
         .collect();
 
     Ok(Json(ListAllocResp { allocs }))
diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs
index 3653377b..7b43d36d 100644
--- a/libsqlx-server/src/http/user/mod.rs
+++ b/libsqlx-server/src/http/user/mod.rs
@@ -5,8 +5,8 @@ use axum::response::IntoResponse;
 use axum::routing::post;
 use axum::{Json, Router};
 use color_eyre::Result;
-use hyper::StatusCode;
 use hyper::server::accept::Accept;
+use hyper::StatusCode;
 use serde::Serialize;
 use tokio::io::{AsyncRead, AsyncWrite};
 
@@ -30,12 +30,12 @@ impl IntoResponse for HranaError {
     fn into_response(self) -> axum::response::Response {
         let (message, code) = match self.code() {
             Some(code) => (self.to_string(), code.to_owned()),
-            None => ("internal error, please check the logs".to_owned(), "INTERNAL_ERROR".to_owned()),
-        };
-        let resp = ErrorResponseBody {
-            message,
-            code,
+            None => (
+                "internal error, please check the logs".to_owned(),
+                "INTERNAL_ERROR".to_owned(),
+            ),
         };
+        let resp = ErrorResponseBody { message, code };
         let mut resp = Json(resp).into_response();
         *resp.status_mut() = StatusCode::BAD_REQUEST;
         resp
diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs
index 686e6fca..2f0cafa1 100644
--- a/libsqlx-server/src/manager.rs
+++ b/libsqlx-server/src/manager.rs
@@ -12,7 +12,7 @@ use crate::compactor::CompactionQueue;
 use crate::linc::bus::Dispatch;
 use crate::linc::handler::Handler;
 use crate::linc::Inbound;
-use crate::meta::{DatabaseId, Store};
+use crate::meta::{AllocMeta, DatabaseId, Store};
 use crate::replica_commit_store::ReplicaCommitStore;
 
 pub struct Manager {
@@ -52,14 +52,14 @@ impl Manager {
             return Ok(Some(sender.clone()));
         }
 
-        if let Some(config) = self.meta_store.meta(&database_id)? {
+        if let Some(meta) = self.meta_store.meta(&database_id)? {
             let path = self.db_path.join("dbs").join(database_id.to_string());
             tokio::fs::create_dir_all(&path).await?;
             let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN);
             let alloc = Allocation {
                 inbox,
                 database: Database::from_config(
-                    &config,
+                    &meta.config,
                     path,
                     dispatcher.clone(),
                     self.compaction_queue.clone(),
@@ -67,9 +67,9 @@ impl Manager {
                 )?,
                 connections_futs: JoinSet::new(),
                 next_conn_id: 0,
-                max_concurrent_connections: config.max_conccurent_connection,
+                max_concurrent_connections: meta.config.max_conccurent_connection,
                 dispatcher,
-                db_name: config.db_name,
+                db_name: meta.config.db_name,
                 connections: HashMap::new(),
             };
 
@@ -86,12 +86,13 @@ impl Manager {
     pub async fn allocate(
         self: &Arc<Self>,
         database_id: DatabaseId,
-        meta: &AllocConfig,
+        config: AllocConfig,
         dispatcher: Arc<dyn Dispatch>,
-    ) -> crate::Result<()> {
-        self.store().allocate(&database_id, meta)?;
+    ) -> crate::Result<AllocMeta> {
+        let meta = self.store().allocate(&database_id, config)?;
         self.schedule(database_id, dispatcher).await?;
-        Ok(())
+
+        Ok(meta)
     }
 
     pub async fn deallocate(&self, database_id: DatabaseId) -> crate::Result<()> {
diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs
index baefaeaf..a2a3ac87 100644
--- a/libsqlx-server/src/meta.rs
+++ b/libsqlx-server/src/meta.rs
@@ -1,6 +1,7 @@
 use std::fmt;
 use std::mem::size_of;
 
+use chrono::{DateTime, Utc};
 use heed::bytemuck::{Pod, Zeroable};
 use heed_types::{OwnedType, SerdeBincode};
 use itertools::Itertools;
@@ -11,9 +12,15 @@ use tokio::task::block_in_place;
 
 use crate::allocation::config::AllocConfig;
 
+#[derive(Debug, Serialize, Deserialize)]
+pub struct AllocMeta {
+    pub config: AllocConfig,
+    pub created_at: DateTime<Utc>,
+}
+
 pub struct Store {
     env: heed::Env,
-    alloc_config_db: heed::Database<OwnedType<DatabaseId>, SerdeBincode<AllocConfig>>,
+    alloc_config_db: heed::Database<OwnedType<DatabaseId>, SerdeBincode<AllocMeta>>,
 }
 
 #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Hash, Clone, Copy, Pod, Zeroable)]
@@ -73,7 +80,7 @@ impl Store {
         })
     }
 
-    pub fn allocate(&self, id: &DatabaseId, meta: &AllocConfig) -> crate::Result<()> {
+    pub fn allocate(&self, id: &DatabaseId, config: AllocConfig) -> crate::Result<AllocMeta> {
         block_in_place(|| {
             let mut txn = self.env.write_txn()?;
             if self
@@ -82,14 +89,19 @@ impl Store {
                 .get(&txn, id)?
                 .is_some()
             {
-                Err(AllocationError::AlreadyExist(meta.db_name.clone()))?;
+                Err(AllocationError::AlreadyExist(config.db_name.clone()))?;
             };
 
-            self.alloc_config_db.put(&mut txn, id, meta)?;
+            let meta = AllocMeta {
+                config,
+                created_at: Utc::now(),
+            };
+
+            self.alloc_config_db.put(&mut txn, id, &meta)?;
 
             txn.commit()?;
 
-            Ok(())
+            Ok(meta)
         })
     }
 
@@ -103,14 +115,14 @@ impl Store {
         })
     }
 
-    pub fn meta(&self, id: &DatabaseId) -> crate::Result<Option<AllocConfig>> {
+    pub fn meta(&self, id: &DatabaseId) -> crate::Result<Option<AllocMeta>> {
         block_in_place(|| {
             let txn = self.env.read_txn()?;
             Ok(self.alloc_config_db.get(&txn, id)?)
         })
     }
 
-    pub fn list_allocs(&self) -> crate::Result<Vec<AllocConfig>> {
+    pub fn list_allocs(&self) -> crate::Result<Vec<AllocMeta>> {
         block_in_place(|| {
             let txn = self.env.read_txn()?;
             let res = self