// ooni-probe-cli/internal/netxlite/dnsx/dnsovertcp.go
// Simone Basso 3cb782f0a2
// refactor(netx): move dns transports in netxlite/dnsx (#503)
// While there, modernize the way in which we run tests to avoid
// depending on the fake files scattered around the tree and to
// use some well defined mock structures instead.
//
// Part of https://github.com/ooni/probe/issues/1591
// 2021-09-09 21:24:27 +02:00
//
// 103 lines
// 2.5 KiB
// Go

package dnsx
import (
"context"
"errors"
"io"
"math"
"net"
"time"
)
// DialContextFunc opens a connection to address over the given network,
// with ctx governing the dial itself. Its signature matches
// (*net.Dialer).DialContext, so that method value can be used directly.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
// DNSOverTCP is a RoundTripper that carries DNS queries over a stream
// connection, either plaintext TCP or TLS (DoT). Build instances with
// NewDNSOverTCP or NewDNSOverTLS rather than with a composite literal.
//
// Known limitation: every query dials a brand new connection, which adds
// connection-setup latency to each response.
type DNSOverTCP struct {
	// dial opens the connection used by each query.
	dial DialContextFunc

	// address is the upstream server address; network is the label
	// reported by Network ("tcp" or "dot", depending on the constructor).
	address, network string

	// requiresPadding is the value reported by RequiresPadding.
	requiresPadding bool
}
// NewDNSOverTCP returns a transport for DNS over plaintext TCP that sends
// queries to address using connections opened by dial. The returned
// transport reports "tcp" from Network and false from RequiresPadding.
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
	txp := new(DNSOverTCP)
	txp.dial = dial
	txp.address = address
	txp.network = "tcp"
	txp.requiresPadding = false
	return txp
}
// NewDNSOverTLS returns a DNSOverTCP configured for DNS over TLS (DoT):
// it sends queries to address using connections opened by dial, reports
// "dot" from Network and true from RequiresPadding. This type performs no
// TLS handshake of its own, so dial is expected to return connections that
// are already secured.
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
	var txp DNSOverTCP
	txp.dial, txp.address = dial, address
	txp.network, txp.requiresPadding = "dot", true
	return &txp
}
// RoundTrip implements RoundTripper.RoundTrip.
//
// Every call dials a fresh connection to t.address, always passing "tcp"
// as the dial network (t.network only labels the transport and is never
// given to the dialer). It writes query preceded by its length as a
// two-byte big-endian integer, reads back a reply framed the same way, and
// closes the connection before returning.
//
// The context is passed to the dialer. I/O on the connection is bounded by
// a ten-second timeout or by the context deadline, whichever comes first.
// Cancelling a context that has no deadline does not interrupt I/O that is
// already in progress.
//
// It returns an error when query is longer than math.MaxUint16 bytes (the
// largest length the prefix can encode), or when dialing, setting the
// deadline, writing, or reading fails.
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
	if len(query) > math.MaxUint16 {
		return nil, errors.New("query too long")
	}
	conn, err := t.dial(ctx, "tcp", t.address)
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	// Bound all I/O so that a silent server cannot block us indefinitely,
	// and never let the I/O outlive the caller's own deadline.
	const ioTimeout = 10 * time.Second
	deadline := time.Now().Add(ioTimeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err = conn.SetDeadline(deadline); err != nil {
		return nil, err
	}
	// Write request: two-byte big-endian length followed by the query,
	// assembled in one allocation and sent with a single Write.
	buf := make([]byte, 2, 2+len(query))
	buf[0] = byte(len(query) >> 8)
	buf[1] = byte(len(query))
	buf = append(buf, query...)
	if _, err = conn.Write(buf); err != nil {
		return nil, err
	}
	// Read response: two-byte big-endian length, then exactly that many bytes.
	header := make([]byte, 2)
	if _, err = io.ReadFull(conn, header); err != nil {
		return nil, err
	}
	length := int(header[0])<<8 | int(header[1])
	reply := make([]byte, length)
	if _, err = io.ReadFull(conn, reply); err != nil {
		return nil, err
	}
	return reply, nil
}
// RequiresPadding reports whether queries sent over this transport should
// be padded, per RFC 8467: true for transports built by NewDNSOverTLS,
// false for those built by NewDNSOverTCP.
func (t *DNSOverTCP) RequiresPadding() (padded bool) {
	padded = t.requiresPadding
	return
}
// Network returns the label of this transport's protocol: "tcp" for
// transports built by NewDNSOverTCP and "dot" for those built by
// NewDNSOverTLS. It is a label only; RoundTrip always dials "tcp".
func (t *DNSOverTCP) Network() (label string) {
	label = t.network
	return
}
// Address returns the upstream server address that was passed to the
// constructor and that RoundTrip dials for every query.
func (t *DNSOverTCP) Address() (server string) {
	server = t.address
	return
}
// CloseIdleConnections is a no-op. RoundTrip closes its connection before
// returning, so this transport never holds idle connections.
func (*DNSOverTCP) CloseIdleConnections() {}
// Compile-time check that *DNSOverTCP satisfies RoundTripper.
var _ RoundTripper = (*DNSOverTCP)(nil)