From 62bd62ece178c1e13704fd129836547c07ca7d3f Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Thu, 26 May 2022 23:49:14 +0200 Subject: [PATCH] fix(dnsoverudp): allow to cancel async round trip immediately (#763) To this end, we need to refactor the implementation to give the DNSOverUDPChannel owenership over the net.Conn. Once this happens, DNSOverUDPChannel.Close closes the conn. When the conn is closed, the background goroutine will terminate immediately because any blocking I/O operation will be immediately unblocked and return net.ErrClosed. See https://github.com/ooni/probe/issues/2099#issuecomment-1139066946 --- internal/netxlite/dnsoverudp.go | 132 +++++++++++++++++---------- internal/netxlite/dnsoverudp_test.go | 107 ++++++++++++---------- 2 files changed, 145 insertions(+), 94 deletions(-) diff --git a/internal/netxlite/dnsoverudp.go b/internal/netxlite/dnsoverudp.go index c492a9f..fa9d0e4 100644 --- a/internal/netxlite/dnsoverudp.go +++ b/internal/netxlite/dnsoverudp.go @@ -26,9 +26,9 @@ import ( // translated into socket errors (among them, host_unreachable); // // 2. connected sockets ignore responses from illegitimate IP addresses but -// most if not all DNS resolvers also do that, therefore it does not seem to +// most if not all DNS resolvers also do that, therefore this does not seem to // be a realistic censorship vector. At the same time, connected sockets -// provide us for free the feature that we don't need to bother with checking +// provide us for free with the feature that we don't need to bother with checking // whether the reply comes from the expected server. // // Being able to observe some ICMP errors is good because it could possibly @@ -86,7 +86,11 @@ func (t *DNSOverUDPTransport) RoundTrip( const opTimeout = 5 * time.Second ctx, cancel := context.WithTimeout(ctx, opTimeout) defer cancel() - outch := t.AsyncRoundTrip(query, 1) // buffer to avoid background's goroutine leak + outch, err := t.AsyncRoundTrip(ctx, query, 1) // buffer to avoid background's goroutine leak + if err != nil { + return nil, err + } + defer outch.Close() // we own the channel return outch.Next(ctx) } @@ -153,16 +157,30 @@ func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error, // or more *DNSOverUDPResponse that makes extracting information from // the underlying channels more user friendly than interacting with // the channels directly, thanks to useful wrapper methods implementing -// common access patterns. You can still use the channels directly if -// there's no convenience method for your specific access pattern. +// common access patterns. You can still use the underlying channels +// directly if there's no suitable convenience method. +// +// You MUST call the .Close method when done. Not calling such a method +// leaks goroutines and causes connections to stay open forever. type DNSOverUDPChannel struct { // Response is the channel where we'll post responses. This channel - // WON'T be closed when the background goroutine terminates. + // WILL NOT be closed when the background goroutine terminates. Response <-chan *DNSOverUDPResponse - // Joined is a channel that IS CLOSED when the background - // goroutine performing this round trip TERMINATES. + // Joined IS CLOSED when the background goroutine terminates. Joined <-chan bool + + // conn is the underlying connection, which we can Close to + // immediately cause the background goroutine to join. + conn net.Conn +} + +// Close releases the resources allocated by the channel. You MUST +// call this method to force the background goroutine that is performing +// the round trip to terminate. Calling this method also ensures we +// close the connection used by the round trip. This method is idempotent. +func (ch *DNSOverUDPChannel) Close() error { + return ch.conn.Close() } // Next blocks until the next response is received on Response or the @@ -205,11 +223,12 @@ func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) { // // The real round trip runs in a background goroutine. We will terminate the background // goroutine when (1) the IOTimeout expires for the connection we're using or (2) we -// cannot write on the "Response" channel. Note that the background goroutine WILL NOT -// close the "Response" channel to signal its completion. Hence, who reads such a -// channel MUST be prepared for read operations to block forever and use a -// select for draining the channel in a deadlock-safe way. Also, we WILL NOT ever -// emit a nil message over the "Response" channel. +// cannot write on the "Response" channel or (3) the connection is closed by calling the +// Close method of DNSOverUDPChannel. Note that the background goroutine WILL NOT close +// the "Response" channel to signal its completion. Hence, who reads such a +// channel MUST be prepared for read operations to block forever (i.e., should use +// a select operation for draining the channel in a deadlock-safe way). Also, +// we WILL NOT ever post a nil message to the "Response" channel. // // The returned DNSOverUDPChannel contains another channel called Joined that is // closed when the background goroutine terminates, so you can use this channel @@ -217,56 +236,75 @@ func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) { // // If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type, // you don't need to worry about these low level details though. -func (t *DNSOverUDPTransport) AsyncRoundTrip(query model.DNSQuery, buffer int) *DNSOverUDPChannel { +// +// We give you OWNERSHIP of the returned DNSOverUDPChannel and you MUST +// call its .Close method when done with using it. +func (t *DNSOverUDPTransport) AsyncRoundTrip( + ctx context.Context, query model.DNSQuery, buffer int) (*DNSOverUDPChannel, error) { + rawQuery, err := query.Bytes() + if err != nil { + return nil, err + } + conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(t.IOTimeout)) if buffer < 2 { buffer = 1 // as documented } outch := make(chan *DNSOverUDPResponse, buffer) joinedch := make(chan bool) - go t.roundTripLoop(query, outch, joinedch) - return &DNSOverUDPChannel{ + go t.sendRecvLoop(conn, rawQuery, query, outch, joinedch) + dnsch := &DNSOverUDPChannel{ Response: outch, Joined: joinedch, + conn: conn, // transfer ownership } + return dnsch, nil } -// roundTripLoop performs the round trip and writes results into the "outch" channel. This -// function ASSUMES that "outch" is configured to have AT LEAST one buffer slot. This function -// TAKES OWNERSHIP of "outch" but WILL NOT close it when done. This function instead OWNS -// the "joinedch" channel and WILL CLOSE it when done. -func (t *DNSOverUDPTransport) roundTripLoop( - query model.DNSQuery, outch chan<- *DNSOverUDPResponse, joinedch chan<- bool) { - defer close(joinedch) // as documented - rawQuery, err := query.Bytes() - if err != nil { +// sendRecvLoop sends the given raw query on the given conn and receives responses +// from the conn posting them onto the given output channel. +// +// Arguments: +// +// 1. conn is the BORROWED net.Conn (we will use it for reading or writing but +// we do not own the connection and we're not going to close it); +// +// 2. rawQuery contains the rawQuery and is BORROWED (we won't modify it); +// +// 3. query contains the original query and is also BORROWED; +// +// 4. outch is the channel where to emit measurements and is OWNED by this +// function (that said, we WILL NOT close this channel); +// +// 5. eofch is the channel to signal EOF, which is OWNED by this function +// and closed when this function exits. +// +// This method terminates in the following cases: +// +// 1. I/O error while reading or writing (including the deadline expiring or +// the owner of the connection closing the connection); +// +// 2. We cannot post on the output channel because either there is +// noone reading the channel or the channel's buffer is full. +// +// 3. We cannot parse incoming data as a valid DNS response message that +// responds to the query that we originally sent. +func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte, + query model.DNSQuery, outch chan<- *DNSOverUDPResponse, eofch chan<- bool) { + defer close(eofch) // synchronize with the caller + myaddr := conn.LocalAddr().String() + if _, err := conn.Write(rawQuery); err != nil { outch <- t.newDNSOverUDPResponse( - "", err, query, nil, "serialize_query") // one-sized buffer, can't block - return - } - // While dial operations return immediately for UDP, we MAY be calling the - // dialer's resolver if t.Endpoint contains a domain name. So, let us basically - // enforce the same overall deadline covering DNS lookup and I/O operations. - deadline := time.Now().Add(t.IOTimeout) - ctx, cancel := context.WithDeadline(context.Background(), deadline) - defer cancel() - conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint) - if err != nil { - outch <- t.newDNSOverUDPResponse( - "", err, query, nil, ConnectOperation) // one-sized buffer, can't block - return - } - defer conn.Close() // we own the conn - conn.SetDeadline(deadline) - localAddr := conn.LocalAddr().String() - if _, err = conn.Write(rawQuery); err != nil { - outch <- t.newDNSOverUDPResponse( - localAddr, err, query, nil, WriteOperation) // one-sized buffer, can't block + myaddr, err, query, nil, WriteOperation) // one-sized buffer, can't block return } for { resp, err := t.recv(query, conn) select { - case outch <- t.newDNSOverUDPResponse(localAddr, err, query, resp, ReadOperation): + case outch <- t.newDNSOverUDPResponse(myaddr, err, query, resp, ReadOperation): default: return // no-one is reading the channel -- so long... } diff --git a/internal/netxlite/dnsoverudp_test.go b/internal/netxlite/dnsoverudp_test.go index 895179a..f79b2f7 100644 --- a/internal/netxlite/dnsoverudp_test.go +++ b/internal/netxlite/dnsoverudp_test.go @@ -283,60 +283,65 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("AsyncRoundTrip", func(t *testing.T) { t.Run("calling Next with cancelled context", func(t *testing.T) { - blocker := make(chan interface{}) - const expected = 17 - input := bytes.NewReader(make([]byte, expected)) - txp := NewDNSOverUDPTransport( - &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - <-blocker // block here until Next returns because of expired context - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil - }, - MockWrite: func(b []byte) (int, error) { - return len(b), nil - }, - MockRead: input.Read, - MockClose: func() error { - return nil - }, - MockLocalAddr: func() net.Addr { - return &mocks.Addr{ - MockNetwork: func() string { - return "udp" - }, - MockString: func() string { - return "127.0.0.1:1345" - }, - } - }, - }, nil - }, - }, "9.9.9.9:53", - ) - expectedResp := &mocks.DNSResponse{} - txp.Decoder = &mocks.DNSDecoder{ - MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { - return expectedResp, nil + srvr := &filtering.DNSServer{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionCache + }, + Cache: map[string][]string{ + "dns.google.": {"8.8.8.8"}, }, } - query := &mocks.DNSQuery{ - MockBytes: func() ([]byte, error) { - return make([]byte, 128), nil - }, + listener, err := srvr.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) } - out := txp.AsyncRoundTrip(query, 1) - ctx, cancel := context.WithCancel(context.Background()) - cancel() // immediately cancel - resp, err := out.Next(ctx) + defer listener.Close() + dialer := NewDialerWithoutResolver(model.DiscardLogger) + txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google.", dns.TypeA, false) + ctx := context.Background() + rch, err := txp.AsyncRoundTrip(ctx, query, 1) + if err != nil { + t.Fatal(err) + } + defer rch.Close() + ctx, cancel := context.WithCancel(ctx) + cancel() // fail immediately + resp, err := rch.Next(ctx) if !errors.Is(err, context.Canceled) { t.Fatal("unexpected err", err) } if resp != nil { t.Fatal("unexpected resp") } - close(blocker) // unblock the background goroutine + }) + + t.Run("no-one is reading the channel", func(t *testing.T) { + srvr := &filtering.DNSServer{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionLocalHostPlusCache // i.e., two responses + }, + Cache: map[string][]string{ + "dns.google.": {"8.8.8.8"}, + }, + } + listener, err := srvr.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dialer := NewDialerWithoutResolver(model.DiscardLogger) + txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google.", dns.TypeA, false) + ctx := context.Background() + rch, err := txp.AsyncRoundTrip(ctx, query, 1) // but just one place + if err != nil { + t.Fatal(err) + } + defer rch.Close() + <-rch.Joined // should see no-one is reading and stop }) t.Run("typical usage to obtain late responses", func(t *testing.T) { @@ -357,7 +362,11 @@ func TestDNSOverUDPTransport(t *testing.T) { txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) - rch := txp.AsyncRoundTrip(query, 1) + rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) + if err != nil { + t.Fatal(err) + } + defer rch.Close() resp, err := rch.Next(context.Background()) if err != nil { t.Fatal(err) @@ -408,7 +417,11 @@ func TestDNSOverUDPTransport(t *testing.T) { txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) - rch := txp.AsyncRoundTrip(query, 1) + rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) + if err != nil { + t.Fatal(err) + } + defer rch.Close() result := <-rch.Response if result.Err == nil || result.Err.Error() != "generic_timeout_error" { t.Fatal("unexpected error", result.Err)