diff --git a/convert.go b/convert.go index 50c6fff..0bd59c7 100644 --- a/convert.go +++ b/convert.go @@ -13,103 +13,41 @@ var errIncorrectNetAddr = fmt.Errorf("incorrect network addr conversion") // FromNetAddr converts a net.Addr type to a Multiaddr. func FromNetAddr(a net.Addr) (ma.Multiaddr, error) { + return defaultCodecs.FromNetAddr(a) +} + +func (cm *CodecMap) FromNetAddr(a net.Addr) (ma.Multiaddr, error) { if a == nil { return nil, fmt.Errorf("nil multiaddr") } - - switch a.Network() { - case "tcp", "tcp4", "tcp6": - ac, ok := a.(*net.TCPAddr) - if !ok { - return nil, errIncorrectNetAddr - } - - // Get IP Addr - ipm, err := FromIP(ac.IP) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Get TCP Addr - tcpm, err := ma.NewMultiaddr(fmt.Sprintf("/tcp/%d", ac.Port)) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Encapsulate - return ipm.Encapsulate(tcpm), nil - - case "udp", "upd4", "udp6": - ac, ok := a.(*net.UDPAddr) - if !ok { - return nil, errIncorrectNetAddr - } - - // Get IP Addr - ipm, err := FromIP(ac.IP) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Get UDP Addr - udpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d", ac.Port)) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Encapsulate - return ipm.Encapsulate(udpm), nil - - case "utp", "utp4", "utp6": - acc, ok := a.(*utp.Addr) - if !ok { - return nil, errIncorrectNetAddr - } - - // Get UDP Addr - ac, ok := acc.Child().(*net.UDPAddr) - if !ok { - return nil, errIncorrectNetAddr - } - - // Get IP Addr - ipm, err := FromIP(ac.IP) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Get UDP Addr - utpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d/utp", ac.Port)) - if err != nil { - return nil, errIncorrectNetAddr - } - - // Encapsulate - return ipm.Encapsulate(utpm), nil - - case "ip", "ip4", "ip6": - ac, ok := a.(*net.IPAddr) - if !ok { - return nil, errIncorrectNetAddr - } - return FromIP(ac.IP) - - case "ip+net": - ac, ok := a.(*net.IPNet) - if !ok { - return nil, errIncorrectNetAddr - } - return FromIP(ac.IP) - - default: - return nil, fmt.Errorf("unknown network %v", a.Network()) + p, err := cm.getAddrParser(a.Network()) + if err != nil { + return nil, err } + + return p(a) } // ToNetAddr converts a Multiaddr to a net.Addr // Must be ThinWaist. acceptable protocol stacks are: // /ip{4,6}/{tcp, udp} func ToNetAddr(maddr ma.Multiaddr) (net.Addr, error) { + return defaultCodecs.ToNetAddr(maddr) +} + +func (cm *CodecMap) ToNetAddr(maddr ma.Multiaddr) (net.Addr, error) { + protos := maddr.Protocols() + final := protos[len(protos)-1] + + p, err := cm.getMaddrParser(final.Name) + if err != nil { + return nil, err + } + + return p(maddr) +} + +func parseBasicNetMaddr(maddr ma.Multiaddr) (net.Addr, error) { network, host, err := DialArgs(maddr) if err != nil { return nil, err @@ -143,6 +81,8 @@ func FromIP(ip net.IP) (ma.Multiaddr, error) { // DialArgs is a convenience function returning arguments for use in net.Dial func DialArgs(m ma.Multiaddr) (string, string, error) { + // TODO: find a 'good' way to eliminate the function. + // My preference is with a multiaddr.Format(...) function if !IsThinWaist(m) { return "", "", fmt.Errorf("%s is not a 'thin waist' address", m) } @@ -170,3 +110,91 @@ func DialArgs(m ma.Multiaddr) (string, string, error) { } return network, host, nil } + +func parseTcpNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.TCPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get IP Addr + ipm, err := FromIP(ac.IP) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Get TCP Addr + tcpm, err := ma.NewMultiaddr(fmt.Sprintf("/tcp/%d", ac.Port)) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Encapsulate + return ipm.Encapsulate(tcpm), nil +} + +func parseUdpNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.UDPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get IP Addr + ipm, err := FromIP(ac.IP) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Get UDP Addr + udpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d", ac.Port)) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Encapsulate + return ipm.Encapsulate(udpm), nil +} + +func parseUtpNetAddr(a net.Addr) (ma.Multiaddr, error) { + acc, ok := a.(*utp.Addr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get UDP Addr + ac, ok := acc.Child().(*net.UDPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + + // Get IP Addr + ipm, err := FromIP(ac.IP) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Get UDP Addr + utpm, err := ma.NewMultiaddr(fmt.Sprintf("/udp/%d/utp", ac.Port)) + if err != nil { + return nil, errIncorrectNetAddr + } + + // Encapsulate + return ipm.Encapsulate(utpm), nil +} + +func parseIpNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.IPAddr) + if !ok { + return nil, errIncorrectNetAddr + } + return FromIP(ac.IP) +} + +func parseIpPlusNetAddr(a net.Addr) (ma.Multiaddr, error) { + ac, ok := a.(*net.IPNet) + if !ok { + return nil, errIncorrectNetAddr + } + return FromIP(ac.IP) +} diff --git a/registry.go b/registry.go new file mode 100644 index 0000000..816e554 --- /dev/null +++ b/registry.go @@ -0,0 +1,128 @@ +package manet + +import ( + "fmt" + "net" + "sync" + + ma "github.com/jbenet/go-multiaddr" +) + +type FromNetAddrFunc func(a net.Addr) (ma.Multiaddr, error) +type ToNetAddrFunc func(ma ma.Multiaddr) (net.Addr, error) + +var defaultCodecs *CodecMap + +func init() { + defaultCodecs = NewCodecMap() + defaultCodecs.RegisterNetCodec(tcpAddrSpec) + defaultCodecs.RegisterNetCodec(udpAddrSpec) + defaultCodecs.RegisterNetCodec(utpAddrSpec) + defaultCodecs.RegisterNetCodec(ip4AddrSpec) + defaultCodecs.RegisterNetCodec(ip6AddrSpec) +} + +type CodecMap struct { + codecs map[string]*NetCodec + addrParsers map[string]FromNetAddrFunc + maddrParsers map[string]ToNetAddrFunc + lk sync.Mutex +} + +func NewCodecMap() *CodecMap { + return &CodecMap{ + codecs: make(map[string]*NetCodec), + addrParsers: make(map[string]FromNetAddrFunc), + maddrParsers: make(map[string]ToNetAddrFunc), + } +} + +type NetCodec struct { + // NetAddrNetworks is an array of strings that may be returned + // by net.Addr.Network() calls on addresses belonging to this type + NetAddrNetworks []string + + // ProtocolName is the string value for Multiaddr address keys + ProtocolName string + + // ParseNetAddr parses a net.Addr belonging to this type into a multiaddr + ParseNetAddr FromNetAddrFunc + + // ConvertMultiaddr converts a multiaddr of this type back into a net.Addr + ConvertMultiaddr ToNetAddrFunc + + // Protocol returns the multiaddr protocol struct for this type + Protocol ma.Protocol +} + +func RegisterNetCodec(a *NetCodec) { + defaultCodecs.RegisterNetCodec(a) +} + +func (cm *CodecMap) RegisterNetCodec(a *NetCodec) { + cm.lk.Lock() + defer cm.lk.Unlock() + cm.codecs[a.ProtocolName] = a + for _, n := range a.NetAddrNetworks { + cm.addrParsers[n] = a.ParseNetAddr + } + + cm.maddrParsers[a.ProtocolName] = a.ConvertMultiaddr +} + +var tcpAddrSpec = &NetCodec{ + ProtocolName: "tcp", + NetAddrNetworks: []string{"tcp", "tcp4", "tcp6"}, + ParseNetAddr: parseTcpNetAddr, + ConvertMultiaddr: parseBasicNetMaddr, +} + +var udpAddrSpec = &NetCodec{ + ProtocolName: "udp", + NetAddrNetworks: []string{"udp", "udp4", "udp6"}, + ParseNetAddr: parseUdpNetAddr, + ConvertMultiaddr: parseBasicNetMaddr, +} + +var utpAddrSpec = &NetCodec{ + ProtocolName: "utp", + NetAddrNetworks: []string{"utp", "utp4", "utp6"}, + ParseNetAddr: parseUtpNetAddr, + ConvertMultiaddr: parseBasicNetMaddr, +} + +var ip4AddrSpec = &NetCodec{ + ProtocolName: "ip4", + NetAddrNetworks: []string{"ip4"}, + ParseNetAddr: parseIpNetAddr, + ConvertMultiaddr: parseBasicNetMaddr, +} + +var ip6AddrSpec = &NetCodec{ + ProtocolName: "ip6", + NetAddrNetworks: []string{"ip6"}, + ParseNetAddr: parseIpNetAddr, + ConvertMultiaddr: parseBasicNetMaddr, +} + +func (cm *CodecMap) getAddrParser(net string) (FromNetAddrFunc, error) { + cm.lk.Lock() + defer cm.lk.Unlock() + + parser, ok := cm.addrParsers[net] + if !ok { + return nil, fmt.Errorf("unknown network %v", net) + } + return parser, nil +} + +func (cm *CodecMap) getMaddrParser(name string) (ToNetAddrFunc, error) { + cm.lk.Lock() + defer cm.lk.Unlock() + p, ok := cm.maddrParsers[name] + if !ok { + return nil, fmt.Errorf("network not supported: %s", name) + } + + return p, nil +} diff --git a/registry_test.go b/registry_test.go new file mode 100644 index 0000000..6e825e8 --- /dev/null +++ b/registry_test.go @@ -0,0 +1,50 @@ +package manet + +import ( + "net" + "testing" + + ma "github.com/jbenet/go-multiaddr" +) + +func TestRegisterSpec(t *testing.T) { + cm := NewCodecMap() + myproto := &NetCodec{ + ProtocolName: "test", + NetAddrNetworks: []string{"test", "iptest", "blahtest"}, + ConvertMultiaddr: func(a ma.Multiaddr) (net.Addr, error) { return nil, nil }, + ParseNetAddr: func(a net.Addr) (ma.Multiaddr, error) { return nil, nil }, + } + + cm.RegisterNetCodec(myproto) + + _, ok := cm.addrParsers["test"] + if !ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.addrParsers["iptest"] + if !ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.addrParsers["blahtest"] + if !ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.maddrParsers["test"] + if !ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.maddrParsers["iptest"] + if ok { + t.Fatal("myproto not properly registered") + } + + _, ok = cm.maddrParsers["blahtest"] + if ok { + t.Fatal("myproto not properly registered") + } +}