package netxlite

import (
	"bytes"
	"context"
	"errors"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/apex/log"
	"github.com/google/go-cmp/cmp"
	"github.com/miekg/dns"
	"github.com/ooni/probe-cli/v3/internal/model"
	"github.com/ooni/probe-cli/v3/internal/model/mocks"
	"github.com/ooni/probe-cli/v3/internal/netxlite/filtering"
	"github.com/ooni/probe-cli/v3/internal/testingx"
)

// TestDNSOverUDPTransport exercises the DNS-over-UDP transport created by
// NewUnwrappedDNSOverUDPTransport. It covers:
//
//   - RoundTrip failure and success paths using mocked dialers, conns, queries
//     and decoders, plus a round trip against a local filtering.DNSServer;
//
//   - the handling of delayed (a.k.a. late) DNS responses, both without a trace
//     and with a trace injected through the context;
//
//   - CloseIdleConnections forwarding to the underlying dialer;
//
//   - the simple accessors (RequiresPadding, Network, Address).
func TestDNSOverUDPTransport(t *testing.T) {
	t.Run("RoundTrip", func(t *testing.T) {
		t.Run("cannot encode query", func(t *testing.T) {
			expected := errors.New("mocked error")
			const address = "9.9.9.9:53"
			// The dialer is nil because this test expects RoundTrip to
			// fail when serializing the query, before dialing.
			txp := NewUnwrappedDNSOverUDPTransport(nil, address)
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return nil, expected
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expected) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected nil response here")
			}
		})

		t.Run("dial failure", func(t *testing.T) {
			mocked := errors.New("mocked error")
			const address = "9.9.9.9:53"
			txp := NewUnwrappedDNSOverUDPTransport(&mocks.Dialer{
				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
					return nil, mocked
				},
			}, address)
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, mocked) {
				t.Fatal("not the error we expected")
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("Write failure", func(t *testing.T) {
			mocked := errors.New("mocked error")
			txp := NewUnwrappedDNSOverUDPTransport(
				&mocks.Dialer{
					MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
						// The dial succeeds but writing the query fails.
						return &mocks.Conn{
							MockSetDeadline: func(t time.Time) error {
								return nil
							},
							MockWrite: func(b []byte) (int, error) {
								return 0, mocked
							},
							MockClose: func() error {
								return nil
							},
							MockLocalAddr: func() net.Addr {
								return &mocks.Addr{
									MockNetwork: func() string {
										return "udp"
									},
									MockString: func() string {
										return "127.0.0.1:1345"
									},
								}
							},
						}, nil
					},
				}, "9.9.9.9:53",
			)
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, mocked) {
				t.Fatal("not the error we expected")
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("Read failure", func(t *testing.T) {
			mocked := errors.New("mocked error")
			txp := NewUnwrappedDNSOverUDPTransport(
				&mocks.Dialer{
					MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
						// The write succeeds but reading the response fails.
						return &mocks.Conn{
							MockSetDeadline: func(t time.Time) error {
								return nil
							},
							MockWrite: func(b []byte) (int, error) {
								return len(b), nil
							},
							MockRead: func(b []byte) (int, error) {
								return 0, mocked
							},
							MockClose: func() error {
								return nil
							},
							MockLocalAddr: func() net.Addr {
								return &mocks.Addr{
									MockNetwork: func() string {
										return "udp"
									},
									MockString: func() string {
										return "127.0.0.1:1345"
									},
								}
							},
						}, nil
					},
				}, "9.9.9.9:53",
			)
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, mocked) {
				t.Fatal("not the error we expected")
			}
			if resp != nil {
				t.Fatal("expected no response here")
			}
		})

		t.Run("decode failure", func(t *testing.T) {
			// The mocked conn serves this many zero bytes as the "response".
			const expected = 17
			input := bytes.NewReader(make([]byte, expected))
			txp := NewUnwrappedDNSOverUDPTransport(
				&mocks.Dialer{
					MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
						return &mocks.Conn{
							MockSetDeadline: func(t time.Time) error {
								return nil
							},
							MockWrite: func(b []byte) (int, error) {
								return len(b), nil
							},
							MockRead: input.Read,
							MockClose: func() error {
								return nil
							},
							MockLocalAddr: func() net.Addr {
								return &mocks.Addr{
									MockNetwork: func() string {
										return "udp"
									},
									MockString: func() string {
										return "127.0.0.1:1345"
									},
								}
							},
						}, nil
					},
				}, "9.9.9.9:53",
			)
			// Replace the decoder so that decoding always fails.
			expectedErr := errors.New("mocked error")
			txp.Decoder = &mocks.DNSDecoder{
				MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
					return nil, expectedErr
				},
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if !errors.Is(err, expectedErr) {
				t.Fatal("unexpected err", err)
			}
			if resp != nil {
				t.Fatal("expected nil resp")
			}
		})

		t.Run("decode success", func(t *testing.T) {
			// The mocked conn serves this many zero bytes as the "response".
			const expected = 17
			input := bytes.NewReader(make([]byte, expected))
			txp := NewUnwrappedDNSOverUDPTransport(
				&mocks.Dialer{
					MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
						return &mocks.Conn{
							MockSetDeadline: func(t time.Time) error {
								return nil
							},
							MockWrite: func(b []byte) (int, error) {
								return len(b), nil
							},
							MockRead: input.Read,
							MockClose: func() error {
								return nil
							},
							MockLocalAddr: func() net.Addr {
								return &mocks.Addr{
									MockNetwork: func() string {
										return "udp"
									},
									MockString: func() string {
										return "127.0.0.1:1345"
									},
								}
							},
						}, nil
					},
				}, "9.9.9.9:53",
			)
			// Replace the decoder so that decoding always yields expectedResp,
			// which RoundTrip must hand back to the caller unchanged.
			expectedResp := &mocks.DNSResponse{}
			txp.Decoder = &mocks.DNSDecoder{
				MockDecodeResponse: func(data []byte, query model.DNSQuery) (model.DNSResponse, error) {
					return expectedResp, nil
				},
			}
			query := &mocks.DNSQuery{
				MockBytes: func() ([]byte, error) {
					return make([]byte, 128), nil
				},
			}
			resp, err := txp.RoundTrip(context.Background(), query)
			if err != nil {
				t.Fatal(err)
			}
			if resp != expectedResp {
				t.Fatal("unexpected resp")
			}
		})

		t.Run("using a real server", func(t *testing.T) {
			// Local DNS server answering from its cache.
			srvr := &filtering.DNSServer{
				OnQuery: func(domain string) filtering.DNSAction {
					return filtering.DNSActionCache
				},
				Cache: map[string][]string{
					"dns.google.": {"8.8.8.8"},
				},
			}
			listener, err := srvr.Start("127.0.0.1:0")
			if err != nil {
				t.Fatal(err)
			}
			defer listener.Close()
			dialer := NewDialerWithoutResolver(model.DiscardLogger)
			txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String())
			encoder := &DNSEncoderMiekg{}
			query := encoder.Encode("dns.google.", dns.TypeA, false)
			resp, err := txp.RoundTrip(context.Background(), query)
			if err != nil {
				t.Fatal(err)
			}
			addrs, err := resp.DecodeLookupHost()
			if err != nil {
				t.Fatal(err)
			}
			if diff := cmp.Diff(addrs, []string{"8.8.8.8"}); diff != "" {
				t.Fatal(diff)
			}
		})
	})

	t.Run("recording delayed DNS responses", func(t *testing.T) {
		t.Run("without any context-injected traces", func(t *testing.T) {
			srvr := &filtering.DNSServer{
				OnQuery: func(domain string) filtering.DNSAction {
					return filtering.DNSActionLocalHostPlusCache
				},
				Cache: map[string][]string{
					"dns.google.": {"8.8.8.8"},
				},
			}
			listener, err := srvr.Start("127.0.0.1:0")
			if err != nil {
				t.Fatal(err)
			}
			defer listener.Close()
			dialer := NewDialerWithoutResolver(model.DiscardLogger)
			expectedAddress := listener.LocalAddr().String()
			txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
			txp.lateResponses = make(chan any, 1) // with buffer to avoid deadlocks
			encoder := &DNSEncoderMiekg{}
			query := encoder.Encode("dns.google.", dns.TypeA, false)
			rch, err := txp.RoundTrip(context.Background(), query)
			if err != nil {
				t.Fatal(err)
			}
			if _, err := rch.DecodeLookupHost(); err != nil {
				t.Fatal(err)
			}
			// Now wait for the delayed response to arrive. We don't care much
			// about observing it here, rather we want to know it happened.
			<-txp.lateResponses
		})

		t.Run("uses a context-injected custom trace (success case)", func(t *testing.T) {
			// Flags written by the trace callback (which runs in another
			// goroutine) and read by the test body; all guarded by mu.
			var (
				delayedDNSResponseCalled bool
				goodQueryType            bool
				goodTransportNetwork     bool
				goodTransportAddress     bool
				goodLookupAddrs          bool
				goodError                bool
			)
			srvr := &filtering.DNSServer{
				OnQuery: func(domain string) filtering.DNSAction {
					return filtering.DNSActionLocalHostPlusCache
				},
				Cache: map[string][]string{
					"dns.google.": {"8.8.8.8"},
				},
			}
			listener, err := srvr.Start("127.0.0.1:0")
			if err != nil {
				t.Fatal(err)
			}
			defer listener.Close()
			dialer := NewDialerWithoutResolver(model.DiscardLogger)
			expectedAddress := listener.LocalAddr().String()
			txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
			encoder := &DNSEncoderMiekg{}
			query := encoder.Encode("dns.google.", dns.TypeA, false)
			zeroTime := time.Now()
			deterministicTime := testingx.NewTimeDeterministic(zeroTime)
			// Addresses we expect the delayed response to carry.
			expectedAddrs := []string{"8.8.8.8"}
			respChannel := make(chan *model.DNSResponse, 8)
			mu := new(sync.Mutex)
			tx := &mocks.Trace{
				MockTimeNow: deterministicTime.Now,
				MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
					query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
					finished time.Time) error {
					mu.Lock()
					delayedDNSResponseCalled = true
					goodQueryType = (query.Type() == dns.TypeA)
					goodTransportNetwork = (txp.Network() == "udp")
					goodTransportAddress = (txp.Address() == expectedAddress)
					goodLookupAddrs = (cmp.Diff(expectedAddrs, addrs) == "")
					goodError = (err == nil)
					mu.Unlock()
					// Non-blocking send: notify the test body without ever
					// stalling the goroutine that invokes this callback.
					select {
					case respChannel <- &response:
						return nil
					default:
						return errors.New("full buffer")
					}
				},
				MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
					finished time.Time) {
					// do nothing
				},
				MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
					return conn
				},
			}
			ctx := ContextWithTrace(context.Background(), tx)
			rch, err := txp.RoundTrip(ctx, query)
			if err != nil {
				// Fail before waiting: when the round trip fails, the trace
				// callback may never run and the receive below would then
				// block until the global test timeout.
				t.Fatal(err)
			}
			<-respChannel // wait for the delayed response
			addrs, err := rch.DecodeLookupHost()
			if err != nil {
				t.Fatal(err)
			}
			mu.Lock()
			defer mu.Unlock()
			// The first (non-delayed) response is expected to be 127.0.0.1.
			if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
				t.Fatal(diff)
			}
			if !delayedDNSResponseCalled {
				t.Fatal("delayedDNSResponse not called")
			}
			if !goodQueryType {
				t.Fatal("unexpected query type")
			}
			if !goodTransportNetwork {
				t.Fatal("unexpected DNS transport network")
			}
			if !goodTransportAddress {
				t.Fatal("unexpected DNS Transport address")
			}
			if !goodLookupAddrs {
				t.Fatal("unexpected delayed DNSLookup address")
			}
			if !goodError {
				t.Fatal("unexpected error encountered")
			}
		})

		t.Run("uses a context-injected custom trace (failure case)", func(t *testing.T) {
			// Flags written by the trace callback (which runs in another
			// goroutine) and read by the test body; all guarded by mu.
			var (
				delayedDNSResponseCalled bool
				goodQueryType            bool
				goodTransportNetwork     bool
				goodTransportAddress     bool
				goodLookupAddrs          bool
				goodError                bool
			)
			srvr := &filtering.DNSServer{
				OnQuery: func(domain string) filtering.DNSAction {
					return filtering.DNSActionLocalHostPlusCache
				},
				Cache: map[string][]string{
					// Note: the cache here is empty so we should
					// get a "no such host" error from the server.
				},
			}
			listener, err := srvr.Start("127.0.0.1:0")
			if err != nil {
				t.Fatal(err)
			}
			defer listener.Close()
			dialer := NewDialerWithoutResolver(model.DiscardLogger)
			expectedAddress := listener.LocalAddr().String()
			txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress)
			encoder := &DNSEncoderMiekg{}
			query := encoder.Encode("dns.google.", dns.TypeA, false)
			zeroTime := time.Now()
			deterministicTime := testingx.NewTimeDeterministic(zeroTime)
			respChannel := make(chan *model.DNSResponse, 8)
			mu := new(sync.Mutex)
			tx := &mocks.Trace{
				MockTimeNow: deterministicTime.Now,
				MockOnDelayedDNSResponse: func(started time.Time, txp model.DNSTransport,
					query model.DNSQuery, response model.DNSResponse, addrs []string, err error,
					finished time.Time) error {
					mu.Lock()
					delayedDNSResponseCalled = true
					goodQueryType = (query.Type() == dns.TypeA)
					goodTransportNetwork = (txp.Network() == "udp")
					goodTransportAddress = (txp.Address() == expectedAddress)
					goodLookupAddrs = (len(addrs) == 0)
					goodError = errors.Is(err, ErrOODNSNoSuchHost)
					mu.Unlock()
					respChannel <- &response
					return errors.New("mocked") // return error to stop background routine to record responses
				},
				MockOnConnectDone: func(started time.Time, network, domain, remoteAddr string, err error,
					finished time.Time) {
					// do nothing
				},
				MockMaybeWrapNetConn: func(conn net.Conn) net.Conn {
					return conn
				},
			}
			ctx := ContextWithTrace(context.Background(), tx)
			rch, err := txp.RoundTrip(ctx, query)
			if err != nil {
				// Fail before waiting: when the round trip fails, the trace
				// callback may never run and the receive below would then
				// block until the global test timeout.
				t.Fatal(err)
			}
			<-respChannel // wait for the delayed response
			addrs, err := rch.DecodeLookupHost()
			if err != nil {
				t.Fatal(err)
			}
			mu.Lock()
			defer mu.Unlock()
			// The first (non-delayed) response is expected to be 127.0.0.1.
			if diff := cmp.Diff(addrs, []string{"127.0.0.1"}); diff != "" {
				t.Fatal(diff)
			}
			if !delayedDNSResponseCalled {
				t.Fatal("delayedDNSResponse not called")
			}
			if !goodQueryType {
				t.Fatal("unexpected query type")
			}
			if !goodTransportNetwork {
				t.Fatal("unexpected DNS transport network")
			}
			if !goodTransportAddress {
				t.Fatal("unexpected DNS Transport address")
			}
			if !goodLookupAddrs {
				t.Fatal("unexpected delayed DNSLookup address")
			}
			if !goodError {
				t.Fatal("unexpected error encountered")
			}
		})
	})

	t.Run("CloseIdleConnections", func(t *testing.T) {
		// The transport must forward the call to the underlying dialer.
		var called bool
		dialer := &mocks.Dialer{
			MockCloseIdleConnections: func() {
				called = true
			},
		}
		const address = "9.9.9.9:53"
		txp := NewUnwrappedDNSOverUDPTransport(dialer, address)
		txp.CloseIdleConnections()
		if !called {
			t.Fatal("not called")
		}
	})

	t.Run("other functions okay", func(t *testing.T) {
		const address = "9.9.9.9:53"
		txp := NewUnwrappedDNSOverUDPTransport(NewDialerWithoutResolver(log.Log), address)
		if txp.RequiresPadding() != false {
			t.Fatal("invalid RequiresPadding")
		}
		if txp.Network() != "udp" {
			t.Fatal("invalid Network")
		}
		if txp.Address() != address {
			t.Fatal("invalid Address")
		}
	})
}