package measurexlite

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"net"
	"reflect"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/ooni/probe-cli/v3/internal/model"
	"github.com/ooni/probe-cli/v3/internal/model/mocks"
	"github.com/ooni/probe-cli/v3/internal/netxlite"
	"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
	"github.com/ooni/probe-cli/v3/internal/testingx"
)

func TestNewTLSHandshakerStdlib(t *testing.T) {
	t.Run("NewTLSHandshakerStdlib creates a wrapped TLSHandshaker", func(t *testing.T) {
		underlying := &mocks.TLSHandshaker{}
		zeroTime := time.Now()
		trace := NewTrace(0, zeroTime)
		trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
			return underlying
		}
		thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
		thxt := thx.(*tlsHandshakerTrace)
		if thxt.thx != underlying {
			t.Fatal("invalid TLS handshaker")
		}
		if thxt.tx != trace {
			t.Fatal("invalid trace")
		}
	})

	t.Run("Handshake calls the underlying dialer with context-based tracing", func(t *testing.T) {
		expectedErr := errors.New("mocked err")
		zeroTime := time.Now()
		trace := NewTrace(0, zeroTime)
		var hasCorrectTrace bool
		underlying := &mocks.TLSHandshaker{
			MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) {
				gotTrace := netxlite.ContextTraceOrDefault(ctx)
				hasCorrectTrace = (gotTrace == trace)
				return nil, tls.ConnectionState{}, expectedErr
			},
		}
		trace.NewTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker {
			return underlying
		}
		thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
		ctx := context.Background()
		conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{})
		if !errors.Is(err, expectedErr) {
			t.Fatal("unexpected err", err)
		}
		if !reflect.ValueOf(state).IsZero() {
			t.Fatal("expected zero-value state")
		}
		if conn != nil {
			t.Fatal("expected nil conn")
		}
		if !hasCorrectTrace {
			t.Fatal("does not have the correct trace")
		}
	})

	t.Run("Handshake saves into the trace", func(t *testing.T) {
		mockedErr := errors.New("mocked")
		zeroTime := time.Now()
		td := testingx.NewTimeDeterministic(zeroTime)
		trace := NewTrace(0, zeroTime)
		trace.TimeNowFn = td.Now // deterministic timing
		thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
		ctx := context.Background()
		tcpConn := &mocks.Conn{
			MockSetDeadline: func(t time.Time) error {
				return nil
			},
			MockRemoteAddr: func() net.Addr {
				return &mocks.Addr{
					MockNetwork: func() string {
						return "tcp"
					},
					MockString: func() string {
						return "1.1.1.1:443"
					},
				}
			},
			MockWrite: func(b []byte) (int, error) {
				return 0, mockedErr
			},
			MockClose: func() error {
				return nil
			},
		}
		tlsConfig := &tls.Config{
			InsecureSkipVerify: true,
			ServerName:         "dns.cloudflare.com",
		}
		conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
		if !errors.Is(err, mockedErr) {
			t.Fatal("unexpected err", err)
		}
		if !reflect.ValueOf(state).IsZero() {
			t.Fatal("expected zero-value state")
		}
		if conn != nil {
			t.Fatal("expected nil conn")
		}

		t.Run("TLSHandshake events", func(t *testing.T) {
			events := trace.TLSHandshakes()
			if len(events) != 1 {
				t.Fatal("expected to see single TLSHandshake event")
			}
			expectedFailure := "unknown_failure: mocked"
			expect := &model.ArchivalTLSOrQUICHandshakeResult{
				Network:            "tcp",
				Address:            "1.1.1.1:443",
				CipherSuite:        "",
				Failure:            &expectedFailure,
				NegotiatedProtocol: "",
				NoTLSVerify:        true,
				PeerCertificates:   []model.ArchivalMaybeBinaryData{},
				ServerName:         "dns.cloudflare.com",
				T:                  time.Second.Seconds(),
				Tags:               []string{},
				TLSVersion:         "",
			}
			got := events[0]
			if diff := cmp.Diff(expect, got); diff != "" {
				t.Fatal(diff)
			}
		})

		t.Run("Network events", func(t *testing.T) {
			events := trace.NetworkEvents()
			if len(events) != 2 {
				t.Fatal("expected to see two Network events")
			}

			t.Run("tls_handshake_start", func(t *testing.T) {
				expect := &model.ArchivalNetworkEvent{
					Address:   "",
					Failure:   nil,
					NumBytes:  0,
					Operation: "tls_handshake_start",
					Proto:     "",
					T:         0,
					Tags:      []string{},
				}
				got := events[0]
				if diff := cmp.Diff(expect, got); diff != "" {
					t.Fatal(diff)
				}
			})

			t.Run("tls_handshake_done", func(t *testing.T) {
				expect := &model.ArchivalNetworkEvent{
					Address:   "",
					Failure:   nil,
					NumBytes:  0,
					Operation: "tls_handshake_done",
					Proto:     "",
					T0:        time.Second.Seconds(),
					T:         time.Second.Seconds(),
					Tags:      []string{},
				}
				got := events[1]
				if diff := cmp.Diff(expect, got); diff != "" {
					t.Fatal(diff)
				}
			})
		})
	})

	t.Run("Handshake discards events when buffers are full", func(t *testing.T) {
		mockedErr := errors.New("mocked")
		zeroTime := time.Now()
		trace := NewTrace(0, zeroTime)
		trace.networkEvent = make(chan *model.ArchivalNetworkEvent)             // no buffer
		trace.tlsHandshake = make(chan *model.ArchivalTLSOrQUICHandshakeResult) // no buffer
		thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
		ctx := context.Background()
		tcpConn := &mocks.Conn{
			MockSetDeadline: func(t time.Time) error {
				return nil
			},
			MockRemoteAddr: func() net.Addr {
				return &mocks.Addr{
					MockNetwork: func() string {
						return "tcp"
					},
					MockString: func() string {
						return "1.1.1.1:443"
					},
				}
			},
			MockWrite: func(b []byte) (int, error) {
				return 0, mockedErr
			},
			MockClose: func() error {
				return nil
			},
		}
		tlsConfig := &tls.Config{
			InsecureSkipVerify: true,
			ServerName:         "dns.cloudflare.com",
		}
		conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig)
		if !errors.Is(err, mockedErr) {
			t.Fatal("unexpected err", err)
		}
		if !reflect.ValueOf(state).IsZero() {
			t.Fatal("expected zero-value state")
		}
		if conn != nil {
			t.Fatal("expected nil conn")
		}

		t.Run("TLSHandshake events", func(t *testing.T) {
			events := trace.TLSHandshakes()
			if len(events) != 0 {
				t.Fatal("expected to see no TLSHandshake events")
			}
		})

		t.Run("Network events", func(t *testing.T) {
			events := trace.NetworkEvents()
			if len(events) != 0 {
				t.Fatal("expected to see no Network events")
			}
		})
	})

	t.Run("we collect the desired data with a local TLS server", func(t *testing.T) {
		server := filtering.NewTLSServer(filtering.TLSActionBlockText)
		dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger)
		ctx := context.Background()
		conn, err := dialer.DialContext(ctx, "tcp", server.Endpoint())
		if err != nil {
			t.Fatal(err)
		}
		defer conn.Close()
		zeroTime := time.Now()
		dt := testingx.NewTimeDeterministic(zeroTime)
		trace := NewTrace(0, zeroTime)
		trace.TimeNowFn = dt.Now // deterministic timing
		thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger)
		tlsConfig := &tls.Config{
			RootCAs:    server.CertPool(),
			ServerName: "dns.google",
		}
		tlsConn, connState, err := thx.Handshake(ctx, conn, tlsConfig)
		if err != nil {
			t.Fatal(err)
		}
		defer tlsConn.Close()
		data, err := netxlite.ReadAllContext(ctx, tlsConn)
		if err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(data, filtering.HTTPBlockpage451) {
			t.Fatal("bytes should match")
		}

		t.Run("TLSHandshake events", func(t *testing.T) {
			events := trace.TLSHandshakes()
			if len(events) != 1 {
				t.Fatal("expected to see a single TLSHandshake event")
			}
			expected := &model.ArchivalTLSOrQUICHandshakeResult{
				Network:            "tcp",
				Address:            conn.RemoteAddr().String(),
				CipherSuite:        netxlite.TLSCipherSuiteString(connState.CipherSuite),
				Failure:            nil,
				NegotiatedProtocol: "",
				NoTLSVerify:        false,
				PeerCertificates:   []model.ArchivalMaybeBinaryData{},
				ServerName:         "dns.google",
				T:                  time.Second.Seconds(),
				Tags:               []string{},
				TLSVersion:         netxlite.TLSVersionString(connState.Version),
			}
			got := events[0]
			// TODO(bassosimone): it's still unclear to me how to test that
			// I am getting exactly the expected certificate here. I think the
			// certificate is generated on the fly by google/martian. So, I'm
			// just going to reduce the precision of this check.
			if len(got.PeerCertificates) != 2 {
				t.Fatal("expected to see two certificates")
			}
			got.PeerCertificates = []model.ArchivalMaybeBinaryData{} // see above
			if diff := cmp.Diff(expected, got); diff != "" {
				t.Fatal(diff)
			}
		})

		t.Run("Network events", func(t *testing.T) {
			events := trace.NetworkEvents()
			if len(events) != 2 {
				t.Fatal("expected to see two Network events")
			}

			t.Run("tls_handshake_start", func(t *testing.T) {
				expect := &model.ArchivalNetworkEvent{
					Address:   "",
					Failure:   nil,
					NumBytes:  0,
					Operation: "tls_handshake_start",
					Proto:     "",
					T:         0,
					Tags:      []string{},
				}
				got := events[0]
				if diff := cmp.Diff(expect, got); diff != "" {
					t.Fatal(diff)
				}
			})

			t.Run("tls_handshake_done", func(t *testing.T) {
				expect := &model.ArchivalNetworkEvent{
					Address:   "",
					Failure:   nil,
					NumBytes:  0,
					Operation: "tls_handshake_done",
					Proto:     "",
					T0:        time.Second.Seconds(),
					T:         time.Second.Seconds(),
					Tags:      []string{},
				}
				got := events[1]
				if diff := cmp.Diff(expect, got); diff != "" {
					t.Fatal(diff)
				}
			})
		})
	})
}

func TestFirstTLSHandshake(t *testing.T) {
	t.Run("returns nil when buffer is empty", func(t *testing.T) {
		zeroTime := time.Now()
		trace := NewTrace(0, zeroTime)
		got := trace.FirstTLSHandshakeOrNil()
		if got != nil {
			t.Fatal("expected nil event")
		}
	})

	t.Run("return first non-nil TLSHandshake", func(t *testing.T) {
		filler := func(tx *Trace, events []*model.ArchivalTLSOrQUICHandshakeResult) {
			for _, ev := range events {
				tx.tlsHandshake <- ev
			}
		}
		zeroTime := time.Now()
		trace := NewTrace(0, zeroTime)
		expect := []*model.ArchivalTLSOrQUICHandshakeResult{{
			Network:            "tcp",
			Address:            "1.1.1.1:443",
			CipherSuite:        "",
			Failure:            nil,
			NegotiatedProtocol: "",
			NoTLSVerify:        true,
			PeerCertificates:   []model.ArchivalMaybeBinaryData{},
			ServerName:         "dns.cloudflare.com",
			T:                  time.Second.Seconds(),
			Tags:               []string{},
			TLSVersion:         "",
		}, {
			Network:            "tcp",
			Address:            "8.8.8.8:443",
			CipherSuite:        "",
			Failure:            nil,
			NegotiatedProtocol: "",
			NoTLSVerify:        true,
			PeerCertificates:   []model.ArchivalMaybeBinaryData{},
			ServerName:         "dns.google.com",
			T:                  time.Second.Seconds(),
			Tags:               []string{},
			TLSVersion:         "",
		}}
		filler(trace, expect)
		got := trace.FirstTLSHandshakeOrNil()
		if diff := cmp.Diff(got, expect[0]); diff != "" {
			t.Fatal(diff)
		}
	})
}

func TestTLSPeerCerts(t *testing.T) {
	type args struct {
		state tls.ConnectionState
		err   error
	}
	tests := []struct {
		name    string
		args    args
		wantOut []model.ArchivalMaybeBinaryData
	}{{
		name: "x509.HostnameError",
		args: args{
			state: tls.ConnectionState{},
			err: x509.HostnameError{
				Certificate: &x509.Certificate{
					Raw: []byte("deadbeef"),
				},
			},
		},
		wantOut: []model.ArchivalMaybeBinaryData{{
			Value: "deadbeef",
		}},
	}, {
		name: "x509.UnknownAuthorityError",
		args: args{
			state: tls.ConnectionState{},
			err: x509.UnknownAuthorityError{
				Cert: &x509.Certificate{
					Raw: []byte("deadbeef"),
				},
			},
		},
		wantOut: []model.ArchivalMaybeBinaryData{{
			Value: "deadbeef",
		}},
	}, {
		name: "x509.CertificateInvalidError",
		args: args{
			state: tls.ConnectionState{},
			err: x509.CertificateInvalidError{
				Cert: &x509.Certificate{
					Raw: []byte("deadbeef"),
				},
			},
		},
		wantOut: []model.ArchivalMaybeBinaryData{{
			Value: "deadbeef",
		}},
	}, {
		name: "successful case",
		args: args{
			state: tls.ConnectionState{
				PeerCertificates: []*x509.Certificate{{
					Raw: []byte("deadbeef"),
				}, {
					Raw: []byte("abad1dea"),
				}},
			},
			err: nil,
		},
		wantOut: []model.ArchivalMaybeBinaryData{{
			Value: "deadbeef",
		}, {
			Value: "abad1dea",
		}},
	}}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotOut := TLSPeerCerts(tt.args.state, tt.args.err)
			if diff := cmp.Diff(tt.wantOut, gotOut); diff != "" {
				t.Fatal(diff)
			}
		})
	}
}