diff --git a/config/config.go b/config/config.go index 90c1ff16f..ec9053366 100644 --- a/config/config.go +++ b/config/config.go @@ -189,44 +189,44 @@ func (b *BootstrappedUpstreamConfig) UnmarshalYAML(unmarshal func(interface{}) e // //nolint:maligned type Config struct { - Upstreams UpstreamsConfig `yaml:"upstreams"` - ConnectIPVersion IPVersion `yaml:"connectIPVersion"` - CustomDNS CustomDNSConfig `yaml:"customDNS"` - Conditional ConditionalUpstreamConfig `yaml:"conditional"` - Blocking BlockingConfig `yaml:"blocking"` - ClientLookup ClientLookupConfig `yaml:"clientLookup"` - Caching CachingConfig `yaml:"caching"` - QueryLog QueryLogConfig `yaml:"queryLog"` - Prometheus MetricsConfig `yaml:"prometheus"` - Redis RedisConfig `yaml:"redis"` - Log log.Config `yaml:"log"` - Ports PortsConfig `yaml:"ports"` - DoHUserAgent string `yaml:"dohUserAgent"` - MinTLSServeVer string `yaml:"minTlsServeVersion" default:"1.2"` - StartVerifyUpstream bool `yaml:"startVerifyUpstream" default:"false"` - CertFile string `yaml:"certFile"` - KeyFile string `yaml:"keyFile"` - BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"` - HostsFile HostsFileConfig `yaml:"hostsFile"` - FQDNOnly FQDNOnly `yaml:"fqdnOnly"` - Filtering FilteringConfig `yaml:"filtering"` - EDE EDE `yaml:"ede"` - ECS ECS `yaml:"ecs"` - SUDN SUDN `yaml:"specialUseDomains"` + Upstreams UpstreamsConfig `yaml:"upstreams"` + ConnectIPVersion IPVersion `yaml:"connectIPVersion"` + CustomDNS CustomDNSConfig `yaml:"customDNS"` + Conditional ConditionalUpstreamConfig `yaml:"conditional"` + Blocking BlockingConfig `yaml:"blocking"` + ClientLookup ClientLookupConfig `yaml:"clientLookup"` + Caching CachingConfig `yaml:"caching"` + QueryLog QueryLogConfig `yaml:"queryLog"` + Prometheus MetricsConfig `yaml:"prometheus"` + Redis RedisConfig `yaml:"redis"` + Log log.Config `yaml:"log"` + Ports PortsConfig `yaml:"ports"` + DoHUserAgent string `yaml:"dohUserAgent"` + MinTLSServeVer string `yaml:"minTlsServeVersion" default:"1.2"` + CertFile string `yaml:"certFile"` + KeyFile string `yaml:"keyFile"` + BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"` + HostsFile HostsFileConfig `yaml:"hostsFile"` + FQDNOnly FQDNOnly `yaml:"fqdnOnly"` + Filtering FilteringConfig `yaml:"filtering"` + EDE EDE `yaml:"ede"` + ECS ECS `yaml:"ecs"` + SUDN SUDN `yaml:"specialUseDomains"` // Deprecated options Deprecated struct { - Upstream *UpstreamGroups `yaml:"upstream"` - UpstreamTimeout *Duration `yaml:"upstreamTimeout"` - DisableIPv6 *bool `yaml:"disableIPv6"` - LogLevel *log.Level `yaml:"logLevel"` - LogFormat *log.FormatType `yaml:"logFormat"` - LogPrivacy *bool `yaml:"logPrivacy"` - LogTimestamp *bool `yaml:"logTimestamp"` - DNSPorts *ListenConfig `yaml:"port"` - HTTPPorts *ListenConfig `yaml:"httpPort"` - HTTPSPorts *ListenConfig `yaml:"httpsPort"` - TLSPorts *ListenConfig `yaml:"tlsPort"` + Upstream *UpstreamGroups `yaml:"upstream"` + UpstreamTimeout *Duration `yaml:"upstreamTimeout"` + DisableIPv6 *bool `yaml:"disableIPv6"` + LogLevel *log.Level `yaml:"logLevel"` + LogFormat *log.FormatType `yaml:"logFormat"` + LogPrivacy *bool `yaml:"logPrivacy"` + LogTimestamp *bool `yaml:"logTimestamp"` + DNSPorts *ListenConfig `yaml:"port"` + HTTPPorts *ListenConfig `yaml:"httpPort"` + HTTPSPorts *ListenConfig `yaml:"httpsPort"` + TLSPorts *ListenConfig `yaml:"tlsPort"` + StartVerifyUpstream *bool `yaml:"startVerifyUpstream"` } `yaml:",inline"` } @@ -514,14 +514,15 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool { cfg.Filtering.QueryTypes.Insert(dns.Type(dns.TypeAAAA)) } }), - "port": Move(To("ports.dns", &cfg.Ports)), - "httpPort": Move(To("ports.http", &cfg.Ports)), - "httpsPort": Move(To("ports.https", &cfg.Ports)), - "tlsPort": Move(To("ports.tls", &cfg.Ports)), - "logLevel": Move(To("log.level", &cfg.Log)), - "logFormat": Move(To("log.format", &cfg.Log)), - "logPrivacy": Move(To("log.privacy", &cfg.Log)), - "logTimestamp": Move(To("log.timestamp", &cfg.Log)), + "port": Move(To("ports.dns", &cfg.Ports)), + "httpPort": Move(To("ports.http", &cfg.Ports)), + "httpsPort": Move(To("ports.https", &cfg.Ports)), + "tlsPort": Move(To("ports.tls", &cfg.Ports)), + "logLevel": Move(To("log.level", &cfg.Log)), + "logFormat": Move(To("log.format", &cfg.Log)), + "logPrivacy": Move(To("log.privacy", &cfg.Log)), + "logTimestamp": Move(To("log.timestamp", &cfg.Log)), + "startVerifyUpstream": Move(To("upstreams.startVerify", &cfg.Upstreams)), }) usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts diff --git a/config/config_test.go b/config/config_test.go index 06b982833..bb9f09e48 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -770,6 +770,7 @@ bootstrapDns: func defaultTestFileConfig() { Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"})) + Expect(config.Upstreams.StartVerify).Should(BeFalse()) Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3)) Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8")) Expect(config.Upstreams.Groups["default"][1].Host).Should(Equal("8.8.4.4")) @@ -798,7 +799,6 @@ func defaultTestFileConfig() { Expect(config.DoHUserAgent).Should(Equal("testBlocky")) Expect(config.MinTLSServeVer).Should(Equal("1.3")) - Expect(config.StartVerifyUpstream).Should(BeFalse()) Expect(GetConfig()).Should(Not(BeNil())) } @@ -806,6 +806,7 @@ func defaultTestFileConfig() { func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile { return tmpDir.CreateStringFile("config.yml", "upstreams:", + " startVerify: false", " groups:", " default:", " - tcp+udp:8.8.8.8", @@ -860,7 +861,7 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile { "logLevel: debug", "dohUserAgent: testBlocky", "minTlsServeVersion: 1.3", - "startVerifyUpstream: false") + ) } func writeConfigDir(tmpDir *helpertest.TmpFolder) error { diff --git a/config/upstreams.go b/config/upstreams.go index 5f1787d82..6cb975d88 100644 --- a/config/upstreams.go +++ b/config/upstreams.go @@ -8,9 +8,10 @@ const UpstreamDefaultCfgName = "default" // UpstreamsConfig upstream servers configuration type UpstreamsConfig struct { - Timeout Duration `yaml:"timeout" default:"2s"` - Groups UpstreamGroups `yaml:"groups"` - Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"` + Timeout Duration `yaml:"timeout" default:"2s"` + Groups UpstreamGroups `yaml:"groups"` + Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"` + StartVerify bool `yaml:"startVerify" default:"false"` } type UpstreamGroups map[string][]Upstream @@ -37,13 +38,32 @@ func (c *UpstreamsConfig) LogConfig(logger *logrus.Entry) { // UpstreamGroup represents the config for one group (upstream branch) type UpstreamGroup struct { - Name string - Upstreams []Upstream + UpstreamsConfig + + Name string // group name +} + +// NewUpstreamGroup creates an UpstreamGroup with the given name and upstreams. +// +// The upstreams from `cfg.Groups` are ignored. +func NewUpstreamGroup(name string, cfg UpstreamsConfig, upstreams []Upstream) UpstreamGroup { + group := UpstreamGroup{ + Name: name, + UpstreamsConfig: cfg, + } + + group.Groups = UpstreamGroups{name: upstreams} + + return group +} + +func (c *UpstreamGroup) Upstreams() []Upstream { + return c.Groups[c.Name] } // IsEnabled implements `config.Configurable`. func (c *UpstreamGroup) IsEnabled() bool { - return len(c.Upstreams) != 0 + return len(c.Upstreams()) != 0 } // LogConfig implements `config.Configurable`. @@ -51,7 +71,7 @@ func (c *UpstreamGroup) LogConfig(logger *logrus.Entry) { logger.Info("group: ", c.Name) logger.Info("upstreams:") - for _, upstream := range c.Upstreams { + for _, upstream := range c.Upstreams() { logger.Infof(" - %s", upstream) } } diff --git a/config/upstreams_test.go b/config/upstreams_test.go index 3dc6bd4a0..8850c0a71 100644 --- a/config/upstreams_test.go +++ b/config/upstreams_test.go @@ -65,13 +65,13 @@ var _ = Describe("ParallelBestConfig", func() { var cfg UpstreamGroup BeforeEach(func() { - cfg = UpstreamGroup{ - Name: UpstreamDefaultCfgName, - Upstreams: []Upstream{ - {Host: "host1"}, - {Host: "host2"}, - }, - } + upstreamsCfg, err := WithDefaults[UpstreamsConfig]() + Expect(err).Should(Succeed()) + + cfg = NewUpstreamGroup("test", upstreamsCfg, []Upstream{ + {Host: "host1"}, + {Host: "host2"}, + }) }) Describe("IsEnabled", func() { @@ -102,7 +102,7 @@ var _ = Describe("ParallelBestConfig", func() { cfg.LogConfig(logger) Expect(hook.Calls).ShouldNot(BeEmpty()) - Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: default"))) + Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: test"))) Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstreams:"))) Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host1:"))) Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:"))) diff --git a/docs/config.yml b/docs/config.yml index 045447066..9c5d06af0 100644 --- a/docs/config.yml +++ b/docs/config.yml @@ -24,9 +24,8 @@ upstreams: strategy: parallel_best # optional: timeout to query the upstream resolver. Default: 2s timeout: 2s - -# optional: If true, blocky will fail to start unless at least one upstream server per group is reachable. Default: false -startVerifyUpstream: true + # optional: If true, blocky will fail to start unless at least one upstream server per group is reachable. Default: false + startVerify: false # optional: Determines how blocky will create outgoing connections. This impacts both upstreams, and lists. # accepted: dual, v4, v6 diff --git a/docs/configuration.md b/docs/configuration.md index db717afe0..eee36e70d 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -17,7 +17,6 @@ configuration properties as [JSON](config.yml). | keyFile | path | no | | Path to cert and key file for SSL encryption (DoH and DoT); if empty, self-signed certificate is generated | | dohUserAgent | string | no | | HTTP User Agent for DoH upstreams | | minTlsServeVersion | string | no | 1.2 | Minimum TLS version that the DoT and DoH server use to serve those encrypted DNS requests | -| startVerifyUpstream | bool | no | false | If true, blocky will fail to start unless at least one upstream server per group is reachable. | | connectIPVersion | enum (dual, v4, v6) | no | dual | IP version to use for outgoing connections (dual, v4, v6) | !!! example @@ -70,6 +69,16 @@ All logging options are optional. ## Upstreams configuration +| Parameter | Type | Default value | Description | +| --------------------- | ------------------------------------ | ------------- | ----------------------------------------------------------------------------------------------- | +| usptreams.groups | map of name to upstream | none | Upstream DNS servers to use, in groups. | +| usptreams.startVerify | bool | false | If true, blocky will fail to start unless at least one upstream server per group is functional. | +| usptreams.strategy | enum (parallel_best, random, strict) | parallel_best | Upstream server usage strategy. | +| usptreams.timeout | bool | true | Upstream connection timeout. | + + +### Upstream Groups + To resolve a DNS query, blocky needs external public or private DNS resolvers. Blocky supports DNS resolvers with following network protocols (net part of the resolver URL): @@ -133,6 +142,22 @@ The logic determining what group a client belongs to follows a strict order: IP, If a client matches multiple client name or CIDR groups, a warning is logged and the first found group is used. +### Upstream connection timeout + +Blocky will wait 2 seconds (default value) for the response from the external upstream DNS server. You can change this +value by setting the `timeout` configuration parameter (in **duration format**). + +!!! example + + ```yaml + upstreams: + timeout: 5s + groups: + default: + - 46.182.19.48 + - 80.241.218.68 + ``` + ### Upstream strategy Blocky supports different upstream strategies (default `parallel_best`) that determine how and to which upstream DNS servers requests are forwarded. @@ -160,21 +185,6 @@ Currently available strategies: - 9.8.7.6 ``` -### Upstream lookup timeout - -Blocky will wait 2 seconds (default value) for the response from the external upstream DNS server. You can change this -value by setting the `timeout` configuration parameter (in **duration format**). - -!!! example - - ```yaml - upstreams: - timeout: 5s - groups: - default: - - 46.182.19.48 - - 80.241.218.68 - ``` ## Bootstrap DNS configuration diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index de474720f..7cf031fb9 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -65,7 +65,7 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er }, } - bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS) + bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams) if err != nil { return nil, err } @@ -76,11 +76,8 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er return b, nil } - // Bootstrap doesn't have a `LogConfig` method, and since that's the only place - // where `ParallelBestResolver` uses its config, we can just use an empty one. - pbCfg := config.UpstreamGroup{Name: upstreamDefaultCfgName} - - parallelResolver := newParallelBestResolver(pbCfg, bootstraped.Resolvers()) + pbCfg := config.NewUpstreamGroup("", cfg.Upstreams, nil) + pbCfg.UpstreamsConfig.Groups = nil // To be on the safe side it doesn't try to use anything besides the bootstrap // Always enable prefetching to avoid stalling user requests // Otherwise, a request to blocky could end up waiting for 2 DNS requests: @@ -100,14 +97,14 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er NewFilteringResolver(cfg.Filtering), // false: no metrics, to not overwrite the main blocking resolver ones newCachingResolver(ctx, cachingCfg, nil, false), - parallelResolver, + newParallelBestResolver(pbCfg, bootstraped.Resolvers()), ) return b, nil } func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) { - hostname := r.upstream.Host + hostname := r.cfg.Host if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier return newIPSet([]net.IP{ip}), nil @@ -249,7 +246,9 @@ func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns. // map of bootstraped resolvers their hardcoded IPs type bootstrapedResolvers map[Resolver][]net.IP -func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (bootstrapedResolvers, error) { +func newBootstrapedResolvers( + b *Bootstrap, cfg config.BootstrapDNSConfig, upstreamsCfg config.UpstreamsConfig, +) (bootstrapedResolvers, error) { upstreamIPs := make(bootstrapedResolvers, len(cfg)) var multiErr *multierror.Error @@ -289,7 +288,7 @@ func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (boots continue } - resolver := newUpstreamResolverUnchecked(upstream, b) + resolver := newUpstreamResolverUnchecked(newUpstreamConfig(upstream, upstreamsCfg), b) upstreamIPs[resolver] = ips } diff --git a/resolver/bootstrap_test.go b/resolver/bootstrap_test.go index 5e4ffb397..4d2b65ba2 100644 --- a/resolver/bootstrap_test.go +++ b/resolver/bootstrap_test.go @@ -32,7 +32,6 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { ) BeforeEach(func() { - config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest sutConfig = &config.Config{ BootstrapDNS: []config.BootstrappedUpstreamConfig{ { @@ -43,6 +42,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { IPs: []net.IP{net.IPv4zero}, }, }, + Upstreams: defaultUpstreamsConfig, } ctx, cancelFn = context.WithCancel(context.Background()) @@ -327,7 +327,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { upstream.Host = "localhost" // force bootstrap to do resolve, and not just return the IP as is - r := newUpstreamResolverUnchecked(upstream, sut) + r := newUpstreamResolverUnchecked(newUpstreamConfig(upstream, sutConfig.Upstreams), sut) rsp, err := r.Resolve(ctx, mainReq) Expect(err).Should(Succeed()) diff --git a/resolver/client_names_resolver.go b/resolver/client_names_resolver.go index 019b508b0..e3ecc4f79 100644 --- a/resolver/client_names_resolver.go +++ b/resolver/client_names_resolver.go @@ -28,11 +28,11 @@ type ClientNamesResolver struct { // NewClientNamesResolver creates new resolver instance func NewClientNamesResolver(ctx context.Context, - cfg config.ClientLookupConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, + cfg config.ClientLookupConfig, upstreamsCfg config.UpstreamsConfig, bootstrap *Bootstrap, ) (cr *ClientNamesResolver, err error) { var r Resolver if !cfg.Upstream.IsDefault() { - r, err = NewUpstreamResolver(ctx, cfg.Upstream, bootstrap, shouldVerifyUpstreams) + r, err = NewUpstreamResolver(ctx, newUpstreamConfig(cfg.Upstream, upstreamsCfg), bootstrap) if err != nil { return nil, err } diff --git a/resolver/client_names_resolver_test.go b/resolver/client_names_resolver_test.go index 4ea611b06..a95372788 100644 --- a/resolver/client_names_resolver_test.go +++ b/resolver/client_names_resolver_test.go @@ -39,7 +39,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { ctx, cancelFn = context.WithCancel(context.Background()) DeferCleanup(cancelFn) - sut, err = NewClientNamesResolver(ctx, sutConfig, nil, false) + sut, err = NewClientNamesResolver(ctx, sutConfig, defaultUpstreamsConfig, nil) Expect(err).Should(Succeed()) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) @@ -370,9 +370,12 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("errors during construction", func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) + upstreamsCfg := defaultUpstreamsConfig + upstreamsCfg.StartVerify = true + r, err := NewClientNamesResolver(ctx, config.ClientLookupConfig{ Upstream: config.Upstream{Host: "example.com"}, - }, b, true) + }, upstreamsCfg, b) Expect(err).ShouldNot(Succeed()) Expect(r).Should(BeNil()) diff --git a/resolver/conditional_upstream_resolver.go b/resolver/conditional_upstream_resolver.go index 9823e1e12..32ffb841e 100644 --- a/resolver/conditional_upstream_resolver.go +++ b/resolver/conditional_upstream_resolver.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "fmt" "strings" "github.com/0xERR0R/blocky/config" @@ -24,14 +25,15 @@ type ConditionalUpstreamResolver struct { // NewConditionalUpstreamResolver returns new resolver instance func NewConditionalUpstreamResolver( - ctx context.Context, cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool, + ctx context.Context, cfg config.ConditionalUpstreamConfig, upstreamsCfg config.UpstreamsConfig, bootstrap *Bootstrap, ) (*ConditionalUpstreamResolver, error) { m := make(map[string]Resolver, len(cfg.Mapping.Upstreams)) for domain, upstreams := range cfg.Mapping.Upstreams { - cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams} + name := fmt.Sprintf("", domain) + cfg := config.NewUpstreamGroup(name, upstreamsCfg, upstreams) - r, err := NewParallelBestResolver(ctx, cfg, bootstrap, shouldVerifyUpstreams) + r, err := NewParallelBestResolver(ctx, cfg, bootstrap) if err != nil { return nil, err } diff --git a/resolver/conditional_upstream_resolver_test.go b/resolver/conditional_upstream_resolver_test.go index 2c6ef4bc8..35c3feef2 100644 --- a/resolver/conditional_upstream_resolver_test.go +++ b/resolver/conditional_upstream_resolver_test.go @@ -17,8 +17,10 @@ import ( var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), func() { var ( - sut *ConditionalUpstreamResolver - m *mockResolver + sut *ConditionalUpstreamResolver + sutConfig config.ConditionalUpstreamConfig + + m *mockResolver ctx context.Context cancelFn context.CancelFunc @@ -65,7 +67,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) DeferCleanup(refuseTestUpstream.Close) - sut, _ = NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{ + sutConfig = config.ConditionalUpstreamConfig{ Mapping: config.ConditionalUpstreamMapping{ Upstreams: map[string][]config.Upstream{ "fritz.box": {fbTestUpstream.Start()}, @@ -74,7 +76,11 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu ".": {dotTestUpstream.Start()}, }, }, - }, nil, false) + } + }) + + JustBeforeEach(func() { + sut, _ = NewConditionalUpstreamResolver(ctx, sutConfig, defaultUpstreamsConfig, systemResolverBootstrap) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) sut.Next(m) @@ -194,14 +200,19 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu It("errors during construction", func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{ + upstreamsCfg := config.UpstreamsConfig{ + StartVerify: true, + } + + sutConfig := config.ConditionalUpstreamConfig{ Mapping: config.ConditionalUpstreamMapping{ Upstreams: map[string][]config.Upstream{ ".": {config.Upstream{Host: "example.com"}}, }, }, - }, b, true) + } + r, err := NewConditionalUpstreamResolver(ctx, sutConfig, upstreamsCfg, b) Expect(err).ShouldNot(Succeed()) Expect(r).Should(BeNil()) }) diff --git a/resolver/parallel_best_resolver.go b/resolver/parallel_best_resolver.go index c0f022e23..de87f88b0 100644 --- a/resolver/parallel_best_resolver.go +++ b/resolver/parallel_best_resolver.go @@ -13,7 +13,6 @@ import ( "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" - "github.com/miekg/dns" "github.com/mroth/weightedrand/v2" "github.com/sirupsen/logrus" @@ -31,7 +30,6 @@ type ParallelBestResolver struct { configurable[*config.UpstreamGroup] typed - groupName string resolvers []*upstreamResolverStatus resolverCount int @@ -53,6 +51,16 @@ func newUpstreamResolverStatus(resolver Resolver) *upstreamResolverStatus { return status } +func newUpstreamResolverStatuses(resolvers []Resolver) []*upstreamResolverStatus { + statuses := make([]*upstreamResolverStatus, 0, len(resolvers)) + + for _, r := range resolvers { + statuses = append(statuses, newUpstreamResolverStatus(r)) + } + + return statuses +} + func (r *upstreamResolverStatus) resolve(ctx context.Context, req *model.Request) (*model.Response, error) { resp, err := r.resolver.Resolve(ctx, req) if err != nil { @@ -83,25 +91,13 @@ type requestResponse struct { err error } -// testResolver sends a test query to verify the resolver is reachable and working -func testResolver(ctx context.Context, r *UpstreamResolver) error { - request := newRequest("github.com.", dns.Type(dns.TypeA)) - - resp, err := r.Resolve(ctx, request) - if err != nil || resp.RType != model.ResponseTypeRESOLVED { - return fmt.Errorf("test resolve of upstream server failed: %w", err) - } - - return nil -} - // NewParallelBestResolver creates new resolver instance func NewParallelBestResolver( - ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool, + ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, ) (*ParallelBestResolver, error) { logger := log.PrefixedLog(parallelResolverType) - resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams) + resolvers, err := createResolvers(ctx, logger, cfg, bootstrap) if err != nil { return nil, err } @@ -109,20 +105,12 @@ func NewParallelBestResolver( return newParallelBestResolver(cfg, resolvers), nil } -func newParallelBestResolver( - cfg config.UpstreamGroup, resolvers []Resolver, -) *ParallelBestResolver { - resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers)) - - for _, r := range resolvers { - resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r)) - } - +func newParallelBestResolver(cfg config.UpstreamGroup, resolvers []Resolver) *ParallelBestResolver { typeName := "parallel_best" resolverCount := parallelBestResolverCount retryWithDifferentResolver := false - if config.GetConfig().Upstreams.Strategy == config.UpstreamStrategyRandom { + if cfg.Strategy == config.UpstreamStrategyRandom { typeName = "random" resolverCount = 1 retryWithDifferentResolver = true @@ -132,11 +120,9 @@ func newParallelBestResolver( configurable: withConfig(&cfg), typed: withType(typeName), - groupName: cfg.Name, - resolvers: resolverStatuses, - resolverCount: resolverCount, retryWithDifferentResolver: retryWithDifferentResolver, + resolvers: newUpstreamResolverStatuses(resolvers), } return &r @@ -147,12 +133,12 @@ func (r *ParallelBestResolver) Name() string { } func (r *ParallelBestResolver) String() string { - result := make([]string, len(r.resolvers)) + resolvers := make([]string, len(r.resolvers)) for i, s := range r.resolvers { - result[i] = fmt.Sprintf("%s", s.resolver) + resolvers[i] = fmt.Sprintf("%s", s.resolver) } - return fmt.Sprintf("%s upstreams '%s (%s)'", r.Type(), r.groupName, strings.Join(result, ",")) + return fmt.Sprintf("%s upstreams '%s (%s)'", r.Type(), r.cfg.Name, strings.Join(resolvers, ",")) } // Resolve sends the query request to multiple upstream resolvers and returns the fastest result diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index 13e6faf28..1d784535b 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -9,7 +9,6 @@ import ( . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" - "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -17,18 +16,19 @@ import ( var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { const ( + timeout = 50 * time.Millisecond + verifyUpstreams = true noVerifyUpstreams = false - - timeout = 50 * time.Millisecond ) var ( - sut *ParallelBestResolver - upstreams []config.Upstream - sutVerify bool - ctx context.Context - cancelFn context.CancelFunc + sut *ParallelBestResolver + sutStrategy config.UpstreamStrategy + upstreams []config.Upstream + sutVerify bool + ctx context.Context + cancelFn context.CancelFunc err error @@ -42,29 +42,27 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) BeforeEach(func() { - old := config.GetConfig().Upstreams.Timeout - DeferCleanup(func() { config.GetConfig().Upstreams.Timeout = old }) - - config.GetConfig().Upstreams.Timeout = config.Duration(timeout) - config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest - ctx, cancelFn = context.WithCancel(context.Background()) DeferCleanup(cancelFn) upstreams = []config.Upstream{{Host: "wrong"}, {Host: "127.0.0.2"}} + sutStrategy = config.UpstreamStrategyParallelBest sutVerify = noVerifyUpstreams bootstrap = systemResolverBootstrap }) JustBeforeEach(func() { - sutConfig := config.UpstreamGroup{ - Name: upstreamDefaultCfgName, - Upstreams: upstreams, + upstreamsCfg := config.UpstreamsConfig{ + StartVerify: sutVerify, + Strategy: sutStrategy, + Timeout: config.Duration(timeout), } - sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap, sutVerify) + sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams) + + sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap) }) Describe("IsEnabled", func() { @@ -90,26 +88,14 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) }) - When("some default upstream resolvers cannot be reached", func() { - It("should start normally", func() { - mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { - response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122") + When("default upstream resolvers are not defined", func() { + BeforeEach(func() { + upstreams = []config.Upstream{} + }) - return - }) - defer mockUpstream.Close() - - upstreams := []config.Upstream{ - {Host: "wrong"}, - mockUpstream.Start(), - } - - _, err := NewParallelBestResolver(ctx, config.UpstreamGroup{ - Name: upstreamDefaultCfgName, - Upstreams: upstreams, - }, - systemResolverBootstrap, verifyUpstreams) - Expect(err).Should(Not(HaveOccurred())) + It("should fail on startup", func() { + Expect(err).Should(HaveOccurred()) + Expect(err).Should(MatchError(ContainSubstring("no external DNS resolvers configured"))) }) }) @@ -230,13 +216,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { withError2 := config.Upstream{Host: "wrong"} upstreams = []config.Upstream{withError1, withError2} - Expect(err).Should(Succeed()) }) It("Should return error", func() { Expect(err).Should(Succeed()) request := newRequest("example.com.", A) - _, err = sut.Resolve(ctx, request) + _, err = sut.Resolve(ctx, request) Expect(err).Should(HaveOccurred()) }) }) @@ -265,7 +250,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { Describe("Weighted random on resolver selection", func() { When("5 upstream resolvers are defined", func() { - It("should use 2 random peeked resolvers, weighted with last error timestamp", func() { + BeforeEach(func() { withError1 := config.Upstream{Host: "wrong1"} withError2 := config.Upstream{Host: "wrong2"} @@ -275,12 +260,10 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{ - Name: upstreamDefaultCfgName, - Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, - }, - systemResolverBootstrap, noVerifyUpstreams) + upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2} + }) + It("should use 2 random peeked resolvers, weighted with last error timestamp", func() { By("all resolvers have same weight for random -> equal distribution", func() { resolverCount := make(map[Resolver]int) @@ -335,11 +318,12 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("errors during construction", func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewParallelBestResolver(ctx, config.UpstreamGroup{ - Name: "test", - Upstreams: []config.Upstream{{Host: "example.com"}}, - }, b, verifyUpstreams) + upstreamsCfg := sut.cfg.UpstreamsConfig + upstreamsCfg.StartVerify = true + group := config.NewUpstreamGroup("test", upstreamsCfg, []config.Upstream{{Host: "example.com"}}) + + r, err := NewParallelBestResolver(ctx, group, b) Expect(err).ShouldNot(Succeed()) Expect(r).Should(BeNil()) }) @@ -347,7 +331,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { Describe("random resolver strategy", func() { BeforeEach(func() { - config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyRandom + sutStrategy = config.UpstreamStrategyRandom }) Describe("Name", func() { @@ -468,7 +452,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { Describe("Weighted random on resolver selection", func() { When("4 upstream resolvers are defined", func() { - It("should use 2 random peeked resolvers, weighted with last error timestamp", func() { + BeforeEach(func() { withError1 := config.Upstream{Host: "wrong1"} withError2 := config.Upstream{Host: "wrong2"} @@ -478,12 +462,10 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{ - Name: upstreamDefaultCfgName, - Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, - }, - systemResolverBootstrap, noVerifyUpstreams) + upstreams = []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2} + }) + It("should use 2 random peeked resolvers, weighted with last error timestamp", func() { By("all resolvers have same weight for random -> equal distribution", func() { resolverCount := make(map[Resolver]int) diff --git a/resolver/resolver.go b/resolver/resolver.go index c10969ae9..c0a687a86 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -218,21 +218,25 @@ func (c *configurable[T]) LogConfig(logger *logrus.Entry) { func createResolvers( ctx context.Context, logger *logrus.Entry, - cfg config.UpstreamGroup, bootstrap *Bootstrap, shoudVerifyUpstreams bool, + cfg config.UpstreamGroup, bootstrap *Bootstrap, ) ([]Resolver, error) { - resolvers := make([]Resolver, 0, len(cfg.Upstreams)) + if len(cfg.Upstreams()) == 0 { + return nil, fmt.Errorf("no external DNS resolvers configured for group %s", cfg.Name) + } + + resolvers := make([]Resolver, 0, len(cfg.Upstreams())) hasValidResolvers := false - for _, u := range cfg.Upstreams { - resolver, err := NewUpstreamResolver(ctx, u, bootstrap, shoudVerifyUpstreams) + for _, u := range cfg.Upstreams() { + resolver, err := NewUpstreamResolver(ctx, newUpstreamConfig(u, cfg.UpstreamsConfig), bootstrap) if err != nil { logger.Warnf("upstream group %s: %v", cfg.Name, err) continue } - if shoudVerifyUpstreams { - err = testResolver(ctx, resolver) + if cfg.StartVerify { + err = resolver.testResolve(ctx) if err != nil { logger.Warn(err) } else { @@ -243,7 +247,7 @@ func createResolvers( resolvers = append(resolvers, resolver) } - if shoudVerifyUpstreams && !hasValidResolvers { + if cfg.StartVerify && !hasValidResolvers { return nil, fmt.Errorf("no valid upstream for group %s", cfg.Name) } diff --git a/resolver/resolver_suite_test.go b/resolver/resolver_suite_test.go index 71ce26c23..5786808aa 100644 --- a/resolver/resolver_suite_test.go +++ b/resolver/resolver_suite_test.go @@ -1,20 +1,33 @@ -package resolver_test +package resolver import ( "context" "testing" + "time" + "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" - "github.com/go-redis/redis/v8" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) +var defaultUpstreamsConfig config.UpstreamsConfig + func init() { log.Silence() redis.SetLogger(NoLogs{}) + + var err error + + defaultUpstreamsConfig, err = config.WithDefaults[config.UpstreamsConfig]() + if err != nil { + panic(err) + } + + // Shorter timeout for tests + defaultUpstreamsConfig.Timeout = config.Duration(50 * time.Millisecond) } func TestResolver(t *testing.T) { diff --git a/resolver/strict_resolver.go b/resolver/strict_resolver.go index 8f40b155e..cde91e936 100644 --- a/resolver/strict_resolver.go +++ b/resolver/strict_resolver.go @@ -24,17 +24,16 @@ type StrictResolver struct { configurable[*config.UpstreamGroup] typed - groupName string resolvers []*upstreamResolverStatus } // NewStrictResolver creates a new strict resolver instance func NewStrictResolver( - ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, shouldVerifyUpstreams bool, + ctx context.Context, cfg config.UpstreamGroup, bootstrap *Bootstrap, ) (*StrictResolver, error) { logger := log.PrefixedLog(strictResolverType) - resolvers, err := createResolvers(ctx, logger, cfg, bootstrap, shouldVerifyUpstreams) + resolvers, err := createResolvers(ctx, logger, cfg, bootstrap) if err != nil { return nil, err } @@ -45,18 +44,11 @@ func NewStrictResolver( func newStrictResolver( cfg config.UpstreamGroup, resolvers []Resolver, ) *StrictResolver { - resolverStatuses := make([]*upstreamResolverStatus, 0, len(resolvers)) - - for _, r := range resolvers { - resolverStatuses = append(resolverStatuses, newUpstreamResolverStatus(r)) - } - r := StrictResolver{ configurable: withConfig(&cfg), typed: withType(strictResolverType), - groupName: cfg.Name, - resolvers: resolverStatuses, + resolvers: newUpstreamResolverStatuses(resolvers), } return &r @@ -72,7 +64,7 @@ func (r *StrictResolver) String() string { result[i] = fmt.Sprintf("%s", s.resolver) } - return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.groupName, strings.Join(result, ",")) + return fmt.Sprintf("%s upstreams '%s (%s)'", strictResolverType, r.cfg.Name, strings.Join(result, ",")) } // Resolve sends the query request in a strict order to the upstream resolvers diff --git a/resolver/strict_resolver_test.go b/resolver/strict_resolver_test.go index 0a4468442..451541630 100644 --- a/resolver/strict_resolver_test.go +++ b/resolver/strict_resolver_test.go @@ -54,11 +54,11 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) JustBeforeEach(func() { - sutConfig := config.UpstreamGroup{ - Name: upstreamDefaultCfgName, - Upstreams: upstreams, - } - sut, err = NewStrictResolver(ctx, sutConfig, bootstrap, sutVerify) + upstreamsCfg := defaultUpstreamsConfig + upstreamsCfg.StartVerify = sutVerify + + sutConfig := config.NewUpstreamGroup("test", upstreamsCfg, upstreams) + sut, err = NewStrictResolver(ctx, sutConfig, bootstrap) }) config.GetConfig().Upstreams.Timeout = config.Duration(time.Second) @@ -94,24 +94,21 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) When("some default upstream resolvers cannot be reached", func() { - It("should start normally", func() { + BeforeEach(func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122") return }) - defer mockUpstream.Close() + DeferCleanup(mockUpstream.Close) - upstreams := []config.Upstream{ + upstreams = []config.Upstream{ {Host: "wrong"}, mockUpstream.Start(), } + }) - _, err := NewStrictResolver(ctx, config.UpstreamGroup{ - Name: upstreamDefaultCfgName, - Upstreams: upstreams, - }, - systemResolverBootstrap, verifyUpstreams) + It("should start normally", func() { Expect(err).Should(Not(HaveOccurred())) }) }) diff --git a/resolver/upstream_resolver.go b/resolver/upstream_resolver.go index 31f5f706d..1b0386f01 100644 --- a/resolver/upstream_resolver.go +++ b/resolver/upstream_resolver.go @@ -29,11 +29,34 @@ const ( retryAttempts = 3 ) +type upstreamConfig struct { + config.UpstreamsConfig + config.Upstream +} + +func newUpstreamConfig(upstream config.Upstream, cfg config.UpstreamsConfig) upstreamConfig { + return upstreamConfig{cfg, upstream} +} + +func (c upstreamConfig) String() string { + return c.Upstream.String() +} + +// IsEnabled implements `config.Configurable`. +func (c upstreamConfig) IsEnabled() bool { + return true +} + +// LogConfig implements `config.Configurable`. +func (c upstreamConfig) LogConfig(logger *logrus.Entry) { + logger.Info(c.Upstream) +} + // UpstreamResolver sends request to external DNS server type UpstreamResolver struct { typed + configurable[upstreamConfig] - upstream config.Upstream upstreamClient upstreamClient bootstrap *Bootstrap } @@ -54,7 +77,7 @@ type httpUpstreamClient struct { host string } -func createUpstreamClient(cfg config.Upstream) upstreamClient { +func createUpstreamClient(cfg upstreamConfig) upstreamClient { tlsConfig := tls.Config{ ServerName: cfg.Host, MinVersion: tls.VersionTLS12, @@ -187,11 +210,11 @@ func (r *dnsUpstreamClient) callExternal( // NewUpstreamResolver creates new resolver instance func NewUpstreamResolver( - ctx context.Context, upstream config.Upstream, bootstrap *Bootstrap, verify bool, + ctx context.Context, cfg upstreamConfig, bootstrap *Bootstrap, ) (*UpstreamResolver, error) { - r := newUpstreamResolverUnchecked(upstream, bootstrap) + r := newUpstreamResolverUnchecked(cfg, bootstrap) - if verify { + if cfg.StartVerify { _, err := r.bootstrap.UpstreamIPs(ctx, r) if err != nil { return nil, err @@ -202,30 +225,34 @@ func NewUpstreamResolver( } // newUpstreamResolverUnchecked creates new resolver instance without validating the upstream -func newUpstreamResolverUnchecked(upstream config.Upstream, bootstrap *Bootstrap) *UpstreamResolver { - upstreamClient := createUpstreamClient(upstream) +func newUpstreamResolverUnchecked(cfg upstreamConfig, bootstrap *Bootstrap) *UpstreamResolver { + upstreamClient := createUpstreamClient(cfg) return &UpstreamResolver{ - typed: withType("upstream"), + typed: withType("upstream"), + configurable: withConfig(cfg), - upstream: upstream, upstreamClient: upstreamClient, bootstrap: bootstrap, } } -// IsEnabled implements `config.Configurable`. -func (r *UpstreamResolver) IsEnabled() bool { - return true +func (r UpstreamResolver) String() string { + return fmt.Sprintf("%s '%s'", r.Type(), r.cfg) } -// LogConfig implements `config.Configurable`. -func (r *UpstreamResolver) LogConfig(logger *logrus.Entry) { - logger.Info(r.upstream) +func (r *UpstreamResolver) log() *logrus.Entry { + return r.typed.log().WithField("upstream", r.cfg.String()) } -func (r UpstreamResolver) String() string { - return fmt.Sprintf("%s '%s'", r.Type(), r.upstream) +// testResolve sends a test query to verify the upstream is reachable and working +func (r *UpstreamResolver) testResolve(ctx context.Context) error { + // example.com MUST always resolve. See SUDN resolver + request := newRequest("example.com.", dns.Type(dns.TypeA)) + + _, err := r.Resolve(ctx, request) + + return err } // Resolve calls external resolver @@ -242,15 +269,21 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) err = retry.Do( func() error { - ctx, cancel := context.WithTimeout(ctx, config.GetConfig().Upstreams.Timeout.ToDuration()) - defer cancel() - ip = ips.Current() - upstreamURL := r.upstreamClient.fmtURL(ip, r.upstream.Port, r.upstream.Path) + upstreamURL := r.upstreamClient.fmtURL(ip, r.cfg.Port, r.cfg.Path) + + ctx := ctx // make sure we don't overwrite the outer function's context + + if r.cfg.Timeout.IsAboveZero() { + var cancel context.CancelFunc + + ctx, cancel = context.WithTimeout(ctx, r.cfg.Timeout.ToDuration()) + defer cancel() + } response, rtt, err := r.upstreamClient.callExternal(ctx, request.Req, upstreamURL, request.Protocol) if err != nil { - return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.upstream, upstreamURL, err) + return fmt.Errorf("can't resolve request via upstream server %s (%s): %w", r.cfg, upstreamURL, err) } resp = response @@ -266,7 +299,7 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) retry.RetryIf(isTimeout), retry.OnRetry(func(n uint, err error) { r.log().WithFields(logrus.Fields{ - "upstream": r.upstream.String(), + "upstream": r.cfg.String(), "upstream_ip": ip.String(), "question": util.QuestionToString(request.Req.Question), "attempt": fmt.Sprintf("%d/%d", n+1, retryAttempts), @@ -275,25 +308,20 @@ func (r *UpstreamResolver) Resolve(ctx context.Context, request *model.Request) ips.Next() })) if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - // Make the error more user friendly than just "context deadline exceeded" - err = fmt.Errorf("timeout (%w)", err) - } - return nil, err } - return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.upstream)}, nil + return &model.Response{Res: resp, Reason: fmt.Sprintf("RESOLVED (%s)", r.cfg)}, nil } func (r *UpstreamResolver) logResponse(request *model.Request, resp *dns.Msg, ip net.IP, rtt time.Duration) { r.log().WithFields(logrus.Fields{ "answer": util.AnswerToString(resp.Answer), "return_code": dns.RcodeToString[resp.Rcode], - "upstream": r.upstream.String(), + "upstream": r.cfg.String(), "upstream_ip": ip.String(), "protocol": request.Protocol, - "net": r.upstream.Net, + "net": r.cfg.Net, "response_time_ms": rtt.Milliseconds(), }).Debugf("received response from upstream") } diff --git a/resolver/upstream_resolver_test.go b/resolver/upstream_resolver_test.go index cc3512b61..7c9bffd0f 100644 --- a/resolver/upstream_resolver_test.go +++ b/resolver/upstream_resolver_test.go @@ -18,10 +18,10 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { +var _ = FDescribe("UpstreamResolver", Label("upstreamResolver"), func() { var ( sut *UpstreamResolver - sutConfig config.Upstream + sutConfig upstreamConfig ctx context.Context cancelFn context.CancelFunc @@ -31,7 +31,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { ctx, cancelFn = context.WithCancel(context.Background()) DeferCleanup(cancelFn) - sutConfig = config.Upstream{Host: "localhost"} + sutConfig = newUpstreamConfig(config.Upstream{Host: "localhost"}, defaultUpstreamsConfig) }) JustBeforeEach(func() { @@ -66,8 +66,8 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream.Close) - upstream := mockUpstream.Start() - sut := newUpstreamResolverUnchecked(upstream, nil) + sutConfig.Upstream = mockUpstream.Start() + sut := newUpstreamResolverUnchecked(sutConfig, nil) Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( @@ -76,7 +76,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), HaveTTL(BeNumerically("==", 123)), - HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))), + HaveReason(fmt.Sprintf("RESOLVED (%s)", sutConfig.Upstream))), ) }) }) @@ -85,8 +85,8 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerError(dns.RcodeNameError) DeferCleanup(mockUpstream.Close) - upstream := mockUpstream.Start() - sut := newUpstreamResolverUnchecked(upstream, nil) + sutConfig.Upstream = mockUpstream.Start() + sut := newUpstreamResolverUnchecked(sutConfig, nil) Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( @@ -94,7 +94,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { HaveNoAnswer(), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeNameError), - HaveReason(fmt.Sprintf("RESOLVED (%s)", upstream))), + HaveReason(fmt.Sprintf("RESOLVED (%s)", sutConfig.Upstream))), ) }) }) @@ -104,8 +104,8 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { return nil }) DeferCleanup(mockUpstream.Close) - upstream := mockUpstream.Start() - sut := newUpstreamResolverUnchecked(upstream, nil) + sutConfig.Upstream = mockUpstream.Start() + sut := newUpstreamResolverUnchecked(sutConfig, nil) _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) @@ -114,27 +114,25 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { When("Timeout occurs", func() { var counter int32 var attemptsWithTimeout int32 - var sut *UpstreamResolver BeforeEach(func() { - resolveFn := func(request *dns.Msg) (response *dns.Msg) { - atomic.AddInt32(&counter, 1) + resolveFn := func(request *dns.Msg) *dns.Msg { // timeout on first x attempts - if atomic.LoadInt32(&counter) <= atomic.LoadInt32(&attemptsWithTimeout) { - time.Sleep(110 * time.Millisecond) + if atomic.AddInt32(&counter, 1) <= atomic.LoadInt32(&attemptsWithTimeout) { + time.Sleep(2 * sutConfig.Timeout.ToDuration()) } + response, err := util.NewMsgWithAnswer("example.com", 123, A, "123.124.122.122") Expect(err).Should(Succeed()) return response } + mockUpstream := NewMockUDPUpstreamServer().WithAnswerFn(resolveFn) DeferCleanup(mockUpstream.Close) - upstream := mockUpstream.Start() - - sut = newUpstreamResolverUnchecked(upstream, nil) - sut.upstreamClient.(*dnsUpstreamClient).udpClient.Timeout = 100 * time.Millisecond + sutConfig.Upstream = mockUpstream.Start() }) + It("should perform a retry with 3 attempts", func() { By("2 attempts with timeout -> should resolve with third attempt", func() { atomic.StoreInt32(&counter, 0) @@ -166,7 +164,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { mockUpstream := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream.Close) - sutConfig = mockUpstream.Start() + sutConfig.Upstream = mockUpstream.Start() }) It("should retry with UDP", func() { @@ -188,8 +186,6 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { Describe("Using Dns over HTTP (DOH) upstream", func() { var ( - sut *UpstreamResolver - upstream config.Upstream respFn func(request *dns.Msg) (response *dns.Msg) modifyHTTPRespFn func(w http.ResponseWriter) ) @@ -205,8 +201,8 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { }) JustBeforeEach(func() { - upstream = newTestDOHUpstream(respFn, modifyHTTPRespFn) - sut = newUpstreamResolverUnchecked(upstream, nil) + sutConfig.Upstream = newTestDOHUpstream(respFn, modifyHTTPRespFn) + sut = newUpstreamResolverUnchecked(sutConfig, nil) // use insecure certificates for test doh upstream sut.upstreamClient.(*httpUpstreamClient).client.Transport = &http.Transport{ @@ -224,7 +220,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), HaveTTL(BeNumerically("==", 123)), - HaveReason(fmt.Sprintf("RESOLVED (https://%s:%d)", upstream.Host, upstream.Port)), + HaveReason(fmt.Sprintf("RESOLVED (%s)", sutConfig.Upstream)), )) }) }) @@ -267,10 +263,12 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { }) When("Configured DOH resolver does not respond", func() { JustBeforeEach(func() { - sut = newUpstreamResolverUnchecked(config.Upstream{ + sutConfig.Upstream = config.Upstream{ Net: config.NetProtocolHttps, Host: "wronghost.example.com", - }, systemResolverBootstrap) + } + + sut = newUpstreamResolverUnchecked(sutConfig, systemResolverBootstrap) }) It("should return error", func() { _, err := sut.Resolve(ctx, newRequest("example.com.", A)) diff --git a/resolver/upstream_tree_resolver.go b/resolver/upstream_tree_resolver.go index 9ed84831b..65345173b 100644 --- a/resolver/upstream_tree_resolver.go +++ b/resolver/upstream_tree_resolver.go @@ -2,6 +2,7 @@ package resolver import ( "context" + "errors" "fmt" "strings" @@ -23,15 +24,15 @@ type UpstreamTreeResolver struct { branches map[string]Resolver } -func NewUpstreamTreeResolver(cfg config.UpstreamsConfig, branches map[string]Resolver) (Resolver, error) { +func NewUpstreamTreeResolver(ctx context.Context, cfg config.UpstreamsConfig, bootstrap *Bootstrap) (Resolver, error) { if len(cfg.Groups[upstreamDefaultCfgName]) == 0 { return nil, fmt.Errorf("no external DNS resolvers configured as default upstream resolvers. "+ "Please configure at least one under '%s' configuration name", upstreamDefaultCfgName) } - if len(branches) != len(cfg.Groups) { - return nil, fmt.Errorf("amount of passed in branches (%d) does not match amount of configured upstream groups (%d)", - len(branches), len(cfg.Groups)) + branches, err := createUpstreamBranches(ctx, cfg, bootstrap) + if err != nil { + return nil, err } if len(branches) == 1 { @@ -51,6 +52,45 @@ func NewUpstreamTreeResolver(cfg config.UpstreamsConfig, branches map[string]Res return &r, nil } +func createUpstreamBranches( + ctx context.Context, cfg config.UpstreamsConfig, bootstrap *Bootstrap, +) (map[string]Resolver, error) { + branches := make(map[string]Resolver, len(cfg.Groups)) + errs := make([]error, 0, len(cfg.Groups)) + + for group, upstreams := range cfg.Groups { + var ( + upstream Resolver + err error + ) + + groupConfig := config.NewUpstreamGroup(group, cfg, upstreams) + + switch cfg.Strategy { + case config.UpstreamStrategyParallelBest: + fallthrough + case config.UpstreamStrategyRandom: + upstream, err = NewParallelBestResolver(ctx, groupConfig, bootstrap) + case config.UpstreamStrategyStrict: + upstream, err = NewStrictResolver(ctx, groupConfig, bootstrap) + } + + if err != nil { + errs = append(errs, fmt.Errorf("group %s: %w", group, err)) + + continue + } + + branches[group] = upstream + } + + if len(errs) != 0 { + return nil, errors.Join(errs...) + } + + return branches, nil +} + func (r *UpstreamTreeResolver) Name() string { return r.String() } @@ -77,7 +117,7 @@ func (r *UpstreamTreeResolver) Resolve(ctx context.Context, request *model.Reque } func (r *UpstreamTreeResolver) upstreamGroupByClient(request *model.Request) string { - groups := []string{} + groups := make([]string, 0, len(r.branches)) clientIP := request.ClientIP.String() // try IP diff --git a/resolver/upstream_tree_resolver_test.go b/resolver/upstream_tree_resolver_test.go index 11b709042..df9a7fdcc 100644 --- a/resolver/upstream_tree_resolver_test.go +++ b/resolver/upstream_tree_resolver_test.go @@ -2,38 +2,40 @@ package resolver import ( "context" + "fmt" "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" . "github.com/0xERR0R/blocky/model" - "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/stretchr/testify/mock" ) -var mockRes *mockResolver - var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { var ( sut Resolver sutConfig config.UpstreamsConfig - branches map[string]Resolver err error + + ctx context.Context + cancelFn context.CancelFunc ) BeforeEach(func() { - mockRes = &mockResolver{} + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + + sutConfig = defaultUpstreamsConfig }) JustBeforeEach(func() { - sut, err = NewUpstreamTreeResolver(sutConfig, branches) + sut, err = NewUpstreamTreeResolver(ctx, sutConfig, systemResolverBootstrap) }) - When("has no configuration", func() { + When("it has no configuration", func() { BeforeEach(func() { sutConfig = config.UpstreamsConfig{} }) @@ -45,67 +47,56 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { }) }) - When("amount of passed in resolvers doesn't match amount of groups", func() { + When("it has only default group", func() { BeforeEach(func() { - sutConfig = config.UpstreamsConfig{ - Groups: config.UpstreamGroups{ - upstreamDefaultCfgName: { - {Host: "wrong"}, - {Host: "127.0.0.1"}, - }, + sutConfig.Groups = config.UpstreamGroups{ + upstreamDefaultCfgName: { + {Host: "wrong"}, + {Host: "127.0.0.1"}, }, } - branches = map[string]Resolver{} }) - It("should return error", func() { - Expect(err).To(HaveOccurred()) - Expect(err).To(MatchError( - "amount of passed in branches (0) does not match amount of configured upstream groups (1)")) - Expect(sut).To(BeNil()) - }) - }) + When("strategy is parallel", func() { + BeforeEach(func() { + sutConfig.Strategy = config.UpstreamStrategyParallelBest + }) - When("has only default group", func() { - BeforeEach(func() { - sutConfig = config.UpstreamsConfig{ - Groups: config.UpstreamGroups{ - upstreamDefaultCfgName: { - {Host: "wrong"}, - {Host: "127.0.0.1"}, - }, - }, - } - branches = createBranchesMock(sutConfig) - }) - Describe("Type", func() { - It("does not return error", func() { + It("returns the resolver directly", func() { Expect(err).ToNot(HaveOccurred()) + + _, ok := sut.(*ParallelBestResolver) + Expect(ok).Should(BeTrue()) }) - It("follows conventions", func() { - expectValidResolverType(sut) + }) + + When("strategy is strict", func() { + BeforeEach(func() { + sutConfig.Strategy = config.UpstreamStrategyStrict }) - It("returns mock", func() { - Expect(sut.Type()).To(Equal("mock")) + + It("returns the resolver directly", func() { + Expect(err).ToNot(HaveOccurred()) + + _, ok := sut.(*StrictResolver) + Expect(ok).Should(BeTrue()) }) }) }) - When("has multiple groups", func() { + When("it has multiple groups", func() { BeforeEach(func() { - sutConfig = config.UpstreamsConfig{ - Groups: config.UpstreamGroups{ - upstreamDefaultCfgName: { - {Host: "wrong"}, - {Host: "127.0.0.1"}, - }, - "test": { - {Host: "some-resolver"}, - }, + sutConfig.Groups = config.UpstreamGroups{ + upstreamDefaultCfgName: { + {Host: "wrong"}, + {Host: "127.0.0.1"}, + }, + "test": { + {Host: "some-resolver"}, }, } - branches = createBranchesMock(sutConfig) }) + Describe("Type", func() { It("does not return error", func() { Expect(err).ToNot(HaveOccurred()) @@ -117,6 +108,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Type()).To(Equal(upstreamTreeResolverType)) }) }) + Describe("Configuration output", func() { It("should return configuration", func() { Expect(sut.IsEnabled()).Should(BeTrue()) @@ -140,62 +132,41 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { }) }) + When("start verify is enabled", func() { + BeforeEach(func() { + sutConfig.StartVerify = true + }) + + It("should fail", func() { + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError(ContainSubstring("no valid upstream"))) + Expect(sut).To(BeNil()) + }) + }) + When("client specific resolvers are defined", func() { - var ( - ctx context.Context - cancelFn context.CancelFunc - ) + groups := map[string]string{ + upstreamDefaultCfgName: "127.0.0.1", + "laptop": "127.0.0.2", + "client-*-m": "127.0.0.3", + "client[0-9]": "127.0.0.4", + "192.168.178.33": "127.0.0.5", + "10.43.8.67/28": "127.0.0.6", + "name-matches1": "127.0.0.7", + "name-matches*": "127.0.0.8", + } BeforeEach(func() { - ctx, cancelFn = context.WithCancel(context.Background()) - DeferCleanup(cancelFn) - - sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{ - upstreamDefaultCfgName: {config.Upstream{}}, - "laptop": {config.Upstream{}}, - "client-*-m": {config.Upstream{}}, - "client[0-9]": {config.Upstream{}}, - "192.168.178.33": {config.Upstream{}}, - "10.43.8.67/28": {config.Upstream{}}, - "name-matches1": {config.Upstream{}}, - "name-matches*": {config.Upstream{}}, - }} - - createMockResolver := func(group string) *mockResolver { - resolver := &mockResolver{} - - resolver.On("Resolve", mock.Anything) - resolver.ResponseFn = func(req *dns.Msg) *dns.Msg { - res := new(dns.Msg) - res.SetReply(req) - - ptr := new(dns.PTR) - ptr.Ptr = group - ptr.Hdr = util.CreateHeader(req.Question[0], 1) - res.Answer = append(res.Answer, ptr) - - return res - } - - return resolver - } + sutConfig.Groups = make(config.UpstreamGroups, len(groups)) - branches = map[string]Resolver{ - upstreamDefaultCfgName: nil, - "laptop": nil, - "client-*-m": nil, - "client[0-9]": nil, - "192.168.178.33": nil, - "10.43.8.67/28": nil, - "name-matches1": nil, - "name-matches*": nil, - } + for group, ip := range groups { + Expect(ip).ShouldNot(BeNil()) - for group := range branches { - branches[group] = createMockResolver(group) - } + server := NewMockUDPUpstreamServer().WithAnswerRR(fmt.Sprintf("example.com 123 IN A %s", ip)) + sutConfig.Groups[group] = []config.Upstream{server.Start()} - Expect(branches).To(HaveLen(8)) + DeferCleanup(server.Close) + } }) It("Should use default if client name or IP don't match", func() { @@ -204,7 +175,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "default"), + BeDNSRecord("example.com.", A, groups["default"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -215,7 +186,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "laptop"), + BeDNSRecord("example.com.", A, groups["laptop"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -226,7 +197,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "client-*-m"), + BeDNSRecord("example.com.", A, groups["client-*-m"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -237,7 +208,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "client[0-9]"), + BeDNSRecord("example.com.", A, groups["client[0-9]"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -248,7 +219,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "192.168.178.33"), + BeDNSRecord("example.com.", A, groups["192.168.178.33"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -259,7 +230,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "192.168.178.33"), + BeDNSRecord("example.com.", A, groups["192.168.178.33"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -270,7 +241,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "10.43.8.67/28"), + BeDNSRecord("example.com.", A, groups["10.43.8.67/28"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -281,7 +252,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "192.168.178.33"), + BeDNSRecord("example.com.", A, groups["192.168.178.33"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -292,7 +263,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( - BeDNSRecord("example.com.", A, "laptop"), + BeDNSRecord("example.com.", A, groups["laptop"]), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), )) @@ -307,8 +278,8 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { Should( SatisfyAll( SatisfyAny( - BeDNSRecord("example.com.", A, "name-matches1"), - BeDNSRecord("example.com.", A, "name-matches*"), + BeDNSRecord("example.com.", A, groups["name-matches1"]), + BeDNSRecord("example.com.", A, groups["name-matches*"]), ), HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -319,13 +290,3 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { }) }) }) - -func createBranchesMock(cfg config.UpstreamsConfig) map[string]Resolver { - branches := make(map[string]Resolver, len(cfg.Groups)) - - for name := range cfg.Groups { - branches[name] = mockRes - } - - return branches -} diff --git a/server/server.go b/server/server.go index 60ae465d7..924602c24 100644 --- a/server/server.go +++ b/server/server.go @@ -388,22 +388,14 @@ func createQueryResolver( cfg *config.Config, bootstrap *resolver.Bootstrap, redisClient *redis.Client, -) (r resolver.ChainedResolver, err error) { - upstreamBranches, uErr := createUpstreamBranches(ctx, cfg, bootstrap) - if uErr != nil { - return nil, fmt.Errorf("creation of upstream branches failed: %w", uErr) - } - - upstreamTree, utErr := resolver.NewUpstreamTreeResolver(cfg.Upstreams, upstreamBranches) - +) (resolver.ChainedResolver, error) { + upstreamTree, utErr := resolver.NewUpstreamTreeResolver(ctx, cfg.Upstreams, bootstrap) blocking, blErr := resolver.NewBlockingResolver(ctx, cfg.Blocking, redisClient, bootstrap) - clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, bootstrap, cfg.StartVerifyUpstream) - condUpstream, cuErr := resolver.NewConditionalUpstreamResolver( - ctx, cfg.Conditional, bootstrap, cfg.StartVerifyUpstream, - ) + clientNames, cnErr := resolver.NewClientNamesResolver(ctx, cfg.ClientLookup, cfg.Upstreams, bootstrap) + condUpstream, cuErr := resolver.NewConditionalUpstreamResolver(ctx, cfg.Conditional, cfg.Upstreams, bootstrap) hostsFile, hfErr := resolver.NewHostsFileResolver(ctx, cfg.HostsFile, bootstrap) - err = multierror.Append( + err := multierror.Append( multierror.Prefix(utErr, "upstream tree resolver: "), multierror.Prefix(blErr, "blocking resolver: "), multierror.Prefix(cnErr, "client names resolver: "), @@ -414,7 +406,7 @@ func createQueryResolver( return nil, err } - r = resolver.Chain( + r := resolver.Chain( resolver.NewFilteringResolver(cfg.Filtering), resolver.NewFQDNOnlyResolver(cfg.FQDNOnly), resolver.NewECSResolver(cfg.ECS), @@ -434,42 +426,6 @@ func createQueryResolver( return r, nil } -func createUpstreamBranches( - ctx context.Context, - cfg *config.Config, - bootstrap *resolver.Bootstrap, -) (map[string]resolver.Resolver, error) { - upstreamBranches := make(map[string]resolver.Resolver, len(cfg.Upstreams.Groups)) - - var uErr error - - for group, upstreams := range cfg.Upstreams.Groups { - var ( - upstream resolver.Resolver - err error - ) - - groupConfig := config.UpstreamGroup{ - Name: group, - Upstreams: upstreams, - } - - switch cfg.Upstreams.Strategy { - case config.UpstreamStrategyParallelBest: - upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream) - case config.UpstreamStrategyStrict: - upstream, err = resolver.NewStrictResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream) - case config.UpstreamStrategyRandom: - upstream, err = resolver.NewParallelBestResolver(ctx, groupConfig, bootstrap, cfg.StartVerifyUpstream) - } - - upstreamBranches[group] = upstream - uErr = multierror.Append(multierror.Prefix(err, fmt.Sprintf("group %s: ", group))).ErrorOrNil() - } - - return upstreamBranches, uErr -} - func (s *Server) registerDNSHandlers() { for _, server := range s.dnsServers { handler := server.Handler.(*dns.ServeMux) diff --git a/server/server_test.go b/server/server_test.go index b14a5d9cb..53b026134 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -697,62 +697,6 @@ var _ = Describe("Running DNS server", func() { }) }) - Describe("NewServer with strict upstream strategy", func() { - It("successfully returns upstream branches", func() { - branches, err := createUpstreamBranches(context.Background(), &config.Config{ - Upstreams: config.UpstreamsConfig{ - Strategy: config.UpstreamStrategyStrict, - Groups: config.UpstreamGroups{ - "default": {{Host: "0.0.0.0"}}, - }, - }, - }, nil) - - Expect(err).ToNot(HaveOccurred()) - Expect(branches).ToNot(BeNil()) - Expect(branches).To(HaveLen(1)) - _ = branches["default"].(*resolver.StrictResolver) - }) - }) - - Describe("NewServer with random upstream strategy", func() { - It("successfully returns upstream branches", func() { - branches, err := createUpstreamBranches(context.Background(), &config.Config{ - Upstreams: config.UpstreamsConfig{ - Strategy: config.UpstreamStrategyRandom, - Groups: config.UpstreamGroups{ - "default": {{Host: "0.0.0.0"}}, - }, - }, - }, nil) - - Expect(err).ToNot(HaveOccurred()) - Expect(branches).ToNot(BeNil()) - Expect(branches).To(HaveLen(1)) - _ = branches["default"].(*resolver.ParallelBestResolver) - }) - }) - - Describe("create query resolver", func() { - When("some upstream returns error", func() { - It("create query resolver should return error", func() { - r, err := createQueryResolver(ctx, &config.Config{ - StartVerifyUpstream: true, - Upstreams: config.UpstreamsConfig{ - Groups: config.UpstreamGroups{ - "default": {{Host: "0.0.0.0"}}, - }, - }, - }, - nil, nil) - - Expect(err).To(HaveOccurred()) - Expect(err).To(MatchError(ContainSubstring("creation of upstream branches failed: "))) - Expect(r).To(BeNil()) - }) - }) - }) - Describe("resolve client IP", func() { Context("UDP address", func() { It("should correct resolve client IP", func() {