Skip to content

Commit

Permalink
feat: add upstream strategy random (0xERR0R#1221)
Browse files Browse the repository at this point in the history
Also simplify code by getting rid of `resolversPerClient` and all surrounding logic.
  • Loading branch information
DerRockWolf authored Nov 18, 2023
1 parent 4a5a395 commit 94663ee
Show file tree
Hide file tree
Showing 23 changed files with 653 additions and 384 deletions.
2 changes: 1 addition & 1 deletion api/api_client.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion api/api_server.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion api/api_types.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (s *StartStrategyType) do(setup func() error, logErr func(error)) error {
type QueryLogField string

// UpstreamStrategy data field to be logged
// ENUM(parallel_best,strict)
// ENUM(parallel_best,strict,random)
type UpstreamStrategy uint8

//nolint:gochecknoglobals
Expand Down
8 changes: 7 additions & 1 deletion config/config_enum.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 21 additions & 0 deletions config/upstreams.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,24 @@ func (c *UpstreamsConfig) LogConfig(logger *logrus.Entry) {
}
}
}

// UpstreamGroup represents the config for one group (upstream branch)
type UpstreamGroup struct {
Name string
Upstreams []Upstream
}

// IsEnabled implements `config.Configurable`.
func (c *UpstreamGroup) IsEnabled() bool {
return len(c.Upstreams) != 0
}

// LogConfig implements `config.Configurable`.
func (c *UpstreamGroup) LogConfig(logger *logrus.Entry) {
logger.Info("group: ", c.Name)
logger.Info("upstreams:")

for _, upstream := range c.Upstreams {
logger.Infof(" - %s", upstream)
}
}
113 changes: 82 additions & 31 deletions config/upstreams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,53 +9,104 @@ import (
)

var _ = Describe("ParallelBestConfig", func() {
var cfg UpstreamsConfig

suiteBeforeEach()

BeforeEach(func() {
cfg = UpstreamsConfig{
Timeout: Duration(5 * time.Second),
Groups: UpstreamGroups{
UpstreamDefaultCfgName: {
{Host: "host1"},
{Host: "host2"},
Context("UpstreamsConfig", func() {
var cfg UpstreamsConfig

BeforeEach(func() {
cfg = UpstreamsConfig{
Timeout: Duration(5 * time.Second),
Groups: UpstreamGroups{
UpstreamDefaultCfgName: {
{Host: "host1"},
{Host: "host2"},
},
},
},
}
})
}
})

Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := UpstreamsConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())
Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := UpstreamsConfig{}
Expect(defaults.Set(&cfg)).Should(Succeed())

Expect(cfg.IsEnabled()).Should(BeFalse())
Expect(cfg.IsEnabled()).Should(BeFalse())
})

When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})

When("disabled", func() {
It("should be false", func() {
cfg := UpstreamsConfig{}

Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})

When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)

Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("groups:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
})
})
})

When("disabled", func() {
It("should be false", func() {
cfg := UpstreamsConfig{}
Context("UpstreamGroupConfig", func() {
var cfg UpstreamGroup

BeforeEach(func() {
cfg = UpstreamGroup{
Name: UpstreamDefaultCfgName,
Upstreams: []Upstream{
{Host: "host1"},
{Host: "host2"},
},
}
})

Describe("IsEnabled", func() {
It("should be false by default", func() {
cfg := UpstreamGroup{}
Expect(defaults.Set(&cfg)).Should(Succeed())

Expect(cfg.IsEnabled()).Should(BeFalse())
})

When("enabled", func() {
It("should be true", func() {
Expect(cfg.IsEnabled()).Should(BeTrue())
})
})

When("disabled", func() {
It("should be false", func() {
cfg := UpstreamGroup{}

Expect(cfg.IsEnabled()).Should(BeFalse())
})
})
})
})

Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)
Describe("LogConfig", func() {
It("should log configuration", func() {
cfg.LogConfig(logger)

Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("timeout:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("groups:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
Expect(hook.Calls).ShouldNot(BeEmpty())
Expect(hook.Messages).Should(ContainElement(ContainSubstring("group: default")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring("upstreams:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host1:")))
Expect(hook.Messages).Should(ContainElement(ContainSubstring(":host2:")))
})
})
})
})
2 changes: 1 addition & 1 deletion docs/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ upstreams:
laptop*:
- 123.123.123.123
# optional: Determines what strategy blocky uses to choose the upstream servers.
# accepted: parallel_best, strict
# accepted: parallel_best, strict, random
# default: parallel_best
strategy: parallel_best
# optional: timeout to query the upstream resolver. Default: 2s
Expand Down
8 changes: 6 additions & 2 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,14 @@ Blocky supports different upstream strategies (default `parallel_best`) that det

Currently available strategies:

- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one.
- `parallel_best`: blocky picks 2 random (weighted) resolvers from the upstream group for each query and returns the answer from the fastest one.
If an upstream failed to answer within the last hour, it is less likely to be chosen for the race.
This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers
This improves your network speed and increases your privacy - your DNS traffic will be distributed over multiple providers.
(When using 10 upstream servers, each upstream will get on average 20% of the DNS requests)
- `random`: blocky picks one random (weighted) resolver from the upstream group for each query and if successful, returns its response.
If the selected resolver fails to respond, a second one is picked to which the query is sent.
The weighting is identical to the `parallel_best` strategy.
Although the `random` strategy might be slower than the `parallel_best` strategy, it offers more privacy since each request is sent to a single upstream.
- `strict`: blocky forwards the request in a strict order. If the first upstream does not respond, the second is asked, and so on.

!!! example
Expand Down
22 changes: 5 additions & 17 deletions resolver/blocking_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"context"
"fmt"
"net"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/0xERR0R/blocky/cache/expirationcache"
"golang.org/x/exp/maps"

"github.com/hashicorp/go-multierror"

Expand Down Expand Up @@ -206,21 +208,11 @@ func (r *BlockingResolver) RefreshLists() error {
return err.ErrorOrNil()
}

//nolint:prealloc
func (r *BlockingResolver) retrieveAllBlockingGroups() []string {
groups := make(map[string]bool, len(r.cfg.BlackLists))

for group := range r.cfg.BlackLists {
groups[group] = true
}

var result []string
for k := range groups {
result = append(result, k)
}
result := maps.Keys(r.cfg.BlackLists)

result = append(result, "default")
sort.Strings(result)
slices.Sort(result)

return result
}
Expand Down Expand Up @@ -615,11 +607,7 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
}

func (r *BlockingResolver) initFQDNIPCache() {
identifiers := make([]string, 0)

for identifier := range r.clientGroupsBlock {
identifiers = append(identifiers, identifier)
}
identifiers := maps.Keys(r.clientGroupsBlock)

for _, identifier := range identifiers {
if isFQDN(identifier) {
Expand Down
17 changes: 5 additions & 12 deletions resolver/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)

const (
Expand Down Expand Up @@ -77,9 +78,9 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er

// Bootstrap doesn't have a `LogConfig` method, and since that's the only place
// where `ParallelBestResolver` uses its config, we can just use an empty one.
var pbCfg config.UpstreamsConfig
pbCfg := config.UpstreamGroup{Name: upstreamDefaultCfgName}

parallelResolver := newParallelBestResolver(pbCfg, bootstraped.ResolverGroups())
parallelResolver := newParallelBestResolver(pbCfg, bootstraped.Resolvers())

// Always enable prefetching to avoid stalling user requests
// Otherwise, a request to blocky could end up waiting for 2 DNS requests:
Expand Down Expand Up @@ -300,16 +301,8 @@ func newBootstrapedResolvers(b *Bootstrap, cfg config.BootstrapDNSConfig) (boots
return upstreamIPs, nil
}

func (br bootstrapedResolvers) ResolverGroups() map[string][]Resolver {
resolvers := make([]Resolver, 0, len(br))

for resolver := range br {
resolvers = append(resolvers, resolver)
}

return map[string][]Resolver{
upstreamDefaultCfgName: resolvers,
}
func (br bootstrapedResolvers) Resolvers() []Resolver {
return maps.Keys(br)
}

type IPSet struct {
Expand Down
1 change: 1 addition & 0 deletions resolver/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
)

BeforeEach(func() {
config.GetConfig().Upstreams.Strategy = config.UpstreamStrategyParallelBest
sutConfig = &config.Config{
BootstrapDNS: []config.BootstrappedUpstreamConfig{
{
Expand Down
2 changes: 1 addition & 1 deletion resolver/caching_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil
}

logger.WithField("next_resolver", Name(r.next)).Debug("not in cache: go to next resolver")
logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
response, err = r.next.Resolve(request)

if err == nil {
Expand Down
10 changes: 3 additions & 7 deletions resolver/conditional_upstream_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,10 @@ func NewConditionalUpstreamResolver(
) (*ConditionalUpstreamResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))

for domain, upstream := range cfg.Mapping.Upstreams {
pbCfg := config.UpstreamsConfig{
Groups: config.UpstreamGroups{
upstreamDefaultCfgName: upstream,
},
}
for domain, upstreams := range cfg.Mapping.Upstreams {
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}

r, err := NewParallelBestResolver(pbCfg, bootstrap, shouldVerifyUpstreams)
r, err := NewParallelBestResolver(cfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion resolver/custom_dns_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (r *CustomDNSResolver) Resolve(request *model.Request) (*model.Response, er
}
}

logger.WithField("resolver", Name(r.next)).Trace("go to next resolver")
logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")

return r.next.Resolve(request)
}
Loading

0 comments on commit 94663ee

Please sign in to comment.