Skip to content

Commit 41b2aed

Browse files
committed
refactor: issue request synchronously for MultiReader::commit and MultiWriter::commit
See #6 and #7.
1 parent b473b5e commit 41b2aed

File tree

4 files changed

+116
-39
lines changed

4 files changed

+116
-39
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ log = "0.4.14"
2727
static_assertions = "1.1.0"
2828
hashbrown = "0.12.0"
2929
hashlink = "0.8.0"
30+
either = "1.9.0"
3031

3132
[dev-dependencies]
3233
rand = "0.8.4"

src/client/mod.rs

+80-38
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::future::Future;
55
use std::time::Duration;
66

77
use const_format::formatcp;
8+
use either::{Either, Left, Right};
89
use thiserror::Error;
910
use tokio::sync::{mpsc, watch};
1011

@@ -279,15 +280,25 @@ impl Client {
279280
receiver
280281
}
281282

282-
async fn wait<T, F>(result: Result<F>) -> Result<T>
283+
async fn wait<T, E, F>(result: std::result::Result<F, E>) -> std::result::Result<T, E>
283284
where
284-
F: Future<Output = Result<T>>, {
285+
F: Future<Output = std::result::Result<T, E>>, {
285286
match result {
286287
Err(err) => Err(err),
287288
Ok(future) => future.await,
288289
}
289290
}
290291

292+
async fn resolve<T, E, F>(result: std::result::Result<Either<F, T>, E>) -> std::result::Result<T, E>
293+
where
294+
F: Future<Output = std::result::Result<T, E>>, {
295+
match result {
296+
Err(err) => Err(err),
297+
Ok(Right(r)) => Ok(r),
298+
Ok(Left(future)) => future.await,
299+
}
300+
}
301+
291302
async fn map_wait<T, U, Fu, Fn>(result: Result<Fu>, f: Fn) -> Result<U>
292303
where
293304
Fu: Future<Output = Result<T>>,
@@ -936,8 +947,11 @@ trait MultiBuffer {
936947
fn op_code() -> OpCode;
937948

938949
fn build_request(&mut self) -> MarshalledRequest {
939-
let header = MultiHeader { op: OpCode::Error, done: true, err: -1 };
940950
let buffer = self.buffer();
951+
if buffer.is_empty() {
952+
return Default::default();
953+
}
954+
let header = MultiHeader { op: OpCode::Error, done: true, err: -1 };
941955
buffer.append_record(&header);
942956
buffer.finish();
943957
MarshalledRequest(std::mem::take(buffer))
@@ -1014,23 +1028,32 @@ impl<'a> MultiReader<'a> {
10141028
///
10151029
/// # Notable behaviors
10161030
/// Individual errors(eg. [Error::NoNode]) are reported individually through [MultiReadResult::Error].
1017-
pub async fn commit(&mut self) -> Result<Vec<MultiReadResult>> {
1018-
if self.buf.is_empty() {
1019-
return Ok(Default::default());
1020-
}
1031+
pub fn commit(&mut self) -> impl Future<Output = Result<Vec<MultiReadResult>>> + Send + '_ {
10211032
let request = self.build_request();
1033+
Client::resolve(self.commit_internally(request))
1034+
}
1035+
1036+
fn commit_internally(
1037+
&self,
1038+
request: MarshalledRequest,
1039+
) -> Result<Either<impl Future<Output = Result<Vec<MultiReadResult>>> + Send + '_, Vec<MultiReadResult>>> {
1040+
if request.is_empty() {
1041+
return Ok(Right(Vec::default()));
1042+
}
10221043
let receiver = self.client.send_marshalled_request(request);
1023-
let (body, _) = receiver.await?;
1024-
let response = record::unmarshal::<Vec<MultiReadResponse>>(&mut body.as_slice())?;
1025-
let mut results = Vec::with_capacity(response.len());
1026-
for result in response {
1027-
match result {
1028-
MultiReadResponse::Data { data, stat } => results.push(MultiReadResult::Data { data, stat }),
1029-
MultiReadResponse::Children { children } => results.push(MultiReadResult::Children { children }),
1030-
MultiReadResponse::Error(err) => results.push(MultiReadResult::Error { err }),
1044+
Ok(Left(async move {
1045+
let (body, _) = receiver.await?;
1046+
let response = record::unmarshal::<Vec<MultiReadResponse>>(&mut body.as_slice())?;
1047+
let mut results = Vec::with_capacity(response.len());
1048+
for result in response {
1049+
match result {
1050+
MultiReadResponse::Data { data, stat } => results.push(MultiReadResult::Data { data, stat }),
1051+
MultiReadResponse::Children { children } => results.push(MultiReadResult::Children { children }),
1052+
MultiReadResponse::Error(err) => results.push(MultiReadResult::Error { err }),
1053+
}
10311054
}
1032-
}
1033-
Ok(results)
1055+
Ok(results)
1056+
}))
10341057
}
10351058

10361059
/// Clears collected operations.
@@ -1184,30 +1207,49 @@ impl<'a> MultiWriter<'a> {
11841207
///
11851208
/// # Notable errors
11861209
/// * [Error::BadVersion] if check version failed.
1187-
pub async fn commit(&mut self) -> std::result::Result<Vec<MultiWriteResult>, MultiWriteError> {
1188-
if self.buf.is_empty() {
1189-
return Ok(Default::default());
1190-
}
1210+
pub fn commit(
1211+
&mut self,
1212+
) -> impl Future<Output = std::result::Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + '_ {
11911213
let request = self.build_request();
1214+
Client::resolve(self.commit_internally(request))
1215+
}
1216+
1217+
fn commit_internally(
1218+
&self,
1219+
request: MarshalledRequest,
1220+
) -> std::result::Result<
1221+
Either<
1222+
impl Future<Output = std::result::Result<Vec<MultiWriteResult>, MultiWriteError>> + Send + '_,
1223+
Vec<MultiWriteResult>,
1224+
>,
1225+
MultiWriteError,
1226+
> {
1227+
if request.is_empty() {
1228+
return Ok(Right(Vec::default()));
1229+
}
11921230
let receiver = self.client.send_marshalled_request(request);
1193-
let (body, _) = receiver.await?;
1194-
let response = record::unmarshal::<Vec<MultiWriteResponse>>(&mut body.as_slice())?;
1195-
let failed = response.first().map(|r| matches!(r, MultiWriteResponse::Error(_))).unwrap_or(false);
1196-
let mut results = if failed { Vec::new() } else { Vec::with_capacity(response.len()) };
1197-
for (index, result) in response.into_iter().enumerate() {
1198-
match result {
1199-
MultiWriteResponse::Check => results.push(MultiWriteResult::Check),
1200-
MultiWriteResponse::Delete => results.push(MultiWriteResult::Delete),
1201-
MultiWriteResponse::Create { path, stat } => {
1202-
util::strip_root_path(path, self.client.chroot.root())?;
1203-
results.push(MultiWriteResult::Create { path: path.to_string(), stat });
1204-
},
1205-
MultiWriteResponse::SetData { stat } => results.push(MultiWriteResult::SetData { stat }),
1206-
MultiWriteResponse::Error(Error::UnexpectedErrorCode(0)) => {},
1207-
MultiWriteResponse::Error(err) => return Err(MultiWriteError::OperationFailed { index, source: err }),
1231+
Ok(Left(async move {
1232+
let (body, _) = receiver.await?;
1233+
let response = record::unmarshal::<Vec<MultiWriteResponse>>(&mut body.as_slice())?;
1234+
let failed = response.first().map(|r| matches!(r, MultiWriteResponse::Error(_))).unwrap_or(false);
1235+
let mut results = if failed { Vec::new() } else { Vec::with_capacity(response.len()) };
1236+
for (index, result) in response.into_iter().enumerate() {
1237+
match result {
1238+
MultiWriteResponse::Check => results.push(MultiWriteResult::Check),
1239+
MultiWriteResponse::Delete => results.push(MultiWriteResult::Delete),
1240+
MultiWriteResponse::Create { path, stat } => {
1241+
util::strip_root_path(path, self.client.chroot.root())?;
1242+
results.push(MultiWriteResult::Create { path: path.to_string(), stat });
1243+
},
1244+
MultiWriteResponse::SetData { stat } => results.push(MultiWriteResult::SetData { stat }),
1245+
MultiWriteResponse::Error(Error::UnexpectedErrorCode(0)) => {},
1246+
MultiWriteResponse::Error(err) => {
1247+
return Err(MultiWriteError::OperationFailed { index, source: err })
1248+
},
1249+
}
12081250
}
1209-
}
1210-
Ok(results)
1251+
Ok(results)
1252+
}))
12111253
}
12121254

12131255
/// Clears collected operations.

src/session/request.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::error::Error;
1212
use crate::proto::{self, AddWatchMode, ConnectRequest, OpCode, RequestHeader};
1313
use crate::record::{self, Record, StaticRecord};
1414

15-
#[derive(Clone, Debug)]
15+
#[derive(Clone, Debug, Default)]
1616
pub struct MarshalledRequest(pub Vec<u8>);
1717

1818
impl MarshalledRequest {
@@ -82,6 +82,10 @@ impl MarshalledRequest {
8282
};
8383
(op_code, watcher_info)
8484
}
85+
86+
pub fn is_empty(&self) -> bool {
87+
self.0.is_empty()
88+
}
8589
}
8690

8791
pub enum Operation {

tests/zookeeper.rs

+30
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,36 @@ async fn test_multi() {
215215
assert_that!(results).is_empty();
216216
}
217217

218+
#[tokio::test]
219+
async fn test_multi_async_order() {
220+
let docker = DockerCli::default();
221+
let zookeeper = docker.run(zookeeper_image());
222+
let zk_port = zookeeper.get_host_port(2181);
223+
224+
let cluster = format!("127.0.0.1:{}", zk_port);
225+
let client = zk::Client::connect(&cluster).await.unwrap();
226+
227+
let create_options = zk::CreateOptions::new(zk::CreateMode::Persistent, zk::Acl::anyone_all());
228+
client.create("/a", "a0".as_bytes(), &create_options).await.unwrap();
229+
230+
let mut writer = client.new_multi_writer();
231+
writer.add_set_data("/a", "a1".as_bytes(), None).unwrap();
232+
let write = writer.commit();
233+
234+
let mut reader = client.new_multi_reader();
235+
reader.add_get_data("/a").unwrap();
236+
let mut results = reader.commit().await.unwrap();
237+
let zk::MultiReadResult::Data { data, stat } = results.remove(0) else { panic!("expect get data result") };
238+
239+
let mut write_results = write.await.unwrap();
240+
let zk::MultiWriteResult::SetData { stat: set_stat } = write_results.remove(0) else {
241+
panic!("expect set data result")
242+
};
243+
244+
assert_that!(data).is_equal_to("a1".as_bytes().to_owned());
245+
assert_that!(stat).is_equal_to(set_stat);
246+
}
247+
218248
#[tokio::test]
219249
async fn test_no_node() {
220250
let docker = DockerCli::default();

0 commit comments

Comments
 (0)