fix(filtering): avoid the if err == nil pattern (#567)
1. in normal code is better to always do if err != nil so that the ifs only contain error code (this is ~coding policy) 2. in tests we want to ensure we narrow down the error to the real error that happened, to have greater confidence Written while working on https://github.com/ooni/probe/issues/1803#issuecomment-957323297
This commit is contained in:
parent
374577f5a8
commit
560b1a9a97
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -43,11 +44,11 @@ func TestDNSProxy(t *testing.T) {
|
||||||
if addrs == nil {
|
if addrs == nil {
|
||||||
t.Fatal("unexpected empty addrs")
|
t.Fatal("unexpected empty addrs")
|
||||||
}
|
}
|
||||||
var foundQuad8 bool
|
var found bool
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
foundQuad8 = foundQuad8 || addr == "8.8.8.8"
|
found = found || addr == "8.8.8.8"
|
||||||
}
|
}
|
||||||
if !foundQuad8 {
|
if !found {
|
||||||
t.Fatal("did not find 8.8.8.8")
|
t.Fatal("did not find 8.8.8.8")
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
|
@ -104,11 +105,11 @@ func TestDNSProxy(t *testing.T) {
|
||||||
if addrs == nil {
|
if addrs == nil {
|
||||||
t.Fatal("expected non-empty addrs")
|
t.Fatal("expected non-empty addrs")
|
||||||
}
|
}
|
||||||
var found127001 bool
|
var found bool
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
found127001 = found127001 || addr == "127.0.0.1"
|
found = found || addr == "127.0.0.1"
|
||||||
}
|
}
|
||||||
if !found127001 {
|
if !found {
|
||||||
t.Fatal("did not find 127.0.0.1")
|
t.Fatal("did not find 127.0.0.1")
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
|
@ -124,7 +125,7 @@ func TestDNSProxy(t *testing.T) {
|
||||||
r := newresolver(listener)
|
r := newresolver(listener)
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
addrs, err := r.LookupHost(ctx, "dns.google")
|
||||||
if err == nil || err.Error() != netxlite.FailureDNSNoAnswer {
|
if err == nil || err.Error() != netxlite.FailureDNSNoAnswer {
|
||||||
t.Fatal(err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if addrs != nil {
|
if addrs != nil {
|
||||||
t.Fatal("expected empty addrs")
|
t.Fatal("expected empty addrs")
|
||||||
|
@ -140,15 +141,15 @@ func TestDNSProxy(t *testing.T) {
|
||||||
// careful because lots of legacy code uses SerialResolver.
|
// careful because lots of legacy code uses SerialResolver.
|
||||||
const timeout = time.Second
|
const timeout = time.Second
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
listener, done, err := newproxy(DNSActionTimeout)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
listener, done, err := newproxy(DNSActionTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
r := newresolver(listener)
|
r := newresolver(listener)
|
||||||
addrs, err := r.LookupHost(ctx, "dns.google")
|
addrs, err := r.LookupHost(ctx, "dns.google")
|
||||||
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
|
if err == nil || err.Error() != netxlite.FailureGenericTimeoutError {
|
||||||
t.Fatal(err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if addrs != nil {
|
if addrs != nil {
|
||||||
t.Fatal("expected empty addrs")
|
t.Fatal("expected empty addrs")
|
||||||
|
@ -160,8 +161,8 @@ func TestDNSProxy(t *testing.T) {
|
||||||
t.Run("Start with invalid address", func(t *testing.T) {
|
t.Run("Start with invalid address", func(t *testing.T) {
|
||||||
p := &DNSProxy{}
|
p := &DNSProxy{}
|
||||||
listener, err := p.Start("127.0.0.1")
|
listener, err := p.Start("127.0.0.1")
|
||||||
if err == nil {
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
t.Fatal("expected an error")
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if listener != nil {
|
if listener != nil {
|
||||||
t.Fatal("expected nil listener")
|
t.Fatal("expected nil listener")
|
||||||
|
@ -273,13 +274,13 @@ func TestDNSProxy(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("proxy", func(t *testing.T) {
|
t.Run("proxy", func(t *testing.T) {
|
||||||
t.Run("pack fails", func(t *testing.T) {
|
t.Run("Pack fails", func(t *testing.T) {
|
||||||
p := &DNSProxy{}
|
p := &DNSProxy{}
|
||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.Rcode = -1 // causes Pack to fail
|
query.Rcode = -1 // causes Pack to fail
|
||||||
reply, err := p.proxy(query)
|
reply, err := p.proxy(query)
|
||||||
if err == nil {
|
if err == nil || !strings.HasSuffix(err.Error(), "bad rcode") {
|
||||||
t.Fatal("expected error here")
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if reply != nil {
|
if reply != nil {
|
||||||
t.Fatal("expected nil reply")
|
t.Fatal("expected nil reply")
|
||||||
|
@ -305,7 +306,7 @@ func TestDNSProxy(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("unpack fails", func(t *testing.T) {
|
t.Run("Unpack fails", func(t *testing.T) {
|
||||||
p := &DNSProxy{
|
p := &DNSProxy{
|
||||||
Upstream: &mocks.DNSTransport{
|
Upstream: &mocks.DNSTransport{
|
||||||
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) {
|
||||||
|
@ -315,8 +316,8 @@ func TestDNSProxy(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
reply, err := p.proxy(&dns.Msg{})
|
reply, err := p.proxy(&dns.Msg{})
|
||||||
if err == nil {
|
if err == nil || !strings.HasSuffix(err.Error(), "overflow unpacking uint16") {
|
||||||
t.Fatal("expected error")
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if reply != nil {
|
if reply != nil {
|
||||||
t.Fatal("expected nil reply here")
|
t.Fatal("expected nil reply here")
|
||||||
|
|
|
@ -2,6 +2,7 @@ package filtering
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -97,7 +98,7 @@ func TestHTTPProxy(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
resp, err := httpGET(ctx, listener.Addr(), "nexa.polito.it")
|
resp, err := httpGET(ctx, listener.Addr(), "nexa.polito.it")
|
||||||
if err == nil || !strings.HasSuffix(err.Error(), "context deadline exceeded") {
|
if !errors.Is(err, context.DeadlineExceeded) {
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if resp != nil {
|
if resp != nil {
|
||||||
|
@ -159,8 +160,8 @@ func TestHTTPProxy(t *testing.T) {
|
||||||
t.Run("Start fails on an invalid address", func(t *testing.T) {
|
t.Run("Start fails on an invalid address", func(t *testing.T) {
|
||||||
p := &HTTPProxy{}
|
p := &HTTPProxy{}
|
||||||
listener, err := p.Start("127.0.0.1")
|
listener, err := p.Start("127.0.0.1")
|
||||||
if err == nil {
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
t.Fatal("expected an error")
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if listener != nil {
|
if listener != nil {
|
||||||
t.Fatal("expected nil listener")
|
t.Fatal("expected nil listener")
|
||||||
|
|
|
@ -60,16 +60,21 @@ func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, erro
|
||||||
|
|
||||||
func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) {
|
func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
for {
|
for p.oneloop(listener) {
|
||||||
|
// nothing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *TLSProxy) oneloop(listener net.Listener) bool {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err == nil {
|
if err != nil && strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||||
|
return false // we need to stop
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return true // we can continue running
|
||||||
|
}
|
||||||
go p.handle(conn)
|
go p.handle(conn)
|
||||||
continue
|
return true // we can continue running
|
||||||
}
|
|
||||||
if strings.HasSuffix(err.Error(), "use of closed network connection") {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -143,10 +148,11 @@ type tlsClientHelloReader struct {
|
||||||
|
|
||||||
func (c *tlsClientHelloReader) Read(b []byte) (int, error) {
|
func (c *tlsClientHelloReader) Read(b []byte) (int, error) {
|
||||||
count, err := c.Conn.Read(b)
|
count, err := c.Conn.Read(b)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
c.clientHello = append(c.clientHello, b[:count]...)
|
return 0, err
|
||||||
}
|
}
|
||||||
return count, err
|
c.clientHello = append(c.clientHello, b[:count]...)
|
||||||
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write prevents writing on the real connection
|
// Write prevents writing on the real connection
|
||||||
|
|
|
@ -134,25 +134,24 @@ func TestTLSProxy(t *testing.T) {
|
||||||
<-done // wait for background goroutine to exit
|
<-done // wait for background goroutine to exit
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dial := func(ctx context.Context, endpoint string) (net.Conn, error) {
|
||||||
|
d := netxlite.NewDialerWithoutResolver(log.Log)
|
||||||
|
return d.DialContext(ctx, "tcp", endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("handle cannot read ClientHello", func(t *testing.T) {
|
t.Run("handle cannot read ClientHello", func(t *testing.T) {
|
||||||
listener, done, err := newproxy(TLSActionPass)
|
listener, done, err := newproxy(TLSActionPass)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
conn, err := net.Dial("tcp", listener.Addr().String())
|
conn, err := dial(context.Background(), listener.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
conn.Write([]byte("GET / HTTP/1.0\r\n\r\n"))
|
conn.Write([]byte("GET / HTTP/1.0\r\n\r\n"))
|
||||||
buff := make([]byte, 1<<17)
|
buff := make([]byte, 1<<17)
|
||||||
_, err = conn.Read(buff)
|
_, err = conn.Read(buff)
|
||||||
// Implementation note: we need to wrap the error because
|
if err == nil || err.Error() != netxlite.FailureConnectionReset {
|
||||||
// otherwise the error string on Windows is different from Unix
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected non-nil error")
|
|
||||||
}
|
|
||||||
err = netxlite.NewTopLevelGenericErrWrapper(err)
|
|
||||||
if err.Error() != netxlite.FailureConnectionReset {
|
|
||||||
t.Fatal("unexpected err", err)
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
listener.Close()
|
listener.Close()
|
||||||
|
@ -251,11 +250,45 @@ func TestTLSProxy(t *testing.T) {
|
||||||
t.Run("Start fails on an invalid address", func(t *testing.T) {
|
t.Run("Start fails on an invalid address", func(t *testing.T) {
|
||||||
p := &TLSProxy{}
|
p := &TLSProxy{}
|
||||||
listener, err := p.Start("127.0.0.1")
|
listener, err := p.Start("127.0.0.1")
|
||||||
if err == nil {
|
if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") {
|
||||||
t.Fatal("expected an error")
|
t.Fatal("unexpected err", err)
|
||||||
}
|
}
|
||||||
if listener != nil {
|
if listener != nil {
|
||||||
t.Fatal("expected nil listener")
|
t.Fatal("expected nil listener")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("oneloop correctly handles a listener error", func(t *testing.T) {
|
||||||
|
listener := &mocks.Listener{
|
||||||
|
MockAccept: func() (net.Conn, error) {
|
||||||
|
return nil, errors.New("mocked error")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
p := &TLSProxy{}
|
||||||
|
if !p.oneloop(listener) {
|
||||||
|
t.Fatal("should return true here")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSClientHelloReader(t *testing.T) {
|
||||||
|
t.Run("on failure", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
chr := &tlsClientHelloReader{
|
||||||
|
Conn: &mocks.Conn{
|
||||||
|
MockRead: func(b []byte) (int, error) {
|
||||||
|
return 0, expected
|
||||||
|
},
|
||||||
|
},
|
||||||
|
clientHello: []byte{},
|
||||||
|
}
|
||||||
|
buf := make([]byte, 128)
|
||||||
|
count, err := chr.Read(buf)
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if count != 0 {
|
||||||
|
t.Fatal("invalid count")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
32
internal/netxlite/mocks/listener.go
Normal file
32
internal/netxlite/mocks/listener.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
// Listener allows mocking a net.Listener.
|
||||||
|
type Listener struct {
|
||||||
|
// Accept allows mocking Accept.
|
||||||
|
MockAccept func() (net.Conn, error)
|
||||||
|
|
||||||
|
// Close allows mocking Close.
|
||||||
|
MockClose func() error
|
||||||
|
|
||||||
|
// Addr allows mocking Addr.
|
||||||
|
MockAddr func() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ net.Listener = &Listener{}
|
||||||
|
|
||||||
|
// Accept implements net.Listener.Accept
|
||||||
|
func (li *Listener) Accept() (net.Conn, error) {
|
||||||
|
return li.MockAccept()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements net.Listener.Closer.
|
||||||
|
func (li *Listener) Close() error {
|
||||||
|
return li.MockClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr implements net.Listener.Addr
|
||||||
|
func (li *Listener) Addr() net.Addr {
|
||||||
|
return li.MockAddr()
|
||||||
|
}
|
56
internal/netxlite/mocks/listener_test.go
Normal file
56
internal/netxlite/mocks/listener_test.go
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
package mocks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestListener(t *testing.T) {
|
||||||
|
t.Run("Accept", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
li := &Listener{
|
||||||
|
MockAccept: func() (net.Conn, error) {
|
||||||
|
return nil, expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
conn, err := li.Accept()
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
t.Fatal("expected nil conn")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Close", func(t *testing.T) {
|
||||||
|
expected := errors.New("mocked error")
|
||||||
|
li := &Listener{
|
||||||
|
MockClose: func() error {
|
||||||
|
return expected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := li.Close()
|
||||||
|
if !errors.Is(err, expected) {
|
||||||
|
t.Fatal("unexpected err", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Addr", func(t *testing.T) {
|
||||||
|
addr := &net.TCPAddr{
|
||||||
|
IP: net.IPv4(127, 0, 0, 1),
|
||||||
|
Port: 1234,
|
||||||
|
}
|
||||||
|
li := &Listener{
|
||||||
|
MockAddr: func() net.Addr {
|
||||||
|
return addr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
outAddr := li.Addr()
|
||||||
|
if diff := cmp.Diff(addr, outAddr); diff != "" {
|
||||||
|
t.Fatal(diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user