3cb782f0a2
While there, modernize the way in which we run tests to avoid depending on the fake files scattered around the tree, and to use some well-defined mock structures instead. Part of https://github.com/ooni/probe/issues/1591
103 lines
2.5 KiB
Go
103 lines
2.5 KiB
Go
package dnsx
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"math"
|
|
"net"
|
|
"time"
|
|
)
|
|
|
|
// DialContextFunc opens a connection to address over the given network.
// Its signature is the same as (*net.Dialer).DialContext, so a dialer's
// method value can be used directly.
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
// DNSOverTCP is a DNS over TCP/TLS RoundTripper. Use NewDNSOverTCP
|
|
// and NewDNSOverTLS to create specific instances that use plaintext
|
|
// queries or encrypted queries over TLS.
|
|
//
|
|
// As a known bug, this implementation always creates a new connection
|
|
// for each incoming query, thus increasing the response delay.
|
|
type DNSOverTCP struct {
|
|
dial DialContextFunc
|
|
address string
|
|
network string
|
|
requiresPadding bool
|
|
}
|
|
|
|
// NewDNSOverTCP creates a new DNSOverTCP transport.
|
|
func NewDNSOverTCP(dial DialContextFunc, address string) *DNSOverTCP {
|
|
return &DNSOverTCP{
|
|
dial: dial,
|
|
address: address,
|
|
network: "tcp",
|
|
requiresPadding: false,
|
|
}
|
|
}
|
|
|
|
// NewDNSOverTLS creates a new DNSOverTLS transport.
|
|
func NewDNSOverTLS(dial DialContextFunc, address string) *DNSOverTCP {
|
|
return &DNSOverTCP{
|
|
dial: dial,
|
|
address: address,
|
|
network: "dot",
|
|
requiresPadding: true,
|
|
}
|
|
}
|
|
|
|
// RoundTrip implements RoundTripper.RoundTrip.
|
|
func (t *DNSOverTCP) RoundTrip(ctx context.Context, query []byte) ([]byte, error) {
|
|
if len(query) > math.MaxUint16 {
|
|
return nil, errors.New("query too long")
|
|
}
|
|
conn, err := t.dial(ctx, "tcp", t.address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer conn.Close()
|
|
if err = conn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
|
return nil, err
|
|
}
|
|
// Write request
|
|
buf := []byte{byte(len(query) >> 8)}
|
|
buf = append(buf, byte(len(query)))
|
|
buf = append(buf, query...)
|
|
if _, err = conn.Write(buf); err != nil {
|
|
return nil, err
|
|
}
|
|
// Read response
|
|
header := make([]byte, 2)
|
|
if _, err = io.ReadFull(conn, header); err != nil {
|
|
return nil, err
|
|
}
|
|
length := int(header[0])<<8 | int(header[1])
|
|
reply := make([]byte, length)
|
|
if _, err = io.ReadFull(conn, reply); err != nil {
|
|
return nil, err
|
|
}
|
|
return reply, nil
|
|
}
|
|
|
|
// RequiresPadding returns true for DoT and false for TCP
|
|
// according to RFC8467.
|
|
func (t *DNSOverTCP) RequiresPadding() bool {
|
|
return t.requiresPadding
|
|
}
|
|
|
|
// Network returns the transport network (e.g., doh, dot)
|
|
func (t *DNSOverTCP) Network() string {
|
|
return t.network
|
|
}
|
|
|
|
// Address returns the upstream server address.
|
|
func (t *DNSOverTCP) Address() string {
|
|
return t.address
|
|
}
|
|
|
|
// CloseIdleConnections closes idle connections.
|
|
func (t *DNSOverTCP) CloseIdleConnections() {
|
|
// nothing to do
|
|
}
|
|
|
|
// Compile-time check that *DNSOverTCP satisfies RoundTripper.
var _ RoundTripper = (*DNSOverTCP)(nil)
|