fix(filtering): avoid the if err == nil pattern (#567)

1. in normal code is better to always do if err != nil so that
the ifs only contain error code (this is ~coding policy)

2. in tests we want to ensure we narrow down the error to the
real error that happened, to have greater confidence

Written while working on https://github.com/ooni/probe/issues/1803#issuecomment-957323297
This commit is contained in:
Simone Basso 2021-11-02 19:48:10 +01:00 committed by GitHub
parent 374577f5a8
commit 560b1a9a97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 171 additions and 42 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net" "net"
"strings"
"testing" "testing"
"time" "time"
@ -43,11 +44,11 @@ func TestDNSProxy(t *testing.T) {
if addrs == nil { if addrs == nil {
t.Fatal("unexpected empty addrs") t.Fatal("unexpected empty addrs")
} }
var foundQuad8 bool var found bool
for _, addr := range addrs { for _, addr := range addrs {
foundQuad8 = foundQuad8 || addr == "8.8.8.8" found = found || addr == "8.8.8.8"
} }
if !foundQuad8 { if !found {
t.Fatal("did not find 8.8.8.8") t.Fatal("did not find 8.8.8.8")
} }
listener.Close() listener.Close()
@ -104,11 +105,11 @@ func TestDNSProxy(t *testing.T) {
if addrs == nil { if addrs == nil {
t.Fatal("expected non-empty addrs") t.Fatal("expected non-empty addrs")
} }
var found127001 bool var found bool
for _, addr := range addrs { for _, addr := range addrs {
found127001 = found127001 || addr == "127.0.0.1" found = found || addr == "127.0.0.1"
} }
if !found127001 { if !found {
t.Fatal("did not find 127.0.0.1") t.Fatal("did not find 127.0.0.1")
} }
listener.Close() listener.Close()
@ -124,7 +125,7 @@ func TestDNSProxy(t *testing.T) {
r := newresolver(listener) r := newresolver(listener)
addrs, err := r.LookupHost(ctx, "dns.google") addrs, err := r.LookupHost(ctx, "dns.google")
if err == nil || err.Error() != netxlite.FailureDNSNoAnswer { if err == nil || err.Error() != netxlite.FailureDNSNoAnswer {
t.Fatal(err) t.Fatal("unexpected err", err)
} }
if addrs != nil { if addrs != nil {
t.Fatal("expected empty addrs") t.Fatal("expected empty addrs")
@ -140,15 +141,15 @@ func TestDNSProxy(t *testing.T) {
// careful because lots of legacy code uses SerialResolver. // careful because lots of legacy code uses SerialResolver.
const timeout = time.Second const timeout = time.Second
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
listener, done, err := newproxy(DNSActionTimeout)
defer cancel() defer cancel()
listener, done, err := newproxy(DNSActionTimeout)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
r := newresolver(listener) r := newresolver(listener)
addrs, err := r.LookupHost(ctx, "dns.google") addrs, err := r.LookupHost(ctx, "dns.google")
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
t.Fatal(err) t.Fatal("unexpected err", err)
} }
if addrs != nil { if addrs != nil {
t.Fatal("expected empty addrs") t.Fatal("expected empty addrs")
@ -160,8 +161,8 @@ func TestDNSProxy(t *testing.T) {
t.Run("Start with invalid address", func(t *testing.T) { t.Run("Start with invalid address", func(t *testing.T) {
p := &DNSProxy{} p := &DNSProxy{}
listener, err := p.Start("127.0.0.1") listener, err := p.Start("127.0.0.1")
if err == nil { if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
t.Fatal("expected an error") t.Fatal("unexpected err", err)
} }
if listener != nil { if listener != nil {
t.Fatal("expected nil listener") t.Fatal("expected nil listener")
@ -273,13 +274,13 @@ func TestDNSProxy(t *testing.T) {
}) })
t.Run("proxy", func(t *testing.T) { t.Run("proxy", func(t *testing.T) {
t.Run("pack fails", func(t *testing.T) { t.Run("Pack fails", func(t *testing.T) {
p := &DNSProxy{} p := &DNSProxy{}
query := &dns.Msg{} query := &dns.Msg{}
query.Rcode = -1 // causes Pack to fail query.Rcode = -1 // causes Pack to fail
reply, err := p.proxy(query) reply, err := p.proxy(query)
if err == nil { if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") {
t.Fatal("expected error here") t.Fatal("unexpected err", err)
} }
if reply != nil { if reply != nil {
t.Fatal("expected nil reply") t.Fatal("expected nil reply")
@ -305,7 +306,7 @@ func TestDNSProxy(t *testing.T) {
} }
}) })
t.Run("unpack fails", func(t *testing.T) { t.Run("Unpack fails", func(t *testing.T) {
p := &DNSProxy{ p := &DNSProxy{
Upstream: &mocks.DNSTransport{ Upstream: &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
@ -315,8 +316,8 @@ func TestDNSProxy(t *testing.T) {
}, },
} }
reply, err := p.proxy(&dns.Msg{}) reply, err := p.proxy(&dns.Msg{})
if err == nil { if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") {
t.Fatal("expected error") t.Fatal("unexpected err", err)
} }
if reply != nil { if reply != nil {
t.Fatal("expected nil reply here") t.Fatal("expected nil reply here")

View File

@ -2,6 +2,7 @@ package filtering
import ( import (
"context" "context"
"errors"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -97,7 +98,7 @@ func TestHTTPProxy(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
resp, err := httpGET(ctx, listener.Addr(), "nexa.polito.it") resp, err := httpGET(ctx, listener.Addr(), "nexa.polito.it")
if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") { if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
if resp != nil { if resp != nil {
@ -159,8 +160,8 @@ func TestHTTPProxy(t *testing.T) {
t.Run("Start fails on an invalid address", func(t *testing.T) { t.Run("Start fails on an invalid address", func(t *testing.T) {
p := &HTTPProxy{} p := &HTTPProxy{}
listener, err := p.Start("127.0.0.1") listener, err := p.Start("127.0.0.1")
if err == nil { if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
t.Fatal("expected an error") t.Fatal("unexpected err", err)
} }
if listener != nil { if listener != nil {
t.Fatal("expected nil listener") t.Fatal("expected nil listener")

View File

@ -60,16 +60,21 @@ func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, erro
func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) { func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) {
defer close(done) defer close(done)
for { for p.oneloop(listener) {
// nothing
}
}
func (p *TLSProxy) oneloop(listener net.Listener) bool {
conn, err := listener.Accept() conn, err := listener.Accept()
if err == nil { if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") {
return false // we need to stop
}
if err != nil {
return true // we can continue running
}
go p.handle(conn) go p.handle(conn)
continue return true // we can continue running
}
if strings.HasSuffix(err.Error(), "use of closed network connection") {
break
}
}
} }
const ( const (
@ -143,10 +148,11 @@ type tlsClientHelloReader struct {
func (c *tlsClientHelloReader) Read(b []byte) (int, error) { func (c *tlsClientHelloReader) Read(b []byte) (int, error) {
count, err := c.Conn.Read(b) count, err := c.Conn.Read(b)
if err == nil { if err != nil {
c.clientHello = append(c.clientHello, b[:count]...) return 0, err
} }
return count, err c.clientHello = append(c.clientHello, b[:count]...)
return count, nil
} }
// Write prevents writing on the real connection // Write prevents writing on the real connection

View File

@ -134,25 +134,24 @@ func TestTLSProxy(t *testing.T) {
<-done // wait for background goroutine to exit <-done // wait for background goroutine to exit
}) })
dial := func(ctx context.Context, endpoint string) (net.Conn, error) {
d := netxlite.NewDialerWithoutResolver(log.Log)
return d.DialContext(ctx, "tcp", endpoint)
}
t.Run("handle cannot read ClientHello", func(t *testing.T) { t.Run("handle cannot read ClientHello", func(t *testing.T) {
listener, done, err := newproxy(TLSActionPass) listener, done, err := newproxy(TLSActionPass)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
conn, err := net.Dial("tcp", listener.Addr().String()) conn, err := dial(context.Background(), listener.Addr().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
conn.Write([]byte("GET / HTTP/1.0\r\n\r\n")) conn.Write([]byte("GET / HTTP/1.0\r\n\r\n"))
buff := make([]byte, 1<<17) buff := make([]byte, 1<<17)
_, err = conn.Read(buff) _, err = conn.Read(buff)
// Implementation note: we need to wrap the error because if err == nil || err.Error() != netxlite.FailureConnectionReset {
// otherwise the error string on Windows is different from Unix
if err == nil {
t.Fatal("expected non-nil error")
}
err = netxlite.NewTopLevelGenericErrWrapper(err)
if err.Error() != netxlite.FailureConnectionReset {
t.Fatal("unexpected err", err) t.Fatal("unexpected err", err)
} }
listener.Close() listener.Close()
@ -251,11 +250,45 @@ func TestTLSProxy(t *testing.T) {
t.Run("Start fails on an invalid address", func(t *testing.T) { t.Run("Start fails on an invalid address", func(t *testing.T) {
p := &TLSProxy{} p := &TLSProxy{}
listener, err := p.Start("127.0.0.1") listener, err := p.Start("127.0.0.1")
if err == nil { if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
t.Fatal("expected an error") t.Fatal("unexpected err", err)
} }
if listener != nil { if listener != nil {
t.Fatal("expected nil listener") t.Fatal("expected nil listener")
} }
}) })
t.Run("oneloop correctly handles a listener error", func(t *testing.T) {
listener := &mocks.Listener{
MockAccept: func() (net.Conn, error) {
return nil, errors.New("mocked error")
},
}
p := &TLSProxy{}
if !p.oneloop(listener) {
t.Fatal("should return true here")
}
})
}
func TestTLSClientHelloReader(t *testing.T) {
t.Run("on failure", func(t *testing.T) {
expected := errors.New("mocked error")
chr := &tlsClientHelloReader{
Conn: &mocks.Conn{
MockRead: func(b []byte) (int, error) {
return 0, expected
},
},
clientHello: []byte{},
}
buf := make([]byte, 128)
count, err := chr.Read(buf)
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if count != 0 {
t.Fatal("invalid count")
}
})
} }

View File

@ -0,0 +1,32 @@
package mocks
import "net"
// Listener allows mocking a net.Listener.
type Listener struct {
// Accept allows mocking Accept.
MockAccept func() (net.Conn, error)
// Close allows mocking Close.
MockClose func() error
// Addr allows mocking Addr.
MockAddr func() net.Addr
}
var _ net.Listener = &Listener{}
// Accept implements net.Listener.Accept
func (li *Listener) Accept() (net.Conn, error) {
return li.MockAccept()
}
// Close implements net.Listener.Closer.
func (li *Listener) Close() error {
return li.MockClose()
}
// Addr implements net.Listener.Addr
func (li *Listener) Addr() net.Addr {
return li.MockAddr()
}

View File

@ -0,0 +1,56 @@
package mocks
import (
"errors"
"net"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestListener(t *testing.T) {
t.Run("Accept", func(t *testing.T) {
expected := errors.New("mocked error")
li := &Listener{
MockAccept: func() (net.Conn, error) {
return nil, expected
},
}
conn, err := li.Accept()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
if conn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("Close", func(t *testing.T) {
expected := errors.New("mocked error")
li := &Listener{
MockClose: func() error {
return expected
},
}
err := li.Close()
if !errors.Is(err, expected) {
t.Fatal("unexpected err", err)
}
})
t.Run("Addr", func(t *testing.T) {
addr := &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
}
li := &Listener{
MockAddr: func() net.Addr {
return addr
},
}
outAddr := li.Addr()
if diff := cmp.Diff(addr, outAddr); diff != "" {
t.Fatal(diff)
}
})
}