diff --git a/internal/netxlite/dnsoverudp.go b/internal/netxlite/dnsoverudp.go index c492a9f..fa9d0e4 100644 --- a/internal/netxlite/dnsoverudp.go +++ b/internal/netxlite/dnsoverudp.go @@ -26,9 +26,9 @@ import ( // translated into socket errors (among them, host_unreachable); // // 2. connected sockets ignore responses from illegitimate IP addresses but -// most if not all DNS resolvers also do that, therefore it does not seem to +// most if not all DNS resolvers also do that, therefore this does not seem to // be a realistic censorship vector. At the same time, connected sockets -// provide us for free the feature that we don't need to bother with checking +// provide us for free with the feature that we don't need to bother with checking // whether the reply comes from the expected server. // // Being able to observe some ICMP errors is good because it could possibly @@ -86,7 +86,11 @@ func (t *DNSOverUDPTransport) RoundTrip( const opTimeout = 5 * time.Second ctx, cancel := context.WithTimeout(ctx, opTimeout) defer cancel() - outch := t.AsyncRoundTrip(query, 1) // buffer to avoid background's goroutine leak + outch, err := t.AsyncRoundTrip(ctx, query, 1) // buffer to avoid background's goroutine leak + if err != nil { + return nil, err + } + defer outch.Close() // we own the channel return outch.Next(ctx) } @@ -153,16 +157,30 @@ func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error, // or more *DNSOverUDPResponse that makes extracting information from // the underlying channels more user friendly than interacting with // the channels directly, thanks to useful wrapper methods implementing -// common access patterns. You can still use the channels directly if -// there's no convenience method for your specific access pattern. +// common access patterns. You can still use the underlying channels +// directly if there's no suitable convenience method. +// +// You MUST call the .Close method when done. 
Not calling such a method +// leaks goroutines and causes connections to stay open forever. type DNSOverUDPChannel struct { // Response is the channel where we'll post responses. This channel - // WON'T be closed when the background goroutine terminates. + // WILL NOT be closed when the background goroutine terminates. Response <-chan *DNSOverUDPResponse - // Joined is a channel that IS CLOSED when the background - // goroutine performing this round trip TERMINATES. + // Joined IS CLOSED when the background goroutine terminates. Joined <-chan bool + + // conn is the underlying connection, which we can Close to + // immediately cause the background goroutine to join. + conn net.Conn +} + +// Close releases the resources allocated by the channel. You MUST +// call this method to force the background goroutine that is performing +// the round trip to terminate. Calling this method also ensures we +// close the connection used by the round trip. This method is idempotent. +func (ch *DNSOverUDPChannel) Close() error { + return ch.conn.Close() } // Next blocks until the next response is received on Response or the @@ -205,11 +223,12 @@ func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) { // // The real round trip runs in a background goroutine. We will terminate the background // goroutine when (1) the IOTimeout expires for the connection we're using or (2) we -// cannot write on the "Response" channel. Note that the background goroutine WILL NOT -// close the "Response" channel to signal its completion. Hence, who reads such a -// channel MUST be prepared for read operations to block forever and use a -// select for draining the channel in a deadlock-safe way. Also, we WILL NOT ever -// emit a nil message over the "Response" channel. +// cannot write on the "Response" channel or (3) the connection is closed by calling the +// Close method of DNSOverUDPChannel. 
Note that the background goroutine WILL NOT close +// the "Response" channel to signal its completion. Hence, who reads such a +// channel MUST be prepared for read operations to block forever (i.e., should use +// a select operation for draining the channel in a deadlock-safe way). Also, +// we WILL NOT ever post a nil message to the "Response" channel. // // The returned DNSOverUDPChannel contains another channel called Joined that is // closed when the background goroutine terminates, so you can use this channel @@ -217,56 +236,75 @@ func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) { // // If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type, // you don't need to worry about these low level details though. -func (t *DNSOverUDPTransport) AsyncRoundTrip(query model.DNSQuery, buffer int) *DNSOverUDPChannel { +// +// We give you OWNERSHIP of the returned DNSOverUDPChannel and you MUST +// call its .Close method when done with using it. +func (t *DNSOverUDPTransport) AsyncRoundTrip( + ctx context.Context, query model.DNSQuery, buffer int) (*DNSOverUDPChannel, error) { + rawQuery, err := query.Bytes() + if err != nil { + return nil, err + } + conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(t.IOTimeout)) if buffer < 2 { buffer = 1 // as documented } outch := make(chan *DNSOverUDPResponse, buffer) joinedch := make(chan bool) - go t.roundTripLoop(query, outch, joinedch) - return &DNSOverUDPChannel{ + go t.sendRecvLoop(conn, rawQuery, query, outch, joinedch) + dnsch := &DNSOverUDPChannel{ Response: outch, Joined: joinedch, + conn: conn, // transfer ownership } + return dnsch, nil } -// roundTripLoop performs the round trip and writes results into the "outch" channel. This -// function ASSUMES that "outch" is configured to have AT LEAST one buffer slot. This function -// TAKES OWNERSHIP of "outch" but WILL NOT close it when done. 
This function instead OWNS -// the "joinedch" channel and WILL CLOSE it when done. -func (t *DNSOverUDPTransport) roundTripLoop( - query model.DNSQuery, outch chan<- *DNSOverUDPResponse, joinedch chan<- bool) { - defer close(joinedch) // as documented - rawQuery, err := query.Bytes() - if err != nil { +// sendRecvLoop sends the given raw query on the given conn and receives responses +// from the conn posting them onto the given output channel. +// +// Arguments: +// +// 1. conn is the BORROWED net.Conn (we will use it for reading or writing but +// we do not own the connection and we're not going to close it); +// +// 2. rawQuery contains the serialized query and is BORROWED (we won't modify it); +// +// 3. query contains the original query and is also BORROWED; +// +// 4. outch is the channel where to emit measurements and is OWNED by this +// function (that said, we WILL NOT close this channel); +// +// 5. eofch is the channel to signal EOF, which is OWNED by this function +// and closed when this function exits. +// +// This method terminates in the following cases: +// +// 1. I/O error while reading or writing (including the deadline expiring or +// the owner of the connection closing the connection); +// +// 2. We cannot post on the output channel because either there is +// no one reading the channel or the channel's buffer is full; +// +// 3. We cannot parse incoming data as a valid DNS response message that +// responds to the query that we originally sent. 
+func (t *DNSOverUDPTransport) sendRecvLoop(conn net.Conn, rawQuery []byte, + query model.DNSQuery, outch chan<- *DNSOverUDPResponse, eofch chan<- bool) { + defer close(eofch) // synchronize with the caller + myaddr := conn.LocalAddr().String() + if _, err := conn.Write(rawQuery); err != nil { outch <- t.newDNSOverUDPResponse( - "", err, query, nil, "serialize_query") // one-sized buffer, can't block - return - } - // While dial operations return immediately for UDP, we MAY be calling the - // dialer's resolver if t.Endpoint contains a domain name. So, let us basically - // enforce the same overall deadline covering DNS lookup and I/O operations. - deadline := time.Now().Add(t.IOTimeout) - ctx, cancel := context.WithDeadline(context.Background(), deadline) - defer cancel() - conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint) - if err != nil { - outch <- t.newDNSOverUDPResponse( - "", err, query, nil, ConnectOperation) // one-sized buffer, can't block - return - } - defer conn.Close() // we own the conn - conn.SetDeadline(deadline) - localAddr := conn.LocalAddr().String() - if _, err = conn.Write(rawQuery); err != nil { - outch <- t.newDNSOverUDPResponse( - localAddr, err, query, nil, WriteOperation) // one-sized buffer, can't block + myaddr, err, query, nil, WriteOperation) // one-sized buffer, can't block return } for { resp, err := t.recv(query, conn) select { - case outch <- t.newDNSOverUDPResponse(localAddr, err, query, resp, ReadOperation): + case outch <- t.newDNSOverUDPResponse(myaddr, err, query, resp, ReadOperation): default: return // no-one is reading the channel -- so long... 
} diff --git a/internal/netxlite/dnsoverudp_test.go b/internal/netxlite/dnsoverudp_test.go index 895179a..f79b2f7 100644 --- a/internal/netxlite/dnsoverudp_test.go +++ b/internal/netxlite/dnsoverudp_test.go @@ -283,60 +283,65 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Run("AsyncRoundTrip", func(t *testing.T) { t.Run("calling Next with cancelled context", func(t *testing.T) { - blocker := make(chan interface{}) - const expected = 17 - input := bytes.NewReader(make([]byte, expected)) - txp := NewDNSOverUDPTransport( - &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - <-blocker // block here until Next returns because of expired context - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil - }, - MockWrite: func(b []byte) (int, error) { - return len(b), nil - }, - MockRead: input.Read, - MockClose: func() error { - return nil - }, - MockLocalAddr: func() net.Addr { - return &mocks.Addr{ - MockNetwork: func() string { - return "udp" - }, - MockString: func() string { - return "127.0.0.1:1345" - }, - } - }, - }, nil - }, - }, "9.9.9.9:53", - ) - expectedResp := &mocks.DNSResponse{} - txp.Decoder = &mocks.DNSDecoder{ - MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { - return expectedResp, nil + srvr := &filtering.DNSServer{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionCache + }, + Cache: map[string][]string{ + "dns.google.": {"8.8.8.8"}, }, } - query := &mocks.DNSQuery{ - MockBytes: func() ([]byte, error) { - return make([]byte, 128), nil - }, + listener, err := srvr.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) } - out := txp.AsyncRoundTrip(query, 1) - ctx, cancel := context.WithCancel(context.Background()) - cancel() // immediately cancel - resp, err := out.Next(ctx) + defer listener.Close() + dialer := NewDialerWithoutResolver(model.DiscardLogger) + txp := NewDNSOverUDPTransport(dialer, 
listener.LocalAddr().String()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google.", dns.TypeA, false) + ctx := context.Background() + rch, err := txp.AsyncRoundTrip(ctx, query, 1) + if err != nil { + t.Fatal(err) + } + defer rch.Close() + ctx, cancel := context.WithCancel(ctx) + cancel() // fail immediately + resp, err := rch.Next(ctx) if !errors.Is(err, context.Canceled) { t.Fatal("unexpected err", err) } if resp != nil { t.Fatal("unexpected resp") } - close(blocker) // unblock the background goroutine + }) + + t.Run("no-one is reading the channel", func(t *testing.T) { + srvr := &filtering.DNSServer{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionLocalHostPlusCache // i.e., two responses + }, + Cache: map[string][]string{ + "dns.google.": {"8.8.8.8"}, + }, + } + listener, err := srvr.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dialer := NewDialerWithoutResolver(model.DiscardLogger) + txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google.", dns.TypeA, false) + ctx := context.Background() + rch, err := txp.AsyncRoundTrip(ctx, query, 1) // but just one place + if err != nil { + t.Fatal(err) + } + defer rch.Close() + <-rch.Joined // should see no-one is reading and stop }) t.Run("typical usage to obtain late responses", func(t *testing.T) { @@ -357,7 +362,11 @@ func TestDNSOverUDPTransport(t *testing.T) { txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) - rch := txp.AsyncRoundTrip(query, 1) + rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) + if err != nil { + t.Fatal(err) + } + defer rch.Close() resp, err := rch.Next(context.Background()) if err != nil { t.Fatal(err) @@ -408,7 +417,11 @@ func TestDNSOverUDPTransport(t *testing.T) { txp.IOTimeout = 30 * time.Millisecond 
// short timeout to have a fast test encoder := &DNSEncoderMiekg{} query := encoder.Encode("dns.google.", dns.TypeA, false) - rch := txp.AsyncRoundTrip(query, 1) + rch, err := txp.AsyncRoundTrip(context.Background(), query, 1) + if err != nil { + t.Fatal(err) + } + defer rch.Close() result := <-rch.Response if result.Err == nil || result.Err.Error() != "generic_timeout_error" { t.Fatal("unexpected error", result.Err)