diff --git a/internal/netxlite/filtering/dns.go b/internal/netxlite/filtering/dns.go index 452dc2a..a055c53 100644 --- a/internal/netxlite/filtering/dns.go +++ b/internal/netxlite/filtering/dns.go @@ -34,11 +34,19 @@ const ( // DNSActionTimeout never replies to the query. DNSActionTimeout = DNSAction("timeout") + + // DNSActionCache causes the proxy to check the cache. If there + // are entries, they are returned. Otherwise, NXDOMAIN is returned. + DNSActionCache = DNSAction("cache") ) // DNSProxy is a DNS proxy that routes traffic to an upstream // resolver and may implement filtering policies. type DNSProxy struct { + // Cache is the DNS cache. Note that the keys of the map + // must be FQDNs (i.e., including the final `.`). + Cache map[string][]string + // OnQuery is the MANDATORY hook called whenever we // receive a query for the given domain. OnQuery func(domain string) DNSAction @@ -135,6 +143,8 @@ func (p *DNSProxy) replyDefault(query *dns.Msg) (*dns.Msg, error) { return p.empty(query), nil case DNSActionTimeout: return nil, errors.New("let's ignore this query") + case DNSActionCache: + return p.cache(name, query), nil default: return p.refused(query), nil } @@ -213,6 +223,20 @@ func (p *DNSProxy) proxy(query *dns.Msg) (*dns.Msg, error) { return reply, nil } +func (p *DNSProxy) cache(name string, query *dns.Msg) *dns.Msg { + addrs := p.Cache[name] + var ipAddrs []net.IP + for _, addr := range addrs { + if ip := net.ParseIP(addr); ip != nil { + ipAddrs = append(ipAddrs, ip) + } + } + if len(ipAddrs) <= 0 { + return p.nxdomain(query) + } + return p.compose(query, ipAddrs...) +} + func (p *DNSProxy) dnstransport() DNSTransport { if p.Upstream != nil { return p.Upstream diff --git a/internal/netxlite/filtering/dns_test.go b/internal/netxlite/filtering/dns_test.go index 76782cd..debd0a6 100644 --- a/internal/netxlite/filtering/dns_test.go +++ b/internal/netxlite/filtering/dns_test.go @@ -15,8 +15,9 @@ import ( ) func TestDNSProxy(t *testing.T) { - newproxy := func(action DNSAction) (DNSListener, <-chan interface{}, error) { + newProxyWithCache := func(action DNSAction, cache map[string][]string) (DNSListener, <-chan interface{}, error) { p := &DNSProxy{ + Cache: cache, OnQuery: func(domain string) DNSAction { return action }, @@ -24,6 +25,10 @@ func TestDNSProxy(t *testing.T) { return p.start("127.0.0.1:0") } + newProxy := func(action DNSAction) (DNSListener, <-chan interface{}, error) { + return newProxyWithCache(action, nil) + } + newresolver := func(listener DNSListener) netxlite.Resolver { dlr := netxlite.NewDialerWithoutResolver(log.Log) r := netxlite.NewResolverUDP(log.Log, dlr, listener.LocalAddr().String()) @@ -32,7 +37,7 @@ func TestDNSProxy(t *testing.T) { t.Run("DNSActionPass", func(t *testing.T) { ctx := context.Background() - listener, done, err := newproxy(DNSActionPass) + listener, done, err := newProxy(DNSActionPass) if err != nil { t.Fatal(err) } @@ -57,7 +62,7 @@ func TestDNSProxy(t *testing.T) { t.Run("DNSActionNXDOMAIN", func(t *testing.T) { ctx := context.Background() - listener, done, err := newproxy(DNSActionNXDOMAIN) + listener, done, err := newProxy(DNSActionNXDOMAIN) if err != nil { t.Fatal(err) } @@ -75,7 +80,7 @@ func TestDNSProxy(t *testing.T) { t.Run("DNSActionRefused", func(t *testing.T) { ctx := context.Background() - listener, done, err := newproxy(DNSActionRefused) + listener, done, err := newProxy(DNSActionRefused) if err != nil { t.Fatal(err) } @@ -93,7 +98,7 @@ func TestDNSProxy(t *testing.T) { t.Run("DNSActionLocalHost", func(t *testing.T) { ctx := context.Background() - listener, done, err := newproxy(DNSActionLocalHost) + listener, done, err := newProxy(DNSActionLocalHost) if err != nil { t.Fatal(err) } @@ -118,7 +123,7 @@ func TestDNSProxy(t *testing.T) { t.Run("DNSActionEmpty", func(t *testing.T) { ctx := context.Background() - listener, done, err := newproxy(DNSActionNoAnswer) + listener, done, err := newProxy(DNSActionNoAnswer) if err != nil { t.Fatal(err) } @@ -142,7 +147,7 @@ func TestDNSProxy(t *testing.T) { const timeout = time.Second ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - listener, done, err := newproxy(DNSActionTimeout) + listener, done, err := newProxy(DNSActionTimeout) if err != nil { t.Fatal(err) } @@ -158,6 +163,51 @@ func TestDNSProxy(t *testing.T) { <-done // wait for background goroutine to exit }) + t.Run("DNSActionCache without entries", func(t *testing.T) { + ctx := context.Background() + listener, done, err := newProxyWithCache(DNSActionCache, nil) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { + t.Fatal("unexpected err", err) + } + if addrs != nil { + t.Fatal("expected empty addrs") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + + t.Run("DNSActionCache with entries", func(t *testing.T) { + ctx := context.Background() + cache := map[string][]string{ + "dns.google.": {"8.8.8.8", "8.8.4.4"}, + } + listener, done, err := newProxyWithCache(DNSActionCache, cache) + if err != nil { + t.Fatal(err) + } + r := newresolver(listener) + addrs, err := r.LookupHost(ctx, "dns.google") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 2 { + t.Fatal("expected two entries") + } + if addrs[0] != "8.8.8.8" { + t.Fatal("invalid first entry") + } + if addrs[1] != "8.8.4.4" { + t.Fatal("invalid second entry") + } + listener.Close() + <-done // wait for background goroutine to exit + }) + t.Run("Start with invalid address", func(t *testing.T) { p := &DNSProxy{} listener, err := p.Start("127.0.0.1") diff --git a/internal/netxlite/filtering/testdata/valid.json b/internal/netxlite/filtering/testdata/valid.json index 9b437e2..ab9d3fd 100644 --- a/internal/netxlite/filtering/testdata/valid.json +++ b/internal/netxlite/filtering/testdata/valid.json @@ -1,4 +1,7 @@ { + "DNSCache": { + "dns.google": ["8.8.8.8", "8.8.4.4"] + }, "Domains": { "x.org": "pass" } diff --git a/internal/netxlite/filtering/tproxy.go b/internal/netxlite/filtering/tproxy.go index 43f200b..59fb03f 100644 --- a/internal/netxlite/filtering/tproxy.go +++ b/internal/netxlite/filtering/tproxy.go @@ -43,6 +43,15 @@ const ( // TProxyConfig contains configuration for TProxy. type TProxyConfig struct { + // DNSCache is the cached used when the domains policy is "cache". Note + // that the map MUST contain FQDNs. That is, you need to append + // a final dot to the domain name (e.g., `example.com.`). If you + // use the NewTProxyConfig factory, you don't need to worry about this + // issue, because the factory will canonicalize non-canonical + // entries. Otherwise, you can explicitly call the CanonicalizeDNS + // method _before_ using the TProxy. + DNSCache map[string][]string + // Domains contains rules for filtering the lookup of domains. Note // that the map MUST contain FQDNs. That is, you need to append // a final dot to the domain name (e.g., `example.com.`). If you @@ -84,6 +93,11 @@ func (c *TProxyConfig) CanonicalizeDNS() { domains[dns.CanonicalName(domain)] = policy } c.Domains = domains + cache := make(map[string][]string) + for domain, addrs := range c.DNSCache { + cache[dns.CanonicalName(domain)] = addrs + } + c.DNSCache = cache } // TProxy is a netxlite.TProxable that implements self censorship. @@ -146,7 +160,7 @@ func newTProxy(config *TProxyConfig, logger Logger, dnsListenerAddr, func (p *TProxy) newDNSListener(listenAddr string) error { var err error - dnsProxy := &DNSProxy{OnQuery: p.onQuery} + dnsProxy := &DNSProxy{Cache: p.config.DNSCache, OnQuery: p.onQuery} p.dnsListener, err = dnsProxy.Start(listenAddr) return err } diff --git a/internal/netxlite/filtering/tproxy_test.go b/internal/netxlite/filtering/tproxy_test.go index 038ec33..aca4247 100644 --- a/internal/netxlite/filtering/tproxy_test.go +++ b/internal/netxlite/filtering/tproxy_test.go @@ -58,7 +58,10 @@ func TestNewTProxyConfig(t *testing.T) { t.Fatal("expected non-nil config here") } if config.Domains["x.org."] != "pass" { - t.Fatal("did not auto-canonicalize names") + t.Fatal("did not auto-canonicalize config.Domains") + } + if len(config.DNSCache["dns.google."]) != 2 { + t.Fatal("did not auto-canonicalize config.DNSCache") } }) } @@ -519,3 +522,54 @@ func TestTProxyDial(t *testing.T) { } }) } + +func TestTProxyDNSCache(t *testing.T) { + t.Run("without cache but with the cache rule", func(t *testing.T) { + config := &TProxyConfig{ + Domains: map[string]DNSAction{ + "dns.google.": DNSActionCache, + }, + } + proxy, err := NewTProxy(config, log.Log) + if err != nil { + t.Fatal(err) + } + ctx := context.Background() + addrs, err := proxy.LookupHost(ctx, "dns.google") + if err == nil || err.Error() != netxlite.FailureDNSNXDOMAINError { + t.Fatal("unexpected err", err) + } + if addrs != nil { + t.Fatal("expected nil addrs") + } + }) + + t.Run("with cache", func(t *testing.T) { + config := &TProxyConfig{ + DNSCache: map[string][]string{ + "dns.google.": {"8.8.8.8", "8.8.4.4"}, + }, + Domains: map[string]DNSAction{ + "dns.google.": DNSActionCache, + }, + } + proxy, err := NewTProxy(config, log.Log) + if err != nil { + t.Fatal(err) + } + ctx := context.Background() + addrs, err := proxy.LookupHost(ctx, "dns.google") + if err != nil { + t.Fatal(err) + } + if len(addrs) != 2 { + t.Fatal("expected two addrs") + } + if addrs[0] != "8.8.8.8" { + t.Fatal("invalid first address") + } + if addrs[1] != "8.8.4.4" { + t.Fatal("invalid second address") + } + }) +}