From 1eb9e8c9b06421c2d9ca5e70ced7844491606c40 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Thu, 9 Sep 2021 20:49:12 +0200 Subject: [PATCH] refactor(netx/resolver): add CloseIdleConnections to RoundTripper (#501) While there, also change to pointer receiver and use internal testing for what are clearly unit tests. Part of https://github.com/ooni/probe/issues/1591. --- .../experiment/urlgetter/configurer_test.go | 8 +- internal/engine/netx/netx_test.go | 20 ++--- internal/engine/netx/resolver/dnsoverhttps.go | 32 ++++--- .../engine/netx/resolver/dnsoverhttps_test.go | 84 +++++++++++-------- internal/engine/netx/resolver/dnsovertcp.go | 23 +++-- .../engine/netx/resolver/dnsovertcp_test.go | 34 ++++---- internal/engine/netx/resolver/dnsoverudp.go | 19 +++-- .../engine/netx/resolver/dnsoverudp_test.go | 32 ++++--- internal/engine/netx/resolver/emitter_test.go | 9 +- internal/engine/netx/resolver/fake_test.go | 4 + internal/engine/netx/resolver/serial.go | 5 +- internal/netxlite/mocks/http.go | 17 ++++ internal/netxlite/mocks/http_test.go | 31 +++++++ 13 files changed, 203 insertions(+), 115 deletions(-) diff --git a/internal/engine/experiment/urlgetter/configurer_test.go b/internal/engine/experiment/urlgetter/configurer_test.go index 08f27c3..2d96a17 100644 --- a/internal/engine/experiment/urlgetter/configurer_test.go +++ b/internal/engine/experiment/urlgetter/configurer_test.go @@ -124,7 +124,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) { if !ok { t.Fatal("not the DNS transport we expected") } - dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS) + dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS) if !ok { t.Fatal("not the DNS transport we expected") } @@ -200,7 +200,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) { if !ok { t.Fatal("not the DNS transport we expected") } - dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS) + dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS) if !ok { t.Fatal("not the DNS transport we expected") } @@ -276,7 +276,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T) if !ok { t.Fatal("not the DNS transport we expected") } - dohtxp, ok := stxp.RoundTripper.(resolver.DNSOverHTTPS) + dohtxp, ok := stxp.RoundTripper.(*resolver.DNSOverHTTPS) if !ok { t.Fatal("not the DNS transport we expected") } @@ -352,7 +352,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) { if !ok { t.Fatal("not the DNS transport we expected") } - udptxp, ok := stxp.RoundTripper.(resolver.DNSOverUDP) + udptxp, ok := stxp.RoundTripper.(*resolver.DNSOverUDP) if !ok { t.Fatal("not the DNS transport we expected") } diff --git a/internal/engine/netx/netx_test.go b/internal/engine/netx/netx_test.go index 228a63d..a7a36e1 100644 --- a/internal/engine/netx/netx_test.go +++ b/internal/engine/netx/netx_test.go @@ -680,7 +680,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok { + if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() @@ -696,7 +696,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok { + if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() @@ -712,7 +712,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := r.Transport().(resolver.DNSOverHTTPS); !ok { + if _, ok := r.Transport().(*resolver.DNSOverHTTPS); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() @@ -733,7 +733,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) { if !ok { t.Fatal("not the transport we expected") } - if _, ok := txp.RoundTripper.(resolver.DNSOverHTTPS); !ok { + if _, ok := txp.RoundTripper.(*resolver.DNSOverHTTPS); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() @@ -749,7 +749,7 @@ func TestNewDNSClientUDP(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - if _, ok := r.Transport().(resolver.DNSOverUDP); !ok { + if _, ok := r.Transport().(*resolver.DNSOverUDP); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() @@ -770,7 +770,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) { if !ok { t.Fatal("not the transport we expected") } - if _, ok := txp.RoundTripper.(resolver.DNSOverUDP); !ok { + if _, ok := txp.RoundTripper.(*resolver.DNSOverUDP); !ok { t.Fatal("not the transport we expected") } dnsclient.CloseIdleConnections() @@ -786,7 +786,7 @@ func TestNewDNSClientTCP(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(resolver.DNSOverTCP) + txp, ok := r.Transport().(*resolver.DNSOverTCP) if !ok { t.Fatal("not the transport we expected") } @@ -811,7 +811,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) { if !ok { t.Fatal("not the transport we expected") } - dotcp, ok := txp.RoundTripper.(resolver.DNSOverTCP) + dotcp, ok := txp.RoundTripper.(*resolver.DNSOverTCP) if !ok { t.Fatal("not the transport we expected") } @@ -831,7 +831,7 @@ func TestNewDNSClientDoT(t *testing.T) { if !ok { t.Fatal("not the resolver we expected") } - txp, ok := r.Transport().(resolver.DNSOverTCP) + txp, ok := r.Transport().(*resolver.DNSOverTCP) if !ok { t.Fatal("not the transport we expected") } @@ -856,7 +856,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) { if !ok { t.Fatal("not the transport we expected") } - dotls, ok := txp.RoundTripper.(resolver.DNSOverTCP) + dotls, ok := txp.RoundTripper.(*resolver.DNSOverTCP) if !ok { t.Fatal("not the transport we expected") } diff --git a/internal/engine/netx/resolver/dnsoverhttps.go b/internal/engine/netx/resolver/dnsoverhttps.go index 4dc25a3..d98fc11 100644 --- a/internal/engine/netx/resolver/dnsoverhttps.go +++ b/internal/engine/netx/resolver/dnsoverhttps.go @@ -11,28 +11,35 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/iox" ) +// HTTPClient is the HTTP client expected by DNSOverHTTPS. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) + CloseIdleConnections() +} + // DNSOverHTTPS is a DNS over HTTPS RoundTripper. Requests are submitted over // an HTTP/HTTPS channel provided by URL using the Do function. type DNSOverHTTPS struct { - Do func(req *http.Request) (*http.Response, error) + Client HTTPClient URL string HostOverride string } // NewDNSOverHTTPS creates a new DNSOverHTTP instance from the // specified http.Client and URL, as a convenience. -func NewDNSOverHTTPS(client *http.Client, URL string) DNSOverHTTPS { +func NewDNSOverHTTPS(client *http.Client, URL string) *DNSOverHTTPS { return NewDNSOverHTTPSWithHostOverride(client, URL, "") } // NewDNSOverHTTPSWithHostOverride is like NewDNSOverHTTPS except that // it's creating a resolver where we use the specified host. -func NewDNSOverHTTPSWithHostOverride(client *http.Client, URL, hostOverride string) DNSOverHTTPS { - return DNSOverHTTPS{Do: client.Do, URL: URL, HostOverride: hostOverride} +func NewDNSOverHTTPSWithHostOverride( + client *http.Client, URL, hostOverride string) *DNSOverHTTPS { + return &DNSOverHTTPS{Client: client, URL: URL, HostOverride: hostOverride} } // RoundTrip implements RoundTripper.RoundTrip. -func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { +func (t *DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { ctx, cancel := context.WithTimeout(ctx, 45*time.Second) defer cancel() req, err := http.NewRequest("POST", t.URL, bytes.NewReader(query)) @@ -43,7 +50,7 @@ func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, erro req.Header.Set("user-agent", httpheader.UserAgent()) req.Header.Set("content-type", "application/dns-message") var resp *http.Response - resp, err = t.Do(req.WithContext(ctx)) + resp, err = t.Client.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -60,18 +67,23 @@ func (t DNSOverHTTPS) RoundTrip(ctx context.Context, query []byte) ([]byte, erro } // RequiresPadding returns true for DoH according to RFC8467 -func (t DNSOverHTTPS) RequiresPadding() bool { +func (t *DNSOverHTTPS) RequiresPadding() bool { return true } // Network returns the transport network (e.g., doh, dot) -func (t DNSOverHTTPS) Network() string { +func (t *DNSOverHTTPS) Network() string { return "doh" } // Address returns the upstream server address. -func (t DNSOverHTTPS) Address() string { +func (t *DNSOverHTTPS) Address() string { return t.URL } -var _ RoundTripper = DNSOverHTTPS{} +// CloseIdleConnections closes idle connections. +func (t *DNSOverHTTPS) CloseIdleConnections() { + t.Client.CloseIdleConnections() +} + +var _ RoundTripper = &DNSOverHTTPS{} diff --git a/internal/engine/netx/resolver/dnsoverhttps_test.go b/internal/engine/netx/resolver/dnsoverhttps_test.go index dedc7f0..c6e5350 100644 --- a/internal/engine/netx/resolver/dnsoverhttps_test.go +++ b/internal/engine/netx/resolver/dnsoverhttps_test.go @@ -1,4 +1,4 @@ -package resolver_test +package resolver import ( "bytes" @@ -10,12 +10,12 @@ import ( "testing" "github.com/ooni/probe-cli/v3/internal/engine/httpheader" - "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" + "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) func TestDNSOverHTTPSNewRequestFailure(t *testing.T) { const invalidURL = "\t" - txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL) + txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL) data, err := txp.RoundTrip(context.Background(), nil) if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { t.Fatal("expected an error here") @@ -27,9 +27,11 @@ func TestDNSOverHTTPSNewRequestFailure(t *testing.T) { func TestDNSOverHTTPSClientDoFailure(t *testing.T) { expected := errors.New("mocked error") - txp := resolver.DNSOverHTTPS{ - Do: func(*http.Request) (*http.Response, error) { - return nil, expected + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return nil, expected + }, }, URL: "https://cloudflare-dns.com/dns-query", } @@ -43,12 +45,14 @@ func TestDNSOverHTTPSClientDoFailure(t *testing.T) { } func TestDNSOverHTTPSHTTPFailure(t *testing.T) { - txp := resolver.DNSOverHTTPS{ - Do: func(*http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 500, - Body: io.NopCloser(strings.NewReader("")), - }, nil + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 500, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }, }, URL: "https://cloudflare-dns.com/dns-query", } @@ -62,12 +66,14 @@ func TestDNSOverHTTPSHTTPFailure(t *testing.T) { } func TestDNSOverHTTPSMissingContentType(t *testing.T) { - txp := resolver.DNSOverHTTPS{ - Do: func(*http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader("")), - }, nil + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }, }, URL: "https://cloudflare-dns.com/dns-query", } @@ -82,15 +88,17 @@ func TestDNSOverHTTPSMissingContentType(t *testing.T) { func TestDNSOverHTTPSSuccess(t *testing.T) { body := []byte("AAA") - txp := resolver.DNSOverHTTPS{ - Do: func(*http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewReader(body)), - Header: http.Header{ - "Content-Type": []string{"application/dns-message"}, - }, - }, nil + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(body)), + Header: http.Header{ + "Content-Type": []string{"application/dns-message"}, + }, + }, nil + }, }, URL: "https://cloudflare-dns.com/dns-query", } @@ -105,7 +113,7 @@ func TestDNSOverHTTPSSuccess(t *testing.T) { func TestDNSOverHTTPTransportOK(t *testing.T) { const queryURL = "https://cloudflare-dns.com/dns-query" - txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL) + txp := NewDNSOverHTTPS(http.DefaultClient, queryURL) if txp.Network() != "doh" { t.Fatal("invalid network") } @@ -120,10 +128,12 @@ func TestDNSOverHTTPTransportOK(t *testing.T) { func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) { expected := errors.New("mocked error") var correct bool - txp := resolver.DNSOverHTTPS{ - Do: func(req *http.Request) (*http.Response, error) { - correct = req.Header.Get("User-Agent") == httpheader.UserAgent() - return nil, expected + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + correct = req.Header.Get("User-Agent") == httpheader.UserAgent() + return nil, expected + }, }, URL: "https://cloudflare-dns.com/dns-query", } @@ -144,10 +154,12 @@ func TestDNSOverHTTPSHostOverride(t *testing.T) { expected := errors.New("mocked error") hostOverride := "test.com" - txp := resolver.DNSOverHTTPS{ - Do: func(req *http.Request) (*http.Response, error) { - correct = req.Host == hostOverride - return nil, expected + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + correct = req.Host == hostOverride + return nil, expected + }, }, URL: "https://cloudflare-dns.com/dns-query", HostOverride: hostOverride, diff --git a/internal/engine/netx/resolver/dnsovertcp.go b/internal/engine/netx/resolver/dnsovertcp.go index 0c90f36..1c7037c 100644 --- a/internal/engine/netx/resolver/dnsovertcp.go +++ b/internal/engine/netx/resolver/dnsovertcp.go @@ -26,8 +26,8 @@ type DNSOverTCP struct { } // NewDNSOverTCP creates a new DNSOverTCP transport. -func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP { - return DNSOverTCP{ +func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP { + return &DNSOverTCP{ dial: dial, address: address, network: "tcp", @@ -36,8 +36,8 @@ func NewDNSOverTCP(dial DialContextFunc, address string) DNSOverTCP { } // NewDNSOverTLS creates a new DNSOverTLS transport. -func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP { - return DNSOverTCP{ +func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP { + return &DNSOverTCP{ dial: dial, address: address, network: "dot", @@ -46,7 +46,7 @@ func NewDNSOverTLS(dial DialContextFunc, address string) DNSOverTCP { } // RoundTrip implements RoundTripper.RoundTrip. -func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { +func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { if len(query) > math.MaxUint16 { return nil, errors.New("query too long") } @@ -80,18 +80,23 @@ func (t DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) // RequiresPadding returns true for DoT and false for TCP // according to RFC8467. -func (t DNSOverTCP) RequiresPadding() bool { +func (t *DNSOverTCP) RequiresPadding() bool { return t.requiresPadding } // Network returns the transport network (e.g., doh, dot) -func (t DNSOverTCP) Network() string { +func (t *DNSOverTCP) Network() string { return t.network } // Address returns the upstream server address. -func (t DNSOverTCP) Address() string { +func (t *DNSOverTCP) Address() string { return t.address } -var _ RoundTripper = DNSOverTCP{} +// CloseIdleConnections closes idle connections. +func (t *DNSOverTCP) CloseIdleConnections() { + // nothing to do +} + +var _ RoundTripper = &DNSOverTCP{} diff --git a/internal/engine/netx/resolver/dnsovertcp_test.go b/internal/engine/netx/resolver/dnsovertcp_test.go index c3d035f..3295a7e 100644 --- a/internal/engine/netx/resolver/dnsovertcp_test.go +++ b/internal/engine/netx/resolver/dnsovertcp_test.go @@ -1,17 +1,15 @@ -package resolver_test +package resolver import ( "context" "errors" "net" "testing" - - "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" ) func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) { const address = "9.9.9.9:53" - txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address) + txp := NewDNSOverTCP(new(net.Dialer).DialContext, address) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18)) if err == nil { t.Fatal("expected an error here") @@ -24,8 +22,8 @@ func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) { func TestDNSOverTCPTransportDialFailure(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") - fakedialer := resolver.FakeDialer{Err: mocked} - txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address) + fakedialer := FakeDialer{Err: mocked} + txp := NewDNSOverTCP(fakedialer.DialContext, address) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -38,10 +36,10 @@ func TestDNSOverTCPTransportDialFailure(t *testing.T) { func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") - fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{ + fakedialer := FakeDialer{Conn: &FakeConn{ SetDeadlineError: mocked, }} - txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address) + txp := NewDNSOverTCP(fakedialer.DialContext, address) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -54,10 +52,10 @@ func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) { func TestDNSOverTCPTransportWriteFailure(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") - fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{ + fakedialer := FakeDialer{Conn: &FakeConn{ WriteError: mocked, }} - txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address) + txp := NewDNSOverTCP(fakedialer.DialContext, address) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -70,10 +68,10 @@ func TestDNSOverTCPTransportWriteFailure(t *testing.T) { func TestDNSOverTCPTransportReadFailure(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") - fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{ + fakedialer := FakeDialer{Conn: &FakeConn{ ReadError: mocked, }} - txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address) + txp := NewDNSOverTCP(fakedialer.DialContext, address) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -86,11 +84,11 @@ func TestDNSOverTCPTransportReadFailure(t *testing.T) { func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") - fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{ + fakedialer := FakeDialer{Conn: &FakeConn{ ReadError: mocked, ReadData: []byte{byte(0), byte(2)}, }} - txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address) + txp := NewDNSOverTCP(fakedialer.DialContext, address) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -103,11 +101,11 @@ func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) { func TestDNSOverTCPTransportAllGood(t *testing.T) { const address = "9.9.9.9:53" mocked := errors.New("mocked error") - fakedialer := resolver.FakeDialer{Conn: &resolver.FakeConn{ + fakedialer := FakeDialer{Conn: &FakeConn{ ReadError: mocked, ReadData: []byte{byte(0), byte(1), byte(1)}, }} - txp := resolver.NewDNSOverTCP(fakedialer.DialContext, address) + txp := NewDNSOverTCP(fakedialer.DialContext, address) reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) if err != nil { t.Fatal(err) @@ -119,7 +117,7 @@ func TestDNSOverTCPTransportAllGood(t *testing.T) { func TestDNSOverTCPTransportOK(t *testing.T) { const address = "9.9.9.9:53" - txp := resolver.NewDNSOverTCP(new(net.Dialer).DialContext, address) + txp := NewDNSOverTCP(new(net.Dialer).DialContext, address) if txp.RequiresPadding() != false { t.Fatal("invalid RequiresPadding") } @@ -133,7 +131,7 @@ func TestDNSOverTCPTransportOK(t *testing.T) { func TestDNSOverTLSTransportOK(t *testing.T) { const address = "9.9.9.9:853" - txp := resolver.NewDNSOverTLS(resolver.DialTLSContext, address) + txp := NewDNSOverTLS(DialTLSContext, address) if txp.RequiresPadding() != true { t.Fatal("invalid RequiresPadding") } diff --git a/internal/engine/netx/resolver/dnsoverudp.go b/internal/engine/netx/resolver/dnsoverudp.go index bb4fb3d..bbd60a5 100644 --- a/internal/engine/netx/resolver/dnsoverudp.go +++ b/internal/engine/netx/resolver/dnsoverudp.go @@ -18,12 +18,12 @@ type DNSOverUDP struct { } // NewDNSOverUDP creates a DNSOverUDP instance. -func NewDNSOverUDP(dialer Dialer, address string) DNSOverUDP { - return DNSOverUDP{dialer: dialer, address: address} +func NewDNSOverUDP(dialer Dialer, address string) *DNSOverUDP { + return &DNSOverUDP{dialer: dialer, address: address} } // RoundTrip implements RoundTripper.RoundTrip. -func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { +func (t *DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { conn, err := t.dialer.DialContext(ctx, "udp", t.address) if err != nil { return nil, err @@ -47,18 +47,23 @@ func (t DNSOverUDP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) } // RequiresPadding returns false for UDP according to RFC8467 -func (t DNSOverUDP) RequiresPadding() bool { +func (t *DNSOverUDP) RequiresPadding() bool { return false } // Network returns the transport network (e.g., doh, dot) -func (t DNSOverUDP) Network() string { +func (t *DNSOverUDP) Network() string { return "udp" } // Address returns the upstream server address. -func (t DNSOverUDP) Address() string { +func (t *DNSOverUDP) Address() string { return t.address } -var _ RoundTripper = DNSOverUDP{} +// CloseIdleConnections closes idle connections. +func (t *DNSOverUDP) CloseIdleConnections() { + // nothing to do +} + +var _ RoundTripper = &DNSOverUDP{} diff --git a/internal/engine/netx/resolver/dnsoverudp_test.go b/internal/engine/netx/resolver/dnsoverudp_test.go index 9c9a4ce..04906dd 100644 --- a/internal/engine/netx/resolver/dnsoverudp_test.go +++ b/internal/engine/netx/resolver/dnsoverudp_test.go @@ -1,18 +1,16 @@ -package resolver_test +package resolver import ( "context" "errors" "net" "testing" - - "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" ) func TestDNSOverUDPDialFailure(t *testing.T) { mocked := errors.New("mocked error") const address = "9.9.9.9:53" - txp := resolver.NewDNSOverUDP(resolver.FakeDialer{Err: mocked}, address) + txp := NewDNSOverUDP(FakeDialer{Err: mocked}, address) data, err := txp.RoundTrip(context.Background(), nil) if !errors.Is(err, mocked) { t.Fatal("not the error we expected") @@ -24,9 +22,9 @@ func TestDNSOverUDPDialFailure(t *testing.T) { func TestDNSOverUDPSetDeadlineError(t *testing.T) { mocked := errors.New("mocked error") - txp := resolver.NewDNSOverUDP( - resolver.FakeDialer{ - Conn: &resolver.FakeConn{ + txp := NewDNSOverUDP( + FakeDialer{ + Conn: &FakeConn{ SetDeadlineError: mocked, }, }, "9.9.9.9:53", @@ -42,9 +40,9 @@ func TestDNSOverUDPSetDeadlineError(t *testing.T) { func TestDNSOverUDPWriteFailure(t *testing.T) { mocked := errors.New("mocked error") - txp := resolver.NewDNSOverUDP( - resolver.FakeDialer{ - Conn: &resolver.FakeConn{ + txp := NewDNSOverUDP( + FakeDialer{ + Conn: &FakeConn{ WriteError: mocked, }, }, "9.9.9.9:53", @@ -60,9 +58,9 @@ func TestDNSOverUDPWriteFailure(t *testing.T) { func TestDNSOverUDPReadFailure(t *testing.T) { mocked := errors.New("mocked error") - txp := resolver.NewDNSOverUDP( - resolver.FakeDialer{ - Conn: &resolver.FakeConn{ + txp := NewDNSOverUDP( + FakeDialer{ + Conn: &FakeConn{ ReadError: mocked, }, }, "9.9.9.9:53", @@ -78,9 +76,9 @@ func TestDNSOverUDPReadFailure(t *testing.T) { func TestDNSOverUDPReadSuccess(t *testing.T) { const expected = 17 - txp := resolver.NewDNSOverUDP( - resolver.FakeDialer{ - Conn: &resolver.FakeConn{ReadData: make([]byte, 17)}, + txp := NewDNSOverUDP( + FakeDialer{ + Conn: &FakeConn{ReadData: make([]byte, 17)}, }, "9.9.9.9:53", ) data, err := txp.RoundTrip(context.Background(), nil) @@ -94,7 +92,7 @@ func TestDNSOverUDPReadSuccess(t *testing.T) { func TestDNSOverUDPTransportOK(t *testing.T) { const address = "9.9.9.9:53" - txp := resolver.NewDNSOverUDP(&net.Dialer{}, address) + txp := NewDNSOverUDP(&net.Dialer{}, address) if txp.RequiresPadding() != false { t.Fatal("invalid RequiresPadding") } diff --git a/internal/engine/netx/resolver/emitter_test.go b/internal/engine/netx/resolver/emitter_test.go index d275d3d..a6bdd4c 100644 --- a/internal/engine/netx/resolver/emitter_test.go +++ b/internal/engine/netx/resolver/emitter_test.go @@ -13,6 +13,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/handlers" "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx" "github.com/ooni/probe-cli/v3/internal/engine/netx/resolver" + "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) func TestEmitterTransportSuccess(t *testing.T) { @@ -107,9 +108,11 @@ func TestEmitterResolverFailure(t *testing.T) { } ctx = modelx.WithMeasurementRoot(ctx, root) r := resolver.EmitterResolver{Resolver: resolver.NewSerialResolver( - resolver.DNSOverHTTPS{ - Do: func(req *http.Request) (*http.Response, error) { - return nil, io.EOF + &resolver.DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, }, URL: "https://dns.google.com/", }, diff --git a/internal/engine/netx/resolver/fake_test.go b/internal/engine/netx/resolver/fake_test.go index 576e738..b74bbec 100644 --- a/internal/engine/netx/resolver/fake_test.go +++ b/internal/engine/netx/resolver/fake_test.go @@ -93,6 +93,10 @@ func (ft FakeTransport) Network() string { return "fake" } +func (fk FakeTransport) CloseIdleConnections() { + // nothing to do +} + type FakeEncoder struct { Data []byte Err error diff --git a/internal/engine/netx/resolver/serial.go b/internal/engine/netx/resolver/serial.go index 6844df7..377e3bf 100644 --- a/internal/engine/netx/resolver/serial.go +++ b/internal/engine/netx/resolver/serial.go @@ -22,6 +22,9 @@ type RoundTripper interface { // Address is the address of the round tripper (e.g. "1.1.1.1:853") Address() string + + // CloseIdleConnections closes idle connections. + CloseIdleConnections() } // SerialResolver is a resolver that first issues an A query and then @@ -81,7 +84,7 @@ func (r SerialResolver) roundTripWithRetry( } errorslist = append(errorslist, err) var operr *net.OpError - if errors.As(err, &operr) == false || operr.Timeout() == false { + if !errors.As(err, &operr) || !operr.Timeout() { // The first error is the one that is most likely to be caused // by the network. Subsequent errors are more likely to be caused // by context deadlines. So, the first error is attached to an diff --git a/internal/netxlite/mocks/http.go b/internal/netxlite/mocks/http.go index 0d658e3..4f4cf3e 100644 --- a/internal/netxlite/mocks/http.go +++ b/internal/netxlite/mocks/http.go @@ -17,3 +17,20 @@ func (txp *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (txp *HTTPTransport) CloseIdleConnections() { txp.MockCloseIdleConnections() } + +// HTTPClient allows mocking an http.Client. +type HTTPClient struct { + MockDo func(req *http.Request) (*http.Response, error) + + MockCloseIdleConnections func() +} + +// Do calls MockDo. +func (txp *HTTPClient) Do(req *http.Request) (*http.Response, error) { + return txp.MockDo(req) +} + +// CloseIdleConnections calls MockCloseIdleConnections. +func (txp *HTTPClient) CloseIdleConnections() { + txp.MockCloseIdleConnections() +} diff --git a/internal/netxlite/mocks/http_test.go b/internal/netxlite/mocks/http_test.go index 94107b0..7351207 100644 --- a/internal/netxlite/mocks/http_test.go +++ b/internal/netxlite/mocks/http_test.go @@ -38,3 +38,34 @@ func TestHTTPTransport(t *testing.T) { } }) } + +func TestHTTPClient(t *testing.T) { + t.Run("Do", func(t *testing.T) { + expected := errors.New("mocked error") + clnt := &HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, expected + }, + } + resp, err := clnt.Do(&http.Request{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } + if resp != nil { + t.Fatal("expected nil response here") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + called := &atomicx.Int64{} + clnt := &HTTPClient{ + MockCloseIdleConnections: func() { + called.Add(1) + }, + } + clnt.CloseIdleConnections() + if called.Load() != 1 { + t.Fatal("not called") + } + }) +}