refactor: DNSTransport I/Os DNS messages (#760)

This diff refactors the DNSTransport model to receive in input a DNSQuery and return in output a DNSResponse.

The design of DNSQuery and DNSResponse takes into account the use case of a transport using getaddrinfo, meaning that we don't need to serialize and deserialize messages when using getaddrinfo.

The current codebase does not use a getaddrinfo transport, but I wrote one such a transport in the Websteps Winter 2021 prototype (https://github.com/bassosimone/websteps-illustrated/).

The design conversation that lead to producing this diff is https://github.com/ooni/probe/issues/2099
This commit is contained in:
Simone Basso 2022-05-25 17:03:58 +02:00 committed by GitHub
parent 7a0a156aec
commit 01a513a496
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1731 additions and 1076 deletions

1
go.mod
View File

@ -17,7 +17,6 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/hexops/gotextdiff v1.0.3
github.com/iancoleman/strcase v0.2.0 github.com/iancoleman/strcase v0.2.0
github.com/lucas-clemente/quic-go v0.27.0 github.com/lucas-clemente/quic-go v0.27.0
github.com/mattn/go-colorable v0.1.12 github.com/mattn/go-colorable v0.1.12

2
go.sum
View File

@ -378,8 +378,6 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=

View File

@ -317,7 +317,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
if err != nil { if err != nil {
return nil, err return nil, err
} }
var txp model.DNSTransport = netxlite.NewDNSOverTLS( var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport(
tlsDialer.DialTLSContext, endpoint) tlsDialer.DialTLSContext, endpoint)
if config.ResolveSaver != nil { if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{ txp = resolver.SaverDNSTransport{

View File

@ -76,40 +76,6 @@ func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
return c.SetWriteDeadlineError return c.SetWriteDeadlineError
} }
type FakeTransport struct {
Data []byte
Err error
}
func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
return ft.Data, ft.Err
}
func (ft FakeTransport) RequiresPadding() bool {
return false
}
func (ft FakeTransport) Address() string {
return ""
}
func (ft FakeTransport) Network() string {
return "fake"
}
func (fk FakeTransport) CloseIdleConnections() {
// nothing to do
}
type FakeEncoder struct {
Data []byte
Err error
}
func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
return fe.Data, fe.Err
}
func NewFakeResolverThatFails() model.Resolver { func NewFakeResolverThatFails() model.Resolver {
return NewFakeResolverWithExplicitError(netxlite.ErrOODNSNoSuchHost) return NewFakeResolverWithExplicitError(netxlite.ErrOODNSNoSuchHost)
} }

View File

@ -99,14 +99,14 @@ func TestNewResolverTCPDomain(t *testing.T) {
func TestNewResolverDoTAddress(t *testing.T) { func TestNewResolverDoTAddress(t *testing.T) {
reso := netxlite.NewSerialResolver( reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "8.8.8.8:853")) netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853"))
testresolverquick(t, reso) testresolverquick(t, reso)
testresolverquickidna(t, reso) testresolverquickidna(t, reso)
} }
func TestNewResolverDoTDomain(t *testing.T) { func TestNewResolverDoTDomain(t *testing.T) {
reso := netxlite.NewSerialResolver( reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "dns.google.com:853")) netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853"))
testresolverquick(t, reso) testresolverquick(t, reso)
testresolverquickidna(t, reso) testresolverquickidna(t, reso)
} }

View File

@ -46,28 +46,41 @@ type SaverDNSTransport struct {
} }
// RoundTrip implements RoundTripper.RoundTrip // RoundTrip implements RoundTripper.RoundTrip
func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (txp SaverDNSTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
start := time.Now() start := time.Now()
txp.Saver.Write(trace.Event{ txp.Saver.Write(trace.Event{
Address: txp.Address(), Address: txp.Address(),
DNSQuery: query, DNSQuery: txp.maybeQueryBytes(query),
Name: "dns_round_trip_start", Name: "dns_round_trip_start",
Proto: txp.Network(), Proto: txp.Network(),
Time: start, Time: start,
}) })
reply, err := txp.DNSTransport.RoundTrip(ctx, query) response, err := txp.DNSTransport.RoundTrip(ctx, query)
stop := time.Now() stop := time.Now()
txp.Saver.Write(trace.Event{ txp.Saver.Write(trace.Event{
Address: txp.Address(), Address: txp.Address(),
DNSQuery: query, DNSQuery: txp.maybeQueryBytes(query),
DNSReply: reply, DNSReply: txp.maybeResponseBytes(response),
Duration: stop.Sub(start), Duration: stop.Sub(start),
Err: err, Err: err,
Name: "dns_round_trip_done", Name: "dns_round_trip_done",
Proto: txp.Network(), Proto: txp.Network(),
Time: stop, Time: stop,
}) })
return reply, err return response, err
}
func (txp SaverDNSTransport) maybeQueryBytes(query model.DNSQuery) []byte {
data, _ := query.Bytes()
return data
}
func (txp SaverDNSTransport) maybeResponseBytes(response model.DNSResponse) []byte {
if response == nil {
return nil
}
return response.Bytes()
} }
var _ model.Resolver = SaverResolver{} var _ model.Resolver = SaverResolver{}

View File

@ -10,6 +10,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestSaverResolverFailure(t *testing.T) { func TestSaverResolverFailure(t *testing.T) {
@ -110,12 +112,25 @@ func TestSaverDNSTransportFailure(t *testing.T) {
expected := errors.New("no such host") expected := errors.New("no such host")
saver := &trace.Saver{} saver := &trace.Saver{}
txp := resolver.SaverDNSTransport{ txp := resolver.SaverDNSTransport{
DNSTransport: resolver.FakeTransport{ DNSTransport: &mocks.DNSTransport{
Err: expected, MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
}, },
Saver: saver, Saver: saver,
} }
query := []byte("abc") rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query) reply, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -127,7 +142,7 @@ func TestSaverDNSTransportFailure(t *testing.T) {
if len(ev) != 2 { if len(ev) != 2 {
t.Fatal("expected number of events") t.Fatal("expected number of events")
} }
if !bytes.Equal(ev[0].DNSQuery, query) { if !bytes.Equal(ev[0].DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery") t.Fatal("unexpected DNSQuery")
} }
if ev[0].Name != "dns_round_trip_start" { if ev[0].Name != "dns_round_trip_start" {
@ -136,7 +151,7 @@ func TestSaverDNSTransportFailure(t *testing.T) {
if !ev[0].Time.Before(time.Now()) { if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong") t.Fatal("the saved time is wrong")
} }
if !bytes.Equal(ev[1].DNSQuery, query) { if !bytes.Equal(ev[1].DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery") t.Fatal("unexpected DNSQuery")
} }
if ev[1].DNSReply != nil { if ev[1].DNSReply != nil {
@ -157,27 +172,45 @@ func TestSaverDNSTransportFailure(t *testing.T) {
} }
func TestSaverDNSTransportSuccess(t *testing.T) { func TestSaverDNSTransportSuccess(t *testing.T) {
expected := []byte("def") expected := []byte{0xef, 0xbe, 0xad, 0xde}
saver := &trace.Saver{} saver := &trace.Saver{}
response := &mocks.DNSResponse{
MockBytes: func() []byte {
return expected
},
}
txp := resolver.SaverDNSTransport{ txp := resolver.SaverDNSTransport{
DNSTransport: resolver.FakeTransport{ DNSTransport: &mocks.DNSTransport{
Data: expected, MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return response, nil
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
}, },
Saver: saver, Saver: saver,
} }
query := []byte("abc") rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query) reply, err := txp.RoundTrip(context.Background(), query)
if err != nil { if err != nil {
t.Fatal("we expected nil error here") t.Fatal("we expected nil error here")
} }
if !bytes.Equal(reply, expected) { if !bytes.Equal(reply.Bytes(), expected) {
t.Fatal("expected another reply here") t.Fatal("expected another reply here")
} }
ev := saver.Read() ev := saver.Read()
if len(ev) != 2 { if len(ev) != 2 {
t.Fatal("expected number of events") t.Fatal("expected number of events")
} }
if !bytes.Equal(ev[0].DNSQuery, query) { if !bytes.Equal(ev[0].DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery") t.Fatal("unexpected DNSQuery")
} }
if ev[0].Name != "dns_round_trip_start" { if ev[0].Name != "dns_round_trip_start" {
@ -186,7 +219,7 @@ func TestSaverDNSTransportSuccess(t *testing.T) {
if !ev[0].Time.Before(time.Now()) { if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong") t.Fatal("the saved time is wrong")
} }
if !bytes.Equal(ev[1].DNSQuery, query) { if !bytes.Equal(ev[1].DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery") t.Fatal("unexpected DNSQuery")
} }
if !bytes.Equal(ev[1].DNSReply, expected) { if !bytes.Equal(ev[1].DNSReply, expected) {

View File

@ -36,18 +36,31 @@ type DNSRoundTripEvent struct {
Reply []byte Reply []byte
} }
func (txp *dnsxRoundTripperDB) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (txp *dnsxRoundTripperDB) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
started := time.Since(txp.begin).Seconds() started := time.Since(txp.begin).Seconds()
reply, err := txp.DNSTransport.RoundTrip(ctx, query) response, err := txp.DNSTransport.RoundTrip(ctx, query)
finished := time.Since(txp.begin).Seconds() finished := time.Since(txp.begin).Seconds()
txp.db.InsertIntoDNSRoundTrip(&DNSRoundTripEvent{ txp.db.InsertIntoDNSRoundTrip(&DNSRoundTripEvent{
Network: txp.DNSTransport.Network(), Network: txp.DNSTransport.Network(),
Address: txp.DNSTransport.Address(), Address: txp.DNSTransport.Address(),
Query: query, Query: txp.maybeQueryBytes(query),
Started: started, Started: started,
Finished: finished, Finished: finished,
Failure: NewFailure(err), Failure: NewFailure(err),
Reply: reply, Reply: txp.maybeResponseBytes(response),
}) })
return reply, err return response, err
}
func (txp *dnsxRoundTripperDB) maybeQueryBytes(query model.DNSQuery) []byte {
data, _ := query.Bytes()
return data
}
func (txp *dnsxRoundTripperDB) maybeResponseBytes(response model.DNSResponse) []byte {
if response == nil {
return nil
}
return response.Bytes()
} }

View File

@ -1,36 +1,20 @@
package mocks package mocks
import ( //
"net" // Mocks for model.DNSDecoder
//
"github.com/miekg/dns" import (
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
// DNSDecoder allows mocking dnsx.DNSDecoder. // DNSDecoder allows mocking model.DNSDecoder.
type DNSDecoder struct { type DNSDecoder struct {
MockDecodeLookupHost func(qtype uint16, reply []byte, queryID uint16) ([]string, error) MockDecodeResponse func(data []byte, query model.DNSQuery) (model.DNSResponse, error)
MockDecodeHTTPS func(reply []byte, queryID uint16) (*model.HTTPSSvc, error)
MockDecodeNS func(reply []byte, queryID uint16) ([]*net.NS, error)
MockDecodeReply func(reply []byte) (*dns.Msg, error)
} }
// DecodeLookupHost calls MockDecodeLookupHost. var _ model.DNSDecoder = &DNSDecoder{}
func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
return e.MockDecodeLookupHost(qtype, reply, queryID)
}
// DecodeHTTPS calls MockDecodeHTTPS. func (e *DNSDecoder) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
func (e *DNSDecoder) DecodeHTTPS(reply []byte, queryID uint16) (*model.HTTPSSvc, error) { return e.MockDecodeResponse(data, query)
return e.MockDecodeHTTPS(reply, queryID)
}
// DecodeNS calls MockDecodeNS.
func (e *DNSDecoder) DecodeNS(reply []byte, queryID uint16) ([]*net.NS, error) {
return e.MockDecodeNS(reply, queryID)
}
// DecodeReply calls MockDecodeReply.
func (e *DNSDecoder) DecodeReply(reply []byte) (*dns.Msg, error) {
return e.MockDecodeReply(reply)
} }

View File

@ -2,70 +2,20 @@ package mocks
import ( import (
"errors" "errors"
"net"
"testing" "testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
func TestDNSDecoder(t *testing.T) { func TestDNSDecoder(t *testing.T) {
t.Run("DecodeLookupHost", func(t *testing.T) { t.Run("DecodeResponse", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
e := &DNSDecoder{ e := &DNSDecoder{
MockDecodeLookupHost: func(qtype uint16, reply []byte, queryID uint16) ([]string, error) { MockDecodeResponse: func(reply []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected return nil, expected
}, },
} }
out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17), dns.Id()) out, err := e.DecodeResponse(make([]byte, 17), &DNSQuery{})
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeHTTPS", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected
},
}
out, err := e.DecodeHTTPS(make([]byte, 17), dns.Id())
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeNS", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
return nil, expected
},
}
out, err := e.DecodeNS(make([]byte, 17), dns.Id())
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeReply", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeReply: func(reply []byte) (*dns.Msg, error) {
return nil, expected
},
}
out, err := e.DecodeReply(make([]byte, 17))
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }

View File

@ -1,11 +1,19 @@
package mocks package mocks
// DNSEncoder allows mocking dnsx.DNSEncoder. //
// Mocks for model.DNSEncoder.
//
import "github.com/ooni/probe-cli/v3/internal/model"
// DNSEncoder allows mocking model.DNSEncoder.
type DNSEncoder struct { type DNSEncoder struct {
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) MockEncode func(domain string, qtype uint16, padding bool) model.DNSQuery
} }
var _ model.DNSEncoder = &DNSEncoder{}
// Encode calls MockEncode. // Encode calls MockEncode.
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) model.DNSQuery {
return e.MockEncode(domain, qtype, padding) return e.MockEncode(domain, qtype, padding)
} }

View File

@ -5,24 +5,46 @@ import (
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
) )
func TestDNSEncoder(t *testing.T) { func TestDNSEncoder(t *testing.T) {
t.Run("Encode", func(t *testing.T) { t.Run("Encode", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
queryID := dns.Id()
e := &DNSEncoder{ e := &DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { MockEncode: func(domain string, qtype uint16, padding bool) model.DNSQuery {
return nil, 0, expected return &DNSQuery{
MockDomain: func() string {
return dns.Fqdn(domain) // do what an implementation MUST do
},
MockType: func() uint16 {
return qtype
},
MockBytes: func() ([]byte, error) {
return nil, expected
},
MockID: func() uint16 {
return queryID
}, },
} }
out, queryID, err := e.Encode("dns.google", dns.TypeA, true) },
}
query := e.Encode("dns.google", dns.TypeA, true)
if query.Domain() != "dns.google." {
t.Fatal("invalid domain")
}
if query.Type() != dns.TypeA {
t.Fatal("invalid type")
}
out, err := query.Bytes()
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if out != nil { if out != nil {
t.Fatal("unexpected out") t.Fatal("unexpected out")
} }
if queryID != 0 { if query.ID() != queryID {
t.Fatal("unexpected queryID") t.Fatal("unexpected queryID")
} }
}) })

View File

@ -0,0 +1,33 @@
package mocks
//
// Mocks for model.DNSQuery.
//
import "github.com/ooni/probe-cli/v3/internal/model"
// DNSQuery allocks mocking model.DNSQuery.
type DNSQuery struct {
MockDomain func() string
MockType func() uint16
MockBytes func() ([]byte, error)
MockID func() uint16
}
func (q *DNSQuery) Domain() string {
return q.MockDomain()
}
func (q *DNSQuery) Type() uint16 {
return q.MockType()
}
func (q *DNSQuery) Bytes() ([]byte, error) {
return q.MockBytes()
}
func (q *DNSQuery) ID() uint16 {
return q.MockID()
}
var _ model.DNSQuery = &DNSQuery{}

View File

@ -0,0 +1,62 @@
package mocks
import (
"bytes"
"testing"
"github.com/miekg/dns"
)
func TestDNSQuery(t *testing.T) {
t.Run("Domain", func(t *testing.T) {
expected := "dns.google."
q := &DNSQuery{
MockDomain: func() string {
return expected
},
}
if q.Domain() != expected {
t.Fatal("invalid domain")
}
})
t.Run("Type", func(t *testing.T) {
expected := dns.TypeAAAA
q := &DNSQuery{
MockType: func() uint16 {
return expected
},
}
if q.Type() != expected {
t.Fatal("invalid type")
}
})
t.Run("Bytes", func(t *testing.T) {
expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef}
q := &DNSQuery{
MockBytes: func() ([]byte, error) {
return expected, nil
},
}
out, err := q.Bytes()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected, out) {
t.Fatal("invalid bytes")
}
})
t.Run("ID", func(t *testing.T) {
expected := dns.Id()
q := &DNSQuery{
MockID: func() uint16 {
return expected
},
}
if q.ID() != expected {
t.Fatal("invalid id")
}
})
}

View File

@ -0,0 +1,47 @@
package mocks
//
// Mocks for model.DNSResponse
//
import (
"net"
"github.com/ooni/probe-cli/v3/internal/model"
)
// DNSResponse allows mocking model.DNSResponse.
type DNSResponse struct {
MockQuery func() model.DNSQuery
MockBytes func() []byte
MockRcode func() int
MockDecodeHTTPS func() (*model.HTTPSSvc, error)
MockDecodeLookupHost func() ([]string, error)
MockDecodeNS func() ([]*net.NS, error)
}
var _ model.DNSResponse = &DNSResponse{}
func (r *DNSResponse) Query() model.DNSQuery {
return r.MockQuery()
}
func (r *DNSResponse) Bytes() []byte {
return r.MockBytes()
}
func (r *DNSResponse) Rcode() int {
return r.MockRcode()
}
func (r *DNSResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
return r.MockDecodeHTTPS()
}
func (r *DNSResponse) DecodeLookupHost() ([]string, error) {
return r.MockDecodeLookupHost()
}
func (r *DNSResponse) DecodeNS() ([]*net.NS, error) {
return r.MockDecodeNS()
}

View File

@ -0,0 +1,105 @@
package mocks
import (
"bytes"
"errors"
"net"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
)
func TestDNSResponse(t *testing.T) {
t.Run("Query", func(t *testing.T) {
qid := dns.Id()
query := &DNSQuery{
MockID: func() uint16 {
return qid
},
}
resp := &DNSResponse{
MockQuery: func() model.DNSQuery {
return query
},
}
out := resp.Query()
if out.ID() != query.ID() {
t.Fatal("invalid query")
}
})
t.Run("Bytes", func(t *testing.T) {
expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef}
resp := &DNSResponse{
MockBytes: func() []byte {
return expected
},
}
out := resp.Bytes()
if !bytes.Equal(expected, out) {
t.Fatal("invalid bytes")
}
})
t.Run("Rcode", func(t *testing.T) {
expected := dns.RcodeBadAlg
resp := &DNSResponse{
MockRcode: func() int {
return expected
},
}
out := resp.Rcode()
if out != expected {
t.Fatal("invalid rcode")
}
})
t.Run("DecodeLookupHost", func(t *testing.T) {
expected := errors.New("mocked error")
r := &DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
return nil, expected
},
}
out, err := r.DecodeLookupHost()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeHTTPS", func(t *testing.T) {
expected := errors.New("mocked error")
r := &DNSResponse{
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
return nil, expected
},
}
out, err := r.DecodeHTTPS()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeNS", func(t *testing.T) {
expected := errors.New("mocked error")
r := &DNSResponse{
MockDecodeNS: func() ([]*net.NS, error) {
return nil, expected
},
}
out, err := r.DecodeNS()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
}

View File

@ -1,10 +1,14 @@
package mocks package mocks
import "context" import (
"context"
"github.com/ooni/probe-cli/v3/internal/model"
)
// DNSTransport allows mocking dnsx.DNSTransport. // DNSTransport allows mocking dnsx.DNSTransport.
type DNSTransport struct { type DNSTransport struct {
MockRoundTrip func(ctx context.Context, query []byte) ([]byte, error) MockRoundTrip func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error)
MockRequiresPadding func() bool MockRequiresPadding func() bool
@ -15,8 +19,10 @@ type DNSTransport struct {
MockCloseIdleConnections func() MockCloseIdleConnections func()
} }
var _ model.DNSTransport = &DNSTransport{}
// RoundTrip calls MockRoundTrip. // RoundTrip calls MockRoundTrip.
func (txp *DNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (txp *DNSTransport) RoundTrip(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return txp.MockRoundTrip(ctx, query) return txp.MockRoundTrip(ctx, query)
} }

View File

@ -6,17 +6,18 @@ import (
"testing" "testing"
"github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model"
) )
func TestDNSTransport(t *testing.T) { func TestDNSTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
txp := &DNSTransport{ txp := &DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected return nil, expected
}, },
} }
resp, err := txp.RoundTrip(context.Background(), make([]byte, 16)) resp, err := txp.RoundTrip(context.Background(), &DNSQuery{})
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err) t.Fatal("not the error we expected", err)
} }

View File

@ -1,5 +1,9 @@
package model package model
//
// Network extensions
//
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
@ -9,74 +13,81 @@ import (
"time" "time"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/miekg/dns"
) )
// // DNSResponse is a parsed DNS response ready for further processing.
// Network extensions type DNSResponse interface {
// // Query is the query associated with this response.
Query() DNSQuery
// The DNSDecoder decodes DNS replies. // Bytes returns the bytes from which we parsed the query.
Bytes() []byte
// Rcode returns the response's Rcode.
Rcode() int
// DecodeHTTPS returns information gathered from all the HTTPS
// records found inside of this response.
DecodeHTTPS() (*HTTPSSvc, error)
// DecodeLookupHost returns the addresses in the response matching
// the original query type (one of A and AAAA).
DecodeLookupHost() ([]string, error)
// DecodeNS returns all the NS entries in this response.
DecodeNS() ([]*net.NS, error)
}
// The DNSDecoder decodes DNS responses.
type DNSDecoder interface { type DNSDecoder interface {
// DecodeLookupHost decodes an A or AAAA reply. // DecodeResponse decodes a DNS response message.
//
// Arguments:
//
// - qtype is the query type (e.g., dns.TypeAAAA)
//
// - data contains the reply bytes read from a DNSTransport
//
// - queryID is the original query ID
//
// Returns:
//
// - on success, a list of IP addrs inside the reply and a nil error
//
// - on failure, a nil list and an error.
//
// Note that this function will return an error if there is no
// IP address inside of the reply.
DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error)
// DecodeHTTPS is like DecodeLookupHost but decodes an HTTPS reply.
//
// The argument is the reply as read by the DNSTransport.
//
// On success, this function returns an HTTPSSvc structure and
// a nil error. On failure, the HTTPSSvc pointer is nil and
// the error points to the error that occurred.
//
// This function will return an error if the HTTPS reply does not
// contain at least a valid ALPN entry. It will not return
// an error, though, when there are no IPv4/IPv6 hints in the reply.
DecodeHTTPS(data []byte, queryID uint16) (*HTTPSSvc, error)
// DecodeNS is like DecodeHTTPS but for NS queries.
DecodeNS(data []byte, queryID uint16) ([]*net.NS, error)
// DecodeReply decodes a DNS reply message.
// //
// Arguments: // Arguments:
// //
// - data is the raw reply // - data is the raw reply
// //
// This function fails if we cannot parse data as a DNS // This function fails if we cannot parse data as a DNS
// message or the message is not a reply. // message or the message is not a response.
// //
// If you use this function, remember that: // Regarding the returned response, remember that the Rcode
// MAY still be nonzero (this method does not treat a nonzero
// Rcode as an error when parsing the response).
DecodeResponse(data []byte, query DNSQuery) (DNSResponse, error)
}
// DNSQuery is an encoded DNS query ready to be sent using a DNSTransport.
type DNSQuery interface {
// Domain is the domain we're querying for.
Domain() string
// Type is the query type.
Type() uint16
// Bytes serializes the query to bytes. This function may fail if we're not
// able to correctly encode the domain into a query message.
// //
// 1. the Rcode MAY be nonzero; // The value returned by this function WILL be memoized after the first call,
// // so you SHOULD create a new DNSQuery if you need to retry a query.
// 2. the replyID MAY NOT match the original query ID. Bytes() ([]byte, error)
//
// That is, this is a very basic parsing method. // ID returns the query ID.
DecodeReply(data []byte) (*dns.Msg, error) ID() uint16
} }
// The DNSEncoder encodes DNS queries to bytes // The DNSEncoder encodes DNS queries to bytes
type DNSEncoder interface { type DNSEncoder interface {
// Encode transforms its arguments into a serialized DNS query. // Encode transforms its arguments into a serialized DNS query.
// //
// Every time you call Encode, you get a new DNSQuery value
// using a query ID selected at random.
//
// Serialization to bytes is lazy to acommodate DNS transports that
// do not need to serialize and send bytes, e.g., getaddrinfo.
//
// You serialize to bytes using DNSQuery.Bytes. This operation MAY fail
// if the domain name cannot be packed into a DNS message (e.g., it is
// too long to fit into the message).
//
// Arguments: // Arguments:
// //
// - domain is the domain for the query (e.g., x.org); // - domain is the domain for the query (e.g., x.org);
@ -85,16 +96,15 @@ type DNSEncoder interface {
// //
// - padding is whether to add padding to the query. // - padding is whether to add padding to the query.
// //
// On success, this function returns a valid byte array, the queryID, and // This function will transform the domain into an FQDN is it's not
// a nil error. On failure, we have a non-nil error, a nil arrary and a zero // already expressed in the FQDN format.
// query ID. Encode(domain string, qtype uint16, padding bool) DNSQuery
Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
} }
// DNSTransport represents an abstract DNS transport. // DNSTransport represents an abstract DNS transport.
type DNSTransport interface { type DNSTransport interface {
// RoundTrip sends a DNS query and receives the reply. // RoundTrip sends a DNS query and receives the reply.
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) RoundTrip(ctx context.Context, query DNSQuery) (DNSResponse, error)
// RequiresPadding returns whether this transport needs padding. // RequiresPadding returns whether this transport needs padding.
RequiresPadding() bool RequiresPadding() bool

View File

@ -15,14 +15,16 @@ import (
// DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder. // DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder.
type DNSDecoderMiekg struct{} type DNSDecoderMiekg struct{}
// ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID. var (
var ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID) // ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID.
ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID)
// ErrDNSIsQuery indicates that we were passed a DNS query. // ErrDNSIsQuery indicates that we were passed a DNS query.
var ErrDNSIsQuery = errors.New("ooresolver: expected response but received query") ErrDNSIsQuery = errors.New("ooresolver: expected response but received query")
)
// DecodeReply implements model.DNSDecoder.DecodeReply // DecodeResponse implements model.DNSDecoder.DecodeResponse.
func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) { func (d *DNSDecoderMiekg) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
reply := &dns.Msg{} reply := &dns.Msg{}
if err := reply.Unpack(data); err != nil { if err := reply.Unpack(data); err != nil {
return nil, err return nil, err
@ -30,46 +32,64 @@ func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) {
if !reply.Response { if !reply.Response {
return nil, ErrDNSIsQuery return nil, ErrDNSIsQuery
} }
return reply, nil if reply.Id != query.ID() {
}
// decodeSuccessfulReply decodes the bytes in data as a successful reply for the
// given queryID. This function returns an error if:
//
// 1. we cannot decode data
//
// 2. the decoded message is not a reply
//
// 3. the query ID does not match
//
// 4. the Rcode is not zero.
func (d *DNSDecoderMiekg) decodeSuccessfulReply(data []byte, queryID uint16) (*dns.Msg, error) {
reply, err := d.DecodeReply(data)
if err != nil {
return nil, err
}
if reply.Id != queryID {
return nil, ErrDNSReplyWithWrongQueryID return nil, ErrDNSReplyWithWrongQueryID
} }
resp := &dnsResponse{
bytes: data,
msg: reply,
query: query,
}
return resp, nil
}
// dnsResponse implements model.DNSResponse.
type dnsResponse struct {
// bytes contains the response bytes.
bytes []byte
// msg contains the message.
msg *dns.Msg
// query is the original query.
query model.DNSQuery
}
// Query implements model.DNSResponse.Query.
func (r *dnsResponse) Query() model.DNSQuery {
return r.query
}
// Bytes implements model.DNSResponse.Bytes.
func (r *dnsResponse) Bytes() []byte {
return r.bytes
}
// Rcode implements model.DNSResponse.Rcode.
func (r *dnsResponse) Rcode() int {
return r.msg.Rcode
}
func (r *dnsResponse) rcodeToError() error {
// TODO(bassosimone): map more errors to net.DNSError names // TODO(bassosimone): map more errors to net.DNSError names
// TODO(bassosimone): add support for lame referral. // TODO(bassosimone): add support for lame referral.
switch reply.Rcode { switch r.msg.Rcode {
case dns.RcodeSuccess: case dns.RcodeSuccess:
return reply, nil return nil
case dns.RcodeNameError: case dns.RcodeNameError:
return nil, ErrOODNSNoSuchHost return ErrOODNSNoSuchHost
case dns.RcodeRefused: case dns.RcodeRefused:
return nil, ErrOODNSRefused return ErrOODNSRefused
case dns.RcodeServerFailure: case dns.RcodeServerFailure:
return nil, ErrOODNSServfail return ErrOODNSServfail
default: default:
return nil, ErrOODNSMisbehaving return ErrOODNSMisbehaving
} }
} }
func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPSSvc, error) { // DecodeHTTPS implements model.DNSResponse.DecodeHTTPS.
reply, err := d.decodeSuccessfulReply(data, queryID) func (r *dnsResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
if err != nil { if err := r.rcodeToError(); err != nil {
return nil, err return nil, err
} }
out := &model.HTTPSSvc{ out := &model.HTTPSSvc{
@ -77,7 +97,7 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS
IPv4: []string{}, // ensure it's not nil IPv4: []string{}, // ensure it's not nil
IPv6: []string{}, // ensure it's not nil IPv6: []string{}, // ensure it's not nil
} }
for _, answer := range reply.Answer { for _, answer := range r.msg.Answer {
switch avalue := answer.(type) { switch avalue := answer.(type) {
case *dns.HTTPS: case *dns.HTTPS:
for _, v := range avalue.Value { for _, v := range avalue.Value {
@ -102,14 +122,14 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS
return out, nil return out, nil
} }
func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error) { // DecodeLookupHost implements model.DNSResponse.DecodeLookupHost.
reply, err := d.decodeSuccessfulReply(data, queryID) func (r *dnsResponse) DecodeLookupHost() ([]string, error) {
if err != nil { if err := r.rcodeToError(); err != nil {
return nil, err return nil, err
} }
var addrs []string var addrs []string
for _, answer := range reply.Answer { for _, answer := range r.msg.Answer {
switch qtype { switch r.Query().Type() {
case dns.TypeA: case dns.TypeA:
if rra, ok := answer.(*dns.A); ok { if rra, ok := answer.(*dns.A); ok {
ip := rra.A ip := rra.A
@ -128,13 +148,13 @@ func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID ui
return addrs, nil return addrs, nil
} }
func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, error) { // DecodeNS implements model.DNSResponse.DecodeNS.
reply, err := d.decodeSuccessfulReply(data, queryID) func (r *dnsResponse) DecodeNS() ([]*net.NS, error) {
if err != nil { if err := r.rcodeToError(); err != nil {
return nil, err return nil, err
} }
out := []*net.NS{} out := []*net.NS{}
for _, answer := range reply.Answer { for _, answer := range r.msg.Answer {
switch avalue := answer.(type) { switch avalue := answer.(type) {
case *dns.NS: case *dns.NS:
out = append(out, &net.NS{Host: avalue.Ns}) out = append(out, &net.NS{Host: avalue.Ns})
@ -147,3 +167,4 @@ func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, erro
} }
var _ model.DNSDecoder = &DNSDecoderMiekg{} var _ model.DNSDecoder = &DNSDecoderMiekg{}
var _ model.DNSResponse = &dnsResponse{}

View File

@ -1,26 +1,27 @@
package netxlite package netxlite
import ( import (
"bytes"
"errors" "errors"
"net" "net"
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/runtimex" "github.com/ooni/probe-cli/v3/internal/runtimex"
) )
func TestDNSDecoder(t *testing.T) { func TestDNSDecoderMiekg(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) { t.Run("DecodeResponse", func(t *testing.T) {
t.Run("UnpackError", func(t *testing.T) { t.Run("UnpackError", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(dns.TypeA, nil, 0) resp, err := d.DecodeResponse(nil, &mocks.DNSQuery{})
if err == nil || err.Error() != "dns: overflow unpacking uint16" { if err == nil || err.Error() != "dns: overflow unpacking uint16" {
t.Fatal("unexpected error", err) t.Fatal("unexpected error", err)
} }
if data != nil { if resp != nil {
t.Fatal("expected nil data here") t.Fatal("expected nil resp here")
} }
}) })
@ -28,12 +29,12 @@ func TestDNSDecoder(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
queryID := dns.Id() queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID) rawQuery := dnsGenQuery(dns.TypeA, queryID)
addrs, err := d.DecodeLookupHost(dns.TypeA, rawQuery, queryID) resp, err := d.DecodeResponse(rawQuery, &mocks.DNSQuery{})
if !errors.Is(err, ErrDNSIsQuery) { if !errors.Is(err, ErrDNSIsQuery) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if len(addrs) > 0 { if resp != nil {
t.Fatal("expected no addrs") t.Fatal("expected nil resp here")
} }
}) })
@ -44,214 +45,194 @@ func TestDNSDecoder(t *testing.T) {
unrelatedID = 14 unrelatedID = 14
) )
reply := dnsGenLookupHostReplySuccess(dnsGenQuery(dns.TypeA, queryID)) reply := dnsGenLookupHostReplySuccess(dnsGenQuery(dns.TypeA, queryID))
data, err := d.DecodeLookupHost(dns.TypeA, reply, unrelatedID) resp, err := d.DecodeResponse(reply, &mocks.DNSQuery{
MockID: func() uint16 {
return unrelatedID
},
})
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) { if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
t.Fatal("unexpected error", err) t.Fatal("unexpected error", err)
} }
if data != nil { if resp != nil {
t.Fatal("expected nil data here") t.Fatal("expected nil resp here")
} }
}) })
t.Run("NXDOMAIN", func(t *testing.T) { t.Run("dnsResponse.Query", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
queryID := dns.Id() queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError( rawQuery := dnsGenQuery(dns.TypeA, queryID)
dnsGenQuery(dns.TypeA, queryID), dns.RcodeNameError), queryID) rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
if err == nil || !strings.HasSuffix(err.Error(), "no such host") { query := &mocks.DNSQuery{
t.Fatal("not the error we expected", err) MockID: func() uint16 {
return queryID
},
} }
if data != nil { resp, err := d.DecodeResponse(rawResponse, query)
t.Fatal("expected nil data here")
}
})
t.Run("Refused", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
dnsGenQuery(dns.TypeA, queryID), dns.RcodeRefused), queryID)
if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("Servfail", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
dnsGenQuery(dns.TypeA, queryID), dns.RcodeServerFailure), queryID)
if !errors.Is(err, ErrOODNSServfail) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("no address", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID)), queryID)
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("decode A", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.8.8"), queryID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(data) != 2 { if resp.Query().ID() != query.ID() {
t.Fatal("expected two entries here") t.Fatal("invalid query")
}
if data[0] != "1.1.1.1" {
t.Fatal("invalid first IPv4 entry")
}
if data[1] != "8.8.8.8" {
t.Fatal("invalid second IPv4 entry")
} }
}) })
t.Run("decode AAAA", func(t *testing.T) { t.Run("dnsResponse.Bytes", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
queryID := dns.Id() queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess( rawQuery := dnsGenQuery(dns.TypeA, queryID)
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID) rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(data) != 2 { if !bytes.Equal(rawResponse, resp.Bytes()) {
t.Fatal("expected two entries here") t.Fatal("invalid bytes")
}
if data[0] != "::1" {
t.Fatal("invalid first IPv6 entry")
}
if data[1] != "fe80::1" {
t.Fatal("invalid second IPv6 entry")
} }
}) })
t.Run("unexpected A reply", func(t *testing.T) { t.Run("dnsResponse.Rcode", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
queryID := dns.Id() queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess( rawQuery := dnsGenQuery(dns.TypeA, queryID)
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID) rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
if !errors.Is(err, ErrOODNSNoAnswer) { query := &mocks.DNSQuery{
t.Fatal("not the error we expected", err) MockID: func() uint16 {
return queryID
},
} }
if data != nil { resp, err := d.DecodeResponse(rawResponse, query)
t.Fatal("expected nil data here")
}
})
t.Run("unexpected AAAA reply", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.4.4"), queryID)
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
})
t.Run("decodeSuccessfulReply", func(t *testing.T) {
d := &DNSDecoderMiekg{}
msg := &dns.Msg{}
msg.Rcode = dns.RcodeFormatError // an rcode we don't handle
msg.Response = true
data, err := msg.Pack()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
reply, err := d.decodeSuccessfulReply(data, 0) if resp.Rcode() != dns.RcodeRefused {
if !errors.Is(err, ErrOODNSMisbehaving) { // catch all error t.Fatal("invalid rcode")
t.Fatal("not the error we expected", err)
}
if reply != nil {
t.Fatal("expected nil reply")
} }
}) })
t.Run("DecodeHTTPS", func(t *testing.T) { t.Run("dnsResponse.rcodeToError", func(t *testing.T) {
t.Run("with nil data", func(t *testing.T) { // Here we want to ensure we map all the errors we recognize
d := &DNSDecoderMiekg{} // correctly and we also map unrecognized errors correctly
reply, err := d.DecodeHTTPS(nil, 0) var inputsOutputs = []struct {
if err == nil || err.Error() != "dns: overflow unpacking uint16" { name string
t.Fatal("not the error we expected", err) rcode int
} err error
if reply != nil { }{{
t.Fatal("expected nil reply") name: "when rcode is zero",
} rcode: 0,
}) err: nil,
}, {
t.Run("with bytes containing a query", func(t *testing.T) { name: "NXDOMAIN",
rcode: dns.RcodeNameError,
err: ErrOODNSNoSuchHost,
}, {
name: "refused",
rcode: dns.RcodeRefused,
err: ErrOODNSRefused,
}, {
name: "servfail",
rcode: dns.RcodeServerFailure,
err: ErrOODNSServfail,
}, {
name: "anything else",
rcode: dns.RcodeFormatError,
err: ErrOODNSMisbehaving,
}}
for _, io := range inputsOutputs {
t.Run(io.name, func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
queryID := dns.Id() queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID) rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
https, err := d.DecodeHTTPS(rawQuery, queryID) rawResponse := dnsGenReplyWithError(rawQuery, io.rcode)
if !errors.Is(err, ErrDNSIsQuery) { query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
// The following cast should always work in this configuration
err = resp.(*dnsResponse).rcodeToError()
if !errors.Is(err, io.err) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if https != nil { })
t.Fatal("expected nil https")
} }
}) })
t.Run("wrong query ID", func(t *testing.T) { t.Run("dnsResponse.DecodeHTTPS", func(t *testing.T) {
t.Run("with failure", func(t *testing.T) {
// Ensure that we're not trying to decode if rcode != 0
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
const ( queryID := dns.Id()
queryID = 17 rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
unrelatedID = 14 rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
) query := &mocks.DNSQuery{
reply := dnsGenHTTPSReplySuccess(dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil) MockID: func() uint16 {
data, err := d.DecodeHTTPS(reply, unrelatedID) return queryID
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) { },
t.Fatal("unexpected error", err)
} }
if data != nil { resp, err := d.DecodeResponse(rawResponse, query)
t.Fatal("expected nil data here") if err != nil {
t.Fatal(err)
}
https, err := resp.DecodeHTTPS()
if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("expected nil https result")
} }
}) })
t.Run("with empty answer", func(t *testing.T) { t.Run("with empty answer", func(t *testing.T) {
queryID := dns.Id()
data := dnsGenHTTPSReplySuccess(
dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data, queryID) queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
rawResponse := dnsGenHTTPSReplySuccess(rawQuery, nil, nil, nil)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
https, err := resp.DecodeHTTPS()
if !errors.Is(err, ErrOODNSNoAnswer) { if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if reply != nil { if https != nil {
t.Fatal("expected nil reply") t.Fatal("expected nil https results")
} }
}) })
t.Run("with full answer", func(t *testing.T) { t.Run("with full answer", func(t *testing.T) {
queryID := dns.Id()
alpn := []string{"h3"} alpn := []string{"h3"}
v4 := []string{"1.1.1.1"} v4 := []string{"1.1.1.1"}
v6 := []string{"::1"} v6 := []string{"::1"}
data := dnsGenHTTPSReplySuccess(
dnsGenQuery(dns.TypeHTTPS, queryID), alpn, v4, v6)
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data, queryID) queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
rawResponse := dnsGenHTTPSReplySuccess(rawQuery, alpn, v4, v6)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
reply, err := resp.DecodeHTTPS()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -267,74 +248,244 @@ func TestDNSDecoder(t *testing.T) {
}) })
}) })
t.Run("DecodeNS", func(t *testing.T) { t.Run("dnsResponse.DecodeNS", func(t *testing.T) {
t.Run("with nil data", func(t *testing.T) { t.Run("with failure", func(t *testing.T) {
d := &DNSDecoderMiekg{} // Ensure that we're not trying to decode if rcode != 0
reply, err := d.DecodeNS(nil, 0)
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
t.Fatal("not the error we expected", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("with bytes containing a query", func(t *testing.T) {
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
queryID := dns.Id() queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeNS, queryID) rawQuery := dnsGenQuery(dns.TypeNS, queryID)
ns, err := d.DecodeNS(rawQuery, queryID) rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
if !errors.Is(err, ErrDNSIsQuery) { query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
ns, err := resp.DecodeNS()
if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if len(ns) > 0 { if len(ns) > 0 {
t.Fatal("expected no result") t.Fatal("expected empty ns result")
}
})
t.Run("wrong query ID", func(t *testing.T) {
d := &DNSDecoderMiekg{}
const (
queryID = 17
unrelatedID = 14
)
reply := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
data, err := d.DecodeNS(reply, unrelatedID)
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
t.Fatal("unexpected error", err)
}
if data != nil {
t.Fatal("expected nil data here")
} }
}) })
t.Run("with empty answer", func(t *testing.T) { t.Run("with empty answer", func(t *testing.T) {
queryID := dns.Id()
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
reply, err := d.DecodeNS(data, queryID) queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
rawResponse := dnsGenNSReplySuccess(rawQuery)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
ns, err := resp.DecodeNS()
if !errors.Is(err, ErrOODNSNoAnswer) { if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if reply != nil { if len(ns) > 0 {
t.Fatal("expected nil reply") t.Fatal("expected empty ns results")
} }
}) })
t.Run("with full answer", func(t *testing.T) { t.Run("with full answer", func(t *testing.T) {
queryID := dns.Id()
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID), "ns1.zdns.google.")
d := &DNSDecoderMiekg{} d := &DNSDecoderMiekg{}
reply, err := d.DecodeNS(data, queryID) queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
rawResponse := dnsGenNSReplySuccess(rawQuery, "ns1.zdns.google.")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(reply) != 1 { ns, err := resp.DecodeNS()
t.Fatal("unexpected reply length") if err != nil {
t.Fatal(err)
} }
if reply[0].Host != "ns1.zdns.google." { if len(ns) != 1 {
t.Fatal("unexpected reply host") t.Fatal("unexpected ns length")
} }
if ns[0].Host != "ns1.zdns.google." {
t.Fatal("unexpected host")
}
})
})
t.Run("dnsResponse.LookupHost", func(t *testing.T) {
t.Run("with failure", func(t *testing.T) {
// Ensure that we're not trying to decode if rcode != 0
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("unexpected err", err)
}
if len(addrs) > 0 {
t.Fatal("expected empty addrs result")
}
})
t.Run("with empty answer", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err)
}
if len(addrs) > 0 {
t.Fatal("expected empty ns results")
}
})
t.Run("decode A", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
MockType: func() uint16 {
return dns.TypeA
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
if len(addrs) != 2 {
t.Fatal("expected two entries here")
}
if addrs[0] != "1.1.1.1" {
t.Fatal("invalid first IPv4 entry")
}
if addrs[1] != "8.8.8.8" {
t.Fatal("invalid second IPv4 entry")
}
})
t.Run("decode AAAA", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeAAAA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
MockType: func() uint16 {
return dns.TypeAAAA
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
if len(addrs) != 2 {
t.Fatal("expected two entries here")
}
if addrs[0] != "::1" {
t.Fatal("invalid first IPv6 entry")
}
if addrs[1] != "fe80::1" {
t.Fatal("invalid second IPv6 entry")
}
})
t.Run("unexpected A reply to AAAA query", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeAAAA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
MockType: func() uint16 {
return dns.TypeAAAA
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if len(addrs) > 0 {
t.Fatal("expected no addrs here")
}
})
t.Run("unexpected AAAA reply to A query", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
MockType: func() uint16 {
return dns.TypeA
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if len(addrs) > 0 {
t.Fatal("expected no addrs here")
}
})
}) })
}) })
} }
@ -371,8 +522,8 @@ func dnsGenReplyWithError(rawQuery []byte, code int) []byte {
return data return data
} }
// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given // dnsGenLookupHostReplySuccess generates a successful DNS reply containing the given ips...
// qtype (e.g., dns.TypeA) containing the given ips... in the answer. // in the answers where each answer's type depends on the IP's type (A/AAAA).
func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte { func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
query := new(dns.Msg) query := new(dns.Msg)
err := query.Unpack(rawQuery) err := query.Unpack(rawQuery)
@ -388,28 +539,22 @@ func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
reply.MsgHdr.RecursionAvailable = true reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query) reply.SetReply(query)
for _, ip := range ips { for _, ip := range ips {
switch question.Qtype { switch isIPv6(ip) {
case dns.TypeA: case false:
if isIPv6(ip) {
continue
}
reply.Answer = append(reply.Answer, &dns.A{ reply.Answer = append(reply.Answer, &dns.A{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"), Name: question.Name,
Rrtype: question.Qtype, Rrtype: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: 0, Ttl: 0,
}, },
A: net.ParseIP(ip), A: net.ParseIP(ip),
}) })
case dns.TypeAAAA: case true:
if !isIPv6(ip) {
continue
}
reply.Answer = append(reply.Answer, &dns.AAAA{ reply.Answer = append(reply.Answer, &dns.AAAA{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"), Name: question.Name,
Rrtype: question.Qtype, Rrtype: dns.TypeAAAA,
Class: dns.ClassINET, Class: dns.ClassINET,
Ttl: 0, Ttl: 0,
}, },

View File

@ -5,7 +5,10 @@ package netxlite
// //
import ( import (
"sync"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
@ -23,18 +26,82 @@ const (
dnsDNSSECEnabled = true dnsDNSSECEnabled = true
) )
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { // Encoder implements model.DNSEncoder.Encode.
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) model.DNSQuery {
return &dnsQuery{
bytesCalls: &atomicx.Int64{},
domain: domain,
kind: qtype,
id: dns.Id(),
memoizedBytes: []byte{},
mu: sync.Mutex{},
padding: padding,
}
}
// dnsQuery implements model.DNSQuery.
type dnsQuery struct {
// bytesCalls counts the calls to the bytes() method
bytesCalls *atomicx.Int64
// domain is the domain.
domain string
// kind is the query type.
kind uint16
// id is the query ID.
id uint16
// memoizedBytes contains the query encoded as bytes. We only fill
// this field the first time the Bytes method is called.
memoizedBytes []byte
// mu provides mutual exclusion.
mu sync.Mutex
// padding indicates whether we need padding.
padding bool
}
// Domain implements model.DNSQuery.Domain.
func (q *dnsQuery) Domain() string {
return q.domain
}
// Type implements model.DNSQuery.Type.
func (q *dnsQuery) Type() uint16 {
return q.kind
}
// Bytes implements model.DNSQuery.Bytes.
func (q *dnsQuery) Bytes() ([]byte, error) {
defer q.mu.Unlock()
q.mu.Lock()
if len(q.memoizedBytes) <= 0 {
q.bytesCalls.Add(1) // for testing
data, err := q.bytes()
if err != nil {
return nil, err
}
q.memoizedBytes = data
}
return q.memoizedBytes, nil
}
// bytes is the unmemoized implementation of Bytes
func (q *dnsQuery) bytes() ([]byte, error) {
question := dns.Question{ question := dns.Question{
Name: dns.Fqdn(domain), Name: dns.Fqdn(q.domain),
Qtype: qtype, Qtype: q.kind,
Qclass: dns.ClassINET, Qclass: dns.ClassINET,
} }
query := new(dns.Msg) query := new(dns.Msg)
query.Id = dns.Id() query.Id = q.id
query.RecursionDesired = true query.RecursionDesired = true
query.Question = make([]dns.Question, 1) query.Question = make([]dns.Question, 1)
query.Question[0] = question query.Question[0] = question
if padding { if q.padding {
query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled) query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled)
// Clients SHOULD pad queries to the closest multiple of // Clients SHOULD pad queries to the closest multiple of
// 128 octets RFC8467#section-4.1. We inflate the query // 128 octets RFC8467#section-4.1. We inflate the query
@ -47,8 +114,13 @@ func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]b
opt.Padding = make([]byte, remainder) opt.Padding = make([]byte, remainder)
query.IsEdns0().Option = append(query.IsEdns0().Option, opt) query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
} }
data, err := query.Pack() return query.Pack()
return data, query.Id, err }
// ID implements model.DNSQuery.ID
func (q *dnsQuery) ID() uint16 {
return q.id
} }
var _ model.DNSEncoder = &DNSEncoderMiekg{} var _ model.DNSEncoder = &DNSEncoderMiekg{}
var _ model.DNSQuery = &dnsQuery{}

View File

@ -1,29 +1,103 @@
package netxlite package netxlite
import ( import (
"bytes"
"encoding/binary"
"strings" "strings"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/randx"
"github.com/ooni/probe-cli/v3/internal/runtimex"
) )
func TestDNSEncoder(t *testing.T) { func TestDNSEncoderMiekg(t *testing.T) {
t.Run("we can fail to encode a domain name to bytes", func(t *testing.T) {
e := &DNSEncoderMiekg{}
domain := randx.LettersUppercase(512)
query := e.Encode(domain, dns.TypeA, false)
data, err := query.Bytes()
if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") {
t.Fatal("unexpected err", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("calls to bytes are memoized", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
e := &DNSEncoderMiekg{}
query := e.Encode("x.org", dns.TypeA, false)
checkResult := func(data []byte, err error) {
if err != nil {
t.Fatal("unexpected err", err)
}
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
}
const repeat = 3
for idx := 0; idx < repeat; idx++ {
checkResult(query.Bytes())
}
// The following cast will always work in this configuration
if query.(*dnsQuery).bytesCalls.Load() != 1 {
t.Fatal("invalid number of calls")
}
})
t.Run("on failure", func(t *testing.T) {
e := &DNSEncoderMiekg{}
domain := randx.LettersUppercase(512)
query := e.Encode(domain, dns.TypeA, false)
checkResult := func(data []byte, err error) {
if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") {
t.Fatal("unexpected err", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
}
const repeat = 3
for idx := 0; idx < repeat; idx++ {
checkResult(query.Bytes())
}
// The following cast will always work in this configuration
if query.(*dnsQuery).bytesCalls.Load() != repeat {
t.Fatal("invalid number of calls")
}
})
})
t.Run("encode A", func(t *testing.T) { t.Run("encode A", func(t *testing.T) {
e := &DNSEncoderMiekg{} e := &DNSEncoderMiekg{}
data, _, err := e.Encode("x.org", dns.TypeA, false) query := e.Encode("x.org", dns.TypeA, false)
if query.Domain() != "x.org" {
t.Fatal("invalid domain")
}
if query.Type() != dns.TypeA {
t.Fatal("invalid type")
}
data, err := query.Bytes()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA)) dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
}) })
t.Run("encode AAAA", func(t *testing.T) { t.Run("encode AAAA", func(t *testing.T) {
e := &DNSEncoderMiekg{} e := &DNSEncoderMiekg{}
data, _, err := e.Encode("x.org", dns.TypeAAAA, false) query := e.Encode("x.org", dns.TypeAAAA, false)
if query.Domain() != "x.org" {
t.Fatal("invalid domain")
}
if query.Type() != dns.TypeAAAA {
t.Fatal("invalid type")
}
data, err := query.Bytes()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA)) dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
}) })
t.Run("encode padding", func(t *testing.T) { t.Run("encode padding", func(t *testing.T) {
@ -31,7 +105,7 @@ func TestDNSEncoder(t *testing.T) {
// array of values we obtain the right query size. // array of values we obtain the right query size.
getquerylen := func(domainlen int, padding bool) int { getquerylen := func(domainlen int, padding bool) int {
e := &DNSEncoderMiekg{} e := &DNSEncoderMiekg{}
data, _, err := e.Encode( query := e.Encode(
// This is not a valid name because it ends up being way // This is not a valid name because it ends up being way
// longer than 255 octets. However, the library is allowing // longer than 255 octets. However, the library is allowing
// us to generate such name and we are not going to send // us to generate such name and we are not going to send
@ -40,6 +114,7 @@ func TestDNSEncoder(t *testing.T) {
dns.Fqdn(strings.Repeat("x.", domainlen)), dns.Fqdn(strings.Repeat("x.", domainlen)),
dns.TypeA, padding, dns.TypeA, padding,
) )
data, err := query.Bytes()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -63,8 +138,13 @@ func TestDNSEncoder(t *testing.T) {
// dnsValidateEncodedQueryBytes validates the query serialized in data // dnsValidateEncodedQueryBytes validates the query serialized in data
// for the given query type qtype (e.g., dns.TypeAAAA). // for the given query type qtype (e.g., dns.TypeAAAA).
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte) { func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte, qid uint16) {
// skipping over the query ID var wirequery uint16
err := binary.Read(bytes.NewReader(data), binary.BigEndian, &wirequery)
runtimex.PanicOnError(err, "binary.Read failed unexpectedly")
if wirequery != qid {
t.Fatal("invalid query ID")
}
if data[2] != 1 { if data[2] != 1 {
t.Fatal("FLAGS should only have RD set") t.Fatal("FLAGS should only have RD set")
} }

View File

@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"io"
"net/http" "net/http"
"time" "time"
@ -19,6 +20,9 @@ type DNSOverHTTPSTransport struct {
// Client is the MANDATORY http client to use. // Client is the MANDATORY http client to use.
Client model.HTTPClient Client model.HTTPClient
// Decoder is the MANDATORY DNSDecoder.
Decoder model.DNSDecoder
// URL is the MANDATORY URL of the DNS-over-HTTPS server. // URL is the MANDATORY URL of the DNS-over-HTTPS server.
URL string URL string
@ -31,9 +35,9 @@ type DNSOverHTTPSTransport struct {
// //
// Arguments: // Arguments:
// //
// - client in http.Client-like type (e.g., http.DefaultClient); // - client is a model.HTTPClient type;
// //
// - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query). // - URL is the DoH resolver URL (e.g., https://dns.google/dns-query).
func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport { func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "") return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
} }
@ -42,22 +46,31 @@ func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPS
// with the given Host header override. // with the given Host header override.
func NewDNSOverHTTPSTransportWithHostOverride( func NewDNSOverHTTPSTransportWithHostOverride(
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport { client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
return &DNSOverHTTPSTransport{Client: client, URL: URL, HostOverride: hostOverride} return &DNSOverHTTPSTransport{
Client: client,
Decoder: &DNSDecoderMiekg{},
URL: URL,
HostOverride: hostOverride,
}
} }
// RoundTrip sends a query and receives a reply. // RoundTrip sends a query and receives a reply.
func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (t *DNSOverHTTPSTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
rawQuery, err := query.Bytes()
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, 45*time.Second) ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel() defer cancel()
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query)) req, err := http.NewRequest("POST", t.URL, bytes.NewReader(rawQuery))
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Host = t.HostOverride req.Host = t.HostOverride
req.Header.Set("user-agent", model.HTTPHeaderUserAgent) req.Header.Set("user-agent", model.HTTPHeaderUserAgent)
req.Header.Set("content-type", "application/dns-message") req.Header.Set("content-type", "application/dns-message")
var resp *http.Response resp, err := t.Client.Do(req.WithContext(ctx))
resp, err = t.Client.Do(req.WithContext(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +83,13 @@ func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]
if resp.Header.Get("content-type") != "application/dns-message" { if resp.Header.Get("content-type") != "application/dns-message" {
return nil, errors.New("doh: invalid content-type") return nil, errors.New("doh: invalid content-type")
} }
return ReadAllContext(ctx, resp.Body) const maxresponsesize = 1 << 20
limitReader := io.LimitReader(resp.Body, maxresponsesize)
rawResponse, err := ReadAllContext(ctx, limitReader)
if err != nil {
return nil, err
}
return t.Decoder.DecodeResponse(rawResponse, query)
} }
// RequiresPadding returns true for DoH according to RFC8467. // RequiresPadding returns true for DoH according to RFC8467.

View File

@ -15,14 +15,36 @@ import (
func TestDNSOverHTTPSTransport(t *testing.T) { func TestDNSOverHTTPSTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
t.Run("query serialization failure", func(t *testing.T) {
txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
expected := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("NewRequestFailure", func(t *testing.T) { t.Run("NewRequestFailure", func(t *testing.T) {
const invalidURL = "\t" const invalidURL = "\t"
txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL) txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { MockBytes: func() ([]byte, error) {
t.Fatal("expected an error here") return make([]byte, 17), nil
},
} }
if data != nil { resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
}) })
@ -37,11 +59,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
}, },
URL: "https://cloudflare-dns.com/dns-query", URL: "https://cloudflare-dns.com/dns-query",
} }
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
if !errors.Is(err, expected) { MockBytes: func() ([]byte, error) {
t.Fatal("expected an error here") return make([]byte, 17), nil
},
} }
if data != nil { resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
}) })
@ -58,11 +85,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
}, },
URL: "https://cloudflare-dns.com/dns-query", URL: "https://cloudflare-dns.com/dns-query",
} }
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
if err == nil || err.Error() != "doh: server returned error" { MockBytes: func() ([]byte, error) {
t.Fatal("expected an error here") return make([]byte, 17), nil
},
} }
if data != nil { resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || err.Error() != "doh: server returned error" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
}) })
@ -79,11 +111,86 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
}, },
URL: "https://cloudflare-dns.com/dns-query", URL: "https://cloudflare-dns.com/dns-query",
} }
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
if err == nil || err.Error() != "doh: invalid content-type" { MockBytes: func() ([]byte, error) {
t.Fatal("expected an error here") return make([]byte, 17), nil
},
} }
if data != nil { resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || err.Error() != "doh: invalid content-type" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("ReadAllContext fails", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
}),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("decode response failure", func(t *testing.T) {
expected := errors.New("mocked error")
body := []byte("AAA")
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
Decoder: &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
}) })
@ -103,13 +210,23 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
}, },
}, },
URL: "https://cloudflare-dns.com/dns-query", URL: "https://cloudflare-dns.com/dns-query",
Decoder: &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return &mocks.DNSResponse{}, nil
},
},
} }
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(data, body) { if resp == nil {
t.Fatal("not the response we expected") t.Fatal("expected non-nil resp here")
} }
}) })
@ -125,7 +242,12 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
}, },
URL: "https://cloudflare-dns.com/dns-query", URL: "https://cloudflare-dns.com/dns-query",
} }
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
data, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("expected an error here") t.Fatal("expected an error here")
} }
@ -151,18 +273,22 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
URL: "https://cloudflare-dns.com/dns-query", URL: "https://cloudflare-dns.com/dns-query",
HostOverride: hostOverride, HostOverride: hostOverride,
} }
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) { if !errors.Is(err, expected) {
t.Fatal("expected an error here") t.Fatal("expected an error here")
} }
if data != nil { if resp != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
if !correct { if !correct {
t.Fatal("did not see correct host override") t.Fatal("did not see correct host override")
} }
}) })
}) })
t.Run("other functions behave correctly", func(t *testing.T) { t.Run("other functions behave correctly", func(t *testing.T) {

View File

@ -20,9 +20,12 @@ type DialContextFunc func(context.Context, string, string) (net.Conn, error)
// DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport. // DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport.
// //
// Bug: this implementation always creates a new connection for each query. // Note: this implementation always creates a new connection for each query. This
// strategy is less efficient but MAY be more robust for cleartext TCP connections
// when querying for a blocked domain name causes endpoint blocking.
type DNSOverTCPTransport struct { type DNSOverTCPTransport struct {
dial DialContextFunc dial DialContextFunc
decoder model.DNSDecoder
address string address string
network string network string
requiresPadding bool requiresPadding bool
@ -36,47 +39,58 @@ type DNSOverTCPTransport struct {
// //
// - address is the endpoint address (e.g., 8.8.8.8:53). // - address is the endpoint address (e.g., 8.8.8.8:53).
func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
return &DNSOverTCPTransport{ return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false)
dial: dial,
address: address,
network: "tcp",
requiresPadding: false,
}
} }
// NewDNSOverTLS creates a new DNSOverTLS transport. // NewDNSOverTLSTransport creates a new DNSOverTLS transport.
// //
// Arguments: // Arguments:
// //
// - dial is a function with the net.Dialer.DialContext's signature; // - dial is a function with the net.Dialer.DialContext's signature;
// //
// - address is the endpoint address (e.g., 8.8.8.8:853). // - address is the endpoint address (e.g., 8.8.8.8:853).
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCPTransport { func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
return newDNSOverTCPOrTLSTransport(dial, "dot", address, true)
}
// newDNSOverTCPOrTLSTransport is the common factory for creating a transport
func newDNSOverTCPOrTLSTransport(
dial DialContextFunc, network, address string, padding bool) *DNSOverTCPTransport {
return &DNSOverTCPTransport{ return &DNSOverTCPTransport{
dial: dial, dial: dial,
decoder: &DNSDecoderMiekg{},
address: address, address: address,
network: "dot", network: network,
requiresPadding: true, requiresPadding: padding,
} }
} }
// errQueryTooLarge indicates the query is too large for the transport.
var errQueryTooLarge = errors.New("oodns: query too large for this transport")
// RoundTrip sends a query and receives a reply. // RoundTrip sends a query and receives a reply.
func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (t *DNSOverTCPTransport) RoundTrip(
if len(query) > math.MaxUint16 { ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, errors.New("query too long") // TODO(bassosimone): this method should more strictly honour the context, which
// currently is only used to bound the dial operation
rawQuery, err := query.Bytes()
if err != nil {
return nil, err
}
if len(rawQuery) > math.MaxUint16 {
return nil, errQueryTooLarge
} }
conn, err := t.dial(ctx, "tcp", t.address) conn, err := t.dial(ctx, "tcp", t.address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer conn.Close() defer conn.Close()
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil { const iotimeout = 10 * time.Second
return nil, err conn.SetDeadline(time.Now().Add(iotimeout))
}
// Write request // Write request
buf := []byte{byte(len(query) >> 8)} buf := []byte{byte(len(rawQuery) >> 8)}
buf = append(buf, byte(len(query))) buf = append(buf, byte(len(rawQuery)))
buf = append(buf, query...) buf = append(buf, rawQuery...)
if _, err = conn.Write(buf); err != nil { if _, err = conn.Write(buf); err != nil {
return nil, err return nil, err
} }
@ -86,11 +100,11 @@ func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]by
return nil, err return nil, err
} }
length := int(header[0])<<8 | int(header[1]) length := int(header[0])<<8 | int(header[1])
reply := make([]byte, length) rawResponse := make([]byte, length)
if _, err = io.ReadFull(conn, reply); err != nil { if _, err = io.ReadFull(conn, rawResponse); err != nil {
return nil, err return nil, err
} }
return reply, nil return t.decoder.DecodeResponse(rawResponse, query)
} }
// RequiresPadding returns true for DoT and false for TCP // RequiresPadding returns true for DoT and false for TCP

View File

@ -6,73 +6,83 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"io" "io"
"math"
"net" "net"
"testing" "testing"
"time" "time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestDNSOverTCPTransport(t *testing.T) { func TestDNSOverTCPTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
t.Run("cannot encode query", func(t *testing.T) {
expected := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("query too large", func(t *testing.T) { t.Run("query too large", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18)) query := &mocks.DNSQuery{
if err == nil { MockBytes: func() ([]byte, error) {
t.Fatal("expected an error here") return make([]byte, math.MaxUint16+1), nil
},
} }
if reply != nil { resp, err := txp.RoundTrip(context.Background(), query)
t.Fatal("expected nil reply here") if !errors.Is(err, errQueryTooLarge) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
} }
}) })
t.Run("dial failure", func(t *testing.T) { t.Run("dial failure", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{ fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked return nil, mocked
}, },
} }
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if reply != nil { if resp != nil {
t.Fatal("expected nil reply here") t.Fatal("expected nil resp here")
}
})
t.Run("SetDeadline failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
} }
}) })
t.Run("write failure", func(t *testing.T) { t.Run("write failure", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{ fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{ return &mocks.Conn{
@ -89,18 +99,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, },
} }
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if reply != nil { if resp != nil {
t.Fatal("expected nil reply here") t.Fatal("expected nil resp here")
} }
}) })
t.Run("first read fails", func(t *testing.T) { t.Run("first read fails", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{ fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{ return &mocks.Conn{
@ -120,18 +135,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, },
} }
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if reply != nil { if resp != nil {
t.Fatal("expected nil reply here") t.Fatal("expected nil resp here")
} }
}) })
t.Run("second read fails", func(t *testing.T) { t.Run("second read fails", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := io.MultiReader( input := io.MultiReader(
bytes.NewReader([]byte{byte(0), byte(2)}), bytes.NewReader([]byte{byte(0), byte(2)}),
&mocks.Reader{ &mocks.Reader{
@ -157,17 +177,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, },
} }
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if reply != nil { if resp != nil {
t.Fatal("expected nil reply here") t.Fatal("expected nil resp here")
} }
}) })
t.Run("successful case", func(t *testing.T) { t.Run("decode failure", func(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)}) input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := &mocks.Dialer{ fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
@ -186,11 +212,56 @@ func TestDNSOverTCPTransport(t *testing.T) {
}, },
} }
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, mocked
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("successful case", func(t *testing.T) {
const address = "9.9.9.9:53"
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
expectedResp := &mocks.DNSResponse{}
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return expectedResp, nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(reply) != 1 || reply[0] != 1 { if resp != expectedResp {
t.Fatal("not the response we expected") t.Fatal("not the response we expected")
} }
}) })
@ -213,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
t.Run("other functions okay with TLS", func(t *testing.T) { t.Run("other functions okay with TLS", func(t *testing.T) {
const address = "9.9.9.9:853" const address = "9.9.9.9:853"
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address) txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
if txp.RequiresPadding() != true { if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding") t.Fatal("invalid RequiresPadding")
} }

View File

@ -14,6 +14,7 @@ import (
// DNSOverUDPTransport is a DNS-over-UDP DNSTransport. // DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
type DNSOverUDPTransport struct { type DNSOverUDPTransport struct {
dialer model.Dialer dialer model.Dialer
decoder model.DNSDecoder
address string address string
} }
@ -25,11 +26,20 @@ type DNSOverUDPTransport struct {
// //
// - address is the endpoint address (e.g., 8.8.8.8:53). // - address is the endpoint address (e.g., 8.8.8.8:53).
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport { func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
return &DNSOverUDPTransport{dialer: dialer, address: address} return &DNSOverUDPTransport{
dialer: dialer,
decoder: &DNSDecoderMiekg{},
address: address,
}
} }
// RoundTrip sends a query and receives a reply. // RoundTrip sends a query and receives a reply.
func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { func (t *DNSOverUDPTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
rawQuery, err := query.Bytes()
if err != nil {
return nil, err
}
conn, err := t.dialer.DialContext(ctx, "udp", t.address) conn, err := t.dialer.DialContext(ctx, "udp", t.address)
if err != nil { if err != nil {
return nil, err return nil, err
@ -37,19 +47,19 @@ func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]by
defer conn.Close() defer conn.Close()
// Use five seconds timeout like Bionic does. See // Use five seconds timeout like Bionic does. See
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance // https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { const iotimeout = 5 * time.Second
conn.SetDeadline(time.Now().Add(iotimeout))
if _, err = conn.Write(rawQuery); err != nil {
return nil, err return nil, err
} }
if _, err = conn.Write(query); err != nil { const maxmessagesize = 1 << 17
return nil, err rawResponse := make([]byte, maxmessagesize)
} count, err := conn.Read(rawResponse)
reply := make([]byte, 1<<17)
var n int
n, err = conn.Read(reply)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return reply[:n], nil rawResponse = rawResponse[:count]
return t.decoder.DecodeResponse(rawResponse, query)
} }
// RequiresPadding returns false for UDP according to RFC8467. // RequiresPadding returns false for UDP according to RFC8467.

View File

@ -9,11 +9,30 @@ import (
"time" "time"
"github.com/apex/log" "github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestDNSOverUDPTransport(t *testing.T) { func TestDNSOverUDPTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) {
t.Run("cannot encode query", func(t *testing.T) {
expected := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewDNSOverUDPTransport(nil, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("dial failure", func(t *testing.T) { t.Run("dial failure", func(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
@ -22,36 +41,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
return nil, mocked return nil, mocked
}, },
}, address) }, address)
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if data != nil { if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("SetDeadline failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
}) })
@ -75,11 +74,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
}, },
}, "9.9.9.9:53", }, "9.9.9.9:53",
) )
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if data != nil { if resp != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
}) })
@ -106,15 +110,61 @@ func TestDNSOverUDPTransport(t *testing.T) {
}, },
}, "9.9.9.9:53", }, "9.9.9.9:53",
) )
data, err := txp.RoundTrip(context.Background(), nil) query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if data != nil { if resp != nil {
t.Fatal("expected no response here") t.Fatal("expected no response here")
} }
}) })
t.Run("decode failure", func(t *testing.T) {
const expected = 17
input := bytes.NewReader(make([]byte, expected))
txp := NewDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53",
)
expectedErr := errors.New("mocked error")
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expectedErr
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil resp")
}
})
t.Run("read success", func(t *testing.T) { t.Run("read success", func(t *testing.T) {
const expected = 17 const expected = 17
input := bytes.NewReader(make([]byte, expected)) input := bytes.NewReader(make([]byte, expected))
@ -136,12 +186,23 @@ func TestDNSOverUDPTransport(t *testing.T) {
}, },
}, "9.9.9.9:53", }, "9.9.9.9:53",
) )
data, err := txp.RoundTrip(context.Background(), nil) expectedResp := &mocks.DNSResponse{}
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return expectedResp, nil
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(data) != expected { if resp != expectedResp {
t.Fatal("expected non nil data") t.Fatal("unexpected resp")
} }
}) })
}) })

View File

@ -1,15 +1,12 @@
package filtering package filtering
import ( import (
"context"
"errors" "errors"
"io" "io"
"net" "net"
"net/http"
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/runtimex" "github.com/ooni/probe-cli/v3/internal/runtimex"
) )
@ -51,19 +48,13 @@ type DNSProxy struct {
// receive a query for the given domain. // receive a query for the given domain.
OnQuery func(domain string) DNSAction OnQuery func(domain string) DNSAction
// Upstream is the OPTIONAL upstream transport. // UpstreamEndpoint is the OPTIONAL upstream transport endpoint.
Upstream DNSTransport UpstreamEndpoint string
// mockableReply allows to mock DNSProxy.reply in tests. // mockableReply allows to mock DNSProxy.reply in tests.
mockableReply func(query *dns.Msg) (*dns.Msg, error) mockableReply func(query *dns.Msg) (*dns.Msg, error)
} }
// DNSTransport is the type we expect from an upstream DNS transport.
type DNSTransport interface {
RoundTrip(ctx context.Context, query []byte) ([]byte, error)
CloseIdleConnections()
}
// DNSListener is the interface returned by DNSProxy.Start // DNSListener is the interface returned by DNSProxy.Start
type DNSListener interface { type DNSListener interface {
io.Closer io.Closer
@ -204,23 +195,24 @@ func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg {
return reply return reply
} }
var (
// errDNSExpectedSingleQuestion means we expected to see a single question
errDNSExpectedSingleQuestion = errors.New("filtering: expected single DNS question")
// errDNSExpectedQueryNotResponse means we expected to see a query.
errDNSExpectedQueryNotResponse = errors.New("filtering: expected query not response")
)
func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) { func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) {
queryBytes, err := query.Pack() if query.Response {
if err != nil { return nil, errDNSExpectedQueryNotResponse
return nil, err
} }
txp := p.dnstransport() if len(query.Question) != 1 {
defer txp.CloseIdleConnections() return nil, errDNSExpectedSingleQuestion
ctx := context.Background()
replyBytes, err := txp.RoundTrip(ctx, queryBytes)
if err != nil {
return nil, err
} }
reply := &dns.Msg{} clnt := &dns.Client{}
if err := reply.Unpack(replyBytes); err != nil { resp, _, err := clnt.Exchange(query, p.upstreamEndpoint())
return nil, err return resp, err
}
return reply, nil
} }
func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg { func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
@ -237,10 +229,9 @@ func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
return p.compose(query, ipAddrs...) return p.compose(query, ipAddrs...)
} }
func (p *DNSProxy) dnstransport() DNSTransport { func (p *DNSProxy) upstreamEndpoint() string {
if p.Upstream != nil { if p.UpstreamEndpoint != "" {
return p.Upstream return p.UpstreamEndpoint
} }
const URL = "https://1.1.1.1/dns-query" return "8.8.8.8:53"
return netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, URL)
} }

View File

@ -283,9 +283,7 @@ func TestDNSProxy(t *testing.T) {
if len(p) < len(data) { if len(p) < len(data) {
panic("buffer too small") panic("buffer too small")
} }
for i := 0; i < len(data); i++ { copy(p, data)
p[i] = data[i]
}
return len(data), &net.UDPAddr{}, nil return len(data), &net.UDPAddr{}, nil
}, },
} }
@ -314,9 +312,7 @@ func TestDNSProxy(t *testing.T) {
if len(p) < len(data) { if len(p) < len(data) {
panic("buffer too small") panic("buffer too small")
} }
for i := 0; i < len(data); i++ { copy(p, data)
p[i] = data[i]
}
return len(data), &net.UDPAddr{}, nil return len(data), &net.UDPAddr{}, nil
}, },
} }
@ -328,12 +324,24 @@ func TestDNSProxy(t *testing.T) {
}) })
t.Run("proxy", func(t *testing.T) { t.Run("proxy", func(t *testing.T) {
t.Run("Pack fails", func(t *testing.T) { t.Run("with response", func(t *testing.T) {
p := &DNSProxy{} p := &DNSProxy{}
query := &dns.Msg{} query := &dns.Msg{}
query.Rcode = -1 // causes Pack to fail query.Response = true
reply, err := p.proxy(query) reply, err := p.proxy(query)
if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") { if !errors.Is(err, errDNSExpectedQueryNotResponse) {
t.Fatal("unexpected err", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("with no questions", func(t *testing.T) {
p := &DNSProxy{}
query := &dns.Msg{}
reply, err := p.proxy(query)
if !errors.Is(err, errDNSExpectedSingleQuestion) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if reply != nil { if reply != nil {
@ -342,35 +350,13 @@ func TestDNSProxy(t *testing.T) {
}) })
t.Run("round trip fails", func(t *testing.T) { t.Run("round trip fails", func(t *testing.T) {
expected := errors.New("mocked error")
p := &DNSProxy{ p := &DNSProxy{
Upstream: &mocks.DNSTransport{ UpstreamEndpoint: "antani",
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return nil, expected
},
MockCloseIdleConnections: func() {},
},
} }
reply, err := p.proxy(&dns.Msg{}) query := &dns.Msg{}
if !errors.Is(err, expected) { query.Question = append(query.Question, dns.Question{})
t.Fatal("unexpected err", err) reply, err := p.proxy(query)
} if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("Unpack fails", func(t *testing.T) {
p := &DNSProxy{
Upstream: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 1), nil
},
MockCloseIdleConnections: func() {},
},
}
reply, err := p.proxy(&dns.Msg{})
if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if reply != nil { if reply != nil {

View File

@ -9,7 +9,6 @@ import (
"net" "net"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
) )
@ -19,15 +18,6 @@ import (
// You should probably use NewUnwrappedParallelResolver to // You should probably use NewUnwrappedParallelResolver to
// create a new instance of this type. // create a new instance of this type.
type ParallelResolver struct { type ParallelResolver struct {
// Encoder is the MANDATORY encoder to use.
Encoder model.DNSEncoder
// Decoder is the MANDATORY decoder to use.
Decoder model.DNSDecoder
// NumTimeouts is MANDATORY and counts the number of timeouts.
NumTimeouts *atomicx.Int64
// Txp is the MANDATORY underlying DNS transport. // Txp is the MANDATORY underlying DNS transport.
Txp model.DNSTransport Txp model.DNSTransport
} }
@ -36,9 +26,6 @@ type ParallelResolver struct {
// not wrapped and you should wrap if before using it. // not wrapped and you should wrap if before using it.
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver { func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
return &ParallelResolver{ return &ParallelResolver{
Encoder: &DNSEncoderMiekg{},
Decoder: &DNSDecoderMiekg{},
NumTimeouts: &atomicx.Int64{},
Txp: t, Txp: t,
} }
} }
@ -80,22 +67,22 @@ func (r *ParallelResolver) LookupHost(ctx context.Context, hostname string) ([]s
var addrs []string var addrs []string
addrs = append(addrs, ares.addrs...) addrs = append(addrs, ares.addrs...)
addrs = append(addrs, aaaares.addrs...) addrs = append(addrs, aaaares.addrs...)
if len(addrs) < 1 {
return nil, ErrOODNSNoAnswer
}
return addrs, nil return addrs, nil
} }
// LookupHTTPS implements Resolver.LookupHTTPS. // LookupHTTPS implements Resolver.LookupHTTPS.
func (r *ParallelResolver) LookupHTTPS( func (r *ParallelResolver) LookupHTTPS(
ctx context.Context, hostname string) (*model.HTTPSSvc, error) { ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
querydata, queryID, err := r.Encoder.Encode( encoder := &DNSEncoderMiekg{}
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
replydata, err := r.Txp.RoundTrip(ctx, querydata) return response.DecodeHTTPS()
if err != nil {
return nil, err
}
return r.Decoder.DecodeHTTPS(replydata, queryID)
} }
// parallelResolverResult is the internal representation of a // parallelResolverResult is the internal representation of a
@ -108,7 +95,9 @@ type parallelResolverResult struct {
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A). // lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string, func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
qtype uint16, out chan<- *parallelResolverResult) { qtype uint16, out chan<- *parallelResolverResult) {
querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil { if err != nil {
out <- &parallelResolverResult{ out <- &parallelResolverResult{
addrs: []string{}, addrs: []string{},
@ -116,15 +105,7 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
} }
return return
} }
replydata, err := r.Txp.RoundTrip(ctx, querydata) addrs, err := response.DecodeLookupHost()
if err != nil {
out <- &parallelResolverResult{
addrs: []string{},
err: err,
}
return
}
addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
out <- &parallelResolverResult{ out <- &parallelResolverResult{
addrs: addrs, addrs: addrs,
err: err, err: err,
@ -134,14 +115,11 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
// LookupNS implements Resolver.LookupNS. // LookupNS implements Resolver.LookupNS.
func (r *ParallelResolver) LookupNS( func (r *ParallelResolver) LookupNS(
ctx context.Context, hostname string) ([]*net.NS, error) { ctx context.Context, hostname string) ([]*net.NS, error) {
querydata, queryID, err := r.Encoder.Encode( encoder := &DNSEncoderMiekg{}
hostname, dns.TypeNS, r.Txp.RequiresPadding()) query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
replydata, err := r.Txp.RoundTrip(ctx, querydata) return response.DecodeNS()
if err != nil {
return nil, err
}
return r.Decoder.DecodeNS(replydata, queryID)
} }

View File

@ -8,14 +8,13 @@ import (
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
) )
func TestParallelResolver(t *testing.T) { func TestParallelResolver(t *testing.T) {
t.Run("transport okay", func(t *testing.T) { t.Run("transport okay", func(t *testing.T) {
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := NewUnwrappedParallelResolver(txp) r := NewUnwrappedParallelResolver(txp)
rtx := r.Transport() rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
@ -30,30 +29,10 @@ func TestParallelResolver(t *testing.T) {
}) })
t.Run("LookupHost", func(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
t.Run("Encode error", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, mocked
},
},
Txp: txp,
}
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
})
t.Run("RoundTrip error", func(t *testing.T) { t.Run("RoundTrip error", func(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, mocked return nil, mocked
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
@ -72,8 +51,13 @@ func TestParallelResolver(t *testing.T) {
t.Run("empty reply", func(t *testing.T) { t.Run("empty reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return dnsGenLookupHostReplySuccess(query), nil response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
return nil, nil
},
}
return response, nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -91,8 +75,16 @@ func TestParallelResolver(t *testing.T) {
t.Run("with A reply", func(t *testing.T) { t.Run("with A reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeA {
return nil, nil
}
return []string{"8.8.8.8"}, nil
},
}
return response, nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -110,8 +102,16 @@ func TestParallelResolver(t *testing.T) {
t.Run("with AAAA reply", func(t *testing.T) { t.Run("with AAAA reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return dnsGenLookupHostReplySuccess(query, "::1"), nil response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeAAAA {
return nil, nil
}
return []string{"::1"}, nil
},
}
return response, nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -131,22 +131,15 @@ func TestParallelResolver(t *testing.T) {
afailure := errors.New("a failure") afailure := errors.New("a failure")
aaaafailure := errors.New("aaaa failure") aaaafailure := errors.New("aaaa failure")
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
msg := &dns.Msg{} switch query.Type() {
if err := msg.Unpack(query); err != nil { case dns.TypeA:
return nil, err
}
if len(msg.Question) != 1 {
return nil, errors.New("expected just one question")
}
q := msg.Question[0]
if q.Qtype == dns.TypeA {
return nil, afailure return nil, afailure
} case dns.TypeAAAA:
if q.Qtype == dns.TypeAAAA {
return nil, aaaafailure return nil, aaaafailure
default:
return nil, errors.New("unexpected query")
} }
return nil, errors.New("expected A or AAAA query")
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -179,44 +172,11 @@ func TestParallelResolver(t *testing.T) {
}) })
t.Run("LookupHTTPS", func(t *testing.T) { t.Run("LookupHTTPS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
https, err := r.LookupHTTPS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) { t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &ParallelResolver{ r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{ Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected return nil, expected
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
@ -234,23 +194,17 @@ func TestParallelResolver(t *testing.T) {
} }
}) })
t.Run("for decode error", func(t *testing.T) { t.Run("for DecodeHTTPS error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &ParallelResolver{ r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{ Txp: &mocks.DNSTransport{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return make([]byte, 64), 0, nil response := &mocks.DNSResponse{
}, MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
},
Decoder: &mocks.DNSDecoder{
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected return nil, expected
}, },
}, }
NumTimeouts: &atomicx.Int64{}, return response, nil
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 128), nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return false return false
@ -269,44 +223,11 @@ func TestParallelResolver(t *testing.T) {
}) })
t.Run("LookupNS", func(t *testing.T) { t.Run("LookupNS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
ns, err := r.LookupNS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if ns != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) { t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &ParallelResolver{ r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{ Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected return nil, expected
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
@ -327,20 +248,14 @@ func TestParallelResolver(t *testing.T) {
t.Run("for decode error", func(t *testing.T) { t.Run("for decode error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &ParallelResolver{ r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{ Txp: &mocks.DNSTransport{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return make([]byte, 64), 0, nil response := &mocks.DNSResponse{
}, MockDecodeNS: func() ([]*net.NS, error) {
},
Decoder: &mocks.DNSDecoder{
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
return nil, expected return nil, expected
}, },
}, }
NumTimeouts: &atomicx.Int64{}, return response, nil
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 128), nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return false return false

View File

@ -26,12 +26,6 @@ import (
// QUIRK: unlike the ParallelResolver, this resolver's LookupHost retries // QUIRK: unlike the ParallelResolver, this resolver's LookupHost retries
// each query three times for soft errors. // each query three times for soft errors.
type SerialResolver struct { type SerialResolver struct {
// Encoder is the MANDATORY encoder to use.
Encoder model.DNSEncoder
// Decoder is the MANDATORY decoder to use.
Decoder model.DNSDecoder
// NumTimeouts is MANDATORY and counts the number of timeouts. // NumTimeouts is MANDATORY and counts the number of timeouts.
NumTimeouts *atomicx.Int64 NumTimeouts *atomicx.Int64
@ -42,8 +36,6 @@ type SerialResolver struct {
// NewSerialResolver creates a new SerialResolver instance. // NewSerialResolver creates a new SerialResolver instance.
func NewSerialResolver(t model.DNSTransport) *SerialResolver { func NewSerialResolver(t model.DNSTransport) *SerialResolver {
return &SerialResolver{ return &SerialResolver{
Encoder: &DNSEncoderMiekg{},
Decoder: &DNSDecoderMiekg{},
NumTimeouts: &atomicx.Int64{}, NumTimeouts: &atomicx.Int64{},
Txp: t, Txp: t,
} }
@ -82,22 +74,22 @@ func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]str
} }
addrs = append(addrs, addrsA...) addrs = append(addrs, addrsA...)
addrs = append(addrs, addrsAAAA...) addrs = append(addrs, addrsAAAA...)
if len(addrs) < 1 {
return nil, ErrOODNSNoAnswer
}
return addrs, nil return addrs, nil
} }
// LookupHTTPS implements Resolver.LookupHTTPS. // LookupHTTPS implements Resolver.LookupHTTPS.
func (r *SerialResolver) LookupHTTPS( func (r *SerialResolver) LookupHTTPS(
ctx context.Context, hostname string) (*model.HTTPSSvc, error) { ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
querydata, queryID, err := r.Encoder.Encode( encoder := &DNSEncoderMiekg{}
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
replydata, err := r.Txp.RoundTrip(ctx, querydata) return response.DecodeHTTPS()
if err != nil {
return nil, err
}
return r.Decoder.DecodeHTTPS(replydata, queryID)
} }
func (r *SerialResolver) lookupHostWithRetry( func (r *SerialResolver) lookupHostWithRetry(
@ -132,28 +124,23 @@ func (r *SerialResolver) lookupHostWithRetry(
// qtype (dns.A or dns.AAAA) without retrying on failure. // qtype (dns.A or dns.AAAA) without retrying on failure.
func (r *SerialResolver) lookupHostWithoutRetry( func (r *SerialResolver) lookupHostWithoutRetry(
ctx context.Context, hostname string, qtype uint16) ([]string, error) { ctx context.Context, hostname string, qtype uint16) ([]string, error) {
querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
replydata, err := r.Txp.RoundTrip(ctx, querydata) return response.DecodeLookupHost()
if err != nil {
return nil, err
}
return r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
} }
// LookupNS implements Resolver.LookupNS. // LookupNS implements Resolver.LookupNS.
func (r *SerialResolver) LookupNS( func (r *SerialResolver) LookupNS(
ctx context.Context, hostname string) ([]*net.NS, error) { ctx context.Context, hostname string) ([]*net.NS, error) {
querydata, queryID, err := r.Encoder.Encode( encoder := &DNSEncoderMiekg{}
hostname, dns.TypeNS, r.Txp.RequiresPadding()) query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
replydata, err := r.Txp.RoundTrip(ctx, querydata) return response.DecodeNS()
if err != nil {
return nil, err
}
return r.Decoder.DecodeNS(replydata, queryID)
} }

View File

@ -7,6 +7,7 @@ import (
"net" "net"
"testing" "testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
@ -30,7 +31,7 @@ func (err *errorWithTimeout) Unwrap() error {
func TestSerialResolver(t *testing.T) { func TestSerialResolver(t *testing.T) {
t.Run("transport okay", func(t *testing.T) { t.Run("transport okay", func(t *testing.T) {
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := NewSerialResolver(txp) r := NewSerialResolver(txp)
rtx := r.Transport() rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
@ -45,30 +46,10 @@ func TestSerialResolver(t *testing.T) {
}) })
t.Run("LookupHost", func(t *testing.T) { t.Run("LookupHost", func(t *testing.T) {
t.Run("Encode error", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, mocked
},
},
Txp: txp,
}
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
})
t.Run("RoundTrip error", func(t *testing.T) { t.Run("RoundTrip error", func(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, mocked return nil, mocked
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
@ -87,8 +68,13 @@ func TestSerialResolver(t *testing.T) {
t.Run("empty reply", func(t *testing.T) { t.Run("empty reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return dnsGenLookupHostReplySuccess(query), nil response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
return nil, nil
},
}
return response, nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -106,8 +92,16 @@ func TestSerialResolver(t *testing.T) {
t.Run("with A reply", func(t *testing.T) { t.Run("with A reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeA {
return nil, nil
}
return []string{"8.8.8.8"}, nil
},
}
return response, nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -125,8 +119,16 @@ func TestSerialResolver(t *testing.T) {
t.Run("with AAAA reply", func(t *testing.T) { t.Run("with AAAA reply", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return dnsGenLookupHostReplySuccess(query, "::1"), nil response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeAAAA {
return nil, nil
}
return []string{"::1"}, nil
},
}
return response, nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -144,11 +146,12 @@ func TestSerialResolver(t *testing.T) {
t.Run("with timeout", func(t *testing.T) { t.Run("with timeout", func(t *testing.T) {
txp := &mocks.DNSTransport{ txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, &net.OpError{ err := &net.OpError{
Err: &errorWithTimeout{ETIMEDOUT}, Err: &errorWithTimeout{ETIMEDOUT},
Op: "dial", Op: "dial",
} }
return nil, err
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return true return true
@ -184,44 +187,12 @@ func TestSerialResolver(t *testing.T) {
}) })
t.Run("LookupHTTPS", func(t *testing.T) { t.Run("LookupHTTPS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
https, err := r.LookupHTTPS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) { t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &SerialResolver{ r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{}, NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{ Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected return nil, expected
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
@ -242,20 +213,15 @@ func TestSerialResolver(t *testing.T) {
t.Run("for decode error", func(t *testing.T) { t.Run("for decode error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &SerialResolver{ r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: &mocks.DNSDecoder{
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected
},
},
NumTimeouts: &atomicx.Int64{}, NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{ Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return make([]byte, 128), nil response := &mocks.DNSResponse{
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
return nil, expected
},
}
return response, nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return false return false
@ -274,44 +240,12 @@ func TestSerialResolver(t *testing.T) {
}) })
t.Run("LookupNS", func(t *testing.T) { t.Run("LookupNS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
ns, err := r.LookupNS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if ns != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) { t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &SerialResolver{ r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{}, NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{ Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected return nil, expected
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
@ -332,20 +266,15 @@ func TestSerialResolver(t *testing.T) {
t.Run("for decode error", func(t *testing.T) { t.Run("for decode error", func(t *testing.T) {
expected := errors.New("mocked error") expected := errors.New("mocked error")
r := &SerialResolver{ r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: &mocks.DNSDecoder{
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
return nil, expected
},
},
NumTimeouts: &atomicx.Int64{}, NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{ Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return make([]byte, 128), nil response := &mocks.DNSResponse{
MockDecodeNS: func() ([]*net.NS, error) {
return nil, expected
},
}
return response, nil
}, },
MockRequiresPadding: func() bool { MockRequiresPadding: func() bool {
return false return false