Skip to content

Commit

Permalink
squash: fix(tests): update to pass contexts where needed
Browse files Browse the repository at this point in the history
This is split to try and make review easier by having the previous
commit with all interesting changes, and then boring commits with just
the API churn.
  • Loading branch information
ThinkChaos committed Nov 19, 2023
1 parent 48b213d commit c9db986
Show file tree
Hide file tree
Showing 24 changed files with 392 additions and 269 deletions.
39 changes: 22 additions & 17 deletions api/api_interface_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"time"

// . "github.com/0xERR0R/blocky/helpertest"
"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/miekg/dns"
Expand Down Expand Up @@ -54,14 +53,14 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus {
return args.Get(0).(BlockingStatus)
}

func (m *QuerierMock) Query(question string, qType dns.Type) (*model.Response, error) {
args := m.Called(question, qType)
func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
args := m.Called(ctx, question, qType)

return args.Get(0).(*model.Response), args.Error(1)
}

func (m *CacheControlMock) FlushCaches() {
_ = m.Called()
func (m *CacheControlMock) FlushCaches(ctx context.Context) {
_ = m.Called(ctx)
}

var _ = Describe("API implementation tests", func() {
Expand All @@ -71,9 +70,15 @@ var _ = Describe("API implementation tests", func() {
listRefreshMock *ListRefreshMock
cacheControlMock *CacheControlMock
sut *OpenAPIInterfaceImpl

ctx context.Context
cancelFn context.CancelFunc
)

BeforeEach(func() {
ctx, cancelFn = context.WithCancel(context.Background())
DeferCleanup(cancelFn)

blockingControlMock = &BlockingControlMock{}
querierMock = &QuerierMock{}
listRefreshMock = &ListRefreshMock{}
Expand All @@ -95,12 +100,12 @@ var _ = Describe("API implementation tests", func() {
)
Expect(err).Should(Succeed())

querierMock.On("Query", "google.com.", A).Return(&model.Response{
querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{
Res: queryResponse,
Reason: "reason",
}, nil)

resp, err := sut.Query(context.Background(), QueryRequestObject{
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "google.com", Type: "A",
},
Expand All @@ -116,7 +121,7 @@ var _ = Describe("API implementation tests", func() {
})

It("should return 400 on wrong parameter", func() {
resp, err := sut.Query(context.Background(), QueryRequestObject{
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "google.com",
Type: "WRONGTYPE",
Expand All @@ -135,7 +140,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() {
listRefreshMock.On("RefreshLists").Return(nil)

resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed())
var resp200 ListRefresh200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -144,7 +149,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 500 on failure", func() {
listRefreshMock.On("RefreshLists").Return(errors.New("failed"))

resp, err := sut.ListRefresh(context.Background(), ListRefreshRequestObject{})
resp, err := sut.ListRefresh(ctx, ListRefreshRequestObject{})
Expect(err).Should(Succeed())
var resp500 ListRefresh500TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp500))
Expand All @@ -160,7 +165,7 @@ var _ = Describe("API implementation tests", func() {
duration := "3s"
grroups := "gr1,gr2"

resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{
Duration: &duration,
Groups: &grroups,
Expand All @@ -173,7 +178,7 @@ var _ = Describe("API implementation tests", func() {

It("should return 400 on failure", func() {
blockingControlMock.On("DisableBlocking", mock.Anything, mock.Anything).Return(errors.New("failed"))
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{})
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{})
Expect(err).Should(Succeed())
var resp400 DisableBlocking400TextResponse
Expect(resp).Should(BeAssignableToTypeOf(resp400))
Expand All @@ -182,7 +187,7 @@ var _ = Describe("API implementation tests", func() {

It("should return 400 on wrong duration parameter", func() {
wrongDuration := "4sds"
resp, err := sut.DisableBlocking(context.Background(), DisableBlockingRequestObject{
resp, err := sut.DisableBlocking(ctx, DisableBlockingRequestObject{
Params: DisableBlockingParams{
Duration: &wrongDuration,
},
Expand All @@ -197,7 +202,7 @@ var _ = Describe("API implementation tests", func() {
It("should return 200 on success", func() {
blockingControlMock.On("EnableBlocking").Return()

resp, err := sut.EnableBlocking(context.Background(), EnableBlockingRequestObject{})
resp, err := sut.EnableBlocking(ctx, EnableBlockingRequestObject{})
Expect(err).Should(Succeed())
var resp200 EnableBlocking200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -212,7 +217,7 @@ var _ = Describe("API implementation tests", func() {
AutoEnableInSec: 47,
})

resp, err := sut.BlockingStatus(context.Background(), BlockingStatusRequestObject{})
resp, err := sut.BlockingStatus(ctx, BlockingStatusRequestObject{})
Expect(err).Should(Succeed())
var resp200 BlockingStatus200JSONResponse
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand All @@ -227,8 +232,8 @@ var _ = Describe("API implementation tests", func() {
Describe("Cache API", func() {
When("Cache flush is called", func() {
It("should return 200 on success", func() {
cacheControlMock.On("FlushCaches").Return()
resp, err := sut.CacheFlush(context.Background(), CacheFlushRequestObject{})
cacheControlMock.On("FlushCaches", ctx).Return()
resp, err := sut.CacheFlush(ctx, CacheFlushRequestObject{})
Expect(err).Should(Succeed())
var resp200 CacheFlush200Response
Expect(resp).Should(BeAssignableToTypeOf(resp200))
Expand Down
6 changes: 3 additions & 3 deletions cache/expirationcache/expiration_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ var _ = Describe("Expiration cache", func() {
Describe("preExpiration function", func() {
When("function is defined", func() {
It("should update the value and TTL if function returns values", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "v2"

return &v2, time.Second
Expand All @@ -169,7 +169,7 @@ var _ = Describe("Expiration cache", func() {
})

It("should update the value and TTL if function returns values on cleanup if element is expired", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
v2 := "val2"

return &v2, time.Second
Expand All @@ -192,7 +192,7 @@ var _ = Describe("Expiration cache", func() {
})

It("should delete the key if function returns nil", func() {
fn := func(key string) (val *string, ttl time.Duration) {
fn := func(ctx context.Context, key string) (val *string, ttl time.Duration) {
return nil, 0
}
cache := NewCacheWithOnExpired[string](ctx, Options{CleanupInterval: 100 * time.Microsecond}, fn)
Expand Down
8 changes: 4 additions & 4 deletions cache/expirationcache/prefetching_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down Expand Up @@ -86,7 +86,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand All @@ -113,7 +113,7 @@ var _ = Describe("Prefetching expiration cache", func() {
Options: Options{
CleanupInterval: 100 * time.Millisecond,
},
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down Expand Up @@ -143,7 +143,7 @@ var _ = Describe("Prefetching expiration cache", func() {
},
PrefetchThreshold: 2,
PrefetchExpires: 100 * time.Millisecond,
ReloadFn: func(cacheKey string) (*string, time.Duration) {
ReloadFn: func(ctx context.Context, cacheKey string) (*string, time.Duration) {
v := "v2"

return &v, 50 * time.Millisecond
Expand Down
Loading

0 comments on commit c9db986

Please sign in to comment.