refactor(netx): move dns transports in netxlite/dnsx (#503)

While there, modernize the way in which we run tests to avoid
depending on the fake files scattered around the tree and to
use some well defined mock structures instead.

Part of https://github.com/ooni/probe/issues/1591
This commit is contained in:
Simone Basso 2021-09-09 21:24:27 +02:00 committed by GitHub
parent b3c36b5c7f
commit 3cb782f0a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 546 additions and 91 deletions

View File

@ -0,0 +1,28 @@
package resolver
import "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx"
// Variables that other packages expect to find here but have been
// moved into the internal/netxlite/dnsx package.
var (
NewSerialResolver = dnsx.NewSerialResolver
NewDNSOverUDP = dnsx.NewDNSOverUDP
NewDNSOverTCP = dnsx.NewDNSOverTCP
NewDNSOverTLS = dnsx.NewDNSOverTLS
NewDNSOverHTTPS = dnsx.NewDNSOverHTTPS
NewDNSOverHTTPSWithHostOverride = dnsx.NewDNSOverHTTPSWithHostOverride
)
// Types that other packages expect to find here but have been
// moved into the internal/netxlite/dnsx package.
type (
DNSOverHTTPS = dnsx.DNSOverHTTPS
DNSOverTCP = dnsx.DNSOverTCP
DNSOverUDP = dnsx.DNSOverUDP
MiekgEncoder = dnsx.MiekgEncoder
MiekgDecoder = dnsx.MiekgDecoder
RoundTripper = dnsx.RoundTripper
SerialResolver = dnsx.SerialResolver
Dialer = dnsx.Dialer
DialContextFunc = dnsx.DialContextFunc
)

View File

@ -1,4 +1,4 @@
package resolver package dnsx
import ( import (
"errors" "errors"

View File

@ -1,6 +1,7 @@
package resolver package dnsx
import ( import (
"net"
"strings" "strings"
"testing" "testing"
@ -20,7 +21,7 @@ func TestDecoderUnpackError(t *testing.T) {
func TestDecoderNXDOMAIN(t *testing.T) { func TestDecoderNXDOMAIN(t *testing.T) {
d := &MiekgDecoder{} d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeNameError)) data, err := d.Decode(dns.TypeA, genReplyError(t, dns.RcodeNameError))
if err == nil || !strings.HasSuffix(err.Error(), "no such host") { if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
@ -31,7 +32,7 @@ func TestDecoderNXDOMAIN(t *testing.T) {
func TestDecoderOtherError(t *testing.T) { func TestDecoderOtherError(t *testing.T) {
d := &MiekgDecoder{} d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeRefused)) data, err := d.Decode(dns.TypeA, genReplyError(t, dns.RcodeRefused))
if err == nil || !strings.HasSuffix(err.Error(), "query failed") { if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
@ -42,7 +43,7 @@ func TestDecoderOtherError(t *testing.T) {
func TestDecoderNoAddress(t *testing.T) { func TestDecoderNoAddress(t *testing.T) {
d := &MiekgDecoder{} d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplySuccess(t, dns.TypeA)) data, err := d.Decode(dns.TypeA, genReplySuccess(t, dns.TypeA))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") { if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
@ -54,7 +55,7 @@ func TestDecoderNoAddress(t *testing.T) {
func TestDecoderDecodeA(t *testing.T) { func TestDecoderDecodeA(t *testing.T) {
d := &MiekgDecoder{} d := &MiekgDecoder{}
data, err := d.Decode( data, err := d.Decode(
dns.TypeA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8")) dns.TypeA, genReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -72,7 +73,7 @@ func TestDecoderDecodeA(t *testing.T) {
func TestDecoderDecodeAAAA(t *testing.T) { func TestDecoderDecodeAAAA(t *testing.T) {
d := &MiekgDecoder{} d := &MiekgDecoder{}
data, err := d.Decode( data, err := d.Decode(
dns.TypeAAAA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1")) dns.TypeAAAA, genReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -90,7 +91,7 @@ func TestDecoderDecodeAAAA(t *testing.T) {
func TestDecoderUnexpectedAReply(t *testing.T) { func TestDecoderUnexpectedAReply(t *testing.T) {
d := &MiekgDecoder{} d := &MiekgDecoder{}
data, err := d.Decode( data, err := d.Decode(
dns.TypeA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1")) dns.TypeA, genReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") { if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
@ -102,7 +103,7 @@ func TestDecoderUnexpectedAReply(t *testing.T) {
func TestDecoderUnexpectedAAAAReply(t *testing.T) { func TestDecoderUnexpectedAAAAReply(t *testing.T) {
d := &MiekgDecoder{} d := &MiekgDecoder{}
data, err := d.Decode( data, err := d.Decode(
dns.TypeAAAA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4.")) dns.TypeAAAA, genReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") { if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
@ -110,3 +111,71 @@ func TestDecoderUnexpectedAAAAReply(t *testing.T) {
t.Fatal("expected nil data here") t.Fatal("expected nil data here")
} }
} }
func genReplyError(t *testing.T, code int) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetRcode(query, code)
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}
func genReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte {
question := dns.Question{
Name: dns.Fqdn("x.org"),
Qtype: qtype,
Qclass: dns.ClassINET,
}
query := new(dns.Msg)
query.Id = dns.Id()
query.RecursionDesired = true
query.Question = make([]dns.Question, 1)
query.Question[0] = question
reply := new(dns.Msg)
reply.Compress = true
reply.MsgHdr.RecursionAvailable = true
reply.SetReply(query)
for _, ip := range ips {
switch qtype {
case dns.TypeA:
reply.Answer = append(reply.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
A: net.ParseIP(ip),
})
case dns.TypeAAAA:
reply.Answer = append(reply.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: dns.Fqdn("x.org"),
Rrtype: qtype,
Class: dns.ClassINET,
Ttl: 0,
},
AAAA: net.ParseIP(ip),
})
}
}
data, err := reply.Pack()
if err != nil {
t.Fatal(err)
}
return data
}

View File

@ -1,4 +1,4 @@
package resolver package dnsx
import ( import (
"bytes" "bytes"

View File

@ -1,4 +1,4 @@
package resolver package dnsx
import ( import (
"bytes" "bytes"

View File

@ -1,4 +1,4 @@
package resolver package dnsx
import ( import (
"context" "context"

View File

@ -1,10 +1,16 @@
package resolver package dnsx
import ( import (
"bytes"
"context" "context"
"crypto/tls"
"errors" "errors"
"io"
"net" "net"
"testing" "testing"
"time"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) { func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
@ -22,7 +28,11 @@ func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) {
func TestDNSOverTCPTransportDialFailure(t *testing.T) { func TestDNSOverTCPTransportDialFailure(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
fakedialer := FakeDialer{Err: mocked} fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
@ -36,9 +46,18 @@ func TestDNSOverTCPTransportDialFailure(t *testing.T) {
func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) { func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{ fakedialer := &mocks.Dialer{
SetDeadlineError: mocked, MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
}} return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
@ -52,9 +71,21 @@ func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) {
func TestDNSOverTCPTransportWriteFailure(t *testing.T) { func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{ fakedialer := &mocks.Dialer{
WriteError: mocked, MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
}} return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
@ -68,9 +99,24 @@ func TestDNSOverTCPTransportWriteFailure(t *testing.T) {
func TestDNSOverTCPTransportReadFailure(t *testing.T) { func TestDNSOverTCPTransportReadFailure(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{ fakedialer := &mocks.Dialer{
ReadError: mocked, MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
}} return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
@ -84,10 +130,30 @@ func TestDNSOverTCPTransportReadFailure(t *testing.T) {
func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) { func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
fakedialer := FakeDialer{Conn: &FakeConn{ input := io.MultiReader(
ReadError: mocked, bytes.NewReader([]byte{byte(0), byte(2)}),
ReadData: []byte{byte(0), byte(2)}, &mocks.Reader{
}} MockRead: func(b []byte) (int, error) {
return 0, mocked
},
},
)
fakedialer := &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
@ -100,11 +166,23 @@ func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) {
func TestDNSOverTCPTransportAllGood(t *testing.T) { func TestDNSOverTCPTransportAllGood(t *testing.T) {
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
mocked := errors.New("mocked error") input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
fakedialer := FakeDialer{Conn: &FakeConn{ fakedialer := &mocks.Dialer{
ReadError: mocked, MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
ReadData: []byte{byte(0), byte(1), byte(1)}, return &mocks.Conn{
}} MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}
txp := NewDNSOverTCP(fakedialer.DialContext, address) txp := NewDNSOverTCP(fakedialer.DialContext, address)
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
if err != nil { if err != nil {
@ -131,7 +209,7 @@ func TestDNSOverTCPTransportOK(t *testing.T) {
func TestDNSOverTLSTransportOK(t *testing.T) { func TestDNSOverTLSTransportOK(t *testing.T) {
const address = "9.9.9.9:853" const address = "9.9.9.9:853"
txp := NewDNSOverTLS(DialTLSContext, address) txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address)
if txp.RequiresPadding() != true { if txp.RequiresPadding() != true {
t.Fatal("invalid RequiresPadding") t.Fatal("invalid RequiresPadding")
} }

View File

@ -1,4 +1,4 @@
package resolver package dnsx
import ( import (
"context" "context"

View File

@ -1,16 +1,24 @@
package resolver package dnsx
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"net" "net"
"testing" "testing"
"time"
"github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
) )
func TestDNSOverUDPDialFailure(t *testing.T) { func TestDNSOverUDPDialFailure(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
const address = "9.9.9.9:53" const address = "9.9.9.9:53"
txp := NewDNSOverUDP(FakeDialer{Err: mocked}, address) txp := NewDNSOverUDP(&mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return nil, mocked
},
}, address)
data, err := txp.RoundTrip(context.Background(), nil) data, err := txp.RoundTrip(context.Background(), nil)
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -23,9 +31,16 @@ func TestDNSOverUDPDialFailure(t *testing.T) {
func TestDNSOverUDPSetDeadlineError(t *testing.T) { func TestDNSOverUDPSetDeadlineError(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := NewDNSOverUDP( txp := NewDNSOverUDP(
FakeDialer{ &mocks.Dialer{
Conn: &FakeConn{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
SetDeadlineError: mocked, return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return mocked
},
MockClose: func() error {
return nil
},
}, nil
}, },
}, "9.9.9.9:53", }, "9.9.9.9:53",
) )
@ -41,9 +56,19 @@ func TestDNSOverUDPSetDeadlineError(t *testing.T) {
func TestDNSOverUDPWriteFailure(t *testing.T) { func TestDNSOverUDPWriteFailure(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := NewDNSOverUDP( txp := NewDNSOverUDP(
FakeDialer{ &mocks.Dialer{
Conn: &FakeConn{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
WriteError: mocked, return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
}, },
}, "9.9.9.9:53", }, "9.9.9.9:53",
) )
@ -59,9 +84,22 @@ func TestDNSOverUDPWriteFailure(t *testing.T) {
func TestDNSOverUDPReadFailure(t *testing.T) { func TestDNSOverUDPReadFailure(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := NewDNSOverUDP( txp := NewDNSOverUDP(
FakeDialer{ &mocks.Dialer{
Conn: &FakeConn{ MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
ReadError: mocked, return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: func(b []byte) (int, error) {
return 0, mocked
},
MockClose: func() error {
return nil
},
}, nil
}, },
}, "9.9.9.9:53", }, "9.9.9.9:53",
) )
@ -76,9 +114,23 @@ func TestDNSOverUDPReadFailure(t *testing.T) {
func TestDNSOverUDPReadSuccess(t *testing.T) { func TestDNSOverUDPReadSuccess(t *testing.T) {
const expected = 17 const expected = 17
input := bytes.NewReader(make([]byte, expected))
txp := NewDNSOverUDP( txp := NewDNSOverUDP(
FakeDialer{ &mocks.Dialer{
Conn: &FakeConn{ReadData: make([]byte, 17)}, MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return &mocks.Conn{
MockSetDeadline: func(t time.Time) error {
return nil
},
MockWrite: func(b []byte) (int, error) {
return len(b), nil
},
MockRead: input.Read,
MockClose: func() error {
return nil
},
}, nil
},
}, "9.9.9.9:53", }, "9.9.9.9:53",
) )
data, err := txp.RoundTrip(context.Background(), nil) data, err := txp.RoundTrip(context.Background(), nil)

View File

@ -1,4 +1,4 @@
package resolver package dnsx
import "github.com/miekg/dns" import "github.com/miekg/dns"

View File

@ -1,4 +1,4 @@
package resolver package dnsx
import ( import (
"strings" "strings"

View File

@ -0,0 +1,11 @@
package mocks
// Decoder allows mocking dnsx.Decoder.
type Decoder struct {
MockDecode func(qtype uint16, reply []byte) ([]string, error)
}
// Decode calls MockDecode.
func (e *Decoder) Decode(qtype uint16, reply []byte) ([]string, error) {
return e.MockDecode(qtype, reply)
}

View File

@ -0,0 +1,26 @@
package mocks
import (
"errors"
"testing"
"github.com/miekg/dns"
)
func TestDecoder(t *testing.T) {
t.Run("Decode", func(t *testing.T) {
expected := errors.New("mocked error")
e := &Decoder{
MockDecode: func(qtype uint16, reply []byte) ([]string, error) {
return nil, expected
},
}
out, err := e.Decode(dns.TypeA, make([]byte, 17))
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
}

View File

@ -0,0 +1,2 @@
// Package mocks contains mocks for dnsx.
package mocks

View File

@ -0,0 +1,11 @@
package mocks
// Encoder allows mocking dnsx.Encoder.
type Encoder struct {
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, error)
}
// Encode calls MockEncode.
func (e *Encoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
return e.MockEncode(domain, qtype, padding)
}

View File

@ -0,0 +1,26 @@
package mocks
import (
"errors"
"testing"
"github.com/miekg/dns"
)
func TestEncoder(t *testing.T) {
t.Run("Encode", func(t *testing.T) {
expected := errors.New("mocked error")
e := &Encoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
return nil, expected
},
}
out, err := e.Encode("dns.google", dns.TypeA, true)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if out != nil {
t.Fatal("unexpected out")
}
})
}

View File

@ -0,0 +1,41 @@
package mocks
import "context"
// RoundTripper allows mocking dnsx.RoundTripper.
type RoundTripper struct {
MockRoundTrip func(ctx context.Context, query []byte) (reply []byte, err error)
MockRequiresPadding func() bool
MockNetwork func() string
MockAddress func() string
MockCloseIdleConnections func()
}
// RoundTrip calls MockRoundTrip.
func (txp *RoundTripper) RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) {
return txp.MockRoundTrip(ctx, query)
}
// RequiresPadding calls MockRequiresPadding.
func (txp *RoundTripper) RequiresPadding() bool {
return txp.MockRequiresPadding()
}
// Network calls MockNetwork.
func (txp *RoundTripper) Network() string {
return txp.MockNetwork()
}
// Address calls MockAddress.
func (txp *RoundTripper) Address() string {
return txp.MockAddress()
}
// CloseIdleConnections calls MockCloseIdleConnections.
func (txp *RoundTripper) CloseIdleConnections() {
txp.MockCloseIdleConnections()
}

View File

@ -0,0 +1,73 @@
package mocks
import (
"context"
"errors"
"testing"
"github.com/ooni/probe-cli/v3/internal/atomicx"
)
func TestRoundTripper(t *testing.T) {
t.Run("RoundTrip", func(t *testing.T) {
expected := errors.New("mocked error")
txp := &RoundTripper{
MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) {
return nil, expected
},
}
resp, err := txp.RoundTrip(context.Background(), make([]byte, 16))
if !errors.Is(err, expected) {
t.Fatal("not the error we expected", err)
}
if resp != nil {
t.Fatal("expected nil response here")
}
})
t.Run("RequiresPadding", func(t *testing.T) {
txp := &RoundTripper{
MockRequiresPadding: func() bool {
return true
},
}
if txp.RequiresPadding() != true {
t.Fatal("unexpected result")
}
})
t.Run("Network", func(t *testing.T) {
txp := &RoundTripper{
MockNetwork: func() string {
return "antani"
},
}
if txp.Network() != "antani" {
t.Fatal("unexpected result")
}
})
t.Run("Address", func(t *testing.T) {
txp := &RoundTripper{
MockAddress: func() string {
return "mascetti"
},
}
if txp.Address() != "mascetti" {
t.Fatal("unexpected result")
}
})
t.Run("CloseIdleConnections", func(t *testing.T) {
called := &atomicx.Int64{}
txp := &RoundTripper{
MockCloseIdleConnections: func() {
called.Add(1)
},
}
txp.CloseIdleConnections()
if called.Load() != 1 {
t.Fatal("not called")
}
})
}

View File

@ -0,0 +1,21 @@
package dnsx
import "context"
// RoundTripper represents an abstract DNS transport.
type RoundTripper interface {
// RoundTrip sends a DNS query and receives the reply.
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
// RequiresPadding return true for DoH and DoT according to RFC8467
RequiresPadding() bool
// Network is the network of the round tripper (e.g. "dot")
Network() string
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
Address() string
// CloseIdleConnections closes idle connections.
CloseIdleConnections()
}

View File

@ -1,4 +1,4 @@
package resolver package dnsx
import ( import (
"context" "context"
@ -9,24 +9,6 @@ import (
"github.com/ooni/probe-cli/v3/internal/atomicx" "github.com/ooni/probe-cli/v3/internal/atomicx"
) )
// RoundTripper represents an abstract DNS transport.
type RoundTripper interface {
// RoundTrip sends a DNS query and receives the reply.
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
// RequiresPadding return true for DoH and DoT according to RFC8467
RequiresPadding() bool
// Network is the network of the round tripper (e.g. "dot")
Network() string
// Address is the address of the round tripper (e.g. "1.1.1.1:853")
Address() string
// CloseIdleConnections closes idle connections.
CloseIdleConnections()
}
// SerialResolver is a resolver that first issues an A query and then // SerialResolver is a resolver that first issues an A query and then
// issues an AAAA query for the requested domain. // issues an AAAA query for the requested domain.
type SerialResolver struct { type SerialResolver struct {
@ -117,5 +99,3 @@ func (r *SerialResolver) roundTrip(
} }
return r.Decoder.Decode(qtype, replydata) return r.Decoder.Decode(qtype, replydata)
} }
var _ Resolver = &SerialResolver{}

View File

@ -1,20 +1,21 @@
package resolver_test package dnsx
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"net" "net"
"strings" "strings"
"syscall"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" "github.com/ooni/probe-cli/v3/internal/netxlite/dnsx/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite/errorsx"
) )
func TestOONIGettingTransport(t *testing.T) { func TestOONIGettingTransport(t *testing.T) {
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853") txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := resolver.NewSerialResolver(txp) r := NewSerialResolver(txp)
rtx := r.Transport() rtx := r.Transport()
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
t.Fatal("not the transport we expected") t.Fatal("not the transport we expected")
@ -29,8 +30,15 @@ func TestOONIGettingTransport(t *testing.T) {
func TestOONIEncodeError(t *testing.T) { func TestOONIEncodeError(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, "8.8.8.8:853") txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
r := resolver.SerialResolver{Encoder: resolver.FakeEncoder{Err: mocked}, Txp: txp} r := SerialResolver{
Encoder: &mocks.Encoder{
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) {
return nil, mocked
},
},
Txp: txp,
}
addrs, err := r.LookupHost(context.Background(), "www.gogle.com") addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -42,8 +50,15 @@ func TestOONIEncodeError(t *testing.T) {
func TestOONIRoundTripError(t *testing.T) { func TestOONIRoundTripError(t *testing.T) {
mocked := errors.New("mocked error") mocked := errors.New("mocked error")
txp := resolver.FakeTransport{Err: mocked} txp := &mocks.RoundTripper{
r := resolver.NewSerialResolver(txp) MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return nil, mocked
},
MockRequiresPadding: func() bool {
return true
},
}
r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com") addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, mocked) { if !errors.Is(err, mocked) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -54,8 +69,15 @@ func TestOONIRoundTripError(t *testing.T) {
} }
func TestOONIWithEmptyReply(t *testing.T) { func TestOONIWithEmptyReply(t *testing.T) {
txp := resolver.FakeTransport{Data: resolver.GenReplySuccess(t, dns.TypeA)} txp := &mocks.RoundTripper{
r := resolver.NewSerialResolver(txp) MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return genReplySuccess(t, dns.TypeA), nil
},
MockRequiresPadding: func() bool {
return true
},
}
r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com") addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") { if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
@ -66,10 +88,15 @@ func TestOONIWithEmptyReply(t *testing.T) {
} }
func TestOONIWithAReply(t *testing.T) { func TestOONIWithAReply(t *testing.T) {
txp := resolver.FakeTransport{ txp := &mocks.RoundTripper{
Data: resolver.GenReplySuccess(t, dns.TypeA, "8.8.8.8"), MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return genReplySuccess(t, dns.TypeA, "8.8.8.8"), nil
},
MockRequiresPadding: func() bool {
return true
},
} }
r := resolver.NewSerialResolver(txp) r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com") addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -80,10 +107,15 @@ func TestOONIWithAReply(t *testing.T) {
} }
func TestOONIWithAAAAReply(t *testing.T) { func TestOONIWithAAAAReply(t *testing.T) {
txp := resolver.FakeTransport{ txp := &mocks.RoundTripper{
Data: resolver.GenReplySuccess(t, dns.TypeAAAA, "::1"), MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return genReplySuccess(t, dns.TypeAAAA, "::1"), nil
},
MockRequiresPadding: func() bool {
return true
},
} }
r := resolver.NewSerialResolver(txp) r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com") addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -94,12 +126,17 @@ func TestOONIWithAAAAReply(t *testing.T) {
} }
func TestOONIWithTimeout(t *testing.T) { func TestOONIWithTimeout(t *testing.T) {
txp := resolver.FakeTransport{ txp := &mocks.RoundTripper{
Err: &net.OpError{Err: syscall.ETIMEDOUT, Op: "dial"}, MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
return nil, &net.OpError{Err: errorsx.ETIMEDOUT, Op: "dial"}
},
MockRequiresPadding: func() bool {
return true
},
} }
r := resolver.NewSerialResolver(txp) r := NewSerialResolver(txp)
addrs, err := r.LookupHost(context.Background(), "www.gogle.com") addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
if !errors.Is(err, syscall.ETIMEDOUT) { if !errors.Is(err, errorsx.ETIMEDOUT) {
t.Fatal("not the error we expected") t.Fatal("not the error we expected")
} }
if addrs != nil { if addrs != nil {