From b9a844ecee04df1ebcf2600d5d66b4d5e1d7a9f9 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 29 Sep 2021 16:04:26 +0200 Subject: [PATCH] feat: run ~always netxlite integration tests (#522) * feat: run ~always netxlite integration tests This diff ensures that we check on windows, linux, macos that our fundamental networking library (netxlite) works. We combine unit and integration tests. This work is part of https://github.com/ooni/probe/issues/1733, where I want to have more strong guarantees about the foundations. * fix(filtering/tls_test.go): make portable on Windows The trick here is to use the wrapped error so to normalize the different errors messages we see on Windows. * fix(netxlite/quic_test.go): make portable on windows Rather than using the zero port, use the `x` port which fails when the stdlib is parsing the address. The zero port seems to work on Windows while it does not on Unix. * fix(serialresolver_test.go): make error more timeout than before This seems enough to convince Go on Windows about this error being really a timeout timeouty timeouted thingie. --- .github/workflows/coverage.yml | 2 +- .github/workflows/netxlite.yml | 20 + internal/netxlite/dialer_test.go | 2 +- internal/netxlite/filtering/dns.go | 218 ++++++++++ internal/netxlite/filtering/dns_test.go | 326 +++++++++++++++ internal/netxlite/filtering/doc.go | 2 + internal/netxlite/filtering/tls.go | 234 +++++++++++ internal/netxlite/filtering/tls_test.go | 261 ++++++++++++ internal/netxlite/http_test.go | 2 +- internal/netxlite/integration_test.go | 492 ++++++++++++++++++++--- internal/netxlite/quic_test.go | 30 +- internal/netxlite/serialresolver_test.go | 21 +- internal/runtimex/runtimex.go | 12 + internal/runtimex/runtimex_test.go | 74 +++- 14 files changed, 1588 insertions(+), 108 deletions(-) create mode 100644 .github/workflows/netxlite.yml create mode 100644 internal/netxlite/filtering/dns.go create mode 100644 internal/netxlite/filtering/dns_test.go create mode 100644 internal/netxlite/filtering/doc.go create mode 100644 internal/netxlite/filtering/tls.go create mode 100644 internal/netxlite/filtering/tls_test.go diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index f3ac09b..cccb360 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - go: [ "1.17" ] + go: [ "1.17.1" ] steps: - uses: actions/setup-go@v1 with: diff --git a/.github/workflows/netxlite.yml b/.github/workflows/netxlite.yml new file mode 100644 index 0000000..3b1d2e7 --- /dev/null +++ b/.github/workflows/netxlite.yml @@ -0,0 +1,20 @@ +# netxlite runs unit and integration tests on our fundamental net library +name: netxlite +on: + pull_request: + push: + branches: + - "master" +jobs: + test: + runs-on: "${{ matrix.os }}" + strategy: + matrix: + go: [ "1.17.1" ] + os: [ "ubuntu-20.04", "windows-2019", "macos-10.15" ] + steps: + - uses: actions/setup-go@v1 + with: + go-version: "${{ matrix.go }}" + - uses: actions/checkout@v2 + - run: go test -race ./internal/netxlite/... diff --git a/internal/netxlite/dialer_test.go b/internal/netxlite/dialer_test.go index 67cbafb..f95eb1f 100644 --- a/internal/netxlite/dialer_test.go +++ b/internal/netxlite/dialer_test.go @@ -72,7 +72,7 @@ func TestDialerSystem(t *testing.T) { }) t.Run("enforces the configured timeout", func(t *testing.T) { - const timeout = 1 * time.Millisecond + const timeout = 1 * time.Nanosecond d := &dialerSystem{timeout: timeout} ctx := context.Background() start := time.Now() diff --git a/internal/netxlite/filtering/dns.go b/internal/netxlite/filtering/dns.go new file mode 100644 index 0000000..8adec9d --- /dev/null +++ b/internal/netxlite/filtering/dns.go @@ -0,0 +1,218 @@ +package filtering + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "strings" + + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +// DNSAction is the action that this proxy should take. +type DNSAction int + +const ( + // DNSActionProxy proxies the traffic to the upstream server. + DNSActionProxy = DNSAction(iota) + + // DNSActionNXDOMAIN replies with NXDOMAIN. + DNSActionNXDOMAIN + + // DNSActionRefused replies with Refused. + DNSActionRefused + + // DNSActionLocalHost replies with `127.0.0.1` and `::1`. + DNSActionLocalHost + + // DNSActionEmpty returns an empty reply. + DNSActionEmpty + + // DNSActionTimeout never replies to the query. + DNSActionTimeout +) + +// DNSProxy is a DNS proxy that routes traffic to an upstream +// resolver and may implement filtering policies. +type DNSProxy struct { + // OnQuery is the MANDATORY hook called whenever we + // receive a query for the given domain. + OnQuery func(domain string) DNSAction + + // Upstream is the OPTIONAL upstream transport. + Upstream DNSTransport + + // mockableReply allows to mock DNSProxy.reply in tests. + mockableReply func(query *dns.Msg) (*dns.Msg, error) +} + +// DNSTransport is the type we expect from an upstream DNS transport. +type DNSTransport interface { + RoundTrip(ctx context.Context, query []byte) ([]byte, error) + CloseIdleConnections() +} + +// DNSListener is the interface returned by DNSProxy.Start +type DNSListener interface { + io.Closer + LocalAddr() net.Addr +} + +// Start starts the proxy. +func (p *DNSProxy) Start(address string) (DNSListener, error) { + pconn, _, err := p.start(address) + return pconn, err +} + +func (p *DNSProxy) start(address string) (DNSListener, <-chan interface{}, error) { + pconn, err := net.ListenPacket("udp", address) + if err != nil { + return nil, nil, err + } + done := make(chan interface{}) + go p.mainloop(pconn, done) + return pconn, done, nil +} + +func (p *DNSProxy) mainloop(pconn net.PacketConn, done chan<- interface{}) { + defer close(done) + for p.oneloop(pconn) { + // nothing + } +} + +func (p *DNSProxy) oneloop(pconn net.PacketConn) bool { + buffer := make([]byte, 1<<12) + count, addr, err := pconn.ReadFrom(buffer) + if err != nil { + return !strings.HasSuffix(err.Error(), "use of closed network connection") + } + buffer = buffer[:count] + query := &dns.Msg{} + if err := query.Unpack(buffer); err != nil { + return true // can continue + } + reply, err := p.reply(query) + if err != nil { + return true // can continue + } + replyBytes, err := reply.Pack() + if err != nil { + return true // can continue + } + pconn.WriteTo(replyBytes, addr) + return true // can continue +} + +func (p *DNSProxy) reply(query *dns.Msg) (*dns.Msg, error) { + if p.mockableReply != nil { + return p.mockableReply(query) + } + return p.replyDefault(query) +} + +func (p *DNSProxy) replyDefault(query *dns.Msg) (*dns.Msg, error) { + if len(query.Question) != 1 { + return nil, errors.New("unhandled message") + } + name := query.Question[0].Name + switch p.OnQuery(name) { + case DNSActionProxy: + return p.proxy(query) + case DNSActionNXDOMAIN: + return p.nxdomain(query), nil + case DNSActionLocalHost: + return p.localHost(query), nil + case DNSActionEmpty: + return p.empty(query), nil + case DNSActionTimeout: + return nil, errors.New("let's ignore this query") + default: + return p.refused(query), nil + } +} + +func (p *DNSProxy) refused(query *dns.Msg) *dns.Msg { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeRefused) + return m +} + +func (p *DNSProxy) nxdomain(query *dns.Msg) *dns.Msg { + m := new(dns.Msg) + m.SetRcode(query, dns.RcodeNameError) + return m +} + +func (p *DNSProxy) localHost(query *dns.Msg) *dns.Msg { + return p.compose(query, net.IPv6loopback, net.IPv4(127, 0, 0, 1)) +} + +func (p *DNSProxy) empty(query *dns.Msg) *dns.Msg { + return p.compose(query) +} + +func (p *DNSProxy) compose(query *dns.Msg, ips ...net.IP) *dns.Msg { + runtimex.PanicIfTrue(len(query.Question) != 1, "expecting a single question") + question := query.Question[0] + reply := new(dns.Msg) + reply.Compress = true + reply.MsgHdr.RecursionAvailable = true + reply.SetReply(query) + for _, ip := range ips { + isIPv6 := strings.Contains(ip.String(), ":") + if !isIPv6 && question.Qtype == dns.TypeA { + reply.Answer = append(reply.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 0, + }, + A: ip, + }) + } else if isIPv6 && question.Qtype == dns.TypeAAAA { + reply.Answer = append(reply.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 0, + }, + AAAA: ip, + }) + } + } + return reply +} + +func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) { + queryBytes, err := query.Pack() + if err != nil { + return nil, err + } + txp := p.dnstransport() + defer txp.CloseIdleConnections() + ctx := context.Background() + replyBytes, err := txp.RoundTrip(ctx, queryBytes) + if err != nil { + return nil, err + } + reply := &dns.Msg{} + if err := reply.Unpack(replyBytes); err != nil { + return nil, err + } + return reply, nil +} + +func (p *DNSProxy) dnstransport() DNSTransport { + if p.Upstream != nil { + return p.Upstream + } + const URL = "https://1.1.1.1/dns-query" + return netxlite.NewDNSOverHTTPS(http.DefaultClient, URL) +} diff --git a/internal/netxlite/filtering/dns_test.go b/internal/netxlite/filtering/dns_test.go new file mode 100644 index 0000000..f68514b --- /dev/null +++ b/internal/netxlite/filtering/dns_test.go @@ -0,0 +1,326 @@ +package filtering + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/apex/log" + "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" +) + +func TestDNSProxy(t *testing.T) { + newproxy := func(action DNSAction) (DNSListener, <-chan interface{}, error) { + p := &DNSProxy{ + OnQuery: func(domain string) DNSAction { + return action + }, + } + return p.start("127.0.0.1:0") + } + + newresolver := func(listener DNSListener) netxlite.Resolver { + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + return r + } + + t.Run("DNSActionProxy with default proxy", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionProxy) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("unexpected empty addrs") + } + var foundQuad8 bool + for _, addr := range addrs { + foundQuad8 = foundQuad8 || addr == "8.8.8.8" + } + if !foundQuad8 { + t.Fatal("did not find 8.8.8.8") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionNXDOMAIN", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionNXDOMAIN) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { + t.Fatal("unexpected err", err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionRefused", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionRefused) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureDNSRefusedError { + t.Fatal("unexpected err", err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionLocalHost", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionLocalHost) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-empty addrs") + } + var found127001 bool + for _, addr := range addrs { + found127001 = found127001 || addr == "127.0.0.1" + } + if !found127001 { + t.Fatal("did not find 127.0.0.1") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionEmpty", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(DNSActionEmpty) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureDNSNoAnswer { + t.Fatal(err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionTimeout", func(t *testing.T) { + // Implementation note: if you see this test running for more + // than one second, then it means we're not checking the context + // immediately. We should be improving there but we need to be + // careful because lots of legacy code uses SerialResolver. + const timeout = time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + listener, done, err := newproxy(DNSActionTimeout) + defer cancel() + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal(err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("Start with invalid address", func(t *testing.T) { + p := &DNSProxy{} + listener, err := p.Start("127.0.0.1") + if err == nil { + t.Fatal("expected an error") + } + if listener != nil { + t.Fatal("expected nil listener") + } + }) + + t.Run("oneloop", func(t *testing.T) { + t.Run("ReadFrom failure after which we should continue", func(t *testing.T) { + expected := errors.New("mocked error") + p := &DNSProxy{} + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + return 0, nil, expected + }, + } + okay := p.oneloop(conn) + if !okay { + t.Fatal("we should be okay after this error") + } + }) + + t.Run("ReadFrom the connection is closed", func(t *testing.T) { + expected := errors.New("use of closed network connection") + p := &DNSProxy{} + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + return 0, nil, expected + }, + } + okay := p.oneloop(conn) + if okay { + t.Fatal("we should not be okay after this error") + } + }) + + t.Run("Unpack fails", func(t *testing.T) { + p := &DNSProxy{} + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + if len(p) < 4 { + panic("buffer too small") + } + p[0] = 7 + return 1, &net.UDPAddr{}, nil + }, + } + okay := p.oneloop(conn) + if !okay { + t.Fatal("we should be okay after this error") + } + }) + + t.Run("reply fails", func(t *testing.T) { + p := &DNSProxy{} + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + query := &dns.Msg{} + query.Question = append(query.Question, dns.Question{}) + query.Question = append(query.Question, dns.Question{}) + data, err := query.Pack() + if err != nil { + panic(err) + } + if len(p) < len(data) { + panic("buffer too small") + } + for i := 0; i < len(data); i++ { + p[i] = data[i] + } + return len(data), &net.UDPAddr{}, nil + }, + } + okay := p.oneloop(conn) + if !okay { + t.Fatal("we should be okay after this error") + } + }) + + t.Run("pack fails", func(t *testing.T) { + p := &DNSProxy{ + mockableReply: func(query *dns.Msg) (*dns.Msg, error) { + reply := &dns.Msg{} + reply.MsgHdr.Rcode = -1 // causes pack to fail + return reply, nil + }, + } + conn := &mocks.QUICUDPLikeConn{ + MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { + query := &dns.Msg{} + query.Question = append(query.Question, dns.Question{}) + data, err := query.Pack() + if err != nil { + panic(err) + } + if len(p) < len(data) { + panic("buffer too small") + } + for i := 0; i < len(data); i++ { + p[i] = data[i] + } + return len(data), &net.UDPAddr{}, nil + }, + } + okay := p.oneloop(conn) + if !okay { + t.Fatal("we should be okay after this error") + } + }) + }) + + t.Run("proxy", func(t *testing.T) { + t.Run("pack fails", func(t *testing.T) { + p := &DNSProxy{} + query := &dns.Msg{} + query.Rcode = -1 // causes Pack to fail + reply, err := p.proxy(query) + if err == nil { + t.Fatal("expected error here") + } + if reply != nil { + t.Fatal("expected nil reply") + } + }) + + t.Run("round trip fails", func(t *testing.T) { + expected := errors.New("mocked error") + p := &DNSProxy{ + Upstream: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return nil, expected + }, + MockCloseIdleConnections: func() {}, + }, + } + reply, err := p.proxy(&dns.Msg{}) + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) + + t.Run("unpack fails", func(t *testing.T) { + p := &DNSProxy{ + Upstream: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return make([]byte, 1), nil + }, + MockCloseIdleConnections: func() {}, + }, + } + reply, err := p.proxy(&dns.Msg{}) + if err == nil { + t.Fatal("expected error") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) + }) +} diff --git a/internal/netxlite/filtering/doc.go b/internal/netxlite/filtering/doc.go new file mode 100644 index 0000000..46bf2f6 --- /dev/null +++ b/internal/netxlite/filtering/doc.go @@ -0,0 +1,2 @@ +// Package filtering contains primitives for implementing filtering. +package filtering diff --git a/internal/netxlite/filtering/tls.go b/internal/netxlite/filtering/tls.go new file mode 100644 index 0000000..0c4f581 --- /dev/null +++ b/internal/netxlite/filtering/tls.go @@ -0,0 +1,234 @@ +package filtering + +import ( + "crypto/tls" + "errors" + "io" + "net" + "strings" + "sync" +) + +// TLSAction is the action that this proxy should take. +type TLSAction int + +const ( + // TLSActionProxy proxies the traffic to the destination. + TLSActionProxy = TLSAction(iota) + + // TLSActionReset resets the connection. + TLSActionReset + + // TLSActionTimeout causes the connection to timeout. + TLSActionTimeout + + // TLSActionEOF closes the connection. + TLSActionEOF + + // TLSActionAlertInternalError sends an internal error + // alert message to the TLS client. + TLSActionAlertInternalError + + // TLSActionAlertUnrecognizedName tells the client that + // it's handshaking with an unknown SNI. + TLSActionAlertUnrecognizedName +) + +// TLSProxy is a TLS proxy that routes the traffic depending +// on the SNI value and may implement filtering policies. +type TLSProxy struct { + // OnIncomingSNI is the MANDATORY hook called whenever we have + // successfully received a ClientHello message. + OnIncomingSNI func(sni string) TLSAction +} + +// Start starts the proxy. +func (p *TLSProxy) Start(address string) (net.Listener, error) { + listener, _, err := p.start(address) + return listener, err +} + +// Start starts the proxy. +func (p *TLSProxy) start(address string) (net.Listener, <-chan interface{}, error) { + listener, err := net.Listen("tcp", address) + if err != nil { + return nil, nil, err + } + done := make(chan interface{}) + go p.mainloop(listener, done) + return listener, done, nil +} + +func (p *TLSProxy) mainloop(listener net.Listener, done chan<- interface{}) { + defer close(done) + for { + conn, err := listener.Accept() + if err == nil { + go p.handle(conn) + continue + } + if strings.HasSuffix(err.Error(), "use of closed network connection") { + break + } + } +} + +const ( + tlsAlertInternalError = byte(80) + tlsAlertUnrecognizedName = byte(112) +) + +func (p *TLSProxy) handle(conn net.Conn) { + defer conn.Close() + sni, hello, err := p.readClientHello(conn) + if err != nil { + p.reset(conn) + return + } + switch p.OnIncomingSNI(sni) { + case TLSActionProxy: + p.proxy(conn, sni, hello) + case TLSActionTimeout: + p.timeout(conn) + case TLSActionAlertInternalError: + p.alert(conn, tlsAlertInternalError) + case TLSActionAlertUnrecognizedName: + p.alert(conn, tlsAlertUnrecognizedName) + case TLSActionEOF: + p.eof(conn) + default: + p.reset(conn) + } +} + +// readClientHello reads the incoming ClientHello message. +// +// Arguments: +// +// - conn is the connection from which to read the ClientHello. +// +// Returns: +// +// - a string containing the SNI (empty on error); +// +// - bytes from the original ClientHello (nil on error); +// +// - an error (nil on success). +func (p *TLSProxy) readClientHello(conn net.Conn) (string, []byte, error) { + connWrapper := &tlsClientHelloReader{Conn: conn} + var ( + expectedErr = errors.New("cannot continue handhake") + sni string + mutex sync.Mutex // just for safety + ) + err := tls.Server(connWrapper, &tls.Config{ + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + mutex.Lock() + sni = info.ServerName + mutex.Unlock() + return nil, expectedErr + }, + }).Handshake() + if !errors.Is(err, expectedErr) { + return "", nil, err + } + return sni, connWrapper.clientHello, nil +} + +// tlsClientHelloReader wraps a net.Conn for the purpose of +// saving the bytes of the ClientHello message. +type tlsClientHelloReader struct { + net.Conn + clientHello []byte +} + +func (c *tlsClientHelloReader) Read(b []byte) (int, error) { + count, err := c.Conn.Read(b) + if err == nil { + c.clientHello = append(c.clientHello, b[:count]...) + } + return count, err +} + +// Write prevents writing on the real connection +func (c *tlsClientHelloReader) Write(b []byte) (int, error) { + return 0, errors.New("cannot write on this connection") +} + +func (p *TLSProxy) reset(conn net.Conn) { + if tc, ok := conn.(*net.TCPConn); ok { + tc.SetLinger(0) + } + conn.Close() +} + +func (p *TLSProxy) timeout(conn net.Conn) { + buffer := make([]byte, 1<<14) + conn.Read(buffer) + conn.Close() +} + +func (p *TLSProxy) eof(conn net.Conn) { + conn.Close() +} + +func (p *TLSProxy) alert(conn net.Conn, code byte) { + alertdata := []byte{ + 21, // alert + 3, // version[0] + 3, // version[1] + 0, // length[0] + 2, // length[1] + 2, // fatal + code, + } + conn.Write(alertdata) + conn.Close() +} + +func (p *TLSProxy) proxy(conn net.Conn, sni string, hello []byte) { + p.proxydial(conn, sni, hello, net.Dial) +} + +func (p *TLSProxy) proxydial(conn net.Conn, sni string, hello []byte, + dial func(network, address string) (net.Conn, error)) { + if sni == "" { // don't know the destination host + p.reset(conn) + return + } + serverconn, err := dial("tcp", net.JoinHostPort(sni, "443")) + if err != nil { + p.reset(conn) + return + } + if p.connectingToMyself(serverconn) { + p.reset(conn) + return + } + if _, err := serverconn.Write(hello); err != nil { + p.reset(conn) + return + } + defer serverconn.Close() // conn is owned by the caller + wg := &sync.WaitGroup{} + wg.Add(2) + go p.forward(wg, conn, serverconn) + go p.forward(wg, serverconn, conn) + wg.Wait() +} + +// connectingToMyself returns true when the proxy has been somehow +// forced to create a connection to itself. +func (p *TLSProxy) connectingToMyself(conn net.Conn) bool { + local := conn.LocalAddr().String() + localAddr, _, localErr := net.SplitHostPort(local) + remote := conn.RemoteAddr().String() + remoteAddr, _, remoteErr := net.SplitHostPort(remote) + return localErr != nil || remoteErr != nil || localAddr == remoteAddr +} + +// forward will forward the traffic. +func (p *TLSProxy) forward(wg *sync.WaitGroup, left net.Conn, right net.Conn) { + defer wg.Done() + io.Copy(left, right) +} diff --git a/internal/netxlite/filtering/tls_test.go b/internal/netxlite/filtering/tls_test.go new file mode 100644 index 0000000..6e441ee --- /dev/null +++ b/internal/netxlite/filtering/tls_test.go @@ -0,0 +1,261 @@ +package filtering + +import ( + "context" + "crypto/tls" + "errors" + "net" + "strings" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" +) + +func TestTLSProxy(t *testing.T) { + newproxy := func(action TLSAction) (net.Listener, <-chan interface{}, error) { + p := &TLSProxy{ + OnIncomingSNI: func(sni string) TLSAction { + return action + }, + } + return p.start("127.0.0.1:0") + } + + dialTLS := func(ctx context.Context, endpoint string, sni string) (net.Conn, error) { + d := netxlite.NewDialerWithoutResolver(log.Log) + th := netxlite.NewTLSHandshakerStdlib(log.Log) + tdx := netxlite.NewTLSDialerWithConfig(d, th, &tls.Config{ + ServerName: sni, + NextProtos: []string{"h2", "http/1.1"}, + RootCAs: netxlite.NewDefaultCertPool(), + }) + return tdx.DialTLSContext(ctx, "tcp", endpoint) + } + + t.Run("TLSActionProxy with default proxy", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionProxy) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err != nil { + t.Fatal(err) + } + conn.Close() + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionTimeout", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionTimeout) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionAlertInternalError", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionAlertInternalError) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || !strings.HasSuffix(err.Error(), "tls: internal error") { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionAlertUnrecognizedName", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionAlertUnrecognizedName) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || !strings.HasSuffix(err.Error(), "tls: unrecognized name") { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionEOF", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionEOF) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || err.Error() != netxlite.FailureEOFError { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionReset", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionReset) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "dns.google") + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("handle cannot read ClientHello", func(t *testing.T) { + listener, done, err := newproxy(TLSActionProxy) + if err != nil { + t.Fatal(err) + } + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("GET / HTTP/1.0\r\n\r\n")) + buff := make([]byte, 1<<17) + _, err = conn.Read(buff) + // Implementation note: we need to wrap the error because + // otherwise the error string on Windows is different from Unix + if err == nil { + t.Fatal("expected non-nil error") + } + err = netxlite.NewTopLevelGenericErrWrapper(err) + if err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionProxy fails because we don't have SNI", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionProxy) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "127.0.0.1") + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("TLSActionProxy fails because we can't dial", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newproxy(TLSActionProxy) + if err != nil { + t.Fatal(err) + } + conn, err := dialTLS(ctx, listener.Addr().String(), "antani.ooni.org") + if err == nil || err.Error() != netxlite.FailureConnectionReset { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("proxydial fails because it's connecting to itself", func(t *testing.T) { + p := &TLSProxy{} + conn := &mocks.Conn{ + MockClose: func() error { + return nil + }, + } + p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + return nil + }, + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.IPv6loopback, + } + }, + MockRemoteAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.IPv6loopback, + } + }, + }, nil + }) + }) + + t.Run("proxydial fails because it cannot write the hello", func(t *testing.T) { + p := &TLSProxy{} + conn := &mocks.Conn{ + MockClose: func() error { + return nil + }, + } + p.proxydial(conn, "ooni.org", nil, func(network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + return nil + }, + MockLocalAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.IPv6loopback, + } + }, + MockRemoteAddr: func() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4(10, 0, 0, 1), + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, errors.New("mocked error") + }, + }, nil + }) + }) + + t.Run("Start fails on an invalid address", func(t *testing.T) { + p := &TLSProxy{} + listener, err := p.Start("127.0.0.1") + if err == nil { + t.Fatal("expected an error") + } + if listener != nil { + t.Fatal("expected nil listener") + } + }) +} diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index 9045fc3..2844d8a 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -177,7 +177,7 @@ func TestNewHTTPTransport(t *testing.T) { td := NewTLSDialer(d, NewTLSHandshakerStdlib(log.Log)) txp := NewHTTPTransport(log.Log, d, td) client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com/robots.txt") + resp, err := client.Get("https://8.8.4.4/robots.txt") if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) } diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index 7bc791e..d4c0234 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -3,27 +3,42 @@ package netxlite_test import ( "context" "crypto/tls" + "fmt" "net" "net/http" "testing" + "time" "github.com/apex/log" "github.com/lucas-clemente/quic-go" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" utls "gitlab.com/yawning/utls.git" ) -func TestResolver(t *testing.T) { +// This set of integration tests ensures that we continue to +// be able to measure the conditions we care about + +func TestMeasureWithSystemResolver(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } - t.Run("works as intended", func(t *testing.T) { - // TODO(bassosimone): this is actually an integration - // test but how to test this case? + // + // Measurement conditions we care about: + // + // - success + // + // - nxdomain + // + // - timeout + // + + t.Run("on success", func(t *testing.T) { r := netxlite.NewResolverStdlib(log.Log) defer r.CloseIdleConnections() - addrs, err := r.LookupHost(context.Background(), "dns.google.com") + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "dns.google.com") if err != nil { t.Fatal(err) } @@ -31,6 +46,413 @@ func TestResolver(t *testing.T) { t.Fatal("expected non-nil result here") } }) + + t.Run("for nxdomain", func(t *testing.T) { + r := netxlite.NewResolverStdlib(log.Log) + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "antani.ooni.org") + if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) + + t.Run("for timeout", func(t *testing.T) { + r := netxlite.NewResolverStdlib(log.Log) + defer r.CloseIdleConnections() + const timeout = time.Nanosecond + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + addrs, err := r.LookupHost(ctx, "ooni.org") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) +} + +func TestMeasureWithUDPResolver(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // + // Measurement conditions we care about: + // + // - success + // + // - nxdomain + // + // - refused + // + // - timeout + // + + t.Run("on success", func(t *testing.T) { + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, "8.8.4.4:53") + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "dns.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil result here") + } + }) + + t.Run("for nxdomain", func(t *testing.T) { + proxy := &filtering.DNSProxy{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionNXDOMAIN + }, + } + listener, err := proxy.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "ooni.org") + if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) + + t.Run("for refused", func(t *testing.T) { + proxy := &filtering.DNSProxy{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionRefused + }, + } + listener, err := proxy.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "ooni.org") + if err == nil || err.Error() != netxlite.FailureDNSRefusedError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) + + t.Run("for timeout", func(t *testing.T) { + proxy := &filtering.DNSProxy{ + OnQuery: func(domain string) filtering.DNSAction { + return filtering.DNSActionTimeout + }, + } + listener, err := proxy.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + dlr := netxlite.NewDialerWithoutResolver(log.Log) + r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) + defer r.CloseIdleConnections() + ctx := context.Background() + addrs, err := r.LookupHost(ctx, "ooni.org") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil result here") + } + }) +} + +func TestMeasureWithDialer(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // + // Measurement conditions we care about: + // + // - success + // + // - connection refused + // + // - timeout + // + + t.Run("on success", func(t *testing.T) { + d := netxlite.NewDialerWithoutResolver(log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + conn, err := d.DialContext(ctx, "tcp", "8.8.4.4:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn here") + } + conn.Close() + }) + + t.Run("on connection refused", func(t *testing.T) { + d := netxlite.NewDialerWithoutResolver(log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + // Here we assume that no-one is listening on 127.0.0.1:1 + conn, err := d.DialContext(ctx, "tcp", "127.0.0.1:1") + if err == nil || err.Error() != netxlite.FailureConnectionRefused { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) + + t.Run("on timeout", func(t *testing.T) { + d := netxlite.NewDialerWithoutResolver(log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + // Here we assume 8.8.4.4:1 is filtered + conn, err := d.DialContext(ctx, "tcp", "8.8.4.4:1") + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("not the error we expected", err) + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) +} + +func TestMeasureWithTLSHandshaker(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // + // Measurement conditions we care about: + // + // - success + // + // - connection reset + // + // - timeout + // + + dial := func(ctx context.Context, address string) (net.Conn, error) { + d := netxlite.NewDialerWithoutResolver(log.Log) + return d.DialContext(ctx, "tcp", address) + } + + successFlow := func(th netxlite.TLSHandshaker) error { + ctx := context.Background() + conn, err := dial(ctx, "8.8.4.4:443") + if err != nil { + return fmt.Errorf("dial failed: %w", err) + } + defer conn.Close() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h2", "http/1.1"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + tconn, _, err := th.Handshake(ctx, conn, config) + if err != nil { + return fmt.Errorf("tls handshake failed: %w", err) + } + tconn.Close() + return nil + } + + connectionResetFlow := func(th netxlite.TLSHandshaker) error { + tlsProxy := &filtering.TLSProxy{ + OnIncomingSNI: func(sni string) filtering.TLSAction { + return filtering.TLSActionReset + }, + } + listener, err := tlsProxy.Start("127.0.0.1:0") + if err != nil { + return fmt.Errorf("cannot start proxy: %w", err) + } + defer listener.Close() + ctx := context.Background() + conn, err := dial(ctx, listener.Addr().String()) + if err != nil { + return fmt.Errorf("dial failed: %w", err) + } + defer conn.Close() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h2", "http/1.1"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + tconn, _, err := th.Handshake(ctx, conn, config) + if err == nil { + return fmt.Errorf("tls handshake succeded unexpectedly") + } + if err.Error() != netxlite.FailureConnectionReset { + return fmt.Errorf("not the error we expected: %w", err) + } + if tconn != nil { + return fmt.Errorf("expected nil tconn here") + } + return nil + } + + timeoutFlow := func(th netxlite.TLSHandshaker) error { + tlsProxy := &filtering.TLSProxy{ + OnIncomingSNI: func(sni string) filtering.TLSAction { + return filtering.TLSActionTimeout + }, + } + listener, err := tlsProxy.Start("127.0.0.1:0") + if err != nil { + return fmt.Errorf("cannot start proxy: %w", err) + } + defer listener.Close() + ctx := context.Background() + conn, err := dial(ctx, listener.Addr().String()) + if err != nil { + return fmt.Errorf("dial failed: %w", err) + } + defer conn.Close() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h2", "http/1.1"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + tconn, _, err := th.Handshake(ctx, conn, config) + if err == nil { + return fmt.Errorf("tls handshake succeded unexpectedly") + } + if err.Error() != netxlite.FailureGenericTimeoutError { + return fmt.Errorf("not the error we expected: %w", err) + } + if tconn != nil { + return fmt.Errorf("expected nil tconn here") + } + return nil + } + + t.Run("for stdlib handshaker", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + th := netxlite.NewTLSHandshakerStdlib(log.Log) + err := successFlow(th) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on connection reset", func(t *testing.T) { + th := netxlite.NewTLSHandshakerStdlib(log.Log) + err := connectionResetFlow(th) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on timeout", func(t *testing.T) { + th := netxlite.NewTLSHandshakerStdlib(log.Log) + err := timeoutFlow(th) + if err != nil { + t.Fatal(err) + } + }) + }) + + t.Run("for utls handshaker", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + th := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloFirefox_55) + err := successFlow(th) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on connection reset", func(t *testing.T) { + th := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloFirefox_55) + err := connectionResetFlow(th) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("on timeout", func(t *testing.T) { + th := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloFirefox_55) + err := timeoutFlow(th) + if err != nil { + t.Fatal(err) + } + }) + }) +} + +func TestMeasureWithQUICDialer(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // + // Measurement conditions we care about: + // + // - success + // + // - timeout + // + + t.Run("on success", func(t *testing.T) { + ql := netxlite.NewQUICListener() + d := netxlite.NewQUICDialerWithoutResolver(ql, log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h3"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + sess, err := d.DialContext(ctx, "udp", "8.8.4.4:443", config, &quic.Config{}) + if err != nil { + t.Fatal(err) + } + if sess == nil { + t.Fatal("expected non-nil sess here") + } + sess.CloseWithError(0, "") + }) + + t.Run("on timeout", func(t *testing.T) { + ql := netxlite.NewQUICListener() + d := netxlite.NewQUICDialerWithoutResolver(ql, log.Log) + defer d.CloseIdleConnections() + ctx := context.Background() + config := &tls.Config{ + ServerName: "dns.google", + NextProtos: []string{"h3"}, + RootCAs: netxlite.NewDefaultCertPool(), + } + // Here we assume 8.8.4.4:1 is filtered + sess, err := d.DialContext(ctx, "udp", "8.8.4.4:1", config, &quic.Config{}) + if err == nil || err.Error() != netxlite.FailureGenericTimeoutError { + t.Fatal("not the error we expected", err) + } + if sess != nil { + t.Fatal("expected nil sess here") + } + }) } func TestHTTPTransport(t *testing.T) { @@ -73,63 +495,3 @@ func TestHTTP3Transport(t *testing.T) { txp.CloseIdleConnections() }) } - -func TestUTLSHandshaker(t *testing.T) { - t.Run("with chrome fingerprint", func(t *testing.T) { - h := netxlite.NewTLSHandshakerUTLS(log.Log, &utls.HelloChrome_Auto) - cfg := &tls.Config{ServerName: "google.com"} - conn, err := net.Dial("tcp", "google.com:443") - if err != nil { - t.Fatal("unexpected error", err) - } - conn, _, err = h.Handshake(context.Background(), conn, cfg) - if err != nil { - t.Fatal("unexpected error", err) - } - if conn == nil { - t.Fatal("nil connection") - } - }) -} - -func TestQUICDialer(t *testing.T) { - if testing.Short() { - t.Skip("skip test in short mode") - } - - t.Run("works as intended", func(t *testing.T) { - tlsConfig := &tls.Config{ - ServerName: "dns.google", - } - d := netxlite.NewQUICDialerWithoutResolver( - netxlite.NewQUICListener(), log.Log, - ) - ctx := context.Background() - sess, err := d.DialContext( - ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) - if err != nil { - t.Fatal("not the error we expected", err) - } - <-sess.HandshakeComplete().Done() - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } - }) - - t.Run("can guess the SNI and ALPN when using a domain name for web", func(t *testing.T) { - d := netxlite.NewQUICDialerWithResolver( - netxlite.NewQUICListener(), log.Log, - netxlite.NewResolverStdlib(log.Log), - ) - ctx := context.Background() - sess, err := d.DialContext( - ctx, "udp", "dns.google:443", &tls.Config{}, &quic.Config{}) - if err != nil { - t.Fatal("not the error we expected", err) - } - <-sess.HandshakeComplete().Done() - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } - }) -} diff --git a/internal/netxlite/quic_test.go b/internal/netxlite/quic_test.go index 8ed70c0..ce7dd88 100644 --- a/internal/netxlite/quic_test.go +++ b/internal/netxlite/quic_test.go @@ -252,25 +252,6 @@ func TestQUICDialerResolver(t *testing.T) { }) t.Run("DialContext", func(t *testing.T) { - t.Run("on success", func(t *testing.T) { - tlsConfig := &tls.Config{} - dialer := &quicDialerResolver{ - Resolver: NewResolverStdlib(log.Log), - Dialer: &quicDialerQUICGo{ - QUICListener: &quicListenerStdlib{}, - }} - sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:443", - tlsConfig, &quic.Config{}) - if err != nil { - t.Fatal(err) - } - <-sess.HandshakeComplete().Done() - if err := sess.CloseWithError(0, ""); err != nil { - t.Fatal(err) - } - }) - t.Run("with missing port", func(t *testing.T) { tlsConfig := &tls.Config{} dialer := &quicDialerResolver{ @@ -306,7 +287,7 @@ func TestQUICDialerResolver(t *testing.T) { } }) - t.Run("with invalid port (i.e., the zero port)", func(t *testing.T) { + t.Run("with invalid, non-numeric port)", func(t *testing.T) { // This test allows us to check for the case where every attempt // to establish a connection leads to a failure tlsConf := &tls.Config{} @@ -316,13 +297,12 @@ func TestQUICDialerResolver(t *testing.T) { QUICListener: &quicListenerStdlib{}, }} sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:0", + context.Background(), "udp", "8.8.4.4:x", tlsConf, &quic.Config{}) if err == nil { t.Fatal("expected an error here") } - if !strings.HasSuffix(err.Error(), "sendto: invalid argument") && - !strings.HasSuffix(err.Error(), "sendto: can't assign requested address") { + if !strings.HasSuffix(err.Error(), "invalid syntax") { t.Fatal("not the error we expected", err) } if sess != nil { @@ -344,7 +324,7 @@ func TestQUICDialerResolver(t *testing.T) { }, }} sess, err := dialer.DialContext( - context.Background(), "udp", "www.google.com:443", + context.Background(), "udp", "8.8.4.4:443", tlsConfig, &quic.Config{}) if !errors.Is(err, expected) { t.Fatal("not the error we expected", err) @@ -355,7 +335,7 @@ func TestQUICDialerResolver(t *testing.T) { if tlsConfig.ServerName != "" { t.Fatal("should not have changed tlsConfig.ServerName") } - if gotTLSConfig.ServerName != "www.google.com" { + if gotTLSConfig.ServerName != "8.8.4.4" { t.Fatal("gotTLSConfig.ServerName has not been set") } }) diff --git a/internal/netxlite/serialresolver_test.go b/internal/netxlite/serialresolver_test.go index f621487..4238176 100644 --- a/internal/netxlite/serialresolver_test.go +++ b/internal/netxlite/serialresolver_test.go @@ -12,6 +12,22 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) +// errorWithTimeout is an error that golang will always consider +// to be a timeout because it has a Timeout() bool method +type errorWithTimeout struct { + error +} + +// Timeout returns whether this error is a timeout. +func (err *errorWithTimeout) Timeout() bool { + return true +} + +// Unwrap allows to unwrap the error. +func (err *errorWithTimeout) Unwrap() error { + return err.error +} + func TestSerialResolver(t *testing.T) { t.Run("transport okay", func(t *testing.T) { txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") @@ -129,7 +145,10 @@ func TestSerialResolver(t *testing.T) { t.Run("with timeout", func(t *testing.T) { txp := &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return nil, &net.OpError{Err: ETIMEDOUT, Op: "dial"} + return nil, &net.OpError{ + Err: &errorWithTimeout{ETIMEDOUT}, + Op: "dial", + } }, MockRequiresPadding: func() bool { return true diff --git a/internal/runtimex/runtimex.go b/internal/runtimex/runtimex.go index 6d454e4..0be3a0c 100644 --- a/internal/runtimex/runtimex.go +++ b/internal/runtimex/runtimex.go @@ -10,3 +10,15 @@ func PanicOnError(err error, message string) { panic(fmt.Errorf("%s: %w", message, err)) } } + +// PanicIfFalse calls panic if assertion is false. +func PanicIfFalse(assertion bool, message string) { + if !assertion { + panic(message) + } +} + +// PanicIfTrue calls panic if assertion is true. +func PanicIfTrue(assertion bool, message string) { + PanicIfFalse(!assertion, message) +} diff --git a/internal/runtimex/runtimex_test.go b/internal/runtimex/runtimex_test.go index 9b129ec..4649540 100644 --- a/internal/runtimex/runtimex_test.go +++ b/internal/runtimex/runtimex_test.go @@ -7,21 +7,67 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) -func TestGood(t *testing.T) { - runtimex.PanicOnError(nil, "antani failed") -} - -func TestBad(t *testing.T) { - expected := errors.New("mocked error") - if !errors.Is(badfunc(expected), expected) { - t.Fatal("not the error we expected") +func TestPanicOnError(t *testing.T) { + badfunc := func(in error) (out error) { + defer func() { + out = recover().(error) + }() + runtimex.PanicOnError(in, "antani failed") + return } + + t.Run("error is nil", func(t *testing.T) { + runtimex.PanicOnError(nil, "antani failed") + }) + + t.Run("error is not nil", func(t *testing.T) { + expected := errors.New("mocked error") + if !errors.Is(badfunc(expected), expected) { + t.Fatal("not the error we expected") + } + }) } -func badfunc(in error) (out error) { - defer func() { - out = recover().(error) - }() - runtimex.PanicOnError(in, "antani failed") - return +func TestPanicIfFalse(t *testing.T) { + badfunc := func(in bool, message string) (out error) { + defer func() { + out = errors.New(recover().(string)) + }() + runtimex.PanicIfFalse(in, message) + return + } + + t.Run("assertion is true", func(t *testing.T) { + runtimex.PanicIfFalse(true, "antani failed") + }) + + t.Run("assertion is false", func(t *testing.T) { + message := "mocked error" + err := badfunc(false, message) + if err == nil || err.Error() != message { + t.Fatal("not the error we expected", err) + } + }) +} + +func TestPanicIfTrue(t *testing.T) { + badfunc := func(in bool, message string) (out error) { + defer func() { + out = errors.New(recover().(string)) + }() + runtimex.PanicIfTrue(in, message) + return + } + + t.Run("assertion is false", func(t *testing.T) { + runtimex.PanicIfTrue(false, "antani failed") + }) + + t.Run("assertion is true", func(t *testing.T) { + message := "mocked error" + err := badfunc(true, message) + if err == nil || err.Error() != message { + t.Fatal("not the error we expected", err) + } + }) }