diff --git a/api/api_interface_impl_test.go b/api/api_interface_impl_test.go index 0072e6d03..3d3c8a608 100644 --- a/api/api_interface_impl_test.go +++ b/api/api_interface_impl_test.go @@ -5,7 +5,6 @@ import ( "errors" "time" - // . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/miekg/dns" @@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus { return args.Get(0).(BlockingStatus) } -func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) { - args := m.Called(question, qType) +func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) { + args := m.Called(ctx, question, qType) return args.Get(0).(*model.Response), args.Error(1) } -func (m *CacheControlMock) FlushCaches() { - _ = m.Called() +func (m *CacheControlMock) FlushCaches(ctx context.Context) { + _ = m.Called(ctx) } var _ = Describe("API implementation tests", func() { @@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() { listRefreshMock *ListRefreshMock cacheControlMock *CacheControlMock sut *OpenAPIInterfaceImpl + + ctx context.Context + cancelFn context.CancelFunc ) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + blockingControlMock = &BlockingControlMock{} querierMock = &QuerierMock{} listRefreshMock = &ListRefreshMock{} @@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() { ) Expect(err).Should(Succeed()) - querierMock.On("Query", "google.com.", A).Return(&model.Response{ + querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{ Res: queryResponse, Reason: "reason", }, nil) - resp, err := sut.Query(context.Background(), QueryRequestObject{ + resp, err := sut.Query(ctx, QueryRequestObject{ Body: &ApiQueryRequest{ Query: "google.com", Type: "A", }, @@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() { }) It("should return 400 on wrong parameter", func() { - resp, err := sut.Query(context.Background(), QueryRequestObject{ + resp, err := sut.Query(ctx, QueryRequestObject{ Body: &ApiQueryRequest{ Query: "google.com", Type: "WRONGTYPE", @@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() { It("should return 200 on success", func() { listRefreshMock.On("RefreshLists").Return(nil) - resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{}) + resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{}) Expect(err).Should(Succeed()) var resp200 ListRefresh200Response Expect(resp).Should(BeAssignableToTypeOf(resp200)) @@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() { It("should return 500 on failure", func() { listRefreshMock.On("RefreshLists").Return(errors.New("failed")) - resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{}) + resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{}) Expect(err).Should(Succeed()) var resp500 ListRefresh500TextResponse Expect(resp).Should(BeAssignableToTypeOf(resp500)) @@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() { duration := "3s" grroups := "gr1,gr2" - resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{ + resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{ Params: DisableBlockingParams{ Duration: &duration, Groups: &grroups, @@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() { It("should return 400 on failure", func() { blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed")) - resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{}) + resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{}) Expect(err).Should(Succeed()) var resp400 DisableBlocking400TextResponse Expect(resp).Should(BeAssignableToTypeOf(resp400)) @@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() { It("should return 400 on wrong duration parameter", func() { wrongDuration := "4sds" - resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{ + resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{ Params: DisableBlockingParams{ Duration: &wrongDuration, }, @@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() { It("should return 200 on success", func() { blockingControlMock.On("EnableBlocking").Return() - resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{}) + resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{}) Expect(err).Should(Succeed()) var resp200 EnableBlocking200Response Expect(resp).Should(BeAssignableToTypeOf(resp200)) @@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() { AutoEnableInSec: 47, }) - resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{}) + resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{}) Expect(err).Should(Succeed()) var resp200 BlockingStatus200JSONResponse Expect(resp).Should(BeAssignableToTypeOf(resp200)) @@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() { Describe("Cache API", func() { When("Cache flush is called", func() { It("should return 200 on success", func() { - cacheControlMock.On("FlushCaches").Return() - resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{}) + cacheControlMock.On("FlushCaches", ctx).Return() + resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{}) Expect(err).Should(Succeed()) var resp200 CacheFlush200Response Expect(resp).Should(BeAssignableToTypeOf(resp200)) diff --git a/cache/expirationcache/expiration_cache_test.go b/cache/expirationcache/expiration_cache_test.go index 9364350b2..f8313dffd 100644 --- a/cache/expirationcache/expiration_cache_test.go +++ b/cache/expirationcache/expiration_cache_test.go @@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() { Describe("preExpiration function", func() { When("function is defined", func() { It("should update the value and TTL if function returns values", func() { - fn := func(key string) (val *string, ttl time.Duration) { + fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) { v2 := "v2" return &v2, time.Second @@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() { }) It("should update the value and TTL if function returns values on cleanup if element is expired", func() { - fn := func(key string) (val *string, ttl time.Duration) { + fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) { v2 := "val2" return &v2, time.Second @@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() { }) It("should delete the key if function returns nil", func() { - fn := func(key string) (val *string, ttl time.Duration) { + fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) { return nil, 0 } cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn) diff --git a/cache/expirationcache/prefetching_cache_test.go b/cache/expirationcache/prefetching_cache_test.go index 91b796a1b..fb94c62ff 100644 --- a/cache/expirationcache/prefetching_cache_test.go +++ b/cache/expirationcache/prefetching_cache_test.go @@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() { }, PrefetchThreshold: 2, PrefetchExpires: 100 * time.Millisecond, - ReloadFn: func(cacheKey string) (*string, time.Duration) { + ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) { v := "v2" return &v, 50 * time.Millisecond @@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() { }, PrefetchThreshold: 2, PrefetchExpires: 100 * time.Millisecond, - ReloadFn: func(cacheKey string) (*string, time.Duration) { + ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) { v := "v2" return &v, 50 * time.Millisecond @@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() { Options: Options{ CleanupInterval: 100 * time.Millisecond, }, - ReloadFn: func(cacheKey string) (*string, time.Duration) { + ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) { v := "v2" return &v, 50 * time.Millisecond @@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() { }, PrefetchThreshold: 2, PrefetchExpires: 100 * time.Millisecond, - ReloadFn: func(cacheKey string) (*string, time.Duration) { + ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) { v := "v2" return &v, 50 * time.Millisecond diff --git a/resolver/blocking_resolver_test.go b/resolver/blocking_resolver_test.go index b7004fbed..e55b7419e 100644 --- a/resolver/blocking_resolver_test.go +++ b/resolver/blocking_resolver_test.go @@ -155,7 +155,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { } Bus().Publish(ApplicationStarted, "") Eventually(func(g Gomega) { - g.Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))). + g.Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "192.168.178.39", "client1"))). Should(And( BeDNSRecord("blocked2.com.", A, "0.0.0.0"), HaveTTL(BeNumerically("==", 60)), @@ -185,6 +185,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Domain is on the black list", func() { It("should block request", func() { Eventually(sut.Resolve). + WithContext(ctx). WithArguments(newRequestWithClient("regex.com.", dns.Type(dns.TypeA), "1.2.1.2", "client1")). Should( SatisfyAll( @@ -222,7 +223,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("client name is defined in client groups block", func() { It("should block the A query if domain is on the black list (single)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -233,7 +234,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the A query if domain is on the black list (multipart 1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client2"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -244,7 +245,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the A query if domain is on the black list (multipart 2)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client3"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -255,7 +256,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the A query if domain is on the black list (merged)", func() { - Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client3"))). Should( SatisfyAll( BeDNSRecord("blocked2.com.", A, "0.0.0.0"), @@ -266,7 +267,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the AAAA query if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", AAAA, "1.2.1.2", "client1"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", AAAA, "::"), @@ -277,18 +278,18 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the HTTPS query if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", HTTPS, "1.2.1.2", "client1"))). Should(HaveReturnCode(dns.RcodeNameError)) }) It("should block the MX query if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", MX, "1.2.1.2", "client1"))). Should(HaveReturnCode(dns.RcodeNameError)) }) }) When("Client ip is defined in client groups block", func() { It("should block the query if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "192.168.178.55", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -301,7 +302,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() { It("should not block the query for 10.43.8.63 if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.63", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -313,7 +314,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { m.AssertExpectations(GinkgoT()) }) It("should not block the query for 10.43.8.80 if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.80", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -328,7 +329,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Client CIDR (10.43.8.64 - 10.43.8.79) is defined in client groups block", func() { It("should block the query for 10.43.8.64 if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.64", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -339,7 +340,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block the query for 10.43.8.79 if domain is on the black list", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "10.43.8.79", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -353,7 +354,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Client has multiple names and for each name a client group block definition exists", func() { It("should block query if domain is in one group", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "client1", "altname"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -364,7 +365,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) It("should block query if domain is in another group too", func() { - Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "client1", "altName"))). Should( SatisfyAll( BeDNSRecord("blocked2.com.", A, "0.0.0.0"), @@ -377,7 +378,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("Client name matches wildcard", func() { It("should block query if domain is in one group", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "wildcard1name"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -391,7 +392,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Default group is defined", func() { It("should block domains from default group for each client", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -418,7 +419,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return NXDOMAIN if query is blocked", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -444,7 +445,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return answer with specified TTL", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -461,7 +462,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return custom IP with specified TTL", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "12.12.12.12"), @@ -489,7 +490,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return ipv4 address for A query if query is blocked", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "12.12.12.12"), @@ -501,7 +502,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should return ipv6 address for AAAA query if query is blocked", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", AAAA, "2001:db8:85a3::8a2e:370:7334"), @@ -528,7 +529,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should use fallback for ipv6 and return zero ip", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", AAAA, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", AAAA, "::"), @@ -547,7 +548,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145") }) It("should block query, if lookup result contains blacklisted IP", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "0.0.0.0"), @@ -567,7 +568,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { ) }) It("should block query, if lookup result contains blacklisted IP", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", AAAA, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("example.com.", AAAA, "::"), @@ -590,7 +591,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { mockAnswer.Answer = []dns.RR{rr1, rr2, rr3} }) It("should block the query, if response contains a CNAME with domain on a blacklist", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "0.0.0.0"), @@ -617,7 +618,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { } }) It("Should not be blocked", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -649,7 +650,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) It("should block everything else except domains on the white list with default group", func() { By("querying domain on the whitelist", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -662,7 +663,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("querying another domain, which is not on the whitelist", func() { - Expect(sut.Resolve(newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("google.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("google.com.", A, "0.0.0.0"), @@ -678,7 +679,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { It("should block everything else except domains on the white list "+ "if multiple white list only groups are defined", func() { By("querying domain on the whitelist", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "one-client"))). Should( SatisfyAll( HaveNoAnswer(), @@ -691,7 +692,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("querying another domain, which is not on the whitelist", func() { - Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "one-client"))). Should( SatisfyAll( BeDNSRecord("blocked2.com.", A, "0.0.0.0"), @@ -706,7 +707,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { It("should block everything else except domains on the white list "+ "if multiple white list only groups are defined", func() { By("querying domain on the whitelist group 1", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "all-client"))). Should( SatisfyAll( HaveNoAnswer(), @@ -719,7 +720,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("querying another domain, which is in the whitelist group 1", func() { - Expect(sut.Resolve(newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked2.com.", A, "1.2.1.2", "all-client"))). Should( SatisfyAll( HaveNoAnswer(), @@ -745,7 +746,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { mockAnswer, _ = util.NewMsgWithAnswer("example.com.", 300, A, "123.145.123.145") }) It("should not block if DNS answer contains IP from the white list", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.145.123.145"), @@ -775,7 +776,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) When("domain is not on the black list", func() { It("should delegate to next resolver", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -792,7 +793,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { } }) It("should delegate to next resolver", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -819,7 +820,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Disable blocking is called", func() { It("no query should be blocked", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -830,7 +831,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("Perform query to ensure that the blocking status is active (group1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -847,7 +848,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again (defaultGroup)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -861,7 +862,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again (group1)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -880,7 +881,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again (defaultGroup)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -893,7 +894,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("Perform query to ensure that the blocking status is active (group1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -908,7 +909,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Disable blocking for all groups is called with a duration parameter", func() { It("No query should be blocked only for passed amount of time", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -918,7 +919,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) By("Perform query to ensure that the blocking status is active (group1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -941,7 +942,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -954,7 +955,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("perform the same query again to ensure that this query will not be blocked (group1)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -974,7 +975,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { // wait 1 sec Eventually(enabled, "1s").Should(Receive(BeTrue())) - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -983,7 +984,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveReturnCode(dns.RcodeSuccess), )) - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -998,7 +999,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { When("Disable blocking for one group is called with a duration parameter", func() { It("No query should be blocked only for passed amount of time", func() { By("Perform query to ensure that the blocking status is active (defaultGroup)", func() { - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -1008,7 +1009,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { )) }) By("Perform query to ensure that the blocking status is active (group1)", func() { - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), @@ -1031,7 +1032,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { By("perform the same query again to ensure that this query will not be blocked (defaultGroup)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -1042,7 +1043,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { }) By("perform the same query again to ensure that this query will not be blocked (group1)", func() { // now is blocking disabled, query the url again - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( HaveNoAnswer(), @@ -1062,7 +1063,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { // wait 1 sec Eventually(enabled, "1s").Should(Receive(BeTrue())) - Expect(sut.Resolve(newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("blocked3.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("blocked3.com.", A, "0.0.0.0"), @@ -1071,7 +1072,7 @@ var _ = Describe("BlockingResolver", Label("blockingResolver"), func() { HaveReturnCode(dns.RcodeSuccess), )) - Expect(sut.Resolve(newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). + Expect(sut.Resolve(ctx, newRequestWithClient("domain1.com.", A, "1.2.1.2", "unknown"))). Should( SatisfyAll( BeDNSRecord("domain1.com.", A, "0.0.0.0"), diff --git a/resolver/bootstrap_test.go b/resolver/bootstrap_test.go index 61b3722dc..5e4ffb397 100644 --- a/resolver/bootstrap_test.go +++ b/resolver/bootstrap_test.go @@ -72,7 +72,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }, } - _, err := sut.resolveUpstream(nil, "example.com") + _, err := sut.resolveUpstream(ctx, nil, "example.com") Expect(err).ShouldNot(Succeed()) Expect(usedSystemResolver).Should(Receive(BeTrue())) }) @@ -244,7 +244,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { When("called from bootstrap.upstream", func() { It("uses hardcoded IPs", func() { - ips, err := sut.resolveUpstream(bootstrapUpstream, "host") + ips, err := sut.resolveUpstream(ctx, bootstrapUpstream, "host") Expect(err).Should(Succeed()) Expect(ips).Should(Equal(sutConfig.BootstrapDNS[0].IPs)) @@ -253,7 +253,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { When("hostname is an IP", func() { It("returns immediately", func() { - ips, err := sut.resolve("0.0.0.0", config.IPVersionDual.QTypes()) + ips, err := sut.resolve(ctx, "0.0.0.0", config.IPVersionDual.QTypes()) Expect(err).Should(Succeed()) Expect(ips).Should(ContainElement(net.IPv4zero)) @@ -269,7 +269,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil) - ips, err := sut.resolve("localhost", []dns.Type{AAAA}) + ips, err := sut.resolve(ctx, "localhost", []dns.Type{AAAA}) Expect(err).Should(Succeed()) Expect(ips).Should(HaveLen(1)) @@ -283,7 +283,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { bootstrapUpstream.On("Resolve", mock.Anything).Return(nil, resolveErr) - ips, err := sut.resolve("localhost", []dns.Type{A}) + ips, err := sut.resolve(ctx, "localhost", []dns.Type{A}) Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring(resolveErr.Error())) @@ -297,7 +297,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { bootstrapUpstream.On("Resolve", mock.Anything).Return(&model.Response{Res: bootstrapResponse}, nil) - ips, err := sut.resolve("unknownhost.invalid", []dns.Type{A}) + ips, err := sut.resolve(ctx, "unknownhost.invalid", []dns.Type{A}) Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring("no such host")) @@ -329,7 +329,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { r := newUpstreamResolverUnchecked(upstream, sut) - rsp, err := r.Resolve(mainReq) + rsp, err := r.Resolve(ctx, mainReq) Expect(err).Should(Succeed()) Expect(mockUpstreamServer.GetCallCount()).Should(Equal(1)) Expect(rsp.Res.Question[0].Name).Should(Equal("example.com.")) @@ -373,7 +373,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { // implicit expectation of 0 bootstrapUpstream.Resolve calls - _, err = t.DialContext(context.Background(), "ip", "!bad-addr!") + _, err = t.DialContext(ctx, "ip", "!bad-addr!") Expect(err).ShouldNot(Succeed()) }) @@ -384,7 +384,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { t := sut.NewHTTPTransport() - _, err = t.DialContext(context.Background(), "ip", "abc:123") + _, err = t.DialContext(ctx, "ip", "abc:123") Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring(resolveErr.Error())) @@ -397,7 +397,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { t := sut.NewHTTPTransport() - _, err = t.DialContext(context.Background(), "ip", "abc:123") + _, err = t.DialContext(ctx, "ip", "abc:123") Expect(err).ShouldNot(Succeed()) Expect(err.Error()).Should(ContainSubstring("no such host")) @@ -437,7 +437,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { Describe("resolve", func() { AfterEach(func() { - _, err := sut.resolveUpstream(nil, "example.com") + _, err := sut.resolveUpstream(ctx, nil, "example.com") Expect(err).Should(Succeed()) m.AssertExpectations(GinkgoT()) @@ -501,7 +501,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { AfterEach(func() { t := sut.NewHTTPTransport() - conn, err := t.DialContext(context.Background(), dialIPVersion.Net(), "localhost:0") + conn, err := t.DialContext(ctx, dialIPVersion.Net(), "localhost:0") Expect(err).Should(Succeed()) Expect(conn).Should(Equal(aMockConn)) @@ -583,7 +583,7 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }) It("uses both", func() { - _, err := sut.resolve("example.com.", []dns.Type{dns.Type(dns.TypeA)}) + _, err := sut.resolve(ctx, "example.com.", []dns.Type{dns.Type(dns.TypeA)}) Expect(err).To(Succeed()) diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index acc7be64c..df95b613b 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -102,7 +102,7 @@ var _ = Describe("CachingResolver", func() { })).Should(Succeed()) // first request - _, _ = sut.Resolve(newRequest("example.com.", A)) + _, _ = sut.Resolve(ctx, newRequest("example.com.", A)) // Domain is not prefetched Expect(domainPrefetched).ShouldNot(Receive()) @@ -112,7 +112,7 @@ var _ = Describe("CachingResolver", func() { // now query again > threshold for i := 0; i < prefetchThreshold+1; i++ { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(Succeed()) } @@ -120,7 +120,7 @@ var _ = Describe("CachingResolver", func() { Eventually(domainPrefetched, "10s").Should(Receive(Equal(true))) // and it should hit from prefetch cache - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -144,7 +144,7 @@ var _ = Describe("CachingResolver", func() { }) It("should cache response and use response's TTL for multiple records", func() { By("first request", func() { - result, err := sut.Resolve(newRequest("example.com.", A)) + result, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(Succeed()) Expect(result). Should( @@ -164,7 +164,7 @@ var _ = Describe("CachingResolver", func() { By("second request", func() { Eventually(func(g Gomega) { - result, err := sut.Resolve(newRequest("example.com.", A)) + result, err := sut.Resolve(ctx, newRequest("example.com.", A)) g.Expect(err).Should(Succeed()) g.Expect(result). Should( @@ -206,7 +206,7 @@ var _ = Describe("CachingResolver", func() { _ = Bus().SubscribeOnce(CachingResultCacheChanged, func(d int) { totalCacheCount <- d }) - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -227,7 +227,7 @@ var _ = Describe("CachingResolver", func() { domain <- true }) - g.Expect(sut.Resolve(newRequest("example.com.", A))). + g.Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -252,7 +252,7 @@ var _ = Describe("CachingResolver", func() { It("should cache response and use min caching time as TTL", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -264,7 +264,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", A)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", A)). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -287,7 +289,7 @@ var _ = Describe("CachingResolver", func() { It("should cache response and use min caching time as TTL", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -298,7 +300,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -332,7 +336,7 @@ var _ = Describe("CachingResolver", func() { It("Shouldn't cache any responses", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -343,7 +347,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -366,7 +372,7 @@ var _ = Describe("CachingResolver", func() { }) It("should cache response and use max caching time as TTL if response TTL is bigger", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -377,7 +383,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -405,7 +413,7 @@ var _ = Describe("CachingResolver", func() { }) It("should cache response and return 0 TTL if entry is expired", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -418,7 +426,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve, "2s").WithArguments(newRequest("example.com.", A)). + Eventually(sut.Resolve, "2s"). + WithContext(ctx). + WithArguments(newRequest("example.com.", A)). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), @@ -449,7 +459,7 @@ var _ = Describe("CachingResolver", func() { }) By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeNameError), @@ -460,7 +470,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should(SatisfyAll( HaveResponseType(ResponseTypeCACHED), HaveReason("CACHED NEGATIVE"), @@ -483,7 +495,7 @@ var _ = Describe("CachingResolver", func() { It("response shouldn't be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeNameError), @@ -494,7 +506,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReason(""), @@ -517,7 +531,7 @@ var _ = Describe("CachingResolver", func() { It("response should be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -528,7 +542,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("example.com.", AAAA)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("example.com.", AAAA)). Should(SatisfyAll( HaveResponseType(ResponseTypeCACHED), HaveReason("CACHED"), @@ -551,7 +567,7 @@ var _ = Describe("CachingResolver", func() { }) It("Should be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("google.de.", MX))). + Expect(sut.Resolve(ctx, newRequest("google.de.", MX))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -563,7 +579,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("google.de.", MX)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("google.de.", MX)). Should(SatisfyAll( HaveResponseType(ResponseTypeCACHED), HaveReason("CACHED"), @@ -587,7 +605,7 @@ var _ = Describe("CachingResolver", func() { }) It("Should not be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -599,7 +617,7 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -621,7 +639,7 @@ var _ = Describe("CachingResolver", func() { }) It("Should not be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -633,7 +651,7 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -659,7 +677,7 @@ var _ = Describe("CachingResolver", func() { }) It("Should not be cached", func() { By("first request", func() { - Expect(sut.Resolve(newRequest("google.de.", A))). + Expect(sut.Resolve(ctx, newRequest("google.de.", A))). Should(SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), HaveReturnCode(dns.RcodeSuccess), @@ -676,7 +694,9 @@ var _ = Describe("CachingResolver", func() { }) By("second request", func() { - Eventually(sut.Resolve).WithArguments(newRequest("google.de.", A)). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(newRequest("google.de.", A)). Should(SatisfyAll( HaveResponseType(ResponseTypeCACHED), HaveReason("CACHED"), @@ -738,7 +758,7 @@ var _ = Describe("CachingResolver", func() { }) It("put in redis", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should(HaveResponseType(ResponseTypeRESOLVED)) Eventually(func() []string { @@ -760,7 +780,9 @@ var _ = Describe("CachingResolver", func() { } redisClient.CacheChannel <- redisMockMsg - Eventually(sut.Resolve).WithArguments(request). + Eventually(sut.Resolve). + WithContext(ctx). + WithArguments(request). Should( SatisfyAll( HaveResponseType(ResponseTypeCACHED), diff --git a/resolver/client_names_resolver_test.go b/resolver/client_names_resolver_test.go index dfe03fb07..4ea611b06 100644 --- a/resolver/client_names_resolver_test.go +++ b/resolver/client_names_resolver_test.go @@ -22,8 +22,9 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { sut *ClientNamesResolver sutConfig config.ClientLookupConfig m *mockResolver - ctx context.Context - cancelFn context.CancelFunc + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -34,6 +35,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { JustBeforeEach(func() { var err error + ctx, cancelFn = context.WithCancel(context.Background()) DeferCleanup(cancelFn) @@ -71,7 +73,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should use clientID if set", func() { request := newRequestWithClientID("google1.de.", dns.Type(dns.TypeA), "1.2.3.4", "client123") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -82,7 +84,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should use IP as fallback if clientID not set", func() { request := newRequestWithClientID("google2.de.", dns.Type(dns.TypeA), "1.2.3.4", "") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -112,7 +114,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve defined name with ipv4 address", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.4") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -124,7 +126,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve defined name with ipv6 address", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "2a02:590:505:4700:2e4f:1503:ce74:df78") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -135,7 +137,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should resolve multiple names defined names", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "1.2.3.5") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -168,7 +170,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve client name", func() { By("first request", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -180,7 +182,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { By("second request", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -198,7 +200,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { By("third request", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -223,7 +225,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve all client names", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -251,7 +253,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -272,7 +274,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should resolve the client name depending to defined order", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -298,7 +300,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { It("should use fallback for client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -318,7 +320,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should use fallback for client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -335,7 +337,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should resolve no names", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -351,7 +353,7 @@ var _ = Describe("ClientResolver", Label("clientNamesResolver"), func() { }) It("should use fallback for client name", func() { request := newRequestWithClient("google.de.", dns.Type(dns.TypeA), "192.168.178.25") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), diff --git a/resolver/conditional_upstream_resolver_test.go b/resolver/conditional_upstream_resolver_test.go index 72f0cb2ac..2c6ef4bc8 100644 --- a/resolver/conditional_upstream_resolver_test.go +++ b/resolver/conditional_upstream_resolver_test.go @@ -19,6 +19,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu var ( sut *ConditionalUpstreamResolver m *mockResolver + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -28,6 +31,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + fbTestUpstream := NewMockUDPUpstreamServer().WithAnswerFn(func(request *dns.Msg) (response *dns.Msg) { response, _ = util.NewMsgWithAnswer(request.Question[0].Name, 123, A, "123.124.122.122") @@ -59,7 +65,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) DeferCleanup(refuseTestUpstream.Close) - sut, _ = NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{ + sut, _ = NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{ Mapping: config.ConditionalUpstreamMapping{ Upstreams: map[string][]config.Upstream{ "fritz.box": {fbTestUpstream.Start()}, @@ -93,7 +99,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu Describe("Resolve conditional DNS queries via defined DNS server", func() { When("conditional resolver returns error code", func() { It("Should be returned without changes", func() { - Expect(sut.Resolve(newRequest("refused.domain.", A))). + Expect(sut.Resolve(ctx, newRequest("refused.domain.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -109,7 +115,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu When("Query is exact equal defined condition in mapping", func() { Context("first mapping entry", func() { It("Should resolve the IP of conditional DNS", func() { - Expect(sut.Resolve(newRequest("fritz.box.", A))). + Expect(sut.Resolve(ctx, newRequest("fritz.box.", A))). Should( SatisfyAll( BeDNSRecord("fritz.box.", A, "123.124.122.122"), @@ -125,7 +131,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) Context("last mapping entry", func() { It("Should resolve the IP of conditional DNS", func() { - Expect(sut.Resolve(newRequest("other.box.", A))). + Expect(sut.Resolve(ctx, newRequest("other.box.", A))). Should( SatisfyAll( BeDNSRecord("other.box.", A, "192.192.192.192"), @@ -141,7 +147,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) When("Query is a subdomain of defined condition in mapping", func() { It("Should resolve the IP of subdomain", func() { - Expect(sut.Resolve(newRequest("test.fritz.box.", A))). + Expect(sut.Resolve(ctx, newRequest("test.fritz.box.", A))). Should( SatisfyAll( BeDNSRecord("test.fritz.box.", A, "123.124.122.122"), @@ -156,7 +162,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu }) When("Query is not fqdn and . condition is defined in mapping", func() { It("Should resolve the IP of .", func() { - Expect(sut.Resolve(newRequest("test.", A))). + Expect(sut.Resolve(ctx, newRequest("test.", A))). Should( SatisfyAll( BeDNSRecord("test.", A, "168.168.168.168"), @@ -173,7 +179,7 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu Describe("Delegation to next resolver", func() { When("Query doesn't match defined mapping", func() { It("should delegate to next resolver", func() { - Expect(sut.Resolve(newRequest("google.com.", A))). + Expect(sut.Resolve(ctx, newRequest("google.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -186,11 +192,9 @@ var _ = Describe("ConditionalUpstreamResolver", Label("conditionalResolver"), fu When("upstream is invalid", func() { It("errors during construction", func() { - ctx, cancelFn := context.WithCancel(context.Background()) - DeferCleanup(cancelFn) b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewConditionalUpstreamResolver(config.ConditionalUpstreamConfig{ + r, err := NewConditionalUpstreamResolver(ctx, config.ConditionalUpstreamConfig{ Mapping: config.ConditionalUpstreamMapping{ Upstreams: map[string][]config.Upstream{ ".": {config.Upstream{Host: "example.com"}}, diff --git a/resolver/custom_dns_resolver_test.go b/resolver/custom_dns_resolver_test.go index 9ad65f170..d06f3672f 100644 --- a/resolver/custom_dns_resolver_test.go +++ b/resolver/custom_dns_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "net" "time" @@ -21,6 +22,9 @@ var _ = Describe("CustomDNSResolver", func() { sut *CustomDNSResolver m *mockResolver cfg config.CustomDNSConfig + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -30,6 +34,9 @@ var _ = Describe("CustomDNSResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + cfg = config.CustomDNSConfig{ Mapping: config.CustomDNSMapping{HostIPs: map[string][]net.IP{ "custom.domain": {net.ParseIP("192.168.143.123")}, @@ -73,7 +80,7 @@ var _ = Describe("CustomDNSResolver", func() { Context("filterUnmappedTypes is true", func() { BeforeEach(func() { cfg.FilterUnmappedTypes = true }) It("defined ip4 query should be resolved", func() { - Expect(sut.Resolve(newRequest("custom.domain.", A))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))). Should( SatisfyAll( BeDNSRecord("custom.domain.", A, "192.168.143.123"), @@ -86,7 +93,7 @@ var _ = Describe("CustomDNSResolver", func() { m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) }) It("TXT query for defined mapping should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("custom.domain.", TXT))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))). Should( SatisfyAll( HaveNoAnswer(), @@ -98,7 +105,7 @@ var _ = Describe("CustomDNSResolver", func() { m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) }) It("ip6 query should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("custom.domain.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -114,7 +121,7 @@ var _ = Describe("CustomDNSResolver", func() { Context("filterUnmappedTypes is false", func() { BeforeEach(func() { cfg.FilterUnmappedTypes = false }) It("defined ip4 query should be resolved", func() { - Expect(sut.Resolve(newRequest("custom.domain.", A))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", A))). Should( SatisfyAll( BeDNSRecord("custom.domain.", A, "192.168.143.123"), @@ -127,7 +134,7 @@ var _ = Describe("CustomDNSResolver", func() { m.AssertNotCalled(GinkgoT(), "Resolve", mock.Anything) }) It("TXT query for defined mapping should be delegated to next resolver", func() { - Expect(sut.Resolve(newRequest("custom.domain.", TXT))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", TXT))). Should( SatisfyAll( HaveNoAnswer(), @@ -139,7 +146,7 @@ var _ = Describe("CustomDNSResolver", func() { m.AssertExpectations(GinkgoT()) }) It("ip6 query should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("custom.domain.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("custom.domain.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -154,7 +161,7 @@ var _ = Describe("CustomDNSResolver", func() { }) When("Ip 6 mapping is defined for custom domain ", func() { It("ip6 query should be resolved", func() { - Expect(sut.Resolve(newRequest("ip6.domain.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("ip6.domain.", AAAA))). Should( SatisfyAll( BeDNSRecord("ip6.domain.", AAAA, "2001:db8:85a3::8a2e:370:7334"), @@ -170,7 +177,7 @@ var _ = Describe("CustomDNSResolver", func() { When("Multiple IPs are defined for custom domain ", func() { It("all IPs for the current type should be returned", func() { By("IPv6 query", func() { - Expect(sut.Resolve(newRequest("multiple.ips.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("multiple.ips.", AAAA))). Should( SatisfyAll( BeDNSRecord("multiple.ips.", AAAA, "2001:db8:85a3::8a2e:370:7334"), @@ -185,7 +192,7 @@ var _ = Describe("CustomDNSResolver", func() { }) By("IPv4 query", func() { - Expect(sut.Resolve(newRequest("multiple.ips.", A))). + Expect(sut.Resolve(ctx, newRequest("multiple.ips.", A))). Should( SatisfyAll( WithTransform(ToAnswer, SatisfyAll( @@ -207,7 +214,7 @@ var _ = Describe("CustomDNSResolver", func() { When("Reverse DNS request is received", func() { It("should resolve the defined domain name", func() { By("ipv4", func() { - Expect(sut.Resolve(newRequest("123.143.168.192.in-addr.arpa.", PTR))). + Expect(sut.Resolve(ctx, newRequest("123.143.168.192.in-addr.arpa.", PTR))). Should( SatisfyAll( WithTransform(ToAnswer, SatisfyAll( @@ -226,7 +233,7 @@ var _ = Describe("CustomDNSResolver", func() { }) By("ipv6", func() { - Expect(sut.Resolve(newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.", + Expect(sut.Resolve(ctx, newRequest("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2.ip6.arpa.", PTR))). Should( SatisfyAll( @@ -250,7 +257,7 @@ var _ = Describe("CustomDNSResolver", func() { }) When("Domain mapping is defined", func() { It("subdomain must also match", func() { - Expect(sut.Resolve(newRequest("ABC.CUSTOM.DOMAIN.", A))). + Expect(sut.Resolve(ctx, newRequest("ABC.CUSTOM.DOMAIN.", A))). Should( SatisfyAll( BeDNSRecord("ABC.CUSTOM.DOMAIN.", A, "192.168.143.123"), @@ -268,7 +275,7 @@ var _ = Describe("CustomDNSResolver", func() { Describe("Delegating to next resolver", func() { When("no mapping for domain exist", func() { It("should delegate to next resolver", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), diff --git a/resolver/ede_resolver_test.go b/resolver/ede_resolver_test.go index fc7e7ad5b..ec4e87911 100644 --- a/resolver/ede_resolver_test.go +++ b/resolver/ede_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "errors" "github.com/0xERR0R/blocky/config" @@ -21,6 +22,9 @@ var _ = Describe("EdeResolver", func() { sutConfig config.EdeConfig m *mockResolver mockAnswer *dns.Msg + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -30,6 +34,9 @@ var _ = Describe("EdeResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + mockAnswer = new(dns.Msg) }) @@ -54,7 +61,7 @@ var _ = Describe("EdeResolver", func() { } }) It("shouldn't add EDE information", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -86,7 +93,7 @@ var _ = Describe("EdeResolver", func() { } It("should add EDE information", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -114,7 +121,7 @@ var _ = Describe("EdeResolver", func() { }) It("should return it", func() { - resp, err := sut.Resolve(newRequest("example.com", A)) + resp, err := sut.Resolve(ctx, newRequest("example.com", A)) Expect(resp).To(BeNil()) Expect(err).To(Equal(resolveErr)) }) diff --git a/resolver/filtering_resolver_test.go b/resolver/filtering_resolver_test.go index aeaf56008..cccadd2e2 100644 --- a/resolver/filtering_resolver_test.go +++ b/resolver/filtering_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -18,6 +20,9 @@ var _ = Describe("FilteringResolver", func() { sutConfig config.FilteringConfig m *mockResolver mockAnswer *dns.Msg + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -27,6 +32,9 @@ var _ = Describe("FilteringResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + mockAnswer = new(dns.Msg) }) @@ -60,7 +68,7 @@ var _ = Describe("FilteringResolver", func() { } }) It("Should delegate to next resolver if request query has other type", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -72,7 +80,7 @@ var _ = Describe("FilteringResolver", func() { Expect(m.Calls).Should(HaveLen(1)) }) It("Should return empty answer for defined query type", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -90,7 +98,7 @@ var _ = Describe("FilteringResolver", func() { sutConfig = config.FilteringConfig{} }) It("Should return empty answer without error", func() { - Expect(sut.Resolve(newRequest("example.com.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example.com.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), diff --git a/resolver/fqdn_only_resolver_test.go b/resolver/fqdn_only_resolver_test.go index 7838cc349..666db3517 100644 --- a/resolver/fqdn_only_resolver_test.go +++ b/resolver/fqdn_only_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -17,6 +19,9 @@ var _ = Describe("FqdnOnlyResolver", func() { sutConfig config.FqdnOnlyConfig m *mockResolver mockAnswer *dns.Msg + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -26,6 +31,9 @@ var _ = Describe("FqdnOnlyResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + mockAnswer = new(dns.Msg) }) @@ -57,7 +65,7 @@ var _ = Describe("FqdnOnlyResolver", func() { sutConfig = config.FqdnOnlyConfig{Enable: true} }) It("Should delegate to next resolver if request query is fqdn", func() { - Expect(sut.Resolve(newRequest("example.com", A))). + Expect(sut.Resolve(ctx, newRequest("example.com", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -69,7 +77,7 @@ var _ = Describe("FqdnOnlyResolver", func() { Expect(m.Calls).Should(HaveLen(1)) }) It("Should return NXDOMAIN if request query is not fqdn", func() { - Expect(sut.Resolve(newRequest("example", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -103,7 +111,7 @@ var _ = Describe("FqdnOnlyResolver", func() { sutConfig = config.FqdnOnlyConfig{Enable: false} }) It("Should delegate to next resolver if request query is fqdn", func() { - Expect(sut.Resolve(newRequest("example.com", A))). + Expect(sut.Resolve(ctx, newRequest("example.com", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -115,7 +123,7 @@ var _ = Describe("FqdnOnlyResolver", func() { Expect(m.Calls).Should(HaveLen(1)) }) It("Should delegate to next resolver if request query is not fqdn", func() { - Expect(sut.Resolve(newRequest("example", AAAA))). + Expect(sut.Resolve(ctx, newRequest("example", AAAA))). Should( SatisfyAll( HaveNoAnswer(), diff --git a/resolver/hosts_file_resolver_test.go b/resolver/hosts_file_resolver_test.go index 79fead569..197c75dc3 100644 --- a/resolver/hosts_file_resolver_test.go +++ b/resolver/hosts_file_resolver_test.go @@ -25,6 +25,9 @@ var _ = Describe("HostsFileResolver", func() { tmpFile *TmpFile err error resp *Response + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -34,6 +37,9 @@ var _ = Describe("HostsFileResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + tmpDir = NewTmpFolder("HostsFileResolver") Expect(tmpDir.Error).Should(Succeed()) DeferCleanup(tmpDir.Clean) @@ -53,8 +59,6 @@ var _ = Describe("HostsFileResolver", func() { }) JustBeforeEach(func() { - ctx, cancelFn := context.WithCancel(context.Background()) - DeferCleanup(cancelFn) sut, err = NewHostsFileResolver(ctx, sutConfig, systemResolverBootstrap) Expect(err).Should(Succeed()) @@ -96,7 +100,7 @@ var _ = Describe("HostsFileResolver", func() { Expect(sut.hosts.isEmpty()).Should(BeTrue()) }) It("should go to next resolver on query", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -112,11 +116,11 @@ var _ = Describe("HostsFileResolver", func() { sutConfig.Sources = make([]config.BytesSource, 0) }) JustBeforeEach(func() { - err = sut.loadSources(context.Background()) + err = sut.loadSources(ctx) Expect(err).Should(Succeed()) }) It("should go to next resolver on query", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -178,7 +182,7 @@ var _ = Describe("HostsFileResolver", func() { When("IPv4 mapping is defined for a host", func() { It("defined ipv4 query should be resolved", func() { - Expect(sut.Resolve(newRequest("ipv4host.", A))). + Expect(sut.Resolve(ctx, newRequest("ipv4host.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -188,7 +192,7 @@ var _ = Describe("HostsFileResolver", func() { )) }) It("defined ipv4 query for alias should be resolved", func() { - Expect(sut.Resolve(newRequest("router2.", A))). + Expect(sut.Resolve(ctx, newRequest("router2.", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -198,7 +202,7 @@ var _ = Describe("HostsFileResolver", func() { )) }) It("ipv4 query should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("does.not.exist.", A))). + Expect(sut.Resolve(ctx, newRequest("does.not.exist.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -210,7 +214,7 @@ var _ = Describe("HostsFileResolver", func() { When("IPv6 mapping is defined for a host", func() { It("defined ipv6 query should be resolved", func() { - Expect(sut.Resolve(newRequest("ipv6host.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("ipv6host.", AAAA))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -220,7 +224,7 @@ var _ = Describe("HostsFileResolver", func() { )) }) It("ipv6 query should return NOERROR and empty result", func() { - Expect(sut.Resolve(newRequest("does.not.exist.", AAAA))). + Expect(sut.Resolve(ctx, newRequest("does.not.exist.", AAAA))). Should( SatisfyAll( HaveNoAnswer(), @@ -232,7 +236,7 @@ var _ = Describe("HostsFileResolver", func() { When("the domain is not known", func() { It("calls the next resolver", func() { - resp, err = sut.Resolve(newRequest("not-in-hostsfile.tld.", A)) + resp, err = sut.Resolve(ctx, newRequest("not-in-hostsfile.tld.", A)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -241,7 +245,7 @@ var _ = Describe("HostsFileResolver", func() { When("the question type is not handled", func() { It("calls the next resolver", func() { - resp, err = sut.Resolve(newRequest("localhost.", MX)) + resp, err = sut.Resolve(ctx, newRequest("localhost.", MX)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -251,7 +255,7 @@ var _ = Describe("HostsFileResolver", func() { When("Reverse DNS request is received", func() { It("should resolve the defined domain name", func() { By("ipv4 with one hostname", func() { - Expect(sut.Resolve(newRequest("2.0.0.10.in-addr.arpa.", PTR))). + Expect(sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.arpa.", PTR))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -261,7 +265,7 @@ var _ = Describe("HostsFileResolver", func() { )) }) By("ipv4 with aliases", func() { - Expect(sut.Resolve(newRequest("1.0.0.10.in-addr.arpa.", PTR))). + Expect(sut.Resolve(ctx, newRequest("1.0.0.10.in-addr.arpa.", PTR))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -274,7 +278,9 @@ var _ = Describe("HostsFileResolver", func() { )) }) By("ipv6", func() { - Expect(sut.Resolve(newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR))). + Expect(sut.Resolve(ctx, + newRequest("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.f.a.a.f.f.a.a.f.f.a.a.f.f.a.a.f.ip6.arpa.", PTR)), + ). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -290,7 +296,7 @@ var _ = Describe("HostsFileResolver", func() { }) It("should ignore invalid PTR", func() { - resp, err = sut.Resolve(newRequest("2.0.0.10.in-addr.fail.arpa.", PTR)) + resp, err = sut.Resolve(ctx, newRequest("2.0.0.10.in-addr.fail.arpa.", PTR)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -298,7 +304,7 @@ var _ = Describe("HostsFileResolver", func() { When("filterLoopback is true", func() { It("calls the next resolver", func() { - resp, err = sut.Resolve(newRequest("1.0.0.127.in-addr.arpa.", PTR)) + resp, err = sut.Resolve(ctx, newRequest("1.0.0.127.in-addr.arpa.", PTR)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -307,7 +313,7 @@ var _ = Describe("HostsFileResolver", func() { When("the IP is not known", func() { It("calls the next resolver", func() { - resp, err = sut.Resolve(newRequest("255.255.255.255.in-addr.arpa.", PTR)) + resp, err = sut.Resolve(ctx, newRequest("255.255.255.255.in-addr.arpa.", PTR)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeHOSTSFILE)) m.AssertExpectations(GinkgoT()) @@ -320,7 +326,7 @@ var _ = Describe("HostsFileResolver", func() { }) It("resolve the defined domain name", func() { - Expect(sut.Resolve(newRequest("1.1.0.127.in-addr.arpa.", PTR))). + Expect(sut.Resolve(ctx, newRequest("1.1.0.127.in-addr.arpa.", PTR))). Should( SatisfyAll( HaveResponseType(ResponseTypeHOSTSFILE), @@ -338,7 +344,7 @@ var _ = Describe("HostsFileResolver", func() { Describe("Delegating to next resolver", func() { When("no hosts file is provided", func() { It("should delegate to next resolver", func() { - _, err = sut.Resolve(newRequest("example.com.", A)) + _, err = sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(Succeed()) // delegate was executed m.AssertExpectations(GinkgoT()) diff --git a/resolver/metrics_resolver_test.go b/resolver/metrics_resolver_test.go index a686e5377..d62cf5973 100644 --- a/resolver/metrics_resolver_test.go +++ b/resolver/metrics_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "errors" "github.com/0xERR0R/blocky/config" @@ -21,6 +22,9 @@ var _ = Describe("MetricResolver", func() { var ( sut *MetricsResolver m *mockResolver + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -30,6 +34,9 @@ var _ = Describe("MetricResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sut = NewMetricsResolver(config.MetricsConfig{Enable: true}) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: new(dns.Msg)}, nil) @@ -56,7 +63,7 @@ var _ = Describe("MetricResolver", func() { Context("Recording request metrics", func() { When("Request will be performed", func() { It("Should record metrics", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "", "client"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -77,7 +84,7 @@ var _ = Describe("MetricResolver", func() { sut.Next(m) }) It("Error should be recorded", func() { - _, err := sut.Resolve(newRequestWithClient("example.com.", A, "", "client")) + _, err := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "", "client")) Expect(err).Should(HaveOccurred()) diff --git a/resolver/mocks_test.go b/resolver/mocks_test.go index 48fa25862..83d4f71fa 100644 --- a/resolver/mocks_test.go +++ b/resolver/mocks_test.go @@ -23,7 +23,7 @@ type mockResolver struct { mock.Mock NextResolver - ResolveFn func(req *model.Request) (*model.Response, error) + ResolveFn func(ctx context.Context, req *model.Request) (*model.Response, error) ResponseFn func(req *dns.Msg) *dns.Msg AnswerFn func(qType dns.Type, qName string) (*dns.Msg, error) } @@ -45,11 +45,11 @@ func (r *mockResolver) LogConfig(*logrus.Entry) { r.Called() } -func (r *mockResolver) Resolve(req *model.Request) (*model.Response, error) { +func (r *mockResolver) Resolve(ctx context.Context, req *model.Request) (*model.Response, error) { args := r.Called(req) if r.ResolveFn != nil { - return r.ResolveFn(req) + return r.ResolveFn(ctx, req) } if r.ResponseFn != nil { diff --git a/resolver/noop_resolver_test.go b/resolver/noop_resolver_test.go index b9f470966..9e0adb1c9 100644 --- a/resolver/noop_resolver_test.go +++ b/resolver/noop_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" . "github.com/onsi/ginkgo/v2" @@ -8,7 +10,12 @@ import ( ) var _ = Describe("NoOpResolver", func() { - var sut *NoOpResolver + var ( + sut *NoOpResolver + + ctx context.Context + cancelFn context.CancelFunc + ) Describe("Type", func() { It("follows conventions", func() { @@ -17,12 +24,15 @@ var _ = Describe("NoOpResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sut = NewNoOpResolver() }) Describe("Resolving", func() { It("returns no response", func() { - resp, err := sut.Resolve(newRequest("test.tld", A)) + resp, err := sut.Resolve(ctx, newRequest("test.tld", A)) Expect(err).Should(Succeed()) Expect(resp).Should(Equal(NoResponse)) }) diff --git a/resolver/parallel_best_resolver_test.go b/resolver/parallel_best_resolver_test.go index cf7088e16..026534221 100644 --- a/resolver/parallel_best_resolver_test.go +++ b/resolver/parallel_best_resolver_test.go @@ -64,7 +64,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { Upstreams: upstreams, } - sut, err = NewParallelBestResolver(sutConfig, bootstrap, sutVerify) + sut, err = NewParallelBestResolver(ctx, sutConfig, bootstrap, sutVerify) }) Describe("IsEnabled", func() { @@ -104,7 +104,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream.Start(), } - _, err := NewParallelBestResolver(config.UpstreamGroup{ + _, err := NewParallelBestResolver(ctx, config.UpstreamGroup{ Name: upstreamDefaultCfgName, Upstreams: upstreams, }, @@ -192,7 +192,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("Should use result from fastest one", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -214,7 +214,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("Should use result from successful resolver", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.123"), @@ -235,7 +235,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("Should return error", func() { Expect(err).Should(Succeed()) request := newRequest("example.com.", A) - _, err = sut.Resolve(request) + _, err = sut.Resolve(ctx, request) Expect(err).Should(HaveOccurred()) }) @@ -251,7 +251,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -275,7 +275,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - sut, _ = NewParallelBestResolver(config.UpstreamGroup{ + sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{ Name: upstreamDefaultCfgName, Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, }, @@ -301,7 +301,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { By("perform 100 request, error upstream's weight will be reduced", func() { for i := 0; i < 100; i++ { request := newRequest("example.com.", A) - _, _ = sut.Resolve(request) + _, _ = sut.Resolve(ctx, request) } }) @@ -335,7 +335,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("errors during construction", func() { b := newTestBootstrap(ctx, &dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure}}) - r, err := NewParallelBestResolver(config.UpstreamGroup{ + r, err := NewParallelBestResolver(ctx, config.UpstreamGroup{ Name: "test", Upstreams: []config.Upstream{{Host: "example.com"}}, }, b, verifyUpstreams) @@ -372,7 +372,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("Should return result from either one", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should(SatisfyAll( HaveTTL(BeNumerically("==", 123)), HaveResponseType(ResponseTypeRESOLVED), @@ -398,7 +398,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("should ask another random upstream and return its response", func() { request := newRequest("example.com", A) - Expect(sut.Resolve(request)).Should( + Expect(sut.Resolve(ctx, request)).Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.2"), HaveTTL(BeNumerically("==", 123)), @@ -439,7 +439,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { }) It("Should return error", func() { request := newRequest("example.com.", A) - _, err := sut.Resolve(request) + _, err := sut.Resolve(ctx, request) Expect(err).Should(HaveOccurred()) }) }) @@ -454,7 +454,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -478,7 +478,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { mockUpstream2 := NewMockUDPUpstreamServer().WithAnswerRR("example.com 123 IN A 123.124.122.122") DeferCleanup(mockUpstream2.Close) - sut, _ = NewParallelBestResolver(config.UpstreamGroup{ + sut, _ = NewParallelBestResolver(ctx, config.UpstreamGroup{ Name: upstreamDefaultCfgName, Upstreams: []config.Upstream{withError1, mockUpstream1.Start(), mockUpstream2.Start(), withError2}, }, @@ -499,7 +499,7 @@ var _ = Describe("ParallelBestResolver", Label("parallelBestResolver"), func() { By("perform 100 request, error upstream's weight will be reduced", func() { for i := 0; i < 100; i++ { request := newRequest("example.com.", A) - _, _ = sut.Resolve(request) + _, _ = sut.Resolve(ctx, request) } }) diff --git a/resolver/query_logging_resolver_test.go b/resolver/query_logging_resolver_test.go index efc33bf51..40e326e54 100644 --- a/resolver/query_logging_resolver_test.go +++ b/resolver/query_logging_resolver_test.go @@ -44,6 +44,9 @@ var _ = Describe("QueryLoggingResolver", func() { m *mockResolver tmpDir *TmpFolder mockAnswer *dns.Msg + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -53,6 +56,9 @@ var _ = Describe("QueryLoggingResolver", func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + mockAnswer = new(dns.Msg) tmpDir = NewTmpFolder("queryLoggingResolver") Expect(tmpDir.Error).Should(Succeed()) @@ -64,9 +70,6 @@ var _ = Describe("QueryLoggingResolver", func() { sutConfig.SetDefaults() // not called when using a struct literal } - ctx, cancelFn := context.WithCancel(context.Background()) - DeferCleanup(cancelFn) - sut = NewQueryLoggingResolver(ctx, sutConfig) m = &mockResolver{} m.On("Resolve", mock.Anything).Return(&Response{Res: mockAnswer, Reason: "reason"}, nil) @@ -98,7 +101,7 @@ var _ = Describe("QueryLoggingResolver", func() { } }) It("should process request without query logging", func() { - Expect(sut.Resolve(newRequest("example.com", A))). + Expect(sut.Resolve(ctx, newRequest("example.com", A))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -120,7 +123,7 @@ var _ = Describe("QueryLoggingResolver", func() { }) It("should create a log file per client", func() { By("request from client 1", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -128,7 +131,7 @@ var _ = Describe("QueryLoggingResolver", func() { )) }) By("request from client 2, has name with special chars, should be escaped", func() { - Expect(sut.Resolve(newRequestWithClient( + Expect(sut.Resolve(ctx, newRequestWithClient( "example.com.", A, "192.168.178.26", "cl/ient2\\$%&test"))). Should( SatisfyAll( @@ -188,7 +191,7 @@ var _ = Describe("QueryLoggingResolver", func() { }) It("should create one log file for all clients", func() { By("request from client 1", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -196,7 +199,7 @@ var _ = Describe("QueryLoggingResolver", func() { )) }) By("request from client 2, has name with special chars, should be escaped", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.26", "client2"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -249,7 +252,7 @@ var _ = Describe("QueryLoggingResolver", func() { }) It("should create one log file", func() { By("request from client 1", func() { - Expect(sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). + Expect(sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1"))). Should( SatisfyAll( HaveResponseType(ResponseTypeRESOLVED), @@ -297,7 +300,7 @@ var _ = Describe("QueryLoggingResolver", func() { sut.writer = mockWriter Eventually(func() int { - _, ierr := sut.Resolve(newRequestWithClient("example.com.", A, "192.168.178.25", "client1")) + _, ierr := sut.Resolve(ctx, newRequestWithClient("example.com.", A, "192.168.178.25", "client1")) Expect(ierr).Should(Succeed()) return len(sut.logChan) diff --git a/resolver/rewriter_resolver_test.go b/resolver/rewriter_resolver_test.go index a74d0f317..6027c7458 100644 --- a/resolver/rewriter_resolver_test.go +++ b/resolver/rewriter_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" "github.com/0xERR0R/blocky/model" @@ -94,7 +96,7 @@ var _ = Describe("RewriterResolver", func() { return res } - resp, err := sut.Resolve(request) + resp, err := sut.Resolve(context.Background(), request) Expect(err).Should(Succeed()) if resp != mNextResponse { Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal)) @@ -132,18 +134,18 @@ var _ = Describe("RewriterResolver", func() { expectNilAnswer = true // Make inner call the NoOpResolver - mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { + mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req).Should(Equal(request)) // Inner should see fqdnRewritten Expect(req.Req.Question[0].Name).Should(Equal(fqdnRewritten)) - return mInner.next.Resolve(req) + return mInner.next.Resolve(ctx, req) } // Resolver after RewriterResolver should see `fqdnOriginal` mNext.On("Resolve", mock.Anything) - mNext.ResolveFn = func(req *model.Request) (*model.Response, error) { + mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal)) return mNextResponse, nil @@ -156,7 +158,7 @@ var _ = Describe("RewriterResolver", func() { expectNilAnswer = true // Make inner return a nil Answer but not an empty Response - mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { + mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req).Should(Equal(request)) // Inner should see fqdnRewritten @@ -179,7 +181,7 @@ var _ = Describe("RewriterResolver", func() { fqdnRewritten = sampleRewritten // Make inner return a nil Answer but not an empty Response - mInner.ResolveFn = func(req *model.Request) (*model.Response, error) { + mInner.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req).Should(Equal(request)) // Inner should see fqdnRewritten @@ -190,7 +192,7 @@ var _ = Describe("RewriterResolver", func() { // Resolver after RewriterResolver should see `fqdnOriginal` mNext.On("Resolve", mock.Anything) - mNext.ResolveFn = func(req *model.Request) (*model.Response, error) { + mNext.ResolveFn = func(ctx context.Context, req *model.Request) (*model.Response, error) { Expect(req.Req.Question[0].Name).Should(Equal(fqdnOriginal)) return mNextResponse, nil diff --git a/resolver/strict_resolver_test.go b/resolver/strict_resolver_test.go index 6f5e3dd74..0a4468442 100644 --- a/resolver/strict_resolver_test.go +++ b/resolver/strict_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "time" "github.com/0xERR0R/blocky/config" @@ -27,6 +28,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { err error bootstrap *Bootstrap + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -36,6 +40,9 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + upstreams = []config.Upstream{ {Host: "wrong"}, {Host: "127.0.0.2"}, @@ -51,7 +58,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { Name: upstreamDefaultCfgName, Upstreams: upstreams, } - sut, err = NewStrictResolver(sutConfig, bootstrap, sutVerify) + sut, err = NewStrictResolver(ctx, sutConfig, bootstrap, sutVerify) }) config.GetConfig().Upstreams.Timeout = config.Duration(time.Second) @@ -100,7 +107,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { mockUpstream.Start(), } - _, err := NewStrictResolver(config.UpstreamGroup{ + _, err := NewStrictResolver(ctx, config.UpstreamGroup{ Name: upstreamDefaultCfgName, Upstreams: upstreams, }, @@ -151,7 +158,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("Should use result from first one", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -180,7 +187,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("should return response from next upstream", func() { request := newRequest("example.com", A) - Expect(sut.Resolve(request)).Should( + Expect(sut.Resolve(ctx, request)).Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.2"), HaveTTL(BeNumerically("==", 123)), @@ -214,7 +221,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("should return error", func() { request := newRequest("example.com", A) - _, err := sut.Resolve(request) + _, err := sut.Resolve(ctx, request) Expect(err).To(HaveOccurred()) }) }) @@ -230,7 +237,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("Should use result from second one", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.123"), @@ -247,7 +254,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { }) It("Should return error", func() { request := newRequest("example.com.", A) - _, err = sut.Resolve(request) + _, err = sut.Resolve(ctx, request) Expect(err).Should(HaveOccurred()) }) }) @@ -262,7 +269,7 @@ var _ = Describe("StrictResolver", Label("strictResolver"), func() { It("Should use result from defined resolver", func() { request := newRequest("example.com.", A) - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), diff --git a/resolver/sudn_resolver_test.go b/resolver/sudn_resolver_test.go index 1c004b203..6dc3de282 100644 --- a/resolver/sudn_resolver_test.go +++ b/resolver/sudn_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "fmt" "github.com/0xERR0R/blocky/config" @@ -19,6 +20,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { sut *SpecialUseDomainNamesResolver sutConfig config.SUDNConfig m *mockResolver + + ctx context.Context + cancelFn context.CancelFunc ) Describe("Type", func() { @@ -30,6 +34,9 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { BeforeEach(func() { var err error + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sutConfig, err = config.WithDefaults[config.SUDNConfig]() Expect(err).Should(Succeed()) }) @@ -48,7 +55,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { Describe("handlers", func() { It("should have correct response type", func() { for domain, handler := range sudnHandlers { - resp, err := sut.Resolve(newRequest(domain, A)) + resp, err := sut.Resolve(ctx, newRequest(domain, A)) Expect(err).Should(Succeed()) if handler == nil { @@ -90,7 +97,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { DescribeTable("handled domains", func(qType dns.Type, qName string, expectedRCode int, extraMatchers ...types.GomegaMatcher) { - resp, err := sut.Resolve(newRequest(qName, qType)) + resp, err := sut.Resolve(ctx, newRequest(qName, qType)) Expect(err).Should(Succeed()) Expect(resp).Should(SatisfyAll( HaveResponseType(ResponseTypeSPECIAL), @@ -133,7 +140,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { DescribeTable("", func(qType dns.Type, qName string, expectedRCode int) { - resp, err := sut.Resolve(newRequest(qName, qType)) + resp, err := sut.Resolve(ctx, newRequest(qName, qType)) Expect(err).Should(Succeed()) Expect(resp).Should(HaveReturnCode(expectedRCode)) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL)) @@ -150,7 +157,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { }) It("should forward example.com", func() { - Expect(sut.Resolve(newRequest("example.com", A))). + Expect(sut.Resolve(ctx, newRequest("example.com", A))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.145.123.145"), @@ -161,7 +168,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { }) It("should forward home.arpa. IN DS", func() { - Expect(sut.Resolve(newRequest("something.home.arpa.", DS))). + Expect(sut.Resolve(ctx, newRequest("something.home.arpa.", DS))). Should( SatisfyAll( // setup code doesn't care about the question @@ -173,7 +180,7 @@ var _ = Describe("SudnResolver", Label("sudnResolver"), func() { }) It("should forward non special use domains", func() { - resp, err := sut.Resolve(newRequest("something.not-special.", AAAA)) + resp, err := sut.Resolve(ctx, newRequest("something.not-special.", AAAA)) Expect(err).Should(Succeed()) Expect(resp).ShouldNot(HaveResponseType(ResponseTypeSPECIAL)) }) diff --git a/resolver/upstream_resolver_test.go b/resolver/upstream_resolver_test.go index 8f5a26065..3c9f94e7d 100644 --- a/resolver/upstream_resolver_test.go +++ b/resolver/upstream_resolver_test.go @@ -1,6 +1,7 @@ package resolver import ( + "context" "crypto/tls" "fmt" "net/http" @@ -21,9 +22,15 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { var ( sut *UpstreamResolver sutConfig config.Upstream + + ctx context.Context + cancelFn context.CancelFunc ) BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sutConfig = config.Upstream{Host: "localhost"} }) @@ -62,7 +69,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { upstream := mockUpstream.Start() sut := newUpstreamResolverUnchecked(upstream, nil) - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -81,7 +88,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { upstream := mockUpstream.Start() sut := newUpstreamResolverUnchecked(upstream, nil) - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( HaveNoAnswer(), @@ -100,7 +107,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { upstream := mockUpstream.Start() sut := newUpstreamResolverUnchecked(upstream, nil) - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) }) }) @@ -133,7 +140,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { atomic.StoreInt32(&counter, 0) atomic.StoreInt32(&attemptsWithTimeout, 2) - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -146,7 +153,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { By("3 attempts with timeout -> should return error", func() { atomic.StoreInt32(&counter, 0) atomic.StoreInt32(&attemptsWithTimeout, 3) - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("i/o timeout")) }) @@ -185,7 +192,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { }) When("Configured DOH resolver can resolve query", func() { It("should return answer from DNS upstream", func() { - Expect(sut.Resolve(newRequest("example.com.", A))). + Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( SatisfyAll( BeDNSRecord("example.com.", A, "123.124.122.122"), @@ -203,7 +210,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { } }) It("should return error", func() { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500")) }) @@ -215,7 +222,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { } }) It("should return error", func() { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should( ContainSubstring("http return content type should be 'application/dns-message', but was 'text'")) @@ -228,7 +235,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { } }) It("should return error", func() { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(ContainSubstring("can't unpack message")) }) @@ -241,7 +248,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { }, systemResolverBootstrap) }) It("should return error", func() { - _, err := sut.Resolve(newRequest("example.com.", A)) + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) Expect(err).Should(HaveOccurred()) Expect(err.Error()).Should(Or( ContainSubstring("no such host"), diff --git a/resolver/upstream_tree_resolver_test.go b/resolver/upstream_tree_resolver_test.go index 3855761ab..11b709042 100644 --- a/resolver/upstream_tree_resolver_test.go +++ b/resolver/upstream_tree_resolver_test.go @@ -1,6 +1,8 @@ package resolver import ( + "context" + "github.com/0xERR0R/blocky/config" . "github.com/0xERR0R/blocky/helpertest" "github.com/0xERR0R/blocky/log" @@ -139,7 +141,15 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { }) When("client specific resolvers are defined", func() { + var ( + ctx context.Context + cancelFn context.CancelFunc + ) + BeforeEach(func() { + ctx, cancelFn = context.WithCancel(context.Background()) + DeferCleanup(cancelFn) + sutConfig = config.UpstreamsConfig{Groups: config.UpstreamGroups{ upstreamDefaultCfgName: {config.Upstream{}}, "laptop": {config.Upstream{}}, @@ -191,7 +201,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use default if client name or IP don't match", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "test") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "default"), @@ -202,7 +212,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client name matches exact", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "laptop") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "laptop"), @@ -213,7 +223,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client name matches with wildcard", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "client-test-m") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "client-*-m"), @@ -224,7 +234,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client name matches with range wildcard", func() { request := newRequestWithClient("example.com.", A, "192.168.178.55", "client7") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "client[0-9]"), @@ -235,7 +245,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client IP matches", func() { request := newRequestWithClient("example.com.", A, "192.168.178.33", "noname") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "192.168.178.33"), @@ -246,7 +256,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client name (containing IP) matches", func() { request := newRequestWithClient("example.com.", A, "0.0.0.0", "192.168.178.33") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "192.168.178.33"), @@ -257,7 +267,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client specific resolver if client's CIDR (10.43.8.64 - 10.43.8.79) matches", func() { request := newRequestWithClient("example.com.", A, "10.43.8.70", "noname") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "10.43.8.67/28"), @@ -268,7 +278,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use exact IP match before client name match", func() { request := newRequestWithClient("example.com.", A, "192.168.178.33", "laptop") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "192.168.178.33"), @@ -279,7 +289,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { It("Should use client name match before CIDR match", func() { request := newRequestWithClient("example.com.", A, "10.43.8.70", "laptop") - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( BeDNSRecord("example.com.", A, "laptop"), @@ -293,7 +303,7 @@ var _ = Describe("UpstreamTreeResolver", Label("upstreamTreeResolver"), func() { request := newRequestWithClient("example.com.", A, "0.0.0.0", "name-matches1") request.Log = logger - Expect(sut.Resolve(request)). + Expect(sut.Resolve(ctx, request)). Should( SatisfyAll( SatisfyAny( diff --git a/server/server_test.go b/server/server_test.go index 5c902a728..21f45474e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -697,7 +697,7 @@ var _ = Describe("Running DNS server", func() { Describe("NewServer with strict upstream strategy", func() { It("successfully returns upstream branches", func() { - branches, err := createUpstreamBranches(&config.Config{ + branches, err := createUpstreamBranches(context.Background(), &config.Config{ Upstreams: config.UpstreamsConfig{ Strategy: config.UpstreamStrategyStrict, Groups: config.UpstreamGroups{ @@ -715,7 +715,7 @@ var _ = Describe("Running DNS server", func() { Describe("NewServer with random upstream strategy", func() { It("successfully returns upstream branches", func() { - branches, err := createUpstreamBranches(&config.Config{ + branches, err := createUpstreamBranches(context.Background(), &config.Config{ Upstreams: config.UpstreamsConfig{ Strategy: config.UpstreamStrategyRandom, Groups: config.UpstreamGroups{