fix(dnsoverudp): allow to cancel async round trip immediately (#763)
To this end, we need to refactor the implementation to give the DNSOverUDPChannel owenership over the net.Conn. Once this happens, DNSOverUDPChannel.Close closes the conn. When the conn is closed, the background goroutine will terminate immediately because any blocking I/O operation will be immediately unblocked and return net.ErrClosed. See https://github.com/ooni/probe/issues/2099#issuecomment-1139066946
This commit is contained in:
parent
16f7407b13
commit
62bd62ece1
|
@ -26,9 +26,9 @@ import (
|
||||||
// translated into socket errors (among them, host_unreachable);
|
// translated into socket errors (among them, host_unreachable);
|
||||||
//
|
//
|
||||||
// 2. connected sockets ignore responses from illegitimate IP addresses but
|
// 2. connected sockets ignore responses from illegitimate IP addresses but
|
||||||
// most if not all DNS resolvers also do that, therefore it does not seem to
|
// most if not all DNS resolvers also do that, therefore this does not seem to
|
||||||
// be a realistic censorship vector. At the same time, connected sockets
|
// be a realistic censorship vector. At the same time, connected sockets
|
||||||
// provide us for free the feature that we don't need to bother with checking
|
// provide us for free with the feature that we don't need to bother with checking
|
||||||
// whether the reply comes from the expected server.
|
// whether the reply comes from the expected server.
|
||||||
//
|
//
|
||||||
// Being able to observe some ICMP errors is good because it could possibly
|
// Being able to observe some ICMP errors is good because it could possibly
|
||||||
|
@ -86,7 +86,11 @@ func (t *DNSOverUDPTransport) RoundTrip(
|
||||||
const opTimeout = 5 * time.Second
|
const opTimeout = 5 * time.Second
|
||||||
ctx, cancel := context.WithTimeout(ctx, opTimeout)
|
ctx, cancel := context.WithTimeout(ctx, opTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
outch := t.AsyncRoundTrip(query, 1) // buffer to avoid background's goroutine leak
|
outch, err := t.AsyncRoundTrip(ctx, query, 1) // buffer to avoid background's goroutine leak
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer outch.Close() // we own the channel
|
||||||
return outch.Next(ctx)
|
return outch.Next(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,16 +157,30 @@ func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error,
|
||||||
// or more *DNSOverUDPResponse that makes extracting information from
|
// or more *DNSOverUDPResponse that makes extracting information from
|
||||||
// the underlying channels more user friendly than interacting with
|
// the underlying channels more user friendly than interacting with
|
||||||
// the channels directly, thanks to useful wrapper methods implementing
|
// the channels directly, thanks to useful wrapper methods implementing
|
||||||
// common access patterns. You can still use the channels directly if
|
// common access patterns. You can still use the underlying channels
|
||||||
// there's no convenience method for your specific access pattern.
|
// directly if there's no suitable convenience method.
|
||||||
|
//
|
||||||
|
// You MUST call the .Close method when done. Not calling such a method
|
||||||
|
// leaks goroutines and causes connections to stay open forever.
|
||||||
type DNSOverUDPChannel struct {
|
type DNSOverUDPChannel struct {
|
||||||
// Response is the channel where we'll post responses. This channel
|
// Response is the channel where we'll post responses. This channel
|
||||||
// WON'T be closed when the background goroutine terminates.
|
// WILL NOT be closed when the background goroutine terminates.
|
||||||
Response <-chan *DNSOverUDPResponse
|
Response <-chan *DNSOverUDPResponse
|
||||||
|
|
||||||
// Joined is a channel that IS CLOSED when the background
|
// Joined IS CLOSED when the background goroutine terminates.
|
||||||
// goroutine performing this round trip TERMINATES.
|
|
||||||
Joined <-chan bool
|
Joined <-chan bool
|
||||||
|
|
||||||
|
// conn is the underlying connection, which we can Close to
|
||||||
|
// immediately cause the background goroutine to join.
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases the resources allocated by the channel. You MUST
|
||||||
|
// call this method to force the background goroutine that is performing
|
||||||
|
// the round trip to terminate. Calling this method also ensures we
|
||||||
|
// close the connection used by the round trip. This method is idempotent.
|
||||||
|
func (ch *DNSOverUDPChannel) Close() error {
|
||||||
|
return ch.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next blocks until the next response is received on Response or the
|
// Next blocks until the next response is received on Response or the
|
||||||
|
@ -205,11 +223,12 @@ func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) {
|
||||||
//
|
//
|
||||||
// The real round trip runs in a background goroutine. We will terminate the background
|
// The real round trip runs in a background goroutine. We will terminate the background
|
||||||
// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we
|
// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we
|
||||||
// cannot write on the "Response" channel. Note that the background goroutine WILL NOT
|
// cannot write on the "Response" channel or (3) the connection is closed by calling the
|
||||||
// close the "Response" channel to signal its completion. Hence, who reads such a
|
// Close method of DNSOverUDPChannel. Note that the background goroutine WILL NOT close
|
||||||
// channel MUST be prepared for read operations to block forever and use a
|
// the "Response" channel to signal its completion. Hence, who reads such a
|
||||||
// select for draining the channel in a deadlock-safe way. Also, we WILL NOT ever
|
// channel MUST be prepared for read operations to block forever (i.e., should use
|
||||||
// emit a nil message over the "Response" channel.
|
// a select operation for draining the channel in a deadlock-safe way). Also,
|
||||||
|
// we WILL NOT ever post a nil message to the "Response" channel.
|
||||||
//
|
//
|
||||||
// The returned DNSOverUDPChannel contains another channel called Joined that is
|
// The returned DNSOverUDPChannel contains another channel called Joined that is
|
||||||
// closed when the background goroutine terminates, so you can use this channel
|
// closed when the background goroutine terminates, so you can use this channel
|
||||||
|
@ -217,56 +236,75 @@ func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) {
|
||||||
//
|
//
|
||||||
// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type,
|
// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type,
|
||||||
// you don't need to worry about these low level details though.
|
// you don't need to worry about these low level details though.
|
||||||
func (t *DNSOverUDPTransport) AsyncRoundTrip(query model.DNSQuery, buffer int) *DNSOverUDPChannel {
|
//
|
||||||
|
// We give you OWNERSHIP of the returned DNSOverUDPChannel and you MUST
|
||||||
|
// call its .Close method when done with using it.
|
||||||
|
func (t *DNSOverUDPTransport) AsyncRoundTrip(
|
||||||
|
ctx context.Context, query model.DNSQuery, buffer int) (*DNSOverUDPChannel, error) {
|
||||||
|
rawQuery, err := query.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
conn.SetDeadline(time.Now().Add(t.IOTimeout))
|
||||||
if buffer < 2 {
|
if buffer < 2 {
|
||||||
buffer = 1 // as documented
|
buffer = 1 // as documented
|
||||||
}
|
}
|
||||||
outch := make(chan *DNSOverUDPResponse, buffer)
|
outch := make(chan *DNSOverUDPResponse, buffer)
|
||||||
joinedch := make(chan bool)
|
joinedch := make(chan bool)
|
||||||
go t.roundTripLoop(query, outch, joinedch)
|
go t.sendRecvLoop(conn, rawQuery, query, outch, joinedch)
|
||||||
return &DNSOverUDPChannel{
|
dnsch := &DNSOverUDPChannel{
|
||||||
Response: outch,
|
Response: outch,
|
||||||
Joined: joinedch,
|
Joined: joinedch,
|
||||||
|
conn: conn, // transfer ownership
|
||||||
}
|
}
|
||||||
|
return dnsch, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// roundTripLoop performs the round trip and writes results into the "outch" channel. This
|
// sendRecvLoop sends the given raw query on the given conn and receives responses
|
||||||
// function ASSUMES that "outch" is configured to have AT LEAST one buffer slot. This function
|
// from the conn posting them onto the given output channel.
|
||||||
// TAKES OWNERSHIP of "outch" but WILL NOT close it when done. This function instead OWNS
|
//
|
||||||
// the "joinedch" channel and WILL CLOSE it when done.
|
// Arguments:
|
||||||
func (t *DNSOverUDPTransport) roundTripLoop(
|
//
|
||||||
query model.DNSQuery, outch chan<- *DNSOverUDPResponse, joinedch chan<- bool) {
|
// 1. conn is the BORROWED net.Conn (we will use it for reading or writing but
|
||||||
defer close(joinedch) // as documented
|
// we do not own the connection and we're not going to close it);
|
||||||
rawQuery, err := query.Bytes()
|
//
|
||||||
if err != nil {
|
// 2. rawQuery contains the rawQuery and is BORROWED (we won't modify it);
|
||||||
|
//
|
||||||
|
// 3. query contains the original query and is also BORROWED;
|
||||||
|
//
|
||||||
|
// 4. outch is the channel where to emit measurements and is OWNED by this
|
||||||
|
// function (that said, we WILL NOT close this channel);
|
||||||
|
//
|
||||||
|
// 5. eofch is the channel to signal EOF, which is OWNED by this function
|
||||||
|
// and closed when this function exits.
|
||||||
|
//
|
||||||
|
// This method terminates in the following cases:
|
||||||
|
//
|
||||||
|
// 1. I/O error while reading or writing (including the deadline expiring or
|
||||||
|
// the owner of the connection closing the connection);
|
||||||
|
//
|
||||||
|
// 2. We cannot post on the output channel because either there is
|
||||||
|
// noone reading the channel or the channel's buffer is full.
|
||||||
|
//
|
||||||
|
// 3. We cannot parse incoming data as a valid DNS response message that
|
||||||
|
// responds to the query that we originally sent.
|
||||||
|
func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte,
|
||||||
|
query model.DNSQuery, outch chan<- *DNSOverUDPResponse, eofch chan<- bool) {
|
||||||
|
defer close(eofch) // synchronize with the caller
|
||||||
|
myaddr := conn.LocalAddr().String()
|
||||||
|
if _, err := conn.Write(rawQuery); err != nil {
|
||||||
outch <- t.newDNSOverUDPResponse(
|
outch <- t.newDNSOverUDPResponse(
|
||||||
"", err, query, nil, "serialize_query") // one-sized buffer, can't block
|
myaddr, err, query, nil, WriteOperation) // one-sized buffer, can't block
|
||||||
return
|
|
||||||
}
|
|
||||||
// While dial operations return immediately for UDP, we MAY be calling the
|
|
||||||
// dialer's resolver if t.Endpoint contains a domain name. So, let us basically
|
|
||||||
// enforce the same overall deadline covering DNS lookup and I/O operations.
|
|
||||||
deadline := time.Now().Add(t.IOTimeout)
|
|
||||||
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
|
||||||
defer cancel()
|
|
||||||
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
|
|
||||||
if err != nil {
|
|
||||||
outch <- t.newDNSOverUDPResponse(
|
|
||||||
"", err, query, nil, ConnectOperation) // one-sized buffer, can't block
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close() // we own the conn
|
|
||||||
conn.SetDeadline(deadline)
|
|
||||||
localAddr := conn.LocalAddr().String()
|
|
||||||
if _, err = conn.Write(rawQuery); err != nil {
|
|
||||||
outch <- t.newDNSOverUDPResponse(
|
|
||||||
localAddr, err, query, nil, WriteOperation) // one-sized buffer, can't block
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
resp, err := t.recv(query, conn)
|
resp, err := t.recv(query, conn)
|
||||||
select {
|
select {
|
||||||
case outch <- t.newDNSOverUDPResponse(localAddr, err, query, resp, ReadOperation):
|
case outch <- t.newDNSOverUDPResponse(myaddr, err, query, resp, ReadOperation):
|
||||||
default:
|
default:
|
||||||
return // no-one is reading the channel -- so long...
|
return // no-one is reading the channel -- so long...
|
||||||
}
|
}
|
||||||
|
|
|
@ -283,60 +283,65 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("AsyncRoundTrip", func(t *testing.T) {
|
t.Run("AsyncRoundTrip", func(t *testing.T) {
|
||||||
t.Run("calling Next with cancelled context", func(t *testing.T) {
|
t.Run("calling Next with cancelled context", func(t *testing.T) {
|
||||||
blocker := make(chan interface{})
|
srvr := &filtering.DNSServer{
|
||||||
const expected = 17
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
input := bytes.NewReader(make([]byte, expected))
|
return filtering.DNSActionCache
|
||||||
txp := NewDNSOverUDPTransport(
|
},
|
||||||
&mocks.Dialer{
|
Cache: map[string][]string{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
"dns.google.": {"8.8.8.8"},
|
||||||
<-blocker // block here until Next returns because of expired context
|
|
||||||
return &mocks.Conn{
|
|
||||||
MockSetDeadline: func(t time.Time) error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
MockWrite: func(b []byte) (int, error) {
|
|
||||||
return len(b), nil
|
|
||||||
},
|
|
||||||
MockRead: input.Read,
|
|
||||||
MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
MockLocalAddr: func() net.Addr {
|
|
||||||
return &mocks.Addr{
|
|
||||||
MockNetwork: func() string {
|
|
||||||
return "udp"
|
|
||||||
},
|
|
||||||
MockString: func() string {
|
|
||||||
return "127.0.0.1:1345"
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}, "9.9.9.9:53",
|
|
||||||
)
|
|
||||||
expectedResp := &mocks.DNSResponse{}
|
|
||||||
txp.Decoder = &mocks.DNSDecoder{
|
|
||||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
|
||||||
return expectedResp, nil
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
query := &mocks.DNSQuery{
|
listener, err := srvr.Start("127.0.0.1:0")
|
||||||
MockBytes: func() ([]byte, error) {
|
if err != nil {
|
||||||
return make([]byte, 128), nil
|
t.Fatal(err)
|
||||||
},
|
|
||||||
}
|
}
|
||||||
out := txp.AsyncRoundTrip(query, 1)
|
defer listener.Close()
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
cancel() // immediately cancel
|
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
resp, err := out.Next(ctx)
|
encoder := &DNSEncoderMiekg{}
|
||||||
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
rch, err := txp.AsyncRoundTrip(ctx, query, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer rch.Close()
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
cancel() // fail immediately
|
||||||
|
resp, err := rch.Next(ctx)
|
||||||
if !errors.Is(err, context.Canceled) {
|
if !errors.Is(err, context.Canceled) {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
t.Fatal("unexpected resp")
|
t.Fatal("unexpected resp")
|
||||||
}
|
}
|
||||||
close(blocker) // unblock the background goroutine
|
})
|
||||||
|
|
||||||
|
t.Run("no-one is reading the channel", func(t *testing.T) {
|
||||||
|
srvr := &filtering.DNSServer{
|
||||||
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
|
return filtering.DNSActionLocalHostPlusCache // i.e., two responses
|
||||||
|
},
|
||||||
|
Cache: map[string][]string{
|
||||||
|
"dns.google.": {"8.8.8.8"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
listener, err := srvr.Start("127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
|
encoder := &DNSEncoderMiekg{}
|
||||||
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
rch, err := txp.AsyncRoundTrip(ctx, query, 1) // but just one place
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer rch.Close()
|
||||||
|
<-rch.Joined // should see no-one is reading and stop
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("typical usage to obtain late responses", func(t *testing.T) {
|
t.Run("typical usage to obtain late responses", func(t *testing.T) {
|
||||||
|
@ -357,7 +362,11 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
rch := txp.AsyncRoundTrip(query, 1)
|
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer rch.Close()
|
||||||
resp, err := rch.Next(context.Background())
|
resp, err := rch.Next(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -408,7 +417,11 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
|
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
|
||||||
encoder := &DNSEncoderMiekg{}
|
encoder := &DNSEncoderMiekg{}
|
||||||
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
rch := txp.AsyncRoundTrip(query, 1)
|
rch, err := txp.AsyncRoundTrip(context.Background(), query, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer rch.Close()
|
||||||
result := <-rch.Response
|
result := <-rch.Response
|
||||||
if result.Err == nil || result.Err.Error() != "generic_timeout_error" {
|
if result.Err == nil || result.Err.Error() != "generic_timeout_error" {
|
||||||
t.Fatal("unexpected error", result.Err)
|
t.Fatal("unexpected error", result.Err)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user