refactor(resolver): add CloseIdleConnections to SerialResolver (#502)

While there, generally convert more code to internal testing
and to using pointer receivers as well.

Part of https://github.com/ooni/probe/issues/1591.
This commit is contained in:
Simone Basso 2021-09-09 20:58:04 +02:00 committed by GitHub
parent 1eb9e8c9b0
commit b3c36b5c7f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 52 deletions

View File

@ -116,7 +116,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
if configuration.HTTPConfig.BaseResolver == nil {
t.Fatal("not the BaseResolver we expected")
}
sr, ok := configuration.HTTPConfig.BaseResolver.(resolver.SerialResolver)
sr, ok := configuration.HTTPConfig.BaseResolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -192,7 +192,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
if configuration.HTTPConfig.BaseResolver == nil {
t.Fatal("not the BaseResolver we expected")
}
sr, ok := configuration.HTTPConfig.BaseResolver.(resolver.SerialResolver)
sr, ok := configuration.HTTPConfig.BaseResolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -268,7 +268,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
if configuration.HTTPConfig.BaseResolver == nil {
t.Fatal("not the BaseResolver we expected")
}
sr, ok := configuration.HTTPConfig.BaseResolver.(resolver.SerialResolver)
sr, ok := configuration.HTTPConfig.BaseResolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -344,7 +344,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
if configuration.HTTPConfig.BaseResolver == nil {
t.Fatal("not the BaseResolver we expected")
}
sr, ok := configuration.HTTPConfig.BaseResolver.(resolver.SerialResolver)
sr, ok := configuration.HTTPConfig.BaseResolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}

View File

@ -676,7 +676,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -692,7 +692,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -708,7 +708,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -725,7 +725,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -745,7 +745,7 @@ func TestNewDNSClientUDP(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -762,7 +762,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -782,7 +782,7 @@ func TestNewDNSClientTCP(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -803,7 +803,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -827,7 +827,7 @@ func TestNewDNSClientDoT(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}
@ -848,7 +848,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
if !ok {
t.Fatal("not the resolver we expected")
}

View File

@ -17,7 +17,7 @@ type Decoder interface {
type MiekgDecoder struct{}
// Decode implements Decoder.Decode.
func (d MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
func (d *MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
reply := new(dns.Msg)
if err := reply.Unpack(data); err != nil {
return nil, err
@ -51,4 +51,4 @@ func (d MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
return addrs, nil
}
var _ Decoder = MiekgDecoder{}
var _ Decoder = &MiekgDecoder{}

View File

@ -1,15 +1,14 @@
package resolver_test
package resolver
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestDecoderUnpackError(t *testing.T) {
d := resolver.MiekgDecoder{}
d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, nil)
if err == nil {
t.Fatal("expected an error here")
@ -20,8 +19,8 @@ func TestDecoderUnpackError(t *testing.T) {
}
func TestDecoderNXDOMAIN(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeNameError))
d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeNameError))
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
t.Fatal("not the error we expected")
}
@ -31,8 +30,8 @@ func TestDecoderNXDOMAIN(t *testing.T) {
}
func TestDecoderOtherError(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeRefused))
d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeRefused))
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
t.Fatal("not the error we expected")
}
@ -42,8 +41,8 @@ func TestDecoderOtherError(t *testing.T) {
}
func TestDecoderNoAddress(t *testing.T) {
d := resolver.MiekgDecoder{}
data, err := d.Decode(dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA))
d := &MiekgDecoder{}
data, err := d.Decode(dns.TypeA, GenReplySuccess(t, dns.TypeA))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
@ -53,9 +52,9 @@ func TestDecoderNoAddress(t *testing.T) {
}
func TestDecoderDecodeA(t *testing.T) {
d := resolver.MiekgDecoder{}
d := &MiekgDecoder{}
data, err := d.Decode(
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
dns.TypeA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
if err != nil {
t.Fatal(err)
}
@ -71,9 +70,9 @@ func TestDecoderDecodeA(t *testing.T) {
}
func TestDecoderDecodeAAAA(t *testing.T) {
d := resolver.MiekgDecoder{}
d := &MiekgDecoder{}
data, err := d.Decode(
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
dns.TypeAAAA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err != nil {
t.Fatal(err)
}
@ -89,9 +88,9 @@ func TestDecoderDecodeAAAA(t *testing.T) {
}
func TestDecoderUnexpectedAReply(t *testing.T) {
d := resolver.MiekgDecoder{}
d := &MiekgDecoder{}
data, err := d.Decode(
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
dns.TypeA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}
@ -101,9 +100,9 @@ func TestDecoderUnexpectedAReply(t *testing.T) {
}
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
d := resolver.MiekgDecoder{}
d := &MiekgDecoder{}
data, err := d.Decode(
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
dns.TypeAAAA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
t.Fatal("not the error we expected")
}

View File

@ -22,7 +22,7 @@ const (
)
// Encode implements Encoder.Encode
func (e MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
func (e *MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
question := dns.Question{
Name: dns.Fqdn(domain),
Qtype: qtype,
@ -49,4 +49,4 @@ func (e MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte,
return query.Pack()
}
var _ Encoder = MiekgEncoder{}
var _ Encoder = &MiekgEncoder{}

View File

@ -1,15 +1,14 @@
package resolver_test
package resolver
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)
func TestEncoderEncodeA(t *testing.T) {
e := resolver.MiekgEncoder{}
e := &MiekgEncoder{}
data, err := e.Encode("x.org", dns.TypeA, false)
if err != nil {
t.Fatal(err)
@ -18,7 +17,7 @@ func TestEncoderEncodeA(t *testing.T) {
}
func TestEncoderEncodeAAAA(t *testing.T) {
e := resolver.MiekgEncoder{}
e := &MiekgEncoder{}
data, err := e.Encode("x.org", dns.TypeAAAA, false)
if err != nil {
t.Fatal(err)
@ -68,7 +67,7 @@ func TestEncoderPadding(t *testing.T) {
// The purpose of this unit test is to make sure that for a wide
// array of values we obtain the right query size.
getquerylen := func(domainlen int, padding bool) int {
e := resolver.MiekgEncoder{}
e := &MiekgEncoder{}
data, err := e.Encode(
// This is not a valid name because it ends up being way
// longer than 255 octets. However, the library is allowing
@ -89,7 +88,7 @@ func TestEncoderPadding(t *testing.T) {
if vanillalen < domainlen {
t.Fatal("vanillalen is smaller than domainlen")
}
if (paddedlen % resolver.PaddingDesiredBlockSize) != 0 {
if (paddedlen % PaddingDesiredBlockSize) != 0 {
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
}
if paddedlen < vanillalen {

View File

@ -37,32 +37,37 @@ type SerialResolver struct {
}
// NewSerialResolver creates a new OONI Resolver instance.
func NewSerialResolver(t RoundTripper) SerialResolver {
return SerialResolver{
Encoder: MiekgEncoder{},
Decoder: MiekgDecoder{},
func NewSerialResolver(t RoundTripper) *SerialResolver {
return &SerialResolver{
Encoder: &MiekgEncoder{},
Decoder: &MiekgDecoder{},
NumTimeouts: &atomicx.Int64{},
Txp: t,
}
}
// Transport returns the transport being used.
func (r SerialResolver) Transport() RoundTripper {
func (r *SerialResolver) Transport() RoundTripper {
return r.Txp
}
// Network implements Resolver.Network
func (r SerialResolver) Network() string {
func (r *SerialResolver) Network() string {
return r.Txp.Network()
}
// Address implements Resolver.Address
func (r SerialResolver) Address() string {
func (r *SerialResolver) Address() string {
return r.Txp.Address()
}
// CloseIdleConnections closes idle connections.
func (r *SerialResolver) CloseIdleConnections() {
r.Txp.CloseIdleConnections()
}
// LookupHost implements Resolver.LookupHost.
func (r SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
var addrs []string
addrsA, errA := r.roundTripWithRetry(ctx, hostname, dns.TypeA)
addrsAAAA, errAAAA := r.roundTripWithRetry(ctx, hostname, dns.TypeAAAA)
@ -74,7 +79,7 @@ func (r SerialResolver) LookupHost(ctx context.Context, hostname string) ([]stri
return addrs, nil
}
func (r SerialResolver) roundTripWithRetry(
func (r *SerialResolver) roundTripWithRetry(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
var errorslist []error
for i := 0; i < 3; i++ {
@ -100,7 +105,7 @@ func (r SerialResolver) roundTripWithRetry(
return nil, errorslist[0]
}
func (r SerialResolver) roundTrip(
func (r *SerialResolver) roundTrip(
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
if err != nil {
@ -113,4 +118,4 @@ func (r SerialResolver) roundTrip(
return r.Decoder.Decode(qtype, replydata)
}
var _ Resolver = SerialResolver{}
var _ Resolver = &SerialResolver{}