From ec0561ea8c97058117686473205b31e855bf851b Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Fri, 13 May 2022 17:36:58 +0200 Subject: [PATCH] feat(netxlite): implement parallel resolver (#724) This diff imports the parallel resolver from websteps winter 2022 edition, which was originally implemented here: https://github.com/bassosimone/websteps-illustrated/commit/55231d73cd822a851f532dea1b8089694d58100e See https://github.com/ooni/probe/issues/2096 --- internal/netxlite/parallelresolver.go | 131 ++++++++++ internal/netxlite/parallelresolver_test.go | 269 +++++++++++++++++++++ internal/netxlite/serialresolver.go | 10 +- 3 files changed, 408 insertions(+), 2 deletions(-) create mode 100644 internal/netxlite/parallelresolver.go create mode 100644 internal/netxlite/parallelresolver_test.go diff --git a/internal/netxlite/parallelresolver.go b/internal/netxlite/parallelresolver.go new file mode 100644 index 0000000..1da5397 --- /dev/null +++ b/internal/netxlite/parallelresolver.go @@ -0,0 +1,131 @@ +package netxlite + +// +// Parallel resolver implementation +// + +import ( + "context" + + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/atomicx" + "github.com/ooni/probe-cli/v3/internal/model" +) + +// ParallelResolver uses a transport and performs a LookupHost +// operation in a parallel fashion, hence its name. +// +// You should probably use NewUnwrappedParallelResolver to +// create a new instance of this type. +type ParallelResolver struct { + // Encoder is the MANDATORY encoder to use. + Encoder model.DNSEncoder + + // Decoder is the MANDATORY decoder to use. + Decoder model.DNSDecoder + + // NumTimeouts is MANDATORY and counts the number of timeouts. + NumTimeouts *atomicx.Int64 + + // Txp is the MANDATORY underlying DNS transport. + Txp model.DNSTransport +} + +// UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is +// not wrapped and you should wrap if before using it. +func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver { + return &ParallelResolver{ + Encoder: &DNSEncoderMiekg{}, + Decoder: &DNSDecoderMiekg{}, + NumTimeouts: &atomicx.Int64{}, + Txp: t, + } +} + +// Transport returns the transport being used. +func (r *ParallelResolver) Transport() model.DNSTransport { + return r.Txp +} + +// Network returns the "network" of the underlying transport. +func (r *ParallelResolver) Network() string { + return r.Txp.Network() +} + +// Address returns the "address" of the underlying transport. +func (r *ParallelResolver) Address() string { + return r.Txp.Address() +} + +// CloseIdleConnections closes idle connections, if any. +func (r *ParallelResolver) CloseIdleConnections() { + r.Txp.CloseIdleConnections() +} + +// LookupHost performs an A lookup in parallel with an AAAA lookup. +func (r *ParallelResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { + ach := make(chan *parallelResolverResult) + go r.lookupHost(ctx, hostname, dns.TypeA, ach) + aaaach := make(chan *parallelResolverResult) + go r.lookupHost(ctx, hostname, dns.TypeAAAA, aaaach) + ares := <-ach + aaaares := <-aaaach + if ares.err != nil && aaaares.err != nil { + // Note: we choose to return the A error because we assume that + // it's the more meaningful one: the AAAA error may just be telling + // us that there is no AAAA record for the website. + return nil, ares.err + } + var addrs []string + addrs = append(addrs, ares.addrs...) + addrs = append(addrs, aaaares.addrs...) + return addrs, nil +} + +// LookupHTTPS implements Resolver.LookupHTTPS. +func (r *ParallelResolver) LookupHTTPS( + ctx context.Context, hostname string) (*model.HTTPSSvc, error) { + querydata, err := r.Encoder.Encode( + hostname, dns.TypeHTTPS, r.Txp.RequiresPadding()) + if err != nil { + return nil, err + } + replydata, err := r.Txp.RoundTrip(ctx, querydata) + if err != nil { + return nil, err + } + return r.Decoder.DecodeHTTPS(replydata) +} + +// parallelResolverResult is the internal representation of a +// lookup using either the A or the AAAA query type. +type parallelResolverResult struct { + addrs []string + err error +} + +// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A). +func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string, + qtype uint16, out chan<- *parallelResolverResult) { + querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding()) + if err != nil { + out <- ¶llelResolverResult{ + addrs: []string{}, + err: err, + } + return + } + replydata, err := r.Txp.RoundTrip(ctx, querydata) + if err != nil { + out <- ¶llelResolverResult{ + addrs: []string{}, + err: err, + } + return + } + addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata) + out <- ¶llelResolverResult{ + addrs: addrs, + err: err, + } +} diff --git a/internal/netxlite/parallelresolver_test.go b/internal/netxlite/parallelresolver_test.go new file mode 100644 index 0000000..772340a --- /dev/null +++ b/internal/netxlite/parallelresolver_test.go @@ -0,0 +1,269 @@ +package netxlite + +import ( + "context" + "crypto/tls" + "errors" + "testing" + + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/atomicx" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/model/mocks" +) + +func TestParallelResolver(t *testing.T) { + t.Run("transport okay", func(t *testing.T) { + txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") + r := NewUnwrappedParallelResolver(txp) + rtx := r.Transport() + if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { + t.Fatal("not the transport we expected") + } + if r.Network() != rtx.Network() { + t.Fatal("invalid network seen from the resolver") + } + if r.Address() != rtx.Address() { + t.Fatal("invalid address seen from the resolver") + } + }) + + t.Run("LookupHost", func(t *testing.T) { + t.Run("Encode error", func(t *testing.T) { + mocked := errors.New("mocked error") + txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") + r := ParallelResolver{ + Encoder: &mocks.DNSEncoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return nil, mocked + }, + }, + Txp: txp, + } + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil address here") + } + }) + + t.Run("RoundTrip error", func(t *testing.T) { + mocked := errors.New("mocked error") + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return nil, mocked + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewUnwrappedParallelResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil address here") + } + }) + + t.Run("empty reply", func(t *testing.T) { + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return dnsGenLookupHostReplySuccess(t, dns.TypeA), nil + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewUnwrappedParallelResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if !errors.Is(err, ErrOODNSNoAnswer) { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil address here") + } + }) + + t.Run("with A reply", func(t *testing.T) { + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return dnsGenLookupHostReplySuccess(t, dns.TypeA, "8.8.8.8"), nil + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewUnwrappedParallelResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("not the result we expected") + } + }) + + t.Run("with AAAA reply", func(t *testing.T) { + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1"), nil + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewUnwrappedParallelResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "::1" { + t.Fatal("not the result we expected") + } + }) + + t.Run("A failure takes precedence over AAAA failure", func(t *testing.T) { + afailure := errors.New("a failure") + aaaafailure := errors.New("aaaa failure") + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + msg := &dns.Msg{} + if err := msg.Unpack(query); err != nil { + return nil, err + } + if len(msg.Question) != 1 { + return nil, errors.New("expected just one question") + } + q := msg.Question[0] + if q.Qtype == dns.TypeA { + return nil, afailure + } + if q.Qtype == dns.TypeAAAA { + return nil, aaaafailure + } + return nil, errors.New("expected A or AAAA query") + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewUnwrappedParallelResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if !errors.Is(err, afailure) { + t.Fatal("unexpected error", err) + } + if len(addrs) != 0 { + t.Fatal("not the result we expected") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + r := &ParallelResolver{ + Txp: &mocks.DNSTransport{ + MockCloseIdleConnections: func() { + called = true + }, + }, + } + r.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("LookupHTTPS", func(t *testing.T) { + t.Run("for encoding error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &ParallelResolver{ + Encoder: &mocks.DNSEncoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return nil, expected + }, + }, + Decoder: nil, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.DNSTransport{ + MockRequiresPadding: func() bool { + return false + }, + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) + + t.Run("for round-trip error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &ParallelResolver{ + Encoder: &mocks.DNSEncoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return make([]byte, 64), nil + }, + }, + Decoder: nil, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return nil, expected + }, + MockRequiresPadding: func() bool { + return false + }, + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) + + t.Run("for decode error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &ParallelResolver{ + Encoder: &mocks.DNSEncoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return make([]byte, 64), nil + }, + }, + Decoder: &mocks.DNSDecoder{ + MockDecodeHTTPS: func(reply []byte) (*model.HTTPSSvc, error) { + return nil, expected + }, + }, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return make([]byte, 128), nil + }, + MockRequiresPadding: func() bool { + return false + }, + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) + }) +} diff --git a/internal/netxlite/serialresolver.go b/internal/netxlite/serialresolver.go index d662a31..5b87e9b 100644 --- a/internal/netxlite/serialresolver.go +++ b/internal/netxlite/serialresolver.go @@ -1,5 +1,9 @@ package netxlite +// +// Serial resolver implementation +// + import ( "context" "errors" @@ -10,11 +14,13 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" ) -// SerialResolver uses a transport and sends performs a LookupHost +// SerialResolver uses a transport and performs a LookupHost // operation in a serial fashion (query for A first, wait for response, // then query for AAAA, and wait for response), hence its name. // // You should probably use NewSerialResolver to create a new instance. +// +// Deprecated: please use ParallelResolver in new code. type SerialResolver struct { // Encoder is the MANDATORY encoder to use. Encoder model.DNSEncoder @@ -25,7 +31,7 @@ type SerialResolver struct { // NumTimeouts is MANDATORY and counts the number of timeouts. NumTimeouts *atomicx.Int64 - // Txp is the underlying DNS transport. + // Txp is the MANDATORY underlying DNS transport. Txp model.DNSTransport }