diff --git a/internal/engine/experiment/smtp/smtp.go b/internal/engine/experiment/smtp/smtp.go index 09d3484..50b66af 100644 --- a/internal/engine/experiment/smtp/smtp.go +++ b/internal/engine/experiment/smtp/smtp.go @@ -92,7 +92,7 @@ type TestKeys struct { Queries []*model.ArchivalDNSLookupResult `json:"queries"` TCPConnect []*model.ArchivalTCPConnectResult `json:"tcp_connect"` TLSHandshakes []*model.ArchivalTLSOrQUICHandshakeResult `json:"tls_handshakes"` - SMTPErrors map[string][]*string `json:"smtp"` + Errors map[string][]*string `json:"smtp"` NoOpCounter uint8 `json:"successful_noops"` // Used for global failure (DNS resolution) Failure string `json:"failure"` @@ -117,9 +117,7 @@ func (m Measurer) ExperimentVersion() string { return testVersion } -// Manages sequential SMTP sessions to the same hostname (over different IPs) -// don't use in parallel! -type SMTPRunner struct { +type TCPRunner struct { trace *measurexlite.Trace logger model.Logger ctx context.Context @@ -127,17 +125,84 @@ type SMTPRunner struct { tlsconfig *tls.Config host string port string - // addr is changed everytime SMTPRunner.conn(addr) is called - addr string } -func (r SMTPRunner) smtp_error(err error) { - key := net.JoinHostPort(r.addr, r.port) - // Key is initialized in conn() no need to check here - r.tk.SMTPErrors[key] = append(r.tk.SMTPErrors[key], tracex.NewFailure(err)) +type TCPSession struct { + addr string + port string + runner *TCPRunner + errors []*string + tls bool + raw_conn *net.Conn + tls_conn *net.Conn } -func (r SMTPRunner) resolve(host string) ([]string, bool) { +func (s TCPSession) Close() { + if s.tls { + var conn = *s.tls_conn + conn.Close() + //*(s.tls_conn).Close() + } else { + var conn = *s.raw_conn + conn.Close() + //*(s.raw_conn).Close() + } +} + +func (s TCPSession) current_conn() net.Conn { + if s.tls { + return *s.tls_conn + } else { + return *s.raw_conn + } +} + +func (r TCPRunner) conn(addr string, port string) (*TCPSession, bool) { + key := net.JoinHostPort(addr, port) + // Initialize errors + if r.tk.Errors == nil { + r.tk.Errors = make(map[string][]*string) + } + r.tk.Errors[key] = []*string{} + + s := new(TCPSession) + if !s.conn(addr, port, r, r.tk.Errors[key]) { + return nil, false + } + return s, true +} + +func (r TCPRunner) dial(addr string, port string) (net.Conn, error) { + dialer := r.trace.NewDialerWithoutResolver(r.logger) + conn, err := dialer.DialContext(r.ctx, "tcp", net.JoinHostPort(addr, port)) + r.tk.TCPConnect = append(r.tk.TCPConnect, r.trace.TCPConnects()...) + return conn, err + +} + +func (s TCPSession) conn(addr string, port string, runner TCPRunner, errors []*string) bool { + // Initialize addr field and corresponding errors in TestKeys + s.addr = addr + s.port = port + s.tls = false + s.runner = &runner + s.errors = errors + + conn, err := runner.dial(addr, port) + if err != nil { + s.error(err) + return false + } + s.raw_conn = &conn + + return true +} + +func (s TCPSession) error(err error) { + s.errors = append(s.errors, tracex.NewFailure(err)) +} + +func (r TCPRunner) resolve(host string) ([]string, bool) { r.logger.Infof("Resolving DNS for %s", host) resolver := r.trace.NewStdlibResolver(r.logger) addrs, err := resolver.LookupHost(r.ctx, host) @@ -151,77 +216,70 @@ func (r SMTPRunner) resolve(host string) ([]string, bool) { return addrs, true } -func (r SMTPRunner) conn(addr string) (net.Conn, bool) { - // Initialize addr field and corresponding errors in TestKeys - r.addr = addr - if r.tk.SMTPErrors == nil { - r.tk.SMTPErrors = make(map[string][]*string) +func (s TCPSession) handshake() bool { + if s.tls { + // TLS already initialized... + return true } - r.tk.SMTPErrors[net.JoinHostPort(addr, r.port)] = []*string{} - - dialer := r.trace.NewDialerWithoutResolver(r.logger) - conn, err := dialer.DialContext(r.ctx, "tcp", net.JoinHostPort(r.addr, r.port)) - r.tk.TCPConnect = append(r.tk.TCPConnect, r.trace.TCPConnects()...) + s.runner.logger.Infof("Starting TLS handshake with %s:%s", s.addr, s.port) + thx := s.runner.trace.NewTLSHandshakerStdlib(s.runner.logger) + tconn, _, err := thx.Handshake(s.runner.ctx, *s.raw_conn, s.runner.tlsconfig) + s.runner.tk.TLSHandshakes = append(s.runner.tk.TLSHandshakes, s.runner.trace.FirstTLSHandshakeOrNil()) if err != nil { - r.smtp_error(err) - return nil, false + s.error(err) + return false } - return conn, true + + s.tls = true + s.tls_conn = &tconn + s.runner.logger.Infof("Handshake succeeded") + return true } -func (r SMTPRunner) handshake(conn net.Conn) (net.Conn, bool) { - r.logger.Infof("Starting TLS handshake with %s:%s (%s)", r.host, r.port, r.addr) - thx := r.trace.NewTLSHandshakerStdlib(r.logger) - tconn, _, err := thx.Handshake(r.ctx, conn, r.tlsconfig) - r.tk.TLSHandshakes = append(r.tk.TLSHandshakes, r.trace.FirstTLSHandshakeOrNil()) - if err != nil { - r.smtp_error(err) - return nil, false +func (s TCPSession) starttls(message string) bool { + if s.tls { + // TLS already initialized... + return true } - r.logger.Infof("Handshake succeeded") - return tconn, true -} - -func (r SMTPRunner) starttls(conn net.Conn, message string) (net.Conn, bool) { if message != "" { - r.logger.Infof("Asking for StartTLS upgrade") - conn.Write([]byte(message)) + s.runner.logger.Infof("Asking for StartTLS upgrade") + s.current_conn().Write([]byte(message)) } - tconn, success := r.handshake(conn) - return tconn, success + return s.handshake() } -func (r SMTPRunner) smtp(conn net.Conn, ehlo string, noop uint8) bool { - client, err := smtp.NewClient(conn, ehlo) +func (s TCPSession) smtp(ehlo string, noop uint8) bool { + // Auto-choose plaintext/TCP session + client, err := smtp.NewClient(s.current_conn(), ehlo) if err != nil { - r.smtp_error(err) + s.error(err) return false } err = client.Hello(ehlo) if err != nil { - r.smtp_error(err) + s.error(err) return false } if noop > 0 { - r.logger.Infof("Trying to generate more no-op traffic") + s.runner.logger.Infof("Trying to generate more no-op traffic") // TODO: noop counter per IP address - r.tk.NoOpCounter = 0 - for r.tk.NoOpCounter < noop { - r.tk.NoOpCounter += 1 - r.logger.Infof("NoOp Iteration %d", r.tk.NoOpCounter) + s.runner.tk.NoOpCounter = 0 + for s.runner.tk.NoOpCounter < noop { + s.runner.tk.NoOpCounter += 1 + s.runner.logger.Infof("NoOp Iteration %d", s.runner.tk.NoOpCounter) err = client.Noop() if err != nil { - r.smtp_error(err) + s.error(err) break } } - if r.tk.NoOpCounter == noop { - r.logger.Infof("Successfully generated no-op traffic") + if s.runner.tk.NoOpCounter == noop { + s.runner.logger.Infof("Successfully generated no-op traffic") return true } else { - r.logger.Infof("Failed no-op traffic at iteration %d", r.tk.NoOpCounter) + s.runner.logger.Infof("Failed no-op traffic at iteration %d", s.runner.tk.NoOpCounter) return false } } @@ -255,7 +313,7 @@ func (m Measurer) Run( ServerName: config.host, } - runner := SMTPRunner{ + runner := TCPRunner{ trace: trace, logger: log, ctx: ctx, @@ -263,7 +321,6 @@ func (m Measurer) Run( tlsconfig: &tlsconfig, host: config.host, port: config.port, - addr: "", } // First resolve DNS @@ -273,38 +330,34 @@ func (m Measurer) Run( } for _, addr := range addrs { - conn, success := runner.conn(addr) + tcp_session, success := runner.conn(addr, config.port) if !success { - return nil + continue } - defer conn.Close() + defer tcp_session.Close() if config.forced_tls { // Direct TLS connection - tconn, success := runner.handshake(conn) - if !success { + if !tcp_session.handshake() { continue } - defer tconn.Close() // Try EHLO + NoOps - if !runner.smtp(tconn, "localhost", 10) { + if !tcp_session.smtp("localhost", 10) { continue } } else { // StartTLS... first try plaintext EHLO - if !runner.smtp(conn, "localhost", 0) { + if !tcp_session.smtp("localhost", 0) { continue } // Upgrade via StartTLS and try EHLO + NoOps - tconn, success := runner.starttls(conn, "STARTTLS\n") - if !success { + if !tcp_session.starttls("STARTTLS\n") { continue } - defer tconn.Close() - if !runner.smtp(tconn, "localhost", 10) { + if !tcp_session.smtp("localhost", 10) { continue } }