8f7e3803eb
Acknowledge that transports MAY be used in isolation (i.e., outside of a Resolver) and add support for wrapping. Ensure that every factory that creates an unwrapped type is named accordingly to hopefully ensure there are no surprises. Implement DNSTransport wrapping and use a technique similar to the one used by Dialer to customize the DNSTransport while constructing more complex data types (e.g., a specific resolver). Ensure that the stdlib resolver's own "getaddrinfo" transport (1) is wrapped and (2) could be extended during construction. This work is part of my ongoing effort to bring to this repository websteps-illustrated changes relative to netxlite. Ref issue: https://github.com/ooni/probe/issues/2096
134 lines
3.8 KiB
Go
134 lines
3.8 KiB
Go
package netxlite
|
|
|
|
//
|
|
// DNS-over-{TCP,TLS} transport
|
|
//
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"math"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/ooni/probe-cli/v3/internal/model"
|
|
)
|
|
|
|
// DialContextFunc has the same signature as net.Dialer.DialContext: given a
// context, a network (e.g., "tcp") and an address (e.g., "8.8.8.8:53"), it
// returns either a connected net.Conn or an error.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
// DNSOverTCPTransport is a DNS-over-{TCP,TLS} DNSTransport.
|
|
//
|
|
// Note: this implementation always creates a new connection for each query. This
|
|
// strategy is less efficient but MAY be more robust for cleartext TCP connections
|
|
// when querying for a blocked domain name causes endpoint blocking.
|
|
type DNSOverTCPTransport struct {
|
|
dial DialContextFunc
|
|
decoder model.DNSDecoder
|
|
address string
|
|
network string
|
|
requiresPadding bool
|
|
}
|
|
|
|
// NewUnwrappedDNSOverTCPTransport creates a new DNSOverTCPTransport
|
|
// that has not been wrapped yet.
|
|
//
|
|
// Arguments:
|
|
//
|
|
// - dial is a function with the net.Dialer.DialContext's signature;
|
|
//
|
|
// - address is the endpoint address (e.g., 8.8.8.8:53).
|
|
func NewUnwrappedDNSOverTCPTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
|
return newDNSOverTCPOrTLSTransport(dial, "tcp", address, false)
|
|
}
|
|
|
|
// NewUnwrappedDNSOverTLSTransport creates a new DNSOverTLS transport
|
|
// that has not been wrapped yet.
|
|
//
|
|
// Arguments:
|
|
//
|
|
// - dial is a function with the net.Dialer.DialContext's signature;
|
|
//
|
|
// - address is the endpoint address (e.g., 8.8.8.8:853).
|
|
func NewUnwrappedDNSOverTLSTransport(dial DialContextFunc, address string) *DNSOverTCPTransport {
|
|
return newDNSOverTCPOrTLSTransport(dial, "dot", address, true)
|
|
}
|
|
|
|
// newDNSOverTCPOrTLSTransport is the common factory for creating a transport
|
|
func newDNSOverTCPOrTLSTransport(
|
|
dial DialContextFunc, network, address string, padding bool) *DNSOverTCPTransport {
|
|
return &DNSOverTCPTransport{
|
|
dial: dial,
|
|
decoder: &DNSDecoderMiekg{},
|
|
address: address,
|
|
network: network,
|
|
requiresPadding: padding,
|
|
}
|
|
}
|
|
|
|
var (
	// errQueryTooLarge is returned by RoundTrip when the serialized query is
	// longer than math.MaxUint16 bytes and therefore cannot be described by
	// the two-byte length prefix that precedes each message on the stream.
	errQueryTooLarge = errors.New("oodns: query too large for this transport")
)
|
|
|
|
// RoundTrip sends a query and receives a reply.
|
|
func (t *DNSOverTCPTransport) RoundTrip(
|
|
ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
|
|
// TODO(bassosimone): this method should more strictly honour the context, which
|
|
// currently is only used to bound the dial operation
|
|
rawQuery, err := query.Bytes()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(rawQuery) > math.MaxUint16 {
|
|
return nil, errQueryTooLarge
|
|
}
|
|
conn, err := t.dial(ctx, "tcp", t.address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer conn.Close()
|
|
const iotimeout = 10 * time.Second
|
|
conn.SetDeadline(time.Now().Add(iotimeout))
|
|
// Write request
|
|
buf := []byte{byte(len(rawQuery) >> 8)}
|
|
buf = append(buf, byte(len(rawQuery)))
|
|
buf = append(buf, rawQuery...)
|
|
if _, err = conn.Write(buf); err != nil {
|
|
return nil, err
|
|
}
|
|
// Read response
|
|
header := make([]byte, 2)
|
|
if _, err = io.ReadFull(conn, header); err != nil {
|
|
return nil, err
|
|
}
|
|
length := int(header[0])<<8 | int(header[1])
|
|
rawResponse := make([]byte, length)
|
|
if _, err = io.ReadFull(conn, rawResponse); err != nil {
|
|
return nil, err
|
|
}
|
|
return t.decoder.DecodeResponse(rawResponse, query)
|
|
}
|
|
|
|
// RequiresPadding returns true for DoT and false for TCP
|
|
// according to RFC8467.
|
|
func (t *DNSOverTCPTransport) RequiresPadding() bool {
|
|
return t.requiresPadding
|
|
}
|
|
|
|
// Network returns the transport network, i.e., "dot" or "tcp".
|
|
func (t *DNSOverTCPTransport) Network() string {
|
|
return t.network
|
|
}
|
|
|
|
// Address returns the upstream server endpoint (e.g., "1.1.1.1:853").
|
|
func (t *DNSOverTCPTransport) Address() string {
|
|
return t.address
|
|
}
|
|
|
|
// CloseIdleConnections closes idle connections, if any.
|
|
func (t *DNSOverTCPTransport) CloseIdleConnections() {
|
|
// nothing to do
|
|
}
|
|
|
|
// Compile-time check that *DNSOverTCPTransport implements model.DNSTransport.
var _ model.DNSTransport = (*DNSOverTCPTransport)(nil)
|