Skip to content

Commit

Permalink
refactor: add UpstreamGroup.UpstreamsConfig to make values accessible
Browse files Browse the repository at this point in the history
Move `startVerifyUpstream` to `upstreams.startVerify` so it's accessible
via `UpstreamGroup` and we don't need to pass `startVerify` to all
resolver constructors that call `NewUpstreamResolver`.

Also has the nice benefit of greatly reducing the usage of `GetConfig`.
  • Loading branch information
ThinkChaos committed Nov 22, 2023
1 parent 3ea4012 commit 475023d
Show file tree
Hide file tree
Showing 22 changed files with 432 additions and 502 deletions.
87 changes: 44 additions & 43 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,44 +189,44 @@ func (b *BootstrappedUpstreamConfig) UnmarshalYAML(unmarshal func(interface{}) e
//
//nolint:maligned
type Config struct {
Upstreams UpstreamsConfig `yaml:"upstreams"`
ConnectIPVersion IPVersion `yaml:"connectIPVersion"`
CustomDNS CustomDNSConfig `yaml:"customDNS"`
Conditional ConditionalUpstreamConfig `yaml:"conditional"`
Blocking BlockingConfig `yaml:"blocking"`
ClientLookup ClientLookupConfig `yaml:"clientLookup"`
Caching CachingConfig `yaml:"caching"`
QueryLog QueryLogConfig `yaml:"queryLog"`
Prometheus MetricsConfig `yaml:"prometheus"`
Redis RedisConfig `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports PortsConfig `yaml:"ports"`
DoHUserAgent string `yaml:"dohUserAgent"`
MinTLSServeVer string `yaml:"minTlsServeVersion" default:"1.2"`
StartVerifyUpstream bool `yaml:"startVerifyUpstream" default:"false"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"`
HostsFile HostsFileConfig `yaml:"hostsFile"`
FQDNOnly FQDNOnly `yaml:"fqdnOnly"`
Filtering FilteringConfig `yaml:"filtering"`
EDE EDE `yaml:"ede"`
ECS ECS `yaml:"ecs"`
SUDN SUDN `yaml:"specialUseDomains"`
Upstreams UpstreamsConfig `yaml:"upstreams"`
ConnectIPVersion IPVersion `yaml:"connectIPVersion"`
CustomDNS CustomDNSConfig `yaml:"customDNS"`
Conditional ConditionalUpstreamConfig `yaml:"conditional"`
Blocking BlockingConfig `yaml:"blocking"`
ClientLookup ClientLookupConfig `yaml:"clientLookup"`
Caching CachingConfig `yaml:"caching"`
QueryLog QueryLogConfig `yaml:"queryLog"`
Prometheus MetricsConfig `yaml:"prometheus"`
Redis RedisConfig `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports PortsConfig `yaml:"ports"`
DoHUserAgent string `yaml:"dohUserAgent"`
MinTLSServeVer string `yaml:"minTlsServeVersion" default:"1.2"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
BootstrapDNS BootstrapDNSConfig `yaml:"bootstrapDns"`
HostsFile HostsFileConfig `yaml:"hostsFile"`
FQDNOnly FQDNOnly `yaml:"fqdnOnly"`
Filtering FilteringConfig `yaml:"filtering"`
EDE EDE `yaml:"ede"`
ECS ECS `yaml:"ecs"`
SUDN SUDN `yaml:"specialUseDomains"`

// Deprecated options
Deprecated struct {
Upstream *UpstreamGroups `yaml:"upstream"`
UpstreamTimeout *Duration `yaml:"upstreamTimeout"`
DisableIPv6 *bool `yaml:"disableIPv6"`
LogLevel *log.Level `yaml:"logLevel"`
LogFormat *log.FormatType `yaml:"logFormat"`
LogPrivacy *bool `yaml:"logPrivacy"`
LogTimestamp *bool `yaml:"logTimestamp"`
DNSPorts *ListenConfig `yaml:"port"`
HTTPPorts *ListenConfig `yaml:"httpPort"`
HTTPSPorts *ListenConfig `yaml:"httpsPort"`
TLSPorts *ListenConfig `yaml:"tlsPort"`
Upstream *UpstreamGroups `yaml:"upstream"`
UpstreamTimeout *Duration `yaml:"upstreamTimeout"`
DisableIPv6 *bool `yaml:"disableIPv6"`
LogLevel *log.Level `yaml:"logLevel"`
LogFormat *log.FormatType `yaml:"logFormat"`
LogPrivacy *bool `yaml:"logPrivacy"`
LogTimestamp *bool `yaml:"logTimestamp"`
DNSPorts *ListenConfig `yaml:"port"`
HTTPPorts *ListenConfig `yaml:"httpPort"`
HTTPSPorts *ListenConfig `yaml:"httpsPort"`
TLSPorts *ListenConfig `yaml:"tlsPort"`
StartVerifyUpstream *bool `yaml:"startVerifyUpstream"`
} `yaml:",inline"`
}

Expand Down Expand Up @@ -514,14 +514,15 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {
cfg.Filtering.QueryTypes.Insert(dns.Type(dns.TypeAAAA))
}
}),
"port": Move(To("ports.dns", &cfg.Ports)),
"httpPort": Move(To("ports.http", &cfg.Ports)),
"httpsPort": Move(To("ports.https", &cfg.Ports)),
"tlsPort": Move(To("ports.tls", &cfg.Ports)),
"logLevel": Move(To("log.level", &cfg.Log)),
"logFormat": Move(To("log.format", &cfg.Log)),
"logPrivacy": Move(To("log.privacy", &cfg.Log)),
"logTimestamp": Move(To("log.timestamp", &cfg.Log)),
"port": Move(To("ports.dns", &cfg.Ports)),
"httpPort": Move(To("ports.http", &cfg.Ports)),
"httpsPort": Move(To("ports.https", &cfg.Ports)),
"tlsPort": Move(To("ports.tls", &cfg.Ports)),
"logLevel": Move(To("log.level", &cfg.Log)),
"logFormat": Move(To("log.format", &cfg.Log)),
"logPrivacy": Move(To("log.privacy", &cfg.Log)),
"logTimestamp": Move(To("log.timestamp", &cfg.Log)),
"startVerifyUpstream": Move(To("upstreams.startVerify", &cfg.Upstreams)),
})

usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts
Expand Down
5 changes: 3 additions & 2 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,7 @@ bootstrapDns:

func defaultTestFileConfig() {
Expect(config.Ports.DNS).Should(Equal(ListenConfig{"55553", ":55554", "[::1]:55555"}))
Expect(config.Upstreams.StartVerify).Should(BeFalse())
Expect(config.Upstreams.Groups["default"]).Should(HaveLen(3))
Expect(config.Upstreams.Groups["default"][0].Host).Should(Equal("8.8.8.8"))
Expect(config.Upstreams.Groups["default"][1].Host).Should(Equal("8.8.4.4"))
Expand Down Expand Up @@ -798,14 +799,14 @@ func defaultTestFileConfig() {

Expect(config.DoHUserAgent).Should(Equal("testBlocky"))
Expect(config.MinTLSServeVer).Should(Equal("1.3"))
Expect(config.StartVerifyUpstream).Should(BeFalse())

Expect(GetConfig()).Should(Not(BeNil()))
}

func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
return tmpDir.CreateStringFile("config.yml",
"upstreams:",
" startVerify: false",
" groups:",
" default:",
" - tcp+udp:8.8.8.8",
Expand Down Expand Up @@ -860,7 +861,7 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
"logLevel: debug",
"dohUserAgent: testBlocky",
"minTlsServeVersion: 1.3",
"startVerifyUpstream: false")
)
}

func writeConfigDir(tmpDir *helpertest.TmpFolder) error {
Expand Down
34 changes: 27 additions & 7 deletions config/upstreams.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ const UpstreamDefaultCfgName = "default"

// UpstreamsConfig upstream servers configuration
type UpstreamsConfig struct {
Timeout Duration `yaml:"timeout" default:"2s"`
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
Timeout Duration `yaml:"timeout" default:"2s"`
Groups UpstreamGroups `yaml:"groups"`
Strategy UpstreamStrategy `yaml:"strategy" default:"parallel_best"`
StartVerify bool `yaml:"startVerify" default:"false"`
}

type UpstreamGroups map[string][]Upstream
Expand All @@ -37,21 +38,40 @@ func (c *UpstreamsConfig) LogConfig(logger *logrus.Entry) {

// UpstreamGroup represents the config for one group (upstream branch)
type UpstreamGroup struct {
Name string
Upstreams []Upstream
UpstreamsConfig

Name string // group name
}

// NewUpstreamGroup creates an UpstreamGroup with the given name and upstreams.
//
// The upstreams from `cfg.Groups` are ignored.
func NewUpstreamGroup(name string, cfg UpstreamsConfig, upstreams []Upstream) UpstreamGroup {
group := UpstreamGroup{
Name: name,
UpstreamsConfig: cfg,
}

group.Groups = UpstreamGroups{name: upstreams}

return group
}

func (c *UpstreamGroup) Upstreams() []Upstream {
return c.Groups[c.Name]
}

// IsEnabled implements `config.Configurable`.
func (c *UpstreamGroup) IsEnabled() bool {
return len(c.Upstreams) != 0
return len(c.Upstreams()) != 0
}

// LogConfig implements `config.Configurable`.
func (c *UpstreamGroup) LogConfig(logger *logrus.Entry) {
logger.Info("group: ", c.Name)
logger.Info("upstreams:")

for _, upstream := range c.Upstreams {
for _, upstream := range c.Upstreams() {
logger.Infof(" - %s", upstream)
}
}
16 changes: 8 additions & 8 deletions config/upstreams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ var _ = Describe("ParallelBestConfig", func() {
var cfg UpstreamGroup

BeforeEach(func() {
cfg = UpstreamGroup{
Name: UpstreamDefaultCfgName,
Upstreams: []Upstream{
{Host: "host1"},
{Host: "host2"},
},
}
upstreamsCfg, err := WithDefaults[UpstreamsConfig]()
Expect(err).Should(Succeed())

cfg = NewUpstreamGroup("test", upstreamsCfg, []Upstream{
{Host: "host1"},
{Host: "host2"},
})
})

Describe("IsEnabled", func() {
Expand Down Expand Up @@ -102,7 +102,7 @@ var _ = Describe("ParallelBestConfig", func() {
cfg.LogConfig(logger)

Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: default")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: test")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstreams:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host1:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
Expand Down
19 changes: 9 additions & 10 deletions resolver/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
},
}

bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS)
bootstraped, err := newBootstrapedResolvers(b, cfg.BootstrapDNS, cfg.Upstreams)
if err != nil {
return nil, err
}
Expand All @@ -76,11 +76,8 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
return b, nil
}

// Bootstrap doesn't have a `LogConfig` method, and since that's the only place
// where `ParallelBestResolver` uses its config, we can just use an empty one.
pbCfg := config.UpstreamGroup{Name: upstreamDefaultCfgName}

parallelResolver := newParallelBestResolver(pbCfg, bootstraped.Resolvers())
pbCfg := config.NewUpstreamGroup("<bootstrap>", cfg.Upstreams, nil)
pbCfg.UpstreamsConfig.Groups = nil // To be on the safe side it doesn't try to use anything besides the bootstrap

// Always enable prefetching to avoid stalling user requests
// Otherwise, a request to blocky could end up waiting for 2 DNS requests:
Expand All @@ -100,14 +97,14 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
NewFilteringResolver(cfg.Filtering),
// false: no metrics, to not overwrite the main blocking resolver ones
newCachingResolver(ctx, cachingCfg, nil, false),
parallelResolver,
newParallelBestResolver(pbCfg, bootstraped.Resolvers()),
)

return b, nil
}

func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
hostname := r.upstream.Host
hostname := r.cfg.Host

if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
return newIPSet([]net.IP{ip}), nil
Expand Down Expand Up @@ -249,7 +246,9 @@ func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.
// map of bootstraped resolvers their hardcoded IPs
type bootstrapedResolvers map[Resolver][]net.IP

func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (bootstrapedResolvers, error) {
func newBootstrapedResolvers(
b *Bootstrap, cfg config.BootstrapDNSConfig, upstreamsCfg config.UpstreamsConfig,
) (bootstrapedResolvers, error) {
upstreamIPs := make(bootstrapedResolvers, len(cfg))

var multiErr *multierror.Error
Expand Down Expand Up @@ -289,7 +288,7 @@ func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (boots
continue
}

resolver := newUpstreamResolverUnchecked(upstream, b)
resolver := newUpstreamResolverUnchecked(newUpstreamConfig(upstream, upstreamsCfg), b)

upstreamIPs[resolver] = ips
}
Expand Down
4 changes: 2 additions & 2 deletions resolver/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
)

BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
sutConfig = &config.Config{
BootstrapDNS: []config.BootstrappedUpstreamConfig{
{
Expand All @@ -43,6 +42,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
IPs: []net.IP{net.IPv4zero},
},
},
Upstreams: defaultUpstreamsConfig,
}

ctx, cancelFn = context.WithCancel(context.Background())
Expand Down Expand Up @@ -327,7 +327,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {

upstream.Host = "localhost" // force bootstrap to do resolve, and not just return the IP as is

r := newUpstreamResolverUnchecked(upstream, sut)
r := newUpstreamResolverUnchecked(newUpstreamConfig(upstream, sutConfig.Upstreams), sut)

rsp, err := r.Resolve(ctx, mainReq)
Expect(err).Should(Succeed())
Expand Down
4 changes: 2 additions & 2 deletions resolver/client_names_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ type ClientNamesResolver struct {

// NewClientNamesResolver creates new resolver instance
func NewClientNamesResolver(ctx context.Context,
cfg config.ClientLookupConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
cfg config.ClientLookupConfig, upstreamsCfg config.UpstreamsConfig, bootstrap *Bootstrap,
) (cr *ClientNamesResolver, err error) {
var r Resolver
if !cfg.Upstream.IsDefault() {
r, err = NewUpstreamResolver(ctx, cfg.Upstream, bootstrap, shouldVerifyUpstreams)
r, err = NewUpstreamResolver(ctx, newUpstreamConfig(cfg.Upstream, upstreamsCfg), bootstrap)
if err != nil {
return nil, err
}
Expand Down
7 changes: 5 additions & 2 deletions resolver/client_names_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)

sut, err = NewClientNamesResolver(ctx, sutConfig, nil, false)
sut, err = NewClientNamesResolver(ctx, sutConfig, defaultUpstreamsConfig, nil)
Expect(err).Should(Succeed())
m = &mockResolver{}
m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil)
Expand Down Expand Up @@ -370,9 +370,12 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() {
It("errors during construction", func() {
b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}})

upstreamsCfg := defaultUpstreamsConfig
upstreamsCfg.StartVerify = true

r, err := NewClientNamesResolver(ctx, config.ClientLookupConfig{
Upstream: config.Upstream{Host: "example.com"},
}, b, true)
}, upstreamsCfg, b)

Expect(err).ShouldNot(Succeed())
Expect(r).Should(BeNil())
Expand Down
8 changes: 5 additions & 3 deletions resolver/conditional_upstream_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package resolver

import (
"context"
"fmt"
"strings"

"github.com/0xERR0R/blocky/config"
Expand All @@ -24,14 +25,15 @@ type ConditionalUpstreamResolver struct {

// NewConditionalUpstreamResolver returns new resolver instance
func NewConditionalUpstreamResolver(
ctx context.Context, cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
ctx context.Context, cfg config.ConditionalUpstreamConfig, upstreamsCfg config.UpstreamsConfig, bootstrap *Bootstrap,
) (*ConditionalUpstreamResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))

for domain, upstreams := range cfg.Mapping.Upstreams {
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}
name := fmt.Sprintf("<conditional in %s>", domain)
cfg := config.NewUpstreamGroup(name, upstreamsCfg, upstreams)

r, err := NewParallelBestResolver(ctx, cfg, bootstrap, shouldVerifyUpstreams)
r, err := NewParallelBestResolver(ctx, cfg, bootstrap)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 475023d

Please sign in to comment.