refactor(resolver): add CloseIdleConnections to SerialResolver (#502)
While there, generally convert more code to internal testing and to using pointer receivers as well. Part of https://github.com/ooni/probe/issues/1591.
This commit is contained in:
parent
1eb9e8c9b0
commit
b3c36b5c7f
|
@ -116,7 +116,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSPowerdns(t *testing.T) {
|
|||
if configuration.HTTPConfig.BaseResolver == nil {
|
||||
t.Fatal("not the BaseResolver we expected")
|
||||
}
|
||||
sr, ok := configuration.HTTPConfig.BaseResolver.(resolver.SerialResolver)
|
||||
sr, ok := configuration.HTTPConfig.BaseResolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -192,7 +192,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSGoogle(t *testing.T) {
|
|||
if configuration.HTTPConfig.BaseResolver == nil {
|
||||
t.Fatal("not the BaseResolver we expected")
|
||||
}
|
||||
sr, ok := configuration.HTTPConfig.BaseResolver.(resolver.SerialResolver)
|
||||
sr, ok := configuration.HTTPConfig.BaseResolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -268,7 +268,7 @@ func TestConfigurerNewConfigurationResolverDNSOverHTTPSCloudflare(t *testing.T)
|
|||
if configuration.HTTPConfig.BaseResolver == nil {
|
||||
t.Fatal("not the BaseResolver we expected")
|
||||
}
|
||||
sr, ok := configuration.HTTPConfig.BaseResolver.(resolver.SerialResolver)
|
||||
sr, ok := configuration.HTTPConfig.BaseResolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -344,7 +344,7 @@ func TestConfigurerNewConfigurationResolverUDP(t *testing.T) {
|
|||
if configuration.HTTPConfig.BaseResolver == nil {
|
||||
t.Fatal("not the BaseResolver we expected")
|
||||
}
|
||||
sr, ok := configuration.HTTPConfig.BaseResolver.(resolver.SerialResolver)
|
||||
sr, ok := configuration.HTTPConfig.BaseResolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
|
|
@ -676,7 +676,7 @@ func TestNewDNSClientPowerdnsDoH(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -692,7 +692,7 @@ func TestNewDNSClientGoogleDoH(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -708,7 +708,7 @@ func TestNewDNSClientCloudflareDoH(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -725,7 +725,7 @@ func TestNewDNSClientCloudflareDoHSaver(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -745,7 +745,7 @@ func TestNewDNSClientUDP(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -762,7 +762,7 @@ func TestNewDNSClientUDPDNSSaver(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -782,7 +782,7 @@ func TestNewDNSClientTCP(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -803,7 +803,7 @@ func TestNewDNSClientTCPDNSSaver(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -827,7 +827,7 @@ func TestNewDNSClientDoT(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
@ -848,7 +848,7 @@ func TestNewDNSClientDoTDNSSaver(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r, ok := dnsclient.Resolver.(resolver.SerialResolver)
|
||||
r, ok := dnsclient.Resolver.(*resolver.SerialResolver)
|
||||
if !ok {
|
||||
t.Fatal("not the resolver we expected")
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ type Decoder interface {
|
|||
type MiekgDecoder struct{}
|
||||
|
||||
// Decode implements Decoder.Decode.
|
||||
func (d MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
|
||||
func (d *MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
|
||||
reply := new(dns.Msg)
|
||||
if err := reply.Unpack(data); err != nil {
|
||||
return nil, err
|
||||
|
@ -51,4 +51,4 @@ func (d MiekgDecoder) Decode(qtype uint16, data []byte) ([]string, error) {
|
|||
return addrs, nil
|
||||
}
|
||||
|
||||
var _ Decoder = MiekgDecoder{}
|
||||
var _ Decoder = &MiekgDecoder{}
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
package resolver_test
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestDecoderUnpackError(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error here")
|
||||
|
@ -20,8 +19,8 @@ func TestDecoderUnpackError(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDecoderNXDOMAIN(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeNameError))
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeNameError))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no such host") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
|
@ -31,8 +30,8 @@ func TestDecoderNXDOMAIN(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDecoderOtherError(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplyError(t, dns.RcodeRefused))
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, GenReplyError(t, dns.RcodeRefused))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "query failed") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
|
@ -42,8 +41,8 @@ func TestDecoderOtherError(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDecoderNoAddress(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA))
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(dns.TypeA, GenReplySuccess(t, dns.TypeA))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
|
@ -53,9 +52,9 @@ func TestDecoderNoAddress(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDecoderDecodeA(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
|
||||
dns.TypeA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.8.8"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -71,9 +70,9 @@ func TestDecoderDecodeA(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDecoderDecodeAAAA(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
dns.TypeAAAA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -89,9 +88,9 @@ func TestDecoderDecodeAAAA(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDecoderUnexpectedAReply(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeA, resolver.GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
dns.TypeA, GenReplySuccess(t, dns.TypeAAAA, "::1", "fe80::1"))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
|
@ -101,9 +100,9 @@ func TestDecoderUnexpectedAReply(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDecoderUnexpectedAAAAReply(t *testing.T) {
|
||||
d := resolver.MiekgDecoder{}
|
||||
d := &MiekgDecoder{}
|
||||
data, err := d.Decode(
|
||||
dns.TypeAAAA, resolver.GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
|
||||
dns.TypeAAAA, GenReplySuccess(t, dns.TypeA, "1.1.1.1", "8.8.4.4."))
|
||||
if err == nil || !strings.HasSuffix(err.Error(), "no response returned") {
|
||||
t.Fatal("not the error we expected")
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ const (
|
|||
)
|
||||
|
||||
// Encode implements Encoder.Encode
|
||||
func (e MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
func (e *MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte, error) {
|
||||
question := dns.Question{
|
||||
Name: dns.Fqdn(domain),
|
||||
Qtype: qtype,
|
||||
|
@ -49,4 +49,4 @@ func (e MiekgEncoder) Encode(domain string, qtype uint16, padding bool) ([]byte,
|
|||
return query.Pack()
|
||||
}
|
||||
|
||||
var _ Encoder = MiekgEncoder{}
|
||||
var _ Encoder = &MiekgEncoder{}
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
package resolver_test
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
|
||||
)
|
||||
|
||||
func TestEncoderEncodeA(t *testing.T) {
|
||||
e := resolver.MiekgEncoder{}
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -18,7 +17,7 @@ func TestEncoderEncodeA(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEncoderEncodeAAAA(t *testing.T) {
|
||||
e := resolver.MiekgEncoder{}
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode("x.org", dns.TypeAAAA, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -68,7 +67,7 @@ func TestEncoderPadding(t *testing.T) {
|
|||
// The purpose of this unit test is to make sure that for a wide
|
||||
// array of values we obtain the right query size.
|
||||
getquerylen := func(domainlen int, padding bool) int {
|
||||
e := resolver.MiekgEncoder{}
|
||||
e := &MiekgEncoder{}
|
||||
data, err := e.Encode(
|
||||
// This is not a valid name because it ends up being way
|
||||
// longer than 255 octets. However, the library is allowing
|
||||
|
@ -89,7 +88,7 @@ func TestEncoderPadding(t *testing.T) {
|
|||
if vanillalen < domainlen {
|
||||
t.Fatal("vanillalen is smaller than domainlen")
|
||||
}
|
||||
if (paddedlen % resolver.PaddingDesiredBlockSize) != 0 {
|
||||
if (paddedlen % PaddingDesiredBlockSize) != 0 {
|
||||
t.Fatal("paddedlen is not a multiple of PaddingDesiredBlockSize")
|
||||
}
|
||||
if paddedlen < vanillalen {
|
||||
|
|
|
@ -37,32 +37,37 @@ type SerialResolver struct {
|
|||
}
|
||||
|
||||
// NewSerialResolver creates a new OONI Resolver instance.
|
||||
func NewSerialResolver(t RoundTripper) SerialResolver {
|
||||
return SerialResolver{
|
||||
Encoder: MiekgEncoder{},
|
||||
Decoder: MiekgDecoder{},
|
||||
func NewSerialResolver(t RoundTripper) *SerialResolver {
|
||||
return &SerialResolver{
|
||||
Encoder: &MiekgEncoder{},
|
||||
Decoder: &MiekgDecoder{},
|
||||
NumTimeouts: &atomicx.Int64{},
|
||||
Txp: t,
|
||||
}
|
||||
}
|
||||
|
||||
// Transport returns the transport being used.
|
||||
func (r SerialResolver) Transport() RoundTripper {
|
||||
func (r *SerialResolver) Transport() RoundTripper {
|
||||
return r.Txp
|
||||
}
|
||||
|
||||
// Network implements Resolver.Network
|
||||
func (r SerialResolver) Network() string {
|
||||
func (r *SerialResolver) Network() string {
|
||||
return r.Txp.Network()
|
||||
}
|
||||
|
||||
// Address implements Resolver.Address
|
||||
func (r SerialResolver) Address() string {
|
||||
func (r *SerialResolver) Address() string {
|
||||
return r.Txp.Address()
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes idle connections.
|
||||
func (r *SerialResolver) CloseIdleConnections() {
|
||||
r.Txp.CloseIdleConnections()
|
||||
}
|
||||
|
||||
// LookupHost implements Resolver.LookupHost.
|
||||
func (r SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
func (r *SerialResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
|
||||
var addrs []string
|
||||
addrsA, errA := r.roundTripWithRetry(ctx, hostname, dns.TypeA)
|
||||
addrsAAAA, errAAAA := r.roundTripWithRetry(ctx, hostname, dns.TypeAAAA)
|
||||
|
@ -74,7 +79,7 @@ func (r SerialResolver) LookupHost(ctx context.Context, hostname string) ([]stri
|
|||
return addrs, nil
|
||||
}
|
||||
|
||||
func (r SerialResolver) roundTripWithRetry(
|
||||
func (r *SerialResolver) roundTripWithRetry(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
var errorslist []error
|
||||
for i := 0; i < 3; i++ {
|
||||
|
@ -100,7 +105,7 @@ func (r SerialResolver) roundTripWithRetry(
|
|||
return nil, errorslist[0]
|
||||
}
|
||||
|
||||
func (r SerialResolver) roundTrip(
|
||||
func (r *SerialResolver) roundTrip(
|
||||
ctx context.Context, hostname string, qtype uint16) ([]string, error) {
|
||||
querydata, err := r.Encoder.Encode(hostname, qtype, r.Txp.RequiresPadding())
|
||||
if err != nil {
|
||||
|
@ -113,4 +118,4 @@ func (r SerialResolver) roundTrip(
|
|||
return r.Decoder.Decode(qtype, replydata)
|
||||
}
|
||||
|
||||
var _ Resolver = SerialResolver{}
|
||||
var _ Resolver = &SerialResolver{}
|
||||
|
|
Loading…
Reference in New Issue
Block a user