package dialer import ( "context" "net" "github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/model" ) // byteCounterDialer is a byte-counting-aware dialer. To perform byte counting, you // should make sure that you insert this dialer in the dialing chain. type byteCounterDialer struct { model.Dialer } // DialContext implements Dialer.DialContext func (d *byteCounterDialer) DialContext( ctx context.Context, network, address string) (net.Conn, error) { conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { return nil, err } if exp := contextExperimentByteCounter(ctx); exp != nil { conn = &bytecounter.Conn{Conn: conn, Counter: exp} } if sess := contextSessionByteCounter(ctx); sess != nil { conn = &bytecounter.Conn{Conn: conn, Counter: sess} } return conn, nil } type byteCounterSessionKey struct{} // contextSessionByteCounter retrieves the session byte counter from the context func contextSessionByteCounter(ctx context.Context) *bytecounter.Counter { counter, _ := ctx.Value(byteCounterSessionKey{}).(*bytecounter.Counter) return counter } // WithSessionByteCounter assigns the session byte counter to the context. func WithSessionByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context { return context.WithValue(ctx, byteCounterSessionKey{}, counter) } type byteCounterExperimentKey struct{} // contextExperimentByteCounter retrieves the experiment byte counter from the context func contextExperimentByteCounter(ctx context.Context) *bytecounter.Counter { counter, _ := ctx.Value(byteCounterExperimentKey{}).(*bytecounter.Counter) return counter } // WithExperimentByteCounter assigns the experiment byte counter to the context. func WithExperimentByteCounter(ctx context.Context, counter *bytecounter.Counter) context.Context { return context.WithValue(ctx, byteCounterExperimentKey{}, counter) }