Skip to content
This repository was archived by the owner on Aug 2, 2021. It is now read-only.

Commit b3d0708

Browse files
authored
P2p validate accounting (#2051)
* p2p/protocols: change sequence of message accounting
1 parent 681742a commit b3d0708

File tree

6 files changed

+139
-68
lines changed

6 files changed

+139
-68
lines changed

p2p/protocols/accounting.go

+20-25
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ type Price struct {
7171
// A protocol provides the message price in absolute value
7272
// This method then returns the correct signed amount,
7373
// depending on who pays, which is identified by the `payer` argument:
74-
// `Send` will pass a `Sender` payer, `Receive` will pass the `Receiver` argument.
74+
// Sending will pass a `Sender` payer, receiving will pass the `Receiver` argument.
7575
// Thus: If Sending and sender pays, amount negative, otherwise positive
7676
// If Receiving, and receiver pays, amount negative, otherwise positive
7777
func (p *Price) For(payer Payer, size uint32) int64 {
@@ -93,6 +93,10 @@ type Balance interface {
9393
// positive amount = credit local node
9494
// negative amount = debit local node
9595
Add(amount int64, peer *Peer) error
96+
// Check is a dry-run for the Add operation:
97+
// As the accounting takes place **after** the actual send/receive operation happens,
98+
// we want to make sure that that operation would not result in any problem
99+
Check(amount int64, peer *Peer) error
96100
}
97101

98102
// Accounting implements the Hook interface
@@ -118,44 +122,35 @@ func SetupAccountingMetrics(reportInterval time.Duration, path string) *Accounti
118122
return NewAccountingMetrics(metrics.AccountingRegistry, reportInterval, path)
119123
}
120124

121-
// Send takes a peer, a size and a msg and
122-
// - calculates the cost for the local node sending a msg of size to peer querying the message for its price
123-
// - credits/debits local node using balance interface
124-
func (ah *Accounting) Send(peer *Peer, size uint32, msg interface{}) error {
125-
// get the price for a message
126-
var pricedMessage PricedMessage
127-
var ok bool
128-
// if the msg implements `Price`, it is an accounted message
129-
if pricedMessage, ok = msg.(PricedMessage); !ok {
130-
return nil
131-
}
132-
// evaluate the price for sending messages
133-
costToLocalNode := pricedMessage.Price().For(Sender, size)
125+
// Apply takes a peer, the signed cost for the local node and the msg size and credits/debits local node using balance interface
126+
func (ah *Accounting) Apply(peer *Peer, costToLocalNode int64, size uint32) error {
134127
// do the accounting
135128
err := ah.Add(costToLocalNode, peer)
136129
// record metrics: just increase counters for user-facing metrics
137130
ah.doMetrics(costToLocalNode, size, err)
138131
return err
139132
}
140133

141-
// Receive takes a peer, a size and a msg and
142-
// - calculates the cost for the local node receiving a msg of size from peer querying the message for its price
143-
// - credits/debits local node using balance interface
144-
func (ah *Accounting) Receive(peer *Peer, size uint32, msg interface{}) error {
134+
// Validate calculates the cost for the local node sending or receiving a msg to/from a peer querying the message for its price.
135+
// It returns either the signed cost for the local node as int64 or an error, signaling that the accounting operation would fail
136+
// (no change has been applied at this point)
137+
func (ah *Accounting) Validate(peer *Peer, size uint32, msg interface{}, payer Payer) (int64, error) {
145138
// get the price for a message (by querying the message type via the PricedMessage interface)
146139
var pricedMessage PricedMessage
147140
var ok bool
148141
// if the msg implements `Price`, it is an accounted message
149142
if pricedMessage, ok = msg.(PricedMessage); !ok {
150-
return nil
143+
return 0, nil
151144
}
152145
// evaluate the price for receiving messages
153-
costToLocalNode := pricedMessage.Price().For(Receiver, size)
154-
// do the accounting
155-
err := ah.Add(costToLocalNode, peer)
156-
// record metrics: just increase counters for user-facing metrics
157-
ah.doMetrics(costToLocalNode, size, err)
158-
return err
146+
costToLocalNode := pricedMessage.Price().For(payer, size)
147+
// check that the operation would perform correctly
148+
err := ah.Check(costToLocalNode, peer)
149+
if err != nil {
150+
// signal to caller that the operation would fail
151+
return 0, err
152+
}
153+
return costToLocalNode, nil
159154
}
160155

161156
// record some metrics

p2p/protocols/accounting_simulation_test.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ func TestAccountingSimulation(t *testing.T) {
7171
}
7272
defer os.RemoveAll(dir)
7373
SetupAccountingMetrics(1*time.Second, filepath.Join(dir, "metrics.db"))
74-
//define the node.Service for this test
7574
services := adapters.Services{
7675
"accounting": func(ctx *adapters.ServiceContext) (node.Service, error) {
7776
return bal.newNode(), nil
@@ -82,8 +81,9 @@ func TestAccountingSimulation(t *testing.T) {
8281
net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{DefaultService: "accounting"})
8382
defer net.Shutdown()
8483

85-
// we send msgs messages per node, wait for all messages to arrive
86-
bal.wg.Add(*nodes * *msgs)
84+
// we send msgs messages per node, wait for all messages to have been accounted
85+
// (Add runs for each send and receive, so we need double the amount)
86+
bal.wg.Add(*nodes * *msgs * 2)
8787
trigger := make(chan enode.ID)
8888
go func() {
8989
// wait for all of them to arrive
@@ -168,13 +168,14 @@ func newMatrix(n int) *matrix {
168168
}
169169

170170
// called from the testBalance's Add accounting function: register balance change
171-
func (m *matrix) add(i, j int, v int64) error {
171+
func (m *matrix) add(i, j int, v int64, wg *sync.WaitGroup) error {
172172
// index for the balance of local node i with remote nodde j is
173173
// i * number of nodes + remote node
174174
mi := i*m.n + j
175175
// register that balance
176176
m.lock.Lock()
177177
m.m[mi] += v
178+
wg.Done()
178179
m.lock.Unlock()
179180
return nil
180181
}
@@ -229,13 +230,17 @@ type testNode struct {
229230
peerCount int
230231
}
231232

233+
func (t *testNode) Check(a int64, p *Peer) error {
234+
return nil
235+
}
236+
232237
// do the accounting for the peer's test protocol
233238
// testNode implements protocols.Balance
234239
func (t *testNode) Add(a int64, p *Peer) error {
235240
//get the index for the remote peer
236241
remote := t.bal.id2n[p.ID()]
237242
log.Debug("add", "local", t.i, "remote", remote, "amount", a)
238-
return t.bal.add(t.i, remote, a)
243+
return t.bal.add(t.i, remote, a, t.bal.wg)
239244
}
240245

241246
//run the p2p protocol
@@ -260,7 +265,6 @@ func (t *testNode) run(p *p2p.Peer, rw p2p.MsgReadWriter) error {
260265

261266
// p2p message receive handler function
262267
func (tp *testPeer) handle(ctx context.Context, msg interface{}) error {
263-
tp.wg.Done()
264268
log.Debug("receive", "from", tp.remote, "to", tp.local, "type", reflect.TypeOf(msg), "msg", msg)
265269
return nil
266270
}

p2p/protocols/accounting_test.go

+17-3
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ func (m *zeroPriceMsg) Price() *Price {
9595
}
9696
}
9797

98+
//dummy accounting implementation, only stores values for later check
99+
func (d *dummyBalance) Check(amount int64, peer *Peer) error {
100+
return nil
101+
}
102+
98103
//dummy accounting implementation, only stores values for later check
99104
func (d *dummyBalance) Add(amount int64, peer *Peer) error {
100105
d.amount = amount
@@ -168,16 +173,25 @@ func TestBalance(t *testing.T) {
168173
}
169174

170175
func checkAccountingTestCases(t *testing.T, cases []testCase, acc *Accounting, peer *Peer, balance *dummyBalance, send bool) {
176+
t.Helper()
171177
for _, c := range cases {
172178
var err error
173-
var expectedResult int64
179+
var expectedResult, cost int64
174180
//reset balance before every check
175181
balance.amount = 0
176182
if send {
177-
err = acc.Send(peer, c.size, c.msg)
183+
cost, err = acc.Validate(peer, c.size, c.msg, Sender)
184+
if err != nil {
185+
t.Fatal(err)
186+
}
187+
err = acc.Apply(peer, cost, c.size)
178188
expectedResult = c.sendResult
179189
} else {
180-
err = acc.Receive(peer, c.size, c.msg)
190+
cost, err = acc.Validate(peer, c.size, c.msg, Receiver)
191+
if err != nil {
192+
t.Fatal(err)
193+
}
194+
err = acc.Apply(peer, cost, c.size)
181195
expectedResult = c.recvResult
182196
}
183197

p2p/protocols/protocol.go

+52-18
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ func errorf(code int, format string, params ...interface{}) *Error {
115115
//To access this functionality, we provide a Hook interface which will call accounting methods
116116
//NOTE: there could be more such (horizontal) hooks in the future
117117
type Hook interface {
118-
//A hook for sending messages
119-
Send(peer *Peer, size uint32, msg interface{}) error
120-
//A hook for receiving messages
121-
Receive(peer *Peer, size uint32, msg interface{}) error
118+
// A hook for applying accounting
119+
Apply(peer *Peer, costToLocalNode int64, size uint32) error
120+
// Run some validation before applying accounting
121+
Validate(peer *Peer, size uint32, msg interface{}, payer Payer) (int64, error)
122122
}
123123

124124
// Spec is a protocol specification including its name and version as well as
@@ -203,6 +203,7 @@ type Peer struct {
203203
spec *Spec
204204
encode func(context.Context, interface{}) (interface{}, int, error)
205205
decode func(p2p.Msg) (context.Context, []byte, error)
206+
lock sync.Mutex
206207
}
207208

208209
// NewPeer constructs a new peer
@@ -277,15 +278,31 @@ func (p *Peer) Send(ctx context.Context, msg interface{}) error {
277278
}
278279
size = len(r)
279280
}
280-
// if the accounting hook is set, call it
281+
282+
// if the accounting hook is set, do accounting logic
281283
if p.spec.Hook != nil {
282-
err = p.spec.Hook.Send(p, uint32(size), msg)
284+
285+
// let's lock, we want to avoid that after validating, a separate call might interfere
286+
p.lock.Lock()
287+
defer p.lock.Unlock()
288+
// validate that this operation would succeed...
289+
costToLocalNode, err := p.spec.Hook.Validate(p, uint32(size), wmsg, Sender)
283290
if err != nil {
291+
// ...because if it would fail, we return and don't send the message
284292
return err
285293
}
294+
// seems like accounting would succeed, thus send the message first...
295+
err = p2p.Send(p.rw, code, wmsg)
296+
if err != nil {
297+
return err
298+
}
299+
// ...and finally apply (write) the accounting change
300+
err = p.spec.Hook.Apply(p, costToLocalNode, uint32(size))
301+
} else {
302+
err = p2p.Send(p.rw, code, wmsg)
286303
}
287304

288-
return p2p.Send(p.rw, code, wmsg)
305+
return err
289306
}
290307

291308
// handleIncoming(code)
@@ -322,23 +339,40 @@ func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{})
322339
return errorf(ErrDecode, "<= %v: %v", msg, err)
323340
}
324341

325-
// if the accounting hook is set, call it
342+
// if the accounting hook is set, do accounting logic
326343
if p.spec.Hook != nil {
327-
err := p.spec.Hook.Receive(p, uint32(len(msgBytes)), val)
344+
345+
p.lock.Lock()
346+
defer p.lock.Unlock()
347+
348+
size := uint32(len(msgBytes))
349+
350+
// validate that the accounting call would succeed...
351+
costToLocalNode, err := p.spec.Hook.Validate(p, size, val, Receiver)
328352
if err != nil {
353+
// ...because if it would fail, we return and don't handle the message
329354
return err
330355
}
331-
}
332356

333-
// call the registered handler callbacks
334-
// a registered callback take the decoded message as argument as an interface
335-
// which the handler is supposed to cast to the appropriate type
336-
// it is entirely safe not to check the cast in the handler since the handler is
337-
// chosen based on the proper type in the first place
338-
if err := handle(ctx, val); err != nil {
339-
return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
357+
// seems like accounting would be fine, so handle the message
358+
if err := handle(ctx, val); err != nil {
359+
return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
360+
}
361+
362+
// handling succeeded, finally apply accounting
363+
err = p.spec.Hook.Apply(p, costToLocalNode, size)
364+
} else {
365+
// call the registered handler callbacks
366+
// a registered callback take the decoded message as argument as an interface
367+
// which the handler is supposed to cast to the appropriate type
368+
// it is entirely safe not to check the cast in the handler since the handler is
369+
// chosen based on the proper type in the first place
370+
if err := handle(ctx, val); err != nil {
371+
return errorf(ErrHandler, "(msg code %v): %v", msg.Code, err)
372+
}
340373
}
341-
return nil
374+
375+
return err
342376
}
343377

344378
// Handshake negotiates a handshake on the peer connection

p2p/protocols/protocol_test.go

+9-10
Original file line numberDiff line numberDiff line change
@@ -209,26 +209,25 @@ type dummyMsg struct {
209209
Content string
210210
}
211211

212-
func (d *dummyHook) Send(peer *Peer, size uint32, msg interface{}) error {
212+
func (d *dummyHook) Validate(peer *Peer, size uint32, msg interface{}, payer Payer) (int64, error) {
213213
d.mu.Lock()
214214
defer d.mu.Unlock()
215-
216215
d.peer = peer
217216
d.size = size
218217
d.msg = msg
219-
d.send = true
220-
return d.err
218+
if payer == Sender {
219+
d.send = true
220+
} else {
221+
d.send = false
222+
d.waitC <- struct{}{}
223+
}
224+
return 0, d.err
221225
}
222226

223-
func (d *dummyHook) Receive(peer *Peer, size uint32, msg interface{}) error {
227+
func (d *dummyHook) Apply(peer *Peer, cost int64, size uint32) error {
224228
d.mu.Lock()
225229
defer d.mu.Unlock()
226230

227-
d.peer = peer
228-
d.size = size
229-
d.msg = msg
230-
d.send = false
231-
d.waitC <- struct{}{}
232231
return d.err
233232
}
234233

swap/swap.go

+31-6
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ func keyToID(key string, prefix string) enode.ID {
304304
return enode.HexID(key[len(prefix):])
305305
}
306306

307-
// createOwner assings keys and addresses
307+
// createOwner assigns keys and addresses
308308
func createOwner(prvkey *ecdsa.PrivateKey) *Owner {
309309
pubkey := &prvkey.PublicKey
310310
return &Owner{
@@ -314,20 +314,45 @@ func createOwner(prvkey *ecdsa.PrivateKey) *Owner {
314314
}
315315
}
316316

317+
// modifyBalanceOk checks that the amount would not result in crossing the disconnection threshold
318+
func (s *Swap) modifyBalanceOk(amount int64, swapPeer *Peer) (err error) {
319+
// check if balance with peer is over the disconnect threshold and if the message would increase the existing debt
320+
balance := swapPeer.getBalance()
321+
if balance >= s.params.DisconnectThreshold && amount > 0 {
322+
return fmt.Errorf("balance for peer %s is over the disconnect threshold %d and cannot incur more debt, disconnecting", swapPeer.ID().String(), s.params.DisconnectThreshold)
323+
}
324+
325+
return nil
326+
}
327+
328+
// Check is called as a *dry run* before applying the actual accounting to an operation.
329+
// It only checks that performing a given accounting operation would not incur in an error.
330+
// If it returns no error, this signals to the caller that the operation is safe
331+
func (s *Swap) Check(amount int64, peer *protocols.Peer) (err error) {
332+
swapPeer := s.getPeer(peer.ID())
333+
if swapPeer == nil {
334+
return fmt.Errorf("peer %s not a swap enabled peer", peer.ID().String())
335+
}
336+
337+
swapPeer.lock.Lock()
338+
defer swapPeer.lock.Unlock()
339+
// currently this is the only real check needed:
340+
return s.modifyBalanceOk(amount, swapPeer)
341+
}
342+
317343
// Add is the (sole) accounting function
318344
// Swap implements the protocols.Balance interface
319345
func (s *Swap) Add(amount int64, peer *protocols.Peer) (err error) {
320346
swapPeer := s.getPeer(peer.ID())
321347
if swapPeer == nil {
322348
return fmt.Errorf("peer %s not a swap enabled peer", peer.ID().String())
323349
}
350+
324351
swapPeer.lock.Lock()
325352
defer swapPeer.lock.Unlock()
326-
327-
// check if balance with peer is over the disconnect threshold and if the message would increase the existing debt
328-
balance := swapPeer.getBalance()
329-
if balance >= s.params.DisconnectThreshold && amount > 0 {
330-
return fmt.Errorf("balance for peer %s is over the disconnect threshold %d and cannot incur more debt, disconnecting", peer.ID().String(), s.params.DisconnectThreshold)
353+
// we should probably check here again:
354+
if err = s.modifyBalanceOk(amount, swapPeer); err != nil {
355+
return err
331356
}
332357

333358
if err = swapPeer.updateBalance(amount); err != nil {

0 commit comments

Comments
 (0)