chore: improve testing and increase coverage (#794)
This diff improves testing and increases coverage inside the ./internal/netxlite and ./internal/tracex packages. See https://github.com/ooni/probe/issues/2121
This commit is contained in:
		
							parent
							
								
									464d03184e
								
							
						
					
					
						commit
						d5249a6cf7
					
				| @ -207,23 +207,6 @@ func (state *getaddrinfoState) addrinfoToString(r *C.struct_addrinfo) (string, e | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // staticAddrinfoWithInvalidFamily is an helper to construct an addrinfo struct |  | ||||||
| // that we use in testing. (We cannot call CGO directly from tests.) |  | ||||||
| func staticAddrinfoWithInvalidFamily() *C.struct_addrinfo { |  | ||||||
| 	var value C.struct_addrinfo       // zeroed by Go |  | ||||||
| 	value.ai_socktype = C.SOCK_STREAM // this is what the code expects |  | ||||||
| 	value.ai_family = 0               // but 0 is not AF_INET{,6} |  | ||||||
| 	return &value |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // staticAddrinfoWithInvalidSocketType is an helper to construct an addrinfo struct |  | ||||||
| // that we use in testing. (We cannot call CGO directly from tests.) |  | ||||||
| func staticAddrinfoWithInvalidSocketType() *C.struct_addrinfo { |  | ||||||
| 	var value C.struct_addrinfo      // zeroed by Go |  | ||||||
| 	value.ai_socktype = C.SOCK_DGRAM // not SOCK_STREAM |  | ||||||
| 	return &value |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // getaddrinfoCopyIP copies a net.IP. | // getaddrinfoCopyIP copies a net.IP. | ||||||
| // | // | ||||||
| // This function is adapted from copyIP | // This function is adapted from copyIP | ||||||
|  | |||||||
| @ -177,29 +177,29 @@ func NewOOHTTPBaseTransport(dialer model.Dialer, tlsDialer model.TLSDialer) mode | |||||||
| 
 | 
 | ||||||
| 	// Ensure we correctly forward CloseIdleConnections. | 	// Ensure we correctly forward CloseIdleConnections. | ||||||
| 	return &httpTransportConnectionsCloser{ | 	return &httpTransportConnectionsCloser{ | ||||||
| 		HTTPTransport: &stdlibTransport{&oohttp.StdlibTransport{Transport: txp}}, | 		HTTPTransport: &httpTransportStdlib{&oohttp.StdlibTransport{Transport: txp}}, | ||||||
| 		Dialer:        dialer, | 		Dialer:        dialer, | ||||||
| 		TLSDialer:     tlsDialer, | 		TLSDialer:     tlsDialer, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // stdlibTransport wraps oohttp.StdlibTransport to add .Network() | // stdlibTransport wraps oohttp.StdlibTransport to add .Network() | ||||||
| type stdlibTransport struct { | type httpTransportStdlib struct { | ||||||
| 	StdlibTransport *oohttp.StdlibTransport | 	StdlibTransport *oohttp.StdlibTransport | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| var _ model.HTTPTransport = &stdlibTransport{} | var _ model.HTTPTransport = &httpTransportStdlib{} | ||||||
| 
 | 
 | ||||||
| func (txp *stdlibTransport) CloseIdleConnections() { | func (txp *httpTransportStdlib) CloseIdleConnections() { | ||||||
| 	txp.StdlibTransport.CloseIdleConnections() | 	txp.StdlibTransport.CloseIdleConnections() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (txp *stdlibTransport) RoundTrip(req *http.Request) (*http.Response, error) { | func (txp *httpTransportStdlib) RoundTrip(req *http.Request) (*http.Response, error) { | ||||||
| 	return txp.StdlibTransport.RoundTrip(req) | 	return txp.StdlibTransport.RoundTrip(req) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Network implements HTTPTransport.Network. | // Network implements HTTPTransport.Network. | ||||||
| func (txp *stdlibTransport) Network() string { | func (txp *httpTransportStdlib) Network() string { | ||||||
| 	return "tcp" | 	return "tcp" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -272,7 +272,7 @@ func TestNewHTTPTransport(t *testing.T) { | |||||||
| 		if tlsWithReadTimeout.TLSDialer != td { | 		if tlsWithReadTimeout.TLSDialer != td { | ||||||
| 			t.Fatal("invalid tls dialer") | 			t.Fatal("invalid tls dialer") | ||||||
| 		} | 		} | ||||||
| 		stdlib := connectionsCloser.HTTPTransport.(*stdlibTransport) | 		stdlib := connectionsCloser.HTTPTransport.(*httpTransportStdlib) | ||||||
| 		if !stdlib.StdlibTransport.ForceAttemptHTTP2 { | 		if !stdlib.StdlibTransport.ForceAttemptHTTP2 { | ||||||
| 			t.Fatal("invalid ForceAttemptHTTP2") | 			t.Fatal("invalid ForceAttemptHTTP2") | ||||||
| 		} | 		} | ||||||
| @ -292,83 +292,13 @@ func TestNewHTTPTransport(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestHTTPDialerWithReadTimeout(t *testing.T) { | func TestHTTPDialerWithReadTimeout(t *testing.T) { | ||||||
| 	t.Run("on success", func(t *testing.T) { | 	t.Run("DialContext", func(t *testing.T) { | ||||||
| 		var ( | 		t.Run("on success", func(t *testing.T) { | ||||||
| 			calledWithZeroTime    bool | 			var ( | ||||||
| 			calledWithNonZeroTime bool | 				calledWithZeroTime    bool | ||||||
| 		) | 				calledWithNonZeroTime bool | ||||||
| 		origConn := &mocks.Conn{ | 			) | ||||||
| 			MockSetReadDeadline: func(t time.Time) error { | 			origConn := &mocks.Conn{ | ||||||
| 				switch t.IsZero() { |  | ||||||
| 				case true: |  | ||||||
| 					calledWithZeroTime = true |  | ||||||
| 				case false: |  | ||||||
| 					calledWithNonZeroTime = true |  | ||||||
| 				} |  | ||||||
| 				return nil |  | ||||||
| 			}, |  | ||||||
| 			MockRead: func(b []byte) (int, error) { |  | ||||||
| 				return 0, io.EOF |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 		d := &httpDialerWithReadTimeout{ |  | ||||||
| 			Dialer: &mocks.Dialer{ |  | ||||||
| 				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { |  | ||||||
| 					return origConn, nil |  | ||||||
| 				}, |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 		ctx := context.Background() |  | ||||||
| 		conn, err := d.DialContext(ctx, "", "") |  | ||||||
| 		if err != nil { |  | ||||||
| 			t.Fatal(err) |  | ||||||
| 		} |  | ||||||
| 		if _, okay := conn.(*httpConnWithReadTimeout); !okay { |  | ||||||
| 			t.Fatal("invalid conn type") |  | ||||||
| 		} |  | ||||||
| 		if conn.(*httpConnWithReadTimeout).Conn != origConn { |  | ||||||
| 			t.Fatal("invalid origin conn") |  | ||||||
| 		} |  | ||||||
| 		b := make([]byte, 1024) |  | ||||||
| 		count, err := conn.Read(b) |  | ||||||
| 		if !errors.Is(err, io.EOF) { |  | ||||||
| 			t.Fatal("invalid error") |  | ||||||
| 		} |  | ||||||
| 		if count != 0 { |  | ||||||
| 			t.Fatal("invalid count") |  | ||||||
| 		} |  | ||||||
| 		if !calledWithZeroTime || !calledWithNonZeroTime { |  | ||||||
| 			t.Fatal("not called") |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| 
 |  | ||||||
| 	t.Run("on failure", func(t *testing.T) { |  | ||||||
| 		expected := errors.New("mocked error") |  | ||||||
| 		d := &httpDialerWithReadTimeout{ |  | ||||||
| 			Dialer: &mocks.Dialer{ |  | ||||||
| 				MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { |  | ||||||
| 					return nil, expected |  | ||||||
| 				}, |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 		conn, err := d.DialContext(context.Background(), "", "") |  | ||||||
| 		if !errors.Is(err, expected) { |  | ||||||
| 			t.Fatal("not the error we expected") |  | ||||||
| 		} |  | ||||||
| 		if conn != nil { |  | ||||||
| 			t.Fatal("expected nil conn here") |  | ||||||
| 		} |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { |  | ||||||
| 	t.Run("on success", func(t *testing.T) { |  | ||||||
| 		var ( |  | ||||||
| 			calledWithZeroTime    bool |  | ||||||
| 			calledWithNonZeroTime bool |  | ||||||
| 		) |  | ||||||
| 		origConn := &mocks.TLSConn{ |  | ||||||
| 			Conn: mocks.Conn{ |  | ||||||
| 				MockSetReadDeadline: func(t time.Time) error { | 				MockSetReadDeadline: func(t time.Time) error { | ||||||
| 					switch t.IsZero() { | 					switch t.IsZero() { | ||||||
| 					case true: | 					case true: | ||||||
| @ -381,81 +311,155 @@ func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { | |||||||
| 				MockRead: func(b []byte) (int, error) { | 				MockRead: func(b []byte) (int, error) { | ||||||
| 					return 0, io.EOF | 					return 0, io.EOF | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			} | ||||||
| 		} | 			d := &httpDialerWithReadTimeout{ | ||||||
| 		d := &httpTLSDialerWithReadTimeout{ | 				Dialer: &mocks.Dialer{ | ||||||
| 			TLSDialer: &mocks.TLSDialer{ | 					MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { | ||||||
| 				MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { | 						return origConn, nil | ||||||
| 					return origConn, nil | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			} | ||||||
| 		} | 			ctx := context.Background() | ||||||
| 		ctx := context.Background() | 			conn, err := d.DialContext(ctx, "", "") | ||||||
| 		conn, err := d.DialTLSContext(ctx, "", "") | 			if err != nil { | ||||||
| 		if err != nil { | 				t.Fatal(err) | ||||||
| 			t.Fatal(err) | 			} | ||||||
| 		} | 			if _, okay := conn.(*httpConnWithReadTimeout); !okay { | ||||||
| 		if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { | 				t.Fatal("invalid conn type") | ||||||
| 			t.Fatal("invalid conn type") | 			} | ||||||
| 		} | 			if conn.(*httpConnWithReadTimeout).Conn != origConn { | ||||||
| 		if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { | 				t.Fatal("invalid origin conn") | ||||||
| 			t.Fatal("invalid origin conn") | 			} | ||||||
| 		} | 			b := make([]byte, 1024) | ||||||
| 		b := make([]byte, 1024) | 			count, err := conn.Read(b) | ||||||
| 		count, err := conn.Read(b) | 			if !errors.Is(err, io.EOF) { | ||||||
| 		if !errors.Is(err, io.EOF) { | 				t.Fatal("invalid error") | ||||||
| 			t.Fatal("invalid error") | 			} | ||||||
| 		} | 			if count != 0 { | ||||||
| 		if count != 0 { | 				t.Fatal("invalid count") | ||||||
| 			t.Fatal("invalid count") | 			} | ||||||
| 		} | 			if !calledWithZeroTime || !calledWithNonZeroTime { | ||||||
| 		if !calledWithZeroTime || !calledWithNonZeroTime { | 				t.Fatal("not called") | ||||||
| 			t.Fatal("not called") | 			} | ||||||
| 		} | 		}) | ||||||
| 	}) |  | ||||||
| 
 | 
 | ||||||
| 	t.Run("on failure", func(t *testing.T) { | 		t.Run("on failure", func(t *testing.T) { | ||||||
| 		expected := errors.New("mocked error") | 			expected := errors.New("mocked error") | ||||||
| 		d := &httpTLSDialerWithReadTimeout{ | 			d := &httpDialerWithReadTimeout{ | ||||||
| 			TLSDialer: &mocks.TLSDialer{ | 				Dialer: &mocks.Dialer{ | ||||||
| 				MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { | 					MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { | ||||||
| 					return nil, expected | 						return nil, expected | ||||||
|  | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			} | ||||||
| 		} | 			conn, err := d.DialContext(context.Background(), "", "") | ||||||
| 		conn, err := d.DialTLSContext(context.Background(), "", "") | 			if !errors.Is(err, expected) { | ||||||
| 		if !errors.Is(err, expected) { | 				t.Fatal("not the error we expected") | ||||||
| 			t.Fatal("not the error we expected") | 			} | ||||||
| 		} | 			if conn != nil { | ||||||
| 		if conn != nil { | 				t.Fatal("expected nil conn here") | ||||||
| 			t.Fatal("expected nil conn here") | 			} | ||||||
| 		} | 		}) | ||||||
| 	}) | 	}) | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| 	t.Run("with invalid conn type", func(t *testing.T) { | func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { | ||||||
| 		var called bool | 	t.Run("DialContext", func(t *testing.T) { | ||||||
| 		d := &httpTLSDialerWithReadTimeout{ | 		t.Run("on success", func(t *testing.T) { | ||||||
| 			TLSDialer: &mocks.TLSDialer{ | 			var ( | ||||||
| 				MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { | 				calledWithZeroTime    bool | ||||||
| 					return &mocks.Conn{ | 				calledWithNonZeroTime bool | ||||||
| 						MockClose: func() error { | 			) | ||||||
| 							called = true | 			origConn := &mocks.TLSConn{ | ||||||
| 							return nil | 				Conn: mocks.Conn{ | ||||||
| 						}, | 					MockSetReadDeadline: func(t time.Time) error { | ||||||
| 					}, nil | 						switch t.IsZero() { | ||||||
|  | 						case true: | ||||||
|  | 							calledWithZeroTime = true | ||||||
|  | 						case false: | ||||||
|  | 							calledWithNonZeroTime = true | ||||||
|  | 						} | ||||||
|  | 						return nil | ||||||
|  | 					}, | ||||||
|  | 					MockRead: func(b []byte) (int, error) { | ||||||
|  | 						return 0, io.EOF | ||||||
|  | 					}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			} | ||||||
| 		} | 			d := &httpTLSDialerWithReadTimeout{ | ||||||
| 		conn, err := d.DialTLSContext(context.Background(), "", "") | 				TLSDialer: &mocks.TLSDialer{ | ||||||
| 		if !errors.Is(err, ErrNotTLSConn) { | 					MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { | ||||||
| 			t.Fatal("not the error we expected") | 						return origConn, nil | ||||||
| 		} | 					}, | ||||||
| 		if conn != nil { | 				}, | ||||||
| 			t.Fatal("expected nil conn here") | 			} | ||||||
| 		} | 			ctx := context.Background() | ||||||
| 		if !called { | 			conn, err := d.DialTLSContext(ctx, "", "") | ||||||
| 			t.Fatal("not called") | 			if err != nil { | ||||||
| 		} | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  | 			if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { | ||||||
|  | 				t.Fatal("invalid conn type") | ||||||
|  | 			} | ||||||
|  | 			if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { | ||||||
|  | 				t.Fatal("invalid origin conn") | ||||||
|  | 			} | ||||||
|  | 			b := make([]byte, 1024) | ||||||
|  | 			count, err := conn.Read(b) | ||||||
|  | 			if !errors.Is(err, io.EOF) { | ||||||
|  | 				t.Fatal("invalid error") | ||||||
|  | 			} | ||||||
|  | 			if count != 0 { | ||||||
|  | 				t.Fatal("invalid count") | ||||||
|  | 			} | ||||||
|  | 			if !calledWithZeroTime || !calledWithNonZeroTime { | ||||||
|  | 				t.Fatal("not called") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		t.Run("on failure", func(t *testing.T) { | ||||||
|  | 			expected := errors.New("mocked error") | ||||||
|  | 			d := &httpTLSDialerWithReadTimeout{ | ||||||
|  | 				TLSDialer: &mocks.TLSDialer{ | ||||||
|  | 					MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { | ||||||
|  | 						return nil, expected | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 			conn, err := d.DialTLSContext(context.Background(), "", "") | ||||||
|  | 			if !errors.Is(err, expected) { | ||||||
|  | 				t.Fatal("not the error we expected") | ||||||
|  | 			} | ||||||
|  | 			if conn != nil { | ||||||
|  | 				t.Fatal("expected nil conn here") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		t.Run("with invalid conn type", func(t *testing.T) { | ||||||
|  | 			var called bool | ||||||
|  | 			d := &httpTLSDialerWithReadTimeout{ | ||||||
|  | 				TLSDialer: &mocks.TLSDialer{ | ||||||
|  | 					MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { | ||||||
|  | 						return &mocks.Conn{ | ||||||
|  | 							MockClose: func() error { | ||||||
|  | 								called = true | ||||||
|  | 								return nil | ||||||
|  | 							}, | ||||||
|  | 						}, nil | ||||||
|  | 					}, | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 			conn, err := d.DialTLSContext(context.Background(), "", "") | ||||||
|  | 			if !errors.Is(err, ErrNotTLSConn) { | ||||||
|  | 				t.Fatal("not the error we expected") | ||||||
|  | 			} | ||||||
|  | 			if conn != nil { | ||||||
|  | 				t.Fatal("expected nil conn here") | ||||||
|  | 			} | ||||||
|  | 			if !called { | ||||||
|  | 				t.Fatal("not called") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -140,7 +140,7 @@ func (d *quicDialerQUICGo) DialContext(ctx context.Context, network string, | |||||||
| 		pconn.Close() // we own it on failure | 		pconn.Close() // we own it on failure | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn}, nil | 	return newQUICConnectionOwnsConn(qconn, pconn), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context, | func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context, | ||||||
| @ -183,6 +183,8 @@ type quicDialerHandshakeCompleter struct { | |||||||
| 	Dialer model.QUICDialer | 	Dialer model.QUICDialer | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | var _ model.QUICDialer = &quicDialerHandshakeCompleter{} | ||||||
|  | 
 | ||||||
| // DialContext implements model.QUICDialer.DialContext. | // DialContext implements model.QUICDialer.DialContext. | ||||||
| func (d *quicDialerHandshakeCompleter) DialContext( | func (d *quicDialerHandshakeCompleter) DialContext( | ||||||
| 	ctx context.Context, network, address string, | 	ctx context.Context, network, address string, | ||||||
| @ -214,6 +216,10 @@ type quicConnectionOwnsConn struct { | |||||||
| 	conn model.UDPLikeConn | 	conn model.UDPLikeConn | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func newQUICConnectionOwnsConn(qconn quic.EarlyConnection, pconn model.UDPLikeConn) *quicConnectionOwnsConn { | ||||||
|  | 	return &quicConnectionOwnsConn{EarlyConnection: qconn, conn: pconn} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // CloseWithError implements quic.EarlyConnection.CloseWithError. | // CloseWithError implements quic.EarlyConnection.CloseWithError. | ||||||
| func (qconn *quicConnectionOwnsConn) CloseWithError( | func (qconn *quicConnectionOwnsConn) CloseWithError( | ||||||
| 	code quic.ApplicationErrorCode, reason string) error { | 	code quic.ApplicationErrorCode, reason string) error { | ||||||
|  | |||||||
| @ -302,6 +302,31 @@ func TestQUICDialerQUICGo(t *testing.T) { | |||||||
| 				t.Fatal("the ServerName field must match") | 				t.Fatal("the ServerName field must match") | ||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
|  | 
 | ||||||
|  | 		t.Run("returns a quicDialerOwnConn in case of success", func(t *testing.T) { | ||||||
|  | 			tlsConfig := &tls.Config{ | ||||||
|  | 				ServerName: "dns.google", | ||||||
|  | 			} | ||||||
|  | 			fakeconn := &mocks.QUICEarlyConnection{} | ||||||
|  | 			systemdialer := quicDialerQUICGo{ | ||||||
|  | 				QUICListener: &quicListenerStdlib{}, | ||||||
|  | 				mockDialEarlyContext: func(ctx context.Context, pconn net.PacketConn, | ||||||
|  | 					remoteAddr net.Addr, host string, tlsConfig *tls.Config, | ||||||
|  | 					quicConfig *quic.Config) (quic.EarlyConnection, error) { | ||||||
|  | 					return fakeconn, nil | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 			ctx := context.Background() | ||||||
|  | 			qconn, err := systemdialer.DialContext( | ||||||
|  | 				ctx, "udp", "8.8.8.8:443", tlsConfig, &quic.Config{}) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  | 			connOwner := qconn.(*quicConnectionOwnsConn) | ||||||
|  | 			if connOwner.EarlyConnection != fakeconn { | ||||||
|  | 				t.Fatal("invalid underlying conn") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -406,6 +431,33 @@ func TestQUICDialerHandshakeCompleter(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestQUICConnectionOwnsConn(t *testing.T) { | ||||||
|  | 	var ( | ||||||
|  | 		quicClose bool | ||||||
|  | 		udpClose  bool | ||||||
|  | 	) | ||||||
|  | 	qconn := &mocks.QUICEarlyConnection{ | ||||||
|  | 		MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error { | ||||||
|  | 			quicClose = true | ||||||
|  | 			return nil | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	pconn := &mocks.UDPLikeConn{ | ||||||
|  | 		MockClose: func() error { | ||||||
|  | 			udpClose = true | ||||||
|  | 			return nil | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	conn := newQUICConnectionOwnsConn(qconn, pconn) | ||||||
|  | 	conn.CloseWithError(0, "") | ||||||
|  | 	if !quicClose { | ||||||
|  | 		t.Fatal("did not call qconn.CloseWithError") | ||||||
|  | 	} | ||||||
|  | 	if !udpClose { | ||||||
|  | 		t.Fatal("did not call pconn.Close") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestQUICDialerResolver(t *testing.T) { | func TestQUICDialerResolver(t *testing.T) { | ||||||
| 	t.Run("CloseIdleConnections", func(t *testing.T) { | 	t.Run("CloseIdleConnections", func(t *testing.T) { | ||||||
| 		var ( | 		var ( | ||||||
| @ -518,6 +570,27 @@ func TestQUICDialerResolver(t *testing.T) { | |||||||
| 				t.Fatal("gotTLSConfig.ServerName has not been set") | 				t.Fatal("gotTLSConfig.ServerName has not been set") | ||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
|  | 
 | ||||||
|  | 		t.Run("on success", func(t *testing.T) { | ||||||
|  | 			expectedQConn := &mocks.QUICEarlyConnection{} | ||||||
|  | 			dialer := &quicDialerResolver{ | ||||||
|  | 				Resolver: NewResolverStdlib(log.Log), | ||||||
|  | 				Dialer: &mocks.QUICDialer{ | ||||||
|  | 					MockDialContext: func(ctx context.Context, network, address string, | ||||||
|  | 						tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { | ||||||
|  | 						return expectedQConn, nil | ||||||
|  | 					}, | ||||||
|  | 				}} | ||||||
|  | 			qconn, err := dialer.DialContext( | ||||||
|  | 				context.Background(), "udp", "8.8.4.4:443", | ||||||
|  | 				&tls.Config{}, &quic.Config{}) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  | 			if qconn != expectedQConn { | ||||||
|  | 				t.Fatal("unexpected underlying qconn") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	t.Run("lookup host with address", func(t *testing.T) { | 	t.Run("lookup host with address", func(t *testing.T) { | ||||||
|  | |||||||
| @ -22,6 +22,8 @@ type ParallelResolver struct { | |||||||
| 	Txp model.DNSTransport | 	Txp model.DNSTransport | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | var _ model.Resolver = &ParallelResolver{} | ||||||
|  | 
 | ||||||
| // UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is | // UnwrappedParallelResolver creates a new ParallelResolver instance. This instance is | ||||||
| // not wrapped and you should wrap if before using it. | // not wrapped and you should wrap if before using it. | ||||||
| func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver { | func NewUnwrappedParallelResolver(t model.DNSTransport) *ParallelResolver { | ||||||
|  | |||||||
| @ -33,6 +33,8 @@ type SerialResolver struct { | |||||||
| 	Txp model.DNSTransport | 	Txp model.DNSTransport | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | var _ model.Resolver = &SerialResolver{} | ||||||
|  | 
 | ||||||
| // NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance. | // NewUnwrappedSerialResolver creates a new, and unwrapped, SerialResolver instance. | ||||||
| func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver { | func NewUnwrappedSerialResolver(t model.DNSTransport) *SerialResolver { | ||||||
| 	return &SerialResolver{ | 	return &SerialResolver{ | ||||||
|  | |||||||
							
								
								
									
										36
									
								
								internal/tracex/event_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								internal/tracex/event_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,36 @@ | |||||||
|  | package tracex | ||||||
|  | 
 | ||||||
|  | import "testing" | ||||||
|  | 
 | ||||||
|  | func TestUnusedEventsNames(t *testing.T) { | ||||||
|  | 	// Tests that we don't break the names of events we're currently | ||||||
|  | 	// not getting the name of directly even if they're saved. | ||||||
|  | 
 | ||||||
|  | 	t.Run("EventQUICHandshakeStart", func(t *testing.T) { | ||||||
|  | 		ev := &EventQUICHandshakeStart{} | ||||||
|  | 		if ev.Name() != "quic_handshake_start" { | ||||||
|  | 			t.Fatal("invalid event name") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EventQUICHandshakeDone", func(t *testing.T) { | ||||||
|  | 		ev := &EventQUICHandshakeDone{} | ||||||
|  | 		if ev.Name() != "quic_handshake_done" { | ||||||
|  | 			t.Fatal("invalid event name") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EventTLSHandshakeStart", func(t *testing.T) { | ||||||
|  | 		ev := &EventTLSHandshakeStart{} | ||||||
|  | 		if ev.Name() != "tls_handshake_start" { | ||||||
|  | 			t.Fatal("invalid event name") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("EventTLSHandshakeDone", func(t *testing.T) { | ||||||
|  | 		ev := &EventTLSHandshakeDone{} | ||||||
|  | 		if ev.Name() != "tls_handshake_done" { | ||||||
|  | 			t.Fatal("invalid event name") | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  | } | ||||||
| @ -177,6 +177,14 @@ func TestQUICDialerSaver(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestWrapQUICListener(t *testing.T) { | ||||||
|  | 	var saver *Saver | ||||||
|  | 	ql := &mocks.QUICListener{} | ||||||
|  | 	if saver.WrapQUICListener(ql) != ql { | ||||||
|  | 		t.Fatal("unexpected result") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestQUICListenerSaver(t *testing.T) { | func TestQUICListenerSaver(t *testing.T) { | ||||||
| 	t.Run("on failure", func(t *testing.T) { | 	t.Run("on failure", func(t *testing.T) { | ||||||
| 		expected := errors.New("mocked error") | 		expected := errors.New("mocked error") | ||||||
|  | |||||||
| @ -15,219 +15,370 @@ import ( | |||||||
| 	"github.com/ooni/probe-cli/v3/internal/runtimex" | 	"github.com/ooni/probe-cli/v3/internal/runtimex" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | func TestWrapResolver(t *testing.T) { | ||||||
|  | 	var saver *Saver | ||||||
|  | 	reso := &mocks.Resolver{} | ||||||
|  | 	if saver.WrapResolver(reso) != reso { | ||||||
|  | 		t.Fatal("unexpected result") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestResolverSaver(t *testing.T) { | func TestResolverSaver(t *testing.T) { | ||||||
| 	t.Run("on failure", func(t *testing.T) { | 	t.Run("LookupHost", func(t *testing.T) { | ||||||
| 		expected := netxlite.ErrOODNSNoSuchHost | 		t.Run("on failure", func(t *testing.T) { | ||||||
|  | 			expected := netxlite.ErrOODNSNoSuchHost | ||||||
|  | 			saver := &Saver{} | ||||||
|  | 			reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected)) | ||||||
|  | 			addrs, err := reso.LookupHost(context.Background(), "www.google.com") | ||||||
|  | 			if !errors.Is(err, expected) { | ||||||
|  | 				t.Fatal("not the error we expected") | ||||||
|  | 			} | ||||||
|  | 			if addrs != nil { | ||||||
|  | 				t.Fatal("expected nil address here") | ||||||
|  | 			} | ||||||
|  | 			ev := saver.Read() | ||||||
|  | 			if len(ev) != 2 { | ||||||
|  | 				t.Fatal("expected number of events") | ||||||
|  | 			} | ||||||
|  | 			if ev[0].Value().Hostname != "www.google.com" { | ||||||
|  | 				t.Fatal("unexpected Hostname") | ||||||
|  | 			} | ||||||
|  | 			if ev[0].Name() != "resolve_start" { | ||||||
|  | 				t.Fatal("unexpected name") | ||||||
|  | 			} | ||||||
|  | 			if !ev[0].Value().Time.Before(time.Now()) { | ||||||
|  | 				t.Fatal("the saved time is wrong") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Addresses != nil { | ||||||
|  | 				t.Fatal("unexpected Addresses") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Duration <= 0 { | ||||||
|  | 				t.Fatal("unexpected Duration") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError { | ||||||
|  | 				t.Fatal("unexpected Err") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Hostname != "www.google.com" { | ||||||
|  | 				t.Fatal("unexpected Hostname") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Name() != "resolve_done" { | ||||||
|  | 				t.Fatal("unexpected name") | ||||||
|  | 			} | ||||||
|  | 			if !ev[1].Value().Time.After(ev[0].Value().Time) { | ||||||
|  | 				t.Fatal("the saved time is wrong") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		t.Run("on success", func(t *testing.T) { | ||||||
|  | 			expected := []string{"8.8.8.8", "8.8.4.4"} | ||||||
|  | 			saver := &Saver{} | ||||||
|  | 			reso := saver.WrapResolver(newFakeResolverWithResult(expected)) | ||||||
|  | 			addrs, err := reso.LookupHost(context.Background(), "www.google.com") | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal("expected nil error here") | ||||||
|  | 			} | ||||||
|  | 			if !reflect.DeepEqual(addrs, expected) { | ||||||
|  | 				t.Fatal("not the result we expected") | ||||||
|  | 			} | ||||||
|  | 			ev := saver.Read() | ||||||
|  | 			if len(ev) != 2 { | ||||||
|  | 				t.Fatal("expected number of events") | ||||||
|  | 			} | ||||||
|  | 			if ev[0].Value().Hostname != "www.google.com" { | ||||||
|  | 				t.Fatal("unexpected Hostname") | ||||||
|  | 			} | ||||||
|  | 			if ev[0].Name() != "resolve_start" { | ||||||
|  | 				t.Fatal("unexpected name") | ||||||
|  | 			} | ||||||
|  | 			if !ev[0].Value().Time.Before(time.Now()) { | ||||||
|  | 				t.Fatal("the saved time is wrong") | ||||||
|  | 			} | ||||||
|  | 			if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { | ||||||
|  | 				t.Fatal("unexpected Addresses") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Duration <= 0 { | ||||||
|  | 				t.Fatal("unexpected Duration") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Err.IsNotNil() { | ||||||
|  | 				t.Fatal("unexpected Err") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Hostname != "www.google.com" { | ||||||
|  | 				t.Fatal("unexpected Hostname") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Name() != "resolve_done" { | ||||||
|  | 				t.Fatal("unexpected name") | ||||||
|  | 			} | ||||||
|  | 			if !ev[1].Value().Time.After(ev[0].Value().Time) { | ||||||
|  | 				t.Fatal("the saved time is wrong") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Network", func(t *testing.T) { | ||||||
| 		saver := &Saver{} | 		saver := &Saver{} | ||||||
| 		reso := saver.WrapResolver(newFakeResolverWithExplicitError(expected)) | 		child := &mocks.Resolver{ | ||||||
| 		addrs, err := reso.LookupHost(context.Background(), "www.google.com") | 			MockNetwork: func() string { | ||||||
| 		if !errors.Is(err, expected) { | 				return "x" | ||||||
| 			t.Fatal("not the error we expected") | 			}, | ||||||
| 		} | 		} | ||||||
| 		if addrs != nil { | 		reso := saver.WrapResolver(child) | ||||||
| 			t.Fatal("expected nil address here") | 		if reso.Network() != "x" { | ||||||
| 		} | 			t.Fatal("unexpected result") | ||||||
| 		ev := saver.Read() |  | ||||||
| 		if len(ev) != 2 { |  | ||||||
| 			t.Fatal("expected number of events") |  | ||||||
| 		} |  | ||||||
| 		if ev[0].Value().Hostname != "www.google.com" { |  | ||||||
| 			t.Fatal("unexpected Hostname") |  | ||||||
| 		} |  | ||||||
| 		if ev[0].Name() != "resolve_start" { |  | ||||||
| 			t.Fatal("unexpected name") |  | ||||||
| 		} |  | ||||||
| 		if !ev[0].Value().Time.Before(time.Now()) { |  | ||||||
| 			t.Fatal("the saved time is wrong") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().Addresses != nil { |  | ||||||
| 			t.Fatal("unexpected Addresses") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().Duration <= 0 { |  | ||||||
| 			t.Fatal("unexpected Duration") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError { |  | ||||||
| 			t.Fatal("unexpected Err") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().Hostname != "www.google.com" { |  | ||||||
| 			t.Fatal("unexpected Hostname") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Name() != "resolve_done" { |  | ||||||
| 			t.Fatal("unexpected name") |  | ||||||
| 		} |  | ||||||
| 		if !ev[1].Value().Time.After(ev[0].Value().Time) { |  | ||||||
| 			t.Fatal("the saved time is wrong") |  | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	t.Run("on success", func(t *testing.T) { | 	t.Run("Address", func(t *testing.T) { | ||||||
| 		expected := []string{"8.8.8.8", "8.8.4.4"} |  | ||||||
| 		saver := &Saver{} | 		saver := &Saver{} | ||||||
| 		reso := saver.WrapResolver(newFakeResolverWithResult(expected)) | 		child := &mocks.Resolver{ | ||||||
| 		addrs, err := reso.LookupHost(context.Background(), "www.google.com") | 			MockAddress: func() string { | ||||||
| 		if err != nil { | 				return "x" | ||||||
| 			t.Fatal("expected nil error here") | 			}, | ||||||
| 		} | 		} | ||||||
| 		if !reflect.DeepEqual(addrs, expected) { | 		reso := saver.WrapResolver(child) | ||||||
| 			t.Fatal("not the result we expected") | 		if reso.Address() != "x" { | ||||||
|  | 			t.Fatal("unexpected result") | ||||||
| 		} | 		} | ||||||
| 		ev := saver.Read() | 	}) | ||||||
| 		if len(ev) != 2 { | 
 | ||||||
| 			t.Fatal("expected number of events") | 	t.Run("LookupHTTPS", func(t *testing.T) { | ||||||
|  | 		expected := errors.New("mocked") | ||||||
|  | 		saver := &Saver{} | ||||||
|  | 		child := &mocks.Resolver{ | ||||||
|  | 			MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) { | ||||||
|  | 				return nil, expected | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 		if ev[0].Value().Hostname != "www.google.com" { | 		reso := saver.WrapResolver(child) | ||||||
| 			t.Fatal("unexpected Hostname") | 		https, err := reso.LookupHTTPS(context.Background(), "dns.google") | ||||||
|  | 		if !errors.Is(err, expected) { | ||||||
|  | 			t.Fatal("unexpected err", err) | ||||||
| 		} | 		} | ||||||
| 		if ev[0].Name() != "resolve_start" { | 		if https != nil { | ||||||
| 			t.Fatal("unexpected name") | 			t.Fatal("expected nil") | ||||||
| 		} | 		} | ||||||
| 		if !ev[0].Value().Time.Before(time.Now()) { | 	}) | ||||||
| 			t.Fatal("the saved time is wrong") | 
 | ||||||
|  | 	t.Run("LookupNS", func(t *testing.T) { | ||||||
|  | 		expected := errors.New("mocked") | ||||||
|  | 		saver := &Saver{} | ||||||
|  | 		child := &mocks.Resolver{ | ||||||
|  | 			MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) { | ||||||
|  | 				return nil, expected | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 		if !reflect.DeepEqual(ev[1].Value().Addresses, expected) { | 		reso := saver.WrapResolver(child) | ||||||
| 			t.Fatal("unexpected Addresses") | 		ns, err := reso.LookupNS(context.Background(), "dns.google") | ||||||
|  | 		if !errors.Is(err, expected) { | ||||||
|  | 			t.Fatal("unexpected err", err) | ||||||
| 		} | 		} | ||||||
| 		if ev[1].Value().Duration <= 0 { | 		if len(ns) != 0 { | ||||||
| 			t.Fatal("unexpected Duration") | 			t.Fatal("expected zero length array") | ||||||
| 		} | 		} | ||||||
| 		if ev[1].Value().Err.IsNotNil() { | 	}) | ||||||
| 			t.Fatal("unexpected Err") | 
 | ||||||
|  | 	t.Run("CloseIdleConnections", func(t *testing.T) { | ||||||
|  | 		var called bool | ||||||
|  | 		saver := &Saver{} | ||||||
|  | 		child := &mocks.Resolver{ | ||||||
|  | 			MockCloseIdleConnections: func() { | ||||||
|  | 				called = true | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 		if ev[1].Value().Hostname != "www.google.com" { | 		reso := saver.WrapResolver(child) | ||||||
| 			t.Fatal("unexpected Hostname") | 		reso.CloseIdleConnections() | ||||||
| 		} | 		if !called { | ||||||
| 		if ev[1].Name() != "resolve_done" { | 			t.Fatal("not called") | ||||||
| 			t.Fatal("unexpected name") |  | ||||||
| 		} |  | ||||||
| 		if !ev[1].Value().Time.After(ev[0].Value().Time) { |  | ||||||
| 			t.Fatal("the saved time is wrong") |  | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestWrapDNSTransport(t *testing.T) { | ||||||
|  | 	var saver *Saver | ||||||
|  | 	txp := &mocks.DNSTransport{} | ||||||
|  | 	if saver.WrapDNSTransport(txp) != txp { | ||||||
|  | 		t.Fatal("unexpected result") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestDNSTransportSaver(t *testing.T) { | func TestDNSTransportSaver(t *testing.T) { | ||||||
| 	t.Run("on failure", func(t *testing.T) { | 	t.Run("RoundTrip", func(t *testing.T) { | ||||||
| 		expected := netxlite.ErrOODNSNoSuchHost | 		t.Run("on failure", func(t *testing.T) { | ||||||
| 		saver := &Saver{} | 			expected := netxlite.ErrOODNSNoSuchHost | ||||||
| 		txp := saver.WrapDNSTransport(&mocks.DNSTransport{ | 			saver := &Saver{} | ||||||
| 			MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { | 			txp := saver.WrapDNSTransport(&mocks.DNSTransport{ | ||||||
| 				return nil, expected | 				MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { | ||||||
| 			}, | 					return nil, expected | ||||||
| 			MockNetwork: func() string { | 				}, | ||||||
| 				return "fake" | 				MockNetwork: func() string { | ||||||
| 			}, | 					return "fake" | ||||||
| 			MockAddress: func() string { | 				}, | ||||||
| 				return "" | 				MockAddress: func() string { | ||||||
| 			}, | 					return "" | ||||||
|  | 				}, | ||||||
|  | 			}) | ||||||
|  | 			rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} | ||||||
|  | 			query := &mocks.DNSQuery{ | ||||||
|  | 				MockBytes: func() ([]byte, error) { | ||||||
|  | 					return rawQuery, nil | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 			reply, err := txp.RoundTrip(context.Background(), query) | ||||||
|  | 			if !errors.Is(err, expected) { | ||||||
|  | 				t.Fatal("not the error we expected") | ||||||
|  | 			} | ||||||
|  | 			if reply != nil { | ||||||
|  | 				t.Fatal("expected nil reply here") | ||||||
|  | 			} | ||||||
|  | 			ev := saver.Read() | ||||||
|  | 			if len(ev) != 2 { | ||||||
|  | 				t.Fatal("expected number of events") | ||||||
|  | 			} | ||||||
|  | 			if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { | ||||||
|  | 				t.Fatal("unexpected DNSQuery") | ||||||
|  | 			} | ||||||
|  | 			if ev[0].Name() != "dns_round_trip_start" { | ||||||
|  | 				t.Fatal("unexpected name") | ||||||
|  | 			} | ||||||
|  | 			if !ev[0].Value().Time.Before(time.Now()) { | ||||||
|  | 				t.Fatal("the saved time is wrong") | ||||||
|  | 			} | ||||||
|  | 			if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { | ||||||
|  | 				t.Fatal("unexpected DNSQuery") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().DNSResponse != nil { | ||||||
|  | 				t.Fatal("unexpected DNSReply") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Duration <= 0 { | ||||||
|  | 				t.Fatal("unexpected Duration") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError { | ||||||
|  | 				t.Fatal("unexpected Err") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Name() != "dns_round_trip_done" { | ||||||
|  | 				t.Fatal("unexpected name") | ||||||
|  | 			} | ||||||
|  | 			if !ev[1].Value().Time.After(ev[0].Value().Time) { | ||||||
|  | 				t.Fatal("the saved time is wrong") | ||||||
|  | 			} | ||||||
| 		}) | 		}) | ||||||
| 		rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} | 
 | ||||||
| 		query := &mocks.DNSQuery{ | 		t.Run("on success", func(t *testing.T) { | ||||||
| 			MockBytes: func() ([]byte, error) { | 			expected := []byte{0xef, 0xbe, 0xad, 0xde} | ||||||
| 				return rawQuery, nil | 			saver := &Saver{} | ||||||
|  | 			response := &mocks.DNSResponse{ | ||||||
|  | 				MockBytes: func() []byte { | ||||||
|  | 					return expected | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 			txp := saver.WrapDNSTransport(&mocks.DNSTransport{ | ||||||
|  | 				MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { | ||||||
|  | 					return response, nil | ||||||
|  | 				}, | ||||||
|  | 				MockNetwork: func() string { | ||||||
|  | 					return "fake" | ||||||
|  | 				}, | ||||||
|  | 				MockAddress: func() string { | ||||||
|  | 					return "" | ||||||
|  | 				}, | ||||||
|  | 			}) | ||||||
|  | 			rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} | ||||||
|  | 			query := &mocks.DNSQuery{ | ||||||
|  | 				MockBytes: func() ([]byte, error) { | ||||||
|  | 					return rawQuery, nil | ||||||
|  | 				}, | ||||||
|  | 			} | ||||||
|  | 			reply, err := txp.RoundTrip(context.Background(), query) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal("we expected nil error here") | ||||||
|  | 			} | ||||||
|  | 			if !bytes.Equal(reply.Bytes(), expected) { | ||||||
|  | 				t.Fatal("expected another reply here") | ||||||
|  | 			} | ||||||
|  | 			ev := saver.Read() | ||||||
|  | 			if len(ev) != 2 { | ||||||
|  | 				t.Fatal("expected number of events") | ||||||
|  | 			} | ||||||
|  | 			if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { | ||||||
|  | 				t.Fatal("unexpected DNSQuery") | ||||||
|  | 			} | ||||||
|  | 			if ev[0].Name() != "dns_round_trip_start" { | ||||||
|  | 				t.Fatal("unexpected name") | ||||||
|  | 			} | ||||||
|  | 			if !ev[0].Value().Time.Before(time.Now()) { | ||||||
|  | 				t.Fatal("the saved time is wrong") | ||||||
|  | 			} | ||||||
|  | 			if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { | ||||||
|  | 				t.Fatal("unexpected DNSQuery") | ||||||
|  | 			} | ||||||
|  | 			if !bytes.Equal(ev[1].Value().DNSResponse, expected) { | ||||||
|  | 				t.Fatal("unexpected DNSReply") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Duration <= 0 { | ||||||
|  | 				t.Fatal("unexpected Duration") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Value().Err.IsNotNil() { | ||||||
|  | 				t.Fatal("unexpected Err") | ||||||
|  | 			} | ||||||
|  | 			if ev[1].Name() != "dns_round_trip_done" { | ||||||
|  | 				t.Fatal("unexpected name") | ||||||
|  | 			} | ||||||
|  | 			if !ev[1].Value().Time.After(ev[0].Value().Time) { | ||||||
|  | 				t.Fatal("the saved time is wrong") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	t.Run("Network", func(t *testing.T) { | ||||||
|  | 		saver := &Saver{} | ||||||
|  | 		child := &mocks.DNSTransport{ | ||||||
|  | 			MockNetwork: func() string { | ||||||
|  | 				return "x" | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 		reply, err := txp.RoundTrip(context.Background(), query) | 		txp := saver.WrapDNSTransport(child) | ||||||
| 		if !errors.Is(err, expected) { | 		if txp.Network() != "x" { | ||||||
| 			t.Fatal("not the error we expected") | 			t.Fatal("unexpected result") | ||||||
| 		} |  | ||||||
| 		if reply != nil { |  | ||||||
| 			t.Fatal("expected nil reply here") |  | ||||||
| 		} |  | ||||||
| 		ev := saver.Read() |  | ||||||
| 		if len(ev) != 2 { |  | ||||||
| 			t.Fatal("expected number of events") |  | ||||||
| 		} |  | ||||||
| 		if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { |  | ||||||
| 			t.Fatal("unexpected DNSQuery") |  | ||||||
| 		} |  | ||||||
| 		if ev[0].Name() != "dns_round_trip_start" { |  | ||||||
| 			t.Fatal("unexpected name") |  | ||||||
| 		} |  | ||||||
| 		if !ev[0].Value().Time.Before(time.Now()) { |  | ||||||
| 			t.Fatal("the saved time is wrong") |  | ||||||
| 		} |  | ||||||
| 		if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { |  | ||||||
| 			t.Fatal("unexpected DNSQuery") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().DNSResponse != nil { |  | ||||||
| 			t.Fatal("unexpected DNSReply") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().Duration <= 0 { |  | ||||||
| 			t.Fatal("unexpected Duration") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().Err != netxlite.FailureDNSNXDOMAINError { |  | ||||||
| 			t.Fatal("unexpected Err") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Name() != "dns_round_trip_done" { |  | ||||||
| 			t.Fatal("unexpected name") |  | ||||||
| 		} |  | ||||||
| 		if !ev[1].Value().Time.After(ev[0].Value().Time) { |  | ||||||
| 			t.Fatal("the saved time is wrong") |  | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	t.Run("on success", func(t *testing.T) { | 	t.Run("Address", func(t *testing.T) { | ||||||
| 		expected := []byte{0xef, 0xbe, 0xad, 0xde} |  | ||||||
| 		saver := &Saver{} | 		saver := &Saver{} | ||||||
| 		response := &mocks.DNSResponse{ | 		child := &mocks.DNSTransport{ | ||||||
| 			MockBytes: func() []byte { |  | ||||||
| 				return expected |  | ||||||
| 			}, |  | ||||||
| 		} |  | ||||||
| 		txp := saver.WrapDNSTransport(&mocks.DNSTransport{ |  | ||||||
| 			MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) { |  | ||||||
| 				return response, nil |  | ||||||
| 			}, |  | ||||||
| 			MockNetwork: func() string { |  | ||||||
| 				return "fake" |  | ||||||
| 			}, |  | ||||||
| 			MockAddress: func() string { | 			MockAddress: func() string { | ||||||
| 				return "" | 				return "x" | ||||||
| 			}, |  | ||||||
| 		}) |  | ||||||
| 		rawQuery := []byte{0xde, 0xad, 0xbe, 0xef} |  | ||||||
| 		query := &mocks.DNSQuery{ |  | ||||||
| 			MockBytes: func() ([]byte, error) { |  | ||||||
| 				return rawQuery, nil |  | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 		reply, err := txp.RoundTrip(context.Background(), query) | 		txp := saver.WrapDNSTransport(child) | ||||||
| 		if err != nil { | 		if txp.Address() != "x" { | ||||||
| 			t.Fatal("we expected nil error here") | 			t.Fatal("unexpected result") | ||||||
| 		} | 		} | ||||||
| 		if !bytes.Equal(reply.Bytes(), expected) { | 	}) | ||||||
| 			t.Fatal("expected another reply here") | 
 | ||||||
|  | 	t.Run("CloseIdleConnections", func(t *testing.T) { | ||||||
|  | 		var called bool | ||||||
|  | 		saver := &Saver{} | ||||||
|  | 		child := &mocks.DNSTransport{ | ||||||
|  | 			MockCloseIdleConnections: func() { | ||||||
|  | 				called = true | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 		ev := saver.Read() | 		txp := saver.WrapDNSTransport(child) | ||||||
| 		if len(ev) != 2 { | 		txp.CloseIdleConnections() | ||||||
| 			t.Fatal("expected number of events") | 		if !called { | ||||||
|  | 			t.Fatal("not called") | ||||||
| 		} | 		} | ||||||
| 		if !bytes.Equal(ev[0].Value().DNSQuery, rawQuery) { | 	}) | ||||||
| 			t.Fatal("unexpected DNSQuery") | 
 | ||||||
|  | 	t.Run("RequiresPadding", func(t *testing.T) { | ||||||
|  | 		saver := &Saver{} | ||||||
|  | 		child := &mocks.DNSTransport{ | ||||||
|  | 			MockRequiresPadding: func() bool { | ||||||
|  | 				return true | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 		if ev[0].Name() != "dns_round_trip_start" { | 		txp := saver.WrapDNSTransport(child) | ||||||
| 			t.Fatal("unexpected name") | 		if !txp.RequiresPadding() { | ||||||
| 		} | 			t.Fatal("unexpected result") | ||||||
| 		if !ev[0].Value().Time.Before(time.Now()) { |  | ||||||
| 			t.Fatal("the saved time is wrong") |  | ||||||
| 		} |  | ||||||
| 		if !bytes.Equal(ev[1].Value().DNSQuery, rawQuery) { |  | ||||||
| 			t.Fatal("unexpected DNSQuery") |  | ||||||
| 		} |  | ||||||
| 		if !bytes.Equal(ev[1].Value().DNSResponse, expected) { |  | ||||||
| 			t.Fatal("unexpected DNSReply") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().Duration <= 0 { |  | ||||||
| 			t.Fatal("unexpected Duration") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Value().Err.IsNotNil() { |  | ||||||
| 			t.Fatal("unexpected Err") |  | ||||||
| 		} |  | ||||||
| 		if ev[1].Name() != "dns_round_trip_done" { |  | ||||||
| 			t.Fatal("unexpected name") |  | ||||||
| 		} |  | ||||||
| 		if !ev[1].Value().Time.After(ev[0].Value().Time) { |  | ||||||
| 			t.Fatal("the saved time is wrong") |  | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  | |||||||
| @ -12,6 +12,14 @@ import ( | |||||||
| 	"github.com/ooni/probe-cli/v3/internal/model/mocks" | 	"github.com/ooni/probe-cli/v3/internal/model/mocks" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | func TestWrapTLSHandshaker(t *testing.T) { | ||||||
|  | 	var saver *Saver | ||||||
|  | 	thx := &mocks.TLSHandshaker{} | ||||||
|  | 	if saver.WrapTLSHandshaker(thx) != thx { | ||||||
|  | 		t.Fatal("unexpected result") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestTLSHandshakerSaver(t *testing.T) { | func TestTLSHandshakerSaver(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	t.Run("Handshake", func(t *testing.T) { | 	t.Run("Handshake", func(t *testing.T) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user