diff --git a/codec.go b/codec.go index 0a63f12..bdbeba6 100644 --- a/codec.go +++ b/codec.go @@ -55,7 +55,7 @@ func stringToBytes(s string) ([]byte, error) { } err = p.Transcoder.ValidateBytes(a) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to validate multiaddr %q: invalid value %q for protocol %s: %w", s, sp[0], p.Name, err) } if p.Size < 0 { // varint size. _, _ = b.Write(varint.ToUvarint(uint64(len(a)))) @@ -79,12 +79,16 @@ func readComponent(b []byte) (int, Component, error) { if p.Code == 0 { return 0, Component{}, fmt.Errorf("no protocol with code %d", code) } + pPtr := protocolPtrByCode[code] + if pPtr == nil { + return 0, Component{}, fmt.Errorf("no protocol with code %d", code) + } if p.Size == 0 { c, err := validateComponent(Component{ - bytes: string(b[:offset]), - offset: offset, - protocol: p, + bytes: string(b[:offset]), + valueStartIdx: offset, + protocol: pPtr, }) return offset, c, err @@ -109,9 +113,9 @@ func readComponent(b []byte) (int, Component, error) { } c, err := validateComponent(Component{ - bytes: string(b[:offset+size]), - protocol: p, - offset: offset, + bytes: string(b[:offset+size]), + protocol: pPtr, + valueStartIdx: offset, }) return offset + size, c, err diff --git a/component.go b/component.go index aa8f4ef..39f5d5a 100644 --- a/component.go +++ b/component.go @@ -1,6 +1,7 @@ package multiaddr import ( + "bytes" "encoding/binary" "encoding/json" "fmt" @@ -11,9 +12,11 @@ import ( // Component is a single multiaddr Component. type Component struct { - bytes string // Uses the string type to ensure immutability. - protocol Protocol - offset int + // bytes is the raw bytes of the component. It includes the protocol code as + // varint, possibly the size of the value, and the value. + bytes string // string for immutability. + protocol *Protocol + valueStartIdx int // Index of the first byte of the Component's value in the bytes array } func (c Component) AsMultiaddr() Multiaddr { @@ -107,10 +110,16 @@ func (c Component) Compare(o Component) int { } func (c Component) Protocols() []Protocol { - return []Protocol{c.protocol} + if c.protocol == nil { + return nil + } + return []Protocol{*c.protocol} } func (c Component) ValueForProtocol(code int) (string, error) { + if c.protocol == nil { + return "", fmt.Errorf("component has nil protocol") + } if c.protocol.Code != code { return "", ErrProtocolNotFound } @@ -118,11 +127,14 @@ func (c Component) ValueForProtocol(code int) (string, error) { } func (c Component) Protocol() Protocol { - return c.protocol + if c.protocol == nil { + return Protocol{} + } + return *c.protocol } func (c Component) RawValue() []byte { - return []byte(c.bytes[c.offset:]) + return []byte(c.bytes[c.valueStartIdx:]) } func (c Component) Value() string { @@ -135,10 +147,13 @@ func (c Component) Value() string { } func (c Component) valueAndErr() (string, error) { + if c.protocol == nil { + return "", fmt.Errorf("component has nil protocol") + } if c.protocol.Transcoder == nil { return "", nil } - value, err := c.protocol.Transcoder.BytesToString([]byte(c.bytes[c.offset:])) + value, err := c.protocol.Transcoder.BytesToString([]byte(c.bytes[c.valueStartIdx:])) if err != nil { return "", err } @@ -154,6 +169,9 @@ func (c Component) String() string { // writeTo is an efficient, private function for string-formatting a multiaddr. // Trust me, we tend to allocate a lot when doing this. func (c Component) writeTo(b *strings.Builder) { + if c.protocol == nil { + return + } b.WriteByte('/') b.WriteString(c.protocol.Name) value := c.Value() @@ -185,6 +203,11 @@ func NewComponent(protocol, value string) (Component, error) { } func newComponent(protocol Protocol, bvalue []byte) (Component, error) { + protocolPtr := protocolPtrByCode[protocol.Code] + if protocolPtr == nil { + protocolPtr = &protocol + } + size := len(bvalue) size += len(protocol.VCode) if protocol.Size < 0 { @@ -205,9 +228,9 @@ func newComponent(protocol Protocol, bvalue []byte) (Component, error) { return validateComponent( Component{ - bytes: string(maddr), - protocol: protocol, - offset: offset, + bytes: string(maddr), + protocol: protocolPtr, + valueStartIdx: offset, }) } @@ -215,13 +238,53 @@ func newComponent(protocol Protocol, bvalue []byte) (Component, error) { // It ensures that we will be able to call all methods on Component without // error. func validateComponent(c Component) (Component, error) { + if c.protocol == nil { + return Component{}, fmt.Errorf("component is missing its protocol") + } + if c.valueStartIdx > len(c.bytes) { + return Component{}, fmt.Errorf("component valueStartIdx is greater than the length of the component's bytes") + } + + if len(c.protocol.VCode) == 0 { + return Component{}, fmt.Errorf("Component is missing its protocol's VCode field") + } + if len(c.bytes) < len(c.protocol.VCode) { + return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.bytes), len(c.protocol.VCode)) + } + if !bytes.Equal([]byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode) { + return Component{}, fmt.Errorf("component's VCode field is invalid: %v != %v", []byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode) + } + if c.protocol.Size < 0 { + size, n, err := ReadVarintCode([]byte(c.bytes[len(c.protocol.VCode):])) + if err != nil { + return Component{}, err + } + if size != len(c.bytes[c.valueStartIdx:]) { + return Component{}, fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:])) + } + + if len(c.protocol.VCode)+n+size != len(c.bytes) { + return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+n+size, len(c.bytes)) + } + } else { + // Fixed size value + size := c.protocol.Size / 8 + if size != len(c.bytes[c.valueStartIdx:]) { + return Component{}, fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:])) + } + + if len(c.protocol.VCode)+size != len(c.bytes) { + return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+size, len(c.bytes)) + } + } + _, err := c.valueAndErr() if err != nil { return Component{}, err } if c.protocol.Transcoder != nil { - err = c.protocol.Transcoder.ValidateBytes([]byte(c.bytes[c.offset:])) + err = c.protocol.Transcoder.ValidateBytes([]byte(c.bytes[c.valueStartIdx:])) if err != nil { return Component{}, err } diff --git a/matest/matest.go b/matest/matest.go index aa76655..b361468 100644 --- a/matest/matest.go +++ b/matest/matest.go @@ -72,6 +72,9 @@ type MultiaddrMatcher struct { multiaddr.Multiaddr } +// Implements the Matcher interface for gomock.Matcher +// Let's us use this struct in gomock tests. Example: +// Expect(mock.Method(gomock.Any(), multiaddrMatcher).Return(nil) func (m MultiaddrMatcher) Matches(x interface{}) bool { if m2, ok := x.(multiaddr.Multiaddr); ok { return m.Equal(m2) diff --git a/multiaddr_test.go b/multiaddr_test.go index b281c2d..95b57f4 100644 --- a/multiaddr_test.go +++ b/multiaddr_test.go @@ -1132,3 +1132,43 @@ func FuzzSplitRoundtrip(f *testing.F) { } }) } + +func BenchmarkComponentValidation(b *testing.B) { + comp, err := NewComponent("ip4", "127.0.0.1") + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := validateComponent(comp) + if err != nil { + b.Fatal(err) + } + } +} + +func FuzzComponents(f *testing.F) { + for _, v := range good { + m := StringCast(v) + for _, c := range m { + f.Add(c.Bytes()) + } + } + f.Fuzz(func(t *testing.T, compBytes []byte) { + n, c, err := readComponent(compBytes) + if err != nil { + t.Skip() + } + if c.protocol == nil { + t.Fatal("component has nil protocol") + } + if c.protocol.Code == 0 { + t.Fatal("component has nil protocol code") + } + if !bytes.Equal(c.Bytes(), compBytes[:n]) { + t.Logf("component bytes: %v", c.Bytes()) + t.Logf("original bytes: %v", compBytes[:n]) + t.Fatal("component bytes are not equal to the original bytes") + } + }) +} diff --git a/protocol.go b/protocol.go index 61a2924..d0ce032 100644 --- a/protocol.go +++ b/protocol.go @@ -47,6 +47,9 @@ type Protocol struct { var protocolsByName = map[string]Protocol{} var protocolsByCode = map[int]Protocol{} +// Keep a map of pointers so that we can reuse the same pointer for the same protocol. +var protocolPtrByCode = map[int]*Protocol{} + // Protocols is the list of multiaddr protocols supported by this module. var Protocols = []Protocol{} @@ -65,10 +68,14 @@ func AddProtocol(p Protocol) error { if p.Path && p.Size >= 0 { return fmt.Errorf("path protocols must have variable-length sizes") } + if len(p.VCode) == 0 { + return fmt.Errorf("protocol code %d is missing its VCode field", p.Code) + } Protocols = append(Protocols, p) protocolsByName[p.Name] = p protocolsByCode[p.Code] = p + protocolPtrByCode[p.Code] = &p return nil }