package badproxy import ( "crypto/rsa" "crypto/tls" "crypto/x509" "errors" "net" "testing" "time" "github.com/google/martian/v3/mitm" ) func TestCleartext(t *testing.T) { listener := newproxy(t) checkdial(t, listener.Addr().String(), nil, net.Dial) killproxy(t, listener) } func TestTLS(t *testing.T) { listener := newproxytls(t) checkdial(t, listener.Addr().String(), nil, func(network, address string) (net.Conn, error) { conn, err := tls.Dial(network, address, &tls.Config{ InsecureSkipVerify: true, ServerName: "antani.local", }) if err != nil { return nil, err } if err = conn.Handshake(); err != nil { conn.Close() return nil, err } return conn, nil }) killproxy(t, listener) } func TestListenError(t *testing.T) { proxy := NewCensoringProxy() listener, err := proxy.Start("8.8.8.8:80") if err == nil { t.Fatal("expected an error here") } if listener != nil { t.Fatal("expected nil listener here") } } func TestStarTLS(t *testing.T) { expected := errors.New("mocked error") t.Run("when we cannot create a new authority", func(t *testing.T) { proxy := NewCensoringProxy() proxy.mitmNewAuthority = func( name string, organization string, validity time.Duration, ) (*x509.Certificate, *rsa.PrivateKey, error) { return nil, nil, expected } cert, privkey, err := proxy.StartTLS("127.0.0.1:0") if !errors.Is(err, expected) { t.Fatal("not the error we expected") } if cert != nil { t.Fatal("expected nil cert") } if privkey != nil { t.Fatal("expected nil privkey") } }) t.Run("when we cannot create a new config", func(t *testing.T) { proxy := NewCensoringProxy() proxy.mitmNewConfig = func( ca *x509.Certificate, privateKey interface{}, ) (*mitm.Config, error) { return nil, expected } cert, privkey, err := proxy.StartTLS("127.0.0.1:0") if !errors.Is(err, expected) { t.Fatal("not the error we expected") } if cert != nil { t.Fatal("expected nil cert") } if privkey != nil { t.Fatal("expected nil privkey") } }) t.Run("when we cannot listen", func(t *testing.T) { proxy := NewCensoringProxy() proxy.tlsListen = func( network string, laddr string, config *tls.Config, ) (net.Listener, error) { return nil, expected } cert, privkey, err := proxy.StartTLS("127.0.0.1:0") if !errors.Is(err, expected) { t.Fatal("not the error we expected") } if cert != nil { t.Fatal("expected nil cert") } if privkey != nil { t.Fatal("expected nil privkey") } }) } func newproxy(t *testing.T) net.Listener { proxy := NewCensoringProxy() listener, err := proxy.Start("127.0.0.1:0") if err != nil { t.Fatal(err) } return listener } func newproxytls(t *testing.T) net.Listener { proxy := NewCensoringProxy() listener, _, err := proxy.StartTLS("127.0.0.1:0") if err != nil { t.Fatal(err) } return listener } func killproxy(t *testing.T, listener net.Listener) { err := listener.Close() if err != nil { t.Fatal(err) } } func checkdial( t *testing.T, proxyAddr string, expectErr error, dial func(network, address string) (net.Conn, error), ) { conn, err := dial("tcp", proxyAddr) if err != expectErr { t.Fatal("not the result we expected") } if conn == nil && expectErr == nil { t.Fatal("expected actionable conn") } if conn != nil && expectErr != nil { t.Fatal("expected nil conn") } if conn != nil { conn.Write([]byte("123454321")) conn.Close() } }