Skip to content

refactor: Backwards compatible Encapsulate/Decapsulate/Join #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,36 @@ func stringToBytes(s string) ([]byte, error) {
return b.Bytes(), nil
}

func readComponent(b []byte) (int, Component, error) {
func readComponent(b []byte) (int, *Component, error) {
var offset int
code, n, err := ReadVarintCode(b)
if err != nil {
return 0, Component{}, err
return 0, nil, err
}
offset += n

p := ProtocolWithCode(code)
if p.Code == 0 {
return 0, Component{}, fmt.Errorf("no protocol with code %d", code)
return 0, nil, fmt.Errorf("no protocol with code %d", code)
}
pPtr := protocolPtrByCode[code]
if pPtr == nil {
return 0, Component{}, fmt.Errorf("no protocol with code %d", code)
return 0, nil, fmt.Errorf("no protocol with code %d", code)
}

if p.Size == 0 {
c, err := validateComponent(Component{
c := &Component{
bytes: string(b[:offset]),
valueStartIdx: offset,
protocol: pPtr,
})
}

err := validateComponent(c)
if err != nil {
return 0, nil, err
}

return offset, c, err
return offset, c, nil
}

var size int
Expand All @@ -100,7 +105,7 @@ func readComponent(b []byte) (int, Component, error) {
var n int
size, n, err = ReadVarintCode(b[offset:])
if err != nil {
return 0, Component{}, err
return 0, nil, err
}
offset += n
} else {
Expand All @@ -109,14 +114,18 @@ func readComponent(b []byte) (int, Component, error) {
}

if len(b[offset:]) < size || size <= 0 {
return 0, Component{}, fmt.Errorf("invalid value for size %d", len(b[offset:]))
return 0, nil, fmt.Errorf("invalid value for size %d", len(b[offset:]))
}

c, err := validateComponent(Component{
c := &Component{
bytes: string(b[:offset+size]),
protocol: pPtr,
valueStartIdx: offset,
})
}
err = validateComponent(c)
if err != nil {
return 0, nil, err
}

return offset + size, c, err
}
Expand All @@ -142,7 +151,7 @@ func readMultiaddr(b []byte) (int, Multiaddr, error) {
return bytesRead, nil, fmt.Errorf("unexpected component after path component")
}
sawPathComponent = c.protocol.Path
res = append(res, c)
res = append(res, *c)
}
return bytesRead, res, nil
}
68 changes: 38 additions & 30 deletions component.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ func (c *Component) AsMultiaddr() Multiaddr {
return []Component{*c}
}

func (c *Component) Encapsulate(o Multiaddr) Multiaddr {
func (c *Component) Encapsulate(o asMultiaddr) Multiaddr {
return c.AsMultiaddr().Encapsulate(o)
}

func (c *Component) Decapsulate(o Multiaddr) Multiaddr {
func (c *Component) Decapsulate(o asMultiaddr) Multiaddr {
return c.AsMultiaddr().Decapsulate(o)
}

Expand Down Expand Up @@ -63,7 +63,7 @@ func (c *Component) UnmarshalBinary(data []byte) error {
if err != nil {
return err
}
*c = comp
*c = *comp
return nil
}

Expand All @@ -87,7 +87,7 @@ func (c *Component) UnmarshalText(data []byte) error {
if err != nil {
return err
}
*c = comp
*c = *comp
return nil
}

Expand Down Expand Up @@ -236,24 +236,24 @@ func (c *Component) writeTo(b *strings.Builder) {
}

// NewComponent constructs a new multiaddr component
func NewComponent(protocol, value string) (Component, error) {
func NewComponent(protocol, value string) (*Component, error) {
p := ProtocolWithName(protocol)
if p.Code == 0 {
return Component{}, fmt.Errorf("unsupported protocol: %s", protocol)
return nil, fmt.Errorf("unsupported protocol: %s", protocol)
}
if p.Transcoder != nil {
bts, err := p.Transcoder.StringToBytes(value)
if err != nil {
return Component{}, err
return nil, err
}
return newComponent(p, bts)
} else if value != "" {
return Component{}, fmt.Errorf("protocol %s doesn't take a value", p.Name)
return nil, fmt.Errorf("protocol %s doesn't take a value", p.Name)
}
return newComponent(p, nil)
}

func newComponent(protocol Protocol, bvalue []byte) (Component, error) {
func newComponent(protocol Protocol, bvalue []byte) (*Component, error) {
protocolPtr := protocolPtrByCode[protocol.Code]
if protocolPtr == nil {
protocolPtr = &protocol
Expand All @@ -274,71 +274,79 @@ func newComponent(protocol Protocol, bvalue []byte) (Component, error) {

// Shouldn't happen
if len(maddr) != offset+len(bvalue) {
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(maddr), offset+len(bvalue))
return nil, fmt.Errorf("component size mismatch: %d != %d", len(maddr), offset+len(bvalue))
}

return validateComponent(
Component{
bytes: string(maddr),
protocol: protocolPtr,
valueStartIdx: offset,
})
c := &Component{
bytes: string(maddr),
protocol: protocolPtr,
valueStartIdx: offset,
}

err := validateComponent(c)
if err != nil {
return nil, err
}
return c, nil
}

// validateComponent MUST be called after creating a non-zero Component.
// It ensures that we will be able to call all methods on Component without
// error.
func validateComponent(c Component) (Component, error) {
func validateComponent(c *Component) error {
if c == nil {
return errNilPtr
}
if c.protocol == nil {
return Component{}, fmt.Errorf("component is missing its protocol")
return fmt.Errorf("component is missing its protocol")
}
if c.valueStartIdx > len(c.bytes) {
return Component{}, fmt.Errorf("component valueStartIdx is greater than the length of the component's bytes")
return fmt.Errorf("component valueStartIdx is greater than the length of the component's bytes")
}

if len(c.protocol.VCode) == 0 {
return Component{}, fmt.Errorf("Component is missing its protocol's VCode field")
return fmt.Errorf("Component is missing its protocol's VCode field")
}
if len(c.bytes) < len(c.protocol.VCode) {
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.bytes), len(c.protocol.VCode))
return fmt.Errorf("component size mismatch: %d != %d", len(c.bytes), len(c.protocol.VCode))
}
if !bytes.Equal([]byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode) {
return Component{}, fmt.Errorf("component's VCode field is invalid: %v != %v", []byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode)
return fmt.Errorf("component's VCode field is invalid: %v != %v", []byte(c.bytes[:len(c.protocol.VCode)]), c.protocol.VCode)
}
if c.protocol.Size < 0 {
size, n, err := ReadVarintCode([]byte(c.bytes[len(c.protocol.VCode):]))
if err != nil {
return Component{}, err
return err
}
if size != len(c.bytes[c.valueStartIdx:]) {
return Component{}, fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
return fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
}

if len(c.protocol.VCode)+n+size != len(c.bytes) {
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+n+size, len(c.bytes))
return fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+n+size, len(c.bytes))
}
} else {
// Fixed size value
size := c.protocol.Size / 8
if size != len(c.bytes[c.valueStartIdx:]) {
return Component{}, fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
return fmt.Errorf("component value size mismatch: %d != %d", size, len(c.bytes[c.valueStartIdx:]))
}

if len(c.protocol.VCode)+size != len(c.bytes) {
return Component{}, fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+size, len(c.bytes))
return fmt.Errorf("component size mismatch: %d != %d", len(c.protocol.VCode)+size, len(c.bytes))
}
}

_, err := c.valueAndErr()
if err != nil {
return Component{}, err
return err

}
if c.protocol.Transcoder != nil {
err = c.protocol.Transcoder.ValidateBytes([]byte(c.bytes[c.valueStartIdx:]))
if err != nil {
return Component{}, err
return err
}
}
return c, nil
return nil
}
26 changes: 14 additions & 12 deletions multiaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,23 +181,25 @@ func (m Multiaddr) Protocols() []Protocol {
return out
}

// Encapsulate wraps a given Multiaddr, returning the resulting joined Multiaddr
func (m Multiaddr) Encapsulate(o Multiaddr) Multiaddr {
return Join(m, o)
type asMultiaddr interface {
AsMultiaddr() Multiaddr
}

func (m Multiaddr) EncapsulateC(c *Component) Multiaddr {
if c.Empty() {
return m
}
out := make([]Component, 0, len(m)+1)
out = append(out, m...)
out = append(out, *c)
return out
func (m Multiaddr) AsMultiaddr() Multiaddr {
return m
}

// Encapsulate wraps a given Multiaddr, returning the resulting joined Multiaddr
func (m Multiaddr) Encapsulate(other asMultiaddr) Multiaddr {
return Join(m, other)
}

// Decapsulate unwraps Multiaddr up until the given Multiaddr is found.
func (m Multiaddr) Decapsulate(rightParts Multiaddr) Multiaddr {
func (m Multiaddr) Decapsulate(rightPartsAny asMultiaddr) Multiaddr {
if rightPartsAny == nil {
return m
}
rightParts := rightPartsAny.AsMultiaddr()
leftParts := m

lastIndex := -1
Expand Down
47 changes: 40 additions & 7 deletions multiaddr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ func TestReturnsNilOnEmpty(t *testing.T) {
require.Nil(t, a)
}

func TestJoinWithComponents(t *testing.T) {
var m Multiaddr
c, err := NewComponent("ip4", "127.0.0.1")
require.NoError(t, err)

expected := "/ip4/127.0.0.1"
require.Equal(t, expected, Join(m, c).String())

}

func TestConstructFails(t *testing.T) {
cases := []string{
"/ip4",
Expand Down Expand Up @@ -520,6 +530,19 @@ func TestEncapsulate(t *testing.T) {
if d != nil {
t.Error("decapsulate /ip4 failed: ", d)
}

t.Run("Encapsulating with components", func(t *testing.T) {
left, last := SplitLast(m)
joined := left.Encapsulate(last)
require.True(t, joined.Equal(m))

first, rest := SplitFirst(m)
joined = first.Encapsulate(rest)
require.True(t, joined.Equal(m))
// Component type
joined = (*first).Encapsulate(rest)
require.True(t, joined.Equal(m))
})
}

func TestDecapsulateComment(t *testing.T) {
Expand Down Expand Up @@ -580,6 +603,16 @@ func TestDecapsulate(t *testing.T) {
require.Equal(t, expected, actual)
})
}

for _, tc := range testcases {
t.Run("Decapsulating with components"+tc.left, func(t *testing.T) {
left, last := SplitLast(StringCast(tc.left))
butLast := left.Decapsulate(last)
require.Equal(t, butLast.String(), left.String())
// Round trip
require.Equal(t, tc.left, butLast.Encapsulate(last).String())
})
}
}

func assertValueForProto(t *testing.T, a Multiaddr, p int, exp string) {
Expand Down Expand Up @@ -949,7 +982,7 @@ func TestUseNilComponent(t *testing.T) {
foo.AsMultiaddr()
foo.Encapsulate(nil)
foo.Decapsulate(nil)
foo.Empty()
require.True(t, foo.Empty())
foo.Bytes()
foo.MarshalBinary()
foo.MarshalJSON()
Expand All @@ -967,7 +1000,7 @@ func TestUseNilComponent(t *testing.T) {
_ = foo.String()

var m Multiaddr = nil
m.EncapsulateC(foo)
m.Encapsulate(foo)
}

func TestFilterAddrs(t *testing.T) {
Expand Down Expand Up @@ -1124,12 +1157,12 @@ func FuzzSplitRoundtrip(f *testing.F) {

// Test SplitFirst
first, rest := SplitFirst(addr)
joined := Join(first.AsMultiaddr(), rest)
joined := Join(first, rest)
require.True(t, addr.Equal(joined), "SplitFirst and Join should round-trip")

// Test SplitLast
rest, last := SplitLast(addr)
joined = Join(rest, last.AsMultiaddr())
joined = Join(rest, last)
require.True(t, addr.Equal(joined), "SplitLast and Join should round-trip")

p := addr.Protocols()
Expand All @@ -1155,12 +1188,12 @@ func FuzzSplitRoundtrip(f *testing.F) {
return c.Protocol().Code == proto.Code
}
beforeC, after := SplitFirst(addr)
joined = Join(beforeC.AsMultiaddr(), after)
joined = Join(beforeC, after)
require.True(t, addr.Equal(joined))
tryPubMethods(after)

before, afterC := SplitLast(addr)
joined = Join(before, afterC.AsMultiaddr())
joined = Join(before, afterC)
require.True(t, addr.Equal(joined))
tryPubMethods(before)

Expand All @@ -1180,7 +1213,7 @@ func BenchmarkComponentValidation(b *testing.B) {
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, err := validateComponent(comp)
err := validateComponent(comp)
if err != nil {
b.Fatal(err)
}
Expand Down
Loading