package netxlite

import (
	"context"
	"crypto/tls"
	"errors"
	"net/http"
	"testing"

	"github.com/apex/log"
	"github.com/lucas-clemente/quic-go"
	"github.com/lucas-clemente/quic-go/http3"
	"github.com/ooni/probe-cli/v3/internal/model/mocks"
	nlmocks "github.com/ooni/probe-cli/v3/internal/netxlite/mocks"
)

// TestHTTP3Dialer groups the subtests for the http3Dialer type.
func TestHTTP3Dialer(t *testing.T) {
	// Dial: the mocked QUICDialer always fails, and the test expects dial to
	// return an error matching the mocked one along with a nil session.
	t.Run("Dial", func(t *testing.T) {
		expected := errors.New("mocked error")
		d := &http3Dialer{
			QUICDialer: &mocks.QUICDialer{
				MockDialContext: func(ctx context.Context, network, address string, tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlySession, error) {
					return nil, expected
				},
			},
		}
		sess, err := d.dial("", "", &tls.Config{}, &quic.Config{})
		if !errors.Is(err, expected) {
			t.Fatal("unexpected err", err)
		}
		// The value under test is a QUIC session, so the failure message
		// names it as such.
		if sess != nil {
			t.Fatal("unexpected session")
		}
	})
}

func TestHTTP3Transport(t *testing.T) {
	t.Run("CloseIdleConnections", func(t *testing.T) {
		var (
			calledHTTP3  bool
			calledDialer bool
		)
		txp := &http3Transport{
			child: &nlmocks.HTTP3RoundTripper{
				MockClose: func() error {
					calledHTTP3 = true
					return nil
				},
			},
			dialer: &mocks.QUICDialer{
				MockCloseIdleConnections: func() {
					calledDialer = true
				},
			},
		}
		txp.CloseIdleConnections()
		if !calledHTTP3 || !calledDialer {
			t.Fatal("not called")
		}
	})

	t.Run("Network", func(t *testing.T) {
		txp := &http3Transport{}
		if txp.Network() != "quic" {
			t.Fatal("unexpected .Network return value")
		}
	})

	t.Run("RoundTrip", func(t *testing.T) {
		expected := errors.New("mocked error")
		txp := &http3Transport{
			child: &nlmocks.HTTP3RoundTripper{
				MockRoundTrip: func(req *http.Request) (*http.Response, error) {
					return nil, expected
				},
			},
		}
		resp, err := txp.RoundTrip(&http.Request{})
		if !errors.Is(err, expected) {
			t.Fatal("unexpected err", err)
		}
		if resp != nil {
			t.Fatal("unexpected resp")
		}
	})
}

// TestNewHTTP3Transport checks the value built by NewHTTP3Transport.
func TestNewHTTP3Transport(t *testing.T) {
	// The test unwraps the returned transport layer by layer:
	// *httpTransportLogger -> *http3Transport -> *http3.RoundTripper,
	// checking the fields each layer is expected to carry.
	t.Run("creates the correct type chain", func(t *testing.T) {
		qd := &mocks.QUICDialer{}
		config := &tls.Config{}
		txp := NewHTTP3Transport(log.Log, qd, config)

		// Use checked type assertions so that a wrong type is reported as a
		// regular test failure naming the actual type, rather than a panic.
		logger, ok := txp.(*httpTransportLogger)
		if !ok {
			t.Fatalf("unexpected transport type %T", txp)
		}
		if logger.Logger != log.Log {
			t.Fatal("invalid logger")
		}
		h3txp, ok := logger.HTTPTransport.(*http3Transport)
		if !ok {
			t.Fatalf("unexpected HTTPTransport type %T", logger.HTTPTransport)
		}
		if h3txp.dialer != qd {
			t.Fatal("invalid dialer")
		}
		h3, ok := h3txp.child.(*http3.RoundTripper)
		if !ok {
			t.Fatalf("unexpected child type %T", h3txp.child)
		}
		if h3.Dial == nil {
			t.Fatal("invalid Dial")
		}
		if !h3.DisableCompression {
			t.Fatal("invalid DisableCompression")
		}
		if h3.TLSClientConfig != config {
			t.Fatal("invalid TLSClientConfig")
		}
	})
}