ooni-probe-cli/internal/netxlite/dnsovertcp_test.go

300 lines
8.0 KiB
Go
Raw Permalink Normal View History

package netxlite
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
"math"
"net"
"testing"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)
func TestDNSOverTCPTransport(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
t.Run("cannot encode query", func(t *testing.T) {
expected := errors.New("mocked error")
const address = "9.9.9.9:53"
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("query too large", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, math.MaxUint16+1), nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, errQueryTooLarge) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("dial failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
}
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("write failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("first read fails", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("second read fails", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := io.MultiReader(
bytes.NewReader([]byte{byte(0), byte(2)}),
&mocks.Reader{
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
},
)
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("not the error we expected")
}
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("decode failure", func(t *testing.T) {
const address = "9.9.9.9:53"
mocked := errors.New("mocked error")
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return nil, mocked
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if !errors.Is(err, mocked) {
t.Fatal("unexpected err", err)
}
if resp != nil {
t.Fatal("expected nil resp here")
}
})
t.Run("successful case", func(t *testing.T) {
const address = "9.9.9.9:53"
query := &mocks.DNSQuery{
MockBytes: func() ([]byte, error) {
return make([]byte, 128), nil
},
}
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
expectedResp := &mocks.DNSResponse{}
txp.decoder = &mocks.DNSDecoder{
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
return expectedResp, nil
},
}
resp, err := txp.RoundTrip(context.Background(), query)
if err != nil {
t.Fatal(err)
}
if resp != expectedResp {
t.Fatal("not the response we expected")
}
})
})
t.Run("other functions okay with TCP", func(t *testing.T) {
const address = "9.9.9.9:53"
txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
if txp.RequiresPadding() != false {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "tcp" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
txp.CloseIdleConnections()
})
t.Run("other functions okay with TLS", func(t *testing.T) {
const address = "9.9.9.9:853"
txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding")
}
if txp.Network() != "dot" {
t.Fatal("invalid Network")
}
if txp.Address() != address {
t.Fatal("invalid Address")
}
txp.CloseIdleConnections()
})
}