Skip to content

Commit

Permalink
squash: add context args and params where needed
Browse files Browse the repository at this point in the history
This is split to try and make review easier by having the previous
commit with all interesting changes, and then boring commits with just
the API churn.
  • Loading branch information
ThinkChaos committed Nov 19, 2023
1 parent 62f7471 commit 53deed6
Show file tree
Hide file tree
Showing 24 changed files with 113 additions and 88 deletions.
2 changes: 1 addition & 1 deletion cache/expirationcache/expiration_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func NewCacheWithOnExpired[T any](ctx context.Context, options Options,
l, _ := lru.New(defaultSize)
c := &ExpiringLRUCache[T]{
cleanUpInterval: defaultCleanUpInterval,
preExpirationFn: func(key string) (val *T, ttl time.Duration) {
preExpirationFn: func(ctx context.Context, key string) (val *T, ttl time.Duration) {
return nil, 0
},
onCacheHit: func(key string) {},
Expand Down
6 changes: 4 additions & 2 deletions cache/expirationcache/prefetching_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ func (e *PrefetchingExpiringLRUCache[T]) shouldPrefetch(cacheKey string) bool {
return cnt != nil && int64(cnt.Load()) > int64(e.prefetchThreshold)
}

func (e *PrefetchingExpiringLRUCache[T]) onExpired(cacheKey string) (val *cacheValue[T], ttl time.Duration) {
func (e *PrefetchingExpiringLRUCache[T]) onExpired(
ctx context.Context, cacheKey string,
) (val *cacheValue[T], ttl time.Duration) {
if e.shouldPrefetch(cacheKey) {
loadedVal, ttl := e.reloadFn(cacheKey)
loadedVal, ttl := e.reloadFn(ctx, cacheKey)
if loadedVal != nil {
if e.onPrefetchEntryReloaded != nil {
e.onPrefetchEntryReloaded(cacheKey)
Expand Down
22 changes: 11 additions & 11 deletions resolver/blocking_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ func NewBlockingResolver(ctx context.Context,

res.fqdnIPCache = expirationcache.NewCacheWithOnExpired[[]net.IP](ctx, expirationcache.Options{
CleanupInterval: defaultBlockingCleanUpInterval,
}, func(key string) (val *[]net.IP, ttl time.Duration) {
return res.queryForFQIdentifierIPs(key)
}, func(ctx context.Context, key string) (val *[]net.IP, ttl time.Duration) {
return res.queryForFQIdentifierIPs(ctx, key)
})

if res.redisClient != nil {
Expand Down Expand Up @@ -356,7 +356,7 @@ func (r *BlockingResolver) hasWhiteListOnlyAllowed(groupsToCheck []string) bool
return false
}

func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
func (r *BlockingResolver) handleBlacklist(ctx context.Context, groupsToCheck []string,
request *model.Request, logger *logrus.Entry,
) (bool, *model.Response, error) {
logger.WithField("groupsToCheck", strings.Join(groupsToCheck, "; ")).Debug("checking groups for request")
Expand All @@ -369,7 +369,7 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
if groups := r.matches(groupsToCheck, r.whitelistMatcher, domain); len(groups) > 0 {
logger.WithField("groups", groups).Debugf("domain is whitelisted")

resp, err := r.next.Resolve(request)
resp, err := r.next.Resolve(ctx, request)

return true, resp, err
}
Expand All @@ -391,18 +391,18 @@ func (r *BlockingResolver) handleBlacklist(groupsToCheck []string,
}

// Resolve checks the query against the blacklist and delegates to next resolver if domain is not blocked
func (r *BlockingResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *BlockingResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "blacklist_resolver")
groupsToCheck := r.groupsToCheckForClient(request)

if len(groupsToCheck) > 0 {
handled, resp, err := r.handleBlacklist(groupsToCheck, request, logger)
handled, resp, err := r.handleBlacklist(ctx, groupsToCheck, request, logger)
if handled {
return resp, err
}
}

respFromNext, err := r.next.Resolve(request)
respFromNext, err := r.next.Resolve(ctx, request)

if err == nil && len(groupsToCheck) > 0 && respFromNext.Res != nil {
for _, rr := range respFromNext.Res.Answer {
Expand Down Expand Up @@ -572,15 +572,15 @@ func (b ipBlockHandler) handleBlock(question dns.Question, response *dns.Msg) {
}
}

func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP, time.Duration) {
func (r *BlockingResolver) queryForFQIdentifierIPs(ctx context.Context, identifier string) (*[]net.IP, time.Duration) {
prefixedLog := log.WithPrefix(r.log(), "client_id_cache")

var result []net.IP

var ttl time.Duration

for _, qType := range []uint16{dns.TypeA, dns.TypeAAAA} {
resp, err := r.next.Resolve(&model.Request{
resp, err := r.next.Resolve(ctx, &model.Request{
Req: util.NewMsgWithQuestion(identifier, dns.Type(qType)),
Log: prefixedLog,
})
Expand All @@ -604,12 +604,12 @@ func (r *BlockingResolver) queryForFQIdentifierIPs(identifier string) (*[]net.IP
return &result, ttl
}

func (r *BlockingResolver) initFQDNIPCache() {
func (r *BlockingResolver) initFQDNIPCache(ctx context.Context) {
identifiers := maps.Keys(r.clientGroupsBlock)

for _, identifier := range identifiers {
if isFQDN(identifier) {
iPs, ttl := r.queryForFQIdentifierIPs(identifier)
iPs, ttl := r.queryForFQIdentifierIPs(ctx, identifier)
r.fqdnIPCache.Put(identifier, iPs, ttl)
}
}
Expand Down
18 changes: 9 additions & 9 deletions resolver/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,22 +106,22 @@ func NewBootstrap(ctx context.Context, cfg *config.Config) (b *Bootstrap, err er
return b, nil
}

func (b *Bootstrap) UpstreamIPs(r *UpstreamResolver) (*IPSet, error) {
func (b *Bootstrap) UpstreamIPs(ctx context.Context, r *UpstreamResolver) (*IPSet, error) {
hostname := r.upstream.Host

if ip := net.ParseIP(hostname); ip != nil { // nil-safe when hostname is an IP: makes writing test easier
return newIPSet([]net.IP{ip}), nil
}

ips, err := b.resolveUpstream(r, hostname)
ips, err := b.resolveUpstream(ctx, r, hostname)
if err != nil {
return nil, err
}

return newIPSet(ips), nil
}

func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string) ([]net.IP, error) {
// Use system resolver if no bootstrap is configured
if b.resolver == nil {
ctx, cancel := context.WithTimeout(ctx, b.timeout)
Expand All @@ -135,7 +135,7 @@ func (b *Bootstrap) resolveUpstream(r Resolver, host string) ([]net.IP, error) {
return ips, nil
}

return b.resolve(host, b.connectIPVersion.QTypes())
return b.resolve(ctx, host, b.connectIPVersion.QTypes())
}

// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
Expand Down Expand Up @@ -175,7 +175,7 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
}

// Resolve the host with the bootstrap DNS
ips, err := b.resolve(host, qTypes)
ips, err := b.resolve(ctx, host, qTypes)
if err != nil {
logger.Errorf("resolve error: %s", err)

Expand All @@ -192,11 +192,11 @@ func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.
return b.dialer.DialContext(ctx, network, addrWithIP)
}

func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
func (b *Bootstrap) resolve(ctx context.Context, hostname string, qTypes []dns.Type) (ips []net.IP, err error) {
ips = make([]net.IP, 0, len(qTypes))

for _, qType := range qTypes {
qIPs, qErr := b.resolveType(hostname, qType)
qIPs, qErr := b.resolveType(ctx, hostname, qType)
if qErr != nil {
err = multierror.Append(err, qErr)

Expand All @@ -213,7 +213,7 @@ func (b *Bootstrap) resolve(hostname string, qTypes []dns.Type) (ips []net.IP, e
return
}

func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP, err error) {
func (b *Bootstrap) resolveType(ctx context.Context, hostname string, qType dns.Type) (ips []net.IP, err error) {
if ip := net.ParseIP(hostname); ip != nil {
return []net.IP{ip}, nil
}
Expand All @@ -223,7 +223,7 @@ func (b *Bootstrap) resolveType(hostname string, qType dns.Type) (ips []net.IP,
Log: b.log,
}

rsp, err := b.resolver.Resolve(&req)
rsp, err := b.resolver.Resolve(ctx, &req)
if err != nil {
return nil, err
}
Expand Down
12 changes: 6 additions & 6 deletions resolver/caching_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin
}
}

func (r *CachingResolver) reloadCacheEntry(cacheKey string) (*[]byte, time.Duration) {
func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) {
qType, domainName := util.ExtractCacheKey(cacheKey)
logger := r.log()

logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType)

req := newRequest(dns.Fqdn(domainName), qType, logger)
response, err := r.next.Resolve(req)
response, err := r.next.Resolve(ctx, req)

if err == nil {
if response.Res.Rcode == dns.RcodeSuccess {
Expand Down Expand Up @@ -157,13 +157,13 @@ func (r *CachingResolver) LogConfig(logger *logrus.Entry) {

// Resolve checks if the current query result is already in the cache and returns it
// or delegates to the next resolver
func (r *CachingResolver) Resolve(request *model.Request) (response *model.Response, err error) {
func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) (response *model.Response, err error) {
logger := log.WithPrefix(request.Log, "caching_resolver")

if r.cfg.MaxCachingTime < 0 {
logger.Debug("skip cache")

return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}

for _, question := range request.Req.Question {
Expand All @@ -189,7 +189,7 @@ func (r *CachingResolver) Resolve(request *model.Request) (response *model.Respo
}

logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver")
response, err = r.next.Resolve(request)
response, err = r.next.Resolve(ctx, request)

if err == nil {
cacheTTL := r.adjustTTLs(response.Res.Answer)
Expand Down Expand Up @@ -318,7 +318,7 @@ func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{})
}
}

func (r *CachingResolver) FlushCaches() {
func (r *CachingResolver) FlushCaches(context.Context) {
r.log().Debug("flush caches")
r.resultCache.Clear()
}
16 changes: 8 additions & 8 deletions resolver/client_names_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func NewClientNamesResolver(ctx context.Context,
) (cr *ClientNamesResolver, err error) {
var r Resolver
if !cfg.Upstream.IsDefault() {
r, err = NewUpstreamResolver(cfg.Upstream, bootstrap, shouldVerifyUpstreams)
r, err = NewUpstreamResolver(ctx, cfg.Upstream, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
Expand All @@ -59,17 +59,17 @@ func (r *ClientNamesResolver) LogConfig(logger *logrus.Entry) {
}

// Resolve tries to resolve the client name from the ip address
func (r *ClientNamesResolver) Resolve(request *model.Request) (*model.Response, error) {
clientNames := r.getClientNames(request)
func (r *ClientNamesResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
clientNames := r.getClientNames(ctx, request)

request.ClientNames = clientNames
request.Log = request.Log.WithField("client_names", strings.Join(clientNames, "; "))

return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}

// returns names of client
func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
func (r *ClientNamesResolver) getClientNames(ctx context.Context, request *model.Request) []string {
if request.RequestClientID != "" {
return []string{request.RequestClientID}
}
Expand All @@ -88,7 +88,7 @@ func (r *ClientNamesResolver) getClientNames(request *model.Request) []string {
return cpy
}

names := r.resolveClientNames(ip, log.WithPrefix(request.Log, "client_names_resolver"))
names := r.resolveClientNames(ctx, ip, log.WithPrefix(request.Log, "client_names_resolver"))

r.cache.Put(ip.String(), &names, time.Hour)

Expand All @@ -111,7 +111,7 @@ func extractClientNamesFromAnswer(answer []dns.RR, fallbackIP net.IP) (clientNam
}

// tries to resolve client name from mapping, performs reverse DNS lookup otherwise
func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry) (result []string) {
func (r *ClientNamesResolver) resolveClientNames(ctx context.Context, ip net.IP, logger *logrus.Entry) (result []string) {
// try client mapping first
result = r.getNameFromIPMapping(ip, result)
if len(result) > 0 {
Expand All @@ -124,7 +124,7 @@ func (r *ClientNamesResolver) resolveClientNames(ip net.IP, logger *logrus.Entry

reverse, _ := dns.ReverseAddr(ip.String())

resp, err := r.externalResolver.Resolve(&model.Request{
resp, err := r.externalResolver.Resolve(ctx, &model.Request{
Req: util.NewMsgWithQuestion(reverse, dns.Type(dns.TypePTR)),
Log: logger,
})
Expand Down
23 changes: 13 additions & 10 deletions resolver/conditional_upstream_resolver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package resolver

import (
"context"
"strings"

"github.com/0xERR0R/blocky/config"
Expand All @@ -23,14 +24,14 @@ type ConditionalUpstreamResolver struct {

// NewConditionalUpstreamResolver returns new resolver instance
func NewConditionalUpstreamResolver(
cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
ctx context.Context, cfg config.ConditionalUpstreamConfig, bootstrap *Bootstrap, shouldVerifyUpstreams bool,
) (*ConditionalUpstreamResolver, error) {
m := make(map[string]Resolver, len(cfg.Mapping.Upstreams))

for domain, upstreams := range cfg.Mapping.Upstreams {
cfg := config.UpstreamGroup{Name: upstreamDefaultCfgName, Upstreams: upstreams}

r, err := NewParallelBestResolver(cfg, bootstrap, shouldVerifyUpstreams)
r, err := NewParallelBestResolver(ctx, cfg, bootstrap, shouldVerifyUpstreams)
if err != nil {
return nil, err
}
Expand All @@ -48,15 +49,17 @@ func NewConditionalUpstreamResolver(
return &r, nil
}

func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bool, *model.Response, error) {
func (r *ConditionalUpstreamResolver) processRequest(
ctx context.Context, request *model.Request,
) (bool, *model.Response, error) {
domainFromQuestion := util.ExtractDomain(request.Req.Question[0])
domain := domainFromQuestion

if strings.Contains(domainFromQuestion, ".") {
// try with domain with and without sub-domains
for len(domain) > 0 {
if resolver, found := r.mapping[domain]; found {
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)

return true, resp, err
}
Expand All @@ -68,7 +71,7 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
}
}
} else if resolver, found := r.mapping["."]; found {
resp, err := r.internalResolve(resolver, domainFromQuestion, domain, request)
resp, err := r.internalResolve(ctx, resolver, domainFromQuestion, domain, request)

return true, resp, err
}
Expand All @@ -77,29 +80,29 @@ func (r *ConditionalUpstreamResolver) processRequest(request *model.Request) (bo
}

// Resolve uses the conditional resolver to resolve the query
func (r *ConditionalUpstreamResolver) Resolve(request *model.Request) (*model.Response, error) {
func (r *ConditionalUpstreamResolver) Resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
logger := log.WithPrefix(request.Log, "conditional_resolver")

if len(r.mapping) > 0 {
resolved, resp, err := r.processRequest(request)
resolved, resp, err := r.processRequest(ctx, request)
if resolved {
return resp, err
}
}

logger.WithField("next_resolver", Name(r.next)).Trace("go to next resolver")

return r.next.Resolve(request)
return r.next.Resolve(ctx, request)
}

func (r *ConditionalUpstreamResolver) internalResolve(reso Resolver, doFQ, do string,
func (r *ConditionalUpstreamResolver) internalResolve(ctx context.Context, reso Resolver, doFQ, do string,
req *model.Request,
) (*model.Response, error) {
// internal request resolution
logger := log.WithPrefix(req.Log, "conditional_resolver")

req.Req.Question[0].Name = dns.Fqdn(doFQ)
response, err := reso.Resolve(req)
response, err := reso.Resolve(ctx, req)

if err == nil {
response.Reason = "CONDITIONAL"
Expand Down
Loading

0 comments on commit 53deed6

Please sign in to comment.