feat: context-based tracing to record delayed DNS responses (#870)

See https://github.com/ooni/probe/issues/2221

Co-authored-by: decfox <decfox@github.com>
Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
DecFox
2022-08-22 17:51:32 +05:30
committed by GitHub
parent fe6d378a1f
commit 2301a30630
10 changed files with 508 additions and 327 deletions
+63 -1
View File
@@ -6,6 +6,7 @@ package measurexlite
import (
"context"
"errors"
"log"
"net"
"time"
@@ -97,9 +98,20 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol
}
}
// DNSNetworkAddresser is the type of something we just used to perform a DNS
// round trip (e.g., model.DNSTransport, model.Resolver) that allows us to get
// the network and the address of the underlying resolver/transport.
type DNSNetworkAddresser interface {
// Address is like model.DNSTransport.Address
Address() string
// Network is like model.DNSTransport.Network
Network() string
}
// NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip
// from the available information right after the DNS RoundTrip
func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso model.Resolver, query model.DNSQuery,
func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso DNSNetworkAddresser, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult {
return &model.ArchivalDNSLookupResult{
Answers: archivalAnswersFromAddrs(addrs),
@@ -167,3 +179,53 @@ func (tx *Trace) FirstDNSLookup() *model.ArchivalDNSLookupResult {
}
return ev[0]
}
// ErrDelayedDNSResponseBufferFull indicates that the delayedDNSResponse buffer is full.
var ErrDelayedDNSResponseBufferFull = errors.New("buffer full")
// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse
func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Time) error {
t := finished.Sub(tx.ZeroTime)
select {
case tx.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(
tx.Index,
started.Sub(tx.ZeroTime),
txp,
query,
response,
addrs,
err,
t,
):
return nil
default:
return ErrDelayedDNSResponseBufferFull
}
}
// DelayedDNSResponseWithTimeout drains the network events buffered inside
// the delayedDNSResponse channel. We construct a child context based on [ctx]
// and the given [timeout] and we stop reading when original [ctx] has been
// cancelled or the given [timeout] expires, whatever happens first. Once the
// timeout expired, we drain the chan as much as possible before returning.
func (tx *Trace) DelayedDNSResponseWithTimeout(ctx context.Context,
timeout time.Duration) (out []*model.ArchivalDNSLookupResult) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
for {
select {
case <-ctx.Done():
for { // once the context is done enter in channel draining mode
select {
case ev := <-tx.delayedDNSResponse:
out = append(out, ev)
default:
return
}
}
case ev := <-tx.delayedDNSResponse:
out = append(out, ev)
}
}
}
+153
View File
@@ -2,6 +2,7 @@ package measurexlite
import (
"context"
"errors"
"net"
"testing"
"time"
@@ -322,6 +323,158 @@ func TestFirstDNSLookup(t *testing.T) {
})
}
func TestDelayedDNSResponseWithTimeout(t *testing.T) {
t.Run("OnDelayedDNSResponse saves into the trace", func(t *testing.T) {
t.Run("when buffer is not full", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{
MockNetwork: func() string {
return "udp"
},
MockAddress: func() string {
return "1.1.1.1"
},
}
started := trace.TimeNow()
query := &mocks.DNSQuery{
MockType: func() uint16 {
return dns.TypeA
},
MockDomain: func() string {
return "dns.google.com"
},
}
addrs := []string{"1.1.1.1"}
finished := trace.TimeNow()
// 1. fill the trace
err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{},
addrs, nil, finished)
// 2. read the trace
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
if err != nil {
t.Fatal("unexpected error", err)
}
if len(got) != 1 {
t.Fatal("unexpected output from trace")
}
})
t.Run("when buffer is full", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
trace.delayedDNSResponse = make(chan *model.ArchivalDNSLookupResult) // no buffer
txp := &mocks.DNSTransport{
MockNetwork: func() string {
return "udp"
},
MockAddress: func() string {
return "1.1.1.1"
},
}
started := trace.TimeNow()
query := &mocks.DNSQuery{
MockType: func() uint16 {
return dns.TypeA
},
MockDomain: func() string {
return "dns.google.com"
},
}
addrs := []string{"1.1.1.1"}
finished := trace.TimeNow()
// 1. attempt to write into the trace
err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{},
addrs, nil, finished)
if !errors.Is(err, ErrDelayedDNSResponseBufferFull) {
t.Fatal("unexpected error", err)
}
// 2. confirm we didn't write anything
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
if len(got) != 0 {
t.Fatal("unexpected output from trace")
}
})
})
t.Run("DelayedDNSResponseWithTimeout drains the trace", func(t *testing.T) {
t.Run("context is already cancelled and we still drain the trace", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{
MockNetwork: func() string {
return "udp"
},
MockAddress: func() string {
return "1.1.1.1"
},
}
started := trace.TimeNow()
query := &mocks.DNSQuery{
MockType: func() uint16 {
return dns.TypeA
},
MockDomain: func() string {
return "dns.google.com"
},
}
addrs := []string{"1.1.1.1"}
finished := trace.TimeNow()
events := 4
for i := 0; i < events; i++ {
// fill the trace
trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime),
txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime))
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // we ensure that the context cancels before draining all the events
// drain the trace
got := trace.DelayedDNSResponseWithTimeout(ctx, 10*time.Second)
if len(got) != 4 {
t.Fatal("unexpected output from trace", len(got))
}
})
t.Run("normal case where the context times out after we start draining", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{
MockNetwork: func() string {
return "udp"
},
MockAddress: func() string {
return "1.1.1.1"
},
}
started := trace.TimeNow()
query := &mocks.DNSQuery{
MockType: func() uint16 {
return dns.TypeA
},
MockDomain: func() string {
return "dns.google.com"
},
}
addrs := []string{"1.1.1.1"}
finished := trace.TimeNow()
trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime),
txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime))
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
if len(got) != 1 {
t.Fatal("unexpected output from trace")
}
})
})
}
func TestAnswersFromAddrs(t *testing.T) {
tests := []struct {
name string
+21 -15
View File
@@ -34,8 +34,7 @@ type Trace struct {
// traces, you can use zero to indicate the "default" trace.
Index int64
// networkEvent is MANDATORY and buffers network events. If you create
// this channel manually, ensure it has some buffer.
// networkEvent is MANDATORY and buffers network events.
networkEvent chan *model.ArchivalNetworkEvent
// NewStdlibResolverFn is OPTIONAL and can be used to overide
@@ -62,20 +61,19 @@ type Trace struct {
// calls to the netxlite.NewQUICDialerWithoutResolver factory.
NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer
// dnsLookup is MANDATORY and buffers DNS Lookup observations. If you create
// this channel manually, ensure it has some buffer.
// dnsLookup is MANDATORY and buffers DNS Lookup observations.
dnsLookup chan *model.ArchivalDNSLookupResult
// tcpConnect is MANDATORY and buffers TCP connect observations. If you create
// this channel manually, ensure it has some buffer.
// delayedDNSResponse is MANDATORY and buffers delayed DNS responses.
delayedDNSResponse chan *model.ArchivalDNSLookupResult
// tcpConnect is MANDATORY and buffers TCP connect observations.
tcpConnect chan *model.ArchivalTCPConnectResult
// tlsHandshake is MANDATORY and buffers TLS handshake observations. If you create
// this channel manually, ensure it has some buffer.
// tlsHandshake is MANDATORY and buffers TLS handshake observations.
tlsHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
// quicHandshake is MANDATORY and buffers QUIC handshake observations. If you create
// this channel manually, ensure it has some buffer.
// quicHandshake is MANDATORY and buffers QUIC handshake observations.
quicHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now
@@ -88,23 +86,27 @@ type Trace struct {
const (
// NetworkEventBufferSize is the buffer size for constructing
// the Trace's NetworkEvent buffered channel.
// the Trace's networkEvent buffered channel.
NetworkEventBufferSize = 64
// DNSLookupBufferSize is the buffer size for constructing
// the Trace's DNSLookup map of buffered channels.
// the Trace's dnsLookup buffered channel.
DNSLookupBufferSize = 8
// DNSResponseBufferSize is the buffer size for constructing
// the Trace's dnsDelayedResponse buffered channel.
DelayedDNSResponseBufferSize = 8
// TCPConnectBufferSize is the buffer size for constructing
// the Trace's TCPConnect buffered channel.
// the Trace's tcpConnect buffered channel.
TCPConnectBufferSize = 8
// TLSHandshakeBufferSize is the buffer for construcing
// the Trace's TLSHandshake buffered channel.
// the Trace's tlsHandshake buffered channel.
TLSHandshakeBufferSize = 8
// QUICHandshakeBufferSize is the buffer for constructing
// the Trace's QUICHandshake buffered channel.
// the Trace's quicHandshake buffered channel.
QUICHandshakeBufferSize = 8
)
@@ -132,6 +134,10 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
chan *model.ArchivalDNSLookupResult,
DNSLookupBufferSize,
),
delayedDNSResponse: make(
chan *model.ArchivalDNSLookupResult,
DelayedDNSResponseBufferSize,
),
tcpConnect: make(
chan *model.ArchivalTCPConnectResult,
TCPConnectBufferSize,
+27 -8
View File
@@ -84,7 +84,7 @@ func TestNewTrace(t *testing.T) {
}
})
t.Run("DNSLookup has the expected buffer size", func(t *testing.T) {
t.Run("dnsLookup has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
@@ -99,11 +99,30 @@ func TestNewTrace(t *testing.T) {
}
}
if idx != DNSLookupBufferSize {
t.Fatal("invalid DNSLookup channel buffer size")
t.Fatal("invalid dnsLookup channel buffer size")
}
})
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) {
t.Run("delayedDNSResponse has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalDNSLookupResult{}
ff.Fill(ev)
select {
case trace.delayedDNSResponse <- ev:
idx++
default:
break Loop
}
}
if idx != DelayedDNSResponseBufferSize {
t.Fatal("invalid delayedDNSResponse channel buffer size")
}
})
t.Run("tcpConnect has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
@@ -118,11 +137,11 @@ func TestNewTrace(t *testing.T) {
}
}
if idx != TCPConnectBufferSize {
t.Fatal("invalid TCPConnect channel buffer size")
t.Fatal("invalid tcpConnect channel buffer size")
}
})
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) {
t.Run("tlsHandshake has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
@@ -137,11 +156,11 @@ func TestNewTrace(t *testing.T) {
}
}
if idx != TLSHandshakeBufferSize {
t.Fatal("invalid TLSHandshake channel buffer size")
t.Fatal("invalid tlsHandshake channel buffer size")
}
})
t.Run("QUICHandshake has the expected buffer size", func(t *testing.T) {
t.Run("quicHandshake has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
@@ -156,7 +175,7 @@ func TestNewTrace(t *testing.T) {
}
}
if idx != QUICHandshakeBufferSize {
t.Fatal("invalid QUICHandshake channel buffer size")
t.Fatal("invalid quicHandshake channel buffer size")
}
})