ooni-probe-cli/internal/netxlite/dnsovertcp.go
Simone Basso 8f7e3803eb
feat(netxlite): implement DNSTransport wrapping (#776)
Acknowledge that transports MAY be used in isolation (i.e., outside
of a Resolver) and add support for wrapping.

Ensure that every factory that creates an unwrapped type is named
accordingly, so that there are hopefully no surprises.

Implement DNSTransport wrapping and use a technique similar to the
one used by Dialer to customize the DNSTransport while constructing
more complex data types (e.g., a specific resolver).

Ensure that the stdlib resolver's own "getaddrinfo" transport (1)
is wrapped and (2) can be extended during construction.

This work is part of my ongoing effort to bring into this repository
the netxlite-related changes from websteps-illustrated.

Ref issue: https://github.com/ooni/probe/issues/2096
2022-06-01 11:10:08 +02:00

134 lines
3.8 KiB
Go

package netxlite
//
// DNS-over-{TCP,TLS} transport
//
import (
"context"
"errors"
"io"
"math"
"net"
"time"
"github.com/ooni/probe-cli/v3/internal/model"
)
// DialContextFunc is a function type whose signature matches the one of
// net.Dialer.DialContext: given a context, a network, and an address, it
// returns a connected net.Conn or an error.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
// DNSOverTCPTransport implements model.DNSTransport using either
// DNS-over-TCP or DNS-over-TLS, depending on the factory that created it.
//
// Note: every query travels over a brand new connection, which is closed
// when RoundTrip returns. This is less efficient than reusing connections,
// but it MAY be more robust for cleartext TCP connections when querying
// for a blocked domain name causes endpoint blocking.
type DNSOverTCPTransport struct {
	// dial establishes the connection used by each query.
	dial DialContextFunc

	// decoder turns raw replies into model.DNSResponse values.
	decoder model.DNSDecoder

	// address is the server endpoint (e.g., 8.8.8.8:53). network is the
	// label returned by Network ("tcp" or "dot"); RoundTrip always dials
	// using "tcp" regardless of this value.
	address, network string

	// requiresPadding is the value returned by RequiresPadding.
	requiresPadding bool
}
// NewUnwrappedDNSOverTCPTransport returns a DNS-over-TCP transport (network
// "tcp", no padding) that is not wrapped by any other transport.
//
// Arguments:
//
// - dial has the same signature as net.Dialer.DialContext and is used to
// create a new connection for every query;
//
// - address is the server endpoint (e.g., 8.8.8.8:53).
func NewUnwrappedDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
	const (
		network = "tcp"
		padding = false // RFC8467: no padding for cleartext TCP
	)
	return newDNSOverTCPOrTLSTransport(dial, network, address, padding)
}
// NewUnwrappedDNSOverTLSTransport returns a DNS-over-TLS transport (network
// "dot", padding required) that is not wrapped by any other transport.
//
// Arguments:
//
// - dial has the same signature as net.Dialer.DialContext and is used to
// create a new connection for every query. NOTE(review): RoundTrip always
// calls dial with network "tcp", so presumably this dial function must
// perform the TLS handshake itself — confirm with callers;
//
// - address is the server endpoint (e.g., 8.8.8.8:853).
func NewUnwrappedDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
	const (
		network = "dot"
		padding = true // RFC8467: pad queries sent over DoT
	)
	return newDNSOverTCPOrTLSTransport(dial, network, address, padding)
}
// newDNSOverTCPOrTLSTransport is the factory shared by the TCP and TLS
// constructors. It stores the given dial function, network label, endpoint
// address, and padding flag, and always uses a DNSDecoderMiekg decoder.
func newDNSOverTCPOrTLSTransport(
	dial DialContextFunc, network, address string, padding bool) *DNSOverTCPTransport {
	txp := &DNSOverTCPTransport{
		network:         network,
		address:         address,
		dial:            dial,
		requiresPadding: padding,
		decoder:         &DNSDecoderMiekg{},
	}
	return txp
}
var (
	// errQueryTooLarge is returned by RoundTrip when the serialized query is
	// longer than math.MaxUint16 bytes and therefore cannot be described by
	// the two-byte length prefix used to frame messages.
	errQueryTooLarge error = errors.New("oodns: query too large for this transport")
)
// RoundTrip sends a query and receives a reply.
//
// The serialized query is prefixed with its length as two big-endian bytes
// and written over a freshly dialed connection. The reply is read back using
// the same framing, then decoded. The connection is closed before returning.
//
// Arguments:
//
// - ctx bounds the dial and, when it carries a deadline, also bounds the
// subsequent I/O, which is otherwise limited to ten seconds;
//
// - query is the DNS query to send.
//
// Returns the decoded response or an error. Queries whose serialized form
// is longer than math.MaxUint16 bytes fail with errQueryTooLarge.
func (t *DNSOverTCPTransport) RoundTrip(
	ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
	// TODO(bassosimone): this method should more strictly honour the context: its
	// deadline now also bounds the I/O, but cancellation is only observed while dialing
	rawQuery, err := query.Bytes()
	if err != nil {
		return nil, err
	}
	if len(rawQuery) > math.MaxUint16 {
		return nil, errQueryTooLarge
	}
	// Note: t.network is only a label ("tcp" or "dot"); we always dial "tcp".
	conn, err := t.dial(ctx, "tcp", t.address)
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	// Bound the I/O with the earlier of the default timeout and the context
	// deadline, so that we never exceed a shorter deadline set by the caller.
	const iotimeout = 10 * time.Second
	deadline := time.Now().Add(iotimeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	_ = conn.SetDeadline(deadline) // best effort, as in the original code
	// Write request: two-byte big-endian length prefix followed by the query.
	buf := make([]byte, 0, 2+len(rawQuery))
	buf = append(buf, byte(len(rawQuery)>>8), byte(len(rawQuery)))
	buf = append(buf, rawQuery...)
	if _, err = conn.Write(buf); err != nil {
		return nil, err
	}
	// Read response using the same length-prefixed framing.
	header := make([]byte, 2)
	if _, err = io.ReadFull(conn, header); err != nil {
		return nil, err
	}
	length := int(header[0])<<8 | int(header[1])
	rawResponse := make([]byte, length)
	if _, err = io.ReadFull(conn, rawResponse); err != nil {
		return nil, err
	}
	return t.decoder.DecodeResponse(rawResponse, query)
}
// Address returns the endpoint of the DNS server used by this
// transport (e.g., "1.1.1.1:853").
func (t *DNSOverTCPTransport) Address() string {
	return t.address
}

// Network returns the name of the transport network: "dot" for
// DNS-over-TLS and "tcp" for DNS-over-TCP.
func (t *DNSOverTCPTransport) Network() string {
	return t.network
}

// RequiresPadding reports whether queries sent through this transport
// should be padded: true for DNS-over-TLS and false for DNS-over-TCP,
// following RFC8467.
func (t *DNSOverTCPTransport) RequiresPadding() bool {
	return t.requiresPadding
}
// CloseIdleConnections is a no-op: RoundTrip dials a new connection for each
// query and closes it before returning, so no idle connection is ever kept.
func (t *DNSOverTCPTransport) CloseIdleConnections() {}

// Compile-time check that *DNSOverTCPTransport implements model.DNSTransport.
var _ model.DNSTransport = (*DNSOverTCPTransport)(nil)