diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index f3ac09b..cccb360 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: [ "1.17" ] + go: [ "1.17.1" ] steps: - uses: actions/setup-go@v1 with: diff --git a/.github/workflows/netxlite.yml b/.github/workflows/netxlite.yml new file mode 100644 index 0000000..3b1d2e7 --- /dev/null +++ b/.github/workflows/netxlite.yml @@ -0,0 +1,20 @@ +# netxlite runs unit and integration tests on our fundamental net library +name: netxlite +on: + pull_request: + push: + branches: + - "master" +jobs: + test: + runs-on: "${{ matrix.os }}" + strategy: + matrix: + go: [ "1.17.1" ] + os: [ "ubuntu-20.04", "windows-2019", "macos-10.15" ] + steps: + - uses: actions/setup-go@v1 + with: + go-version: "${{ matrix.go }}" + - uses: actions/checkout@v2 + - run: go test -race ./internal/netxlite/... diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index 67cbafb..f95eb1f 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -72,7 +72,7 @@ func TestDialerSystem(t *testing.T) { }) t.Run("enforces the configured timeout", func(t *testing.T) { - const timeout = 1 * time.Millisecond + const timeout = 1 * time.Nanosecond d := &dialerSystem{timeout: timeout} ctx := context.Background() start := time.Now() diff --git a/internal/netxlite/filtering/dns.go b/internal/netxlite/filtering/dns.go new file mode 100644 index 0000000..8adec9d --- /dev/null +++ b/internal/netxlite/filtering/dns.go @@ -0,0 +1,218 @@ +package filtering + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "strings" + + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +// DNSAction is the action that this proxy should take. +type DNSAction int + +const ( + // DNSActionProxy proxies the traffic to the upstream server. + DNSActionProxy = DNSAction(iota) + + // DNSActionNXDOMAIN replies with NXDOMAIN. + DNSActionNXDOMAIN + + // DNSActionRefused replies with Refused. + DNSActionRefused + + // DNSActionLocalHost replies with `127.0.0.1` and `::1`. + DNSActionLocalHost + + // DNSActionEmpty returns an empty reply. + DNSActionEmpty + + // DNSActionTimeout never replies to the query. + DNSActionTimeout +) + +// DNSProxy is a DNS proxy that routes traffic to an upstream +// resolver and may implement filtering policies. +type DNSProxy struct { + // OnQuery is the MANDATORY hook called whenever we + // receive a query for the given domain. + OnQuery func(domain string) DNSAction + + // Upstream is the OPTIONAL upstream transport. + Upstream DNSTransport + + // mockableReply allows to mock DNSProxy.reply in tests. + mockableReply func(query *dns.Msg) (*dns.Msg, error) +} + +// DNSTransport is the type we expect from an upstream DNS transport. +type DNSTransport interface { + RoundTrip(ctx context.Context, query []byte) ([]byte, error) + CloseIdleConnections() +} + +// DNSListener is the interface returned by DNSProxy.Start +type DNSListener interface { + io.Closer + LocalAddr() net.Addr +} + +// Start starts the proxy. +func (p *DNSProxy) Start(address string) (DNSListener, error) { + pconn, _, err := p.start(address) + return pconn, err +} + +func (p *DNSProxy) start(address string) (DNSListener, <-chan interface{}, error) { + pconn, err := net.ListenPacket("udp", address) + if err != nil { + return nil, nil, err + } + done := make(chan interface{}) + go p.mainloop(pconn, done) + return pconn, done, nil +} + +func (p *DNSProxy) mainloop(pconn net.PacketConn, done chan<- interface{}) { + defer close(done) + for p.oneloop(pconn) { + // nothing + } +} + +func (p *DNSProxy) oneloop(pconn net.PacketConn) bool { + buffer := make([]byte, 1<<12) + count, addr, err := pconn.ReadFrom(buffer) + if err != nil { + return !strings.HasSuffix(err.Error(), "use of closed network connection") + } + buffer = buffer[:count] + query := &dns.Msg{} + if err := query.Unpack(buffer); err != nil { + return true // can continue + } + reply, err := p.reply(query) + if err != nil { + return true // can continue + } + replyBytes, err := reply.Pack() + if err != nil { + return true // can continue + } + pconn.WriteTo(replyBytes, addr) + return true // can continue +} + +func (p *DNSProxy) reply(query *dns.Msg) (*dns.Msg, error) { + if p.mockableReply != nil { + return p.mockableReply(query) + } + return p.replyDefault(query) +} + +func (p *DNSProxy) replyDefault(query *dns.Msg) (*dns.Msg, error) { + if len(query.Question) != 1 { + return nil, errors.New("unhandled message") + } + name := query.Question[0].Name + switch p.OnQuery(name) { + case DNSActionProxy: + return p.proxy(query) + case DNSActionNXDOMAIN: + return p.nxdomain(query), nil + case DNSActionLocalHost: + return p.localHost(query), nil + case DNSActionEmpty: + return p.empty(query), nil + case DNSActionTimeout: + return nil, errors.New("let's ignore this query") + default: + return p.refused(query), nil + } +} + +func (p *DNSProxy) refused(query *dns.Msg) *dns.Msg { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeRefused) + return m +} + +func (p *DNSProxy) nxdomain(query *dns.Msg) *dns.Msg { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeNameError) + return m +} + +func (p *DNSProxy) localHost(query *dns.Msg) *dns.Msg { + return p.compose(query, net.IPv6loopback, net.IPv4(127, 0, 0, 1)) +} + +func (p *DNSProxy) empty(query *dns.Msg) *dns.Msg { + return p.compose(query) +} + +func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg { + runtimex.PanicIfTrue(len(query.Question) != 1, "expecting a single question") + question := query.Question[0] + reply := new(dns.Msg) + reply.Compress = true + reply.MsgHdr.RecursionAvailable = true + reply.SetReply(query) + for _, ip := range ips { + isIPv6 := strings.Contains(ip.String(), ":") + if !isIPv6 && question.Qtype == dns.TypeA { + reply.Answer = append(reply.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: ip, + }) + } else if isIPv6 && question.Qtype == dns.TypeAAAA { + reply.Answer = append(reply.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 0, + }, + AAAA: ip, + }) + } + } + return reply +} + +func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) { + queryBytes, err := query.Pack() + if err != nil { + return nil, err + } + txp := p.dnstransport() + defer txp.CloseIdleConnections() + ctx := context.Background() + replyBytes, err := txp.RoundTrip(ctx, queryBytes) + if err != nil { + return nil, err + } + reply := &dns.Msg{} + if err := reply.Unpack(replyBytes); err != nil { + return nil, err + } + return reply, nil +} + +func (p *DNSProxy) dnstransport() DNSTransport { + if p.Upstream != nil { + return p.Upstream + } + const URL = "https://1.1.1.1/dns-query" + return netxlite.NewDNSOverHTTPS(http.DefaultClient, URL) +} diff --git a/internal/netxlite/filtering/dns_test.go b/internal/netxlite/filtering/dns_test.go new file mode 100644 index 0000000..f68514b --- /dev/null +++ b/internal/netxlite/filtering/dns_test.go @@ -0,0 +1,326 @@ +package filtering + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/apex/log" + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" +) + +func TestDNSProxy(t *testing.T) { + newproxy := func(action DNSAction) (DNSListener, <-chan interface{}, error) { + p := &DNSProxy{ + OnQuery: func(domain string) DNSAction { + return action + }, + } + return p.start("127.0.0.1:0") + } + + newresolver := func(listener DNSListener) netxlite.Resolver { + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + return r + } + + t.Run("DNSActionProxy with default proxy", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionProxy) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("unexpected empty addrs") + } + var foundQuad8 bool + for _, addr := range addrs { + foundQuad8 = foundQuad8 || addr == "8.8.8.8" + } + if !foundQuad8 { + t.Fatal("did not find 8.8.8.8") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionNXDOMAIN", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionNXDOMAIN) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { + t.Fatal("unexpected err", err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionRefused", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionRefused) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureDNSRefusedError { + t.Fatal("unexpected err", err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionLocalHost", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionLocalHost) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-empty addrs") + } + var found127001 bool + for _, addr := range addrs { + found127001 = found127001 || addr == "127.0.0.1" + } + if !found127001 { + t.Fatal("did not find 127.0.0.1") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionEmpty", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionEmpty) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureDNSNoAnswer { + t.Fatal(err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionTimeout", func(t *testing.T) { + // Implementation note: if you see this test running for more + // than one second, then it means we're not checking the context + // immediately. We should be improving there but we need to be + // careful because lots of legacy code uses SerialResolver. + const timeout = time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + listener, done, err := newproxy(DNSActionTimeout) + defer cancel() + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal(err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("Start with invalid address", func(t *testing.T) { + p := &DNSProxy{} + listener, err := p.Start("127.0.0.1") + if err == nil { + t.Fatal("expected an error") + } + if listener != nil { + t.Fatal("expected nil listener") + } + }) + + t.Run("oneloop", func(t *testing.T) { + t.Run("ReadFrom failure after which we should continue", func(t *testing.T) { + expected := errors.New("mocked error") + p := &DNSProxy{} + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + return 0, nil, expected + }, + } + okay := p.oneloop(conn) + if !okay { + t.Fatal("we should be okay after this error") + } + }) + + t.Run("ReadFrom the connection is closed", func(t *testing.T) { + expected := errors.New("use of closed network connection") + p := &DNSProxy{} + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + return 0, nil, expected + }, + } + okay := p.oneloop(conn) + if okay { + t.Fatal("we should not be okay after this error") + } + }) + + t.Run("Unpack fails", func(t *testing.T) { + p := &DNSProxy{} + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + if len(p) < 4 { + panic("buffer too small") + } + p[0] = 7 + return 1, &net.UDPAddr{}, nil + }, + } + okay := p.oneloop(conn) + if !okay { + t.Fatal("we should be okay after this error") + } + }) + + t.Run("reply fails", func(t *testing.T) { + p := &DNSProxy{} + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + query := &dns.Msg{} + query.Question = append(query.Question, dns.Question{}) + query.Question = append(query.Question, dns.Question{}) + data, err := query.Pack() + if err != nil { + panic(err) + } + if len(p) < len(data) { + panic("buffer too small") + } + for i := 0; i < len(data); i++ { + p[i] = data[i] + } + return len(data), &net.UDPAddr{}, nil + }, + } + okay := p.oneloop(conn) + if !okay { + t.Fatal("we should be okay after this error") + } + }) + + t.Run("pack fails", func(t *testing.T) { + p := &DNSProxy{ + mockableReply: func(query *dns.Msg) (*dns.Msg, error) { + reply := &dns.Msg{} + reply.MsgHdr.Rcode = -1 // causes pack to fail + return reply, nil + }, + } + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + query := &dns.Msg{} + query.Question = append(query.Question, dns.Question{}) + data, err := query.Pack() + if err != nil { + panic(err) + } + if len(p) < len(data) { + panic("buffer too small") + } + for i := 0; i < len(data); i++ { + p[i] = data[i] + } + return len(data), &net.UDPAddr{}, nil + }, + } + okay := p.oneloop(conn) + if !okay { + t.Fatal("we should be okay after this error") + } + }) + }) + + t.Run("proxy", func(t *testing.T) { + t.Run("pack fails", func(t *testing.T) { + p := &DNSProxy{} + query := &dns.Msg{} + query.Rcode = -1 // causes Pack to fail + reply, err := p.proxy(query) + if err == nil { + t.Fatal("expected error here") + } + if reply != nil { + t.Fatal("expected nil reply") + } + }) + + t.Run("round trip fails", func(t *testing.T) { + expected := errors.New("mocked error") + p := &DNSProxy{ + Upstream: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return nil, expected + }, + MockCloseIdleConnections: func() {}, + }, + } + reply, err := p.proxy(&dns.Msg{}) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) + + t.Run("unpack fails", func(t *testing.T) { + p := &DNSProxy{ + Upstream: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return make([]byte, 1), nil + }, + MockCloseIdleConnections: func() {}, + }, + } + reply, err := p.proxy(&dns.Msg{}) + if err == nil { + t.Fatal("expected error") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) + }) +} diff --git a/internal/netxlite/filtering/doc.go b/internal/netxlite/filtering/doc.go new file mode 100644 index 0000000..46bf2f6 --- /dev/null +++ b/internal/netxlite/filtering/doc.go @@ -0,0 +1,2 @@ +// Package filtering contains primitives for implementing filtering. +package filtering diff --git a/internal/netxlite/filtering/tls.go b/internal/netxlite/filtering/tls.go new file mode 100644 index 0000000..0c4f581 --- /dev/null +++ b/internal/netxlite/filtering/tls.go @@ -0,0 +1,234 @@ +package filtering + +import ( + "crypto/tls" + "errors" + "io" + "net" + "strings" + "sync" +) + +// TLSAction is the action that this proxy should take. +type TLSAction int + +const ( + // TLSActionProxy proxies the traffic to the destination. + TLSActionProxy = TLSAction(iota) + + // TLSActionReset resets the connection. + TLSActionReset + + // TLSActionTimeout causes the connection to timeout. + TLSActionTimeout + + // TLSActionEOF closes the connection. + TLSActionEOF + + // TLSActionAlertInternalError sends an internal error + // alert message to the TLS client. + TLSActionAlertInternalError + + // TLSActionAlertUnrecognizedName tells the client that + // it's handshaking with an unknown SNI. + TLSActionAlertUnrecognizedName +) + +// TLSProxy is a TLS proxy that routes the traffic depending +// on the SNI value and may implement filtering policies. +type TLSProxy struct { + // OnIncomingSNI is the MANDATORY hook called whenever we have + // successfully received a ClientHello message. + OnIncomingSNI func(sni string) TLSAction +} + +// Start starts the proxy. +func (p *TLSProxy) Start(address string) (net.Listener, error) { + listener, _, err := p.start(address) + return listener, err +} + +// Start starts the proxy. +func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, error) { + listener, err := net.Listen("tcp", address) + if err != nil { + return nil, nil, err + } + done := make(chan interface{}) + go p.mainloop(listener, done) + return listener, done, nil +} + +func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) { + defer close(done) + for { + conn, err := listener.Accept() + if err == nil { + go p.handle(conn) + continue + } + if strings.HasSuffix(err.Error(), "use of closed network connection") { + break + } + } +} + +const ( + tlsAlertInternalError = byte(80) + tlsAlertUnrecognizedName = byte(112) +) + +func (p *TLSProxy) handle(conn net.Conn) { + defer conn.Close() + sni, hello, err := p.readClientHello(conn) + if err != nil { + p.reset(conn) + return + } + switch p.OnIncomingSNI(sni) { + case TLSActionProxy: + p.proxy(conn, sni, hello) + case TLSActionTimeout: + p.timeout(conn) + case TLSActionAlertInternalError: + p.alert(conn, tlsAlertInternalError) + case TLSActionAlertUnrecognizedName: + p.alert(conn, tlsAlertUnrecognizedName) + case TLSActionEOF: + p.eof(conn) + default: + p.reset(conn) + } +} + +// readClientHello reads the incoming ClientHello message. +// +// Arguments: +// +// - conn is the connection from which to read the ClientHello. +// +// Returns: +// +// - a string containing the SNI (empty on error); +// +// - bytes from the original ClientHello (nil on error); +// +// - an error (nil on success). +func (p *TLSProxy) readClientHello(conn net.Conn) (string, []byte, error) { + connWrapper := &tlsClientHelloReader{Conn: conn} + var ( + expectedErr = errors.New("cannot continue handhake") + sni string + mutex sync.Mutex // just for safety + ) + err := tls.Server(connWrapper, &tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + mutex.Lock() + sni = info.ServerName + mutex.Unlock() + return nil, expectedErr + }, + }).Handshake() + if !errors.Is(err, expectedErr) { + return "", nil, err + } + return sni, connWrapper.clientHello, nil +} + +// tlsClientHelloReader wraps a net.Conn for the purpose of +// saving the bytes of the ClientHello message. +type tlsClientHelloReader struct { + net.Conn + clientHello []byte +} + +func (c *tlsClientHelloReader) Read(b []byte) (int, error) { + count, err := c.Conn.Read(b) + if err == nil { + c.clientHello = append(c.clientHello, b[:count]...) + } + return count, err +} + +// Write prevents writing on the real connection +func (c *tlsClientHelloReader) Write(b []byte) (int, error) { + return 0, errors.New("cannot write on this connection") +} + +func (p *TLSProxy) reset(conn net.Conn) { + if tc, ok := conn.(*net.TCPConn); ok { + tc.SetLinger(0) + } + conn.Close() +} + +func (p *TLSProxy) timeout(conn net.Conn) { + buffer := make([]byte, 1<<14) + conn.Read(buffer) + conn.Close() +} + +func (p *TLSProxy) eof(conn net.Conn) { + conn.Close() +} + +func (p *TLSProxy) alert(conn net.Conn, code byte) { + alertdata := []byte{ + 21, // alert + 3, // version[0] + 3, // version[1] + 0, // length[0] + 2, // length[1] + 2, // fatal + code, + } + conn.Write(alertdata) + conn.Close() +} + +func (p *TLSProxy) proxy(conn net.Conn, sni string, hello []byte) { + p.proxydial(conn, sni, hello, net.Dial) +} + +func (p *TLSProxy) proxydial(conn net.Conn, sni string, hello []byte, + dial func(network, address string) (net.Conn, error)) { + if sni == "" { // don't know the destination host + p.reset(conn) + return + } + serverconn, err := dial("tcp", net.JoinHostPort(sni, "443")) + if err != nil { + p.reset(conn) + return + } + if p.connectingToMyself(serverconn) { + p.reset(conn) + return + } + if _, err := serverconn.Write(hello); err != nil { + p.reset(conn) + return + } + defer serverconn.Close() // conn is owned by the caller + wg := &sync.WaitGroup{} + wg.Add(2) + go p.forward(wg, conn, serverconn) + go p.forward(wg, serverconn, conn) + wg.Wait() +} + +// connectingToMyself returns true when the proxy has been somehow +// forced to create a connection to itself. +func (p *TLSProxy) connectingToMyself(conn net.Conn) bool { + local := conn.LocalAddr().String() + localAddr, _, localErr := net.SplitHostPort(local) + remote := conn.RemoteAddr().String() + remoteAddr, _, remoteErr := net.SplitHostPort(remote) + return localErr != nil || remoteErr != nil || localAddr == remoteAddr +} + +// forward will forward the traffic. +func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) { + defer wg.Done() + io.Copy(left, right) +} diff --git a/internal/netxlite/filtering/tls_test.go b/internal/netxlite/filtering/tls_test.go new file mode 100644 index 0000000..6e441ee --- /dev/null +++ b/internal/netxlite/filtering/tls_test.go @@ -0,0 +1,261 @@ +package filtering + +import ( + "context" + "crypto/tls" + "errors" + "net" + "strings" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" +) + +func TestTLSProxy(t *testing.T) { + newproxy := func(action TLSAction) (net.Listener, <-chan interface{}, error) { + p := &TLSProxy{ + OnIncomingSNI: func(sni string) TLSAction { + return action + }, + } + return p.start("127.0.0.1:0") + } + + dialTLS := func(ctx context.Context, endpoint string, sni string) (net.Conn, error) { + d := netxlite.NewDialerWithoutResolver(log.Log) + th := netxlite.NewTLSHandshakerStdlib(log.Log) + tdx := netxlite.NewTLSDialerWithConfig(d, th, &tls.Config{ + ServerName: sni, + NextProtos: []string{"h2", "http/1.1"}, + RootCAs: netxlite.NewDefaultCertPool(), + }) + return tdx.DialTLSContext(ctx, "tcp", endpoint) + } + + t.Run("TLSActionProxy with default proxy", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionProxy) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err != nil { + t.Fatal(err) + } + conn.Close() + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionTimeout", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionTimeout) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionAlertInternalError", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionAlertInternalError) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionAlertUnrecognizedName) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionEOF", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionEOF) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || err.Error() != netxlite.FailureEOFError { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionReset", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionReset) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("handle cannot read ClientHello", func(t *testing.T) { + listener, done, err := newproxy(TLSActionProxy) + if err != nil { + t.Fatal(err) + } + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("GET / HTTP/1.0\r\n\r\n")) + buff := make([]byte, 1<<17) + _, err = conn.Read(buff) + // Implementation note: we need to wrap the error because + // otherwise the error string on Windows is different from Unix + if err == nil { + t.Fatal("expected non-nil error") + } + err = netxlite.NewTopLevelGenericErrWrapper(err) + if err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionProxy fails because we don't have SNI", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionProxy) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "127.0.0.1") + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionProxy fails because we can't dial", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionProxy) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "antani.ooni.org") + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("proxydial fails because it's connecting to itself", func(t *testing.T) { + p := &TLSProxy{} + conn := &mocks.Conn{ + MockClose: func() error { + return nil + }, + } + p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + return nil + }, + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.IPv6loopback, + } + }, + MockRemoteAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.IPv6loopback, + } + }, + }, nil + }) + }) + + t.Run("proxydial fails because it cannot write the hello", func(t *testing.T) { + p := &TLSProxy{} + conn := &mocks.Conn{ + MockClose: func() error { + return nil + }, + } + p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + return nil + }, + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.IPv6loopback, + } + }, + MockRemoteAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(10, 0, 0, 1), + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, errors.New("mocked error") + }, + }, nil + }) + }) + + t.Run("Start fails on an invalid address", func(t *testing.T) { + p := &TLSProxy{} + listener, err := p.Start("127.0.0.1") + if err == nil { + t.Fatal("expected an error") + } + if listener != nil { + t.Fatal("expected nil listener") + } + }) +} diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index 9045fc3..2844d8a 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -177,7 +177,7 @@ func TestNewHTTPTransport(t *testing.T) { td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) txp := NewHTTPTransport(log.Log, d, td) client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com/robots.txt") + resp, err := client.Get("https://8.8.4.4/robots.txt") if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) } diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index 7bc791e..d4c0234 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -3,27 +3,42 @@ package netxlite_test import ( "context" "crypto/tls" + "fmt" "net" "net/http" "testing" + "time" "github.com/apex/log" "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" utls "gitlab.com/yawning/utls.git" ) -func TestResolver(t *testing.T) { +// This set of integration tests ensures that we continue to +// be able to measure the conditions we care about + +func TestMeasureWithSystemResolver(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } - t.Run("works as intended", func(t *testing.T) { - // TODO(bassosimone): this is actually an integration - // test but how to test this case? + // + // Measurement conditions we care about: + // + // - success + // + // - nxdomain + // + // - timeout + // + + t.Run("on success", func(t *testing.T) { r := netxlite.NewResolverStdlib(log.Log) defer r.CloseIdleConnections() - addrs, err := r.LookupHost(context.Background(), "dns.google.com") + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "dns.google.com") if err != nil { t.Fatal(err) } @@ -31,6 +46,413 @@ func TestResolver(t *testing.T) { t.Fatal("expected non-nil result here") } }) + + t.Run("for nxdomain", func(t *testing.T) { + r := netxlite.NewResolverStdlib(log.Log) + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "antani.ooni.org") + if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) + + t.Run("for timeout", func(t *testing.T) { + r := netxlite.NewResolverStdlib(log.Log) + defer r.CloseIdleConnections() + const timeout = time.Nanosecond + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + addrs, err := r.LookupHost(ctx, "ooni.org") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) +} + +func TestMeasureWithUDPResolver(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // + // Measurement conditions we care about: + // + // - success + // + // - nxdomain + // + // - refused + // + // - timeout + // + + t.Run("on success", func(t *testing.T) { + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, "8.8.4.4:53") + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "dns.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil result here") + } + }) + + t.Run("for nxdomain", func(t *testing.T) { + proxy := &filtering.DNSProxy{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionNXDOMAIN + }, + } + listener, err := proxy.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "ooni.org") + if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) + + t.Run("for refused", func(t *testing.T) { + proxy := &filtering.DNSProxy{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionRefused + }, + } + listener, err := proxy.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "ooni.org") + if err == nil || err.Error() != netxlite.FailureDNSRefusedError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) + + t.Run("for timeout", func(t *testing.T) { + proxy := &filtering.DNSProxy{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionTimeout + }, + } + listener, err := proxy.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "ooni.org") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) +} + +func TestMeasureWithDialer(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // + // Measurement conditions we care about: + // + // - success + // + // - connection refused + // + // - timeout + // + + t.Run("on success", func(t *testing.T) { + d := netxlite.NewDialerWithoutResolver(log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + conn, err := d.DialContext(ctx, "tcp", "8.8.4.4:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn here") + } + conn.Close() + }) + + t.Run("on connection refused", func(t *testing.T) { + d := netxlite.NewDialerWithoutResolver(log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + // Here we assume that no-one is listening on 127.0.0.1:1 + conn, err := d.DialContext(ctx, "tcp", "127.0.0.1:1") + if err == nil || err.Error() != netxlite.FailureConnectionRefused { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) + + t.Run("on timeout", func(t *testing.T) { + d := netxlite.NewDialerWithoutResolver(log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + // Here we assume 8.8.4.4:1 is filtered + conn, err := d.DialContext(ctx, "tcp", "8.8.4.4:1") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) +} + +func TestMeasureWithTLSHandshaker(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // + // Measurement conditions we care about: + // + // - success + // + // - connection reset + // + // - timeout + // + + dial := func(ctx context.Context, address string) (net.Conn, error) { + d := netxlite.NewDialerWithoutResolver(log.Log) + return d.DialContext(ctx, "tcp", address) + } + + successFlow := func(th netxlite.TLSHandshaker) error { + ctx := context.Background() + conn, err := dial(ctx, "8.8.4.4:443") + if err != nil { + return fmt.Errorf("dial failed: %w", err) + } + defer conn.Close() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h2", "http/1.1"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + tconn, _, err := th.Handshake(ctx, conn, config) + if err != nil { + return fmt.Errorf("tls handshake failed: %w", err) + } + tconn.Close() + return nil + } + + connectionResetFlow := func(th netxlite.TLSHandshaker) error { + tlsProxy := &filtering.TLSProxy{ + OnIncomingSNI: func(sni string) filtering.TLSAction { + return filtering.TLSActionReset + }, + } + listener, err := tlsProxy.Start("127.0.0.1:0") + if err != nil { + return fmt.Errorf("cannot start proxy: %w", err) + } + defer listener.Close() + ctx := context.Background() + conn, err := dial(ctx, listener.Addr().String()) + if err != nil { + return fmt.Errorf("dial failed: %w", err) + } + defer conn.Close() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h2", "http/1.1"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + tconn, _, err := th.Handshake(ctx, conn, config) + if err == nil { + return fmt.Errorf("tls handshake succeded unexpectedly") + } + if err.Error() != netxlite.FailureConnectionReset { + return fmt.Errorf("not the error we expected: %w", err) + } + if tconn != nil { + return fmt.Errorf("expected nil tconn here") + } + return nil + } + + timeoutFlow := func(th netxlite.TLSHandshaker) error { + tlsProxy := &filtering.TLSProxy{ + OnIncomingSNI: func(sni string) filtering.TLSAction { + return filtering.TLSActionTimeout + }, + } + listener, err := tlsProxy.Start("127.0.0.1:0") + if err != nil { + return fmt.Errorf("cannot start proxy: %w", err) + } + defer listener.Close() + ctx := context.Background() + conn, err := dial(ctx, listener.Addr().String()) + if err != nil { + return fmt.Errorf("dial failed: %w", err) + } + defer conn.Close() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h2", "http/1.1"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + tconn, _, err := th.Handshake(ctx, conn, config) + if err == nil { + return fmt.Errorf("tls handshake succeded unexpectedly") + } + if err.Error() != netxlite.FailureGenericTimeoutError { + return fmt.Errorf("not the error we expected: %w", err) + } + if tconn != nil { + return fmt.Errorf("expected nil tconn here") + } + return nil + } + + t.Run("for stdlib handshaker", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + th := netxlite.NewTLSHandshakerStdlib(log.Log) + err := successFlow(th) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on connection reset", func(t *testing.T) { + th := netxlite.NewTLSHandshakerStdlib(log.Log) + err := connectionResetFlow(th) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on timeout", func(t *testing.T) { + th := netxlite.NewTLSHandshakerStdlib(log.Log) + err := timeoutFlow(th) + if err != nil { + t.Fatal(err) + } + }) + }) + + t.Run("for utls handshaker", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + th := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloFirefox_55) + err := successFlow(th) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on connection reset", func(t *testing.T) { + th := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloFirefox_55) + err := connectionResetFlow(th) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on timeout", func(t *testing.T) { + th := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloFirefox_55) + err := timeoutFlow(th) + if err != nil { + t.Fatal(err) + } + }) + }) +} + +func TestMeasureWithQUICDialer(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // + // Measurement conditions we care about: + // + // - success + // + // - timeout + // + + t.Run("on success", func(t *testing.T) { + ql := netxlite.NewQUICListener() + d := netxlite.NewQUICDialerWithoutResolver(ql, log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h3"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + sess, err := d.DialContext(ctx, "udp", "8.8.4.4:443", config, &quic.Config{}) + if err != nil { + t.Fatal(err) + } + if sess == nil { + t.Fatal("expected non-nil sess here") + } + sess.CloseWithError(0, "") + }) + + t.Run("on timeout", func(t *testing.T) { + ql := netxlite.NewQUICListener() + d := netxlite.NewQUICDialerWithoutResolver(ql, log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h3"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + // Here we assume 8.8.4.4:1 is filtered + sess, err := d.DialContext(ctx, "udp", "8.8.4.4:1", config, &quic.Config{}) + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } + }) } func TestHTTPTransport(t *testing.T) { @@ -73,63 +495,3 @@ func TestHTTP3Transport(t *testing.T) { txp.CloseIdleConnections() }) } - -func TestUTLSHandshaker(t *testing.T) { - t.Run("with chrome fingerprint", func(t *testing.T) { - h := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_Auto) - cfg := &tls.Config{ServerName: "google.com"} - conn, err := net.Dial("tcp", "google.com:443") - if err != nil { - t.Fatal("unexpected error", err) - } - conn, _, err = h.Handshake(context.Background(), conn, cfg) - if err != nil { - t.Fatal("unexpected error", err) - } - if conn == nil { - t.Fatal("nil connection") - } - }) -} - -func TestQUICDialer(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - - t.Run("works as intended", func(t *testing.T) { - tlsConfig := &tls.Config{ - ServerName: "dns.google", - } - d := netxlite.NewQUICDialerWithoutResolver( - netxlite.NewQUICListener(), log.Log, - ) - ctx := context.Background() - sess, err := d.DialContext( - ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) - if err != nil { - t.Fatal("not the error we expected", err) - } - <-sess.HandshakeComplete().Done() - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } - }) - - t.Run("can guess the SNI and ALPN when using a domain name for web", func(t *testing.T) { - d := netxlite.NewQUICDialerWithResolver( - netxlite.NewQUICListener(), log.Log, - netxlite.NewResolverStdlib(log.Log), - ) - ctx := context.Background() - sess, err := d.DialContext( - ctx, "udp", "dns.google:443", &tls.Config{}, &quic.Config{}) - if err != nil { - t.Fatal("not the error we expected", err) - } - <-sess.HandshakeComplete().Done() - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } - }) -} diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index 8ed70c0..ce7dd88 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -252,25 +252,6 @@ func TestQUICDialerResolver(t *testing.T) { }) t.Run("DialContext", func(t *testing.T) { - t.Run("on success", func(t *testing.T) { - tlsConfig := &tls.Config{} - dialer := &quicDialerResolver{ - Resolver: NewResolverStdlib(log.Log), - Dialer: &quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - }} - sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:443", - tlsConfig, &quic.Config{}) - if err != nil { - t.Fatal(err) - } - <-sess.HandshakeComplete().Done() - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } - }) - t.Run("with missing port", func(t *testing.T) { tlsConfig := &tls.Config{} dialer := &quicDialerResolver{ @@ -306,7 +287,7 @@ func TestQUICDialerResolver(t *testing.T) { } }) - t.Run("with invalid port (i.e., the zero port)", func(t *testing.T) { + t.Run("with invalid, non-numeric port)", func(t *testing.T) { // This test allows us to check for the case where every attempt // to establish a connection leads to a failure tlsConf := &tls.Config{} @@ -316,13 +297,12 @@ func TestQUICDialerResolver(t *testing.T) { QUICListener: &quicListenerStdlib{}, }} sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:0", + context.Background(), "udp", "8.8.4.4:x", tlsConf, &quic.Config{}) if err == nil { t.Fatal("expected an error here") } - if !strings.HasSuffix(err.Error(), "sendto: invalid argument") && - !strings.HasSuffix(err.Error(), "sendto: can't assign requested address") { + if !strings.HasSuffix(err.Error(), "invalid syntax") { t.Fatal("not the error we expected", err) } if sess != nil { @@ -344,7 +324,7 @@ func TestQUICDialerResolver(t *testing.T) { }, }} sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:443", + context.Background(), "udp", "8.8.4.4:443", tlsConfig, &quic.Config{}) if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) @@ -355,7 +335,7 @@ func TestQUICDialerResolver(t *testing.T) { if tlsConfig.ServerName != "" { t.Fatal("should not have changed tlsConfig.ServerName") } - if gotTLSConfig.ServerName != "www.google.com" { + if gotTLSConfig.ServerName != "8.8.4.4" { t.Fatal("gotTLSConfig.ServerName has not been set") } }) diff --git a/internal/netxlite/serialresolver_test.go b/internal/netxlite/serialresolver_test.go index f621487..4238176 100644 --- a/internal/netxlite/serialresolver_test.go +++ b/internal/netxlite/serialresolver_test.go @@ -12,6 +12,22 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) +// errorWithTimeout is an error that golang will always consider +// to be a timeout because it has a Timeout() bool method +type errorWithTimeout struct { + error +} + +// Timeout returns whether this error is a timeout. +func (err *errorWithTimeout) Timeout() bool { + return true +} + +// Unwrap allows to unwrap the error. +func (err *errorWithTimeout) Unwrap() error { + return err.error +} + func TestSerialResolver(t *testing.T) { t.Run("transport okay", func(t *testing.T) { txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") @@ -129,7 +145,10 @@ func TestSerialResolver(t *testing.T) { t.Run("with timeout", func(t *testing.T) { txp := &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return nil, &net.OpError{Err: ETIMEDOUT, Op: "dial"} + return nil, &net.OpError{ + Err: &errorWithTimeout{ETIMEDOUT}, + Op: "dial", + } }, MockRequiresPadding: func() bool { return true diff --git a/internal/runtimex/runtimex.go b/internal/runtimex/runtimex.go index 6d454e4..0be3a0c 100644 --- a/internal/runtimex/runtimex.go +++ b/internal/runtimex/runtimex.go @@ -10,3 +10,15 @@ func PanicOnError(err error, message string) { panic(fmt.Errorf("%s: %w", message, err)) } } + +// PanicIfFalse calls panic if assertion is false. +func PanicIfFalse(assertion bool, message string) { + if !assertion { + panic(message) + } +} + +// PanicIfTrue calls panic if assertion is true. +func PanicIfTrue(assertion bool, message string) { + PanicIfFalse(!assertion, message) +} diff --git a/internal/runtimex/runtimex_test.go b/internal/runtimex/runtimex_test.go index 9b129ec..4649540 100644 --- a/internal/runtimex/runtimex_test.go +++ b/internal/runtimex/runtimex_test.go @@ -7,21 +7,67 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) -func TestGood(t *testing.T) { - runtimex.PanicOnError(nil, "antani failed") -} - -func TestBad(t *testing.T) { - expected := errors.New("mocked error") - if !errors.Is(badfunc(expected), expected) { - t.Fatal("not the error we expected") +func TestPanicOnError(t *testing.T) { + badfunc := func(in error) (out error) { + defer func() { + out = recover().(error) + }() + runtimex.PanicOnError(in, "antani failed") + return } + + t.Run("error is nil", func(t *testing.T) { + runtimex.PanicOnError(nil, "antani failed") + }) + + t.Run("error is not nil", func(t *testing.T) { + expected := errors.New("mocked error") + if !errors.Is(badfunc(expected), expected) { + t.Fatal("not the error we expected") + } + }) } -func badfunc(in error) (out error) { - defer func() { - out = recover().(error) - }() - runtimex.PanicOnError(in, "antani failed") - return +func TestPanicIfFalse(t *testing.T) { + badfunc := func(in bool, message string) (out error) { + defer func() { + out = errors.New(recover().(string)) + }() + runtimex.PanicIfFalse(in, message) + return + } + + t.Run("assertion is true", func(t *testing.T) { + runtimex.PanicIfFalse(true, "antani failed") + }) + + t.Run("assertion is false", func(t *testing.T) { + message := "mocked error" + err := badfunc(false, message) + if err == nil || err.Error() != message { + t.Fatal("not the error we expected", err) + } + }) +} + +func TestPanicIfTrue(t *testing.T) { + badfunc := func(in bool, message string) (out error) { + defer func() { + out = errors.New(recover().(string)) + }() + runtimex.PanicIfTrue(in, message) + return + } + + t.Run("assertion is false", func(t *testing.T) { + runtimex.PanicIfTrue(false, "antani failed") + }) + + t.Run("assertion is true", func(t *testing.T) { + message := "mocked error" + err := badfunc(true, message) + if err == nil || err.Error() != message { + t.Fatal("not the error we expected", err) + } + }) }