refactor: DNSTransport I/Os DNS messages (#760)
This diff refactors the DNSTransport model to receive in input a DNSQuery and return in output a DNSResponse. The design of DNSQuery and DNSResponse takes into account the use case of a transport using getaddrinfo, meaning that we don't need to serialize and deserialize messages when using getaddrinfo. The current codebase does not use a getaddrinfo transport, but I wrote one such a transport in the Websteps Winter 2021 prototype (https://github.com/bassosimone/websteps-illustrated/). The design conversation that lead to producing this diff is https://github.com/ooni/probe/issues/2099
This commit is contained in:
parent
7a0a156aec
commit
01a513a496
1
go.mod
1
go.mod
|
@ -17,7 +17,6 @@ require (
|
||||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
github.com/hexops/gotextdiff v1.0.3
|
|
||||||
github.com/iancoleman/strcase v0.2.0
|
github.com/iancoleman/strcase v0.2.0
|
||||||
github.com/lucas-clemente/quic-go v0.27.0
|
github.com/lucas-clemente/quic-go v0.27.0
|
||||||
github.com/mattn/go-colorable v0.1.12
|
github.com/mattn/go-colorable v0.1.12
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -378,8 +378,6 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO
|
||||||
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
|
github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ=
|
||||||
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
|
github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I=
|
||||||
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
|
github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc=
|
||||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
|
||||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
|
||||||
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
|
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog=
|
||||||
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
|
github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68=
|
||||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
|
|
|
@ -317,7 +317,7 @@ func NewDNSClientWithOverrides(config Config, URL, hostOverride, SNIOverride,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var txp model.DNSTransport = netxlite.NewDNSOverTLS(
|
var txp model.DNSTransport = netxlite.NewDNSOverTLSTransport(
|
||||||
tlsDialer.DialTLSContext, endpoint)
|
tlsDialer.DialTLSContext, endpoint)
|
||||||
if config.ResolveSaver != nil {
|
if config.ResolveSaver != nil {
|
||||||
txp = resolver.SaverDNSTransport{
|
txp = resolver.SaverDNSTransport{
|
||||||
|
|
|
@ -76,40 +76,6 @@ func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) {
|
||||||
return c.SetWriteDeadlineError
|
return c.SetWriteDeadlineError
|
||||||
}
|
}
|
||||||
|
|
||||||
type FakeTransport struct {
|
|
||||||
Data []byte
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ft FakeTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
|
||||||
return ft.Data, ft.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ft FakeTransport) RequiresPadding() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ft FakeTransport) Address() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ft FakeTransport) Network() string {
|
|
||||||
return "fake"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fk FakeTransport) CloseIdleConnections() {
|
|
||||||
// nothing to do
|
|
||||||
}
|
|
||||||
|
|
||||||
type FakeEncoder struct {
|
|
||||||
Data []byte
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fe FakeEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
|
||||||
return fe.Data, fe.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewFakeResolverThatFails() model.Resolver {
|
func NewFakeResolverThatFails() model.Resolver {
|
||||||
return NewFakeResolverWithExplicitError(netxlite.ErrOODNSNoSuchHost)
|
return NewFakeResolverWithExplicitError(netxlite.ErrOODNSNoSuchHost)
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,14 +99,14 @@ func TestNewResolverTCPDomain(t *testing.T) {
|
||||||
|
|
||||||
func TestNewResolverDoTAddress(t *testing.T) {
|
func TestNewResolverDoTAddress(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewSerialResolver(
|
||||||
netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "8.8.8.8:853"))
|
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "8.8.8.8:853"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewResolverDoTDomain(t *testing.T) {
|
func TestNewResolverDoTDomain(t *testing.T) {
|
||||||
reso := netxlite.NewSerialResolver(
|
reso := netxlite.NewSerialResolver(
|
||||||
netxlite.NewDNSOverTLS(new(tls.Dialer).DialContext, "dns.google.com:853"))
|
netxlite.NewDNSOverTLSTransport(new(tls.Dialer).DialContext, "dns.google.com:853"))
|
||||||
testresolverquick(t, reso)
|
testresolverquick(t, reso)
|
||||||
testresolverquickidna(t, reso)
|
testresolverquickidna(t, reso)
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,28 +46,41 @@ type SaverDNSTransport struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements RoundTripper.RoundTrip
|
// RoundTrip implements RoundTripper.RoundTrip
|
||||||
func (txp SaverDNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (txp SaverDNSTransport) RoundTrip(
|
||||||
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
txp.Saver.Write(trace.Event{
|
txp.Saver.Write(trace.Event{
|
||||||
Address: txp.Address(),
|
Address: txp.Address(),
|
||||||
DNSQuery: query,
|
DNSQuery: txp.maybeQueryBytes(query),
|
||||||
Name: "dns_round_trip_start",
|
Name: "dns_round_trip_start",
|
||||||
Proto: txp.Network(),
|
Proto: txp.Network(),
|
||||||
Time: start,
|
Time: start,
|
||||||
})
|
})
|
||||||
reply, err := txp.DNSTransport.RoundTrip(ctx, query)
|
response, err := txp.DNSTransport.RoundTrip(ctx, query)
|
||||||
stop := time.Now()
|
stop := time.Now()
|
||||||
txp.Saver.Write(trace.Event{
|
txp.Saver.Write(trace.Event{
|
||||||
Address: txp.Address(),
|
Address: txp.Address(),
|
||||||
DNSQuery: query,
|
DNSQuery: txp.maybeQueryBytes(query),
|
||||||
DNSReply: reply,
|
DNSReply: txp.maybeResponseBytes(response),
|
||||||
Duration: stop.Sub(start),
|
Duration: stop.Sub(start),
|
||||||
Err: err,
|
Err: err,
|
||||||
Name: "dns_round_trip_done",
|
Name: "dns_round_trip_done",
|
||||||
Proto: txp.Network(),
|
Proto: txp.Network(),
|
||||||
Time: stop,
|
Time: stop,
|
||||||
})
|
})
|
||||||
return reply, err
|
return response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (txp SaverDNSTransport) maybeQueryBytes(query model.DNSQuery) []byte {
|
||||||
|
data, _ := query.Bytes()
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (txp SaverDNSTransport) maybeResponseBytes(response model.DNSResponse) []byte {
|
||||||
|
if response == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return response.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.Resolver = SaverResolver{}
|
var _ model.Resolver = SaverResolver{}
|
||||||
|
|
|
@ -10,6 +10,8 @@ import (
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
"github.com/ooni/probe-cli/v3/internal/engine/netx/trace"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSaverResolverFailure(t *testing.T) {
|
func TestSaverResolverFailure(t *testing.T) {
|
||||||
|
@ -110,12 +112,25 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
||||||
expected := errors.New("no such host")
|
expected := errors.New("no such host")
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
txp := resolver.SaverDNSTransport{
|
txp := resolver.SaverDNSTransport{
|
||||||
DNSTransport: resolver.FakeTransport{
|
DNSTransport: &mocks.DNSTransport{
|
||||||
Err: expected,
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "fake"
|
||||||
|
},
|
||||||
|
MockAddress: func() string {
|
||||||
|
return ""
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
}
|
}
|
||||||
query := []byte("abc")
|
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return rawQuery, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
reply, err := txp.RoundTrip(context.Background(), query)
|
reply, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
|
@ -127,7 +142,7 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
||||||
if len(ev) != 2 {
|
if len(ev) != 2 {
|
||||||
t.Fatal("expected number of events")
|
t.Fatal("expected number of events")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(ev[0].DNSQuery, query) {
|
if !bytes.Equal(ev[0].DNSQuery, rawQuery) {
|
||||||
t.Fatal("unexpected DNSQuery")
|
t.Fatal("unexpected DNSQuery")
|
||||||
}
|
}
|
||||||
if ev[0].Name != "dns_round_trip_start" {
|
if ev[0].Name != "dns_round_trip_start" {
|
||||||
|
@ -136,7 +151,7 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
||||||
if !ev[0].Time.Before(time.Now()) {
|
if !ev[0].Time.Before(time.Now()) {
|
||||||
t.Fatal("the saved time is wrong")
|
t.Fatal("the saved time is wrong")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(ev[1].DNSQuery, query) {
|
if !bytes.Equal(ev[1].DNSQuery, rawQuery) {
|
||||||
t.Fatal("unexpected DNSQuery")
|
t.Fatal("unexpected DNSQuery")
|
||||||
}
|
}
|
||||||
if ev[1].DNSReply != nil {
|
if ev[1].DNSReply != nil {
|
||||||
|
@ -157,27 +172,45 @@ func TestSaverDNSTransportFailure(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSaverDNSTransportSuccess(t *testing.T) {
|
func TestSaverDNSTransportSuccess(t *testing.T) {
|
||||||
expected := []byte("def")
|
expected := []byte{0xef, 0xbe, 0xad, 0xde}
|
||||||
saver := &trace.Saver{}
|
saver := &trace.Saver{}
|
||||||
|
response := &mocks.DNSResponse{
|
||||||
|
MockBytes: func() []byte {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
txp := resolver.SaverDNSTransport{
|
txp := resolver.SaverDNSTransport{
|
||||||
DNSTransport: resolver.FakeTransport{
|
DNSTransport: &mocks.DNSTransport{
|
||||||
Data: expected,
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return response, nil
|
||||||
|
},
|
||||||
|
MockNetwork: func() string {
|
||||||
|
return "fake"
|
||||||
|
},
|
||||||
|
MockAddress: func() string {
|
||||||
|
return ""
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Saver: saver,
|
Saver: saver,
|
||||||
}
|
}
|
||||||
query := []byte("abc")
|
rawQuery := []byte{0xde, 0xad, 0xbe, 0xef}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return rawQuery, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
reply, err := txp.RoundTrip(context.Background(), query)
|
reply, err := txp.RoundTrip(context.Background(), query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("we expected nil error here")
|
t.Fatal("we expected nil error here")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(reply, expected) {
|
if !bytes.Equal(reply.Bytes(), expected) {
|
||||||
t.Fatal("expected another reply here")
|
t.Fatal("expected another reply here")
|
||||||
}
|
}
|
||||||
ev := saver.Read()
|
ev := saver.Read()
|
||||||
if len(ev) != 2 {
|
if len(ev) != 2 {
|
||||||
t.Fatal("expected number of events")
|
t.Fatal("expected number of events")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(ev[0].DNSQuery, query) {
|
if !bytes.Equal(ev[0].DNSQuery, rawQuery) {
|
||||||
t.Fatal("unexpected DNSQuery")
|
t.Fatal("unexpected DNSQuery")
|
||||||
}
|
}
|
||||||
if ev[0].Name != "dns_round_trip_start" {
|
if ev[0].Name != "dns_round_trip_start" {
|
||||||
|
@ -186,7 +219,7 @@ func TestSaverDNSTransportSuccess(t *testing.T) {
|
||||||
if !ev[0].Time.Before(time.Now()) {
|
if !ev[0].Time.Before(time.Now()) {
|
||||||
t.Fatal("the saved time is wrong")
|
t.Fatal("the saved time is wrong")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(ev[1].DNSQuery, query) {
|
if !bytes.Equal(ev[1].DNSQuery, rawQuery) {
|
||||||
t.Fatal("unexpected DNSQuery")
|
t.Fatal("unexpected DNSQuery")
|
||||||
}
|
}
|
||||||
if !bytes.Equal(ev[1].DNSReply, expected) {
|
if !bytes.Equal(ev[1].DNSReply, expected) {
|
||||||
|
|
|
@ -36,18 +36,31 @@ type DNSRoundTripEvent struct {
|
||||||
Reply []byte
|
Reply []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (txp *dnsxRoundTripperDB) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (txp *dnsxRoundTripperDB) RoundTrip(
|
||||||
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
started := time.Since(txp.begin).Seconds()
|
started := time.Since(txp.begin).Seconds()
|
||||||
reply, err := txp.DNSTransport.RoundTrip(ctx, query)
|
response, err := txp.DNSTransport.RoundTrip(ctx, query)
|
||||||
finished := time.Since(txp.begin).Seconds()
|
finished := time.Since(txp.begin).Seconds()
|
||||||
txp.db.InsertIntoDNSRoundTrip(&DNSRoundTripEvent{
|
txp.db.InsertIntoDNSRoundTrip(&DNSRoundTripEvent{
|
||||||
Network: txp.DNSTransport.Network(),
|
Network: txp.DNSTransport.Network(),
|
||||||
Address: txp.DNSTransport.Address(),
|
Address: txp.DNSTransport.Address(),
|
||||||
Query: query,
|
Query: txp.maybeQueryBytes(query),
|
||||||
Started: started,
|
Started: started,
|
||||||
Finished: finished,
|
Finished: finished,
|
||||||
Failure: NewFailure(err),
|
Failure: NewFailure(err),
|
||||||
Reply: reply,
|
Reply: txp.maybeResponseBytes(response),
|
||||||
})
|
})
|
||||||
return reply, err
|
return response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (txp *dnsxRoundTripperDB) maybeQueryBytes(query model.DNSQuery) []byte {
|
||||||
|
data, _ := query.Bytes()
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (txp *dnsxRoundTripperDB) maybeResponseBytes(response model.DNSResponse) []byte {
|
||||||
|
if response == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return response.Bytes()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,36 +1,20 @@
|
||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
//
|
||||||
"net"
|
// Mocks for model.DNSDecoder
|
||||||
|
//
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
import (
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSDecoder allows mocking dnsx.DNSDecoder.
|
// DNSDecoder allows mocking model.DNSDecoder.
|
||||||
type DNSDecoder struct {
|
type DNSDecoder struct {
|
||||||
MockDecodeLookupHost func(qtype uint16, reply []byte, queryID uint16) ([]string, error)
|
MockDecodeResponse func(data []byte, query model.DNSQuery) (model.DNSResponse, error)
|
||||||
MockDecodeHTTPS func(reply []byte, queryID uint16) (*model.HTTPSSvc, error)
|
|
||||||
MockDecodeNS func(reply []byte, queryID uint16) ([]*net.NS, error)
|
|
||||||
MockDecodeReply func(reply []byte) (*dns.Msg, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecodeLookupHost calls MockDecodeLookupHost.
|
var _ model.DNSDecoder = &DNSDecoder{}
|
||||||
func (e *DNSDecoder) DecodeLookupHost(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
|
|
||||||
return e.MockDecodeLookupHost(qtype, reply, queryID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecodeHTTPS calls MockDecodeHTTPS.
|
func (e *DNSDecoder) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
func (e *DNSDecoder) DecodeHTTPS(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
return e.MockDecodeResponse(data, query)
|
||||||
return e.MockDecodeHTTPS(reply, queryID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecodeNS calls MockDecodeNS.
|
|
||||||
func (e *DNSDecoder) DecodeNS(reply []byte, queryID uint16) ([]*net.NS, error) {
|
|
||||||
return e.MockDecodeNS(reply, queryID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecodeReply calls MockDecodeReply.
|
|
||||||
func (e *DNSDecoder) DecodeReply(reply []byte) (*dns.Msg, error) {
|
|
||||||
return e.MockDecodeReply(reply)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,70 +2,20 @@ package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSDecoder(t *testing.T) {
|
func TestDNSDecoder(t *testing.T) {
|
||||||
t.Run("DecodeLookupHost", func(t *testing.T) {
|
t.Run("DecodeResponse", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
e := &DNSDecoder{
|
e := &DNSDecoder{
|
||||||
MockDecodeLookupHost: func(qtype uint16, reply []byte, queryID uint16) ([]string, error) {
|
MockDecodeResponse: func(reply []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
out, err := e.DecodeLookupHost(dns.TypeA, make([]byte, 17), dns.Id())
|
out, err := e.DecodeResponse(make([]byte, 17), &DNSQuery{})
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if out != nil {
|
|
||||||
t.Fatal("unexpected out")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("DecodeHTTPS", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
e := &DNSDecoder{
|
|
||||||
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
}
|
|
||||||
out, err := e.DecodeHTTPS(make([]byte, 17), dns.Id())
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if out != nil {
|
|
||||||
t.Fatal("unexpected out")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("DecodeNS", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
e := &DNSDecoder{
|
|
||||||
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
}
|
|
||||||
out, err := e.DecodeNS(make([]byte, 17), dns.Id())
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if out != nil {
|
|
||||||
t.Fatal("unexpected out")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("DecodeReply", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
e := &DNSDecoder{
|
|
||||||
MockDecodeReply: func(reply []byte) (*dns.Msg, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
}
|
|
||||||
out, err := e.DecodeReply(make([]byte, 17))
|
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,19 @@
|
||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
// DNSEncoder allows mocking dnsx.DNSEncoder.
|
//
|
||||||
|
// Mocks for model.DNSEncoder.
|
||||||
|
//
|
||||||
|
|
||||||
|
import "github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
|
||||||
|
// DNSEncoder allows mocking model.DNSEncoder.
|
||||||
type DNSEncoder struct {
|
type DNSEncoder struct {
|
||||||
MockEncode func(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
|
MockEncode func(domain string, qtype uint16, padding bool) model.DNSQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ model.DNSEncoder = &DNSEncoder{}
|
||||||
|
|
||||||
// Encode calls MockEncode.
|
// Encode calls MockEncode.
|
||||||
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
func (e *DNSEncoder) Encode(domain string, qtype uint16, padding bool) model.DNSQuery {
|
||||||
return e.MockEncode(domain, qtype, padding)
|
return e.MockEncode(domain, qtype, padding)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,24 +5,46 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSEncoder(t *testing.T) {
|
func TestDNSEncoder(t *testing.T) {
|
||||||
t.Run("Encode", func(t *testing.T) {
|
t.Run("Encode", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
|
queryID := dns.Id()
|
||||||
e := &DNSEncoder{
|
e := &DNSEncoder{
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
MockEncode: func(domain string, qtype uint16, padding bool) model.DNSQuery {
|
||||||
return nil, 0, expected
|
return &DNSQuery{
|
||||||
|
MockDomain: func() string {
|
||||||
|
return dns.Fqdn(domain) // do what an implementation MUST do
|
||||||
|
},
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return qtype
|
||||||
|
},
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
out, queryID, err := e.Encode("dns.google", dns.TypeA, true)
|
query := e.Encode("dns.google", dns.TypeA, true)
|
||||||
|
if query.Domain() != "dns.google." {
|
||||||
|
t.Fatal("invalid domain")
|
||||||
|
}
|
||||||
|
if query.Type() != dns.TypeA {
|
||||||
|
t.Fatal("invalid type")
|
||||||
|
}
|
||||||
|
out, err := query.Bytes()
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if out != nil {
|
if out != nil {
|
||||||
t.Fatal("unexpected out")
|
t.Fatal("unexpected out")
|
||||||
}
|
}
|
||||||
if queryID != 0 {
|
if query.ID() != queryID {
|
||||||
t.Fatal("unexpected queryID")
|
t.Fatal("unexpected queryID")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
33
internal/model/mocks/dnsquery.go
Normal file
33
internal/model/mocks/dnsquery.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
//
|
||||||
|
// Mocks for model.DNSQuery.
|
||||||
|
//
|
||||||
|
|
||||||
|
import "github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
|
||||||
|
// DNSQuery allocks mocking model.DNSQuery.
|
||||||
|
type DNSQuery struct {
|
||||||
|
MockDomain func() string
|
||||||
|
MockType func() uint16
|
||||||
|
MockBytes func() ([]byte, error)
|
||||||
|
MockID func() uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *DNSQuery) Domain() string {
|
||||||
|
return q.MockDomain()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *DNSQuery) Type() uint16 {
|
||||||
|
return q.MockType()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *DNSQuery) Bytes() ([]byte, error) {
|
||||||
|
return q.MockBytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *DNSQuery) ID() uint16 {
|
||||||
|
return q.MockID()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.DNSQuery = &DNSQuery{}
|
62
internal/model/mocks/dnsquery_test.go
Normal file
62
internal/model/mocks/dnsquery_test.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDNSQuery(t *testing.T) {
|
||||||
|
t.Run("Domain", func(t *testing.T) {
|
||||||
|
expected := "dns.google."
|
||||||
|
q := &DNSQuery{
|
||||||
|
MockDomain: func() string {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if q.Domain() != expected {
|
||||||
|
t.Fatal("invalid domain")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Type", func(t *testing.T) {
|
||||||
|
expected := dns.TypeAAAA
|
||||||
|
q := &DNSQuery{
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if q.Type() != expected {
|
||||||
|
t.Fatal("invalid type")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Bytes", func(t *testing.T) {
|
||||||
|
expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef}
|
||||||
|
q := &DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return expected, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out, err := q.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(expected, out) {
|
||||||
|
t.Fatal("invalid bytes")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ID", func(t *testing.T) {
|
||||||
|
expected := dns.Id()
|
||||||
|
q := &DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if q.ID() != expected {
|
||||||
|
t.Fatal("invalid id")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
47
internal/model/mocks/dnsresponse.go
Normal file
47
internal/model/mocks/dnsresponse.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
//
|
||||||
|
// Mocks for model.DNSResponse
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DNSResponse allows mocking model.DNSResponse.
|
||||||
|
type DNSResponse struct {
|
||||||
|
MockQuery func() model.DNSQuery
|
||||||
|
MockBytes func() []byte
|
||||||
|
MockRcode func() int
|
||||||
|
MockDecodeHTTPS func() (*model.HTTPSSvc, error)
|
||||||
|
MockDecodeLookupHost func() ([]string, error)
|
||||||
|
MockDecodeNS func() ([]*net.NS, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ model.DNSResponse = &DNSResponse{}
|
||||||
|
|
||||||
|
func (r *DNSResponse) Query() model.DNSQuery {
|
||||||
|
return r.MockQuery()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DNSResponse) Bytes() []byte {
|
||||||
|
return r.MockBytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DNSResponse) Rcode() int {
|
||||||
|
return r.MockRcode()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DNSResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
|
||||||
|
return r.MockDecodeHTTPS()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DNSResponse) DecodeLookupHost() ([]string, error) {
|
||||||
|
return r.MockDecodeLookupHost()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *DNSResponse) DecodeNS() ([]*net.NS, error) {
|
||||||
|
return r.MockDecodeNS()
|
||||||
|
}
|
105
internal/model/mocks/dnsresponse_test.go
Normal file
105
internal/model/mocks/dnsresponse_test.go
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDNSResponse(t *testing.T) {
|
||||||
|
t.Run("Query", func(t *testing.T) {
|
||||||
|
qid := dns.Id()
|
||||||
|
query := &DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return qid
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp := &DNSResponse{
|
||||||
|
MockQuery: func() model.DNSQuery {
|
||||||
|
return query
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := resp.Query()
|
||||||
|
if out.ID() != query.ID() {
|
||||||
|
t.Fatal("invalid query")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Bytes", func(t *testing.T) {
|
||||||
|
expected := []byte{0xde, 0xea, 0xad, 0xbe, 0xef}
|
||||||
|
resp := &DNSResponse{
|
||||||
|
MockBytes: func() []byte {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := resp.Bytes()
|
||||||
|
if !bytes.Equal(expected, out) {
|
||||||
|
t.Fatal("invalid bytes")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Rcode", func(t *testing.T) {
|
||||||
|
expected := dns.RcodeBadAlg
|
||||||
|
resp := &DNSResponse{
|
||||||
|
MockRcode: func() int {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out := resp.Rcode()
|
||||||
|
if out != expected {
|
||||||
|
t.Fatal("invalid rcode")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DecodeLookupHost", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := &DNSResponse{
|
||||||
|
MockDecodeLookupHost: func() ([]string, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out, err := r.DecodeLookupHost()
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if out != nil {
|
||||||
|
t.Fatal("unexpected out")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DecodeHTTPS", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := &DNSResponse{
|
||||||
|
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out, err := r.DecodeHTTPS()
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if out != nil {
|
||||||
|
t.Fatal("unexpected out")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DecodeNS", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
r := &DNSResponse{
|
||||||
|
MockDecodeNS: func() ([]*net.NS, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out, err := r.DecodeNS()
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if out != nil {
|
||||||
|
t.Fatal("unexpected out")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,10 +1,14 @@
|
||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import "context"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
// DNSTransport allows mocking dnsx.DNSTransport.
|
// DNSTransport allows mocking dnsx.DNSTransport.
|
||||||
type DNSTransport struct {
|
type DNSTransport struct {
|
||||||
MockRoundTrip func(ctx context.Context, query []byte) ([]byte, error)
|
MockRoundTrip func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error)
|
||||||
|
|
||||||
MockRequiresPadding func() bool
|
MockRequiresPadding func() bool
|
||||||
|
|
||||||
|
@ -15,8 +19,10 @@ type DNSTransport struct {
|
||||||
MockCloseIdleConnections func()
|
MockCloseIdleConnections func()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ model.DNSTransport = &DNSTransport{}
|
||||||
|
|
||||||
// RoundTrip calls MockRoundTrip.
|
// RoundTrip calls MockRoundTrip.
|
||||||
func (txp *DNSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (txp *DNSTransport) RoundTrip(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return txp.MockRoundTrip(ctx, query)
|
return txp.MockRoundTrip(ctx, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,17 +6,18 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSTransport(t *testing.T) {
|
func TestDNSTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
txp := &DNSTransport{
|
txp := &DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) ([]byte, error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
resp, err := txp.RoundTrip(context.Background(), make([]byte, 16))
|
resp, err := txp.RoundTrip(context.Background(), &DNSQuery{})
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("not the error we expected", err)
|
t.Fatal("not the error we expected", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
package model
|
package model
|
||||||
|
|
||||||
|
//
|
||||||
|
// Network extensions
|
||||||
|
//
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -9,74 +13,81 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//
|
// DNSResponse is a parsed DNS response ready for further processing.
|
||||||
// Network extensions
|
type DNSResponse interface {
|
||||||
//
|
// Query is the query associated with this response.
|
||||||
|
Query() DNSQuery
|
||||||
|
|
||||||
// The DNSDecoder decodes DNS replies.
|
// Bytes returns the bytes from which we parsed the query.
|
||||||
|
Bytes() []byte
|
||||||
|
|
||||||
|
// Rcode returns the response's Rcode.
|
||||||
|
Rcode() int
|
||||||
|
|
||||||
|
// DecodeHTTPS returns information gathered from all the HTTPS
|
||||||
|
// records found inside of this response.
|
||||||
|
DecodeHTTPS() (*HTTPSSvc, error)
|
||||||
|
|
||||||
|
// DecodeLookupHost returns the addresses in the response matching
|
||||||
|
// the original query type (one of A and AAAA).
|
||||||
|
DecodeLookupHost() ([]string, error)
|
||||||
|
|
||||||
|
// DecodeNS returns all the NS entries in this response.
|
||||||
|
DecodeNS() ([]*net.NS, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The DNSDecoder decodes DNS responses.
|
||||||
type DNSDecoder interface {
|
type DNSDecoder interface {
|
||||||
// DecodeLookupHost decodes an A or AAAA reply.
|
// DecodeResponse decodes a DNS response message.
|
||||||
//
|
|
||||||
// Arguments:
|
|
||||||
//
|
|
||||||
// - qtype is the query type (e.g., dns.TypeAAAA)
|
|
||||||
//
|
|
||||||
// - data contains the reply bytes read from a DNSTransport
|
|
||||||
//
|
|
||||||
// - queryID is the original query ID
|
|
||||||
//
|
|
||||||
// Returns:
|
|
||||||
//
|
|
||||||
// - on success, a list of IP addrs inside the reply and a nil error
|
|
||||||
//
|
|
||||||
// - on failure, a nil list and an error.
|
|
||||||
//
|
|
||||||
// Note that this function will return an error if there is no
|
|
||||||
// IP address inside of the reply.
|
|
||||||
DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error)
|
|
||||||
|
|
||||||
// DecodeHTTPS is like DecodeLookupHost but decodes an HTTPS reply.
|
|
||||||
//
|
|
||||||
// The argument is the reply as read by the DNSTransport.
|
|
||||||
//
|
|
||||||
// On success, this function returns an HTTPSSvc structure and
|
|
||||||
// a nil error. On failure, the HTTPSSvc pointer is nil and
|
|
||||||
// the error points to the error that occurred.
|
|
||||||
//
|
|
||||||
// This function will return an error if the HTTPS reply does not
|
|
||||||
// contain at least a valid ALPN entry. It will not return
|
|
||||||
// an error, though, when there are no IPv4/IPv6 hints in the reply.
|
|
||||||
DecodeHTTPS(data []byte, queryID uint16) (*HTTPSSvc, error)
|
|
||||||
|
|
||||||
// DecodeNS is like DecodeHTTPS but for NS queries.
|
|
||||||
DecodeNS(data []byte, queryID uint16) ([]*net.NS, error)
|
|
||||||
|
|
||||||
// DecodeReply decodes a DNS reply message.
|
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - data is the raw reply
|
// - data is the raw reply
|
||||||
//
|
//
|
||||||
// This function fails if we cannot parse data as a DNS
|
// This function fails if we cannot parse data as a DNS
|
||||||
// message or the message is not a reply.
|
// message or the message is not a response.
|
||||||
//
|
//
|
||||||
// If you use this function, remember that:
|
// Regarding the returned response, remember that the Rcode
|
||||||
|
// MAY still be nonzero (this method does not treat a nonzero
|
||||||
|
// Rcode as an error when parsing the response).
|
||||||
|
DecodeResponse(data []byte, query DNSQuery) (DNSResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNSQuery is an encoded DNS query ready to be sent using a DNSTransport.
|
||||||
|
type DNSQuery interface {
|
||||||
|
// Domain is the domain we're querying for.
|
||||||
|
Domain() string
|
||||||
|
|
||||||
|
// Type is the query type.
|
||||||
|
Type() uint16
|
||||||
|
|
||||||
|
// Bytes serializes the query to bytes. This function may fail if we're not
|
||||||
|
// able to correctly encode the domain into a query message.
|
||||||
//
|
//
|
||||||
// 1. the Rcode MAY be nonzero;
|
// The value returned by this function WILL be memoized after the first call,
|
||||||
//
|
// so you SHOULD create a new DNSQuery if you need to retry a query.
|
||||||
// 2. the replyID MAY NOT match the original query ID.
|
Bytes() ([]byte, error)
|
||||||
//
|
|
||||||
// That is, this is a very basic parsing method.
|
// ID returns the query ID.
|
||||||
DecodeReply(data []byte) (*dns.Msg, error)
|
ID() uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// The DNSEncoder encodes DNS queries to bytes
|
// The DNSEncoder encodes DNS queries to bytes
|
||||||
type DNSEncoder interface {
|
type DNSEncoder interface {
|
||||||
// Encode transforms its arguments into a serialized DNS query.
|
// Encode transforms its arguments into a serialized DNS query.
|
||||||
//
|
//
|
||||||
|
// Every time you call Encode, you get a new DNSQuery value
|
||||||
|
// using a query ID selected at random.
|
||||||
|
//
|
||||||
|
// Serialization to bytes is lazy to acommodate DNS transports that
|
||||||
|
// do not need to serialize and send bytes, e.g., getaddrinfo.
|
||||||
|
//
|
||||||
|
// You serialize to bytes using DNSQuery.Bytes. This operation MAY fail
|
||||||
|
// if the domain name cannot be packed into a DNS message (e.g., it is
|
||||||
|
// too long to fit into the message).
|
||||||
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - domain is the domain for the query (e.g., x.org);
|
// - domain is the domain for the query (e.g., x.org);
|
||||||
|
@ -85,16 +96,15 @@ type DNSEncoder interface {
|
||||||
//
|
//
|
||||||
// - padding is whether to add padding to the query.
|
// - padding is whether to add padding to the query.
|
||||||
//
|
//
|
||||||
// On success, this function returns a valid byte array, the queryID, and
|
// This function will transform the domain into an FQDN is it's not
|
||||||
// a nil error. On failure, we have a non-nil error, a nil arrary and a zero
|
// already expressed in the FQDN format.
|
||||||
// query ID.
|
Encode(domain string, qtype uint16, padding bool) DNSQuery
|
||||||
Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DNSTransport represents an abstract DNS transport.
|
// DNSTransport represents an abstract DNS transport.
|
||||||
type DNSTransport interface {
|
type DNSTransport interface {
|
||||||
// RoundTrip sends a DNS query and receives the reply.
|
// RoundTrip sends a DNS query and receives the reply.
|
||||||
RoundTrip(ctx context.Context, query []byte) (reply []byte, err error)
|
RoundTrip(ctx context.Context, query DNSQuery) (DNSResponse, error)
|
||||||
|
|
||||||
// RequiresPadding returns whether this transport needs padding.
|
// RequiresPadding returns whether this transport needs padding.
|
||||||
RequiresPadding() bool
|
RequiresPadding() bool
|
||||||
|
|
|
@ -15,14 +15,16 @@ import (
|
||||||
// DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder.
|
// DNSDecoderMiekg uses github.com/miekg/dns to implement the Decoder.
|
||||||
type DNSDecoderMiekg struct{}
|
type DNSDecoderMiekg struct{}
|
||||||
|
|
||||||
// ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID.
|
var (
|
||||||
var ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID)
|
// ErrDNSReplyWithWrongQueryID indicates we have got a DNS reply with the wrong queryID.
|
||||||
|
ErrDNSReplyWithWrongQueryID = errors.New(FailureDNSReplyWithWrongQueryID)
|
||||||
|
|
||||||
// ErrDNSIsQuery indicates that we were passed a DNS query.
|
// ErrDNSIsQuery indicates that we were passed a DNS query.
|
||||||
var ErrDNSIsQuery = errors.New("ooresolver: expected response but received query")
|
ErrDNSIsQuery = errors.New("ooresolver: expected response but received query")
|
||||||
|
)
|
||||||
|
|
||||||
// DecodeReply implements model.DNSDecoder.DecodeReply
|
// DecodeResponse implements model.DNSDecoder.DecodeResponse.
|
||||||
func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) {
|
func (d *DNSDecoderMiekg) DecodeResponse(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
reply := &dns.Msg{}
|
reply := &dns.Msg{}
|
||||||
if err := reply.Unpack(data); err != nil {
|
if err := reply.Unpack(data); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -30,46 +32,64 @@ func (d *DNSDecoderMiekg) DecodeReply(data []byte) (*dns.Msg, error) {
|
||||||
if !reply.Response {
|
if !reply.Response {
|
||||||
return nil, ErrDNSIsQuery
|
return nil, ErrDNSIsQuery
|
||||||
}
|
}
|
||||||
return reply, nil
|
if reply.Id != query.ID() {
|
||||||
}
|
|
||||||
|
|
||||||
// decodeSuccessfulReply decodes the bytes in data as a successful reply for the
|
|
||||||
// given queryID. This function returns an error if:
|
|
||||||
//
|
|
||||||
// 1. we cannot decode data
|
|
||||||
//
|
|
||||||
// 2. the decoded message is not a reply
|
|
||||||
//
|
|
||||||
// 3. the query ID does not match
|
|
||||||
//
|
|
||||||
// 4. the Rcode is not zero.
|
|
||||||
func (d *DNSDecoderMiekg) decodeSuccessfulReply(data []byte, queryID uint16) (*dns.Msg, error) {
|
|
||||||
reply, err := d.DecodeReply(data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if reply.Id != queryID {
|
|
||||||
return nil, ErrDNSReplyWithWrongQueryID
|
return nil, ErrDNSReplyWithWrongQueryID
|
||||||
}
|
}
|
||||||
|
resp := &dnsResponse{
|
||||||
|
bytes: data,
|
||||||
|
msg: reply,
|
||||||
|
query: query,
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsResponse implements model.DNSResponse.
|
||||||
|
type dnsResponse struct {
|
||||||
|
// bytes contains the response bytes.
|
||||||
|
bytes []byte
|
||||||
|
|
||||||
|
// msg contains the message.
|
||||||
|
msg *dns.Msg
|
||||||
|
|
||||||
|
// query is the original query.
|
||||||
|
query model.DNSQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query implements model.DNSResponse.Query.
|
||||||
|
func (r *dnsResponse) Query() model.DNSQuery {
|
||||||
|
return r.query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes implements model.DNSResponse.Bytes.
|
||||||
|
func (r *dnsResponse) Bytes() []byte {
|
||||||
|
return r.bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rcode implements model.DNSResponse.Rcode.
|
||||||
|
func (r *dnsResponse) Rcode() int {
|
||||||
|
return r.msg.Rcode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *dnsResponse) rcodeToError() error {
|
||||||
// TODO(bassosimone): map more errors to net.DNSError names
|
// TODO(bassosimone): map more errors to net.DNSError names
|
||||||
// TODO(bassosimone): add support for lame referral.
|
// TODO(bassosimone): add support for lame referral.
|
||||||
switch reply.Rcode {
|
switch r.msg.Rcode {
|
||||||
case dns.RcodeSuccess:
|
case dns.RcodeSuccess:
|
||||||
return reply, nil
|
return nil
|
||||||
case dns.RcodeNameError:
|
case dns.RcodeNameError:
|
||||||
return nil, ErrOODNSNoSuchHost
|
return ErrOODNSNoSuchHost
|
||||||
case dns.RcodeRefused:
|
case dns.RcodeRefused:
|
||||||
return nil, ErrOODNSRefused
|
return ErrOODNSRefused
|
||||||
case dns.RcodeServerFailure:
|
case dns.RcodeServerFailure:
|
||||||
return nil, ErrOODNSServfail
|
return ErrOODNSServfail
|
||||||
default:
|
default:
|
||||||
return nil, ErrOODNSMisbehaving
|
return ErrOODNSMisbehaving
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
// DecodeHTTPS implements model.DNSResponse.DecodeHTTPS.
|
||||||
reply, err := d.decodeSuccessfulReply(data, queryID)
|
func (r *dnsResponse) DecodeHTTPS() (*model.HTTPSSvc, error) {
|
||||||
if err != nil {
|
if err := r.rcodeToError(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := &model.HTTPSSvc{
|
out := &model.HTTPSSvc{
|
||||||
|
@ -77,7 +97,7 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS
|
||||||
IPv4: []string{}, // ensure it's not nil
|
IPv4: []string{}, // ensure it's not nil
|
||||||
IPv6: []string{}, // ensure it's not nil
|
IPv6: []string{}, // ensure it's not nil
|
||||||
}
|
}
|
||||||
for _, answer := range reply.Answer {
|
for _, answer := range r.msg.Answer {
|
||||||
switch avalue := answer.(type) {
|
switch avalue := answer.(type) {
|
||||||
case *dns.HTTPS:
|
case *dns.HTTPS:
|
||||||
for _, v := range avalue.Value {
|
for _, v := range avalue.Value {
|
||||||
|
@ -102,14 +122,14 @@ func (d *DNSDecoderMiekg) DecodeHTTPS(data []byte, queryID uint16) (*model.HTTPS
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID uint16) ([]string, error) {
|
// DecodeLookupHost implements model.DNSResponse.DecodeLookupHost.
|
||||||
reply, err := d.decodeSuccessfulReply(data, queryID)
|
func (r *dnsResponse) DecodeLookupHost() ([]string, error) {
|
||||||
if err != nil {
|
if err := r.rcodeToError(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var addrs []string
|
var addrs []string
|
||||||
for _, answer := range reply.Answer {
|
for _, answer := range r.msg.Answer {
|
||||||
switch qtype {
|
switch r.Query().Type() {
|
||||||
case dns.TypeA:
|
case dns.TypeA:
|
||||||
if rra, ok := answer.(*dns.A); ok {
|
if rra, ok := answer.(*dns.A); ok {
|
||||||
ip := rra.A
|
ip := rra.A
|
||||||
|
@ -128,13 +148,13 @@ func (d *DNSDecoderMiekg) DecodeLookupHost(qtype uint16, data []byte, queryID ui
|
||||||
return addrs, nil
|
return addrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, error) {
|
// DecodeNS implements model.DNSResponse.DecodeNS.
|
||||||
reply, err := d.decodeSuccessfulReply(data, queryID)
|
func (r *dnsResponse) DecodeNS() ([]*net.NS, error) {
|
||||||
if err != nil {
|
if err := r.rcodeToError(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := []*net.NS{}
|
out := []*net.NS{}
|
||||||
for _, answer := range reply.Answer {
|
for _, answer := range r.msg.Answer {
|
||||||
switch avalue := answer.(type) {
|
switch avalue := answer.(type) {
|
||||||
case *dns.NS:
|
case *dns.NS:
|
||||||
out = append(out, &net.NS{Host: avalue.Ns})
|
out = append(out, &net.NS{Host: avalue.Ns})
|
||||||
|
@ -147,3 +167,4 @@ func (d *DNSDecoderMiekg) DecodeNS(data []byte, queryID uint16) ([]*net.NS, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.DNSDecoder = &DNSDecoderMiekg{}
|
var _ model.DNSDecoder = &DNSDecoderMiekg{}
|
||||||
|
var _ model.DNSResponse = &dnsResponse{}
|
||||||
|
|
|
@ -1,26 +1,27 @@
|
||||||
package netxlite
|
package netxlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSDecoder(t *testing.T) {
|
func TestDNSDecoderMiekg(t *testing.T) {
|
||||||
t.Run("LookupHost", func(t *testing.T) {
|
t.Run("DecodeResponse", func(t *testing.T) {
|
||||||
t.Run("UnpackError", func(t *testing.T) {
|
t.Run("UnpackError", func(t *testing.T) {
|
||||||
d := &DNSDecoderMiekg{}
|
d := &DNSDecoderMiekg{}
|
||||||
data, err := d.DecodeLookupHost(dns.TypeA, nil, 0)
|
resp, err := d.DecodeResponse(nil, &mocks.DNSQuery{})
|
||||||
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
||||||
t.Fatal("unexpected error", err)
|
t.Fatal("unexpected error", err)
|
||||||
}
|
}
|
||||||
if data != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected nil data here")
|
t.Fatal("expected nil resp here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -28,12 +29,12 @@ func TestDNSDecoder(t *testing.T) {
|
||||||
d := &DNSDecoderMiekg{}
|
d := &DNSDecoderMiekg{}
|
||||||
queryID := dns.Id()
|
queryID := dns.Id()
|
||||||
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||||
addrs, err := d.DecodeLookupHost(dns.TypeA, rawQuery, queryID)
|
resp, err := d.DecodeResponse(rawQuery, &mocks.DNSQuery{})
|
||||||
if !errors.Is(err, ErrDNSIsQuery) {
|
if !errors.Is(err, ErrDNSIsQuery) {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if len(addrs) > 0 {
|
if resp != nil {
|
||||||
t.Fatal("expected no addrs")
|
t.Fatal("expected nil resp here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -44,297 +45,447 @@ func TestDNSDecoder(t *testing.T) {
|
||||||
unrelatedID = 14
|
unrelatedID = 14
|
||||||
)
|
)
|
||||||
reply := dnsGenLookupHostReplySuccess(dnsGenQuery(dns.TypeA, queryID))
|
reply := dnsGenLookupHostReplySuccess(dnsGenQuery(dns.TypeA, queryID))
|
||||||
data, err := d.DecodeLookupHost(dns.TypeA, reply, unrelatedID)
|
resp, err := d.DecodeResponse(reply, &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return unrelatedID
|
||||||
|
},
|
||||||
|
})
|
||||||
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
||||||
t.Fatal("unexpected error", err)
|
t.Fatal("unexpected error", err)
|
||||||
}
|
}
|
||||||
if data != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected nil data here")
|
t.Fatal("expected nil resp here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("NXDOMAIN", func(t *testing.T) {
|
t.Run("dnsResponse.Query", func(t *testing.T) {
|
||||||
d := &DNSDecoderMiekg{}
|
d := &DNSDecoderMiekg{}
|
||||||
queryID := dns.Id()
|
queryID := dns.Id()
|
||||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
|
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||||
dnsGenQuery(dns.TypeA, queryID), dns.RcodeNameError), queryID)
|
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
query := &mocks.DNSQuery{
|
||||||
t.Fatal("not the error we expected", err)
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if data != nil {
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
t.Fatal("expected nil data here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Refused", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
queryID := dns.Id()
|
|
||||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
|
|
||||||
dnsGenQuery(dns.TypeA, queryID), dns.RcodeRefused), queryID)
|
|
||||||
if !errors.Is(err, ErrOODNSRefused) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if data != nil {
|
|
||||||
t.Fatal("expected nil data here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Servfail", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
queryID := dns.Id()
|
|
||||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenReplyWithError(
|
|
||||||
dnsGenQuery(dns.TypeA, queryID), dns.RcodeServerFailure), queryID)
|
|
||||||
if !errors.Is(err, ErrOODNSServfail) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if data != nil {
|
|
||||||
t.Fatal("expected nil data here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("no address", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
queryID := dns.Id()
|
|
||||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
|
|
||||||
dnsGenQuery(dns.TypeA, queryID)), queryID)
|
|
||||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if data != nil {
|
|
||||||
t.Fatal("expected nil data here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("decode A", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
queryID := dns.Id()
|
|
||||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
|
|
||||||
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.8.8"), queryID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(data) != 2 {
|
if resp.Query().ID() != query.ID() {
|
||||||
t.Fatal("expected two entries here")
|
t.Fatal("invalid query")
|
||||||
}
|
|
||||||
if data[0] != "1.1.1.1" {
|
|
||||||
t.Fatal("invalid first IPv4 entry")
|
|
||||||
}
|
|
||||||
if data[1] != "8.8.8.8" {
|
|
||||||
t.Fatal("invalid second IPv4 entry")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("decode AAAA", func(t *testing.T) {
|
t.Run("dnsResponse.Bytes", func(t *testing.T) {
|
||||||
d := &DNSDecoderMiekg{}
|
d := &DNSDecoderMiekg{}
|
||||||
queryID := dns.Id()
|
queryID := dns.Id()
|
||||||
data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
|
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||||
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID)
|
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(data) != 2 {
|
if !bytes.Equal(rawResponse, resp.Bytes()) {
|
||||||
t.Fatal("expected two entries here")
|
t.Fatal("invalid bytes")
|
||||||
}
|
|
||||||
if data[0] != "::1" {
|
|
||||||
t.Fatal("invalid first IPv6 entry")
|
|
||||||
}
|
|
||||||
if data[1] != "fe80::1" {
|
|
||||||
t.Fatal("invalid second IPv6 entry")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("unexpected A reply", func(t *testing.T) {
|
t.Run("dnsResponse.Rcode", func(t *testing.T) {
|
||||||
d := &DNSDecoderMiekg{}
|
d := &DNSDecoderMiekg{}
|
||||||
queryID := dns.Id()
|
queryID := dns.Id()
|
||||||
data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(
|
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||||
dnsGenQuery(dns.TypeAAAA, queryID), "::1", "fe80::1"), queryID)
|
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
|
||||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
query := &mocks.DNSQuery{
|
||||||
t.Fatal("not the error we expected", err)
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if data != nil {
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
t.Fatal("expected nil data here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("unexpected AAAA reply", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
queryID := dns.Id()
|
|
||||||
data, err := d.DecodeLookupHost(dns.TypeAAAA, dnsGenLookupHostReplySuccess(
|
|
||||||
dnsGenQuery(dns.TypeA, queryID), "1.1.1.1", "8.8.4.4"), queryID)
|
|
||||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if data != nil {
|
|
||||||
t.Fatal("expected nil data here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("decodeSuccessfulReply", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
msg := &dns.Msg{}
|
|
||||||
msg.Rcode = dns.RcodeFormatError // an rcode we don't handle
|
|
||||||
msg.Response = true
|
|
||||||
data, err := msg.Pack()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
reply, err := d.decodeSuccessfulReply(data, 0)
|
|
||||||
if !errors.Is(err, ErrOODNSMisbehaving) { // catch all error
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("DecodeHTTPS", func(t *testing.T) {
|
|
||||||
t.Run("with nil data", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
reply, err := d.DecodeHTTPS(nil, 0)
|
|
||||||
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("with bytes containing a query", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
queryID := dns.Id()
|
|
||||||
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
|
||||||
https, err := d.DecodeHTTPS(rawQuery, queryID)
|
|
||||||
if !errors.Is(err, ErrDNSIsQuery) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if https != nil {
|
|
||||||
t.Fatal("expected nil https")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("wrong query ID", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
const (
|
|
||||||
queryID = 17
|
|
||||||
unrelatedID = 14
|
|
||||||
)
|
|
||||||
reply := dnsGenHTTPSReplySuccess(dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
|
|
||||||
data, err := d.DecodeHTTPS(reply, unrelatedID)
|
|
||||||
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
|
||||||
t.Fatal("unexpected error", err)
|
|
||||||
}
|
|
||||||
if data != nil {
|
|
||||||
t.Fatal("expected nil data here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("with empty answer", func(t *testing.T) {
|
|
||||||
queryID := dns.Id()
|
|
||||||
data := dnsGenHTTPSReplySuccess(
|
|
||||||
dnsGenQuery(dns.TypeHTTPS, queryID), nil, nil, nil)
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
reply, err := d.DecodeHTTPS(data, queryID)
|
|
||||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("with full answer", func(t *testing.T) {
|
|
||||||
queryID := dns.Id()
|
|
||||||
alpn := []string{"h3"}
|
|
||||||
v4 := []string{"1.1.1.1"}
|
|
||||||
v6 := []string{"::1"}
|
|
||||||
data := dnsGenHTTPSReplySuccess(
|
|
||||||
dnsGenQuery(dns.TypeHTTPS, queryID), alpn, v4, v6)
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
reply, err := d.DecodeHTTPS(data, queryID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(alpn, reply.ALPN); diff != "" {
|
if resp.Rcode() != dns.RcodeRefused {
|
||||||
t.Fatal(diff)
|
t.Fatal("invalid rcode")
|
||||||
}
|
|
||||||
if diff := cmp.Diff(v4, reply.IPv4); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(v6, reply.IPv6); diff != "" {
|
|
||||||
t.Fatal(diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("DecodeNS", func(t *testing.T) {
|
|
||||||
t.Run("with nil data", func(t *testing.T) {
|
|
||||||
d := &DNSDecoderMiekg{}
|
|
||||||
reply, err := d.DecodeNS(nil, 0)
|
|
||||||
if err == nil || err.Error() != "dns: overflow unpacking uint16" {
|
|
||||||
t.Fatal("not the error we expected", err)
|
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with bytes containing a query", func(t *testing.T) {
|
t.Run("dnsResponse.rcodeToError", func(t *testing.T) {
|
||||||
d := &DNSDecoderMiekg{}
|
// Here we want to ensure we map all the errors we recognize
|
||||||
queryID := dns.Id()
|
// correctly and we also map unrecognized errors correctly
|
||||||
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
|
var inputsOutputs = []struct {
|
||||||
ns, err := d.DecodeNS(rawQuery, queryID)
|
name string
|
||||||
if !errors.Is(err, ErrDNSIsQuery) {
|
rcode int
|
||||||
t.Fatal("unexpected err", err)
|
err error
|
||||||
}
|
}{{
|
||||||
if len(ns) > 0 {
|
name: "when rcode is zero",
|
||||||
t.Fatal("expected no result")
|
rcode: 0,
|
||||||
|
err: nil,
|
||||||
|
}, {
|
||||||
|
name: "NXDOMAIN",
|
||||||
|
rcode: dns.RcodeNameError,
|
||||||
|
err: ErrOODNSNoSuchHost,
|
||||||
|
}, {
|
||||||
|
name: "refused",
|
||||||
|
rcode: dns.RcodeRefused,
|
||||||
|
err: ErrOODNSRefused,
|
||||||
|
}, {
|
||||||
|
name: "servfail",
|
||||||
|
rcode: dns.RcodeServerFailure,
|
||||||
|
err: ErrOODNSServfail,
|
||||||
|
}, {
|
||||||
|
name: "anything else",
|
||||||
|
rcode: dns.RcodeFormatError,
|
||||||
|
err: ErrOODNSMisbehaving,
|
||||||
|
}}
|
||||||
|
for _, io := range inputsOutputs {
|
||||||
|
t.Run(io.name, func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||||
|
rawResponse := dnsGenReplyWithError(rawQuery, io.rcode)
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// The following cast should always work in this configuration
|
||||||
|
err = resp.(*dnsResponse).rcodeToError()
|
||||||
|
if !errors.Is(err, io.err) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("wrong query ID", func(t *testing.T) {
|
t.Run("dnsResponse.DecodeHTTPS", func(t *testing.T) {
|
||||||
d := &DNSDecoderMiekg{}
|
t.Run("with failure", func(t *testing.T) {
|
||||||
const (
|
// Ensure that we're not trying to decode if rcode != 0
|
||||||
queryID = 17
|
d := &DNSDecoderMiekg{}
|
||||||
unrelatedID = 14
|
queryID := dns.Id()
|
||||||
)
|
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||||
reply := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
|
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
|
||||||
data, err := d.DecodeNS(reply, unrelatedID)
|
query := &mocks.DNSQuery{
|
||||||
if !errors.Is(err, ErrDNSReplyWithWrongQueryID) {
|
MockID: func() uint16 {
|
||||||
t.Fatal("unexpected error", err)
|
return queryID
|
||||||
}
|
},
|
||||||
if data != nil {
|
}
|
||||||
t.Fatal("expected nil data here")
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
}
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
https, err := resp.DecodeHTTPS()
|
||||||
|
if !errors.Is(err, ErrOODNSRefused) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if https != nil {
|
||||||
|
t.Fatal("expected nil https result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with empty answer", func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||||
|
rawResponse := dnsGenHTTPSReplySuccess(rawQuery, nil, nil, nil)
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
https, err := resp.DecodeHTTPS()
|
||||||
|
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if https != nil {
|
||||||
|
t.Fatal("expected nil https results")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with full answer", func(t *testing.T) {
|
||||||
|
alpn := []string{"h3"}
|
||||||
|
v4 := []string{"1.1.1.1"}
|
||||||
|
v6 := []string{"::1"}
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeHTTPS, queryID)
|
||||||
|
rawResponse := dnsGenHTTPSReplySuccess(rawQuery, alpn, v4, v6)
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
reply, err := resp.DecodeHTTPS()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(alpn, reply.ALPN); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(v4, reply.IPv4); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(v6, reply.IPv6); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with empty answer", func(t *testing.T) {
|
t.Run("dnsResponse.DecodeNS", func(t *testing.T) {
|
||||||
queryID := dns.Id()
|
t.Run("with failure", func(t *testing.T) {
|
||||||
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID))
|
// Ensure that we're not trying to decode if rcode != 0
|
||||||
d := &DNSDecoderMiekg{}
|
d := &DNSDecoderMiekg{}
|
||||||
reply, err := d.DecodeNS(data, queryID)
|
queryID := dns.Id()
|
||||||
if !errors.Is(err, ErrOODNSNoAnswer) {
|
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
|
||||||
t.Fatal("unexpected err", err)
|
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
|
||||||
}
|
query := &mocks.DNSQuery{
|
||||||
if reply != nil {
|
MockID: func() uint16 {
|
||||||
t.Fatal("expected nil reply")
|
return queryID
|
||||||
}
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ns, err := resp.DecodeNS()
|
||||||
|
if !errors.Is(err, ErrOODNSRefused) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if len(ns) > 0 {
|
||||||
|
t.Fatal("expected empty ns result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with empty answer", func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
|
||||||
|
rawResponse := dnsGenNSReplySuccess(rawQuery)
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ns, err := resp.DecodeNS()
|
||||||
|
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if len(ns) > 0 {
|
||||||
|
t.Fatal("expected empty ns results")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with full answer", func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeNS, queryID)
|
||||||
|
rawResponse := dnsGenNSReplySuccess(rawQuery, "ns1.zdns.google.")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ns, err := resp.DecodeNS()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(ns) != 1 {
|
||||||
|
t.Fatal("unexpected ns length")
|
||||||
|
}
|
||||||
|
if ns[0].Host != "ns1.zdns.google." {
|
||||||
|
t.Fatal("unexpected host")
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("with full answer", func(t *testing.T) {
|
t.Run("dnsResponse.LookupHost", func(t *testing.T) {
|
||||||
queryID := dns.Id()
|
t.Run("with failure", func(t *testing.T) {
|
||||||
data := dnsGenNSReplySuccess(dnsGenQuery(dns.TypeNS, queryID), "ns1.zdns.google.")
|
// Ensure that we're not trying to decode if rcode != 0
|
||||||
d := &DNSDecoderMiekg{}
|
d := &DNSDecoderMiekg{}
|
||||||
reply, err := d.DecodeNS(data, queryID)
|
queryID := dns.Id()
|
||||||
if err != nil {
|
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||||
t.Fatal(err)
|
rawResponse := dnsGenReplyWithError(rawQuery, dns.RcodeRefused)
|
||||||
}
|
query := &mocks.DNSQuery{
|
||||||
if len(reply) != 1 {
|
MockID: func() uint16 {
|
||||||
t.Fatal("unexpected reply length")
|
return queryID
|
||||||
}
|
},
|
||||||
if reply[0].Host != "ns1.zdns.google." {
|
}
|
||||||
t.Fatal("unexpected reply host")
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
}
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if !errors.Is(err, ErrOODNSRefused) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if len(addrs) > 0 {
|
||||||
|
t.Fatal("expected empty addrs result")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with empty answer", func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||||
|
rawResponse := dnsGenLookupHostReplySuccess(rawQuery)
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if len(addrs) > 0 {
|
||||||
|
t.Fatal("expected empty ns results")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("decode A", func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||||
|
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return dns.TypeA
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(addrs) != 2 {
|
||||||
|
t.Fatal("expected two entries here")
|
||||||
|
}
|
||||||
|
if addrs[0] != "1.1.1.1" {
|
||||||
|
t.Fatal("invalid first IPv4 entry")
|
||||||
|
}
|
||||||
|
if addrs[1] != "8.8.8.8" {
|
||||||
|
t.Fatal("invalid second IPv4 entry")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("decode AAAA", func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeAAAA, queryID)
|
||||||
|
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return dns.TypeAAAA
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(addrs) != 2 {
|
||||||
|
t.Fatal("expected two entries here")
|
||||||
|
}
|
||||||
|
if addrs[0] != "::1" {
|
||||||
|
t.Fatal("invalid first IPv6 entry")
|
||||||
|
}
|
||||||
|
if addrs[1] != "fe80::1" {
|
||||||
|
t.Fatal("invalid second IPv6 entry")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unexpected A reply to AAAA query", func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeAAAA, queryID)
|
||||||
|
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "1.1.1.1", "8.8.8.8")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return dns.TypeAAAA
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if len(addrs) > 0 {
|
||||||
|
t.Fatal("expected no addrs here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unexpected AAAA reply to A query", func(t *testing.T) {
|
||||||
|
d := &DNSDecoderMiekg{}
|
||||||
|
queryID := dns.Id()
|
||||||
|
rawQuery := dnsGenQuery(dns.TypeA, queryID)
|
||||||
|
rawResponse := dnsGenLookupHostReplySuccess(rawQuery, "::1", "fe80::1")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockID: func() uint16 {
|
||||||
|
return queryID
|
||||||
|
},
|
||||||
|
MockType: func() uint16 {
|
||||||
|
return dns.TypeA
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := d.DecodeResponse(rawResponse, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
addrs, err := resp.DecodeLookupHost()
|
||||||
|
if !errors.Is(err, ErrOODNSNoAnswer) {
|
||||||
|
t.Fatal("not the error we expected", err)
|
||||||
|
}
|
||||||
|
if len(addrs) > 0 {
|
||||||
|
t.Fatal("expected no addrs here")
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -371,8 +522,8 @@ func dnsGenReplyWithError(rawQuery []byte, code int) []byte {
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given
|
// dnsGenLookupHostReplySuccess generates a successful DNS reply containing the given ips...
|
||||||
// qtype (e.g., dns.TypeA) containing the given ips... in the answer.
|
// in the answers where each answer's type depends on the IP's type (A/AAAA).
|
||||||
func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
|
func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
|
||||||
query := new(dns.Msg)
|
query := new(dns.Msg)
|
||||||
err := query.Unpack(rawQuery)
|
err := query.Unpack(rawQuery)
|
||||||
|
@ -388,28 +539,22 @@ func dnsGenLookupHostReplySuccess(rawQuery []byte, ips ...string) []byte {
|
||||||
reply.MsgHdr.RecursionAvailable = true
|
reply.MsgHdr.RecursionAvailable = true
|
||||||
reply.SetReply(query)
|
reply.SetReply(query)
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
switch question.Qtype {
|
switch isIPv6(ip) {
|
||||||
case dns.TypeA:
|
case false:
|
||||||
if isIPv6(ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
reply.Answer = append(reply.Answer, &dns.A{
|
reply.Answer = append(reply.Answer, &dns.A{
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: dns.Fqdn("x.org"),
|
Name: question.Name,
|
||||||
Rrtype: question.Qtype,
|
Rrtype: dns.TypeA,
|
||||||
Class: dns.ClassINET,
|
Class: dns.ClassINET,
|
||||||
Ttl: 0,
|
Ttl: 0,
|
||||||
},
|
},
|
||||||
A: net.ParseIP(ip),
|
A: net.ParseIP(ip),
|
||||||
})
|
})
|
||||||
case dns.TypeAAAA:
|
case true:
|
||||||
if !isIPv6(ip) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
reply.Answer = append(reply.Answer, &dns.AAAA{
|
reply.Answer = append(reply.Answer, &dns.AAAA{
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: dns.Fqdn("x.org"),
|
Name: question.Name,
|
||||||
Rrtype: question.Qtype,
|
Rrtype: dns.TypeAAAA,
|
||||||
Class: dns.ClassINET,
|
Class: dns.ClassINET,
|
||||||
Ttl: 0,
|
Ttl: 0,
|
||||||
},
|
},
|
||||||
|
|
|
@ -5,7 +5,10 @@ package netxlite
|
||||||
//
|
//
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,18 +26,82 @@ const (
|
||||||
dnsDNSSECEnabled = true
|
dnsDNSSECEnabled = true
|
||||||
)
|
)
|
||||||
|
|
||||||
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
// Encoder implements model.DNSEncoder.Encode.
|
||||||
|
func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) model.DNSQuery {
|
||||||
|
return &dnsQuery{
|
||||||
|
bytesCalls: &atomicx.Int64{},
|
||||||
|
domain: domain,
|
||||||
|
kind: qtype,
|
||||||
|
id: dns.Id(),
|
||||||
|
memoizedBytes: []byte{},
|
||||||
|
mu: sync.Mutex{},
|
||||||
|
padding: padding,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsQuery implements model.DNSQuery.
|
||||||
|
type dnsQuery struct {
|
||||||
|
// bytesCalls counts the calls to the bytes() method
|
||||||
|
bytesCalls *atomicx.Int64
|
||||||
|
|
||||||
|
// domain is the domain.
|
||||||
|
domain string
|
||||||
|
|
||||||
|
// kind is the query type.
|
||||||
|
kind uint16
|
||||||
|
|
||||||
|
// id is the query ID.
|
||||||
|
id uint16
|
||||||
|
|
||||||
|
// memoizedBytes contains the query encoded as bytes. We only fill
|
||||||
|
// this field the first time the Bytes method is called.
|
||||||
|
memoizedBytes []byte
|
||||||
|
|
||||||
|
// mu provides mutual exclusion.
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// padding indicates whether we need padding.
|
||||||
|
padding bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Domain implements model.DNSQuery.Domain.
|
||||||
|
func (q *dnsQuery) Domain() string {
|
||||||
|
return q.domain
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type implements model.DNSQuery.Type.
|
||||||
|
func (q *dnsQuery) Type() uint16 {
|
||||||
|
return q.kind
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes implements model.DNSQuery.Bytes.
|
||||||
|
func (q *dnsQuery) Bytes() ([]byte, error) {
|
||||||
|
defer q.mu.Unlock()
|
||||||
|
q.mu.Lock()
|
||||||
|
if len(q.memoizedBytes) <= 0 {
|
||||||
|
q.bytesCalls.Add(1) // for testing
|
||||||
|
data, err := q.bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
q.memoizedBytes = data
|
||||||
|
}
|
||||||
|
return q.memoizedBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bytes is the unmemoized implementation of Bytes
|
||||||
|
func (q *dnsQuery) bytes() ([]byte, error) {
|
||||||
question := dns.Question{
|
question := dns.Question{
|
||||||
Name: dns.Fqdn(domain),
|
Name: dns.Fqdn(q.domain),
|
||||||
Qtype: qtype,
|
Qtype: q.kind,
|
||||||
Qclass: dns.ClassINET,
|
Qclass: dns.ClassINET,
|
||||||
}
|
}
|
||||||
query := new(dns.Msg)
|
query := new(dns.Msg)
|
||||||
query.Id = dns.Id()
|
query.Id = q.id
|
||||||
query.RecursionDesired = true
|
query.RecursionDesired = true
|
||||||
query.Question = make([]dns.Question, 1)
|
query.Question = make([]dns.Question, 1)
|
||||||
query.Question[0] = question
|
query.Question[0] = question
|
||||||
if padding {
|
if q.padding {
|
||||||
query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled)
|
query.SetEdns0(dnsEDNS0MaxResponseSize, dnsDNSSECEnabled)
|
||||||
// Clients SHOULD pad queries to the closest multiple of
|
// Clients SHOULD pad queries to the closest multiple of
|
||||||
// 128 octets RFC8467#section-4.1. We inflate the query
|
// 128 octets RFC8467#section-4.1. We inflate the query
|
||||||
|
@ -47,8 +114,13 @@ func (e *DNSEncoderMiekg) Encode(domain string, qtype uint16, padding bool) ([]b
|
||||||
opt.Padding = make([]byte, remainder)
|
opt.Padding = make([]byte, remainder)
|
||||||
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
|
query.IsEdns0().Option = append(query.IsEdns0().Option, opt)
|
||||||
}
|
}
|
||||||
data, err := query.Pack()
|
return query.Pack()
|
||||||
return data, query.Id, err
|
}
|
||||||
|
|
||||||
|
// ID implements model.DNSQuery.ID
|
||||||
|
func (q *dnsQuery) ID() uint16 {
|
||||||
|
return q.id
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ model.DNSEncoder = &DNSEncoderMiekg{}
|
var _ model.DNSEncoder = &DNSEncoderMiekg{}
|
||||||
|
var _ model.DNSQuery = &dnsQuery{}
|
||||||
|
|
|
@ -1,29 +1,103 @@
|
||||||
package netxlite
|
package netxlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/randx"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSEncoder(t *testing.T) {
|
func TestDNSEncoderMiekg(t *testing.T) {
|
||||||
|
t.Run("we can fail to encode a domain name to bytes", func(t *testing.T) {
|
||||||
|
e := &DNSEncoderMiekg{}
|
||||||
|
domain := randx.LettersUppercase(512)
|
||||||
|
query := e.Encode(domain, dns.TypeA, false)
|
||||||
|
data, err := query.Bytes()
|
||||||
|
if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if data != nil {
|
||||||
|
t.Fatal("expected nil data here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("calls to bytes are memoized", func(t *testing.T) {
|
||||||
|
t.Run("on success", func(t *testing.T) {
|
||||||
|
e := &DNSEncoderMiekg{}
|
||||||
|
query := e.Encode("x.org", dns.TypeA, false)
|
||||||
|
checkResult := func(data []byte, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
|
||||||
|
}
|
||||||
|
const repeat = 3
|
||||||
|
for idx := 0; idx < repeat; idx++ {
|
||||||
|
checkResult(query.Bytes())
|
||||||
|
}
|
||||||
|
// The following cast will always work in this configuration
|
||||||
|
if query.(*dnsQuery).bytesCalls.Load() != 1 {
|
||||||
|
t.Fatal("invalid number of calls")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
e := &DNSEncoderMiekg{}
|
||||||
|
domain := randx.LettersUppercase(512)
|
||||||
|
query := e.Encode(domain, dns.TypeA, false)
|
||||||
|
checkResult := func(data []byte, err error) {
|
||||||
|
if err == nil || !strings.HasSuffix(err.Error(), "bad rdata") {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if data != nil {
|
||||||
|
t.Fatal("expected nil data here")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const repeat = 3
|
||||||
|
for idx := 0; idx < repeat; idx++ {
|
||||||
|
checkResult(query.Bytes())
|
||||||
|
}
|
||||||
|
// The following cast will always work in this configuration
|
||||||
|
if query.(*dnsQuery).bytesCalls.Load() != repeat {
|
||||||
|
t.Fatal("invalid number of calls")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("encode A", func(t *testing.T) {
|
t.Run("encode A", func(t *testing.T) {
|
||||||
e := &DNSEncoderMiekg{}
|
e := &DNSEncoderMiekg{}
|
||||||
data, _, err := e.Encode("x.org", dns.TypeA, false)
|
query := e.Encode("x.org", dns.TypeA, false)
|
||||||
|
if query.Domain() != "x.org" {
|
||||||
|
t.Fatal("invalid domain")
|
||||||
|
}
|
||||||
|
if query.Type() != dns.TypeA {
|
||||||
|
t.Fatal("invalid type")
|
||||||
|
}
|
||||||
|
data, err := query.Bytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
|
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("encode AAAA", func(t *testing.T) {
|
t.Run("encode AAAA", func(t *testing.T) {
|
||||||
e := &DNSEncoderMiekg{}
|
e := &DNSEncoderMiekg{}
|
||||||
data, _, err := e.Encode("x.org", dns.TypeAAAA, false)
|
query := e.Encode("x.org", dns.TypeAAAA, false)
|
||||||
|
if query.Domain() != "x.org" {
|
||||||
|
t.Fatal("invalid domain")
|
||||||
|
}
|
||||||
|
if query.Type() != dns.TypeAAAA {
|
||||||
|
t.Fatal("invalid type")
|
||||||
|
}
|
||||||
|
data, err := query.Bytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA))
|
dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA), query.ID())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("encode padding", func(t *testing.T) {
|
t.Run("encode padding", func(t *testing.T) {
|
||||||
|
@ -31,7 +105,7 @@ func TestDNSEncoder(t *testing.T) {
|
||||||
// array of values we obtain the right query size.
|
// array of values we obtain the right query size.
|
||||||
getquerylen := func(domainlen int, padding bool) int {
|
getquerylen := func(domainlen int, padding bool) int {
|
||||||
e := &DNSEncoderMiekg{}
|
e := &DNSEncoderMiekg{}
|
||||||
data, _, err := e.Encode(
|
query := e.Encode(
|
||||||
// This is not a valid name because it ends up being way
|
// This is not a valid name because it ends up being way
|
||||||
// longer than 255 octets. However, the library is allowing
|
// longer than 255 octets. However, the library is allowing
|
||||||
// us to generate such name and we are not going to send
|
// us to generate such name and we are not going to send
|
||||||
|
@ -40,6 +114,7 @@ func TestDNSEncoder(t *testing.T) {
|
||||||
dns.Fqdn(strings.Repeat("x.", domainlen)),
|
dns.Fqdn(strings.Repeat("x.", domainlen)),
|
||||||
dns.TypeA, padding,
|
dns.TypeA, padding,
|
||||||
)
|
)
|
||||||
|
data, err := query.Bytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -63,8 +138,13 @@ func TestDNSEncoder(t *testing.T) {
|
||||||
|
|
||||||
// dnsValidateEncodedQueryBytes validates the query serialized in data
|
// dnsValidateEncodedQueryBytes validates the query serialized in data
|
||||||
// for the given query type qtype (e.g., dns.TypeAAAA).
|
// for the given query type qtype (e.g., dns.TypeAAAA).
|
||||||
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte) {
|
func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte, qid uint16) {
|
||||||
// skipping over the query ID
|
var wirequery uint16
|
||||||
|
err := binary.Read(bytes.NewReader(data), binary.BigEndian, &wirequery)
|
||||||
|
runtimex.PanicOnError(err, "binary.Read failed unexpectedly")
|
||||||
|
if wirequery != qid {
|
||||||
|
t.Fatal("invalid query ID")
|
||||||
|
}
|
||||||
if data[2] != 1 {
|
if data[2] != 1 {
|
||||||
t.Fatal("FLAGS should only have RD set")
|
t.Fatal("FLAGS should only have RD set")
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -19,6 +20,9 @@ type DNSOverHTTPSTransport struct {
|
||||||
// Client is the MANDATORY http client to use.
|
// Client is the MANDATORY http client to use.
|
||||||
Client model.HTTPClient
|
Client model.HTTPClient
|
||||||
|
|
||||||
|
// Decoder is the MANDATORY DNSDecoder.
|
||||||
|
Decoder model.DNSDecoder
|
||||||
|
|
||||||
// URL is the MANDATORY URL of the DNS-over-HTTPS server.
|
// URL is the MANDATORY URL of the DNS-over-HTTPS server.
|
||||||
URL string
|
URL string
|
||||||
|
|
||||||
|
@ -31,9 +35,9 @@ type DNSOverHTTPSTransport struct {
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - client in http.Client-like type (e.g., http.DefaultClient);
|
// - client is a model.HTTPClient type;
|
||||||
//
|
//
|
||||||
// - URL is the DoH resolver URL (e.g., https://1.1.1.1/dns-query).
|
// - URL is the DoH resolver URL (e.g., https://dns.google/dns-query).
|
||||||
func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
|
func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPSTransport {
|
||||||
return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
|
return NewDNSOverHTTPSTransportWithHostOverride(client, URL, "")
|
||||||
}
|
}
|
||||||
|
@ -42,22 +46,31 @@ func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNSOverHTTPS
|
||||||
// with the given Host header override.
|
// with the given Host header override.
|
||||||
func NewDNSOverHTTPSTransportWithHostOverride(
|
func NewDNSOverHTTPSTransportWithHostOverride(
|
||||||
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
|
client model.HTTPClient, URL, hostOverride string) *DNSOverHTTPSTransport {
|
||||||
return &DNSOverHTTPSTransport{Client: client, URL: URL, HostOverride: hostOverride}
|
return &DNSOverHTTPSTransport{
|
||||||
|
Client: client,
|
||||||
|
Decoder: &DNSDecoderMiekg{},
|
||||||
|
URL: URL,
|
||||||
|
HostOverride: hostOverride,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip sends a query and receives a reply.
|
// RoundTrip sends a query and receives a reply.
|
||||||
func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverHTTPSTransport) RoundTrip(
|
||||||
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
rawQuery, err := query.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query))
|
req, err := http.NewRequest("POST", t.URL, bytes.NewReader(rawQuery))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Host = t.HostOverride
|
req.Host = t.HostOverride
|
||||||
req.Header.Set("user-agent", model.HTTPHeaderUserAgent)
|
req.Header.Set("user-agent", model.HTTPHeaderUserAgent)
|
||||||
req.Header.Set("content-type", "application/dns-message")
|
req.Header.Set("content-type", "application/dns-message")
|
||||||
var resp *http.Response
|
resp, err := t.Client.Do(req.WithContext(ctx))
|
||||||
resp, err = t.Client.Do(req.WithContext(ctx))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -70,7 +83,13 @@ func (t *DNSOverHTTPSTransport) RoundTrip(ctx context.Context, query []byte) ([]
|
||||||
if resp.Header.Get("content-type") != "application/dns-message" {
|
if resp.Header.Get("content-type") != "application/dns-message" {
|
||||||
return nil, errors.New("doh: invalid content-type")
|
return nil, errors.New("doh: invalid content-type")
|
||||||
}
|
}
|
||||||
return ReadAllContext(ctx, resp.Body)
|
const maxresponsesize = 1 << 20
|
||||||
|
limitReader := io.LimitReader(resp.Body, maxresponsesize)
|
||||||
|
rawResponse, err := ReadAllContext(ctx, limitReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return t.Decoder.DecodeResponse(rawResponse, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns true for DoH according to RFC8467.
|
// RequiresPadding returns true for DoH according to RFC8467.
|
||||||
|
|
|
@ -15,14 +15,36 @@ import (
|
||||||
|
|
||||||
func TestDNSOverHTTPSTransport(t *testing.T) {
|
func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
|
t.Run("query serialization failure", func(t *testing.T) {
|
||||||
|
txp := NewDNSOverHTTPSTransport(http.DefaultClient, "https://1.1.1.1/dns-query")
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected no response here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("NewRequestFailure", func(t *testing.T) {
|
t.Run("NewRequestFailure", func(t *testing.T) {
|
||||||
const invalidURL = "\t"
|
const invalidURL = "\t"
|
||||||
txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
|
txp := NewDNSOverHTTPSTransport(http.DefaultClient, invalidURL)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
MockBytes: func() ([]byte, error) {
|
||||||
t.Fatal("expected an error here")
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if data != nil {
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
t.Fatal("expected no response here")
|
t.Fatal("expected no response here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -37,11 +59,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
if !errors.Is(err, expected) {
|
MockBytes: func() ([]byte, error) {
|
||||||
t.Fatal("expected an error here")
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if data != nil {
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
t.Fatal("expected no response here")
|
t.Fatal("expected no response here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -58,11 +85,16 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
if err == nil || err.Error() != "doh: server returned error" {
|
MockBytes: func() ([]byte, error) {
|
||||||
t.Fatal("expected an error here")
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if data != nil {
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if err == nil || err.Error() != "doh: server returned error" {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
t.Fatal("expected no response here")
|
t.Fatal("expected no response here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -79,11 +111,86 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
if err == nil || err.Error() != "doh: invalid content-type" {
|
MockBytes: func() ([]byte, error) {
|
||||||
t.Fatal("expected an error here")
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if data != nil {
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if err == nil || err.Error() != "doh: invalid content-type" {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected no response here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ReadAllContext fails", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
txp := &DNSOverHTTPSTransport{
|
||||||
|
Client: &mocks.HTTPClient{
|
||||||
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: io.NopCloser(&mocks.Reader{
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return 0, expected
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/dns-message"},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
|
}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected no response here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("decode response failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
body := []byte("AAA")
|
||||||
|
txp := &DNSOverHTTPSTransport{
|
||||||
|
Client: &mocks.HTTPClient{
|
||||||
|
MockDo: func(*http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: 200,
|
||||||
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
|
Header: http.Header{
|
||||||
|
"Content-Type": []string{"application/dns-message"},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
|
Decoder: &mocks.DNSDecoder{
|
||||||
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
t.Fatal("expected no response here")
|
t.Fatal("expected no response here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -103,13 +210,23 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
|
Decoder: &mocks.DNSDecoder{
|
||||||
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return &mocks.DNSResponse{}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !bytes.Equal(data, body) {
|
if resp == nil {
|
||||||
t.Fatal("not the response we expected")
|
t.Fatal("expected non-nil resp here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -125,7 +242,12 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
}
|
}
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("expected an error here")
|
t.Fatal("expected an error here")
|
||||||
}
|
}
|
||||||
|
@ -151,18 +273,22 @@ func TestDNSOverHTTPSTransport(t *testing.T) {
|
||||||
URL: "https://cloudflare-dns.com/dns-query",
|
URL: "https://cloudflare-dns.com/dns-query",
|
||||||
HostOverride: hostOverride,
|
HostOverride: hostOverride,
|
||||||
}
|
}
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 17), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, expected) {
|
if !errors.Is(err, expected) {
|
||||||
t.Fatal("expected an error here")
|
t.Fatal("expected an error here")
|
||||||
}
|
}
|
||||||
if data != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected no response here")
|
t.Fatal("expected no response here")
|
||||||
}
|
}
|
||||||
if !correct {
|
if !correct {
|
||||||
t.Fatal("did not see correct host override")
|
t.Fatal("did not see correct host override")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("other functions behave correctly", func(t *testing.T) {
|
t.Run("other functions behave correctly", func(t *testing.T) {
|
||||||
|
|
|
@ -20,9 +20,12 @@ type DialContextFunc func(context.Context, string, string) (net.Conn, error)
|
||||||
|
|
||||||
// DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport.
|
// DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport.
|
||||||
//
|
//
|
||||||
// Bug: this implementation always creates a new connection for each query.
|
// Note: this implementation always creates a new connection for each query. This
|
||||||
|
// strategy is less efficient but MAY be more robust for cleartext TCP connections
|
||||||
|
// when querying for a blocked domain name causes endpoint blocking.
|
||||||
type DNSOverTCPTransport struct {
|
type DNSOverTCPTransport struct {
|
||||||
dial DialContextFunc
|
dial DialContextFunc
|
||||||
|
decoder model.DNSDecoder
|
||||||
address string
|
address string
|
||||||
network string
|
network string
|
||||||
requiresPadding bool
|
requiresPadding bool
|
||||||
|
@ -36,47 +39,58 @@ type DNSOverTCPTransport struct {
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
||||||
func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
func NewDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||||
return &DNSOverTCPTransport{
|
return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false)
|
||||||
dial: dial,
|
|
||||||
address: address,
|
|
||||||
network: "tcp",
|
|
||||||
requiresPadding: false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
// NewDNSOverTLSTransport creates a new DNSOverTLS transport.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
//
|
//
|
||||||
// - dial is a function with the net.Dialer.DialContext's signature;
|
// - dial is a function with the net.Dialer.DialContext's signature;
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:853).
|
// - address is the endpoint address (e.g., 8.8.8.8:853).
|
||||||
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
func NewDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
||||||
|
return newDNSOverTCPOrTLSTransport(dial, "dot", address, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDNSOverTCPOrTLSTransport is the common factory for creating a transport
|
||||||
|
func newDNSOverTCPOrTLSTransport(
|
||||||
|
dial DialContextFunc, network, address string, padding bool) *DNSOverTCPTransport {
|
||||||
return &DNSOverTCPTransport{
|
return &DNSOverTCPTransport{
|
||||||
dial: dial,
|
dial: dial,
|
||||||
|
decoder: &DNSDecoderMiekg{},
|
||||||
address: address,
|
address: address,
|
||||||
network: "dot",
|
network: network,
|
||||||
requiresPadding: true,
|
requiresPadding: padding,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// errQueryTooLarge indicates the query is too large for the transport.
|
||||||
|
var errQueryTooLarge = errors.New("oodns: query too large for this transport")
|
||||||
|
|
||||||
// RoundTrip sends a query and receives a reply.
|
// RoundTrip sends a query and receives a reply.
|
||||||
func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverTCPTransport) RoundTrip(
|
||||||
if len(query) > math.MaxUint16 {
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, errors.New("query too long")
|
// TODO(bassosimone): this method should more strictly honour the context, which
|
||||||
|
// currently is only used to bound the dial operation
|
||||||
|
rawQuery, err := query.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(rawQuery) > math.MaxUint16 {
|
||||||
|
return nil, errQueryTooLarge
|
||||||
}
|
}
|
||||||
conn, err := t.dial(ctx, "tcp", t.address)
|
conn, err := t.dial(ctx, "tcp", t.address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
const iotimeout = 10 * time.Second
|
||||||
return nil, err
|
conn.SetDeadline(time.Now().Add(iotimeout))
|
||||||
}
|
|
||||||
// Write request
|
// Write request
|
||||||
buf := []byte{byte(len(query) >> 8)}
|
buf := []byte{byte(len(rawQuery) >> 8)}
|
||||||
buf = append(buf, byte(len(query)))
|
buf = append(buf, byte(len(rawQuery)))
|
||||||
buf = append(buf, query...)
|
buf = append(buf, rawQuery...)
|
||||||
if _, err = conn.Write(buf); err != nil {
|
if _, err = conn.Write(buf); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -86,11 +100,11 @@ func (t *DNSOverTCPTransport) RoundTrip(ctx context.Context, query []byte) ([]by
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
length := int(header[0])<<8 | int(header[1])
|
length := int(header[0])<<8 | int(header[1])
|
||||||
reply := make([]byte, length)
|
rawResponse := make([]byte, length)
|
||||||
if _, err = io.ReadFull(conn, reply); err != nil {
|
if _, err = io.ReadFull(conn, rawResponse); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return reply, nil
|
return t.decoder.DecodeResponse(rawResponse, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns true for DoT and false for TCP
|
// RequiresPadding returns true for DoT and false for TCP
|
||||||
|
|
|
@ -6,73 +6,83 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverTCPTransport(t *testing.T) {
|
func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
|
t.Run("cannot encode query", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
const address = "9.9.9.9:53"
|
||||||
|
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected nil response here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("query too large", func(t *testing.T) {
|
t.Run("query too large", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
txp := NewDNSOverTCPTransport(new(net.Dialer).DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18))
|
query := &mocks.DNSQuery{
|
||||||
if err == nil {
|
MockBytes: func() ([]byte, error) {
|
||||||
t.Fatal("expected an error here")
|
return make([]byte, math.MaxUint16+1), nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if reply != nil {
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
t.Fatal("expected nil reply here")
|
if !errors.Is(err, errQueryTooLarge) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected nil response here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("dial failure", func(t *testing.T) {
|
t.Run("dial failure", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
fakedialer := &mocks.Dialer{
|
fakedialer := &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
if reply != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected nil reply here")
|
t.Fatal("expected nil resp here")
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("SetDeadline failure", func(t *testing.T) {
|
|
||||||
const address = "9.9.9.9:53"
|
|
||||||
mocked := errors.New("mocked error")
|
|
||||||
fakedialer := &mocks.Dialer{
|
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
return &mocks.Conn{
|
|
||||||
MockSetDeadline: func(t time.Time) error {
|
|
||||||
return mocked
|
|
||||||
},
|
|
||||||
MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
|
||||||
if !errors.Is(err, mocked) {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply here")
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("write failure", func(t *testing.T) {
|
t.Run("write failure", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
fakedialer := &mocks.Dialer{
|
fakedialer := &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -89,18 +99,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
if reply != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected nil reply here")
|
t.Fatal("expected nil resp here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("first read fails", func(t *testing.T) {
|
t.Run("first read fails", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
fakedialer := &mocks.Dialer{
|
fakedialer := &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
return &mocks.Conn{
|
return &mocks.Conn{
|
||||||
|
@ -120,18 +135,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
if reply != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected nil reply here")
|
t.Fatal("expected nil resp here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("second read fails", func(t *testing.T) {
|
t.Run("second read fails", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
input := io.MultiReader(
|
input := io.MultiReader(
|
||||||
bytes.NewReader([]byte{byte(0), byte(2)}),
|
bytes.NewReader([]byte{byte(0), byte(2)}),
|
||||||
&mocks.Reader{
|
&mocks.Reader{
|
||||||
|
@ -157,17 +177,23 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
if reply != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected nil reply here")
|
t.Fatal("expected nil resp here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("successful case", func(t *testing.T) {
|
t.Run("decode failure", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
|
mocked := errors.New("mocked error")
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
|
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
|
||||||
fakedialer := &mocks.Dialer{
|
fakedialer := &mocks.Dialer{
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
@ -186,11 +212,56 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11))
|
txp.decoder = &mocks.DNSDecoder{
|
||||||
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return nil, mocked
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, mocked) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected nil resp here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("successful case", func(t *testing.T) {
|
||||||
|
const address = "9.9.9.9:53"
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
|
||||||
|
fakedialer := &mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return &mocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
MockRead: input.Read,
|
||||||
|
MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
txp := NewDNSOverTCPTransport(fakedialer.DialContext, address)
|
||||||
|
expectedResp := &mocks.DNSResponse{}
|
||||||
|
txp.decoder = &mocks.DNSDecoder{
|
||||||
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return expectedResp, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(reply) != 1 || reply[0] != 1 {
|
if resp != expectedResp {
|
||||||
t.Fatal("not the response we expected")
|
t.Fatal("not the response we expected")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -213,7 +284,7 @@ func TestDNSOverTCPTransport(t *testing.T) {
|
||||||
|
|
||||||
t.Run("other functions okay with TLS", func(t *testing.T) {
|
t.Run("other functions okay with TLS", func(t *testing.T) {
|
||||||
const address = "9.9.9.9:853"
|
const address = "9.9.9.9:853"
|
||||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address)
|
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
|
||||||
if txp.RequiresPadding() != true {
|
if txp.RequiresPadding() != true {
|
||||||
t.Fatal("invalid RequiresPadding")
|
t.Fatal("invalid RequiresPadding")
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
// DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
|
// DNSOverUDPTransport is a DNS-over-UDP DNSTransport.
|
||||||
type DNSOverUDPTransport struct {
|
type DNSOverUDPTransport struct {
|
||||||
dialer model.Dialer
|
dialer model.Dialer
|
||||||
|
decoder model.DNSDecoder
|
||||||
address string
|
address string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,11 +26,20 @@ type DNSOverUDPTransport struct {
|
||||||
//
|
//
|
||||||
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
||||||
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
|
func NewDNSOverUDPTransport(dialer model.Dialer, address string) *DNSOverUDPTransport {
|
||||||
return &DNSOverUDPTransport{dialer: dialer, address: address}
|
return &DNSOverUDPTransport{
|
||||||
|
dialer: dialer,
|
||||||
|
decoder: &DNSDecoderMiekg{},
|
||||||
|
address: address,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip sends a query and receives a reply.
|
// RoundTrip sends a query and receives a reply.
|
||||||
func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
func (t *DNSOverUDPTransport) RoundTrip(
|
||||||
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
rawQuery, err := query.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
conn, err := t.dialer.DialContext(ctx, "udp", t.address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -37,19 +47,19 @@ func (t *DNSOverUDPTransport) RoundTrip(ctx context.Context, query []byte) ([]by
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
// Use five seconds timeout like Bionic does. See
|
// Use five seconds timeout like Bionic does. See
|
||||||
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
// https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance
|
||||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
const iotimeout = 5 * time.Second
|
||||||
|
conn.SetDeadline(time.Now().Add(iotimeout))
|
||||||
|
if _, err = conn.Write(rawQuery); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err = conn.Write(query); err != nil {
|
const maxmessagesize = 1 << 17
|
||||||
return nil, err
|
rawResponse := make([]byte, maxmessagesize)
|
||||||
}
|
count, err := conn.Read(rawResponse)
|
||||||
reply := make([]byte, 1<<17)
|
|
||||||
var n int
|
|
||||||
n, err = conn.Read(reply)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return reply[:n], nil
|
rawResponse = rawResponse[:count]
|
||||||
|
return t.decoder.DecodeResponse(rawResponse, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiresPadding returns false for UDP according to RFC8467.
|
// RequiresPadding returns false for UDP according to RFC8467.
|
||||||
|
|
|
@ -9,11 +9,30 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/apex/log"
|
"github.com/apex/log"
|
||||||
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDNSOverUDPTransport(t *testing.T) {
|
func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
t.Run("RoundTrip", func(t *testing.T) {
|
t.Run("RoundTrip", func(t *testing.T) {
|
||||||
|
t.Run("cannot encode query", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
const address = "9.9.9.9:53"
|
||||||
|
txp := NewDNSOverUDPTransport(nil, address)
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected nil response here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("dial failure", func(t *testing.T) {
|
t.Run("dial failure", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
const address = "9.9.9.9:53"
|
const address = "9.9.9.9:53"
|
||||||
|
@ -22,36 +41,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
},
|
},
|
||||||
}, address)
|
}, address)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
if data != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected no response here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("SetDeadline failure", func(t *testing.T) {
|
|
||||||
mocked := errors.New("mocked error")
|
|
||||||
txp := NewDNSOverUDPTransport(
|
|
||||||
&mocks.Dialer{
|
|
||||||
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
return &mocks.Conn{
|
|
||||||
MockSetDeadline: func(t time.Time) error {
|
|
||||||
return mocked
|
|
||||||
},
|
|
||||||
MockClose: func() error {
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}, "9.9.9.9:53",
|
|
||||||
)
|
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
|
||||||
if !errors.Is(err, mocked) {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if data != nil {
|
|
||||||
t.Fatal("expected no response here")
|
t.Fatal("expected no response here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -75,11 +74,16 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
)
|
)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
if data != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected no response here")
|
t.Fatal("expected no response here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -106,15 +110,61 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
)
|
)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if !errors.Is(err, mocked) {
|
if !errors.Is(err, mocked) {
|
||||||
t.Fatal("not the error we expected")
|
t.Fatal("not the error we expected")
|
||||||
}
|
}
|
||||||
if data != nil {
|
if resp != nil {
|
||||||
t.Fatal("expected no response here")
|
t.Fatal("expected no response here")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("decode failure", func(t *testing.T) {
|
||||||
|
const expected = 17
|
||||||
|
input := bytes.NewReader(make([]byte, expected))
|
||||||
|
txp := NewDNSOverUDPTransport(
|
||||||
|
&mocks.Dialer{
|
||||||
|
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return &mocks.Conn{
|
||||||
|
MockSetDeadline: func(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
MockWrite: func(b []byte) (int, error) {
|
||||||
|
return len(b), nil
|
||||||
|
},
|
||||||
|
MockRead: input.Read,
|
||||||
|
MockClose: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}, "9.9.9.9:53",
|
||||||
|
)
|
||||||
|
expectedErr := errors.New("mocked error")
|
||||||
|
txp.decoder = &mocks.DNSDecoder{
|
||||||
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return nil, expectedErr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
|
if !errors.Is(err, expectedErr) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Fatal("expected nil resp")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("read success", func(t *testing.T) {
|
t.Run("read success", func(t *testing.T) {
|
||||||
const expected = 17
|
const expected = 17
|
||||||
input := bytes.NewReader(make([]byte, expected))
|
input := bytes.NewReader(make([]byte, expected))
|
||||||
|
@ -136,12 +186,23 @@ func TestDNSOverUDPTransport(t *testing.T) {
|
||||||
},
|
},
|
||||||
}, "9.9.9.9:53",
|
}, "9.9.9.9:53",
|
||||||
)
|
)
|
||||||
data, err := txp.RoundTrip(context.Background(), nil)
|
expectedResp := &mocks.DNSResponse{}
|
||||||
|
txp.decoder = &mocks.DNSDecoder{
|
||||||
|
MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
|
return expectedResp, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
query := &mocks.DNSQuery{
|
||||||
|
MockBytes: func() ([]byte, error) {
|
||||||
|
return make([]byte, 128), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := txp.RoundTrip(context.Background(), query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(data) != expected {
|
if resp != expectedResp {
|
||||||
t.Fatal("expected non nil data")
|
t.Fatal("unexpected resp")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
package filtering
|
package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/ooni/probe-cli/v3/internal/netxlite"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
"github.com/ooni/probe-cli/v3/internal/runtimex"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -51,19 +48,13 @@ type DNSProxy struct {
|
||||||
// receive a query for the given domain.
|
// receive a query for the given domain.
|
||||||
OnQuery func(domain string) DNSAction
|
OnQuery func(domain string) DNSAction
|
||||||
|
|
||||||
// Upstream is the OPTIONAL upstream transport.
|
// UpstreamEndpoint is the OPTIONAL upstream transport endpoint.
|
||||||
Upstream DNSTransport
|
UpstreamEndpoint string
|
||||||
|
|
||||||
// mockableReply allows to mock DNSProxy.reply in tests.
|
// mockableReply allows to mock DNSProxy.reply in tests.
|
||||||
mockableReply func(query *dns.Msg) (*dns.Msg, error)
|
mockableReply func(query *dns.Msg) (*dns.Msg, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DNSTransport is the type we expect from an upstream DNS transport.
|
|
||||||
type DNSTransport interface {
|
|
||||||
RoundTrip(ctx context.Context, query []byte) ([]byte, error)
|
|
||||||
CloseIdleConnections()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DNSListener is the interface returned by DNSProxy.Start
|
// DNSListener is the interface returned by DNSProxy.Start
|
||||||
type DNSListener interface {
|
type DNSListener interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
|
@ -204,23 +195,24 @@ func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg {
|
||||||
return reply
|
return reply
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// errDNSExpectedSingleQuestion means we expected to see a single question
|
||||||
|
errDNSExpectedSingleQuestion = errors.New("filtering: expected single DNS question")
|
||||||
|
|
||||||
|
// errDNSExpectedQueryNotResponse means we expected to see a query.
|
||||||
|
errDNSExpectedQueryNotResponse = errors.New("filtering: expected query not response")
|
||||||
|
)
|
||||||
|
|
||||||
func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) {
|
func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) {
|
||||||
queryBytes, err := query.Pack()
|
if query.Response {
|
||||||
if err != nil {
|
return nil, errDNSExpectedQueryNotResponse
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
txp := p.dnstransport()
|
if len(query.Question) != 1 {
|
||||||
defer txp.CloseIdleConnections()
|
return nil, errDNSExpectedSingleQuestion
|
||||||
ctx := context.Background()
|
|
||||||
replyBytes, err := txp.RoundTrip(ctx, queryBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
reply := &dns.Msg{}
|
clnt := &dns.Client{}
|
||||||
if err := reply.Unpack(replyBytes); err != nil {
|
resp, _, err := clnt.Exchange(query, p.upstreamEndpoint())
|
||||||
return nil, err
|
return resp, err
|
||||||
}
|
|
||||||
return reply, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
|
func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
|
||||||
|
@ -237,10 +229,9 @@ func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg {
|
||||||
return p.compose(query, ipAddrs...)
|
return p.compose(query, ipAddrs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *DNSProxy) dnstransport() DNSTransport {
|
func (p *DNSProxy) upstreamEndpoint() string {
|
||||||
if p.Upstream != nil {
|
if p.UpstreamEndpoint != "" {
|
||||||
return p.Upstream
|
return p.UpstreamEndpoint
|
||||||
}
|
}
|
||||||
const URL = "https://1.1.1.1/dns-query"
|
return "8.8.8.8:53"
|
||||||
return netxlite.NewDNSOverHTTPSTransport(http.DefaultClient, URL)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -283,9 +283,7 @@ func TestDNSProxy(t *testing.T) {
|
||||||
if len(p) < len(data) {
|
if len(p) < len(data) {
|
||||||
panic("buffer too small")
|
panic("buffer too small")
|
||||||
}
|
}
|
||||||
for i := 0; i < len(data); i++ {
|
copy(p, data)
|
||||||
p[i] = data[i]
|
|
||||||
}
|
|
||||||
return len(data), &net.UDPAddr{}, nil
|
return len(data), &net.UDPAddr{}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -314,9 +312,7 @@ func TestDNSProxy(t *testing.T) {
|
||||||
if len(p) < len(data) {
|
if len(p) < len(data) {
|
||||||
panic("buffer too small")
|
panic("buffer too small")
|
||||||
}
|
}
|
||||||
for i := 0; i < len(data); i++ {
|
copy(p, data)
|
||||||
p[i] = data[i]
|
|
||||||
}
|
|
||||||
return len(data), &net.UDPAddr{}, nil
|
return len(data), &net.UDPAddr{}, nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -328,12 +324,24 @@ func TestDNSProxy(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("proxy", func(t *testing.T) {
|
t.Run("proxy", func(t *testing.T) {
|
||||||
t.Run("Pack fails", func(t *testing.T) {
|
t.Run("with response", func(t *testing.T) {
|
||||||
p := &DNSProxy{}
|
p := &DNSProxy{}
|
||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.Rcode = -1 // causes Pack to fail
|
query.Response = true
|
||||||
reply, err := p.proxy(query)
|
reply, err := p.proxy(query)
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") {
|
if !errors.Is(err, errDNSExpectedQueryNotResponse) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if reply != nil {
|
||||||
|
t.Fatal("expected nil reply")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with no questions", func(t *testing.T) {
|
||||||
|
p := &DNSProxy{}
|
||||||
|
query := &dns.Msg{}
|
||||||
|
reply, err := p.proxy(query)
|
||||||
|
if !errors.Is(err, errDNSExpectedSingleQuestion) {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if reply != nil {
|
if reply != nil {
|
||||||
|
@ -342,35 +350,13 @@ func TestDNSProxy(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("round trip fails", func(t *testing.T) {
|
t.Run("round trip fails", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
|
||||||
p := &DNSProxy{
|
p := &DNSProxy{
|
||||||
Upstream: &mocks.DNSTransport{
|
UpstreamEndpoint: "antani",
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
MockCloseIdleConnections: func() {},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
reply, err := p.proxy(&dns.Msg{})
|
query := &dns.Msg{}
|
||||||
if !errors.Is(err, expected) {
|
query.Question = append(query.Question, dns.Question{})
|
||||||
t.Fatal("unexpected err", err)
|
reply, err := p.proxy(query)
|
||||||
}
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
if reply != nil {
|
|
||||||
t.Fatal("expected nil reply here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Unpack fails", func(t *testing.T) {
|
|
||||||
p := &DNSProxy{
|
|
||||||
Upstream: &mocks.DNSTransport{
|
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
|
||||||
return make([]byte, 1), nil
|
|
||||||
},
|
|
||||||
MockCloseIdleConnections: func() {},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
reply, err := p.proxy(&dns.Msg{})
|
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") {
|
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if reply != nil {
|
if reply != nil {
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,15 +18,6 @@ import (
|
||||||
// You should probably use NewUnwrappedParallelResolver to
|
// You should probably use NewUnwrappedParallelResolver to
|
||||||
// create a new instance of this type.
|
// create a new instance of this type.
|
||||||
type ParallelResolver struct {
|
type ParallelResolver struct {
|
||||||
// Encoder is the MANDATORY encoder to use.
|
|
||||||
Encoder model.DNSEncoder
|
|
||||||
|
|
||||||
// Decoder is the MANDATORY decoder to use.
|
|
||||||
Decoder model.DNSDecoder
|
|
||||||
|
|
||||||
// NumTimeouts is MANDATORY and counts the number of timeouts.
|
|
||||||
NumTimeouts *atomicx.Int64
|
|
||||||
|
|
||||||
// Txp is the MANDATORY underlying DNS transport.
|
// Txp is the MANDATORY underlying DNS transport.
|
||||||
Txp model.DNSTransport
|
Txp model.DNSTransport
|
||||||
}
|
}
|
||||||
|
@ -36,10 +26,7 @@ type ParallelResolver struct {
|
||||||
// not wrapped and you should wrap if before using it.
|
// not wrapped and you should wrap if before using it.
|
||||||
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
|
func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver {
|
||||||
return &ParallelResolver{
|
return &ParallelResolver{
|
||||||
Encoder: &DNSEncoderMiekg{},
|
Txp: t,
|
||||||
Decoder: &DNSDecoderMiekg{},
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: t,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,22 +67,22 @@ func (r *ParallelResolver) LookupHost(ctx context.Context, hostname string) ([]s
|
||||||
var addrs []string
|
var addrs []string
|
||||||
addrs = append(addrs, ares.addrs...)
|
addrs = append(addrs, ares.addrs...)
|
||||||
addrs = append(addrs, aaaares.addrs...)
|
addrs = append(addrs, aaaares.addrs...)
|
||||||
|
if len(addrs) < 1 {
|
||||||
|
return nil, ErrOODNSNoAnswer
|
||||||
|
}
|
||||||
return addrs, nil
|
return addrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupHTTPS implements Resolver.LookupHTTPS.
|
// LookupHTTPS implements Resolver.LookupHTTPS.
|
||||||
func (r *ParallelResolver) LookupHTTPS(
|
func (r *ParallelResolver) LookupHTTPS(
|
||||||
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
|
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
|
||||||
querydata, queryID, err := r.Encoder.Encode(
|
encoder := &DNSEncoderMiekg{}
|
||||||
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
||||||
|
response, err := r.Txp.RoundTrip(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
return response.DecodeHTTPS()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return r.Decoder.DecodeHTTPS(replydata, queryID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// parallelResolverResult is the internal representation of a
|
// parallelResolverResult is the internal representation of a
|
||||||
|
@ -108,7 +95,9 @@ type parallelResolverResult struct {
|
||||||
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
|
// lookupHost issues a lookup host query for the specified qtype (e.g., dns.A).
|
||||||
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
||||||
qtype uint16, out chan<- *parallelResolverResult) {
|
qtype uint16, out chan<- *parallelResolverResult) {
|
||||||
querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
encoder := &DNSEncoderMiekg{}
|
||||||
|
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||||
|
response, err := r.Txp.RoundTrip(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
out <- ¶llelResolverResult{
|
out <- ¶llelResolverResult{
|
||||||
addrs: []string{},
|
addrs: []string{},
|
||||||
|
@ -116,15 +105,7 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
addrs, err := response.DecodeLookupHost()
|
||||||
if err != nil {
|
|
||||||
out <- ¶llelResolverResult{
|
|
||||||
addrs: []string{},
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
addrs, err := r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
|
|
||||||
out <- ¶llelResolverResult{
|
out <- ¶llelResolverResult{
|
||||||
addrs: addrs,
|
addrs: addrs,
|
||||||
err: err,
|
err: err,
|
||||||
|
@ -134,14 +115,11 @@ func (r *ParallelResolver) lookupHost(ctx context.Context, hostname string,
|
||||||
// LookupNS implements Resolver.LookupNS.
|
// LookupNS implements Resolver.LookupNS.
|
||||||
func (r *ParallelResolver) LookupNS(
|
func (r *ParallelResolver) LookupNS(
|
||||||
ctx context.Context, hostname string) ([]*net.NS, error) {
|
ctx context.Context, hostname string) ([]*net.NS, error) {
|
||||||
querydata, queryID, err := r.Encoder.Encode(
|
encoder := &DNSEncoderMiekg{}
|
||||||
hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
||||||
|
response, err := r.Txp.RoundTrip(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
return response.DecodeNS()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return r.Decoder.DecodeNS(replydata, queryID)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,14 +8,13 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParallelResolver(t *testing.T) {
|
func TestParallelResolver(t *testing.T) {
|
||||||
t.Run("transport okay", func(t *testing.T) {
|
t.Run("transport okay", func(t *testing.T) {
|
||||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||||
r := NewUnwrappedParallelResolver(txp)
|
r := NewUnwrappedParallelResolver(txp)
|
||||||
rtx := r.Transport()
|
rtx := r.Transport()
|
||||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||||
|
@ -30,30 +29,10 @@ func TestParallelResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("LookupHost", func(t *testing.T) {
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
t.Run("Encode error", func(t *testing.T) {
|
|
||||||
mocked := errors.New("mocked error")
|
|
||||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
|
||||||
r := ParallelResolver{
|
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return nil, 0, mocked
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Txp: txp,
|
|
||||||
}
|
|
||||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
|
||||||
if !errors.Is(err, mocked) {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if addrs != nil {
|
|
||||||
t.Fatal("expected nil address here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("RoundTrip error", func(t *testing.T) {
|
t.Run("RoundTrip error", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
|
@ -72,8 +51,13 @@ func TestParallelResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("empty reply", func(t *testing.T) {
|
t.Run("empty reply", func(t *testing.T) {
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return dnsGenLookupHostReplySuccess(query), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeLookupHost: func() ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return true
|
return true
|
||||||
|
@ -91,8 +75,16 @@ func TestParallelResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("with A reply", func(t *testing.T) {
|
t.Run("with A reply", func(t *testing.T) {
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeLookupHost: func() ([]string, error) {
|
||||||
|
if query.Type() != dns.TypeA {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return []string{"8.8.8.8"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return true
|
return true
|
||||||
|
@ -110,8 +102,16 @@ func TestParallelResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("with AAAA reply", func(t *testing.T) {
|
t.Run("with AAAA reply", func(t *testing.T) {
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return dnsGenLookupHostReplySuccess(query, "::1"), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeLookupHost: func() ([]string, error) {
|
||||||
|
if query.Type() != dns.TypeAAAA {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return []string{"::1"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return true
|
return true
|
||||||
|
@ -131,22 +131,15 @@ func TestParallelResolver(t *testing.T) {
|
||||||
afailure := errors.New("a failure")
|
afailure := errors.New("a failure")
|
||||||
aaaafailure := errors.New("aaaa failure")
|
aaaafailure := errors.New("aaaa failure")
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
msg := &dns.Msg{}
|
switch query.Type() {
|
||||||
if err := msg.Unpack(query); err != nil {
|
case dns.TypeA:
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(msg.Question) != 1 {
|
|
||||||
return nil, errors.New("expected just one question")
|
|
||||||
}
|
|
||||||
q := msg.Question[0]
|
|
||||||
if q.Qtype == dns.TypeA {
|
|
||||||
return nil, afailure
|
return nil, afailure
|
||||||
}
|
case dns.TypeAAAA:
|
||||||
if q.Qtype == dns.TypeAAAA {
|
|
||||||
return nil, aaaafailure
|
return nil, aaaafailure
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unexpected query")
|
||||||
}
|
}
|
||||||
return nil, errors.New("expected A or AAAA query")
|
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return true
|
return true
|
||||||
|
@ -179,44 +172,11 @@ func TestParallelResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("LookupHTTPS", func(t *testing.T) {
|
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||||
t.Run("for encoding error", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
r := &ParallelResolver{
|
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return nil, 0, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: nil,
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: &mocks.DNSTransport{
|
|
||||||
MockRequiresPadding: func() bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
https, err := r.LookupHTTPS(ctx, "example.com")
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if https != nil {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("for round-trip error", func(t *testing.T) {
|
t.Run("for round-trip error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &ParallelResolver{
|
r := &ParallelResolver{
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return make([]byte, 64), 0, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: nil,
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: &mocks.DNSTransport{
|
Txp: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
|
@ -234,23 +194,17 @@ func TestParallelResolver(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("for decode error", func(t *testing.T) {
|
t.Run("for DecodeHTTPS error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &ParallelResolver{
|
r := &ParallelResolver{
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return make([]byte, 64), 0, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: &mocks.DNSDecoder{
|
|
||||||
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: &mocks.DNSTransport{
|
Txp: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return make([]byte, 128), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return false
|
return false
|
||||||
|
@ -269,44 +223,11 @@ func TestParallelResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("LookupNS", func(t *testing.T) {
|
t.Run("LookupNS", func(t *testing.T) {
|
||||||
t.Run("for encoding error", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
r := &ParallelResolver{
|
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return nil, 0, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: nil,
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: &mocks.DNSTransport{
|
|
||||||
MockRequiresPadding: func() bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
ns, err := r.LookupNS(ctx, "example.com")
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if ns != nil {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("for round-trip error", func(t *testing.T) {
|
t.Run("for round-trip error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &ParallelResolver{
|
r := &ParallelResolver{
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return make([]byte, 64), 0, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: nil,
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: &mocks.DNSTransport{
|
Txp: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
|
@ -327,20 +248,14 @@ func TestParallelResolver(t *testing.T) {
|
||||||
t.Run("for decode error", func(t *testing.T) {
|
t.Run("for decode error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &ParallelResolver{
|
r := &ParallelResolver{
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return make([]byte, 64), 0, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: &mocks.DNSDecoder{
|
|
||||||
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: &mocks.DNSTransport{
|
Txp: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return make([]byte, 128), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeNS: func() ([]*net.NS, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -26,12 +26,6 @@ import (
|
||||||
// QUIRK: unlike the ParallelResolver, this resolver's LookupHost retries
|
// QUIRK: unlike the ParallelResolver, this resolver's LookupHost retries
|
||||||
// each query three times for soft errors.
|
// each query three times for soft errors.
|
||||||
type SerialResolver struct {
|
type SerialResolver struct {
|
||||||
// Encoder is the MANDATORY encoder to use.
|
|
||||||
Encoder model.DNSEncoder
|
|
||||||
|
|
||||||
// Decoder is the MANDATORY decoder to use.
|
|
||||||
Decoder model.DNSDecoder
|
|
||||||
|
|
||||||
// NumTimeouts is MANDATORY and counts the number of timeouts.
|
// NumTimeouts is MANDATORY and counts the number of timeouts.
|
||||||
NumTimeouts *atomicx.Int64
|
NumTimeouts *atomicx.Int64
|
||||||
|
|
||||||
|
@ -42,8 +36,6 @@ type SerialResolver struct {
|
||||||
// NewSerialResolver creates a new SerialResolver instance.
|
// NewSerialResolver creates a new SerialResolver instance.
|
||||||
func NewSerialResolver(t model.DNSTransport) *SerialResolver {
|
func NewSerialResolver(t model.DNSTransport) *SerialResolver {
|
||||||
return &SerialResolver{
|
return &SerialResolver{
|
||||||
Encoder: &DNSEncoderMiekg{},
|
|
||||||
Decoder: &DNSDecoderMiekg{},
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
NumTimeouts: &atomicx.Int64{},
|
||||||
Txp: t,
|
Txp: t,
|
||||||
}
|
}
|
||||||
|
@ -82,22 +74,22 @@ func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]str
|
||||||
}
|
}
|
||||||
addrs = append(addrs, addrsA...)
|
addrs = append(addrs, addrsA...)
|
||||||
addrs = append(addrs, addrsAAAA...)
|
addrs = append(addrs, addrsAAAA...)
|
||||||
|
if len(addrs) < 1 {
|
||||||
|
return nil, ErrOODNSNoAnswer
|
||||||
|
}
|
||||||
return addrs, nil
|
return addrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupHTTPS implements Resolver.LookupHTTPS.
|
// LookupHTTPS implements Resolver.LookupHTTPS.
|
||||||
func (r *SerialResolver) LookupHTTPS(
|
func (r *SerialResolver) LookupHTTPS(
|
||||||
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
|
ctx context.Context, hostname string) (*model.HTTPSSvc, error) {
|
||||||
querydata, queryID, err := r.Encoder.Encode(
|
encoder := &DNSEncoderMiekg{}
|
||||||
hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
query := encoder.Encode(hostname, dns.TypeHTTPS, r.Txp.RequiresPadding())
|
||||||
|
response, err := r.Txp.RoundTrip(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
return response.DecodeHTTPS()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return r.Decoder.DecodeHTTPS(replydata, queryID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SerialResolver) lookupHostWithRetry(
|
func (r *SerialResolver) lookupHostWithRetry(
|
||||||
|
@ -132,28 +124,23 @@ func (r *SerialResolver) lookupHostWithRetry(
|
||||||
// qtype (dns.A or dns.AAAA) without retrying on failure.
|
// qtype (dns.A or dns.AAAA) without retrying on failure.
|
||||||
func (r *SerialResolver) lookupHostWithoutRetry(
|
func (r *SerialResolver) lookupHostWithoutRetry(
|
||||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||||
querydata, queryID, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
encoder := &DNSEncoderMiekg{}
|
||||||
|
query := encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||||
|
response, err := r.Txp.RoundTrip(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
return response.DecodeLookupHost()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return r.Decoder.DecodeLookupHost(qtype, replydata, queryID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupNS implements Resolver.LookupNS.
|
// LookupNS implements Resolver.LookupNS.
|
||||||
func (r *SerialResolver) LookupNS(
|
func (r *SerialResolver) LookupNS(
|
||||||
ctx context.Context, hostname string) ([]*net.NS, error) {
|
ctx context.Context, hostname string) ([]*net.NS, error) {
|
||||||
querydata, queryID, err := r.Encoder.Encode(
|
encoder := &DNSEncoderMiekg{}
|
||||||
hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
query := encoder.Encode(hostname, dns.TypeNS, r.Txp.RequiresPadding())
|
||||||
|
response, err := r.Txp.RoundTrip(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
replydata, err := r.Txp.RoundTrip(ctx, querydata)
|
return response.DecodeNS()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return r.Decoder.DecodeNS(replydata, queryID)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
"github.com/ooni/probe-cli/v3/internal/atomicx"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model"
|
"github.com/ooni/probe-cli/v3/internal/model"
|
||||||
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
"github.com/ooni/probe-cli/v3/internal/model/mocks"
|
||||||
|
@ -30,7 +31,7 @@ func (err *errorWithTimeout) Unwrap() error {
|
||||||
|
|
||||||
func TestSerialResolver(t *testing.T) {
|
func TestSerialResolver(t *testing.T) {
|
||||||
t.Run("transport okay", func(t *testing.T) {
|
t.Run("transport okay", func(t *testing.T) {
|
||||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
txp := NewDNSOverTLSTransport((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
||||||
r := NewSerialResolver(txp)
|
r := NewSerialResolver(txp)
|
||||||
rtx := r.Transport()
|
rtx := r.Transport()
|
||||||
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" {
|
||||||
|
@ -45,30 +46,10 @@ func TestSerialResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("LookupHost", func(t *testing.T) {
|
t.Run("LookupHost", func(t *testing.T) {
|
||||||
t.Run("Encode error", func(t *testing.T) {
|
|
||||||
mocked := errors.New("mocked error")
|
|
||||||
txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853")
|
|
||||||
r := SerialResolver{
|
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return nil, 0, mocked
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Txp: txp,
|
|
||||||
}
|
|
||||||
addrs, err := r.LookupHost(context.Background(), "www.gogle.com")
|
|
||||||
if !errors.Is(err, mocked) {
|
|
||||||
t.Fatal("not the error we expected")
|
|
||||||
}
|
|
||||||
if addrs != nil {
|
|
||||||
t.Fatal("expected nil address here")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("RoundTrip error", func(t *testing.T) {
|
t.Run("RoundTrip error", func(t *testing.T) {
|
||||||
mocked := errors.New("mocked error")
|
mocked := errors.New("mocked error")
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, mocked
|
return nil, mocked
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
|
@ -87,8 +68,13 @@ func TestSerialResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("empty reply", func(t *testing.T) {
|
t.Run("empty reply", func(t *testing.T) {
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return dnsGenLookupHostReplySuccess(query), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeLookupHost: func() ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return true
|
return true
|
||||||
|
@ -106,8 +92,16 @@ func TestSerialResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("with A reply", func(t *testing.T) {
|
t.Run("with A reply", func(t *testing.T) {
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return dnsGenLookupHostReplySuccess(query, "8.8.8.8"), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeLookupHost: func() ([]string, error) {
|
||||||
|
if query.Type() != dns.TypeA {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return []string{"8.8.8.8"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return true
|
return true
|
||||||
|
@ -125,8 +119,16 @@ func TestSerialResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("with AAAA reply", func(t *testing.T) {
|
t.Run("with AAAA reply", func(t *testing.T) {
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return dnsGenLookupHostReplySuccess(query, "::1"), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeLookupHost: func() ([]string, error) {
|
||||||
|
if query.Type() != dns.TypeAAAA {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return []string{"::1"}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return true
|
return true
|
||||||
|
@ -144,11 +146,12 @@ func TestSerialResolver(t *testing.T) {
|
||||||
|
|
||||||
t.Run("with timeout", func(t *testing.T) {
|
t.Run("with timeout", func(t *testing.T) {
|
||||||
txp := &mocks.DNSTransport{
|
txp := &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, &net.OpError{
|
err := &net.OpError{
|
||||||
Err: &errorWithTimeout{ETIMEDOUT},
|
Err: &errorWithTimeout{ETIMEDOUT},
|
||||||
Op: "dial",
|
Op: "dial",
|
||||||
}
|
}
|
||||||
|
return nil, err
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return true
|
return true
|
||||||
|
@ -184,44 +187,12 @@ func TestSerialResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("LookupHTTPS", func(t *testing.T) {
|
t.Run("LookupHTTPS", func(t *testing.T) {
|
||||||
t.Run("for encoding error", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
r := &SerialResolver{
|
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return nil, 0, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: nil,
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: &mocks.DNSTransport{
|
|
||||||
MockRequiresPadding: func() bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
https, err := r.LookupHTTPS(ctx, "example.com")
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if https != nil {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("for round-trip error", func(t *testing.T) {
|
t.Run("for round-trip error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &SerialResolver{
|
r := &SerialResolver{
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return make([]byte, 64), 0, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: nil,
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
NumTimeouts: &atomicx.Int64{},
|
||||||
Txp: &mocks.DNSTransport{
|
Txp: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
|
@ -242,20 +213,15 @@ func TestSerialResolver(t *testing.T) {
|
||||||
t.Run("for decode error", func(t *testing.T) {
|
t.Run("for decode error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &SerialResolver{
|
r := &SerialResolver{
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return make([]byte, 64), 0, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: &mocks.DNSDecoder{
|
|
||||||
MockDecodeHTTPS: func(reply []byte, queryID uint16) (*model.HTTPSSvc, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
NumTimeouts: &atomicx.Int64{},
|
||||||
Txp: &mocks.DNSTransport{
|
Txp: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return make([]byte, 128), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeHTTPS: func() (*model.HTTPSSvc, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return false
|
return false
|
||||||
|
@ -274,44 +240,12 @@ func TestSerialResolver(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("LookupNS", func(t *testing.T) {
|
t.Run("LookupNS", func(t *testing.T) {
|
||||||
t.Run("for encoding error", func(t *testing.T) {
|
|
||||||
expected := errors.New("mocked error")
|
|
||||||
r := &SerialResolver{
|
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return nil, 0, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: nil,
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
|
||||||
Txp: &mocks.DNSTransport{
|
|
||||||
MockRequiresPadding: func() bool {
|
|
||||||
return false
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
ctx := context.Background()
|
|
||||||
ns, err := r.LookupNS(ctx, "example.com")
|
|
||||||
if !errors.Is(err, expected) {
|
|
||||||
t.Fatal("unexpected err", err)
|
|
||||||
}
|
|
||||||
if ns != nil {
|
|
||||||
t.Fatal("unexpected result")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("for round-trip error", func(t *testing.T) {
|
t.Run("for round-trip error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &SerialResolver{
|
r := &SerialResolver{
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return make([]byte, 64), 0, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: nil,
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
NumTimeouts: &atomicx.Int64{},
|
||||||
Txp: &mocks.DNSTransport{
|
Txp: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return nil, expected
|
return nil, expected
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
|
@ -332,20 +266,15 @@ func TestSerialResolver(t *testing.T) {
|
||||||
t.Run("for decode error", func(t *testing.T) {
|
t.Run("for decode error", func(t *testing.T) {
|
||||||
expected := errors.New("mocked error")
|
expected := errors.New("mocked error")
|
||||||
r := &SerialResolver{
|
r := &SerialResolver{
|
||||||
Encoder: &mocks.DNSEncoder{
|
|
||||||
MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, uint16, error) {
|
|
||||||
return make([]byte, 64), 0, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Decoder: &mocks.DNSDecoder{
|
|
||||||
MockDecodeNS: func(reply []byte, queryID uint16) ([]*net.NS, error) {
|
|
||||||
return nil, expected
|
|
||||||
},
|
|
||||||
},
|
|
||||||
NumTimeouts: &atomicx.Int64{},
|
NumTimeouts: &atomicx.Int64{},
|
||||||
Txp: &mocks.DNSTransport{
|
Txp: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
||||||
return make([]byte, 128), nil
|
response := &mocks.DNSResponse{
|
||||||
|
MockDecodeNS: func() ([]*net.NS, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
},
|
},
|
||||||
MockRequiresPadding: func() bool {
|
MockRequiresPadding: func() bool {
|
||||||
return false
|
return false
|
||||||
|
|
Loading…
Reference in New Issue
Block a user