diff --git a/internal/netxlite/filtering/dns_test.go b/internal/netxlite/filtering/dns_test.go index c5a0817..76782cd 100644 --- a/internal/netxlite/filtering/dns_test.go +++ b/internal/netxlite/filtering/dns_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "strings" "testing" "time" @@ -43,11 +44,11 @@ func TestDNSProxy(t *testing.T) { if addrs == nil { t.Fatal("unexpected empty addrs") } - var foundQuad8 bool + var found bool for _, addr := range addrs { - foundQuad8 = foundQuad8 || addr == "8.8.8.8" + found = found || addr == "8.8.8.8" } - if !foundQuad8 { + if !found { t.Fatal("did not find 8.8.8.8") } listener.Close() @@ -104,11 +105,11 @@ func TestDNSProxy(t *testing.T) { if addrs == nil { t.Fatal("expected non-empty addrs") } - var found127001 bool + var found bool for _, addr := range addrs { - found127001 = found127001 || addr == "127.0.0.1" + found = found || addr == "127.0.0.1" } - if !found127001 { + if !found { t.Fatal("did not find 127.0.0.1") } listener.Close() @@ -124,7 +125,7 @@ func TestDNSProxy(t *testing.T) { r := newresolver(listener) addrs, err := r.LookupHost(ctx, "dns.google") if err == nil || err.Error() != netxlite.FailureDNSNoAnswer { - t.Fatal(err) + t.Fatal("unexpected err", err) } if addrs != nil { t.Fatal("expected empty addrs") @@ -140,15 +141,15 @@ func TestDNSProxy(t *testing.T) { // careful because lots of legacy code uses SerialResolver. const timeout = time.Second ctx, cancel := context.WithTimeout(context.Background(), timeout) - listener, done, err := newproxy(DNSActionTimeout) defer cancel() + listener, done, err := newproxy(DNSActionTimeout) if err != nil { t.Fatal(err) } r := newresolver(listener) addrs, err := r.LookupHost(ctx, "dns.google") if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { - t.Fatal(err) + t.Fatal("unexpected err", err) } if addrs != nil { t.Fatal("expected empty addrs") @@ -160,8 +161,8 @@ func TestDNSProxy(t *testing.T) { t.Run("Start with invalid address", func(t *testing.T) { p := &DNSProxy{} listener, err := p.Start("127.0.0.1") - if err == nil { - t.Fatal("expected an error") + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("unexpected err", err) } if listener != nil { t.Fatal("expected nil listener") @@ -273,13 +274,13 @@ func TestDNSProxy(t *testing.T) { }) t.Run("proxy", func(t *testing.T) { - t.Run("pack fails", func(t *testing.T) { + t.Run("Pack fails", func(t *testing.T) { p := &DNSProxy{} query := &dns.Msg{} query.Rcode = -1 // causes Pack to fail reply, err := p.proxy(query) - if err == nil { - t.Fatal("expected error here") + if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") { + t.Fatal("unexpected err", err) } if reply != nil { t.Fatal("expected nil reply") @@ -305,7 +306,7 @@ func TestDNSProxy(t *testing.T) { } }) - t.Run("unpack fails", func(t *testing.T) { + t.Run("Unpack fails", func(t *testing.T) { p := &DNSProxy{ Upstream: &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { @@ -315,8 +316,8 @@ func TestDNSProxy(t *testing.T) { }, } reply, err := p.proxy(&dns.Msg{}) - if err == nil { - t.Fatal("expected error") + if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") { + t.Fatal("unexpected err", err) } if reply != nil { t.Fatal("expected nil reply here") diff --git a/internal/netxlite/filtering/http_test.go b/internal/netxlite/filtering/http_test.go index 025ee74..c43abdf 100644 --- a/internal/netxlite/filtering/http_test.go +++ b/internal/netxlite/filtering/http_test.go @@ -2,6 +2,7 @@ package filtering import ( "context" + "errors" "net" "net/http" "net/url" @@ -97,7 +98,7 @@ func TestHTTPProxy(t *testing.T) { t.Fatal(err) } resp, err := httpGET(ctx, listener.Addr(), "nexa.polito.it") - if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") { + if !errors.Is(err, context.DeadlineExceeded) { t.Fatal("unexpected err", err) } if resp != nil { @@ -159,8 +160,8 @@ func TestHTTPProxy(t *testing.T) { t.Run("Start fails on an invalid address", func(t *testing.T) { p := &HTTPProxy{} listener, err := p.Start("127.0.0.1") - if err == nil { - t.Fatal("expected an error") + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("unexpected err", err) } if listener != nil { t.Fatal("expected nil listener") diff --git a/internal/netxlite/filtering/tls.go b/internal/netxlite/filtering/tls.go index 7bdc836..0810c85 100644 --- a/internal/netxlite/filtering/tls.go +++ b/internal/netxlite/filtering/tls.go @@ -60,18 +60,23 @@ func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, erro func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) { defer close(done) - for { - conn, err := listener.Accept() - if err == nil { - go p.handle(conn) - continue - } - if strings.HasSuffix(err.Error(), "use of closed network connection") { - break - } + for p.oneloop(listener) { + // nothing } } +func (p *TLSProxy) oneloop(listener net.Listener) bool { + conn, err := listener.Accept() + if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") { + return false // we need to stop + } + if err != nil { + return true // we can continue running + } + go p.handle(conn) + return true // we can continue running +} + const ( tlsAlertInternalError = byte(80) tlsAlertUnrecognizedName = byte(112) @@ -143,10 +148,11 @@ type tlsClientHelloReader struct { func (c *tlsClientHelloReader) Read(b []byte) (int, error) { count, err := c.Conn.Read(b) - if err == nil { - c.clientHello = append(c.clientHello, b[:count]...) + if err != nil { + return 0, err } - return count, err + c.clientHello = append(c.clientHello, b[:count]...) + return count, nil } // Write prevents writing on the real connection diff --git a/internal/netxlite/filtering/tls_test.go b/internal/netxlite/filtering/tls_test.go index 162010e..cbf25df 100644 --- a/internal/netxlite/filtering/tls_test.go +++ b/internal/netxlite/filtering/tls_test.go @@ -134,25 +134,24 @@ func TestTLSProxy(t *testing.T) { <-done // wait for background goroutine to exit }) + dial := func(ctx context.Context, endpoint string) (net.Conn, error) { + d := netxlite.NewDialerWithoutResolver(log.Log) + return d.DialContext(ctx, "tcp", endpoint) + } + t.Run("handle cannot read ClientHello", func(t *testing.T) { listener, done, err := newproxy(TLSActionPass) if err != nil { t.Fatal(err) } - conn, err := net.Dial("tcp", listener.Addr().String()) + conn, err := dial(context.Background(), listener.Addr().String()) if err != nil { t.Fatal(err) } conn.Write([]byte("GET / HTTP/1.0\r\n\r\n")) buff := make([]byte, 1<<17) _, err = conn.Read(buff) - // Implementation note: we need to wrap the error because - // otherwise the error string on Windows is different from Unix - if err == nil { - t.Fatal("expected non-nil error") - } - err = netxlite.NewTopLevelGenericErrWrapper(err) - if err.Error() != netxlite.FailureConnectionReset { + if err == nil || err.Error() != netxlite.FailureConnectionReset { t.Fatal("unexpected err", err) } listener.Close() @@ -251,11 +250,45 @@ func TestTLSProxy(t *testing.T) { t.Run("Start fails on an invalid address", func(t *testing.T) { p := &TLSProxy{} listener, err := p.Start("127.0.0.1") - if err == nil { - t.Fatal("expected an error") + if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { + t.Fatal("unexpected err", err) } if listener != nil { t.Fatal("expected nil listener") } }) + + t.Run("oneloop correctly handles a listener error", func(t *testing.T) { + listener := &mocks.Listener{ + MockAccept: func() (net.Conn, error) { + return nil, errors.New("mocked error") + }, + } + p := &TLSProxy{} + if !p.oneloop(listener) { + t.Fatal("should return true here") + } + }) +} + +func TestTLSClientHelloReader(t *testing.T) { + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + chr := &tlsClientHelloReader{ + Conn: &mocks.Conn{ + MockRead: func(b []byte) (int, error) { + return 0, expected + }, + }, + clientHello: []byte{}, + } + buf := make([]byte, 128) + count, err := chr.Read(buf) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if count != 0 { + t.Fatal("invalid count") + } + }) } diff --git a/internal/netxlite/mocks/listener.go b/internal/netxlite/mocks/listener.go new file mode 100644 index 0000000..6d3fbb9 --- /dev/null +++ b/internal/netxlite/mocks/listener.go @@ -0,0 +1,32 @@ +package mocks + +import "net" + +// Listener allows mocking a net.Listener. +type Listener struct { + // Accept allows mocking Accept. + MockAccept func() (net.Conn, error) + + // Close allows mocking Close. + MockClose func() error + + // Addr allows mocking Addr. + MockAddr func() net.Addr +} + +var _ net.Listener = &Listener{} + +// Accept implements net.Listener.Accept +func (li *Listener) Accept() (net.Conn, error) { + return li.MockAccept() +} + +// Close implements net.Listener.Closer. +func (li *Listener) Close() error { + return li.MockClose() +} + +// Addr implements net.Listener.Addr +func (li *Listener) Addr() net.Addr { + return li.MockAddr() +} diff --git a/internal/netxlite/mocks/listener_test.go b/internal/netxlite/mocks/listener_test.go new file mode 100644 index 0000000..17d6426 --- /dev/null +++ b/internal/netxlite/mocks/listener_test.go @@ -0,0 +1,56 @@ +package mocks + +import ( + "errors" + "net" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestListener(t *testing.T) { + t.Run("Accept", func(t *testing.T) { + expected := errors.New("mocked error") + li := &Listener{ + MockAccept: func() (net.Conn, error) { + return nil, expected + }, + } + conn, err := li.Accept() + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + }) + + t.Run("Close", func(t *testing.T) { + expected := errors.New("mocked error") + li := &Listener{ + MockClose: func() error { + return expected + }, + } + err := li.Close() + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + }) + + t.Run("Addr", func(t *testing.T) { + addr := &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + } + li := &Listener{ + MockAddr: func() net.Addr { + return addr + }, + } + outAddr := li.Addr() + if diff := cmp.Diff(addr, outAddr); diff != "" { + t.Fatal(diff) + } + }) +}