fix(netxlite): reject replies with wrong queryID (#732)

This diff has been extracted from c2f7ccab0e

See https://github.com/ooni/probe/issues/2096

While there, export DecodeReply to decode a raw reply without
interpreting the Rcode or parsing the results, which seems a
nice extra feature to have to more flexibly parse DNS replies
in other parts of the codebase.
This commit is contained in:
Simone Basso 2022-05-14 19:38:46 +02:00 committed by GitHub
parent f5b801ae95
commit 9d2301cae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 262 additions and 129 deletions

View File

@ -1,20 +1,30 @@
package mocks package mocks
import "github.com/ooni/probe-cli/v3/internal/model" import (
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
)
// DNSDecoder allows mocking dnsx.DNSDecoder. // DNSDecoder allows mocking dnsx.DNSDecoder.
type DNSDecoder struct { type DNSDecoder struct {
MockDecodeLookupHost func(qtype uint16, reply []byte) ([]string, error) MockDecodeLookupHost func(qtype uint16, reply []byte, queryID uint16) ([]string, error)
MockDecodeHTTPS func(reply []byte) (*model.HTTPSSvc, error) MockDecodeHTTPS func(reply []byte, queryID uint16) (*model.HTTPSSvc, error)
MockDecodeReply func(reply []byte) (*dns.Msg, error)
} }
// DecodeLookupHost calls MockDecodeLookupHost. // DecodeLookupHost calls MockDecodeLookupHost.
func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte) ([]string, error) { func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
return e.MockDecodeLookupHost(qtype, reply) return e.MockDecodeLookupHost(qtype, reply, queryID)
} }
// DecodeHTTPS calls MockDecodeHTTPS. // DecodeHTTPS calls MockDecodeHTTPS.
func (e *DNSDecoder) DecodeHTTPS(reply []byte) (*model.HTTPSSvc, error) { func (e *DNSDecoder) DecodeHTTPS(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return e.MockDecodeHTTPS(reply) return e.MockDecodeHTTPS(reply, queryID)
}
// DecodeReply calls MockDecodeReply.
func (e *DNSDecoder) DecodeReply(reply []byte) (*dns.Msg, error) {
return e.MockDecodeReply(reply)
} }

View File

@ -12,11 +12,11 @@ func TestDNSDecoder(t *testing.T) {
t.Run("DecodeLookupHost", func(t *testing.T) { t.Run("DecodeLookupHost", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
e := &DNSDecoder{ e := &DNSDecoder{
MockDecodeLookupHost: func(qtype uint16, reply []byte) ([]string, error) { MockDecodeLookupHost: func(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
return nil, expected return nil, expected
}, },
} }
out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17)) out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17), dns.Id())
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
@ -28,11 +28,27 @@ func TestDNSDecoder(t *testing.T) {
t.Run("DecodeHTTPS", func(t *testing.T) { t.Run("DecodeHTTPS", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
e := &DNSDecoder{ e := &DNSDecoder{
MockDecodeHTTPS: func(reply []byte) (*model.HTTPSSvc, error) { MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected return nil, expected
}, },
} }
out, err := e.DecodeHTTPS(make([]byte, 17)) out, err := e.DecodeHTTPS(make([]byte, 17), dns.Id())
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeReply", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeReply: func(reply []byte) (*dns.Msg, error) {
return nil, expected
},
}
out, err := e.DecodeReply(make([]byte, 17))
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }

View File

@ -2,10 +2,10 @@ package mocks
// DNSEncoder allows mocking dnsx.DNSEncoder. // DNSEncoder allows mocking dnsx.DNSEncoder.
type DNSEncoder struct { type DNSEncoder struct {
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, error) MockEncode func(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
} }
// Encode calls MockEncode. // Encode calls MockEncode.
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) { func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return e.MockEncode(domain, qtype, padding) return e.MockEncode(domain, qtype, padding)
} }

View File

@ -11,16 +11,19 @@ func TestDNSEncoder(t *testing.T) {
t.Run("Encode", func(t *testing.T) { t.Run("Encode", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
e := &DNSEncoder{ e := &DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, expected return nil, 0, expected
}, },
} }
out, err := e.Encode("dns.google", dns.TypeA, true) out, queryID, err := e.Encode("dns.google", dns.TypeA, true)
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if out != nil { if out != nil {
t.Fatal("unexpected out") t.Fatal("unexpected out")
} }
if queryID != 0 {
t.Fatal("unexpected queryID")
}
}) })
} }

View File

@ -9,6 +9,18 @@ import (
) )
func TestHTTPTransport(t *testing.T) { func TestHTTPTransport(t *testing.T) {
t.Run("Network", func(t *testing.T) {
expected := "quic"
txp := &HTTPTransport{
MockNetwork: func() string {
return expected
},
}
if txp.Network() != expected {
t.Fatal("unexpected network value")
}
})
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
txp := &HTTPTransport{ txp := &HTTPTransport{

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/miekg/dns"
) )
// //
@ -25,6 +26,8 @@ type DNSDecoder interface {
// //
// - data contains the reply bytes read from a DNSTransport // - data contains the reply bytes read from a DNSTransport
// //
// - queryID is the original query ID
//
// Returns: // Returns:
// //
// - on success, a list of IP addrs inside the reply and a nil error // - on success, a list of IP addrs inside the reply and a nil error
@ -33,9 +36,9 @@ type DNSDecoder interface {
// //
// Note that this function will return an error if there is no // Note that this function will return an error if there is no
// IP address inside of the reply. // IP address inside of the reply.
DecodeLookupHost(qtype uint16, data []byte) ([]string, error) DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error)
// DecodeHTTPS decodes an HTTPS reply. // DecodeHTTPS is like DecodeLookupHost but decodes an HTTPS reply.
// //
// The argument is the reply as read by the DNSTransport. // The argument is the reply as read by the DNSTransport.
// //
@ -46,7 +49,22 @@ type DNSDecoder interface {
// This function will return an error if the HTTPS reply does not // This function will return an error if the HTTPS reply does not
// contain at least a valid ALPN entry. It will not return // contain at least a valid ALPN entry. It will not return
// an error, though, when there are no IPv4/IPv6 hints in the reply. // an error, though, when there are no IPv4/IPv6 hints in the reply.
DecodeHTTPS(data []byte) (*HTTPSSvc, error) DecodeHTTPS(data []byte, queryID uint16) (*HTTPSSvc, error)
// DecodeReply decodes a DNS reply message.
//
// Arguments:
//
// - data is the raw reply
//
// If you use this function, remember that:
//
// 1. the Rcode MAY be nonzero;
//
// 2. the replyID MAY NOT match the original query ID.
//
// That is, this is a very basic parsing method.
DecodeReply(data []byte) (*dns.Msg, error)
} }
// The DNSEncoder encodes DNS queries to bytes // The DNSEncoder encodes DNS queries to bytes
@ -61,9 +79,10 @@ type DNSEncoder interface {
// //
// - padding is whether to add padding to the query. // - padding is whether to add padding to the query.
// //
// On success, this function returns a valid byte array and // On success, this function returns a valid byte array, the queryID, and
// a nil error. On failure, we have an error and the byte array is nil. // a nil error. On failure, we have a non-nil error, a nil arrary and a zero
Encode(domain string, qtype uint16, padding bool) ([]byte, error) // query ID.
Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
} }
// DNSTransport represents an abstract DNS transport. // DNSTransport represents an abstract DNS transport.

View File

@ -287,6 +287,9 @@ func classifyResolverError(err error) string {
if errors.Is(err, ErrOODNSServfail) { if errors.Is(err, ErrOODNSServfail) {
return FailureDNSServfailError return FailureDNSServfailError
} }
if errors.Is(err, ErrDNSReplyWithWrongQueryID) {
return FailureDNSReplyWithWrongQueryID
}
return classifyGenericError(err) return classifyGenericError(err)
} }

View File

@ -275,6 +275,12 @@ func TestClassifyResolverError(t *testing.T) {
} }
}) })
t.Run("for dns reply with wrong queryID", func(t *testing.T) {
if classifyResolverError(ErrDNSReplyWithWrongQueryID) != FailureDNSReplyWithWrongQueryID {
t.Fatal("unexpected result")
}
})
t.Run("for another kind of error", func(t *testing.T) { t.Run("for another kind of error", func(t *testing.T) {
if classifyResolverError(io.EOF) != FailureEOFError { if classifyResolverError(io.EOF) != FailureEOFError {
t.Fatal("unexpected result") t.Fatal("unexpected result")

View File

@ -1,6 +1,8 @@
package netxlite package netxlite
import ( import (
"errors"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
@ -8,11 +10,26 @@ import (
// DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder. // DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder.
type DNSDecoderMiekg struct{} type DNSDecoderMiekg struct{}
func (d *DNSDecoderMiekg) parseReply(data []byte) (*dns.Msg, error) { // ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID.
var ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID)
// DecodeReply implements model.DNSDecoder.DecodeReply
func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) {
reply := new(dns.Msg) reply := new(dns.Msg)
if err := reply.Unpack(data); err != nil { if err := reply.Unpack(data); err != nil {
return nil, err return nil, err
} }
return reply, nil
}
func (d *DNSDecoderMiekg) parseReply(data []byte, queryID uint16) (*dns.Msg, error) {
reply, err := d.DecodeReply(data)
if err != nil {
return nil, err
}
if reply.Id != queryID {
return nil, ErrDNSReplyWithWrongQueryID
}
// TODO(bassosimone): map more errors to net.DNSError names // TODO(bassosimone): map more errors to net.DNSError names
// TODO(bassosimone): add support for lame referral. // TODO(bassosimone): add support for lame referral.
switch reply.Rcode { switch reply.Rcode {
@ -29,8 +46,8 @@ func (d *DNSDecoderMiekg) parseReply(data []byte) (*dns.Msg, error) {
} }
} }
func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte) (*model.HTTPSSvc, error) { func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPSSvc, error) {
reply, err := d.parseReply(data) reply, err := d.parseReply(data, queryID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -64,8 +81,8 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte) (*model.HTTPSSvc, error) {
return out, nil return out, nil
} }
func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte) ([]string, error) { func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error) {
reply, err := d.parseReply(data) reply, err := d.parseReply(data, queryID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -8,15 +8,32 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/runtimex"
) )
func TestDNSDecoder(t *testing.T) { func TestDNSDecoder(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
t.Run("UnpackError", func(t *testing.T) { t.Run("UnpackError", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(dns.TypeA, nil) data, err := d.DecodeLookupHost(dns.TypeA, nil, 0)
if err == nil { if err == nil || err.Error() != "dns: overflow unpacking uint16" {
t.Fatal("expected an error here") t.Fatal("unexpected error", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("wrong query ID", func(t *testing.T) {
d := &DNSDecoderMiekg{}
const (
queryID = 17
unrelatedID = 14
)
reply := dnsGenLookupHostReplySuccess(dnsGenQuery(dns.TypeA, queryID))
data, err := d.DecodeLookupHost(dns.TypeA, reply, unrelatedID)
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
t.Fatal("unexpected error", err)
} }
if data != nil { if data != nil {
t.Fatal("expected nil data here") t.Fatal("expected nil data here")
@ -25,8 +42,9 @@ func TestDNSDecoder(t *testing.T) {
t.Run("NXDOMAIN", func(t *testing.T) { t.Run("NXDOMAIN", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost( queryID := dns.Id()
dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeNameError)) data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
dnsGenQuery(dns.TypeA, queryID), dns.RcodeNameError), queryID)
if err == nil || !strings.HasSuffix(err.Error(), "no such host") { if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
@ -37,8 +55,9 @@ func TestDNSDecoder(t *testing.T) {
t.Run("Refused", func(t *testing.T) { t.Run("Refused", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost( queryID := dns.Id()
dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeRefused)) data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
dnsGenQuery(dns.TypeA, queryID), dns.RcodeRefused), queryID)
if !errors.Is(err, ErrOODNSRefused) { if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
@ -49,8 +68,9 @@ func TestDNSDecoder(t *testing.T) {
t.Run("Servfail", func(t *testing.T) { t.Run("Servfail", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost( queryID := dns.Id()
dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeServerFailure)) data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
dnsGenQuery(dns.TypeA, queryID), dns.RcodeServerFailure), queryID)
if !errors.Is(err, ErrOODNSServfail) { if !errors.Is(err, ErrOODNSServfail) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
@ -61,7 +81,9 @@ func TestDNSDecoder(t *testing.T) {
t.Run("no address", func(t *testing.T) { t.Run("no address", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeA)) queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID)), queryID)
if !errors.Is(err, ErrOODNSNoAnswer) { if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
@ -72,8 +94,9 @@ func TestDNSDecoder(t *testing.T) {
t.Run("decode A", func(t *testing.T) { t.Run("decode A", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost( queryID := dns.Id()
dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8")) data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.8.8"), queryID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -90,8 +113,9 @@ func TestDNSDecoder(t *testing.T) {
t.Run("decode AAAA", func(t *testing.T) { t.Run("decode AAAA", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost( queryID := dns.Id()
dns.TypeAAAA, dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1")) data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -108,8 +132,9 @@ func TestDNSDecoder(t *testing.T) {
t.Run("unexpected A reply", func(t *testing.T) { t.Run("unexpected A reply", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost( queryID := dns.Id()
dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1")) data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID)
if !errors.Is(err, ErrOODNSNoAnswer) { if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
@ -120,8 +145,9 @@ func TestDNSDecoder(t *testing.T) {
t.Run("unexpected AAAA reply", func(t *testing.T) { t.Run("unexpected AAAA reply", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost( queryID := dns.Id()
dns.TypeAAAA, dnsGenLookupHostReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4.")) data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.4.4"), queryID)
if !errors.Is(err, ErrOODNSNoAnswer) { if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
@ -139,7 +165,7 @@ func TestDNSDecoder(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
reply, err := d.parseReply(data) reply, err := d.parseReply(data, 0)
if !errors.Is(err, ErrOODNSMisbehaving) { // catch all error if !errors.Is(err, ErrOODNSMisbehaving) { // catch all error
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
@ -151,7 +177,7 @@ func TestDNSDecoder(t *testing.T) {
t.Run("DecodeHTTPS", func(t *testing.T) { t.Run("DecodeHTTPS", func(t *testing.T) {
t.Run("with nil data", func(t *testing.T) { t.Run("with nil data", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(nil) reply, err := d.DecodeHTTPS(nil, 0)
if err == nil || err.Error() != "dns: overflow unpacking uint16" { if err == nil || err.Error() != "dns: overflow unpacking uint16" {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }
@ -160,10 +186,28 @@ func TestDNSDecoder(t *testing.T) {
} }
}) })
t.Run("with empty answer", func(t *testing.T) { t.Run("wrong query ID", func(t *testing.T) {
data := dnsGenHTTPSReplySuccess(t, nil, nil, nil)
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data) const (
queryID = 17
unrelatedID = 14
)
reply := dnsGenHTTPSReplySuccess(dnsGenQuery(dns.TypeA, queryID), nil, nil, nil)
data, err := d.DecodeLookupHost(dns.TypeA, reply, unrelatedID)
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
t.Fatal("unexpected error", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("with empty answer", func(t *testing.T) {
queryID := dns.Id()
data := dnsGenHTTPSReplySuccess(
dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data, queryID)
if !errors.Is(err, ErrOODNSNoAnswer) { if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
@ -173,12 +217,14 @@ func TestDNSDecoder(t *testing.T) {
}) })
t.Run("with full answer", func(t *testing.T) { t.Run("with full answer", func(t *testing.T) {
queryID := dns.Id()
alpn := []string{"h3"} alpn := []string{"h3"}
v4 := []string{"1.1.1.1"} v4 := []string{"1.1.1.1"}
v6 := []string{"::1"} v6 := []string{"::1"}
data := dnsGenHTTPSReplySuccess(t, alpn, v4, v6) data := dnsGenHTTPSReplySuccess(
dnsGenQuery(dns.TypeHTTPS, queryID), alpn, v4, v6)
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data) reply, err := d.DecodeHTTPS(data, queryID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -195,64 +241,73 @@ func TestDNSDecoder(t *testing.T) {
}) })
} }
// dnsGenReplyWithError generates a DNS reply for the given // dnsGenQuery generates a query suitable to be used with testing.
// query type (e.g., dns.TypeA) using code as the Rcode. func dnsGenQuery(qtype uint16, queryID uint16) []byte {
func dnsGenReplyWithError(t *testing.T, qtype uint16, code int) []byte {
question := dns.Question{ question := dns.Question{
Name: dns.Fqdn("x.org"), Name: dns.Fqdn("x.org"),
Qtype: qtype, Qtype: qtype,
Qclass: dns.ClassINET, Qclass: dns.ClassINET,
} }
query := new(dns.Msg) query := new(dns.Msg)
query.Id = dns.Id() query.Id = queryID
query.RecursionDesired = true query.RecursionDesired = true
query.Question = make([]dns.Question, 1) query.Question = make([]dns.Question, 1)
query.Question[0] = question query.Question[0] = question
data, err := query.Pack()
runtimex.PanicOnError(err, "query.Pack failed")
return data
}
// dnsGenReplyWithError generates a DNS reply for the given
// query type (e.g., dns.TypeA) using code as the Rcode.
func dnsGenReplyWithError(rawQuery []byte, code int) []byte {
query := new(dns.Msg)
err := query.Unpack(rawQuery)
runtimex.PanicOnError(err, "query.Unpack failed")
reply := new(dns.Msg) reply := new(dns.Msg)
reply.Compress = true reply.Compress = true
reply.MsgHdr.RecursionAvailable = true reply.MsgHdr.RecursionAvailable = true
reply.SetRcode(query, code) reply.SetRcode(query, code)
data, err := reply.Pack() data, err := reply.Pack()
if err != nil { runtimex.PanicOnError(err, "reply.Pack failed")
t.Fatal(err)
}
return data return data
} }
// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given // dnsGenLookupHostReplySuccess generates a successful DNS reply for the given
// qtype (e.g., dns.TypeA) containing the given ips... in the answer. // qtype (e.g., dns.TypeA) containing the given ips... in the answer.
func dnsGenLookupHostReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte { func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg) query := new(dns.Msg)
query.Id = dns.Id() err := query.Unpack(rawQuery)
query.RecursionDesired = true runtimex.PanicOnError(err, "query.Unpack failed")
query.Question = make([]dns.Question, 1) runtimex.PanicIfFalse(len(query.Question) == 1, "more than one question")
query.Question[0] = question question := query.Question[0]
reply := new(dns.Msg) reply := new(dns.Msg)
reply.Compress = true reply.Compress = true
reply.MsgHdr.RecursionAvailable = true reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query) reply.SetReply(query)
for _, ip := range ips { for _, ip := range ips {
switch qtype { switch question.Qtype {
case dns.TypeA: case dns.TypeA:
if isIPv6(ip) {
continue
}
reply.Answer = append(reply.Answer, &dns.A{ reply.Answer = append(reply.Answer, &dns.A{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"), Name: dns.Fqdn("x.org"),
Rrtype: qtype, Rrtype: question.Qtype,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: 0, Ttl: 0,
}, },
A: net.ParseIP(ip), A: net.ParseIP(ip),
}) })
case dns.TypeAAAA: case dns.TypeAAAA:
if !isIPv6(ip) {
continue
}
reply.Answer = append(reply.Answer, &dns.AAAA{ reply.Answer = append(reply.Answer, &dns.AAAA{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"), Name: dns.Fqdn("x.org"),
Rrtype: qtype, Rrtype: question.Qtype,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: 0, Ttl: 0,
}, },
@ -261,25 +316,16 @@ func dnsGenLookupHostReplySuccess(t *testing.T, qtype uint16, ips ...string) []b
} }
} }
data, err := reply.Pack() data, err := reply.Pack()
if err != nil { runtimex.PanicOnError(err, "reply.Pack failed")
t.Fatal(err)
}
return data return data
} }
// dnsGenHTTPSReplySuccess generates a successful HTTPS response containing // dnsGenHTTPSReplySuccess generates a successful HTTPS response containing
// the given (possibly nil) alpns, ipv4s, and ipv6s. // the given (possibly nil) alpns, ipv4s, and ipv6s.
func dnsGenHTTPSReplySuccess(t *testing.T, alpns, ipv4s, ipv6s []string) []byte { func dnsGenHTTPSReplySuccess(rawQuery []byte, alpns, ipv4s, ipv6s []string) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: dns.TypeHTTPS,
Qclass: dns.ClassINET,
}
query := new(dns.Msg) query := new(dns.Msg)
query.Id = dns.Id() err := query.Unpack(rawQuery)
query.RecursionDesired = true runtimex.PanicOnError(err, "query.Unpack failed")
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg) reply := new(dns.Msg)
reply.Compress = true reply.Compress = true
reply.MsgHdr.RecursionAvailable = true reply.MsgHdr.RecursionAvailable = true
@ -315,8 +361,6 @@ func dnsGenHTTPSReplySuccess(t *testing.T, alpns, ipv4s, ipv6s []string) []byte
answer.Value = append(answer.Value, &dns.SVCBIPv6Hint{Hint: addrs}) answer.Value = append(answer.Value, &dns.SVCBIPv6Hint{Hint: addrs})
} }
data, err := reply.Pack() data, err := reply.Pack()
if err != nil { runtimex.PanicOnError(err, "reply.Pack failed")
t.Fatal(err)
}
return data return data
} }

View File

@ -19,7 +19,7 @@ const (
dnsDNSSECEnabled = true dnsDNSSECEnabled = true
) )
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, error) { func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
question := dns.Question{ question := dns.Question{
Name: dns.Fqdn(domain), Name: dns.Fqdn(domain),
Qtype: qtype, Qtype: qtype,
@ -43,7 +43,8 @@ func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]b
opt.Padding = make([]byte, remainder) opt.Padding = make([]byte, remainder)
query.IsEdns0().Option = append(query.IsEdns0().Option, opt) query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
} }
return query.Pack() data, err := query.Pack()
return data, query.Id, err
} }
var _ model.DNSEncoder = &DNSEncoderMiekg{} var _ model.DNSEncoder = &DNSEncoderMiekg{}

View File

@ -10,7 +10,7 @@ import (
func TestDNSEncoder(t *testing.T) { func TestDNSEncoder(t *testing.T) {
t.Run("encode A", func(t *testing.T) { t.Run("encode A", func(t *testing.T) {
e := &DNSEncoderMiekg{} e := &DNSEncoderMiekg{}
data, err := e.Encode("x.org", dns.TypeA, false) data, _, err := e.Encode("x.org", dns.TypeA, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -19,7 +19,7 @@ func TestDNSEncoder(t *testing.T) {
t.Run("encode AAAA", func(t *testing.T) { t.Run("encode AAAA", func(t *testing.T) {
e := &DNSEncoderMiekg{} e := &DNSEncoderMiekg{}
data, err := e.Encode("x.org", dns.TypeAAAA, false) data, _, err := e.Encode("x.org", dns.TypeAAAA, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -31,7 +31,7 @@ func TestDNSEncoder(t *testing.T) {
// array of values we obtain the right query size. // array of values we obtain the right query size.
getquerylen := func(domainlen int, padding bool) int { getquerylen := func(domainlen int, padding bool) int {
e := &DNSEncoderMiekg{} e := &DNSEncoderMiekg{}
data, err := e.Encode( data, _, err := e.Encode(
// This is not a valid name because it ends up being way // This is not a valid name because it ends up being way
// longer than 255 octets. However, the library is allowing // longer than 255 octets. However, the library is allowing
// us to generate such name and we are not going to send // us to generate such name and we are not going to send

View File

@ -1,5 +1,5 @@
// Code generated by go generate; DO NOT EDIT. // Code generated by go generate; DO NOT EDIT.
// Generated: 2022-05-13 19:09:07.343096 +0200 CEST m=+0.512294417 // Generated: 2022-05-14 18:04:43.744122 +0200 CEST m=+0.315992417
package netxlite package netxlite
@ -25,6 +25,7 @@ const (
FailureDNSNoAnswer = "dns_no_answer" FailureDNSNoAnswer = "dns_no_answer"
FailureDNSNonRecoverableFailure = "dns_non_recoverable_failure" FailureDNSNonRecoverableFailure = "dns_non_recoverable_failure"
FailureDNSRefusedError = "dns_refused_error" FailureDNSRefusedError = "dns_refused_error"
FailureDNSReplyWithWrongQueryID = "dns_reply_with_wrong_query_id"
FailureDNSServerMisbehaving = "dns_server_misbehaving" FailureDNSServerMisbehaving = "dns_server_misbehaving"
FailureDNSServfailError = "dns_servfail_error" FailureDNSServfailError = "dns_servfail_error"
FailureDNSTemporaryFailure = "dns_temporary_failure" FailureDNSTemporaryFailure = "dns_temporary_failure"
@ -75,6 +76,7 @@ var failuresMap = map[string]string{
"dns_non_recoverable_failure": "dns_non_recoverable_failure", "dns_non_recoverable_failure": "dns_non_recoverable_failure",
"dns_nxdomain_error": "dns_nxdomain_error", "dns_nxdomain_error": "dns_nxdomain_error",
"dns_refused_error": "dns_refused_error", "dns_refused_error": "dns_refused_error",
"dns_reply_with_wrong_query_id": "dns_reply_with_wrong_query_id",
"dns_server_misbehaving": "dns_server_misbehaving", "dns_server_misbehaving": "dns_server_misbehaving",
"dns_servfail_error": "dns_servfail_error", "dns_servfail_error": "dns_servfail_error",
"dns_temporary_failure": "dns_temporary_failure", "dns_temporary_failure": "dns_temporary_failure",

View File

@ -159,6 +159,7 @@ var Specs = []*ErrorSpec{
NewLibraryError("DNS_server_misbehaving"), NewLibraryError("DNS_server_misbehaving"),
NewLibraryError("DNS_no_answer"), NewLibraryError("DNS_no_answer"),
NewLibraryError("DNS_servfail_error"), NewLibraryError("DNS_servfail_error"),
NewLibraryError("DNS_reply_with_wrong_query_ID"),
NewLibraryError("EOF_error"), NewLibraryError("EOF_error"),
NewLibraryError("generic_timeout_error"), NewLibraryError("generic_timeout_error"),
NewLibraryError("QUIC_incompatible_version"), NewLibraryError("QUIC_incompatible_version"),

View File

@ -85,7 +85,7 @@ func (r *ParallelResolver) LookupHost(ctx context.Context, hostname string) ([]s
// LookupHTTPS implements Resolver.LookupHTTPS. // LookupHTTPS implements Resolver.LookupHTTPS.
func (r *ParallelResolver) LookupHTTPS( func (r *ParallelResolver) LookupHTTPS(
ctx context.Context, hostname string) (*model.HTTPSSvc, error) { ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
querydata, err := r.Encoder.Encode( querydata, queryID, err := r.Encoder.Encode(
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
if err != nil { if err != nil {
return nil, err return nil, err
@ -94,7 +94,7 @@ func (r *ParallelResolver) LookupHTTPS(
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r.Decoder.DecodeHTTPS(replydata) return r.Decoder.DecodeHTTPS(replydata, queryID)
} }
// parallelResolverResult is the internal representation of a // parallelResolverResult is the internal representation of a
@ -107,7 +107,7 @@ type parallelResolverResult struct {
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A). // lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string, func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
qtype uint16, out chan<- *parallelResolverResult) { qtype uint16, out chan<- *parallelResolverResult) {
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
if err != nil { if err != nil {
out <- &parallelResolverResult{ out <- &parallelResolverResult{
addrs: []string{}, addrs: []string{},
@ -123,7 +123,7 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
} }
return return
} }
addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata) addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
out <- &parallelResolverResult{ out <- &parallelResolverResult{
addrs: addrs, addrs: addrs,
err: err, err: err,

View File

@ -34,8 +34,8 @@ func TestParallelResolver(t *testing.T) {
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := ParallelResolver{ r := ParallelResolver{
Encoder: &mocks.DNSEncoder{ Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, mocked return nil, 0, mocked
}, },
}, },
Txp: txp, Txp: txp,
@ -72,7 +72,7 @@ func TestParallelResolver(t *testing.T) {
t.Run("empty reply", func(t *testing.T) { t.Run("empty reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeA), nil return dnsGenLookupHostReplySuccess(query), nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -91,7 +91,7 @@ func TestParallelResolver(t *testing.T) {
t.Run("with A reply", func(t *testing.T) { t.Run("with A reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeA, "8.8.8.8"), nil return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -103,14 +103,14 @@ func TestParallelResolver(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if len(addrs) != 1 || addrs[0] != "8.8.8.8" { if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
t.Fatal("not the result we expected") t.Fatal("not the result we expected", addrs)
} }
}) })
t.Run("with AAAA reply", func(t *testing.T) { t.Run("with AAAA reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1"), nil return dnsGenLookupHostReplySuccess(query, "::1"), nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -122,7 +122,7 @@ func TestParallelResolver(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if len(addrs) != 1 || addrs[0] != "::1" { if len(addrs) != 1 || addrs[0] != "::1" {
t.Fatal("not the result we expected") t.Fatal("not the result we expected", addrs)
} }
}) })
@ -182,8 +182,8 @@ func TestParallelResolver(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &ParallelResolver{ r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{ Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, expected return nil, 0, expected
}, },
}, },
Decoder: nil, Decoder: nil,
@ -208,8 +208,8 @@ func TestParallelResolver(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &ParallelResolver{ r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{ Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), nil return make([]byte, 64), 0, nil
}, },
}, },
Decoder: nil, Decoder: nil,
@ -237,12 +237,12 @@ func TestParallelResolver(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &ParallelResolver{ r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{ Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), nil return make([]byte, 64), 0, nil
}, },
}, },
Decoder: &mocks.DNSDecoder{ Decoder: &mocks.DNSDecoder{
MockDecodeHTTPS: func(reply []byte) (*model.HTTPSSvc, error) { MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected return nil, expected
}, },
}, },

View File

@ -84,7 +84,7 @@ func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]str
// LookupHTTPS implements Resolver.LookupHTTPS. // LookupHTTPS implements Resolver.LookupHTTPS.
func (r *SerialResolver) LookupHTTPS( func (r *SerialResolver) LookupHTTPS(
ctx context.Context, hostname string) (*model.HTTPSSvc, error) { ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
querydata, err := r.Encoder.Encode( querydata, queryID, err := r.Encoder.Encode(
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
if err != nil { if err != nil {
return nil, err return nil, err
@ -93,7 +93,7 @@ func (r *SerialResolver) LookupHTTPS(
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r.Decoder.DecodeHTTPS(replydata) return r.Decoder.DecodeHTTPS(replydata, queryID)
} }
func (r *SerialResolver) lookupHostWithRetry( func (r *SerialResolver) lookupHostWithRetry(
@ -126,7 +126,7 @@ func (r *SerialResolver) lookupHostWithRetry(
// qtype (dns.A or dns.AAAA) without retrying on failure. // qtype (dns.A or dns.AAAA) without retrying on failure.
func (r *SerialResolver) lookupHostWithoutRetry( func (r *SerialResolver) lookupHostWithoutRetry(
ctx context.Context, hostname string, qtype uint16) ([]string, error) { ctx context.Context, hostname string, qtype uint16) ([]string, error) {
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -134,5 +134,5 @@ func (r *SerialResolver) lookupHostWithoutRetry(
if err != nil { if err != nil {
return nil, err return nil, err
} }
return r.Decoder.DecodeLookupHost(qtype, replydata) return r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
} }

View File

@ -7,7 +7,6 @@ import (
"net" "net"
"testing" "testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
@ -51,8 +50,8 @@ func TestSerialResolver(t *testing.T) {
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := SerialResolver{ r := SerialResolver{
Encoder: &mocks.DNSEncoder{ Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, mocked return nil, 0, mocked
}, },
}, },
Txp: txp, Txp: txp,
@ -89,7 +88,7 @@ func TestSerialResolver(t *testing.T) {
t.Run("empty reply", func(t *testing.T) { t.Run("empty reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeA), nil return dnsGenLookupHostReplySuccess(query), nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -108,7 +107,7 @@ func TestSerialResolver(t *testing.T) {
t.Run("with A reply", func(t *testing.T) { t.Run("with A reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeA, "8.8.8.8"), nil return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -127,7 +126,7 @@ func TestSerialResolver(t *testing.T) {
t.Run("with AAAA reply", func(t *testing.T) { t.Run("with AAAA reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1"), nil return dnsGenLookupHostReplySuccess(query, "::1"), nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -189,8 +188,8 @@ func TestSerialResolver(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &SerialResolver{ r := &SerialResolver{
Encoder: &mocks.DNSEncoder{ Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, expected return nil, 0, expected
}, },
}, },
Decoder: nil, Decoder: nil,
@ -215,8 +214,8 @@ func TestSerialResolver(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &SerialResolver{ r := &SerialResolver{
Encoder: &mocks.DNSEncoder{ Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), nil return make([]byte, 64), 0, nil
}, },
}, },
Decoder: nil, Decoder: nil,
@ -244,12 +243,12 @@ func TestSerialResolver(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &SerialResolver{ r := &SerialResolver{
Encoder: &mocks.DNSEncoder{ Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), nil return make([]byte, 64), 0, nil
}, },
}, },
Decoder: &mocks.DNSDecoder{ Decoder: &mocks.DNSDecoder{
MockDecodeHTTPS: func(reply []byte) (*model.HTTPSSvc, error) { MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected return nil, expected
}, },
}, },