refactor: allow automatically wrap net/quic conn (#867)

See https://github.com/ooni/probe/issues/2219
This commit is contained in:
DecFox 2022-08-18 00:28:06 +05:30 committed by GitHub
parent e1d014e826
commit 097926c51f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 83 additions and 101 deletions

View File

@ -172,8 +172,7 @@ func (m *Measurer) quicHandshake(ctx context.Context, index int64,
alpn := strings.Split(m.config.alpn(), " ") alpn := strings.Split(m.config.alpn(), " ")
trace := measurexlite.NewTrace(index, zeroTime) trace := measurexlite.NewTrace(index, zeroTime)
ol := measurexlite.NewOperationLogger(logger, "SimpleQUICPing #%d %s %s %v", index, address, sni, alpn) ol := measurexlite.NewOperationLogger(logger, "SimpleQUICPing #%d %s %s %v", index, address, sni, alpn)
quicListener := netxlite.NewQUICListener() listener := netxlite.NewQUICListener()
listener := trace.WrapQUICListener(quicListener)
dialer := trace.NewQUICDialerWithoutResolver(listener, logger) dialer := trace.NewQUICDialerWithoutResolver(listener, logger)
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: alpn, NextProtos: alpn,

View File

@ -182,7 +182,6 @@ func (m *Measurer) tlsConnectAndHandshake(ctx context.Context, index int64,
return sp return sp
} }
defer conn.Close() defer conn.Close()
conn = trace.WrapNetConn(conn)
thx := trace.NewTLSHandshakerStdlib(logger) thx := trace.NewTLSHandshakerStdlib(logger)
config := &tls.Config{ config := &tls.Config{
NextProtos: alpn, NextProtos: alpn,

View File

@ -21,8 +21,8 @@ func MaybeClose(conn net.Conn) (err error) {
return return
} }
// WrapNetConn returns a wrapped conn that saves network events into this trace. // MaybeWrapNetConn implements model.Trace.MaybeWrapNetConn.
func (tx *Trace) WrapNetConn(conn net.Conn) net.Conn { func (tx *Trace) MaybeWrapNetConn(conn net.Conn) net.Conn {
return &connTrace{ return &connTrace{
Conn: conn, Conn: conn,
tx: tx, tx: tx,
@ -77,8 +77,8 @@ func MaybeCloseUDPLikeConn(conn model.UDPLikeConn) (err error) {
return return
} }
// WrapUDPLikeConn returns a wrapped conn that saves network events into this trace. // MaybeWrapUDPLikeConn implements model.Trace.MaybeWrapUDPLikeConn.
func (tx *Trace) WrapUDPLikeConn(conn model.UDPLikeConn) model.UDPLikeConn { func (tx *Trace) MaybeWrapUDPLikeConn(conn model.UDPLikeConn) model.UDPLikeConn {
return &udpLikeConnTrace{ return &udpLikeConnTrace{
UDPLikeConn: conn, UDPLikeConn: conn,
tx: tx, tx: tx,

View File

@ -40,7 +40,7 @@ func TestWrapNetConn(t *testing.T) {
underlying := &mocks.Conn{} underlying := &mocks.Conn{}
zeroTime := time.Now() zeroTime := time.Now()
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
conn := trace.WrapNetConn(underlying) conn := trace.MaybeWrapNetConn(underlying)
ct := conn.(*connTrace) ct := conn.(*connTrace)
if ct.Conn != underlying { if ct.Conn != underlying {
t.Fatal("invalid underlying") t.Fatal("invalid underlying")
@ -70,7 +70,7 @@ func TestWrapNetConn(t *testing.T) {
td := testingx.NewTimeDeterministic(zeroTime) td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time counting trace.TimeNowFn = td.Now // deterministic time counting
conn := trace.WrapNetConn(underlying) conn := trace.MaybeWrapNetConn(underlying)
const bufsiz = 128 const bufsiz = 128
buffer := make([]byte, bufsiz) buffer := make([]byte, bufsiz)
count, err := conn.Read(buffer) count, err := conn.Read(buffer)
@ -118,7 +118,7 @@ func TestWrapNetConn(t *testing.T) {
zeroTime := time.Now() zeroTime := time.Now()
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
conn := trace.WrapNetConn(underlying) conn := trace.MaybeWrapNetConn(underlying)
const bufsiz = 128 const bufsiz = 128
buffer := make([]byte, bufsiz) buffer := make([]byte, bufsiz)
count, err := conn.Read(buffer) count, err := conn.Read(buffer)
@ -154,7 +154,7 @@ func TestWrapNetConn(t *testing.T) {
td := testingx.NewTimeDeterministic(zeroTime) td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time tracking trace.TimeNowFn = td.Now // deterministic time tracking
conn := trace.WrapNetConn(underlying) conn := trace.MaybeWrapNetConn(underlying)
const bufsiz = 128 const bufsiz = 128
buffer := make([]byte, bufsiz) buffer := make([]byte, bufsiz)
count, err := conn.Write(buffer) count, err := conn.Write(buffer)
@ -202,7 +202,7 @@ func TestWrapNetConn(t *testing.T) {
zeroTime := time.Now() zeroTime := time.Now()
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
conn := trace.WrapNetConn(underlying) conn := trace.MaybeWrapNetConn(underlying)
const bufsiz = 128 const bufsiz = 128
buffer := make([]byte, bufsiz) buffer := make([]byte, bufsiz)
count, err := conn.Write(buffer) count, err := conn.Write(buffer)
@ -224,7 +224,7 @@ func TestWrapUDPLikeConn(t *testing.T) {
underlying := &mocks.UDPLikeConn{} underlying := &mocks.UDPLikeConn{}
zeroTime := time.Now() zeroTime := time.Now()
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
conn := trace.WrapUDPLikeConn(underlying) conn := trace.MaybeWrapUDPLikeConn(underlying)
ct := conn.(*udpLikeConnTrace) ct := conn.(*udpLikeConnTrace)
if ct.UDPLikeConn != underlying { if ct.UDPLikeConn != underlying {
t.Fatal("invalid underlying") t.Fatal("invalid underlying")
@ -248,7 +248,7 @@ func TestWrapUDPLikeConn(t *testing.T) {
td := testingx.NewTimeDeterministic(zeroTime) td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time counting trace.TimeNowFn = td.Now // deterministic time counting
conn := trace.WrapUDPLikeConn(underlying) conn := trace.MaybeWrapUDPLikeConn(underlying)
const bufsiz = 128 const bufsiz = 128
buffer := make([]byte, bufsiz) buffer := make([]byte, bufsiz)
count, addr, err := conn.ReadFrom(buffer) count, addr, err := conn.ReadFrom(buffer)
@ -293,7 +293,7 @@ func TestWrapUDPLikeConn(t *testing.T) {
zeroTime := time.Now() zeroTime := time.Now()
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
conn := trace.WrapUDPLikeConn(underlying) conn := trace.MaybeWrapUDPLikeConn(underlying)
const bufsiz = 128 const bufsiz = 128
buffer := make([]byte, bufsiz) buffer := make([]byte, bufsiz)
count, addr, err := conn.ReadFrom(buffer) count, addr, err := conn.ReadFrom(buffer)
@ -322,7 +322,7 @@ func TestWrapUDPLikeConn(t *testing.T) {
td := testingx.NewTimeDeterministic(zeroTime) td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now // deterministic time tracking trace.TimeNowFn = td.Now // deterministic time tracking
conn := trace.WrapUDPLikeConn(underlying) conn := trace.MaybeWrapUDPLikeConn(underlying)
const bufsiz = 128 const bufsiz = 128
buffer := make([]byte, bufsiz) buffer := make([]byte, bufsiz)
addr := &mocks.Addr{ addr := &mocks.Addr{
@ -365,7 +365,7 @@ func TestWrapUDPLikeConn(t *testing.T) {
zeroTime := time.Now() zeroTime := time.Now()
trace := NewTrace(0, zeroTime) trace := NewTrace(0, zeroTime)
trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer trace.networkEvent = make(chan *model.ArchivalNetworkEvent) // no buffer
conn := trace.WrapUDPLikeConn(underlying) conn := trace.MaybeWrapUDPLikeConn(underlying)
const bufsiz = 128 const bufsiz = 128
buffer := make([]byte, bufsiz) buffer := make([]byte, bufsiz)
addr := &mocks.Addr{ addr := &mocks.Addr{

View File

@ -7,7 +7,6 @@ package measurexlite
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"net"
"time" "time"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
@ -15,29 +14,6 @@ import (
"github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/netxlite"
) )
// WrapQUICListener returns a wrapped model.QUICListener that uses this trace.
func (tx *Trace) WrapQUICListener(listener model.QUICListener) model.QUICListener {
return &quicListenerTrace{
QUICListener: listener,
tx: tx,
}
}
// quicListenerTrace is a trace-aware QUIC listener.
type quicListenerTrace struct {
model.QUICListener
tx *Trace
}
// Listen implements model.QUICListener.Listen
func (ql *quicListenerTrace) Listen(addr *net.UDPAddr) (model.UDPLikeConn, error) {
pconn, err := ql.QUICListener.Listen(addr)
if err != nil {
return nil, err
}
return ql.tx.WrapUDPLikeConn(pconn), nil
}
// NewQUICDialerWithoutResolver is equivalent to netxlite.NewQUICDialerWithoutResolver // NewQUICDialerWithoutResolver is equivalent to netxlite.NewQUICDialerWithoutResolver
// except that it returns a model.QUICDialer that uses this trace. // except that it returns a model.QUICDialer that uses this trace.
func (tx *Trace) NewQUICDialerWithoutResolver(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer { func (tx *Trace) NewQUICDialerWithoutResolver(listener model.QUICListener, dl model.DebugLogger) model.QUICDialer {

View File

@ -17,65 +17,6 @@ import (
"github.com/ooni/probe-cli/v3/internal/testingx" "github.com/ooni/probe-cli/v3/internal/testingx"
) )
func TestNewQUICListener(t *testing.T) {
t.Run("NewQUICListenerTrace creates a wrapped listener", func(t *testing.T) {
underlying := &mocks.QUICListener{}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
listenert := trace.WrapQUICListener(underlying).(*quicListenerTrace)
if listenert.QUICListener != underlying {
t.Fatal("invalid quic dialer")
}
if listenert.tx != trace {
t.Fatal("invalid trace")
}
})
t.Run("Listen works as intended", func(t *testing.T) {
t.Run("with error", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
mockedErr := errors.New("mocked")
mockListener := &mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
return nil, mockedErr
},
}
listener := trace.WrapQUICListener(mockListener)
pconn, err := listener.Listen(&net.UDPAddr{})
if !errors.Is(err, mockedErr) {
t.Fatal("unexpected err", err)
}
if pconn != nil {
t.Fatal("expected nil conn")
}
})
t.Run("without error", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
mockConn := &mocks.UDPLikeConn{}
mockListener := &mocks.QUICListener{
MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) {
return mockConn, nil
},
}
listener := trace.WrapQUICListener(mockListener)
pconn, err := listener.Listen(&net.UDPAddr{})
if err != nil {
t.Fatal("unexpected err", err)
}
conn := pconn.(*udpLikeConnTrace)
if conn.UDPLikeConn != mockConn {
t.Fatal("invalid conn")
}
if conn.tx != trace {
t.Fatal("invalid trace")
}
})
})
}
func TestNewQUICDialerWithoutResolver(t *testing.T) { func TestNewQUICDialerWithoutResolver(t *testing.T) {
t.Run("NewQUICDialerWithoutResolver creates a wrapped dialer", func(t *testing.T) { t.Run("NewQUICDialerWithoutResolver creates a wrapped dialer", func(t *testing.T) {
underlying := &mocks.QUICDialer{} underlying := &mocks.QUICDialer{}

View File

@ -6,6 +6,7 @@ package mocks
import ( import (
"crypto/tls" "crypto/tls"
"net"
"time" "time"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
@ -16,6 +17,10 @@ import (
type Trace struct { type Trace struct {
MockTimeNow func() time.Time MockTimeNow func() time.Time
MockMaybeWrapNetConn func(conn net.Conn) net.Conn
MockMaybeWrapUDPLikeConn func(conn model.UDPLikeConn) model.UDPLikeConn
MockOnDNSRoundTripForLookupHost func(started time.Time, reso model.Resolver, query model.DNSQuery, MockOnDNSRoundTripForLookupHost func(started time.Time, reso model.Resolver, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Time) response model.DNSResponse, addrs []string, err error, finished time.Time)
@ -39,6 +44,14 @@ func (t *Trace) TimeNow() time.Time {
return t.MockTimeNow() return t.MockTimeNow()
} }
func (t *Trace) MaybeWrapNetConn(conn net.Conn) net.Conn {
return t.MockMaybeWrapNetConn(conn)
}
func (t *Trace) MaybeWrapUDPLikeConn(conn model.UDPLikeConn) model.UDPLikeConn {
return t.MockMaybeWrapUDPLikeConn(conn)
}
func (t *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery, func (t *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Time) { response model.DNSResponse, addrs []string, err error, finished time.Time) {
t.MockOnDNSRoundTripForLookupHost(started, reso, query, response, addrs, err, finished) t.MockOnDNSRoundTripForLookupHost(started, reso, query, response, addrs, err, finished)

View File

@ -2,6 +2,7 @@ package mocks
import ( import (
"crypto/tls" "crypto/tls"
"net"
"testing" "testing"
"time" "time"
@ -22,6 +23,32 @@ func TestTrace(t *testing.T) {
} }
}) })
t.Run("MaybeWrapNetConn", func(t *testing.T) {
expect := &Conn{}
tx := &Trace{
MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
return expect
},
}
got := tx.MaybeWrapNetConn(&Conn{})
if got != expect {
t.Fatal("not working as intended")
}
})
t.Run("MaybeWrapUDPLikeConn", func(t *testing.T) {
expect := &UDPLikeConn{}
tx := &Trace{
MockMaybeWrapUDPLikeConn: func(conn model.UDPLikeConn) model.UDPLikeConn {
return expect
},
}
got := tx.MaybeWrapUDPLikeConn(&UDPLikeConn{})
if got != expect {
t.Fatal("not working as intended")
}
})
t.Run("OnDNSRoundTripForLookupHost", func(t *testing.T) { t.Run("OnDNSRoundTripForLookupHost", func(t *testing.T) {
var called bool var called bool
tx := &Trace{ tx := &Trace{

View File

@ -303,6 +303,21 @@ type Trace interface {
// can use functionality exported by the ./internal/testingx pkg. // can use functionality exported by the ./internal/testingx pkg.
TimeNow() time.Time TimeNow() time.Time
// MaybeWrapNetConn possibly wraps a net.Conn with the caller trace. If there's no
// desire to wrap the net.Conn, this function just returns the original net.Conn.
//
// Arguments:
//
// - conn is the non-nil underlying net.Conn to be wrapped
MaybeWrapNetConn(conn net.Conn) net.Conn
// MaybeWrapUDPLikeConn is like MaybeWrapNetConn but for UDPLikeConn.
//
// Arguments:
//
// - conn is the non-nil underlying UDPLikeConn to be wrapped
MaybeWrapUDPLikeConn(conn UDPLikeConn) UDPLikeConn
// OnDNSRoundTripForLookupHost is used with a DNSTransport and called // OnDNSRoundTripForLookupHost is used with a DNSTransport and called
// when the RoundTrip terminates. // when the RoundTrip terminates.
// //

View File

@ -224,7 +224,7 @@ func (d *dialerResolverWithTracing) DialContext(ctx context.Context, network, ad
trace.OnConnectDone(started, network, onlyhost, target, err, finished) trace.OnConnectDone(started, network, onlyhost, target, err, finished)
if err == nil { if err == nil {
conn = &dialerErrWrapperConn{conn} conn = &dialerErrWrapperConn{conn}
return conn, nil return trace.MaybeWrapNetConn(conn), nil
} }
errorslist = append(errorslist, err) errorslist = append(errorslist, err)
} }

View File

@ -135,6 +135,7 @@ func (d *quicDialerQUICGo) DialContext(ctx context.Context, network string,
} }
tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, udpAddr.Port) tlsConfig = d.maybeApplyTLSDefaults(tlsConfig, udpAddr.Port)
trace := ContextTraceOrDefault(ctx) trace := ContextTraceOrDefault(ctx)
pconn = trace.MaybeWrapUDPLikeConn(pconn)
started := trace.TimeNow() started := trace.TimeNow()
trace.OnQUICHandshakeStart(started, address, quicConfig) trace.OnQUICHandshakeStart(started, address, quicConfig)
qconn, err := d.dialEarlyContext( qconn, err := d.dialEarlyContext(

View File

@ -7,6 +7,7 @@ package netxlite
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"net"
"time" "time"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
@ -50,6 +51,16 @@ func (*traceDefault) TimeNow() time.Time {
return time.Now() return time.Now()
} }
// MaybeWrapNetConn implements model.Trace.MaybeWrapNetConn
func (*traceDefault) MaybeWrapNetConn(conn net.Conn) net.Conn {
return conn
}
// MaybeWrapUDPLikeConn implements model.Trace.MaybeWrapUDPLikeConn
func (*traceDefault) MaybeWrapUDPLikeConn(conn model.UDPLikeConn) model.UDPLikeConn {
return conn
}
// OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost. // OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost.
func (*traceDefault) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery, func (*traceDefault) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Time) { response model.DNSResponse, addrs []string, err error, finished time.Time) {