From de130d249cd28737fed169eafd2be8aefd0d9362 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Tue, 28 Sep 2021 11:26:16 +0200 Subject: [PATCH] refactor(dnsx): group tests together (#516) Part of https://github.com/ooni/probe/issues/1591 --- internal/netxlite/dnsx/dnsdecoder_test.go | 387 ++++++++-------- internal/netxlite/dnsx/dnsencoder_test.go | 107 +++-- internal/netxlite/dnsx/dnsoverhttps_test.go | 350 +++++++------- internal/netxlite/dnsx/dnsovertcp_test.go | 388 ++++++++-------- internal/netxlite/dnsx/dnsoverudp_test.go | 272 +++++------ internal/netxlite/dnsx/serialresolver.go | 2 +- internal/netxlite/dnsx/serialresolver_test.go | 438 +++++++++--------- 7 files changed, 983 insertions(+), 961 deletions(-) diff --git a/internal/netxlite/dnsx/dnsdecoder_test.go b/internal/netxlite/dnsx/dnsdecoder_test.go index 9976722..a15e15f 100644 --- a/internal/netxlite/dnsx/dnsdecoder_test.go +++ b/internal/netxlite/dnsx/dnsdecoder_test.go @@ -11,114 +11,185 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" ) -func TestDecoderUnpackError(t *testing.T) { - d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost(dns.TypeA, nil) - if err == nil { - t.Fatal("expected an error here") - } - if data != nil { - t.Fatal("expected nil data here") - } +func TestDNSDecoder(t *testing.T) { + t.Run("LookupHost", func(t *testing.T) { + t.Run("UnpackError", func(t *testing.T) { + d := &DNSDecoderMiekg{} + data, err := d.DecodeLookupHost(dns.TypeA, nil) + if err == nil { + t.Fatal("expected an error here") + } + if data != nil { + t.Fatal("expected nil data here") + } + }) + + t.Run("NXDOMAIN", func(t *testing.T) { + d := &DNSDecoderMiekg{} + data, err := d.DecodeLookupHost( + dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeNameError)) + if err == nil || !strings.HasSuffix(err.Error(), "no such host") { + t.Fatal("not the error we expected", err) + } + if data != nil { + t.Fatal("expected nil data here") + } + }) + + t.Run("Refused", func(t *testing.T) { + d := &DNSDecoderMiekg{} + data, err := d.DecodeLookupHost( + dns.TypeA, dnsGenReplyWithError(t, dns.TypeA, dns.RcodeRefused)) + if !errors.Is(err, errorsx.ErrOODNSRefused) { + t.Fatal("not the error we expected", err) + } + if data != nil { + t.Fatal("expected nil data here") + } + }) + + t.Run("no address", func(t *testing.T) { + d := &DNSDecoderMiekg{} + data, err := d.DecodeLookupHost(dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeA)) + if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { + t.Fatal("not the error we expected", err) + } + if data != nil { + t.Fatal("expected nil data here") + } + }) + + t.Run("decode A", func(t *testing.T) { + d := &DNSDecoderMiekg{} + data, err := d.DecodeLookupHost( + dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8")) + if err != nil { + t.Fatal(err) + } + if len(data) != 2 { + t.Fatal("expected two entries here") + } + if data[0] != "1.1.1.1" { + t.Fatal("invalid first IPv4 entry") + } + if data[1] != "8.8.8.8" { + t.Fatal("invalid second IPv4 entry") + } + }) + + t.Run("decode AAAA", func(t *testing.T) { + d := &DNSDecoderMiekg{} + data, err := d.DecodeLookupHost( + dns.TypeAAAA, dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1")) + if err != nil { + t.Fatal(err) + } + if len(data) != 2 { + t.Fatal("expected two entries here") + } + if data[0] != "::1" { + t.Fatal("invalid first IPv6 entry") + } + if data[1] != "fe80::1" { + t.Fatal("invalid second IPv6 entry") + } + }) + + t.Run("unexpected A reply", func(t *testing.T) { + d := &DNSDecoderMiekg{} + data, err := d.DecodeLookupHost( + dns.TypeA, dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1")) + if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { + t.Fatal("not the error we expected", err) + } + if data != nil { + t.Fatal("expected nil data here") + } + }) + + t.Run("unexpected AAAA reply", func(t *testing.T) { + d := &DNSDecoderMiekg{} + data, err := d.DecodeLookupHost( + dns.TypeAAAA, dnsGenLookupHostReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4.")) + if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { + t.Fatal("not the error we expected", err) + } + if data != nil { + t.Fatal("expected nil data here") + } + }) + }) + + t.Run("parseReply", func(t *testing.T) { + d := &DNSDecoderMiekg{} + msg := &dns.Msg{} + msg.Rcode = dns.RcodeFormatError // an rcode we don't handle + data, err := msg.Pack() + if err != nil { + t.Fatal(err) + } + reply, err := d.parseReply(data) + if !errors.Is(err, errorsx.ErrOODNSMisbehaving) { // catch all error + t.Fatal("not the error we expected", err) + } + if reply != nil { + t.Fatal("expected nil reply") + } + }) + + t.Run("DecodeHTTPS", func(t *testing.T) { + t.Run("with nil data", func(t *testing.T) { + d := &DNSDecoderMiekg{} + reply, err := d.DecodeHTTPS(nil) + if err == nil || err.Error() != "dns: overflow unpacking uint16" { + t.Fatal("not the error we expected", err) + } + if reply != nil { + t.Fatal("expected nil reply") + } + }) + + t.Run("with empty answer", func(t *testing.T) { + data := dnsGenHTTPSReplySuccess(t, nil, nil, nil) + d := &DNSDecoderMiekg{} + reply, err := d.DecodeHTTPS(data) + if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { + t.Fatal("unexpected err", err) + } + if reply != nil { + t.Fatal("expected nil reply") + } + }) + + t.Run("with full answer", func(t *testing.T) { + alpn := []string{"h3"} + v4 := []string{"1.1.1.1"} + v6 := []string{"::1"} + data := dnsGenHTTPSReplySuccess(t, alpn, v4, v6) + d := &DNSDecoderMiekg{} + reply, err := d.DecodeHTTPS(data) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(alpn, reply.ALPN); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(v4, reply.IPv4); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(v6, reply.IPv6); diff != "" { + t.Fatal(diff) + } + }) + }) } -func TestDecoderNXDOMAIN(t *testing.T) { - d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost(dns.TypeA, genReplyError(t, dns.RcodeNameError)) - if err == nil || !strings.HasSuffix(err.Error(), "no such host") { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } -} - -func TestDecoderRefusedError(t *testing.T) { - d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost(dns.TypeA, genReplyError(t, dns.RcodeRefused)) - if !errors.Is(err, errorsx.ErrOODNSRefused) { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } -} - -func TestDecoderNoAddress(t *testing.T) { - d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost(dns.TypeA, genReplySuccess(t, dns.TypeA)) - if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } -} - -func TestDecoderDecodeA(t *testing.T) { - d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost( - dns.TypeA, genReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8")) - if err != nil { - t.Fatal(err) - } - if len(data) != 2 { - t.Fatal("expected two entries here") - } - if data[0] != "1.1.1.1" { - t.Fatal("invalid first IPv4 entry") - } - if data[1] != "8.8.8.8" { - t.Fatal("invalid second IPv4 entry") - } -} - -func TestDecoderDecodeAAAA(t *testing.T) { - d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost( - dns.TypeAAAA, genReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1")) - if err != nil { - t.Fatal(err) - } - if len(data) != 2 { - t.Fatal("expected two entries here") - } - if data[0] != "::1" { - t.Fatal("invalid first IPv6 entry") - } - if data[1] != "fe80::1" { - t.Fatal("invalid second IPv6 entry") - } -} - -func TestDecoderUnexpectedAReply(t *testing.T) { - d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost( - dns.TypeA, genReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1")) - if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } -} - -func TestDecoderUnexpectedAAAAReply(t *testing.T) { - d := &DNSDecoderMiekg{} - data, err := d.DecodeLookupHost( - dns.TypeAAAA, genReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4.")) - if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { - t.Fatal("not the error we expected", err) - } - if data != nil { - t.Fatal("expected nil data here") - } -} - -func genReplyError(t *testing.T, code int) []byte { +// dnsGenReplyWithError generates a DNS reply for the given +// query type (e.g., dns.TypeA) using code as the Rcode. +func dnsGenReplyWithError(t *testing.T, qtype uint16, code int) []byte { question := dns.Question{ Name: dns.Fqdn("x.org"), - Qtype: dns.TypeA, + Qtype: qtype, Qclass: dns.ClassINET, } query := new(dns.Msg) @@ -137,7 +208,9 @@ func genReplyError(t *testing.T, code int) []byte { return data } -func genReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte { +// dnsGenLookupHostReplySuccess generates a successful DNS reply for the given +// qtype (e.g., dns.TypeA) containing the given ips... in the answer. +func dnsGenLookupHostReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte { question := dns.Question{ Name: dns.Fqdn("x.org"), Qtype: qtype, @@ -183,24 +256,9 @@ func genReplySuccess(t *testing.T, qtype uint16, ips ...string) []byte { return data } -func TestParseReply(t *testing.T) { - d := &DNSDecoderMiekg{} - msg := &dns.Msg{} - msg.Rcode = dns.RcodeFormatError // an rcode we don't handle - data, err := msg.Pack() - if err != nil { - t.Fatal(err) - } - reply, err := d.parseReply(data) - if !errors.Is(err, errorsx.ErrOODNSMisbehaving) { // catch all error - t.Fatal("not the error we expected", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } -} - -func genReplyHTTPS(t *testing.T, alpns, ipv4, ipv6 []string) []byte { +// dnsGenHTTPSReplySuccess generates a successful HTTPS response containing +// the given (possibly nil) alpns, ipv4s, and ipv6s. +func dnsGenHTTPSReplySuccess(t *testing.T, alpns, ipv4s, ipv6s []string) []byte { question := dns.Question{ Name: dns.Fqdn("x.org"), Qtype: dns.TypeHTTPS, @@ -218,43 +276,32 @@ func genReplyHTTPS(t *testing.T, alpns, ipv4, ipv6 []string) []byte { answer := &dns.HTTPS{ SVCB: dns.SVCB{ Hdr: dns.RR_Header{ - Name: dns.Fqdn("x.org"), - Rrtype: dns.TypeHTTPS, - Class: dns.ClassINET, - Ttl: 100, - Rdlength: 0, + Name: dns.Fqdn("x.org"), + Rrtype: dns.TypeHTTPS, + Class: dns.ClassINET, + Ttl: 100, }, - Priority: 5, - Target: dns.Fqdn("x.org"), - Value: []dns.SVCBKeyValue{}, + Target: dns.Fqdn("x.org"), + Value: []dns.SVCBKeyValue{}, }, } reply.Answer = append(reply.Answer, answer) if len(alpns) > 0 { - answer.Value = append(answer.Value, &dns.SVCBAlpn{ - Alpn: alpns, - }) - answer.Hdr.Rdlength++ + answer.Value = append(answer.Value, &dns.SVCBAlpn{Alpn: alpns}) } - if len(ipv4) > 0 { + if len(ipv4s) > 0 { var addrs []net.IP - for _, addr := range ipv4 { + for _, addr := range ipv4s { addrs = append(addrs, net.ParseIP(addr)) } - answer.Value = append(answer.Value, &dns.SVCBIPv4Hint{ - Hint: addrs, - }) - answer.Hdr.Rdlength++ + answer.Value = append(answer.Value, &dns.SVCBIPv4Hint{Hint: addrs}) } - if len(ipv6) > 0 { + if len(ipv6s) > 0 { var addrs []net.IP - for _, addr := range ipv6 { + for _, addr := range ipv6s { addrs = append(addrs, net.ParseIP(addr)) } - answer.Value = append(answer.Value, &dns.SVCBIPv6Hint{ - Hint: addrs, - }) - answer.Hdr.Rdlength++ + answer.Value = append(answer.Value, &dns.SVCBIPv6Hint{Hint: addrs}) } data, err := reply.Pack() if err != nil { @@ -262,49 +309,3 @@ func genReplyHTTPS(t *testing.T, alpns, ipv4, ipv6 []string) []byte { } return data } - -func TestDecodeHTTPS(t *testing.T) { - t.Run("with nil data", func(t *testing.T) { - d := &DNSDecoderMiekg{} - reply, err := d.DecodeHTTPS(nil) - if err == nil || err.Error() != "dns: overflow unpacking uint16" { - t.Fatal("not the error we expected", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } - }) - - t.Run("with empty answer", func(t *testing.T) { - data := genReplyHTTPS(t, nil, nil, nil) - d := &DNSDecoderMiekg{} - reply, err := d.DecodeHTTPS(data) - if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { - t.Fatal("unexpected err", err) - } - if reply != nil { - t.Fatal("expected nil reply") - } - }) - - t.Run("with full answer", func(t *testing.T) { - alpn := []string{"h3"} - v4 := []string{"1.1.1.1"} - v6 := []string{"::1"} - data := genReplyHTTPS(t, alpn, v4, v6) - d := &DNSDecoderMiekg{} - reply, err := d.DecodeHTTPS(data) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(alpn, reply.ALPN); diff != "" { - t.Fatal(diff) - } - if diff := cmp.Diff(v4, reply.IPv4); diff != "" { - t.Fatal(diff) - } - if diff := cmp.Diff(v6, reply.IPv6); diff != "" { - t.Fatal(diff) - } - }) -} diff --git a/internal/netxlite/dnsx/dnsencoder_test.go b/internal/netxlite/dnsx/dnsencoder_test.go index fa3a9cf..99edad8 100644 --- a/internal/netxlite/dnsx/dnsencoder_test.go +++ b/internal/netxlite/dnsx/dnsencoder_test.go @@ -7,25 +7,64 @@ import ( "github.com/miekg/dns" ) -func TestEncoderEncodeA(t *testing.T) { - e := &DNSEncoderMiekg{} - data, err := e.Encode("x.org", dns.TypeA, false) - if err != nil { - t.Fatal(err) - } - validate(t, data, byte(dns.TypeA)) +func TestDNSEncoder(t *testing.T) { + + t.Run("encode A", func(t *testing.T) { + e := &DNSEncoderMiekg{} + data, err := e.Encode("x.org", dns.TypeA, false) + if err != nil { + t.Fatal(err) + } + dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA)) + }) + + t.Run("encode AAAA", func(t *testing.T) { + e := &DNSEncoderMiekg{} + data, err := e.Encode("x.org", dns.TypeAAAA, false) + if err != nil { + t.Fatal(err) + } + dnsValidateEncodedQueryBytes(t, data, byte(dns.TypeA)) + }) + + t.Run("encode padding", func(t *testing.T) { + // The purpose of this unit test is to make sure that for a wide + // array of values we obtain the right query size. + getquerylen := func(domainlen int, padding bool) int { + e := &DNSEncoderMiekg{} + data, err := e.Encode( + // This is not a valid name because it ends up being way + // longer than 255 octets. However, the library is allowing + // us to generate such name and we are not going to send + // it on the wire. Also, we check below that the query that + // we generate is long enough, so we should be good. + dns.Fqdn(strings.Repeat("x.", domainlen)), + dns.TypeA, padding, + ) + if err != nil { + t.Fatal(err) + } + return len(data) + } + for domainlen := 1; domainlen <= 4000; domainlen++ { + vanillalen := getquerylen(domainlen, false) + paddedlen := getquerylen(domainlen, true) + if vanillalen < domainlen { + t.Fatal("vanillalen is smaller than domainlen") + } + if (paddedlen % dnsPaddingDesiredBlockSize) != 0 { + t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize") + } + if paddedlen < vanillalen { + t.Fatal("paddedlen is smaller than vanillalen") + } + } + }) } -func TestEncoderEncodeAAAA(t *testing.T) { - e := &DNSEncoderMiekg{} - data, err := e.Encode("x.org", dns.TypeAAAA, false) - if err != nil { - t.Fatal(err) - } - validate(t, data, byte(dns.TypeA)) -} - -func validate(t *testing.T, data []byte, qtype byte) { +// dnsValidateEncodedQueryBytes validates the query serialized in data +// for the given query type qtype (e.g., dns.TypeAAAA). +func dnsValidateEncodedQueryBytes(t *testing.T, data []byte, qtype byte) { // skipping over the query ID if data[2] != 1 { t.Fatal("FLAGS should only have RD set") @@ -62,37 +101,3 @@ func validate(t *testing.T, data []byte, qtype byte) { t.Fatal("The query is not IN") } } - -func TestEncoderPadding(t *testing.T) { - // The purpose of this unit test is to make sure that for a wide - // array of values we obtain the right query size. - getquerylen := func(domainlen int, padding bool) int { - e := &DNSEncoderMiekg{} - data, err := e.Encode( - // This is not a valid name because it ends up being way - // longer than 255 octets. However, the library is allowing - // us to generate such name and we are not going to send - // it on the wire. Also, we check below that the query that - // we generate is long enough, so we should be good. - dns.Fqdn(strings.Repeat("x.", domainlen)), - dns.TypeA, padding, - ) - if err != nil { - t.Fatal(err) - } - return len(data) - } - for domainlen := 1; domainlen <= 4000; domainlen++ { - vanillalen := getquerylen(domainlen, false) - paddedlen := getquerylen(domainlen, true) - if vanillalen < domainlen { - t.Fatal("vanillalen is smaller than domainlen") - } - if (paddedlen % dnsPaddingDesiredBlockSize) != 0 { - t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize") - } - if paddedlen < vanillalen { - t.Fatal("paddedlen is smaller than vanillalen") - } - } -} diff --git a/internal/netxlite/dnsx/dnsoverhttps_test.go b/internal/netxlite/dnsx/dnsoverhttps_test.go index 0d1df36..1c88fb3 100644 --- a/internal/netxlite/dnsx/dnsoverhttps_test.go +++ b/internal/netxlite/dnsx/dnsoverhttps_test.go @@ -13,180 +13,184 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) -func TestDNSOverHTTPSNewRequestFailure(t *testing.T) { - const invalidURL = "\t" - txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL) - data, err := txp.RoundTrip(context.Background(), nil) - if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { - t.Fatal("expected an error here") - } - if data != nil { - t.Fatal("expected no response here") - } -} +func TestDNSOverHTTPS(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("NewRequestFailure", func(t *testing.T) { + const invalidURL = "\t" + txp := NewDNSOverHTTPS(http.DefaultClient, invalidURL) + data, err := txp.RoundTrip(context.Background(), nil) + if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") { + t.Fatal("expected an error here") + } + if data != nil { + t.Fatal("expected no response here") + } + }) -func TestDNSOverHTTPSClientDoFailure(t *testing.T) { - expected := errors.New("mocked error") - txp := &DNSOverHTTPS{ - Client: &mocks.HTTPClient{ - MockDo: func(*http.Request) (*http.Response, error) { - return nil, expected - }, - }, - URL: "https://cloudflare-dns.com/dns-query", - } - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, expected) { - t.Fatal("expected an error here") - } - if data != nil { - t.Fatal("expected no response here") - } -} - -func TestDNSOverHTTPSHTTPFailure(t *testing.T) { - txp := &DNSOverHTTPS{ - Client: &mocks.HTTPClient{ - MockDo: func(*http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 500, - Body: io.NopCloser(strings.NewReader("")), - }, nil - }, - }, - URL: "https://cloudflare-dns.com/dns-query", - } - data, err := txp.RoundTrip(context.Background(), nil) - if err == nil || err.Error() != "doh: server returned error" { - t.Fatal("expected an error here") - } - if data != nil { - t.Fatal("expected no response here") - } -} - -func TestDNSOverHTTPSMissingContentType(t *testing.T) { - txp := &DNSOverHTTPS{ - Client: &mocks.HTTPClient{ - MockDo: func(*http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader("")), - }, nil - }, - }, - URL: "https://cloudflare-dns.com/dns-query", - } - data, err := txp.RoundTrip(context.Background(), nil) - if err == nil || err.Error() != "doh: invalid content-type" { - t.Fatal("expected an error here") - } - if data != nil { - t.Fatal("expected no response here") - } -} - -func TestDNSOverHTTPSSuccess(t *testing.T) { - body := []byte("AAA") - txp := &DNSOverHTTPS{ - Client: &mocks.HTTPClient{ - MockDo: func(*http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(bytes.NewReader(body)), - Header: http.Header{ - "Content-Type": []string{"application/dns-message"}, + t.Run("client.Do failure", func(t *testing.T) { + expected := errors.New("mocked error") + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return nil, expected }, - }, nil + }, + URL: "https://cloudflare-dns.com/dns-query", + } + data, err := txp.RoundTrip(context.Background(), nil) + if !errors.Is(err, expected) { + t.Fatal("expected an error here") + } + if data != nil { + t.Fatal("expected no response here") + } + }) + + t.Run("server returns 500", func(t *testing.T) { + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 500, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }, + }, + URL: "https://cloudflare-dns.com/dns-query", + } + data, err := txp.RoundTrip(context.Background(), nil) + if err == nil || err.Error() != "doh: server returned error" { + t.Fatal("expected an error here") + } + if data != nil { + t.Fatal("expected no response here") + } + }) + + t.Run("missing content type", func(t *testing.T) { + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }, + }, + URL: "https://cloudflare-dns.com/dns-query", + } + data, err := txp.RoundTrip(context.Background(), nil) + if err == nil || err.Error() != "doh: invalid content-type" { + t.Fatal("expected an error here") + } + if data != nil { + t.Fatal("expected no response here") + } + }) + + t.Run("success", func(t *testing.T) { + body := []byte("AAA") + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(body)), + Header: http.Header{ + "Content-Type": []string{"application/dns-message"}, + }, + }, nil + }, + }, + URL: "https://cloudflare-dns.com/dns-query", + } + data, err := txp.RoundTrip(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, body) { + t.Fatal("not the response we expected") + } + }) + + t.Run("sets the correct user-agent", func(t *testing.T) { + expected := errors.New("mocked error") + var correct bool + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + correct = req.Header.Get("User-Agent") == httpheader.UserAgent() + return nil, expected + }, + }, + URL: "https://cloudflare-dns.com/dns-query", + } + data, err := txp.RoundTrip(context.Background(), nil) + if !errors.Is(err, expected) { + t.Fatal("expected an error here") + } + if data != nil { + t.Fatal("expected no response here") + } + if !correct { + t.Fatal("did not see correct user agent") + } + }) + + t.Run("we can override the Host header", func(t *testing.T) { + var correct bool + expected := errors.New("mocked error") + hostOverride := "test.com" + txp := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + correct = req.Host == hostOverride + return nil, expected + }, + }, + URL: "https://cloudflare-dns.com/dns-query", + HostOverride: hostOverride, + } + data, err := txp.RoundTrip(context.Background(), nil) + if !errors.Is(err, expected) { + t.Fatal("expected an error here") + } + if data != nil { + t.Fatal("expected no response here") + } + if !correct { + t.Fatal("did not see correct host override") + } + }) + + }) + + t.Run("other functions behave correctly", func(t *testing.T) { + const queryURL = "https://cloudflare-dns.com/dns-query" + txp := NewDNSOverHTTPS(http.DefaultClient, queryURL) + if txp.Network() != "doh" { + t.Fatal("invalid network") + } + if txp.RequiresPadding() != true { + t.Fatal("should require padding") + } + if txp.Address() != queryURL { + t.Fatal("invalid address") + } + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + doh := &DNSOverHTTPS{ + Client: &mocks.HTTPClient{ + MockCloseIdleConnections: func() { + called = true + }, }, - }, - URL: "https://cloudflare-dns.com/dns-query", - } - data, err := txp.RoundTrip(context.Background(), nil) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(data, body) { - t.Fatal("not the response we expected") - } -} - -func TestDNSOverHTTPTransportOK(t *testing.T) { - const queryURL = "https://cloudflare-dns.com/dns-query" - txp := NewDNSOverHTTPS(http.DefaultClient, queryURL) - if txp.Network() != "doh" { - t.Fatal("invalid network") - } - if txp.RequiresPadding() != true { - t.Fatal("should require padding") - } - if txp.Address() != queryURL { - t.Fatal("invalid address") - } -} - -func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) { - expected := errors.New("mocked error") - var correct bool - txp := &DNSOverHTTPS{ - Client: &mocks.HTTPClient{ - MockDo: func(req *http.Request) (*http.Response, error) { - correct = req.Header.Get("User-Agent") == httpheader.UserAgent() - return nil, expected - }, - }, - URL: "https://cloudflare-dns.com/dns-query", - } - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, expected) { - t.Fatal("expected an error here") - } - if data != nil { - t.Fatal("expected no response here") - } - if !correct { - t.Fatal("did not see correct user agent") - } -} - -func TestDNSOverHTTPSHostOverride(t *testing.T) { - var correct bool - expected := errors.New("mocked error") - - hostOverride := "test.com" - txp := &DNSOverHTTPS{ - Client: &mocks.HTTPClient{ - MockDo: func(req *http.Request) (*http.Response, error) { - correct = req.Host == hostOverride - return nil, expected - }, - }, - URL: "https://cloudflare-dns.com/dns-query", - HostOverride: hostOverride, - } - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, expected) { - t.Fatal("expected an error here") - } - if data != nil { - t.Fatal("expected no response here") - } - if !correct { - t.Fatal("did not see correct host override") - } -} - -func TestDNSOverHTTPSCloseIdleConnections(t *testing.T) { - var called bool - doh := &DNSOverHTTPS{ - Client: &mocks.HTTPClient{ - MockCloseIdleConnections: func() { - called = true - }, - }, - } - doh.CloseIdleConnections() - if !called { - t.Fatal("not called") - } + } + doh.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) } diff --git a/internal/netxlite/dnsx/dnsovertcp_test.go b/internal/netxlite/dnsx/dnsovertcp_test.go index efcbe34..803d97d 100644 --- a/internal/netxlite/dnsx/dnsovertcp_test.go +++ b/internal/netxlite/dnsx/dnsovertcp_test.go @@ -13,210 +13,214 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) -func TestDNSOverTCPTransportQueryTooLarge(t *testing.T) { - const address = "9.9.9.9:53" - txp := NewDNSOverTCP(new(net.Dialer).DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18)) - if err == nil { - t.Fatal("expected an error here") - } - if reply != nil { - t.Fatal("expected nil reply here") - } -} +func TestDNSOverTCP(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("query too large", func(t *testing.T) { + const address = "9.9.9.9:53" + txp := NewDNSOverTCP(new(net.Dialer).DialContext, address) + reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<18)) + if err == nil { + t.Fatal("expected an error here") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) -func TestDNSOverTCPTransportDialFailure(t *testing.T) { - const address = "9.9.9.9:53" - mocked := errors.New("mocked error") - fakedialer := &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, mocked - }, - } - txp := NewDNSOverTCP(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") - } -} + t.Run("dial failure", func(t *testing.T) { + const address = "9.9.9.9:53" + mocked := errors.New("mocked error") + fakedialer := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, mocked + }, + } + txp := NewDNSOverTCP(fakedialer.DialContext, address) + reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) -func TestDNSOverTCPTransportSetDealineFailure(t *testing.T) { - const address = "9.9.9.9:53" - mocked := errors.New("mocked error") - fakedialer := &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return mocked + t.Run("SetDeadline failure", func(t *testing.T) { + const address = "9.9.9.9:53" + mocked := errors.New("mocked error") + fakedialer := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return mocked + }, + MockClose: func() error { + return nil + }, + }, nil }, - MockClose: func() error { - return nil - }, - }, nil - }, - } - txp := NewDNSOverTCP(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") - } -} + } + txp := NewDNSOverTCP(fakedialer.DialContext, address) + reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) -func TestDNSOverTCPTransportWriteFailure(t *testing.T) { - const address = "9.9.9.9:53" - mocked := errors.New("mocked error") - fakedialer := &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil + t.Run("write failure", func(t *testing.T) { + const address = "9.9.9.9:53" + mocked := errors.New("mocked error") + fakedialer := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return 0, mocked + }, + MockClose: func() error { + return nil + }, + }, nil }, - MockWrite: func(b []byte) (int, error) { - return 0, mocked - }, - MockClose: func() error { - return nil - }, - }, nil - }, - } - txp := NewDNSOverTCP(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") - } -} + } + txp := NewDNSOverTCP(fakedialer.DialContext, address) + reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) -func TestDNSOverTCPTransportReadFailure(t *testing.T) { - const address = "9.9.9.9:53" - mocked := errors.New("mocked error") - fakedialer := &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil + t.Run("first read fails", func(t *testing.T) { + const address = "9.9.9.9:53" + mocked := errors.New("mocked error") + fakedialer := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + return 0, mocked + }, + MockClose: func() error { + return nil + }, + }, nil }, - MockWrite: func(b []byte) (int, error) { - return len(b), nil - }, - MockRead: func(b []byte) (int, error) { - return 0, mocked - }, - MockClose: func() error { - return nil - }, - }, nil - }, - } - txp := NewDNSOverTCP(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") - } -} + } + txp := NewDNSOverTCP(fakedialer.DialContext, address) + reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) -func TestDNSOverTCPTransportSecondReadFailure(t *testing.T) { - const address = "9.9.9.9:53" - mocked := errors.New("mocked error") - input := io.MultiReader( - bytes.NewReader([]byte{byte(0), byte(2)}), - &mocks.Reader{ - MockRead: func(b []byte) (int, error) { - return 0, mocked - }, - }, - ) - fakedialer := &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil + t.Run("second read fails", func(t *testing.T) { + const address = "9.9.9.9:53" + mocked := errors.New("mocked error") + input := io.MultiReader( + bytes.NewReader([]byte{byte(0), byte(2)}), + &mocks.Reader{ + MockRead: func(b []byte) (int, error) { + return 0, mocked + }, }, - MockWrite: func(b []byte) (int, error) { - return len(b), nil + ) + fakedialer := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: input.Read, + MockClose: func() error { + return nil + }, + }, nil }, - MockRead: input.Read, - MockClose: func() error { - return nil - }, - }, nil - }, - } - txp := NewDNSOverTCP(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if reply != nil { - t.Fatal("expected nil reply here") - } -} + } + txp := NewDNSOverTCP(fakedialer.DialContext, address) + reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + }) -func TestDNSOverTCPTransportAllGood(t *testing.T) { - const address = "9.9.9.9:53" - input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)}) - fakedialer := &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil + t.Run("successful case", func(t *testing.T) { + const address = "9.9.9.9:53" + input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)}) + fakedialer := &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: input.Read, + MockClose: func() error { + return nil + }, + }, nil }, - MockWrite: func(b []byte) (int, error) { - return len(b), nil - }, - MockRead: input.Read, - MockClose: func() error { - return nil - }, - }, nil - }, - } - txp := NewDNSOverTCP(fakedialer.DialContext, address) - reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) - if err != nil { - t.Fatal(err) - } - if len(reply) != 1 || reply[0] != 1 { - t.Fatal("not the response we expected") - } -} + } + txp := NewDNSOverTCP(fakedialer.DialContext, address) + reply, err := txp.RoundTrip(context.Background(), make([]byte, 1<<11)) + if err != nil { + t.Fatal(err) + } + if len(reply) != 1 || reply[0] != 1 { + t.Fatal("not the response we expected") + } + }) + }) -func TestDNSOverTCPTransportOK(t *testing.T) { - const address = "9.9.9.9:53" - txp := NewDNSOverTCP(new(net.Dialer).DialContext, address) - if txp.RequiresPadding() != false { - t.Fatal("invalid RequiresPadding") - } - if txp.Network() != "tcp" { - t.Fatal("invalid Network") - } - if txp.Address() != address { - t.Fatal("invalid Address") - } -} + t.Run("other functions okay with TCP", func(t *testing.T) { + const address = "9.9.9.9:53" + txp := NewDNSOverTCP(new(net.Dialer).DialContext, address) + if txp.RequiresPadding() != false { + t.Fatal("invalid RequiresPadding") + } + if txp.Network() != "tcp" { + t.Fatal("invalid Network") + } + if txp.Address() != address { + t.Fatal("invalid Address") + } + }) -func TestDNSOverTLSTransportOK(t *testing.T) { - const address = "9.9.9.9:853" - txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address) - if txp.RequiresPadding() != true { - t.Fatal("invalid RequiresPadding") - } - if txp.Network() != "dot" { - t.Fatal("invalid Network") - } - if txp.Address() != address { - t.Fatal("invalid Address") - } + t.Run("other functions okay with TLS", func(t *testing.T) { + const address = "9.9.9.9:853" + txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, address) + if txp.RequiresPadding() != true { + t.Fatal("invalid RequiresPadding") + } + if txp.Network() != "dot" { + t.Fatal("invalid Network") + } + if txp.Address() != address { + t.Fatal("invalid Address") + } + }) } diff --git a/internal/netxlite/dnsx/dnsoverudp_test.go b/internal/netxlite/dnsx/dnsoverudp_test.go index 092ffa3..fab0a92 100644 --- a/internal/netxlite/dnsx/dnsoverudp_test.go +++ b/internal/netxlite/dnsx/dnsoverudp_test.go @@ -11,147 +11,151 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/mocks" ) -func TestDNSOverUDPDialFailure(t *testing.T) { - mocked := errors.New("mocked error") - const address = "9.9.9.9:53" - txp := NewDNSOverUDP(&mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, mocked - }, - }, address) - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if data != nil { - t.Fatal("expected no response here") - } -} +func TestDNSOverUDP(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("dial failure", func(t *testing.T) { + mocked := errors.New("mocked error") + const address = "9.9.9.9:53" + txp := NewDNSOverUDP(&mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, mocked + }, + }, address) + data, err := txp.RoundTrip(context.Background(), nil) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if data != nil { + t.Fatal("expected no response here") + } + }) -func TestDNSOverUDPSetDeadlineError(t *testing.T) { - mocked := errors.New("mocked error") - txp := NewDNSOverUDP( - &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return mocked + t.Run("SetDeadline failure", func(t *testing.T) { + mocked := errors.New("mocked error") + txp := NewDNSOverUDP( + &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return mocked + }, + MockClose: func() error { + return nil + }, + }, nil }, - MockClose: func() error { - return nil - }, - }, nil - }, - }, "9.9.9.9:53", - ) - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if data != nil { - t.Fatal("expected no response here") - } -} + }, "9.9.9.9:53", + ) + data, err := txp.RoundTrip(context.Background(), nil) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if data != nil { + t.Fatal("expected no response here") + } + }) -func TestDNSOverUDPWriteFailure(t *testing.T) { - mocked := errors.New("mocked error") - txp := NewDNSOverUDP( - &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil + t.Run("Write failure", func(t *testing.T) { + mocked := errors.New("mocked error") + txp := NewDNSOverUDP( + &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return 0, mocked + }, + MockClose: func() error { + return nil + }, + }, nil }, - MockWrite: func(b []byte) (int, error) { - return 0, mocked - }, - MockClose: func() error { - return nil - }, - }, nil - }, - }, "9.9.9.9:53", - ) - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if data != nil { - t.Fatal("expected no response here") - } -} + }, "9.9.9.9:53", + ) + data, err := txp.RoundTrip(context.Background(), nil) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if data != nil { + t.Fatal("expected no response here") + } + }) -func TestDNSOverUDPReadFailure(t *testing.T) { - mocked := errors.New("mocked error") - txp := NewDNSOverUDP( - &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil + t.Run("Read failure", func(t *testing.T) { + mocked := errors.New("mocked error") + txp := NewDNSOverUDP( + &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: func(b []byte) (int, error) { + return 0, mocked + }, + MockClose: func() error { + return nil + }, + }, nil }, - MockWrite: func(b []byte) (int, error) { - return len(b), nil - }, - MockRead: func(b []byte) (int, error) { - return 0, mocked - }, - MockClose: func() error { - return nil - }, - }, nil - }, - }, "9.9.9.9:53", - ) - data, err := txp.RoundTrip(context.Background(), nil) - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if data != nil { - t.Fatal("expected no response here") - } -} + }, "9.9.9.9:53", + ) + data, err := txp.RoundTrip(context.Background(), nil) + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if data != nil { + t.Fatal("expected no response here") + } + }) -func TestDNSOverUDPReadSuccess(t *testing.T) { - const expected = 17 - input := bytes.NewReader(make([]byte, expected)) - txp := NewDNSOverUDP( - &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil + t.Run("read success", func(t *testing.T) { + const expected = 17 + input := bytes.NewReader(make([]byte, expected)) + txp := NewDNSOverUDP( + &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockWrite: func(b []byte) (int, error) { + return len(b), nil + }, + MockRead: input.Read, + MockClose: func() error { + return nil + }, + }, nil }, - MockWrite: func(b []byte) (int, error) { - return len(b), nil - }, - MockRead: input.Read, - MockClose: func() error { - return nil - }, - }, nil - }, - }, "9.9.9.9:53", - ) - data, err := txp.RoundTrip(context.Background(), nil) - if err != nil { - t.Fatal(err) - } - if len(data) != expected { - t.Fatal("expected non nil data") - } -} + }, "9.9.9.9:53", + ) + data, err := txp.RoundTrip(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if len(data) != expected { + t.Fatal("expected non nil data") + } + }) + }) -func TestDNSOverUDPTransportOK(t *testing.T) { - const address = "9.9.9.9:53" - txp := NewDNSOverUDP(&net.Dialer{}, address) - if txp.RequiresPadding() != false { - t.Fatal("invalid RequiresPadding") - } - if txp.Network() != "udp" { - t.Fatal("invalid Network") - } - if txp.Address() != address { - t.Fatal("invalid Address") - } + t.Run("other functions okay", func(t *testing.T) { + const address = "9.9.9.9:53" + txp := NewDNSOverUDP(&net.Dialer{}, address) + if txp.RequiresPadding() != false { + t.Fatal("invalid RequiresPadding") + } + if txp.Network() != "udp" { + t.Fatal("invalid Network") + } + if txp.Address() != address { + t.Fatal("invalid Address") + } + }) } diff --git a/internal/netxlite/dnsx/serialresolver.go b/internal/netxlite/dnsx/serialresolver.go index 6ac1745..2ef0db6 100644 --- a/internal/netxlite/dnsx/serialresolver.go +++ b/internal/netxlite/dnsx/serialresolver.go @@ -61,7 +61,7 @@ func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]str return addrs, nil } -// LookupHTTPS issues an HTTPS query without retrying on failure. +// LookupHTTPS implements Resolver.LookupHTTPS. func (r *SerialResolver) LookupHTTPS( ctx context.Context, hostname string) (*HTTPSSvc, error) { querydata, err := r.Encoder.Encode( diff --git a/internal/netxlite/dnsx/serialresolver_test.go b/internal/netxlite/dnsx/serialresolver_test.go index 5efe8f7..4ac2e29 100644 --- a/internal/netxlite/dnsx/serialresolver_test.go +++ b/internal/netxlite/dnsx/serialresolver_test.go @@ -13,241 +13,245 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/errorsx" ) -func TestOONIGettingTransport(t *testing.T) { - txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") - r := NewSerialResolver(txp) - rtx := r.Transport() - if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { - t.Fatal("not the transport we expected") - } - if r.Network() != rtx.Network() { - t.Fatal("invalid network seen from the resolver") - } - if r.Address() != rtx.Address() { - t.Fatal("invalid address seen from the resolver") - } -} - -func TestOONIEncodeError(t *testing.T) { - mocked := errors.New("mocked error") - txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") - r := SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { - return nil, mocked - }, - }, - Txp: txp, - } - addrs, err := r.LookupHost(context.Background(), "www.gogle.com") - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected nil address here") - } -} - -func TestOONIRoundTripError(t *testing.T) { - mocked := errors.New("mocked error") - txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return nil, mocked - }, - MockRequiresPadding: func() bool { - return true - }, - } - r := NewSerialResolver(txp) - addrs, err := r.LookupHost(context.Background(), "www.gogle.com") - if !errors.Is(err, mocked) { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected nil address here") - } -} - -func TestOONIWithEmptyReply(t *testing.T) { - txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return genReplySuccess(t, dns.TypeA), nil - }, - MockRequiresPadding: func() bool { - return true - }, - } - r := NewSerialResolver(txp) - addrs, err := r.LookupHost(context.Background(), "www.gogle.com") - if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { - t.Fatal("not the error we expected", err) - } - if addrs != nil { - t.Fatal("expected nil address here") - } -} - -func TestOONIWithAReply(t *testing.T) { - txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return genReplySuccess(t, dns.TypeA, "8.8.8.8"), nil - }, - MockRequiresPadding: func() bool { - return true - }, - } - r := NewSerialResolver(txp) - addrs, err := r.LookupHost(context.Background(), "www.gogle.com") - if err != nil { - t.Fatal(err) - } - if len(addrs) != 1 || addrs[0] != "8.8.8.8" { - t.Fatal("not the result we expected") - } -} - -func TestOONIWithAAAAReply(t *testing.T) { - txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return genReplySuccess(t, dns.TypeAAAA, "::1"), nil - }, - MockRequiresPadding: func() bool { - return true - }, - } - r := NewSerialResolver(txp) - addrs, err := r.LookupHost(context.Background(), "www.gogle.com") - if err != nil { - t.Fatal(err) - } - if len(addrs) != 1 || addrs[0] != "::1" { - t.Fatal("not the result we expected") - } -} - -func TestOONIWithTimeout(t *testing.T) { - txp := &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return nil, &net.OpError{Err: errorsx.ETIMEDOUT, Op: "dial"} - }, - MockRequiresPadding: func() bool { - return true - }, - } - r := NewSerialResolver(txp) - addrs, err := r.LookupHost(context.Background(), "www.gogle.com") - if !errors.Is(err, errorsx.ETIMEDOUT) { - t.Fatal("not the error we expected") - } - if addrs != nil { - t.Fatal("expected nil address here") - } - if r.NumTimeouts.Load() <= 0 { - t.Fatal("we didn't actually take the timeouts") - } -} - -func TestSerialResolverCloseIdleConnections(t *testing.T) { - var called bool - r := &SerialResolver{ - Txp: &mocks.DNSTransport{ - MockCloseIdleConnections: func() { - called = true - }, - }, - } - r.CloseIdleConnections() - if !called { - t.Fatal("not called") - } -} - -func TestSerialResolverLookupHTTPS(t *testing.T) { - t.Run("for encoding error", func(t *testing.T) { - expected := errors.New("mocked error") - r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { - return nil, expected - }, - }, - Decoder: nil, - NumTimeouts: &atomicx.Int64{}, - Txp: &mocks.DNSTransport{ - MockRequiresPadding: func() bool { - return false - }, - }, +func TestSerialResolver(t *testing.T) { + t.Run("transport okay", func(t *testing.T) { + txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") + r := NewSerialResolver(txp) + rtx := r.Transport() + if rtx.Network() != "dot" || rtx.Address() != "8.8.8.8:853" { + t.Fatal("not the transport we expected") } - ctx := context.Background() - https, err := r.LookupHTTPS(ctx, "example.com") - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) + if r.Network() != rtx.Network() { + t.Fatal("invalid network seen from the resolver") } - if https != nil { - t.Fatal("unexpected result") + if r.Address() != rtx.Address() { + t.Fatal("invalid address seen from the resolver") } }) - t.Run("for round-trip error", func(t *testing.T) { - expected := errors.New("mocked error") - r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { - return make([]byte, 64), nil + t.Run("LookupHost", func(t *testing.T) { + t.Run("Encode error", func(t *testing.T) { + mocked := errors.New("mocked error") + txp := NewDNSOverTLS((&tls.Dialer{}).DialContext, "8.8.8.8:853") + r := SerialResolver{ + Encoder: &mocks.DNSEncoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return nil, mocked + }, }, - }, - Decoder: nil, - NumTimeouts: &atomicx.Int64{}, - Txp: &mocks.DNSTransport{ + Txp: txp, + } + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil address here") + } + }) + + t.Run("RoundTrip error", func(t *testing.T) { + mocked := errors.New("mocked error") + txp := &mocks.DNSTransport{ MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return nil, expected + return nil, mocked }, MockRequiresPadding: func() bool { - return false + return true + }, + } + r := NewSerialResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if !errors.Is(err, mocked) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil address here") + } + }) + + t.Run("empty reply", func(t *testing.T) { + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return dnsGenLookupHostReplySuccess(t, dns.TypeA), nil + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewSerialResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if !errors.Is(err, errorsx.ErrOODNSNoAnswer) { + t.Fatal("not the error we expected", err) + } + if addrs != nil { + t.Fatal("expected nil address here") + } + }) + + t.Run("with A reply", func(t *testing.T) { + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return dnsGenLookupHostReplySuccess(t, dns.TypeA, "8.8.8.8"), nil + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewSerialResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "8.8.8.8" { + t.Fatal("not the result we expected") + } + }) + + t.Run("with AAAA reply", func(t *testing.T) { + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return dnsGenLookupHostReplySuccess(t, dns.TypeAAAA, "::1"), nil + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewSerialResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 1 || addrs[0] != "::1" { + t.Fatal("not the result we expected") + } + }) + + t.Run("with timeout", func(t *testing.T) { + txp := &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return nil, &net.OpError{Err: errorsx.ETIMEDOUT, Op: "dial"} + }, + MockRequiresPadding: func() bool { + return true + }, + } + r := NewSerialResolver(txp) + addrs, err := r.LookupHost(context.Background(), "www.gogle.com") + if !errors.Is(err, errorsx.ETIMEDOUT) { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil address here") + } + if r.NumTimeouts.Load() <= 0 { + t.Fatal("we didn't actually take the timeouts") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + r := &SerialResolver{ + Txp: &mocks.DNSTransport{ + MockCloseIdleConnections: func() { + called = true }, }, } - ctx := context.Background() - https, err := r.LookupHTTPS(ctx, "example.com") - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if https != nil { - t.Fatal("unexpected result") + r.CloseIdleConnections() + if !called { + t.Fatal("not called") } }) - t.Run("for decode error", func(t *testing.T) { - expected := errors.New("mocked error") - r := &SerialResolver{ - Encoder: &mocks.DNSEncoder{ - MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { - return make([]byte, 64), nil + t.Run("LookupHTTPS", func(t *testing.T) { + t.Run("for encoding error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &SerialResolver{ + Encoder: &mocks.DNSEncoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return nil, expected + }, }, - }, - Decoder: &mocks.DNSDecoder{ - MockDecodeHTTPS: func(reply []byte) (*mocks.HTTPSSvc, error) { - return nil, expected + Decoder: nil, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.DNSTransport{ + MockRequiresPadding: func() bool { + return false + }, }, - }, - NumTimeouts: &atomicx.Int64{}, - Txp: &mocks.DNSTransport{ - MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { - return make([]byte, 128), nil + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) + + t.Run("for round-trip error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &SerialResolver{ + Encoder: &mocks.DNSEncoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return make([]byte, 64), nil + }, }, - MockRequiresPadding: func() bool { - return false + Decoder: nil, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return nil, expected + }, + MockRequiresPadding: func() bool { + return false + }, }, - }, - } - ctx := context.Background() - https, err := r.LookupHTTPS(ctx, "example.com") - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if https != nil { - t.Fatal("unexpected result") - } + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) + + t.Run("for decode error", func(t *testing.T) { + expected := errors.New("mocked error") + r := &SerialResolver{ + Encoder: &mocks.DNSEncoder{ + MockEncode: func(domain string, qtype uint16, padding bool) ([]byte, error) { + return make([]byte, 64), nil + }, + }, + Decoder: &mocks.DNSDecoder{ + MockDecodeHTTPS: func(reply []byte) (*mocks.HTTPSSvc, error) { + return nil, expected + }, + }, + NumTimeouts: &atomicx.Int64{}, + Txp: &mocks.DNSTransport{ + MockRoundTrip: func(ctx context.Context, query []byte) (reply []byte, err error) { + return make([]byte, 128), nil + }, + MockRequiresPadding: func() bool { + return false + }, + }, + } + ctx := context.Background() + https, err := r.LookupHTTPS(ctx, "example.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if https != nil { + t.Fatal("unexpected result") + } + }) }) }