feat(netxlite): implement parallel resolver (#724)
This diff imports the parallel resolver from websteps winter 2022
edition, which was originally implemented here:
55231d73cd
See https://github.com/ooni/probe/issues/2096
This commit is contained in:
parent
0efd4ff130
commit
ec0561ea8c
131
internal/netxlite/parallelresolver.go
Normal file
131
internal/netxlite/parallelresolver.go
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
package netxlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// Parallel resolver implementation
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParallelResolver uses a transport and performs a LookupHost
|
||||||
|
// operation in a parallel fashion, hence its name.
|
||||||
|
//
|
||||||
|
// You should probably use NewUnwrappedParallelResolver to
|
||||||
|
// create a new instance of this type.
|
||||||
|
type ParallelResolver struct {
|
||||||
|
// Encoder is the MANDATORY encoder to use.
|
||||||
|
Encoder model.DNSEncoder
|
||||||
|
|
||||||
|
// Decoder is the MANDATORY decoder to use.
|
||||||
|
Decoder model.DNSDecoder
|
||||||
|
|
||||||
|
// NumTimeouts is MANDATORY and counts the number of timeouts.
|
||||||
|
NumTimeouts *atomicx.Int64
|
||||||
|
|
||||||
|
// Txp is the MANDATORY underlying DNS transport.
|
||||||
|
Txp model.DNSTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is
|
||||||
|
// not wrapped and you should wrap if before using it.
|
||||||
|
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
|
||||||
|
return &ParallelResolver{
|
||||||
|
Encoder: &DNSEncoderMiekg{},
|
||||||
|
Decoder: &DNSDecoderMiekg{},
|
||||||
|
NumTimeouts: &atomicx.Int64{},
|
||||||
|
Txp: t,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transport returns the transport being used.
|
||||||
|
func (r *ParallelResolver) Transport() model.DNSTransport {
|
||||||
|
return r.Txp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Network returns the "network" of the underlying transport.
|
||||||
|
func (r *ParallelResolver) Network() string {
|
||||||
|
return r.Txp.Network()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Address returns the "address" of the underlying transport.
|
||||||
|
func (r *ParallelResolver) Address() string {
|
||||||
|
return r.Txp.Address()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseIdleConnections closes idle connections, if any.
|
||||||
|
func (r *ParallelResolver) CloseIdleConnections() {
|
||||||
|
r.Txp.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupHost performs an A lookup in parallel with an AAAA lookup.
|
||||||
|
func (r *ParallelResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||||
|
ach := make(chan *parallelResolverResult)
|
||||||
|
go r.lookupHost(ctx, hostname, dns.TypeA, ach)
|
||||||
|
aaaach := make(chan *parallelResolverResult)
|
||||||
|
go r.lookupHost(ctx, hostname, dns.TypeAAAA, aaaach)
|
||||||
|
ares := <-ach
|
||||||
|
aaaares := <-aaaach
|
||||||
|
if ares.err != nil && aaaares.err != nil {
|
||||||
|
// Note: we choose to return the A error because we assume that
|
||||||
|
// it's the more meaningful one: the AAAA error may just be telling
|
||||||
|
// us that there is no AAAA record for the website.
|
||||||
|
return nil, ares.err
|
||||||
|
}
|
||||||
|
var addrs []string
|
||||||
|
addrs = append(addrs, ares.addrs...)
|
||||||
|
addrs = append(addrs, aaaares.addrs...)
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupHTTPS implements Resolver.LookupHTTPS.
|
||||||
|
func (r *ParallelResolver) LookupHTTPS(
|
||||||
|
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
|
||||||
|
querydata, err := r.Encoder.Encode(
|
||||||
|
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return r.Decoder.DecodeHTTPS(replydata)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parallelResolverResult is the internal representation of a
|
||||||
|
// lookup using either the A or the AAAA query type.
|
||||||
|
type parallelResolverResult struct {
|
||||||
|
addrs []string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
|
||||||
|
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
||||||
|
qtype uint16, out chan<- *parallelResolverResult) {
|
||||||
|
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||||
|
if err != nil {
|
||||||
|
out <- ¶llelResolverResult{
|
||||||
|
addrs: []string{},
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
||||||
|
if err != nil {
|
||||||
|
out <- ¶llelResolverResult{
|
||||||
|
addrs: []string{},
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata)
|
||||||
|
out <- ¶llelResolverResult{
|
||||||
|
addrs: addrs,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
269
internal/netxlite/parallelresolver_test.go
Normal file
269
internal/netxlite/parallelresolver_test.go
Normal file
|
@ -0,0 +1,269 @@
|
||||||
|
package netxlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParallelResolver(t *testing.T) {
|
||||||
|
t.Run("transport okay", func(t *testing.T) {
|
||||||
|
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||||
|
r := NewUnwrappedParallelResolver(txp)
|
||||||
|
rtx := r.Transport()
|
||||||
|
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||||
|
t.Fatal("not the transport we expected")
|
||||||
|
}
|
||||||
|
if r.Network() != rtx.Network() {
|
||||||
|
t.Fatal("invalid network seen from the resolver")
|
||||||
|
}
|
||||||
|
if r.Address() != rtx.Address() {
|
||||||
|
t.Fatal("invalid address seen from the resolver")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
|
t.Run("Encode error", func(t *testing.T) {
|
||||||
|
mocked := errors.New("mocked error")
|
||||||
|
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||||
|
r := ParallelResolver{
|
||||||
|
Encoder: &mocks.DNSEncoder{
|
||||||
|
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||||
|
return nil, mocked
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Txp: txp,
|
||||||
|
}
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
|
if !errors.Is(err, mocked) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("expected nil address here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("RoundTrip error", func(t *testing.T) {
|
||||||
|
mocked := errors.New("mocked error")
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
|
return nil, mocked
|
||||||
|
},
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r := NewUnwrappedParallelResolver(txp)
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
|
if !errors.Is(err, mocked) {
|
||||||
|
t.Fatal("not the error we expected")
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("expected nil address here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty reply", func(t *testing.T) {
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
|
return dnsGenLookupHostReplySuccess(t, dns.TypeA), nil
|
||||||
|
},
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r := NewUnwrappedParallelResolver(txp)
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
|
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if addrs != nil {
|
||||||
|
t.Fatal("expected nil address here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with A reply", func(t *testing.T) {
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
|
return dnsGenLookupHostReplySuccess(t, dns.TypeA, "8.8.8.8"), nil
|
||||||
|
},
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r := NewUnwrappedParallelResolver(txp)
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(addrs) != 1 || addrs[0] != "8.8.8.8" {
|
||||||
|
t.Fatal("not the result we expected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with AAAA reply", func(t *testing.T) {
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
|
return dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1"), nil
|
||||||
|
},
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r := NewUnwrappedParallelResolver(txp)
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(addrs) != 1 || addrs[0] != "::1" {
|
||||||
|
t.Fatal("not the result we expected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("A failure takes precedence over AAAA failure", func(t *testing.T) {
|
||||||
|
afailure := errors.New("a failure")
|
||||||
|
aaaafailure := errors.New("aaaa failure")
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
|
msg := &dns.Msg{}
|
||||||
|
if err := msg.Unpack(query); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(msg.Question) != 1 {
|
||||||
|
return nil, errors.New("expected just one question")
|
||||||
|
}
|
||||||
|
q := msg.Question[0]
|
||||||
|
if q.Qtype == dns.TypeA {
|
||||||
|
return nil, afailure
|
||||||
|
}
|
||||||
|
if q.Qtype == dns.TypeAAAA {
|
||||||
|
return nil, aaaafailure
|
||||||
|
}
|
||||||
|
return nil, errors.New("expected A or AAAA query")
|
||||||
|
},
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r := NewUnwrappedParallelResolver(txp)
|
||||||
|
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
||||||
|
if !errors.Is(err, afailure) {
|
||||||
|
t.Fatal("unexpected error", err)
|
||||||
|
}
|
||||||
|
if len(addrs) != 0 {
|
||||||
|
t.Fatal("not the result we expected")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
r := &ParallelResolver{
|
||||||
|
Txp: &mocks.DNSTransport{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||||
|
t.Run("for encoding error", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := &ParallelResolver{
|
||||||
|
Encoder: &mocks.DNSEncoder{
|
||||||
|
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Decoder: nil,
|
||||||
|
NumTimeouts: &atomicx.Int64{},
|
||||||
|
Txp: &mocks.DNSTransport{
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
https, err := r.LookupHTTPS(ctx, "example.com")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if https != nil {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("for round-trip error", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := &ParallelResolver{
|
||||||
|
Encoder: &mocks.DNSEncoder{
|
||||||
|
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||||
|
return make([]byte, 64), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Decoder: nil,
|
||||||
|
NumTimeouts: &atomicx.Int64{},
|
||||||
|
Txp: &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
https, err := r.LookupHTTPS(ctx, "example.com")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if https != nil {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("for decode error", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := &ParallelResolver{
|
||||||
|
Encoder: &mocks.DNSEncoder{
|
||||||
|
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||||
|
return make([]byte, 64), nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Decoder: &mocks.DNSDecoder{
|
||||||
|
MockDecodeHTTPS: func(reply []byte) (*model.HTTPSSvc, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NumTimeouts: &atomicx.Int64{},
|
||||||
|
Txp: &mocks.DNSTransport{
|
||||||
|
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
MockRequiresPadding: func() bool {
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
https, err := r.LookupHTTPS(ctx, "example.com")
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if https != nil {
|
||||||
|
t.Fatal("unexpected result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,5 +1,9 @@
|
||||||
package netxlite
|
package netxlite
|
||||||
|
|
||||||
|
//
|
||||||
|
// Serial resolver implementation
|
||||||
|
//
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -10,11 +14,13 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SerialResolver uses a transport and sends performs a LookupHost
|
// SerialResolver uses a transport and performs a LookupHost
|
||||||
// operation in a serial fashion (query for A first, wait for response,
|
// operation in a serial fashion (query for A first, wait for response,
|
||||||
// then query for AAAA, and wait for response), hence its name.
|
// then query for AAAA, and wait for response), hence its name.
|
||||||
//
|
//
|
||||||
// You should probably use NewSerialResolver to create a new instance.
|
// You should probably use NewSerialResolver to create a new instance.
|
||||||
|
//
|
||||||
|
// Deprecated: please use ParallelResolver in new code.
|
||||||
type SerialResolver struct {
|
type SerialResolver struct {
|
||||||
// Encoder is the MANDATORY encoder to use.
|
// Encoder is the MANDATORY encoder to use.
|
||||||
Encoder model.DNSEncoder
|
Encoder model.DNSEncoder
|
||||||
|
@ -25,7 +31,7 @@ type SerialResolver struct {
|
||||||
// NumTimeouts is MANDATORY and counts the number of timeouts.
|
// NumTimeouts is MANDATORY and counts the number of timeouts.
|
||||||
NumTimeouts *atomicx.Int64
|
NumTimeouts *atomicx.Int64
|
||||||
|
|
||||||
// Txp is the underlying DNS transport.
|
// Txp is the MANDATORY underlying DNS transport.
|
||||||
Txp model.DNSTransport
|
Txp model.DNSTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user