fix(filtering): avoid the if err == nil pattern (#567)
1. in normal code is better to always do if err != nil so that the ifs only contain error code (this is ~coding policy) 2. in tests we want to ensure we narrow down the error to the real error that happened, to have greater confidence Written while working on https://github.com/ooni/probe/issues/1803#issuecomment-957323297
This commit is contained in:
parent
374577f5a8
commit
560b1a9a97
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -43,11 +44,11 @@ func TestDNSProxy(t *testing.T) {
|
|||
if addrs == nil {
|
||||
t.Fatal("unexpected empty addrs")
|
||||
}
|
||||
var foundQuad8 bool
|
||||
var found bool
|
||||
for _, addr := range addrs {
|
||||
foundQuad8 = foundQuad8 || addr == "8.8.8.8"
|
||||
found = found || addr == "8.8.8.8"
|
||||
}
|
||||
if !foundQuad8 {
|
||||
if !found {
|
||||
t.Fatal("did not find 8.8.8.8")
|
||||
}
|
||||
listener.Close()
|
||||
|
@ -104,11 +105,11 @@ func TestDNSProxy(t *testing.T) {
|
|||
if addrs == nil {
|
||||
t.Fatal("expected non-empty addrs")
|
||||
}
|
||||
var found127001 bool
|
||||
var found bool
|
||||
for _, addr := range addrs {
|
||||
found127001 = found127001 || addr == "127.0.0.1"
|
||||
found = found || addr == "127.0.0.1"
|
||||
}
|
||||
if !found127001 {
|
||||
if !found {
|
||||
t.Fatal("did not find 127.0.0.1")
|
||||
}
|
||||
listener.Close()
|
||||
|
@ -124,7 +125,7 @@ func TestDNSProxy(t *testing.T) {
|
|||
r := newresolver(listener)
|
||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
||||
if err == nil || err.Error() != netxlite.FailureDNSNoAnswer {
|
||||
t.Fatal(err)
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected empty addrs")
|
||||
|
@ -140,15 +141,15 @@ func TestDNSProxy(t *testing.T) {
|
|||
// careful because lots of legacy code uses SerialResolver.
|
||||
const timeout = time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
listener, done, err := newproxy(DNSActionTimeout)
|
||||
defer cancel()
|
||||
listener, done, err := newproxy(DNSActionTimeout)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := newresolver(listener)
|
||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
||||
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
|
||||
t.Fatal(err)
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if addrs != nil {
|
||||
t.Fatal("expected empty addrs")
|
||||
|
@ -160,8 +161,8 @@ func TestDNSProxy(t *testing.T) {
|
|||
t.Run("Start with invalid address", func(t *testing.T) {
|
||||
p := &DNSProxy{}
|
||||
listener, err := p.Start("127.0.0.1")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if listener != nil {
|
||||
t.Fatal("expected nil listener")
|
||||
|
@ -273,13 +274,13 @@ func TestDNSProxy(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("proxy", func(t *testing.T) {
|
||||
t.Run("pack fails", func(t *testing.T) {
|
||||
t.Run("Pack fails", func(t *testing.T) {
|
||||
p := &DNSProxy{}
|
||||
query := &dns.Msg{}
|
||||
query.Rcode = -1 // causes Pack to fail
|
||||
reply, err := p.proxy(query)
|
||||
if err == nil {
|
||||
t.Fatal("expected error here")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply")
|
||||
|
@ -305,7 +306,7 @@ func TestDNSProxy(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("unpack fails", func(t *testing.T) {
|
||||
t.Run("Unpack fails", func(t *testing.T) {
|
||||
p := &DNSProxy{
|
||||
Upstream: &mocks.DNSTransport{
|
||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||
|
@ -315,8 +316,8 @@ func TestDNSProxy(t *testing.T) {
|
|||
},
|
||||
}
|
||||
reply, err := p.proxy(&dns.Msg{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if reply != nil {
|
||||
t.Fatal("expected nil reply here")
|
||||
|
|
|
@ -2,6 +2,7 @@ package filtering
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -97,7 +98,7 @@ func TestHTTPProxy(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
resp, err := httpGET(ctx, listener.Addr(), "nexa.polito.it")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if resp != nil {
|
||||
|
@ -159,8 +160,8 @@ func TestHTTPProxy(t *testing.T) {
|
|||
t.Run("Start fails on an invalid address", func(t *testing.T) {
|
||||
p := &HTTPProxy{}
|
||||
listener, err := p.Start("127.0.0.1")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if listener != nil {
|
||||
t.Fatal("expected nil listener")
|
||||
|
|
|
@ -60,18 +60,23 @@ func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, erro
|
|||
|
||||
func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) {
|
||||
defer close(done)
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err == nil {
|
||||
go p.handle(conn)
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
break
|
||||
}
|
||||
for p.oneloop(listener) {
|
||||
// nothing
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TLSProxy) oneloop(listener net.Listener) bool {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
return false // we need to stop
|
||||
}
|
||||
if err != nil {
|
||||
return true // we can continue running
|
||||
}
|
||||
go p.handle(conn)
|
||||
return true // we can continue running
|
||||
}
|
||||
|
||||
const (
|
||||
tlsAlertInternalError = byte(80)
|
||||
tlsAlertUnrecognizedName = byte(112)
|
||||
|
@ -143,10 +148,11 @@ type tlsClientHelloReader struct {
|
|||
|
||||
func (c *tlsClientHelloReader) Read(b []byte) (int, error) {
|
||||
count, err := c.Conn.Read(b)
|
||||
if err == nil {
|
||||
c.clientHello = append(c.clientHello, b[:count]...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, err
|
||||
c.clientHello = append(c.clientHello, b[:count]...)
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Write prevents writing on the real connection
|
||||
|
|
|
@ -134,25 +134,24 @@ func TestTLSProxy(t *testing.T) {
|
|||
<-done // wait for background goroutine to exit
|
||||
})
|
||||
|
||||
dial := func(ctx context.Context, endpoint string) (net.Conn, error) {
|
||||
d := netxlite.NewDialerWithoutResolver(log.Log)
|
||||
return d.DialContext(ctx, "tcp", endpoint)
|
||||
}
|
||||
|
||||
t.Run("handle cannot read ClientHello", func(t *testing.T) {
|
||||
listener, done, err := newproxy(TLSActionPass)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := net.Dial("tcp", listener.Addr().String())
|
||||
conn, err := dial(context.Background(), listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn.Write([]byte("GET / HTTP/1.0\r\n\r\n"))
|
||||
buff := make([]byte, 1<<17)
|
||||
_, err = conn.Read(buff)
|
||||
// Implementation note: we need to wrap the error because
|
||||
// otherwise the error string on Windows is different from Unix
|
||||
if err == nil {
|
||||
t.Fatal("expected non-nil error")
|
||||
}
|
||||
err = netxlite.NewTopLevelGenericErrWrapper(err)
|
||||
if err.Error() != netxlite.FailureConnectionReset {
|
||||
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
listener.Close()
|
||||
|
@ -251,11 +250,45 @@ func TestTLSProxy(t *testing.T) {
|
|||
t.Run("Start fails on an invalid address", func(t *testing.T) {
|
||||
p := &TLSProxy{}
|
||||
listener, err := p.Start("127.0.0.1")
|
||||
if err == nil {
|
||||
t.Fatal("expected an error")
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if listener != nil {
|
||||
t.Fatal("expected nil listener")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("oneloop correctly handles a listener error", func(t *testing.T) {
|
||||
listener := &mocks.Listener{
|
||||
MockAccept: func() (net.Conn, error) {
|
||||
return nil, errors.New("mocked error")
|
||||
},
|
||||
}
|
||||
p := &TLSProxy{}
|
||||
if !p.oneloop(listener) {
|
||||
t.Fatal("should return true here")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTLSClientHelloReader(t *testing.T) {
|
||||
t.Run("on failure", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
chr := &tlsClientHelloReader{
|
||||
Conn: &mocks.Conn{
|
||||
MockRead: func(b []byte) (int, error) {
|
||||
return 0, expected
|
||||
},
|
||||
},
|
||||
clientHello: []byte{},
|
||||
}
|
||||
buf := make([]byte, 128)
|
||||
count, err := chr.Read(buf)
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("invalid count")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
32
internal/netxlite/mocks/listener.go
Normal file
32
internal/netxlite/mocks/listener.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package mocks
|
||||
|
||||
import "net"
|
||||
|
||||
// Listener allows mocking a net.Listener.
|
||||
type Listener struct {
|
||||
// Accept allows mocking Accept.
|
||||
MockAccept func() (net.Conn, error)
|
||||
|
||||
// Close allows mocking Close.
|
||||
MockClose func() error
|
||||
|
||||
// Addr allows mocking Addr.
|
||||
MockAddr func() net.Addr
|
||||
}
|
||||
|
||||
var _ net.Listener = &Listener{}
|
||||
|
||||
// Accept implements net.Listener.Accept
|
||||
func (li *Listener) Accept() (net.Conn, error) {
|
||||
return li.MockAccept()
|
||||
}
|
||||
|
||||
// Close implements net.Listener.Closer.
|
||||
func (li *Listener) Close() error {
|
||||
return li.MockClose()
|
||||
}
|
||||
|
||||
// Addr implements net.Listener.Addr
|
||||
func (li *Listener) Addr() net.Addr {
|
||||
return li.MockAddr()
|
||||
}
|
56
internal/netxlite/mocks/listener_test.go
Normal file
56
internal/netxlite/mocks/listener_test.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestListener(t *testing.T) {
|
||||
t.Run("Accept", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
li := &Listener{
|
||||
MockAccept: func() (net.Conn, error) {
|
||||
return nil, expected
|
||||
},
|
||||
}
|
||||
conn, err := li.Accept()
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
if conn != nil {
|
||||
t.Fatal("expected nil conn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close", func(t *testing.T) {
|
||||
expected := errors.New("mocked error")
|
||||
li := &Listener{
|
||||
MockClose: func() error {
|
||||
return expected
|
||||
},
|
||||
}
|
||||
err := li.Close()
|
||||
if !errors.Is(err, expected) {
|
||||
t.Fatal("unexpected err", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Addr", func(t *testing.T) {
|
||||
addr := &net.TCPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
li := &Listener{
|
||||
MockAddr: func() net.Addr {
|
||||
return addr
|
||||
},
|
||||
}
|
||||
outAddr := li.Addr()
|
||||
if diff := cmp.Diff(addr, outAddr); diff != "" {
|
||||
t.Fatal(diff)
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Reference in New Issue
Block a user