From 01a513a49698c9516b8996217b98b43baa6df71f Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 25 May 2022 17:03:58 +0200 Subject: [PATCH] refactor: DNSTransport I/Os DNS messages (#760) This diff refactors the DNSTransport model to receive in input a DNSQuery and return in output a DNSResponse. The design of DNSQuery and DNSResponse takes into account the use case of a transport using getaddrinfo, meaning that we don't need to serialize and deserialize messages when using getaddrinfo. The current codebase does not use a getaddrinfo transport, but I wrote one such a transport in the Websteps Winter 2021 prototype (https://github.com/bassosimone/websteps-illustrated/). The design conversation that lead to producing this diff is https://github.com/ooni/probe/issues/2099 --- go.mod | 1 - go.sum | 2 - internal/engine/netx/netx.go | 2 +- internal/engine/netx/resolver/fake_test.go | 34 - .../engine/netx/resolver/integration_test.go | 4 +- internal/engine/netx/resolver/saver.go | 25 +- internal/engine/netx/resolver/saver_test.go | 57 +- internal/measurex/dnsx.go | 23 +- internal/model/mocks/dnsdecoder.go | 34 +- internal/model/mocks/dnsdecoder_test.go | 56 +- internal/model/mocks/dnsencoder.go | 14 +- internal/model/mocks/dnsencoder_test.go | 30 +- internal/model/mocks/dnsquery.go | 33 + internal/model/mocks/dnsquery_test.go | 62 ++ internal/model/mocks/dnsresponse.go | 47 ++ internal/model/mocks/dnsresponse_test.go | 105 +++ internal/model/mocks/dnstransport.go | 12 +- internal/model/mocks/dnstransport_test.go | 5 +- internal/model/netx.go | 120 +-- internal/netxlite/dnsdecoder.go | 109 +-- internal/netxlite/dnsdecoder_test.go | 697 +++++++++++------- internal/netxlite/dnsencoder.go | 86 ++- internal/netxlite/dnsencoder_test.go | 96 ++- internal/netxlite/dnsoverhttps.go | 35 +- internal/netxlite/dnsoverhttps_test.go | 172 ++++- internal/netxlite/dnsovertcp.go | 60 +- internal/netxlite/dnsovertcp_test.go | 163 ++-- internal/netxlite/dnsoverudp.go | 30 +- internal/netxlite/dnsoverudp_test.go | 129 +++- internal/netxlite/filtering/dns.go | 51 +- internal/netxlite/filtering/dns_test.go | 58 +- internal/netxlite/parallelresolver.go | 54 +- internal/netxlite/parallelresolver_test.go | 189 ++--- internal/netxlite/serialresolver.go | 43 +- internal/netxlite/serialresolver_test.go | 169 ++--- 35 files changed, 1731 insertions(+), 1076 deletions(-) create mode 100644 internal/model/mocks/dnsquery.go create mode 100644 internal/model/mocks/dnsquery_test.go create mode 100644 internal/model/mocks/dnsresponse.go create mode 100644 internal/model/mocks/dnsresponse_test.go diff --git a/go.mod b/go.mod index b73aad3..c41de38 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,6 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 - github.com/hexops/gotextdiff v1.0.3 github.com/iancoleman/strcase v0.2.0 github.com/lucas-clemente/quic-go v0.27.0 github.com/mattn/go-colorable v0.1.12 diff --git a/go.sum b/go.sum index 4298ad6..9b18f77 100644 --- a/go.sum +++ b/go.sum @@ -378,8 +378,6 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= -github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= -github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= diff --git a/internal/engine/netx/netx.go b/internal/engine/netx/netx.go index 6e3441d..2efdb13 100644 --- a/internal/engine/netx/netx.go +++ b/internal/engine/netx/netx.go @@ -317,7 +317,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride, if err != nil { return nil, err } - var txp model.DNSTransport = netxlite.NewDNSOverTLS( + var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport( tlsDialer.DialTLSContext, endpoint) if config.ResolveSaver != nil { txp = resolver.SaverDNSTransport{ diff --git a/internal/engine/netx/resolver/fake_test.go b/internal/engine/netx/resolver/fake_test.go index ebdbf13..9c62447 100644 --- a/internal/engine/netx/resolver/fake_test.go +++ b/internal/engine/netx/resolver/fake_test.go @@ -76,40 +76,6 @@ func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) { return c.SetWriteDeadlineError } -type FakeTransport struct { - Data []byte - Err error -} - -func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { - return ft.Data, ft.Err -} - -func (ft FakeTransport) RequiresPadding() bool { - return false -} - -func (ft FakeTransport) Address() string { - return "" -} - -func (ft FakeTransport) Network() string { - return "fake" -} - -func (fk FakeTransport) CloseIdleConnections() { - // nothing to do -} - -type FakeEncoder struct { - Data []byte - Err error -} - -func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) { - return fe.Data, fe.Err -} - func NewFakeResolverThatFails() model.Resolver { return NewFakeResolverWithExplicitError(netxlite.ErrOODNSNoSuchHost) } diff --git a/internal/engine/netx/resolver/integration_test.go b/internal/engine/netx/resolver/integration_test.go index 217d9a5..545f208 100644 --- a/internal/engine/netx/resolver/integration_test.go +++ b/internal/engine/netx/resolver/integration_test.go @@ -99,14 +99,14 @@ func TestNewResolverTCPDomain(t *testing.T) { func TestNewResolverDoTAddress(t *testing.T) { reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "8.8.8.8:853")) + netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853")) testresolverquick(t, reso) testresolverquickidna(t, reso) } func TestNewResolverDoTDomain(t *testing.T) { reso := netxlite.NewSerialResolver( - netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "dns.google.com:853")) + netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853")) testresolverquick(t, reso) testresolverquickidna(t, reso) } diff --git a/internal/engine/netx/resolver/saver.go b/internal/engine/netx/resolver/saver.go index 0c034dc..ba11a9b 100644 --- a/internal/engine/netx/resolver/saver.go +++ b/internal/engine/netx/resolver/saver.go @@ -46,28 +46,41 @@ type SaverDNSTransport struct { } // RoundTrip implements RoundTripper.RoundTrip -func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { +func (txp SaverDNSTransport) RoundTrip( + ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { start := time.Now() txp.Saver.Write(trace.Event{ Address: txp.Address(), - DNSQuery: query, + DNSQuery: txp.maybeQueryBytes(query), Name: "dns_round_trip_start", Proto: txp.Network(), Time: start, }) - reply, err := txp.DNSTransport.RoundTrip(ctx, query) + response, err := txp.DNSTransport.RoundTrip(ctx, query) stop := time.Now() txp.Saver.Write(trace.Event{ Address: txp.Address(), - DNSQuery: query, - DNSReply: reply, + DNSQuery: txp.maybeQueryBytes(query), + DNSReply: txp.maybeResponseBytes(response), Duration: stop.Sub(start), Err: err, Name: "dns_round_trip_done", Proto: txp.Network(), Time: stop, }) - return reply, err + return response, err +} + +func (txp SaverDNSTransport) maybeQueryBytes(query model.DNSQuery) []byte { + data, _ := query.Bytes() + return data +} + +func (txp SaverDNSTransport) maybeResponseBytes(response model.DNSResponse) []byte { + if response == nil { + return nil + } + return response.Bytes() } var _ model.Resolver = SaverResolver{} diff --git a/internal/engine/netx/resolver/saver_test.go b/internal/engine/netx/resolver/saver_test.go index d9ad7c9..01f35cb 100644 --- a/internal/engine/netx/resolver/saver_test.go +++ b/internal/engine/netx/resolver/saver_test.go @@ -10,6 +10,8 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" ) func TestSaverResolverFailure(t *testing.T) { @@ -110,12 +112,25 @@ func TestSaverDNSTransportFailure(t *testing.T) { expected := errors.New("no such host") saver := &trace.Saver{} txp := resolver.SaverDNSTransport{ - DNSTransport: resolver.FakeTransport{ - Err: expected, + DNSTransport: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expected + }, + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, }, Saver: saver, } - query := []byte("abc") + rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return rawQuery, nil + }, + } reply, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, expected) { t.Fatal("not the error we expected") @@ -127,7 +142,7 @@ func TestSaverDNSTransportFailure(t *testing.T) { if len(ev) != 2 { t.Fatal("expected number of events") } - if !bytes.Equal(ev[0].DNSQuery, query) { + if !bytes.Equal(ev[0].DNSQuery, rawQuery) { t.Fatal("unexpected DNSQuery") } if ev[0].Name != "dns_round_trip_start" { @@ -136,7 +151,7 @@ func TestSaverDNSTransportFailure(t *testing.T) { if !ev[0].Time.Before(time.Now()) { t.Fatal("the saved time is wrong") } - if !bytes.Equal(ev[1].DNSQuery, query) { + if !bytes.Equal(ev[1].DNSQuery, rawQuery) { t.Fatal("unexpected DNSQuery") } if ev[1].DNSReply != nil { @@ -157,27 +172,45 @@ func TestSaverDNSTransportFailure(t *testing.T) { } func TestSaverDNSTransportSuccess(t *testing.T) { - expected := []byte("def") + expected := []byte{0xef, 0xbe, 0xad, 0xde} saver := &trace.Saver{} + response := &mocks.DNSResponse{ + MockBytes: func() []byte { + return expected + }, + } txp := resolver.SaverDNSTransport{ - DNSTransport: resolver.FakeTransport{ - Data: expected, + DNSTransport: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + return response, nil + }, + MockNetwork: func() string { + return "fake" + }, + MockAddress: func() string { + return "" + }, }, Saver: saver, } - query := []byte("abc") + rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return rawQuery, nil + }, + } reply, err := txp.RoundTrip(context.Background(), query) if err != nil { t.Fatal("we expected nil error here") } - if !bytes.Equal(reply, expected) { + if !bytes.Equal(reply.Bytes(), expected) { t.Fatal("expected another reply here") } ev := saver.Read() if len(ev) != 2 { t.Fatal("expected number of events") } - if !bytes.Equal(ev[0].DNSQuery, query) { + if !bytes.Equal(ev[0].DNSQuery, rawQuery) { t.Fatal("unexpected DNSQuery") } if ev[0].Name != "dns_round_trip_start" { @@ -186,7 +219,7 @@ func TestSaverDNSTransportSuccess(t *testing.T) { if !ev[0].Time.Before(time.Now()) { t.Fatal("the saved time is wrong") } - if !bytes.Equal(ev[1].DNSQuery, query) { + if !bytes.Equal(ev[1].DNSQuery, rawQuery) { t.Fatal("unexpected DNSQuery") } if !bytes.Equal(ev[1].DNSReply, expected) { diff --git a/internal/measurex/dnsx.go b/internal/measurex/dnsx.go index a82ecf3..26bef6e 100644 --- a/internal/measurex/dnsx.go +++ b/internal/measurex/dnsx.go @@ -36,18 +36,31 @@ type DNSRoundTripEvent struct { Reply []byte } -func (txp *dnsxRoundTripperDB) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { +func (txp *dnsxRoundTripperDB) RoundTrip( + ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { started := time.Since(txp.begin).Seconds() - reply, err := txp.DNSTransport.RoundTrip(ctx, query) + response, err := txp.DNSTransport.RoundTrip(ctx, query) finished := time.Since(txp.begin).Seconds() txp.db.InsertIntoDNSRoundTrip(&DNSRoundTripEvent{ Network: txp.DNSTransport.Network(), Address: txp.DNSTransport.Address(), - Query: query, + Query: txp.maybeQueryBytes(query), Started: started, Finished: finished, Failure: NewFailure(err), - Reply: reply, + Reply: txp.maybeResponseBytes(response), }) - return reply, err + return response, err +} + +func (txp *dnsxRoundTripperDB) maybeQueryBytes(query model.DNSQuery) []byte { + data, _ := query.Bytes() + return data +} + +func (txp *dnsxRoundTripperDB) maybeResponseBytes(response model.DNSResponse) []byte { + if response == nil { + return nil + } + return response.Bytes() } diff --git a/internal/model/mocks/dnsdecoder.go b/internal/model/mocks/dnsdecoder.go index 66d3f0e..ab3edd1 100644 --- a/internal/model/mocks/dnsdecoder.go +++ b/internal/model/mocks/dnsdecoder.go @@ -1,36 +1,20 @@ package mocks -import ( - "net" +// +// Mocks for model.DNSDecoder +// - "github.com/miekg/dns" +import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// DNSDecoder allows mocking dnsx.DNSDecoder. +// DNSDecoder allows mocking model.DNSDecoder. type DNSDecoder struct { - MockDecodeLookupHost func(qtype uint16, reply []byte, queryID uint16) ([]string, error) - MockDecodeHTTPS func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) - MockDecodeNS func(reply []byte, queryID uint16) ([]*net.NS, error) - MockDecodeReply func(reply []byte) (*dns.Msg, error) + MockDecodeResponse func(data []byte, query model.DNSQuery) (model.DNSResponse, error) } -// DecodeLookupHost calls MockDecodeLookupHost. -func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte, queryID uint16) ([]string, error) { - return e.MockDecodeLookupHost(qtype, reply, queryID) -} +var _ model.DNSDecoder = &DNSDecoder{} -// DecodeHTTPS calls MockDecodeHTTPS. -func (e *DNSDecoder) DecodeHTTPS(reply []byte, queryID uint16) (*model.HTTPSSvc, error) { - return e.MockDecodeHTTPS(reply, queryID) -} - -// DecodeNS calls MockDecodeNS. -func (e *DNSDecoder) DecodeNS(reply []byte, queryID uint16) ([]*net.NS, error) { - return e.MockDecodeNS(reply, queryID) -} - -// DecodeReply calls MockDecodeReply. -func (e *DNSDecoder) DecodeReply(reply []byte) (*dns.Msg, error) { - return e.MockDecodeReply(reply) +func (e *DNSDecoder) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) { + return e.MockDecodeResponse(data, query) } diff --git a/internal/model/mocks/dnsdecoder_test.go b/internal/model/mocks/dnsdecoder_test.go index 1d52f97..e3a4024 100644 --- a/internal/model/mocks/dnsdecoder_test.go +++ b/internal/model/mocks/dnsdecoder_test.go @@ -2,70 +2,20 @@ package mocks import ( "errors" - "net" "testing" - "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/model" ) func TestDNSDecoder(t *testing.T) { - t.Run("DecodeLookupHost", func(t *testing.T) { + t.Run("DecodeResponse", func(t *testing.T) { expected := errors.New("mocked error") e := &DNSDecoder{ - MockDecodeLookupHost: func(qtype uint16, reply []byte, queryID uint16) ([]string, error) { + MockDecodeResponse: func(reply []byte, query model.DNSQuery) (model.DNSResponse, error) { return nil, expected }, } - out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17), dns.Id()) - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if out != nil { - t.Fatal("unexpected out") - } - }) - - t.Run("DecodeHTTPS", func(t *testing.T) { - expected := errors.New("mocked error") - e := &DNSDecoder{ - MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) { - return nil, expected - }, - } - out, err := e.DecodeHTTPS(make([]byte, 17), dns.Id()) - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if out != nil { - t.Fatal("unexpected out") - } - }) - - t.Run("DecodeNS", func(t *testing.T) { - expected := errors.New("mocked error") - e := &DNSDecoder{ - MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) { - return nil, expected - }, - } - out, err := e.DecodeNS(make([]byte, 17), dns.Id()) - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if out != nil { - t.Fatal("unexpected out") - } - }) - - t.Run("DecodeReply", func(t *testing.T) { - expected := errors.New("mocked error") - e := &DNSDecoder{ - MockDecodeReply: func(reply []byte) (*dns.Msg, error) { - return nil, expected - }, - } - out, err := e.DecodeReply(make([]byte, 17)) + out, err := e.DecodeResponse(make([]byte, 17), &DNSQuery{}) if !errors.Is(err, expected) { t.Fatal("unexpected err", err) } diff --git a/internal/model/mocks/dnsencoder.go b/internal/model/mocks/dnsencoder.go index d59ea19..228a714 100644 --- a/internal/model/mocks/dnsencoder.go +++ b/internal/model/mocks/dnsencoder.go @@ -1,11 +1,19 @@ package mocks -// DNSEncoder allows mocking dnsx.DNSEncoder. +// +// Mocks for model.DNSEncoder. +// + +import "github.com/ooni/probe-cli/v3/internal/model" + +// DNSEncoder allows mocking model.DNSEncoder. type DNSEncoder struct { - MockEncode func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) + MockEncode func(domain string, qtype uint16, padding bool) model.DNSQuery } +var _ model.DNSEncoder = &DNSEncoder{} + // Encode calls MockEncode. -func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { +func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) model.DNSQuery { return e.MockEncode(domain, qtype, padding) } diff --git a/internal/model/mocks/dnsencoder_test.go b/internal/model/mocks/dnsencoder_test.go index 0918f77..92af1bb 100644 --- a/internal/model/mocks/dnsencoder_test.go +++ b/internal/model/mocks/dnsencoder_test.go @@ -5,24 +5,46 @@ import ( "testing" "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/model" ) func TestDNSEncoder(t *testing.T) { t.Run("Encode", func(t *testing.T) { expected := errors.New("mocked error") + queryID := dns.Id() e := &DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return nil, 0, expected + MockEncode: func(domain string, qtype uint16, padding bool) model.DNSQuery { + return &DNSQuery{ + MockDomain: func() string { + return dns.Fqdn(domain) // do what an implementation MUST do + }, + MockType: func() uint16 { + return qtype + }, + MockBytes: func() ([]byte, error) { + return nil, expected + }, + MockID: func() uint16 { + return queryID + }, + } }, } - out, queryID, err := e.Encode("dns.google", dns.TypeA, true) + query := e.Encode("dns.google", dns.TypeA, true) + if query.Domain() != "dns.google." { + t.Fatal("invalid domain") + } + if query.Type() != dns.TypeA { + t.Fatal("invalid type") + } + out, err := query.Bytes() if !errors.Is(err, expected) { t.Fatal("unexpected err", err) } if out != nil { t.Fatal("unexpected out") } - if queryID != 0 { + if query.ID() != queryID { t.Fatal("unexpected queryID") } }) diff --git a/internal/model/mocks/dnsquery.go b/internal/model/mocks/dnsquery.go new file mode 100644 index 0000000..0ea8661 --- /dev/null +++ b/internal/model/mocks/dnsquery.go @@ -0,0 +1,33 @@ +package mocks + +// +// Mocks for model.DNSQuery. +// + +import "github.com/ooni/probe-cli/v3/internal/model" + +// DNSQuery allocks mocking model.DNSQuery. +type DNSQuery struct { + MockDomain func() string + MockType func() uint16 + MockBytes func() ([]byte, error) + MockID func() uint16 +} + +func (q *DNSQuery) Domain() string { + return q.MockDomain() +} + +func (q *DNSQuery) Type() uint16 { + return q.MockType() +} + +func (q *DNSQuery) Bytes() ([]byte, error) { + return q.MockBytes() +} + +func (q *DNSQuery) ID() uint16 { + return q.MockID() +} + +var _ model.DNSQuery = &DNSQuery{} diff --git a/internal/model/mocks/dnsquery_test.go b/internal/model/mocks/dnsquery_test.go new file mode 100644 index 0000000..6b13c13 --- /dev/null +++ b/internal/model/mocks/dnsquery_test.go @@ -0,0 +1,62 @@ +package mocks + +import ( + "bytes" + "testing" + + "github.com/miekg/dns" +) + +func TestDNSQuery(t *testing.T) { + t.Run("Domain", func(t *testing.T) { + expected := "dns.google." + q := &DNSQuery{ + MockDomain: func() string { + return expected + }, + } + if q.Domain() != expected { + t.Fatal("invalid domain") + } + }) + + t.Run("Type", func(t *testing.T) { + expected := dns.TypeAAAA + q := &DNSQuery{ + MockType: func() uint16 { + return expected + }, + } + if q.Type() != expected { + t.Fatal("invalid type") + } + }) + + t.Run("Bytes", func(t *testing.T) { + expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef} + q := &DNSQuery{ + MockBytes: func() ([]byte, error) { + return expected, nil + }, + } + out, err := q.Bytes() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expected, out) { + t.Fatal("invalid bytes") + } + }) + + t.Run("ID", func(t *testing.T) { + expected := dns.Id() + q := &DNSQuery{ + MockID: func() uint16 { + return expected + }, + } + if q.ID() != expected { + t.Fatal("invalid id") + } + }) +} diff --git a/internal/model/mocks/dnsresponse.go b/internal/model/mocks/dnsresponse.go new file mode 100644 index 0000000..7752c42 --- /dev/null +++ b/internal/model/mocks/dnsresponse.go @@ -0,0 +1,47 @@ +package mocks + +// +// Mocks for model.DNSResponse +// + +import ( + "net" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// DNSResponse allows mocking model.DNSResponse. +type DNSResponse struct { + MockQuery func() model.DNSQuery + MockBytes func() []byte + MockRcode func() int + MockDecodeHTTPS func() (*model.HTTPSSvc, error) + MockDecodeLookupHost func() ([]string, error) + MockDecodeNS func() ([]*net.NS, error) +} + +var _ model.DNSResponse = &DNSResponse{} + +func (r *DNSResponse) Query() model.DNSQuery { + return r.MockQuery() +} + +func (r *DNSResponse) Bytes() []byte { + return r.MockBytes() +} + +func (r *DNSResponse) Rcode() int { + return r.MockRcode() +} + +func (r *DNSResponse) DecodeHTTPS() (*model.HTTPSSvc, error) { + return r.MockDecodeHTTPS() +} + +func (r *DNSResponse) DecodeLookupHost() ([]string, error) { + return r.MockDecodeLookupHost() +} + +func (r *DNSResponse) DecodeNS() ([]*net.NS, error) { + return r.MockDecodeNS() +} diff --git a/internal/model/mocks/dnsresponse_test.go b/internal/model/mocks/dnsresponse_test.go new file mode 100644 index 0000000..e362544 --- /dev/null +++ b/internal/model/mocks/dnsresponse_test.go @@ -0,0 +1,105 @@ +package mocks + +import ( + "bytes" + "errors" + "net" + "testing" + + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/model" +) + +func TestDNSResponse(t *testing.T) { + t.Run("Query", func(t *testing.T) { + qid := dns.Id() + query := &DNSQuery{ + MockID: func() uint16 { + return qid + }, + } + resp := &DNSResponse{ + MockQuery: func() model.DNSQuery { + return query + }, + } + out := resp.Query() + if out.ID() != query.ID() { + t.Fatal("invalid query") + } + }) + + t.Run("Bytes", func(t *testing.T) { + expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef} + resp := &DNSResponse{ + MockBytes: func() []byte { + return expected + }, + } + out := resp.Bytes() + if !bytes.Equal(expected, out) { + t.Fatal("invalid bytes") + } + }) + + t.Run("Rcode", func(t *testing.T) { + expected := dns.RcodeBadAlg + resp := &DNSResponse{ + MockRcode: func() int { + return expected + }, + } + out := resp.Rcode() + if out != expected { + t.Fatal("invalid rcode") + } + }) + + t.Run("DecodeLookupHost", func(t *testing.T) { + expected := errors.New("mocked error") + r := &DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + return nil, expected + }, + } + out, err := r.DecodeLookupHost() + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if out != nil { + t.Fatal("unexpected out") + } + }) + + t.Run("DecodeHTTPS", func(t *testing.T) { + expected := errors.New("mocked error") + r := &DNSResponse{ + MockDecodeHTTPS: func() (*model.HTTPSSvc, error) { + return nil, expected + }, + } + out, err := r.DecodeHTTPS() + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if out != nil { + t.Fatal("unexpected out") + } + }) + + t.Run("DecodeNS", func(t *testing.T) { + expected := errors.New("mocked error") + r := &DNSResponse{ + MockDecodeNS: func() ([]*net.NS, error) { + return nil, expected + }, + } + out, err := r.DecodeNS() + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if out != nil { + t.Fatal("unexpected out") + } + }) +} diff --git a/internal/model/mocks/dnstransport.go b/internal/model/mocks/dnstransport.go index ed32f09..006367a 100644 --- a/internal/model/mocks/dnstransport.go +++ b/internal/model/mocks/dnstransport.go @@ -1,10 +1,14 @@ package mocks -import "context" +import ( + "context" + + "github.com/ooni/probe-cli/v3/internal/model" +) // DNSTransport allows mocking dnsx.DNSTransport. type DNSTransport struct { - MockRoundTrip func(ctx context.Context, query []byte) ([]byte, error) + MockRoundTrip func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) MockRequiresPadding func() bool @@ -15,8 +19,10 @@ type DNSTransport struct { MockCloseIdleConnections func() } +var _ model.DNSTransport = &DNSTransport{} + // RoundTrip calls MockRoundTrip. -func (txp *DNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { +func (txp *DNSTransport) RoundTrip(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return txp.MockRoundTrip(ctx, query) } diff --git a/internal/model/mocks/dnstransport_test.go b/internal/model/mocks/dnstransport_test.go index ba42a04..4d26784 100644 --- a/internal/model/mocks/dnstransport_test.go +++ b/internal/model/mocks/dnstransport_test.go @@ -6,17 +6,18 @@ import ( "testing" "github.com/ooni/probe-cli/v3/internal/atomicx" + "github.com/ooni/probe-cli/v3/internal/model" ) func TestDNSTransport(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { expected := errors.New("mocked error") txp := &DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) { + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return nil, expected }, } - resp, err := txp.RoundTrip(context.Background(), make([]byte, 16)) + resp, err := txp.RoundTrip(context.Background(), &DNSQuery{}) if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) } diff --git a/internal/model/netx.go b/internal/model/netx.go index a25575b..4fc7cbf 100644 --- a/internal/model/netx.go +++ b/internal/model/netx.go @@ -1,5 +1,9 @@ package model +// +// Network extensions +// + import ( "context" "crypto/tls" @@ -9,74 +13,81 @@ import ( "time" "github.com/lucas-clemente/quic-go" - "github.com/miekg/dns" ) -// -// Network extensions -// +// DNSResponse is a parsed DNS response ready for further processing. +type DNSResponse interface { + // Query is the query associated with this response. + Query() DNSQuery -// The DNSDecoder decodes DNS replies. + // Bytes returns the bytes from which we parsed the query. + Bytes() []byte + + // Rcode returns the response's Rcode. + Rcode() int + + // DecodeHTTPS returns information gathered from all the HTTPS + // records found inside of this response. + DecodeHTTPS() (*HTTPSSvc, error) + + // DecodeLookupHost returns the addresses in the response matching + // the original query type (one of A and AAAA). + DecodeLookupHost() ([]string, error) + + // DecodeNS returns all the NS entries in this response. + DecodeNS() ([]*net.NS, error) +} + +// The DNSDecoder decodes DNS responses. type DNSDecoder interface { - // DecodeLookupHost decodes an A or AAAA reply. - // - // Arguments: - // - // - qtype is the query type (e.g., dns.TypeAAAA) - // - // - data contains the reply bytes read from a DNSTransport - // - // - queryID is the original query ID - // - // Returns: - // - // - on success, a list of IP addrs inside the reply and a nil error - // - // - on failure, a nil list and an error. - // - // Note that this function will return an error if there is no - // IP address inside of the reply. - DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error) - - // DecodeHTTPS is like DecodeLookupHost but decodes an HTTPS reply. - // - // The argument is the reply as read by the DNSTransport. - // - // On success, this function returns an HTTPSSvc structure and - // a nil error. On failure, the HTTPSSvc pointer is nil and - // the error points to the error that occurred. - // - // This function will return an error if the HTTPS reply does not - // contain at least a valid ALPN entry. It will not return - // an error, though, when there are no IPv4/IPv6 hints in the reply. - DecodeHTTPS(data []byte, queryID uint16) (*HTTPSSvc, error) - - // DecodeNS is like DecodeHTTPS but for NS queries. - DecodeNS(data []byte, queryID uint16) ([]*net.NS, error) - - // DecodeReply decodes a DNS reply message. + // DecodeResponse decodes a DNS response message. // // Arguments: // // - data is the raw reply // // This function fails if we cannot parse data as a DNS - // message or the message is not a reply. + // message or the message is not a response. // - // If you use this function, remember that: + // Regarding the returned response, remember that the Rcode + // MAY still be nonzero (this method does not treat a nonzero + // Rcode as an error when parsing the response). + DecodeResponse(data []byte, query DNSQuery) (DNSResponse, error) +} + +// DNSQuery is an encoded DNS query ready to be sent using a DNSTransport. +type DNSQuery interface { + // Domain is the domain we're querying for. + Domain() string + + // Type is the query type. + Type() uint16 + + // Bytes serializes the query to bytes. This function may fail if we're not + // able to correctly encode the domain into a query message. // - // 1. the Rcode MAY be nonzero; - // - // 2. the replyID MAY NOT match the original query ID. - // - // That is, this is a very basic parsing method. - DecodeReply(data []byte) (*dns.Msg, error) + // The value returned by this function WILL be memoized after the first call, + // so you SHOULD create a new DNSQuery if you need to retry a query. + Bytes() ([]byte, error) + + // ID returns the query ID. + ID() uint16 } // The DNSEncoder encodes DNS queries to bytes type DNSEncoder interface { // Encode transforms its arguments into a serialized DNS query. // + // Every time you call Encode, you get a new DNSQuery value + // using a query ID selected at random. + // + // Serialization to bytes is lazy to acommodate DNS transports that + // do not need to serialize and send bytes, e.g., getaddrinfo. + // + // You serialize to bytes using DNSQuery.Bytes. This operation MAY fail + // if the domain name cannot be packed into a DNS message (e.g., it is + // too long to fit into the message). + // // Arguments: // // - domain is the domain for the query (e.g., x.org); @@ -85,16 +96,15 @@ type DNSEncoder interface { // // - padding is whether to add padding to the query. // - // On success, this function returns a valid byte array, the queryID, and - // a nil error. On failure, we have a non-nil error, a nil arrary and a zero - // query ID. - Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) + // This function will transform the domain into an FQDN is it's not + // already expressed in the FQDN format. + Encode(domain string, qtype uint16, padding bool) DNSQuery } // DNSTransport represents an abstract DNS transport. type DNSTransport interface { // RoundTrip sends a DNS query and receives the reply. - RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) + RoundTrip(ctx context.Context, query DNSQuery) (DNSResponse, error) // RequiresPadding returns whether this transport needs padding. RequiresPadding() bool diff --git a/internal/netxlite/dnsdecoder.go b/internal/netxlite/dnsdecoder.go index cf6590d..ef7b46e 100644 --- a/internal/netxlite/dnsdecoder.go +++ b/internal/netxlite/dnsdecoder.go @@ -15,14 +15,16 @@ import ( // DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder. type DNSDecoderMiekg struct{} -// ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID. -var ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID) +var ( + // ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID. + ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID) -// ErrDNSIsQuery indicates that we were passed a DNS query. -var ErrDNSIsQuery = errors.New("ooresolver: expected response but received query") + // ErrDNSIsQuery indicates that we were passed a DNS query. + ErrDNSIsQuery = errors.New("ooresolver: expected response but received query") +) -// DecodeReply implements model.DNSDecoder.DecodeReply -func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) { +// DecodeResponse implements model.DNSDecoder.DecodeResponse. +func (d *DNSDecoderMiekg) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) { reply := &dns.Msg{} if err := reply.Unpack(data); err != nil { return nil, err @@ -30,46 +32,64 @@ func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) { if !reply.Response { return nil, ErrDNSIsQuery } - return reply, nil -} - -// decodeSuccessfulReply decodes the bytes in data as a successful reply for the -// given queryID. This function returns an error if: -// -// 1. we cannot decode data -// -// 2. the decoded message is not a reply -// -// 3. the query ID does not match -// -// 4. the Rcode is not zero. -func (d *DNSDecoderMiekg) decodeSuccessfulReply(data []byte, queryID uint16) (*dns.Msg, error) { - reply, err := d.DecodeReply(data) - if err != nil { - return nil, err - } - if reply.Id != queryID { + if reply.Id != query.ID() { return nil, ErrDNSReplyWithWrongQueryID } + resp := &dnsResponse{ + bytes: data, + msg: reply, + query: query, + } + return resp, nil +} + +// dnsResponse implements model.DNSResponse. +type dnsResponse struct { + // bytes contains the response bytes. + bytes []byte + + // msg contains the message. + msg *dns.Msg + + // query is the original query. + query model.DNSQuery +} + +// Query implements model.DNSResponse.Query. +func (r *dnsResponse) Query() model.DNSQuery { + return r.query +} + +// Bytes implements model.DNSResponse.Bytes. +func (r *dnsResponse) Bytes() []byte { + return r.bytes +} + +// Rcode implements model.DNSResponse.Rcode. +func (r *dnsResponse) Rcode() int { + return r.msg.Rcode +} + +func (r *dnsResponse) rcodeToError() error { // TODO(bassosimone): map more errors to net.DNSError names // TODO(bassosimone): add support for lame referral. - switch reply.Rcode { + switch r.msg.Rcode { case dns.RcodeSuccess: - return reply, nil + return nil case dns.RcodeNameError: - return nil, ErrOODNSNoSuchHost + return ErrOODNSNoSuchHost case dns.RcodeRefused: - return nil, ErrOODNSRefused + return ErrOODNSRefused case dns.RcodeServerFailure: - return nil, ErrOODNSServfail + return ErrOODNSServfail default: - return nil, ErrOODNSMisbehaving + return ErrOODNSMisbehaving } } -func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPSSvc, error) { - reply, err := d.decodeSuccessfulReply(data, queryID) - if err != nil { +// DecodeHTTPS implements model.DNSResponse.DecodeHTTPS. +func (r *dnsResponse) DecodeHTTPS() (*model.HTTPSSvc, error) { + if err := r.rcodeToError(); err != nil { return nil, err } out := &model.HTTPSSvc{ @@ -77,7 +97,7 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS IPv4: []string{}, // ensure it's not nil IPv6: []string{}, // ensure it's not nil } - for _, answer := range reply.Answer { + for _, answer := range r.msg.Answer { switch avalue := answer.(type) { case *dns.HTTPS: for _, v := range avalue.Value { @@ -102,14 +122,14 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS return out, nil } -func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error) { - reply, err := d.decodeSuccessfulReply(data, queryID) - if err != nil { +// DecodeLookupHost implements model.DNSResponse.DecodeLookupHost. +func (r *dnsResponse) DecodeLookupHost() ([]string, error) { + if err := r.rcodeToError(); err != nil { return nil, err } var addrs []string - for _, answer := range reply.Answer { - switch qtype { + for _, answer := range r.msg.Answer { + switch r.Query().Type() { case dns.TypeA: if rra, ok := answer.(*dns.A); ok { ip := rra.A @@ -128,13 +148,13 @@ func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID ui return addrs, nil } -func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, error) { - reply, err := d.decodeSuccessfulReply(data, queryID) - if err != nil { +// DecodeNS implements model.DNSResponse.DecodeNS. +func (r *dnsResponse) DecodeNS() ([]*net.NS, error) { + if err := r.rcodeToError(); err != nil { return nil, err } out := []*net.NS{} - for _, answer := range reply.Answer { + for _, answer := range r.msg.Answer { switch avalue := answer.(type) { case *dns.NS: out = append(out, &net.NS{Host: avalue.Ns}) @@ -147,3 +167,4 @@ func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, erro } var _ model.DNSDecoder = &DNSDecoderMiekg{} +var _ model.DNSResponse = &dnsResponse{} diff --git a/internal/netxlite/dnsdecoder_test.go b/internal/netxlite/dnsdecoder_test.go index 67fbd9d..839f203 100644 --- a/internal/netxlite/dnsdecoder_test.go +++ b/internal/netxlite/dnsdecoder_test.go @@ -1,26 +1,27 @@ package netxlite import ( + "bytes" "errors" "net" - "strings" "testing" "github.com/google/go-cmp/cmp" "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/runtimex" ) -func TestDNSDecoder(t *testing.T) { - t.Run("LookupHost", func(t *testing.T) { +func TestDNSDecoderMiekg(t *testing.T) { + t.Run("DecodeResponse", func(t *testing.T) { t.Run("UnpackError", func(t *testing.T) { d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost(dns.TypeA, nil, 0) + resp, err := d.DecodeResponse(nil, &mocks.DNSQuery{}) if err == nil || err.Error() != "dns: overflow unpacking uint16" { t.Fatal("unexpected error", err) } - if data != nil { - t.Fatal("expected nil data here") + if resp != nil { + t.Fatal("expected nil resp here") } }) @@ -28,12 +29,12 @@ func TestDNSDecoder(t *testing.T) { d := &DNSDecoderMiekg{} queryID := dns.Id() rawQuery := dnsGenQuery(dns.TypeA, queryID) - addrs, err := d.DecodeLookupHost(dns.TypeA, rawQuery, queryID) + resp, err := d.DecodeResponse(rawQuery, &mocks.DNSQuery{}) if !errors.Is(err, ErrDNSIsQuery) { t.Fatal("unexpected err", err) } - if len(addrs) > 0 { - t.Fatal("expected no addrs") + if resp != nil { + t.Fatal("expected nil resp here") } }) @@ -44,297 +45,447 @@ func TestDNSDecoder(t *testing.T) { unrelatedID = 14 ) reply := dnsGenLookupHostReplySuccess(dnsGenQuery(dns.TypeA, queryID)) - data, err := d.DecodeLookupHost(dns.TypeA, reply, unrelatedID) + resp, err := d.DecodeResponse(reply, &mocks.DNSQuery{ + MockID: func() uint16 { + return unrelatedID + }, + }) if !errors.Is(err, ErrDNSReplyWithWrongQueryID) { t.Fatal("unexpected error", err) } - if data != nil { - t.Fatal("expected nil data here") + if resp != nil { + t.Fatal("expected nil resp here") } }) - t.Run("NXDOMAIN", func(t *testing.T) { + t.Run("dnsResponse.Query", func(t *testing.T) { d := &DNSDecoderMiekg{} queryID := dns.Id() - data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError( - dnsGenQuery(dns.TypeA, queryID), dns.RcodeNameError), queryID) - if err == nil || !strings.HasSuffix(err.Error(), "no such host") { - t.Fatal("not the error we expected", err) + rawQuery := dnsGenQuery(dns.TypeA, queryID) + rawResponse := dnsGenLookupHostReplySuccess(rawQuery) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, } - if data != nil { - t.Fatal("expected nil data here") - } - }) - - t.Run("Refused", func(t *testing.T) { - d := &DNSDecoderMiekg{} - queryID := dns.Id() - data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError( - dnsGenQuery(dns.TypeA, queryID), dns.RcodeRefused), queryID) - if !errors.Is(err, ErrOODNSRefused) { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } - }) - - t.Run("Servfail", func(t *testing.T) { - d := &DNSDecoderMiekg{} - queryID := dns.Id() - data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError( - dnsGenQuery(dns.TypeA, queryID), dns.RcodeServerFailure), queryID) - if !errors.Is(err, ErrOODNSServfail) { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } - }) - - t.Run("no address", func(t *testing.T) { - d := &DNSDecoderMiekg{} - queryID := dns.Id() - data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess( - dnsGenQuery(dns.TypeA, queryID)), queryID) - if !errors.Is(err, ErrOODNSNoAnswer) { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } - }) - - t.Run("decode A", func(t *testing.T) { - d := &DNSDecoderMiekg{} - queryID := dns.Id() - data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess( - dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.8.8"), queryID) + resp, err := d.DecodeResponse(rawResponse, query) if err != nil { t.Fatal(err) } - if len(data) != 2 { - t.Fatal("expected two entries here") - } - if data[0] != "1.1.1.1" { - t.Fatal("invalid first IPv4 entry") - } - if data[1] != "8.8.8.8" { - t.Fatal("invalid second IPv4 entry") + if resp.Query().ID() != query.ID() { + t.Fatal("invalid query") } }) - t.Run("decode AAAA", func(t *testing.T) { + t.Run("dnsResponse.Bytes", func(t *testing.T) { d := &DNSDecoderMiekg{} queryID := dns.Id() - data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess( - dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID) + rawQuery := dnsGenQuery(dns.TypeA, queryID) + rawResponse := dnsGenLookupHostReplySuccess(rawQuery) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) if err != nil { t.Fatal(err) } - if len(data) != 2 { - t.Fatal("expected two entries here") - } - if data[0] != "::1" { - t.Fatal("invalid first IPv6 entry") - } - if data[1] != "fe80::1" { - t.Fatal("invalid second IPv6 entry") + if !bytes.Equal(rawResponse, resp.Bytes()) { + t.Fatal("invalid bytes") } }) - t.Run("unexpected A reply", func(t *testing.T) { + t.Run("dnsResponse.Rcode", func(t *testing.T) { d := &DNSDecoderMiekg{} queryID := dns.Id() - data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess( - dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID) - if !errors.Is(err, ErrOODNSNoAnswer) { - t.Fatal("not the error we expected", err) + rawQuery := dnsGenQuery(dns.TypeA, queryID) + rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, } - if data != nil { - t.Fatal("expected nil data here") - } - }) - - t.Run("unexpected AAAA reply", func(t *testing.T) { - d := &DNSDecoderMiekg{} - queryID := dns.Id() - data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess( - dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.4.4"), queryID) - if !errors.Is(err, ErrOODNSNoAnswer) { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } - }) - }) - - t.Run("decodeSuccessfulReply", func(t *testing.T) { - d := &DNSDecoderMiekg{} - msg := &dns.Msg{} - msg.Rcode = dns.RcodeFormatError // an rcode we don't handle - msg.Response = true - data, err := msg.Pack() - if err != nil { - t.Fatal(err) - } - reply, err := d.decodeSuccessfulReply(data, 0) - if !errors.Is(err, ErrOODNSMisbehaving) { // catch all error - t.Fatal("not the error we expected", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } - }) - - t.Run("DecodeHTTPS", func(t *testing.T) { - t.Run("with nil data", func(t *testing.T) { - d := &DNSDecoderMiekg{} - reply, err := d.DecodeHTTPS(nil, 0) - if err == nil || err.Error() != "dns: overflow unpacking uint16" { - t.Fatal("not the error we expected", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } - }) - - t.Run("with bytes containing a query", func(t *testing.T) { - d := &DNSDecoderMiekg{} - queryID := dns.Id() - rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID) - https, err := d.DecodeHTTPS(rawQuery, queryID) - if !errors.Is(err, ErrDNSIsQuery) { - t.Fatal("unexpected err", err) - } - if https != nil { - t.Fatal("expected nil https") - } - }) - - t.Run("wrong query ID", func(t *testing.T) { - d := &DNSDecoderMiekg{} - const ( - queryID = 17 - unrelatedID = 14 - ) - reply := dnsGenHTTPSReplySuccess(dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil) - data, err := d.DecodeHTTPS(reply, unrelatedID) - if !errors.Is(err, ErrDNSReplyWithWrongQueryID) { - t.Fatal("unexpected error", err) - } - if data != nil { - t.Fatal("expected nil data here") - } - }) - - t.Run("with empty answer", func(t *testing.T) { - queryID := dns.Id() - data := dnsGenHTTPSReplySuccess( - dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil) - d := &DNSDecoderMiekg{} - reply, err := d.DecodeHTTPS(data, queryID) - if !errors.Is(err, ErrOODNSNoAnswer) { - t.Fatal("unexpected err", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } - }) - - t.Run("with full answer", func(t *testing.T) { - queryID := dns.Id() - alpn := []string{"h3"} - v4 := []string{"1.1.1.1"} - v6 := []string{"::1"} - data := dnsGenHTTPSReplySuccess( - dnsGenQuery(dns.TypeHTTPS, queryID), alpn, v4, v6) - d := &DNSDecoderMiekg{} - reply, err := d.DecodeHTTPS(data, queryID) + resp, err := d.DecodeResponse(rawResponse, query) if err != nil { t.Fatal(err) } - if diff := cmp.Diff(alpn, reply.ALPN); diff != "" { - t.Fatal(diff) - } - if diff := cmp.Diff(v4, reply.IPv4); diff != "" { - t.Fatal(diff) - } - if diff := cmp.Diff(v6, reply.IPv6); diff != "" { - t.Fatal(diff) - } - }) - }) - - t.Run("DecodeNS", func(t *testing.T) { - t.Run("with nil data", func(t *testing.T) { - d := &DNSDecoderMiekg{} - reply, err := d.DecodeNS(nil, 0) - if err == nil || err.Error() != "dns: overflow unpacking uint16" { - t.Fatal("not the error we expected", err) - } - if reply != nil { - t.Fatal("expected nil reply") + if resp.Rcode() != dns.RcodeRefused { + t.Fatal("invalid rcode") } }) - t.Run("with bytes containing a query", func(t *testing.T) { - d := &DNSDecoderMiekg{} - queryID := dns.Id() - rawQuery := dnsGenQuery(dns.TypeNS, queryID) - ns, err := d.DecodeNS(rawQuery, queryID) - if !errors.Is(err, ErrDNSIsQuery) { - t.Fatal("unexpected err", err) - } - if len(ns) > 0 { - t.Fatal("expected no result") + t.Run("dnsResponse.rcodeToError", func(t *testing.T) { + // Here we want to ensure we map all the errors we recognize + // correctly and we also map unrecognized errors correctly + var inputsOutputs = []struct { + name string + rcode int + err error + }{{ + name: "when rcode is zero", + rcode: 0, + err: nil, + }, { + name: "NXDOMAIN", + rcode: dns.RcodeNameError, + err: ErrOODNSNoSuchHost, + }, { + name: "refused", + rcode: dns.RcodeRefused, + err: ErrOODNSRefused, + }, { + name: "servfail", + rcode: dns.RcodeServerFailure, + err: ErrOODNSServfail, + }, { + name: "anything else", + rcode: dns.RcodeFormatError, + err: ErrOODNSMisbehaving, + }} + for _, io := range inputsOutputs { + t.Run(io.name, func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID) + rawResponse := dnsGenReplyWithError(rawQuery, io.rcode) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + // The following cast should always work in this configuration + err = resp.(*dnsResponse).rcodeToError() + if !errors.Is(err, io.err) { + t.Fatal("unexpected err", err) + } + }) } }) - t.Run("wrong query ID", func(t *testing.T) { - d := &DNSDecoderMiekg{} - const ( - queryID = 17 - unrelatedID = 14 - ) - reply := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID)) - data, err := d.DecodeNS(reply, unrelatedID) - if !errors.Is(err, ErrDNSReplyWithWrongQueryID) { - t.Fatal("unexpected error", err) - } - if data != nil { - t.Fatal("expected nil data here") - } + t.Run("dnsResponse.DecodeHTTPS", func(t *testing.T) { + t.Run("with failure", func(t *testing.T) { + // Ensure that we're not trying to decode if rcode != 0 + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID) + rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + https, err := resp.DecodeHTTPS() + if !errors.Is(err, ErrOODNSRefused) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("expected nil https result") + } + }) + + t.Run("with empty answer", func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID) + rawResponse := dnsGenHTTPSReplySuccess(rawQuery, nil, nil, nil) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + https, err := resp.DecodeHTTPS() + if !errors.Is(err, ErrOODNSNoAnswer) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("expected nil https results") + } + }) + + t.Run("with full answer", func(t *testing.T) { + alpn := []string{"h3"} + v4 := []string{"1.1.1.1"} + v6 := []string{"::1"} + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID) + rawResponse := dnsGenHTTPSReplySuccess(rawQuery, alpn, v4, v6) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + reply, err := resp.DecodeHTTPS() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(alpn, reply.ALPN); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(v4, reply.IPv4); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(v6, reply.IPv6); diff != "" { + t.Fatal(diff) + } + }) }) - t.Run("with empty answer", func(t *testing.T) { - queryID := dns.Id() - data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID)) - d := &DNSDecoderMiekg{} - reply, err := d.DecodeNS(data, queryID) - if !errors.Is(err, ErrOODNSNoAnswer) { - t.Fatal("unexpected err", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } + t.Run("dnsResponse.DecodeNS", func(t *testing.T) { + t.Run("with failure", func(t *testing.T) { + // Ensure that we're not trying to decode if rcode != 0 + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeNS, queryID) + rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + ns, err := resp.DecodeNS() + if !errors.Is(err, ErrOODNSRefused) { + t.Fatal("unexpected err", err) + } + if len(ns) > 0 { + t.Fatal("expected empty ns result") + } + }) + + t.Run("with empty answer", func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeNS, queryID) + rawResponse := dnsGenNSReplySuccess(rawQuery) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + ns, err := resp.DecodeNS() + if !errors.Is(err, ErrOODNSNoAnswer) { + t.Fatal("unexpected err", err) + } + if len(ns) > 0 { + t.Fatal("expected empty ns results") + } + }) + + t.Run("with full answer", func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeNS, queryID) + rawResponse := dnsGenNSReplySuccess(rawQuery, "ns1.zdns.google.") + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + ns, err := resp.DecodeNS() + if err != nil { + t.Fatal(err) + } + if len(ns) != 1 { + t.Fatal("unexpected ns length") + } + if ns[0].Host != "ns1.zdns.google." { + t.Fatal("unexpected host") + } + }) }) - t.Run("with full answer", func(t *testing.T) { - queryID := dns.Id() - data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID), "ns1.zdns.google.") - d := &DNSDecoderMiekg{} - reply, err := d.DecodeNS(data, queryID) - if err != nil { - t.Fatal(err) - } - if len(reply) != 1 { - t.Fatal("unexpected reply length") - } - if reply[0].Host != "ns1.zdns.google." { - t.Fatal("unexpected reply host") - } + t.Run("dnsResponse.LookupHost", func(t *testing.T) { + t.Run("with failure", func(t *testing.T) { + // Ensure that we're not trying to decode if rcode != 0 + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeA, queryID) + rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if !errors.Is(err, ErrOODNSRefused) { + t.Fatal("unexpected err", err) + } + if len(addrs) > 0 { + t.Fatal("expected empty addrs result") + } + }) + + t.Run("with empty answer", func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeA, queryID) + rawResponse := dnsGenLookupHostReplySuccess(rawQuery) + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if !errors.Is(err, ErrOODNSNoAnswer) { + t.Fatal("unexpected err", err) + } + if len(addrs) > 0 { + t.Fatal("expected empty ns results") + } + }) + + t.Run("decode A", func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeA, queryID) + rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8") + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + MockType: func() uint16 { + return dns.TypeA + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if err != nil { + t.Fatal(err) + } + if len(addrs) != 2 { + t.Fatal("expected two entries here") + } + if addrs[0] != "1.1.1.1" { + t.Fatal("invalid first IPv4 entry") + } + if addrs[1] != "8.8.8.8" { + t.Fatal("invalid second IPv4 entry") + } + }) + + t.Run("decode AAAA", func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeAAAA, queryID) + rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1") + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + MockType: func() uint16 { + return dns.TypeAAAA + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if err != nil { + t.Fatal(err) + } + if len(addrs) != 2 { + t.Fatal("expected two entries here") + } + if addrs[0] != "::1" { + t.Fatal("invalid first IPv6 entry") + } + if addrs[1] != "fe80::1" { + t.Fatal("invalid second IPv6 entry") + } + }) + + t.Run("unexpected A reply to AAAA query", func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeAAAA, queryID) + rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8") + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + MockType: func() uint16 { + return dns.TypeAAAA + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if !errors.Is(err, ErrOODNSNoAnswer) { + t.Fatal("not the error we expected", err) + } + if len(addrs) > 0 { + t.Fatal("expected no addrs here") + } + }) + + t.Run("unexpected AAAA reply to A query", func(t *testing.T) { + d := &DNSDecoderMiekg{} + queryID := dns.Id() + rawQuery := dnsGenQuery(dns.TypeA, queryID) + rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1") + query := &mocks.DNSQuery{ + MockID: func() uint16 { + return queryID + }, + MockType: func() uint16 { + return dns.TypeA + }, + } + resp, err := d.DecodeResponse(rawResponse, query) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if !errors.Is(err, ErrOODNSNoAnswer) { + t.Fatal("not the error we expected", err) + } + if len(addrs) > 0 { + t.Fatal("expected no addrs here") + } + }) }) }) } @@ -371,8 +522,8 @@ func dnsGenReplyWithError(rawQuery []byte, code int) []byte { return data } -// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given -// qtype (e.g., dns.TypeA) containing the given ips... in the answer. +// dnsGenLookupHostReplySuccess generates a successful DNS reply containing the given ips... +// in the answers where each answer's type depends on the IP's type (A/AAAA). func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte { query := new(dns.Msg) err := query.Unpack(rawQuery) @@ -388,28 +539,22 @@ func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte { reply.MsgHdr.RecursionAvailable = true reply.SetReply(query) for _, ip := range ips { - switch question.Qtype { - case dns.TypeA: - if isIPv6(ip) { - continue - } + switch isIPv6(ip) { + case false: reply.Answer = append(reply.Answer, &dns.A{ Hdr: dns.RR_Header{ - Name: dns.Fqdn("x.org"), - Rrtype: question.Qtype, + Name: question.Name, + Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0, }, A: net.ParseIP(ip), }) - case dns.TypeAAAA: - if !isIPv6(ip) { - continue - } + case true: reply.Answer = append(reply.Answer, &dns.AAAA{ Hdr: dns.RR_Header{ - Name: dns.Fqdn("x.org"), - Rrtype: question.Qtype, + Name: question.Name, + Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0, }, diff --git a/internal/netxlite/dnsencoder.go b/internal/netxlite/dnsencoder.go index 91e3ba1..9b9e8cb 100644 --- a/internal/netxlite/dnsencoder.go +++ b/internal/netxlite/dnsencoder.go @@ -5,7 +5,10 @@ package netxlite // import ( + "sync" + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -23,18 +26,82 @@ const ( dnsDNSSECEnabled = true ) -func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { +// Encoder implements model.DNSEncoder.Encode. +func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) model.DNSQuery { + return &dnsQuery{ + bytesCalls: &atomicx.Int64{}, + domain: domain, + kind: qtype, + id: dns.Id(), + memoizedBytes: []byte{}, + mu: sync.Mutex{}, + padding: padding, + } +} + +// dnsQuery implements model.DNSQuery. +type dnsQuery struct { + // bytesCalls counts the calls to the bytes() method + bytesCalls *atomicx.Int64 + + // domain is the domain. + domain string + + // kind is the query type. + kind uint16 + + // id is the query ID. + id uint16 + + // memoizedBytes contains the query encoded as bytes. We only fill + // this field the first time the Bytes method is called. + memoizedBytes []byte + + // mu provides mutual exclusion. + mu sync.Mutex + + // padding indicates whether we need padding. + padding bool +} + +// Domain implements model.DNSQuery.Domain. +func (q *dnsQuery) Domain() string { + return q.domain +} + +// Type implements model.DNSQuery.Type. +func (q *dnsQuery) Type() uint16 { + return q.kind +} + +// Bytes implements model.DNSQuery.Bytes. +func (q *dnsQuery) Bytes() ([]byte, error) { + defer q.mu.Unlock() + q.mu.Lock() + if len(q.memoizedBytes) <= 0 { + q.bytesCalls.Add(1) // for testing + data, err := q.bytes() + if err != nil { + return nil, err + } + q.memoizedBytes = data + } + return q.memoizedBytes, nil +} + +// bytes is the unmemoized implementation of Bytes +func (q *dnsQuery) bytes() ([]byte, error) { question := dns.Question{ - Name: dns.Fqdn(domain), - Qtype: qtype, + Name: dns.Fqdn(q.domain), + Qtype: q.kind, Qclass: dns.ClassINET, } query := new(dns.Msg) - query.Id = dns.Id() + query.Id = q.id query.RecursionDesired = true query.Question = make([]dns.Question, 1) query.Question[0] = question - if padding { + if q.padding { query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled) // Clients SHOULD pad queries to the closest multiple of // 128 octets RFC8467#section-4.1. We inflate the query @@ -47,8 +114,13 @@ func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]b opt.Padding = make([]byte, remainder) query.IsEdns0().Option = append(query.IsEdns0().Option, opt) } - data, err := query.Pack() - return data, query.Id, err + return query.Pack() +} + +// ID implements model.DNSQuery.ID +func (q *dnsQuery) ID() uint16 { + return q.id } var _ model.DNSEncoder = &DNSEncoderMiekg{} +var _ model.DNSQuery = &dnsQuery{} diff --git a/internal/netxlite/dnsencoder_test.go b/internal/netxlite/dnsencoder_test.go index 6f85809..0b24ce9 100644 --- a/internal/netxlite/dnsencoder_test.go +++ b/internal/netxlite/dnsencoder_test.go @@ -1,29 +1,103 @@ package netxlite import ( + "bytes" + "encoding/binary" "strings" "testing" "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/randx" + "github.com/ooni/probe-cli/v3/internal/runtimex" ) -func TestDNSEncoder(t *testing.T) { +func TestDNSEncoderMiekg(t *testing.T) { + t.Run("we can fail to encode a domain name to bytes", func(t *testing.T) { + e := &DNSEncoderMiekg{} + domain := randx.LettersUppercase(512) + query := e.Encode(domain, dns.TypeA, false) + data, err := query.Bytes() + if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") { + t.Fatal("unexpected err", err) + } + if data != nil { + t.Fatal("expected nil data here") + } + }) + + t.Run("calls to bytes are memoized", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + e := &DNSEncoderMiekg{} + query := e.Encode("x.org", dns.TypeA, false) + checkResult := func(data []byte, err error) { + if err != nil { + t.Fatal("unexpected err", err) + } + dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID()) + } + const repeat = 3 + for idx := 0; idx < repeat; idx++ { + checkResult(query.Bytes()) + } + // The following cast will always work in this configuration + if query.(*dnsQuery).bytesCalls.Load() != 1 { + t.Fatal("invalid number of calls") + } + }) + + t.Run("on failure", func(t *testing.T) { + e := &DNSEncoderMiekg{} + domain := randx.LettersUppercase(512) + query := e.Encode(domain, dns.TypeA, false) + checkResult := func(data []byte, err error) { + if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") { + t.Fatal("unexpected err", err) + } + if data != nil { + t.Fatal("expected nil data here") + } + } + const repeat = 3 + for idx := 0; idx < repeat; idx++ { + checkResult(query.Bytes()) + } + // The following cast will always work in this configuration + if query.(*dnsQuery).bytesCalls.Load() != repeat { + t.Fatal("invalid number of calls") + } + }) + }) + t.Run("encode A", func(t *testing.T) { e := &DNSEncoderMiekg{} - data, _, err := e.Encode("x.org", dns.TypeA, false) + query := e.Encode("x.org", dns.TypeA, false) + if query.Domain() != "x.org" { + t.Fatal("invalid domain") + } + if query.Type() != dns.TypeA { + t.Fatal("invalid type") + } + data, err := query.Bytes() if err != nil { t.Fatal(err) } - dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA)) + dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID()) }) t.Run("encode AAAA", func(t *testing.T) { e := &DNSEncoderMiekg{} - data, _, err := e.Encode("x.org", dns.TypeAAAA, false) + query := e.Encode("x.org", dns.TypeAAAA, false) + if query.Domain() != "x.org" { + t.Fatal("invalid domain") + } + if query.Type() != dns.TypeAAAA { + t.Fatal("invalid type") + } + data, err := query.Bytes() if err != nil { t.Fatal(err) } - dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA)) + dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID()) }) t.Run("encode padding", func(t *testing.T) { @@ -31,7 +105,7 @@ func TestDNSEncoder(t *testing.T) { // array of values we obtain the right query size. getquerylen := func(domainlen int, padding bool) int { e := &DNSEncoderMiekg{} - data, _, err := e.Encode( + query := e.Encode( // This is not a valid name because it ends up being way // longer than 255 octets. However, the library is allowing // us to generate such name and we are not going to send @@ -40,6 +114,7 @@ func TestDNSEncoder(t *testing.T) { dns.Fqdn(strings.Repeat("x.", domainlen)), dns.TypeA, padding, ) + data, err := query.Bytes() if err != nil { t.Fatal(err) } @@ -63,8 +138,13 @@ func TestDNSEncoder(t *testing.T) { // dnsValidateEncodedQueryBytes validates the query serialized in data // for the given query type qtype (e.g., dns.TypeAAAA). -func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte) { - // skipping over the query ID +func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte, qid uint16) { + var wirequery uint16 + err := binary.Read(bytes.NewReader(data), binary.BigEndian, &wirequery) + runtimex.PanicOnError(err, "binary.Read failed unexpectedly") + if wirequery != qid { + t.Fatal("invalid query ID") + } if data[2] != 1 { t.Fatal("FLAGS should only have RD set") } diff --git a/internal/netxlite/dnsoverhttps.go b/internal/netxlite/dnsoverhttps.go index 5fef09e..196bc28 100644 --- a/internal/netxlite/dnsoverhttps.go +++ b/internal/netxlite/dnsoverhttps.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "errors" + "io" "net/http" "time" @@ -19,6 +20,9 @@ type DNSOverHTTPSTransport struct { // Client is the MANDATORY http client to use. Client model.HTTPClient + // Decoder is the MANDATORY DNSDecoder. + Decoder model.DNSDecoder + // URL is the MANDATORY URL of the DNS-over-HTTPS server. URL string @@ -31,9 +35,9 @@ type DNSOverHTTPSTransport struct { // // Arguments: // -// - client in http.Client-like type (e.g., http.DefaultClient); +// - client is a model.HTTPClient type; // -// - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query). +// - URL is the DoH resolver URL (e.g., https://dns.google/dns-query). func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport { return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "") } @@ -42,22 +46,31 @@ func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPS // with the given Host header override. func NewDNSOverHTTPSTransportWithHostOverride( client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport { - return &DNSOverHTTPSTransport{Client: client, URL: URL, HostOverride: hostOverride} + return &DNSOverHTTPSTransport{ + Client: client, + Decoder: &DNSDecoderMiekg{}, + URL: URL, + HostOverride: hostOverride, + } } // RoundTrip sends a query and receives a reply. -func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { +func (t *DNSOverHTTPSTransport) RoundTrip( + ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + rawQuery, err := query.Bytes() + if err != nil { + return nil, err + } ctx, cancel := context.WithTimeout(ctx, 45*time.Second) defer cancel() - req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query)) + req, err := http.NewRequest("POST", t.URL, bytes.NewReader(rawQuery)) if err != nil { return nil, err } req.Host = t.HostOverride req.Header.Set("user-agent", model.HTTPHeaderUserAgent) req.Header.Set("content-type", "application/dns-message") - var resp *http.Response - resp, err = t.Client.Do(req.WithContext(ctx)) + resp, err := t.Client.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -70,7 +83,13 @@ func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([] if resp.Header.Get("content-type") != "application/dns-message" { return nil, errors.New("doh: invalid content-type") } - return ReadAllContext(ctx, resp.Body) + const maxresponsesize = 1 << 20 + limitReader := io.LimitReader(resp.Body, maxresponsesize) + rawResponse, err := ReadAllContext(ctx, limitReader) + if err != nil { + return nil, err + } + return t.Decoder.DecodeResponse(rawResponse, query) } // RequiresPadding returns true for DoH according to RFC8467. diff --git a/internal/netxlite/dnsoverhttps_test.go b/internal/netxlite/dnsoverhttps_test.go index 24ef46d..57480f0 100644 --- a/internal/netxlite/dnsoverhttps_test.go +++ b/internal/netxlite/dnsoverhttps_test.go @@ -15,14 +15,36 @@ import ( func TestDNSOverHTTPSTransport(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { + t.Run("query serialization failure", func(t *testing.T) { + txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query") + expected := errors.New("mocked error") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return nil, expected + }, + } + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected no response here") + } + }) + t.Run("NewRequestFailure", func(t *testing.T) { const invalidURL = "\t" txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL) - data, err := txp.RoundTrip(context.Background(), nil) - if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { - t.Fatal("expected an error here") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, } - if data != nil { + resp, err := txp.RoundTrip(context.Background(), query) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("unexpected err", err) + } + if resp != nil { t.Fatal("expected no response here") } }) @@ -37,11 +59,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) { }, URL: "https://cloudflare-dns.com/dns-query", } - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, expected) { - t.Fatal("expected an error here") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, } - if data != nil { + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { t.Fatal("expected no response here") } }) @@ -58,11 +85,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) { }, URL: "https://cloudflare-dns.com/dns-query", } - data, err := txp.RoundTrip(context.Background(), nil) - if err == nil || err.Error() != "doh: server returned error" { - t.Fatal("expected an error here") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, } - if data != nil { + resp, err := txp.RoundTrip(context.Background(), query) + if err == nil || err.Error() != "doh: server returned error" { + t.Fatal("unexpected err", err) + } + if resp != nil { t.Fatal("expected no response here") } }) @@ -79,11 +111,86 @@ func TestDNSOverHTTPSTransport(t *testing.T) { }, URL: "https://cloudflare-dns.com/dns-query", } - data, err := txp.RoundTrip(context.Background(), nil) - if err == nil || err.Error() != "doh: invalid content-type" { - t.Fatal("expected an error here") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, } - if data != nil { + resp, err := txp.RoundTrip(context.Background(), query) + if err == nil || err.Error() != "doh: invalid content-type" { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected no response here") + } + }) + + t.Run("ReadAllContext fails", func(t *testing.T) { + expected := errors.New("mocked error") + txp := &DNSOverHTTPSTransport{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(&mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + }), + Header: http.Header{ + "Content-Type": []string{"application/dns-message"}, + }, + }, nil + }, + }, + URL: "https://cloudflare-dns.com/dns-query", + } + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected no response here") + } + }) + + t.Run("decode response failure", func(t *testing.T) { + expected := errors.New("mocked error") + body := []byte("AAA") + txp := &DNSOverHTTPSTransport{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(body)), + Header: http.Header{ + "Content-Type": []string{"application/dns-message"}, + }, + }, nil + }, + }, + URL: "https://cloudflare-dns.com/dns-query", + Decoder: &mocks.DNSDecoder{ + MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expected + }, + }, + } + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { t.Fatal("expected no response here") } }) @@ -103,13 +210,23 @@ func TestDNSOverHTTPSTransport(t *testing.T) { }, }, URL: "https://cloudflare-dns.com/dns-query", + Decoder: &mocks.DNSDecoder{ + MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { + return &mocks.DNSResponse{}, nil + }, + }, } - data, err := txp.RoundTrip(context.Background(), nil) + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) if err != nil { t.Fatal(err) } - if !bytes.Equal(data, body) { - t.Fatal("not the response we expected") + if resp == nil { + t.Fatal("expected non-nil resp here") } }) @@ -125,7 +242,12 @@ func TestDNSOverHTTPSTransport(t *testing.T) { }, URL: "https://cloudflare-dns.com/dns-query", } - data, err := txp.RoundTrip(context.Background(), nil) + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, + } + data, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, expected) { t.Fatal("expected an error here") } @@ -151,18 +273,22 @@ func TestDNSOverHTTPSTransport(t *testing.T) { URL: "https://cloudflare-dns.com/dns-query", HostOverride: hostOverride, } - data, err := txp.RoundTrip(context.Background(), nil) + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 17), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, expected) { t.Fatal("expected an error here") } - if data != nil { + if resp != nil { t.Fatal("expected no response here") } if !correct { t.Fatal("did not see correct host override") } }) - }) t.Run("other functions behave correctly", func(t *testing.T) { diff --git a/internal/netxlite/dnsovertcp.go b/internal/netxlite/dnsovertcp.go index 190b946..1f174c7 100644 --- a/internal/netxlite/dnsovertcp.go +++ b/internal/netxlite/dnsovertcp.go @@ -20,9 +20,12 @@ type DialContextFunc func(context.Context, string, string) (net.Conn, error) // DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport. // -// Bug: this implementation always creates a new connection for each query. +// Note: this implementation always creates a new connection for each query. This +// strategy is less efficient but MAY be more robust for cleartext TCP connections +// when querying for a blocked domain name causes endpoint blocking. type DNSOverTCPTransport struct { dial DialContextFunc + decoder model.DNSDecoder address string network string requiresPadding bool @@ -36,47 +39,58 @@ type DNSOverTCPTransport struct { // // - address is the endpoint address (e.g., 8.8.8.8:53). func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { - return &DNSOverTCPTransport{ - dial: dial, - address: address, - network: "tcp", - requiresPadding: false, - } + return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false) } -// NewDNSOverTLS creates a new DNSOverTLS transport. +// NewDNSOverTLSTransport creates a new DNSOverTLS transport. // // Arguments: // // - dial is a function with the net.Dialer.DialContext's signature; // // - address is the endpoint address (e.g., 8.8.8.8:853). -func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCPTransport { +func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport { + return newDNSOverTCPOrTLSTransport(dial, "dot", address, true) +} + +// newDNSOverTCPOrTLSTransport is the common factory for creating a transport +func newDNSOverTCPOrTLSTransport( + dial DialContextFunc, network, address string, padding bool) *DNSOverTCPTransport { return &DNSOverTCPTransport{ dial: dial, + decoder: &DNSDecoderMiekg{}, address: address, - network: "dot", - requiresPadding: true, + network: network, + requiresPadding: padding, } } +// errQueryTooLarge indicates the query is too large for the transport. +var errQueryTooLarge = errors.New("oodns: query too large for this transport") + // RoundTrip sends a query and receives a reply. -func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { - if len(query) > math.MaxUint16 { - return nil, errors.New("query too long") +func (t *DNSOverTCPTransport) RoundTrip( + ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + // TODO(bassosimone): this method should more strictly honour the context, which + // currently is only used to bound the dial operation + rawQuery, err := query.Bytes() + if err != nil { + return nil, err + } + if len(rawQuery) > math.MaxUint16 { + return nil, errQueryTooLarge } conn, err := t.dial(ctx, "tcp", t.address) if err != nil { return nil, err } defer conn.Close() - if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil { - return nil, err - } + const iotimeout = 10 * time.Second + conn.SetDeadline(time.Now().Add(iotimeout)) // Write request - buf := []byte{byte(len(query) >> 8)} - buf = append(buf, byte(len(query))) - buf = append(buf, query...) + buf := []byte{byte(len(rawQuery) >> 8)} + buf = append(buf, byte(len(rawQuery))) + buf = append(buf, rawQuery...) if _, err = conn.Write(buf); err != nil { return nil, err } @@ -86,11 +100,11 @@ func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]by return nil, err } length := int(header[0])<<8 | int(header[1]) - reply := make([]byte, length) - if _, err = io.ReadFull(conn, reply); err != nil { + rawResponse := make([]byte, length) + if _, err = io.ReadFull(conn, rawResponse); err != nil { return nil, err } - return reply, nil + return t.decoder.DecodeResponse(rawResponse, query) } // RequiresPadding returns true for DoT and false for TCP diff --git a/internal/netxlite/dnsovertcp_test.go b/internal/netxlite/dnsovertcp_test.go index 7fc4128..3be1391 100644 --- a/internal/netxlite/dnsovertcp_test.go +++ b/internal/netxlite/dnsovertcp_test.go @@ -6,73 +6,83 @@ import ( "crypto/tls" "errors" "io" + "math" "net" "testing" "time" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" ) func TestDNSOverTCPTransport(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { + t.Run("cannot encode query", func(t *testing.T) { + expected := errors.New("mocked error") + const address = "9.9.9.9:53" + txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return nil, expected + }, + } + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected nil response here") + } + }) + t.Run("query too large", func(t *testing.T) { const address = "9.9.9.9:53" txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18)) - if err == nil { - t.Fatal("expected an error here") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, math.MaxUint16+1), nil + }, } - if reply != nil { - t.Fatal("expected nil reply here") + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, errQueryTooLarge) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected nil response here") } }) t.Run("dial failure", func(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } fakedialer := &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, mocked }, } txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") } - if reply != nil { - t.Fatal("expected nil reply here") - } - }) - - t.Run("SetDeadline failure", func(t *testing.T) { - const address = "9.9.9.9:53" - mocked := errors.New("mocked error") - fakedialer := &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return mocked - }, - MockClose: func() error { - return nil - }, - }, nil - }, - } - txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") + if resp != nil { + t.Fatal("expected nil resp here") } }) t.Run("write failure", func(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } fakedialer := &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -89,18 +99,23 @@ func TestDNSOverTCPTransport(t *testing.T) { }, } txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") } - if reply != nil { - t.Fatal("expected nil reply here") + if resp != nil { + t.Fatal("expected nil resp here") } }) t.Run("first read fails", func(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } fakedialer := &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { return &mocks.Conn{ @@ -120,18 +135,23 @@ func TestDNSOverTCPTransport(t *testing.T) { }, } txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") } - if reply != nil { - t.Fatal("expected nil reply here") + if resp != nil { + t.Fatal("expected nil resp here") } }) t.Run("second read fails", func(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } input := io.MultiReader( bytes.NewReader([]byte{byte(0), byte(2)}), &mocks.Reader{ @@ -157,17 +177,23 @@ func TestDNSOverTCPTransport(t *testing.T) { }, } txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") } - if reply != nil { - t.Fatal("expected nil reply here") + if resp != nil { + t.Fatal("expected nil resp here") } }) - t.Run("successful case", func(t *testing.T) { + t.Run("decode failure", func(t *testing.T) { const address = "9.9.9.9:53" + mocked := errors.New("mocked error") + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)}) fakedialer := &mocks.Dialer{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { @@ -186,11 +212,56 @@ func TestDNSOverTCPTransport(t *testing.T) { }, } txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + txp.decoder = &mocks.DNSDecoder{ + MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { + return nil, mocked + }, + } + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, mocked) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected nil resp here") + } + }) + + t.Run("successful case", func(t *testing.T) { + const address = "9.9.9.9:53" + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } + input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)}) + fakedialer := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: input.Read, + MockClose: func() error { + return nil + }, + }, nil + }, + } + txp := NewDNSOverTCPTransport(fakedialer.DialContext, address) + expectedResp := &mocks.DNSResponse{} + txp.decoder = &mocks.DNSDecoder{ + MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { + return expectedResp, nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) if err != nil { t.Fatal(err) } - if len(reply) != 1 || reply[0] != 1 { + if resp != expectedResp { t.Fatal("not the response we expected") } }) @@ -213,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) { t.Run("other functions okay with TLS", func(t *testing.T) { const address = "9.9.9.9:853" - txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address) + txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address) if txp.RequiresPadding() != true { t.Fatal("invalid RequiresPadding") } diff --git a/internal/netxlite/dnsoverudp.go b/internal/netxlite/dnsoverudp.go index 4a79271..2928826 100644 --- a/internal/netxlite/dnsoverudp.go +++ b/internal/netxlite/dnsoverudp.go @@ -14,6 +14,7 @@ import ( // DNSOverUDPTransport is a DNS-over-UDP DNSTransport. type DNSOverUDPTransport struct { dialer model.Dialer + decoder model.DNSDecoder address string } @@ -25,11 +26,20 @@ type DNSOverUDPTransport struct { // // - address is the endpoint address (e.g., 8.8.8.8:53). func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport { - return &DNSOverUDPTransport{dialer: dialer, address: address} + return &DNSOverUDPTransport{ + dialer: dialer, + decoder: &DNSDecoderMiekg{}, + address: address, + } } // RoundTrip sends a query and receives a reply. -func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { +func (t *DNSOverUDPTransport) RoundTrip( + ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + rawQuery, err := query.Bytes() + if err != nil { + return nil, err + } conn, err := t.dialer.DialContext(ctx, "udp", t.address) if err != nil { return nil, err @@ -37,19 +47,19 @@ func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]by defer conn.Close() // Use five seconds timeout like Bionic does. See // https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance - if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + const iotimeout = 5 * time.Second + conn.SetDeadline(time.Now().Add(iotimeout)) + if _, err = conn.Write(rawQuery); err != nil { return nil, err } - if _, err = conn.Write(query); err != nil { - return nil, err - } - reply := make([]byte, 1<<17) - var n int - n, err = conn.Read(reply) + const maxmessagesize = 1 << 17 + rawResponse := make([]byte, maxmessagesize) + count, err := conn.Read(rawResponse) if err != nil { return nil, err } - return reply[:n], nil + rawResponse = rawResponse[:count] + return t.decoder.DecodeResponse(rawResponse, query) } // RequiresPadding returns false for UDP according to RFC8467. diff --git a/internal/netxlite/dnsoverudp_test.go b/internal/netxlite/dnsoverudp_test.go index e7c6292..5720edf 100644 --- a/internal/netxlite/dnsoverudp_test.go +++ b/internal/netxlite/dnsoverudp_test.go @@ -9,11 +9,30 @@ import ( "time" "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" ) func TestDNSOverUDPTransport(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { + t.Run("cannot encode query", func(t *testing.T) { + expected := errors.New("mocked error") + const address = "9.9.9.9:53" + txp := NewDNSOverUDPTransport(nil, address) + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return nil, expected + }, + } + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected nil response here") + } + }) + t.Run("dial failure", func(t *testing.T) { mocked := errors.New("mocked error") const address = "9.9.9.9:53" @@ -22,36 +41,16 @@ func TestDNSOverUDPTransport(t *testing.T) { return nil, mocked }, }, address) - data, err := txp.RoundTrip(context.Background(), nil) + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") } - if data != nil { - t.Fatal("expected no response here") - } - }) - - t.Run("SetDeadline failure", func(t *testing.T) { - mocked := errors.New("mocked error") - txp := NewDNSOverUDPTransport( - &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return mocked - }, - MockClose: func() error { - return nil - }, - }, nil - }, - }, "9.9.9.9:53", - ) - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if data != nil { + if resp != nil { t.Fatal("expected no response here") } }) @@ -75,11 +74,16 @@ func TestDNSOverUDPTransport(t *testing.T) { }, }, "9.9.9.9:53", ) - data, err := txp.RoundTrip(context.Background(), nil) + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") } - if data != nil { + if resp != nil { t.Fatal("expected no response here") } }) @@ -106,15 +110,61 @@ func TestDNSOverUDPTransport(t *testing.T) { }, }, "9.9.9.9:53", ) - data, err := txp.RoundTrip(context.Background(), nil) + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") } - if data != nil { + if resp != nil { t.Fatal("expected no response here") } }) + t.Run("decode failure", func(t *testing.T) { + const expected = 17 + input := bytes.NewReader(make([]byte, expected)) + txp := NewDNSOverUDPTransport( + &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: input.Read, + MockClose: func() error { + return nil + }, + }, nil + }, + }, "9.9.9.9:53", + ) + expectedErr := errors.New("mocked error") + txp.decoder = &mocks.DNSDecoder{ + MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { + return nil, expectedErr + }, + } + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) + if !errors.Is(err, expectedErr) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("expected nil resp") + } + }) + t.Run("read success", func(t *testing.T) { const expected = 17 input := bytes.NewReader(make([]byte, expected)) @@ -136,12 +186,23 @@ func TestDNSOverUDPTransport(t *testing.T) { }, }, "9.9.9.9:53", ) - data, err := txp.RoundTrip(context.Background(), nil) + expectedResp := &mocks.DNSResponse{} + txp.decoder = &mocks.DNSDecoder{ + MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { + return expectedResp, nil + }, + } + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } + resp, err := txp.RoundTrip(context.Background(), query) if err != nil { t.Fatal(err) } - if len(data) != expected { - t.Fatal("expected non nil data") + if resp != expectedResp { + t.Fatal("unexpected resp") } }) }) diff --git a/internal/netxlite/filtering/dns.go b/internal/netxlite/filtering/dns.go index d25cd34..8b63527 100644 --- a/internal/netxlite/filtering/dns.go +++ b/internal/netxlite/filtering/dns.go @@ -1,15 +1,12 @@ package filtering import ( - "context" "errors" "io" "net" - "net/http" "strings" "github.com/miekg/dns" - "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/runtimex" ) @@ -51,19 +48,13 @@ type DNSProxy struct { // receive a query for the given domain. OnQuery func(domain string) DNSAction - // Upstream is the OPTIONAL upstream transport. - Upstream DNSTransport + // UpstreamEndpoint is the OPTIONAL upstream transport endpoint. + UpstreamEndpoint string // mockableReply allows to mock DNSProxy.reply in tests. mockableReply func(query *dns.Msg) (*dns.Msg, error) } -// DNSTransport is the type we expect from an upstream DNS transport. -type DNSTransport interface { - RoundTrip(ctx context.Context, query []byte) ([]byte, error) - CloseIdleConnections() -} - // DNSListener is the interface returned by DNSProxy.Start type DNSListener interface { io.Closer @@ -204,23 +195,24 @@ func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg { return reply } +var ( + // errDNSExpectedSingleQuestion means we expected to see a single question + errDNSExpectedSingleQuestion = errors.New("filtering: expected single DNS question") + + // errDNSExpectedQueryNotResponse means we expected to see a query. + errDNSExpectedQueryNotResponse = errors.New("filtering: expected query not response") +) + func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) { - queryBytes, err := query.Pack() - if err != nil { - return nil, err + if query.Response { + return nil, errDNSExpectedQueryNotResponse } - txp := p.dnstransport() - defer txp.CloseIdleConnections() - ctx := context.Background() - replyBytes, err := txp.RoundTrip(ctx, queryBytes) - if err != nil { - return nil, err + if len(query.Question) != 1 { + return nil, errDNSExpectedSingleQuestion } - reply := &dns.Msg{} - if err := reply.Unpack(replyBytes); err != nil { - return nil, err - } - return reply, nil + clnt := &dns.Client{} + resp, _, err := clnt.Exchange(query, p.upstreamEndpoint()) + return resp, err } func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg { @@ -237,10 +229,9 @@ func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg { return p.compose(query, ipAddrs...) } -func (p *DNSProxy) dnstransport() DNSTransport { - if p.Upstream != nil { - return p.Upstream +func (p *DNSProxy) upstreamEndpoint() string { + if p.UpstreamEndpoint != "" { + return p.UpstreamEndpoint } - const URL = "https://1.1.1.1/dns-query" - return netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, URL) + return "8.8.8.8:53" } diff --git a/internal/netxlite/filtering/dns_test.go b/internal/netxlite/filtering/dns_test.go index b2e1d91..88d5ac5 100644 --- a/internal/netxlite/filtering/dns_test.go +++ b/internal/netxlite/filtering/dns_test.go @@ -283,9 +283,7 @@ func TestDNSProxy(t *testing.T) { if len(p) < len(data) { panic("buffer too small") } - for i := 0; i < len(data); i++ { - p[i] = data[i] - } + copy(p, data) return len(data), &net.UDPAddr{}, nil }, } @@ -314,9 +312,7 @@ func TestDNSProxy(t *testing.T) { if len(p) < len(data) { panic("buffer too small") } - for i := 0; i < len(data); i++ { - p[i] = data[i] - } + copy(p, data) return len(data), &net.UDPAddr{}, nil }, } @@ -328,12 +324,24 @@ func TestDNSProxy(t *testing.T) { }) t.Run("proxy", func(t *testing.T) { - t.Run("Pack fails", func(t *testing.T) { + t.Run("with response", func(t *testing.T) { p := &DNSProxy{} query := &dns.Msg{} - query.Rcode = -1 // causes Pack to fail + query.Response = true reply, err := p.proxy(query) - if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") { + if !errors.Is(err, errDNSExpectedQueryNotResponse) { + t.Fatal("unexpected err", err) + } + if reply != nil { + t.Fatal("expected nil reply") + } + }) + + t.Run("with no questions", func(t *testing.T) { + p := &DNSProxy{} + query := &dns.Msg{} + reply, err := p.proxy(query) + if !errors.Is(err, errDNSExpectedSingleQuestion) { t.Fatal("unexpected err", err) } if reply != nil { @@ -342,35 +350,13 @@ func TestDNSProxy(t *testing.T) { }) t.Run("round trip fails", func(t *testing.T) { - expected := errors.New("mocked error") p := &DNSProxy{ - Upstream: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return nil, expected - }, - MockCloseIdleConnections: func() {}, - }, + UpstreamEndpoint: "antani", } - reply, err := p.proxy(&dns.Msg{}) - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if reply != nil { - t.Fatal("expected nil reply here") - } - }) - - t.Run("Unpack fails", func(t *testing.T) { - p := &DNSProxy{ - Upstream: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return make([]byte, 1), nil - }, - MockCloseIdleConnections: func() {}, - }, - } - reply, err := p.proxy(&dns.Msg{}) - if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") { + query := &dns.Msg{} + query.Question = append(query.Question, dns.Question{}) + reply, err := p.proxy(query) + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { t.Fatal("unexpected err", err) } if reply != nil { diff --git a/internal/netxlite/parallelresolver.go b/internal/netxlite/parallelresolver.go index a5dfed7..8c22bb5 100644 --- a/internal/netxlite/parallelresolver.go +++ b/internal/netxlite/parallelresolver.go @@ -9,7 +9,6 @@ import ( "net" "github.com/miekg/dns" - "github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -19,15 +18,6 @@ import ( // You should probably use NewUnwrappedParallelResolver to // create a new instance of this type. type ParallelResolver struct { - // Encoder is the MANDATORY encoder to use. - Encoder model.DNSEncoder - - // Decoder is the MANDATORY decoder to use. - Decoder model.DNSDecoder - - // NumTimeouts is MANDATORY and counts the number of timeouts. - NumTimeouts *atomicx.Int64 - // Txp is the MANDATORY underlying DNS transport. Txp model.DNSTransport } @@ -36,10 +26,7 @@ type ParallelResolver struct { // not wrapped and you should wrap if before using it. func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver { return &ParallelResolver{ - Encoder: &DNSEncoderMiekg{}, - Decoder: &DNSDecoderMiekg{}, - NumTimeouts: &atomicx.Int64{}, - Txp: t, + Txp: t, } } @@ -80,22 +67,22 @@ func (r *ParallelResolver) LookupHost(ctx context.Context, hostname string) ([]s var addrs []string addrs = append(addrs, ares.addrs...) addrs = append(addrs, aaaares.addrs...) + if len(addrs) < 1 { + return nil, ErrOODNSNoAnswer + } return addrs, nil } // LookupHTTPS implements Resolver.LookupHTTPS. func (r *ParallelResolver) LookupHTTPS( ctx context.Context, hostname string) (*model.HTTPSSvc, error) { - querydata, queryID, err := r.Encoder.Encode( - hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) + response, err := r.Txp.RoundTrip(ctx, query) if err != nil { return nil, err } - replydata, err := r.Txp.RoundTrip(ctx, querydata) - if err != nil { - return nil, err - } - return r.Decoder.DecodeHTTPS(replydata, queryID) + return response.DecodeHTTPS() } // parallelResolverResult is the internal representation of a @@ -108,7 +95,9 @@ type parallelResolverResult struct { // lookupHost issues a lookup host query for the specified qtype (e.g., dns.A). func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string, qtype uint16, out chan<- *parallelResolverResult) { - querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) + response, err := r.Txp.RoundTrip(ctx, query) if err != nil { out <- ¶llelResolverResult{ addrs: []string{}, @@ -116,15 +105,7 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string, } return } - replydata, err := r.Txp.RoundTrip(ctx, querydata) - if err != nil { - out <- ¶llelResolverResult{ - addrs: []string{}, - err: err, - } - return - } - addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata, queryID) + addrs, err := response.DecodeLookupHost() out <- ¶llelResolverResult{ addrs: addrs, err: err, @@ -134,14 +115,11 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string, // LookupNS implements Resolver.LookupNS. func (r *ParallelResolver) LookupNS( ctx context.Context, hostname string) ([]*net.NS, error) { - querydata, queryID, err := r.Encoder.Encode( - hostname, dns.TypeNS, r.Txp.RequiresPadding()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding()) + response, err := r.Txp.RoundTrip(ctx, query) if err != nil { return nil, err } - replydata, err := r.Txp.RoundTrip(ctx, querydata) - if err != nil { - return nil, err - } - return r.Decoder.DecodeNS(replydata, queryID) + return response.DecodeNS() } diff --git a/internal/netxlite/parallelresolver_test.go b/internal/netxlite/parallelresolver_test.go index e334dae..f208623 100644 --- a/internal/netxlite/parallelresolver_test.go +++ b/internal/netxlite/parallelresolver_test.go @@ -8,14 +8,13 @@ import ( "testing" "github.com/miekg/dns" - "github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" ) func TestParallelResolver(t *testing.T) { t.Run("transport okay", func(t *testing.T) { - txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") + txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") r := NewUnwrappedParallelResolver(txp) rtx := r.Transport() if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { @@ -30,30 +29,10 @@ func TestParallelResolver(t *testing.T) { }) t.Run("LookupHost", func(t *testing.T) { - t.Run("Encode error", func(t *testing.T) { - mocked := errors.New("mocked error") - txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") - r := ParallelResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return nil, 0, mocked - }, - }, - Txp: txp, - } - addrs, err := r.LookupHost(context.Background(), "www.gogle.com") - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected nil address here") - } - }) - t.Run("RoundTrip error", func(t *testing.T) { mocked := errors.New("mocked error") txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return nil, mocked }, MockRequiresPadding: func() bool { @@ -72,8 +51,13 @@ func TestParallelResolver(t *testing.T) { t.Run("empty reply", func(t *testing.T) { txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return dnsGenLookupHostReplySuccess(query), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + return nil, nil + }, + } + return response, nil }, MockRequiresPadding: func() bool { return true @@ -91,8 +75,16 @@ func TestParallelResolver(t *testing.T) { t.Run("with A reply", func(t *testing.T) { txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + if query.Type() != dns.TypeA { + return nil, nil + } + return []string{"8.8.8.8"}, nil + }, + } + return response, nil }, MockRequiresPadding: func() bool { return true @@ -110,8 +102,16 @@ func TestParallelResolver(t *testing.T) { t.Run("with AAAA reply", func(t *testing.T) { txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return dnsGenLookupHostReplySuccess(query, "::1"), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + if query.Type() != dns.TypeAAAA { + return nil, nil + } + return []string{"::1"}, nil + }, + } + return response, nil }, MockRequiresPadding: func() bool { return true @@ -131,22 +131,15 @@ func TestParallelResolver(t *testing.T) { afailure := errors.New("a failure") aaaafailure := errors.New("aaaa failure") txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - msg := &dns.Msg{} - if err := msg.Unpack(query); err != nil { - return nil, err - } - if len(msg.Question) != 1 { - return nil, errors.New("expected just one question") - } - q := msg.Question[0] - if q.Qtype == dns.TypeA { + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + switch query.Type() { + case dns.TypeA: return nil, afailure - } - if q.Qtype == dns.TypeAAAA { + case dns.TypeAAAA: return nil, aaaafailure + default: + return nil, errors.New("unexpected query") } - return nil, errors.New("expected A or AAAA query") }, MockRequiresPadding: func() bool { return true @@ -179,44 +172,11 @@ func TestParallelResolver(t *testing.T) { }) t.Run("LookupHTTPS", func(t *testing.T) { - t.Run("for encoding error", func(t *testing.T) { - expected := errors.New("mocked error") - r := &ParallelResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return nil, 0, expected - }, - }, - Decoder: nil, - NumTimeouts: &atomicx.Int64{}, - Txp: &mocks.DNSTransport{ - MockRequiresPadding: func() bool { - return false - }, - }, - } - ctx := context.Background() - https, err := r.LookupHTTPS(ctx, "example.com") - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if https != nil { - t.Fatal("unexpected result") - } - }) - t.Run("for round-trip error", func(t *testing.T) { expected := errors.New("mocked error") r := &ParallelResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return make([]byte, 64), 0, nil - }, - }, - Decoder: nil, - NumTimeouts: &atomicx.Int64{}, Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return nil, expected }, MockRequiresPadding: func() bool { @@ -234,23 +194,17 @@ func TestParallelResolver(t *testing.T) { } }) - t.Run("for decode error", func(t *testing.T) { + t.Run("for DecodeHTTPS error", func(t *testing.T) { expected := errors.New("mocked error") r := &ParallelResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return make([]byte, 64), 0, nil - }, - }, - Decoder: &mocks.DNSDecoder{ - MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) { - return nil, expected - }, - }, - NumTimeouts: &atomicx.Int64{}, Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return make([]byte, 128), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeHTTPS: func() (*model.HTTPSSvc, error) { + return nil, expected + }, + } + return response, nil }, MockRequiresPadding: func() bool { return false @@ -269,44 +223,11 @@ func TestParallelResolver(t *testing.T) { }) t.Run("LookupNS", func(t *testing.T) { - t.Run("for encoding error", func(t *testing.T) { - expected := errors.New("mocked error") - r := &ParallelResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return nil, 0, expected - }, - }, - Decoder: nil, - NumTimeouts: &atomicx.Int64{}, - Txp: &mocks.DNSTransport{ - MockRequiresPadding: func() bool { - return false - }, - }, - } - ctx := context.Background() - ns, err := r.LookupNS(ctx, "example.com") - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if ns != nil { - t.Fatal("unexpected result") - } - }) - t.Run("for round-trip error", func(t *testing.T) { expected := errors.New("mocked error") r := &ParallelResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return make([]byte, 64), 0, nil - }, - }, - Decoder: nil, - NumTimeouts: &atomicx.Int64{}, Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return nil, expected }, MockRequiresPadding: func() bool { @@ -327,20 +248,14 @@ func TestParallelResolver(t *testing.T) { t.Run("for decode error", func(t *testing.T) { expected := errors.New("mocked error") r := &ParallelResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return make([]byte, 64), 0, nil - }, - }, - Decoder: &mocks.DNSDecoder{ - MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) { - return nil, expected - }, - }, - NumTimeouts: &atomicx.Int64{}, Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return make([]byte, 128), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeNS: func() ([]*net.NS, error) { + return nil, expected + }, + } + return response, nil }, MockRequiresPadding: func() bool { return false diff --git a/internal/netxlite/serialresolver.go b/internal/netxlite/serialresolver.go index b0805ab..316ca88 100644 --- a/internal/netxlite/serialresolver.go +++ b/internal/netxlite/serialresolver.go @@ -26,12 +26,6 @@ import ( // QUIRK: unlike the ParallelResolver, this resolver's LookupHost retries // each query three times for soft errors. type SerialResolver struct { - // Encoder is the MANDATORY encoder to use. - Encoder model.DNSEncoder - - // Decoder is the MANDATORY decoder to use. - Decoder model.DNSDecoder - // NumTimeouts is MANDATORY and counts the number of timeouts. NumTimeouts *atomicx.Int64 @@ -42,8 +36,6 @@ type SerialResolver struct { // NewSerialResolver creates a new SerialResolver instance. func NewSerialResolver(t model.DNSTransport) *SerialResolver { return &SerialResolver{ - Encoder: &DNSEncoderMiekg{}, - Decoder: &DNSDecoderMiekg{}, NumTimeouts: &atomicx.Int64{}, Txp: t, } @@ -82,22 +74,22 @@ func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]str } addrs = append(addrs, addrsA...) addrs = append(addrs, addrsAAAA...) + if len(addrs) < 1 { + return nil, ErrOODNSNoAnswer + } return addrs, nil } // LookupHTTPS implements Resolver.LookupHTTPS. func (r *SerialResolver) LookupHTTPS( ctx context.Context, hostname string) (*model.HTTPSSvc, error) { - querydata, queryID, err := r.Encoder.Encode( - hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) + response, err := r.Txp.RoundTrip(ctx, query) if err != nil { return nil, err } - replydata, err := r.Txp.RoundTrip(ctx, querydata) - if err != nil { - return nil, err - } - return r.Decoder.DecodeHTTPS(replydata, queryID) + return response.DecodeHTTPS() } func (r *SerialResolver) lookupHostWithRetry( @@ -132,28 +124,23 @@ func (r *SerialResolver) lookupHostWithRetry( // qtype (dns.A or dns.AAAA) without retrying on failure. func (r *SerialResolver) lookupHostWithoutRetry( ctx context.Context, hostname string, qtype uint16) ([]string, error) { - querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) + response, err := r.Txp.RoundTrip(ctx, query) if err != nil { return nil, err } - replydata, err := r.Txp.RoundTrip(ctx, querydata) - if err != nil { - return nil, err - } - return r.Decoder.DecodeLookupHost(qtype, replydata, queryID) + return response.DecodeLookupHost() } // LookupNS implements Resolver.LookupNS. func (r *SerialResolver) LookupNS( ctx context.Context, hostname string) ([]*net.NS, error) { - querydata, queryID, err := r.Encoder.Encode( - hostname, dns.TypeNS, r.Txp.RequiresPadding()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding()) + response, err := r.Txp.RoundTrip(ctx, query) if err != nil { return nil, err } - replydata, err := r.Txp.RoundTrip(ctx, querydata) - if err != nil { - return nil, err - } - return r.Decoder.DecodeNS(replydata, queryID) + return response.DecodeNS() } diff --git a/internal/netxlite/serialresolver_test.go b/internal/netxlite/serialresolver_test.go index af5df96..ed9df5f 100644 --- a/internal/netxlite/serialresolver_test.go +++ b/internal/netxlite/serialresolver_test.go @@ -7,6 +7,7 @@ import ( "net" "testing" + "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" @@ -30,7 +31,7 @@ func (err *errorWithTimeout) Unwrap() error { func TestSerialResolver(t *testing.T) { t.Run("transport okay", func(t *testing.T) { - txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") + txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853") r := NewSerialResolver(txp) rtx := r.Transport() if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { @@ -45,30 +46,10 @@ func TestSerialResolver(t *testing.T) { }) t.Run("LookupHost", func(t *testing.T) { - t.Run("Encode error", func(t *testing.T) { - mocked := errors.New("mocked error") - txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") - r := SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return nil, 0, mocked - }, - }, - Txp: txp, - } - addrs, err := r.LookupHost(context.Background(), "www.gogle.com") - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected nil address here") - } - }) - t.Run("RoundTrip error", func(t *testing.T) { mocked := errors.New("mocked error") txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return nil, mocked }, MockRequiresPadding: func() bool { @@ -87,8 +68,13 @@ func TestSerialResolver(t *testing.T) { t.Run("empty reply", func(t *testing.T) { txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return dnsGenLookupHostReplySuccess(query), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + return nil, nil + }, + } + return response, nil }, MockRequiresPadding: func() bool { return true @@ -106,8 +92,16 @@ func TestSerialResolver(t *testing.T) { t.Run("with A reply", func(t *testing.T) { txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + if query.Type() != dns.TypeA { + return nil, nil + } + return []string{"8.8.8.8"}, nil + }, + } + return response, nil }, MockRequiresPadding: func() bool { return true @@ -125,8 +119,16 @@ func TestSerialResolver(t *testing.T) { t.Run("with AAAA reply", func(t *testing.T) { txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return dnsGenLookupHostReplySuccess(query, "::1"), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeLookupHost: func() ([]string, error) { + if query.Type() != dns.TypeAAAA { + return nil, nil + } + return []string{"::1"}, nil + }, + } + return response, nil }, MockRequiresPadding: func() bool { return true @@ -144,11 +146,12 @@ func TestSerialResolver(t *testing.T) { t.Run("with timeout", func(t *testing.T) { txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return nil, &net.OpError{ + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + err := &net.OpError{ Err: &errorWithTimeout{ETIMEDOUT}, Op: "dial", } + return nil, err }, MockRequiresPadding: func() bool { return true @@ -184,44 +187,12 @@ func TestSerialResolver(t *testing.T) { }) t.Run("LookupHTTPS", func(t *testing.T) { - t.Run("for encoding error", func(t *testing.T) { - expected := errors.New("mocked error") - r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return nil, 0, expected - }, - }, - Decoder: nil, - NumTimeouts: &atomicx.Int64{}, - Txp: &mocks.DNSTransport{ - MockRequiresPadding: func() bool { - return false - }, - }, - } - ctx := context.Background() - https, err := r.LookupHTTPS(ctx, "example.com") - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if https != nil { - t.Fatal("unexpected result") - } - }) - t.Run("for round-trip error", func(t *testing.T) { expected := errors.New("mocked error") r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return make([]byte, 64), 0, nil - }, - }, - Decoder: nil, NumTimeouts: &atomicx.Int64{}, Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return nil, expected }, MockRequiresPadding: func() bool { @@ -242,20 +213,15 @@ func TestSerialResolver(t *testing.T) { t.Run("for decode error", func(t *testing.T) { expected := errors.New("mocked error") r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return make([]byte, 64), 0, nil - }, - }, - Decoder: &mocks.DNSDecoder{ - MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) { - return nil, expected - }, - }, NumTimeouts: &atomicx.Int64{}, Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return make([]byte, 128), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeHTTPS: func() (*model.HTTPSSvc, error) { + return nil, expected + }, + } + return response, nil }, MockRequiresPadding: func() bool { return false @@ -274,44 +240,12 @@ func TestSerialResolver(t *testing.T) { }) t.Run("LookupNS", func(t *testing.T) { - t.Run("for encoding error", func(t *testing.T) { - expected := errors.New("mocked error") - r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return nil, 0, expected - }, - }, - Decoder: nil, - NumTimeouts: &atomicx.Int64{}, - Txp: &mocks.DNSTransport{ - MockRequiresPadding: func() bool { - return false - }, - }, - } - ctx := context.Background() - ns, err := r.LookupNS(ctx, "example.com") - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if ns != nil { - t.Fatal("unexpected result") - } - }) - t.Run("for round-trip error", func(t *testing.T) { expected := errors.New("mocked error") r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return make([]byte, 64), 0, nil - }, - }, - Decoder: nil, NumTimeouts: &atomicx.Int64{}, Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { return nil, expected }, MockRequiresPadding: func() bool { @@ -332,20 +266,15 @@ func TestSerialResolver(t *testing.T) { t.Run("for decode error", func(t *testing.T) { expected := errors.New("mocked error") r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) { - return make([]byte, 64), 0, nil - }, - }, - Decoder: &mocks.DNSDecoder{ - MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) { - return nil, expected - }, - }, NumTimeouts: &atomicx.Int64{}, Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return make([]byte, 128), nil + MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { + response := &mocks.DNSResponse{ + MockDecodeNS: func() ([]*net.NS, error) { + return nil, expected + }, + } + return response, nil }, MockRequiresPadding: func() bool { return false