diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index cd2353b..fd86eaa 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -6,6 +6,7 @@ package measurexlite import ( "context" + "errors" "log" "net" "time" @@ -97,9 +98,20 @@ func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resol } } +// DNSNetworkAddresser is the type of something we just used to perform a DNS +// round trip (e.g., model.DNSTransport, model.Resolver) that allows us to get +// the network and the address of the underlying resolver/transport. +type DNSNetworkAddresser interface { + // Address is like model.DNSTransport.Address + Address() string + + // Network is like model.DNSTransport.Network + Network() string +} + // NewArchivalDNSLookupResultFromRoundTrip generates a model.ArchivalDNSLookupResultFromRoundTrip // from the available information right after the DNS RoundTrip -func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso model.Resolver, query model.DNSQuery, +func NewArchivalDNSLookupResultFromRoundTrip(index int64, started time.Duration, reso DNSNetworkAddresser, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Duration) *model.ArchivalDNSLookupResult { return &model.ArchivalDNSLookupResult{ Answers: archivalAnswersFromAddrs(addrs), @@ -167,3 +179,53 @@ func (tx *Trace) FirstDNSLookup() *model.ArchivalDNSLookupResult { } return ev[0] } + +// ErrDelayedDNSResponseBufferFull indicates that the delayedDNSResponse buffer is full. 
+var ErrDelayedDNSResponseBufferFull = errors.New("buffer full") + +// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse +func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery, + response model.DNSResponse, addrs []string, err error, finished time.Time) error { + t := finished.Sub(tx.ZeroTime) + select { + case tx.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip( + tx.Index, + started.Sub(tx.ZeroTime), + txp, + query, + response, + addrs, + err, + t, + ): + return nil + default: + return ErrDelayedDNSResponseBufferFull + } +} + +// DelayedDNSResponseWithTimeout drains the network events buffered inside +// the delayedDNSResponse channel. We construct a child context based on [ctx] +// and the given [timeout] and we stop reading when original [ctx] has been +// cancelled or the given [timeout] expires, whatever happens first. Once the +// timeout expired, we drain the chan as much as possible before returning. +func (tx *Trace) DelayedDNSResponseWithTimeout(ctx context.Context, + timeout time.Duration) (out []*model.ArchivalDNSLookupResult) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + for { + select { + case <-ctx.Done(): + for { // once the context is done enter in channel draining mode + select { + case ev := <-tx.delayedDNSResponse: + out = append(out, ev) + default: + return + } + } + case ev := <-tx.delayedDNSResponse: + out = append(out, ev) + } + } +} diff --git a/internal/measurexlite/dns_test.go b/internal/measurexlite/dns_test.go index b7138bb..fb0c94e 100644 --- a/internal/measurexlite/dns_test.go +++ b/internal/measurexlite/dns_test.go @@ -2,6 +2,7 @@ package measurexlite import ( "context" + "errors" "net" "testing" "time" @@ -322,6 +323,158 @@ func TestFirstDNSLookup(t *testing.T) { }) } +func TestDelayedDNSResponseWithTimeout(t *testing.T) { + t.Run("OnDelayedDNSResponse saves into the trace", func(t *testing.T) { + t.Run("when buffer is not full", 
func(t *testing.T) { + zeroTime := time.Now() + td := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = td.Now + txp := &mocks.DNSTransport{ + MockNetwork: func() string { + return "udp" + }, + MockAddress: func() string { + return "1.1.1.1" + }, + } + started := trace.TimeNow() + query := &mocks.DNSQuery{ + MockType: func() uint16 { + return dns.TypeA + }, + MockDomain: func() string { + return "dns.google.com" + }, + } + addrs := []string{"1.1.1.1"} + finished := trace.TimeNow() + // 1. fill the trace + err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{}, + addrs, nil, finished) + // 2. read the trace + got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second) + if err != nil { + t.Fatal("unexpected error", err) + } + if len(got) != 1 { + t.Fatal("unexpected output from trace") + } + }) + + t.Run("when buffer is full", func(t *testing.T) { + zeroTime := time.Now() + td := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = td.Now + trace.delayedDNSResponse = make(chan *model.ArchivalDNSLookupResult) // no buffer + txp := &mocks.DNSTransport{ + MockNetwork: func() string { + return "udp" + }, + MockAddress: func() string { + return "1.1.1.1" + }, + } + started := trace.TimeNow() + query := &mocks.DNSQuery{ + MockType: func() uint16 { + return dns.TypeA + }, + MockDomain: func() string { + return "dns.google.com" + }, + } + addrs := []string{"1.1.1.1"} + finished := trace.TimeNow() + // 1. attempt to write into the trace + err := trace.OnDelayedDNSResponse(started, txp, query, &mocks.DNSResponse{}, + addrs, nil, finished) + if !errors.Is(err, ErrDelayedDNSResponseBufferFull) { + t.Fatal("unexpected error", err) + } + // 2. 
confirm we didn't write anything + got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second) + if len(got) != 0 { + t.Fatal("unexpected output from trace") + } + }) + }) + + t.Run("DelayedDNSResponseWithTimeout drains the trace", func(t *testing.T) { + t.Run("context is already cancelled and we still drain the trace", func(t *testing.T) { + zeroTime := time.Now() + td := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = td.Now + txp := &mocks.DNSTransport{ + MockNetwork: func() string { + return "udp" + }, + MockAddress: func() string { + return "1.1.1.1" + }, + } + started := trace.TimeNow() + query := &mocks.DNSQuery{ + MockType: func() uint16 { + return dns.TypeA + }, + MockDomain: func() string { + return "dns.google.com" + }, + } + addrs := []string{"1.1.1.1"} + finished := trace.TimeNow() + events := 4 + for i := 0; i < events; i++ { + // fill the trace + trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime), + txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime)) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() // we ensure that the context cancels before draining all the events + // drain the trace + got := trace.DelayedDNSResponseWithTimeout(ctx, 10*time.Second) + if len(got) != 4 { + t.Fatal("unexpected output from trace", len(got)) + } + }) + + t.Run("normal case where the context times out after we start draining", func(t *testing.T) { + zeroTime := time.Now() + td := testingx.NewTimeDeterministic(zeroTime) + trace := NewTrace(0, zeroTime) + trace.TimeNowFn = td.Now + txp := &mocks.DNSTransport{ + MockNetwork: func() string { + return "udp" + }, + MockAddress: func() string { + return "1.1.1.1" + }, + } + started := trace.TimeNow() + query := &mocks.DNSQuery{ + MockType: func() uint16 { + return dns.TypeA + }, + MockDomain: func() string { + return "dns.google.com" + }, + } + addrs := 
[]string{"1.1.1.1"} + finished := trace.TimeNow() + trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime), + txp, query, &mocks.DNSResponse{}, addrs, nil, finished.Sub(trace.ZeroTime)) + got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second) + if len(got) != 1 { + t.Fatal("unexpected output from trace") + } + }) + }) +} + func TestAnswersFromAddrs(t *testing.T) { tests := []struct { name string diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index 49c483b..b631cb6 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -34,8 +34,7 @@ type Trace struct { // traces, you can use zero to indicate the "default" trace. Index int64 - // networkEvent is MANDATORY and buffers network events. If you create - // this channel manually, ensure it has some buffer. + // networkEvent is MANDATORY and buffers network events. networkEvent chan *model.ArchivalNetworkEvent // NewStdlibResolverFn is OPTIONAL and can be used to overide @@ -62,20 +61,19 @@ type Trace struct { // calls to the netxlite.NewQUICDialerWithoutResolver factory. NewQUICDialerWithoutResolverFn func(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer - // dnsLookup is MANDATORY and buffers DNS Lookup observations. If you create - // this channel manually, ensure it has some buffer. + // dnsLookup is MANDATORY and buffers DNS Lookup observations. dnsLookup chan *model.ArchivalDNSLookupResult - // tcpConnect is MANDATORY and buffers TCP connect observations. If you create - // this channel manually, ensure it has some buffer. + // delayedDNSResponse is MANDATORY and buffers delayed DNS responses. + delayedDNSResponse chan *model.ArchivalDNSLookupResult + + // tcpConnect is MANDATORY and buffers TCP connect observations. tcpConnect chan *model.ArchivalTCPConnectResult - // tlsHandshake is MANDATORY and buffers TLS handshake observations. 
If you create - // this channel manually, ensure it has some buffer. + // tlsHandshake is MANDATORY and buffers TLS handshake observations. tlsHandshake chan *model.ArchivalTLSOrQUICHandshakeResult - // quicHandshake is MANDATORY and buffers QUIC handshake observations. If you create - // this channel manually, ensure it has some buffer. + // quicHandshake is MANDATORY and buffers QUIC handshake observations. quicHandshake chan *model.ArchivalTLSOrQUICHandshakeResult // TimeNowFn is OPTIONAL and can be used to override calls to time.Now @@ -88,23 +86,27 @@ type Trace struct { const ( // NetworkEventBufferSize is the buffer size for constructing - // the Trace's NetworkEvent buffered channel. + // the Trace's networkEvent buffered channel. NetworkEventBufferSize = 64 // DNSLookupBufferSize is the buffer size for constructing - // the Trace's DNSLookup map of buffered channels. + // the Trace's dnsLookup buffered channel. DNSLookupBufferSize = 8 + // DelayedDNSResponseBufferSize is the buffer size for constructing + // the Trace's delayedDNSResponse buffered channel. + DelayedDNSResponseBufferSize = 8 + // TCPConnectBufferSize is the buffer size for constructing - // the Trace's TCPConnect buffered channel. + // the Trace's tcpConnect buffered channel. TCPConnectBufferSize = 8 // TLSHandshakeBufferSize is the buffer for construcing - // the Trace's TLSHandshake buffered channel. + // the Trace's tlsHandshake buffered channel. TLSHandshakeBufferSize = 8 // QUICHandshakeBufferSize is the buffer for constructing - // the Trace's QUICHandshake buffered channel. + // the Trace's quicHandshake buffered channel. 
QUICHandshakeBufferSize = 8 ) @@ -132,6 +134,10 @@ func NewTrace(index int64, zeroTime time.Time) *Trace { chan *model.ArchivalDNSLookupResult, DNSLookupBufferSize, ), + delayedDNSResponse: make( + chan *model.ArchivalDNSLookupResult, + DelayedDNSResponseBufferSize, + ), tcpConnect: make( chan *model.ArchivalTCPConnectResult, TCPConnectBufferSize, diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index ce29c66..2e8396e 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -84,7 +84,7 @@ func TestNewTrace(t *testing.T) { } }) - t.Run("DNSLookup has the expected buffer size", func(t *testing.T) { + t.Run("dnsLookup has the expected buffer size", func(t *testing.T) { ff := &testingx.FakeFiller{} var idx int Loop: @@ -99,11 +99,30 @@ func TestNewTrace(t *testing.T) { } } if idx != DNSLookupBufferSize { - t.Fatal("invalid DNSLookup channel buffer size") + t.Fatal("invalid dnsLookup channel buffer size") } }) - t.Run("TCPConnect has the expected buffer size", func(t *testing.T) { + t.Run("delayedDNSResponse has the expected buffer size", func(t *testing.T) { + ff := &testingx.FakeFiller{} + var idx int + Loop: + for { + ev := &model.ArchivalDNSLookupResult{} + ff.Fill(ev) + select { + case trace.delayedDNSResponse <- ev: + idx++ + default: + break Loop + } + } + if idx != DelayedDNSResponseBufferSize { + t.Fatal("invalid delayedDNSResponse channel buffer size") + } + }) + + t.Run("tcpConnect has the expected buffer size", func(t *testing.T) { ff := &testingx.FakeFiller{} var idx int Loop: @@ -118,11 +137,11 @@ func TestNewTrace(t *testing.T) { } } if idx != TCPConnectBufferSize { - t.Fatal("invalid TCPConnect channel buffer size") + t.Fatal("invalid tcpConnect channel buffer size") } }) - t.Run("TLSHandshake has the expected buffer size", func(t *testing.T) { + t.Run("tlsHandshake has the expected buffer size", func(t *testing.T) { ff := &testingx.FakeFiller{} var idx int Loop: @@ -137,11 
+156,11 @@ func TestNewTrace(t *testing.T) { } } if idx != TLSHandshakeBufferSize { - t.Fatal("invalid TLSHandshake channel buffer size") + t.Fatal("invalid tlsHandshake channel buffer size") } }) - t.Run("QUICHandshake has the expected buffer size", func(t *testing.T) { + t.Run("quicHandshake has the expected buffer size", func(t *testing.T) { ff := &testingx.FakeFiller{} var idx int Loop: @@ -156,7 +175,7 @@ func TestNewTrace(t *testing.T) { } } if idx != QUICHandshakeBufferSize { - t.Fatal("invalid QUICHandshake channel buffer size") + t.Fatal("invalid quicHandshake channel buffer size") } }) diff --git a/internal/model/mocks/trace.go b/internal/model/mocks/trace.go index aca9962..bf9132b 100644 --- a/internal/model/mocks/trace.go +++ b/internal/model/mocks/trace.go @@ -24,6 +24,9 @@ type Trace struct { MockOnDNSRoundTripForLookupHost func(started time.Time, reso model.Resolver, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) + MockOnDelayedDNSResponse func(started time.Time, txp model.DNSTransport, query model.DNSQuery, + response model.DNSResponse, addrs []string, err error, finished time.Time) error + MockOnConnectDone func( started time.Time, network, domain, remoteAddr string, err error, finished time.Time) @@ -57,6 +60,11 @@ func (t *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolv t.MockOnDNSRoundTripForLookupHost(started, reso, query, response, addrs, err, finished) } +func (t *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery, + response model.DNSResponse, addrs []string, err error, finished time.Time) error { + return t.MockOnDelayedDNSResponse(started, txp, query, response, addrs, err, finished) +} + func (t *Trace) OnConnectDone( started time.Time, network, domain, remoteAddr string, err error, finished time.Time) { t.MockOnConnectDone(started, network, domain, remoteAddr, err, finished) diff --git a/internal/model/mocks/trace_test.go 
b/internal/model/mocks/trace_test.go index 2972bf7..3b294ff 100644 --- a/internal/model/mocks/trace_test.go +++ b/internal/model/mocks/trace_test.go @@ -71,6 +71,30 @@ func TestTrace(t *testing.T) { } }) + t.Run("OnDelayedDNSResponse", func(t *testing.T) { + var called bool + tx := &Trace{ + MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport, + query model.DNSQuery, response model.DNSResponse, + addrs []string, err error, finished time.Time) error { + called = true + return nil + }, + } + tx.OnDelayedDNSResponse( + time.Now(), + &DNSTransport{}, + &DNSQuery{}, + &DNSResponse{}, + []string{}, + nil, + time.Now(), + ) + if !called { + t.Fatal("not called") + } + }) + t.Run("OnConnectDone", func(t *testing.T) { var called bool tx := &Trace{ diff --git a/internal/model/netx.go b/internal/model/netx.go index 0f3a56b..ecc9f5e 100644 --- a/internal/model/netx.go +++ b/internal/model/netx.go @@ -340,6 +340,29 @@ type Trace interface { OnDNSRoundTripForLookupHost(started time.Time, reso Resolver, query DNSQuery, response DNSResponse, addrs []string, err error, finished time.Time) + // OnDelayedDNSResponse is used with a DNSOverUDPTransport and called + // when we get delayed, unexpected DNS responses. + // + // Arguments: + // + // - started is when we started reading the delayed response; + // + // - txp is the DNS transport used with the resolver; + // + // - query is the non-nil DNS query we use for the RoundTrip; + // + // - response is the non-nil valid DNS response, obtained after some delay; + // + // - addrs is the list of addresses obtained after decoding the delayed response, + // which is empty if the response did not contain any addresses, which we + // extract by calling the DecodeLookupHost method. + // + // - err is the result of DecodeLookupHost: either an error or nil; + // + // - finished is when we have read the delayed response. 
+ OnDelayedDNSResponse(started time.Time, txp DNSTransport, query DNSQuery, + resp DNSResponse, addrs []string, err error, finished time.Time) error + // OnConnectDone is called when connect terminates. // // Arguments: diff --git a/internal/netxlite/dnsoverudp.go b/internal/netxlite/dnsoverudp.go index 4f453c9..9d246c3 100644 --- a/internal/netxlite/dnsoverudp.go +++ b/internal/netxlite/dnsoverudp.go @@ -45,10 +45,6 @@ type DNSOverUDPTransport struct { // Endpoint is the MANDATORY server's endpoint (e.g., 1.1.1.1:53) Endpoint string - - // IOTimeout is the MANDATORY I/O timeout after which any - // conn created to perform round trips times out. - IOTimeout time.Duration } // NewUnwrappedDNSOverUDPTransport creates a DNSOverUDPTransport instance @@ -67,10 +63,9 @@ type DNSOverUDPTransport struct { // have less control over which IP address is being used. func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport { return &DNSOverUDPTransport{ - Decoder: &DNSDecoderMiekg{}, - Dialer: dialer, - Endpoint: address, - IOTimeout: 10 * time.Second, + Decoder: &DNSDecoderMiekg{}, + Dialer: dialer, + Endpoint: address, } } @@ -78,21 +73,36 @@ func NewUnwrappedDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOv func (t *DNSOverUDPTransport) RoundTrip( ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { // QUIRK: the original code had a five seconds timeout, which is - // consistent with the Bionic implementation. Let's enforce such a - // timeout using the context in the outer operation because we - // need to run for more seconds in the background to catch as many - // duplicate replies as possible. + // consistent with the Bionic implementation. 
// // See https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance const opTimeout = 5 * time.Second ctx, cancel := context.WithTimeout(ctx, opTimeout) defer cancel() - outch, err := t.AsyncRoundTrip(ctx, query, 1) // buffer to avoid background's goroutine leak + rawQuery, err := query.Bytes() if err != nil { return nil, err } - defer outch.Close() // we own the channel - return outch.Next(ctx) + conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(opTimeout)) + joinedch := make(chan bool) + myaddr := conn.LocalAddr().String() + if _, err := conn.Write(rawQuery); err != nil { + conn.Close() // we still own the conn + return nil, err + } + resp, err := t.recv(query, conn) + if err != nil { + conn.Close() // we still own the conn + return nil, err + } + // start a goroutine to listen for any delayed DNS response and + // TRANSFER the conn's OWNERSHIP to such a goroutine. + go t.ownConnAndSendRecvLoop(ctx, conn, query, myaddr, joinedch) + return resp, nil } // RequiresPadding returns false for UDP according to RFC8467. @@ -119,196 +129,17 @@ func (t *DNSOverUDPTransport) CloseIdleConnections() { var _ model.DNSTransport = &DNSOverUDPTransport{} -// DNSOverUDPResponse is a response received by a DNSOverUDPTransport when you -// use its AsyncRoundTrip method as opposed to using RoundTrip. -type DNSOverUDPResponse struct { - // Err is the error that occurred (nil in case of success). - Err error - - // LocalAddr is the local UDP address we're using. - LocalAddr string - - // Operation is the operation that failed. - Operation string - - // Query is the related DNS query. - Query model.DNSQuery - - // RemoteAddr is the remote server address. - RemoteAddr string - - // Response is the response (nil iff error is not nil). - Response model.DNSResponse -} - -// newDNSOverUDPResponse creates a new DNSOverUDPResponse instance. 
-func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error, - query model.DNSQuery, resp model.DNSResponse, operation string) *DNSOverUDPResponse { - return &DNSOverUDPResponse{ - Err: err, - LocalAddr: localAddr, - Operation: operation, - Query: query, - RemoteAddr: t.Endpoint, // The common case is to have an IP:port here (domains are discouraged) - Response: resp, - } -} - -// DNSOverUDPChannel is a wrapper around a channel for reading zero -// or more *DNSOverUDPResponse that makes extracting information from -// the underlying channels more user friendly than interacting with -// the channels directly, thanks to useful wrapper methods implementing -// common access patterns. You can still use the underlying channels -// directly if there's no suitable convenience method. -// -// You MUST call the .Close method when done. Not calling such a method -// leaks goroutines and causes connections to stay open forever. -type DNSOverUDPChannel struct { - // Response is the channel where we'll post responses. This channel - // WILL NOT be closed when the background goroutine terminates. - Response <-chan *DNSOverUDPResponse - - // Joined IS CLOSED when the background goroutine terminates. - Joined <-chan bool - - // conn is the underlying connection, which we can Close to - // immediately cause the background goroutine to join. - conn net.Conn -} - -// Close releases the resources allocated by the channel. You MUST -// call this method to force the background goroutine that is performing -// the round trip to terminate. Calling this method also ensures we -// close the connection used by the round trip. This method is idempotent. -func (ch *DNSOverUDPChannel) Close() error { - return ch.conn.Close() -} - -// Next blocks until the next response is received on Response or the -// given context expires, whatever happens first. 
This function will -// completely ignore the Joined channel and will just timeout in case -// you call Next after the background goroutine had joined. In fact, -// the use case for this function is using it to get a response or -// a timeout when you know the DNS round trip is pending. -func (ch *DNSOverUDPChannel) Next(ctx context.Context) (model.DNSResponse, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case out := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil - return out.Response, out.Err - } -} - -// TryNextResponses attempts to read all the buffered messages inside of the "Response" -// channel that contains successful DNS responses. That is, this function will silently skip -// any possible DNSOverUDPResponse with its Err != nil. The use case for this function is -// to obtain all the subsequent response messages we received while we were performing -// other operations (e.g., contacting the test helper of fetching a webpage). -func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) { - for { - select { - case r := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil - if r.Err == nil && r.Response != nil { - out = append(out, r.Response) - } - default: - return - } - } -} - -// AsyncRoundTrip performs an async DNS round trip. The "buffer" argument -// controls how many buffer slots the returned DNSOverUDPChannel's Response -// channel should have. A zero or negative value causes this function to -// create a channel having a single-slot buffer. -// -// The real round trip runs in a background goroutine. We will terminate the background -// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we -// cannot write on the "Response" channel or (3) the connection is closed by calling the -// Close method of DNSOverUDPChannel. Note that the background goroutine WILL NOT close -// the "Response" channel to signal its completion. 
Hence, who reads such a -// channel MUST be prepared for read operations to block forever (i.e., should use -// a select operation for draining the channel in a deadlock-safe way). Also, -// we WILL NOT ever post a nil message to the "Response" channel. -// -// The returned DNSOverUDPChannel contains another channel called Joined that is -// closed when the background goroutine terminates, so you can use this channel -// should you need to synchronize with such goroutine's termination. -// -// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type, -// you don't need to worry about these low level details though. -// -// We give you OWNERSHIP of the returned DNSOverUDPChannel and you MUST -// call its .Close method when done with using it. -func (t *DNSOverUDPTransport) AsyncRoundTrip( - ctx context.Context, query model.DNSQuery, buffer int) (*DNSOverUDPChannel, error) { - rawQuery, err := query.Bytes() - if err != nil { - return nil, err - } - conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(t.IOTimeout)) - if buffer < 2 { - buffer = 1 // as documented - } - outch := make(chan *DNSOverUDPResponse, buffer) - joinedch := make(chan bool) - go t.sendRecvLoop(conn, rawQuery, query, outch, joinedch) - dnsch := &DNSOverUDPChannel{ - Response: outch, - Joined: joinedch, - conn: conn, // transfer ownership - } - return dnsch, nil -} - -// sendRecvLoop sends the given raw query on the given conn and receives responses -// from the conn posting them onto the given output channel. -// -// Arguments: -// -// 1. conn is the BORROWED net.Conn (we will use it for reading or writing but -// we do not own the connection and we're not going to close it); -// -// 2. rawQuery contains the rawQuery and is BORROWED (we won't modify it); -// -// 3. query contains the original query and is also BORROWED; -// -// 4. 
outch is the channel where to emit measurements and is OWNED by this -// function (that said, we WILL NOT close this channel); -// -// 5. eofch is the channel to signal EOF, which is OWNED by this function -// and closed when this function exits. -// -// This method terminates in the following cases: -// -// 1. I/O error while reading or writing (including the deadline expiring or -// the owner of the connection closing the connection); -// -// 2. We cannot post on the output channel because either there is -// noone reading the channel or the channel's buffer is full. -// -// 3. We cannot parse incoming data as a valid DNS response message that -// responds to the query that we originally sent. -func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte, - query model.DNSQuery, outch chan<- *DNSOverUDPResponse, eofch chan<- bool) { +// ownConnAndSendRecvLoop listens for delayed DNS responses after we have returned the +// first response. As the name implies, this function TAKES OWNERSHIP of the [conn]. +func (t *DNSOverUDPTransport) ownConnAndSendRecvLoop(ctx context.Context, conn net.Conn, + query model.DNSQuery, myaddr string, eofch chan<- bool) { defer close(eofch) // synchronize with the caller - myaddr := conn.LocalAddr().String() - if _, err := conn.Write(rawQuery); err != nil { - outch <- t.newDNSOverUDPResponse( - myaddr, err, query, nil, WriteOperation) // one-sized buffer, can't block - return - } + defer conn.Close() // we own the conn + trace := ContextTraceOrDefault(ctx) for { + started := trace.TimeNow() resp, err := t.recv(query, conn) - select { - case outch <- t.newDNSOverUDPResponse(myaddr, err, query, resp, ReadOperation): - default: - return // no-one is reading the channel -- so long... - } + finished := trace.TimeNow() if err != nil { // We are going to consider all errors as fatal for now until we // hear of specific errs that it might have sense to ignore. 
@@ -316,6 +147,16 @@ func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte, // Note that erroring out here includes the expiration of the conn's // I/O deadline, which we set above precisely because we want // the total runtime of this goroutine to be bounded. + // + // Also, we ARE NOT going to report any failure here as a delayed + // DNS response because we only care about duplicate messages, since + // this seems how censorship is implemented in, e.g., China. + return + } + addrs, err := resp.DecodeLookupHost() + if err := trace.OnDelayedDNSResponse(started, t, query, resp, addrs, err, finished); err != nil { + // This error typically indicates that the buffer on which we're + // writing is now full, so there's no point in persisting. return } } diff --git a/internal/netxlite/dnsoverudp_test.go b/internal/netxlite/dnsoverudp_test.go index 2e14a01..6246842 100644 --- a/internal/netxlite/dnsoverudp_test.go +++ b/internal/netxlite/dnsoverudp_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "net" + "sync" "testing" "time" @@ -14,6 +15,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" + "github.com/ooni/probe-cli/v3/internal/testingx" ) func TestDNSOverUDPTransport(t *testing.T) { @@ -281,70 +283,16 @@ func TestDNSOverUDPTransport(t *testing.T) { }) }) - t.Run("AsyncRoundTrip", func(t *testing.T) { - t.Run("calling Next with cancelled context", func(t *testing.T) { - srvr := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionCache - }, - Cache: map[string][]string{ - "dns.google.": {"8.8.8.8"}, - }, - } - listener, err := srvr.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer listener.Close() - dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) - encoder := 
&DNSEncoderMiekg{} - query := encoder.Encode("dns.google.", dns.TypeA, false) - ctx := context.Background() - rch, err := txp.AsyncRoundTrip(ctx, query, 1) - if err != nil { - t.Fatal(err) - } - defer rch.Close() - ctx, cancel := context.WithCancel(ctx) - cancel() // fail immediately - resp, err := rch.Next(ctx) - if !errors.Is(err, context.Canceled) { - t.Fatal("unexpected err", err) - } - if resp != nil { - t.Fatal("unexpected resp") - } - }) - - t.Run("no-one is reading the channel", func(t *testing.T) { - srvr := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionLocalHostPlusCache // i.e., two responses - }, - Cache: map[string][]string{ - "dns.google.": {"8.8.8.8"}, - }, - } - listener, err := srvr.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer listener.Close() - dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) - encoder := &DNSEncoderMiekg{} - query := encoder.Encode("dns.google.", dns.TypeA, false) - ctx := context.Background() - rch, err := txp.AsyncRoundTrip(ctx, query, 1) // but just one place - if err != nil { - t.Fatal(err) - } - defer rch.Close() - <-rch.Joined // should see no-one is reading and stop - }) - - t.Run("typical usage to obtain late responses", func(t *testing.T) { + t.Run("recording delayed DNS responses", func(t *testing.T) { + t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) { + var ( + delayedDNSResponseCalled bool + goodQueryType bool + goodTransportNetwork bool + goodTransportAddress bool + goodLookupAddrs bool + goodError bool + ) srvr := &filtering.DNSServer{ OnQuery: func(domain string) filtering.DNSAction { return filtering.DNSActionLocalHostPlusCache @@ -359,52 +307,94 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := 
NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + expectedAddress := listener.LocalAddr().String() + txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) - rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) + zeroTime := time.Now() + deterministicTime := testingx.NewTimeDeterministic(zeroTime) + expectedAddrs := []string{"8.8.8.8"} + respChannel := make(chan *model.DNSResponse, 8) + mu := new(sync.Mutex) + tx := &mocks.Trace{ + MockTimeNow: deterministicTime.Now, + MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport, + query model.DNSQuery, response model.DNSResponse, addrs []string, err error, + finished time.Time) error { + mu.Lock() + delayedDNSResponseCalled = true + goodQueryType = (query.Type() == dns.TypeA) + goodTransportNetwork = (txp.Network() == "udp") + goodTransportAddress = (txp.Address() == expectedAddress) + goodLookupAddrs = (cmp.Diff(expectedAddrs, addrs) == "") + goodError = (err == nil) + mu.Unlock() + select { + case respChannel <- &response: + return nil + default: + return errors.New("full buffer") + } + }, + MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, + finished time.Time) { + // do nothing + }, + MockMaybeWrapNetConn: func(conn net.Conn) net.Conn { + return conn + }, + } + ctx := ContextWithTrace(context.Background(), tx) + rch, err := txp.RoundTrip(ctx, query) + <-respChannel // wait for the delayed response if err != nil { t.Fatal(err) } - defer rch.Close() - resp, err := rch.Next(context.Background()) - if err != nil { - t.Fatal(err) - } - addrs, err := resp.DecodeLookupHost() + addrs, err := rch.DecodeLookupHost() if err != nil { t.Fatal(err) } + mu.Lock() if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" { t.Fatal(diff) } - // One would not normally busy loop but it's fine to do that in the context - // of this test because we 
know we're going to receive a second reply. In - // a real network experiment here we'll do other activities, e.g., contacting - // the test helper or fetching a webpage. - var additional []model.DNSResponse - for { - additional = rch.TryNextResponses() - if len(additional) > 0 { - if len(additional) != 1 { - t.Fatal("expected exactly one additional response") - } - break - } + if !delayedDNSResponseCalled { + t.Fatal("delayedDNSResponse not called") } - addrs, err = additional[0].DecodeLookupHost() - if err != nil { - t.Fatal(err) + if !goodQueryType { + t.Fatal("unexpected query type") } - if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" { - t.Fatal(diff) + if !goodTransportNetwork { + t.Fatal("unexpected DNS transport network") } + if !goodTransportAddress { + t.Fatal("unexpected DNS Transport address") + } + if !goodLookupAddrs { + t.Fatal("unexpected delayed DNSLookup address") + } + if !goodError { + t.Fatal("unexpected error encountered") + } + mu.Unlock() }) - t.Run("correct behavior when read times out", func(t *testing.T) { + t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) { + var ( + delayedDNSResponseCalled bool + goodQueryType bool + goodTransportNetwork bool + goodTransportAddress bool + goodLookupAddrs bool + goodError bool + ) srvr := &filtering.DNSServer{ OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionTimeout + return filtering.DNSActionLocalHostPlusCache + }, + Cache: map[string][]string{ + // Note: the cache here is nonexistent so we should + // get a "no such host" error from the server. 
}, } listener, err := srvr.Start("127.0.0.1:0") @@ -413,22 +403,71 @@ func TestDNSOverUDPTransport(t *testing.T) { } defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) - txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) - txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test + expectedAddress := listener.LocalAddr().String() + txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) - rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) + zeroTime := time.Now() + deterministicTime := testingx.NewTimeDeterministic(zeroTime) + respChannel := make(chan *model.DNSResponse, 8) + mu := new(sync.Mutex) + tx := &mocks.Trace{ + MockTimeNow: deterministicTime.Now, + MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport, + query model.DNSQuery, response model.DNSResponse, addrs []string, err error, + finished time.Time) error { + mu.Lock() + delayedDNSResponseCalled = true + goodQueryType = (query.Type() == dns.TypeA) + goodTransportNetwork = (txp.Network() == "udp") + goodTransportAddress = (txp.Address() == expectedAddress) + goodLookupAddrs = (len(addrs) == 0) + goodError = errors.Is(err, ErrOODNSNoSuchHost) + mu.Unlock() + respChannel <- &response + return errors.New("mocked") // return error to stop background routine to record responses + }, + MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error, + finished time.Time) { + // do nothing + }, + MockMaybeWrapNetConn: func(conn net.Conn) net.Conn { + return conn + }, + } + ctx := ContextWithTrace(context.Background(), tx) + rch, err := txp.RoundTrip(ctx, query) + <-respChannel // wait for the delayed response if err != nil { t.Fatal(err) } - defer rch.Close() - result := <-rch.Response - if result.Err == nil || result.Err.Error() != "generic_timeout_error" { - t.Fatal("unexpected error", 
result.Err) + addrs, err := rch.DecodeLookupHost() + if err != nil { + t.Fatal(err) } - if result.Operation != ReadOperation { - t.Fatal("unexpected failed operation", result.Operation) + mu.Lock() + if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" { + t.Fatal(diff) } + if !delayedDNSResponseCalled { + t.Fatal("delayedDNSResponse not called") + } + if !goodQueryType { + t.Fatal("unexpected query type") + } + if !goodTransportNetwork { + t.Fatal("unexpected DNS transport network") + } + if !goodTransportAddress { + t.Fatal("unexpected DNS Transport address") + } + if !goodLookupAddrs { + t.Fatal("unexpected delayed DNSLookup address") + } + if !goodError { + t.Fatal("unexpected error encountered") + } + mu.Unlock() }) }) diff --git a/internal/netxlite/trace.go b/internal/netxlite/trace.go index 8409c01..313ad8e 100644 --- a/internal/netxlite/trace.go +++ b/internal/netxlite/trace.go @@ -67,6 +67,12 @@ func (*traceDefault) OnDNSRoundTripForLookupHost(started time.Time, reso model.R // nothing } +// OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse. +func (*traceDefault) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, + query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) error { + return nil +} + // OnConnectDone implements model.Trace.OnConnectDone. func (*traceDefault) OnConnectDone( started time.Time, network, domain, remoteAddr string, err error, finished time.Time) {