feat: dnsping using step-by-step (#831)
Reference issue for this pull request: https://github.com/ooni/probe/issues/2159 This diff refactors the `dnsping` experiment to use the [step-by-step measurement style](https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md). Co-authored-by: decfox <decfox@github.com> Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
@@ -0,0 +1,170 @@
|
||||
package measurexlite
|
||||
|
||||
//
|
||||
// DNS Lookup with tracing
|
||||
//
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/geolocate"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/tracex"
|
||||
)
|
||||
|
||||
// newParallelResolverTrace is equivalent to netxlite.NewParallelResolver
|
||||
// except that it returns a model.Resolver that uses this trace.
|
||||
func (tx *Trace) newParallelResolverTrace(newResolver func() model.Resolver) model.Resolver {
|
||||
return &resolverTrace{
|
||||
r: tx.newParallelResolver(newResolver),
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
// resolverTrace is a trace-aware resolver
|
||||
type resolverTrace struct {
|
||||
r model.Resolver
|
||||
tx *Trace
|
||||
}
|
||||
|
||||
var _ model.Resolver = &resolverTrace{}
|
||||
|
||||
// Address implements model.Resolver.Address
|
||||
func (r *resolverTrace) Address() string {
|
||||
return r.r.Address()
|
||||
}
|
||||
|
||||
// Network implements model.Resolver.Network
|
||||
func (r *resolverTrace) Network() string {
|
||||
return r.r.Network()
|
||||
}
|
||||
|
||||
// CloseIdleConnections implements model.Resolver.CloseIdleConnections
|
||||
func (r *resolverTrace) CloseIdleConnections() {
|
||||
r.r.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// LookupHost implements model.Resolver.LookupHost
|
||||
func (r *resolverTrace) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
return r.r.LookupHost(netxlite.ContextWithTrace(ctx, r.tx), hostname)
|
||||
}
|
||||
|
||||
// LookupHTTPS implements model.Resolver.LookupHTTPS
|
||||
func (r *resolverTrace) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
|
||||
return r.r.LookupHTTPS(netxlite.ContextWithTrace(ctx, r.tx), domain)
|
||||
}
|
||||
|
||||
// LookupNS implements model.Resolver.LookupNS
|
||||
func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) {
|
||||
return r.r.LookupNS(netxlite.ContextWithTrace(ctx, r.tx), domain)
|
||||
}
|
||||
|
||||
// NewParallelUDPResolver returns a trace-ware parallel UDP resolver
|
||||
func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
|
||||
return tx.newParallelResolverTrace(func() model.Resolver {
|
||||
return netxlite.NewParallelUDPResolver(logger, dialer, address)
|
||||
})
|
||||
}
|
||||
|
||||
// NewParallelDNSOverHTTPSResolver returns a trace-aware parallel DoH resolver
|
||||
func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver {
|
||||
return tx.newParallelResolverTrace(func() model.Resolver {
|
||||
return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL)
|
||||
})
|
||||
}
|
||||
|
||||
// OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost
|
||||
func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery,
|
||||
response model.DNSResponse, addrs []string, err error, finished time.Time) {
|
||||
ch := tx.DNSLookup[query.Type()]
|
||||
if ch == nil {
|
||||
// Prevent blocking forever. See https://dave.cheney.net/2014/03/19/channel-axioms.
|
||||
log.Printf("BUG: Requested query type %s has no valid channel to buffer results", dns.TypeToString[query.Type()])
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- NewArchivalDNSLookupResultFromRoundTrip(
|
||||
tx.Index,
|
||||
started.Sub(tx.ZeroTime),
|
||||
reso,
|
||||
query,
|
||||
response,
|
||||
addrs,
|
||||
err,
|
||||
finished.Sub(tx.ZeroTime),
|
||||
):
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip
|
||||
// from the available information right after the DNS RoundTrip
|
||||
func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso model.Resolver, query model.DNSQuery,
|
||||
response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult {
|
||||
return &model.ArchivalDNSLookupResult{
|
||||
Answers: archivalAnswersFromAddrs(addrs),
|
||||
Engine: reso.Network(),
|
||||
Failure: tracex.NewFailure(err),
|
||||
Hostname: query.Domain(),
|
||||
QueryType: dns.TypeToString[query.Type()],
|
||||
ResolverHostname: nil,
|
||||
ResolverAddress: reso.Address(),
|
||||
T: finished.Seconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// archivalAnswersFromAddrs generates model.ArchivalDNSAnswer from an array of addresses
|
||||
func archivalAnswersFromAddrs(addrs []string) (out []model.ArchivalDNSAnswer) {
|
||||
for _, addr := range addrs {
|
||||
ipv6, err := netxlite.IsIPv6(addr)
|
||||
if err != nil {
|
||||
log.Printf("BUG: NewArchivalDNSLookupResult: invalid IP address: %s", addr)
|
||||
continue
|
||||
}
|
||||
asn, org, _ := geolocate.LookupASN(addr)
|
||||
switch ipv6 {
|
||||
case false:
|
||||
out = append(out, model.ArchivalDNSAnswer{
|
||||
ASN: int64(asn),
|
||||
ASOrgName: org,
|
||||
AnswerType: "A",
|
||||
Hostname: "",
|
||||
IPv4: addr,
|
||||
TTL: nil,
|
||||
})
|
||||
case true:
|
||||
out = append(out, model.ArchivalDNSAnswer{
|
||||
ASN: int64(asn),
|
||||
ASOrgName: org,
|
||||
AnswerType: "AAAA",
|
||||
Hostname: "",
|
||||
IPv6: addr,
|
||||
TTL: nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DNSLookupsFromRoundTrip drains the network events buffered inside the corresponding query channel
|
||||
func (tx *Trace) DNSLookupsFromRoundTrip(query uint16) (out []*model.ArchivalDNSLookupResult) {
|
||||
ch := tx.DNSLookup[query]
|
||||
if ch == nil {
|
||||
// Prevent blocking forever. See https://dave.cheney.net/2014/03/19/channel-axioms.
|
||||
log.Printf("BUG: Requested query type %s has no valid channel to buffer results", dns.TypeToString[query])
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case ev := <-ch:
|
||||
out = append(out, ev)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,296 @@
|
||||
package measurexlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestNewUnwrappedParallelResolver(t *testing.T) {
|
||||
t.Run("NewUnwrappedParallelResolver creates an UnwrappedParallelResolver with Trace", func(t *testing.T) {
|
||||
underlying := &mocks.Resolver{}
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.NewParallelResolverFn = func() model.Resolver {
|
||||
return underlying
|
||||
}
|
||||
resolver := trace.newParallelResolverTrace(func() model.Resolver {
|
||||
return nil
|
||||
})
|
||||
resolvert := resolver.(*resolverTrace)
|
||||
if resolvert.r != underlying {
|
||||
t.Fatal("invalid parallel resolver")
|
||||
}
|
||||
if resolvert.tx != trace {
|
||||
t.Fatal("invalid trace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Trace-aware resolver forwards underlying functions", func(t *testing.T) {
|
||||
var called bool
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
newMockResolver := func() model.Resolver {
|
||||
return &mocks.Resolver{
|
||||
MockAddress: func() string {
|
||||
return "dns.google"
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return "udp"
|
||||
},
|
||||
MockCloseIdleConnections: func() {
|
||||
called = true
|
||||
},
|
||||
}
|
||||
}
|
||||
resolver := trace.newParallelResolver(newMockResolver)
|
||||
|
||||
t.Run("Address is correctly forwarded", func(t *testing.T) {
|
||||
got := resolver.Address()
|
||||
if got != "dns.google" {
|
||||
t.Fatal("Address not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Network is correctly forwarded", func(t *testing.T) {
|
||||
got := resolver.Network()
|
||||
if got != "udp" {
|
||||
t.Fatal("Network not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CloseIdleConnections is correctly forwarded", func(t *testing.T) {
|
||||
resolver.CloseIdleConnections()
|
||||
if !called {
|
||||
t.Fatal("CloseIdleConnections not called")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("LookupHost saves into trace", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
if query.Type() != dns.TypeA {
|
||||
return []string{"fe80::a00:20ff:feb9:4c54"}, nil
|
||||
}
|
||||
return []string{"1.1.1.1"}, nil
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return ""
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return "dns.google"
|
||||
},
|
||||
}
|
||||
newResolver := func() model.Resolver {
|
||||
return netxlite.NewUnwrappedParallelResolver(txp)
|
||||
}
|
||||
resolver := trace.newParallelResolverTrace(newResolver)
|
||||
ctx := context.Background()
|
||||
addrs, err := resolver.LookupHost(ctx, "example.com")
|
||||
if err != nil {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(addrs) != 2 {
|
||||
t.Fatal("unexpected array output", addrs)
|
||||
}
|
||||
if addrs[0] != "1.1.1.1" && addrs[1] != "1.1.1.1" {
|
||||
t.Fatal("unexpected array output", addrs)
|
||||
}
|
||||
if addrs[0] != "fe80::a00:20ff:feb9:4c54" && addrs[1] != "fe80::a00:20ff:feb9:4c54" {
|
||||
t.Fatal("unexpected array output", addrs)
|
||||
}
|
||||
|
||||
t.Run("DNSLookups QueryType A", func(t *testing.T) {
|
||||
events := trace.DNSLookupsFromRoundTrip(dns.TypeA)
|
||||
if len(events) != 1 {
|
||||
t.Fatal("expected to see single DNSLookup event")
|
||||
}
|
||||
lookup := events[0]
|
||||
answers := lookup.Answers
|
||||
if lookup.Failure != nil {
|
||||
t.Fatal("unexpected err", *(lookup.Failure))
|
||||
}
|
||||
if lookup.ResolverAddress != "dns.google" {
|
||||
t.Fatal("unexpected address field")
|
||||
}
|
||||
if len(answers) != 1 {
|
||||
t.Fatal("expected 1 DNS answer, got", len(answers))
|
||||
}
|
||||
if answers[0].AnswerType != "A" || answers[0].IPv4 != "1.1.1.1" {
|
||||
t.Fatal("unexpected DNS answer", answers)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DNSLookups QueryType AAAA", func(t *testing.T) {
|
||||
events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA)
|
||||
if len(events) != 1 {
|
||||
t.Fatal("expected to see single DNSLookup event")
|
||||
}
|
||||
lookup := events[0]
|
||||
answers := lookup.Answers
|
||||
if lookup.Failure != nil {
|
||||
t.Fatal("unexpected err", *(lookup.Failure))
|
||||
}
|
||||
if lookup.ResolverAddress != "dns.google" {
|
||||
t.Fatal("unexpected address field")
|
||||
}
|
||||
if len(answers) != 1 {
|
||||
t.Fatal("expected 1 DNS answer, got", len(answers))
|
||||
}
|
||||
if answers[0].AnswerType != "AAAA" || answers[0].IPv6 != "fe80::a00:20ff:feb9:4c54" {
|
||||
t.Fatal("unexpected DNS answer", answers)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("LookupHost discards events when buffers are full", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{
|
||||
dns.TypeA: make(chan *model.ArchivalDNSLookupResult), // no buffer
|
||||
dns.TypeAAAA: make(chan *model.ArchivalDNSLookupResult), // no buffer
|
||||
}
|
||||
trace.TimeNowFn = td.Now
|
||||
txp := &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
response := &mocks.DNSResponse{
|
||||
MockDecodeLookupHost: func() ([]string, error) {
|
||||
if query.Type() != dns.TypeA {
|
||||
return []string{"fe80::a00:20ff:feb9:4c54"}, nil
|
||||
}
|
||||
return []string{"1.1.1.1"}, nil
|
||||
},
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
MockRequiresPadding: func() bool {
|
||||
return true
|
||||
},
|
||||
MockNetwork: func() string {
|
||||
return ""
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return "dns.google"
|
||||
},
|
||||
}
|
||||
newResolver := func() model.Resolver {
|
||||
return netxlite.NewUnwrappedParallelResolver(txp)
|
||||
}
|
||||
resolver := trace.newParallelResolverTrace(newResolver)
|
||||
ctx := context.Background()
|
||||
addrs, err := resolver.LookupHost(ctx, "example.com")
|
||||
if err != nil {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(addrs) != 2 {
|
||||
t.Fatal("unexpected array output", addrs)
|
||||
}
|
||||
|
||||
t.Run("DNSLookups QueryType A", func(t *testing.T) {
|
||||
events := trace.DNSLookupsFromRoundTrip(dns.TypeA)
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected to see no DNSLookup")
|
||||
}
|
||||
})
|
||||
t.Run("DNSLookups QueryType AAAA", func(t *testing.T) {
|
||||
events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA)
|
||||
if len(events) != 0 {
|
||||
t.Fatal("expected to see no DNSLookup")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnswersFromAddrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
}{{
|
||||
name: "with valid input",
|
||||
args: []string{"1.1.1.1", "fe80::a00:20ff:feb9:4c54"},
|
||||
}, {
|
||||
name: "with invalid IPv4 address",
|
||||
args: []string{"1.1.1.1.1", "fe80::a00:20ff:feb9:4c54"},
|
||||
}, {
|
||||
name: "with invalid IPv6 address",
|
||||
args: []string{"1.1.1.1", "fe80::a00:20ff:feb9:::4c54"},
|
||||
}, {
|
||||
name: "with empty input",
|
||||
args: []string{},
|
||||
}, {
|
||||
name: "with nil input",
|
||||
args: nil,
|
||||
}}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := archivalAnswersFromAddrs(tt.args)
|
||||
var idx int
|
||||
for _, inp := range tt.args {
|
||||
ip6, err := netxlite.IsIPv6(inp)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if idx >= len(got) {
|
||||
t.Fatal("unexpected array length")
|
||||
}
|
||||
answer := got[idx]
|
||||
if ip6 {
|
||||
if answer.AnswerType != "AAAA" || answer.IPv6 != inp {
|
||||
t.Fatal("unexpected output", answer)
|
||||
}
|
||||
} else {
|
||||
if answer.AnswerType != "A" || answer.IPv4 != inp {
|
||||
t.Fatal("unexpected output", answer)
|
||||
}
|
||||
}
|
||||
idx++
|
||||
}
|
||||
if idx != len(got) {
|
||||
t.Fatal("unexpected array length", len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSLookupsFromRoundTrips(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
trace := NewTrace(0, zeroTime)
|
||||
checkPanic := func(query uint16, f func(uint16) []*model.ArchivalDNSLookupResult) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatal("unexpected panic encoutered")
|
||||
}
|
||||
}()
|
||||
f(query)
|
||||
}
|
||||
t.Run("DNSLookup is nil", func(t *testing.T) {
|
||||
trace.DNSLookup = nil
|
||||
checkPanic(dns.TypeA, trace.DNSLookupsFromRoundTrip)
|
||||
})
|
||||
t.Run("Query has nil channel", func(t *testing.T) {
|
||||
trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{
|
||||
dns.TypeA: nil,
|
||||
}
|
||||
checkPanic(dns.TypeA, trace.DNSLookupsFromRoundTrip)
|
||||
})
|
||||
}
|
||||
@@ -7,6 +7,7 @@ package measurexlite
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
||||
)
|
||||
@@ -31,13 +32,17 @@ import (
|
||||
type Trace struct {
|
||||
// Index is the MANDATORY unique index of this trace within the
|
||||
// current measurement. If you don't care about uniquely identifying
|
||||
// treaces, you can use zero to indicate the "default" trace.
|
||||
// traces, you can use zero to indicate the "default" trace.
|
||||
Index int64
|
||||
|
||||
// NetworkEvent is MANDATORY and buffers network events. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
NetworkEvent chan *model.ArchivalNetworkEvent
|
||||
|
||||
// NewParallelResolverFn is OPTIONAL and can be used to overide
|
||||
// calls to the netxlite.NewParallelResolver factory.
|
||||
NewParallelResolverFn func() model.Resolver
|
||||
|
||||
// NewDialerWithoutResolverFn is OPTIONAL and can be used to override
|
||||
// calls to the netxlite.NewDialerWithoutResolver factory.
|
||||
NewDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer
|
||||
@@ -46,6 +51,14 @@ type Trace struct {
|
||||
// calls to the netxlite.NewTLSHandshakerStdlib factory.
|
||||
NewTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker
|
||||
|
||||
// DNSLookup is MANDATORY and buffers DNSLookup results based on the
|
||||
// query type. When we create this map using NewTrace, we will create
|
||||
// an entry for each dns.Type in DNSQueryTypes. If you create this channel
|
||||
// manually, you probably want to to the same (and most likely you also
|
||||
// want to create buffered channels). Note that the code will print a
|
||||
// warning and otherwise ignore all the query types not included in this map.
|
||||
DNSLookup map[uint16]chan *model.ArchivalDNSLookupResult
|
||||
|
||||
// TCPConnect is MANDATORY and buffers TCP connect observations. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
TCPConnect chan *model.ArchivalTCPConnectResult
|
||||
@@ -67,6 +80,10 @@ const (
|
||||
// the Trace's NetworkEvent buffered channel.
|
||||
NetworkEventBufferSize = 64
|
||||
|
||||
// DNSLookupBufferSize is the buffer size for constructing
|
||||
// the Trace's DNSLookup map of buffered channels.
|
||||
DNSLookupBufferSize = 8
|
||||
|
||||
// TCPConnectBufferSize is the buffer size for constructing
|
||||
// the Trace's TCPConnect buffered channel.
|
||||
TCPConnectBufferSize = 8
|
||||
@@ -76,6 +93,25 @@ const (
|
||||
TLSHandshakeBufferSize = 8
|
||||
)
|
||||
|
||||
// DNSQueryTypes contains the list of DNS query types for which
|
||||
// NewTrace create entries in Trace.DNSLookup.
|
||||
var DNSQueryTypes = []uint16{
|
||||
dns.TypeANY,
|
||||
dns.TypeA,
|
||||
dns.TypeAAAA,
|
||||
dns.TypeCNAME,
|
||||
dns.TypeNS,
|
||||
}
|
||||
|
||||
// newDefaultDNSLookupMap is a convenience factory for creating Trace.DNSLookup
|
||||
func newDefaultDNSLookupMap() map[uint16]chan *model.ArchivalDNSLookupResult {
|
||||
out := make(map[uint16]chan *model.ArchivalDNSLookupResult)
|
||||
for _, qtype := range DNSQueryTypes {
|
||||
out[qtype] = make(chan *model.ArchivalDNSLookupResult, DNSLookupBufferSize)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// NewTrace creates a new instance of Trace using default settings.
|
||||
//
|
||||
// We create buffered channels using as buffer sizes the constants that
|
||||
@@ -96,6 +132,7 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
|
||||
),
|
||||
NewDialerWithoutResolverFn: nil, // use default
|
||||
NewTLSHandshakerStdlibFn: nil, // use default
|
||||
DNSLookup: newDefaultDNSLookupMap(),
|
||||
TCPConnect: make(
|
||||
chan *model.ArchivalTCPConnectResult,
|
||||
TCPConnectBufferSize,
|
||||
@@ -110,7 +147,7 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
|
||||
}
|
||||
|
||||
// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver
|
||||
// thus allows us to mock this func for testing.
|
||||
// thus allowing us to mock this func for testing.
|
||||
func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
|
||||
if tx.NewDialerWithoutResolverFn != nil {
|
||||
return tx.NewDialerWithoutResolverFn(dl)
|
||||
@@ -118,6 +155,15 @@ func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer {
|
||||
return netxlite.NewDialerWithoutResolver(dl)
|
||||
}
|
||||
|
||||
// newParallelResolver indirectly calls the passed netxlite.NewParallerResolver
|
||||
// thus allowing us to mock this function for testing
|
||||
func (tx *Trace) newParallelResolver(newResolver func() model.Resolver) model.Resolver {
|
||||
if tx.NewParallelResolverFn != nil {
|
||||
return tx.NewParallelResolverFn()
|
||||
}
|
||||
return newResolver()
|
||||
}
|
||||
|
||||
// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib
|
||||
// thus allowing us to mock this func for testing.
|
||||
func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
|
||||
|
||||
@@ -46,6 +46,12 @@ func TestNewTrace(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NewParallelResolverFn is nil", func(t *testing.T) {
|
||||
if trace.NewParallelResolverFn != nil {
|
||||
t.Fatal("expected nil NewUnwrappedParallelResolverFn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) {
|
||||
if trace.NewDialerWithoutResolverFn != nil {
|
||||
t.Fatal("expected nil NewDialerWithoutResolverFn")
|
||||
@@ -58,6 +64,27 @@ func TestNewTrace(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DNSLookup has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
for _, qtype := range DNSQueryTypes {
|
||||
var count int
|
||||
Loop:
|
||||
for {
|
||||
ev := &model.ArchivalDNSLookupResult{}
|
||||
ff.Fill(ev)
|
||||
select {
|
||||
case trace.DNSLookup[qtype] <- ev:
|
||||
count++
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
if count != DNSLookupBufferSize {
|
||||
t.Fatal("invalid DNSLookup A channel buffer size")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
@@ -111,6 +138,57 @@ func TestNewTrace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTrace(t *testing.T) {
|
||||
t.Run("NewParallelResolverFn works as intended", func(t *testing.T) {
|
||||
t.Run("when not nil", func(t *testing.T) {
|
||||
mockedErr := errors.New("mocked")
|
||||
tx := &Trace{
|
||||
NewParallelResolverFn: func() model.Resolver {
|
||||
return &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return []string{}, mockedErr
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
resolver := tx.newParallelResolver(func() model.Resolver {
|
||||
return nil
|
||||
})
|
||||
ctx := context.Background()
|
||||
addrs, err := resolver.LookupHost(ctx, "example.com")
|
||||
if !errors.Is(err, mockedErr) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(addrs) != 0 {
|
||||
t.Fatal("expected array of size 0")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when nil", func(t *testing.T) {
|
||||
tx := &Trace{
|
||||
NewParallelResolverFn: nil,
|
||||
}
|
||||
newResolver := func() model.Resolver {
|
||||
return &mocks.Resolver{
|
||||
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
|
||||
return []string{"1.1.1.1"}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
resolver := tx.newParallelResolver(newResolver)
|
||||
ctx := context.Background()
|
||||
addrs, err := resolver.LookupHost(ctx, "example.com")
|
||||
if err != nil {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if len(addrs) != 1 {
|
||||
t.Fatal("expected array of size 1")
|
||||
}
|
||||
if addrs[0] != "1.1.1.1" {
|
||||
t.Fatal("unexpected array output", addrs)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) {
|
||||
t.Run("when not nil", func(t *testing.T) {
|
||||
mockedErr := errors.New("mocked")
|
||||
|
||||
Reference in New Issue
Block a user