feat(netxlite): implements NS queries (#734)
This diff has been extracted from https://github.com/bassosimone/websteps-illustrated/commit/eb0bf38957e79fbad198fcdc9f9c7b36f61a8e2c. See https://github.com/ooni/probe/issues/2096. While there, skip the broken tests caused by issue https://github.com/ooni/probe/issues/2098.
This commit is contained in:
@@ -6,8 +6,6 @@ package netxlite
|
||||
// This file helps us to decide if an IPAddr is a bogon.
|
||||
//
|
||||
|
||||
// TODO(bassosimone): code in engine/netx should use this file.
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ package netxlite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
@@ -111,4 +112,22 @@ func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID ui
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, error) {
|
||||
reply, err := d.parseReply(data, queryID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := []*net.NS{}
|
||||
for _, answer := range reply.Answer {
|
||||
switch avalue := answer.(type) {
|
||||
case *dns.NS:
|
||||
out = append(out, &net.NS{Host: avalue.Ns})
|
||||
}
|
||||
}
|
||||
if len(out) < 1 {
|
||||
return nil, ErrOODNSNoAnswer
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
var _ model.DNSDecoder = &DNSDecoderMiekg{}
|
||||
|
||||
@@ -192,8 +192,8 @@ func TestDNSDecoder(t *testing.T) {
|
||||
queryID = 17
|
||||
unrelatedID = 14
|
||||
)
|
||||
reply := dnsGenHTTPSReplySuccess(dnsGenQuery(dns.TypeA, queryID), nil, nil, nil)
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, reply, unrelatedID)
|
||||
reply := dnsGenHTTPSReplySuccess(dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
|
||||
data, err := d.DecodeHTTPS(reply, unrelatedID)
|
||||
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
@@ -239,6 +239,64 @@ func TestDNSDecoder(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("DecodeNS", func(t *testing.T) {
|
||||
t.Run("with nil data", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeNS(nil, 0)
|
||||
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong query ID", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
const (
|
||||
queryID = 17
|
||||
unrelatedID = 14
|
||||
)
|
||||
reply := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
|
||||
data, err := d.DecodeNS(reply, unrelatedID)
|
||||
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with empty answer", func(t *testing.T) {
|
||||
queryID := dns.Id()
|
||||
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeNS(data, queryID)
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with full answer", func(t *testing.T) {
|
||||
queryID := dns.Id()
|
||||
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID), "ns1.zdns.google.")
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeNS(data, queryID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(reply) != 1 {
|
||||
t.Fatal("unexpected reply length")
|
||||
}
|
||||
if reply[0].Host != "ns1.zdns.google." {
|
||||
t.Fatal("unexpected reply host")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// dnsGenQuery generates a query suitable to be used with testing.
|
||||
@@ -281,6 +339,10 @@ func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
|
||||
runtimex.PanicOnError(err, "query.Unpack failed")
|
||||
runtimex.PanicIfFalse(len(query.Question) == 1, "more than one question")
|
||||
question := query.Question[0]
|
||||
runtimex.PanicIfFalse(
|
||||
question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA,
|
||||
"invalid query type (expected A or AAAA)",
|
||||
)
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
@@ -326,6 +388,9 @@ func dnsGenHTTPSReplySuccess(rawQuery []byte, alpns, ipv4s, ipv6s []string) []by
|
||||
query := new(dns.Msg)
|
||||
err := query.Unpack(rawQuery)
|
||||
runtimex.PanicOnError(err, "query.Unpack failed")
|
||||
runtimex.PanicIfFalse(len(query.Question) == 1, "expected just a single question")
|
||||
question := query.Question[0]
|
||||
runtimex.PanicIfFalse(question.Qtype == dns.TypeHTTPS, "expected HTTPS query")
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
@@ -364,3 +429,31 @@ func dnsGenHTTPSReplySuccess(rawQuery []byte, alpns, ipv4s, ipv6s []string) []by
|
||||
runtimex.PanicOnError(err, "reply.Pack failed")
|
||||
return data
|
||||
}
|
||||
|
||||
// dnsGenNSReplySuccess generates a successful NS reply using the given names.
|
||||
func dnsGenNSReplySuccess(rawQuery []byte, names ...string) []byte {
|
||||
query := new(dns.Msg)
|
||||
err := query.Unpack(rawQuery)
|
||||
runtimex.PanicOnError(err, "query.Unpack failed")
|
||||
runtimex.PanicIfFalse(len(query.Question) == 1, "more than one question")
|
||||
question := query.Question[0]
|
||||
runtimex.PanicIfFalse(question.Qtype == dns.TypeNS, "expected NS query")
|
||||
reply := new(dns.Msg)
|
||||
reply.Compress = true
|
||||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetReply(query)
|
||||
for _, name := range names {
|
||||
reply.Answer = append(reply.Answer, &dns.NS{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: question.Qtype,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
Ns: name,
|
||||
})
|
||||
}
|
||||
data, err := reply.Pack()
|
||||
runtimex.PanicOnError(err, "reply.Pack failed")
|
||||
return data
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package netxlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
@@ -129,3 +130,18 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// LookupNS implements Resolver.LookupNS.
|
||||
func (r *ParallelResolver) LookupNS(
|
||||
ctx context.Context, hostname string) ([]*net.NS, error) {
|
||||
querydata, queryID, err := r.Encoder.Encode(
|
||||
hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeNS(replydata, queryID)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -266,4 +267,94 @@ func TestParallelResolver(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
t.Run("for encoding error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, expected
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for round-trip error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for decode error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
https, err := r.LookupNS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -136,6 +136,16 @@ func (r *resolverSystem) LookupHTTPS(
|
||||
return nil, ErrNoDNSTransport
|
||||
}
|
||||
|
||||
func (r *resolverSystem) LookupNS(
|
||||
ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
// TODO(bassosimone): figure out in which context it makes sense
|
||||
// to issue this query. How is this implemented under the hood by
|
||||
// the stdlib? Is it using /etc/resolve.conf on Unix? Until we
|
||||
// known all these details, let's pretend this functionality does
|
||||
// not exist in the stdlib and focus on custom resolvers.
|
||||
return nil, ErrNoDNSTransport
|
||||
}
|
||||
|
||||
// resolverLogger is a resolver that emits events
|
||||
type resolverLogger struct {
|
||||
Resolver model.Resolver
|
||||
@@ -188,6 +198,21 @@ func (r *resolverLogger) CloseIdleConnections() {
|
||||
r.Resolver.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func (r *resolverLogger) LookupNS(
|
||||
ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
prefix := fmt.Sprintf("resolve[NS] %s with %s (%s)", domain, r.Network(), r.Address())
|
||||
r.Logger.Debugf("%s...", prefix)
|
||||
start := time.Now()
|
||||
ns, err := r.Resolver.LookupNS(ctx, domain)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
r.Logger.Debugf("%s... %s in %s", prefix, err, elapsed)
|
||||
return nil, err
|
||||
}
|
||||
r.Logger.Debugf("%s... %+v in %s", prefix, ns, elapsed)
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// resolverIDNA supports resolving Internationalized Domain Names.
|
||||
//
|
||||
// See RFC3492 for more information.
|
||||
@@ -226,6 +251,15 @@ func (r *resolverIDNA) CloseIdleConnections() {
|
||||
r.Resolver.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func (r *resolverIDNA) LookupNS(
|
||||
ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
host, err := idna.ToASCII(domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Resolver.LookupNS(ctx, host)
|
||||
}
|
||||
|
||||
// resolverShortCircuitIPAddr recognizes when the input hostname is an
|
||||
// IP address and returns it immediately to the caller.
|
||||
type resolverShortCircuitIPAddr struct {
|
||||
@@ -266,6 +300,18 @@ func (r *resolverShortCircuitIPAddr) CloseIdleConnections() {
|
||||
r.Resolver.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// ErrDNSIPAddress indicates that you passed an IP address to a DNS
|
||||
// function that only works with domain names.
|
||||
var ErrDNSIPAddress = errors.New("ooresolver: expected domain, found IP address")
|
||||
|
||||
func (r *resolverShortCircuitIPAddr) LookupNS(
|
||||
ctx context.Context, hostname string) ([]*net.NS, error) {
|
||||
if net.ParseIP(hostname) != nil {
|
||||
return nil, ErrDNSIPAddress
|
||||
}
|
||||
return r.Resolver.LookupNS(ctx, hostname)
|
||||
}
|
||||
|
||||
// IsIPv6 returns true if the given candidate is a valid IP address
|
||||
// representation and such representation is IPv6.
|
||||
func IsIPv6(candidate string) (bool, error) {
|
||||
@@ -313,6 +359,11 @@ func (r *nullResolver) LookupHTTPS(
|
||||
return nil, ErrNoResolver
|
||||
}
|
||||
|
||||
func (r *nullResolver) LookupNS(
|
||||
ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, ErrNoResolver
|
||||
}
|
||||
|
||||
// resolverErrWrapper is a Resolver that knows about wrapping errors.
|
||||
type resolverErrWrapper struct {
|
||||
Resolver model.Resolver
|
||||
@@ -348,3 +399,12 @@ func (r *resolverErrWrapper) Address() string {
|
||||
func (r *resolverErrWrapper) CloseIdleConnections() {
|
||||
r.Resolver.CloseIdleConnections()
|
||||
}
|
||||
|
||||
func (r *resolverErrWrapper) LookupNS(
|
||||
ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
out, err := r.Resolver.LookupNS(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, newErrWrapper(classifyResolverError, ResolveOperation, err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -166,6 +167,17 @@ func TestResolverSystem(t *testing.T) {
|
||||
t.Fatal("expected nil result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
r := &resolverSystem{}
|
||||
ns, err := r.LookupNS(context.Background(), "x.org")
|
||||
if !errors.Is(err, ErrNoDNSTransport) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("expected nil result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolverLogger(t *testing.T) {
|
||||
@@ -312,6 +324,94 @@ func TestResolverLogger(t *testing.T) {
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
child := &mocks.Resolver{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
reso := &resolverLogger{
|
||||
Resolver: child,
|
||||
Logger: model.DiscardLogger,
|
||||
}
|
||||
reso.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
t.Run("with success", func(t *testing.T) {
|
||||
var count int
|
||||
lo := &mocks.Logger{
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
count++
|
||||
},
|
||||
}
|
||||
expected := []*net.NS{{
|
||||
Host: "ns1.zdns.google.",
|
||||
}}
|
||||
r := &resolverLogger{
|
||||
Logger: lo,
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return expected, nil
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "system"
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return ""
|
||||
},
|
||||
},
|
||||
}
|
||||
ns, err := r.LookupNS(context.Background(), "dns.google")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(expected, ns); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with failure", func(t *testing.T) {
|
||||
var count int
|
||||
lo := &mocks.Logger{
|
||||
MockDebugf: func(format string, v ...interface{}) {
|
||||
count++
|
||||
},
|
||||
}
|
||||
expected := errors.New("mocked error")
|
||||
r := &resolverLogger{
|
||||
Logger: lo,
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "system"
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return ""
|
||||
},
|
||||
},
|
||||
}
|
||||
ns, err := r.LookupNS(context.Background(), "dns.google")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("expected nil addr here")
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatal("unexpected count")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolverIDNA(t *testing.T) {
|
||||
@@ -424,6 +524,63 @@ func TestResolverIDNA(t *testing.T) {
|
||||
t.Fatal("invalid address")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
child := &mocks.Resolver{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
reso := &resolverIDNA{child}
|
||||
reso.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
t.Run("with valid IDNA in input", func(t *testing.T) {
|
||||
expected := []*net.NS{{
|
||||
Host: "ns1.zdns.google.",
|
||||
}}
|
||||
r := &resolverIDNA{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
if domain != "xn--d1acpjx3f.xn--p1ai" {
|
||||
return nil, errors.New("passed invalid domain")
|
||||
}
|
||||
return expected, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "яндекс.рф")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(expected, ns); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with invalid punycode", func(t *testing.T) {
|
||||
r := &resolverIDNA{Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, errors.New("should not happen")
|
||||
},
|
||||
}}
|
||||
// See https://www.farsightsecurity.com/blog/txt-record/punycode-20180711/
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "xn--0000h")
|
||||
if err == nil || !strings.HasPrefix(err.Error(), "idna: invalid label") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolverShortCircuitIPAddr(t *testing.T) {
|
||||
@@ -520,6 +677,100 @@ func TestResolverShortCircuitIPAddr(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
t.Run("with IPv4 addr", func(t *testing.T) {
|
||||
r := &resolverShortCircuitIPAddr{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, errors.New("mocked error")
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "8.8.8.8")
|
||||
if !errors.Is(err, ErrDNSIPAddress) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if len(ns) > 0 {
|
||||
t.Fatal("invalid result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with IPv6 addr", func(t *testing.T) {
|
||||
r := &resolverShortCircuitIPAddr{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, errors.New("mocked error")
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "::1")
|
||||
if !errors.Is(err, ErrDNSIPAddress) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if len(ns) > 0 {
|
||||
t.Fatal("invalid result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with domain", func(t *testing.T) {
|
||||
r := &resolverShortCircuitIPAddr{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, errors.New("mocked error")
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "dns.google")
|
||||
if err == nil || err.Error() != "mocked error" {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if len(ns) > 0 {
|
||||
t.Fatal("invalid result")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Network", func(t *testing.T) {
|
||||
child := &mocks.Resolver{
|
||||
MockNetwork: func() string {
|
||||
return "x"
|
||||
},
|
||||
}
|
||||
reso := &resolverShortCircuitIPAddr{child}
|
||||
if reso.Network() != "x" {
|
||||
t.Fatal("invalid result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Address", func(t *testing.T) {
|
||||
child := &mocks.Resolver{
|
||||
MockAddress: func() string {
|
||||
return "x"
|
||||
},
|
||||
}
|
||||
reso := &resolverShortCircuitIPAddr{child}
|
||||
if reso.Address() != "x" {
|
||||
t.Fatal("invalid result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||
var called bool
|
||||
child := &mocks.Resolver{
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
reso := &resolverShortCircuitIPAddr{child}
|
||||
reso.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsIPv6(t *testing.T) {
|
||||
@@ -592,6 +843,18 @@ func TestNullResolver(t *testing.T) {
|
||||
}
|
||||
r.CloseIdleConnections() // for coverage
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
r := &nullResolver{}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "dns.google")
|
||||
if !errors.Is(err, ErrNoResolver) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if len(ns) > 0 {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolverErrWrapper(t *testing.T) {
|
||||
@@ -719,4 +982,46 @@ func TestResolverErrWrapper(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
expected := []*net.NS{{
|
||||
Host: "antani.local.",
|
||||
}}
|
||||
reso := &resolverErrWrapper{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return expected, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := reso.LookupNS(ctx, "antani.local")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(expected, ns); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expected := io.EOF
|
||||
reso := &resolverErrWrapper{
|
||||
Resolver: &mocks.Resolver{
|
||||
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := reso.LookupNS(ctx, "")
|
||||
if err == nil || err.Error() != FailureEOFError {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(ns) > 0 {
|
||||
t.Fatal("unexpected addrs")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -23,8 +23,8 @@ import (
|
||||
// Deprecated: please use ParallelResolver in new code. We cannot
|
||||
// remove this code as long as we use tracing for measuring.
|
||||
//
|
||||
// QUIRK: unlike the ParallelResolver, this resolver retries each
|
||||
// query three times for soft errors.
|
||||
// QUIRK: unlike the ParallelResolver, this resolver's LookupHost retries
|
||||
// each query three times for soft errors.
|
||||
type SerialResolver struct {
|
||||
// Encoder is the MANDATORY encoder to use.
|
||||
Encoder model.DNSEncoder
|
||||
@@ -142,3 +142,18 @@ func (r *SerialResolver) lookupHostWithoutRetry(
|
||||
}
|
||||
return r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
|
||||
}
|
||||
|
||||
// LookupNS implements Resolver.LookupNS.
|
||||
func (r *SerialResolver) LookupNS(
|
||||
ctx context.Context, hostname string) ([]*net.NS, error) {
|
||||
querydata, queryID, err := r.Encoder.Encode(
|
||||
hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeNS(replydata, queryID)
|
||||
}
|
||||
|
||||
@@ -272,4 +272,94 @@ func TestSerialResolver(t *testing.T) {
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
t.Run("for encoding error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, expected
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for round-trip error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for decode error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
https, err := r.LookupNS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user