package ooapi

import (
	"encoding/json"
	"net/http"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/ooni/probe-cli/v3/internal/atomicx"
	"github.com/ooni/probe-cli/v3/internal/netxlite"
	"github.com/ooni/probe-cli/v3/internal/ooapi/apimodel"
)

// LoginHandler is an http.Handler to test login.
//
// It emulates the register, login, psiphon-config and tor-targets API
// endpoints (see ServeHTTP for the exact URL paths).
//
// Fields:
//   - failCallWith: queue of HTTP status codes; each psiphon or tor
//     request pops the first code and fails with it before checking
//     authentication. Login and register never look at it.
//   - mu: guards failCallWith, noRegister and state.
//   - noRegister: when true, register replies with a 500.
//   - state: one entry per registered client, holding its credentials
//     and, after a successful login, its token and token expiry.
//   - t: used to log what the handler does.
//   - logins, registers: optional counters of login and register
//     requests; ServeHTTP skips counting when they are nil. They are
//     updated without holding mu.
type LoginHandler struct {
	failCallWith []int // ignored by login and register
	mu           sync.Mutex
	noRegister   bool
	state        []*loginState
	t            *testing.T
	logins       *atomicx.Int64
	registers    *atomicx.Int64
}

// forgetLogins drops every registered client. Afterwards no login
// request can match any stored credentials, so login replies with a 401.
func (lh *LoginHandler) forgetLogins() {
	lh.mu.Lock()
	lh.state = nil
	lh.mu.Unlock()
}

// forgetTokens marks every issued token as expired by setting its expiry
// to one hour before now. Registrations are kept, so clients can still log
// in again to get fresh tokens; until they do, psiphon and tor reply 401.
func (lh *LoginHandler) forgetTokens() {
	lh.mu.Lock()
	defer lh.mu.Unlock()
	for idx := range lh.state {
		// An hour in the past is a wide enough margin to expire the token
		// no matter whether the client clock or the server clock is off
		// (thanks Galileo for explaining this to us <3).
		lh.state[idx].Expire = time.Now().Add(-time.Hour)
	}
}

// ServeHTTP routes each request by URL path to the emulated endpoint.
// Register and login requests are counted when the matching counter is
// non-nil. Any unknown path gets a 500.
func (lh *LoginHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The HTTP method is deliberately not validated: that would only make
	// this test handler more complex, and method handling is already tested.
	path := r.URL.Path
	switch {
	case path == "/api/v1/register":
		if counter := lh.registers; counter != nil {
			counter.Add(1)
		}
		lh.register(w, r)
	case path == "/api/v1/login":
		if counter := lh.logins; counter != nil {
			counter.Add(1)
		}
		lh.login(w, r)
	case path == "/api/v1/test-list/psiphon-config":
		lh.psiphon(w, r)
	case path == "/api/v1/test-list/tor-targets":
		lh.tor(w, r)
	default:
		w.WriteHeader(http.StatusInternalServerError)
	}
}

// register emulates the register endpoint.
//
// The body must be a JSON apimodel.RegisterRequest with a non-empty
// password; otherwise the reply is a 400. When noRegister is set the reply
// is a 500. Otherwise it does three things: it fills a RegisterResponse
// using fakeFill, records a new entry pairing the response's ClientID with
// the request's password, and writes the response as JSON.
func (lh *LoginHandler) register(w http.ResponseWriter, r *http.Request) {
	if r.Body == nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	body, err := netxlite.ReadAllContext(r.Context(), r.Body)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	var req apimodel.RegisterRequest
	if json.Unmarshal(body, &req) != nil || req.Password == "" {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	lh.mu.Lock()
	defer lh.mu.Unlock()
	if lh.noRegister {
		// Registration was disabled on purpose: simulate a server fault.
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	var resp apimodel.RegisterResponse
	(&fakeFill{}).Fill(&resp)
	entry := &loginState{ClientID: resp.ClientID, Password: req.Password}
	lh.state = append(lh.state, entry)
	out, err := json.Marshal(&resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	lh.t.Logf("register: %+v", string(out))
	w.Write(out)
}

// login emulates the login endpoint.
//
// The body must be a JSON apimodel.LoginRequest; otherwise the reply is a
// 400. The request's ClientID and Password are looked up among the
// registered clients, and a 401 is returned if there is no match. On a
// match, login builds a LoginResponse using fakeFill and sets its expiry to
// one hour from now. It stores the token and expiry on the matching entry
// and writes the response as JSON.
func (lh *LoginHandler) login(w http.ResponseWriter, r *http.Request) {
	if r.Body == nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	body, err := netxlite.ReadAllContext(r.Context(), r.Body)
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	var req apimodel.LoginRequest
	if err := json.Unmarshal(body, &req); err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	lh.mu.Lock()
	defer lh.mu.Unlock()
	match := -1
	for idx, entry := range lh.state {
		if entry.ClientID == req.ClientID && entry.Password == req.Password {
			match = idx
			break
		}
	}
	if match < 0 {
		lh.t.Log("login: 401")
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	var resp apimodel.LoginResponse
	(&fakeFill{}).Fill(&resp)
	// fakeFill only sets the token expiry to now plus a small delta, but we
	// want the token to stay valid for a long time.
	resp.Expire = time.Now().Add(time.Hour)
	entry := lh.state[match]
	entry.Expire, entry.Token = resp.Expire, resp.Token
	out, err := json.Marshal(&resp)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	lh.t.Logf("login: %+v", string(out))
	w.Write(out)
}

// psiphon emulates the psiphon-config endpoint.
//
// If failCallWith is non-empty, its first status code is consumed and sent
// back without any further processing. Otherwise the request must carry an
// "Authorization: Bearer <token>" header whose token matches an unexpired
// login. On success the reply is a JSON PsiphonConfigResponse filled using
// fakeFill; otherwise the reply is a 401.
func (lh *LoginHandler) psiphon(w http.ResponseWriter, r *http.Request) {
	lh.mu.Lock()
	defer lh.mu.Unlock()
	if len(lh.failCallWith) > 0 {
		code := lh.failCallWith[0]
		lh.failCallWith = lh.failCallWith[1:]
		w.WriteHeader(code)
		return
	}
	// Strip the scheme only when it is really the prefix. The previous
	// strings.Replace call would also delete a "Bearer " found later in
	// the header value.
	token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
	for _, s := range lh.state {
		if token == s.Token && time.Now().Before(s.Expire) {
			var resp apimodel.PsiphonConfigResponse
			ff := &fakeFill{}
			ff.Fill(&resp)
			data, err := json.Marshal(&resp)
			if err != nil {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			lh.t.Logf("psiphon: %+v", string(data))
			w.Write(data)
			return
		}
	}
	lh.t.Log("psiphon: 401")
	w.WriteHeader(http.StatusUnauthorized)
}

// tor emulates the tor-targets endpoint.
//
// If failCallWith is non-empty, its first status code is consumed and sent
// back without any further processing. Otherwise the request must carry an
// "Authorization: Bearer <token>" header whose token matches an unexpired
// login. On success the reply is a JSON TorTargetsResponse filled using
// fakeFill; otherwise the reply is a 401.
func (lh *LoginHandler) tor(w http.ResponseWriter, r *http.Request) {
	lh.mu.Lock()
	defer lh.mu.Unlock()
	if len(lh.failCallWith) > 0 {
		code := lh.failCallWith[0]
		lh.failCallWith = lh.failCallWith[1:]
		w.WriteHeader(code)
		return
	}
	// Strip the scheme only when it is really the prefix. The previous
	// strings.Replace call would also delete a "Bearer " found later in
	// the header value.
	token := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
	for _, s := range lh.state {
		if token == s.Token && time.Now().Before(s.Expire) {
			var resp apimodel.TorTargetsResponse
			ff := &fakeFill{}
			ff.Fill(&resp)
			data, err := json.Marshal(&resp)
			if err != nil {
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			lh.t.Logf("tor: %+v", string(data))
			w.Write(data)
			return
		}
	}
	lh.t.Log("tor: 401")
	w.WriteHeader(http.StatusUnauthorized)
}