ooni-probe-cli/internal/netxlite/dnsoverudp_test.go
Simone Basso cc24f28b9d
feat(netxlite): support extracting the CNAME (#875)
* feat(netxlite): support extracting the CNAME

Closes https://github.com/ooni/probe/issues/2225

* fix(netxlite): attempt to increase coverage and improve tests

1. dnsovergetaddrinfo: specify the behavior of a DNSResponse returned
by this file to make it in line with normal responses and write unit tests
to make sure we adhere to expectations;

2. dnsoverudp: make sure we wait for deferred responses also without a
custom context and post on a private channel and test that;

3. utls: recognize that we can actually write a test for NetConn; once
we use go1.19 by default, the only change needed will be removing a
cast that will no longer be necessary.
2022-08-23 13:04:00 +02:00

535 lines
15 KiB
Go

package netxlite
import (
"bytes"
"context"
"errors"
"net"
"sync"
"testing"
"time"
"github.com/apex/log"
"github.com/google/go-cmp/cmp"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
"github.com/ooni/probe-cli/v3/internal/testingx"
)
func TestDNSOverUDPTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("cannot encode query", func(t *testing.T) {
expected := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewUnwrappedDNSOverUDPTransport(nil, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("dial failure", func(t *testing.T) {
mocked := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewUnwrappedDNSOverUDPTransport(&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
}, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("Write failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewUnwrappedDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "udp"
},
MockString: func() string {
return "127.0.0.1:1345"
},
}
},
}, nil
},
}, "9.9.9.9:53",
)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("Read failure", func(t *testing.T) {
mocked := errors.New("mocked error")
txp := NewUnwrappedDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "udp"
},
MockString: func() string {
return "127.0.0.1:1345"
},
}
},
}, nil
},
}, "9.9.9.9:53",
)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected no response here")
}
})
t.Run("decode failure", func(t *testing.T) {
const expected = 17
input := bytes.NewReader(make([]byte, expected))
txp := NewUnwrappedDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "udp"
},
MockString: func() string {
return "127.0.0.1:1345"
},
}
},
}, nil
},
}, "9.9.9.9:53",
)
expectedErr := errors.New("mocked error")
txp.Decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, expectedErr
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expectedErr) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil resp")
}
})
t.Run("decode success", func(t *testing.T) {
const expected = 17
input := bytes.NewReader(make([]byte, expected))
txp := NewUnwrappedDNSOverUDPTransport(
&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
MockLocalAddr: func() net.Addr {
return &mocks.Addr{
MockNetwork: func() string {
return "udp"
},
MockString: func() string {
return "127.0.0.1:1345"
},
}
},
}, nil
},
}, "9.9.9.9:53",
)
expectedResp := &mocks.DNSResponse{}
txp.Decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return expectedResp, nil
},
}
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal(err)
}
if resp != expectedResp {
t.Fatal("unexpected resp")
}
})
t.Run("using a real server", func(t *testing.T) {
srvr := &filtering.DNSServer{
OnQuery: func(domain string) filtering.DNSAction {
return filtering.DNSActionCache
},
Cache: map[string][]string{
"dns.google.": {"8.8.8.8"},
},
}
listener, err := srvr.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal(err)
}
addrs, err := resp.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" {
t.Fatal(diff)
}
})
})
t.Run("recording delayed DNS responses", func(t *testing.T) {
t.Run("without any context-injected traces", func(t *testing.T) {
srvr := &filtering.DNSServer{
OnQuery: func(domain string) filtering.DNSAction {
return filtering.DNSActionLocalHostPlusCache
},
Cache: map[string][]string{
"dns.google.": {"8.8.8.8"},
},
}
listener, err := srvr.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
expectedAddress := listener.LocalAddr().String()
txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
txp.lateResponses = make(chan any, 1) // with buffer to avoid deadlocks
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
rch, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal(err)
}
if _, err := rch.DecodeLookupHost(); err != nil {
t.Fatal(err)
}
// Now wait for the delayed response to arrive. We don't care much
// about observing it here, rather we want to know it happened.
<-txp.lateResponses
})
t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
var (
delayedDNSResponseCalled bool
goodQueryType bool
goodTransportNetwork bool
goodTransportAddress bool
goodLookupAddrs bool
goodError bool
)
srvr := &filtering.DNSServer{
OnQuery: func(domain string) filtering.DNSAction {
return filtering.DNSActionLocalHostPlusCache
},
Cache: map[string][]string{
"dns.google.": {"8.8.8.8"},
},
}
listener, err := srvr.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
expectedAddress := listener.LocalAddr().String()
txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
expectedAddrs := []string{"8.8.8.8"}
respChannel := make(chan *model.DNSResponse, 8)
mu := new(sync.Mutex)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
finished time.Time) error {
mu.Lock()
delayedDNSResponseCalled = true
goodQueryType = (query.Type() == dns.TypeA)
goodTransportNetwork = (txp.Network() == "udp")
goodTransportAddress = (txp.Address() == expectedAddress)
goodLookupAddrs = (cmp.Diff(expectedAddrs, addrs) == "")
goodError = (err == nil)
mu.Unlock()
select {
case respChannel <- &response:
return nil
default:
return errors.New("full buffer")
}
},
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
finished time.Time) {
// do nothing
},
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
return conn
},
}
ctx := ContextWithTrace(context.Background(), tx)
rch, err := txp.RoundTrip(ctx, query)
<-respChannel // wait for the delayed response
if err != nil {
t.Fatal(err)
}
addrs, err := rch.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
defer mu.Unlock()
mu.Lock()
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
t.Fatal(diff)
}
if !delayedDNSResponseCalled {
t.Fatal("delayedDNSResponse not called")
}
if !goodQueryType {
t.Fatal("unexpected query type")
}
if !goodTransportNetwork {
t.Fatal("unexpected DNS transport network")
}
if !goodTransportAddress {
t.Fatal("unexpected DNS Transport address")
}
if !goodLookupAddrs {
t.Fatal("unexpected delayed DNSLookup address")
}
if !goodError {
t.Fatal("unexpected error encountered")
}
})
t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
var (
delayedDNSResponseCalled bool
goodQueryType bool
goodTransportNetwork bool
goodTransportAddress bool
goodLookupAddrs bool
goodError bool
)
srvr := &filtering.DNSServer{
OnQuery: func(domain string) filtering.DNSAction {
return filtering.DNSActionLocalHostPlusCache
},
Cache: map[string][]string{
// Note: the cache here is nonexistent so we should
// get a "no such host" error from the server.
},
}
listener, err := srvr.Start("127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
dialer := NewDialerWithoutResolver(model.DiscardLogger)
expectedAddress := listener.LocalAddr().String()
txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
encoder := &DNSEncoderMiekg{}
query := encoder.Encode("dns.google.", dns.TypeA, false)
zeroTime := time.Now()
deterministicTime := testingx.NewTimeDeterministic(zeroTime)
respChannel := make(chan *model.DNSResponse, 8)
mu := new(sync.Mutex)
tx := &mocks.Trace{
MockTimeNow: deterministicTime.Now,
MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
finished time.Time) error {
mu.Lock()
delayedDNSResponseCalled = true
goodQueryType = (query.Type() == dns.TypeA)
goodTransportNetwork = (txp.Network() == "udp")
goodTransportAddress = (txp.Address() == expectedAddress)
goodLookupAddrs = (len(addrs) == 0)
goodError = errors.Is(err, ErrOODNSNoSuchHost)
mu.Unlock()
respChannel <- &response
return errors.New("mocked") // return error to stop background routine to record responses
},
MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
finished time.Time) {
// do nothing
},
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
return conn
},
}
ctx := ContextWithTrace(context.Background(), tx)
rch, err := txp.RoundTrip(ctx, query)
<-respChannel // wait for the delayed response
if err != nil {
t.Fatal(err)
}
addrs, err := rch.DecodeLookupHost()
if err != nil {
t.Fatal(err)
}
defer mu.Unlock()
mu.Lock()
if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
t.Fatal(diff)
}
if !delayedDNSResponseCalled {
t.Fatal("delayedDNSResponse not called")
}
if !goodQueryType {
t.Fatal("unexpected query type")
}
if !goodTransportNetwork {
t.Fatal("unexpected DNS transport network")
}
if !goodTransportAddress {
t.Fatal("unexpected DNS Transport address")
}
if !goodLookupAddrs {
t.Fatal("unexpected delayed DNSLookup address")
}
if !goodError {
t.Fatal("unexpected error encountered")
}
})
})
t.Run("CloseIdleConnections", func(t *testing.T) {
var called bool
dialer := &mocks.Dialer{
MockCloseIdleConnections: func() {
called = true
},
}
const address = "9.9.9.9:53"
txp := NewUnwrappedDNSOverUDPTransport(dialer, address)
txp.CloseIdleConnections()
if !called {
t.Fatal("not called")
}
})
t.Run("other functions okay", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewUnwrappedDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "udp" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
})
}