feat: context-based tracing to record delayed DNS responses (#870)

See https://github.com/ooni/probe/issues/2221

Co-authored-by: decfox <decfox@github.com>
Co-authored-by: Simone Basso <bassosimone@gmail.com>
This commit is contained in:
DecFox 2022-08-22 17:51:32 +05:30 committed by GitHub
parent fe6d378a1f
commit 2301a30630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 508 additions and 327 deletions

View File

@ -6,6 +6,7 @@ package measurexlite
import ( import (
"context" "context"
"errors"
"log" "log"
"net" "net"
"time" "time"
@ -97,9 +98,20 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol
} }
} }
// DNSNetworkAddresser abstracts whatever performed a DNS round trip (for
// example a model.DNSTransport or a model.Resolver) down to the two facts
// we need to archive: which network was used and which address was contacted.
type DNSNetworkAddresser interface {
	// Network behaves like model.DNSTransport.Network.
	Network() string

	// Address behaves like model.DNSTransport.Address.
	Address() string
}
// NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip // NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip
// from the available information right after the DNS RoundTrip // from the available information right after the DNS RoundTrip
func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso model.Resolver, query model.DNSQuery, func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso DNSNetworkAddresser, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult { response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult {
return &model.ArchivalDNSLookupResult{ return &model.ArchivalDNSLookupResult{
Answers: archivalAnswersFromAddrs(addrs), Answers: archivalAnswersFromAddrs(addrs),
@ -167,3 +179,53 @@ func (tx *Trace) FirstDNSLookup() *model.ArchivalDNSLookupResult {
} }
return ev[0] return ev[0]
} }
// ErrDelayedDNSResponseBufferFull is the sentinel error returned by
// Trace.OnDelayedDNSResponse when the non-blocking send on the
// delayedDNSResponse channel cannot proceed, i.e., the buffer is full.
var ErrDelayedDNSResponseBufferFull = errors.New("buffer full")
// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse.
//
// The arguments are converted into a model.ArchivalDNSLookupResult whose
// timestamps are relative to the trace's ZeroTime. The result is then
// posted on the delayedDNSResponse channel without blocking: we return nil
// if the send succeeds and ErrDelayedDNSResponseBufferFull otherwise.
func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery,
	response model.DNSResponse, addrs []string, err error, finished time.Time) error {
	sinceZeroStarted := started.Sub(tx.ZeroTime)
	sinceZeroFinished := finished.Sub(tx.ZeroTime)
	ev := NewArchivalDNSLookupResultFromRoundTrip(
		tx.Index, sinceZeroStarted, txp, query, response, addrs, err, sinceZeroFinished,
	)
	select {
	case tx.delayedDNSResponse <- ev:
		return nil
	default:
		// nobody can take the event right now: tell the caller
		return ErrDelayedDNSResponseBufferFull
	}
}
// DelayedDNSResponseWithTimeout collects the DNS lookup results buffered
// inside the delayedDNSResponse channel.
//
// We derive a child context from [ctx] with the given [timeout] and keep
// reading, blocking if needed, until such a child context is done. This
// happens when the original [ctx] is cancelled or the [timeout] expires,
// whichever comes first. After that, we read whatever is still immediately
// available in the channel, without blocking, and return all the results
// collected so far.
func (tx *Trace) DelayedDNSResponseWithTimeout(ctx context.Context,
	timeout time.Duration) (out []*model.ArchivalDNSLookupResult) {
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// Phase 1: blocking reads until the child context is done.
WaitLoop:
	for {
		select {
		case <-ctx.Done():
			break WaitLoop
		case ev := <-tx.delayedDNSResponse:
			out = append(out, ev)
		}
	}

	// Phase 2: nonblocking reads to empty what is left in the buffer.
	for {
		select {
		case ev := <-tx.delayedDNSResponse:
			out = append(out, ev)
		default:
			return out
		}
	}
}

View File

@ -2,6 +2,7 @@ package measurexlite
import ( import (
"context" "context"
"errors"
"net" "net"
"testing" "testing"
"time" "time"
@ -322,6 +323,158 @@ func TestFirstDNSLookup(t *testing.T) {
}) })
} }
func TestDelayedDNSResponseWithTimeout(t *testing.T) {
t.Run("OnDelayedDNSResponse saves into the trace", func(t *testing.T) {
t.Run("when buffer is not full", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{
MockNetwork: func() string {
return "udp"
},
MockAddress: func() string {
return "1.1.1.1"
},
}
started := trace.TimeNow()
query := &mocks.DNSQuery{
MockType: func() uint16 {
return dns.TypeA
},
MockDomain: func() string {
return "dns.google.com"
},
}
addrs := []string{"1.1.1.1"}
finished := trace.TimeNow()
// 1. fill the trace
err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{},
addrs, nil, finished)
// 2. read the trace
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
if err != nil {
t.Fatal("unexpected error", err)
}
if len(got) != 1 {
t.Fatal("unexpected output from trace")
}
})
t.Run("when buffer is full", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
trace.delayedDNSResponse = make(chan *model.ArchivalDNSLookupResult) // no buffer
txp := &mocks.DNSTransport{
MockNetwork: func() string {
return "udp"
},
MockAddress: func() string {
return "1.1.1.1"
},
}
started := trace.TimeNow()
query := &mocks.DNSQuery{
MockType: func() uint16 {
return dns.TypeA
},
MockDomain: func() string {
return "dns.google.com"
},
}
addrs := []string{"1.1.1.1"}
finished := trace.TimeNow()
// 1. attempt to write into the trace
err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{},
addrs, nil, finished)
if !errors.Is(err, ErrDelayedDNSResponseBufferFull) {
t.Fatal("unexpected error", err)
}
// 2. confirm we didn't write anything
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
if len(got) != 0 {
t.Fatal("unexpected output from trace")
}
})
})
t.Run("DelayedDNSResponseWithTimeout drains the trace", func(t *testing.T) {
t.Run("context is already cancelled and we still drain the trace", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{
MockNetwork: func() string {
return "udp"
},
MockAddress: func() string {
return "1.1.1.1"
},
}
started := trace.TimeNow()
query := &mocks.DNSQuery{
MockType: func() uint16 {
return dns.TypeA
},
MockDomain: func() string {
return "dns.google.com"
},
}
addrs := []string{"1.1.1.1"}
finished := trace.TimeNow()
events := 4
for i := 0; i < events; i++ {
// fill the trace
trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime),
txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime))
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // we ensure that the context cancels before draining all the events
// drain the trace
got := trace.DelayedDNSResponseWithTimeout(ctx, 10*time.Second)
if len(got) != 4 {
t.Fatal("unexpected output from trace", len(got))
}
})
t.Run("normal case where the context times out after we start draining", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{
MockNetwork: func() string {
return "udp"
},
MockAddress: func() string {
return "1.1.1.1"
},
}
started := trace.TimeNow()
query := &mocks.DNSQuery{
MockType: func() uint16 {
return dns.TypeA
},
MockDomain: func() string {
return "dns.google.com"
},
}
addrs := []string{"1.1.1.1"}
finished := trace.TimeNow()
trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime),
txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime))
got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second)
if len(got) != 1 {
t.Fatal("unexpected output from trace")
}
})
})
}
func TestAnswersFromAddrs(t *testing.T) { func TestAnswersFromAddrs(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@ -34,8 +34,7 @@ type Trace struct {
// traces, you can use zero to indicate the "default" trace. // traces, you can use zero to indicate the "default" trace.
Index int64 Index int64
// networkEvent is MANDATORY and buffers network events. If you create // networkEvent is MANDATORY and buffers network events.
// this channel manually, ensure it has some buffer.
networkEvent chan *model.ArchivalNetworkEvent networkEvent chan *model.ArchivalNetworkEvent
// NewStdlibResolverFn is OPTIONAL and can be used to override // NewStdlibResolverFn is OPTIONAL and can be used to override
@ -62,20 +61,19 @@ type Trace struct {
// calls to the netxlite.NewQUICDialerWithoutResolver factory. // calls to the netxlite.NewQUICDialerWithoutResolver factory.
NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer
// dnsLookup is MANDATORY and buffers DNS Lookup observations. If you create // dnsLookup is MANDATORY and buffers DNS Lookup observations.
// this channel manually, ensure it has some buffer.
dnsLookup chan *model.ArchivalDNSLookupResult dnsLookup chan *model.ArchivalDNSLookupResult
// tcpConnect is MANDATORY and buffers TCP connect observations. If you create // delayedDNSResponse is MANDATORY and buffers delayed DNS responses.
// this channel manually, ensure it has some buffer. delayedDNSResponse chan *model.ArchivalDNSLookupResult
// tcpConnect is MANDATORY and buffers TCP connect observations.
tcpConnect chan *model.ArchivalTCPConnectResult tcpConnect chan *model.ArchivalTCPConnectResult
// tlsHandshake is MANDATORY and buffers TLS handshake observations. If you create // tlsHandshake is MANDATORY and buffers TLS handshake observations.
// this channel manually, ensure it has some buffer.
tlsHandshake chan *model.ArchivalTLSOrQUICHandshakeResult tlsHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
// quicHandshake is MANDATORY and buffers QUIC handshake observations. If you create // quicHandshake is MANDATORY and buffers QUIC handshake observations.
// this channel manually, ensure it has some buffer.
quicHandshake chan *model.ArchivalTLSOrQUICHandshakeResult quicHandshake chan *model.ArchivalTLSOrQUICHandshakeResult
// TimeNowFn is OPTIONAL and can be used to override calls to time.Now // TimeNowFn is OPTIONAL and can be used to override calls to time.Now
@ -88,23 +86,27 @@ type Trace struct {
const ( const (
// NetworkEventBufferSize is the buffer size for constructing // NetworkEventBufferSize is the buffer size for constructing
// the Trace's NetworkEvent buffered channel. // the Trace's networkEvent buffered channel.
NetworkEventBufferSize = 64 NetworkEventBufferSize = 64
// DNSLookupBufferSize is the buffer size for constructing // DNSLookupBufferSize is the buffer size for constructing
// the Trace's DNSLookup map of buffered channels. // the Trace's dnsLookup buffered channel.
DNSLookupBufferSize = 8 DNSLookupBufferSize = 8
// DelayedDNSResponseBufferSize is the buffer size for constructing
// the Trace's delayedDNSResponse buffered channel.
DelayedDNSResponseBufferSize = 8
// TCPConnectBufferSize is the buffer size for constructing // TCPConnectBufferSize is the buffer size for constructing
// the Trace's TCPConnect buffered channel. // the Trace's tcpConnect buffered channel.
TCPConnectBufferSize = 8 TCPConnectBufferSize = 8
// TLSHandshakeBufferSize is the buffer for constructing // TLSHandshakeBufferSize is the buffer for constructing
// the Trace's TLSHandshake buffered channel. // the Trace's tlsHandshake buffered channel.
TLSHandshakeBufferSize = 8 TLSHandshakeBufferSize = 8
// QUICHandshakeBufferSize is the buffer for constructing // QUICHandshakeBufferSize is the buffer for constructing
// the Trace's QUICHandshake buffered channel. // the Trace's quicHandshake buffered channel.
QUICHandshakeBufferSize = 8 QUICHandshakeBufferSize = 8
) )
@ -132,6 +134,10 @@ func NewTrace(index int64, zeroTime time.Time) *Trace {
chan *model.ArchivalDNSLookupResult, chan *model.ArchivalDNSLookupResult,
DNSLookupBufferSize, DNSLookupBufferSize,
), ),
delayedDNSResponse: make(
chan *model.ArchivalDNSLookupResult,
DelayedDNSResponseBufferSize,
),
tcpConnect: make( tcpConnect: make(
chan *model.ArchivalTCPConnectResult, chan *model.ArchivalTCPConnectResult,
TCPConnectBufferSize, TCPConnectBufferSize,

View File

@ -84,7 +84,7 @@ func TestNewTrace(t *testing.T) {
} }
}) })
t.Run("DNSLookup has the expected buffer size", func(t *testing.T) { t.Run("dnsLookup has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{} ff := &testingx.FakeFiller{}
var idx int var idx int
Loop: Loop:
@ -99,11 +99,30 @@ func TestNewTrace(t *testing.T) {
} }
} }
if idx != DNSLookupBufferSize { if idx != DNSLookupBufferSize {
t.Fatal("invalid DNSLookup channel buffer size") t.Fatal("invalid dnsLookup channel buffer size")
} }
}) })
t.Run("TCPConnect has the expected buffer size", func(t *testing.T) { t.Run("delayedDNSResponse has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{}
var idx int
Loop:
for {
ev := &model.ArchivalDNSLookupResult{}
ff.Fill(ev)
select {
case trace.delayedDNSResponse <- ev:
idx++
default:
break Loop
}
}
if idx != DelayedDNSResponseBufferSize {
t.Fatal("invalid delayedDNSResponse channel buffer size")
}
})
t.Run("tcpConnect has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{} ff := &testingx.FakeFiller{}
var idx int var idx int
Loop: Loop:
@ -118,11 +137,11 @@ func TestNewTrace(t *testing.T) {
} }
} }
if idx != TCPConnectBufferSize { if idx != TCPConnectBufferSize {
t.Fatal("invalid TCPConnect channel buffer size") t.Fatal("invalid tcpConnect channel buffer size")
} }
}) })
t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) { t.Run("tlsHandshake has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{} ff := &testingx.FakeFiller{}
var idx int var idx int
Loop: Loop:
@ -137,11 +156,11 @@ func TestNewTrace(t *testing.T) {
} }
} }
if idx != TLSHandshakeBufferSize { if idx != TLSHandshakeBufferSize {
t.Fatal("invalid TLSHandshake channel buffer size") t.Fatal("invalid tlsHandshake channel buffer size")
} }
}) })
t.Run("QUICHandshake has the expected buffer size", func(t *testing.T) { t.Run("quicHandshake has the expected buffer size", func(t *testing.T) {
ff := &testingx.FakeFiller{} ff := &testingx.FakeFiller{}
var idx int var idx int
Loop: Loop:
@ -156,7 +175,7 @@ func TestNewTrace(t *testing.T) {
} }
} }
if idx != QUICHandshakeBufferSize { if idx != QUICHandshakeBufferSize {
t.Fatal("invalid QUICHandshake channel buffer size") t.Fatal("invalid quicHandshake channel buffer size")
} }
}) })

View File

@ -24,6 +24,9 @@ type Trace struct {
MockOnDNSRoundTripForLookupHost func(started time.Time, reso model.Resolver, query model.DNSQuery, MockOnDNSRoundTripForLookupHost func(started time.Time, reso model.Resolver, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Time) response model.DNSResponse, addrs []string, err error, finished time.Time)
MockOnDelayedDNSResponse func(started time.Time, txp model.DNSTransport, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Time) error
MockOnConnectDone func( MockOnConnectDone func(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) started time.Time, network, domain, remoteAddr string, err error, finished time.Time)
@ -57,6 +60,11 @@ func (t *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolv
t.MockOnDNSRoundTripForLookupHost(started, reso, query, response, addrs, err, finished) t.MockOnDNSRoundTripForLookupHost(started, reso, query, response, addrs, err, finished)
} }
// OnDelayedDNSResponse forwards all its arguments to the
// MockOnDelayedDNSResponse field and returns the error it returns.
func (t *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery,
	response model.DNSResponse, addrs []string, err error, finished time.Time) error {
	mock := t.MockOnDelayedDNSResponse
	return mock(started, txp, query, response, addrs, err, finished)
}
func (t *Trace) OnConnectDone( func (t *Trace) OnConnectDone(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) { started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {
t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished) t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished)

View File

@ -71,6 +71,30 @@ func TestTrace(t *testing.T) {
} }
}) })
t.Run("OnDelayedDNSResponse", func(t *testing.T) {
var called bool
tx := &Trace{
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
query model.DNSQuery, response model.DNSResponse,
addrs []string, err error, finished time.Time) error {
called = true
return nil
},
}
tx.OnDelayedDNSResponse(
time.Now(),
&DNSTransport{},
&DNSQuery{},
&DNSResponse{},
[]string{},
nil,
time.Now(),
)
if !called {
t.Fatal("not called")
}
})
t.Run("OnConnectDone", func(t *testing.T) { t.Run("OnConnectDone", func(t *testing.T) {
var called bool var called bool
tx := &Trace{ tx := &Trace{

View File

@ -340,6 +340,29 @@ type Trace interface {
OnDNSRoundTripForLookupHost(started time.Time, reso Resolver, query DNSQuery, OnDNSRoundTripForLookupHost(started time.Time, reso Resolver, query DNSQuery,
response DNSResponse, addrs []string, err error, finished time.Time) response DNSResponse, addrs []string, err error, finished time.Time)
// OnDelayedDNSResponse is used with a DNSOverUDPTransport and called
// when we get delayed, unexpected DNS responses.
//
// Arguments:
//
// - started is when we started reading the delayed response;
//
// - txp is the DNS transport used with the resolver;
//
// - query is the non-nil DNS query we use for the RoundTrip;
//
// - response is the non-nil valid DNS response, obtained after some delay;
//
// - addrs is the list of addresses obtained after decoding the delayed response,
// which is empty if the response did not contain any addresses, which we
// extract by calling the DecodeLookupHost method;
//
// - err is the result of DecodeLookupHost: either an error or nil;
//
// - finished is when we have read the delayed response.
//
// The return value is an error; a nil value means no error occurred.
OnDelayedDNSResponse(started time.Time, txp DNSTransport, query DNSQuery,
	response DNSResponse, addrs []string, err error, finished time.Time) error
// OnConnectDone is called when connect terminates. // OnConnectDone is called when connect terminates.
// //
// Arguments: // Arguments:

View File

@ -45,10 +45,6 @@ type DNSOverUDPTransport struct {
// Endpoint is the MANDATORY server's endpoint (e.g., 1.1.1.1:53) // Endpoint is the MANDATORY server's endpoint (e.g., 1.1.1.1:53)
Endpoint string Endpoint string
// IOTimeout is the MANDATORY I/O timeout after which any
// conn created to perform round trips times out.
IOTimeout time.Duration
} }
// NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance // NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance
@ -70,7 +66,6 @@ func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOv
Decoder: &DNSDecoderMiekg{}, Decoder: &DNSDecoderMiekg{},
Dialer: dialer, Dialer: dialer,
Endpoint: address, Endpoint: address,
IOTimeout: 10 * time.Second,
} }
} }
@ -78,21 +73,36 @@ func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOv
func (t *DNSOverUDPTransport) RoundTrip( func (t *DNSOverUDPTransport) RoundTrip(
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
// QUIRK: the original code had a five seconds timeout, which is // QUIRK: the original code had a five seconds timeout, which is
// consistent with the Bionic implementation. Let's enforce such a // consistent with the Bionic implementation.
// timeout using the context in the outer operation because we
// need to run for more seconds in the background to catch as many
// duplicate replies as possible.
// //
// See https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance // See https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
const opTimeout = 5 * time.Second const opTimeout = 5 * time.Second
ctx, cancel := context.WithTimeout(ctx, opTimeout) ctx, cancel := context.WithTimeout(ctx, opTimeout)
defer cancel() defer cancel()
outch, err := t.AsyncRoundTrip(ctx, query, 1) // buffer to avoid background's goroutine leak rawQuery, err := query.Bytes()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer outch.Close() // we own the channel conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
return outch.Next(ctx) if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(opTimeout))
joinedch := make(chan bool)
myaddr := conn.LocalAddr().String()
if _, err := conn.Write(rawQuery); err != nil {
conn.Close() // we still own the conn
return nil, err
}
resp, err := t.recv(query, conn)
if err != nil {
conn.Close() // we still own the conn
return nil, err
}
// start a goroutine to listen for any delayed DNS response and
// TRANSFER the conn's OWNERSHIP to such a goroutine.
go t.ownConnAndSendRecvLoop(ctx, conn, query, myaddr, joinedch)
return resp, nil
} }
// RequiresPadding returns false for UDP according to RFC8467. // RequiresPadding returns false for UDP according to RFC8467.
@ -119,196 +129,17 @@ func (t *DNSOverUDPTransport) CloseIdleConnections() {
var _ model.DNSTransport = &DNSOverUDPTransport{} var _ model.DNSTransport = &DNSOverUDPTransport{}
// DNSOverUDPResponse is a response received by a DNSOverUDPTransport when you // ownConnAndSendRecvLoop listens for delayed DNS responses after we have returned the
// use its AsyncRoundTrip method as opposed to using RoundTrip. // first response. As the name implies, this function TAKES OWNERSHIP of the [conn].
type DNSOverUDPResponse struct { func (t *DNSOverUDPTransport) ownConnAndSendRecvLoop(ctx context.Context, conn net.Conn,
// Err is the error that occurred (nil in case of success). query model.DNSQuery, myaddr string, eofch chan<- bool) {
Err error
// LocalAddr is the local UDP address we're using.
LocalAddr string
// Operation is the operation that failed.
Operation string
// Query is the related DNS query.
Query model.DNSQuery
// RemoteAddr is the remote server address.
RemoteAddr string
// Response is the response (nil iff error is not nil).
Response model.DNSResponse
}
// newDNSOverUDPResponse creates a new DNSOverUDPResponse instance.
func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error,
query model.DNSQuery, resp model.DNSResponse, operation string) *DNSOverUDPResponse {
return &DNSOverUDPResponse{
Err: err,
LocalAddr: localAddr,
Operation: operation,
Query: query,
RemoteAddr: t.Endpoint, // The common case is to have an IP:port here (domains are discouraged)
Response: resp,
}
}
// DNSOverUDPChannel is a wrapper around a channel for reading zero
// or more *DNSOverUDPResponse that makes extracting information from
// the underlying channels more user friendly than interacting with
// the channels directly, thanks to useful wrapper methods implementing
// common access patterns. You can still use the underlying channels
// directly if there's no suitable convenience method.
//
// You MUST call the .Close method when done. Not calling such a method
// leaks goroutines and causes connections to stay open forever.
type DNSOverUDPChannel struct {
// Response is the channel where we'll post responses. This channel
// WILL NOT be closed when the background goroutine terminates.
Response <-chan *DNSOverUDPResponse
// Joined IS CLOSED when the background goroutine terminates.
Joined <-chan bool
// conn is the underlying connection, which we can Close to
// immediately cause the background goroutine to join.
conn net.Conn
}
// Close releases the resources allocated by the channel. You MUST
// call this method to force the background goroutine that is performing
// the round trip to terminate. Calling this method also ensures we
// close the connection used by the round trip. This method is idempotent.
func (ch *DNSOverUDPChannel) Close() error {
return ch.conn.Close()
}
// Next blocks until the next response is received on Response or the
// given context expires, whatever happens first. This function will
// completely ignore the Joined channel and will just timeout in case
// you call Next after the background goroutine had joined. In fact,
// the use case for this function is using it to get a response or
// a timeout when you know the DNS round trip is pending.
func (ch *DNSOverUDPChannel) Next(ctx context.Context) (model.DNSResponse, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case out := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil
return out.Response, out.Err
}
}
// TryNextResponses attempts to read all the buffered messages inside of the "Response"
// channel that contains successful DNS responses. That is, this function will silently skip
// any possible DNSOverUDPResponse with its Err != nil. The use case for this function is
// to obtain all the subsequent response messages we received while we were performing
// other operations (e.g., contacting the test helper of fetching a webpage).
func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) {
for {
select {
case r := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil
if r.Err == nil && r.Response != nil {
out = append(out, r.Response)
}
default:
return
}
}
}
// AsyncRoundTrip performs an async DNS round trip. The "buffer" argument
// controls how many buffer slots the returned DNSOverUDPChannel's Response
// channel should have. A zero or negative value causes this function to
// create a channel having a single-slot buffer.
//
// The real round trip runs in a background goroutine. We will terminate the background
// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we
// cannot write on the "Response" channel or (3) the connection is closed by calling the
// Close method of DNSOverUDPChannel. Note that the background goroutine WILL NOT close
// the "Response" channel to signal its completion. Hence, who reads such a
// channel MUST be prepared for read operations to block forever (i.e., should use
// a select operation for draining the channel in a deadlock-safe way). Also,
// we WILL NOT ever post a nil message to the "Response" channel.
//
// The returned DNSOverUDPChannel contains another channel called Joined that is
// closed when the background goroutine terminates, so you can use this channel
// should you need to synchronize with such goroutine's termination.
//
// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type,
// you don't need to worry about these low level details though.
//
// We give you OWNERSHIP of the returned DNSOverUDPChannel and you MUST
// call its .Close method when done with using it.
func (t *DNSOverUDPTransport) AsyncRoundTrip(
ctx context.Context, query model.DNSQuery, buffer int) (*DNSOverUDPChannel, error) {
rawQuery, err := query.Bytes()
if err != nil {
return nil, err
}
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(t.IOTimeout))
if buffer < 2 {
buffer = 1 // as documented
}
outch := make(chan *DNSOverUDPResponse, buffer)
joinedch := make(chan bool)
go t.sendRecvLoop(conn, rawQuery, query, outch, joinedch)
dnsch := &DNSOverUDPChannel{
Response: outch,
Joined: joinedch,
conn: conn, // transfer ownership
}
return dnsch, nil
}
// sendRecvLoop sends the given raw query on the given conn and receives responses
// from the conn posting them onto the given output channel.
//
// Arguments:
//
// 1. conn is the BORROWED net.Conn (we will use it for reading or writing but
// we do not own the connection and we're not going to close it);
//
// 2. rawQuery contains the rawQuery and is BORROWED (we won't modify it);
//
// 3. query contains the original query and is also BORROWED;
//
// 4. outch is the channel where to emit measurements and is OWNED by this
// function (that said, we WILL NOT close this channel);
//
// 5. eofch is the channel to signal EOF, which is OWNED by this function
// and closed when this function exits.
//
// This method terminates in the following cases:
//
// 1. I/O error while reading or writing (including the deadline expiring or
// the owner of the connection closing the connection);
//
// 2. We cannot post on the output channel because either there is
// noone reading the channel or the channel's buffer is full.
//
// 3. We cannot parse incoming data as a valid DNS response message that
// responds to the query that we originally sent.
func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte,
query model.DNSQuery, outch chan<- *DNSOverUDPResponse, eofch chan<- bool) {
defer close(eofch) // synchronize with the caller defer close(eofch) // synchronize with the caller
myaddr := conn.LocalAddr().String() defer conn.Close() // we own the conn
if _, err := conn.Write(rawQuery); err != nil { trace := ContextTraceOrDefault(ctx)
outch <- t.newDNSOverUDPResponse(
myaddr, err, query, nil, WriteOperation) // one-sized buffer, can't block
return
}
for { for {
started := trace.TimeNow()
resp, err := t.recv(query, conn) resp, err := t.recv(query, conn)
select { finished := trace.TimeNow()
case outch <- t.newDNSOverUDPResponse(myaddr, err, query, resp, ReadOperation):
default:
return // no-one is reading the channel -- so long...
}
if err != nil { if err != nil {
// We are going to consider all errors as fatal for now until we // We are going to consider all errors as fatal for now until we
// hear of specific errs that it might have sense to ignore. // hear of specific errs that it might have sense to ignore.
@ -316,6 +147,16 @@ func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte,
// Note that erroring out here includes the expiration of the conn's // Note that erroring out here includes the expiration of the conn's
// I/O deadline, which we set above precisely because we want // I/O deadline, which we set above precisely because we want
// the total runtime of this goroutine to be bounded. // the total runtime of this goroutine to be bounded.
//
// Also, we ARE NOT going to report any failure here as a delayed
// DNS response because we only care about duplicate messages, since
// this seems to be how censorship is implemented in, e.g., China.
return
}
addrs, err := resp.DecodeLookupHost()
if err := trace.OnDelayedDNSResponse(started, t, query, resp, addrs, err, finished); err != nil {
// This error typically indicates that the buffer on which we're
// writing is now full, so there's no point in persisting.
return return
} }
} }

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"errors" "errors"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
@ -14,6 +15,7 @@ import (
"github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering" "github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
"github.com/ooni/probe-cli/v3/internal/testingx"
) )
func TestDNSOverUDPTransport(t *testing.T) { func TestDNSOverUDPTransport(t *testing.T) {
@ -281,70 +283,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
}) })
}) })
t.Run("AsyncRoundTrip", func(t *testing.T) { t.Run("recording delayed DNS responses", func(t *testing.T) {
t.Run("calling Next with cancelled context", func(t *testing.T) { t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
srvr := &filtering.DNSServer{ var (
OnQuery: func(domain string) filtering.DNSAction { delayedDNSResponseCalled bool
return filtering.DNSActionCache goodQueryType bool
}, goodTransportNetwork bool
Cache: map[string][]string{ goodTransportAddress bool
"dns.google.": {"8.8.8.8"}, goodLookupAddrs bool
}, goodError bool
} )
listener, err := srvr.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
ctx := context.Background()
rch, err := txp.AsyncRoundTrip(ctx, query, 1)
if err != nil {
t.Fatal(err)
}
defer rch.Close()
ctx, cancel := context.WithCancel(ctx)
cancel() // fail immediately
resp, err := rch.Next(ctx)
if !errors.Is(err, context.Canceled) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("unexpected resp")
}
})
t.Run("no-one is reading the channel", func(t *testing.T) {
srvr := &filtering.DNSServer{
OnQuery: func(domain string) filtering.DNSAction {
return filtering.DNSActionLocalHostPlusCache // i.e., two responses
},
Cache: map[string][]string{
"dns.google.": {"8.8.8.8"},
},
}
listener, err := srvr.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
ctx := context.Background()
rch, err := txp.AsyncRoundTrip(ctx, query, 1) // but just one place
if err != nil {
t.Fatal(err)
}
defer rch.Close()
<-rch.Joined // should see no-one is reading and stop
})
t.Run("typical usage to obtain late responses", func(t *testing.T) {
srvr := &filtering.DNSServer{ srvr := &filtering.DNSServer{
OnQuery: func(domain string) filtering.DNSAction { OnQuery: func(domain string) filtering.DNSAction {
return filtering.DNSActionLocalHostPlusCache return filtering.DNSActionLocalHostPlusCache
@ -359,52 +307,94 @@ func TestDNSOverUDPTransport(t *testing.T) {
} }
defer listener.Close() defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger) dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) expectedAddress := listener.LocalAddr().String()
txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
encoder := &DNSEncoderMiekg{} encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false) query := encoder.Encode("dns.google.", dns.TypeA, false)
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
expectedAddrs := []string{"8.8.8.8"}
respChannel := make(chan *model.DNSResponse, 8)
mu := new(sync.Mutex)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
finished time.Time) error {
mu.Lock()
delayedDNSResponseCalled = true
goodQueryType = (query.Type() == dns.TypeA)
goodTransportNetwork = (txp.Network() == "udp")
goodTransportAddress = (txp.Address() == expectedAddress)
goodLookupAddrs = (cmp.Diff(expectedAddrs, addrs) == "")
goodError = (err == nil)
mu.Unlock()
select {
case respChannel <- &response:
return nil
default:
return errors.New("full buffer")
}
},
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
finished time.Time) {
// do nothing
},
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
return conn
},
}
ctx := ContextWithTrace(context.Background(), tx)
rch, err := txp.RoundTrip(ctx, query)
<-respChannel // wait for the delayed response
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer rch.Close() addrs, err := rch.DecodeLookupHost()
resp, err := rch.Next(context.Background())
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
mu.Lock()
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" { if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
t.Fatal(diff) t.Fatal(diff)
} }
// One would not normally busy loop but it's fine to do that in the context if !delayedDNSResponseCalled {
// of this test because we know we're going to receive a second reply. In t.Fatal("delayedDNSResponse not called")
// a real network experiment here we'll do other activities, e.g., contacting
// the test helper or fetching a webpage.
var additional []model.DNSResponse
for {
additional = rch.TryNextResponses()
if len(additional) > 0 {
if len(additional) != 1 {
t.Fatal("expected exactly one additional response")
} }
break if !goodQueryType {
t.Fatal("unexpected query type")
} }
if !goodTransportNetwork {
t.Fatal("unexpected DNS transport network")
} }
addrs, err = additional[0].DecodeLookupHost() if !goodTransportAddress {
if err != nil { t.Fatal("unexpected DNS Transport address")
t.Fatal(err)
} }
if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" { if !goodLookupAddrs {
t.Fatal(diff) t.Fatal("unexpected delayed DNSLookup address")
} }
if !goodError {
t.Fatal("unexpected error encountered")
}
mu.Unlock()
}) })
t.Run("correct behavior when read times out", func(t *testing.T) { t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
var (
delayedDNSResponseCalled bool
goodQueryType bool
goodTransportNetwork bool
goodTransportAddress bool
goodLookupAddrs bool
goodError bool
)
srvr := &filtering.DNSServer{ srvr := &filtering.DNSServer{
OnQuery: func(domain string) filtering.DNSAction { OnQuery: func(domain string) filtering.DNSAction {
return filtering.DNSActionTimeout return filtering.DNSActionLocalHostPlusCache
},
Cache: map[string][]string{
// Note: the cache here is nonexistent so we should
// get a "no such host" error from the server.
}, },
} }
listener, err := srvr.Start("127.0.0.1:0") listener, err := srvr.Start("127.0.0.1:0")
@ -413,22 +403,71 @@ func TestDNSOverUDPTransport(t *testing.T) {
} }
defer listener.Close() defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger) dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) expectedAddress := listener.LocalAddr().String()
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
encoder := &DNSEncoderMiekg{} encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false) query := encoder.Encode("dns.google.", dns.TypeA, false)
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
respChannel := make(chan *model.DNSResponse, 8)
mu := new(sync.Mutex)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
finished time.Time) error {
mu.Lock()
delayedDNSResponseCalled = true
goodQueryType = (query.Type() == dns.TypeA)
goodTransportNetwork = (txp.Network() == "udp")
goodTransportAddress = (txp.Address() == expectedAddress)
goodLookupAddrs = (len(addrs) == 0)
goodError = errors.Is(err, ErrOODNSNoSuchHost)
mu.Unlock()
respChannel <- &response
return errors.New("mocked") // return error to stop background routine to record responses
},
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
finished time.Time) {
// do nothing
},
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
return conn
},
}
ctx := ContextWithTrace(context.Background(), tx)
rch, err := txp.RoundTrip(ctx, query)
<-respChannel // wait for the delayed response
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer rch.Close() addrs, err := rch.DecodeLookupHost()
result := <-rch.Response if err != nil {
if result.Err == nil || result.Err.Error() != "generic_timeout_error" { t.Fatal(err)
t.Fatal("unexpected error", result.Err)
} }
if result.Operation != ReadOperation { mu.Lock()
t.Fatal("unexpected failed operation", result.Operation) if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
t.Fatal(diff)
} }
if !delayedDNSResponseCalled {
t.Fatal("delayedDNSResponse not called")
}
if !goodQueryType {
t.Fatal("unexpected query type")
}
if !goodTransportNetwork {
t.Fatal("unexpected DNS transport network")
}
if !goodTransportAddress {
t.Fatal("unexpected DNS Transport address")
}
if !goodLookupAddrs {
t.Fatal("unexpected delayed DNSLookup address")
}
if !goodError {
t.Fatal("unexpected error encountered")
}
mu.Unlock()
}) })
}) })

View File

@ -67,6 +67,12 @@ func (*traceDefault) OnDNSRoundTripForLookupHost(started time.Time, reso model.R
// nothing // nothing
} }
// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse.
//
// The default trace does not record delayed DNS responses: this method
// ignores all of its arguments and always returns nil.
func (*traceDefault) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport,
query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) error {
return nil
}
// OnConnectDone implements model.Trace.OnConnectDone. // OnConnectDone implements model.Trace.OnConnectDone.
func (*traceDefault) OnConnectDone( func (*traceDefault) OnConnectDone(
started time.Time, network, domain, remoteAddr string, err error, finished time.Time) { started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {