diff --git a/internal/cmd/oohelperd/internal/webconnectivity/dns.go b/internal/cmd/oohelperd/internal/webconnectivity/dns.go index 8faaf60..c759fda 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/dns.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/dns.go @@ -19,16 +19,18 @@ type CtrlDNSResult = webconnectivity.ControlDNSResult // DNSConfig configures the DNS check. type DNSConfig struct { - Domain string - Out chan CtrlDNSResult - Resolver model.Resolver - Wg *sync.WaitGroup + Domain string + NewResolver func() model.Resolver + Out chan CtrlDNSResult + Wg *sync.WaitGroup } // DNSDo performs the DNS check. func DNSDo(ctx context.Context, config *DNSConfig) { defer config.Wg.Done() - addrs, err := config.Resolver.LookupHost(ctx, config.Domain) + reso := config.NewResolver() + defer reso.CloseIdleConnections() + addrs, err := reso.LookupHost(ctx, config.Domain) if addrs == nil { addrs = []string{} // fix: the old test helper did that } diff --git a/internal/cmd/oohelperd/internal/webconnectivity/dns_test.go b/internal/cmd/oohelperd/internal/webconnectivity/dns_test.go index cc9b805..c2eb1eb 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/dns_test.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/dns_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/engine/experiment/webconnectivity" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -68,13 +69,18 @@ func TestDNSDo(t *testing.T) { ctx := context.Background() config := &DNSConfig{ Domain: "antani.ooni.org", - Out: make(chan webconnectivity.ControlDNSResult, 1), - Resolver: &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, netxlite.ErrOODNSNoSuchHost - }, + NewResolver: func() model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, netxlite.ErrOODNSNoSuchHost + }, + MockCloseIdleConnections: func() { + // nothing + }, + } }, - Wg: &sync.WaitGroup{}, + Out: make(chan webconnectivity.ControlDNSResult, 1), + Wg: &sync.WaitGroup{}, } config.Wg.Add(1) DNSDo(ctx, config) diff --git a/internal/cmd/oohelperd/internal/webconnectivity/http.go b/internal/cmd/oohelperd/internal/webconnectivity/http.go index 8400258..ed39f2c 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/http.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/http.go @@ -19,9 +19,9 @@ type CtrlHTTPResponse = webconnectivity.ControlHTTPRequestResult // HTTPConfig configures the HTTP check. type HTTPConfig struct { - Client model.HTTPClient Headers map[string][]string MaxAcceptableBody int64 + NewClient func() model.HTTPClient Out chan CtrlHTTPResponse URL string Wg *sync.WaitGroup @@ -50,7 +50,9 @@ func HTTPDo(ctx context.Context, config *HTTPConfig) { } } } - resp, err := config.Client.Do(req) + clnt := config.NewClient() + defer clnt.CloseIdleConnections() + resp, err := clnt.Do(req) if err != nil { config.Out <- CtrlHTTPResponse{ // fix: emit -1 like old test helper does BodyLength: -1, diff --git a/internal/cmd/oohelperd/internal/webconnectivity/http_test.go b/internal/cmd/oohelperd/internal/webconnectivity/http_test.go index 0347da5..0c97a14 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/http_test.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/http_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -17,12 +18,14 @@ func TestHTTPDoWithInvalidURL(t *testing.T) { httpch := make(chan CtrlHTTPResponse, 1) wg.Add(1) go HTTPDo(ctx, &HTTPConfig{ - Client: http.DefaultClient, Headers: nil, MaxAcceptableBody: 1 << 24, - Out: httpch, - URL: "http://[::1]aaaa", - Wg: wg, + NewClient: func() model.HTTPClient { + return http.DefaultClient + }, + Out: httpch, + URL: "http://[::1]aaaa", + Wg: wg, }) // wait for measurement steps to complete wg.Wait() @@ -39,16 +42,18 @@ func TestHTTPDoWithHTTPTransportFailure(t *testing.T) { httpch := make(chan CtrlHTTPResponse, 1) wg.Add(1) go HTTPDo(ctx, &HTTPConfig{ - Client: &http.Client{ - Transport: FakeTransport{ - Err: expected, - }, - }, Headers: nil, MaxAcceptableBody: 1 << 24, - Out: httpch, - URL: "http://www.x.org", - Wg: wg, + NewClient: func() model.HTTPClient { + return &http.Client{ + Transport: FakeTransport{ + Err: expected, + }, + } + }, + Out: httpch, + URL: "http://www.x.org", + Wg: wg, }) // wait for measurement steps to complete wg.Wait() diff --git a/internal/cmd/oohelperd/internal/webconnectivity/measure.go b/internal/cmd/oohelperd/internal/webconnectivity/measure.go index 971df6c..d7f3158 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/measure.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/measure.go @@ -20,10 +20,10 @@ type ( // MeasureConfig contains configuration for Measure. type MeasureConfig struct { - Client model.HTTPClient - Dialer model.Dialer MaxAcceptableBody int64 - Resolver model.Resolver + NewClient func() model.HTTPClient + NewDialer func() model.Dialer + NewResolver func() model.Resolver } // Measure performs the measurement described by the request and @@ -40,10 +40,10 @@ func Measure(ctx context.Context, config MeasureConfig, creq *CtrlRequest) (*Ctr if net.ParseIP(URL.Hostname()) == nil { wg.Add(1) go DNSDo(ctx, &DNSConfig{ - Domain: URL.Hostname(), - Out: dnsch, - Resolver: config.Resolver, - Wg: wg, + Domain: URL.Hostname(), + NewResolver: config.NewResolver, + Out: dnsch, + Wg: wg, }) } // tcpconnect: start @@ -51,19 +51,19 @@ func Measure(ctx context.Context, config MeasureConfig, creq *CtrlRequest) (*Ctr for _, endpoint := range creq.TCPConnect { wg.Add(1) go TCPDo(ctx, &TCPConfig{ - Dialer: config.Dialer, - Endpoint: endpoint, - Out: tcpconnch, - Wg: wg, + Endpoint: endpoint, + NewDialer: config.NewDialer, + Out: tcpconnch, + Wg: wg, }) } // http: start httpch := make(chan CtrlHTTPResponse, 1) wg.Add(1) go HTTPDo(ctx, &HTTPConfig{ - Client: config.Client, Headers: creq.HTTPRequestHeaders, MaxAcceptableBody: config.MaxAcceptableBody, + NewClient: config.NewClient, Out: httpch, URL: creq.HTTPRequest, Wg: wg, diff --git a/internal/cmd/oohelperd/internal/webconnectivity/tcpconnect.go b/internal/cmd/oohelperd/internal/webconnectivity/tcpconnect.go index a741a97..21afa26 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/tcpconnect.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/tcpconnect.go @@ -20,16 +20,18 @@ type TCPResultPair struct { // TCPConfig configures the TCP connect check. type TCPConfig struct { - Dialer model.Dialer - Endpoint string - Out chan TCPResultPair - Wg *sync.WaitGroup + Endpoint string + NewDialer func() model.Dialer + Out chan TCPResultPair + Wg *sync.WaitGroup } // TCPDo performs the TCP check. func TCPDo(ctx context.Context, config *TCPConfig) { defer config.Wg.Done() - conn, err := config.Dialer.DialContext(ctx, "tcp", config.Endpoint) + dialer := config.NewDialer() + defer dialer.CloseIdleConnections() + conn, err := dialer.DialContext(ctx, "tcp", config.Endpoint) if conn != nil { conn.Close() } diff --git a/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity.go b/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity.go index 01a4480..1bd9ab1 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity.go @@ -14,10 +14,10 @@ import ( // Handler implements the Web Connectivity test helper HTTP API. type Handler struct { - Client model.HTTPClient - Dialer model.Dialer MaxAcceptableBody int64 - Resolver model.Resolver + NewClient func() model.HTTPClient + NewDialer func() model.Dialer + NewResolver func() model.Resolver } func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { diff --git a/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity_test.go b/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity_test.go index b70d10f..d7f88f0 100644 --- a/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity_test.go +++ b/internal/cmd/oohelperd/internal/webconnectivity/webconnectivity_test.go @@ -51,10 +51,16 @@ const requestWithoutDomainName = `{ func TestWorkingAsIntended(t *testing.T) { handler := Handler{ - Client: http.DefaultClient, - Dialer: netxlite.NewDialerWithStdlibResolver(model.DiscardLogger), MaxAcceptableBody: 1 << 24, - Resolver: netxlite.NewUnwrappedStdlibResolver(), + NewClient: func() model.HTTPClient { + return http.DefaultClient + }, + NewDialer: func() model.Dialer { + return netxlite.NewDialerWithStdlibResolver(model.DiscardLogger) + }, + NewResolver: func() model.Resolver { + return netxlite.NewUnwrappedStdlibResolver() + }, } srv := httptest.NewServer(handler) defer srv.Close() diff --git a/internal/cmd/oohelperd/oohelperd.go b/internal/cmd/oohelperd/oohelperd.go index 92db20f..71c9f89 100644 --- a/internal/cmd/oohelperd/oohelperd.go +++ b/internal/cmd/oohelperd/oohelperd.go @@ -17,22 +17,22 @@ import ( const maxAcceptableBody = 1 << 24 var ( - dialer model.Dialer - endpoint = flag.String("endpoint", ":8080", "Endpoint where to listen") - httpClient model.HTTPClient - resolver model.Resolver - srvcancel context.CancelFunc - srvctx context.Context - srvwg = new(sync.WaitGroup) + endpoint = flag.String("endpoint", ":8080", "Endpoint where to listen") + srvcancel context.CancelFunc + srvctx context.Context + srvwg = new(sync.WaitGroup) ) func init() { srvctx, srvcancel = context.WithCancel(context.Background()) +} + +func newresolver() model.Resolver { // Implementation note: pin to a specific resolver so we don't depend upon the // default resolver configured by the box. Also, use an encrypted transport thus // we're less vulnerable to any policy implemented by the box's provider. - resolver = netxlite.NewParallelDNSOverHTTPSResolver(log.Log, "https://8.8.8.8/dns-query") - httpClient = netxlite.NewHTTPClientWithResolver(log.Log, resolver) + resolver := netxlite.NewParallelDNSOverHTTPSResolver(log.Log, "https://8.8.8.8/dns-query") + return resolver } func shutdown(srv *http.Server) { @@ -55,10 +55,14 @@ func main() { func testableMain() { mux := http.NewServeMux() mux.Handle("/", webconnectivity.Handler{ - Client: httpClient, - Dialer: dialer, MaxAcceptableBody: maxAcceptableBody, - Resolver: resolver, + NewClient: func() model.HTTPClient { + return netxlite.NewHTTPClientWithResolver(log.Log, newresolver()) + }, + NewDialer: func() model.Dialer { + return netxlite.NewDialerWithResolver(log.Log, newresolver()) + }, + NewResolver: newresolver, }) srv := &http.Server{Addr: *endpoint, Handler: mux} srvwg.Add(1)