diff --git a/resolve.go b/resolve.go index 697022e..586e549 100644 --- a/resolve.go +++ b/resolve.go @@ -16,8 +16,10 @@ var ( dnsProtocol = ma.ProtocolWithCode(ma.P_DNS) ) -var ResolvableProtocols = []ma.Protocol{dnsaddrProtocol, dns4Protocol, dns6Protocol, dnsProtocol} -var DefaultResolver = &Resolver{def: net.DefaultResolver} +var ( + ResolvableProtocols = []ma.Protocol{dnsaddrProtocol, dns4Protocol, dns6Protocol, dnsProtocol} + DefaultResolver = &Resolver{def: net.DefaultResolver} +) const dnsaddrTXTPrefix = "dnsaddr=" @@ -104,179 +106,162 @@ func (r *Resolver) getResolver(domain string) BasicResolver { return r.def } -// Resolve resolves a DNS multiaddr. +// Resolve resolves a DNS multiaddr. It will only resolve the first DNS component in the multiaddr. +// If you need to resolve multiple DNS components, you may call this function again with each returned address. func (r *Resolver) Resolve(ctx context.Context, maddr ma.Multiaddr) ([]ma.Multiaddr, error) { - var results []ma.Multiaddr - for i := 0; maddr != nil; i++ { - var keep ma.Multiaddr - - // Find the next dns component. - keep, maddr = ma.SplitFunc(maddr, func(c ma.Component) bool { - switch c.Protocol().Code { - case dnsProtocol.Code, dns4Protocol.Code, dns6Protocol.Code, dnsaddrProtocol.Code: - return true - default: - return false - } - }) - - // Keep everything before the dns component. - if keep != nil { - if len(results) == 0 { - results = []ma.Multiaddr{keep} - } else { - for i, r := range results { - results[i] = r.Encapsulate(keep) - } - } + // Find the next dns component. + preDNS, maddr := ma.SplitFunc(maddr, func(c ma.Component) bool { + switch c.Protocol().Code { + case dnsProtocol.Code, dns4Protocol.Code, dns6Protocol.Code, dnsaddrProtocol.Code: + return true + default: + return false } + }) - // If the rest is empty, we've hit the end (there _was_ no dns component). - if maddr == nil { - break - } + // If the rest is empty, we've hit the end (there _was_ no dns component). + if maddr == nil { + return []ma.Multiaddr{preDNS}, nil + } - // split off the dns component. - var resolve *ma.Component - resolve, maddr = ma.SplitFirst(maddr) - - proto := resolve.Protocol() - value := resolve.Value() - rslv := r.getResolver(value) - - // resolve the dns component - var resolved []ma.Multiaddr - switch proto.Code { - case dns4Protocol.Code, dns6Protocol.Code, dnsProtocol.Code: - // The dns, dns4, and dns6 resolver simply resolves each - // dns* component into an ipv4/ipv6 address. - - v4only := proto.Code == dns4Protocol.Code - v6only := proto.Code == dns6Protocol.Code - - // XXX: Unfortunately, go does a pretty terrible job of - // differentiating between IPv6 and IPv4. A v4-in-v6 - // AAAA record will _look_ like an A record to us and - // there's nothing we can do about that. - records, err := rslv.LookupIPAddr(ctx, value) - if err != nil { - return nil, err - } + // split off the dns component. + resolve, postDNS := ma.SplitFirst(maddr) + + proto := resolve.Protocol() + value := resolve.Value() + rslv := r.getResolver(value) + + // resolve the dns component + var resolved []ma.Multiaddr + switch proto.Code { + case dns4Protocol.Code, dns6Protocol.Code, dnsProtocol.Code: + // The dns, dns4, and dns6 resolver simply resolves each + // dns* component into an ipv4/ipv6 address. + + v4only := proto.Code == dns4Protocol.Code + v6only := proto.Code == dns6Protocol.Code - // Convert each DNS record into a multiaddr. If the - // protocol is dns4, throw away any IPv6 addresses. If - // the protocol is dns6, throw away any IPv4 addresses. - - for _, r := range records { - var ( - rmaddr ma.Multiaddr - err error - ) - ip4 := r.IP.To4() - if ip4 == nil { - if v4only { - continue - } - rmaddr, err = ma.NewMultiaddr("/ip6/" + r.IP.String()) - } else { - if v6only { - continue - } - rmaddr, err = ma.NewMultiaddr("/ip4/" + ip4.String()) + // XXX: Unfortunately, go does a pretty terrible job of + // differentiating between IPv6 and IPv4. A v4-in-v6 + // AAAA record will _look_ like an A record to us and + // there's nothing we can do about that. + records, err := rslv.LookupIPAddr(ctx, value) + if err != nil { + return nil, err + } + + // Convert each DNS record into a multiaddr. If the + // protocol is dns4, throw away any IPv6 addresses. If + // the protocol is dns6, throw away any IPv4 addresses. + + for _, r := range records { + var ( + rmaddr ma.Multiaddr + err error + ) + ip4 := r.IP.To4() + if ip4 == nil { + if v4only { + continue } - if err != nil { - return nil, err + rmaddr, err = ma.NewMultiaddr("/ip6/" + r.IP.String()) + } else { + if v6only { + continue } - resolved = append(resolved, rmaddr) + rmaddr, err = ma.NewMultiaddr("/ip4/" + ip4.String()) } - case dnsaddrProtocol.Code: - // The dnsaddr resolver is a bit more complicated. We: - // - // 1. Lookup the dnsaddr txt record on _dnsaddr.DOMAIN.TLD - // 2. Take everything _after_ the `/dnsaddr/DOMAIN.TLD` - // part of the multiaddr. - // 3. Find the dnsaddr records (if any) with suffixes - // matching the result of step 2. - - // First, lookup the TXT record - records, err := rslv.LookupTXT(ctx, "_dnsaddr."+value) if err != nil { return nil, err } + resolved = append(resolved, rmaddr) + } + case dnsaddrProtocol.Code: + // The dnsaddr resolver is a bit more complicated. We: + // + // 1. Lookup the dnsaddr txt record on _dnsaddr.DOMAIN.TLD + // 2. Take everything _after_ the `/dnsaddr/DOMAIN.TLD` + // part of the multiaddr. + // 3. Find the dnsaddr records (if any) with suffixes + // matching the result of step 2. + + // First, lookup the TXT record + records, err := rslv.LookupTXT(ctx, "_dnsaddr."+value) + if err != nil { + return nil, err + } - // Then, calculate the length of the suffix we're - // looking for. - length := 0 - if maddr != nil { - length = addrLen(maddr) + // Then, calculate the length of the suffix we're + // looking for. + length := 0 + if postDNS != nil { + length = addrLen(postDNS) + } + + for _, r := range records { + // Ignore non dnsaddr TXT records. + if !strings.HasPrefix(r, dnsaddrTXTPrefix) { + continue } - for _, r := range records { - // Ignore non dnsaddr TXT records. - if !strings.HasPrefix(r, dnsaddrTXTPrefix) { - continue - } + // Extract and decode the multiaddr. + rmaddr, err := ma.NewMultiaddr(r[len(dnsaddrTXTPrefix):]) + if err != nil { + // discard multiaddrs we don't understand. + // XXX: Is this right? It's the best we + // can do for now, really. + continue + } - // Extract and decode the multiaddr. - rmaddr, err := ma.NewMultiaddr(r[len(dnsaddrTXTPrefix):]) - if err != nil { - // discard multiaddrs we don't understand. - // XXX: Is this right? It's the best we - // can do for now, really. + // If we have a suffix to match on. + if postDNS != nil { + // Make sure the new address is at least + // as long as the suffix we're looking + // for. + rmlen := addrLen(rmaddr) + if rmlen < length { + // not long enough. continue } - // If we have a suffix to match on. - if maddr != nil { - // Make sure the new address is at least - // as long as the suffix we're looking - // for. - rmlen := addrLen(rmaddr) - if rmlen < length { - // not long enough. - continue - } - - // Matches everything after the /dnsaddr/... with the end of the - // dnsaddr record: - // - // v----------rmlen-----------------v - // /ip4/1.2.3.4/tcp/1234/p2p/QmFoobar - // /p2p/QmFoobar - // ^--(rmlen - length)--^---length--^ - if !maddr.Equal(offset(rmaddr, rmlen-length)) { - continue - } + // Matches everything after the /dnsaddr/... with the end of the + // dnsaddr record: + // + // v----------rmlen-----------------v + // /ip4/1.2.3.4/tcp/1234/p2p/QmFoobar + // /p2p/QmFoobar + // ^--(rmlen - length)--^---length--^ + if !postDNS.Equal(offset(rmaddr, rmlen-length)) { + continue } - - resolved = append(resolved, rmaddr) } - // consumes the rest of the multiaddr as part of the "match" process. - maddr = nil - default: - panic("unreachable") + // remove the suffix from the multiaddr, we'll add it back at the end. + if postDNS != nil { + rmaddr = rmaddr.Decapsulate(postDNS) + } + resolved = append(resolved, rmaddr) } + default: + panic("unreachable") + } + + if len(resolved) == 0 { + return nil, nil + } - if len(resolved) == 0 { - return nil, nil - } else if len(results) == 0 { - results = resolved - } else { - // We take the cross product here as we don't have any - // better way to represent "ORs" in multiaddrs. For - // example, `/dns/foo.com/p2p-circuit/dns/bar.com` could - // resolve to: - // - // * /ip4/1.1.1.1/p2p-circuit/ip4/2.1.1.1 - // * /ip4/1.1.1.1/p2p-circuit/ip4/2.1.1.2 - // * /ip4/1.1.1.2/p2p-circuit/ip4/2.1.1.1 - // * /ip4/1.1.1.2/p2p-circuit/ip4/2.1.1.2 - results = cross(results, resolved) + if preDNS != nil { + for i, m := range resolved { + resolved[i] = preDNS.Encapsulate(m) + } + } + if postDNS != nil { + for i, m := range resolved { + resolved[i] = m.Encapsulate(postDNS) } } - return results, nil + return resolved, nil } func (r *Resolver) LookupIPAddr(ctx context.Context, domain string) ([]net.IPAddr, error) { diff --git a/resolve_test.go b/resolve_test.go index b2ccb47..f09b939 100644 --- a/resolve_test.go +++ b/resolve_test.go @@ -8,25 +8,33 @@ import ( ma "github.com/multiformats/go-multiaddr" ) -var ip4a = net.IPAddr{IP: net.ParseIP("192.0.2.1")} -var ip4b = net.IPAddr{IP: net.ParseIP("192.0.2.2")} -var ip6a = net.IPAddr{IP: net.ParseIP("2001:db8::a3")} -var ip6b = net.IPAddr{IP: net.ParseIP("2001:db8::a4")} - -var ip4ma = ma.StringCast("/ip4/" + ip4a.IP.String()) -var ip4mb = ma.StringCast("/ip4/" + ip4b.IP.String()) -var ip6ma = ma.StringCast("/ip6/" + ip6a.IP.String()) -var ip6mb = ma.StringCast("/ip6/" + ip6b.IP.String()) - -var txtmc = ma.Join(ip4ma, ma.StringCast("/tcp/123/http")) -var txtmd = ma.Join(ip4ma, ma.StringCast("/tcp/123")) -var txtme = ma.Join(ip4ma, ma.StringCast("/tcp/789/http")) - -var txta = "dnsaddr=" + ip4ma.String() -var txtb = "dnsaddr=" + ip6ma.String() -var txtc = "dnsaddr=" + txtmc.String() -var txtd = "dnsaddr=" + txtmd.String() -var txte = "dnsaddr=" + txtme.String() +var ( + ip4a = net.IPAddr{IP: net.ParseIP("192.0.2.1")} + ip4b = net.IPAddr{IP: net.ParseIP("192.0.2.2")} + ip6a = net.IPAddr{IP: net.ParseIP("2001:db8::a3")} + ip6b = net.IPAddr{IP: net.ParseIP("2001:db8::a4")} +) + +var ( + ip4ma = ma.StringCast("/ip4/" + ip4a.IP.String()) + ip4mb = ma.StringCast("/ip4/" + ip4b.IP.String()) + ip6ma = ma.StringCast("/ip6/" + ip6a.IP.String()) + ip6mb = ma.StringCast("/ip6/" + ip6b.IP.String()) +) + +var ( + txtmc = ma.Join(ip4ma, ma.StringCast("/tcp/123/http")) + txtmd = ma.Join(ip4ma, ma.StringCast("/tcp/123")) + txtme = ma.Join(ip4ma, ma.StringCast("/tcp/789/http")) +) + +var ( + txta = "dnsaddr=" + ip4ma.String() + txtb = "dnsaddr=" + ip6ma.String() + txtc = "dnsaddr=" + txtmc.String() + txtd = "dnsaddr=" + txtmd.String() + txte = "dnsaddr=" + txtme.String() +) func makeResolver() *Resolver { mock := &MockResolver{ @@ -96,7 +104,7 @@ func TestSimpleIPResolve(t *testing.T) { } } -func TestResolveMultiple(t *testing.T) { +func TestResolveOnlyOnce(t *testing.T) { ctx := context.Background() resolver := makeResolver() @@ -104,28 +112,45 @@ func TestResolveMultiple(t *testing.T) { if err != nil { t.Error(err) } + for i, x := range []ma.Multiaddr{ip4ma, ip4mb} { - for j, y := range []ma.Multiaddr{ip6ma, ip6mb} { - expected := ma.Join(x, ma.StringCast("/quic"), y) - actual := addrs[i*2+j] - if !expected.Equal(actual) { - t.Fatalf("expected %s, got %s", expected, actual) + expected := ma.Join(x, ma.StringCast("/quic/dns6/example.com")) + actual := addrs[i] + if !expected.Equal(actual) { + t.Fatalf("expected %s, got %s", expected, actual) + } + } +} + +func resolveAllDNS(ctx context.Context, resolver *Resolver, in ma.Multiaddr) ([]ma.Multiaddr, error) { + var inAddrs []ma.Multiaddr + outAddrs := []ma.Multiaddr{in} + + for len(inAddrs) != len(outAddrs) { + inAddrs = outAddrs + outAddrs = nil + for _, inAddr := range inAddrs { + addrs, err := resolver.Resolve(ctx, inAddr) + if err != nil { + return nil, err } + outAddrs = append(outAddrs, addrs...) } } + return outAddrs, nil } -func TestResolveMultipleAdjacent(t *testing.T) { +func TestResolveMultiple(t *testing.T) { ctx := context.Background() resolver := makeResolver() - addrs, err := resolver.Resolve(ctx, ma.StringCast("/dns4/example.com/dns6/example.com")) + addrs, err := resolveAllDNS(ctx, resolver, ma.StringCast("/dns4/example.com/quic/dns6/example.com")) if err != nil { t.Error(err) } for i, x := range []ma.Multiaddr{ip4ma, ip4mb} { for j, y := range []ma.Multiaddr{ip6ma, ip6mb} { - expected := ma.Join(x, y) + expected := ma.Join(x, ma.StringCast("/quic"), y) actual := addrs[i*2+j] if !expected.Equal(actual) { t.Fatalf("expected %s, got %s", expected, actual) @@ -138,7 +163,7 @@ func TestResolveMultipleSandwitch(t *testing.T) { ctx := context.Background() resolver := makeResolver() - addrs, err := resolver.Resolve(ctx, ma.StringCast("/quic/dns4/example.com/dns6/example.com/http")) + addrs, err := resolveAllDNS(ctx, resolver, ma.StringCast("/quic/dns4/example.com/dns6/example.com/http")) if err != nil { t.Error(err) } diff --git a/util.go b/util.go index b124639..1b72557 100644 --- a/util.go +++ b/util.go @@ -42,16 +42,3 @@ func offset(maddr ma.Multiaddr, offset int) ma.Multiaddr { }) return after } - -// takes the cross product of two sets of multiaddrs -// -// assumes `a` is non-empty. -func cross(a, b []ma.Multiaddr) []ma.Multiaddr { - res := make([]ma.Multiaddr, 0, len(a)*len(b)) - for _, x := range a { - for _, y := range b { - res = append(res, x.Encapsulate(y)) - } - } - return res -}