fix(dnsoverudp): allow to cancel async round trip immediately (#763)

To this end, we need to refactor the implementation to give the
DNSOverUDPChannel ownership over the net.Conn.

Once this happens, DNSOverUDPChannel.Close closes the conn.

When the conn is closed, the background goroutine will terminate
immediately because any blocking I/O operation will be immediately
unblocked and return net.ErrClosed.

See https://github.com/ooni/probe/issues/2099#issuecomment-1139066946
This commit is contained in:
Simone Basso 2022-05-26 23:49:14 +02:00 committed by GitHub
parent 16f7407b13
commit 62bd62ece1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 145 additions and 94 deletions

View File

@ -26,9 +26,9 @@ import (
// translated into socket errors (among them, host_unreachable); // translated into socket errors (among them, host_unreachable);
// //
// 2. connected sockets ignore responses from illegitimate IP addresses but // 2. connected sockets ignore responses from illegitimate IP addresses but
// most if not all DNS resolvers also do that, therefore it does not seem to // most if not all DNS resolvers also do that, therefore this does not seem to
// be a realistic censorship vector. At the same time, connected sockets // be a realistic censorship vector. At the same time, connected sockets
// provide us for free the feature that we don't need to bother with checking // provide us for free with the feature that we don't need to bother with checking
// whether the reply comes from the expected server. // whether the reply comes from the expected server.
// //
// Being able to observe some ICMP errors is good because it could possibly // Being able to observe some ICMP errors is good because it could possibly
@ -86,7 +86,11 @@ func (t *DNSOverUDPTransport) RoundTrip(
const opTimeout = 5 * time.Second const opTimeout = 5 * time.Second
ctx, cancel := context.WithTimeout(ctx, opTimeout) ctx, cancel := context.WithTimeout(ctx, opTimeout)
defer cancel() defer cancel()
outch := t.AsyncRoundTrip(query, 1) // buffer to avoid background's goroutine leak outch, err := t.AsyncRoundTrip(ctx, query, 1) // buffer to avoid background's goroutine leak
if err != nil {
return nil, err
}
defer outch.Close() // we own the channel
return outch.Next(ctx) return outch.Next(ctx)
} }
@ -153,16 +157,30 @@ func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error,
// or more *DNSOverUDPResponse that makes extracting information from // or more *DNSOverUDPResponse that makes extracting information from
// the underlying channels more user friendly than interacting with // the underlying channels more user friendly than interacting with
// the channels directly, thanks to useful wrapper methods implementing // the channels directly, thanks to useful wrapper methods implementing
// common access patterns. You can still use the channels directly if // common access patterns. You can still use the underlying channels
// there's no convenience method for your specific access pattern. // directly if there's no suitable convenience method.
//
// You MUST call the .Close method when done. Not calling such a method
// leaks goroutines and causes connections to stay open forever.
type DNSOverUDPChannel struct { type DNSOverUDPChannel struct {
// Response is the channel where we'll post responses. This channel // Response is the channel where we'll post responses. This channel
// WON'T be closed when the background goroutine terminates. // WILL NOT be closed when the background goroutine terminates.
Response <-chan *DNSOverUDPResponse Response <-chan *DNSOverUDPResponse
// Joined is a channel that IS CLOSED when the background // Joined IS CLOSED when the background goroutine terminates.
// goroutine performing this round trip TERMINATES.
Joined <-chan bool Joined <-chan bool
// conn is the underlying connection, which we can Close to
// immediately cause the background goroutine to join.
conn net.Conn
}
// Close releases the resources allocated by the channel. You MUST
// call this method to force the background goroutine that is performing
// the round trip to terminate. Calling this method also ensures we
// close the connection used by the round trip. This method is idempotent.
func (ch *DNSOverUDPChannel) Close() error {
return ch.conn.Close()
} }
// Next blocks until the next response is received on Response or the // Next blocks until the next response is received on Response or the
@ -205,11 +223,12 @@ func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) {
// //
// The real round trip runs in a background goroutine. We will terminate the background // The real round trip runs in a background goroutine. We will terminate the background
// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we // goroutine when (1) the IOTimeout expires for the connection we're using or (2) we
// cannot write on the "Response" channel. Note that the background goroutine WILL NOT // cannot write on the "Response" channel or (3) the connection is closed by calling the
// close the "Response" channel to signal its completion. Hence, who reads such a // Close method of DNSOverUDPChannel. Note that the background goroutine WILL NOT close
// channel MUST be prepared for read operations to block forever and use a // the "Response" channel to signal its completion. Hence, who reads such a
// select for draining the channel in a deadlock-safe way. Also, we WILL NOT ever // channel MUST be prepared for read operations to block forever (i.e., should use
// emit a nil message over the "Response" channel. // a select operation for draining the channel in a deadlock-safe way). Also,
// we WILL NOT ever post a nil message to the "Response" channel.
// //
// The returned DNSOverUDPChannel contains another channel called Joined that is // The returned DNSOverUDPChannel contains another channel called Joined that is
// closed when the background goroutine terminates, so you can use this channel // closed when the background goroutine terminates, so you can use this channel
@ -217,56 +236,75 @@ func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) {
// //
// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type, // If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type,
// you don't need to worry about these low level details though. // you don't need to worry about these low level details though.
func (t *DNSOverUDPTransport) AsyncRoundTrip(query model.DNSQuery, buffer int) *DNSOverUDPChannel { //
// We give you OWNERSHIP of the returned DNSOverUDPChannel and you MUST
// call its .Close method when done with using it.
func (t *DNSOverUDPTransport) AsyncRoundTrip(
ctx context.Context, query model.DNSQuery, buffer int) (*DNSOverUDPChannel, error) {
rawQuery, err := query.Bytes()
if err != nil {
return nil, err
}
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
if err != nil {
return nil, err
}
conn.SetDeadline(time.Now().Add(t.IOTimeout))
if buffer < 2 { if buffer < 2 {
buffer = 1 // as documented buffer = 1 // as documented
} }
outch := make(chan *DNSOverUDPResponse, buffer) outch := make(chan *DNSOverUDPResponse, buffer)
joinedch := make(chan bool) joinedch := make(chan bool)
go t.roundTripLoop(query, outch, joinedch) go t.sendRecvLoop(conn, rawQuery, query, outch, joinedch)
return &DNSOverUDPChannel{ dnsch := &DNSOverUDPChannel{
Response: outch, Response: outch,
Joined: joinedch, Joined: joinedch,
conn: conn, // transfer ownership
} }
return dnsch, nil
} }
// roundTripLoop performs the round trip and writes results into the "outch" channel. This // sendRecvLoop sends the given raw query on the given conn and receives responses
// function ASSUMES that "outch" is configured to have AT LEAST one buffer slot. This function // from the conn posting them onto the given output channel.
// TAKES OWNERSHIP of "outch" but WILL NOT close it when done. This function instead OWNS //
// the "joinedch" channel and WILL CLOSE it when done. // Arguments:
func (t *DNSOverUDPTransport) roundTripLoop( //
query model.DNSQuery, outch chan<- *DNSOverUDPResponse, joinedch chan<- bool) { // 1. conn is the BORROWED net.Conn (we will use it for reading or writing but
defer close(joinedch) // as documented // we do not own the connection and we're not going to close it);
rawQuery, err := query.Bytes() //
if err != nil { // 2. rawQuery contains the rawQuery and is BORROWED (we won't modify it);
//
// 3. query contains the original query and is also BORROWED;
//
// 4. outch is the channel where to emit measurements and is OWNED by this
// function (that said, we WILL NOT close this channel);
//
// 5. eofch is the channel to signal EOF, which is OWNED by this function
// and closed when this function exits.
//
// This method terminates in the following cases:
//
// 1. I/O error while reading or writing (including the deadline expiring or
// the owner of the connection closing the connection);
//
// 2. We cannot post on the output channel because either there is
// no one reading the channel or the channel's buffer is full.
//
// 3. We cannot parse incoming data as a valid DNS response message that
// responds to the query that we originally sent.
func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte,
query model.DNSQuery, outch chan<- *DNSOverUDPResponse, eofch chan<- bool) {
defer close(eofch) // synchronize with the caller
myaddr := conn.LocalAddr().String()
if _, err := conn.Write(rawQuery); err != nil {
outch <- t.newDNSOverUDPResponse( outch <- t.newDNSOverUDPResponse(
"", err, query, nil, "serialize_query") // one-sized buffer, can't block myaddr, err, query, nil, WriteOperation) // one-sized buffer, can't block
return
}
// While dial operations return immediately for UDP, we MAY be calling the
// dialer's resolver if t.Endpoint contains a domain name. So, let us basically
// enforce the same overall deadline covering DNS lookup and I/O operations.
deadline := time.Now().Add(t.IOTimeout)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
if err != nil {
outch <- t.newDNSOverUDPResponse(
"", err, query, nil, ConnectOperation) // one-sized buffer, can't block
return
}
defer conn.Close() // we own the conn
conn.SetDeadline(deadline)
localAddr := conn.LocalAddr().String()
if _, err = conn.Write(rawQuery); err != nil {
outch <- t.newDNSOverUDPResponse(
localAddr, err, query, nil, WriteOperation) // one-sized buffer, can't block
return return
} }
for { for {
resp, err := t.recv(query, conn) resp, err := t.recv(query, conn)
select { select {
case outch <- t.newDNSOverUDPResponse(localAddr, err, query, resp, ReadOperation): case outch <- t.newDNSOverUDPResponse(myaddr, err, query, resp, ReadOperation):
default: default:
return // no-one is reading the channel -- so long... return // no-one is reading the channel -- so long...
} }

View File

@ -283,60 +283,65 @@ func TestDNSOverUDPTransport(t *testing.T) {
t.Run("AsyncRoundTrip", func(t *testing.T) { t.Run("AsyncRoundTrip", func(t *testing.T) {
t.Run("calling Next with cancelled context", func(t *testing.T) { t.Run("calling Next with cancelled context", func(t *testing.T) {
blocker := make(chan interface{}) srvr := &filtering.DNSServer{
const expected = 17 OnQuery: func(domain string) filtering.DNSAction {
input := bytes.NewReader(make([]byte, expected)) return filtering.DNSActionCache
txp := NewDNSOverUDPTransport( },
&mocks.Dialer{ Cache: map[string][]string{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { "dns.google.": {"8.8.8.8"},
<-blocker // block here until Next returns because of expired context
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "udp"
},
MockString: func() string {
return "127.0.0.1:1345"
},
}
},
}, nil
},
}, "9.9.9.9:53",
)
expectedResp := &mocks.DNSResponse{}
txp.Decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return expectedResp, nil
}, },
} }
query := &mocks.DNSQuery{ listener, err := srvr.Start("127.0.0.1:0")
MockBytes: func() ([]byte, error) { if err != nil {
return make([]byte, 128), nil t.Fatal(err)
},
} }
out := txp.AsyncRoundTrip(query, 1) defer listener.Close()
ctx, cancel := context.WithCancel(context.Background()) dialer := NewDialerWithoutResolver(model.DiscardLogger)
cancel() // immediately cancel txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
resp, err := out.Next(ctx) encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
ctx := context.Background()
rch, err := txp.AsyncRoundTrip(ctx, query, 1)
if err != nil {
t.Fatal(err)
}
defer rch.Close()
ctx, cancel := context.WithCancel(ctx)
cancel() // fail immediately
resp, err := rch.Next(ctx)
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if resp != nil { if resp != nil {
t.Fatal("unexpected resp") t.Fatal("unexpected resp")
} }
close(blocker) // unblock the background goroutine })
t.Run("no-one is reading the channel", func(t *testing.T) {
srvr := &filtering.DNSServer{
OnQuery: func(domain string) filtering.DNSAction {
return filtering.DNSActionLocalHostPlusCache // i.e., two responses
},
Cache: map[string][]string{
"dns.google.": {"8.8.8.8"},
},
}
listener, err := srvr.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
ctx := context.Background()
rch, err := txp.AsyncRoundTrip(ctx, query, 1) // but just one place
if err != nil {
t.Fatal(err)
}
defer rch.Close()
<-rch.Joined // should see no-one is reading and stop
}) })
t.Run("typical usage to obtain late responses", func(t *testing.T) { t.Run("typical usage to obtain late responses", func(t *testing.T) {
@ -357,7 +362,11 @@ func TestDNSOverUDPTransport(t *testing.T) {
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{} encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false) query := encoder.Encode("dns.google.", dns.TypeA, false)
rch := txp.AsyncRoundTrip(query, 1) rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
if err != nil {
t.Fatal(err)
}
defer rch.Close()
resp, err := rch.Next(context.Background()) resp, err := rch.Next(context.Background())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -408,7 +417,11 @@ func TestDNSOverUDPTransport(t *testing.T) {
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
encoder := &DNSEncoderMiekg{} encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false) query := encoder.Encode("dns.google.", dns.TypeA, false)
rch := txp.AsyncRoundTrip(query, 1) rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
if err != nil {
t.Fatal(err)
}
defer rch.Close()
result := <-rch.Response result := <-rch.Response
if result.Err == nil || result.Err.Error() != "generic_timeout_error" { if result.Err == nil || result.Err.Error() != "generic_timeout_error" {
t.Fatal("unexpected error", result.Err) t.Fatal("unexpected error", result.Err)