diff --git a/internal/engine/legacy/netx/emitterdialer_test.go b/internal/engine/legacy/netx/emitterdialer_test.go index 21b31fe..485b935 100644 --- a/internal/engine/legacy/netx/emitterdialer_test.go +++ b/internal/engine/legacy/netx/emitterdialer_test.go @@ -20,7 +20,7 @@ func TestEmitterFailure(t *testing.T) { Beginning: time.Now(), Handler: saver, }) - d := EmitterDialer{Dialer: netxmocks.Dialer{ + d := EmitterDialer{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, @@ -69,7 +69,7 @@ func TestEmitterSuccess(t *testing.T) { Beginning: time.Now(), Handler: saver, }) - d := EmitterDialer{Dialer: netxmocks.Dialer{ + d := EmitterDialer{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return &netxmocks.Conn{ MockRead: func(b []byte) (int, error) { diff --git a/internal/engine/netx/dialer/bytecounter_test.go b/internal/engine/netx/dialer/bytecounter_test.go index be8f9a8..e2937b3 100644 --- a/internal/engine/netx/dialer/bytecounter_test.go +++ b/internal/engine/netx/dialer/bytecounter_test.go @@ -76,7 +76,7 @@ func TestByteCounterNoHandlers(t *testing.T) { } func TestByteCounterConnectFailure(t *testing.T) { - dialer := &byteCounterDialer{Dialer: netxmocks.Dialer{ + dialer := &byteCounterDialer{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, diff --git a/internal/engine/netx/dialer/errorwrapper_test.go b/internal/engine/netx/dialer/errorwrapper_test.go index 9dafba4..122afcd 100644 --- a/internal/engine/netx/dialer/errorwrapper_test.go +++ b/internal/engine/netx/dialer/errorwrapper_test.go @@ -13,7 +13,7 @@ import ( func TestErrorWrapperFailure(t *testing.T) { ctx := context.Background() - d := &errorWrapperDialer{Dialer: netxmocks.Dialer{ + d := &errorWrapperDialer{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, @@ -43,7 +43,7 @@ func errorWrapperCheckErr(t *testing.T, err error, op string) { func TestErrorWrapperSuccess(t *testing.T) { ctx := context.Background() - d := &errorWrapperDialer{Dialer: netxmocks.Dialer{ + d := &errorWrapperDialer{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return &netxmocks.Conn{ MockRead: func(b []byte) (int, error) { diff --git a/internal/engine/netx/dialer/proxy_test.go b/internal/engine/netx/dialer/proxy_test.go index c1c2c38..e91bb80 100644 --- a/internal/engine/netx/dialer/proxy_test.go +++ b/internal/engine/netx/dialer/proxy_test.go @@ -14,7 +14,7 @@ import ( func TestProxyDialerDialContextNoProxyURL(t *testing.T) { expected := errors.New("mocked error") d := &proxyDialer{ - Dialer: netxmocks.Dialer{ + Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, expected }, @@ -45,7 +45,7 @@ func TestProxyDialerDialContextInvalidScheme(t *testing.T) { func TestProxyDialerDialContextWithEOF(t *testing.T) { const expect = "10.0.0.1:9050" d := &proxyDialer{ - Dialer: netxmocks.Dialer{ + Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { if address != expect { return nil, errors.New("unexpected address") diff --git a/internal/engine/netx/dialer/saver_test.go b/internal/engine/netx/dialer/saver_test.go index c66f716..ce962f1 100644 --- a/internal/engine/netx/dialer/saver_test.go +++ b/internal/engine/netx/dialer/saver_test.go @@ -17,7 +17,7 @@ func TestSaverDialerFailure(t *testing.T) { expected := errors.New("mocked error") saver := &trace.Saver{} dlr := &saverDialer{ - Dialer: netxmocks.Dialer{ + Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, expected }, @@ -59,7 +59,7 @@ func TestSaverConnDialerFailure(t *testing.T) { expected := errors.New("mocked error") saver := &trace.Saver{} dlr := &saverConnDialer{ - Dialer: netxmocks.Dialer{ + Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, expected }, @@ -79,7 +79,7 @@ func TestSaverConnDialerSuccess(t *testing.T) { saver := &trace.Saver{} dlr := &saverConnDialer{ Dialer: &saverDialer{ - Dialer: netxmocks.Dialer{ + Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return &netxmocks.Conn{ MockRead: func(b []byte) (int, error) { diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index 879d3b6..6782ff0 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -69,7 +69,7 @@ func (r MockableResolver) Address() string { } func TestDialerResolverDialForSingleIPFails(t *testing.T) { - dialer := &DialerResolver{Dialer: netxmocks.Dialer{ + dialer := &DialerResolver{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, @@ -85,7 +85,7 @@ func TestDialerResolverDialForSingleIPFails(t *testing.T) { func TestDialerResolverDialForManyIPFails(t *testing.T) { dialer := &DialerResolver{ - Dialer: netxmocks.Dialer{ + Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, @@ -102,7 +102,7 @@ func TestDialerResolverDialForManyIPFails(t *testing.T) { } func TestDialerResolverDialForManyIPSuccess(t *testing.T) { - dialer := &DialerResolver{Dialer: netxmocks.Dialer{ + dialer := &DialerResolver{Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return &netxmocks.Conn{ MockClose: func() error { @@ -125,7 +125,7 @@ func TestDialerResolverDialForManyIPSuccess(t *testing.T) { func TestDialerLoggerFailure(t *testing.T) { d := &DialerLogger{ - Dialer: netxmocks.Dialer{ + Dialer: &netxmocks.Dialer{ MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { return nil, io.EOF }, diff --git a/internal/netxmocks/dialer.go b/internal/netxmocks/dialer.go index 0875ac6..c7d0dcb 100644 --- a/internal/netxmocks/dialer.go +++ b/internal/netxmocks/dialer.go @@ -16,8 +16,8 @@ type Dialer struct { } // DialContext implements Dialer.DialContext. -func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.MockDialContext(ctx, network, address) } -var _ dialer = Dialer{} +var _ dialer = &Dialer{} diff --git a/internal/netxmocks/resolver.go b/internal/netxmocks/resolver.go new file mode 100644 index 0000000..467aa62 --- /dev/null +++ b/internal/netxmocks/resolver.go @@ -0,0 +1,34 @@ +package netxmocks + +import "context" + +// resolver is the interface we expect from a resolver +type resolver interface { + LookupHost(ctx context.Context, domain string) ([]string, error) + Network() string + Address() string +} + +// Resolver is a mockable Resolver. +type Resolver struct { + MockLookupHost func(ctx context.Context, domain string) ([]string, error) + MockNetwork func() string + MockAddress func() string +} + +// LookupHost implements Resolver.LookupHost. +func (r *Resolver) LookupHost(ctx context.Context, domain string) ([]string, error) { + return r.MockLookupHost(ctx, domain) +} + +// Address implements Resolver.Address. +func (r *Resolver) Address() string { + return r.MockAddress() +} + +// Network implements Resolver.Network. +func (r *Resolver) Network() string { + return r.MockNetwork() +} + +var _ resolver = &Resolver{} diff --git a/internal/netxmocks/resolver_test.go b/internal/netxmocks/resolver_test.go new file mode 100644 index 0000000..a9c42ff --- /dev/null +++ b/internal/netxmocks/resolver_test.go @@ -0,0 +1,46 @@ +package netxmocks + +import ( + "context" + "errors" + "testing" +) + +func TestResolverLookupHost(t *testing.T) { + expected := errors.New("mocked error") + r := &Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, expected + }, + } + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if addrs != nil { + t.Fatal("expected nil addr") + } +} + +func TestResolverNetwork(t *testing.T) { + r := &Resolver{ + MockNetwork: func() string { + return "antani" + }, + } + if v := r.Network(); v != "antani" { + t.Fatal("unexpected network", v) + } +} + +func TestResolverAddress(t *testing.T) { + r := &Resolver{ + MockAddress: func() string { + return "1.1.1.1" + }, + } + if v := r.Address(); v != "1.1.1.1" { + t.Fatal("unexpected address", v) + } +}