feat(netxlite): observe additional DNS-over-UDP responses (#762)
This diff introduces support for observing additional DNS-over-UDP responses in some censored environments (e.g. China). After some uncertainty around whether to use connected or unconnected UDP sockets, I eventually settled for connected. Here's a recap: | | connected | unconnected | | ----------------------- | --------- | ----------- | | see ICMP errors | ✔️ | ❌ | | responses from any server | ❌ | ✔️ | Because most if not all DNS resolvers expect answers from exactly the same servers to which they sent the query, I would say that it's more important to have some limited ability of observing the effect of ICMP errors (e.g., host_unreachable when we set a low TTL and send out a query to a server). Therefore, my choice was to modify the existing DNS-over-UDP transport. Here's an overview of the changes: 1. introduce a new API for performing an async round trip that returns a channel wrapper where all responses are posted. The channel will not ever be closed, so the reader needs to use select for safely reading. If the reader users the wrapper's Next or TryNextResponses methods, these details do not matter because they already implement a safe reading pattern. 2. the async round trip API performs the round trip in the background and stops processing when it sees the first error. 3. the background running code will use an overall deadline derived from the DNSTransport.IOTimeout field to know when to stop. 4. the background running code will additionally stop running if noone is reading the channel and there are no empty slots in the channel's buffer. 5. the RoundTrip method has been rewritten in terms of the async API. The design I'm using here implements the proposal for async round trips defined at https://github.com/ooni/probe/issues/2099. I have chosen not to make all transports async because the DNS transport seems the only transport that needs to also work in async mode. While there, I noticed that we were not propagating CloseIdleConnection to the underlying dialer, which was potentially wrong, so I did it.
This commit is contained in:
parent
01a513a496
commit
16f7407b13
|
@ -6,16 +6,49 @@ package netxlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
|
// DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
|
||||||
|
//
|
||||||
|
// To construct this type, either manually fill the fields marked as MANDATORY
|
||||||
|
// or just use the NewDNSOverUDPTransport factory directly.
|
||||||
|
//
|
||||||
|
// RoundTrip creates a new connected UDP socket for each outgoing query. Using a
|
||||||
|
// new socket is good because some censored environments will block the client UDP
|
||||||
|
// endpoint for several seconds when you query for blocked domains. We could also
|
||||||
|
// have used an unconnected UDP socket here, but:
|
||||||
|
//
|
||||||
|
// 1. connected sockets are great because they get some ICMP errors to be
|
||||||
|
// translated into socket errors (among them, host_unreachable);
|
||||||
|
//
|
||||||
|
// 2. connected sockets ignore responses from illegitimate IP addresses but
|
||||||
|
// most if not all DNS resolvers also do that, therefore it does not seem to
|
||||||
|
// be a realistic censorship vector. At the same time, connected sockets
|
||||||
|
// provide us for free the feature that we don't need to bother with checking
|
||||||
|
// whether the reply comes from the expected server.
|
||||||
|
//
|
||||||
|
// Being able to observe some ICMP errors is good because it could possibly
|
||||||
|
// make this code suitable to implement parasitic traceroute.
|
||||||
|
//
|
||||||
|
// This transport is capable of collecting additional responses after the first
|
||||||
|
// response. To see these responses, use the AsyncRoundTrip method.
|
||||||
type DNSOverUDPTransport struct {
|
type DNSOverUDPTransport struct {
|
||||||
dialer model.Dialer
|
// Decoder is the MANDATORY DNSDecoder to use.
|
||||||
decoder model.DNSDecoder
|
Decoder model.DNSDecoder
|
||||||
address string
|
|
||||||
|
// Dialer is the MANDATORY dialer used to create the conn.
|
||||||
|
Dialer model.Dialer
|
||||||
|
|
||||||
|
// Endpoint is the MANDATORY server's endpoint (e.g., 1.1.1.1:53)
|
||||||
|
Endpoint string
|
||||||
|
|
||||||
|
// IOTimeout is the MANDATORY I/O timeout after which any
|
||||||
|
// conn created to perform round trips times out.
|
||||||
|
IOTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverUDPTransport creates a DNSOverUDPTransport instance.
|
// NewDNSOverUDPTransport creates a DNSOverUDPTransport instance.
|
||||||
|
@ -25,41 +58,36 @@ type DNSOverUDPTransport struct {
|
||||||
// - dialer is any type that implements the Dialer interface;
|
// - dialer is any type that implements the Dialer interface;
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
||||||
|
//
|
||||||
|
// If the address contains a domain name rather than an IP address
|
||||||
|
// (e.g., dns.google:53), we will end up using the first of the
|
||||||
|
// IP addresses returned by the underlying DNS lookup performed using
|
||||||
|
// the dialer. This usage pattern is NOT RECOMMENDED because we'll
|
||||||
|
// have less control over which IP address is being used.
|
||||||
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
|
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
|
||||||
return &DNSOverUDPTransport{
|
return &DNSOverUDPTransport{
|
||||||
dialer: dialer,
|
Decoder: &DNSDecoderMiekg{},
|
||||||
decoder: &DNSDecoderMiekg{},
|
Dialer: dialer,
|
||||||
address: address,
|
Endpoint: address,
|
||||||
|
IOTimeout: 10 * time.Second,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip sends a query and receives a reply.
|
// RoundTrip sends a query and receives a response.
|
||||||
func (t *DNSOverUDPTransport) RoundTrip(
|
func (t *DNSOverUDPTransport) RoundTrip(
|
||||||
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
rawQuery, err := query.Bytes()
|
// QUIRK: the original code had a five seconds timeout, which is
|
||||||
if err != nil {
|
// consistent with the Bionic implementation. Let's enforce such a
|
||||||
return nil, err
|
// timeout using the context in the outer operation because we
|
||||||
}
|
// need to run for more seconds in the background to catch as many
|
||||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
// duplicate replies as possible.
|
||||||
if err != nil {
|
//
|
||||||
return nil, err
|
// See https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||||
}
|
const opTimeout = 5 * time.Second
|
||||||
defer conn.Close()
|
ctx, cancel := context.WithTimeout(ctx, opTimeout)
|
||||||
// Use five seconds timeout like Bionic does. See
|
defer cancel()
|
||||||
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
outch := t.AsyncRoundTrip(query, 1) // buffer to avoid background's goroutine leak
|
||||||
const iotimeout = 5 * time.Second
|
return outch.Next(ctx)
|
||||||
conn.SetDeadline(time.Now().Add(iotimeout))
|
|
||||||
if _, err = conn.Write(rawQuery); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
const maxmessagesize = 1 << 17
|
|
||||||
rawResponse := make([]byte, maxmessagesize)
|
|
||||||
count, err := conn.Read(rawResponse)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rawResponse = rawResponse[:count]
|
|
||||||
return t.decoder.DecodeResponse(rawResponse, query)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns false for UDP according to RFC8467.
|
// RequiresPadding returns false for UDP according to RFC8467.
|
||||||
|
@ -74,12 +102,194 @@ func (t *DNSOverUDPTransport) Network() string {
|
||||||
|
|
||||||
// Address returns the upstream server address.
|
// Address returns the upstream server address.
|
||||||
func (t *DNSOverUDPTransport) Address() string {
|
func (t *DNSOverUDPTransport) Address() string {
|
||||||
return t.address
|
return t.Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseIdleConnections closes idle connections, if any.
|
// CloseIdleConnections closes idle connections, if any.
|
||||||
func (t *DNSOverUDPTransport) CloseIdleConnections() {
|
func (t *DNSOverUDPTransport) CloseIdleConnections() {
|
||||||
// nothing to do
|
// The underlying dialer MAY have idle connections so let's
|
||||||
|
// forward the call...
|
||||||
|
t.Dialer.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.DNSTransport = &DNSOverUDPTransport{}
|
var _ model.DNSTransport = &DNSOverUDPTransport{}
|
||||||
|
|
||||||
|
// DNSOverUDPResponse is a response received by a DNSOverUDPTransport when you
|
||||||
|
// use its AsyncRoundTrip method as opposed to using RoundTrip.
|
||||||
|
type DNSOverUDPResponse struct {
|
||||||
|
// Err is the error that occurred (nil in case of success).
|
||||||
|
Err error
|
||||||
|
|
||||||
|
// LocalAddr is the local UDP address we're using.
|
||||||
|
LocalAddr string
|
||||||
|
|
||||||
|
// Operation is the operation that failed.
|
||||||
|
Operation string
|
||||||
|
|
||||||
|
// Query is the related DNS query.
|
||||||
|
Query model.DNSQuery
|
||||||
|
|
||||||
|
// RemoteAddr is the remote server address.
|
||||||
|
RemoteAddr string
|
||||||
|
|
||||||
|
// Response is the response (nil iff error is not nil).
|
||||||
|
Response model.DNSResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDNSOverUDPResponse creates a new DNSOverUDPResponse instance.
|
||||||
|
func (t *DNSOverUDPTransport) newDNSOverUDPResponse(localAddr string, err error,
|
||||||
|
query model.DNSQuery, resp model.DNSResponse, operation string) *DNSOverUDPResponse {
|
||||||
|
return &DNSOverUDPResponse{
|
||||||
|
Err: err,
|
||||||
|
LocalAddr: localAddr,
|
||||||
|
Operation: operation,
|
||||||
|
Query: query,
|
||||||
|
RemoteAddr: t.Endpoint, // The common case is to have an IP:port here (domains are discouraged)
|
||||||
|
Response: resp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNSOverUDPChannel is a wrapper around a channel for reading zero
|
||||||
|
// or more *DNSOverUDPResponse that makes extracting information from
|
||||||
|
// the underlying channels more user friendly than interacting with
|
||||||
|
// the channels directly, thanks to useful wrapper methods implementing
|
||||||
|
// common access patterns. You can still use the channels directly if
|
||||||
|
// there's no convenience method for your specific access pattern.
|
||||||
|
type DNSOverUDPChannel struct {
|
||||||
|
// Response is the channel where we'll post responses. This channel
|
||||||
|
// WON'T be closed when the background goroutine terminates.
|
||||||
|
Response <-chan *DNSOverUDPResponse
|
||||||
|
|
||||||
|
// Joined is a channel that IS CLOSED when the background
|
||||||
|
// goroutine performing this round trip TERMINATES.
|
||||||
|
Joined <-chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next blocks until the next response is received on Response or the
|
||||||
|
// given context expires, whatever happens first. This function will
|
||||||
|
// completely ignore the Joined channel and will just timeout in case
|
||||||
|
// you call Next after the background goroutine had joined. In fact,
|
||||||
|
// the use case for this function is using it to get a response or
|
||||||
|
// a timeout when you know the DNS round trip is pending.
|
||||||
|
func (ch *DNSOverUDPChannel) Next(ctx context.Context) (model.DNSResponse, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case out := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil
|
||||||
|
return out.Response, out.Err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryNextResponses attempts to read all the buffered messages inside of the "Response"
|
||||||
|
// channel that contains successful DNS responses. That is, this function will silently skip
|
||||||
|
// any possible DNSOverUDPResponse with its Err != nil. The use case for this function is
|
||||||
|
// to obtain all the subsequent response messages we received while we were performing
|
||||||
|
// other operations (e.g., contacting the test helper of fetching a webpage).
|
||||||
|
func (ch *DNSOverUDPChannel) TryNextResponses() (out []model.DNSResponse) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case r := <-ch.Response: // Note: AsyncRoundTrip WILL NOT close the channel or emit a nil
|
||||||
|
if r.Err == nil && r.Response != nil {
|
||||||
|
out = append(out, r.Response)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AsyncRoundTrip performs an async DNS round trip. The "buffer" argument
|
||||||
|
// controls how many buffer slots the returned DNSOverUDPChannel's Response
|
||||||
|
// channel should have. A zero or negative value causes this function to
|
||||||
|
// create a channel having a single-slot buffer.
|
||||||
|
//
|
||||||
|
// The real round trip runs in a background goroutine. We will terminate the background
|
||||||
|
// goroutine when (1) the IOTimeout expires for the connection we're using or (2) we
|
||||||
|
// cannot write on the "Response" channel. Note that the background goroutine WILL NOT
|
||||||
|
// close the "Response" channel to signal its completion. Hence, who reads such a
|
||||||
|
// channel MUST be prepared for read operations to block forever and use a
|
||||||
|
// select for draining the channel in a deadlock-safe way. Also, we WILL NOT ever
|
||||||
|
// emit a nil message over the "Response" channel.
|
||||||
|
//
|
||||||
|
// The returned DNSOverUDPChannel contains another channel called Joined that is
|
||||||
|
// closed when the background goroutine terminates, so you can use this channel
|
||||||
|
// should you need to synchronize with such goroutine's termination.
|
||||||
|
//
|
||||||
|
// If you are using the Next or TryNextResponses methods of the DNSOverUDPChannel type,
|
||||||
|
// you don't need to worry about these low level details though.
|
||||||
|
func (t *DNSOverUDPTransport) AsyncRoundTrip(query model.DNSQuery, buffer int) *DNSOverUDPChannel {
|
||||||
|
if buffer < 2 {
|
||||||
|
buffer = 1 // as documented
|
||||||
|
}
|
||||||
|
outch := make(chan *DNSOverUDPResponse, buffer)
|
||||||
|
joinedch := make(chan bool)
|
||||||
|
go t.roundTripLoop(query, outch, joinedch)
|
||||||
|
return &DNSOverUDPChannel{
|
||||||
|
Response: outch,
|
||||||
|
Joined: joinedch,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// roundTripLoop performs the round trip and writes results into the "outch" channel. This
|
||||||
|
// function ASSUMES that "outch" is configured to have AT LEAST one buffer slot. This function
|
||||||
|
// TAKES OWNERSHIP of "outch" but WILL NOT close it when done. This function instead OWNS
|
||||||
|
// the "joinedch" channel and WILL CLOSE it when done.
|
||||||
|
func (t *DNSOverUDPTransport) roundTripLoop(
|
||||||
|
query model.DNSQuery, outch chan<- *DNSOverUDPResponse, joinedch chan<- bool) {
|
||||||
|
defer close(joinedch) // as documented
|
||||||
|
rawQuery, err := query.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
outch <- t.newDNSOverUDPResponse(
|
||||||
|
"", err, query, nil, "serialize_query") // one-sized buffer, can't block
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// While dial operations return immediately for UDP, we MAY be calling the
|
||||||
|
// dialer's resolver if t.Endpoint contains a domain name. So, let us basically
|
||||||
|
// enforce the same overall deadline covering DNS lookup and I/O operations.
|
||||||
|
deadline := time.Now().Add(t.IOTimeout)
|
||||||
|
ctx, cancel := context.WithDeadline(context.Background(), deadline)
|
||||||
|
defer cancel()
|
||||||
|
conn, err := t.Dialer.DialContext(ctx, "udp", t.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
outch <- t.newDNSOverUDPResponse(
|
||||||
|
"", err, query, nil, ConnectOperation) // one-sized buffer, can't block
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close() // we own the conn
|
||||||
|
conn.SetDeadline(deadline)
|
||||||
|
localAddr := conn.LocalAddr().String()
|
||||||
|
if _, err = conn.Write(rawQuery); err != nil {
|
||||||
|
outch <- t.newDNSOverUDPResponse(
|
||||||
|
localAddr, err, query, nil, WriteOperation) // one-sized buffer, can't block
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
resp, err := t.recv(query, conn)
|
||||||
|
select {
|
||||||
|
case outch <- t.newDNSOverUDPResponse(localAddr, err, query, resp, ReadOperation):
|
||||||
|
default:
|
||||||
|
return // no-one is reading the channel -- so long...
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// We are going to consider all errors as fatal for now until we
|
||||||
|
// hear of specific errs that it might have sense to ignore.
|
||||||
|
//
|
||||||
|
// Note that erroring out here includes the expiration of the conn's
|
||||||
|
// I/O deadline, which we set above precisely because we want
|
||||||
|
// the total runtime of this goroutine to be bounded.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recv receives a single response for the given query using the given conn.
|
||||||
|
func (t *DNSOverUDPTransport) recv(query model.DNSQuery, conn net.Conn) (model.DNSResponse, error) {
|
||||||
|
const maxmessagesize = 1 << 17
|
||||||
|
rawResponse := make([]byte, maxmessagesize)
|
||||||
|
count, err := conn.Read(rawResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rawResponse = rawResponse[:count]
|
||||||
|
return t.Decoder.DecodeResponse(rawResponse, query)
|
||||||
|
}
|
||||||
|
|
|
@ -9,8 +9,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/apex/log"
|
"github.com/apex/log"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/miekg/dns"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverUDPTransport(t *testing.T) {
|
func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
|
@ -70,6 +73,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
MockClose: func() error {
|
MockClose: func() error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
MockLocalAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "127.0.0.1:1345"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
|
@ -106,6 +119,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
MockClose: func() error {
|
MockClose: func() error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
MockLocalAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "127.0.0.1:1345"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
|
@ -141,12 +164,22 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
MockClose: func() error {
|
MockClose: func() error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
MockLocalAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "127.0.0.1:1345"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
)
|
)
|
||||||
expectedErr := errors.New("mocked error")
|
expectedErr := errors.New("mocked error")
|
||||||
txp.decoder = &mocks.DNSDecoder{
|
txp.Decoder = &mocks.DNSDecoder{
|
||||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, expectedErr
|
return nil, expectedErr
|
||||||
},
|
},
|
||||||
|
@ -165,7 +198,7 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("read success", func(t *testing.T) {
|
t.Run("decode success", func(t *testing.T) {
|
||||||
const expected = 17
|
const expected = 17
|
||||||
input := bytes.NewReader(make([]byte, expected))
|
input := bytes.NewReader(make([]byte, expected))
|
||||||
txp := NewDNSOverUDPTransport(
|
txp := NewDNSOverUDPTransport(
|
||||||
|
@ -182,12 +215,22 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
MockClose: func() error {
|
MockClose: func() error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
MockLocalAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "127.0.0.1:1345"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
)
|
)
|
||||||
expectedResp := &mocks.DNSResponse{}
|
expectedResp := &mocks.DNSResponse{}
|
||||||
txp.decoder = &mocks.DNSDecoder{
|
txp.Decoder = &mocks.DNSDecoder{
|
||||||
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return expectedResp, nil
|
return expectedResp, nil
|
||||||
},
|
},
|
||||||
|
@ -205,6 +248,190 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
t.Fatal("unexpected resp")
|
t.Fatal("unexpected resp")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("using a real server", func(t *testing.T) {
|
||||||
|
srvr := &filtering.DNSServer{
|
||||||
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
|
return filtering.DNSActionCache
|
||||||
|
},
|
||||||
|
Cache: map[string][]string{
|
||||||
|
"dns.google.": {"8.8.8.8"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
listener, err := srvr.Start("127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
|
encoder := &DNSEncoderMiekg{}
|
||||||
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AsyncRoundTrip", func(t *testing.T) {
|
||||||
|
t.Run("calling Next with cancelled context", func(t *testing.T) {
|
||||||
|
blocker := make(chan interface{})
|
||||||
|
const expected = 17
|
||||||
|
input := bytes.NewReader(make([]byte, expected))
|
||||||
|
txp := NewDNSOverUDPTransport(
|
||||||
|
&mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
<-blocker // block here until Next returns because of expired context
|
||||||
|
return &mocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
MockRead: input.Read,
|
||||||
|
MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockLocalAddr: func() net.Addr {
|
||||||
|
return &mocks.Addr{
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "udp"
|
||||||
|
},
|
||||||
|
MockString: func() string {
|
||||||
|
return "127.0.0.1:1345"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}, "9.9.9.9:53",
|
||||||
|
)
|
||||||
|
expectedResp := &mocks.DNSResponse{}
|
||||||
|
txp.Decoder = &mocks.DNSDecoder{
|
||||||
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return expectedResp, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := txp.AsyncRoundTrip(query, 1)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // immediately cancel
|
||||||
|
resp, err := out.Next(ctx)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("unexpected resp")
|
||||||
|
}
|
||||||
|
close(blocker) // unblock the background goroutine
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("typical usage to obtain late responses", func(t *testing.T) {
|
||||||
|
srvr := &filtering.DNSServer{
|
||||||
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
|
return filtering.DNSActionLocalHostPlusCache
|
||||||
|
},
|
||||||
|
Cache: map[string][]string{
|
||||||
|
"dns.google.": {"8.8.8.8"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
listener, err := srvr.Start("127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
|
encoder := &DNSEncoderMiekg{}
|
||||||
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
|
rch := txp.AsyncRoundTrip(query, 1)
|
||||||
|
resp, err := rch.Next(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
// One would not normally busy loop but it's fine to do that in the context
|
||||||
|
// of this test because we know we're going to receive a second reply. In
|
||||||
|
// a real network experiment here we'll do other activities, e.g., contacting
|
||||||
|
// the test helper or fetching a webpage.
|
||||||
|
var additional []model.DNSResponse
|
||||||
|
for {
|
||||||
|
additional = rch.TryNextResponses()
|
||||||
|
if len(additional) > 0 {
|
||||||
|
if len(additional) != 1 {
|
||||||
|
t.Fatal("expected exactly one additional response")
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
addrs, err = additional[0].DecodeLookupHost()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("correct behavior when read times out", func(t *testing.T) {
|
||||||
|
srvr := &filtering.DNSServer{
|
||||||
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
|
return filtering.DNSActionTimeout
|
||||||
|
},
|
||||||
|
}
|
||||||
|
listener, err := srvr.Start("127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
dialer := NewDialerWithoutResolver(model.DiscardLogger)
|
||||||
|
txp := NewDNSOverUDPTransport(dialer, listener.LocalAddr().String())
|
||||||
|
txp.IOTimeout = 30 * time.Millisecond // short timeout to have a fast test
|
||||||
|
encoder := &DNSEncoderMiekg{}
|
||||||
|
query := encoder.Encode("dns.google.", dns.TypeA, false)
|
||||||
|
rch := txp.AsyncRoundTrip(query, 1)
|
||||||
|
result := <-rch.Response
|
||||||
|
if result.Err == nil || result.Err.Error() != "generic_timeout_error" {
|
||||||
|
t.Fatal("unexpected error", result.Err)
|
||||||
|
}
|
||||||
|
if result.Operation != ReadOperation {
|
||||||
|
t.Fatal("unexpected failed operation", result.Operation)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CloseIdleConnections", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
dialer := &mocks.Dialer{
|
||||||
|
MockCloseIdleConnections: func() {
|
||||||
|
called = true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
const address = "9.9.9.9:53"
|
||||||
|
txp := NewDNSOverUDPTransport(dialer, address)
|
||||||
|
txp.CloseIdleConnections()
|
||||||
|
if !called {
|
||||||
|
t.Fatal("not called")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("other functions okay", func(t *testing.T) {
|
t.Run("other functions okay", func(t *testing.T) {
|
||||||
|
|
|
@ -1,22 +1,19 @@
|
||||||
package filtering
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSAction is a DNS filtering action that this proxy should take.
|
// DNSAction is a DNS filtering action that a DNSServer should take.
|
||||||
type DNSAction string
|
type DNSAction string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// DNSActionPass passes the traffic to the upstream server.
|
|
||||||
DNSActionPass = DNSAction("pass")
|
|
||||||
|
|
||||||
// DNSActionNXDOMAIN replies with NXDOMAIN.
|
// DNSActionNXDOMAIN replies with NXDOMAIN.
|
||||||
DNSActionNXDOMAIN = DNSAction("nxdomain")
|
DNSActionNXDOMAIN = DNSAction("nxdomain")
|
||||||
|
|
||||||
|
@ -32,15 +29,19 @@ const (
|
||||||
// DNSActionTimeout never replies to the query.
|
// DNSActionTimeout never replies to the query.
|
||||||
DNSActionTimeout = DNSAction("timeout")
|
DNSActionTimeout = DNSAction("timeout")
|
||||||
|
|
||||||
// DNSActionCache causes the proxy to check the cache. If there
|
// DNSActionCache causes the server to check the cache. If there
|
||||||
// are entries, they are returned. Otherwise, NXDOMAIN is returned.
|
// are entries, they are returned. Otherwise, NXDOMAIN is returned.
|
||||||
DNSActionCache = DNSAction("cache")
|
DNSActionCache = DNSAction("cache")
|
||||||
|
|
||||||
|
// DNSActionLocalHostPlusCache combines the LocalHost and
|
||||||
|
// Cache actions returning first a localhost response followed
|
||||||
|
// by a subsequent response obtained using the cache.
|
||||||
|
DNSActionLocalHostPlusCache = DNSAction("localhost+cache")
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSProxy is a DNS proxy that routes traffic to an upstream
|
// DNSServer is a DNS server implementing filtering policies.
|
||||||
// resolver and may implement filtering policies.
|
type DNSServer struct {
|
||||||
type DNSProxy struct {
|
// Cache is the OPTIONAL DNS cache. Note that the keys of the map
|
||||||
// Cache is the DNS cache. Note that the keys of the map
|
|
||||||
// must be FQDNs (i.e., including the final `.`).
|
// must be FQDNs (i.e., including the final `.`).
|
||||||
Cache map[string][]string
|
Cache map[string][]string
|
||||||
|
|
||||||
|
@ -48,26 +49,27 @@ type DNSProxy struct {
|
||||||
// receive a query for the given domain.
|
// receive a query for the given domain.
|
||||||
OnQuery func(domain string) DNSAction
|
OnQuery func(domain string) DNSAction
|
||||||
|
|
||||||
// UpstreamEndpoint is the OPTIONAL upstream transport endpoint.
|
// onTimeout is the OPTIONAL channel where we emit a true
|
||||||
UpstreamEndpoint string
|
// value each time there's a timeout. If you set this value
|
||||||
|
// to a non-nil channel, then you MUST drain the channel
|
||||||
// mockableReply allows to mock DNSProxy.reply in tests.
|
// for each expected timeout. Otherwise, the code will just
|
||||||
mockableReply func(query *dns.Msg) (*dns.Msg, error)
|
// ignore this field and nothing will be emitted.
|
||||||
|
onTimeout chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// DNSListener is the interface returned by DNSProxy.Start
|
// DNSListener is the interface returned by DNSServer.Start.
|
||||||
type DNSListener interface {
|
type DNSListener interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
LocalAddr() net.Addr
|
LocalAddr() net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the proxy.
|
// Start starts this server.
|
||||||
func (p *DNSProxy) Start(address string) (DNSListener, error) {
|
func (p *DNSServer) Start(address string) (DNSListener, error) {
|
||||||
pconn, _, err := p.start(address)
|
pconn, _, err := p.start(address)
|
||||||
return pconn, err
|
return pconn, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) start(address string) (DNSListener, <-chan interface{}, error) {
|
func (p *DNSServer) start(address string) (DNSListener, <-chan interface{}, error) {
|
||||||
pconn, err := net.ListenPacket("udp", address)
|
pconn, err := net.ListenPacket("udp", address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
@ -77,15 +79,15 @@ func (p *DNSProxy) start(address string) (DNSListener, <-chan interface{}, error
|
||||||
return pconn, done, nil
|
return pconn, done, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) mainloop(pconn net.PacketConn, done chan<- interface{}) {
|
func (p *DNSServer) mainloop(pconn net.PacketConn, done chan<- interface{}) {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
for p.oneloop(pconn) {
|
for p.oneloop(pconn) {
|
||||||
// nothing
|
// nothing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) oneloop(pconn net.PacketConn) bool {
|
func (p *DNSServer) oneloop(pconn net.PacketConn) bool {
|
||||||
buffer := make([]byte, 1<<12)
|
buffer := make([]byte, 1<<17)
|
||||||
count, addr, err := pconn.ReadFrom(buffer)
|
count, addr, err := pconn.ReadFrom(buffer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return !strings.HasSuffix(err.Error(), "use of closed network connection")
|
return !strings.HasSuffix(err.Error(), "use of closed network connection")
|
||||||
|
@ -95,73 +97,70 @@ func (p *DNSProxy) oneloop(pconn net.PacketConn) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) serveAsync(pconn net.PacketConn, addr net.Addr, buffer []byte) {
|
func (p *DNSServer) emit(pconn net.PacketConn, addr net.Addr, reply ...*dns.Msg) (success int) {
|
||||||
|
for _, entry := range reply {
|
||||||
|
replyBytes, err := entry.Pack()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pconn.WriteTo(replyBytes, addr)
|
||||||
|
success++ // we use this value in tests
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DNSServer) serveAsync(pconn net.PacketConn, addr net.Addr, buffer []byte) {
|
||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
if err := query.Unpack(buffer); err != nil {
|
if err := query.Unpack(buffer); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reply, err := p.reply(query)
|
if len(query.Question) < 1 {
|
||||||
if err != nil {
|
return // just discard the query
|
||||||
return
|
|
||||||
}
|
|
||||||
replyBytes, err := reply.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
pconn.WriteTo(replyBytes, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DNSProxy) reply(query *dns.Msg) (*dns.Msg, error) {
|
|
||||||
if p.mockableReply != nil {
|
|
||||||
return p.mockableReply(query)
|
|
||||||
}
|
|
||||||
return p.replyDefault(query)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DNSProxy) replyDefault(query *dns.Msg) (*dns.Msg, error) {
|
|
||||||
if len(query.Question) != 1 {
|
|
||||||
return nil, errors.New("unhandled message")
|
|
||||||
}
|
}
|
||||||
name := query.Question[0].Name
|
name := query.Question[0].Name
|
||||||
switch p.OnQuery(name) {
|
switch p.OnQuery(name) {
|
||||||
case DNSActionPass:
|
|
||||||
return p.proxy(query)
|
|
||||||
case DNSActionNXDOMAIN:
|
case DNSActionNXDOMAIN:
|
||||||
return p.nxdomain(query), nil
|
p.emit(pconn, addr, p.nxdomain(query))
|
||||||
case DNSActionLocalHost:
|
case DNSActionLocalHost:
|
||||||
return p.localHost(query), nil
|
p.emit(pconn, addr, p.localHost(query))
|
||||||
case DNSActionNoAnswer:
|
case DNSActionNoAnswer:
|
||||||
return p.empty(query), nil
|
p.emit(pconn, addr, p.empty(query))
|
||||||
case DNSActionTimeout:
|
case DNSActionTimeout:
|
||||||
return nil, errors.New("let's ignore this query")
|
if p.onTimeout != nil {
|
||||||
|
p.onTimeout <- true
|
||||||
|
}
|
||||||
case DNSActionCache:
|
case DNSActionCache:
|
||||||
return p.cache(name, query), nil
|
p.emit(pconn, addr, p.cache(name, query))
|
||||||
|
case DNSActionLocalHostPlusCache:
|
||||||
|
p.emit(pconn, addr, p.localHost(query))
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
p.emit(pconn, addr, p.cache(name, query))
|
||||||
default:
|
default:
|
||||||
return p.refused(query), nil
|
p.emit(pconn, addr, p.refused(query))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) refused(query *dns.Msg) *dns.Msg {
|
func (p *DNSServer) refused(query *dns.Msg) *dns.Msg {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetRcode(query, dns.RcodeRefused)
|
m.SetRcode(query, dns.RcodeRefused)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) nxdomain(query *dns.Msg) *dns.Msg {
|
func (p *DNSServer) nxdomain(query *dns.Msg) *dns.Msg {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetRcode(query, dns.RcodeNameError)
|
m.SetRcode(query, dns.RcodeNameError)
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) localHost(query *dns.Msg) *dns.Msg {
|
func (p *DNSServer) localHost(query *dns.Msg) *dns.Msg {
|
||||||
return p.compose(query, net.IPv6loopback, net.IPv4(127, 0, 0, 1))
|
return p.compose(query, net.IPv6loopback, net.IPv4(127, 0, 0, 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) empty(query *dns.Msg) *dns.Msg {
|
func (p *DNSServer) empty(query *dns.Msg) *dns.Msg {
|
||||||
return p.compose(query)
|
return p.compose(query)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg {
|
func (p *DNSServer) compose(query *dns.Msg, ips ...net.IP) *dns.Msg {
|
||||||
runtimex.PanicIfTrue(len(query.Question) != 1, "expecting a single question")
|
runtimex.PanicIfTrue(len(query.Question) != 1, "expecting a single question")
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
reply := new(dns.Msg)
|
reply := new(dns.Msg)
|
||||||
|
@ -195,27 +194,7 @@ func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg {
|
||||||
return reply
|
return reply
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func (p *DNSServer) cache(name string, query *dns.Msg) *dns.Msg {
|
||||||
// errDNSExpectedSingleQuestion means we expected to see a single question
|
|
||||||
errDNSExpectedSingleQuestion = errors.New("filtering: expected single DNS question")
|
|
||||||
|
|
||||||
// errDNSExpectedQueryNotResponse means we expected to see a query.
|
|
||||||
errDNSExpectedQueryNotResponse = errors.New("filtering: expected query not response")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) {
|
|
||||||
if query.Response {
|
|
||||||
return nil, errDNSExpectedQueryNotResponse
|
|
||||||
}
|
|
||||||
if len(query.Question) != 1 {
|
|
||||||
return nil, errDNSExpectedSingleQuestion
|
|
||||||
}
|
|
||||||
clnt := &dns.Client{}
|
|
||||||
resp, _, err := clnt.Exchange(query, p.upstreamEndpoint())
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
|
|
||||||
addrs := p.Cache[name]
|
addrs := p.Cache[name]
|
||||||
var ipAddrs []net.IP
|
var ipAddrs []net.IP
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
|
@ -228,10 +207,3 @@ func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
|
||||||
}
|
}
|
||||||
return p.compose(query, ipAddrs...)
|
return p.compose(query, ipAddrs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) upstreamEndpoint() string {
|
|
||||||
if p.UpstreamEndpoint != "" {
|
|
||||||
return p.UpstreamEndpoint
|
|
||||||
}
|
|
||||||
return "8.8.8.8:53"
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,122 +1,98 @@
|
||||||
package filtering
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/apex/log"
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
"github.com/ooni/probe-cli/v3/internal/randx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSProxy(t *testing.T) {
|
func TestDNSServer(t *testing.T) {
|
||||||
if testing.Short() {
|
newServerWithCache := func(action DNSAction, cache map[string][]string) (
|
||||||
t.Skip("skip test in short mode")
|
*DNSServer, DNSListener, <-chan interface{}, error) {
|
||||||
}
|
p := &DNSServer{
|
||||||
newProxyWithCache := func(action DNSAction, cache map[string][]string) (DNSListener, <-chan interface{}, error) {
|
|
||||||
p := &DNSProxy{
|
|
||||||
Cache: cache,
|
Cache: cache,
|
||||||
OnQuery: func(domain string) DNSAction {
|
OnQuery: func(domain string) DNSAction {
|
||||||
return action
|
return action
|
||||||
},
|
},
|
||||||
|
onTimeout: make(chan bool),
|
||||||
}
|
}
|
||||||
return p.start("127.0.0.1:0")
|
listener, done, err := p.start("127.0.0.1:0")
|
||||||
|
return p, listener, done, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newProxy := func(action DNSAction) (DNSListener, <-chan interface{}, error) {
|
newServer := func(action DNSAction) (*DNSServer, DNSListener, <-chan interface{}, error) {
|
||||||
return newProxyWithCache(action, nil)
|
return newServerWithCache(action, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
newresolver := func(listener DNSListener) model.Resolver {
|
newQuery := func(qtype uint16) *dns.Msg {
|
||||||
dlr := netxlite.NewDialerWithoutResolver(log.Log)
|
question := dns.Question{
|
||||||
r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String())
|
Name: dns.Fqdn("dns.google"),
|
||||||
return r
|
Qtype: qtype,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
query := new(dns.Msg)
|
||||||
|
query.Id = dns.Id()
|
||||||
|
query.RecursionDesired = true
|
||||||
|
query.Question = make([]dns.Question, 1)
|
||||||
|
query.Question[0] = question
|
||||||
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("DNSActionPass", func(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
listener, done, err := newProxy(DNSActionPass)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
r := newresolver(listener)
|
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if addrs == nil {
|
|
||||||
t.Fatal("unexpected empty addrs")
|
|
||||||
}
|
|
||||||
var found bool
|
|
||||||
for _, addr := range addrs {
|
|
||||||
found = found || addr == "8.8.8.8"
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Fatal("did not find 8.8.8.8")
|
|
||||||
}
|
|
||||||
listener.Close()
|
|
||||||
<-done // wait for background goroutine to exit
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("DNSActionNXDOMAIN", func(t *testing.T) {
|
t.Run("DNSActionNXDOMAIN", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
_, listener, done, err := newServer(DNSActionNXDOMAIN)
|
||||||
listener, done, err := newProxy(DNSActionNXDOMAIN)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
r := newresolver(listener)
|
reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String())
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
if err != nil {
|
||||||
if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError {
|
t.Fatal(err)
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
}
|
||||||
if addrs != nil {
|
if reply.Rcode != dns.RcodeNameError {
|
||||||
t.Fatal("expected empty addrs")
|
t.Fatal("unexpected rcode")
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
<-done // wait for background goroutine to exit
|
<-done // wait for background goroutine to exit
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DNSActionRefused", func(t *testing.T) {
|
t.Run("DNSActionRefused", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
_, listener, done, err := newServer(DNSActionRefused)
|
||||||
listener, done, err := newProxy(DNSActionRefused)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
r := newresolver(listener)
|
reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String())
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
if err != nil {
|
||||||
if err == nil || err.Error() != netxlite.FailureDNSRefusedError {
|
t.Fatal(err)
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
}
|
||||||
if addrs != nil {
|
if reply.Rcode != dns.RcodeRefused {
|
||||||
t.Fatal("expected empty addrs")
|
t.Fatal("unexpected rcode")
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
<-done // wait for background goroutine to exit
|
<-done // wait for background goroutine to exit
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DNSActionLocalHost", func(t *testing.T) {
|
t.Run("DNSActionLocalHost", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
_, listener, done, err := newServer(DNSActionLocalHost)
|
||||||
listener, done, err := newProxy(DNSActionLocalHost)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
r := newresolver(listener)
|
reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String())
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if addrs == nil {
|
if reply.Rcode != dns.RcodeSuccess {
|
||||||
t.Fatal("expected non-empty addrs")
|
t.Fatal("unexpected rcode")
|
||||||
}
|
}
|
||||||
var found bool
|
var found bool
|
||||||
for _, addr := range addrs {
|
for _, ans := range reply.Answer {
|
||||||
found = found || addr == "127.0.0.1"
|
switch v := ans.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
found = found || v.A.String() == "127.0.0.1"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
t.Fatal("did not find 127.0.0.1")
|
t.Fatal("did not find 127.0.0.1")
|
||||||
|
@ -126,94 +102,154 @@ func TestDNSProxy(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DNSActionEmpty", func(t *testing.T) {
|
t.Run("DNSActionEmpty", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
_, listener, done, err := newServer(DNSActionNoAnswer)
|
||||||
listener, done, err := newProxy(DNSActionNoAnswer)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
r := newresolver(listener)
|
reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String())
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
if err != nil {
|
||||||
if err == nil || err.Error() != netxlite.FailureDNSNoAnswer {
|
t.Fatal(err)
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
}
|
||||||
if addrs != nil {
|
if reply.Rcode != dns.RcodeSuccess {
|
||||||
t.Fatal("expected empty addrs")
|
t.Fatal("unexpected rcode")
|
||||||
|
}
|
||||||
|
if len(reply.Answer) != 0 {
|
||||||
|
t.Fatal("expected no answers")
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
<-done // wait for background goroutine to exit
|
<-done // wait for background goroutine to exit
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DNSActionTimeout", func(t *testing.T) {
|
t.Run("DNSActionTimeout", func(t *testing.T) {
|
||||||
// Implementation note: if you see this test running for more
|
srvr, listener, done, err := newServer(DNSActionTimeout)
|
||||||
// than one second, then it means we're not checking the context
|
|
||||||
// immediately. We should be improving there but we need to be
|
|
||||||
// careful because lots of legacy code uses SerialResolver.
|
|
||||||
const timeout = time.Second
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer cancel()
|
|
||||||
listener, done, err := newProxy(DNSActionTimeout)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
r := newresolver(listener)
|
c := &dns.Client{}
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
conn, err := c.Dial(listener.LocalAddr().String())
|
||||||
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-srvr.onTimeout
|
||||||
|
conn.Close() // close as soon as the server times out, so this test is fast
|
||||||
|
}()
|
||||||
|
reply, _, err := c.ExchangeWithConn(newQuery(dns.TypeA), conn)
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if addrs != nil {
|
if reply != nil {
|
||||||
t.Fatal("expected empty addrs")
|
t.Fatal("expected nil reply here")
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
<-done // wait for background goroutine to exit
|
<-done // wait for background goroutine to exit
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DNSActionCache without entries", func(t *testing.T) {
|
t.Run("DNSActionCache without entries", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
_, listener, done, err := newServerWithCache(DNSActionCache, nil)
|
||||||
listener, done, err := newProxyWithCache(DNSActionCache, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
r := newresolver(listener)
|
reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String())
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
if err != nil {
|
||||||
if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError {
|
t.Fatal(err)
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
}
|
||||||
if addrs != nil {
|
if reply.Rcode != dns.RcodeNameError {
|
||||||
t.Fatal("expected empty addrs")
|
t.Fatal("unexpected rcode")
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
<-done // wait for background goroutine to exit
|
<-done // wait for background goroutine to exit
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DNSActionCache with entries", func(t *testing.T) {
|
t.Run("DNSActionCache with IPv4 entry", func(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
cache := map[string][]string{
|
cache := map[string][]string{
|
||||||
"dns.google.": {"8.8.8.8", "8.8.4.4"},
|
"dns.google.": {"8.8.8.8"},
|
||||||
}
|
}
|
||||||
listener, done, err := newProxyWithCache(DNSActionCache, cache)
|
_, listener, done, err := newServerWithCache(DNSActionCache, cache)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
r := newresolver(listener)
|
reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String())
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(addrs) != 2 {
|
if reply.Rcode != dns.RcodeSuccess {
|
||||||
t.Fatal("expected two entries")
|
t.Fatal("unexpected rcode")
|
||||||
}
|
}
|
||||||
if addrs[0] != "8.8.8.8" {
|
var found bool
|
||||||
t.Fatal("invalid first entry")
|
for _, ans := range reply.Answer {
|
||||||
|
switch v := ans.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
found = found || v.A.String() == "8.8.8.8"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if addrs[1] != "8.8.4.4" {
|
if !found {
|
||||||
t.Fatal("invalid second entry")
|
t.Fatal("did not find 8.8.8.8")
|
||||||
|
}
|
||||||
|
listener.Close()
|
||||||
|
<-done // wait for background goroutine to exit
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DNSActionCache with IPv6 entry", func(t *testing.T) {
|
||||||
|
cache := map[string][]string{
|
||||||
|
"dns.google.": {"2001:4860:4860::8888"},
|
||||||
|
}
|
||||||
|
_, listener, done, err := newServerWithCache(DNSActionCache, cache)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
reply, err := dns.Exchange(newQuery(dns.TypeAAAA), listener.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reply.Rcode != dns.RcodeSuccess {
|
||||||
|
t.Fatal("unexpected rcode")
|
||||||
|
}
|
||||||
|
var found bool
|
||||||
|
for _, ans := range reply.Answer {
|
||||||
|
switch v := ans.(type) {
|
||||||
|
case *dns.AAAA:
|
||||||
|
found = found || v.AAAA.String() == "2001:4860:4860::8888"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("did not find 2001:4860:4860::8888")
|
||||||
|
}
|
||||||
|
listener.Close()
|
||||||
|
<-done // wait for background goroutine to exit
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DNSActionLocalHostPlusCache", func(t *testing.T) {
|
||||||
|
cache := map[string][]string{
|
||||||
|
"dns.google.": {"2001:4860:4860::8888"},
|
||||||
|
}
|
||||||
|
_, listener, done, err := newServerWithCache(DNSActionLocalHostPlusCache, cache)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
reply, err := dns.Exchange(newQuery(dns.TypeAAAA), listener.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if reply.Rcode != dns.RcodeSuccess {
|
||||||
|
t.Fatal("unexpected rcode")
|
||||||
|
}
|
||||||
|
var found bool
|
||||||
|
for _, ans := range reply.Answer {
|
||||||
|
switch v := ans.(type) {
|
||||||
|
case *dns.AAAA:
|
||||||
|
found = found || v.AAAA.String() == "::1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("did not find ::1")
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
<-done // wait for background goroutine to exit
|
<-done // wait for background goroutine to exit
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Start with invalid address", func(t *testing.T) {
|
t.Run("Start with invalid address", func(t *testing.T) {
|
||||||
p := &DNSProxy{}
|
p := &DNSServer{}
|
||||||
listener, err := p.Start("127.0.0.1")
|
listener, err := p.Start("127.0.0.1")
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
|
@ -226,7 +262,7 @@ func TestDNSProxy(t *testing.T) {
|
||||||
t.Run("oneloop", func(t *testing.T) {
|
t.Run("oneloop", func(t *testing.T) {
|
||||||
t.Run("ReadFrom failure after which we should continue", func(t *testing.T) {
|
t.Run("ReadFrom failure after which we should continue", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
p := &DNSProxy{}
|
p := &DNSServer{}
|
||||||
conn := &mocks.UDPLikeConn{
|
conn := &mocks.UDPLikeConn{
|
||||||
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
||||||
return 0, nil, expected
|
return 0, nil, expected
|
||||||
|
@ -240,7 +276,7 @@ func TestDNSProxy(t *testing.T) {
|
||||||
|
|
||||||
t.Run("ReadFrom the connection is closed", func(t *testing.T) {
|
t.Run("ReadFrom the connection is closed", func(t *testing.T) {
|
||||||
expected := errors.New("use of closed network connection")
|
expected := errors.New("use of closed network connection")
|
||||||
p := &DNSProxy{}
|
p := &DNSServer{}
|
||||||
conn := &mocks.UDPLikeConn{
|
conn := &mocks.UDPLikeConn{
|
||||||
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
||||||
return 0, nil, expected
|
return 0, nil, expected
|
||||||
|
@ -253,7 +289,7 @@ func TestDNSProxy(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Unpack fails", func(t *testing.T) {
|
t.Run("Unpack fails", func(t *testing.T) {
|
||||||
p := &DNSProxy{}
|
p := &DNSServer{}
|
||||||
conn := &mocks.UDPLikeConn{
|
conn := &mocks.UDPLikeConn{
|
||||||
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
||||||
if len(p) < 4 {
|
if len(p) < 4 {
|
||||||
|
@ -269,46 +305,16 @@ func TestDNSProxy(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reply fails", func(t *testing.T) {
|
t.Run("no questions", func(t *testing.T) {
|
||||||
p := &DNSProxy{}
|
query := newQuery(dns.TypeA)
|
||||||
|
query.Question = nil // remove the question
|
||||||
|
data, err := query.Pack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
p := &DNSServer{}
|
||||||
conn := &mocks.UDPLikeConn{
|
conn := &mocks.UDPLikeConn{
|
||||||
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
||||||
query := &dns.Msg{}
|
|
||||||
query.Question = append(query.Question, dns.Question{})
|
|
||||||
query.Question = append(query.Question, dns.Question{})
|
|
||||||
data, err := query.Pack()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
if len(p) < len(data) {
|
|
||||||
panic("buffer too small")
|
|
||||||
}
|
|
||||||
copy(p, data)
|
|
||||||
return len(data), &net.UDPAddr{}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
okay := p.oneloop(conn)
|
|
||||||
if !okay {
|
|
||||||
t.Fatal("we should be okay after this error")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("pack fails", func(t *testing.T) {
|
|
||||||
p := &DNSProxy{
|
|
||||||
mockableReply: func(query *dns.Msg) (*dns.Msg, error) {
|
|
||||||
reply := &dns.Msg{}
|
|
||||||
reply.MsgHdr.Rcode = -1 // causes pack to fail
|
|
||||||
return reply, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
conn := &mocks.UDPLikeConn{
|
|
||||||
MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
query := &dns.Msg{}
|
|
||||||
query.Question = append(query.Question, dns.Question{})
|
|
||||||
data, err := query.Pack()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
if len(p) < len(data) {
|
if len(p) < len(data) {
|
||||||
panic("buffer too small")
|
panic("buffer too small")
|
||||||
}
|
}
|
||||||
|
@ -323,45 +329,13 @@ func TestDNSProxy(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("proxy", func(t *testing.T) {
|
t.Run("pack fails", func(t *testing.T) {
|
||||||
t.Run("with response", func(t *testing.T) {
|
query := newQuery(dns.TypeA)
|
||||||
p := &DNSProxy{}
|
query.Question[0].Name = randx.Letters(1024) // should be too large
|
||||||
query := &dns.Msg{}
|
p := &DNSServer{}
|
||||||
query.Response = true
|
count := p.emit(&mocks.UDPLikeConn{}, &mocks.Addr{}, query)
|
||||||
reply, err := p.proxy(query)
|
if count != 0 {
|
||||||
if !errors.Is(err, errDNSExpectedQueryNotResponse) {
|
t.Fatal("expected to see zero here")
|
||||||
t.Fatal("unexpected err", err)
|
}
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("with no questions", func(t *testing.T) {
|
|
||||||
p := &DNSProxy{}
|
|
||||||
query := &dns.Msg{}
|
|
||||||
reply, err := p.proxy(query)
|
|
||||||
if !errors.Is(err, errDNSExpectedSingleQuestion) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("round trip fails", func(t *testing.T) {
|
|
||||||
p := &DNSProxy{
|
|
||||||
UpstreamEndpoint: "antani",
|
|
||||||
}
|
|
||||||
query := &dns.Msg{}
|
|
||||||
query.Question = append(query.Question, dns.Question{})
|
|
||||||
reply, err := p.proxy(query)
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,3 @@
|
||||||
// Package filtering allows to implement self-censorship.
|
// Package filtering allows to implement self-censorship. We expose proxies
|
||||||
//
|
// implementing filtering policies for DNS, TLS, and HTTP.
|
||||||
// The top-level struct is the TProxy. It implements model's
|
|
||||||
// UnderlyingNetworkLibrary interface. Therefore, you can use TProxy to
|
|
||||||
// implement filtering and blocking of TCP, TLS, QUIC, DNS, HTTP.
|
|
||||||
//
|
|
||||||
// We also expose proxies that implement filtering policies for
|
|
||||||
// DNS, TLS, and HTTP.
|
|
||||||
//
|
|
||||||
// The typical usage of this package's functionality is to
|
|
||||||
// load a censoring policy into TProxyConfig and then to create
|
|
||||||
// and start a TProxy instance using NewTProxy.
|
|
||||||
package filtering
|
package filtering
|
||||||
|
|
|
@ -9,6 +9,9 @@ import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO(bassosimone): remove HTTPActionPass since we want integration tests
|
||||||
|
// to only run locally to make them much more predictable.
|
||||||
|
|
||||||
// HTTPAction is an HTTP filtering action that this proxy should take.
|
// HTTPAction is an HTTP filtering action that this proxy should take.
|
||||||
type HTTPAction string
|
type HTTPAction string
|
||||||
|
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
package filtering
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO(bassosimone): remove TLSActionPass since we want integration tests
|
||||||
|
// to only run locally to make them much more predictable.
|
||||||
|
|
||||||
// TLSAction is a TLS filtering action that this proxy should take.
|
// TLSAction is a TLS filtering action that this proxy should take.
|
||||||
type TLSAction string
|
type TLSAction string
|
||||||
|
|
||||||
|
@ -237,5 +238,11 @@ func (p *TLSProxy) connectingToMyself(conn net.Conn) bool {
|
||||||
// forward will forward the traffic.
|
// forward will forward the traffic.
|
||||||
func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) {
|
func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
netxlite.CopyContext(context.Background(), left, right)
|
// We cannot use netxlite.CopyContext here because we want netxlite to
|
||||||
|
// use filtering inside its test suite, so this package cannot depend on
|
||||||
|
// netxlite. In general, we don't want to use io.Copy or io.ReadAll
|
||||||
|
// directly because they may cause the code to block as documented in
|
||||||
|
// internal/netxlite/iox.go. However, this package is only used for
|
||||||
|
// testing, so it's completely okay to make an exception here.
|
||||||
|
io.Copy(left, right)
|
||||||
}
|
}
|
||||||
|
|
|
@ -113,7 +113,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for nxdomain", func(t *testing.T) {
|
t.Run("for nxdomain", func(t *testing.T) {
|
||||||
proxy := &filtering.DNSProxy{
|
proxy := &filtering.DNSServer{
|
||||||
OnQuery: func(domain string) filtering.DNSAction {
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
return filtering.DNSActionNXDOMAIN
|
return filtering.DNSActionNXDOMAIN
|
||||||
},
|
},
|
||||||
|
@ -137,7 +137,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for refused", func(t *testing.T) {
|
t.Run("for refused", func(t *testing.T) {
|
||||||
proxy := &filtering.DNSProxy{
|
proxy := &filtering.DNSServer{
|
||||||
OnQuery: func(domain string) filtering.DNSAction {
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
return filtering.DNSActionRefused
|
return filtering.DNSActionRefused
|
||||||
},
|
},
|
||||||
|
@ -161,7 +161,7 @@ func TestMeasureWithUDPResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for timeout", func(t *testing.T) {
|
t.Run("for timeout", func(t *testing.T) {
|
||||||
proxy := &filtering.DNSProxy{
|
proxy := &filtering.DNSServer{
|
||||||
OnQuery: func(domain string) filtering.DNSAction {
|
OnQuery: func(domain string) filtering.DNSAction {
|
||||||
return filtering.DNSActionTimeout
|
return filtering.DNSActionTimeout
|
||||||
},
|
},
|
||||||
|
|
|
@ -7,6 +7,12 @@ for file in $(find . -type f -name \*.go); do
|
||||||
# implement safer wrappers for these functions.
|
# implement safer wrappers for these functions.
|
||||||
continue
|
continue
|
||||||
fi
|
fi
|
||||||
|
if [ "$file" = "./internal/netxlite/filtering/tls.go" ]; then
|
||||||
|
# We're allowed to use ReadAll and Copy in this file to
|
||||||
|
# avoid depending on netxlite, so we can use filtering
|
||||||
|
# inside of netxlite's own test suite.
|
||||||
|
continue
|
||||||
|
fi
|
||||||
if grep -q 'io\.ReadAll' $file; then
|
if grep -q 'io\.ReadAll' $file; then
|
||||||
echo "in $file: do not use io.ReadAll, use netxlite.ReadAllContext" 1>&2
|
echo "in $file: do not use io.ReadAll, use netxlite.ReadAllContext" 1>&2
|
||||||
exitcode=1
|
exitcode=1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user