package urlgetter_test

import (
	"context"
	"crypto/x509"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/apex/log"
	"github.com/ooni/probe-cli/v3/internal/engine/experiment/urlgetter"
	"github.com/ooni/probe-cli/v3/internal/engine/mockable"
	"github.com/ooni/probe-cli/v3/internal/model"
)

// TestMultiIntegration runs Multi.Collect against several real websites,
// using HEAD requests that do not follow redirects. For every result it
// checks that:
//   - the result refers to one of the submitted targets;
//   - the method was preserved;
//   - no error occurred;
//   - the expected number of queries, requests, TCP connects and TLS
//     handshakes was recorded.
//
// It also checks that exactly one result per input was emitted.
//
// The test needs network access, so it is skipped in short mode.
func TestMultiIntegration(t *testing.T) {
	if testing.Short() {
		t.Skip("skip test in short mode")
	}
	multi := urlgetter.Multi{Session: &mockable.Session{}}
	inputs := []urlgetter.MultiInput{{
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.google.com",
	}, {
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.facebook.com",
	}, {
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.kernel.org",
	}, {
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.instagram.com",
	}}
	outputs := multi.Collect(context.Background(), inputs, "integration-test",
		model.NewPrinterCallbacks(log.Log))
	var count int
	for result := range outputs {
		count++
		switch result.Input.Target {
		case "https://www.google.com":
		case "https://www.facebook.com":
		case "https://www.kernel.org":
		case "https://www.instagram.com":
		default:
			t.Fatalf("unexpected Input.Target: %q", result.Input.Target)
		}
		if result.Input.Config.Method != "HEAD" {
			t.Fatalf("unexpected Input.Config.Method: %q", result.Input.Config.Method)
		}
		if result.Err != nil {
			t.Fatal(result.Err)
		}
		if result.TestKeys.Agent != "agent" {
			t.Fatalf("invalid TestKeys.Agent: %q", result.TestKeys.Agent)
		}
		if n := len(result.TestKeys.Queries); n != 2 {
			t.Fatalf("invalid number of Queries: %d", n)
		}
		if n := len(result.TestKeys.Requests); n != 1 {
			t.Fatalf("invalid number of Requests: %d", n)
		}
		if n := len(result.TestKeys.TCPConnect); n != 1 {
			t.Fatalf("invalid number of TCPConnects: %d", n)
		}
		if n := len(result.TestKeys.TLSHandshakes); n != 1 {
			t.Fatalf("invalid number of TLSHandshakes: %d", n)
		}
	}
	if count != 4 {
		t.Fatalf("invalid number of outputs: %d", count)
	}
}

// TestMultiIntegrationWithBaseTime checks that Multi.Begin is used as the
// zero point for the T fields of the collected measurement entries.
//
// It sets a beginning of time that is significantly in the past and then
// fails if any network event, query, TCP connect or TLS handshake has a T
// smaller than 3600 seconds. It also requires that at least one entry was
// inspected, so the test cannot pass vacuously.
//
// The test needs network access, so it is skipped in short mode.
func TestMultiIntegrationWithBaseTime(t *testing.T) {
	if testing.Short() {
		t.Skip("skip test in short mode")
	}
	begin := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
	multi := urlgetter.Multi{
		Begin:   begin,
		Session: &mockable.Session{},
	}
	inputs := []urlgetter.MultiInput{{
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.google.com",
	}, {
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.instagram.com",
	}}
	outputs := multi.Collect(context.Background(), inputs, "integration-test",
		model.NewPrinterCallbacks(log.Log))
	var count int
	for result := range outputs {
		for _, entry := range result.TestKeys.NetworkEvents {
			if entry.T < 3600 {
				t.Fatalf("base time not correctly set: NetworkEvents T=%v", entry.T)
			}
			count++
		}
		for _, entry := range result.TestKeys.Queries {
			if entry.T < 3600 {
				t.Fatalf("base time not correctly set: Queries T=%v", entry.T)
			}
			count++
		}
		for _, entry := range result.TestKeys.TCPConnect {
			if entry.T < 3600 {
				t.Fatalf("base time not correctly set: TCPConnect T=%v", entry.T)
			}
			count++
		}
		for _, entry := range result.TestKeys.TLSHandshakes {
			if entry.T < 3600 {
				t.Fatalf("base time not correctly set: TLSHandshakes T=%v", entry.T)
			}
			count++
		}
	}
	if count <= 0 {
		t.Fatalf("unexpected number of entries processed: %d", count)
	}
}

// TestMultiIntegrationWithoutBaseTime checks that, when Multi.Begin is left
// unset, the T fields of the collected measurement entries stay small.
//
// It uses the default beginning of time and then fails if any network
// event, query, TCP connect or TLS handshake has a T LARGER than 60 seconds.
// It also requires that at least one entry was inspected, so the test
// cannot pass vacuously.
//
// The test needs network access, so it is skipped in short mode.
func TestMultiIntegrationWithoutBaseTime(t *testing.T) {
	if testing.Short() {
		t.Skip("skip test in short mode")
	}
	multi := urlgetter.Multi{Session: &mockable.Session{}}
	inputs := []urlgetter.MultiInput{{
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.google.com",
	}, {
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.instagram.com",
	}}
	outputs := multi.Collect(context.Background(), inputs, "integration-test",
		model.NewPrinterCallbacks(log.Log))
	var count int
	for result := range outputs {
		for _, entry := range result.TestKeys.NetworkEvents {
			if entry.T > 60 {
				t.Fatalf("base time not correctly set: NetworkEvents T=%v", entry.T)
			}
			count++
		}
		for _, entry := range result.TestKeys.Queries {
			if entry.T > 60 {
				t.Fatalf("base time not correctly set: Queries T=%v", entry.T)
			}
			count++
		}
		for _, entry := range result.TestKeys.TCPConnect {
			if entry.T > 60 {
				t.Fatalf("base time not correctly set: TCPConnect T=%v", entry.T)
			}
			count++
		}
		for _, entry := range result.TestKeys.TLSHandshakes {
			if entry.T > 60 {
				t.Fatalf("base time not correctly set: TLSHandshakes T=%v", entry.T)
			}
			count++
		}
	}
	if count <= 0 {
		t.Fatalf("unexpected number of entries processed: %d", count)
	}
}

// TestMultiContextCanceled calls Multi.Collect with a context that is
// already canceled. Each of the four inputs must still yield exactly one
// result, and each result must:
//   - refer to a submitted target and preserve the method;
//   - carry an error matching context.Canceled;
//   - record one request and no queries, TCP connects or TLS handshakes.
func TestMultiContextCanceled(t *testing.T) {
	multi := urlgetter.Multi{Session: &mockable.Session{}}
	inputs := []urlgetter.MultiInput{{
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.google.com",
	}, {
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.facebook.com",
	}, {
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.kernel.org",
	}, {
		Config: urlgetter.Config{Method: "HEAD", NoFollowRedirects: true},
		Target: "https://www.instagram.com",
	}}
	// Cancel before collecting so that every measurement observes a
	// canceled context from the very start.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	outputs := multi.Collect(ctx, inputs, "integration-test",
		model.NewPrinterCallbacks(log.Log))
	var count int
	for result := range outputs {
		count++
		switch result.Input.Target {
		case "https://www.google.com":
		case "https://www.facebook.com":
		case "https://www.kernel.org":
		case "https://www.instagram.com":
		default:
			t.Fatalf("unexpected Input.Target: %q", result.Input.Target)
		}
		if result.Input.Config.Method != "HEAD" {
			t.Fatalf("unexpected Input.Config.Method: %q", result.Input.Config.Method)
		}
		if !errors.Is(result.Err, context.Canceled) {
			t.Fatalf("unexpected error: %+v", result.Err)
		}
		if result.TestKeys.Agent != "agent" {
			t.Fatalf("invalid TestKeys.Agent: %q", result.TestKeys.Agent)
		}
		if n := len(result.TestKeys.Queries); n != 0 {
			t.Fatalf("invalid number of Queries: %d", n)
		}
		if n := len(result.TestKeys.Requests); n != 1 {
			t.Fatalf("invalid number of Requests: %d", n)
		}
		if n := len(result.TestKeys.TCPConnect); n != 0 {
			t.Fatalf("invalid number of TCPConnects: %d", n)
		}
		if n := len(result.TestKeys.TLSHandshakes); n != 0 {
			t.Fatalf("invalid number of TLSHandshakes: %d", n)
		}
	}
	if count != 4 {
		t.Fatalf("invalid number of outputs: %d", count)
	}
}

// TestMultiWithSpecificCertPool starts a local httptest TLS server and builds
// a cert pool holding just that server's certificate. It passes the pool via
// Config.CertPool and expects Multi.Collect to perform a GET against the
// server and produce exactly one result, with no error.
func TestMultiWithSpecificCertPool(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "Hello, client")
	})
	srv := httptest.NewTLSServer(handler)
	defer srv.Close()

	// The pool starts empty; add only the test server's certificate.
	pool := x509.NewCertPool()
	pool.AddCert(srv.Certificate())

	cfg := urlgetter.Config{
		CertPool:          pool,
		Method:            "GET",
		NoFollowRedirects: true,
	}
	multi := urlgetter.Multi{Session: &mockable.Session{}}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	results := multi.Collect(
		ctx,
		[]urlgetter.MultiInput{{Config: cfg, Target: srv.URL}},
		"integration-test",
		model.NewPrinterCallbacks(log.Log),
	)

	seen := 0
	for res := range results {
		seen++
		if res.Err != nil {
			t.Fatal(res.Err)
		}
	}
	if seen != 1 {
		t.Fatal("unexpected count value")
	}
}