From 16f7407b13b1ec4cc7d8eb7ba04f512ec01afa22 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Thu, 26 May 2022 20:09:00 +0200 Subject: [PATCH] feat(netxlite): observe additional DNS-over-UDP responses (#762) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This diff introduces support for observing additional DNS-over-UDP responses in some censored environments (e.g. China). After some uncertainty around whether to use connected or unconnected UDP sockets, I eventually settled for connected. Here's a recap: | | connected | unconnected | | ----------------------- | --------- | ----------- | | see ICMP errors | ✔️ | ❌ | | responses from any server | ❌ | ✔️ | Because most if not all DNS resolvers expect answers from exactly the same servers to which they sent the query, I would say that it's more important to have some limited ability of observing the effect of ICMP errors (e.g., host_unreachable when we set a low TTL and send out a query to a server). Therefore, my choice was to modify the existing DNS-over-UDP transport. Here's an overview of the changes: 1. introduce a new API for performing an async round trip that returns a channel wrapper where all responses are posted. The channel will not ever be closed, so the reader needs to use select for safely reading. If the reader users the wrapper's Next or TryNextResponses methods, these details do not matter because they already implement a safe reading pattern. 2. the async round trip API performs the round trip in the background and stops processing when it sees the first error. 3. the background running code will use an overall deadline derived from the DNSTransport.IOTimeout field to know when to stop. 4. the background running code will additionally stop running if noone is reading the channel and there are no empty slots in the channel's buffer. 5. the RoundTrip method has been rewritten in terms of the async API. The design I'm using here implements the proposal for async round trips defined at https://github.com/ooni/probe/issues/2099. I have chosen not to make all transports async because the DNS transport seems the only transport that needs to also work in async mode. While there, I noticed that we were not propagating CloseIdleConnection to the underlying dialer, which was potentially wrong, so I did it. --- internal/netxlite/dnsoverudp.go | 276 +++++++++++++++--- internal/netxlite/dnsoverudp_test.go | 233 +++++++++++++++- internal/netxlite/filtering/dns.go | 142 ++++------ internal/netxlite/filtering/dns_test.go | 354 +++++++++++------------- internal/netxlite/filtering/doc.go | 14 +- internal/netxlite/filtering/http.go | 3 + internal/netxlite/filtering/tls.go | 15 +- internal/netxlite/integration_test.go | 6 +- script/nocopyreadall.bash | 6 + 9 files changed, 719 insertions(+), 330 deletions(-) diff --git a/internal/netxlite/dnsoverudp.go b/internal/netxlite/dnsoverudp.go index 2928826..c492a9f 100644 --- a/internal/netxlite/dnsoverudp.go +++ b/internal/netxlite/dnsoverudp.go @@ -6,16 +6,49 @@ package netxlite import ( "context" + "net" "time" "github.com/ooni/probe-cli/v3/internal/model" ) // DNSOverUDPTransport is a DNS-over-UDP DNSTransport. +// +// To construct this type, either manually fill the fields marked as MANDATORY +// or just use the NewDNSOverUDPTransport factory directly. +// +// RoundTrip creates a new connected UDP socket for each outgoing query. Using a +// new socket is good because some censored environments will block the client UDP +// endpoint for several seconds when you query for blocked domains. We could also +// have used an unconnected UDP socket here, but: +// +// 1. connected sockets are great because they get some ICMP errors to be +// translated into socket errors (among them, host_unreachable); +// +// 2. connected sockets ignore responses from illegitimate IP addresses but +// most if not all DNS resolvers also do that, therefore it does not seem to +// be a realistic censorship vector. At the same time, connected sockets +// provide us for free the feature that we don't need to bother with checking +// whether the reply comes from the expected server. +// +// Being able to observe some ICMP errors is good because it could possibly +// make this code suitable to implement parasitic traceroute. +// +// This transport is capable of collecting additional responses after the first +// response. To see these responses, use the AsyncRoundTrip method. type DNSOverUDPTransport struct { - dialer model.Dialer - decoder model.DNSDecoder - address string + // Decoder is the MANDATORY DNSDecoder to use. + Decoder model.DNSDecoder + + // Dialer is the MANDATORY dialer used to create the conn. + Dialer model.Dialer + + // Endpoint is the MANDATORY server's endpoint (e.g., 1.1.1.1:53) + Endpoint string + + // IOTimeout is the MANDATORY I/O timeout after which any + // conn created to perform round trips times out. + IOTimeout time.Duration } // NewDNSOverUDPTransport creates a DNSOverUDPTransport instance. @@ -25,41 +58,36 @@ type DNSOverUDPTransport struct { // - dialer is any type that implements the Dialer interface; // // - address is the endpoint address (e.g., 8.8.8.8:53). +// +// If the address contains a domain name rather than an IP address +// (e.g., dns.google:53), we will end up using the first of the +// IP addresses returned by the underlying DNS lookup performed using +// the dialer. This usage pattern is NOT RECOMMENDED because we'll +// have less control over which IP address is being used. func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport { return &DNSOverUDPTransport{ - dialer: dialer, - decoder: &DNSDecoderMiekg{}, - address: address, + Decoder: &DNSDecoderMiekg{}, + Dialer: dialer, + Endpoint: address, + IOTimeout: 10 * time.Second, } } -// RoundTrip sends a query and receives a reply. +// RoundTrip sends a query and receives a response. func (t *DNSOverUDPTransport) RoundTrip( ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { - rawQuery, err := query.Bytes() - if err != nil { - return nil, err - } - conn, err := t.dialer.DialContext(ctx, "udp", t.address) - if err != nil { - return nil, err - } - defer conn.Close() - // Use five seconds timeout like Bionic does. See - // https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance - const iotimeout = 5 * time.Second - conn.SetDeadline(time.Now().Add(iotimeout)) - if _, err = conn.Write(rawQuery); err != nil { - return nil, err - } - const maxmessagesize = 1 << 17 - rawResponse := make([]byte, maxmessagesize) - count, err := conn.Read(rawResponse) - if err != nil { - return nil, err - } - rawResponse = rawResponse[:count] - return t.decoder.DecodeResponse(rawResponse, query) + // QUIRK: the original code had a five seconds timeout, which is + // consistent with the Bionic implementation. Let's enforce such a + // timeout using the context in the outer operation because we + // need to run for more seconds in the background to catch as many + // duplicate replies as possible. + // + // See https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance + const opTimeout = 5 * time.Second + ctx, cancel := context.WithTimeout(ctx, opTimeout) + defer cancel() + outch := t.AsyncRoundTrip(query, 1) // buffer to avoid background's goroutine leak + return outch.Next(ctx) } // RequiresPadding returns false for UDP according to RFC8467. @@ -74,12 +102,194 @@ func (t *DNSOverUDPTransport) Network() string { // Address returns the upstream server address. func (t *DNSOverUDPTransport) Address() string { - return t.address + return t.Endpoint } // CloseIdleConnections closes idle connections, if any. func (t *DNSOverUDPTransport) CloseIdleConnections() { - // nothing to do + // The underlying dialer MAY have idle connections so let's + // forward the call... + t.Dialer.CloseIdleConnections() } var _ model.DNSTransport = &DNSOverUDPTransport{} + +// DNSOverUDPResponse is a response received by a DNSOverUDPTransport when you +// use its AsyncRoundTrip method as opposed to using RoundTrip. +type DNSOverUDPResponse struct { + // Err is the error that occurred (nil in case of success). + Err error + + // LocalAddr is the local UDP address we're using. + LocalAddr string + + // Operation is the operation that failed. + Operation string + + // Query is the related DNS query. + Query model.DNSQuery + + // RemoteAddr is the remote server address. + RemoteAddr string + + // Response is the response (nil iff error is not nil). + Response model.DNSResponse +} + +// newDNSOverUDPResponse creates a new DNSOverUDPResponse instance. +func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error, + query model.DNSQuery, resp model.DNSResponse, operation string) *DNSOverUDPResponse { + return &DNSOverUDPResponse{ + Err: err, + LocalAddr: localAddr, + Operation: operation, + Query: query, + RemoteAddr: t.Endpoint, // The common case is to have an IP:port here (domains are discouraged) + Response: resp, + } +} + +// DNSOverUDPChannel is a wrapper around a channel for reading zero +// or more *DNSOverUDPResponse that makes extracting information from +// the underlying channels more user friendly than interacting with +// the channels directly, thanks to useful wrapper methods implementing +// common access patterns. You can still use the channels directly if +// there's no convenience method for your specific access pattern. +type DNSOverUDPChannel struct { + // Response is the channel where we'll post responses. This channel + // WON'T be closed when the background goroutine terminates. + Response <-chan *DNSOverUDPResponse + + // Joined is a channel that IS CLOSED when the background + // goroutine performing this round trip TERMINATES. + Joined <-chan bool +} + +// Next blocks until the next response is received on Response or the +// given context expires, whatever happens first. This function will +// completely ignore the Joined channel and will just timeout in case +// you call Next after the background goroutine had joined. In fact, +// the use case for this function is using it to get a response or +// a timeout when you know the DNS round trip is pending. +func (ch *DNSOverUDPChannel) Next(ctx context.Context) (model.DNSResponse, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case out := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil + return out.Response, out.Err + } +} + +// TryNextResponses attempts to read all the buffered messages inside of the "Response" +// channel that contains successful DNS responses. That is, this function will silently skip +// any possible DNSOverUDPResponse with its Err != nil. The use case for this function is +// to obtain all the subsequent response messages we received while we were performing +// other operations (e.g., contacting the test helper of fetching a webpage). +func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) { + for { + select { + case r := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil + if r.Err == nil && r.Response != nil { + out = append(out, r.Response) + } + default: + return + } + } +} + +// AsyncRoundTrip performs an async DNS round trip. The "buffer" argument +// controls how many buffer slots the returned DNSOverUDPChannel's Response +// channel should have. A zero or negative value causes this function to +// create a channel having a single-slot buffer. +// +// The real round trip runs in a background goroutine. We will terminate the background +// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we +// cannot write on the "Response" channel. Note that the background goroutine WILL NOT +// close the "Response" channel to signal its completion. Hence, who reads such a +// channel MUST be prepared for read operations to block forever and use a +// select for draining the channel in a deadlock-safe way. Also, we WILL NOT ever +// emit a nil message over the "Response" channel. +// +// The returned DNSOverUDPChannel contains another channel called Joined that is +// closed when the background goroutine terminates, so you can use this channel +// should you need to synchronize with such goroutine's termination. +// +// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type, +// you don't need to worry about these low level details though. +func (t *DNSOverUDPTransport) AsyncRoundTrip(query model.DNSQuery, buffer int) *DNSOverUDPChannel { + if buffer < 2 { + buffer = 1 // as documented + } + outch := make(chan *DNSOverUDPResponse, buffer) + joinedch := make(chan bool) + go t.roundTripLoop(query, outch, joinedch) + return &DNSOverUDPChannel{ + Response: outch, + Joined: joinedch, + } +} + +// roundTripLoop performs the round trip and writes results into the "outch" channel. This +// function ASSUMES that "outch" is configured to have AT LEAST one buffer slot. This function +// TAKES OWNERSHIP of "outch" but WILL NOT close it when done. This function instead OWNS +// the "joinedch" channel and WILL CLOSE it when done. +func (t *DNSOverUDPTransport) roundTripLoop( + query model.DNSQuery, outch chan<- *DNSOverUDPResponse, joinedch chan<- bool) { + defer close(joinedch) // as documented + rawQuery, err := query.Bytes() + if err != nil { + outch <- t.newDNSOverUDPResponse( + "", err, query, nil, "serialize_query") // one-sized buffer, can't block + return + } + // While dial operations return immediately for UDP, we MAY be calling the + // dialer's resolver if t.Endpoint contains a domain name. So, let us basically + // enforce the same overall deadline covering DNS lookup and I/O operations. + deadline := time.Now().Add(t.IOTimeout) + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint) + if err != nil { + outch <- t.newDNSOverUDPResponse( + "", err, query, nil, ConnectOperation) // one-sized buffer, can't block + return + } + defer conn.Close() // we own the conn + conn.SetDeadline(deadline) + localAddr := conn.LocalAddr().String() + if _, err = conn.Write(rawQuery); err != nil { + outch <- t.newDNSOverUDPResponse( + localAddr, err, query, nil, WriteOperation) // one-sized buffer, can't block + return + } + for { + resp, err := t.recv(query, conn) + select { + case outch <- t.newDNSOverUDPResponse(localAddr, err, query, resp, ReadOperation): + default: + return // no-one is reading the channel -- so long... + } + if err != nil { + // We are going to consider all errors as fatal for now until we + // hear of specific errs that it might have sense to ignore. + // + // Note that erroring out here includes the expiration of the conn's + // I/O deadline, which we set above precisely because we want + // the total runtime of this goroutine to be bounded. + return + } + } +} + +// recv receives a single response for the given query using the given conn. +func (t *DNSOverUDPTransport) recv(query model.DNSQuery, conn net.Conn) (model.DNSResponse, error) { + const maxmessagesize = 1 << 17 + rawResponse := make([]byte, maxmessagesize) + count, err := conn.Read(rawResponse) + if err != nil { + return nil, err + } + rawResponse = rawResponse[:count] + return t.Decoder.DecodeResponse(rawResponse, query) +} diff --git a/internal/netxlite/dnsoverudp_test.go b/internal/netxlite/dnsoverudp_test.go index 5720edf..895179a 100644 --- a/internal/netxlite/dnsoverudp_test.go +++ b/internal/netxlite/dnsoverudp_test.go @@ -9,8 +9,11 @@ import ( "time" "github.com/apex/log" + "github.com/google/go-cmp/cmp" + "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" + "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" ) func TestDNSOverUDPTransport(t *testing.T) { @@ -70,6 +73,16 @@ func TestDNSOverUDPTransport(t *testing.T) { MockClose: func() error { return nil }, + MockLocalAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "udp" + }, + MockString: func() string { + return "127.0.0.1:1345" + }, + } + }, }, nil }, }, "9.9.9.9:53", @@ -106,6 +119,16 @@ func TestDNSOverUDPTransport(t *testing.T) { MockClose: func() error { return nil }, + MockLocalAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "udp" + }, + MockString: func() string { + return "127.0.0.1:1345" + }, + } + }, }, nil }, }, "9.9.9.9:53", @@ -141,12 +164,22 @@ func TestDNSOverUDPTransport(t *testing.T) { MockClose: func() error { return nil }, + MockLocalAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "udp" + }, + MockString: func() string { + return "127.0.0.1:1345" + }, + } + }, }, nil }, }, "9.9.9.9:53", ) expectedErr := errors.New("mocked error") - txp.decoder = &mocks.DNSDecoder{ + txp.Decoder = &mocks.DNSDecoder{ MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { return nil, expectedErr }, @@ -165,7 +198,7 @@ func TestDNSOverUDPTransport(t *testing.T) { } }) - t.Run("read success", func(t *testing.T) { + t.Run("decode success", func(t *testing.T) { const expected = 17 input := bytes.NewReader(make([]byte, expected)) txp := NewDNSOverUDPTransport( @@ -182,12 +215,22 @@ func TestDNSOverUDPTransport(t *testing.T) { MockClose: func() error { return nil }, + MockLocalAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "udp" + }, + MockString: func() string { + return "127.0.0.1:1345" + }, + } + }, }, nil }, }, "9.9.9.9:53", ) expectedResp := &mocks.DNSResponse{} - txp.decoder = &mocks.DNSDecoder{ + txp.Decoder = &mocks.DNSDecoder{ MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { return expectedResp, nil }, @@ -205,6 +248,190 @@ func TestDNSOverUDPTransport(t *testing.T) { t.Fatal("unexpected resp") } }) + + t.Run("using a real server", func(t *testing.T) { + srvr := &filtering.DNSServer{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionCache + }, + Cache: map[string][]string{ + "dns.google.": {"8.8.8.8"}, + }, + } + listener, err := srvr.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dialer := NewDialerWithoutResolver(model.DiscardLogger) + txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google.", dns.TypeA, false) + resp, err := txp.RoundTrip(context.Background(), query) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" { + t.Fatal(diff) + } + }) + }) + + t.Run("AsyncRoundTrip", func(t *testing.T) { + t.Run("calling Next with cancelled context", func(t *testing.T) { + blocker := make(chan interface{}) + const expected = 17 + input := bytes.NewReader(make([]byte, expected)) + txp := NewDNSOverUDPTransport( + &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + <-blocker // block here until Next returns because of expired context + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: input.Read, + MockClose: func() error { + return nil + }, + MockLocalAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "udp" + }, + MockString: func() string { + return "127.0.0.1:1345" + }, + } + }, + }, nil + }, + }, "9.9.9.9:53", + ) + expectedResp := &mocks.DNSResponse{} + txp.Decoder = &mocks.DNSDecoder{ + MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) { + return expectedResp, nil + }, + } + query := &mocks.DNSQuery{ + MockBytes: func() ([]byte, error) { + return make([]byte, 128), nil + }, + } + out := txp.AsyncRoundTrip(query, 1) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // immediately cancel + resp, err := out.Next(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("unexpected resp") + } + close(blocker) // unblock the background goroutine + }) + + t.Run("typical usage to obtain late responses", func(t *testing.T) { + srvr := &filtering.DNSServer{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionLocalHostPlusCache + }, + Cache: map[string][]string{ + "dns.google.": {"8.8.8.8"}, + }, + } + listener, err := srvr.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dialer := NewDialerWithoutResolver(model.DiscardLogger) + txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google.", dns.TypeA, false) + rch := txp.AsyncRoundTrip(query, 1) + resp, err := rch.Next(context.Background()) + if err != nil { + t.Fatal(err) + } + addrs, err := resp.DecodeLookupHost() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" { + t.Fatal(diff) + } + // One would not normally busy loop but it's fine to do that in the context + // of this test because we know we're going to receive a second reply. In + // a real network experiment here we'll do other activities, e.g., contacting + // the test helper or fetching a webpage. + var additional []model.DNSResponse + for { + additional = rch.TryNextResponses() + if len(additional) > 0 { + if len(additional) != 1 { + t.Fatal("expected exactly one additional response") + } + break + } + } + addrs, err = additional[0].DecodeLookupHost() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("correct behavior when read times out", func(t *testing.T) { + srvr := &filtering.DNSServer{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionTimeout + }, + } + listener, err := srvr.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dialer := NewDialerWithoutResolver(model.DiscardLogger) + txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String()) + txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test + encoder := &DNSEncoderMiekg{} + query := encoder.Encode("dns.google.", dns.TypeA, false) + rch := txp.AsyncRoundTrip(query, 1) + result := <-rch.Response + if result.Err == nil || result.Err.Error() != "generic_timeout_error" { + t.Fatal("unexpected error", result.Err) + } + if result.Operation != ReadOperation { + t.Fatal("unexpected failed operation", result.Operation) + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + dialer := &mocks.Dialer{ + MockCloseIdleConnections: func() { + called = true + }, + } + const address = "9.9.9.9:53" + txp := NewDNSOverUDPTransport(dialer, address) + txp.CloseIdleConnections() + if !called { + t.Fatal("not called") + } }) t.Run("other functions okay", func(t *testing.T) { diff --git a/internal/netxlite/filtering/dns.go b/internal/netxlite/filtering/dns.go index 8b63527..3eede6d 100644 --- a/internal/netxlite/filtering/dns.go +++ b/internal/netxlite/filtering/dns.go @@ -1,22 +1,19 @@ package filtering import ( - "errors" "io" "net" "strings" + "time" "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/runtimex" ) -// DNSAction is a DNS filtering action that this proxy should take. +// DNSAction is a DNS filtering action that a DNSServer should take. type DNSAction string const ( - // DNSActionPass passes the traffic to the upstream server. - DNSActionPass = DNSAction("pass") - // DNSActionNXDOMAIN replies with NXDOMAIN. DNSActionNXDOMAIN = DNSAction("nxdomain") @@ -32,15 +29,19 @@ const ( // DNSActionTimeout never replies to the query. DNSActionTimeout = DNSAction("timeout") - // DNSActionCache causes the proxy to check the cache. If there + // DNSActionCache causes the server to check the cache. If there // are entries, they are returned. Otherwise, NXDOMAIN is returned. DNSActionCache = DNSAction("cache") + + // DNSActionLocalHostPlusCache combines the LocalHost and + // Cache actions returning first a localhost response followed + // by a subsequent response obtained using the cache. + DNSActionLocalHostPlusCache = DNSAction("localhost+cache") ) -// DNSProxy is a DNS proxy that routes traffic to an upstream -// resolver and may implement filtering policies. -type DNSProxy struct { - // Cache is the DNS cache. Note that the keys of the map +// DNSServer is a DNS server implementing filtering policies. +type DNSServer struct { + // Cache is the OPTIONAL DNS cache. Note that the keys of the map // must be FQDNs (i.e., including the final `.`). Cache map[string][]string @@ -48,26 +49,27 @@ type DNSProxy struct { // receive a query for the given domain. OnQuery func(domain string) DNSAction - // UpstreamEndpoint is the OPTIONAL upstream transport endpoint. - UpstreamEndpoint string - - // mockableReply allows to mock DNSProxy.reply in tests. - mockableReply func(query *dns.Msg) (*dns.Msg, error) + // onTimeout is the OPTIONAL channel where we emit a true + // value each time there's a timeout. If you set this value + // to a non-nil channel, then you MUST drain the channel + // for each expected timeout. Otherwise, the code will just + // ignore this field and nothing will be emitted. + onTimeout chan bool } -// DNSListener is the interface returned by DNSProxy.Start +// DNSListener is the interface returned by DNSServer.Start. type DNSListener interface { io.Closer LocalAddr() net.Addr } -// Start starts the proxy. -func (p *DNSProxy) Start(address string) (DNSListener, error) { +// Start starts this server. +func (p *DNSServer) Start(address string) (DNSListener, error) { pconn, _, err := p.start(address) return pconn, err } -func (p *DNSProxy) start(address string) (DNSListener, <-chan interface{}, error) { +func (p *DNSServer) start(address string) (DNSListener, <-chan interface{}, error) { pconn, err := net.ListenPacket("udp", address) if err != nil { return nil, nil, err @@ -77,15 +79,15 @@ func (p *DNSProxy) start(address string) (DNSListener, <-chan interface{}, error return pconn, done, nil } -func (p *DNSProxy) mainloop(pconn net.PacketConn, done chan<- interface{}) { +func (p *DNSServer) mainloop(pconn net.PacketConn, done chan<- interface{}) { defer close(done) for p.oneloop(pconn) { // nothing } } -func (p *DNSProxy) oneloop(pconn net.PacketConn) bool { - buffer := make([]byte, 1<<12) +func (p *DNSServer) oneloop(pconn net.PacketConn) bool { + buffer := make([]byte, 1<<17) count, addr, err := pconn.ReadFrom(buffer) if err != nil { return !strings.HasSuffix(err.Error(), "use of closed network connection") @@ -95,73 +97,70 @@ func (p *DNSProxy) oneloop(pconn net.PacketConn) bool { return true } -func (p *DNSProxy) serveAsync(pconn net.PacketConn, addr net.Addr, buffer []byte) { +func (p *DNSServer) emit(pconn net.PacketConn, addr net.Addr, reply ...*dns.Msg) (success int) { + for _, entry := range reply { + replyBytes, err := entry.Pack() + if err != nil { + continue + } + pconn.WriteTo(replyBytes, addr) + success++ // we use this value in tests + } + return +} + +func (p *DNSServer) serveAsync(pconn net.PacketConn, addr net.Addr, buffer []byte) { query := &dns.Msg{} if err := query.Unpack(buffer); err != nil { return } - reply, err := p.reply(query) - if err != nil { - return - } - replyBytes, err := reply.Pack() - if err != nil { - return - } - pconn.WriteTo(replyBytes, addr) -} - -func (p *DNSProxy) reply(query *dns.Msg) (*dns.Msg, error) { - if p.mockableReply != nil { - return p.mockableReply(query) - } - return p.replyDefault(query) -} - -func (p *DNSProxy) replyDefault(query *dns.Msg) (*dns.Msg, error) { - if len(query.Question) != 1 { - return nil, errors.New("unhandled message") + if len(query.Question) < 1 { + return // just discard the query } name := query.Question[0].Name switch p.OnQuery(name) { - case DNSActionPass: - return p.proxy(query) case DNSActionNXDOMAIN: - return p.nxdomain(query), nil + p.emit(pconn, addr, p.nxdomain(query)) case DNSActionLocalHost: - return p.localHost(query), nil + p.emit(pconn, addr, p.localHost(query)) case DNSActionNoAnswer: - return p.empty(query), nil + p.emit(pconn, addr, p.empty(query)) case DNSActionTimeout: - return nil, errors.New("let's ignore this query") + if p.onTimeout != nil { + p.onTimeout <- true + } case DNSActionCache: - return p.cache(name, query), nil + p.emit(pconn, addr, p.cache(name, query)) + case DNSActionLocalHostPlusCache: + p.emit(pconn, addr, p.localHost(query)) + time.Sleep(10 * time.Millisecond) + p.emit(pconn, addr, p.cache(name, query)) default: - return p.refused(query), nil + p.emit(pconn, addr, p.refused(query)) } } -func (p *DNSProxy) refused(query *dns.Msg) *dns.Msg { +func (p *DNSServer) refused(query *dns.Msg) *dns.Msg { m := new(dns.Msg) m.SetRcode(query, dns.RcodeRefused) return m } -func (p *DNSProxy) nxdomain(query *dns.Msg) *dns.Msg { +func (p *DNSServer) nxdomain(query *dns.Msg) *dns.Msg { m := new(dns.Msg) m.SetRcode(query, dns.RcodeNameError) return m } -func (p *DNSProxy) localHost(query *dns.Msg) *dns.Msg { +func (p *DNSServer) localHost(query *dns.Msg) *dns.Msg { return p.compose(query, net.IPv6loopback, net.IPv4(127, 0, 0, 1)) } -func (p *DNSProxy) empty(query *dns.Msg) *dns.Msg { +func (p *DNSServer) empty(query *dns.Msg) *dns.Msg { return p.compose(query) } -func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg { +func (p *DNSServer) compose(query *dns.Msg, ips ...net.IP) *dns.Msg { runtimex.PanicIfTrue(len(query.Question) != 1, "expecting a single question") question := query.Question[0] reply := new(dns.Msg) @@ -195,27 +194,7 @@ func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg { return reply } -var ( - // errDNSExpectedSingleQuestion means we expected to see a single question - errDNSExpectedSingleQuestion = errors.New("filtering: expected single DNS question") - - // errDNSExpectedQueryNotResponse means we expected to see a query. - errDNSExpectedQueryNotResponse = errors.New("filtering: expected query not response") -) - -func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) { - if query.Response { - return nil, errDNSExpectedQueryNotResponse - } - if len(query.Question) != 1 { - return nil, errDNSExpectedSingleQuestion - } - clnt := &dns.Client{} - resp, _, err := clnt.Exchange(query, p.upstreamEndpoint()) - return resp, err -} - -func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg { +func (p *DNSServer) cache(name string, query *dns.Msg) *dns.Msg { addrs := p.Cache[name] var ipAddrs []net.IP for _, addr := range addrs { @@ -228,10 +207,3 @@ func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg { } return p.compose(query, ipAddrs...) } - -func (p *DNSProxy) upstreamEndpoint() string { - if p.UpstreamEndpoint != "" { - return p.UpstreamEndpoint - } - return "8.8.8.8:53" -} diff --git a/internal/netxlite/filtering/dns_test.go b/internal/netxlite/filtering/dns_test.go index 88d5ac5..a59ea63 100644 --- a/internal/netxlite/filtering/dns_test.go +++ b/internal/netxlite/filtering/dns_test.go @@ -1,122 +1,98 @@ package filtering import ( - "context" "errors" "net" "strings" "testing" - "time" - "github.com/apex/log" "github.com/miekg/dns" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" - "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/randx" ) -func TestDNSProxy(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - newProxyWithCache := func(action DNSAction, cache map[string][]string) (DNSListener, <-chan interface{}, error) { - p := &DNSProxy{ +func TestDNSServer(t *testing.T) { + newServerWithCache := func(action DNSAction, cache map[string][]string) ( + *DNSServer, DNSListener, <-chan interface{}, error) { + p := &DNSServer{ Cache: cache, OnQuery: func(domain string) DNSAction { return action }, + onTimeout: make(chan bool), } - return p.start("127.0.0.1:0") + listener, done, err := p.start("127.0.0.1:0") + return p, listener, done, err } - newProxy := func(action DNSAction) (DNSListener, <-chan interface{}, error) { - return newProxyWithCache(action, nil) + newServer := func(action DNSAction) (*DNSServer, DNSListener, <-chan interface{}, error) { + return newServerWithCache(action, nil) } - newresolver := func(listener DNSListener) model.Resolver { - dlr := netxlite.NewDialerWithoutResolver(log.Log) - r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) - return r + newQuery := func(qtype uint16) *dns.Msg { + question := dns.Question{ + Name: dns.Fqdn("dns.google"), + Qtype: qtype, + Qclass: dns.ClassINET, + } + query := new(dns.Msg) + query.Id = dns.Id() + query.RecursionDesired = true + query.Question = make([]dns.Question, 1) + query.Question[0] = question + return query } - t.Run("DNSActionPass", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newProxy(DNSActionPass) - if err != nil { - t.Fatal(err) - } - r := newresolver(listener) - addrs, err := r.LookupHost(ctx, "dns.google") - if err != nil { - t.Fatal(err) - } - if addrs == nil { - t.Fatal("unexpected empty addrs") - } - var found bool - for _, addr := range addrs { - found = found || addr == "8.8.8.8" - } - if !found { - t.Fatal("did not find 8.8.8.8") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - t.Run("DNSActionNXDOMAIN", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newProxy(DNSActionNXDOMAIN) + _, listener, done, err := newServer(DNSActionNXDOMAIN) if err != nil { t.Fatal(err) } - r := newresolver(listener) - addrs, err := r.LookupHost(ctx, "dns.google") - if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { - t.Fatal("unexpected err", err) + reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) } - if addrs != nil { - t.Fatal("expected empty addrs") + if reply.Rcode != dns.RcodeNameError { + t.Fatal("unexpected rcode") } listener.Close() <-done // wait for background goroutine to exit }) t.Run("DNSActionRefused", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newProxy(DNSActionRefused) + _, listener, done, err := newServer(DNSActionRefused) if err != nil { t.Fatal(err) } - r := newresolver(listener) - addrs, err := r.LookupHost(ctx, "dns.google") - if err == nil || err.Error() != netxlite.FailureDNSRefusedError { - t.Fatal("unexpected err", err) + reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) } - if addrs != nil { - t.Fatal("expected empty addrs") + if reply.Rcode != dns.RcodeRefused { + t.Fatal("unexpected rcode") } listener.Close() <-done // wait for background goroutine to exit }) t.Run("DNSActionLocalHost", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newProxy(DNSActionLocalHost) + _, listener, done, err := newServer(DNSActionLocalHost) if err != nil { t.Fatal(err) } - r := newresolver(listener) - addrs, err := r.LookupHost(ctx, "dns.google") + reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) if err != nil { t.Fatal(err) } - if addrs == nil { - t.Fatal("expected non-empty addrs") + if reply.Rcode != dns.RcodeSuccess { + t.Fatal("unexpected rcode") } var found bool - for _, addr := range addrs { - found = found || addr == "127.0.0.1" + for _, ans := range reply.Answer { + switch v := ans.(type) { + case *dns.A: + found = found || v.A.String() == "127.0.0.1" + } } if !found { t.Fatal("did not find 127.0.0.1") @@ -126,94 +102,154 @@ func TestDNSProxy(t *testing.T) { }) t.Run("DNSActionEmpty", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newProxy(DNSActionNoAnswer) + _, listener, done, err := newServer(DNSActionNoAnswer) if err != nil { t.Fatal(err) } - r := newresolver(listener) - addrs, err := r.LookupHost(ctx, "dns.google") - if err == nil || err.Error() != netxlite.FailureDNSNoAnswer { - t.Fatal("unexpected err", err) + reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) } - if addrs != nil { - t.Fatal("expected empty addrs") + if reply.Rcode != dns.RcodeSuccess { + t.Fatal("unexpected rcode") + } + if len(reply.Answer) != 0 { + t.Fatal("expected no answers") } listener.Close() <-done // wait for background goroutine to exit }) t.Run("DNSActionTimeout", func(t *testing.T) { - // Implementation note: if you see this test running for more - // than one second, then it means we're not checking the context - // immediately. We should be improving there but we need to be - // careful because lots of legacy code uses SerialResolver. - const timeout = time.Second - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - listener, done, err := newProxy(DNSActionTimeout) + srvr, listener, done, err := newServer(DNSActionTimeout) if err != nil { t.Fatal(err) } - r := newresolver(listener) - addrs, err := r.LookupHost(ctx, "dns.google") - if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + c := &dns.Client{} + conn, err := c.Dial(listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + go func() { + <-srvr.onTimeout + conn.Close() // close as soon as the server times out, so this test is fast + }() + reply, _, err := c.ExchangeWithConn(newQuery(dns.TypeA), conn) + if !errors.Is(err, net.ErrClosed) { t.Fatal("unexpected err", err) } - if addrs != nil { - t.Fatal("expected empty addrs") + if reply != nil { + t.Fatal("expected nil reply here") } listener.Close() <-done // wait for background goroutine to exit }) t.Run("DNSActionCache without entries", func(t *testing.T) { - ctx := context.Background() - listener, done, err := newProxyWithCache(DNSActionCache, nil) + _, listener, done, err := newServerWithCache(DNSActionCache, nil) if err != nil { t.Fatal(err) } - r := newresolver(listener) - addrs, err := r.LookupHost(ctx, "dns.google") - if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { - t.Fatal("unexpected err", err) + reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) } - if addrs != nil { - t.Fatal("expected empty addrs") + if reply.Rcode != dns.RcodeNameError { + t.Fatal("unexpected rcode") } listener.Close() <-done // wait for background goroutine to exit }) - t.Run("DNSActionCache with entries", func(t *testing.T) { - ctx := context.Background() + t.Run("DNSActionCache with IPv4 entry", func(t *testing.T) { cache := map[string][]string{ - "dns.google.": {"8.8.8.8", "8.8.4.4"}, + "dns.google.": {"8.8.8.8"}, } - listener, done, err := newProxyWithCache(DNSActionCache, cache) + _, listener, done, err := newServerWithCache(DNSActionCache, cache) if err != nil { t.Fatal(err) } - r := newresolver(listener) - addrs, err := r.LookupHost(ctx, "dns.google") + reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) if err != nil { t.Fatal(err) } - if len(addrs) != 2 { - t.Fatal("expected two entries") + if reply.Rcode != dns.RcodeSuccess { + t.Fatal("unexpected rcode") } - if addrs[0] != "8.8.8.8" { - t.Fatal("invalid first entry") + var found bool + for _, ans := range reply.Answer { + switch v := ans.(type) { + case *dns.A: + found = found || v.A.String() == "8.8.8.8" + } } - if addrs[1] != "8.8.4.4" { - t.Fatal("invalid second entry") + if !found { + t.Fatal("did not find 8.8.8.8") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionCache with IPv6 entry", func(t *testing.T) { + cache := map[string][]string{ + "dns.google.": {"2001:4860:4860::8888"}, + } + _, listener, done, err := newServerWithCache(DNSActionCache, cache) + if err != nil { + t.Fatal(err) + } + reply, err := dns.Exchange(newQuery(dns.TypeAAAA), listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + if reply.Rcode != dns.RcodeSuccess { + t.Fatal("unexpected rcode") + } + var found bool + for _, ans := range reply.Answer { + switch v := ans.(type) { + case *dns.AAAA: + found = found || v.AAAA.String() == "2001:4860:4860::8888" + } + } + if !found { + t.Fatal("did not find 2001:4860:4860::8888") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionLocalHostPlusCache", func(t *testing.T) { + cache := map[string][]string{ + "dns.google.": {"2001:4860:4860::8888"}, + } + _, listener, done, err := newServerWithCache(DNSActionLocalHostPlusCache, cache) + if err != nil { + t.Fatal(err) + } + reply, err := dns.Exchange(newQuery(dns.TypeAAAA), listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + if reply.Rcode != dns.RcodeSuccess { + t.Fatal("unexpected rcode") + } + var found bool + for _, ans := range reply.Answer { + switch v := ans.(type) { + case *dns.AAAA: + found = found || v.AAAA.String() == "::1" + } + } + if !found { + t.Fatal("did not find ::1") } listener.Close() <-done // wait for background goroutine to exit }) t.Run("Start with invalid address", func(t *testing.T) { - p := &DNSProxy{} + p := &DNSServer{} listener, err := p.Start("127.0.0.1") if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { t.Fatal("unexpected err", err) @@ -226,7 +262,7 @@ func TestDNSProxy(t *testing.T) { t.Run("oneloop", func(t *testing.T) { t.Run("ReadFrom failure after which we should continue", func(t *testing.T) { expected := errors.New("mocked error") - p := &DNSProxy{} + p := &DNSServer{} conn := &mocks.UDPLikeConn{ MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { return 0, nil, expected @@ -240,7 +276,7 @@ func TestDNSProxy(t *testing.T) { t.Run("ReadFrom the connection is closed", func(t *testing.T) { expected := errors.New("use of closed network connection") - p := &DNSProxy{} + p := &DNSServer{} conn := &mocks.UDPLikeConn{ MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { return 0, nil, expected @@ -253,7 +289,7 @@ func TestDNSProxy(t *testing.T) { }) t.Run("Unpack fails", func(t *testing.T) { - p := &DNSProxy{} + p := &DNSServer{} conn := &mocks.UDPLikeConn{ MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { if len(p) < 4 { @@ -269,46 +305,16 @@ func TestDNSProxy(t *testing.T) { } }) - t.Run("reply fails", func(t *testing.T) { - p := &DNSProxy{} + t.Run("no questions", func(t *testing.T) { + query := newQuery(dns.TypeA) + query.Question = nil // remove the question + data, err := query.Pack() + if err != nil { + t.Fatal(err) + } + p := &DNSServer{} conn := &mocks.UDPLikeConn{ MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { - query := &dns.Msg{} - query.Question = append(query.Question, dns.Question{}) - query.Question = append(query.Question, dns.Question{}) - data, err := query.Pack() - if err != nil { - panic(err) - } - if len(p) < len(data) { - panic("buffer too small") - } - copy(p, data) - return len(data), &net.UDPAddr{}, nil - }, - } - okay := p.oneloop(conn) - if !okay { - t.Fatal("we should be okay after this error") - } - }) - - t.Run("pack fails", func(t *testing.T) { - p := &DNSProxy{ - mockableReply: func(query *dns.Msg) (*dns.Msg, error) { - reply := &dns.Msg{} - reply.MsgHdr.Rcode = -1 // causes pack to fail - return reply, nil - }, - } - conn := &mocks.UDPLikeConn{ - MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { - query := &dns.Msg{} - query.Question = append(query.Question, dns.Question{}) - data, err := query.Pack() - if err != nil { - panic(err) - } if len(p) < len(data) { panic("buffer too small") } @@ -323,45 +329,13 @@ func TestDNSProxy(t *testing.T) { }) }) - t.Run("proxy", func(t *testing.T) { - t.Run("with response", func(t *testing.T) { - p := &DNSProxy{} - query := &dns.Msg{} - query.Response = true - reply, err := p.proxy(query) - if !errors.Is(err, errDNSExpectedQueryNotResponse) { - t.Fatal("unexpected err", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } - }) - - t.Run("with no questions", func(t *testing.T) { - p := &DNSProxy{} - query := &dns.Msg{} - reply, err := p.proxy(query) - if !errors.Is(err, errDNSExpectedSingleQuestion) { - t.Fatal("unexpected err", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } - }) - - t.Run("round trip fails", func(t *testing.T) { - p := &DNSProxy{ - UpstreamEndpoint: "antani", - } - query := &dns.Msg{} - query.Question = append(query.Question, dns.Question{}) - reply, err := p.proxy(query) - if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { - t.Fatal("unexpected err", err) - } - if reply != nil { - t.Fatal("expected nil reply here") - } - }) + t.Run("pack fails", func(t *testing.T) { + query := newQuery(dns.TypeA) + query.Question[0].Name = randx.Letters(1024) // should be too large + p := &DNSServer{} + count := p.emit(&mocks.UDPLikeConn{}, &mocks.Addr{}, query) + if count != 0 { + t.Fatal("expected to see zero here") + } }) } diff --git a/internal/netxlite/filtering/doc.go b/internal/netxlite/filtering/doc.go index e5febfa..69bb7ce 100644 --- a/internal/netxlite/filtering/doc.go +++ b/internal/netxlite/filtering/doc.go @@ -1,13 +1,3 @@ -// Package filtering allows to implement self-censorship. -// -// The top-level struct is the TProxy. It implements model's -// UnderlyingNetworkLibrary interface. Therefore, you can use TProxy to -// implement filtering and blocking of TCP, TLS, QUIC, DNS, HTTP. -// -// We also expose proxies that implement filtering policies for -// DNS, TLS, and HTTP. -// -// The typical usage of this package's functionality is to -// load a censoring policy into TProxyConfig and then to create -// and start a TProxy instance using NewTProxy. +// Package filtering allows to implement self-censorship. We expose proxies +// implementing filtering policies for DNS, TLS, and HTTP. package filtering diff --git a/internal/netxlite/filtering/http.go b/internal/netxlite/filtering/http.go index c39fe1b..6097fca 100644 --- a/internal/netxlite/filtering/http.go +++ b/internal/netxlite/filtering/http.go @@ -9,6 +9,9 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) +// TODO(bassosimone): remove HTTPActionPass since we want integration tests +// to only run locally to make them much more predictable. + // HTTPAction is an HTTP filtering action that this proxy should take. type HTTPAction string diff --git a/internal/netxlite/filtering/tls.go b/internal/netxlite/filtering/tls.go index 32b092e..298a30b 100644 --- a/internal/netxlite/filtering/tls.go +++ b/internal/netxlite/filtering/tls.go @@ -1,16 +1,17 @@ package filtering import ( - "context" "crypto/tls" "errors" + "io" "net" "strings" "sync" - - "github.com/ooni/probe-cli/v3/internal/netxlite" ) +// TODO(bassosimone): remove TLSActionPass since we want integration tests +// to only run locally to make them much more predictable. + // TLSAction is a TLS filtering action that this proxy should take. type TLSAction string @@ -237,5 +238,11 @@ func (p *TLSProxy) connectingToMyself(conn net.Conn) bool { // forward will forward the traffic. func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) { defer wg.Done() - netxlite.CopyContext(context.Background(), left, right) + // We cannot use netxlite.CopyContext here because we want netxlite to + // use filtering inside its test suite, so this package cannot depend on + // netxlite. In general, we don't want to use io.Copy or io.ReadAll + // directly because they may cause the code to block as documented in + // internal/netxlite/iox.go. However, this package is only used for + // testing, so it's completely okay to make an exception here. + io.Copy(left, right) } diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index 60b2ddf..19a35e1 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -113,7 +113,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { }) t.Run("for nxdomain", func(t *testing.T) { - proxy := &filtering.DNSProxy{ + proxy := &filtering.DNSServer{ OnQuery: func(domain string) filtering.DNSAction { return filtering.DNSActionNXDOMAIN }, @@ -137,7 +137,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { }) t.Run("for refused", func(t *testing.T) { - proxy := &filtering.DNSProxy{ + proxy := &filtering.DNSServer{ OnQuery: func(domain string) filtering.DNSAction { return filtering.DNSActionRefused }, @@ -161,7 +161,7 @@ func TestMeasureWithUDPResolver(t *testing.T) { }) t.Run("for timeout", func(t *testing.T) { - proxy := &filtering.DNSProxy{ + proxy := &filtering.DNSServer{ OnQuery: func(domain string) filtering.DNSAction { return filtering.DNSActionTimeout }, diff --git a/script/nocopyreadall.bash b/script/nocopyreadall.bash index 0797655..cd4a234 100755 --- a/script/nocopyreadall.bash +++ b/script/nocopyreadall.bash @@ -7,6 +7,12 @@ for file in $(find . -type f -name \*.go); do # implement safer wrappers for these functions. continue fi + if [ "$file" = "./internal/netxlite/filtering/tls.go" ]; then + # We're allowed to use ReadAll and Copy in this file to + # avoid depending on netxlite, so we can use filtering + # inside of netxlite's own test suite. + continue + fi if grep -q 'io\.ReadAll' $file; then echo "in $file: do not use io.ReadAll, use netxlite.ReadAllContext" 1>&2 exitcode=1