diff --git a/internal/cmd/oohelper/oohelper.go b/internal/cmd/oohelper/oohelper.go index da9f85e..ec18108 100644 --- a/internal/cmd/oohelper/oohelper.go +++ b/internal/cmd/oohelper/oohelper.go @@ -29,8 +29,7 @@ func init() { // puzzling https://github.com/ooni/probe/issues/1409 issue. const resolverURL = "https://8.8.8.8/dns-query" resolver = netxlite.NewParallelDNSOverHTTPSResolver(log.Log, resolverURL) - txp := netxlite.NewHTTPTransportWithResolver(log.Log, resolver) - httpClient = netxlite.NewHTTPClient(txp) + httpClient = netxlite.NewHTTPClientWithResolver(log.Log, resolver) } func main() { diff --git a/internal/cmd/oohelperd/oohelperd.go b/internal/cmd/oohelperd/oohelperd.go index 007cdf5..92db20f 100644 --- a/internal/cmd/oohelperd/oohelperd.go +++ b/internal/cmd/oohelperd/oohelperd.go @@ -32,8 +32,7 @@ func init() { // default resolver configured by the box. Also, use an encrypted transport thus // we're less vulnerable to any policy implemented by the box's provider. resolver = netxlite.NewParallelDNSOverHTTPSResolver(log.Log, "https://8.8.8.8/dns-query") - txp := netxlite.NewHTTPTransportWithResolver(log.Log, resolver) - httpClient = netxlite.NewHTTPClient(txp) + httpClient = netxlite.NewHTTPClientWithResolver(log.Log, resolver) } func shutdown(srv *http.Server) { diff --git a/internal/engine/session.go b/internal/engine/session.go index bcc0f7e..dd62d42 100644 --- a/internal/engine/session.go +++ b/internal/engine/session.go @@ -197,11 +197,9 @@ func NewSession(ctx context.Context, config SessionConfig) (*Session, error) { Logger: sess.logger, ProxyURL: proxyURL, } - dialer := netxlite.NewDialerWithResolver(sess.logger, sess.resolver) - dialer = netxlite.MaybeWrapWithProxyDialer(dialer, proxyURL) - handshaker := netxlite.NewTLSHandshakerStdlib(sess.logger) - tlsDialer := netxlite.NewTLSDialer(dialer, handshaker) - txp := netxlite.NewHTTPTransport(sess.logger, dialer, tlsDialer) + txp := netxlite.NewHTTPTransportWithLoggerResolverAndOptionalProxyURL( + sess.logger, sess.resolver, sess.proxyURL, + ) txp = bytecounter.WrapHTTPTransport(txp, sess.byteCounter) sess.httpDefaultTransport = txp return sess, nil diff --git a/internal/netxlite/dnsoverhttps.go b/internal/netxlite/dnsoverhttps.go index 3a7cec9..aeb56f0 100644 --- a/internal/netxlite/dnsoverhttps.go +++ b/internal/netxlite/dnsoverhttps.go @@ -43,6 +43,18 @@ func NewUnwrappedDNSOverHTTPSTransport(client model.HTTPClient, URL string) *DNS return NewUnwrappedDNSOverHTTPSTransportWithHostOverride(client, URL, "") } +// NewDNSOverHTTPSTransport is like NewUnwrappedDNSOverHTTPSTransport but +// returns an already wrapped DNSTransport. +func NewDNSOverHTTPSTransport(client model.HTTPClient, URL string) model.DNSTransport { + return WrapDNSTransport(NewUnwrappedDNSOverHTTPSTransport(client, URL)) +} + +// NewDNSOverHTTPSTransportWithHTTPTransport is like NewDNSOverHTTPSTransport +// but takes in input an HTTPTransport rather than an HTTPClient. +func NewDNSOverHTTPSTransportWithHTTPTransport(txp model.HTTPTransport, URL string) model.DNSTransport { + return WrapDNSTransport(NewUnwrappedDNSOverHTTPSTransport(NewHTTPClient(txp), URL)) +} + // NewUnwrappedDNSOverHTTPSTransportWithHostOverride creates a new DNSOverHTTPSTransport // with the given Host header override. This instance has not been wrapped yet. func NewUnwrappedDNSOverHTTPSTransportWithHostOverride( diff --git a/internal/netxlite/dnsoverhttps_test.go b/internal/netxlite/dnsoverhttps_test.go index c6bbf23..cc062e6 100644 --- a/internal/netxlite/dnsoverhttps_test.go +++ b/internal/netxlite/dnsoverhttps_test.go @@ -13,6 +13,36 @@ import ( "github.com/ooni/probe-cli/v3/internal/model/mocks" ) +func TestNewDNSOverHTTPSTransport(t *testing.T) { + const URL = "https://1.1.1.1/dns-query" + clnt := NewHTTPClientStdlib(model.DiscardLogger) + txp := NewDNSOverHTTPSTransport(clnt, URL) + ew := txp.(*dnsTransportErrWrapper) + https := ew.DNSTransport.(*DNSOverHTTPSTransport) + if https.Client != clnt { + t.Fatal("invalid client") + } + if https.URL != URL { + t.Fatal("invalid URL") + } +} + +func TestNewDNSOverHTTPSTransportWithHTTPTransport(t *testing.T) { + const URL = "https://1.1.1.1/dns-query" + httpTxp := NewHTTPTransportStdlib(model.DiscardLogger) + txp := NewDNSOverHTTPSTransportWithHTTPTransport(httpTxp, URL) + ew := txp.(*dnsTransportErrWrapper) + https := ew.DNSTransport.(*DNSOverHTTPSTransport) + ewClient := https.Client.(*httpClientErrWrapper) + clnt := ewClient.HTTPClient.(*http.Client) + if clnt.Transport != httpTxp { + t.Fatal("invalid transport") + } + if https.URL != URL { + t.Fatal("invalid URL") + } +} + func TestDNSOverHTTPSTransport(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { t.Run("query serialization failure", func(t *testing.T) { diff --git a/internal/netxlite/http.go b/internal/netxlite/http.go index 41969f7..3eaefb3 100644 --- a/internal/netxlite/http.go +++ b/internal/netxlite/http.go @@ -9,6 +9,7 @@ import ( "errors" "net" "net/http" + "net/url" "time" oohttp "github.com/ooni/oohttp" @@ -105,6 +106,25 @@ func (txp *httpTransportConnectionsCloser) CloseIdleConnections() { txp.TLSDialer.CloseIdleConnections() } +// NewHTTPTransportWithLoggerResolverAndOptionalProxyURL creates an HTTPTransport using +// the given logger and resolver and an optional proxy URL. +// +// Arguments: +// +// - logger is the MANDATORY logger; +// +// - resolver is the MANDATORY resolver; +// +// - purl is the OPTIONAL proxy URL. +func NewHTTPTransportWithLoggerResolverAndOptionalProxyURL( + logger model.DebugLogger, resolver model.Resolver, purl *url.URL) model.HTTPTransport { + dialer := NewDialerWithResolver(logger, resolver) + dialer = MaybeWrapWithProxyDialer(dialer, purl) + handshaker := NewTLSHandshakerStdlib(logger) + tlsDialer := NewTLSDialer(dialer, handshaker) + return NewHTTPTransport(logger, dialer, tlsDialer) +} + // NewHTTPTransportWithResolver creates a new HTTP transport using // the stdlib for everything but the given resolver. func NewHTTPTransportWithResolver(logger model.DebugLogger, reso model.Resolver) model.HTTPTransport { @@ -335,6 +355,12 @@ func NewHTTPClientStdlib(logger model.DebugLogger) model.HTTPClient { return NewHTTPClient(txp) } +// NewHTTPClientWithResolver creates a new HTTPTransport using the +// given resolver and then from that builds an HTTPClient. +func NewHTTPClientWithResolver(logger model.Logger, reso model.Resolver) model.HTTPClient { + return NewHTTPClient(NewHTTPTransportWithResolver(logger, reso)) +} + // NewHTTPClient creates a new, wrapped HTTPClient using the given transport. func NewHTTPClient(txp model.HTTPTransport) model.HTTPClient { return WrapHTTPClient(&http.Client{Transport: txp}) diff --git a/internal/netxlite/http3.go b/internal/netxlite/http3.go index b5ab546..6a9c128 100644 --- a/internal/netxlite/http3.go +++ b/internal/netxlite/http3.go @@ -61,3 +61,19 @@ func NewHTTP3Transport( dialer: dialer, }) } + +// NewHTTP3TransportStdlib creates a new HTTPTransport using http3 that +// uses standard functionality for everything but the logger. +func NewHTTP3TransportStdlib(logger model.DebugLogger) model.HTTPTransport { + ql := NewQUICListener() + reso := NewStdlibResolver(logger) + qd := NewQUICDialerWithResolver(ql, logger, reso) + return NewHTTP3Transport(logger, qd, nil) +} + +// NewHTTPTransportWithResolver creates a new HTTPTransport using http3 +// that uses the given logger and the given resolver. +func NewHTTP3TransportWithResolver(logger model.Logger, reso model.Resolver) model.HTTPTransport { + qd := NewQUICDialerWithResolver(NewQUICListener(), logger, reso) + return NewHTTP3Transport(logger, qd, nil) +} diff --git a/internal/netxlite/http3_test.go b/internal/netxlite/http3_test.go index 6f21295..07249fe 100644 --- a/internal/netxlite/http3_test.go +++ b/internal/netxlite/http3_test.go @@ -6,8 +6,8 @@ import ( "net/http" "testing" - "github.com/apex/log" "github.com/lucas-clemente/quic-go/http3" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/model/mocks" nlmocks "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) @@ -63,29 +63,78 @@ func TestHTTP3Transport(t *testing.T) { }) } +// verifyTypeChainForHTTP3 helps to verify type chains for HTTP3. +// +// Arguments: +// +// - t is the MANDATORY testing ref; +// +// - txp is the MANDATORY HTTP transport to verify; +// +// - underlyingLogger is the MANDATORY logger we expect to find; +// +// - qd is the OPTIONAL QUIC dialer: if not nil, we expect to +// see this value as the QUIC dialer, otherwise we will check the +// type chain of the real dialer; +// +// - config is the MANDATORY TLS config: we'll always check +// whether the TLSClientConfig is equal to this value: passing +// nil here means we expect to see nil in the object; +// +// - reso is the OPTIONAL resolver: if present and the qd is +// nil, we'll unwrap the QUIC dialer and check whether we have +// this resolver as the underlying resolver. +func verifyTypeChainForHTTP3(t *testing.T, txp model.HTTPTransport, + underlyingLogger model.DebugLogger, qd model.QUICDialer, + config *tls.Config, reso model.Resolver) { + logger := txp.(*httpTransportLogger) + if logger.Logger != underlyingLogger { + t.Fatal("invalid logger") + } + ew := logger.HTTPTransport.(*httpTransportErrWrapper) + h3txp := ew.HTTPTransport.(*http3Transport) + if qd != nil && h3txp.dialer != qd { + t.Fatal("invalid dialer") + } + if qd == nil { + qdlog := h3txp.dialer.(*quicDialerLogger) + qdr := qdlog.Dialer.(*quicDialerResolver) + if reso != nil && qdr.Resolver != reso { + t.Fatal("invalid resolver") + } + } + h3 := h3txp.child.(*http3.RoundTripper) + if h3.Dial == nil { + t.Fatal("invalid Dial") + } + if !h3.DisableCompression { + t.Fatal("invalid DisableCompression") + } + if h3.TLSClientConfig != config { + t.Fatal("invalid TLSClientConfig") + } +} + func TestNewHTTP3Transport(t *testing.T) { t.Run("creates the correct type chain", func(t *testing.T) { qd := &mocks.QUICDialer{} config := &tls.Config{} - txp := NewHTTP3Transport(log.Log, qd, config) - logger := txp.(*httpTransportLogger) - if logger.Logger != log.Log { - t.Fatal("invalid logger") - } - ew := logger.HTTPTransport.(*httpTransportErrWrapper) - h3txp := ew.HTTPTransport.(*http3Transport) - if h3txp.dialer != qd { - t.Fatal("invalid dialer") - } - h3 := h3txp.child.(*http3.RoundTripper) - if h3.Dial == nil { - t.Fatal("invalid Dial") - } - if !h3.DisableCompression { - t.Fatal("invalid DisableCompression") - } - if h3.TLSClientConfig != config { - t.Fatal("invalid TLSClientConfig") - } + txp := NewHTTP3Transport(model.DiscardLogger, qd, config) + verifyTypeChainForHTTP3(t, txp, model.DiscardLogger, qd, config, nil) + }) +} + +func TestNewHTTP3TransportStdlib(t *testing.T) { + t.Run("creates the correct type chain", func(t *testing.T) { + txp := NewHTTP3TransportStdlib(model.DiscardLogger) + verifyTypeChainForHTTP3(t, txp, model.DiscardLogger, nil, nil, nil) + }) +} + +func TestNewHTTP3TransportWithResolver(t *testing.T) { + t.Run("creates the correct type chain", func(t *testing.T) { + reso := &mocks.Resolver{} + txp := NewHTTP3TransportWithResolver(model.DiscardLogger, reso) + verifyTypeChainForHTTP3(t, txp, model.DiscardLogger, nil, nil, reso) }) } diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index 6c19f53..e50ea8c 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/http" + "net/url" "strings" "testing" "time" @@ -16,6 +17,48 @@ import ( "github.com/ooni/probe-cli/v3/internal/model/mocks" ) +func TestNewHTTPTransportWithLoggerResolverAndOptionalProxyURL(t *testing.T) { + t.Run("without proxy URL", func(t *testing.T) { + logger := &mocks.Logger{} + resolver := &mocks.Resolver{} + txp := NewHTTPTransportWithLoggerResolverAndOptionalProxyURL(logger, resolver, nil) + txpLogger := txp.(*httpTransportLogger) + if txpLogger.Logger != logger { + t.Fatal("unexpected logger") + } + txpErrWrapper := txpLogger.HTTPTransport.(*httpTransportErrWrapper) + txpCc := txpErrWrapper.HTTPTransport.(*httpTransportConnectionsCloser) + dialer := txpCc.Dialer + dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout) + dialerLog := dialerWithReadTimeout.Dialer.(*dialerLogger) + dialerReso := dialerLog.Dialer.(*dialerResolver) + if dialerReso.Resolver != resolver { + t.Fatal("invalid resolver") + } + }) + + t.Run("with proxy URL", func(t *testing.T) { + URL := &url.URL{} + logger := &mocks.Logger{} + resolver := &mocks.Resolver{} + txp := NewHTTPTransportWithLoggerResolverAndOptionalProxyURL(logger, resolver, URL) + txpLogger := txp.(*httpTransportLogger) + if txpLogger.Logger != logger { + t.Fatal("unexpected logger") + } + txpErrWrapper := txpLogger.HTTPTransport.(*httpTransportErrWrapper) + txpCc := txpErrWrapper.HTTPTransport.(*httpTransportConnectionsCloser) + dialer := txpCc.Dialer + dialerWithReadTimeout := dialer.(*httpDialerWithReadTimeout) + dialerProxy := dialerWithReadTimeout.Dialer.(*proxyDialer) + dialerLog := dialerProxy.Dialer.(*dialerLogger) + dialerReso := dialerLog.Dialer.(*dialerResolver) + if dialerReso.Resolver != resolver { + t.Fatal("invalid resolver") + } + }) +} + func TestNewHTTPTransportWithResolver(t *testing.T) { expected := errors.New("mocked error") reso := &mocks.Resolver{ @@ -553,6 +596,28 @@ func TestNewHTTPClientStdlib(t *testing.T) { } } +func TestNewHTTPClientWithResolver(t *testing.T) { + reso := &mocks.Resolver{} + clnt := NewHTTPClientWithResolver(model.DiscardLogger, reso) + ewc, ok := clnt.(*httpClientErrWrapper) + if !ok { + t.Fatal("expected *httpClientErrWrapper") + } + httpClnt, ok := ewc.HTTPClient.(*http.Client) + if !ok { + t.Fatal("expected *http.Client") + } + txp := httpClnt.Transport.(*httpTransportLogger) + txpEwrap := txp.HTTPTransport.(*httpTransportErrWrapper) + txpCc := txpEwrap.HTTPTransport.(*httpTransportConnectionsCloser) + dialer := txpCc.Dialer.(*httpDialerWithReadTimeout) + dialerLogger := dialer.Dialer.(*dialerLogger) + dialerReso := dialerLogger.Dialer.(*dialerResolver) + if dialerReso.Resolver != reso { + t.Fatal("invalid resolver") + } +} + func TestWrapHTTPClient(t *testing.T) { origClient := &http.Client{} wrapped := WrapHTTPClient(origClient)