package resolver_test

import (
	"bytes"
	"context"
	"errors"
	"io"
	"net/http"
	"strings"
	"testing"

	"github.com/ooni/probe-cli/v3/internal/engine/httpheader"
	"github.com/ooni/probe-cli/v3/internal/engine/netx/resolver"
)

// TestDNSOverHTTPSNewRequestFailure checks that RoundTrip fails, returning no
// data, when the transport's URL contains an invalid control character.
func TestDNSOverHTTPSNewRequestFailure(t *testing.T) {
	const invalidURL = "\t"
	txp := resolver.NewDNSOverHTTPS(http.DefaultClient, invalidURL)
	data, err := txp.RoundTrip(context.Background(), nil)
	if err == nil || !strings.HasSuffix(err.Error(), "invalid control character in URL") {
		t.Fatalf("expected an invalid control character error, got %v", err)
	}
	if data != nil {
		t.Fatalf("expected no response, got %v", data)
	}
}

// TestDNSOverHTTPSClientDoFailure checks that an error returned by the Do hook
// reaches the caller of RoundTrip (matchable with errors.Is) and that no data
// is returned alongside it.
func TestDNSOverHTTPSClientDoFailure(t *testing.T) {
	expected := errors.New("mocked error")
	txp := resolver.DNSOverHTTPS{
		Do: func(*http.Request) (*http.Response, error) {
			return nil, expected
		},
		URL: "https://cloudflare-dns.com/dns-query",
	}
	data, err := txp.RoundTrip(context.Background(), nil)
	if !errors.Is(err, expected) {
		t.Fatalf("expected %v, got %v", expected, err)
	}
	if data != nil {
		t.Fatalf("expected no response, got %v", data)
	}
}

// TestDNSOverHTTPSHTTPFailure checks that an HTTP 500 reply makes RoundTrip
// fail with "doh: server returned error" and return no data.
func TestDNSOverHTTPSHTTPFailure(t *testing.T) {
	const wantErr = "doh: server returned error"
	txp := resolver.DNSOverHTTPS{
		Do: func(*http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: http.StatusInternalServerError,
				Body:       io.NopCloser(strings.NewReader("")),
			}, nil
		},
		URL: "https://cloudflare-dns.com/dns-query",
	}
	data, err := txp.RoundTrip(context.Background(), nil)
	if err == nil || err.Error() != wantErr {
		t.Fatalf("expected %q error, got %v", wantErr, err)
	}
	if data != nil {
		t.Fatalf("expected no response, got %v", data)
	}
}

// TestDNSOverHTTPSMissingContentType checks that a 200 reply carrying no
// Content-Type header is rejected with "doh: invalid content-type" and that
// no data is returned.
func TestDNSOverHTTPSMissingContentType(t *testing.T) {
	const wantErr = "doh: invalid content-type"
	txp := resolver.DNSOverHTTPS{
		Do: func(*http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(strings.NewReader("")),
			}, nil
		},
		URL: "https://cloudflare-dns.com/dns-query",
	}
	data, err := txp.RoundTrip(context.Background(), nil)
	if err == nil || err.Error() != wantErr {
		t.Fatalf("expected %q error, got %v", wantErr, err)
	}
	if data != nil {
		t.Fatalf("expected no response, got %v", data)
	}
}

// TestDNSOverHTTPSSuccess checks that, given a 200 reply whose Content-Type is
// application/dns-message, RoundTrip returns the response body unchanged.
func TestDNSOverHTTPSSuccess(t *testing.T) {
	body := []byte("AAA")
	txp := resolver.DNSOverHTTPS{
		Do: func(*http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(bytes.NewReader(body)),
				Header: http.Header{
					"Content-Type": []string{"application/dns-message"},
				},
			}, nil
		},
		URL: "https://cloudflare-dns.com/dns-query",
	}
	data, err := txp.RoundTrip(context.Background(), nil)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.Equal(data, body) {
		t.Fatalf("not the response we expected: got %q, want %q", data, body)
	}
}

// TestDNSOverHTTPTransportOK checks the Network, RequiresPadding and Address
// accessors of a transport created with NewDNSOverHTTPS.
func TestDNSOverHTTPTransportOK(t *testing.T) {
	const queryURL = "https://cloudflare-dns.com/dns-query"
	txp := resolver.NewDNSOverHTTPS(http.DefaultClient, queryURL)
	if network := txp.Network(); network != "doh" {
		t.Fatalf("invalid network: got %q, want %q", network, "doh")
	}
	if !txp.RequiresPadding() {
		t.Fatal("should require padding")
	}
	if address := txp.Address(); address != queryURL {
		t.Fatalf("invalid address: got %q, want %q", address, queryURL)
	}
}

// TestDNSOverHTTPSClientSetsUserAgent checks that the request passed to the Do
// hook carries the User-Agent header returned by httpheader.UserAgent.
func TestDNSOverHTTPSClientSetsUserAgent(t *testing.T) {
	expected := errors.New("mocked error")
	var gotUA string
	txp := resolver.DNSOverHTTPS{
		Do: func(req *http.Request) (*http.Response, error) {
			gotUA = req.Header.Get("User-Agent")
			return nil, expected
		},
		URL: "https://cloudflare-dns.com/dns-query",
	}
	data, err := txp.RoundTrip(context.Background(), nil)
	// Only Do returns `expected`, so matching it also proves that Do ran and
	// that gotUA was recorded.
	if !errors.Is(err, expected) {
		t.Fatalf("expected %v, got %v", expected, err)
	}
	if data != nil {
		t.Fatalf("expected no response, got %v", data)
	}
	if wantUA := httpheader.UserAgent(); gotUA != wantUA {
		t.Fatalf("did not see correct user agent: got %q, want %q", gotUA, wantUA)
	}
}

// TestDNSOverHTTPSHostOverride checks that, when HostOverride is set, the
// request passed to the Do hook has its Host field set to that value.
func TestDNSOverHTTPSHostOverride(t *testing.T) {
	expected := errors.New("mocked error")
	const hostOverride = "test.com"
	var gotHost string
	txp := resolver.DNSOverHTTPS{
		Do: func(req *http.Request) (*http.Response, error) {
			gotHost = req.Host
			return nil, expected
		},
		URL:          "https://cloudflare-dns.com/dns-query",
		HostOverride: hostOverride,
	}
	data, err := txp.RoundTrip(context.Background(), nil)
	// Only Do returns `expected`, so matching it also proves that Do ran and
	// that gotHost was recorded.
	if !errors.Is(err, expected) {
		t.Fatalf("expected %v, got %v", expected, err)
	}
	if data != nil {
		t.Fatalf("expected no response, got %v", data)
	}
	if gotHost != hostOverride {
		t.Fatalf("did not see correct host override: got %q, want %q", gotHost, hostOverride)
	}
}