Skip to content
This repository was archived by the owner on Jan 6, 2025. It is now read-only.

Commit 261beba

Browse files
extract constants, minor cleanup
1 parent a904fe6 commit 261beba

File tree

3 files changed

+130
-77
lines changed

3 files changed

+130
-77
lines changed

src/transport/message_handler.rs

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,36 @@ use crate::transport::msgs::{LSPSMessage, Prefix, RawLSPSMessage, LSPS_MESSAGE_T
33
use bitcoin::secp256k1::PublicKey;
44
use lightning::ln::peer_handler::CustomMessageHandler;
55
use lightning::ln::wire::CustomMessageReader;
6+
use lightning::log_info;
7+
use lightning::util::logger::Logger;
68
use std::collections::HashMap;
79
use std::convert::{TryFrom, TryInto};
810
use std::io;
11+
use std::ops::Deref;
912
use std::sync::{Arc, Mutex};
1013

14+
/// A trait used to implement a specific LSPS protocol
15+
/// The messages the protocol uses need to be able to be mapped
16+
/// from and into LSPSMessages.
1117
pub trait ProtocolMessageHandler {
1218
type ProtocolMessage: TryFrom<LSPSMessage> + Into<LSPSMessage>;
19+
const PROTOCOL_NUMBER: Option<u16>;
1320

1421
fn handle_message(
1522
&self, message: Self::ProtocolMessage, counterparty_node_id: &PublicKey,
1623
) -> Result<(), lightning::ln::msgs::LightningError>;
1724
fn get_and_clear_pending_protocol_messages(&self) -> Vec<(PublicKey, Self::ProtocolMessage)>;
1825
fn get_and_clear_pending_protocol_events(&self) -> Vec<Event>;
19-
fn get_protocol_number(&self) -> Option<u16>;
26+
fn get_protocol_number(&self) -> Option<u16> {
27+
Self::PROTOCOL_NUMBER
28+
}
2029
}
2130

22-
pub trait MessageHandler {
31+
/// A trait used to implement the mapping from a LSPS transport layer mesage
32+
/// to a specific protocol message. This enables the ProtocolMessageHandler's
33+
/// to not need to know about LSPSMessage and only have to deal with the specific
34+
/// messages related to the protocol that is being implemented.
35+
pub trait TransportMessageHandler {
2336
fn handle_lsps_message(
2437
&self, message: LSPSMessage, counterparty_node_id: &PublicKey,
2538
) -> Result<(), lightning::ln::msgs::LightningError>;
@@ -28,7 +41,7 @@ pub trait MessageHandler {
2841
fn get_protocol_number(&self) -> Option<u16>;
2942
}
3043

31-
impl<T> MessageHandler for T
44+
impl<T> TransportMessageHandler for T
3245
where
3346
T: ProtocolMessageHandler,
3447
LSPSMessage: TryInto<<T as ProtocolMessageHandler>::ProtocolMessage>,
@@ -59,27 +72,37 @@ where
5972
}
6073
}
6174

62-
pub struct LSPManager {
75+
pub struct LSPManager<L: Deref>
76+
where
77+
L::Target: Logger,
78+
{
79+
logger: L,
6380
pending_messages: Mutex<Vec<(PublicKey, RawLSPSMessage)>>,
6481
request_id_to_method_map: Mutex<HashMap<String, String>>,
65-
message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn MessageHandler>>>>,
82+
message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn TransportMessageHandler>>>>,
6683
}
6784

68-
impl LSPManager {
69-
pub fn new() -> Self {
85+
impl<L: Deref> LSPManager<L>
86+
where
87+
L::Target: Logger,
88+
{
89+
pub fn new(logger: L) -> Self {
7090
Self {
91+
logger,
7192
pending_messages: Mutex::new(Vec::new()),
7293
request_id_to_method_map: Mutex::new(HashMap::new()),
7394
message_handlers: Arc::new(Mutex::new(HashMap::new())),
7495
}
7596
}
7697

77-
pub fn get_message_handlers(&self) -> Arc<Mutex<HashMap<Prefix, Arc<dyn MessageHandler>>>> {
98+
pub fn get_message_handlers(
99+
&self,
100+
) -> Arc<Mutex<HashMap<Prefix, Arc<dyn TransportMessageHandler>>>> {
78101
self.message_handlers.clone()
79102
}
80103

81104
pub fn register_message_handler(
82-
&self, prefix: Prefix, message_handler: Arc<dyn MessageHandler>,
105+
&self, prefix: Prefix, message_handler: Arc<dyn TransportMessageHandler>,
83106
) {
84107
self.message_handlers.lock().unwrap().insert(prefix, message_handler);
85108
}
@@ -103,6 +126,13 @@ impl LSPManager {
103126
// TODO: not sure what we are supposed to do when we receive a message we don't have a handler for
104127
if let Some(message_handler) = message_handlers.get(&prefix) {
105128
message_handler.handle_lsps_message(msg, sender_node_id)?;
129+
} else {
130+
log_info!(
131+
self.logger,
132+
"Received a message from {:?} we do not have a handler for: {:?}",
133+
sender_node_id,
134+
msg
135+
);
106136
}
107137
}
108138
Ok(())
@@ -114,7 +144,10 @@ impl LSPManager {
114144
}
115145
}
116146

117-
impl CustomMessageReader for LSPManager {
147+
impl<L: Deref> CustomMessageReader for LSPManager<L>
148+
where
149+
L::Target: Logger,
150+
{
118151
type CustomMessage = RawLSPSMessage;
119152

120153
fn read<R: io::Read>(
@@ -131,7 +164,10 @@ impl CustomMessageReader for LSPManager {
131164
}
132165
}
133166

134-
impl CustomMessageHandler for LSPManager {
167+
impl<L: Deref> CustomMessageHandler for LSPManager<L>
168+
where
169+
L::Target: Logger,
170+
{
135171
fn handle_custom_message(
136172
&self, msg: Self::CustomMessage, sender_node_id: &PublicKey,
137173
) -> Result<(), lightning::ln::msgs::LightningError> {

src/transport/msgs.rs

Lines changed: 78 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ use std::collections::HashMap;
99
use std::convert::TryFrom;
1010
use std::fmt;
1111

12+
const LSPS_MESSAGE_SERIALIZED_STRUCT_NAME: &str = "LSPSMessage";
13+
const JSONRPC_FIELD_KEY: &str = "jsonrpc";
14+
const JSONRPC_FIELD_VALUE: &str = "2.0";
15+
const JSONRPC_METHOD_FIELD_KEY: &str = "method";
16+
const JSONRPC_ID_FIELD_KEY: &str = "id";
17+
const JSONRPC_PARAMS_FIELD_KEY: &str = "params";
18+
const JSONRPC_RESULT_FIELD_KEY: &str = "result";
19+
const JSONRPC_ERROR_FIELD_KEY: &str = "error";
20+
const JSONRPC_INVALID_MESSAGE_ERROR_CODE: i32 = -32700;
21+
const JSONRPC_INVALID_MESSAGE_ERROR_MESSAGE: &str = "parse error";
22+
const LSPS0_LISTPROTOCOLS_METHOD_NAME: &str = "lsps0.listprotocols";
23+
1224
pub const LSPS_MESSAGE_TYPE: u16 = 37913;
1325

1426
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
@@ -118,39 +130,43 @@ impl Serialize for LSPSMessage {
118130
where
119131
S: serde::Serializer,
120132
{
121-
let mut jsonrpc_object = serializer.serialize_struct("LSPSMessage", 3)?;
133+
let mut jsonrpc_object =
134+
serializer.serialize_struct(LSPS_MESSAGE_SERIALIZED_STRUCT_NAME, 3)?;
122135

123-
jsonrpc_object.serialize_field("jsonrpc", "2.0")?;
136+
jsonrpc_object.serialize_field(JSONRPC_FIELD_KEY, JSONRPC_FIELD_VALUE)?;
124137

125138
match self {
126139
LSPSMessage::LSPS0(LSPS0Message::Request(request_id, request)) => {
127-
jsonrpc_object.serialize_field("method", request.method())?;
128-
jsonrpc_object.serialize_field("id", &request_id.0)?;
140+
jsonrpc_object.serialize_field(JSONRPC_METHOD_FIELD_KEY, request.method())?;
141+
jsonrpc_object.serialize_field(JSONRPC_ID_FIELD_KEY, &request_id.0)?;
129142

130143
match request {
131144
LSPS0Request::ListProtocols(params) => {
132-
jsonrpc_object.serialize_field("params", params)?
145+
jsonrpc_object.serialize_field(JSONRPC_PARAMS_FIELD_KEY, params)?
133146
}
134147
};
135148
}
136149
LSPSMessage::LSPS0(LSPS0Message::Response(request_id, response)) => {
137-
jsonrpc_object.serialize_field("id", &request_id.0)?;
150+
jsonrpc_object.serialize_field(JSONRPC_ID_FIELD_KEY, &request_id.0)?;
138151

139152
match response {
140153
LSPS0Response::ListProtocols(result) => {
141-
jsonrpc_object.serialize_field("result", result)?;
154+
jsonrpc_object.serialize_field(JSONRPC_RESULT_FIELD_KEY, result)?;
142155
}
143156
LSPS0Response::ListProtocolsError(error) => {
144-
jsonrpc_object.serialize_field("error", error)?;
157+
jsonrpc_object.serialize_field(JSONRPC_ERROR_FIELD_KEY, error)?;
145158
}
146159
}
147160
}
148161
LSPSMessage::Invalid => {
149-
let error =
150-
ResponseError { code: -32700, message: "parse error".to_string(), data: None };
162+
let error = ResponseError {
163+
code: JSONRPC_INVALID_MESSAGE_ERROR_CODE,
164+
message: JSONRPC_INVALID_MESSAGE_ERROR_MESSAGE.to_string(),
165+
data: None,
166+
};
151167

152-
jsonrpc_object.serialize_field("id", &serde_json::Value::Null)?;
153-
jsonrpc_object.serialize_field("error", &error)?;
168+
jsonrpc_object.serialize_field(JSONRPC_ID_FIELD_KEY, &serde_json::Value::Null)?;
169+
jsonrpc_object.serialize_field(JSONRPC_ERROR_FIELD_KEY, &error)?;
154170
}
155171
}
156172

@@ -202,59 +218,63 @@ impl<'de, 'a> Visitor<'de> for LSPSMessageVisitor<'a> {
202218
}
203219
}
204220

205-
if let Some(method) = method {
206-
if let Some(id) = id {
207-
match method {
208-
"lsps0.listprotocols" => {
209-
let list_protocols_request =
210-
serde_json::from_value(params.unwrap_or(json!({})))
211-
.map_err(de::Error::custom)?;
221+
match (id, method) {
222+
(Some(id), Some(method)) => match method {
223+
LSPS0_LISTPROTOCOLS_METHOD_NAME => {
224+
let list_protocols_request =
225+
serde_json::from_value(params.unwrap_or(json!({})))
226+
.map_err(de::Error::custom)?;
212227

213-
self.request_id_to_method.insert(id.clone(), method.to_string());
228+
self.request_id_to_method.insert(id.clone(), method.to_string());
214229

215-
Ok(LSPSMessage::LSPS0(LSPS0Message::Request(
216-
RequestId(id),
217-
LSPS0Request::ListProtocols(list_protocols_request),
218-
)))
219-
}
220-
_ => Err(de::Error::custom(format!(
221-
"Received request with unknown method: {}",
222-
method
223-
))),
230+
Ok(LSPSMessage::LSPS0(LSPS0Message::Request(
231+
RequestId(id),
232+
LSPS0Request::ListProtocols(list_protocols_request),
233+
)))
224234
}
225-
} else {
226-
Err(de::Error::custom(format!("Received unknown notification: {}", method)))
227-
}
228-
} else if let Some(id) = id {
229-
if let Some(method) = self.request_id_to_method.get(&id) {
230-
match method.as_str() {
231-
"lsps0.listprotocols" => {
232-
if let Some(error) = error {
233-
Ok(LSPSMessage::LSPS0(LSPS0Message::Response(
234-
RequestId(id),
235-
LSPS0Response::ListProtocolsError(error),
236-
)))
237-
} else if let Some(result) = result {
238-
let list_protocols_response =
239-
serde_json::from_value(result).map_err(de::Error::custom)?;
240-
Ok(LSPSMessage::LSPS0(LSPS0Message::Response(
241-
RequestId(id),
242-
LSPS0Response::ListProtocols(list_protocols_response),
243-
)))
244-
} else {
245-
Err(de::Error::custom("Received invalid JSON-RPC object: one of method, result, or error required"))
235+
_ => Err(de::Error::custom(format!(
236+
"Received request with unknown method: {}",
237+
method
238+
))),
239+
},
240+
(Some(id), None) => {
241+
if let Some(method) = self.request_id_to_method.get(&id) {
242+
match method.as_str() {
243+
LSPS0_LISTPROTOCOLS_METHOD_NAME => {
244+
if let Some(error) = error {
245+
Ok(LSPSMessage::LSPS0(LSPS0Message::Response(
246+
RequestId(id),
247+
LSPS0Response::ListProtocolsError(error),
248+
)))
249+
} else if let Some(result) = result {
250+
let list_protocols_response =
251+
serde_json::from_value(result).map_err(de::Error::custom)?;
252+
Ok(LSPSMessage::LSPS0(LSPS0Message::Response(
253+
RequestId(id),
254+
LSPS0Response::ListProtocols(list_protocols_response),
255+
)))
256+
} else {
257+
Err(de::Error::custom("Received invalid JSON-RPC object: one of method, result, or error required"))
258+
}
246259
}
260+
_ => Err(de::Error::custom(format!(
261+
"Received response for an unknown request method: {}",
262+
method
263+
))),
247264
}
248-
_ => Err(de::Error::custom(format!(
249-
"Received response for an unknown request method: {}",
250-
method
251-
))),
265+
} else {
266+
Err(de::Error::custom(format!(
267+
"Received response for unknown request id: {}",
268+
id
269+
)))
252270
}
253-
} else {
254-
Err(de::Error::custom(format!("Received response for unknown request id: {}", id)))
255271
}
256-
} else {
257-
Err(de::Error::custom("Received invalid JSON-RPC object: one of method or id required"))
272+
(None, Some(method)) => {
273+
Err(de::Error::custom(format!("Received unknown notification: {}", method)))
274+
}
275+
(None, None) => Err(de::Error::custom(
276+
"Received invalid JSON-RPC object: one of method or id required",
277+
)),
258278
}
259279
}
260280
}

src/transport/protocol.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::events::Event;
1010
use crate::transport::message_handler::ProtocolMessageHandler;
1111
use crate::utils;
1212

13-
use super::message_handler::MessageHandler;
13+
use super::message_handler::TransportMessageHandler;
1414
use super::msgs::{
1515
LSPS0Message, LSPS0Request, LSPS0Response, ListProtocolsRequest, ListProtocolsResponse, Prefix,
1616
RequestId, ResponseError,
@@ -22,7 +22,7 @@ where
2222
ES::Target: EntropySource,
2323
{
2424
logger: L,
25-
message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn MessageHandler>>>>,
25+
message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn TransportMessageHandler>>>>,
2626
pending_messages: Mutex<Vec<(PublicKey, LSPS0Message)>>,
2727
entropy_source: ES,
2828
pending_events: Mutex<Vec<Event>>,
@@ -34,7 +34,7 @@ where
3434
ES::Target: EntropySource,
3535
{
3636
pub fn new(
37-
logger: L, message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn MessageHandler>>>>,
37+
logger: L, message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn TransportMessageHandler>>>>,
3838
entropy_source: ES,
3939
) -> Self {
4040
Self {
@@ -128,6 +128,7 @@ where
128128
ES::Target: EntropySource,
129129
{
130130
type ProtocolMessage = LSPS0Message;
131+
const PROTOCOL_NUMBER: Option<u16> = None;
131132

132133
fn handle_message(
133134
&self, message: Self::ProtocolMessage, counterparty_node_id: &PublicKey,
@@ -153,10 +154,6 @@ where
153154
let mut pending_events = self.pending_events.lock().unwrap();
154155
std::mem::take(&mut *pending_events)
155156
}
156-
157-
fn get_protocol_number(&self) -> Option<u16> {
158-
None
159-
}
160157
}
161158

162159
#[cfg(test)]
@@ -180,7 +177,7 @@ mod tests {
180177
fn test_handle_list_protocols_request() {
181178
let logger = Arc::new(TestLogger {});
182179
let entropy = Arc::new(TestEntropy {});
183-
let message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn MessageHandler>>>> =
180+
let message_handlers: Arc<Mutex<HashMap<Prefix, Arc<dyn TransportMessageHandler>>>> =
184181
Arc::new(Mutex::new(HashMap::new()));
185182
let lsps0_handler =
186183
Arc::new(LSPS0MessageHandler::new(logger, message_handlers.clone(), entropy));

0 commit comments

Comments
 (0)