refactor: DNSTransport I/Os DNS messages (#760)

This diff refactors the DNSTransport model to receive in input a DNSQuery and return in output a DNSResponse.

The design of DNSQuery and DNSResponse takes into account the use case of a transport using getaddrinfo, meaning that we don't need to serialize and deserialize messages when using getaddrinfo.

The current codebase does not use a getaddrinfo transport, but I wrote one such a transport in the Websteps Winter 2021 prototype (https://github.com/bassosimone/websteps-illustrated/).

The design conversation that lead to producing this diff is https://github.com/ooni/probe/issues/2099
This commit is contained in:
Simone Basso 2022-05-25 17:03:58 +02:00 committed by GitHub
parent 7a0a156aec
commit 01a513a496
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 1731 additions and 1076 deletions

1
go.mod
View File

@ -17,7 +17,6 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/hexops/gotextdiff v1.0.3
github.com/iancoleman/strcase v0.2.0
github.com/lucas-clemente/quic-go v0.27.0
github.com/mattn/go-colorable v0.1.12

2
go.sum
View File

@ -378,8 +378,6 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=

View File

@ -317,7 +317,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
if err != nil {
return nil, err
}
var txp model.DNSTransport = netxlite.NewDNSOverTLS(
var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport(
tlsDialer.DialTLSContext, endpoint)
if config.ResolveSaver != nil {
txp = resolver.SaverDNSTransport{

View File

@ -76,40 +76,6 @@ func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
return c.SetWriteDeadlineError
}
type FakeTransport struct {
Data []byte
Err error
}
func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
return ft.Data, ft.Err
}
func (ft FakeTransport) RequiresPadding() bool {
return false
}
func (ft FakeTransport) Address() string {
return ""
}
func (ft FakeTransport) Network() string {
return "fake"
}
func (fk FakeTransport) CloseIdleConnections() {
// nothing to do
}
type FakeEncoder struct {
Data []byte
Err error
}
func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
return fe.Data, fe.Err
}
func NewFakeResolverThatFails() model.Resolver {
return NewFakeResolverWithExplicitError(netxlite.ErrOODNSNoSuchHost)
}

View File

@ -99,14 +99,14 @@ func TestNewResolverTCPDomain(t *testing.T) {
func TestNewResolverDoTAddress(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "8.8.8.8:853"))
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}
func TestNewResolverDoTDomain(t *testing.T) {
reso := netxlite.NewSerialResolver(
netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "dns.google.com:853"))
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853"))
testresolverquick(t, reso)
testresolverquickidna(t, reso)
}

View File

@ -46,28 +46,41 @@ type SaverDNSTransport struct {
}
// RoundTrip implements RoundTripper.RoundTrip
func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
func (txp SaverDNSTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
start := time.Now()
txp.Saver.Write(trace.Event{
Address: txp.Address(),
DNSQuery: query,
DNSQuery: txp.maybeQueryBytes(query),
Name: "dns_round_trip_start",
Proto: txp.Network(),
Time: start,
})
reply, err := txp.DNSTransport.RoundTrip(ctx, query)
response, err := txp.DNSTransport.RoundTrip(ctx, query)
stop := time.Now()
txp.Saver.Write(trace.Event{
Address: txp.Address(),
DNSQuery: query,
DNSReply: reply,
DNSQuery: txp.maybeQueryBytes(query),
DNSReply: txp.maybeResponseBytes(response),
Duration: stop.Sub(start),
Err: err,
Name: "dns_round_trip_done",
Proto: txp.Network(),
Time: stop,
})
return reply, err
return response, err
}
func (txp SaverDNSTransport) maybeQueryBytes(query model.DNSQuery) []byte {
data, _ := query.Bytes()
return data
}
func (txp SaverDNSTransport) maybeResponseBytes(response model.DNSResponse) []byte {
if response == nil {
return nil
}
return response.Bytes()
}
var _ model.Resolver = SaverResolver{}

View File

@ -10,6 +10,8 @@ import (
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestSaverResolverFailure(t *testing.T) {
@ -110,12 +112,25 @@ func TestSaverDNSTransportFailure(t *testing.T) {
expected := errors.New("no such host")
saver := &trace.Saver{}
txp := resolver.SaverDNSTransport{
DNSTransport: resolver.FakeTransport{
Err: expected,
DNSTransport: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
},
Saver: saver,
}
query := []byte("abc")
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("not the error we expected")
@ -127,7 +142,7 @@ func TestSaverDNSTransportFailure(t *testing.T) {
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].DNSQuery, query) {
if !bytes.Equal(ev[0].DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name != "dns_round_trip_start" {
@ -136,7 +151,7 @@ func TestSaverDNSTransportFailure(t *testing.T) {
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].DNSQuery, query) {
if !bytes.Equal(ev[1].DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[1].DNSReply != nil {
@ -157,27 +172,45 @@ func TestSaverDNSTransportFailure(t *testing.T) {
}
func TestSaverDNSTransportSuccess(t *testing.T) {
expected := []byte("def")
expected := []byte{0xef, 0xbe, 0xad, 0xde}
saver := &trace.Saver{}
response := &mocks.DNSResponse{
MockBytes: func() []byte {
return expected
},
}
txp := resolver.SaverDNSTransport{
DNSTransport: resolver.FakeTransport{
Data: expected,
DNSTransport: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return response, nil
},
MockNetwork: func() string {
return "fake"
},
MockAddress: func() string {
return ""
},
},
Saver: saver,
}
query := []byte("abc")
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return rawQuery, nil
},
}
reply, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal("we expected nil error here")
}
if !bytes.Equal(reply, expected) {
if !bytes.Equal(reply.Bytes(), expected) {
t.Fatal("expected another reply here")
}
ev := saver.Read()
if len(ev) != 2 {
t.Fatal("expected number of events")
}
if !bytes.Equal(ev[0].DNSQuery, query) {
if !bytes.Equal(ev[0].DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if ev[0].Name != "dns_round_trip_start" {
@ -186,7 +219,7 @@ func TestSaverDNSTransportSuccess(t *testing.T) {
if !ev[0].Time.Before(time.Now()) {
t.Fatal("the saved time is wrong")
}
if !bytes.Equal(ev[1].DNSQuery, query) {
if !bytes.Equal(ev[1].DNSQuery, rawQuery) {
t.Fatal("unexpected DNSQuery")
}
if !bytes.Equal(ev[1].DNSReply, expected) {

View File

@ -36,18 +36,31 @@ type DNSRoundTripEvent struct {
Reply []byte
}
func (txp *dnsxRoundTripperDB) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
func (txp *dnsxRoundTripperDB) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
started := time.Since(txp.begin).Seconds()
reply, err := txp.DNSTransport.RoundTrip(ctx, query)
response, err := txp.DNSTransport.RoundTrip(ctx, query)
finished := time.Since(txp.begin).Seconds()
txp.db.InsertIntoDNSRoundTrip(&DNSRoundTripEvent{
Network: txp.DNSTransport.Network(),
Address: txp.DNSTransport.Address(),
Query: query,
Query: txp.maybeQueryBytes(query),
Started: started,
Finished: finished,
Failure: NewFailure(err),
Reply: reply,
Reply: txp.maybeResponseBytes(response),
})
return reply, err
return response, err
}
func (txp *dnsxRoundTripperDB) maybeQueryBytes(query model.DNSQuery) []byte {
data, _ := query.Bytes()
return data
}
func (txp *dnsxRoundTripperDB) maybeResponseBytes(response model.DNSResponse) []byte {
if response == nil {
return nil
}
return response.Bytes()
}

View File

@ -1,36 +1,20 @@
package mocks
import (
"net"
//
// Mocks for model.DNSDecoder
//
"github.com/miekg/dns"
import (
"github.com/ooni/probe-cli/v3/internal/model"
)
// DNSDecoder allows mocking dnsx.DNSDecoder.
// DNSDecoder allows mocking model.DNSDecoder.
type DNSDecoder struct {
MockDecodeLookupHost func(qtype uint16, reply []byte, queryID uint16) ([]string, error)
MockDecodeHTTPS func(reply []byte, queryID uint16) (*model.HTTPSSvc, error)
MockDecodeNS func(reply []byte, queryID uint16) ([]*net.NS, error)
MockDecodeReply func(reply []byte) (*dns.Msg, error)
MockDecodeResponse func(data []byte, query model.DNSQuery) (model.DNSResponse, error)
}
// DecodeLookupHost calls MockDecodeLookupHost.
func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
return e.MockDecodeLookupHost(qtype, reply, queryID)
}
var _ model.DNSDecoder = &DNSDecoder{}
// DecodeHTTPS calls MockDecodeHTTPS.
func (e *DNSDecoder) DecodeHTTPS(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return e.MockDecodeHTTPS(reply, queryID)
}
// DecodeNS calls MockDecodeNS.
func (e *DNSDecoder) DecodeNS(reply []byte, queryID uint16) ([]*net.NS, error) {
return e.MockDecodeNS(reply, queryID)
}
// DecodeReply calls MockDecodeReply.
func (e *DNSDecoder) DecodeReply(reply []byte) (*dns.Msg, error) {
return e.MockDecodeReply(reply)
func (e *DNSDecoder) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return e.MockDecodeResponse(data, query)
}

View File

@ -2,70 +2,20 @@ package mocks
import (
"errors"
"net"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
)
func TestDNSDecoder(t *testing.T) {
t.Run("DecodeLookupHost", func(t *testing.T) {
t.Run("DecodeResponse", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeLookupHost: func(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
MockDecodeResponse: func(reply []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
}
out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17), dns.Id())
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeHTTPS", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected
},
}
out, err := e.DecodeHTTPS(make([]byte, 17), dns.Id())
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeNS", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
return nil, expected
},
}
out, err := e.DecodeNS(make([]byte, 17), dns.Id())
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeReply", func(t *testing.T) {
expected := errors.New("mocked error")
e := &DNSDecoder{
MockDecodeReply: func(reply []byte) (*dns.Msg, error) {
return nil, expected
},
}
out, err := e.DecodeReply(make([]byte, 17))
out, err := e.DecodeResponse(make([]byte, 17), &DNSQuery{})
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}

View File

@ -1,11 +1,19 @@
package mocks
// DNSEncoder allows mocking dnsx.DNSEncoder.
//
// Mocks for model.DNSEncoder.
//
import "github.com/ooni/probe-cli/v3/internal/model"
// DNSEncoder allows mocking model.DNSEncoder.
type DNSEncoder struct {
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
MockEncode func(domain string, qtype uint16, padding bool) model.DNSQuery
}
var _ model.DNSEncoder = &DNSEncoder{}
// Encode calls MockEncode.
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) model.DNSQuery {
return e.MockEncode(domain, qtype, padding)
}

View File

@ -5,24 +5,46 @@ import (
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
)
func TestDNSEncoder(t *testing.T) {
t.Run("Encode", func(t *testing.T) {
expected := errors.New("mocked error")
queryID := dns.Id()
e := &DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
MockEncode: func(domain string, qtype uint16, padding bool) model.DNSQuery {
return &DNSQuery{
MockDomain: func() string {
return dns.Fqdn(domain) // do what an implementation MUST do
},
MockType: func() uint16 {
return qtype
},
MockBytes: func() ([]byte, error) {
return nil, expected
},
MockID: func() uint16 {
return queryID
},
}
},
}
out, queryID, err := e.Encode("dns.google", dns.TypeA, true)
query := e.Encode("dns.google", dns.TypeA, true)
if query.Domain() != "dns.google." {
t.Fatal("invalid domain")
}
if query.Type() != dns.TypeA {
t.Fatal("invalid type")
}
out, err := query.Bytes()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
if queryID != 0 {
if query.ID() != queryID {
t.Fatal("unexpected queryID")
}
})

View File

@ -0,0 +1,33 @@
package mocks
//
// Mocks for model.DNSQuery.
//
import "github.com/ooni/probe-cli/v3/internal/model"
// DNSQuery allocks mocking model.DNSQuery.
type DNSQuery struct {
MockDomain func() string
MockType func() uint16
MockBytes func() ([]byte, error)
MockID func() uint16
}
func (q *DNSQuery) Domain() string {
return q.MockDomain()
}
func (q *DNSQuery) Type() uint16 {
return q.MockType()
}
func (q *DNSQuery) Bytes() ([]byte, error) {
return q.MockBytes()
}
func (q *DNSQuery) ID() uint16 {
return q.MockID()
}
var _ model.DNSQuery = &DNSQuery{}

View File

@ -0,0 +1,62 @@
package mocks
import (
"bytes"
"testing"
"github.com/miekg/dns"
)
func TestDNSQuery(t *testing.T) {
t.Run("Domain", func(t *testing.T) {
expected := "dns.google."
q := &DNSQuery{
MockDomain: func() string {
return expected
},
}
if q.Domain() != expected {
t.Fatal("invalid domain")
}
})
t.Run("Type", func(t *testing.T) {
expected := dns.TypeAAAA
q := &DNSQuery{
MockType: func() uint16 {
return expected
},
}
if q.Type() != expected {
t.Fatal("invalid type")
}
})
t.Run("Bytes", func(t *testing.T) {
expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef}
q := &DNSQuery{
MockBytes: func() ([]byte, error) {
return expected, nil
},
}
out, err := q.Bytes()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected, out) {
t.Fatal("invalid bytes")
}
})
t.Run("ID", func(t *testing.T) {
expected := dns.Id()
q := &DNSQuery{
MockID: func() uint16 {
return expected
},
}
if q.ID() != expected {
t.Fatal("invalid id")
}
})
}

View File

@ -0,0 +1,47 @@
package mocks
//
// Mocks for model.DNSResponse
//
import (
"net"
"github.com/ooni/probe-cli/v3/internal/model"
)
// DNSResponse allows mocking model.DNSResponse.
type DNSResponse struct {
MockQuery func() model.DNSQuery
MockBytes func() []byte
MockRcode func() int
MockDecodeHTTPS func() (*model.HTTPSSvc, error)
MockDecodeLookupHost func() ([]string, error)
MockDecodeNS func() ([]*net.NS, error)
}
var _ model.DNSResponse = &DNSResponse{}
func (r *DNSResponse) Query() model.DNSQuery {
return r.MockQuery()
}
func (r *DNSResponse) Bytes() []byte {
return r.MockBytes()
}
func (r *DNSResponse) Rcode() int {
return r.MockRcode()
}
func (r *DNSResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
return r.MockDecodeHTTPS()
}
func (r *DNSResponse) DecodeLookupHost() ([]string, error) {
return r.MockDecodeLookupHost()
}
func (r *DNSResponse) DecodeNS() ([]*net.NS, error) {
return r.MockDecodeNS()
}

View File

@ -0,0 +1,105 @@
package mocks
import (
"bytes"
"errors"
"net"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
)
func TestDNSResponse(t *testing.T) {
t.Run("Query", func(t *testing.T) {
qid := dns.Id()
query := &DNSQuery{
MockID: func() uint16 {
return qid
},
}
resp := &DNSResponse{
MockQuery: func() model.DNSQuery {
return query
},
}
out := resp.Query()
if out.ID() != query.ID() {
t.Fatal("invalid query")
}
})
t.Run("Bytes", func(t *testing.T) {
expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef}
resp := &DNSResponse{
MockBytes: func() []byte {
return expected
},
}
out := resp.Bytes()
if !bytes.Equal(expected, out) {
t.Fatal("invalid bytes")
}
})
t.Run("Rcode", func(t *testing.T) {
expected := dns.RcodeBadAlg
resp := &DNSResponse{
MockRcode: func() int {
return expected
},
}
out := resp.Rcode()
if out != expected {
t.Fatal("invalid rcode")
}
})
t.Run("DecodeLookupHost", func(t *testing.T) {
expected := errors.New("mocked error")
r := &DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
return nil, expected
},
}
out, err := r.DecodeLookupHost()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeHTTPS", func(t *testing.T) {
expected := errors.New("mocked error")
r := &DNSResponse{
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
return nil, expected
},
}
out, err := r.DecodeHTTPS()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
t.Run("DecodeNS", func(t *testing.T) {
expected := errors.New("mocked error")
r := &DNSResponse{
MockDecodeNS: func() ([]*net.NS, error) {
return nil, expected
},
}
out, err := r.DecodeNS()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
}

View File

@ -1,10 +1,14 @@
package mocks
import "context"
import (
"context"
"github.com/ooni/probe-cli/v3/internal/model"
)
// DNSTransport allows mocking dnsx.DNSTransport.
type DNSTransport struct {
MockRoundTrip func(ctx context.Context, query []byte) ([]byte, error)
MockRoundTrip func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error)
MockRequiresPadding func() bool
@ -15,8 +19,10 @@ type DNSTransport struct {
MockCloseIdleConnections func()
}
var _ model.DNSTransport = &DNSTransport{}
// RoundTrip calls MockRoundTrip.
func (txp *DNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
func (txp *DNSTransport) RoundTrip(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return txp.MockRoundTrip(ctx, query)
}

View File

@ -6,17 +6,18 @@ import (
"testing"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model"
)
func TestDNSTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), make([]byte, 16))
resp, err := txp.RoundTrip(context.Background(), &DNSQuery{})
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}

View File

@ -1,5 +1,9 @@
package model
//
// Network extensions
//
import (
"context"
"crypto/tls"
@ -9,74 +13,81 @@ import (
"time"
"github.com/lucas-clemente/quic-go"
"github.com/miekg/dns"
)
//
// Network extensions
//
// DNSResponse is a parsed DNS response ready for further processing.
type DNSResponse interface {
// Query is the query associated with this response.
Query() DNSQuery
// The DNSDecoder decodes DNS replies.
// Bytes returns the bytes from which we parsed the query.
Bytes() []byte
// Rcode returns the response's Rcode.
Rcode() int
// DecodeHTTPS returns information gathered from all the HTTPS
// records found inside of this response.
DecodeHTTPS() (*HTTPSSvc, error)
// DecodeLookupHost returns the addresses in the response matching
// the original query type (one of A and AAAA).
DecodeLookupHost() ([]string, error)
// DecodeNS returns all the NS entries in this response.
DecodeNS() ([]*net.NS, error)
}
// The DNSDecoder decodes DNS responses.
type DNSDecoder interface {
// DecodeLookupHost decodes an A or AAAA reply.
//
// Arguments:
//
// - qtype is the query type (e.g., dns.TypeAAAA)
//
// - data contains the reply bytes read from a DNSTransport
//
// - queryID is the original query ID
//
// Returns:
//
// - on success, a list of IP addrs inside the reply and a nil error
//
// - on failure, a nil list and an error.
//
// Note that this function will return an error if there is no
// IP address inside of the reply.
DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error)
// DecodeHTTPS is like DecodeLookupHost but decodes an HTTPS reply.
//
// The argument is the reply as read by the DNSTransport.
//
// On success, this function returns an HTTPSSvc structure and
// a nil error. On failure, the HTTPSSvc pointer is nil and
// the error points to the error that occurred.
//
// This function will return an error if the HTTPS reply does not
// contain at least a valid ALPN entry. It will not return
// an error, though, when there are no IPv4/IPv6 hints in the reply.
DecodeHTTPS(data []byte, queryID uint16) (*HTTPSSvc, error)
// DecodeNS is like DecodeHTTPS but for NS queries.
DecodeNS(data []byte, queryID uint16) ([]*net.NS, error)
// DecodeReply decodes a DNS reply message.
// DecodeResponse decodes a DNS response message.
//
// Arguments:
//
// - data is the raw reply
//
// This function fails if we cannot parse data as a DNS
// message or the message is not a reply.
// message or the message is not a response.
//
// If you use this function, remember that:
// Regarding the returned response, remember that the Rcode
// MAY still be nonzero (this method does not treat a nonzero
// Rcode as an error when parsing the response).
DecodeResponse(data []byte, query DNSQuery) (DNSResponse, error)
}
// DNSQuery is an encoded DNS query ready to be sent using a DNSTransport.
type DNSQuery interface {
// Domain is the domain we're querying for.
Domain() string
// Type is the query type.
Type() uint16
// Bytes serializes the query to bytes. This function may fail if we're not
// able to correctly encode the domain into a query message.
//
// 1. the Rcode MAY be nonzero;
//
// 2. the replyID MAY NOT match the original query ID.
//
// That is, this is a very basic parsing method.
DecodeReply(data []byte) (*dns.Msg, error)
// The value returned by this function WILL be memoized after the first call,
// so you SHOULD create a new DNSQuery if you need to retry a query.
Bytes() ([]byte, error)
// ID returns the query ID.
ID() uint16
}
// The DNSEncoder encodes DNS queries to bytes
type DNSEncoder interface {
// Encode transforms its arguments into a serialized DNS query.
//
// Every time you call Encode, you get a new DNSQuery value
// using a query ID selected at random.
//
// Serialization to bytes is lazy to acommodate DNS transports that
// do not need to serialize and send bytes, e.g., getaddrinfo.
//
// You serialize to bytes using DNSQuery.Bytes. This operation MAY fail
// if the domain name cannot be packed into a DNS message (e.g., it is
// too long to fit into the message).
//
// Arguments:
//
// - domain is the domain for the query (e.g., x.org);
@ -85,16 +96,15 @@ type DNSEncoder interface {
//
// - padding is whether to add padding to the query.
//
// On success, this function returns a valid byte array, the queryID, and
// a nil error. On failure, we have a non-nil error, a nil arrary and a zero
// query ID.
Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
// This function will transform the domain into an FQDN is it's not
// already expressed in the FQDN format.
Encode(domain string, qtype uint16, padding bool) DNSQuery
}
// DNSTransport represents an abstract DNS transport.
type DNSTransport interface {
// RoundTrip sends a DNS query and receives the reply.
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
RoundTrip(ctx context.Context, query DNSQuery) (DNSResponse, error)
// RequiresPadding returns whether this transport needs padding.
RequiresPadding() bool

View File

@ -15,14 +15,16 @@ import (
// DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder.
type DNSDecoderMiekg struct{}
// ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID.
var ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID)
var (
// ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID.
ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID)
// ErrDNSIsQuery indicates that we were passed a DNS query.
var ErrDNSIsQuery = errors.New("ooresolver: expected response but received query")
// ErrDNSIsQuery indicates that we were passed a DNS query.
ErrDNSIsQuery = errors.New("ooresolver: expected response but received query")
)
// DecodeReply implements model.DNSDecoder.DecodeReply
func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) {
// DecodeResponse implements model.DNSDecoder.DecodeResponse.
func (d *DNSDecoderMiekg) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
reply := &dns.Msg{}
if err := reply.Unpack(data); err != nil {
return nil, err
@ -30,46 +32,64 @@ func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) {
if !reply.Response {
return nil, ErrDNSIsQuery
}
return reply, nil
}
// decodeSuccessfulReply decodes the bytes in data as a successful reply for the
// given queryID. This function returns an error if:
//
// 1. we cannot decode data
//
// 2. the decoded message is not a reply
//
// 3. the query ID does not match
//
// 4. the Rcode is not zero.
func (d *DNSDecoderMiekg) decodeSuccessfulReply(data []byte, queryID uint16) (*dns.Msg, error) {
reply, err := d.DecodeReply(data)
if err != nil {
return nil, err
}
if reply.Id != queryID {
if reply.Id != query.ID() {
return nil, ErrDNSReplyWithWrongQueryID
}
resp := &dnsResponse{
bytes: data,
msg: reply,
query: query,
}
return resp, nil
}
// dnsResponse implements model.DNSResponse.
type dnsResponse struct {
// bytes contains the response bytes.
bytes []byte
// msg contains the message.
msg *dns.Msg
// query is the original query.
query model.DNSQuery
}
// Query implements model.DNSResponse.Query.
func (r *dnsResponse) Query() model.DNSQuery {
return r.query
}
// Bytes implements model.DNSResponse.Bytes.
func (r *dnsResponse) Bytes() []byte {
return r.bytes
}
// Rcode implements model.DNSResponse.Rcode.
func (r *dnsResponse) Rcode() int {
return r.msg.Rcode
}
func (r *dnsResponse) rcodeToError() error {
// TODO(bassosimone): map more errors to net.DNSError names
// TODO(bassosimone): add support for lame referral.
switch reply.Rcode {
switch r.msg.Rcode {
case dns.RcodeSuccess:
return reply, nil
return nil
case dns.RcodeNameError:
return nil, ErrOODNSNoSuchHost
return ErrOODNSNoSuchHost
case dns.RcodeRefused:
return nil, ErrOODNSRefused
return ErrOODNSRefused
case dns.RcodeServerFailure:
return nil, ErrOODNSServfail
return ErrOODNSServfail
default:
return nil, ErrOODNSMisbehaving
return ErrOODNSMisbehaving
}
}
func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPSSvc, error) {
reply, err := d.decodeSuccessfulReply(data, queryID)
if err != nil {
// DecodeHTTPS implements model.DNSResponse.DecodeHTTPS.
func (r *dnsResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
if err := r.rcodeToError(); err != nil {
return nil, err
}
out := &model.HTTPSSvc{
@ -77,7 +97,7 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS
IPv4: []string{}, // ensure it's not nil
IPv6: []string{}, // ensure it's not nil
}
for _, answer := range reply.Answer {
for _, answer := range r.msg.Answer {
switch avalue := answer.(type) {
case *dns.HTTPS:
for _, v := range avalue.Value {
@ -102,14 +122,14 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS
return out, nil
}
func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error) {
reply, err := d.decodeSuccessfulReply(data, queryID)
if err != nil {
// DecodeLookupHost implements model.DNSResponse.DecodeLookupHost.
func (r *dnsResponse) DecodeLookupHost() ([]string, error) {
if err := r.rcodeToError(); err != nil {
return nil, err
}
var addrs []string
for _, answer := range reply.Answer {
switch qtype {
for _, answer := range r.msg.Answer {
switch r.Query().Type() {
case dns.TypeA:
if rra, ok := answer.(*dns.A); ok {
ip := rra.A
@ -128,13 +148,13 @@ func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID ui
return addrs, nil
}
func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, error) {
reply, err := d.decodeSuccessfulReply(data, queryID)
if err != nil {
// DecodeNS implements model.DNSResponse.DecodeNS.
func (r *dnsResponse) DecodeNS() ([]*net.NS, error) {
if err := r.rcodeToError(); err != nil {
return nil, err
}
out := []*net.NS{}
for _, answer := range reply.Answer {
for _, answer := range r.msg.Answer {
switch avalue := answer.(type) {
case *dns.NS:
out = append(out, &net.NS{Host: avalue.Ns})
@ -147,3 +167,4 @@ func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, erro
}
var _ model.DNSDecoder = &DNSDecoderMiekg{}
var _ model.DNSResponse = &dnsResponse{}

View File

@ -1,26 +1,27 @@
package netxlite
import (
"bytes"
"errors"
"net"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
func TestDNSDecoder(t *testing.T) {
t.Run("LookupHost", func(t *testing.T) {
func TestDNSDecoderMiekg(t *testing.T) {
t.Run("DecodeResponse", func(t *testing.T) {
t.Run("UnpackError", func(t *testing.T) {
d := &DNSDecoderMiekg{}
data, err := d.DecodeLookupHost(dns.TypeA, nil, 0)
resp, err := d.DecodeResponse(nil, &mocks.DNSQuery{})
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
t.Fatal("unexpected error", err)
}
if data != nil {
t.Fatal("expected nil data here")
if resp != nil {
t.Fatal("expected nil resp here")
}
})
@ -28,12 +29,12 @@ func TestDNSDecoder(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
addrs, err := d.DecodeLookupHost(dns.TypeA, rawQuery, queryID)
resp, err := d.DecodeResponse(rawQuery, &mocks.DNSQuery{})
if !errors.Is(err, ErrDNSIsQuery) {
t.Fatal("unexpected err", err)
}
if len(addrs) > 0 {
t.Fatal("expected no addrs")
if resp != nil {
t.Fatal("expected nil resp here")
}
})
@ -44,297 +45,447 @@ func TestDNSDecoder(t *testing.T) {
unrelatedID = 14
)
reply := dnsGenLookupHostReplySuccess(dnsGenQuery(dns.TypeA, queryID))
data, err := d.DecodeLookupHost(dns.TypeA, reply, unrelatedID)
resp, err := d.DecodeResponse(reply, &mocks.DNSQuery{
MockID: func() uint16 {
return unrelatedID
},
})
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
t.Fatal("unexpected error", err)
}
if data != nil {
t.Fatal("expected nil data here")
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("NXDOMAIN", func(t *testing.T) {
t.Run("dnsResponse.Query", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
dnsGenQuery(dns.TypeA, queryID), dns.RcodeNameError), queryID)
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected", err)
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("Refused", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
dnsGenQuery(dns.TypeA, queryID), dns.RcodeRefused), queryID)
if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("Servfail", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
dnsGenQuery(dns.TypeA, queryID), dns.RcodeServerFailure), queryID)
if !errors.Is(err, ErrOODNSServfail) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("no address", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID)), queryID)
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("decode A", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.8.8"), queryID)
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
if len(data) != 2 {
t.Fatal("expected two entries here")
}
if data[0] != "1.1.1.1" {
t.Fatal("invalid first IPv4 entry")
}
if data[1] != "8.8.8.8" {
t.Fatal("invalid second IPv4 entry")
if resp.Query().ID() != query.ID() {
t.Fatal("invalid query")
}
})
t.Run("decode AAAA", func(t *testing.T) {
t.Run("dnsResponse.Bytes", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID)
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
if len(data) != 2 {
t.Fatal("expected two entries here")
}
if data[0] != "::1" {
t.Fatal("invalid first IPv6 entry")
}
if data[1] != "fe80::1" {
t.Fatal("invalid second IPv6 entry")
if !bytes.Equal(rawResponse, resp.Bytes()) {
t.Fatal("invalid bytes")
}
})
t.Run("unexpected A reply", func(t *testing.T) {
t.Run("dnsResponse.Rcode", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID)
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("unexpected AAAA reply", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.4.4"), queryID)
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
})
t.Run("decodeSuccessfulReply", func(t *testing.T) {
d := &DNSDecoderMiekg{}
msg := &dns.Msg{}
msg.Rcode = dns.RcodeFormatError // an rcode we don't handle
msg.Response = true
data, err := msg.Pack()
if err != nil {
t.Fatal(err)
}
reply, err := d.decodeSuccessfulReply(data, 0)
if !errors.Is(err, ErrOODNSMisbehaving) { // catch all error
t.Fatal("not the error we expected", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("DecodeHTTPS", func(t *testing.T) {
t.Run("with nil data", func(t *testing.T) {
d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(nil, 0)
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
t.Fatal("not the error we expected", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("with bytes containing a query", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
https, err := d.DecodeHTTPS(rawQuery, queryID)
if !errors.Is(err, ErrDNSIsQuery) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("expected nil https")
}
})
t.Run("wrong query ID", func(t *testing.T) {
d := &DNSDecoderMiekg{}
const (
queryID = 17
unrelatedID = 14
)
reply := dnsGenHTTPSReplySuccess(dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
data, err := d.DecodeHTTPS(reply, unrelatedID)
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
t.Fatal("unexpected error", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("with empty answer", func(t *testing.T) {
queryID := dns.Id()
data := dnsGenHTTPSReplySuccess(
dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data, queryID)
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("with full answer", func(t *testing.T) {
queryID := dns.Id()
alpn := []string{"h3"}
v4 := []string{"1.1.1.1"}
v6 := []string{"::1"}
data := dnsGenHTTPSReplySuccess(
dnsGenQuery(dns.TypeHTTPS, queryID), alpn, v4, v6)
d := &DNSDecoderMiekg{}
reply, err := d.DecodeHTTPS(data, queryID)
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(alpn, reply.ALPN); diff != "" {
t.Fatal(diff)
}
if diff := cmp.Diff(v4, reply.IPv4); diff != "" {
t.Fatal(diff)
}
if diff := cmp.Diff(v6, reply.IPv6); diff != "" {
t.Fatal(diff)
}
})
})
t.Run("DecodeNS", func(t *testing.T) {
t.Run("with nil data", func(t *testing.T) {
d := &DNSDecoderMiekg{}
reply, err := d.DecodeNS(nil, 0)
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
t.Fatal("not the error we expected", err)
}
if reply != nil {
t.Fatal("expected nil reply")
if resp.Rcode() != dns.RcodeRefused {
t.Fatal("invalid rcode")
}
})
t.Run("with bytes containing a query", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
ns, err := d.DecodeNS(rawQuery, queryID)
if !errors.Is(err, ErrDNSIsQuery) {
t.Fatal("unexpected err", err)
}
if len(ns) > 0 {
t.Fatal("expected no result")
t.Run("dnsResponse.rcodeToError", func(t *testing.T) {
// Here we want to ensure we map all the errors we recognize
// correctly and we also map unrecognized errors correctly
var inputsOutputs = []struct {
name string
rcode int
err error
}{{
name: "when rcode is zero",
rcode: 0,
err: nil,
}, {
name: "NXDOMAIN",
rcode: dns.RcodeNameError,
err: ErrOODNSNoSuchHost,
}, {
name: "refused",
rcode: dns.RcodeRefused,
err: ErrOODNSRefused,
}, {
name: "servfail",
rcode: dns.RcodeServerFailure,
err: ErrOODNSServfail,
}, {
name: "anything else",
rcode: dns.RcodeFormatError,
err: ErrOODNSMisbehaving,
}}
for _, io := range inputsOutputs {
t.Run(io.name, func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
rawResponse := dnsGenReplyWithError(rawQuery, io.rcode)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
// The following cast should always work in this configuration
err = resp.(*dnsResponse).rcodeToError()
if !errors.Is(err, io.err) {
t.Fatal("unexpected err", err)
}
})
}
})
t.Run("wrong query ID", func(t *testing.T) {
d := &DNSDecoderMiekg{}
const (
queryID = 17
unrelatedID = 14
)
reply := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
data, err := d.DecodeNS(reply, unrelatedID)
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
t.Fatal("unexpected error", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
t.Run("dnsResponse.DecodeHTTPS", func(t *testing.T) {
t.Run("with failure", func(t *testing.T) {
// Ensure that we're not trying to decode if rcode != 0
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
https, err := resp.DecodeHTTPS()
if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("expected nil https result")
}
})
t.Run("with empty answer", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
rawResponse := dnsGenHTTPSReplySuccess(rawQuery, nil, nil, nil)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
https, err := resp.DecodeHTTPS()
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("expected nil https results")
}
})
t.Run("with full answer", func(t *testing.T) {
alpn := []string{"h3"}
v4 := []string{"1.1.1.1"}
v6 := []string{"::1"}
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
rawResponse := dnsGenHTTPSReplySuccess(rawQuery, alpn, v4, v6)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
reply, err := resp.DecodeHTTPS()
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(alpn, reply.ALPN); diff != "" {
t.Fatal(diff)
}
if diff := cmp.Diff(v4, reply.IPv4); diff != "" {
t.Fatal(diff)
}
if diff := cmp.Diff(v6, reply.IPv6); diff != "" {
t.Fatal(diff)
}
})
})
t.Run("with empty answer", func(t *testing.T) {
queryID := dns.Id()
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
d := &DNSDecoderMiekg{}
reply, err := d.DecodeNS(data, queryID)
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
t.Run("dnsResponse.DecodeNS", func(t *testing.T) {
t.Run("with failure", func(t *testing.T) {
// Ensure that we're not trying to decode if rcode != 0
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
ns, err := resp.DecodeNS()
if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("unexpected err", err)
}
if len(ns) > 0 {
t.Fatal("expected empty ns result")
}
})
t.Run("with empty answer", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
rawResponse := dnsGenNSReplySuccess(rawQuery)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
ns, err := resp.DecodeNS()
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err)
}
if len(ns) > 0 {
t.Fatal("expected empty ns results")
}
})
t.Run("with full answer", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
rawResponse := dnsGenNSReplySuccess(rawQuery, "ns1.zdns.google.")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
ns, err := resp.DecodeNS()
if err != nil {
t.Fatal(err)
}
if len(ns) != 1 {
t.Fatal("unexpected ns length")
}
if ns[0].Host != "ns1.zdns.google." {
t.Fatal("unexpected host")
}
})
})
t.Run("with full answer", func(t *testing.T) {
queryID := dns.Id()
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID), "ns1.zdns.google.")
d := &DNSDecoderMiekg{}
reply, err := d.DecodeNS(data, queryID)
if err != nil {
t.Fatal(err)
}
if len(reply) != 1 {
t.Fatal("unexpected reply length")
}
if reply[0].Host != "ns1.zdns.google." {
t.Fatal("unexpected reply host")
}
t.Run("dnsResponse.LookupHost", func(t *testing.T) {
t.Run("with failure", func(t *testing.T) {
// Ensure that we're not trying to decode if rcode != 0
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if !errors.Is(err, ErrOODNSRefused) {
t.Fatal("unexpected err", err)
}
if len(addrs) > 0 {
t.Fatal("expected empty addrs result")
}
})
t.Run("with empty answer", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("unexpected err", err)
}
if len(addrs) > 0 {
t.Fatal("expected empty ns results")
}
})
t.Run("decode A", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
MockType: func() uint16 {
return dns.TypeA
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
if len(addrs) != 2 {
t.Fatal("expected two entries here")
}
if addrs[0] != "1.1.1.1" {
t.Fatal("invalid first IPv4 entry")
}
if addrs[1] != "8.8.8.8" {
t.Fatal("invalid second IPv4 entry")
}
})
t.Run("decode AAAA", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeAAAA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
MockType: func() uint16 {
return dns.TypeAAAA
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
if len(addrs) != 2 {
t.Fatal("expected two entries here")
}
if addrs[0] != "::1" {
t.Fatal("invalid first IPv6 entry")
}
if addrs[1] != "fe80::1" {
t.Fatal("invalid second IPv6 entry")
}
})
t.Run("unexpected A reply to AAAA query", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeAAAA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
MockType: func() uint16 {
return dns.TypeAAAA
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if len(addrs) > 0 {
t.Fatal("expected no addrs here")
}
})
t.Run("unexpected AAAA reply to A query", func(t *testing.T) {
d := &DNSDecoderMiekg{}
queryID := dns.Id()
rawQuery := dnsGenQuery(dns.TypeA, queryID)
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1")
query := &mocks.DNSQuery{
MockID: func() uint16 {
return queryID
},
MockType: func() uint16 {
return dns.TypeA
},
}
resp, err := d.DecodeResponse(rawResponse, query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if !errors.Is(err, ErrOODNSNoAnswer) {
t.Fatal("not the error we expected", err)
}
if len(addrs) > 0 {
t.Fatal("expected no addrs here")
}
})
})
})
}
@ -371,8 +522,8 @@ func dnsGenReplyWithError(rawQuery []byte, code int) []byte {
return data
}
// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given
// qtype (e.g., dns.TypeA) containing the given ips... in the answer.
// dnsGenLookupHostReplySuccess generates a successful DNS reply containing the given ips...
// in the answers where each answer's type depends on the IP's type (A/AAAA).
func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
query := new(dns.Msg)
err := query.Unpack(rawQuery)
@ -388,28 +539,22 @@ func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query)
for _, ip := range ips {
switch question.Qtype {
case dns.TypeA:
if isIPv6(ip) {
continue
}
switch isIPv6(ip) {
case false:
reply.Answer = append(reply.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: question.Qtype,
Name: question.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
},
A: net.ParseIP(ip),
})
case dns.TypeAAAA:
if !isIPv6(ip) {
continue
}
case true:
reply.Answer = append(reply.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: question.Qtype,
Name: question.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 0,
},

View File

@ -5,7 +5,10 @@ package netxlite
//
import (
"sync"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model"
)
@ -23,18 +26,82 @@ const (
dnsDNSSECEnabled = true
)
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
// Encoder implements model.DNSEncoder.Encode.
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) model.DNSQuery {
return &dnsQuery{
bytesCalls: &atomicx.Int64{},
domain: domain,
kind: qtype,
id: dns.Id(),
memoizedBytes: []byte{},
mu: sync.Mutex{},
padding: padding,
}
}
// dnsQuery implements model.DNSQuery.
type dnsQuery struct {
// bytesCalls counts the calls to the bytes() method
bytesCalls *atomicx.Int64
// domain is the domain.
domain string
// kind is the query type.
kind uint16
// id is the query ID.
id uint16
// memoizedBytes contains the query encoded as bytes. We only fill
// this field the first time the Bytes method is called.
memoizedBytes []byte
// mu provides mutual exclusion.
mu sync.Mutex
// padding indicates whether we need padding.
padding bool
}
// Domain implements model.DNSQuery.Domain.
func (q *dnsQuery) Domain() string {
return q.domain
}
// Type implements model.DNSQuery.Type.
func (q *dnsQuery) Type() uint16 {
return q.kind
}
// Bytes implements model.DNSQuery.Bytes.
func (q *dnsQuery) Bytes() ([]byte, error) {
defer q.mu.Unlock()
q.mu.Lock()
if len(q.memoizedBytes) <= 0 {
q.bytesCalls.Add(1) // for testing
data, err := q.bytes()
if err != nil {
return nil, err
}
q.memoizedBytes = data
}
return q.memoizedBytes, nil
}
// bytes is the unmemoized implementation of Bytes
func (q *dnsQuery) bytes() ([]byte, error) {
question := dns.Question{
Name: dns.Fqdn(domain),
Qtype: qtype,
Name: dns.Fqdn(q.domain),
Qtype: q.kind,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.Id = q.id
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
if padding {
if q.padding {
query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled)
// Clients SHOULD pad queries to the closest multiple of
// 128 octets RFC8467#section-4.1. We inflate the query
@ -47,8 +114,13 @@ func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]b
opt.Padding = make([]byte, remainder)
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
}
data, err := query.Pack()
return data, query.Id, err
return query.Pack()
}
// ID implements model.DNSQuery.ID
func (q *dnsQuery) ID() uint16 {
return q.id
}
var _ model.DNSEncoder = &DNSEncoderMiekg{}
var _ model.DNSQuery = &dnsQuery{}

View File

@ -1,29 +1,103 @@
package netxlite
import (
"bytes"
"encoding/binary"
"strings"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/randx"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
func TestDNSEncoder(t *testing.T) {
func TestDNSEncoderMiekg(t *testing.T) {
t.Run("we can fail to encode a domain name to bytes", func(t *testing.T) {
e := &DNSEncoderMiekg{}
domain := randx.LettersUppercase(512)
query := e.Encode(domain, dns.TypeA, false)
data, err := query.Bytes()
if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") {
t.Fatal("unexpected err", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
})
t.Run("calls to bytes are memoized", func(t *testing.T) {
t.Run("on success", func(t *testing.T) {
e := &DNSEncoderMiekg{}
query := e.Encode("x.org", dns.TypeA, false)
checkResult := func(data []byte, err error) {
if err != nil {
t.Fatal("unexpected err", err)
}
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
}
const repeat = 3
for idx := 0; idx < repeat; idx++ {
checkResult(query.Bytes())
}
// The following cast will always work in this configuration
if query.(*dnsQuery).bytesCalls.Load() != 1 {
t.Fatal("invalid number of calls")
}
})
t.Run("on failure", func(t *testing.T) {
e := &DNSEncoderMiekg{}
domain := randx.LettersUppercase(512)
query := e.Encode(domain, dns.TypeA, false)
checkResult := func(data []byte, err error) {
if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") {
t.Fatal("unexpected err", err)
}
if data != nil {
t.Fatal("expected nil data here")
}
}
const repeat = 3
for idx := 0; idx < repeat; idx++ {
checkResult(query.Bytes())
}
// The following cast will always work in this configuration
if query.(*dnsQuery).bytesCalls.Load() != repeat {
t.Fatal("invalid number of calls")
}
})
})
t.Run("encode A", func(t *testing.T) {
e := &DNSEncoderMiekg{}
data, _, err := e.Encode("x.org", dns.TypeA, false)
query := e.Encode("x.org", dns.TypeA, false)
if query.Domain() != "x.org" {
t.Fatal("invalid domain")
}
if query.Type() != dns.TypeA {
t.Fatal("invalid type")
}
data, err := query.Bytes()
if err != nil {
t.Fatal(err)
}
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
})
t.Run("encode AAAA", func(t *testing.T) {
e := &DNSEncoderMiekg{}
data, _, err := e.Encode("x.org", dns.TypeAAAA, false)
query := e.Encode("x.org", dns.TypeAAAA, false)
if query.Domain() != "x.org" {
t.Fatal("invalid domain")
}
if query.Type() != dns.TypeAAAA {
t.Fatal("invalid type")
}
data, err := query.Bytes()
if err != nil {
t.Fatal(err)
}
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
})
t.Run("encode padding", func(t *testing.T) {
@ -31,7 +105,7 @@ func TestDNSEncoder(t *testing.T) {
// array of values we obtain the right query size.
getquerylen := func(domainlen int, padding bool) int {
e := &DNSEncoderMiekg{}
data, _, err := e.Encode(
query := e.Encode(
// This is not a valid name because it ends up being way
// longer than 255 octets. However, the library is allowing
// us to generate such name and we are not going to send
@ -40,6 +114,7 @@ func TestDNSEncoder(t *testing.T) {
dns.Fqdn(strings.Repeat("x.", domainlen)),
dns.TypeA, padding,
)
data, err := query.Bytes()
if err != nil {
t.Fatal(err)
}
@ -63,8 +138,13 @@ func TestDNSEncoder(t *testing.T) {
// dnsValidateEncodedQueryBytes validates the query serialized in data
// for the given query type qtype (e.g., dns.TypeAAAA).
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte) {
// skipping over the query ID
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte, qid uint16) {
var wirequery uint16
err := binary.Read(bytes.NewReader(data), binary.BigEndian, &wirequery)
runtimex.PanicOnError(err, "binary.Read failed unexpectedly")
if wirequery != qid {
t.Fatal("invalid query ID")
}
if data[2] != 1 {
t.Fatal("FLAGS should only have RD set")
}

View File

@ -8,6 +8,7 @@ import (
"bytes"
"context"
"errors"
"io"
"net/http"
"time"
@ -19,6 +20,9 @@ type DNSOverHTTPSTransport struct {
// Client is the MANDATORY http client to use.
Client model.HTTPClient
// Decoder is the MANDATORY DNSDecoder.
Decoder model.DNSDecoder
// URL is the MANDATORY URL of the DNS-over-HTTPS server.
URL string
@ -31,9 +35,9 @@ type DNSOverHTTPSTransport struct {
//
// Arguments:
//
// - client in http.Client-like type (e.g., http.DefaultClient);
// - client is a model.HTTPClient type;
//
// - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query).
// - URL is the DoH resolver URL (e.g., https://dns.google/dns-query).
func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
}
@ -42,22 +46,31 @@ func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPS
// with the given Host header override.
func NewDNSOverHTTPSTransportWithHostOverride(
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
return &DNSOverHTTPSTransport{Client: client, URL: URL, HostOverride: hostOverride}
return &DNSOverHTTPSTransport{
Client: client,
Decoder: &DNSDecoderMiekg{},
URL: URL,
HostOverride: hostOverride,
}
}
// RoundTrip sends a query and receives a reply.
func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
func (t *DNSOverHTTPSTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
rawQuery, err := query.Bytes()
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(rawQuery))
if err != nil {
return nil, err
}
req.Host = t.HostOverride
req.Header.Set("user-agent", model.HTTPHeaderUserAgent)
req.Header.Set("content-type", "application/dns-message")
var resp *http.Response
resp, err = t.Client.Do(req.WithContext(ctx))
resp, err := t.Client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -70,7 +83,13 @@ func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]
if resp.Header.Get("content-type") != "application/dns-message" {
return nil, errors.New("doh: invalid content-type")
}
return ReadAllContext(ctx, resp.Body)
const maxresponsesize = 1 << 20
limitReader := io.LimitReader(resp.Body, maxresponsesize)
rawResponse, err := ReadAllContext(ctx, limitReader)
if err != nil {
return nil, err
}
return t.Decoder.DecodeResponse(rawResponse, query)
}
// RequiresPadding returns true for DoH according to RFC8467.

View File

@ -15,14 +15,36 @@ import (
func TestDNSOverHTTPSTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("query serialization failure", func(t *testing.T) {
txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
expected := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("NewRequestFailure", func(t *testing.T) {
const invalidURL = "\t"
txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("expected an error here")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
if data != nil {
resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
@ -37,11 +59,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
if data != nil {
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
@ -58,11 +85,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || err.Error() != "doh: server returned error" {
t.Fatal("expected an error here")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
if data != nil {
resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || err.Error() != "doh: server returned error" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
@ -79,11 +111,86 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
if err == nil || err.Error() != "doh: invalid content-type" {
t.Fatal("expected an error here")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
if data != nil {
resp, err := txp.RoundTrip(context.Background(), query)
if err == nil || err.Error() != "doh: invalid content-type" {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("ReadAllContext fails", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
}),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("decode response failure", func(t *testing.T) {
expected := errors.New("mocked error")
body := []byte("AAA")
txp := &DNSOverHTTPSTransport{
Client: &mocks.HTTPClient{
MockDo: func(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader(body)),
Header: http.Header{
"Content-Type": []string{"application/dns-message"},
},
}, nil
},
},
URL: "https://cloudflare-dns.com/dns-query",
Decoder: &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected no response here")
}
})
@ -103,13 +210,23 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
},
},
URL: "https://cloudflare-dns.com/dns-query",
Decoder: &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return &mocks.DNSResponse{}, nil
},
},
}
data, err := txp.RoundTrip(context.Background(), nil)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, body) {
t.Fatal("not the response we expected")
if resp == nil {
t.Fatal("expected non-nil resp here")
}
})
@ -125,7 +242,12 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
},
URL: "https://cloudflare-dns.com/dns-query",
}
data, err := txp.RoundTrip(context.Background(), nil)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
data, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
@ -151,18 +273,22 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
URL: "https://cloudflare-dns.com/dns-query",
HostOverride: hostOverride,
}
data, err := txp.RoundTrip(context.Background(), nil)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 17), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("expected an error here")
}
if data != nil {
if resp != nil {
t.Fatal("expected no response here")
}
if !correct {
t.Fatal("did not see correct host override")
}
})
})
t.Run("other functions behave correctly", func(t *testing.T) {

View File

@ -20,9 +20,12 @@ type DialContextFunc func(context.Context, string, string) (net.Conn, error)
// DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport.
//
// Bug: this implementation always creates a new connection for each query.
// Note: this implementation always creates a new connection for each query. This
// strategy is less efficient but MAY be more robust for cleartext TCP connections
// when querying for a blocked domain name causes endpoint blocking.
type DNSOverTCPTransport struct {
dial DialContextFunc
decoder model.DNSDecoder
address string
network string
requiresPadding bool
@ -36,47 +39,58 @@ type DNSOverTCPTransport struct {
//
// - address is the endpoint address (e.g., 8.8.8.8:53).
func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
return &DNSOverTCPTransport{
dial: dial,
address: address,
network: "tcp",
requiresPadding: false,
}
return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false)
}
// NewDNSOverTLS creates a new DNSOverTLS transport.
// NewDNSOverTLSTransport creates a new DNSOverTLS transport.
//
// Arguments:
//
// - dial is a function with the net.Dialer.DialContext's signature;
//
// - address is the endpoint address (e.g., 8.8.8.8:853).
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCPTransport {
func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
return newDNSOverTCPOrTLSTransport(dial, "dot", address, true)
}
// newDNSOverTCPOrTLSTransport is the common factory for creating a transport
func newDNSOverTCPOrTLSTransport(
dial DialContextFunc, network, address string, padding bool) *DNSOverTCPTransport {
return &DNSOverTCPTransport{
dial: dial,
decoder: &DNSDecoderMiekg{},
address: address,
network: "dot",
requiresPadding: true,
network: network,
requiresPadding: padding,
}
}
// errQueryTooLarge indicates the query is too large for the transport.
var errQueryTooLarge = errors.New("oodns: query too large for this transport")
// RoundTrip sends a query and receives a reply.
func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
if len(query) > math.MaxUint16 {
return nil, errors.New("query too long")
func (t *DNSOverTCPTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
// TODO(bassosimone): this method should more strictly honour the context, which
// currently is only used to bound the dial operation
rawQuery, err := query.Bytes()
if err != nil {
return nil, err
}
if len(rawQuery) > math.MaxUint16 {
return nil, errQueryTooLarge
}
conn, err := t.dial(ctx, "tcp", t.address)
if err != nil {
return nil, err
}
defer conn.Close()
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
return nil, err
}
const iotimeout = 10 * time.Second
conn.SetDeadline(time.Now().Add(iotimeout))
// Write request
buf := []byte{byte(len(query) >> 8)}
buf = append(buf, byte(len(query)))
buf = append(buf, query...)
buf := []byte{byte(len(rawQuery) >> 8)}
buf = append(buf, byte(len(rawQuery)))
buf = append(buf, rawQuery...)
if _, err = conn.Write(buf); err != nil {
return nil, err
}
@ -86,11 +100,11 @@ func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]by
return nil, err
}
length := int(header[0])<<8 | int(header[1])
reply := make([]byte, length)
if _, err = io.ReadFull(conn, reply); err != nil {
rawResponse := make([]byte, length)
if _, err = io.ReadFull(conn, rawResponse); err != nil {
return nil, err
}
return reply, nil
return t.decoder.DecodeResponse(rawResponse, query)
}
// RequiresPadding returns true for DoT and false for TCP

View File

@ -6,73 +6,83 @@ import (
"crypto/tls"
"errors"
"io"
"math"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestDNSOverTCPTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("cannot encode query", func(t *testing.T) {
expected := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("query too large", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
if err == nil {
t.Fatal("expected an error here")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, math.MaxUint16+1), nil
},
}
if reply != nil {
t.Fatal("expected nil reply here")
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, errQueryTooLarge) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("dial failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("SetDeadline failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("write failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
@ -89,18 +99,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("first read fails", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
@ -120,18 +135,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("second read fails", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := io.MultiReader(
bytes.NewReader([]byte{byte(0), byte(2)}),
&mocks.Reader{
@ -157,17 +177,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if reply != nil {
t.Fatal("expected nil reply here")
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("successful case", func(t *testing.T) {
t.Run("decode failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
@ -186,11 +212,56 @@ func TestDNSOverTCPTransport(t *testing.T) {
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, mocked
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("successful case", func(t *testing.T) {
const address = "9.9.9.9:53"
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
expectedResp := &mocks.DNSResponse{}
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return expectedResp, nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal(err)
}
if len(reply) != 1 || reply[0] != 1 {
if resp != expectedResp {
t.Fatal("not the response we expected")
}
})
@ -213,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
t.Run("other functions okay with TLS", func(t *testing.T) {
const address = "9.9.9.9:853"
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address)
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding")
}

View File

@ -14,6 +14,7 @@ import (
// DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
type DNSOverUDPTransport struct {
dialer model.Dialer
decoder model.DNSDecoder
address string
}
@ -25,11 +26,20 @@ type DNSOverUDPTransport struct {
//
// - address is the endpoint address (e.g., 8.8.8.8:53).
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
return &DNSOverUDPTransport{dialer: dialer, address: address}
return &DNSOverUDPTransport{
dialer: dialer,
decoder: &DNSDecoderMiekg{},
address: address,
}
}
// RoundTrip sends a query and receives a reply.
func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
func (t *DNSOverUDPTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
rawQuery, err := query.Bytes()
if err != nil {
return nil, err
}
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
if err != nil {
return nil, err
@ -37,19 +47,19 @@ func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]by
defer conn.Close()
// Use five seconds timeout like Bionic does. See
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
const iotimeout = 5 * time.Second
conn.SetDeadline(time.Now().Add(iotimeout))
if _, err = conn.Write(rawQuery); err != nil {
return nil, err
}
if _, err = conn.Write(query); err != nil {
return nil, err
}
reply := make([]byte, 1<<17)
var n int
n, err = conn.Read(reply)
const maxmessagesize = 1 << 17
rawResponse := make([]byte, maxmessagesize)
count, err := conn.Read(rawResponse)
if err != nil {
return nil, err
}
return reply[:n], nil
rawResponse = rawResponse[:count]
return t.decoder.DecodeResponse(rawResponse, query)
}
// RequiresPadding returns false for UDP according to RFC8467.

View File

@ -9,11 +9,30 @@ import (
"time"
"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestDNSOverUDPTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("cannot encode query", func(t *testing.T) {
expected := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewDNSOverUDPTransport(nil, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("dial failure", func(t *testing.T) {
mocked := errors.New("mocked error")
const address = "9.9.9.9:53"
@ -22,36 +41,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
return nil, mocked
},
}, address)
data, err := txp.RoundTrip(context.Background(), nil)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
t.Fatal("expected no response here")
}
})
t.Run("SetDeadline failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
if resp != nil {
t.Fatal("expected no response here")
}
})
@ -75,11 +74,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
if resp != nil {
t.Fatal("expected no response here")
}
})
@ -106,15 +110,61 @@ func TestDNSOverUDPTransport(t *testing.T) {
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if data != nil {
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("decode failure", func(t *testing.T) {
const expected = 17
input := bytes.NewReader(make([]byte, expected))
txp := NewDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53",
)
expectedErr := errors.New("mocked error")
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expectedErr
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil resp")
}
})
t.Run("read success", func(t *testing.T) {
const expected = 17
input := bytes.NewReader(make([]byte, expected))
@ -136,12 +186,23 @@ func TestDNSOverUDPTransport(t *testing.T) {
},
}, "9.9.9.9:53",
)
data, err := txp.RoundTrip(context.Background(), nil)
expectedResp := &mocks.DNSResponse{}
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return expectedResp, nil
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal(err)
}
if len(data) != expected {
t.Fatal("expected non nil data")
if resp != expectedResp {
t.Fatal("unexpected resp")
}
})
})

View File

@ -1,15 +1,12 @@
package filtering
import (
"context"
"errors"
"io"
"net"
"net/http"
"strings"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)
@ -51,19 +48,13 @@ type DNSProxy struct {
// receive a query for the given domain.
OnQuery func(domain string) DNSAction
// Upstream is the OPTIONAL upstream transport.
Upstream DNSTransport
// UpstreamEndpoint is the OPTIONAL upstream transport endpoint.
UpstreamEndpoint string
// mockableReply allows to mock DNSProxy.reply in tests.
mockableReply func(query *dns.Msg) (*dns.Msg, error)
}
// DNSTransport is the type we expect from an upstream DNS transport.
type DNSTransport interface {
RoundTrip(ctx context.Context, query []byte) ([]byte, error)
CloseIdleConnections()
}
// DNSListener is the interface returned by DNSProxy.Start
type DNSListener interface {
io.Closer
@ -204,23 +195,24 @@ func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg {
return reply
}
var (
// errDNSExpectedSingleQuestion means we expected to see a single question
errDNSExpectedSingleQuestion = errors.New("filtering: expected single DNS question")
// errDNSExpectedQueryNotResponse means we expected to see a query.
errDNSExpectedQueryNotResponse = errors.New("filtering: expected query not response")
)
func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) {
queryBytes, err := query.Pack()
if err != nil {
return nil, err
if query.Response {
return nil, errDNSExpectedQueryNotResponse
}
txp := p.dnstransport()
defer txp.CloseIdleConnections()
ctx := context.Background()
replyBytes, err := txp.RoundTrip(ctx, queryBytes)
if err != nil {
return nil, err
if len(query.Question) != 1 {
return nil, errDNSExpectedSingleQuestion
}
reply := &dns.Msg{}
if err := reply.Unpack(replyBytes); err != nil {
return nil, err
}
return reply, nil
clnt := &dns.Client{}
resp, _, err := clnt.Exchange(query, p.upstreamEndpoint())
return resp, err
}
func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
@ -237,10 +229,9 @@ func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
return p.compose(query, ipAddrs...)
}
func (p *DNSProxy) dnstransport() DNSTransport {
if p.Upstream != nil {
return p.Upstream
func (p *DNSProxy) upstreamEndpoint() string {
if p.UpstreamEndpoint != "" {
return p.UpstreamEndpoint
}
const URL = "https://1.1.1.1/dns-query"
return netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, URL)
return "8.8.8.8:53"
}

View File

@ -283,9 +283,7 @@ func TestDNSProxy(t *testing.T) {
if len(p) < len(data) {
panic("buffer too small")
}
for i := 0; i < len(data); i++ {
p[i] = data[i]
}
copy(p, data)
return len(data), &net.UDPAddr{}, nil
},
}
@ -314,9 +312,7 @@ func TestDNSProxy(t *testing.T) {
if len(p) < len(data) {
panic("buffer too small")
}
for i := 0; i < len(data); i++ {
p[i] = data[i]
}
copy(p, data)
return len(data), &net.UDPAddr{}, nil
},
}
@ -328,12 +324,24 @@ func TestDNSProxy(t *testing.T) {
})
t.Run("proxy", func(t *testing.T) {
t.Run("Pack fails", func(t *testing.T) {
t.Run("with response", func(t *testing.T) {
p := &DNSProxy{}
query := &dns.Msg{}
query.Rcode = -1 // causes Pack to fail
query.Response = true
reply, err := p.proxy(query)
if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") {
if !errors.Is(err, errDNSExpectedQueryNotResponse) {
t.Fatal("unexpected err", err)
}
if reply != nil {
t.Fatal("expected nil reply")
}
})
t.Run("with no questions", func(t *testing.T) {
p := &DNSProxy{}
query := &dns.Msg{}
reply, err := p.proxy(query)
if !errors.Is(err, errDNSExpectedSingleQuestion) {
t.Fatal("unexpected err", err)
}
if reply != nil {
@ -342,35 +350,13 @@ func TestDNSProxy(t *testing.T) {
})
t.Run("round trip fails", func(t *testing.T) {
expected := errors.New("mocked error")
p := &DNSProxy{
Upstream: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return nil, expected
},
MockCloseIdleConnections: func() {},
},
UpstreamEndpoint: "antani",
}
reply, err := p.proxy(&dns.Msg{})
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if reply != nil {
t.Fatal("expected nil reply here")
}
})
t.Run("Unpack fails", func(t *testing.T) {
p := &DNSProxy{
Upstream: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 1), nil
},
MockCloseIdleConnections: func() {},
},
}
reply, err := p.proxy(&dns.Msg{})
if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") {
query := &dns.Msg{}
query.Question = append(query.Question, dns.Question{})
reply, err := p.proxy(query)
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
t.Fatal("unexpected err", err)
}
if reply != nil {

View File

@ -9,7 +9,6 @@ import (
"net"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model"
)
@ -19,15 +18,6 @@ import (
// You should probably use NewUnwrappedParallelResolver to
// create a new instance of this type.
type ParallelResolver struct {
// Encoder is the MANDATORY encoder to use.
Encoder model.DNSEncoder
// Decoder is the MANDATORY decoder to use.
Decoder model.DNSDecoder
// NumTimeouts is MANDATORY and counts the number of timeouts.
NumTimeouts *atomicx.Int64
// Txp is the MANDATORY underlying DNS transport.
Txp model.DNSTransport
}
@ -36,10 +26,7 @@ type ParallelResolver struct {
// not wrapped and you should wrap if before using it.
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
return &ParallelResolver{
Encoder: &DNSEncoderMiekg{},
Decoder: &DNSDecoderMiekg{},
NumTimeouts: &atomicx.Int64{},
Txp: t,
Txp: t,
}
}
@ -80,22 +67,22 @@ func (r *ParallelResolver) LookupHost(ctx context.Context, hostname string) ([]s
var addrs []string
addrs = append(addrs, ares.addrs...)
addrs = append(addrs, aaaares.addrs...)
if len(addrs) < 1 {
return nil, ErrOODNSNoAnswer
}
return addrs, nil
}
// LookupHTTPS implements Resolver.LookupHTTPS.
func (r *ParallelResolver) LookupHTTPS(
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
querydata, queryID, err := r.Encoder.Encode(
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.DecodeHTTPS(replydata, queryID)
return response.DecodeHTTPS()
}
// parallelResolverResult is the internal representation of a
@ -108,7 +95,9 @@ type parallelResolverResult struct {
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
qtype uint16, out chan<- *parallelResolverResult) {
querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil {
out <- &parallelResolverResult{
addrs: []string{},
@ -116,15 +105,7 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
}
return
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
out <- &parallelResolverResult{
addrs: []string{},
err: err,
}
return
}
addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
addrs, err := response.DecodeLookupHost()
out <- &parallelResolverResult{
addrs: addrs,
err: err,
@ -134,14 +115,11 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
// LookupNS implements Resolver.LookupNS.
func (r *ParallelResolver) LookupNS(
ctx context.Context, hostname string) ([]*net.NS, error) {
querydata, queryID, err := r.Encoder.Encode(
hostname, dns.TypeNS, r.Txp.RequiresPadding())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.DecodeNS(replydata, queryID)
return response.DecodeNS()
}

View File

@ -8,14 +8,13 @@ import (
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestParallelResolver(t *testing.T) {
t.Run("transport okay", func(t *testing.T) {
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := NewUnwrappedParallelResolver(txp)
rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
@ -30,30 +29,10 @@ func TestParallelResolver(t *testing.T) {
})
t.Run("LookupHost", func(t *testing.T) {
t.Run("Encode error", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, mocked
},
},
Txp: txp,
}
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
})
t.Run("RoundTrip error", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, mocked
},
MockRequiresPadding: func() bool {
@ -72,8 +51,13 @@ func TestParallelResolver(t *testing.T) {
t.Run("empty reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(query), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
return nil, nil
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return true
@ -91,8 +75,16 @@ func TestParallelResolver(t *testing.T) {
t.Run("with A reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeA {
return nil, nil
}
return []string{"8.8.8.8"}, nil
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return true
@ -110,8 +102,16 @@ func TestParallelResolver(t *testing.T) {
t.Run("with AAAA reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(query, "::1"), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeAAAA {
return nil, nil
}
return []string{"::1"}, nil
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return true
@ -131,22 +131,15 @@ func TestParallelResolver(t *testing.T) {
afailure := errors.New("a failure")
aaaafailure := errors.New("aaaa failure")
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
msg := &dns.Msg{}
if err := msg.Unpack(query); err != nil {
return nil, err
}
if len(msg.Question) != 1 {
return nil, errors.New("expected just one question")
}
q := msg.Question[0]
if q.Qtype == dns.TypeA {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
switch query.Type() {
case dns.TypeA:
return nil, afailure
}
if q.Qtype == dns.TypeAAAA {
case dns.TypeAAAA:
return nil, aaaafailure
default:
return nil, errors.New("unexpected query")
}
return nil, errors.New("expected A or AAAA query")
},
MockRequiresPadding: func() bool {
return true
@ -179,44 +172,11 @@ func TestParallelResolver(t *testing.T) {
})
t.Run("LookupHTTPS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
https, err := r.LookupHTTPS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockRequiresPadding: func() bool {
@ -234,23 +194,17 @@ func TestParallelResolver(t *testing.T) {
}
})
t.Run("for decode error", func(t *testing.T) {
t.Run("for DecodeHTTPS error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: &mocks.DNSDecoder{
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected
},
},
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 128), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
return nil, expected
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return false
@ -269,44 +223,11 @@ func TestParallelResolver(t *testing.T) {
})
t.Run("LookupNS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
ns, err := r.LookupNS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if ns != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockRequiresPadding: func() bool {
@ -327,20 +248,14 @@ func TestParallelResolver(t *testing.T) {
t.Run("for decode error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &ParallelResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: &mocks.DNSDecoder{
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
return nil, expected
},
},
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 128), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeNS: func() ([]*net.NS, error) {
return nil, expected
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return false

View File

@ -26,12 +26,6 @@ import (
// QUIRK: unlike the ParallelResolver, this resolver's LookupHost retries
// each query three times for soft errors.
type SerialResolver struct {
// Encoder is the MANDATORY encoder to use.
Encoder model.DNSEncoder
// Decoder is the MANDATORY decoder to use.
Decoder model.DNSDecoder
// NumTimeouts is MANDATORY and counts the number of timeouts.
NumTimeouts *atomicx.Int64
@ -42,8 +36,6 @@ type SerialResolver struct {
// NewSerialResolver creates a new SerialResolver instance.
func NewSerialResolver(t model.DNSTransport) *SerialResolver {
return &SerialResolver{
Encoder: &DNSEncoderMiekg{},
Decoder: &DNSDecoderMiekg{},
NumTimeouts: &atomicx.Int64{},
Txp: t,
}
@ -82,22 +74,22 @@ func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]str
}
addrs = append(addrs, addrsA...)
addrs = append(addrs, addrsAAAA...)
if len(addrs) < 1 {
return nil, ErrOODNSNoAnswer
}
return addrs, nil
}
// LookupHTTPS implements Resolver.LookupHTTPS.
func (r *SerialResolver) LookupHTTPS(
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
querydata, queryID, err := r.Encoder.Encode(
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.DecodeHTTPS(replydata, queryID)
return response.DecodeHTTPS()
}
func (r *SerialResolver) lookupHostWithRetry(
@ -132,28 +124,23 @@ func (r *SerialResolver) lookupHostWithRetry(
// qtype (dns.A or dns.AAAA) without retrying on failure.
func (r *SerialResolver) lookupHostWithoutRetry(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
return response.DecodeLookupHost()
}
// LookupNS implements Resolver.LookupNS.
func (r *SerialResolver) LookupNS(
ctx context.Context, hostname string) ([]*net.NS, error) {
querydata, queryID, err := r.Encoder.Encode(
hostname, dns.TypeNS, r.Txp.RequiresPadding())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding())
response, err := r.Txp.RoundTrip(ctx, query)
if err != nil {
return nil, err
}
replydata, err := r.Txp.RoundTrip(ctx, querydata)
if err != nil {
return nil, err
}
return r.Decoder.DecodeNS(replydata, queryID)
return response.DecodeNS()
}

View File

@ -7,6 +7,7 @@ import (
"net"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/atomicx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
@ -30,7 +31,7 @@ func (err *errorWithTimeout) Unwrap() error {
func TestSerialResolver(t *testing.T) {
t.Run("transport okay", func(t *testing.T) {
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := NewSerialResolver(txp)
rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
@ -45,30 +46,10 @@ func TestSerialResolver(t *testing.T) {
})
t.Run("LookupHost", func(t *testing.T) {
t.Run("Encode error", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, mocked
},
},
Txp: txp,
}
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if addrs != nil {
t.Fatal("expected nil address here")
}
})
t.Run("RoundTrip error", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, mocked
},
MockRequiresPadding: func() bool {
@ -87,8 +68,13 @@ func TestSerialResolver(t *testing.T) {
t.Run("empty reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(query), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
return nil, nil
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return true
@ -106,8 +92,16 @@ func TestSerialResolver(t *testing.T) {
t.Run("with A reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeA {
return nil, nil
}
return []string{"8.8.8.8"}, nil
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return true
@ -125,8 +119,16 @@ func TestSerialResolver(t *testing.T) {
t.Run("with AAAA reply", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return dnsGenLookupHostReplySuccess(query, "::1"), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeAAAA {
return nil, nil
}
return []string{"::1"}, nil
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return true
@ -144,11 +146,12 @@ func TestSerialResolver(t *testing.T) {
t.Run("with timeout", func(t *testing.T) {
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return nil, &net.OpError{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
err := &net.OpError{
Err: &errorWithTimeout{ETIMEDOUT},
Op: "dial",
}
return nil, err
},
MockRequiresPadding: func() bool {
return true
@ -184,44 +187,12 @@ func TestSerialResolver(t *testing.T) {
})
t.Run("LookupHTTPS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
https, err := r.LookupHTTPS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if https != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockRequiresPadding: func() bool {
@ -242,20 +213,15 @@ func TestSerialResolver(t *testing.T) {
t.Run("for decode error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: &mocks.DNSDecoder{
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
return nil, expected
},
},
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 128), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
return nil, expected
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return false
@ -274,44 +240,12 @@ func TestSerialResolver(t *testing.T) {
})
t.Run("LookupNS", func(t *testing.T) {
t.Run("for encoding error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return nil, 0, expected
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRequiresPadding: func() bool {
return false
},
},
}
ctx := context.Background()
ns, err := r.LookupNS(ctx, "example.com")
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if ns != nil {
t.Fatal("unexpected result")
}
})
t.Run("for round-trip error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: nil,
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expected
},
MockRequiresPadding: func() bool {
@ -332,20 +266,15 @@ func TestSerialResolver(t *testing.T) {
t.Run("for decode error", func(t *testing.T) {
expected := errors.New("mocked error")
r := &SerialResolver{
Encoder: &mocks.DNSEncoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
return make([]byte, 64), 0, nil
},
},
Decoder: &mocks.DNSDecoder{
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
return nil, expected
},
},
NumTimeouts: &atomicx.Int64{},
Txp: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return make([]byte, 128), nil
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeNS: func() ([]*net.NS, error) {
return nil, expected
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return false