diff --git a/internal/engine/netx/dialer/bytecounter_test.go b/internal/engine/netx/dialer/bytecounter_test.go index 3376385..8fd09a2 100644 --- a/internal/engine/netx/dialer/bytecounter_test.go +++ b/internal/engine/netx/dialer/bytecounter_test.go @@ -11,6 +11,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/netx/bytecounter" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" + "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func dorequest(ctx context.Context, url string) error { @@ -70,7 +71,11 @@ func TestByteCounterNoHandlers(t *testing.T) { } func TestByteCounterConnectFailure(t *testing.T) { - dialer := dialer.ByteCounterDialer{Dialer: dialer.EOFDialer{}} + dialer := dialer.ByteCounterDialer{Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, + }} conn, err := dialer.DialContext(context.Background(), "tcp", "www.google.com:80") if !errors.Is(err, io.EOF) { t.Fatal("not the error we expected") diff --git a/internal/engine/netx/dialer/dns_test.go b/internal/engine/netx/dialer/dns_test.go index ba8e406..497bb27 100644 --- a/internal/engine/netx/dialer/dns_test.go +++ b/internal/engine/netx/dialer/dns_test.go @@ -12,6 +12,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" + "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func TestDNSDialerNoPort(t *testing.T) { @@ -62,7 +63,11 @@ func (r MockableResolver) LookupHost(ctx context.Context, host string) ([]string } func TestDNSDialerDialForSingleIPFails(t *testing.T) { - dialer := dialer.DNSDialer{Dialer: dialer.EOFDialer{}, Resolver: new(net.Resolver)} + dialer := dialer.DNSDialer{Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, + }, Resolver: new(net.Resolver)} conn, err := dialer.DialContext(context.Background(), "tcp", "1.1.1.1:853") if !errors.Is(err, io.EOF) { t.Fatal("not the error we expected") @@ -73,9 +78,14 @@ func TestDNSDialerDialForSingleIPFails(t *testing.T) { } func TestDNSDialerDialForManyIPFails(t *testing.T) { - dialer := dialer.DNSDialer{Dialer: dialer.EOFDialer{}, Resolver: MockableResolver{ - Addresses: []string{"1.1.1.1", "8.8.8.8"}, - }} + dialer := dialer.DNSDialer{ + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, + }, Resolver: MockableResolver{ + Addresses: []string{"1.1.1.1", "8.8.8.8"}, + }} conn, err := dialer.DialContext(context.Background(), "tcp", "dot.dns:853") if !errors.Is(err, io.EOF) { t.Fatal("not the error we expected") @@ -86,7 +96,15 @@ func TestDNSDialerDialForManyIPFails(t *testing.T) { } func TestDNSDialerDialForManyIPSuccess(t *testing.T) { - dialer := dialer.DNSDialer{Dialer: dialer.EOFConnDialer{}, Resolver: MockableResolver{ + dialer := dialer.DNSDialer{Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return &mockablex.Conn{ + MockClose: func() error { + return nil + }, + }, nil + }, + }, Resolver: MockableResolver{ Addresses: []string{"1.1.1.1", "8.8.8.8"}, }} conn, err := dialer.DialContext(context.Background(), "tcp", "dot.dns:853") @@ -106,7 +124,18 @@ func TestDNSDialerDialSetsDialID(t *testing.T) { Handler: saver, }) dialer := dialer.DNSDialer{Dialer: dialer.EmitterDialer{ - Dialer: dialer.EOFConnDialer{}, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return &mockablex.Conn{ + MockClose: func() error { + return nil + }, + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{} + }, + }, nil + }, + }, }, Resolver: MockableResolver{ Addresses: []string{"1.1.1.1", "8.8.8.8"}, }} diff --git a/internal/engine/netx/dialer/emitter_test.go b/internal/engine/netx/dialer/emitter_test.go index adf5ad5..ee7ad7f 100644 --- a/internal/engine/netx/dialer/emitter_test.go +++ b/internal/engine/netx/dialer/emitter_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "net" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/modelx" "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/transactionid" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" + "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func TestEmitterFailure(t *testing.T) { @@ -22,7 +24,11 @@ func TestEmitterFailure(t *testing.T) { Handler: saver, }) ctx = transactionid.WithTransactionID(ctx) - d := dialer.EmitterDialer{Dialer: dialer.EOFDialer{}} + d := dialer.EmitterDialer{Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, + }} conn, err := d.DialContext(ctx, "tcp", "www.google.com:443") if !errors.Is(err, io.EOF) { t.Fatal("not the error we expected") @@ -77,7 +83,24 @@ func TestEmitterSuccess(t *testing.T) { Handler: saver, }) ctx = transactionid.WithTransactionID(ctx) - d := dialer.EmitterDialer{Dialer: dialer.EOFConnDialer{}} + d := dialer.EmitterDialer{Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return &mockablex.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockClose: func() error { + return io.EOF + }, + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{Port: 12345} + }, + }, nil + }, + }} conn, err := d.DialContext(ctx, "tcp", "www.google.com:443") if err != nil { t.Fatal("we expected no error") diff --git a/internal/engine/netx/dialer/eof_test.go b/internal/engine/netx/dialer/eof_test.go deleted file mode 100644 index c629a69..0000000 --- a/internal/engine/netx/dialer/eof_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package dialer - -import ( - "context" - "io" - "net" - "time" -) - -type EOFDialer struct{} - -func (EOFDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - time.Sleep(10 * time.Microsecond) - return nil, io.EOF -} - -type EOFConnDialer struct{} - -func (EOFConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return EOFConn{}, nil -} - -type EOFConn struct { - net.Conn -} - -func (EOFConn) Read(p []byte) (int, error) { - time.Sleep(10 * time.Microsecond) - return 0, io.EOF -} - -func (EOFConn) Write(p []byte) (int, error) { - time.Sleep(10 * time.Microsecond) - return 0, io.EOF -} - -func (EOFConn) Close() error { - time.Sleep(10 * time.Microsecond) - return io.EOF -} - -func (EOFConn) LocalAddr() net.Addr { - return EOFAddr{} -} - -func (EOFConn) RemoteAddr() net.Addr { - return EOFAddr{} -} - -func (EOFConn) SetDeadline(t time.Time) error { - return nil -} - -func (EOFConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (EOFConn) SetWriteDeadline(t time.Time) error { - return nil -} - -type EOFAddr struct{} - -func (EOFAddr) Network() string { - return "tcp" -} - -func (EOFAddr) String() string { - return "127.0.0.1:1234" -} diff --git a/internal/engine/netx/dialer/errorwrapper_test.go b/internal/engine/netx/dialer/errorwrapper_test.go index 260519e..b8436ba 100644 --- a/internal/engine/netx/dialer/errorwrapper_test.go +++ b/internal/engine/netx/dialer/errorwrapper_test.go @@ -4,16 +4,22 @@ import ( "context" "errors" "io" + "net" "testing" "github.com/ooni/probe-cli/v3/internal/engine/legacy/netx/dialid" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" + "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func TestErrorWrapperFailure(t *testing.T) { ctx := dialid.WithDialID(context.Background()) - d := dialer.ErrorWrapperDialer{Dialer: dialer.EOFDialer{}} + d := dialer.ErrorWrapperDialer{Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, + }} conn, err := d.DialContext(ctx, "tcp", "www.google.com:443") if conn != nil { t.Fatal("expected a nil conn here") @@ -42,7 +48,24 @@ func errorWrapperCheckErr(t *testing.T, err error, op string) { func TestErrorWrapperSuccess(t *testing.T) { ctx := dialid.WithDialID(context.Background()) - d := dialer.ErrorWrapperDialer{Dialer: dialer.EOFConnDialer{}} + d := dialer.ErrorWrapperDialer{Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return &mockablex.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockClose: func() error { + return io.EOF + }, + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{Port: 12345} + }, + }, nil + }, + }} conn, err := d.DialContext(ctx, "tcp", "www.google.com") if err != nil { t.Fatal(err) diff --git a/internal/engine/netx/dialer/fake_test.go b/internal/engine/netx/dialer/fake_test.go deleted file mode 100644 index 9b66568..0000000 --- a/internal/engine/netx/dialer/fake_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package dialer - -import ( - "context" - "io" - "net" - "time" -) - -type FakeDialer struct { - Conn net.Conn - Err error -} - -func (d FakeDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - time.Sleep(10 * time.Microsecond) - return d.Conn, d.Err -} - -type FakeConn struct { - ReadError error - ReadData []byte - SetDeadlineError error - SetReadDeadlineError error - SetWriteDeadlineError error - WriteError error -} - -func (c *FakeConn) Read(b []byte) (int, error) { - if len(c.ReadData) > 0 { - n := copy(b, c.ReadData) - c.ReadData = c.ReadData[n:] - return n, nil - } - if c.ReadError != nil { - return 0, c.ReadError - } - return 0, io.EOF -} - -func (c *FakeConn) Write(b []byte) (n int, err error) { - if c.WriteError != nil { - return 0, c.WriteError - } - n = len(b) - return -} - -func (*FakeConn) Close() (err error) { - return -} - -func (*FakeConn) LocalAddr() net.Addr { - return &net.TCPAddr{} -} - -func (*FakeConn) RemoteAddr() net.Addr { - return &net.TCPAddr{} -} - -func (c *FakeConn) SetDeadline(t time.Time) (err error) { - return c.SetDeadlineError -} - -func (c *FakeConn) SetReadDeadline(t time.Time) (err error) { - return c.SetReadDeadlineError -} - -func (c *FakeConn) SetWriteDeadline(t time.Time) (err error) { - return c.SetWriteDeadlineError -} diff --git a/internal/engine/netx/dialer/logging_test.go b/internal/engine/netx/dialer/logging_test.go index ba262e4..bfbf1e4 100644 --- a/internal/engine/netx/dialer/logging_test.go +++ b/internal/engine/netx/dialer/logging_test.go @@ -4,15 +4,21 @@ import ( "context" "errors" "io" + "net" "testing" "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" + "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func TestLoggingDialerFailure(t *testing.T) { d := dialer.LoggingDialer{ - Dialer: dialer.EOFDialer{}, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, + }, Logger: log.Log, } conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") diff --git a/internal/engine/netx/dialer/proxy_test.go b/internal/engine/netx/dialer/proxy_test.go index 7bfa52c..600eda9 100644 --- a/internal/engine/netx/dialer/proxy_test.go +++ b/internal/engine/netx/dialer/proxy_test.go @@ -4,16 +4,23 @@ import ( "context" "errors" "io" + "net" "net/url" "testing" + "time" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" + "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" ) func TestProxyDialerDialContextNoProxyURL(t *testing.T) { expected := errors.New("mocked error") d := dialer.ProxyDialer{ - Dialer: dialer.FakeDialer{Err: expected}, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, expected + }, + }, } conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") if !errors.Is(err, expected) { @@ -26,7 +33,6 @@ func TestProxyDialerDialContextNoProxyURL(t *testing.T) { func TestProxyDialerDialContextInvalidScheme(t *testing.T) { d := dialer.ProxyDialer{ - Dialer: dialer.FakeDialer{}, ProxyURL: &url.URL{Scheme: "antani"}, } conn, err := d.DialContext(context.Background(), "tcp", "www.google.com:443") @@ -40,8 +46,10 @@ func TestProxyDialerDialContextInvalidScheme(t *testing.T) { func TestProxyDialerDialContextWithEOF(t *testing.T) { d := dialer.ProxyDialer{ - Dialer: dialer.FakeDialer{ - Err: io.EOF, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, }, ProxyURL: &url.URL{Scheme: "socks5"}, } @@ -58,8 +66,10 @@ func TestProxyDialerDialContextWithContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately fail d := dialer.ProxyDialer{ - Dialer: dialer.FakeDialer{ - Err: io.EOF, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, }, ProxyURL: &url.URL{Scheme: "socks5"}, } @@ -74,10 +84,19 @@ func TestProxyDialerDialContextWithContextCanceled(t *testing.T) { func TestProxyDialerDialContextWithDialerSuccess(t *testing.T) { d := dialer.ProxyDialer{ - Dialer: dialer.FakeDialer{ - Conn: &dialer.FakeConn{ - ReadError: io.EOF, - WriteError: io.EOF, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return &mockablex.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockClose: func() error { + return io.EOF + }, + }, nil }, }, ProxyURL: &url.URL{Scheme: "socks5"}, @@ -99,10 +118,20 @@ func TestProxyDialerDialContextWithDialerCanceledContext(t *testing.T) { // arm where we receive the conn is much less likely. cancel() d := dialer.ProxyDialer{ - Dialer: dialer.FakeDialer{ - Conn: &dialer.FakeConn{ - ReadError: io.EOF, - WriteError: io.EOF, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + time.Sleep(10 * time.Microsecond) + return &mockablex.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockClose: func() error { + return io.EOF + }, + }, nil }, }, ProxyURL: &url.URL{Scheme: "socks5"}, @@ -121,8 +150,10 @@ func TestProxyDialerDialContextWithDialerCanceledContext(t *testing.T) { func TestProxyDialerWrapper(t *testing.T) { d := dialer.ProxyDialerWrapper{ - Dialer: dialer.FakeDialer{ - Err: io.EOF, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, io.EOF + }, }, } conn, err := d.Dial("tcp", "www.google.com:443") diff --git a/internal/engine/netx/dialer/saver_test.go b/internal/engine/netx/dialer/saver_test.go index a256b3a..3dd59eb 100644 --- a/internal/engine/netx/dialer/saver_test.go +++ b/internal/engine/netx/dialer/saver_test.go @@ -3,11 +3,14 @@ package dialer_test import ( "context" "errors" + "io" + "net" "testing" "time" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" + "github.com/ooni/probe-cli/v3/internal/engine/netx/mockablex" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) @@ -15,8 +18,10 @@ func TestSaverDialerFailure(t *testing.T) { expected := errors.New("mocked error") saver := &trace.Saver{} dlr := dialer.SaverDialer{ - Dialer: dialer.FakeDialer{ - Err: expected, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, expected + }, }, Saver: saver, } @@ -55,8 +60,10 @@ func TestSaverConnDialerFailure(t *testing.T) { expected := errors.New("mocked error") saver := &trace.Saver{} dlr := dialer.SaverConnDialer{ - Dialer: dialer.FakeDialer{ - Err: expected, + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, expected + }, }, Saver: saver, } @@ -68,3 +75,66 @@ func TestSaverConnDialerFailure(t *testing.T) { t.Fatal("expected nil conn here") } } + +func TestSaverConnDialerSuccess(t *testing.T) { + saver := &trace.Saver{} + dlr := dialer.SaverConnDialer{ + Dialer: dialer.SaverDialer{ + Dialer: mockablex.Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return &mockablex.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockWrite: func(b []byte) (int, error) { + return 0, io.EOF + }, + MockClose: func() error { + return io.EOF + }, + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{Port: 12345} + }, + }, nil + }, + }, + Saver: saver, + }, + Saver: saver, + } + conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") + if err != nil { + t.Fatal("not the error we expected", err) + } + conn.Read(nil) + conn.Write(nil) + conn.Close() + events := saver.Read() + if len(events) != 3 { + t.Fatal("unexpected number of events saved", len(events)) + } + if events[0].Name != "connect" { + t.Fatal("expected a connect event") + } + saverCheckConnectEvent(t, &events[0]) + if events[1].Name != "read" { + t.Fatal("expected a read event") + } + saverCheckReadEvent(t, &events[1]) + if events[2].Name != "write" { + t.Fatal("expected a write event") + } + saverCheckWriteEvent(t, &events[2]) +} + +func saverCheckConnectEvent(t *testing.T, ev *trace.Event) { + // TODO(bassosimone): implement +} + +func saverCheckReadEvent(t *testing.T, ev *trace.Event) { + // TODO(bassosimone): implement +} + +func saverCheckWriteEvent(t *testing.T, ev *trace.Event) { + // TODO(bassosimone): implement +} diff --git a/internal/engine/netx/mockablex/conn.go b/internal/engine/netx/mockablex/conn.go new file mode 100644 index 0000000..1e369eb --- /dev/null +++ b/internal/engine/netx/mockablex/conn.go @@ -0,0 +1,60 @@ +package mockablex + +import ( + "net" + "time" +) + +// Conn is a mockable net.Conn. +type Conn struct { + MockRead func(b []byte) (int, error) + MockWrite func(b []byte) (int, error) + MockClose func() error + MockLocalAddr func() net.Addr + MockRemoteAddr func() net.Addr + MockSetDeadline func(t time.Time) error + MockSetReadDeadline func(t time.Time) error + MockSetWriteDeadline func(t time.Time) error +} + +// Read implements net.Conn.Read +func (c *Conn) Read(b []byte) (int, error) { + return c.MockRead(b) +} + +// Write implements net.Conn.Write +func (c *Conn) Write(b []byte) (int, error) { + return c.MockWrite(b) +} + +// Close implements net.Conn.Close +func (c *Conn) Close() error { + return c.MockClose() +} + +// LocalAddr returns the local address +func (c *Conn) LocalAddr() net.Addr { + return c.MockLocalAddr() +} + +// RemoteAddr returns the remote address +func (c *Conn) RemoteAddr() net.Addr { + return c.MockRemoteAddr() +} + +// SetDeadline sets the connection deadline. +func (c *Conn) SetDeadline(t time.Time) error { + return c.MockSetDeadline(t) +} + +// SetReadDeadline sets the read deadline. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.MockSetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.MockSetWriteDeadline(t) +} + +var _ net.Conn = &Conn{} diff --git a/internal/engine/netx/mockablex/conn_test.go b/internal/engine/netx/mockablex/conn_test.go new file mode 100644 index 0000000..b438ccd --- /dev/null +++ b/internal/engine/netx/mockablex/conn_test.go @@ -0,0 +1,126 @@ +package mockablex + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestConnReadWorks(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + } + count, err := c.Read(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } +} + +func TestConnWriteWorks(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockWrite: func(b []byte) (int, error) { + return 0, expected + }, + } + count, err := c.Write(make([]byte, 128)) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if count != 0 { + t.Fatal("expected 0 bytes") + } +} + +func TestConnCloseWorks(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockClose: func() error { + return expected + }, + } + err := c.Close() + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } +} + +func TestConnLocalAddrWorks(t *testing.T) { + expected := &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1234, + } + c := &Conn{ + MockLocalAddr: func() net.Addr { + return expected + }, + } + out := c.LocalAddr() + if diff := cmp.Diff(expected, out); diff != "" { + t.Fatal(diff) + } +} + +func TestConnRemoteAddrWorks(t *testing.T) { + expected := &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1234, + } + c := &Conn{ + MockRemoteAddr: func() net.Addr { + return expected + }, + } + out := c.RemoteAddr() + if diff := cmp.Diff(expected, out); diff != "" { + t.Fatal(diff) + } +} + +func TestConnSetDeadline(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockSetDeadline: func(t time.Time) error { + return expected + }, + } + err := c.SetDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } +} + +func TestConnSetReadDeadline(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockSetReadDeadline: func(t time.Time) error { + return expected + }, + } + err := c.SetReadDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } +} + +func TestConnSetWriteDeadline(t *testing.T) { + expected := errors.New("mocked error") + c := &Conn{ + MockSetWriteDeadline: func(t time.Time) error { + return expected + }, + } + err := c.SetWriteDeadline(time.Time{}) + if !errors.Is(err, expected) { + t.Fatal("not the error we expected", err) + } +} diff --git a/internal/engine/netx/mockablex/dialer.go b/internal/engine/netx/mockablex/dialer.go new file mode 100644 index 0000000..dd9a58f --- /dev/null +++ b/internal/engine/netx/mockablex/dialer.go @@ -0,0 +1,23 @@ +package mockablex + +import ( + "context" + "net" +) + +// dialer is the interface we expect from a dialer +type dialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// Dialer is a mockable Dialer. +type Dialer struct { + MockDialContext func(ctx context.Context, network, address string) (net.Conn, error) +} + +// DialContext implements Dialer.DialContext. +func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.MockDialContext(ctx, network, address) +} + +var _ dialer = Dialer{} diff --git a/internal/engine/netx/mockablex/dialer_test.go b/internal/engine/netx/mockablex/dialer_test.go new file mode 100644 index 0000000..b9c3a83 --- /dev/null +++ b/internal/engine/netx/mockablex/dialer_test.go @@ -0,0 +1,25 @@ +package mockablex + +import ( + "context" + "errors" + "net" + "testing" +) + +func TestDialerWorks(t *testing.T) { + expected := errors.New("mocked error") + d := Dialer{ + MockDialContext: func(ctx context.Context, network string, address string) (net.Conn, error) { + return nil, expected + }, + } + ctx := context.Background() + conn, err := d.DialContext(ctx, "tcp", "8.8.8.8:53") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn") + } +} diff --git a/internal/engine/netx/mockablex/doc.go b/internal/engine/netx/mockablex/doc.go new file mode 100644 index 0000000..45b32d3 --- /dev/null +++ b/internal/engine/netx/mockablex/doc.go @@ -0,0 +1,2 @@ +// Package mockable contains mocks for netx types. +package mockablex