feat: context-based tracing to record delayed DNS responses (#870)
See https://github.com/ooni/probe/issues/2221 Co-authored-by: decfox <decfox@github.com> Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
parent
fe6d378a1f
commit
2301a30630
|
@ -6,6 +6,7 @@ package measurexlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
@ -97,9 +98,20 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DNSNetworkAddresser is the type of something we just used to perform a DNS
|
||||||
|
// round trip (e.g., model.DNSTransport, model.Resolver) that allows us to get
|
||||||
|
// the network and the address of the underlying resolver/transport.
|
||||||
|
type DNSNetworkAddresser interface {
|
||||||
|
// Address is like model.DNSTransport.Address
|
||||||
|
Address() string
|
||||||
|
|
||||||
|
// Network is like model.DNSTransport.Network
|
||||||
|
Network() string
|
||||||
|
}
|
||||||
|
|
||||||
// NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip
|
// NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip
|
||||||
// from the available information right after the DNS RoundTrip
|
// from the available information right after the DNS RoundTrip
|
||||||
func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso model.Resolver, query model.DNSQuery,
|
func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso DNSNetworkAddresser, query model.DNSQuery,
|
||||||
response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult {
|
response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult {
|
||||||
return &model.ArchivalDNSLookupResult{
|
return &model.ArchivalDNSLookupResult{
|
||||||
Answers: archivalAnswersFromAddrs(addrs),
|
Answers: archivalAnswersFromAddrs(addrs),
|
||||||
|
@ -167,3 +179,53 @@ func (tx *Trace) FirstDNSLookup() *model.ArchivalDNSLookupResult {
|
||||||
}
|
}
|
||||||
return ev[0]
|
return ev[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrDelayedDNSResponseBufferFull indicates that the delayedDNSResponse buffer is full.
|
||||||
|
var ErrDelayedDNSResponseBufferFull = errors.New("buffer full")
|
||||||
|
|
||||||
|
// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse
|
||||||
|
func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery,
|
||||||
|
response model.DNSResponse, addrs []string, err error, finished time.Time) error {
|
||||||
|
t := finished.Sub(tx.ZeroTime)
|
||||||
|
select {
|
||||||
|
case tx.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(
|
||||||
|
tx.Index,
|
||||||
|
started.Sub(tx.ZeroTime),
|
||||||
|
txp,
|
||||||
|
query,
|
||||||
|
response,
|
||||||
|
addrs,
|
||||||
|
err,
|
||||||
|
t,
|
||||||
|
):
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return ErrDelayedDNSResponseBufferFull
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DelayedDNSResponseWithTimeout drains the network events buffered inside
|
||||||
|
// the delayedDNSResponse channel. We construct a child context based on [ctx]
|
||||||
|
// and the given [timeout] and we stop reading when original [ctx] has been
|
||||||
|
// cancelled or the given [timeout] expires, whatever happens first. Once the
|
||||||
|
// timeout expired, we drain the chan as much as possible before returning.
|
||||||
|
func (tx *Trace) DelayedDNSResponseWithTimeout(ctx context.Context,
|
||||||
|
timeout time.Duration) (out []*model.ArchivalDNSLookupResult) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
for { // once the context is done enter in channel draining mode
|
||||||
|
select {
|
||||||
|
case ev := <-tx.delayedDNSResponse:
|
||||||
|
out = append(out, ev)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case ev := <-tx.delayedDNSResponse:
|
||||||
|
out = append(out, ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package measurexlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -322,6 +323,158 @@ func TestFirstDNSLookup(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDelayedDNSResponseWithTimeout(t *testing.T) {
|
||||||
|
t.Run("OnDelayedDNSResponse saves into the trace", func(t *testing.T) {
|
||||||
|
t.Run("when buffer is not full", func(t *testing.T) {
|
||||||
|
zeroTime := time.Now()
|
||||||
|
td := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = td.Now
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockAddress: func() string {
|
||||||
|
return "1.1.1.1"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
started := trace.TimeNow()
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return dns.TypeA
|
||||||
|
},
|
||||||
|
MockDomain: func() string {
|
||||||
|
return "dns.google.com"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addrs := []string{"1.1.1.1"}
|
||||||
|
finished := trace.TimeNow()
|
||||||
|
// 1. fill the trace
|
||||||
|
err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{},
|
||||||
|
addrs, nil, finished)
|
||||||
|
// 2. read the trace
|
||||||
|
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("unexpected error", err)
|
||||||
|
}
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatal("unexpected output from trace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("when buffer is full", func(t *testing.T) {
|
||||||
|
zeroTime := time.Now()
|
||||||
|
td := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = td.Now
|
||||||
|
trace.delayedDNSResponse = make(chan *model.ArchivalDNSLookupResult) // no buffer
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockAddress: func() string {
|
||||||
|
return "1.1.1.1"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
started := trace.TimeNow()
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return dns.TypeA
|
||||||
|
},
|
||||||
|
MockDomain: func() string {
|
||||||
|
return "dns.google.com"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addrs := []string{"1.1.1.1"}
|
||||||
|
finished := trace.TimeNow()
|
||||||
|
// 1. attempt to write into the trace
|
||||||
|
err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{},
|
||||||
|
addrs, nil, finished)
|
||||||
|
if !errors.Is(err, ErrDelayedDNSResponseBufferFull) {
|
||||||
|
t.Fatal("unexpected error", err)
|
||||||
|
}
|
||||||
|
// 2. confirm we didn't write anything
|
||||||
|
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatal("unexpected output from trace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DelayedDNSResponseWithTimeout drains the trace", func(t *testing.T) {
|
||||||
|
t.Run("context is already cancelled and we still drain the trace", func(t *testing.T) {
|
||||||
|
zeroTime := time.Now()
|
||||||
|
td := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = td.Now
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockAddress: func() string {
|
||||||
|
return "1.1.1.1"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
started := trace.TimeNow()
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return dns.TypeA
|
||||||
|
},
|
||||||
|
MockDomain: func() string {
|
||||||
|
return "dns.google.com"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addrs := []string{"1.1.1.1"}
|
||||||
|
finished := trace.TimeNow()
|
||||||
|
events := 4
|
||||||
|
for i := 0; i < events; i++ {
|
||||||
|
// fill the trace
|
||||||
|
trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime),
|
||||||
|
txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime))
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // we ensure that the context cancels before draining all the events
|
||||||
|
// drain the trace
|
||||||
|
got := trace.DelayedDNSResponseWithTimeout(ctx, 10*time.Second)
|
||||||
|
if len(got) != 4 {
|
||||||
|
t.Fatal("unexpected output from trace", len(got))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("normal case where the context times out after we start draining", func(t *testing.T) {
|
||||||
|
zeroTime := time.Now()
|
||||||
|
td := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
trace := NewTrace(0, zeroTime)
|
||||||
|
trace.TimeNowFn = td.Now
|
||||||
|
txp := &mocks.DNSTransport{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockAddress: func() string {
|
||||||
|
return "1.1.1.1"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
started := trace.TimeNow()
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return dns.TypeA
|
||||||
|
},
|
||||||
|
MockDomain: func() string {
|
||||||
|
return "dns.google.com"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addrs := []string{"1.1.1.1"}
|
||||||
|
finished := trace.TimeNow()
|
||||||
|
trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime),
|
||||||
|
txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime))
|
||||||
|
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
|
||||||
|
if len(got) != 1 {
|
||||||
|
t.Fatal("unexpected output from trace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestAnswersFromAddrs(t *testing.T) {
|
func TestAnswersFromAddrs(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -34,8 +34,7 @@ type Trace struct {
|
||||||
// traces, you can use zero to indicate the "default" trace.
|
// traces, you can use zero to indicate the "default" trace.
|
||||||
Index int64
|
Index int64
|
||||||
|
|
||||||
// networkEvent is MANDATORY and buffers network events. If you create
|
// networkEvent is MANDATORY and buffers network events.
|
||||||
// this channel manually, ensure it has some buffer.
|
|
||||||
networkEvent chan *model.ArchivalNetworkEvent
|
networkEvent chan *model.ArchivalNetworkEvent
|
||||||
|
|
||||||
// NewStdlibResolverFn is OPTIONAL and can be used to overide
|
// NewStdlibResolverFn is OPTIONAL and can be used to overide
|
||||||
|
@ -62,20 +61,19 @@ type Trace struct {
|
||||||
// calls to the netxlite.NewQUICDialerWithoutResolver factory.
|
// calls to the netxlite.NewQUICDialerWithoutResolver factory.
|
||||||
NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer
|
NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer
|
||||||
|
|
||||||
// dnsLookup is MANDATORY and buffers DNS Lookup observations. If you create
|
// dnsLookup is MANDATORY and buffers DNS Lookup observations.
|
||||||
// this channel manually, ensure it has some buffer.
|
|
||||||
dnsLookup chan *model.ArchivalDNSLookupResult
|
dnsLookup chan *model.ArchivalDNSLookupResult
|
||||||
|
|
||||||
// tcpConnect is MANDATORY and buffers TCP connect observations. If you create
|
// delayedDNSResponse is MANDATORY and buffers delayed DNS responses.
|
||||||
// this channel manually, ensure it has some buffer.
|
delayedDNSResponse chan *model.ArchivalDNSLookupResult
|
||||||
|
|
||||||
|
// tcpConnect is MANDATORY and buffers TCP connect observations.
|
||||||
tcpConnect chan *model.ArchivalTCPConnectResult
|
tcpConnect chan *model.ArchivalTCPConnectResult
|
||||||
|
|
||||||
// tlsHandshake is MANDATORY and buffers TLS handshake observations. If you create
|
// tlsHandshake is MANDATORY and buffers TLS handshake observations.
|
||||||
// this channel manually, ensure it has some buffer.
|
|
||||||
tlsHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
|
tlsHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
|
||||||
|
|
||||||
// quicHandshake is MANDATORY and buffers QUIC handshake observations. If you create
|
// quicHandshake is MANDATORY and buffers QUIC handshake observations.
|
||||||
// this channel manually, ensure it has some buffer.
|
|
||||||
quicHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
|
quicHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
|
||||||
|
|
||||||
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now
|
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now
|
||||||
|
@ -88,23 +86,27 @@ type Trace struct {
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// NetworkEventBufferSize is the buffer size for constructing
|
// NetworkEventBufferSize is the buffer size for constructing
|
||||||
// the Trace's NetworkEvent buffered channel.
|
// the Trace's networkEvent buffered channel.
|
||||||
NetworkEventBufferSize = 64
|
NetworkEventBufferSize = 64
|
||||||
|
|
||||||
// DNSLookupBufferSize is the buffer size for constructing
|
// DNSLookupBufferSize is the buffer size for constructing
|
||||||
// the Trace's DNSLookup map of buffered channels.
|
// the Trace's dnsLookup buffered channel.
|
||||||
DNSLookupBufferSize = 8
|
DNSLookupBufferSize = 8
|
||||||
|
|
||||||
|
// DNSResponseBufferSize is the buffer size for constructing
|
||||||
|
// the Trace's dnsDelayedResponse buffered channel.
|
||||||
|
DelayedDNSResponseBufferSize = 8
|
||||||
|
|
||||||
// TCPConnectBufferSize is the buffer size for constructing
|
// TCPConnectBufferSize is the buffer size for constructing
|
||||||
// the Trace's TCPConnect buffered channel.
|
// the Trace's tcpConnect buffered channel.
|
||||||
TCPConnectBufferSize = 8
|
TCPConnectBufferSize = 8
|
||||||
|
|
||||||
// TLSHandshakeBufferSize is the buffer for construcing
|
// TLSHandshakeBufferSize is the buffer for construcing
|
||||||
// the Trace's TLSHandshake buffered channel.
|
// the Trace's tlsHandshake buffered channel.
|
||||||
TLSHandshakeBufferSize = 8
|
TLSHandshakeBufferSize = 8
|
||||||
|
|
||||||
// QUICHandshakeBufferSize is the buffer for constructing
|
// QUICHandshakeBufferSize is the buffer for constructing
|
||||||
// the Trace's QUICHandshake buffered channel.
|
// the Trace's quicHandshake buffered channel.
|
||||||
QUICHandshakeBufferSize = 8
|
QUICHandshakeBufferSize = 8
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -132,6 +134,10 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
|
||||||
chan *model.ArchivalDNSLookupResult,
|
chan *model.ArchivalDNSLookupResult,
|
||||||
DNSLookupBufferSize,
|
DNSLookupBufferSize,
|
||||||
),
|
),
|
||||||
|
delayedDNSResponse: make(
|
||||||
|
chan *model.ArchivalDNSLookupResult,
|
||||||
|
DelayedDNSResponseBufferSize,
|
||||||
|
),
|
||||||
tcpConnect: make(
|
tcpConnect: make(
|
||||||
chan *model.ArchivalTCPConnectResult,
|
chan *model.ArchivalTCPConnectResult,
|
||||||
TCPConnectBufferSize,
|
TCPConnectBufferSize,
|
||||||
|
|
|
@ -84,7 +84,7 @@ func TestNewTrace(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DNSLookup has the expected buffer size", func(t *testing.T) {
|
t.Run("dnsLookup has the expected buffer size", func(t *testing.T) {
|
||||||
ff := &testingx.FakeFiller{}
|
ff := &testingx.FakeFiller{}
|
||||||
var idx int
|
var idx int
|
||||||
Loop:
|
Loop:
|
||||||
|
@ -99,11 +99,30 @@ func TestNewTrace(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if idx != DNSLookupBufferSize {
|
if idx != DNSLookupBufferSize {
|
||||||
t.Fatal("invalid DNSLookup channel buffer size")
|
t.Fatal("invalid dnsLookup channel buffer size")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
|
t.Run("delayedDNSResponse has the expected buffer size", func(t *testing.T) {
|
||||||
|
ff := &testingx.FakeFiller{}
|
||||||
|
var idx int
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
ev := &model.ArchivalDNSLookupResult{}
|
||||||
|
ff.Fill(ev)
|
||||||
|
select {
|
||||||
|
case trace.delayedDNSResponse <- ev:
|
||||||
|
idx++
|
||||||
|
default:
|
||||||
|
break Loop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idx != DelayedDNSResponseBufferSize {
|
||||||
|
t.Fatal("invalid delayedDNSResponse channel buffer size")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tcpConnect has the expected buffer size", func(t *testing.T) {
|
||||||
ff := &testingx.FakeFiller{}
|
ff := &testingx.FakeFiller{}
|
||||||
var idx int
|
var idx int
|
||||||
Loop:
|
Loop:
|
||||||
|
@ -118,11 +137,11 @@ func TestNewTrace(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if idx != TCPConnectBufferSize {
|
if idx != TCPConnectBufferSize {
|
||||||
t.Fatal("invalid TCPConnect channel buffer size")
|
t.Fatal("invalid tcpConnect channel buffer size")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) {
|
t.Run("tlsHandshake has the expected buffer size", func(t *testing.T) {
|
||||||
ff := &testingx.FakeFiller{}
|
ff := &testingx.FakeFiller{}
|
||||||
var idx int
|
var idx int
|
||||||
Loop:
|
Loop:
|
||||||
|
@ -137,11 +156,11 @@ func TestNewTrace(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if idx != TLSHandshakeBufferSize {
|
if idx != TLSHandshakeBufferSize {
|
||||||
t.Fatal("invalid TLSHandshake channel buffer size")
|
t.Fatal("invalid tlsHandshake channel buffer size")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("QUICHandshake has the expected buffer size", func(t *testing.T) {
|
t.Run("quicHandshake has the expected buffer size", func(t *testing.T) {
|
||||||
ff := &testingx.FakeFiller{}
|
ff := &testingx.FakeFiller{}
|
||||||
var idx int
|
var idx int
|
||||||
Loop:
|
Loop:
|
||||||
|
@ -156,7 +175,7 @@ func TestNewTrace(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if idx != QUICHandshakeBufferSize {
|
if idx != QUICHandshakeBufferSize {
|
||||||
t.Fatal("invalid QUICHandshake channel buffer size")
|
t.Fatal("invalid quicHandshake channel buffer size")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,9 @@ type Trace struct {
|
||||||
MockOnDNSRoundTripForLookupHost func(started time.Time, reso model.Resolver, query model.DNSQuery,
|
MockOnDNSRoundTripForLookupHost func(started time.Time, reso model.Resolver, query model.DNSQuery,
|
||||||
response model.DNSResponse, addrs []string, err error, finished time.Time)
|
response model.DNSResponse, addrs []string, err error, finished time.Time)
|
||||||
|
|
||||||
|
MockOnDelayedDNSResponse func(started time.Time, txp model.DNSTransport, query model.DNSQuery,
|
||||||
|
response model.DNSResponse, addrs []string, err error, finished time.Time) error
|
||||||
|
|
||||||
MockOnConnectDone func(
|
MockOnConnectDone func(
|
||||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
|
started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
|
||||||
|
|
||||||
|
@ -57,6 +60,11 @@ func (t *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolv
|
||||||
t.MockOnDNSRoundTripForLookupHost(started, reso, query, response, addrs, err, finished)
|
t.MockOnDNSRoundTripForLookupHost(started, reso, query, response, addrs, err, finished)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery,
|
||||||
|
response model.DNSResponse, addrs []string, err error, finished time.Time) error {
|
||||||
|
return t.MockOnDelayedDNSResponse(started, txp, query, response, addrs, err, finished)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Trace) OnConnectDone(
|
func (t *Trace) OnConnectDone(
|
||||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||||
t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished)
|
t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished)
|
||||||
|
|
|
@ -71,6 +71,30 @@ func TestTrace(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("OnDelayedDNSResponse", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
tx := &Trace{
|
||||||
|
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
|
||||||
|
query model.DNSQuery, response model.DNSResponse,
|
||||||
|
addrs []string, err error, finished time.Time) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tx.OnDelayedDNSResponse(
|
||||||
|
time.Now(),
|
||||||
|
&DNSTransport{},
|
||||||
|
&DNSQuery{},
|
||||||
|
&DNSResponse{},
|
||||||
|
[]string{},
|
||||||
|
nil,
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("OnConnectDone", func(t *testing.T) {
|
t.Run("OnConnectDone", func(t *testing.T) {
|
||||||
var called bool
|
var called bool
|
||||||
tx := &Trace{
|
tx := &Trace{
|
||||||
|
|
|
@ -340,6 +340,29 @@ type Trace interface {
|
||||||
OnDNSRoundTripForLookupHost(started time.Time, reso Resolver, query DNSQuery,
|
OnDNSRoundTripForLookupHost(started time.Time, reso Resolver, query DNSQuery,
|
||||||
response DNSResponse, addrs []string, err error, finished time.Time)
|
response DNSResponse, addrs []string, err error, finished time.Time)
|
||||||
|
|
||||||
|
// OnDelayedDNSResponse is used with a DNSOverUDPTransport and called
|
||||||
|
// when we get delayed, unexpected DNS responses.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
//
|
||||||
|
// - started is when we started reading the delayed response;
|
||||||
|
//
|
||||||
|
// - txp is the DNS transport used with the resolver;
|
||||||
|
//
|
||||||
|
// - query is the non-nil DNS query we use for the RoundTrip;
|
||||||
|
//
|
||||||
|
// - response is the non-nil valid DNS response, obtained after some delay;
|
||||||
|
//
|
||||||
|
// - addrs is the list of addresses obtained after decoding the delayed response,
|
||||||
|
// which is empty if the response did not contain any addresses, which we
|
||||||
|
// extract by calling the DecodeLookupHost method.
|
||||||
|
//
|
||||||
|
// - err is the result of DecodeLookupHost: either an error or nil;
|
||||||
|
//
|
||||||
|
// - finished is when we have read the delayed response.
|
||||||
|
OnDelayedDNSResponse(started time.Time, txp DNSTransport, query DNSQuery,
|
||||||
|
resp DNSResponse, addrs []string, err error, finsihed time.Time) error
|
||||||
|
|
||||||
// OnConnectDone is called when connect terminates.
|
// OnConnectDone is called when connect terminates.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
|
|
|
@ -45,10 +45,6 @@ type DNSOverUDPTransport struct {
|
||||||
|
|
||||||
// Endpoint is the MANDATORY server's endpoint (e.g., 1.1.1.1:53)
|
// Endpoint is the MANDATORY server's endpoint (e.g., 1.1.1.1:53)
|
||||||
Endpoint string
|
Endpoint string
|
||||||
|
|
||||||
// IOTimeout is the MANDATORY I/O timeout after which any
|
|
||||||
// conn created to perform round trips times out.
|
|
||||||
IOTimeout time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance
|
// NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance
|
||||||
|
@ -70,7 +66,6 @@ func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOv
|
||||||
Decoder: &DNSDecoderMiekg{},
|
Decoder: &DNSDecoderMiekg{},
|
||||||
Dialer: dialer,
|
Dialer: dialer,
|
||||||
Endpoint: address,
|
Endpoint: address,
|
||||||
IOTimeout: 10 * time.Second,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,21 +73,36 @@ func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOv
|
||||||
func (t *DNSOverUDPTransport) RoundTrip(
|
func (t *DNSOverUDPTransport) RoundTrip(
|
||||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
// QUIRK: the original code had a five seconds timeout, which is
|
// QUIRK: the original code had a five seconds timeout, which is
|
||||||
// consistent with the Bionic implementation. Let's enforce such a
|
// consistent with the Bionic implementation.
|
||||||
// timeout using the context in the outer operation because we
|
|
||||||
// need to run for more seconds in the background to catch as many
|
|
||||||
// duplicate replies as possible.
|
|
||||||
//
|
//
|
||||||
// See https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
// See https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||||
const opTimeout = 5 * time.Second
|
const opTimeout = 5 * time.Second
|
||||||
ctx, cancel := context.WithTimeout(ctx, opTimeout)
|
ctx, cancel := context.WithTimeout(ctx, opTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
outch, err := t.AsyncRoundTrip(ctx, query, 1) // buffer to avoid background's goroutine leak
|
rawQuery, err := query.Bytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer outch.Close() // we own the channel
|
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
|
||||||
return outch.Next(ctx)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn.SetDeadline(time.Now().Add(opTimeout))
|
||||||
|
joinedch := make(chan bool)
|
||||||
|
myaddr := conn.LocalAddr().String()
|
||||||
|
if _, err := conn.Write(rawQuery); err != nil {
|
||||||
|
conn.Close() // we still own the conn
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp, err := t.recv(query, conn)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close() // we still own the conn
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// start a goroutine to listen for any delayed DNS response and
|
||||||
|
// TRANSFER the conn's OWNERSHIP to such a goroutine.
|
||||||
|
go t.ownConnAndSendRecvLoop(ctx, conn, query, myaddr, joinedch)
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns false for UDP according to RFC8467.
|
// RequiresPadding returns false for UDP according to RFC8467.
|
||||||
|
@ -119,196 +129,17 @@ func (t *DNSOverUDPTransport) CloseIdleConnections() {
|
||||||
|
|
||||||
var _ model.DNSTransport = &DNSOverUDPTransport{}
|
var _ model.DNSTransport = &DNSOverUDPTransport{}
|
||||||
|
|
||||||
// DNSOverUDPResponse is a response received by a DNSOverUDPTransport when you
|
// ownConnAndSendRecvLoop listens for delayed DNS responses after we have returned the
|
||||||
// use its AsyncRoundTrip method as opposed to using RoundTrip.
|
// first response. As the name implies, this function TAKES OWNERSHIP of the [conn].
|
||||||
type DNSOverUDPResponse struct {
|
func (t *DNSOverUDPTransport) ownConnAndSendRecvLoop(ctx context.Context, conn net.Conn,
|
||||||
// Err is the error that occurred (nil in case of success).
|
query model.DNSQuery, myaddr string, eofch chan<- bool) {
|
||||||
Err error
|
|
||||||
|
|
||||||
// LocalAddr is the local UDP address we're using.
|
|
||||||
LocalAddr string
|
|
||||||
|
|
||||||
// Operation is the operation that failed.
|
|
||||||
Operation string
|
|
||||||
|
|
||||||
// Query is the related DNS query.
|
|
||||||
Query model.DNSQuery
|
|
||||||
|
|
||||||
// RemoteAddr is the remote server address.
|
|
||||||
RemoteAddr string
|
|
||||||
|
|
||||||
// Response is the response (nil iff error is not nil).
|
|
||||||
Response model.DNSResponse
|
|
||||||
}
|
|
||||||
|
|
||||||
// newDNSOverUDPResponse creates a new DNSOverUDPResponse instance.
|
|
||||||
func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error,
|
|
||||||
query model.DNSQuery, resp model.DNSResponse, operation string) *DNSOverUDPResponse {
|
|
||||||
return &DNSOverUDPResponse{
|
|
||||||
Err: err,
|
|
||||||
LocalAddr: localAddr,
|
|
||||||
Operation: operation,
|
|
||||||
Query: query,
|
|
||||||
RemoteAddr: t.Endpoint, // The common case is to have an IP:port here (domains are discouraged)
|
|
||||||
Response: resp,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DNSOverUDPChannel is a wrapper around a channel for reading zero
|
|
||||||
// or more *DNSOverUDPResponse that makes extracting information from
|
|
||||||
// the underlying channels more user friendly than interacting with
|
|
||||||
// the channels directly, thanks to useful wrapper methods implementing
|
|
||||||
// common access patterns. You can still use the underlying channels
|
|
||||||
// directly if there's no suitable convenience method.
|
|
||||||
//
|
|
||||||
// You MUST call the .Close method when done. Not calling such a method
|
|
||||||
// leaks goroutines and causes connections to stay open forever.
|
|
||||||
type DNSOverUDPChannel struct {
|
|
||||||
// Response is the channel where we'll post responses. This channel
|
|
||||||
// WILL NOT be closed when the background goroutine terminates.
|
|
||||||
Response <-chan *DNSOverUDPResponse
|
|
||||||
|
|
||||||
// Joined IS CLOSED when the background goroutine terminates.
|
|
||||||
Joined <-chan bool
|
|
||||||
|
|
||||||
// conn is the underlying connection, which we can Close to
|
|
||||||
// immediately cause the background goroutine to join.
|
|
||||||
conn net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close releases the resources allocated by the channel. You MUST
|
|
||||||
// call this method to force the background goroutine that is performing
|
|
||||||
// the round trip to terminate. Calling this method also ensures we
|
|
||||||
// close the connection used by the round trip. This method is idempotent.
|
|
||||||
func (ch *DNSOverUDPChannel) Close() error {
|
|
||||||
return ch.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next blocks until the next response is received on Response or the
|
|
||||||
// given context expires, whatever happens first. This function will
|
|
||||||
// completely ignore the Joined channel and will just timeout in case
|
|
||||||
// you call Next after the background goroutine had joined. In fact,
|
|
||||||
// the use case for this function is using it to get a response or
|
|
||||||
// a timeout when you know the DNS round trip is pending.
|
|
||||||
func (ch *DNSOverUDPChannel) Next(ctx context.Context) (model.DNSResponse, error) {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case out := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil
|
|
||||||
return out.Response, out.Err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TryNextResponses attempts to read all the buffered messages inside of the "Response"
|
|
||||||
// channel that contains successful DNS responses. That is, this function will silently skip
|
|
||||||
// any possible DNSOverUDPResponse with its Err != nil. The use case for this function is
|
|
||||||
// to obtain all the subsequent response messages we received while we were performing
|
|
||||||
// other operations (e.g., contacting the test helper of fetching a webpage).
|
|
||||||
func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case r := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil
|
|
||||||
if r.Err == nil && r.Response != nil {
|
|
||||||
out = append(out, r.Response)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AsyncRoundTrip performs an async DNS round trip. The "buffer" argument
|
|
||||||
// controls how many buffer slots the returned DNSOverUDPChannel's Response
|
|
||||||
// channel should have. A zero or negative value causes this function to
|
|
||||||
// create a channel having a single-slot buffer.
|
|
||||||
//
|
|
||||||
// The real round trip runs in a background goroutine. We will terminate the background
|
|
||||||
// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we
|
|
||||||
// cannot write on the "Response" channel or (3) the connection is closed by calling the
|
|
||||||
// Close method of DNSOverUDPChannel. Note that the background goroutine WILL NOT close
|
|
||||||
// the "Response" channel to signal its completion. Hence, who reads such a
|
|
||||||
// channel MUST be prepared for read operations to block forever (i.e., should use
|
|
||||||
// a select operation for draining the channel in a deadlock-safe way). Also,
|
|
||||||
// we WILL NOT ever post a nil message to the "Response" channel.
|
|
||||||
//
|
|
||||||
// The returned DNSOverUDPChannel contains another channel called Joined that is
|
|
||||||
// closed when the background goroutine terminates, so you can use this channel
|
|
||||||
// should you need to synchronize with such goroutine's termination.
|
|
||||||
//
|
|
||||||
// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type,
|
|
||||||
// you don't need to worry about these low level details though.
|
|
||||||
//
|
|
||||||
// We give you OWNERSHIP of the returned DNSOverUDPChannel and you MUST
|
|
||||||
// call its .Close method when done with using it.
|
|
||||||
func (t *DNSOverUDPTransport) AsyncRoundTrip(
|
|
||||||
ctx context.Context, query model.DNSQuery, buffer int) (*DNSOverUDPChannel, error) {
|
|
||||||
rawQuery, err := query.Bytes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
conn.SetDeadline(time.Now().Add(t.IOTimeout))
|
|
||||||
if buffer < 2 {
|
|
||||||
buffer = 1 // as documented
|
|
||||||
}
|
|
||||||
outch := make(chan *DNSOverUDPResponse, buffer)
|
|
||||||
joinedch := make(chan bool)
|
|
||||||
go t.sendRecvLoop(conn, rawQuery, query, outch, joinedch)
|
|
||||||
dnsch := &DNSOverUDPChannel{
|
|
||||||
Response: outch,
|
|
||||||
Joined: joinedch,
|
|
||||||
conn: conn, // transfer ownership
|
|
||||||
}
|
|
||||||
return dnsch, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendRecvLoop sends the given raw query on the given conn and receives responses
|
|
||||||
// from the conn posting them onto the given output channel.
|
|
||||||
//
|
|
||||||
// Arguments:
|
|
||||||
//
|
|
||||||
// 1. conn is the BORROWED net.Conn (we will use it for reading or writing but
|
|
||||||
// we do not own the connection and we're not going to close it);
|
|
||||||
//
|
|
||||||
// 2. rawQuery contains the rawQuery and is BORROWED (we won't modify it);
|
|
||||||
//
|
|
||||||
// 3. query contains the original query and is also BORROWED;
|
|
||||||
//
|
|
||||||
// 4. outch is the channel where to emit measurements and is OWNED by this
|
|
||||||
// function (that said, we WILL NOT close this channel);
|
|
||||||
//
|
|
||||||
// 5. eofch is the channel to signal EOF, which is OWNED by this function
|
|
||||||
// and closed when this function exits.
|
|
||||||
//
|
|
||||||
// This method terminates in the following cases:
|
|
||||||
//
|
|
||||||
// 1. I/O error while reading or writing (including the deadline expiring or
|
|
||||||
// the owner of the connection closing the connection);
|
|
||||||
//
|
|
||||||
// 2. We cannot post on the output channel because either there is
|
|
||||||
// noone reading the channel or the channel's buffer is full.
|
|
||||||
//
|
|
||||||
// 3. We cannot parse incoming data as a valid DNS response message that
|
|
||||||
// responds to the query that we originally sent.
|
|
||||||
func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte,
|
|
||||||
query model.DNSQuery, outch chan<- *DNSOverUDPResponse, eofch chan<- bool) {
|
|
||||||
defer close(eofch) // synchronize with the caller
|
defer close(eofch) // synchronize with the caller
|
||||||
myaddr := conn.LocalAddr().String()
|
defer conn.Close() // we own the conn
|
||||||
if _, err := conn.Write(rawQuery); err != nil {
|
trace := ContextTraceOrDefault(ctx)
|
||||||
outch <- t.newDNSOverUDPResponse(
|
|
||||||
myaddr, err, query, nil, WriteOperation) // one-sized buffer, can't block
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for {
|
for {
|
||||||
|
started := trace.TimeNow()
|
||||||
resp, err := t.recv(query, conn)
|
resp, err := t.recv(query, conn)
|
||||||
select {
|
finished := trace.TimeNow()
|
||||||
case outch <- t.newDNSOverUDPResponse(myaddr, err, query, resp, ReadOperation):
|
|
||||||
default:
|
|
||||||
return // no-one is reading the channel -- so long...
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// We are going to consider all errors as fatal for now until we
|
// We are going to consider all errors as fatal for now until we
|
||||||
// hear of specific errs that it might have sense to ignore.
|
// hear of specific errs that it might have sense to ignore.
|
||||||
|
@ -316,6 +147,16 @@ func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte,
|
||||||
// Note that erroring out here includes the expiration of the conn's
|
// Note that erroring out here includes the expiration of the conn's
|
||||||
// I/O deadline, which we set above precisely because we want
|
// I/O deadline, which we set above precisely because we want
|
||||||
// the total runtime of this goroutine to be bounded.
|
// the total runtime of this goroutine to be bounded.
|
||||||
|
//
|
||||||
|
// Also, we ARE NOT going to report any failure here as a delayed
|
||||||
|
// DNS response because we only care about duplicate messages, since
|
||||||
|
// this seems how censorship is implemented in, e.g., China.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if err := trace.OnDelayedDNSResponse(started, t, query, resp, addrs, err, finished); err != nil {
|
||||||
|
// This error typically indicates that the buffer on which we're
|
||||||
|
// writing is now full, so there's no point in persisting.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -14,6 +15,7 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/testingx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverUDPTransport(t *testing.T) {
|
func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
|
@ -281,70 +283,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AsyncRoundTrip", func(t *testing.T) {
|
t.Run("recording delayed DNS responses", func(t *testing.T) {
|
||||||
t.Run("calling Next with cancelled context", func(t *testing.T) {
|
t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
|
||||||
srvr := &filtering.DNSServer{
|
var (
|
||||||
OnQuery: func(domain string) filtering.DNSAction {
|
delayedDNSResponseCalled bool
|
||||||
return filtering.DNSActionCache
|
goodQueryType bool
|
||||||
},
|
goodTransportNetwork bool
|
||||||
Cache: map[string][]string{
|
goodTransportAddress bool
|
||||||
"dns.google.": {"8.8.8.8"},
|
goodLookupAddrs bool
|
||||||
},
|
goodError bool
|
||||||
}
|
)
|
||||||
listener, err := srvr.Start("127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
|
||||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
|
||||||
encoder := &DNSEncoderMiekg{}
|
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
|
||||||
ctx := context.Background()
|
|
||||||
rch, err := txp.AsyncRoundTrip(ctx, query, 1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer rch.Close()
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
cancel() // fail immediately
|
|
||||||
resp, err := rch.Next(ctx)
|
|
||||||
if !errors.Is(err, context.Canceled) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if resp != nil {
|
|
||||||
t.Fatal("unexpected resp")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("no-one is reading the channel", func(t *testing.T) {
|
|
||||||
srvr := &filtering.DNSServer{
|
|
||||||
OnQuery: func(domain string) filtering.DNSAction {
|
|
||||||
return filtering.DNSActionLocalHostPlusCache // i.e., two responses
|
|
||||||
},
|
|
||||||
Cache: map[string][]string{
|
|
||||||
"dns.google.": {"8.8.8.8"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
listener, err := srvr.Start("127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
|
||||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
|
||||||
encoder := &DNSEncoderMiekg{}
|
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
|
||||||
ctx := context.Background()
|
|
||||||
rch, err := txp.AsyncRoundTrip(ctx, query, 1) // but just one place
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer rch.Close()
|
|
||||||
<-rch.Joined // should see no-one is reading and stop
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("typical usage to obtain late responses", func(t *testing.T) {
|
|
||||||
srvr := &filtering.DNSServer{
|
srvr := &filtering.DNSServer{
|
||||||
OnQuery: func(domain string) filtering.DNSAction {
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
return filtering.DNSActionLocalHostPlusCache
|
return filtering.DNSActionLocalHostPlusCache
|
||||||
|
@ -359,52 +307,94 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
expectedAddress := listener.LocalAddr().String()
|
||||||
|
txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
|
zeroTime := time.Now()
|
||||||
|
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
expectedAddrs := []string{"8.8.8.8"}
|
||||||
|
respChannel := make(chan *model.DNSResponse, 8)
|
||||||
|
mu := new(sync.Mutex)
|
||||||
|
tx := &mocks.Trace{
|
||||||
|
MockTimeNow: deterministicTime.Now,
|
||||||
|
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
|
||||||
|
query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
|
||||||
|
finished time.Time) error {
|
||||||
|
mu.Lock()
|
||||||
|
delayedDNSResponseCalled = true
|
||||||
|
goodQueryType = (query.Type() == dns.TypeA)
|
||||||
|
goodTransportNetwork = (txp.Network() == "udp")
|
||||||
|
goodTransportAddress = (txp.Address() == expectedAddress)
|
||||||
|
goodLookupAddrs = (cmp.Diff(expectedAddrs, addrs) == "")
|
||||||
|
goodError = (err == nil)
|
||||||
|
mu.Unlock()
|
||||||
|
select {
|
||||||
|
case respChannel <- &response:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.New("full buffer")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
|
||||||
|
finished time.Time) {
|
||||||
|
// do nothing
|
||||||
|
},
|
||||||
|
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
|
||||||
|
return conn
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := ContextWithTrace(context.Background(), tx)
|
||||||
|
rch, err := txp.RoundTrip(ctx, query)
|
||||||
|
<-respChannel // wait for the delayed response
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer rch.Close()
|
addrs, err := rch.DecodeLookupHost()
|
||||||
resp, err := rch.Next(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
addrs, err := resp.DecodeLookupHost()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
mu.Lock()
|
||||||
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
|
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
|
||||||
t.Fatal(diff)
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
// One would not normally busy loop but it's fine to do that in the context
|
if !delayedDNSResponseCalled {
|
||||||
// of this test because we know we're going to receive a second reply. In
|
t.Fatal("delayedDNSResponse not called")
|
||||||
// a real network experiment here we'll do other activities, e.g., contacting
|
|
||||||
// the test helper or fetching a webpage.
|
|
||||||
var additional []model.DNSResponse
|
|
||||||
for {
|
|
||||||
additional = rch.TryNextResponses()
|
|
||||||
if len(additional) > 0 {
|
|
||||||
if len(additional) != 1 {
|
|
||||||
t.Fatal("expected exactly one additional response")
|
|
||||||
}
|
}
|
||||||
break
|
if !goodQueryType {
|
||||||
|
t.Fatal("unexpected query type")
|
||||||
}
|
}
|
||||||
|
if !goodTransportNetwork {
|
||||||
|
t.Fatal("unexpected DNS transport network")
|
||||||
}
|
}
|
||||||
addrs, err = additional[0].DecodeLookupHost()
|
if !goodTransportAddress {
|
||||||
if err != nil {
|
t.Fatal("unexpected DNS Transport address")
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" {
|
if !goodLookupAddrs {
|
||||||
t.Fatal(diff)
|
t.Fatal("unexpected delayed DNSLookup address")
|
||||||
}
|
}
|
||||||
|
if !goodError {
|
||||||
|
t.Fatal("unexpected error encountered")
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("correct behavior when read times out", func(t *testing.T) {
|
t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
delayedDNSResponseCalled bool
|
||||||
|
goodQueryType bool
|
||||||
|
goodTransportNetwork bool
|
||||||
|
goodTransportAddress bool
|
||||||
|
goodLookupAddrs bool
|
||||||
|
goodError bool
|
||||||
|
)
|
||||||
srvr := &filtering.DNSServer{
|
srvr := &filtering.DNSServer{
|
||||||
OnQuery: func(domain string) filtering.DNSAction {
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
return filtering.DNSActionTimeout
|
return filtering.DNSActionLocalHostPlusCache
|
||||||
|
},
|
||||||
|
Cache: map[string][]string{
|
||||||
|
// Note: the cache here is nonexistent so we should
|
||||||
|
// get a "no such host" error from the server.
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
listener, err := srvr.Start("127.0.0.1:0")
|
listener, err := srvr.Start("127.0.0.1:0")
|
||||||
|
@ -413,22 +403,71 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
expectedAddress := listener.LocalAddr().String()
|
||||||
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
|
txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
|
zeroTime := time.Now()
|
||||||
|
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
|
||||||
|
respChannel := make(chan *model.DNSResponse, 8)
|
||||||
|
mu := new(sync.Mutex)
|
||||||
|
tx := &mocks.Trace{
|
||||||
|
MockTimeNow: deterministicTime.Now,
|
||||||
|
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
|
||||||
|
query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
|
||||||
|
finished time.Time) error {
|
||||||
|
mu.Lock()
|
||||||
|
delayedDNSResponseCalled = true
|
||||||
|
goodQueryType = (query.Type() == dns.TypeA)
|
||||||
|
goodTransportNetwork = (txp.Network() == "udp")
|
||||||
|
goodTransportAddress = (txp.Address() == expectedAddress)
|
||||||
|
goodLookupAddrs = (len(addrs) == 0)
|
||||||
|
goodError = errors.Is(err, ErrOODNSNoSuchHost)
|
||||||
|
mu.Unlock()
|
||||||
|
respChannel <- &response
|
||||||
|
return errors.New("mocked") // return error to stop background routine to record responses
|
||||||
|
},
|
||||||
|
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
|
||||||
|
finished time.Time) {
|
||||||
|
// do nothing
|
||||||
|
},
|
||||||
|
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
|
||||||
|
return conn
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := ContextWithTrace(context.Background(), tx)
|
||||||
|
rch, err := txp.RoundTrip(ctx, query)
|
||||||
|
<-respChannel // wait for the delayed response
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer rch.Close()
|
addrs, err := rch.DecodeLookupHost()
|
||||||
result := <-rch.Response
|
if err != nil {
|
||||||
if result.Err == nil || result.Err.Error() != "generic_timeout_error" {
|
t.Fatal(err)
|
||||||
t.Fatal("unexpected error", result.Err)
|
|
||||||
}
|
}
|
||||||
if result.Operation != ReadOperation {
|
mu.Lock()
|
||||||
t.Fatal("unexpected failed operation", result.Operation)
|
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
}
|
}
|
||||||
|
if !delayedDNSResponseCalled {
|
||||||
|
t.Fatal("delayedDNSResponse not called")
|
||||||
|
}
|
||||||
|
if !goodQueryType {
|
||||||
|
t.Fatal("unexpected query type")
|
||||||
|
}
|
||||||
|
if !goodTransportNetwork {
|
||||||
|
t.Fatal("unexpected DNS transport network")
|
||||||
|
}
|
||||||
|
if !goodTransportAddress {
|
||||||
|
t.Fatal("unexpected DNS Transport address")
|
||||||
|
}
|
||||||
|
if !goodLookupAddrs {
|
||||||
|
t.Fatal("unexpected delayed DNSLookup address")
|
||||||
|
}
|
||||||
|
if !goodError {
|
||||||
|
t.Fatal("unexpected error encountered")
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -67,6 +67,12 @@ func (*traceDefault) OnDNSRoundTripForLookupHost(started time.Time, reso model.R
|
||||||
// nothing
|
// nothing
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse.
|
||||||
|
func (*traceDefault) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport,
|
||||||
|
query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// OnConnectDone implements model.Trace.OnConnectDone.
|
// OnConnectDone implements model.Trace.OnConnectDone.
|
||||||
func (*traceDefault) OnConnectDone(
|
func (*traceDefault) OnConnectDone(
|
||||||
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user