refactor: DNSTransport I/Os DNS messages (#760)
This diff refactors the DNSTransport model to receive in input a DNSQuery and return in output a DNSResponse. The design of DNSQuery and DNSResponse takes into account the use case of a transport using getaddrinfo, meaning that we don't need to serialize and deserialize messages when using getaddrinfo. The current codebase does not use a getaddrinfo transport, but I wrote one such a transport in the Websteps Winter 2021 prototype (https://github.com/bassosimone/websteps-illustrated/). The design conversation that lead to producing this diff is https://github.com/ooni/probe/issues/2099
This commit is contained in:
parent
7a0a156aec
commit
01a513a496
1
go.mod
1
go.mod
|
@ -17,7 +17,6 @@ require (
|
|||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/hexops/gotextdiff v1.0.3
|
||||
github.com/iancoleman/strcase v0.2.0
|
||||
github.com/lucas-clemente/quic-go v0.27.0
|
||||
github.com/mattn/go-colorable v0.1.12
|
||||
|
|
2
go.sum
2
go.sum
|
@ -378,8 +378,6 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
|
|||
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
|
||||
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
|
||||
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
|
||||
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
|
|
|
@ -317,7 +317,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var txp model.DNSTransport = netxlite.NewDNSOverTLS(
|
||||
var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport(
|
||||
tlsDialer.DialTLSContext, endpoint)
|
||||
if config.ResolveSaver != nil {
|
||||
txp = resolver.SaverDNSTransport{
|
||||
|
|
|
@ -76,40 +76,6 @@ func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
|
|||
return c.SetWriteDeadlineError
|
||||
}
|
||||
|
||||
type FakeTransport struct {
|
||||
Data []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
return ft.Data, ft.Err
|
||||
}
|
||||
|
||||
func (ft FakeTransport) RequiresPadding() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (ft FakeTransport) Address() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (ft FakeTransport) Network() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (fk FakeTransport) CloseIdleConnections() {
|
||||
// nothing to do
|
||||
}
|
||||
|
||||
type FakeEncoder struct {
|
||||
Data []byte
|
||||
Err error
|
||||
}
|
||||
|
||||
func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
return fe.Data, fe.Err
|
||||
}
|
||||
|
||||
func NewFakeResolverThatFails() model.Resolver {
|
||||
return NewFakeResolverWithExplicitError(netxlite.ErrOODNSNoSuchHost)
|
||||
}
|
||||
|
|
|
@ -99,14 +99,14 @@ func TestNewResolverTCPDomain(t *testing.T) {
|
|||
|
||||
func TestNewResolverDoTAddress(t *testing.T) {
|
||||
reso := netxlite.NewSerialResolver(
|
||||
netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "8.8.8.8:853"))
|
||||
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
||||
func TestNewResolverDoTDomain(t *testing.T) {
|
||||
reso := netxlite.NewSerialResolver(
|
||||
netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "dns.google.com:853"))
|
||||
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853"))
|
||||
testresolverquick(t, reso)
|
||||
testresolverquickidna(t, reso)
|
||||
}
|
||||
|
|
|
@ -46,28 +46,41 @@ type SaverDNSTransport struct {
|
|||
}
|
||||
|
||||
// RoundTrip implements RoundTripper.RoundTrip
|
||||
func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
func (txp SaverDNSTransport) RoundTrip(
|
||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
start := time.Now()
|
||||
txp.Saver.Write(trace.Event{
|
||||
Address: txp.Address(),
|
||||
DNSQuery: query,
|
||||
DNSQuery: txp.maybeQueryBytes(query),
|
||||
Name: "dns_round_trip_start",
|
||||
Proto: txp.Network(),
|
||||
Time: start,
|
||||
})
|
||||
reply, err := txp.DNSTransport.RoundTrip(ctx, query)
|
||||
response, err := txp.DNSTransport.RoundTrip(ctx, query)
|
||||
stop := time.Now()
|
||||
txp.Saver.Write(trace.Event{
|
||||
Address: txp.Address(),
|
||||
DNSQuery: query,
|
||||
DNSReply: reply,
|
||||
DNSQuery: txp.maybeQueryBytes(query),
|
||||
DNSReply: txp.maybeResponseBytes(response),
|
||||
Duration: stop.Sub(start),
|
||||
Err: err,
|
||||
Name: "dns_round_trip_done",
|
||||
Proto: txp.Network(),
|
||||
Time: stop,
|
||||
})
|
||||
return reply, err
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (txp SaverDNSTransport) maybeQueryBytes(query model.DNSQuery) []byte {
|
||||
data, _ := query.Bytes()
|
||||
return data
|
||||
}
|
||||
|
||||
func (txp SaverDNSTransport) maybeResponseBytes(response model.DNSResponse) []byte {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
return response.Bytes()
|
||||
}
|
||||
|
||||
var _ model.Resolver = SaverResolver{}
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestSaverResolverFailure(t *testing.T) {
|
||||
|
@ -110,12 +112,25 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
|||
expected := errors.New("no such host")
|
||||
saver := &trace.Saver{}
|
||||
txp := resolver.SaverDNSTransport{
|
||||
DNSTransport: resolver.FakeTransport{
|
||||
Err: expected,
|
||||
DNSTransport: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "fake"
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return ""
|
||||
},
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
query := []byte("abc")
|
||||
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return rawQuery, nil
|
||||
},
|
||||
}
|
||||
reply, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected")
|
||||
|
@ -127,7 +142,7 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
|||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if !bytes.Equal(ev[0].DNSQuery, query) {
|
||||
if !bytes.Equal(ev[0].DNSQuery, rawQuery) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[0].Name != "dns_round_trip_start" {
|
||||
|
@ -136,7 +151,7 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
|||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSQuery, query) {
|
||||
if !bytes.Equal(ev[1].DNSQuery, rawQuery) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[1].DNSReply != nil {
|
||||
|
@ -157,27 +172,45 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSaverDNSTransportSuccess(t *testing.T) {
|
||||
expected := []byte("def")
|
||||
expected := []byte{0xef, 0xbe, 0xad, 0xde}
|
||||
saver := &trace.Saver{}
|
||||
response := &mocks.DNSResponse{
|
||||
MockBytes: func() []byte {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
txp := resolver.SaverDNSTransport{
|
||||
DNSTransport: resolver.FakeTransport{
|
||||
Data: expected,
|
||||
DNSTransport: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return response, nil
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "fake"
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return ""
|
||||
},
|
||||
},
|
||||
Saver: saver,
|
||||
}
|
||||
query := []byte("abc")
|
||||
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return rawQuery, nil
|
||||
},
|
||||
}
|
||||
reply, err := txp.RoundTrip(context.Background(), query)
|
||||
if err != nil {
|
||||
t.Fatal("we expected nil error here")
|
||||
}
|
||||
if !bytes.Equal(reply, expected) {
|
||||
if !bytes.Equal(reply.Bytes(), expected) {
|
||||
t.Fatal("expected another reply here")
|
||||
}
|
||||
ev := saver.Read()
|
||||
if len(ev) != 2 {
|
||||
t.Fatal("expected number of events")
|
||||
}
|
||||
if !bytes.Equal(ev[0].DNSQuery, query) {
|
||||
if !bytes.Equal(ev[0].DNSQuery, rawQuery) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if ev[0].Name != "dns_round_trip_start" {
|
||||
|
@ -186,7 +219,7 @@ func TestSaverDNSTransportSuccess(t *testing.T) {
|
|||
if !ev[0].Time.Before(time.Now()) {
|
||||
t.Fatal("the saved time is wrong")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSQuery, query) {
|
||||
if !bytes.Equal(ev[1].DNSQuery, rawQuery) {
|
||||
t.Fatal("unexpected DNSQuery")
|
||||
}
|
||||
if !bytes.Equal(ev[1].DNSReply, expected) {
|
||||
|
|
|
@ -36,18 +36,31 @@ type DNSRoundTripEvent struct {
|
|||
Reply []byte
|
||||
}
|
||||
|
||||
func (txp *dnsxRoundTripperDB) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
func (txp *dnsxRoundTripperDB) RoundTrip(
|
||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
started := time.Since(txp.begin).Seconds()
|
||||
reply, err := txp.DNSTransport.RoundTrip(ctx, query)
|
||||
response, err := txp.DNSTransport.RoundTrip(ctx, query)
|
||||
finished := time.Since(txp.begin).Seconds()
|
||||
txp.db.InsertIntoDNSRoundTrip(&DNSRoundTripEvent{
|
||||
Network: txp.DNSTransport.Network(),
|
||||
Address: txp.DNSTransport.Address(),
|
||||
Query: query,
|
||||
Query: txp.maybeQueryBytes(query),
|
||||
Started: started,
|
||||
Finished: finished,
|
||||
Failure: NewFailure(err),
|
||||
Reply: reply,
|
||||
Reply: txp.maybeResponseBytes(response),
|
||||
})
|
||||
return reply, err
|
||||
return response, err
|
||||
}
|
||||
|
||||
func (txp *dnsxRoundTripperDB) maybeQueryBytes(query model.DNSQuery) []byte {
|
||||
data, _ := query.Bytes()
|
||||
return data
|
||||
}
|
||||
|
||||
func (txp *dnsxRoundTripperDB) maybeResponseBytes(response model.DNSResponse) []byte {
|
||||
if response == nil {
|
||||
return nil
|
||||
}
|
||||
return response.Bytes()
|
||||
}
|
||||
|
|
|
@ -1,36 +1,20 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
"net"
|
||||
//
|
||||
// Mocks for model.DNSDecoder
|
||||
//
|
||||
|
||||
"github.com/miekg/dns"
|
||||
import (
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// DNSDecoder allows mocking dnsx.DNSDecoder.
|
||||
// DNSDecoder allows mocking model.DNSDecoder.
|
||||
type DNSDecoder struct {
|
||||
MockDecodeLookupHost func(qtype uint16, reply []byte, queryID uint16) ([]string, error)
|
||||
MockDecodeHTTPS func(reply []byte, queryID uint16) (*model.HTTPSSvc, error)
|
||||
MockDecodeNS func(reply []byte, queryID uint16) ([]*net.NS, error)
|
||||
MockDecodeReply func(reply []byte) (*dns.Msg, error)
|
||||
MockDecodeResponse func(data []byte, query model.DNSQuery) (model.DNSResponse, error)
|
||||
}
|
||||
|
||||
// DecodeLookupHost calls MockDecodeLookupHost.
|
||||
func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
|
||||
return e.MockDecodeLookupHost(qtype, reply, queryID)
|
||||
}
|
||||
var _ model.DNSDecoder = &DNSDecoder{}
|
||||
|
||||
// DecodeHTTPS calls MockDecodeHTTPS.
|
||||
func (e *DNSDecoder) DecodeHTTPS(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
||||
return e.MockDecodeHTTPS(reply, queryID)
|
||||
}
|
||||
|
||||
// DecodeNS calls MockDecodeNS.
|
||||
func (e *DNSDecoder) DecodeNS(reply []byte, queryID uint16) ([]*net.NS, error) {
|
||||
return e.MockDecodeNS(reply, queryID)
|
||||
}
|
||||
|
||||
// DecodeReply calls MockDecodeReply.
|
||||
func (e *DNSDecoder) DecodeReply(reply []byte) (*dns.Msg, error) {
|
||||
return e.MockDecodeReply(reply)
|
||||
func (e *DNSDecoder) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return e.MockDecodeResponse(data, query)
|
||||
}
|
||||
|
|
|
@ -2,70 +2,20 @@ package mocks
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
func TestDNSDecoder(t *testing.T) {
|
||||
t.Run("DecodeLookupHost", func(t *testing.T) {
|
||||
t.Run("DecodeResponse", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &DNSDecoder{
|
||||
MockDecodeLookupHost: func(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
|
||||
MockDecodeResponse: func(reply []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17), dns.Id())
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeHTTPS", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &DNSDecoder{
|
||||
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.DecodeHTTPS(make([]byte, 17), dns.Id())
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeNS", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &DNSDecoder{
|
||||
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.DecodeNS(make([]byte, 17), dns.Id())
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeReply", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
e := &DNSDecoder{
|
||||
MockDecodeReply: func(reply []byte) (*dns.Msg, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := e.DecodeReply(make([]byte, 17))
|
||||
out, err := e.DecodeResponse(make([]byte, 17), &DNSQuery{})
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
package mocks
|
||||
|
||||
// DNSEncoder allows mocking dnsx.DNSEncoder.
|
||||
//
|
||||
// Mocks for model.DNSEncoder.
|
||||
//
|
||||
|
||||
import "github.com/ooni/probe-cli/v3/internal/model"
|
||||
|
||||
// DNSEncoder allows mocking model.DNSEncoder.
|
||||
type DNSEncoder struct {
|
||||
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
|
||||
MockEncode func(domain string, qtype uint16, padding bool) model.DNSQuery
|
||||
}
|
||||
|
||||
var _ model.DNSEncoder = &DNSEncoder{}
|
||||
|
||||
// Encode calls MockEncode.
|
||||
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) model.DNSQuery {
|
||||
return e.MockEncode(domain, qtype, padding)
|
||||
}
|
||||
|
|
|
@ -5,24 +5,46 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
func TestDNSEncoder(t *testing.T) {
|
||||
t.Run("Encode", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
queryID := dns.Id()
|
||||
e := &DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, expected
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) model.DNSQuery {
|
||||
return &DNSQuery{
|
||||
MockDomain: func() string {
|
||||
return dns.Fqdn(domain) // do what an implementation MUST do
|
||||
},
|
||||
MockType: func() uint16 {
|
||||
return qtype
|
||||
},
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
out, queryID, err := e.Encode("dns.google", dns.TypeA, true)
|
||||
query := e.Encode("dns.google", dns.TypeA, true)
|
||||
if query.Domain() != "dns.google." {
|
||||
t.Fatal("invalid domain")
|
||||
}
|
||||
if query.Type() != dns.TypeA {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
out, err := query.Bytes()
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
if queryID != 0 {
|
||||
if query.ID() != queryID {
|
||||
t.Fatal("unexpected queryID")
|
||||
}
|
||||
})
|
||||
|
|
33
internal/model/mocks/dnsquery.go
Normal file
33
internal/model/mocks/dnsquery.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package mocks
|
||||
|
||||
//
|
||||
// Mocks for model.DNSQuery.
|
||||
//
|
||||
|
||||
import "github.com/ooni/probe-cli/v3/internal/model"
|
||||
|
||||
// DNSQuery allocks mocking model.DNSQuery.
|
||||
type DNSQuery struct {
|
||||
MockDomain func() string
|
||||
MockType func() uint16
|
||||
MockBytes func() ([]byte, error)
|
||||
MockID func() uint16
|
||||
}
|
||||
|
||||
func (q *DNSQuery) Domain() string {
|
||||
return q.MockDomain()
|
||||
}
|
||||
|
||||
func (q *DNSQuery) Type() uint16 {
|
||||
return q.MockType()
|
||||
}
|
||||
|
||||
func (q *DNSQuery) Bytes() ([]byte, error) {
|
||||
return q.MockBytes()
|
||||
}
|
||||
|
||||
func (q *DNSQuery) ID() uint16 {
|
||||
return q.MockID()
|
||||
}
|
||||
|
||||
var _ model.DNSQuery = &DNSQuery{}
|
62
internal/model/mocks/dnsquery_test.go
Normal file
62
internal/model/mocks/dnsquery_test.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestDNSQuery(t *testing.T) {
|
||||
t.Run("Domain", func(t *testing.T) {
|
||||
expected := "dns.google."
|
||||
q := &DNSQuery{
|
||||
MockDomain: func() string {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
if q.Domain() != expected {
|
||||
t.Fatal("invalid domain")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Type", func(t *testing.T) {
|
||||
expected := dns.TypeAAAA
|
||||
q := &DNSQuery{
|
||||
MockType: func() uint16 {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
if q.Type() != expected {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Bytes", func(t *testing.T) {
|
||||
expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef}
|
||||
q := &DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return expected, nil
|
||||
},
|
||||
}
|
||||
out, err := q.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(expected, out) {
|
||||
t.Fatal("invalid bytes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ID", func(t *testing.T) {
|
||||
expected := dns.Id()
|
||||
q := &DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
if q.ID() != expected {
|
||||
t.Fatal("invalid id")
|
||||
}
|
||||
})
|
||||
}
|
47
internal/model/mocks/dnsresponse.go
Normal file
47
internal/model/mocks/dnsresponse.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package mocks
|
||||
|
||||
//
|
||||
// Mocks for model.DNSResponse
|
||||
//
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// DNSResponse allows mocking model.DNSResponse.
|
||||
type DNSResponse struct {
|
||||
MockQuery func() model.DNSQuery
|
||||
MockBytes func() []byte
|
||||
MockRcode func() int
|
||||
MockDecodeHTTPS func() (*model.HTTPSSvc, error)
|
||||
MockDecodeLookupHost func() ([]string, error)
|
||||
MockDecodeNS func() ([]*net.NS, error)
|
||||
}
|
||||
|
||||
var _ model.DNSResponse = &DNSResponse{}
|
||||
|
||||
func (r *DNSResponse) Query() model.DNSQuery {
|
||||
return r.MockQuery()
|
||||
}
|
||||
|
||||
func (r *DNSResponse) Bytes() []byte {
|
||||
return r.MockBytes()
|
||||
}
|
||||
|
||||
func (r *DNSResponse) Rcode() int {
|
||||
return r.MockRcode()
|
||||
}
|
||||
|
||||
func (r *DNSResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
|
||||
return r.MockDecodeHTTPS()
|
||||
}
|
||||
|
||||
func (r *DNSResponse) DecodeLookupHost() ([]string, error) {
|
||||
return r.MockDecodeLookupHost()
|
||||
}
|
||||
|
||||
func (r *DNSResponse) DecodeNS() ([]*net.NS, error) {
|
||||
return r.MockDecodeNS()
|
||||
}
|
105
internal/model/mocks/dnsresponse_test.go
Normal file
105
internal/model/mocks/dnsresponse_test.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
func TestDNSResponse(t *testing.T) {
|
||||
t.Run("Query", func(t *testing.T) {
|
||||
qid := dns.Id()
|
||||
query := &DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return qid
|
||||
},
|
||||
}
|
||||
resp := &DNSResponse{
|
||||
MockQuery: func() model.DNSQuery {
|
||||
return query
|
||||
},
|
||||
}
|
||||
out := resp.Query()
|
||||
if out.ID() != query.ID() {
|
||||
t.Fatal("invalid query")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Bytes", func(t *testing.T) {
|
||||
expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef}
|
||||
resp := &DNSResponse{
|
||||
MockBytes: func() []byte {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
out := resp.Bytes()
|
||||
if !bytes.Equal(expected, out) {
|
||||
t.Fatal("invalid bytes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Rcode", func(t *testing.T) {
|
||||
expected := dns.RcodeBadAlg
|
||||
resp := &DNSResponse{
|
||||
MockRcode: func() int {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
out := resp.Rcode()
|
||||
if out != expected {
|
||||
t.Fatal("invalid rcode")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeLookupHost", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := r.DecodeLookupHost()
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeHTTPS", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &DNSResponse{
|
||||
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := r.DecodeHTTPS()
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeNS", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &DNSResponse{
|
||||
MockDecodeNS: func() ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
out, err := r.DecodeNS()
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if out != nil {
|
||||
t.Fatal("unexpected out")
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,10 +1,14 @@
|
|||
package mocks
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
// DNSTransport allows mocking dnsx.DNSTransport.
|
||||
type DNSTransport struct {
|
||||
MockRoundTrip func(ctx context.Context, query []byte) ([]byte, error)
|
||||
MockRoundTrip func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error)
|
||||
|
||||
MockRequiresPadding func() bool
|
||||
|
||||
|
@ -15,8 +19,10 @@ type DNSTransport struct {
|
|||
MockCloseIdleConnections func()
|
||||
}
|
||||
|
||||
var _ model.DNSTransport = &DNSTransport{}
|
||||
|
||||
// RoundTrip calls MockRoundTrip.
|
||||
func (txp *DNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
func (txp *DNSTransport) RoundTrip(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return txp.MockRoundTrip(ctx, query)
|
||||
}
|
||||
|
||||
|
|
|
@ -6,17 +6,18 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
func TestDNSTransport(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := &DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) {
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), make([]byte, 16))
|
||||
resp, err := txp.RoundTrip(context.Background(), &DNSQuery{})
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
package model
|
||||
|
||||
//
|
||||
// Network extensions
|
||||
//
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
@ -9,74 +13,81 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
//
|
||||
// Network extensions
|
||||
//
|
||||
// DNSResponse is a parsed DNS response ready for further processing.
|
||||
type DNSResponse interface {
|
||||
// Query is the query associated with this response.
|
||||
Query() DNSQuery
|
||||
|
||||
// The DNSDecoder decodes DNS replies.
|
||||
// Bytes returns the bytes from which we parsed the query.
|
||||
Bytes() []byte
|
||||
|
||||
// Rcode returns the response's Rcode.
|
||||
Rcode() int
|
||||
|
||||
// DecodeHTTPS returns information gathered from all the HTTPS
|
||||
// records found inside of this response.
|
||||
DecodeHTTPS() (*HTTPSSvc, error)
|
||||
|
||||
// DecodeLookupHost returns the addresses in the response matching
|
||||
// the original query type (one of A and AAAA).
|
||||
DecodeLookupHost() ([]string, error)
|
||||
|
||||
// DecodeNS returns all the NS entries in this response.
|
||||
DecodeNS() ([]*net.NS, error)
|
||||
}
|
||||
|
||||
// The DNSDecoder decodes DNS responses.
|
||||
type DNSDecoder interface {
|
||||
// DecodeLookupHost decodes an A or AAAA reply.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - qtype is the query type (e.g., dns.TypeAAAA)
|
||||
//
|
||||
// - data contains the reply bytes read from a DNSTransport
|
||||
//
|
||||
// - queryID is the original query ID
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// - on success, a list of IP addrs inside the reply and a nil error
|
||||
//
|
||||
// - on failure, a nil list and an error.
|
||||
//
|
||||
// Note that this function will return an error if there is no
|
||||
// IP address inside of the reply.
|
||||
DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error)
|
||||
|
||||
// DecodeHTTPS is like DecodeLookupHost but decodes an HTTPS reply.
|
||||
//
|
||||
// The argument is the reply as read by the DNSTransport.
|
||||
//
|
||||
// On success, this function returns an HTTPSSvc structure and
|
||||
// a nil error. On failure, the HTTPSSvc pointer is nil and
|
||||
// the error points to the error that occurred.
|
||||
//
|
||||
// This function will return an error if the HTTPS reply does not
|
||||
// contain at least a valid ALPN entry. It will not return
|
||||
// an error, though, when there are no IPv4/IPv6 hints in the reply.
|
||||
DecodeHTTPS(data []byte, queryID uint16) (*HTTPSSvc, error)
|
||||
|
||||
// DecodeNS is like DecodeHTTPS but for NS queries.
|
||||
DecodeNS(data []byte, queryID uint16) ([]*net.NS, error)
|
||||
|
||||
// DecodeReply decodes a DNS reply message.
|
||||
// DecodeResponse decodes a DNS response message.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - data is the raw reply
|
||||
//
|
||||
// This function fails if we cannot parse data as a DNS
|
||||
// message or the message is not a reply.
|
||||
// message or the message is not a response.
|
||||
//
|
||||
// If you use this function, remember that:
|
||||
// Regarding the returned response, remember that the Rcode
|
||||
// MAY still be nonzero (this method does not treat a nonzero
|
||||
// Rcode as an error when parsing the response).
|
||||
DecodeResponse(data []byte, query DNSQuery) (DNSResponse, error)
|
||||
}
|
||||
|
||||
// DNSQuery is an encoded DNS query ready to be sent using a DNSTransport.
|
||||
type DNSQuery interface {
|
||||
// Domain is the domain we're querying for.
|
||||
Domain() string
|
||||
|
||||
// Type is the query type.
|
||||
Type() uint16
|
||||
|
||||
// Bytes serializes the query to bytes. This function may fail if we're not
|
||||
// able to correctly encode the domain into a query message.
|
||||
//
|
||||
// 1. the Rcode MAY be nonzero;
|
||||
//
|
||||
// 2. the replyID MAY NOT match the original query ID.
|
||||
//
|
||||
// That is, this is a very basic parsing method.
|
||||
DecodeReply(data []byte) (*dns.Msg, error)
|
||||
// The value returned by this function WILL be memoized after the first call,
|
||||
// so you SHOULD create a new DNSQuery if you need to retry a query.
|
||||
Bytes() ([]byte, error)
|
||||
|
||||
// ID returns the query ID.
|
||||
ID() uint16
|
||||
}
|
||||
|
||||
// The DNSEncoder encodes DNS queries to bytes
|
||||
type DNSEncoder interface {
|
||||
// Encode transforms its arguments into a serialized DNS query.
|
||||
//
|
||||
// Every time you call Encode, you get a new DNSQuery value
|
||||
// using a query ID selected at random.
|
||||
//
|
||||
// Serialization to bytes is lazy to acommodate DNS transports that
|
||||
// do not need to serialize and send bytes, e.g., getaddrinfo.
|
||||
//
|
||||
// You serialize to bytes using DNSQuery.Bytes. This operation MAY fail
|
||||
// if the domain name cannot be packed into a DNS message (e.g., it is
|
||||
// too long to fit into the message).
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - domain is the domain for the query (e.g., x.org);
|
||||
|
@ -85,16 +96,15 @@ type DNSEncoder interface {
|
|||
//
|
||||
// - padding is whether to add padding to the query.
|
||||
//
|
||||
// On success, this function returns a valid byte array, the queryID, and
|
||||
// a nil error. On failure, we have a non-nil error, a nil arrary and a zero
|
||||
// query ID.
|
||||
Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
|
||||
// This function will transform the domain into an FQDN is it's not
|
||||
// already expressed in the FQDN format.
|
||||
Encode(domain string, qtype uint16, padding bool) DNSQuery
|
||||
}
|
||||
|
||||
// DNSTransport represents an abstract DNS transport.
|
||||
type DNSTransport interface {
|
||||
// RoundTrip sends a DNS query and receives the reply.
|
||||
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
|
||||
RoundTrip(ctx context.Context, query DNSQuery) (DNSResponse, error)
|
||||
|
||||
// RequiresPadding returns whether this transport needs padding.
|
||||
RequiresPadding() bool
|
||||
|
|
|
@ -15,14 +15,16 @@ import (
|
|||
// DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder.
|
||||
type DNSDecoderMiekg struct{}
|
||||
|
||||
// ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID.
|
||||
var ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID)
|
||||
var (
|
||||
// ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID.
|
||||
ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID)
|
||||
|
||||
// ErrDNSIsQuery indicates that we were passed a DNS query.
|
||||
var ErrDNSIsQuery = errors.New("ooresolver: expected response but received query")
|
||||
// ErrDNSIsQuery indicates that we were passed a DNS query.
|
||||
ErrDNSIsQuery = errors.New("ooresolver: expected response but received query")
|
||||
)
|
||||
|
||||
// DecodeReply implements model.DNSDecoder.DecodeReply
|
||||
func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) {
|
||||
// DecodeResponse implements model.DNSDecoder.DecodeResponse.
|
||||
func (d *DNSDecoderMiekg) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
reply := &dns.Msg{}
|
||||
if err := reply.Unpack(data); err != nil {
|
||||
return nil, err
|
||||
|
@ -30,46 +32,64 @@ func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) {
|
|||
if !reply.Response {
|
||||
return nil, ErrDNSIsQuery
|
||||
}
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// decodeSuccessfulReply decodes the bytes in data as a successful reply for the
|
||||
// given queryID. This function returns an error if:
|
||||
//
|
||||
// 1. we cannot decode data
|
||||
//
|
||||
// 2. the decoded message is not a reply
|
||||
//
|
||||
// 3. the query ID does not match
|
||||
//
|
||||
// 4. the Rcode is not zero.
|
||||
func (d *DNSDecoderMiekg) decodeSuccessfulReply(data []byte, queryID uint16) (*dns.Msg, error) {
|
||||
reply, err := d.DecodeReply(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reply.Id != queryID {
|
||||
if reply.Id != query.ID() {
|
||||
return nil, ErrDNSReplyWithWrongQueryID
|
||||
}
|
||||
resp := &dnsResponse{
|
||||
bytes: data,
|
||||
msg: reply,
|
||||
query: query,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// dnsResponse implements model.DNSResponse.
|
||||
type dnsResponse struct {
|
||||
// bytes contains the response bytes.
|
||||
bytes []byte
|
||||
|
||||
// msg contains the message.
|
||||
msg *dns.Msg
|
||||
|
||||
// query is the original query.
|
||||
query model.DNSQuery
|
||||
}
|
||||
|
||||
// Query implements model.DNSResponse.Query.
|
||||
func (r *dnsResponse) Query() model.DNSQuery {
|
||||
return r.query
|
||||
}
|
||||
|
||||
// Bytes implements model.DNSResponse.Bytes.
|
||||
func (r *dnsResponse) Bytes() []byte {
|
||||
return r.bytes
|
||||
}
|
||||
|
||||
// Rcode implements model.DNSResponse.Rcode.
|
||||
func (r *dnsResponse) Rcode() int {
|
||||
return r.msg.Rcode
|
||||
}
|
||||
|
||||
func (r *dnsResponse) rcodeToError() error {
|
||||
// TODO(bassosimone): map more errors to net.DNSError names
|
||||
// TODO(bassosimone): add support for lame referral.
|
||||
switch reply.Rcode {
|
||||
switch r.msg.Rcode {
|
||||
case dns.RcodeSuccess:
|
||||
return reply, nil
|
||||
return nil
|
||||
case dns.RcodeNameError:
|
||||
return nil, ErrOODNSNoSuchHost
|
||||
return ErrOODNSNoSuchHost
|
||||
case dns.RcodeRefused:
|
||||
return nil, ErrOODNSRefused
|
||||
return ErrOODNSRefused
|
||||
case dns.RcodeServerFailure:
|
||||
return nil, ErrOODNSServfail
|
||||
return ErrOODNSServfail
|
||||
default:
|
||||
return nil, ErrOODNSMisbehaving
|
||||
return ErrOODNSMisbehaving
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
||||
reply, err := d.decodeSuccessfulReply(data, queryID)
|
||||
if err != nil {
|
||||
// DecodeHTTPS implements model.DNSResponse.DecodeHTTPS.
|
||||
func (r *dnsResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
|
||||
if err := r.rcodeToError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := &model.HTTPSSvc{
|
||||
|
@ -77,7 +97,7 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS
|
|||
IPv4: []string{}, // ensure it's not nil
|
||||
IPv6: []string{}, // ensure it's not nil
|
||||
}
|
||||
for _, answer := range reply.Answer {
|
||||
for _, answer := range r.msg.Answer {
|
||||
switch avalue := answer.(type) {
|
||||
case *dns.HTTPS:
|
||||
for _, v := range avalue.Value {
|
||||
|
@ -102,14 +122,14 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS
|
|||
return out, nil
|
||||
}
|
||||
|
||||
func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error) {
|
||||
reply, err := d.decodeSuccessfulReply(data, queryID)
|
||||
if err != nil {
|
||||
// DecodeLookupHost implements model.DNSResponse.DecodeLookupHost.
|
||||
func (r *dnsResponse) DecodeLookupHost() ([]string, error) {
|
||||
if err := r.rcodeToError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var addrs []string
|
||||
for _, answer := range reply.Answer {
|
||||
switch qtype {
|
||||
for _, answer := range r.msg.Answer {
|
||||
switch r.Query().Type() {
|
||||
case dns.TypeA:
|
||||
if rra, ok := answer.(*dns.A); ok {
|
||||
ip := rra.A
|
||||
|
@ -128,13 +148,13 @@ func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID ui
|
|||
return addrs, nil
|
||||
}
|
||||
|
||||
func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, error) {
|
||||
reply, err := d.decodeSuccessfulReply(data, queryID)
|
||||
if err != nil {
|
||||
// DecodeNS implements model.DNSResponse.DecodeNS.
|
||||
func (r *dnsResponse) DecodeNS() ([]*net.NS, error) {
|
||||
if err := r.rcodeToError(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := []*net.NS{}
|
||||
for _, answer := range reply.Answer {
|
||||
for _, answer := range r.msg.Answer {
|
||||
switch avalue := answer.(type) {
|
||||
case *dns.NS:
|
||||
out = append(out, &net.NS{Host: avalue.Ns})
|
||||
|
@ -147,3 +167,4 @@ func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, erro
|
|||
}
|
||||
|
||||
var _ model.DNSDecoder = &DNSDecoderMiekg{}
|
||||
var _ model.DNSResponse = &dnsResponse{}
|
||||
|
|
|
@ -1,26 +1,27 @@
|
|||
package netxlite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
func TestDNSDecoder(t *testing.T) {
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
func TestDNSDecoderMiekg(t *testing.T) {
|
||||
t.Run("DecodeResponse", func(t *testing.T) {
|
||||
t.Run("UnpackError", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, nil, 0)
|
||||
resp, err := d.DecodeResponse(nil, &mocks.DNSQuery{})
|
||||
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -28,12 +29,12 @@ func TestDNSDecoder(t *testing.T) {
|
|||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||
addrs, err := d.DecodeLookupHost(dns.TypeA, rawQuery, queryID)
|
||||
resp, err := d.DecodeResponse(rawQuery, &mocks.DNSQuery{})
|
||||
if !errors.Is(err, ErrDNSIsQuery) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
t.Fatal("expected no addrs")
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -44,297 +45,447 @@ func TestDNSDecoder(t *testing.T) {
|
|||
unrelatedID = 14
|
||||
)
|
||||
reply := dnsGenLookupHostReplySuccess(dnsGenQuery(dns.TypeA, queryID))
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, reply, unrelatedID)
|
||||
resp, err := d.DecodeResponse(reply, &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return unrelatedID
|
||||
},
|
||||
})
|
||||
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NXDOMAIN", func(t *testing.T) {
|
||||
t.Run("dnsResponse.Query", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
|
||||
dnsGenQuery(dns.TypeA, queryID), dns.RcodeNameError), queryID)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||
t.Fatal("not the error we expected", err)
|
||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Refused", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
|
||||
dnsGenQuery(dns.TypeA, queryID), dns.RcodeRefused), queryID)
|
||||
if !errors.Is(err, ErrOODNSRefused) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Servfail", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
|
||||
dnsGenQuery(dns.TypeA, queryID), dns.RcodeServerFailure), queryID)
|
||||
if !errors.Is(err, ErrOODNSServfail) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no address", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
|
||||
dnsGenQuery(dns.TypeA, queryID)), queryID)
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode A", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
|
||||
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.8.8"), queryID)
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "1.1.1.1" {
|
||||
t.Fatal("invalid first IPv4 entry")
|
||||
}
|
||||
if data[1] != "8.8.8.8" {
|
||||
t.Fatal("invalid second IPv4 entry")
|
||||
if resp.Query().ID() != query.ID() {
|
||||
t.Fatal("invalid query")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode AAAA", func(t *testing.T) {
|
||||
t.Run("dnsResponse.Bytes", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
|
||||
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID)
|
||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if data[0] != "::1" {
|
||||
t.Fatal("invalid first IPv6 entry")
|
||||
}
|
||||
if data[1] != "fe80::1" {
|
||||
t.Fatal("invalid second IPv6 entry")
|
||||
if !bytes.Equal(rawResponse, resp.Bytes()) {
|
||||
t.Fatal("invalid bytes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unexpected A reply", func(t *testing.T) {
|
||||
t.Run("dnsResponse.Rcode", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
|
||||
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID)
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unexpected AAAA reply", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
|
||||
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.4.4"), queryID)
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("decodeSuccessfulReply", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
msg := &dns.Msg{}
|
||||
msg.Rcode = dns.RcodeFormatError // an rcode we don't handle
|
||||
msg.Response = true
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
reply, err := d.decodeSuccessfulReply(data, 0)
|
||||
if !errors.Is(err, ErrOODNSMisbehaving) { // catch all error
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DecodeHTTPS", func(t *testing.T) {
|
||||
t.Run("with nil data", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeHTTPS(nil, 0)
|
||||
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with bytes containing a query", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||
https, err := d.DecodeHTTPS(rawQuery, queryID)
|
||||
if !errors.Is(err, ErrDNSIsQuery) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("expected nil https")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong query ID", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
const (
|
||||
queryID = 17
|
||||
unrelatedID = 14
|
||||
)
|
||||
reply := dnsGenHTTPSReplySuccess(dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
|
||||
data, err := d.DecodeHTTPS(reply, unrelatedID)
|
||||
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with empty answer", func(t *testing.T) {
|
||||
queryID := dns.Id()
|
||||
data := dnsGenHTTPSReplySuccess(
|
||||
dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeHTTPS(data, queryID)
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with full answer", func(t *testing.T) {
|
||||
queryID := dns.Id()
|
||||
alpn := []string{"h3"}
|
||||
v4 := []string{"1.1.1.1"}
|
||||
v6 := []string{"::1"}
|
||||
data := dnsGenHTTPSReplySuccess(
|
||||
dnsGenQuery(dns.TypeHTTPS, queryID), alpn, v4, v6)
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeHTTPS(data, queryID)
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(alpn, reply.ALPN); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if diff := cmp.Diff(v4, reply.IPv4); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if diff := cmp.Diff(v6, reply.IPv6); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("DecodeNS", func(t *testing.T) {
|
||||
t.Run("with nil data", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeNS(nil, 0)
|
||||
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
if resp.Rcode() != dns.RcodeRefused {
|
||||
t.Fatal("invalid rcode")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with bytes containing a query", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
|
||||
ns, err := d.DecodeNS(rawQuery, queryID)
|
||||
if !errors.Is(err, ErrDNSIsQuery) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(ns) > 0 {
|
||||
t.Fatal("expected no result")
|
||||
t.Run("dnsResponse.rcodeToError", func(t *testing.T) {
|
||||
// Here we want to ensure we map all the errors we recognize
|
||||
// correctly and we also map unrecognized errors correctly
|
||||
var inputsOutputs = []struct {
|
||||
name string
|
||||
rcode int
|
||||
err error
|
||||
}{{
|
||||
name: "when rcode is zero",
|
||||
rcode: 0,
|
||||
err: nil,
|
||||
}, {
|
||||
name: "NXDOMAIN",
|
||||
rcode: dns.RcodeNameError,
|
||||
err: ErrOODNSNoSuchHost,
|
||||
}, {
|
||||
name: "refused",
|
||||
rcode: dns.RcodeRefused,
|
||||
err: ErrOODNSRefused,
|
||||
}, {
|
||||
name: "servfail",
|
||||
rcode: dns.RcodeServerFailure,
|
||||
err: ErrOODNSServfail,
|
||||
}, {
|
||||
name: "anything else",
|
||||
rcode: dns.RcodeFormatError,
|
||||
err: ErrOODNSMisbehaving,
|
||||
}}
|
||||
for _, io := range inputsOutputs {
|
||||
t.Run(io.name, func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||
rawResponse := dnsGenReplyWithError(rawQuery, io.rcode)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// The following cast should always work in this configuration
|
||||
err = resp.(*dnsResponse).rcodeToError()
|
||||
if !errors.Is(err, io.err) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong query ID", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
const (
|
||||
queryID = 17
|
||||
unrelatedID = 14
|
||||
)
|
||||
reply := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
|
||||
data, err := d.DecodeNS(reply, unrelatedID)
|
||||
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
t.Run("dnsResponse.DecodeHTTPS", func(t *testing.T) {
|
||||
t.Run("with failure", func(t *testing.T) {
|
||||
// Ensure that we're not trying to decode if rcode != 0
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
https, err := resp.DecodeHTTPS()
|
||||
if !errors.Is(err, ErrOODNSRefused) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("expected nil https result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with empty answer", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||
rawResponse := dnsGenHTTPSReplySuccess(rawQuery, nil, nil, nil)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
https, err := resp.DecodeHTTPS()
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("expected nil https results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with full answer", func(t *testing.T) {
|
||||
alpn := []string{"h3"}
|
||||
v4 := []string{"1.1.1.1"}
|
||||
v6 := []string{"::1"}
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||
rawResponse := dnsGenHTTPSReplySuccess(rawQuery, alpn, v4, v6)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
reply, err := resp.DecodeHTTPS()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(alpn, reply.ALPN); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if diff := cmp.Diff(v4, reply.IPv4); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if diff := cmp.Diff(v6, reply.IPv6); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("with empty answer", func(t *testing.T) {
|
||||
queryID := dns.Id()
|
||||
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeNS(data, queryID)
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
t.Run("dnsResponse.DecodeNS", func(t *testing.T) {
|
||||
t.Run("with failure", func(t *testing.T) {
|
||||
// Ensure that we're not trying to decode if rcode != 0
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
|
||||
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ns, err := resp.DecodeNS()
|
||||
if !errors.Is(err, ErrOODNSRefused) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(ns) > 0 {
|
||||
t.Fatal("expected empty ns result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with empty answer", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
|
||||
rawResponse := dnsGenNSReplySuccess(rawQuery)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ns, err := resp.DecodeNS()
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(ns) > 0 {
|
||||
t.Fatal("expected empty ns results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with full answer", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
|
||||
rawResponse := dnsGenNSReplySuccess(rawQuery, "ns1.zdns.google.")
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ns, err := resp.DecodeNS()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(ns) != 1 {
|
||||
t.Fatal("unexpected ns length")
|
||||
}
|
||||
if ns[0].Host != "ns1.zdns.google." {
|
||||
t.Fatal("unexpected host")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("with full answer", func(t *testing.T) {
|
||||
queryID := dns.Id()
|
||||
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID), "ns1.zdns.google.")
|
||||
d := &DNSDecoderMiekg{}
|
||||
reply, err := d.DecodeNS(data, queryID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(reply) != 1 {
|
||||
t.Fatal("unexpected reply length")
|
||||
}
|
||||
if reply[0].Host != "ns1.zdns.google." {
|
||||
t.Fatal("unexpected reply host")
|
||||
}
|
||||
t.Run("dnsResponse.LookupHost", func(t *testing.T) {
|
||||
t.Run("with failure", func(t *testing.T) {
|
||||
// Ensure that we're not trying to decode if rcode != 0
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs, err := resp.DecodeLookupHost()
|
||||
if !errors.Is(err, ErrOODNSRefused) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
t.Fatal("expected empty addrs result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with empty answer", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs, err := resp.DecodeLookupHost()
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
t.Fatal("expected empty ns results")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode A", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8")
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
MockType: func() uint16 {
|
||||
return dns.TypeA
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs, err := resp.DecodeLookupHost()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if addrs[0] != "1.1.1.1" {
|
||||
t.Fatal("invalid first IPv4 entry")
|
||||
}
|
||||
if addrs[1] != "8.8.8.8" {
|
||||
t.Fatal("invalid second IPv4 entry")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode AAAA", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeAAAA, queryID)
|
||||
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1")
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
MockType: func() uint16 {
|
||||
return dns.TypeAAAA
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs, err := resp.DecodeLookupHost()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 2 {
|
||||
t.Fatal("expected two entries here")
|
||||
}
|
||||
if addrs[0] != "::1" {
|
||||
t.Fatal("invalid first IPv6 entry")
|
||||
}
|
||||
if addrs[1] != "fe80::1" {
|
||||
t.Fatal("invalid second IPv6 entry")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unexpected A reply to AAAA query", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeAAAA, queryID)
|
||||
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8")
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
MockType: func() uint16 {
|
||||
return dns.TypeAAAA
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs, err := resp.DecodeLookupHost()
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
t.Fatal("expected no addrs here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unexpected AAAA reply to A query", func(t *testing.T) {
|
||||
d := &DNSDecoderMiekg{}
|
||||
queryID := dns.Id()
|
||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1")
|
||||
query := &mocks.DNSQuery{
|
||||
MockID: func() uint16 {
|
||||
return queryID
|
||||
},
|
||||
MockType: func() uint16 {
|
||||
return dns.TypeA
|
||||
},
|
||||
}
|
||||
resp, err := d.DecodeResponse(rawResponse, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs, err := resp.DecodeLookupHost()
|
||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||
t.Fatal("not the error we expected", err)
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
t.Fatal("expected no addrs here")
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@ -371,8 +522,8 @@ func dnsGenReplyWithError(rawQuery []byte, code int) []byte {
|
|||
return data
|
||||
}
|
||||
|
||||
// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given
|
||||
// qtype (e.g., dns.TypeA) containing the given ips... in the answer.
|
||||
// dnsGenLookupHostReplySuccess generates a successful DNS reply containing the given ips...
|
||||
// in the answers where each answer's type depends on the IP's type (A/AAAA).
|
||||
func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
|
||||
query := new(dns.Msg)
|
||||
err := query.Unpack(rawQuery)
|
||||
|
@ -388,28 +539,22 @@ func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
|
|||
reply.MsgHdr.RecursionAvailable = true
|
||||
reply.SetReply(query)
|
||||
for _, ip := range ips {
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
if isIPv6(ip) {
|
||||
continue
|
||||
}
|
||||
switch isIPv6(ip) {
|
||||
case false:
|
||||
reply.Answer = append(reply.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: question.Qtype,
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
A: net.ParseIP(ip),
|
||||
})
|
||||
case dns.TypeAAAA:
|
||||
if !isIPv6(ip) {
|
||||
continue
|
||||
}
|
||||
case true:
|
||||
reply.Answer = append(reply.Answer, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn("x.org"),
|
||||
Rrtype: question.Qtype,
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 0,
|
||||
},
|
||||
|
|
|
@ -5,7 +5,10 @@ package netxlite
|
|||
//
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
|
@ -23,18 +26,82 @@ const (
|
|||
dnsDNSSECEnabled = true
|
||||
)
|
||||
|
||||
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
// Encoder implements model.DNSEncoder.Encode.
|
||||
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) model.DNSQuery {
|
||||
return &dnsQuery{
|
||||
bytesCalls: &atomicx.Int64{},
|
||||
domain: domain,
|
||||
kind: qtype,
|
||||
id: dns.Id(),
|
||||
memoizedBytes: []byte{},
|
||||
mu: sync.Mutex{},
|
||||
padding: padding,
|
||||
}
|
||||
}
|
||||
|
||||
// dnsQuery implements model.DNSQuery.
|
||||
type dnsQuery struct {
|
||||
// bytesCalls counts the calls to the bytes() method
|
||||
bytesCalls *atomicx.Int64
|
||||
|
||||
// domain is the domain.
|
||||
domain string
|
||||
|
||||
// kind is the query type.
|
||||
kind uint16
|
||||
|
||||
// id is the query ID.
|
||||
id uint16
|
||||
|
||||
// memoizedBytes contains the query encoded as bytes. We only fill
|
||||
// this field the first time the Bytes method is called.
|
||||
memoizedBytes []byte
|
||||
|
||||
// mu provides mutual exclusion.
|
||||
mu sync.Mutex
|
||||
|
||||
// padding indicates whether we need padding.
|
||||
padding bool
|
||||
}
|
||||
|
||||
// Domain implements model.DNSQuery.Domain.
|
||||
func (q *dnsQuery) Domain() string {
|
||||
return q.domain
|
||||
}
|
||||
|
||||
// Type implements model.DNSQuery.Type.
|
||||
func (q *dnsQuery) Type() uint16 {
|
||||
return q.kind
|
||||
}
|
||||
|
||||
// Bytes implements model.DNSQuery.Bytes.
|
||||
func (q *dnsQuery) Bytes() ([]byte, error) {
|
||||
defer q.mu.Unlock()
|
||||
q.mu.Lock()
|
||||
if len(q.memoizedBytes) <= 0 {
|
||||
q.bytesCalls.Add(1) // for testing
|
||||
data, err := q.bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q.memoizedBytes = data
|
||||
}
|
||||
return q.memoizedBytes, nil
|
||||
}
|
||||
|
||||
// bytes is the unmemoized implementation of Bytes
|
||||
func (q *dnsQuery) bytes() ([]byte, error) {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn(domain),
|
||||
Qtype: qtype,
|
||||
Name: dns.Fqdn(q.domain),
|
||||
Qtype: q.kind,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
query := new(dns.Msg)
|
||||
query.Id = dns.Id()
|
||||
query.Id = q.id
|
||||
query.RecursionDesired = true
|
||||
query.Question = make([]dns.Question, 1)
|
||||
query.Question[0] = question
|
||||
if padding {
|
||||
if q.padding {
|
||||
query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled)
|
||||
// Clients SHOULD pad queries to the closest multiple of
|
||||
// 128 octets RFC8467#section-4.1. We inflate the query
|
||||
|
@ -47,8 +114,13 @@ func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]b
|
|||
opt.Padding = make([]byte, remainder)
|
||||
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
|
||||
}
|
||||
data, err := query.Pack()
|
||||
return data, query.Id, err
|
||||
return query.Pack()
|
||||
}
|
||||
|
||||
// ID implements model.DNSQuery.ID
|
||||
func (q *dnsQuery) ID() uint16 {
|
||||
return q.id
|
||||
}
|
||||
|
||||
var _ model.DNSEncoder = &DNSEncoderMiekg{}
|
||||
var _ model.DNSQuery = &dnsQuery{}
|
||||
|
|
|
@ -1,29 +1,103 @@
|
|||
package netxlite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/randx"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
func TestDNSEncoder(t *testing.T) {
|
||||
func TestDNSEncoderMiekg(t *testing.T) {
|
||||
t.Run("we can fail to encode a domain name to bytes", func(t *testing.T) {
|
||||
e := &DNSEncoderMiekg{}
|
||||
domain := randx.LettersUppercase(512)
|
||||
query := e.Encode(domain, dns.TypeA, false)
|
||||
data, err := query.Bytes()
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("calls to bytes are memoized", func(t *testing.T) {
|
||||
t.Run("on success", func(t *testing.T) {
|
||||
e := &DNSEncoderMiekg{}
|
||||
query := e.Encode("x.org", dns.TypeA, false)
|
||||
checkResult := func(data []byte, err error) {
|
||||
if err != nil {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
|
||||
}
|
||||
const repeat = 3
|
||||
for idx := 0; idx < repeat; idx++ {
|
||||
checkResult(query.Bytes())
|
||||
}
|
||||
// The following cast will always work in this configuration
|
||||
if query.(*dnsQuery).bytesCalls.Load() != 1 {
|
||||
t.Fatal("invalid number of calls")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
e := &DNSEncoderMiekg{}
|
||||
domain := randx.LettersUppercase(512)
|
||||
query := e.Encode(domain, dns.TypeA, false)
|
||||
checkResult := func(data []byte, err error) {
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected nil data here")
|
||||
}
|
||||
}
|
||||
const repeat = 3
|
||||
for idx := 0; idx < repeat; idx++ {
|
||||
checkResult(query.Bytes())
|
||||
}
|
||||
// The following cast will always work in this configuration
|
||||
if query.(*dnsQuery).bytesCalls.Load() != repeat {
|
||||
t.Fatal("invalid number of calls")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("encode A", func(t *testing.T) {
|
||||
e := &DNSEncoderMiekg{}
|
||||
data, _, err := e.Encode("x.org", dns.TypeA, false)
|
||||
query := e.Encode("x.org", dns.TypeA, false)
|
||||
if query.Domain() != "x.org" {
|
||||
t.Fatal("invalid domain")
|
||||
}
|
||||
if query.Type() != dns.TypeA {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
data, err := query.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
|
||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
|
||||
})
|
||||
|
||||
t.Run("encode AAAA", func(t *testing.T) {
|
||||
e := &DNSEncoderMiekg{}
|
||||
data, _, err := e.Encode("x.org", dns.TypeAAAA, false)
|
||||
query := e.Encode("x.org", dns.TypeAAAA, false)
|
||||
if query.Domain() != "x.org" {
|
||||
t.Fatal("invalid domain")
|
||||
}
|
||||
if query.Type() != dns.TypeAAAA {
|
||||
t.Fatal("invalid type")
|
||||
}
|
||||
data, err := query.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
|
||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
|
||||
})
|
||||
|
||||
t.Run("encode padding", func(t *testing.T) {
|
||||
|
@ -31,7 +105,7 @@ func TestDNSEncoder(t *testing.T) {
|
|||
// array of values we obtain the right query size.
|
||||
getquerylen := func(domainlen int, padding bool) int {
|
||||
e := &DNSEncoderMiekg{}
|
||||
data, _, err := e.Encode(
|
||||
query := e.Encode(
|
||||
// This is not a valid name because it ends up being way
|
||||
// longer than 255 octets. However, the library is allowing
|
||||
// us to generate such name and we are not going to send
|
||||
|
@ -40,6 +114,7 @@ func TestDNSEncoder(t *testing.T) {
|
|||
dns.Fqdn(strings.Repeat("x.", domainlen)),
|
||||
dns.TypeA, padding,
|
||||
)
|
||||
data, err := query.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -63,8 +138,13 @@ func TestDNSEncoder(t *testing.T) {
|
|||
|
||||
// dnsValidateEncodedQueryBytes validates the query serialized in data
|
||||
// for the given query type qtype (e.g., dns.TypeAAAA).
|
||||
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte) {
|
||||
// skipping over the query ID
|
||||
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte, qid uint16) {
|
||||
var wirequery uint16
|
||||
err := binary.Read(bytes.NewReader(data), binary.BigEndian, &wirequery)
|
||||
runtimex.PanicOnError(err, "binary.Read failed unexpectedly")
|
||||
if wirequery != qid {
|
||||
t.Fatal("invalid query ID")
|
||||
}
|
||||
if data[2] != 1 {
|
||||
t.Fatal("FLAGS should only have RD set")
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -19,6 +20,9 @@ type DNSOverHTTPSTransport struct {
|
|||
// Client is the MANDATORY http client to use.
|
||||
Client model.HTTPClient
|
||||
|
||||
// Decoder is the MANDATORY DNSDecoder.
|
||||
Decoder model.DNSDecoder
|
||||
|
||||
// URL is the MANDATORY URL of the DNS-over-HTTPS server.
|
||||
URL string
|
||||
|
||||
|
@ -31,9 +35,9 @@ type DNSOverHTTPSTransport struct {
|
|||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - client in http.Client-like type (e.g., http.DefaultClient);
|
||||
// - client is a model.HTTPClient type;
|
||||
//
|
||||
// - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query).
|
||||
// - URL is the DoH resolver URL (e.g., https://dns.google/dns-query).
|
||||
func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
|
||||
return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
|
||||
}
|
||||
|
@ -42,22 +46,31 @@ func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPS
|
|||
// with the given Host header override.
|
||||
func NewDNSOverHTTPSTransportWithHostOverride(
|
||||
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
|
||||
return &DNSOverHTTPSTransport{Client: client, URL: URL, HostOverride: hostOverride}
|
||||
return &DNSOverHTTPSTransport{
|
||||
Client: client,
|
||||
Decoder: &DNSDecoderMiekg{},
|
||||
URL: URL,
|
||||
HostOverride: hostOverride,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip sends a query and receives a reply.
|
||||
func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
func (t *DNSOverHTTPSTransport) RoundTrip(
|
||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
rawQuery, err := query.Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(rawQuery))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Host = t.HostOverride
|
||||
req.Header.Set("user-agent", model.HTTPHeaderUserAgent)
|
||||
req.Header.Set("content-type", "application/dns-message")
|
||||
var resp *http.Response
|
||||
resp, err = t.Client.Do(req.WithContext(ctx))
|
||||
resp, err := t.Client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -70,7 +83,13 @@ func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]
|
|||
if resp.Header.Get("content-type") != "application/dns-message" {
|
||||
return nil, errors.New("doh: invalid content-type")
|
||||
}
|
||||
return ReadAllContext(ctx, resp.Body)
|
||||
const maxresponsesize = 1 << 20
|
||||
limitReader := io.LimitReader(resp.Body, maxresponsesize)
|
||||
rawResponse, err := ReadAllContext(ctx, limitReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Decoder.DecodeResponse(rawResponse, query)
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoH according to RFC8467.
|
||||
|
|
|
@ -15,14 +15,36 @@ import (
|
|||
|
||||
func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Run("query serialization failure", func(t *testing.T) {
|
||||
txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
|
||||
expected := errors.New("mocked error")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NewRequestFailure", func(t *testing.T) {
|
||||
const invalidURL = "\t"
|
||||
txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("expected an error here")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
if data != nil {
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
@ -37,11 +59,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
|||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
if data != nil {
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
@ -58,11 +85,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
|||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: server returned error" {
|
||||
t.Fatal("expected an error here")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
if data != nil {
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if err == nil || err.Error() != "doh: server returned error" {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
@ -79,11 +111,86 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
|||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if err == nil || err.Error() != "doh: invalid content-type" {
|
||||
t.Fatal("expected an error here")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
if data != nil {
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if err == nil || err.Error() != "doh: invalid content-type" {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ReadAllContext fails", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
txp := &DNSOverHTTPSTransport{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(&mocks.Reader{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, expected
|
||||
},
|
||||
}),
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/dns-message"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode response failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
body := []byte("AAA")
|
||||
txp := &DNSOverHTTPSTransport{
|
||||
Client: &mocks.HTTPClient{
|
||||
MockDo: func(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/dns-message"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
}
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
@ -103,13 +210,23 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
|||
},
|
||||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return &mocks.DNSResponse{}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data, body) {
|
||||
t.Fatal("not the response we expected")
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -125,7 +242,12 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
|||
},
|
||||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
|
@ -151,18 +273,22 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
|||
URL: "https://cloudflare-dns.com/dns-query",
|
||||
HostOverride: hostOverride,
|
||||
}
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 17), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("expected an error here")
|
||||
}
|
||||
if data != nil {
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
if !correct {
|
||||
t.Fatal("did not see correct host override")
|
||||
}
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
t.Run("other functions behave correctly", func(t *testing.T) {
|
||||
|
|
|
@ -20,9 +20,12 @@ type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
|||
|
||||
// DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport.
|
||||
//
|
||||
// Bug: this implementation always creates a new connection for each query.
|
||||
// Note: this implementation always creates a new connection for each query. This
|
||||
// strategy is less efficient but MAY be more robust for cleartext TCP connections
|
||||
// when querying for a blocked domain name causes endpoint blocking.
|
||||
type DNSOverTCPTransport struct {
|
||||
dial DialContextFunc
|
||||
decoder model.DNSDecoder
|
||||
address string
|
||||
network string
|
||||
requiresPadding bool
|
||||
|
@ -36,47 +39,58 @@ type DNSOverTCPTransport struct {
|
|||
//
|
||||
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
||||
func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||
return &DNSOverTCPTransport{
|
||||
dial: dial,
|
||||
address: address,
|
||||
network: "tcp",
|
||||
requiresPadding: false,
|
||||
}
|
||||
return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false)
|
||||
}
|
||||
|
||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
||||
// NewDNSOverTLSTransport creates a new DNSOverTLS transport.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - dial is a function with the net.Dialer.DialContext's signature;
|
||||
//
|
||||
// - address is the endpoint address (e.g., 8.8.8.8:853).
|
||||
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||
func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||
return newDNSOverTCPOrTLSTransport(dial, "dot", address, true)
|
||||
}
|
||||
|
||||
// newDNSOverTCPOrTLSTransport is the common factory for creating a transport
|
||||
func newDNSOverTCPOrTLSTransport(
|
||||
dial DialContextFunc, network, address string, padding bool) *DNSOverTCPTransport {
|
||||
return &DNSOverTCPTransport{
|
||||
dial: dial,
|
||||
decoder: &DNSDecoderMiekg{},
|
||||
address: address,
|
||||
network: "dot",
|
||||
requiresPadding: true,
|
||||
network: network,
|
||||
requiresPadding: padding,
|
||||
}
|
||||
}
|
||||
|
||||
// errQueryTooLarge indicates the query is too large for the transport.
|
||||
var errQueryTooLarge = errors.New("oodns: query too large for this transport")
|
||||
|
||||
// RoundTrip sends a query and receives a reply.
|
||||
func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
if len(query) > math.MaxUint16 {
|
||||
return nil, errors.New("query too long")
|
||||
func (t *DNSOverTCPTransport) RoundTrip(
|
||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
// TODO(bassosimone): this method should more strictly honour the context, which
|
||||
// currently is only used to bound the dial operation
|
||||
rawQuery, err := query.Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rawQuery) > math.MaxUint16 {
|
||||
return nil, errQueryTooLarge
|
||||
}
|
||||
conn, err := t.dial(ctx, "tcp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
const iotimeout = 10 * time.Second
|
||||
conn.SetDeadline(time.Now().Add(iotimeout))
|
||||
// Write request
|
||||
buf := []byte{byte(len(query) >> 8)}
|
||||
buf = append(buf, byte(len(query)))
|
||||
buf = append(buf, query...)
|
||||
buf := []byte{byte(len(rawQuery) >> 8)}
|
||||
buf = append(buf, byte(len(rawQuery)))
|
||||
buf = append(buf, rawQuery...)
|
||||
if _, err = conn.Write(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -86,11 +100,11 @@ func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]by
|
|||
return nil, err
|
||||
}
|
||||
length := int(header[0])<<8 | int(header[1])
|
||||
reply := make([]byte, length)
|
||||
if _, err = io.ReadFull(conn, reply); err != nil {
|
||||
rawResponse := make([]byte, length)
|
||||
if _, err = io.ReadFull(conn, rawResponse); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply, nil
|
||||
return t.decoder.DecodeResponse(rawResponse, query)
|
||||
}
|
||||
|
||||
// RequiresPadding returns true for DoT and false for TCP
|
||||
|
|
|
@ -6,73 +6,83 @@ import (
|
|||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverTCPTransport(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Run("cannot encode query", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("query too large", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, math.MaxUint16+1), nil
|
||||
},
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, errQueryTooLarge) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("dial failure", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetDeadline failure", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("write failure", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
|
@ -89,18 +99,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
|||
},
|
||||
}
|
||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("first read fails", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
|
@ -120,18 +135,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
|||
},
|
||||
}
|
||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("second read fails", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
input := io.MultiReader(
|
||||
bytes.NewReader([]byte{byte(0), byte(2)}),
|
||||
&mocks.Reader{
|
||||
|
@ -157,17 +177,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
|||
},
|
||||
}
|
||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful case", func(t *testing.T) {
|
||||
t.Run("decode failure", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
mocked := errors.New("mocked error")
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
|
@ -186,11 +212,56 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
|||
},
|
||||
}
|
||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
||||
txp.decoder = &mocks.DNSDecoder{
|
||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful case", func(t *testing.T) {
|
||||
const address = "9.9.9.9:53"
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
|
||||
fakedialer := &mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: input.Read,
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||
expectedResp := &mocks.DNSResponse{}
|
||||
txp.decoder = &mocks.DNSDecoder{
|
||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return expectedResp, nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(reply) != 1 || reply[0] != 1 {
|
||||
if resp != expectedResp {
|
||||
t.Fatal("not the response we expected")
|
||||
}
|
||||
})
|
||||
|
@ -213,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
|||
|
||||
t.Run("other functions okay with TLS", func(t *testing.T) {
|
||||
const address = "9.9.9.9:853"
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address)
|
||||
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
|
||||
if txp.RequiresPadding() != true {
|
||||
t.Fatal("invalid RequiresPadding")
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
// DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
|
||||
type DNSOverUDPTransport struct {
|
||||
dialer model.Dialer
|
||||
decoder model.DNSDecoder
|
||||
address string
|
||||
}
|
||||
|
||||
|
@ -25,11 +26,20 @@ type DNSOverUDPTransport struct {
|
|||
//
|
||||
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
||||
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
|
||||
return &DNSOverUDPTransport{dialer: dialer, address: address}
|
||||
return &DNSOverUDPTransport{
|
||||
dialer: dialer,
|
||||
decoder: &DNSDecoderMiekg{},
|
||||
address: address,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip sends a query and receives a reply.
|
||||
func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
||||
func (t *DNSOverUDPTransport) RoundTrip(
|
||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
rawQuery, err := query.Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -37,19 +47,19 @@ func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]by
|
|||
defer conn.Close()
|
||||
// Use five seconds timeout like Bionic does. See
|
||||
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
const iotimeout = 5 * time.Second
|
||||
conn.SetDeadline(time.Now().Add(iotimeout))
|
||||
if _, err = conn.Write(rawQuery); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err = conn.Write(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reply := make([]byte, 1<<17)
|
||||
var n int
|
||||
n, err = conn.Read(reply)
|
||||
const maxmessagesize = 1 << 17
|
||||
rawResponse := make([]byte, maxmessagesize)
|
||||
count, err := conn.Read(rawResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply[:n], nil
|
||||
rawResponse = rawResponse[:count]
|
||||
return t.decoder.DecodeResponse(rawResponse, query)
|
||||
}
|
||||
|
||||
// RequiresPadding returns false for UDP according to RFC8467.
|
||||
|
|
|
@ -9,11 +9,30 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestDNSOverUDPTransport(t *testing.T) {
|
||||
t.Run("RoundTrip", func(t *testing.T) {
|
||||
t.Run("cannot encode query", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
txp := NewDNSOverUDPTransport(nil, address)
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("dial failure", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
const address = "9.9.9.9:53"
|
||||
|
@ -22,36 +41,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
|||
return nil, mocked
|
||||
},
|
||||
}, address)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SetDeadline failure", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverUDPTransport(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return mocked
|
||||
},
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
@ -75,11 +74,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
|||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
@ -106,15 +110,61 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
|||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if data != nil {
|
||||
if resp != nil {
|
||||
t.Fatal("expected no response here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("decode failure", func(t *testing.T) {
|
||||
const expected = 17
|
||||
input := bytes.NewReader(make([]byte, expected))
|
||||
txp := NewDNSOverUDPTransport(
|
||||
&mocks.Dialer{
|
||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return &mocks.Conn{
|
||||
MockSetDeadline: func(t time.Time) error {
|
||||
return nil
|
||||
},
|
||||
MockWrite: func(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
},
|
||||
MockRead: input.Read,
|
||||
MockClose: func() error {
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
expectedErr := errors.New("mocked error")
|
||||
txp.decoder = &mocks.DNSDecoder{
|
||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expectedErr
|
||||
},
|
||||
}
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if !errors.Is(err, expectedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("expected nil resp")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("read success", func(t *testing.T) {
|
||||
const expected = 17
|
||||
input := bytes.NewReader(make([]byte, expected))
|
||||
|
@ -136,12 +186,23 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
|||
},
|
||||
}, "9.9.9.9:53",
|
||||
)
|
||||
data, err := txp.RoundTrip(context.Background(), nil)
|
||||
expectedResp := &mocks.DNSResponse{}
|
||||
txp.decoder = &mocks.DNSDecoder{
|
||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return expectedResp, nil
|
||||
},
|
||||
}
|
||||
query := &mocks.DNSQuery{
|
||||
MockBytes: func() ([]byte, error) {
|
||||
return make([]byte, 128), nil
|
||||
},
|
||||
}
|
||||
resp, err := txp.RoundTrip(context.Background(), query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) != expected {
|
||||
t.Fatal("expected non nil data")
|
||||
if resp != expectedResp {
|
||||
t.Fatal("unexpected resp")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
package filtering
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||
)
|
||||
|
||||
|
@ -51,19 +48,13 @@ type DNSProxy struct {
|
|||
// receive a query for the given domain.
|
||||
OnQuery func(domain string) DNSAction
|
||||
|
||||
// Upstream is the OPTIONAL upstream transport.
|
||||
Upstream DNSTransport
|
||||
// UpstreamEndpoint is the OPTIONAL upstream transport endpoint.
|
||||
UpstreamEndpoint string
|
||||
|
||||
// mockableReply allows to mock DNSProxy.reply in tests.
|
||||
mockableReply func(query *dns.Msg) (*dns.Msg, error)
|
||||
}
|
||||
|
||||
// DNSTransport is the type we expect from an upstream DNS transport.
|
||||
type DNSTransport interface {
|
||||
RoundTrip(ctx context.Context, query []byte) ([]byte, error)
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// DNSListener is the interface returned by DNSProxy.Start
|
||||
type DNSListener interface {
|
||||
io.Closer
|
||||
|
@ -204,23 +195,24 @@ func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg {
|
|||
return reply
|
||||
}
|
||||
|
||||
var (
|
||||
// errDNSExpectedSingleQuestion means we expected to see a single question
|
||||
errDNSExpectedSingleQuestion = errors.New("filtering: expected single DNS question")
|
||||
|
||||
// errDNSExpectedQueryNotResponse means we expected to see a query.
|
||||
errDNSExpectedQueryNotResponse = errors.New("filtering: expected query not response")
|
||||
)
|
||||
|
||||
func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) {
|
||||
queryBytes, err := query.Pack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if query.Response {
|
||||
return nil, errDNSExpectedQueryNotResponse
|
||||
}
|
||||
txp := p.dnstransport()
|
||||
defer txp.CloseIdleConnections()
|
||||
ctx := context.Background()
|
||||
replyBytes, err := txp.RoundTrip(ctx, queryBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(query.Question) != 1 {
|
||||
return nil, errDNSExpectedSingleQuestion
|
||||
}
|
||||
reply := &dns.Msg{}
|
||||
if err := reply.Unpack(replyBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reply, nil
|
||||
clnt := &dns.Client{}
|
||||
resp, _, err := clnt.Exchange(query, p.upstreamEndpoint())
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
|
||||
|
@ -237,10 +229,9 @@ func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
|
|||
return p.compose(query, ipAddrs...)
|
||||
}
|
||||
|
||||
func (p *DNSProxy) dnstransport() DNSTransport {
|
||||
if p.Upstream != nil {
|
||||
return p.Upstream
|
||||
func (p *DNSProxy) upstreamEndpoint() string {
|
||||
if p.UpstreamEndpoint != "" {
|
||||
return p.UpstreamEndpoint
|
||||
}
|
||||
const URL = "https://1.1.1.1/dns-query"
|
||||
return netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, URL)
|
||||
return "8.8.8.8:53"
|
||||
}
|
||||
|
|
|
@ -283,9 +283,7 @@ func TestDNSProxy(t *testing.T) {
|
|||
if len(p) < len(data) {
|
||||
panic("buffer too small")
|
||||
}
|
||||
for i := 0; i < len(data); i++ {
|
||||
p[i] = data[i]
|
||||
}
|
||||
copy(p, data)
|
||||
return len(data), &net.UDPAddr{}, nil
|
||||
},
|
||||
}
|
||||
|
@ -314,9 +312,7 @@ func TestDNSProxy(t *testing.T) {
|
|||
if len(p) < len(data) {
|
||||
panic("buffer too small")
|
||||
}
|
||||
for i := 0; i < len(data); i++ {
|
||||
p[i] = data[i]
|
||||
}
|
||||
copy(p, data)
|
||||
return len(data), &net.UDPAddr{}, nil
|
||||
},
|
||||
}
|
||||
|
@ -328,12 +324,24 @@ func TestDNSProxy(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("proxy", func(t *testing.T) {
|
||||
t.Run("Pack fails", func(t *testing.T) {
|
||||
t.Run("with response", func(t *testing.T) {
|
||||
p := &DNSProxy{}
|
||||
query := &dns.Msg{}
|
||||
query.Rcode = -1 // causes Pack to fail
|
||||
query.Response = true
|
||||
reply, err := p.proxy(query)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") {
|
||||
if !errors.Is(err, errDNSExpectedQueryNotResponse) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with no questions", func(t *testing.T) {
|
||||
p := &DNSProxy{}
|
||||
query := &dns.Msg{}
|
||||
reply, err := p.proxy(query)
|
||||
if !errors.Is(err, errDNSExpectedSingleQuestion) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
|
@ -342,35 +350,13 @@ func TestDNSProxy(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("round trip fails", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
p := &DNSProxy{
|
||||
Upstream: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockCloseIdleConnections: func() {},
|
||||
},
|
||||
UpstreamEndpoint: "antani",
|
||||
}
|
||||
reply, err := p.proxy(&dns.Msg{})
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Unpack fails", func(t *testing.T) {
|
||||
p := &DNSProxy{
|
||||
Upstream: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return make([]byte, 1), nil
|
||||
},
|
||||
MockCloseIdleConnections: func() {},
|
||||
},
|
||||
}
|
||||
reply, err := p.proxy(&dns.Msg{})
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") {
|
||||
query := &dns.Msg{}
|
||||
query.Question = append(query.Question, dns.Question{})
|
||||
reply, err := p.proxy(query)
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"net"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
)
|
||||
|
||||
|
@ -19,15 +18,6 @@ import (
|
|||
// You should probably use NewUnwrappedParallelResolver to
|
||||
// create a new instance of this type.
|
||||
type ParallelResolver struct {
|
||||
// Encoder is the MANDATORY encoder to use.
|
||||
Encoder model.DNSEncoder
|
||||
|
||||
// Decoder is the MANDATORY decoder to use.
|
||||
Decoder model.DNSDecoder
|
||||
|
||||
// NumTimeouts is MANDATORY and counts the number of timeouts.
|
||||
NumTimeouts *atomicx.Int64
|
||||
|
||||
// Txp is the MANDATORY underlying DNS transport.
|
||||
Txp model.DNSTransport
|
||||
}
|
||||
|
@ -36,10 +26,7 @@ type ParallelResolver struct {
|
|||
// not wrapped and you should wrap if before using it.
|
||||
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
|
||||
return &ParallelResolver{
|
||||
Encoder: &DNSEncoderMiekg{},
|
||||
Decoder: &DNSDecoderMiekg{},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: t,
|
||||
Txp: t,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,22 +67,22 @@ func (r *ParallelResolver) LookupHost(ctx context.Context, hostname string) ([]s
|
|||
var addrs []string
|
||||
addrs = append(addrs, ares.addrs...)
|
||||
addrs = append(addrs, aaaares.addrs...)
|
||||
if len(addrs) < 1 {
|
||||
return nil, ErrOODNSNoAnswer
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// LookupHTTPS implements Resolver.LookupHTTPS.
|
||||
func (r *ParallelResolver) LookupHTTPS(
|
||||
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
|
||||
querydata, queryID, err := r.Encoder.Encode(
|
||||
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
||||
response, err := r.Txp.RoundTrip(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeHTTPS(replydata, queryID)
|
||||
return response.DecodeHTTPS()
|
||||
}
|
||||
|
||||
// parallelResolverResult is the internal representation of a
|
||||
|
@ -108,7 +95,9 @@ type parallelResolverResult struct {
|
|||
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
|
||||
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
||||
qtype uint16, out chan<- *parallelResolverResult) {
|
||||
querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
response, err := r.Txp.RoundTrip(ctx, query)
|
||||
if err != nil {
|
||||
out <- ¶llelResolverResult{
|
||||
addrs: []string{},
|
||||
|
@ -116,15 +105,7 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
|||
}
|
||||
return
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
out <- ¶llelResolverResult{
|
||||
addrs: []string{},
|
||||
err: err,
|
||||
}
|
||||
return
|
||||
}
|
||||
addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
|
||||
addrs, err := response.DecodeLookupHost()
|
||||
out <- ¶llelResolverResult{
|
||||
addrs: addrs,
|
||||
err: err,
|
||||
|
@ -134,14 +115,11 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
|||
// LookupNS implements Resolver.LookupNS.
|
||||
func (r *ParallelResolver) LookupNS(
|
||||
ctx context.Context, hostname string) ([]*net.NS, error) {
|
||||
querydata, queryID, err := r.Encoder.Encode(
|
||||
hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
||||
response, err := r.Txp.RoundTrip(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeNS(replydata, queryID)
|
||||
return response.DecodeNS()
|
||||
}
|
||||
|
|
|
@ -8,14 +8,13 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
)
|
||||
|
||||
func TestParallelResolver(t *testing.T) {
|
||||
t.Run("transport okay", func(t *testing.T) {
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
r := NewUnwrappedParallelResolver(txp)
|
||||
rtx := r.Transport()
|
||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||
|
@ -30,30 +29,10 @@ func TestParallelResolver(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("Encode error", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
r := ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, mocked
|
||||
},
|
||||
},
|
||||
Txp: txp,
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RoundTrip error", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
|
@ -72,8 +51,13 @@ func TestParallelResolver(t *testing.T) {
|
|||
|
||||
t.Run("empty reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(query), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
|
@ -91,8 +75,16 @@ func TestParallelResolver(t *testing.T) {
|
|||
|
||||
t.Run("with A reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
if query.Type() != dns.TypeA {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{"8.8.8.8"}, nil
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
|
@ -110,8 +102,16 @@ func TestParallelResolver(t *testing.T) {
|
|||
|
||||
t.Run("with AAAA reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(query, "::1"), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
if query.Type() != dns.TypeAAAA {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{"::1"}, nil
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
|
@ -131,22 +131,15 @@ func TestParallelResolver(t *testing.T) {
|
|||
afailure := errors.New("a failure")
|
||||
aaaafailure := errors.New("aaaa failure")
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
msg := &dns.Msg{}
|
||||
if err := msg.Unpack(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(msg.Question) != 1 {
|
||||
return nil, errors.New("expected just one question")
|
||||
}
|
||||
q := msg.Question[0]
|
||||
if q.Qtype == dns.TypeA {
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
switch query.Type() {
|
||||
case dns.TypeA:
|
||||
return nil, afailure
|
||||
}
|
||||
if q.Qtype == dns.TypeAAAA {
|
||||
case dns.TypeAAAA:
|
||||
return nil, aaaafailure
|
||||
default:
|
||||
return nil, errors.New("unexpected query")
|
||||
}
|
||||
return nil, errors.New("expected A or AAAA query")
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
|
@ -179,44 +172,11 @@ func TestParallelResolver(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||
t.Run("for encoding error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, expected
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
https, err := r.LookupHTTPS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for round-trip error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
|
@ -234,23 +194,17 @@ func TestParallelResolver(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("for decode error", func(t *testing.T) {
|
||||
t.Run("for DecodeHTTPS error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return make([]byte, 128), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
|
@ -269,44 +223,11 @@ func TestParallelResolver(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
t.Run("for encoding error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, expected
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for round-trip error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
|
@ -327,20 +248,14 @@ func TestParallelResolver(t *testing.T) {
|
|||
t.Run("for decode error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &ParallelResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return make([]byte, 128), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeNS: func() ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
|
|
|
@ -26,12 +26,6 @@ import (
|
|||
// QUIRK: unlike the ParallelResolver, this resolver's LookupHost retries
|
||||
// each query three times for soft errors.
|
||||
type SerialResolver struct {
|
||||
// Encoder is the MANDATORY encoder to use.
|
||||
Encoder model.DNSEncoder
|
||||
|
||||
// Decoder is the MANDATORY decoder to use.
|
||||
Decoder model.DNSDecoder
|
||||
|
||||
// NumTimeouts is MANDATORY and counts the number of timeouts.
|
||||
NumTimeouts *atomicx.Int64
|
||||
|
||||
|
@ -42,8 +36,6 @@ type SerialResolver struct {
|
|||
// NewSerialResolver creates a new SerialResolver instance.
|
||||
func NewSerialResolver(t model.DNSTransport) *SerialResolver {
|
||||
return &SerialResolver{
|
||||
Encoder: &DNSEncoderMiekg{},
|
||||
Decoder: &DNSDecoderMiekg{},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: t,
|
||||
}
|
||||
|
@ -82,22 +74,22 @@ func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]str
|
|||
}
|
||||
addrs = append(addrs, addrsA...)
|
||||
addrs = append(addrs, addrsAAAA...)
|
||||
if len(addrs) < 1 {
|
||||
return nil, ErrOODNSNoAnswer
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// LookupHTTPS implements Resolver.LookupHTTPS.
|
||||
func (r *SerialResolver) LookupHTTPS(
|
||||
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
|
||||
querydata, queryID, err := r.Encoder.Encode(
|
||||
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
||||
response, err := r.Txp.RoundTrip(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeHTTPS(replydata, queryID)
|
||||
return response.DecodeHTTPS()
|
||||
}
|
||||
|
||||
func (r *SerialResolver) lookupHostWithRetry(
|
||||
|
@ -132,28 +124,23 @@ func (r *SerialResolver) lookupHostWithRetry(
|
|||
// qtype (dns.A or dns.AAAA) without retrying on failure.
|
||||
func (r *SerialResolver) lookupHostWithoutRetry(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
response, err := r.Txp.RoundTrip(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
|
||||
return response.DecodeLookupHost()
|
||||
}
|
||||
|
||||
// LookupNS implements Resolver.LookupNS.
|
||||
func (r *SerialResolver) LookupNS(
|
||||
ctx context.Context, hostname string) ([]*net.NS, error) {
|
||||
querydata, queryID, err := r.Encoder.Encode(
|
||||
hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
||||
response, err := r.Txp.RoundTrip(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.Decoder.DecodeNS(replydata, queryID)
|
||||
return response.DecodeNS()
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
|
@ -30,7 +31,7 @@ func (err *errorWithTimeout) Unwrap() error {
|
|||
|
||||
func TestSerialResolver(t *testing.T) {
|
||||
t.Run("transport okay", func(t *testing.T) {
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
r := NewSerialResolver(txp)
|
||||
rtx := r.Transport()
|
||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||
|
@ -45,30 +46,10 @@ func TestSerialResolver(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("LookupHost", func(t *testing.T) {
|
||||
t.Run("Encode error", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||
r := SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, mocked
|
||||
},
|
||||
},
|
||||
Txp: txp,
|
||||
}
|
||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||
if !errors.Is(err, mocked) {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected nil address here")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RoundTrip error", func(t *testing.T) {
|
||||
mocked := errors.New("mocked error")
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, mocked
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
|
@ -87,8 +68,13 @@ func TestSerialResolver(t *testing.T) {
|
|||
|
||||
t.Run("empty reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(query), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
|
@ -106,8 +92,16 @@ func TestSerialResolver(t *testing.T) {
|
|||
|
||||
t.Run("with A reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
if query.Type() != dns.TypeA {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{"8.8.8.8"}, nil
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
|
@ -125,8 +119,16 @@ func TestSerialResolver(t *testing.T) {
|
|||
|
||||
t.Run("with AAAA reply", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return dnsGenLookupHostReplySuccess(query, "::1"), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
if query.Type() != dns.TypeAAAA {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{"::1"}, nil
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
|
@ -144,11 +146,12 @@ func TestSerialResolver(t *testing.T) {
|
|||
|
||||
t.Run("with timeout", func(t *testing.T) {
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return nil, &net.OpError{
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
err := &net.OpError{
|
||||
Err: &errorWithTimeout{ETIMEDOUT},
|
||||
Op: "dial",
|
||||
}
|
||||
return nil, err
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
|
@ -184,44 +187,12 @@ func TestSerialResolver(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||
t.Run("for encoding error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, expected
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
https, err := r.LookupHTTPS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if https != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for round-trip error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
|
@ -242,20 +213,15 @@ func TestSerialResolver(t *testing.T) {
|
|||
t.Run("for decode error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return make([]byte, 128), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
|
@ -274,44 +240,12 @@ func TestSerialResolver(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("LookupNS", func(t *testing.T) {
|
||||
t.Run("for encoding error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return nil, 0, expected
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
ns, err := r.LookupNS(ctx, "example.com")
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if ns != nil {
|
||||
t.Fatal("unexpected result")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("for round-trip error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: nil,
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
return nil, expected
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
|
@ -332,20 +266,15 @@ func TestSerialResolver(t *testing.T) {
|
|||
t.Run("for decode error", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
r := &SerialResolver{
|
||||
Encoder: &mocks.DNSEncoder{
|
||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
||||
return make([]byte, 64), 0, nil
|
||||
},
|
||||
},
|
||||
Decoder: &mocks.DNSDecoder{
|
||||
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
return make([]byte, 128), nil
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeNS: func() ([]*net.NS, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return false
|
||||
|
|
Loading…
Reference in New Issue
Block a user