From c9db98618183f7ebe03c9990a0ed09b371452d83 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Sun, 19 Nov 2023 10:56:06 -0500 Subject: [PATCH] squash: fix(tests): update to pass contexts where needed This is split to try and make review easier by having the previous commit with all interesting changes, and then boring commits with just the API churn. --- api/api_interface_impl_test.go | 39 ++++--- .../expirationcache/expiration_cache_test.go | 6 +- .../expirationcache/prefetching_cache_test.go | 8 +- resolver/blocking_resolver_test.go | 109 +++++++++--------- resolver/bootstrap_test.go | 26 ++--- resolver/caching_resolver_test.go | 88 ++++++++------ resolver/client_names_resolver_test.go | 36 +++--- .../conditional_upstream_resolver_test.go | 24 ++-- resolver/custom_dns_resolver_test.go | 33 +++--- resolver/ede_resolver_test.go | 13 ++- resolver/filtering_resolver_test.go | 14 ++- resolver/fqdn_only_resolver_test.go | 16 ++- resolver/hosts_file_resolver_test.go | 46 ++++---- resolver/metrics_resolver_test.go | 11 +- resolver/mocks_test.go | 6 +- resolver/noop_resolver_test.go | 14 ++- resolver/parallel_best_resolver_test.go | 30 ++--- resolver/query_logging_resolver_test.go | 23 ++-- resolver/rewriter_resolver_test.go | 16 +-- resolver/strict_resolver_test.go | 23 ++-- resolver/sudn_resolver_test.go | 19 ++- resolver/upstream_resolver_test.go | 27 +++-- resolver/upstream_tree_resolver_test.go | 30 +++-- server/server_test.go | 4 +- 24 files changed, 392 insertions(+), 269 deletions(-) diff --git a/api/api_interface_impl_test.go b/api/api_interface_impl_test.go index 0072e6d03..3d3c8a608 100644 --- a/api/api_interface_impl_test.go +++ b/api/api_interface_impl_test.go @@ -5,7 +5,6 @@ import ( "errors" "time" - // . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" @@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus { return args.Get(0).(BlockingStatus) } -func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) { - args := m.Called(question, qType) +func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) { + args := m.Called(ctx, question, qType) return args.Get(0).(*model.Response), args.Error(1) } -func (m *CacheControlMock) FlushCaches() { - _ = m.Called() +func (m *CacheControlMock) FlushCaches(ctx context.Context) { + _ = m.Called(ctx) } var _ = Describe("API implementation tests", func() { @@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() { listRefreshMock *ListRefreshMock cacheControlMock *CacheControlMock sut *OpenAPIInterfaceImpl + + ctx context.Context + cancelFn context.CancelFunc ) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + blockingControlMock = &BlockingControlMock{} querierMock = &QuerierMock{} listRefreshMock = &ListRefreshMock{} @@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() { ) Expect(err).Should(Succeed()) - querierMock.On("Query", "google.com.", A).Return(&model.Response{ + querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{ Res: queryResponse, Reason: "reason", }, nil) - resp, err := sut.Query(context.Background(), QueryRequestObject{ + resp, err := sut.Query(ctx, QueryRequestObject{ Body: &ApiQueryRequest{ Query: "google.com", Type: "A", }, @@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() { }) It("should return 400 on wrong parameter", func() { - resp, err := sut.Query(context.Background(), QueryRequestObject{ + resp, err := sut.Query(ctx, QueryRequestObject{ Body: &ApiQueryRequest{ Query: "google.com", Type: "WRONGTYPE", @@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() { It("should return 200 on success", func() { listRefreshMock.On("RefreshLists").Return(nil) - resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{}) + resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{}) Expect(err).Should(Succeed()) var resp200 ListRefresh200Response Expect(resp).Should(BeAssignableToTypeOf(resp200)) @@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() { It("should return 500 on failure", func() { listRefreshMock.On("RefreshLists").Return(errors.New("failed")) - resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{}) + resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{}) Expect(err).Should(Succeed()) var resp500 ListRefresh500TextResponse Expect(resp).Should(BeAssignableToTypeOf(resp500)) @@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() { duration := "3s" grroups := "gr1,gr2" - resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{ + resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{ Params: DisableBlockingParams{ Duration: &duration, Groups: &grroups, @@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() { It("should return 400 on failure", func() { blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed")) - resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{}) + resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{}) Expect(err).Should(Succeed()) var resp400 DisableBlocking400TextResponse Expect(resp).Should(BeAssignableToTypeOf(resp400)) @@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() { It("should return 400 on wrong duration parameter", func() { wrongDuration := "4sds" - resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{ + resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{ Params: DisableBlockingParams{ Duration: &wrongDuration, }, @@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() { It("should return 200 on success", func() { blockingControlMock.On("EnableBlocking").Return() - resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{}) + resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{}) Expect(err).Should(Succeed()) var resp200 EnableBlocking200Response Expect(resp).Should(BeAssignableToTypeOf(resp200)) @@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() { AutoEnableInSec: 47, }) - resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{}) + resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{}) Expect(err).Should(Succeed()) var resp200 BlockingStatus200JSONResponse Expect(resp).Should(BeAssignableToTypeOf(resp200)) @@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() { Describe("Cache API", func() { When("Cache flush is called", func() { It("should return 200 on success", func() { - cacheControlMock.On("FlushCaches").Return() - resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{}) + cacheControlMock.On("FlushCaches", ctx).Return() + resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{}) Expect(err).Should(Succeed()) var resp200 CacheFlush200Response Expect(resp).Should(BeAssignableToTypeOf(resp200)) diff --git a/cache/expirationcache/expiration_cache_test.go b/cache/expirationcache/expiration_cache_test.go index 9364350b2..f8313dffd 100644 --- a/cache/expirationcache/expiration_cache_test.go +++ b/cache/expirationcache/expiration_cache_test.go @@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() { Describe("preExpiration function", func() { When("function is defined", func() { It("should update the value and TTL if function returns values", func() { - fn := func(key string) (val *string, ttl time.Duration) { + fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) { v2 := "v2" return &v2, time.Second @@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() { }) It("should update the value and TTL if function returns values on cleanup if element is expired", func() { - fn := func(key string) (val *string, ttl time.Duration) { + fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) { v2 := "val2" return &v2, time.Second @@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() { }) It("should delete the key if function returns nil", func() { - fn := func(key string) (val *string, ttl time.Duration) { + fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) { return nil, 0 } cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn) diff --git a/cache/expirationcache/prefetching_cache_test.go b/cache/expirationcache/prefetching_cache_test.go index 91b796a1b..fb94c62ff 100644 --- a/cache/expirationcache/prefetching_cache_test.go +++ b/cache/expirationcache/prefetching_cache_test.go @@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() { }, PrefetchThreshold: 2, PrefetchExpires: 100 * time.Millisecond, - ReloadFn: func(cacheKey string) (*string, time.Duration) { + ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) { v := "v2" return &v, 50 * time.Millisecond @@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() { }, PrefetchThreshold: 2, PrefetchExpires: 100 * time.Millisecond, - ReloadFn: func(cacheKey string) (*string, time.Duration) { + ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) { v := "v2" return &v, 50 * time.Millisecond @@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() { Options: Options{ CleanupInterval: 100 * time.Millisecond, }, - ReloadFn: func(cacheKey string) (*string, time.Duration) { + ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) { v := "v2" return &v, 50 * time.Millisecond @@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() { }, PrefetchThreshold: 2, PrefetchExpires: 100 * time.Millisecond, - ReloadFn: func(cacheKey string) (*string, time.Duration) { + ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) { v := "v2" return &v, 50 * time.Millisecond diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index b7004fbed..e55b7419e 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -155,7 +155,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { } Bus().Publish(ApplicationStarted, "") Eventually(func(g Gomega) { - g.Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))). + g.Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))). Should(And( BeDNSRecord("blocked2.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 60)), @@ -185,6 +185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Domain is on the black list", func() { It("should block request", func() { Eventually(sut.Resolve). + WithContext(ctx). WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")). Should( SatisfyAll( @@ -222,7 +223,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("client name is defined in client groups block", func() { It("should block the A query if domain is on the black list (single)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -233,7 +234,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the A query if domain is on the black list (multipart 1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -244,7 +245,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the A query if domain is on the black list (multipart 2)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -255,7 +256,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the A query if domain is on the black list (merged)", func() { - Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))). Should( SatisfyAll( BeDNSRecord("blocked2.com.", A, "0.0.0.0"), @@ -266,7 +267,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the AAAA query if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", AAAA, "::"), @@ -277,18 +278,18 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the HTTPS query if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))). Should(HaveReturnCode(dns.RcodeNameError)) }) It("should block the MX query if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))). Should(HaveReturnCode(dns.RcodeNameError)) }) }) When("Client ip is defined in client groups block", func() { It("should block the query if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -301,7 +302,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() { It("should not block the query for 10.43.8.63 if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -313,7 +314,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { m.AssertExpectations(GinkgoT()) }) It("should not block the query for 10.43.8.80 if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -328,7 +329,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() { It("should block the query for 10.43.8.64 if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -339,7 +340,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the query for 10.43.8.79 if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -353,7 +354,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Client has multiple names and for each name a client group block definition exists", func() { It("should block query if domain is in one group", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -364,7 +365,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block query if domain is in another group too", func() { - Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))). Should( SatisfyAll( BeDNSRecord("blocked2.com.", A, "0.0.0.0"), @@ -377,7 +378,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("Client name matches wildcard", func() { It("should block query if domain is in one group", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -391,7 +392,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Default group is defined", func() { It("should block domains from default group for each client", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -418,7 +419,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return NXDOMAIN if query is blocked", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -444,7 +445,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return answer with specified TTL", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -461,7 +462,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return custom IP with specified TTL", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "12.12.12.12"), @@ -489,7 +490,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return ipv4 address for A query if query is blocked", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "12.12.12.12"), @@ -501,7 +502,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return ipv6 address for AAAA query if query is blocked", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"), @@ -528,7 +529,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should use fallback for ipv6 and return zero ip", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", AAAA, "::"), @@ -547,7 +548,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145") }) It("should block query, if lookup result contains blacklisted IP", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "0.0.0.0"), @@ -567,7 +568,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { ) }) It("should block query, if lookup result contains blacklisted IP", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("example.com.", AAAA, "::"), @@ -590,7 +591,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { mockAnswer.Answer = []dns.RR{rr1, rr2, rr3} }) It("should block the query, if response contains a CNAME with domain on a blacklist", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "0.0.0.0"), @@ -617,7 +618,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { } }) It("Should not be blocked", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -649,7 +650,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should block everything else except domains on the white list with default group", func() { By("querying domain on the whitelist", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -662,7 +663,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("querying another domain, which is not on the whitelist", func() { - Expect(sut.Resolve(newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("google.com.", A, "0.0.0.0"), @@ -678,7 +679,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { It("should block everything else except domains on the white list "+ "if multiple white list only groups are defined", func() { By("querying domain on the whitelist", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))). Should( SatisfyAll( HaveNoAnswer(), @@ -691,7 +692,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("querying another domain, which is not on the whitelist", func() { - Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))). Should( SatisfyAll( BeDNSRecord("blocked2.com.", A, "0.0.0.0"), @@ -706,7 +707,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { It("should block everything else except domains on the white list "+ "if multiple white list only groups are defined", func() { By("querying domain on the whitelist group 1", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))). Should( SatisfyAll( HaveNoAnswer(), @@ -719,7 +720,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("querying another domain, which is in the whitelist group 1", func() { - Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))). Should( SatisfyAll( HaveNoAnswer(), @@ -745,7 +746,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145") }) It("should not block if DNS answer contains IP from the white list", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.145.123.145"), @@ -775,7 +776,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("domain is not on the black list", func() { It("should delegate to next resolver", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -792,7 +793,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { } }) It("should delegate to next resolver", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -819,7 +820,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Disable blocking is called", func() { It("no query should be blocked", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -830,7 +831,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("Perform query to ensure that the blocking status is active (group1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -847,7 +848,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again (defaultGroup)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -861,7 +862,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again (group1)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -880,7 +881,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again (defaultGroup)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -893,7 +894,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("Perform query to ensure that the blocking status is active (group1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -908,7 +909,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Disable blocking for all groups is called with a duration parameter", func() { It("No query should be blocked only for passed amount of time", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -918,7 +919,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) By("Perform query to ensure that the blocking status is active (group1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -941,7 +942,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -954,7 +955,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("perform the same query again to ensure that this query will not be blocked (group1)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -974,7 +975,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { // wait 1 sec Eventually(enabled, "1s").Should(Receive(BeTrue())) - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -983,7 +984,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveReturnCode(dns.RcodeSuccess), )) - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -998,7 +999,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Disable blocking for one group is called with a duration parameter", func() { It("No query should be blocked only for passed amount of time", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -1008,7 +1009,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) By("Perform query to ensure that the blocking status is active (group1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -1031,7 +1032,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -1042,7 +1043,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("perform the same query again to ensure that this query will not be blocked (group1)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -1062,7 +1063,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { // wait 1 sec Eventually(enabled, "1s").Should(Receive(BeTrue())) - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -1071,7 +1072,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveReturnCode(dns.RcodeSuccess), )) - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), diff --git a/resolver/bootstrap_test.go b/resolver/bootstrap_test.go index 61b3722dc..5e4ffb397 100644 --- a/resolver/bootstrap_test.go +++ b/resolver/bootstrap_test.go @@ -72,7 +72,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }, } - _, err := sut.resolveUpstream(nil, "example.com") + _, err := sut.resolveUpstream(ctx, nil, "example.com") Expect(err).ShouldNot(Succeed()) Expect(usedSystemResolver).Should(Receive(BeTrue())) }) @@ -244,7 +244,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { When("called from bootstrap.upstream", func() { It("uses hardcoded IPs", func() { - ips, err := sut.resolveUpstream(bootstrapUpstream, "host") + ips, err := sut.resolveUpstream(ctx, bootstrapUpstream, "host") Expect(err).Should(Succeed()) Expect(ips).Should(Equal(sutConfig.BootstrapDNS[0].IPs)) @@ -253,7 +253,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { When("hostname is an IP", func() { It("returns immediately", func() { - ips, err := sut.resolve("0.0.0.0", config.IPVersionDual.QTypes()) + ips, err := sut.resolve(ctx, "0.0.0.0", config.IPVersionDual.QTypes()) Expect(err).Should(Succeed()) Expect(ips).Should(ContainElement(net.IPv4zero)) @@ -269,7 +269,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil) - ips, err := sut.resolve("localhost", []dns.Type{AAAA}) + ips, err := sut.resolve(ctx, "localhost", []dns.Type{AAAA}) Expect(err).Should(Succeed()) Expect(ips).Should(HaveLen(1)) @@ -283,7 +283,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr) - ips, err := sut.resolve("localhost", []dns.Type{A}) + ips, err := sut.resolve(ctx, "localhost", []dns.Type{A}) Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring(resolveErr.Error())) @@ -297,7 +297,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil) - ips, err := sut.resolve("unknownhost.invalid", []dns.Type{A}) + ips, err := sut.resolve(ctx, "unknownhost.invalid", []dns.Type{A}) Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring("no such host")) @@ -329,7 +329,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { r := newUpstreamResolverUnchecked(upstream, sut) - rsp, err := r.Resolve(mainReq) + rsp, err := r.Resolve(ctx, mainReq) Expect(err).Should(Succeed()) Expect(mockUpstreamServer.GetCallCount()).Should(Equal(1)) Expect(rsp.Res.Question[0].Name).Should(Equal("example.com.")) @@ -373,7 +373,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { // implicit expectation of 0 bootstrapUpstream.Resolve calls - _, err = t.DialContext(context.Background(), "ip", "!bad-addr!") + _, err = t.DialContext(ctx, "ip", "!bad-addr!") Expect(err).ShouldNot(Succeed()) }) @@ -384,7 +384,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { t := sut.NewHTTPTransport() - _, err = t.DialContext(context.Background(), "ip", "abc:123") + _, err = t.DialContext(ctx, "ip", "abc:123") Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring(resolveErr.Error())) @@ -397,7 +397,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { t := sut.NewHTTPTransport() - _, err = t.DialContext(context.Background(), "ip", "abc:123") + _, err = t.DialContext(ctx, "ip", "abc:123") Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring("no such host")) @@ -437,7 +437,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { Describe("resolve", func() { AfterEach(func() { - _, err := sut.resolveUpstream(nil, "example.com") + _, err := sut.resolveUpstream(ctx, nil, "example.com") Expect(err).Should(Succeed()) m.AssertExpectations(GinkgoT()) @@ -501,7 +501,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { AfterEach(func() { t := sut.NewHTTPTransport() - conn, err := t.DialContext(context.Background(), dialIPVersion.Net(), "localhost:0") + conn, err := t.DialContext(ctx, dialIPVersion.Net(), "localhost:0") Expect(err).Should(Succeed()) Expect(conn).Should(Equal(aMockConn)) @@ -583,7 +583,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }) It("uses both", func() { - _, err := sut.resolve("example.com.", []dns.Type{dns.Type(dns.TypeA)}) + _, err := sut.resolve(ctx, "example.com.", []dns.Type{dns.Type(dns.TypeA)}) Expect(err).To(Succeed()) diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index acc7be64c..df95b613b 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -102,7 +102,7 @@ var _ = Describe("CachingResolver", func() { })).Should(Succeed()) // first request - _, _ = sut.Resolve(newRequest("example.com.", A)) + _, _ = sut.Resolve(ctx, newRequest("example.com.", A)) // Domain is not prefetched Expect(domainPrefetched).ShouldNot(Receive()) @@ -112,7 +112,7 @@ var _ = Describe("CachingResolver", func() { // now query again > threshold for i := 0; i < prefetchThreshold+1; i++ { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(Succeed()) } @@ -120,7 +120,7 @@ var _ = Describe("CachingResolver", func() { Eventually(domainPrefetched, "10s").Should(Receive(Equal(true))) // and it should hit from prefetch cache - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -144,7 +144,7 @@ var _ = Describe("CachingResolver", func() { }) It("should cache response and use response's TTL for multiple records", func() { By("first request", func() { - result, err := sut.Resolve(newRequest("example.com.", A)) + result, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(Succeed()) Expect(result). Should( @@ -164,7 +164,7 @@ var _ = Describe("CachingResolver", func() { By("second request", func() { Eventually(func(g Gomega) { - result, err := sut.Resolve(newRequest("example.com.", A)) + result, err := sut.Resolve(ctx, newRequest("example.com.", A)) g.Expect(err).Should(Succeed()) g.Expect(result). Should( @@ -206,7 +206,7 @@ var _ = Describe("CachingResolver", func() { _ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) { totalCacheCount <- d }) - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -227,7 +227,7 @@ var _ = Describe("CachingResolver", func() { domain <- true }) - g.Expect(sut.Resolve(newRequest("example.com.", A))). + g.Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -252,7 +252,7 @@ var _ = Describe("CachingResolver", func() { It("should cache response and use min caching time as TTL", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -264,7 +264,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", A)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", A)). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -287,7 +289,7 @@ var _ = Describe("CachingResolver", func() { It("should cache response and use min caching time as TTL", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -298,7 +300,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -332,7 +336,7 @@ var _ = Describe("CachingResolver", func() { It("Shouldn't cache any responses", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -343,7 +347,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -366,7 +372,7 @@ var _ = Describe("CachingResolver", func() { }) It("should cache response and use max caching time as TTL if response TTL is bigger", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -377,7 +383,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -405,7 +413,7 @@ var _ = Describe("CachingResolver", func() { }) It("should cache response and return 0 TTL if entry is expired", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -418,7 +426,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve, "2s").WithArguments(newRequest("example.com.", A)). + Eventually(sut.Resolve, "2s"). + WithContext(ctx). + WithArguments(newRequest("example.com.", A)). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -449,7 +459,7 @@ var _ = Describe("CachingResolver", func() { }) By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeNameError), @@ -460,7 +470,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should(SatisfyAll( HaveResponseType(ResponseTypeCACHED), HaveReason("CACHED NEGATIVE"), @@ -483,7 +495,7 @@ var _ = Describe("CachingResolver", func() { It("response shouldn't be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeNameError), @@ -494,7 +506,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReason(""), @@ -517,7 +531,7 @@ var _ = Describe("CachingResolver", func() { It("response should be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -528,7 +542,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should(SatisfyAll( HaveResponseType(ResponseTypeCACHED), HaveReason("CACHED"), @@ -551,7 +567,7 @@ var _ = Describe("CachingResolver", func() { }) It("Should be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("google.de.", MX))). + Expect(sut.Resolve(ctx, newRequest("google.de.", MX))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -563,7 +579,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("google.de.", MX)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("google.de.", MX)). Should(SatisfyAll( HaveResponseType(ResponseTypeCACHED), HaveReason("CACHED"), @@ -587,7 +605,7 @@ var _ = Describe("CachingResolver", func() { }) It("Should not be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -599,7 +617,7 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -621,7 +639,7 @@ var _ = Describe("CachingResolver", func() { }) It("Should not be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -633,7 +651,7 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -659,7 +677,7 @@ var _ = Describe("CachingResolver", func() { }) It("Should not be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -676,7 +694,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("google.de.", A)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("google.de.", A)). Should(SatisfyAll( HaveResponseType(ResponseTypeCACHED), HaveReason("CACHED"), @@ -738,7 +758,7 @@ var _ = Describe("CachingResolver", func() { }) It("put in redis", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should(HaveResponseType(ResponseTypeRESOLVED)) Eventually(func() []string { @@ -760,7 +780,9 @@ var _ = Describe("CachingResolver", func() { } redisClient.CacheChannel <- redisMockMsg - Eventually(sut.Resolve).WithArguments(request). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(request). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), diff --git a/resolver/client_names_resolver_test.go b/resolver/client_names_resolver_test.go index dfe03fb07..4ea611b06 100644 --- a/resolver/client_names_resolver_test.go +++ b/resolver/client_names_resolver_test.go @@ -22,8 +22,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { sut *ClientNamesResolver sutConfig config.ClientLookupConfig m *mockResolver - ctx context.Context - cancelFn context.CancelFunc + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -34,6 +35,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { JustBeforeEach(func() { var err error + ctx, cancelFn = context.WithCancel(context.Background()) DeferCleanup(cancelFn) @@ -71,7 +73,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should use clientID if set", func() { request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -82,7 +84,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should use IP as fallback if clientID not set", func() { request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -112,7 +114,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve defined name with ipv4 address", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -124,7 +126,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve defined name with ipv6 address", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -135,7 +137,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should resolve multiple names defined names", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -168,7 +170,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve client name", func() { By("first request", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -180,7 +182,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { By("second request", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -198,7 +200,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { By("third request", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -223,7 +225,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve all client names", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -251,7 +253,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -272,7 +274,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve the client name depending to defined order", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -298,7 +300,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should use fallback for client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -318,7 +320,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should use fallback for client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -335,7 +337,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should resolve no names", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -351,7 +353,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should use fallback for client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), diff --git a/resolver/conditional_upstream_resolver_test.go b/resolver/conditional_upstream_resolver_test.go index 72f0cb2ac..2c6ef4bc8 100644 --- a/resolver/conditional_upstream_resolver_test.go +++ b/resolver/conditional_upstream_resolver_test.go @@ -19,6 +19,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu var ( sut *ConditionalUpstreamResolver m *mockResolver + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -28,6 +31,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122") @@ -59,7 +65,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) DeferCleanup(refuseTestUpstream.Close) - sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{ + sut, _ = NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{ Mapping: config.ConditionalUpstreamMapping{ Upstreams: map[string][]config.Upstream{ "fritz.box": {fbTestUpstream.Start()}, @@ -93,7 +99,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu Describe("Resolve conditional DNS queries via defined DNS server", func() { When("conditional resolver returns error code", func() { It("Should be returned without changes", func() { - Expect(sut.Resolve(newRequest("refused.domain.", A))). + Expect(sut.Resolve(ctx, newRequest("refused.domain.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -109,7 +115,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu When("Query is exact equal defined condition in mapping", func() { Context("first mapping entry", func() { It("Should resolve the IP of conditional DNS", func() { - Expect(sut.Resolve(newRequest("fritz.box.", A))). + Expect(sut.Resolve(ctx, newRequest("fritz.box.", A))). Should( SatisfyAll( BeDNSRecord("fritz.box.", A, "123.124.122.122"), @@ -125,7 +131,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) Context("last mapping entry", func() { It("Should resolve the IP of conditional DNS", func() { - Expect(sut.Resolve(newRequest("other.box.", A))). + Expect(sut.Resolve(ctx, newRequest("other.box.", A))). Should( SatisfyAll( BeDNSRecord("other.box.", A, "192.192.192.192"), @@ -141,7 +147,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) When("Query is a subdomain of defined condition in mapping", func() { It("Should resolve the IP of subdomain", func() { - Expect(sut.Resolve(newRequest("test.fritz.box.", A))). + Expect(sut.Resolve(ctx, newRequest("test.fritz.box.", A))). Should( SatisfyAll( BeDNSRecord("test.fritz.box.", A, "123.124.122.122"), @@ -156,7 +162,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) When("Query is not fqdn and . condition is defined in mapping", func() { It("Should resolve the IP of .", func() { - Expect(sut.Resolve(newRequest("test.", A))). + Expect(sut.Resolve(ctx, newRequest("test.", A))). Should( SatisfyAll( BeDNSRecord("test.", A, "168.168.168.168"), @@ -173,7 +179,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu Describe("Delegation to next resolver", func() { When("Query doesn't match defined mapping", func() { It("should delegate to next resolver", func() { - Expect(sut.Resolve(newRequest("google.com.", A))). + Expect(sut.Resolve(ctx, newRequest("google.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -186,11 +192,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu When("upstream is invalid", func() { It("errors during construction", func() { - ctx, cancelFn := context.WithCancel(context.Background()) - DeferCleanup(cancelFn) b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{ + r, err := NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{ Mapping: config.ConditionalUpstreamMapping{ Upstreams: map[string][]config.Upstream{ ".": {config.Upstream{Host: "example.com"}}, diff --git a/resolver/custom_dns_resolver_test.go b/resolver/custom_dns_resolver_test.go index 9ad65f170..d06f3672f 100644 --- a/resolver/custom_dns_resolver_test.go +++ b/resolver/custom_dns_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "net" "time" @@ -21,6 +22,9 @@ var _ = Describe("CustomDNSResolver", func() { sut *CustomDNSResolver m *mockResolver cfg config.CustomDNSConfig + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -30,6 +34,9 @@ var _ = Describe("CustomDNSResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + cfg = config.CustomDNSConfig{ Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{ "custom.domain": {net.ParseIP("192.168.143.123")}, @@ -73,7 +80,7 @@ var _ = Describe("CustomDNSResolver", func() { Context("filterUnmappedTypes is true", func() { BeforeEach(func() { cfg.FilterUnmappedTypes = true }) It("defined ip4 query should be resolved", func() { - Expect(sut.Resolve(newRequest("custom.domain.", A))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))). Should( SatisfyAll( BeDNSRecord("custom.domain.", A, "192.168.143.123"), @@ -86,7 +93,7 @@ var _ = Describe("CustomDNSResolver", func() { m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) }) It("TXT query for defined mapping should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("custom.domain.", TXT))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))). Should( SatisfyAll( HaveNoAnswer(), @@ -98,7 +105,7 @@ var _ = Describe("CustomDNSResolver", func() { m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) }) It("ip6 query should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("custom.domain.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -114,7 +121,7 @@ var _ = Describe("CustomDNSResolver", func() { Context("filterUnmappedTypes is false", func() { BeforeEach(func() { cfg.FilterUnmappedTypes = false }) It("defined ip4 query should be resolved", func() { - Expect(sut.Resolve(newRequest("custom.domain.", A))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))). Should( SatisfyAll( BeDNSRecord("custom.domain.", A, "192.168.143.123"), @@ -127,7 +134,7 @@ var _ = Describe("CustomDNSResolver", func() { m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) }) It("TXT query for defined mapping should be delegated to next resolver", func() { - Expect(sut.Resolve(newRequest("custom.domain.", TXT))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))). Should( SatisfyAll( HaveNoAnswer(), @@ -139,7 +146,7 @@ var _ = Describe("CustomDNSResolver", func() { m.AssertExpectations(GinkgoT()) }) It("ip6 query should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("custom.domain.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -154,7 +161,7 @@ var _ = Describe("CustomDNSResolver", func() { }) When("Ip 6 mapping is defined for custom domain ", func() { It("ip6 query should be resolved", func() { - Expect(sut.Resolve(newRequest("ip6.domain.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("ip6.domain.", AAAA))). Should( SatisfyAll( BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"), @@ -170,7 +177,7 @@ var _ = Describe("CustomDNSResolver", func() { When("Multiple IPs are defined for custom domain ", func() { It("all IPs for the current type should be returned", func() { By("IPv6 query", func() { - Expect(sut.Resolve(newRequest("multiple.ips.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("multiple.ips.", AAAA))). Should( SatisfyAll( BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"), @@ -185,7 +192,7 @@ var _ = Describe("CustomDNSResolver", func() { }) By("IPv4 query", func() { - Expect(sut.Resolve(newRequest("multiple.ips.", A))). + Expect(sut.Resolve(ctx, newRequest("multiple.ips.", A))). Should( SatisfyAll( WithTransform(ToAnswer, SatisfyAll( @@ -207,7 +214,7 @@ var _ = Describe("CustomDNSResolver", func() { When("Reverse DNS request is received", func() { It("should resolve the defined domain name", func() { By("ipv4", func() { - Expect(sut.Resolve(newRequest("123.143.168.192.in-addr.arpa.", PTR))). + Expect(sut.Resolve(ctx, newRequest("123.143.168.192.in-addr.arpa.", PTR))). Should( SatisfyAll( WithTransform(ToAnswer, SatisfyAll( @@ -226,7 +233,7 @@ var _ = Describe("CustomDNSResolver", func() { }) By("ipv6", func() { - Expect(sut.Resolve(newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.", + Expect(sut.Resolve(ctx, newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.", PTR))). Should( SatisfyAll( @@ -250,7 +257,7 @@ var _ = Describe("CustomDNSResolver", func() { }) When("Domain mapping is defined", func() { It("subdomain must also match", func() { - Expect(sut.Resolve(newRequest("ABC.CUSTOM.DOMAIN.", A))). + Expect(sut.Resolve(ctx, newRequest("ABC.CUSTOM.DOMAIN.", A))). Should( SatisfyAll( BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"), @@ -268,7 +275,7 @@ var _ = Describe("CustomDNSResolver", func() { Describe("Delegating to next resolver", func() { When("no mapping for domain exist", func() { It("should delegate to next resolver", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), diff --git a/resolver/ede_resolver_test.go b/resolver/ede_resolver_test.go index fc7e7ad5b..ec4e87911 100644 --- a/resolver/ede_resolver_test.go +++ b/resolver/ede_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "errors" "github.com/0xERR0R/blocky/config" @@ -21,6 +22,9 @@ var _ = Describe("EdeResolver", func() { sutConfig config.EdeConfig m *mockResolver mockAnswer *dns.Msg + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -30,6 +34,9 @@ var _ = Describe("EdeResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + mockAnswer = new(dns.Msg) }) @@ -54,7 +61,7 @@ var _ = Describe("EdeResolver", func() { } }) It("shouldn't add EDE information", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -86,7 +93,7 @@ var _ = Describe("EdeResolver", func() { } It("should add EDE information", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -114,7 +121,7 @@ var _ = Describe("EdeResolver", func() { }) It("should return it", func() { - resp, err := sut.Resolve(newRequest("example.com", A)) + resp, err := sut.Resolve(ctx, newRequest("example.com", A)) Expect(resp).To(BeNil()) Expect(err).To(Equal(resolveErr)) }) diff --git a/resolver/filtering_resolver_test.go b/resolver/filtering_resolver_test.go index aeaf56008..cccadd2e2 100644 --- a/resolver/filtering_resolver_test.go +++ b/resolver/filtering_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -18,6 +20,9 @@ var _ = Describe("FilteringResolver", func() { sutConfig config.FilteringConfig m *mockResolver mockAnswer *dns.Msg + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -27,6 +32,9 @@ var _ = Describe("FilteringResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + mockAnswer = new(dns.Msg) }) @@ -60,7 +68,7 @@ var _ = Describe("FilteringResolver", func() { } }) It("Should delegate to next resolver if request query has other type", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -72,7 +80,7 @@ var _ = Describe("FilteringResolver", func() { Expect(m.Calls).Should(HaveLen(1)) }) It("Should return empty answer for defined query type", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -90,7 +98,7 @@ var _ = Describe("FilteringResolver", func() { sutConfig = config.FilteringConfig{} }) It("Should return empty answer without error", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), diff --git a/resolver/fqdn_only_resolver_test.go b/resolver/fqdn_only_resolver_test.go index 7838cc349..666db3517 100644 --- a/resolver/fqdn_only_resolver_test.go +++ b/resolver/fqdn_only_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -17,6 +19,9 @@ var _ = Describe("FqdnOnlyResolver", func() { sutConfig config.FqdnOnlyConfig m *mockResolver mockAnswer *dns.Msg + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -26,6 +31,9 @@ var _ = Describe("FqdnOnlyResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + mockAnswer = new(dns.Msg) }) @@ -57,7 +65,7 @@ var _ = Describe("FqdnOnlyResolver", func() { sutConfig = config.FqdnOnlyConfig{Enable: true} }) It("Should delegate to next resolver if request query is fqdn", func() { - Expect(sut.Resolve(newRequest("example.com", A))). + Expect(sut.Resolve(ctx, newRequest("example.com", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -69,7 +77,7 @@ var _ = Describe("FqdnOnlyResolver", func() { Expect(m.Calls).Should(HaveLen(1)) }) It("Should return NXDOMAIN if request query is not fqdn", func() { - Expect(sut.Resolve(newRequest("example", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -103,7 +111,7 @@ var _ = Describe("FqdnOnlyResolver", func() { sutConfig = config.FqdnOnlyConfig{Enable: false} }) It("Should delegate to next resolver if request query is fqdn", func() { - Expect(sut.Resolve(newRequest("example.com", A))). + Expect(sut.Resolve(ctx, newRequest("example.com", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -115,7 +123,7 @@ var _ = Describe("FqdnOnlyResolver", func() { Expect(m.Calls).Should(HaveLen(1)) }) It("Should delegate to next resolver if request query is not fqdn", func() { - Expect(sut.Resolve(newRequest("example", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example", AAAA))). Should( SatisfyAll( HaveNoAnswer(), diff --git a/resolver/hosts_file_resolver_test.go b/resolver/hosts_file_resolver_test.go index 79fead569..197c75dc3 100644 --- a/resolver/hosts_file_resolver_test.go +++ b/resolver/hosts_file_resolver_test.go @@ -25,6 +25,9 @@ var _ = Describe("HostsFileResolver", func() { tmpFile *TmpFile err error resp *Response + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -34,6 +37,9 @@ var _ = Describe("HostsFileResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + tmpDir = NewTmpFolder("HostsFileResolver") Expect(tmpDir.Error).Should(Succeed()) DeferCleanup(tmpDir.Clean) @@ -53,8 +59,6 @@ var _ = Describe("HostsFileResolver", func() { }) JustBeforeEach(func() { - ctx, cancelFn := context.WithCancel(context.Background()) - DeferCleanup(cancelFn) sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap) Expect(err).Should(Succeed()) @@ -96,7 +100,7 @@ var _ = Describe("HostsFileResolver", func() { Expect(sut.hosts.isEmpty()).Should(BeTrue()) }) It("should go to next resolver on query", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -112,11 +116,11 @@ var _ = Describe("HostsFileResolver", func() { sutConfig.Sources = make([]config.BytesSource, 0) }) JustBeforeEach(func() { - err = sut.loadSources(context.Background()) + err = sut.loadSources(ctx) Expect(err).Should(Succeed()) }) It("should go to next resolver on query", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -178,7 +182,7 @@ var _ = Describe("HostsFileResolver", func() { When("IPv4 mapping is defined for a host", func() { It("defined ipv4 query should be resolved", func() { - Expect(sut.Resolve(newRequest("ipv4host.", A))). + Expect(sut.Resolve(ctx, newRequest("ipv4host.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -188,7 +192,7 @@ var _ = Describe("HostsFileResolver", func() { )) }) It("defined ipv4 query for alias should be resolved", func() { - Expect(sut.Resolve(newRequest("router2.", A))). + Expect(sut.Resolve(ctx, newRequest("router2.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -198,7 +202,7 @@ var _ = Describe("HostsFileResolver", func() { )) }) It("ipv4 query should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("does.not.exist.", A))). + Expect(sut.Resolve(ctx, newRequest("does.not.exist.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -210,7 +214,7 @@ var _ = Describe("HostsFileResolver", func() { When("IPv6 mapping is defined for a host", func() { It("defined ipv6 query should be resolved", func() { - Expect(sut.Resolve(newRequest("ipv6host.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("ipv6host.", AAAA))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -220,7 +224,7 @@ var _ = Describe("HostsFileResolver", func() { )) }) It("ipv6 query should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("does.not.exist.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("does.not.exist.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -232,7 +236,7 @@ var _ = Describe("HostsFileResolver", func() { When("the domain is not known", func() { It("calls the next resolver", func() { - resp, err = sut.Resolve(newRequest("not-in-hostsfile.tld.", A)) + resp, err = sut.Resolve(ctx, newRequest("not-in-hostsfile.tld.", A)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -241,7 +245,7 @@ var _ = Describe("HostsFileResolver", func() { When("the question type is not handled", func() { It("calls the next resolver", func() { - resp, err = sut.Resolve(newRequest("localhost.", MX)) + resp, err = sut.Resolve(ctx, newRequest("localhost.", MX)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -251,7 +255,7 @@ var _ = Describe("HostsFileResolver", func() { When("Reverse DNS request is received", func() { It("should resolve the defined domain name", func() { By("ipv4 with one hostname", func() { - Expect(sut.Resolve(newRequest("2.0.0.10.in-addr.arpa.", PTR))). + Expect(sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.arpa.", PTR))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -261,7 +265,7 @@ var _ = Describe("HostsFileResolver", func() { )) }) By("ipv4 with aliases", func() { - Expect(sut.Resolve(newRequest("1.0.0.10.in-addr.arpa.", PTR))). + Expect(sut.Resolve(ctx, newRequest("1.0.0.10.in-addr.arpa.", PTR))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -274,7 +278,9 @@ var _ = Describe("HostsFileResolver", func() { )) }) By("ipv6", func() { - Expect(sut.Resolve(newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR))). + Expect(sut.Resolve(ctx, + newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR)), + ). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -290,7 +296,7 @@ var _ = Describe("HostsFileResolver", func() { }) It("should ignore invalid PTR", func() { - resp, err = sut.Resolve(newRequest("2.0.0.10.in-addr.fail.arpa.", PTR)) + resp, err = sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.fail.arpa.", PTR)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -298,7 +304,7 @@ var _ = Describe("HostsFileResolver", func() { When("filterLoopback is true", func() { It("calls the next resolver", func() { - resp, err = sut.Resolve(newRequest("1.0.0.127.in-addr.arpa.", PTR)) + resp, err = sut.Resolve(ctx, newRequest("1.0.0.127.in-addr.arpa.", PTR)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -307,7 +313,7 @@ var _ = Describe("HostsFileResolver", func() { When("the IP is not known", func() { It("calls the next resolver", func() { - resp, err = sut.Resolve(newRequest("255.255.255.255.in-addr.arpa.", PTR)) + resp, err = sut.Resolve(ctx, newRequest("255.255.255.255.in-addr.arpa.", PTR)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -320,7 +326,7 @@ var _ = Describe("HostsFileResolver", func() { }) It("resolve the defined domain name", func() { - Expect(sut.Resolve(newRequest("1.1.0.127.in-addr.arpa.", PTR))). + Expect(sut.Resolve(ctx, newRequest("1.1.0.127.in-addr.arpa.", PTR))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -338,7 +344,7 @@ var _ = Describe("HostsFileResolver", func() { Describe("Delegating to next resolver", func() { When("no hosts file is provided", func() { It("should delegate to next resolver", func() { - _, err = sut.Resolve(newRequest("example.com.", A)) + _, err = sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(Succeed()) // delegate was executed m.AssertExpectations(GinkgoT()) diff --git a/resolver/metrics_resolver_test.go b/resolver/metrics_resolver_test.go index a686e5377..d62cf5973 100644 --- a/resolver/metrics_resolver_test.go +++ b/resolver/metrics_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "errors" "github.com/0xERR0R/blocky/config" @@ -21,6 +22,9 @@ var _ = Describe("MetricResolver", func() { var ( sut *MetricsResolver m *mockResolver + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -30,6 +34,9 @@ var _ = Describe("MetricResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sut = NewMetricsResolver(config.MetricsConfig{Enable: true}) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) @@ -56,7 +63,7 @@ var _ = Describe("MetricResolver", func() { Context("Recording request metrics", func() { When("Request will be performed", func() { It("Should record metrics", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -77,7 +84,7 @@ var _ = Describe("MetricResolver", func() { sut.Next(m) }) It("Error should be recorded", func() { - _, err := sut.Resolve(newRequestWithClient("example.com.", A, "", "client")) + _, err := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client")) Expect(err).Should(HaveOccurred()) diff --git a/resolver/mocks_test.go b/resolver/mocks_test.go index 48fa25862..83d4f71fa 100644 --- a/resolver/mocks_test.go +++ b/resolver/mocks_test.go @@ -23,7 +23,7 @@ type mockResolver struct { mock.Mock NextResolver - ResolveFn func(req *model.Request) (*model.Response, error) + ResolveFn func(ctx context.Context, req *model.Request) (*model.Response, error) ResponseFn func(req *dns.Msg) *dns.Msg AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error) } @@ -45,11 +45,11 @@ func (r *mockResolver) LogConfig(*logrus.Entry) { r.Called() } -func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) { +func (r *mockResolver) Resolve(ctx context.Context, req *model.Request) (*model.Response, error) { args := r.Called(req) if r.ResolveFn != nil { - return r.ResolveFn(req) + return r.ResolveFn(ctx, req) } if r.ResponseFn != nil { diff --git a/resolver/noop_resolver_test.go b/resolver/noop_resolver_test.go index b9f470966..9e0adb1c9 100644 --- a/resolver/noop_resolver_test.go +++ b/resolver/noop_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" . "github.com/onsi/ginkgo/v2" @@ -8,7 +10,12 @@ import ( ) var _ = Describe("NoOpResolver", func() { - var sut *NoOpResolver + var ( + sut *NoOpResolver + + ctx context.Context + cancelFn context.CancelFunc + ) Describe("Type", func() { It("follows conventions", func() { @@ -17,12 +24,15 @@ var _ = Describe("NoOpResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sut = NewNoOpResolver() }) Describe("Resolving", func() { It("returns no response", func() { - resp, err := sut.Resolve(newRequest("test.tld", A)) + resp, err := sut.Resolve(ctx, newRequest("test.tld", A)) Expect(err).Should(Succeed()) Expect(resp).Should(Equal(NoResponse)) }) diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index cf7088e16..026534221 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -64,7 +64,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { Upstreams: upstreams, } - sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify) + sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap, sutVerify) }) Describe("IsEnabled", func() { @@ -104,7 +104,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream.Start(), } - _, err := NewParallelBestResolver(config.UpstreamGroup{ + _, err := NewParallelBestResolver(ctx, config.UpstreamGroup{ Name: upstreamDefaultCfgName, Upstreams: upstreams, }, @@ -192,7 +192,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("Should use result from fastest one", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -214,7 +214,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("Should use result from successful resolver", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.123"), @@ -235,7 +235,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("Should return error", func() { Expect(err).Should(Succeed()) request := newRequest("example.com.", A) - _, err = sut.Resolve(request) + _, err = sut.Resolve(ctx, request) Expect(err).Should(HaveOccurred()) }) @@ -251,7 +251,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -275,7 +275,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - sut, _ = NewParallelBestResolver(config.UpstreamGroup{ + sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{ Name: upstreamDefaultCfgName, Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, }, @@ -301,7 +301,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { By("perform 100 request, error upstream's weight will be reduced", func() { for i := 0; i < 100; i++ { request := newRequest("example.com.", A) - _, _ = sut.Resolve(request) + _, _ = sut.Resolve(ctx, request) } }) @@ -335,7 +335,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("errors during construction", func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewParallelBestResolver(config.UpstreamGroup{ + r, err := NewParallelBestResolver(ctx, config.UpstreamGroup{ Name: "test", Upstreams: []config.Upstream{{Host: "example.com"}}, }, b, verifyUpstreams) @@ -372,7 +372,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("Should return result from either one", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should(SatisfyAll( HaveTTL(BeNumerically("==", 123)), HaveResponseType(ResponseTypeRESOLVED), @@ -398,7 +398,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("should ask another random upstream and return its response", func() { request := newRequest("example.com", A) - Expect(sut.Resolve(request)).Should( + Expect(sut.Resolve(ctx, request)).Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.2"), HaveTTL(BeNumerically("==", 123)), @@ -439,7 +439,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("Should return error", func() { request := newRequest("example.com.", A) - _, err := sut.Resolve(request) + _, err := sut.Resolve(ctx, request) Expect(err).Should(HaveOccurred()) }) }) @@ -454,7 +454,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -478,7 +478,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - sut, _ = NewParallelBestResolver(config.UpstreamGroup{ + sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{ Name: upstreamDefaultCfgName, Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, }, @@ -499,7 +499,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { By("perform 100 request, error upstream's weight will be reduced", func() { for i := 0; i < 100; i++ { request := newRequest("example.com.", A) - _, _ = sut.Resolve(request) + _, _ = sut.Resolve(ctx, request) } }) diff --git a/resolver/query_logging_resolver_test.go b/resolver/query_logging_resolver_test.go index efc33bf51..40e326e54 100644 --- a/resolver/query_logging_resolver_test.go +++ b/resolver/query_logging_resolver_test.go @@ -44,6 +44,9 @@ var _ = Describe("QueryLoggingResolver", func() { m *mockResolver tmpDir *TmpFolder mockAnswer *dns.Msg + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -53,6 +56,9 @@ var _ = Describe("QueryLoggingResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + mockAnswer = new(dns.Msg) tmpDir = NewTmpFolder("queryLoggingResolver") Expect(tmpDir.Error).Should(Succeed()) @@ -64,9 +70,6 @@ var _ = Describe("QueryLoggingResolver", func() { sutConfig.SetDefaults() // not called when using a struct literal } - ctx, cancelFn := context.WithCancel(context.Background()) - DeferCleanup(cancelFn) - sut = NewQueryLoggingResolver(ctx, sutConfig) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil) @@ -98,7 +101,7 @@ var _ = Describe("QueryLoggingResolver", func() { } }) It("should process request without query logging", func() { - Expect(sut.Resolve(newRequest("example.com", A))). + Expect(sut.Resolve(ctx, newRequest("example.com", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -120,7 +123,7 @@ var _ = Describe("QueryLoggingResolver", func() { }) It("should create a log file per client", func() { By("request from client 1", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -128,7 +131,7 @@ var _ = Describe("QueryLoggingResolver", func() { )) }) By("request from client 2, has name with special chars, should be escaped", func() { - Expect(sut.Resolve(newRequestWithClient( + Expect(sut.Resolve(ctx, newRequestWithClient( "example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))). Should( SatisfyAll( @@ -188,7 +191,7 @@ var _ = Describe("QueryLoggingResolver", func() { }) It("should create one log file for all clients", func() { By("request from client 1", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -196,7 +199,7 @@ var _ = Describe("QueryLoggingResolver", func() { )) }) By("request from client 2, has name with special chars, should be escaped", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -249,7 +252,7 @@ var _ = Describe("QueryLoggingResolver", func() { }) It("should create one log file", func() { By("request from client 1", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -297,7 +300,7 @@ var _ = Describe("QueryLoggingResolver", func() { sut.writer = mockWriter Eventually(func() int { - _, ierr := sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1")) + _, ierr := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1")) Expect(ierr).Should(Succeed()) return len(sut.logChan) diff --git a/resolver/rewriter_resolver_test.go b/resolver/rewriter_resolver_test.go index a74d0f317..6027c7458 100644 --- a/resolver/rewriter_resolver_test.go +++ b/resolver/rewriter_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" @@ -94,7 +96,7 @@ var _ = Describe("RewriterResolver", func() { return res } - resp, err := sut.Resolve(request) + resp, err := sut.Resolve(context.Background(), request) Expect(err).Should(Succeed()) if resp != mNextResponse { Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal)) @@ -132,18 +134,18 @@ var _ = Describe("RewriterResolver", func() { expectNilAnswer = true // Make inner call the NoOpResolver - mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { + mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req).Should(Equal(request)) // Inner should see fqdnRewritten Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten)) - return mInner.next.Resolve(req) + return mInner.next.Resolve(ctx, req) } // Resolver after RewriterResolver should see `fqdnOriginal` mNext.On("Resolve", mock.Anything) - mNext.ResolveFn = func(req *model.Request) (*model.Response, error) { + mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal)) return mNextResponse, nil @@ -156,7 +158,7 @@ var _ = Describe("RewriterResolver", func() { expectNilAnswer = true // Make inner return a nil Answer but not an empty Response - mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { + mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req).Should(Equal(request)) // Inner should see fqdnRewritten @@ -179,7 +181,7 @@ var _ = Describe("RewriterResolver", func() { fqdnRewritten = sampleRewritten // Make inner return a nil Answer but not an empty Response - mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { + mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req).Should(Equal(request)) // Inner should see fqdnRewritten @@ -190,7 +192,7 @@ var _ = Describe("RewriterResolver", func() { // Resolver after RewriterResolver should see `fqdnOriginal` mNext.On("Resolve", mock.Anything) - mNext.ResolveFn = func(req *model.Request) (*model.Response, error) { + mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal)) return mNextResponse, nil diff --git a/resolver/strict_resolver_test.go b/resolver/strict_resolver_test.go index 6f5e3dd74..0a4468442 100644 --- a/resolver/strict_resolver_test.go +++ b/resolver/strict_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "time" "github.com/0xERR0R/blocky/config" @@ -27,6 +28,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { err error bootstrap *Bootstrap + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -36,6 +40,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + upstreams = []config.Upstream{ {Host: "wrong"}, {Host: "127.0.0.2"}, @@ -51,7 +58,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { Name: upstreamDefaultCfgName, Upstreams: upstreams, } - sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify) + sut, err = NewStrictResolver(ctx, sutConfig, bootstrap, sutVerify) }) config.GetConfig().Upstreams.Timeout = config.Duration(time.Second) @@ -100,7 +107,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { mockUpstream.Start(), } - _, err := NewStrictResolver(config.UpstreamGroup{ + _, err := NewStrictResolver(ctx, config.UpstreamGroup{ Name: upstreamDefaultCfgName, Upstreams: upstreams, }, @@ -151,7 +158,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("Should use result from first one", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -180,7 +187,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("should return response from next upstream", func() { request := newRequest("example.com", A) - Expect(sut.Resolve(request)).Should( + Expect(sut.Resolve(ctx, request)).Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.2"), HaveTTL(BeNumerically("==", 123)), @@ -214,7 +221,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("should return error", func() { request := newRequest("example.com", A) - _, err := sut.Resolve(request) + _, err := sut.Resolve(ctx, request) Expect(err).To(HaveOccurred()) }) }) @@ -230,7 +237,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("Should use result from second one", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.123"), @@ -247,7 +254,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("Should return error", func() { request := newRequest("example.com.", A) - _, err = sut.Resolve(request) + _, err = sut.Resolve(ctx, request) Expect(err).Should(HaveOccurred()) }) }) @@ -262,7 +269,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), diff --git a/resolver/sudn_resolver_test.go b/resolver/sudn_resolver_test.go index 1c004b203..6dc3de282 100644 --- a/resolver/sudn_resolver_test.go +++ b/resolver/sudn_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "fmt" "github.com/0xERR0R/blocky/config" @@ -19,6 +20,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { sut *SpecialUseDomainNamesResolver sutConfig config.SUDNConfig m *mockResolver + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -30,6 +34,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { BeforeEach(func() { var err error + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sutConfig, err = config.WithDefaults[config.SUDNConfig]() Expect(err).Should(Succeed()) }) @@ -48,7 +55,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { Describe("handlers", func() { It("should have correct response type", func() { for domain, handler := range sudnHandlers { - resp, err := sut.Resolve(newRequest(domain, A)) + resp, err := sut.Resolve(ctx, newRequest(domain, A)) Expect(err).Should(Succeed()) if handler == nil { @@ -90,7 +97,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { DescribeTable("handled domains", func(qType dns.Type, qName string, expectedRCode int, extraMatchers ...types.GomegaMatcher) { - resp, err := sut.Resolve(newRequest(qName, qType)) + resp, err := sut.Resolve(ctx, newRequest(qName, qType)) Expect(err).Should(Succeed()) Expect(resp).Should(SatisfyAll( HaveResponseType(ResponseTypeSPECIAL), @@ -133,7 +140,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { DescribeTable("", func(qType dns.Type, qName string, expectedRCode int) { - resp, err := sut.Resolve(newRequest(qName, qType)) + resp, err := sut.Resolve(ctx, newRequest(qName, qType)) Expect(err).Should(Succeed()) Expect(resp).Should(HaveReturnCode(expectedRCode)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL)) @@ -150,7 +157,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { }) It("should forward example.com", func() { - Expect(sut.Resolve(newRequest("example.com", A))). + Expect(sut.Resolve(ctx, newRequest("example.com", A))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.145.123.145"), @@ -161,7 +168,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { }) It("should forward home.arpa. IN DS", func() { - Expect(sut.Resolve(newRequest("something.home.arpa.", DS))). + Expect(sut.Resolve(ctx, newRequest("something.home.arpa.", DS))). Should( SatisfyAll( // setup code doesn't care about the question @@ -173,7 +180,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { }) It("should forward non special use domains", func() { - resp, err := sut.Resolve(newRequest("something.not-special.", AAAA)) + resp, err := sut.Resolve(ctx, newRequest("something.not-special.", AAAA)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL)) }) diff --git a/resolver/upstream_resolver_test.go b/resolver/upstream_resolver_test.go index 8f5a26065..3c9f94e7d 100644 --- a/resolver/upstream_resolver_test.go +++ b/resolver/upstream_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "crypto/tls" "fmt" "net/http" @@ -21,9 +22,15 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { var ( sut *UpstreamResolver sutConfig config.Upstream + + ctx context.Context + cancelFn context.CancelFunc ) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sutConfig = config.Upstream{Host: "localhost"} }) @@ -62,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { upstream := mockUpstream.Start() sut := newUpstreamResolverUnchecked(upstream, nil) - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -81,7 +88,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { upstream := mockUpstream.Start() sut := newUpstreamResolverUnchecked(upstream, nil) - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -100,7 +107,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { upstream := mockUpstream.Start() sut := newUpstreamResolverUnchecked(upstream, nil) - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) }) }) @@ -133,7 +140,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { atomic.StoreInt32(&counter, 0) atomic.StoreInt32(&attemptsWithTimeout, 2) - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -146,7 +153,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { By("3 attempts with timeout -> should return error", func() { atomic.StoreInt32(&counter, 0) atomic.StoreInt32(&attemptsWithTimeout, 3) - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("i/o timeout")) }) @@ -185,7 +192,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { }) When("Configured DOH resolver can resolve query", func() { It("should return answer from DNS upstream", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -203,7 +210,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { } }) It("should return error", func() { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500")) }) @@ -215,7 +222,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { } }) It("should return error", func() { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should( ContainSubstring("http return content type should be 'application/dns-message', but was 'text'")) @@ -228,7 +235,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { } }) It("should return error", func() { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("can't unpack message")) }) @@ -241,7 +248,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { }, systemResolverBootstrap) }) It("should return error", func() { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(Or( ContainSubstring("no such host"), diff --git a/resolver/upstream_tree_resolver_test.go b/resolver/upstream_tree_resolver_test.go index 3855761ab..11b709042 100644 --- a/resolver/upstream_tree_resolver_test.go +++ b/resolver/upstream_tree_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -139,7 +141,15 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { }) When("client specific resolvers are defined", func() { + var ( + ctx context.Context + cancelFn context.CancelFunc + ) + BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{ upstreamDefaultCfgName: {config.Upstream{}}, "laptop": {config.Upstream{}}, @@ -191,7 +201,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use default if client name or IP don't match", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "test") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "default"), @@ -202,7 +212,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client name matches exact", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "laptop"), @@ -213,7 +223,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client name matches with wildcard", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "client-*-m"), @@ -224,7 +234,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client name matches with range wildcard", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "client[0-9]"), @@ -235,7 +245,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client IP matches", func() { request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "192.168.178.33"), @@ -246,7 +256,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client name (containing IP) matches", func() { request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "192.168.178.33"), @@ -257,7 +267,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() { request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "10.43.8.67/28"), @@ -268,7 +278,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use exact IP match before client name match", func() { request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "192.168.178.33"), @@ -279,7 +289,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client name match before CIDR match", func() { request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "laptop"), @@ -293,7 +303,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1") request.Log = logger - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( SatisfyAny( diff --git a/server/server_test.go b/server/server_test.go index 5c902a728..21f45474e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -697,7 +697,7 @@ var _ = Describe("Running DNS server", func() { Describe("NewServer with strict upstream strategy", func() { It("successfully returns upstream branches", func() { - branches, err := createUpstreamBranches(&config.Config{ + branches, err := createUpstreamBranches(context.Background(), &config.Config{ Upstreams: config.UpstreamsConfig{ Strategy: config.UpstreamStrategyStrict, Groups: config.UpstreamGroups{ @@ -715,7 +715,7 @@ var _ = Describe("Running DNS server", func() { Describe("NewServer with random upstream strategy", func() { It("successfully returns upstream branches", func() { - branches, err := createUpstreamBranches(&config.Config{ + branches, err := createUpstreamBranches(context.Background(), &config.Config{ Upstreams: config.UpstreamsConfig{ Strategy: config.UpstreamStrategyRandom, Groups: config.UpstreamGroups{