package netxlite

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"io"
	"math"
	"net"
	"testing"
	"time"

	"github.com/ooni/probe-cli/v3/internal/model"
	"github.com/ooni/probe-cli/v3/internal/model/mocks"
)

func TestDNSOverTCPTransport(t *testing.T) {
	t.Run("RoundTrip", func(t *testing.T) {
		t.Run("cannot encode query", func(t *testing.T) {
			expected := errors.New("mocked error")
			const address = "9.9.9.9:53"
			txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return nil, expected
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expected) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected nil response here")
			}
		})

		t.Run("query too large", func(t *testing.T) {
			const address = "9.9.9.9:53"
			txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, math.MaxUint16+1), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, errQueryTooLarge) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected nil response here")
			}
		})

		t.Run("dial failure", func(t *testing.T) {
			const address = "9.9.9.9:53"
			mocked := errors.New("mocked error")
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			fakedialer := &mocks.Dialer{
				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
					return nil, mocked
				},
			}
			txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, mocked) {
				t.Fatal("not the error we expected")
			}
			if resp != nil {
				t.Fatal("expected nil resp here")
			}
		})

		t.Run("write failure", func(t *testing.T) {
			const address = "9.9.9.9:53"
			mocked := errors.New("mocked error")
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			fakedialer := &mocks.Dialer{
				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
					return &mocks.Conn{
						MockSetDeadline: func(t time.Time) error {
							return nil
						},
						MockWrite: func(b []byte) (int, error) {
							return 0, mocked
						},
						MockClose: func() error {
							return nil
						},
					}, nil
				},
			}
			txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, mocked) {
				t.Fatal("not the error we expected")
			}
			if resp != nil {
				t.Fatal("expected nil resp here")
			}
		})

		t.Run("first read fails", func(t *testing.T) {
			const address = "9.9.9.9:53"
			mocked := errors.New("mocked error")
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			fakedialer := &mocks.Dialer{
				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
					return &mocks.Conn{
						MockSetDeadline: func(t time.Time) error {
							return nil
						},
						MockWrite: func(b []byte) (int, error) {
							return len(b), nil
						},
						MockRead: func(b []byte) (int, error) {
							return 0, mocked
						},
						MockClose: func() error {
							return nil
						},
					}, nil
				},
			}
			txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, mocked) {
				t.Fatal("not the error we expected")
			}
			if resp != nil {
				t.Fatal("expected nil resp here")
			}
		})

		t.Run("second read fails", func(t *testing.T) {
			const address = "9.9.9.9:53"
			mocked := errors.New("mocked error")
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			input := io.MultiReader(
				bytes.NewReader([]byte{byte(0), byte(2)}),
				&mocks.Reader{
					MockRead: func(b []byte) (int, error) {
						return 0, mocked
					},
				},
			)
			fakedialer := &mocks.Dialer{
				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
					return &mocks.Conn{
						MockSetDeadline: func(t time.Time) error {
							return nil
						},
						MockWrite: func(b []byte) (int, error) {
							return len(b), nil
						},
						MockRead: input.Read,
						MockClose: func() error {
							return nil
						},
					}, nil
				},
			}
			txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, mocked) {
				t.Fatal("not the error we expected")
			}
			if resp != nil {
				t.Fatal("expected nil resp here")
			}
		})

		t.Run("decode failure", func(t *testing.T) {
			const address = "9.9.9.9:53"
			mocked := errors.New("mocked error")
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
			fakedialer := &mocks.Dialer{
				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
					return &mocks.Conn{
						MockSetDeadline: func(t time.Time) error {
							return nil
						},
						MockWrite: func(b []byte) (int, error) {
							return len(b), nil
						},
						MockRead: input.Read,
						MockClose: func() error {
							return nil
						},
					}, nil
				},
			}
			txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
			txp.decoder = &mocks.DNSDecoder{
				MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
					return nil, mocked
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, mocked) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected nil resp here")
			}
		})

		t.Run("successful case", func(t *testing.T) {
			const address = "9.9.9.9:53"
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			input := bytes.NewReader([]byte{byte(0), byte(1), byte(1)})
			fakedialer := &mocks.Dialer{
				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
					return &mocks.Conn{
						MockSetDeadline: func(t time.Time) error {
							return nil
						},
						MockWrite: func(b []byte) (int, error) {
							return len(b), nil
						},
						MockRead: input.Read,
						MockClose: func() error {
							return nil
						},
					}, nil
				},
			}
			txp := NewUnwrappedDNSOverTCPTransport(fakedialer.DialContext, address)
			expectedResp := &mocks.DNSResponse{}
			txp.decoder = &mocks.DNSDecoder{
				MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
					return expectedResp, nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if err != nil {
				t.Fatal(err)
			}
			if resp != expectedResp {
				t.Fatal("not the response we expected")
			}
		})
	})

	t.Run("other functions okay with TCP", func(t *testing.T) {
		const address = "9.9.9.9:53"
		txp := NewUnwrappedDNSOverTCPTransport(new(net.Dialer).DialContext, address)
		if txp.RequiresPadding() != false {
			t.Fatal("invalid RequiresPadding")
		}
		if txp.Network() != "tcp" {
			t.Fatal("invalid Network")
		}
		if txp.Address() != address {
			t.Fatal("invalid Address")
		}
		txp.CloseIdleConnections()
	})

	t.Run("other functions okay with TLS", func(t *testing.T) {
		const address = "9.9.9.9:853"
		txp := NewUnwrappedDNSOverTLSTransport((&tls.Dialer{}).DialContext, address)
		if txp.RequiresPadding() != true {
			t.Fatal("invalid RequiresPadding")
		}
		if txp.Network() != "dot" {
			t.Fatal("invalid Network")
		}
		if txp.Address() != address {
			t.Fatal("invalid Address")
		}
		txp.CloseIdleConnections()
	})
}