feat: context-based tracing to record delayed DNS responses (#870)
See https://github.com/ooni/probe/issues/2221 Co-authored-by: decfox <decfox@github.com> Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
parent
fe6d378a1f
commit
2301a30630
|
@ -6,6 +6,7 @@ package measurexlite
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
@ -97,9 +98,20 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol
|
|||
}
|
||||
}
|
||||
|
||||
// DNSNetworkAddresser is the type of something we just used to perform a DNS
|
||||
// round trip (e.g., model.DNSTransport, model.Resolver) that allows us to get
|
||||
// the network and the address of the underlying resolver/transport.
|
||||
type DNSNetworkAddresser interface {
|
||||
// Address is like model.DNSTransport.Address
|
||||
Address() string
|
||||
|
||||
// Network is like model.DNSTransport.Network
|
||||
Network() string
|
||||
}
|
||||
|
||||
// NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip
|
||||
// from the available information right after the DNS RoundTrip
|
||||
func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso model.Resolver, query model.DNSQuery,
|
||||
func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso DNSNetworkAddresser, query model.DNSQuery,
|
||||
response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult {
|
||||
return &model.ArchivalDNSLookupResult{
|
||||
Answers: archivalAnswersFromAddrs(addrs),
|
||||
|
@ -167,3 +179,53 @@ func (tx *Trace) FirstDNSLookup() *model.ArchivalDNSLookupResult {
|
|||
}
|
||||
return ev[0]
|
||||
}
|
||||
|
||||
// ErrDelayedDNSResponseBufferFull indicates that the delayedDNSResponse buffer is full.
|
||||
var ErrDelayedDNSResponseBufferFull = errors.New("buffer full")
|
||||
|
||||
// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse
|
||||
func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery,
|
||||
response model.DNSResponse, addrs []string, err error, finished time.Time) error {
|
||||
t := finished.Sub(tx.ZeroTime)
|
||||
select {
|
||||
case tx.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(
|
||||
tx.Index,
|
||||
started.Sub(tx.ZeroTime),
|
||||
txp,
|
||||
query,
|
||||
response,
|
||||
addrs,
|
||||
err,
|
||||
t,
|
||||
):
|
||||
return nil
|
||||
default:
|
||||
return ErrDelayedDNSResponseBufferFull
|
||||
}
|
||||
}
|
||||
|
||||
// DelayedDNSResponseWithTimeout drains the network events buffered inside
|
||||
// the delayedDNSResponse channel. We construct a child context based on [ctx]
|
||||
// and the given [timeout] and we stop reading when original [ctx] has been
|
||||
// cancelled or the given [timeout] expires, whatever happens first. Once the
|
||||
// timeout expired, we drain the chan as much as possible before returning.
|
||||
func (tx *Trace) DelayedDNSResponseWithTimeout(ctx context.Context,
|
||||
timeout time.Duration) (out []*model.ArchivalDNSLookupResult) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
for { // once the context is done enter in channel draining mode
|
||||
select {
|
||||
case ev := <-tx.delayedDNSResponse:
|
||||
out = append(out, ev)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case ev := <-tx.delayedDNSResponse:
|
||||
out = append(out, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package measurexlite
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -322,6 +323,158 @@ func TestFirstDNSLookup(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestDelayedDNSResponseWithTimeout(t *testing.T) {
|
||||
t.Run("OnDelayedDNSResponse saves into the trace", func(t *testing.T) {
|
||||
t.Run("when buffer is not full", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now
|
||||
txp := &mocks.DNSTransport{
|
||||
MockNetwork: func() string {
|
||||
return "udp"
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return "1.1.1.1"
|
||||
},
|
||||
}
|
||||
started := trace.TimeNow()
|
||||
query := &mocks.DNSQuery{
|
||||
MockType: func() uint16 {
|
||||
return dns.TypeA
|
||||
},
|
||||
MockDomain: func() string {
|
||||
return "dns.google.com"
|
||||
},
|
||||
}
|
||||
addrs := []string{"1.1.1.1"}
|
||||
finished := trace.TimeNow()
|
||||
// 1. fill the trace
|
||||
err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{},
|
||||
addrs, nil, finished)
|
||||
// 2. read the trace
|
||||
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
|
||||
if err != nil {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
if len(got) != 1 {
|
||||
t.Fatal("unexpected output from trace")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("when buffer is full", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now
|
||||
trace.delayedDNSResponse = make(chan *model.ArchivalDNSLookupResult) // no buffer
|
||||
txp := &mocks.DNSTransport{
|
||||
MockNetwork: func() string {
|
||||
return "udp"
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return "1.1.1.1"
|
||||
},
|
||||
}
|
||||
started := trace.TimeNow()
|
||||
query := &mocks.DNSQuery{
|
||||
MockType: func() uint16 {
|
||||
return dns.TypeA
|
||||
},
|
||||
MockDomain: func() string {
|
||||
return "dns.google.com"
|
||||
},
|
||||
}
|
||||
addrs := []string{"1.1.1.1"}
|
||||
finished := trace.TimeNow()
|
||||
// 1. attempt to write into the trace
|
||||
err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{},
|
||||
addrs, nil, finished)
|
||||
if !errors.Is(err, ErrDelayedDNSResponseBufferFull) {
|
||||
t.Fatal("unexpected error", err)
|
||||
}
|
||||
// 2. confirm we didn't write anything
|
||||
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
|
||||
if len(got) != 0 {
|
||||
t.Fatal("unexpected output from trace")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("DelayedDNSResponseWithTimeout drains the trace", func(t *testing.T) {
|
||||
t.Run("context is already cancelled and we still drain the trace", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now
|
||||
txp := &mocks.DNSTransport{
|
||||
MockNetwork: func() string {
|
||||
return "udp"
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return "1.1.1.1"
|
||||
},
|
||||
}
|
||||
started := trace.TimeNow()
|
||||
query := &mocks.DNSQuery{
|
||||
MockType: func() uint16 {
|
||||
return dns.TypeA
|
||||
},
|
||||
MockDomain: func() string {
|
||||
return "dns.google.com"
|
||||
},
|
||||
}
|
||||
addrs := []string{"1.1.1.1"}
|
||||
finished := trace.TimeNow()
|
||||
events := 4
|
||||
for i := 0; i < events; i++ {
|
||||
// fill the trace
|
||||
trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime),
|
||||
txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime))
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // we ensure that the context cancels before draining all the events
|
||||
// drain the trace
|
||||
got := trace.DelayedDNSResponseWithTimeout(ctx, 10*time.Second)
|
||||
if len(got) != 4 {
|
||||
t.Fatal("unexpected output from trace", len(got))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("normal case where the context times out after we start draining", func(t *testing.T) {
|
||||
zeroTime := time.Now()
|
||||
td := testingx.NewTimeDeterministic(zeroTime)
|
||||
trace := NewTrace(0, zeroTime)
|
||||
trace.TimeNowFn = td.Now
|
||||
txp := &mocks.DNSTransport{
|
||||
MockNetwork: func() string {
|
||||
return "udp"
|
||||
},
|
||||
MockAddress: func() string {
|
||||
return "1.1.1.1"
|
||||
},
|
||||
}
|
||||
started := trace.TimeNow()
|
||||
query := &mocks.DNSQuery{
|
||||
MockType: func() uint16 {
|
||||
return dns.TypeA
|
||||
},
|
||||
MockDomain: func() string {
|
||||
return "dns.google.com"
|
||||
},
|
||||
}
|
||||
addrs := []string{"1.1.1.1"}
|
||||
finished := trace.TimeNow()
|
||||
trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime),
|
||||
txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime))
|
||||
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
|
||||
if len(got) != 1 {
|
||||
t.Fatal("unexpected output from trace")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnswersFromAddrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
@ -34,8 +34,7 @@ type Trace struct {
|
|||
// traces, you can use zero to indicate the "default" trace.
|
||||
Index int64
|
||||
|
||||
// networkEvent is MANDATORY and buffers network events. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
// networkEvent is MANDATORY and buffers network events.
|
||||
networkEvent chan *model.ArchivalNetworkEvent
|
||||
|
||||
// NewStdlibResolverFn is OPTIONAL and can be used to overide
|
||||
|
@ -62,20 +61,19 @@ type Trace struct {
|
|||
// calls to the netxlite.NewQUICDialerWithoutResolver factory.
|
||||
NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer
|
||||
|
||||
// dnsLookup is MANDATORY and buffers DNS Lookup observations. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
// dnsLookup is MANDATORY and buffers DNS Lookup observations.
|
||||
dnsLookup chan *model.ArchivalDNSLookupResult
|
||||
|
||||
// tcpConnect is MANDATORY and buffers TCP connect observations. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
// delayedDNSResponse is MANDATORY and buffers delayed DNS responses.
|
||||
delayedDNSResponse chan *model.ArchivalDNSLookupResult
|
||||
|
||||
// tcpConnect is MANDATORY and buffers TCP connect observations.
|
||||
tcpConnect chan *model.ArchivalTCPConnectResult
|
||||
|
||||
// tlsHandshake is MANDATORY and buffers TLS handshake observations. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
// tlsHandshake is MANDATORY and buffers TLS handshake observations.
|
||||
tlsHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
|
||||
|
||||
// quicHandshake is MANDATORY and buffers QUIC handshake observations. If you create
|
||||
// this channel manually, ensure it has some buffer.
|
||||
// quicHandshake is MANDATORY and buffers QUIC handshake observations.
|
||||
quicHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
|
||||
|
||||
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now
|
||||
|
@ -88,23 +86,27 @@ type Trace struct {
|
|||
|
||||
const (
|
||||
// NetworkEventBufferSize is the buffer size for constructing
|
||||
// the Trace's NetworkEvent buffered channel.
|
||||
// the Trace's networkEvent buffered channel.
|
||||
NetworkEventBufferSize = 64
|
||||
|
||||
// DNSLookupBufferSize is the buffer size for constructing
|
||||
// the Trace's DNSLookup map of buffered channels.
|
||||
// the Trace's dnsLookup buffered channel.
|
||||
DNSLookupBufferSize = 8
|
||||
|
||||
// DNSResponseBufferSize is the buffer size for constructing
|
||||
// the Trace's dnsDelayedResponse buffered channel.
|
||||
DelayedDNSResponseBufferSize = 8
|
||||
|
||||
// TCPConnectBufferSize is the buffer size for constructing
|
||||
// the Trace's TCPConnect buffered channel.
|
||||
// the Trace's tcpConnect buffered channel.
|
||||
TCPConnectBufferSize = 8
|
||||
|
||||
// TLSHandshakeBufferSize is the buffer for construcing
|
||||
// the Trace's TLSHandshake buffered channel.
|
||||
// the Trace's tlsHandshake buffered channel.
|
||||
TLSHandshakeBufferSize = 8
|
||||
|
||||
// QUICHandshakeBufferSize is the buffer for constructing
|
||||
// the Trace's QUICHandshake buffered channel.
|
||||
// the Trace's quicHandshake buffered channel.
|
||||
QUICHandshakeBufferSize = 8
|
||||
)
|
||||
|
||||
|
@ -132,6 +134,10 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
|
|||
chan *model.ArchivalDNSLookupResult,
|
||||
DNSLookupBufferSize,
|
||||
),
|
||||
delayedDNSResponse: make(
|
||||
chan *model.ArchivalDNSLookupResult,
|
||||
DelayedDNSResponseBufferSize,
|
||||
),
|
||||
tcpConnect: make(
|
||||
chan *model.ArchivalTCPConnectResult,
|
||||
TCPConnectBufferSize,
|
||||
|
|
|
@ -84,7 +84,7 @@ func TestNewTrace(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("DNSLookup has the expected buffer size", func(t *testing.T) {
|
||||
t.Run("dnsLookup has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
Loop:
|
||||
|
@ -99,11 +99,30 @@ func TestNewTrace(t *testing.T) {
|
|||
}
|
||||
}
|
||||
if idx != DNSLookupBufferSize {
|
||||
t.Fatal("invalid DNSLookup channel buffer size")
|
||||
t.Fatal("invalid dnsLookup channel buffer size")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
|
||||
t.Run("delayedDNSResponse has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
Loop:
|
||||
for {
|
||||
ev := &model.ArchivalDNSLookupResult{}
|
||||
ff.Fill(ev)
|
||||
select {
|
||||
case trace.delayedDNSResponse <- ev:
|
||||
idx++
|
||||
default:
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
if idx != DelayedDNSResponseBufferSize {
|
||||
t.Fatal("invalid delayedDNSResponse channel buffer size")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tcpConnect has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
Loop:
|
||||
|
@ -118,11 +137,11 @@ func TestNewTrace(t *testing.T) {
|
|||
}
|
||||
}
|
||||
if idx != TCPConnectBufferSize {
|
||||
t.Fatal("invalid TCPConnect channel buffer size")
|
||||
t.Fatal("invalid tcpConnect channel buffer size")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) {
|
||||
t.Run("tlsHandshake has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
Loop:
|
||||
|
@ -137,11 +156,11 @@ func TestNewTrace(t *testing.T) {
|
|||
}
|
||||
}
|
||||
if idx != TLSHandshakeBufferSize {
|
||||
t.Fatal("invalid TLSHandshake channel buffer size")
|
||||
t.Fatal("invalid tlsHandshake channel buffer size")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("QUICHandshake has the expected buffer size", func(t *testing.T) {
|
||||
t.Run("quicHandshake has the expected buffer size", func(t *testing.T) {
|
||||
ff := &testingx.FakeFiller{}
|
||||
var idx int
|
||||
Loop:
|
||||
|
@ -156,7 +175,7 @@ func TestNewTrace(t *testing.T) {
|
|||
}
|
||||
}
|
||||
if idx != QUICHandshakeBufferSize {
|
||||
t.Fatal("invalid QUICHandshake channel buffer size")
|
||||
t.Fatal("invalid quicHandshake channel buffer size")
|
||||
}
|
||||
})
|
||||
|
||||
|
|
|
@ -24,6 +24,9 @@ type Trace struct {
|
|||
MockOnDNSRoundTripForLookupHost func(started time.Time, reso model.Resolver, query model.DNSQuery,
|
||||
response model.DNSResponse, addrs []string, err error, finished time.Time)
|
||||
|
||||
MockOnDelayedDNSResponse func(started time.Time, txp model.DNSTransport, query model.DNSQuery,
|
||||
response model.DNSResponse, addrs []string, err error, finished time.Time) error
|
||||
|
||||
MockOnConnectDone func(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
|
||||
|
||||
|
@ -57,6 +60,11 @@ func (t *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolv
|
|||
t.MockOnDNSRoundTripForLookupHost(started, reso, query, response, addrs, err, finished)
|
||||
}
|
||||
|
||||
func (t *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery,
|
||||
response model.DNSResponse, addrs []string, err error, finished time.Time) error {
|
||||
return t.MockOnDelayedDNSResponse(started, txp, query, response, addrs, err, finished)
|
||||
}
|
||||
|
||||
func (t *Trace) OnConnectDone(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished)
|
||||
|
|
|
@ -71,6 +71,30 @@ func TestTrace(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("OnDelayedDNSResponse", func(t *testing.T) {
|
||||
var called bool
|
||||
tx := &Trace{
|
||||
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
|
||||
query model.DNSQuery, response model.DNSResponse,
|
||||
addrs []string, err error, finished time.Time) error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
tx.OnDelayedDNSResponse(
|
||||
time.Now(),
|
||||
&DNSTransport{},
|
||||
&DNSQuery{},
|
||||
&DNSResponse{},
|
||||
[]string{},
|
||||
nil,
|
||||
time.Now(),
|
||||
)
|
||||
if !called {
|
||||
t.Fatal("not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnConnectDone", func(t *testing.T) {
|
||||
var called bool
|
||||
tx := &Trace{
|
||||
|
|
|
@ -340,6 +340,29 @@ type Trace interface {
|
|||
OnDNSRoundTripForLookupHost(started time.Time, reso Resolver, query DNSQuery,
|
||||
response DNSResponse, addrs []string, err error, finished time.Time)
|
||||
|
||||
// OnDelayedDNSResponse is used with a DNSOverUDPTransport and called
|
||||
// when we get delayed, unexpected DNS responses.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// - started is when we started reading the delayed response;
|
||||
//
|
||||
// - txp is the DNS transport used with the resolver;
|
||||
//
|
||||
// - query is the non-nil DNS query we use for the RoundTrip;
|
||||
//
|
||||
// - response is the non-nil valid DNS response, obtained after some delay;
|
||||
//
|
||||
// - addrs is the list of addresses obtained after decoding the delayed response,
|
||||
// which is empty if the response did not contain any addresses, which we
|
||||
// extract by calling the DecodeLookupHost method.
|
||||
//
|
||||
// - err is the result of DecodeLookupHost: either an error or nil;
|
||||
//
|
||||
// - finished is when we have read the delayed response.
|
||||
OnDelayedDNSResponse(started time.Time, txp DNSTransport, query DNSQuery,
|
||||
resp DNSResponse, addrs []string, err error, finsihed time.Time) error
|
||||
|
||||
// OnConnectDone is called when connect terminates.
|
||||
//
|
||||
// Arguments:
|
||||
|
|
|
@ -45,10 +45,6 @@ type DNSOverUDPTransport struct {
|
|||
|
||||
// Endpoint is the MANDATORY server's endpoint (e.g., 1.1.1.1:53)
|
||||
Endpoint string
|
||||
|
||||
// IOTimeout is the MANDATORY I/O timeout after which any
|
||||
// conn created to perform round trips times out.
|
||||
IOTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance
|
||||
|
@ -70,7 +66,6 @@ func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOv
|
|||
Decoder: &DNSDecoderMiekg{},
|
||||
Dialer: dialer,
|
||||
Endpoint: address,
|
||||
IOTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -78,21 +73,36 @@ func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOv
|
|||
func (t *DNSOverUDPTransport) RoundTrip(
|
||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||
// QUIRK: the original code had a five seconds timeout, which is
|
||||
// consistent with the Bionic implementation. Let's enforce such a
|
||||
// timeout using the context in the outer operation because we
|
||||
// need to run for more seconds in the background to catch as many
|
||||
// duplicate replies as possible.
|
||||
// consistent with the Bionic implementation.
|
||||
//
|
||||
// See https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||
const opTimeout = 5 * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, opTimeout)
|
||||
defer cancel()
|
||||
outch, err := t.AsyncRoundTrip(ctx, query, 1) // buffer to avoid background's goroutine leak
|
||||
rawQuery, err := query.Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer outch.Close() // we own the channel
|
||||
return outch.Next(ctx)
|
||||
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(opTimeout))
|
||||
joinedch := make(chan bool)
|
||||
myaddr := conn.LocalAddr().String()
|
||||
if _, err := conn.Write(rawQuery); err != nil {
|
||||
conn.Close() // we still own the conn
|
||||
return nil, err
|
||||
}
|
||||
resp, err := t.recv(query, conn)
|
||||
if err != nil {
|
||||
conn.Close() // we still own the conn
|
||||
return nil, err
|
||||
}
|
||||
// start a goroutine to listen for any delayed DNS response and
|
||||
// TRANSFER the conn's OWNERSHIP to such a goroutine.
|
||||
go t.ownConnAndSendRecvLoop(ctx, conn, query, myaddr, joinedch)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// RequiresPadding returns false for UDP according to RFC8467.
|
||||
|
@ -119,196 +129,17 @@ func (t *DNSOverUDPTransport) CloseIdleConnections() {
|
|||
|
||||
var _ model.DNSTransport = &DNSOverUDPTransport{}
|
||||
|
||||
// DNSOverUDPResponse is a response received by a DNSOverUDPTransport when you
|
||||
// use its AsyncRoundTrip method as opposed to using RoundTrip.
|
||||
type DNSOverUDPResponse struct {
|
||||
// Err is the error that occurred (nil in case of success).
|
||||
Err error
|
||||
|
||||
// LocalAddr is the local UDP address we're using.
|
||||
LocalAddr string
|
||||
|
||||
// Operation is the operation that failed.
|
||||
Operation string
|
||||
|
||||
// Query is the related DNS query.
|
||||
Query model.DNSQuery
|
||||
|
||||
// RemoteAddr is the remote server address.
|
||||
RemoteAddr string
|
||||
|
||||
// Response is the response (nil iff error is not nil).
|
||||
Response model.DNSResponse
|
||||
}
|
||||
|
||||
// newDNSOverUDPResponse creates a new DNSOverUDPResponse instance.
|
||||
func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error,
|
||||
query model.DNSQuery, resp model.DNSResponse, operation string) *DNSOverUDPResponse {
|
||||
return &DNSOverUDPResponse{
|
||||
Err: err,
|
||||
LocalAddr: localAddr,
|
||||
Operation: operation,
|
||||
Query: query,
|
||||
RemoteAddr: t.Endpoint, // The common case is to have an IP:port here (domains are discouraged)
|
||||
Response: resp,
|
||||
}
|
||||
}
|
||||
|
||||
// DNSOverUDPChannel is a wrapper around a channel for reading zero
|
||||
// or more *DNSOverUDPResponse that makes extracting information from
|
||||
// the underlying channels more user friendly than interacting with
|
||||
// the channels directly, thanks to useful wrapper methods implementing
|
||||
// common access patterns. You can still use the underlying channels
|
||||
// directly if there's no suitable convenience method.
|
||||
//
|
||||
// You MUST call the .Close method when done. Not calling such a method
|
||||
// leaks goroutines and causes connections to stay open forever.
|
||||
type DNSOverUDPChannel struct {
|
||||
// Response is the channel where we'll post responses. This channel
|
||||
// WILL NOT be closed when the background goroutine terminates.
|
||||
Response <-chan *DNSOverUDPResponse
|
||||
|
||||
// Joined IS CLOSED when the background goroutine terminates.
|
||||
Joined <-chan bool
|
||||
|
||||
// conn is the underlying connection, which we can Close to
|
||||
// immediately cause the background goroutine to join.
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// Close releases the resources allocated by the channel. You MUST
|
||||
// call this method to force the background goroutine that is performing
|
||||
// the round trip to terminate. Calling this method also ensures we
|
||||
// close the connection used by the round trip. This method is idempotent.
|
||||
func (ch *DNSOverUDPChannel) Close() error {
|
||||
return ch.conn.Close()
|
||||
}
|
||||
|
||||
// Next blocks until the next response is received on Response or the
|
||||
// given context expires, whatever happens first. This function will
|
||||
// completely ignore the Joined channel and will just timeout in case
|
||||
// you call Next after the background goroutine had joined. In fact,
|
||||
// the use case for this function is using it to get a response or
|
||||
// a timeout when you know the DNS round trip is pending.
|
||||
func (ch *DNSOverUDPChannel) Next(ctx context.Context) (model.DNSResponse, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case out := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil
|
||||
return out.Response, out.Err
|
||||
}
|
||||
}
|
||||
|
||||
// TryNextResponses attempts to read all the buffered messages inside of the "Response"
|
||||
// channel that contains successful DNS responses. That is, this function will silently skip
|
||||
// any possible DNSOverUDPResponse with its Err != nil. The use case for this function is
|
||||
// to obtain all the subsequent response messages we received while we were performing
|
||||
// other operations (e.g., contacting the test helper of fetching a webpage).
|
||||
func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) {
|
||||
for {
|
||||
select {
|
||||
case r := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil
|
||||
if r.Err == nil && r.Response != nil {
|
||||
out = append(out, r.Response)
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AsyncRoundTrip performs an async DNS round trip. The "buffer" argument
|
||||
// controls how many buffer slots the returned DNSOverUDPChannel's Response
|
||||
// channel should have. A zero or negative value causes this function to
|
||||
// create a channel having a single-slot buffer.
|
||||
//
|
||||
// The real round trip runs in a background goroutine. We will terminate the background
|
||||
// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we
|
||||
// cannot write on the "Response" channel or (3) the connection is closed by calling the
|
||||
// Close method of DNSOverUDPChannel. Note that the background goroutine WILL NOT close
|
||||
// the "Response" channel to signal its completion. Hence, who reads such a
|
||||
// channel MUST be prepared for read operations to block forever (i.e., should use
|
||||
// a select operation for draining the channel in a deadlock-safe way). Also,
|
||||
// we WILL NOT ever post a nil message to the "Response" channel.
|
||||
//
|
||||
// The returned DNSOverUDPChannel contains another channel called Joined that is
|
||||
// closed when the background goroutine terminates, so you can use this channel
|
||||
// should you need to synchronize with such goroutine's termination.
|
||||
//
|
||||
// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type,
|
||||
// you don't need to worry about these low level details though.
|
||||
//
|
||||
// We give you OWNERSHIP of the returned DNSOverUDPChannel and you MUST
|
||||
// call its .Close method when done with using it.
|
||||
func (t *DNSOverUDPTransport) AsyncRoundTrip(
|
||||
ctx context.Context, query model.DNSQuery, buffer int) (*DNSOverUDPChannel, error) {
|
||||
rawQuery, err := query.Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(t.IOTimeout))
|
||||
if buffer < 2 {
|
||||
buffer = 1 // as documented
|
||||
}
|
||||
outch := make(chan *DNSOverUDPResponse, buffer)
|
||||
joinedch := make(chan bool)
|
||||
go t.sendRecvLoop(conn, rawQuery, query, outch, joinedch)
|
||||
dnsch := &DNSOverUDPChannel{
|
||||
Response: outch,
|
||||
Joined: joinedch,
|
||||
conn: conn, // transfer ownership
|
||||
}
|
||||
return dnsch, nil
|
||||
}
|
||||
|
||||
// sendRecvLoop sends the given raw query on the given conn and receives responses
|
||||
// from the conn posting them onto the given output channel.
|
||||
//
|
||||
// Arguments:
|
||||
//
|
||||
// 1. conn is the BORROWED net.Conn (we will use it for reading or writing but
|
||||
// we do not own the connection and we're not going to close it);
|
||||
//
|
||||
// 2. rawQuery contains the rawQuery and is BORROWED (we won't modify it);
|
||||
//
|
||||
// 3. query contains the original query and is also BORROWED;
|
||||
//
|
||||
// 4. outch is the channel where to emit measurements and is OWNED by this
|
||||
// function (that said, we WILL NOT close this channel);
|
||||
//
|
||||
// 5. eofch is the channel to signal EOF, which is OWNED by this function
|
||||
// and closed when this function exits.
|
||||
//
|
||||
// This method terminates in the following cases:
|
||||
//
|
||||
// 1. I/O error while reading or writing (including the deadline expiring or
|
||||
// the owner of the connection closing the connection);
|
||||
//
|
||||
// 2. We cannot post on the output channel because either there is
|
||||
// noone reading the channel or the channel's buffer is full.
|
||||
//
|
||||
// 3. We cannot parse incoming data as a valid DNS response message that
|
||||
// responds to the query that we originally sent.
|
||||
func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte,
|
||||
query model.DNSQuery, outch chan<- *DNSOverUDPResponse, eofch chan<- bool) {
|
||||
// ownConnAndSendRecvLoop listens for delayed DNS responses after we have returned the
|
||||
// first response. As the name implies, this function TAKES OWNERSHIP of the [conn].
|
||||
func (t *DNSOverUDPTransport) ownConnAndSendRecvLoop(ctx context.Context, conn net.Conn,
|
||||
query model.DNSQuery, myaddr string, eofch chan<- bool) {
|
||||
defer close(eofch) // synchronize with the caller
|
||||
myaddr := conn.LocalAddr().String()
|
||||
if _, err := conn.Write(rawQuery); err != nil {
|
||||
outch <- t.newDNSOverUDPResponse(
|
||||
myaddr, err, query, nil, WriteOperation) // one-sized buffer, can't block
|
||||
return
|
||||
}
|
||||
defer conn.Close() // we own the conn
|
||||
trace := ContextTraceOrDefault(ctx)
|
||||
for {
|
||||
started := trace.TimeNow()
|
||||
resp, err := t.recv(query, conn)
|
||||
select {
|
||||
case outch <- t.newDNSOverUDPResponse(myaddr, err, query, resp, ReadOperation):
|
||||
default:
|
||||
return // no-one is reading the channel -- so long...
|
||||
}
|
||||
finished := trace.TimeNow()
|
||||
if err != nil {
|
||||
// We are going to consider all errors as fatal for now until we
|
||||
// hear of specific errs that it might have sense to ignore.
|
||||
|
@ -316,6 +147,16 @@ func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte,
|
|||
// Note that erroring out here includes the expiration of the conn's
|
||||
// I/O deadline, which we set above precisely because we want
|
||||
// the total runtime of this goroutine to be bounded.
|
||||
//
|
||||
// Also, we ARE NOT going to report any failure here as a delayed
|
||||
// DNS response because we only care about duplicate messages, since
|
||||
// this seems how censorship is implemented in, e.g., China.
|
||||
return
|
||||
}
|
||||
addrs, err := resp.DecodeLookupHost()
|
||||
if err := trace.OnDelayedDNSResponse(started, t, query, resp, addrs, err, finished); err != nil {
|
||||
// This error typically indicates that the buffer on which we're
|
||||
// writing is now full, so there's no point in persisting.
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -14,6 +15,7 @@ import (
|
|||
"github.com/ooni/probe-cli/v3/internal/model"
|
||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||
)
|
||||
|
||||
func TestDNSOverUDPTransport(t *testing.T) {
|
||||
|
@ -281,70 +283,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
|||
})
|
||||
})
|
||||
|
||||
t.Run("AsyncRoundTrip", func(t *testing.T) {
|
||||
t.Run("calling Next with cancelled context", func(t *testing.T) {
|
||||
srvr := &filtering.DNSServer{
|
||||
OnQuery: func(domain string) filtering.DNSAction {
|
||||
return filtering.DNSActionCache
|
||||
},
|
||||
Cache: map[string][]string{
|
||||
"dns.google.": {"8.8.8.8"},
|
||||
},
|
||||
}
|
||||
listener, err := srvr.Start("127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||
ctx := context.Background()
|
||||
rch, err := txp.AsyncRoundTrip(ctx, query, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rch.Close()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
cancel() // fail immediately
|
||||
resp, err := rch.Next(ctx)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatal("unexpected resp")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-one is reading the channel", func(t *testing.T) {
|
||||
srvr := &filtering.DNSServer{
|
||||
OnQuery: func(domain string) filtering.DNSAction {
|
||||
return filtering.DNSActionLocalHostPlusCache // i.e., two responses
|
||||
},
|
||||
Cache: map[string][]string{
|
||||
"dns.google.": {"8.8.8.8"},
|
||||
},
|
||||
}
|
||||
listener, err := srvr.Start("127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||
ctx := context.Background()
|
||||
rch, err := txp.AsyncRoundTrip(ctx, query, 1) // but just one place
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rch.Close()
|
||||
<-rch.Joined // should see no-one is reading and stop
|
||||
})
|
||||
|
||||
t.Run("typical usage to obtain late responses", func(t *testing.T) {
|
||||
t.Run("recording delayed DNS responses", func(t *testing.T) {
|
||||
t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
|
||||
var (
|
||||
delayedDNSResponseCalled bool
|
||||
goodQueryType bool
|
||||
goodTransportNetwork bool
|
||||
goodTransportAddress bool
|
||||
goodLookupAddrs bool
|
||||
goodError bool
|
||||
)
|
||||
srvr := &filtering.DNSServer{
|
||||
OnQuery: func(domain string) filtering.DNSAction {
|
||||
return filtering.DNSActionLocalHostPlusCache
|
||||
|
@ -359,52 +307,94 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
|||
}
|
||||
defer listener.Close()
|
||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||
expectedAddress := listener.LocalAddr().String()
|
||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
|
||||
zeroTime := time.Now()
|
||||
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||
expectedAddrs := []string{"8.8.8.8"}
|
||||
respChannel := make(chan *model.DNSResponse, 8)
|
||||
mu := new(sync.Mutex)
|
||||
tx := &mocks.Trace{
|
||||
MockTimeNow: deterministicTime.Now,
|
||||
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
|
||||
query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
|
||||
finished time.Time) error {
|
||||
mu.Lock()
|
||||
delayedDNSResponseCalled = true
|
||||
goodQueryType = (query.Type() == dns.TypeA)
|
||||
goodTransportNetwork = (txp.Network() == "udp")
|
||||
goodTransportAddress = (txp.Address() == expectedAddress)
|
||||
goodLookupAddrs = (cmp.Diff(expectedAddrs, addrs) == "")
|
||||
goodError = (err == nil)
|
||||
mu.Unlock()
|
||||
select {
|
||||
case respChannel <- &response:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("full buffer")
|
||||
}
|
||||
},
|
||||
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
|
||||
finished time.Time) {
|
||||
// do nothing
|
||||
},
|
||||
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
|
||||
return conn
|
||||
},
|
||||
}
|
||||
ctx := ContextWithTrace(context.Background(), tx)
|
||||
rch, err := txp.RoundTrip(ctx, query)
|
||||
<-respChannel // wait for the delayed response
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rch.Close()
|
||||
resp, err := rch.Next(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs, err := resp.DecodeLookupHost()
|
||||
addrs, err := rch.DecodeLookupHost()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mu.Lock()
|
||||
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
// One would not normally busy loop but it's fine to do that in the context
|
||||
// of this test because we know we're going to receive a second reply. In
|
||||
// a real network experiment here we'll do other activities, e.g., contacting
|
||||
// the test helper or fetching a webpage.
|
||||
var additional []model.DNSResponse
|
||||
for {
|
||||
additional = rch.TryNextResponses()
|
||||
if len(additional) > 0 {
|
||||
if len(additional) != 1 {
|
||||
t.Fatal("expected exactly one additional response")
|
||||
if !delayedDNSResponseCalled {
|
||||
t.Fatal("delayedDNSResponse not called")
|
||||
}
|
||||
break
|
||||
if !goodQueryType {
|
||||
t.Fatal("unexpected query type")
|
||||
}
|
||||
if !goodTransportNetwork {
|
||||
t.Fatal("unexpected DNS transport network")
|
||||
}
|
||||
addrs, err = additional[0].DecodeLookupHost()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
if !goodTransportAddress {
|
||||
t.Fatal("unexpected DNS Transport address")
|
||||
}
|
||||
if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" {
|
||||
t.Fatal(diff)
|
||||
if !goodLookupAddrs {
|
||||
t.Fatal("unexpected delayed DNSLookup address")
|
||||
}
|
||||
if !goodError {
|
||||
t.Fatal("unexpected error encountered")
|
||||
}
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
t.Run("correct behavior when read times out", func(t *testing.T) {
|
||||
t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
|
||||
var (
|
||||
delayedDNSResponseCalled bool
|
||||
goodQueryType bool
|
||||
goodTransportNetwork bool
|
||||
goodTransportAddress bool
|
||||
goodLookupAddrs bool
|
||||
goodError bool
|
||||
)
|
||||
srvr := &filtering.DNSServer{
|
||||
OnQuery: func(domain string) filtering.DNSAction {
|
||||
return filtering.DNSActionTimeout
|
||||
return filtering.DNSActionLocalHostPlusCache
|
||||
},
|
||||
Cache: map[string][]string{
|
||||
// Note: the cache here is nonexistent so we should
|
||||
// get a "no such host" error from the server.
|
||||
},
|
||||
}
|
||||
listener, err := srvr.Start("127.0.0.1:0")
|
||||
|
@ -413,22 +403,71 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
|||
}
|
||||
defer listener.Close()
|
||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
|
||||
expectedAddress := listener.LocalAddr().String()
|
||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
|
||||
encoder := &DNSEncoderMiekg{}
|
||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
|
||||
zeroTime := time.Now()
|
||||
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||
respChannel := make(chan *model.DNSResponse, 8)
|
||||
mu := new(sync.Mutex)
|
||||
tx := &mocks.Trace{
|
||||
MockTimeNow: deterministicTime.Now,
|
||||
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
|
||||
query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
|
||||
finished time.Time) error {
|
||||
mu.Lock()
|
||||
delayedDNSResponseCalled = true
|
||||
goodQueryType = (query.Type() == dns.TypeA)
|
||||
goodTransportNetwork = (txp.Network() == "udp")
|
||||
goodTransportAddress = (txp.Address() == expectedAddress)
|
||||
goodLookupAddrs = (len(addrs) == 0)
|
||||
goodError = errors.Is(err, ErrOODNSNoSuchHost)
|
||||
mu.Unlock()
|
||||
respChannel <- &response
|
||||
return errors.New("mocked") // return error to stop background routine to record responses
|
||||
},
|
||||
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
|
||||
finished time.Time) {
|
||||
// do nothing
|
||||
},
|
||||
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
|
||||
return conn
|
||||
},
|
||||
}
|
||||
ctx := ContextWithTrace(context.Background(), tx)
|
||||
rch, err := txp.RoundTrip(ctx, query)
|
||||
<-respChannel // wait for the delayed response
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rch.Close()
|
||||
result := <-rch.Response
|
||||
if result.Err == nil || result.Err.Error() != "generic_timeout_error" {
|
||||
t.Fatal("unexpected error", result.Err)
|
||||
addrs, err := rch.DecodeLookupHost()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result.Operation != ReadOperation {
|
||||
t.Fatal("unexpected failed operation", result.Operation)
|
||||
mu.Lock()
|
||||
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
if !delayedDNSResponseCalled {
|
||||
t.Fatal("delayedDNSResponse not called")
|
||||
}
|
||||
if !goodQueryType {
|
||||
t.Fatal("unexpected query type")
|
||||
}
|
||||
if !goodTransportNetwork {
|
||||
t.Fatal("unexpected DNS transport network")
|
||||
}
|
||||
if !goodTransportAddress {
|
||||
t.Fatal("unexpected DNS Transport address")
|
||||
}
|
||||
if !goodLookupAddrs {
|
||||
t.Fatal("unexpected delayed DNSLookup address")
|
||||
}
|
||||
if !goodError {
|
||||
t.Fatal("unexpected error encountered")
|
||||
}
|
||||
mu.Unlock()
|
||||
})
|
||||
})
|
||||
|
||||
|
|
|
@ -67,6 +67,12 @@ func (*traceDefault) OnDNSRoundTripForLookupHost(started time.Time, reso model.R
|
|||
// nothing
|
||||
}
|
||||
|
||||
// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse.
|
||||
func (*traceDefault) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport,
|
||||
query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnConnectDone implements model.Trace.OnConnectDone.
|
||||
func (*traceDefault) OnConnectDone(
|
||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||
|
|
Loading…
Reference in New Issue
Block a user