package dialer_test import ( "context" "crypto/tls" "errors" "net" "reflect" "testing" "time" "github.com/ooni/probe-cli/v3/internal/engine/netx/dialer" "github.com/ooni/probe-cli/v3/internal/engine/netx/errorx" "github.com/ooni/probe-cli/v3/internal/engine/netx/trace" ) func TestSaverDialerFailure(t *testing.T) { expected := errors.New("mocked error") saver := &trace.Saver{} dlr := dialer.SaverDialer{ Dialer: dialer.FakeDialer{ Err: expected, }, Saver: saver, } conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") if !errors.Is(err, expected) { t.Fatal("expected another error here") } if conn != nil { t.Fatal("expected nil conn here") } ev := saver.Read() if len(ev) != 1 { t.Fatal("expected a single event here") } if ev[0].Address != "www.google.com:443" { t.Fatal("unexpected Address") } if ev[0].Duration <= 0 { t.Fatal("unexpected Duration") } if !errors.Is(ev[0].Err, expected) { t.Fatal("unexpected Err") } if ev[0].Name != errorx.ConnectOperation { t.Fatal("unexpected Name") } if ev[0].Proto != "tcp" { t.Fatal("unexpected Proto") } if !ev[0].Time.Before(time.Now()) { t.Fatal("unexpected Time") } } func TestSaverConnDialerFailure(t *testing.T) { expected := errors.New("mocked error") saver := &trace.Saver{} dlr := dialer.SaverConnDialer{ Dialer: dialer.FakeDialer{ Err: expected, }, Saver: saver, } conn, err := dlr.DialContext(context.Background(), "tcp", "www.google.com:443") if !errors.Is(err, expected) { t.Fatal("not the error we expected") } if conn != nil { t.Fatal("expected nil conn here") } } func TestSaverTLSHandshakerSuccessWithReadWrite(t *testing.T) { // This is the most common use case for collecting reads, writes if testing.Short() { t.Skip("skip test in short mode") } nextprotos := []string{"h2"} saver := &trace.Saver{} tlsdlr := dialer.TLSDialer{ Config: &tls.Config{NextProtos: nextprotos}, Dialer: dialer.SaverConnDialer{ Dialer: new(net.Dialer), Saver: saver, }, TLSHandshaker: dialer.SaverTLSHandshaker{ TLSHandshaker: dialer.SystemTLSHandshaker{}, Saver: saver, }, } // Implementation note: we don't close the connection here because it is // very handy to have the last event being the end of the handshake _, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") if err != nil { t.Fatal(err) } ev := saver.Read() if len(ev) < 4 { // it's a bit tricky to be sure about the right number of // events because network conditions may influence that t.Fatal("unexpected number of events") } if ev[0].Name != "tls_handshake_start" { t.Fatal("unexpected Name") } if ev[0].TLSServerName != "www.google.com" { t.Fatal("unexpected TLSServerName") } if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) { t.Fatal("unexpected TLSNextProtos") } if ev[0].Time.After(time.Now()) { t.Fatal("unexpected Time") } last := len(ev) - 1 for idx := 1; idx < last; idx++ { if ev[idx].Data == nil { t.Fatal("unexpected Data") } if ev[idx].Duration <= 0 { t.Fatal("unexpected Duration") } if ev[idx].Err != nil { t.Fatal("unexpected Err") } if ev[idx].NumBytes <= 0 { t.Fatal("unexpected NumBytes") } switch ev[idx].Name { case errorx.ReadOperation, errorx.WriteOperation: default: t.Fatal("unexpected Name") } if ev[idx].Time.Before(ev[idx-1].Time) { t.Fatal("unexpected Time") } } if ev[last].Duration <= 0 { t.Fatal("unexpected Duration") } if ev[last].Err != nil { t.Fatal("unexpected Err") } if ev[last].Name != "tls_handshake_done" { t.Fatal("unexpected Name") } if ev[last].TLSCipherSuite == "" { t.Fatal("unexpected TLSCipherSuite") } if ev[last].TLSNegotiatedProto != "h2" { t.Fatal("unexpected TLSNegotiatedProto") } if !reflect.DeepEqual(ev[last].TLSNextProtos, nextprotos) { t.Fatal("unexpected TLSNextProtos") } if ev[last].TLSPeerCerts == nil { t.Fatal("unexpected TLSPeerCerts") } if ev[last].TLSServerName != "www.google.com" { t.Fatal("unexpected TLSServerName") } if ev[last].TLSVersion == "" { t.Fatal("unexpected TLSVersion") } if ev[last].Time.Before(ev[last-1].Time) { t.Fatal("unexpected Time") } } func TestSaverTLSHandshakerSuccess(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } nextprotos := []string{"h2"} saver := &trace.Saver{} tlsdlr := dialer.TLSDialer{ Config: &tls.Config{NextProtos: nextprotos}, Dialer: new(net.Dialer), TLSHandshaker: dialer.SaverTLSHandshaker{ TLSHandshaker: dialer.SystemTLSHandshaker{}, Saver: saver, }, } conn, err := tlsdlr.DialTLSContext(context.Background(), "tcp", "www.google.com:443") if err != nil { t.Fatal(err) } conn.Close() ev := saver.Read() if len(ev) != 2 { t.Fatal("unexpected number of events") } if ev[0].Name != "tls_handshake_start" { t.Fatal("unexpected Name") } if ev[0].TLSServerName != "www.google.com" { t.Fatal("unexpected TLSServerName") } if !reflect.DeepEqual(ev[0].TLSNextProtos, nextprotos) { t.Fatal("unexpected TLSNextProtos") } if ev[0].Time.After(time.Now()) { t.Fatal("unexpected Time") } if ev[1].Duration <= 0 { t.Fatal("unexpected Duration") } if ev[1].Err != nil { t.Fatal("unexpected Err") } if ev[1].Name != "tls_handshake_done" { t.Fatal("unexpected Name") } if ev[1].TLSCipherSuite == "" { t.Fatal("unexpected TLSCipherSuite") } if ev[1].TLSNegotiatedProto != "h2" { t.Fatal("unexpected TLSNegotiatedProto") } if !reflect.DeepEqual(ev[1].TLSNextProtos, nextprotos) { t.Fatal("unexpected TLSNextProtos") } if ev[1].TLSPeerCerts == nil { t.Fatal("unexpected TLSPeerCerts") } if ev[1].TLSServerName != "www.google.com" { t.Fatal("unexpected TLSServerName") } if ev[1].TLSVersion == "" { t.Fatal("unexpected TLSVersion") } if ev[1].Time.Before(ev[0].Time) { t.Fatal("unexpected Time") } } func TestSaverTLSHandshakerHostnameError(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } saver := &trace.Saver{} tlsdlr := dialer.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: dialer.SaverTLSHandshaker{ TLSHandshaker: dialer.SystemTLSHandshaker{}, Saver: saver, }, } conn, err := tlsdlr.DialTLSContext( context.Background(), "tcp", "wrong.host.badssl.com:443") if err == nil { t.Fatal("expected an error here") } if conn != nil { t.Fatal("expected nil conn here") } for _, ev := range saver.Read() { if ev.Name != "tls_handshake_done" { continue } if ev.NoTLSVerify == true { t.Fatal("expected NoTLSVerify to be false") } if len(ev.TLSPeerCerts) < 1 { t.Fatal("expected at least a certificate here") } } } func TestSaverTLSHandshakerInvalidCertError(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } saver := &trace.Saver{} tlsdlr := dialer.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: dialer.SaverTLSHandshaker{ TLSHandshaker: dialer.SystemTLSHandshaker{}, Saver: saver, }, } conn, err := tlsdlr.DialTLSContext( context.Background(), "tcp", "expired.badssl.com:443") if err == nil { t.Fatal("expected an error here") } if conn != nil { t.Fatal("expected nil conn here") } for _, ev := range saver.Read() { if ev.Name != "tls_handshake_done" { continue } if ev.NoTLSVerify == true { t.Fatal("expected NoTLSVerify to be false") } if len(ev.TLSPeerCerts) < 1 { t.Fatal("expected at least a certificate here") } } } func TestSaverTLSHandshakerAuthorityError(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } saver := &trace.Saver{} tlsdlr := dialer.TLSDialer{ Dialer: new(net.Dialer), TLSHandshaker: dialer.SaverTLSHandshaker{ TLSHandshaker: dialer.SystemTLSHandshaker{}, Saver: saver, }, } conn, err := tlsdlr.DialTLSContext( context.Background(), "tcp", "self-signed.badssl.com:443") if err == nil { t.Fatal("expected an error here") } if conn != nil { t.Fatal("expected nil conn here") } for _, ev := range saver.Read() { if ev.Name != "tls_handshake_done" { continue } if ev.NoTLSVerify == true { t.Fatal("expected NoTLSVerify to be false") } if len(ev.TLSPeerCerts) < 1 { t.Fatal("expected at least a certificate here") } } } func TestSaverTLSHandshakerNoTLSVerify(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } saver := &trace.Saver{} tlsdlr := dialer.TLSDialer{ Config: &tls.Config{InsecureSkipVerify: true}, Dialer: new(net.Dialer), TLSHandshaker: dialer.SaverTLSHandshaker{ TLSHandshaker: dialer.SystemTLSHandshaker{}, Saver: saver, }, } conn, err := tlsdlr.DialTLSContext( context.Background(), "tcp", "self-signed.badssl.com:443") if err != nil { t.Fatal(err) } if conn == nil { t.Fatal("expected non-nil conn here") } conn.Close() for _, ev := range saver.Read() { if ev.Name != "tls_handshake_done" { continue } if ev.NoTLSVerify != true { t.Fatal("expected NoTLSVerify to be true") } if len(ev.TLSPeerCerts) < 1 { t.Fatal("expected at least a certificate here") } } }