Skip to content

Commit

Permalink
feat: stop timer when processBatch triggers
Browse files Browse the repository at this point in the history
  • Loading branch information
sysulq committed Aug 1, 2024
1 parent 4aec5e8 commit c3fc508
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 49 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
cache: true

- name: Test
run: go test -v ./... -coverprofile=coverage.out
run: go test -race -v ./... -coverprofile=coverage.out

- name: Upload coverage reports to Codecov
uses: codecov/[email protected]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Feature
- 200+ lines of code, easy to understand and maintain.
- 100% test coverage, bug free and reliable.
- Based on generics and can be used with any type of data.
- Use a LRU cache to store the loaded values.
- Use hashicorp/golang-lru to cache the loaded values.
- Can be used to batch and cache multiple requests.
- Deduplicate identical requests, reducing the number of requests.

Expand Down
66 changes: 29 additions & 37 deletions dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ type config struct {

// dataLoader is the main struct for the dataloader
type dataLoader[K comparable, V any] struct {
loader Loader[K, V]
cache *expirable.LRU[K, V]
config config
mu sync.Mutex
batch []K
chs map[K][]chan Result[V]
loader Loader[K, V]
cache *expirable.LRU[K, V]
config config
mu sync.Mutex
batch []K
chs map[K][]chan Result[V]
stopSchedule chan struct{}
}

// Interface is a `DataLoader` Interface which defines a public API for loading data from a particular
Expand All @@ -59,22 +60,6 @@ type Interface[K comparable, V any] interface {
Prime(ctx context.Context, key K, value V) Interface[K, V]
}

// Result is the result of a DataLoader operation
type Result[V any] struct {
data V
err error
}

// Wrap wraps data and an error into a Result
func Wrap[V any](data V, err error) Result[V] {
return Result[V]{data: data, err: err}
}

// TryUnwrap returns the data or an error
func (r Result[V]) Unwrap() (V, error) {
return r.data, r.err
}

// New creates a new DataLoader with the given loader function and options
func New[K comparable, V any](loader Loader[K, V], options ...Option) Interface[K, V] {
config := config{
Expand All @@ -89,10 +74,10 @@ func New[K comparable, V any](loader Loader[K, V], options ...Option) Interface[
}

dl := &dataLoader[K, V]{
loader: loader,
config: config,
loader: loader,
config: config,
stopSchedule: make(chan struct{}),
}

dl.reset()

// Create a cache if the cache size is greater than 0
Expand Down Expand Up @@ -145,7 +130,8 @@ func (d *dataLoader[K, V]) goLoad(ctx context.Context, key K) <-chan Result[V] {
d.mu.Lock()
if len(d.batch) == 0 {
// If there are no keys in the current batch, schedule a new batch timer
go d.scheduleBatch(ctx, ch)
d.stopSchedule = make(chan struct{})
go d.scheduleBatch(ctx, d.stopSchedule)
}

// Check if the key is in flight
Expand All @@ -155,18 +141,19 @@ func (d *dataLoader[K, V]) goLoad(ctx context.Context, key K) <-chan Result[V] {
return ch
}

// Add the key and channel to the current batch
d.batch = append(d.batch, key)
d.chs[key] = []chan Result[V]{ch}

// If the current batch is full, start processing it
if len(d.batch) >= d.config.BatchSize {
// spawn a new goroutine to process the batch
go d.processBatch(ctx, d.batch, d.chs)
close(d.stopSchedule)
// Create a new batch, and a new set of channels
d.reset()
}

// Add the key and channel to the current batch
d.batch = append(d.batch, key)
d.chs[key] = []chan Result[V]{ch}

// Unlock the DataLoader
d.mu.Unlock()

Expand Down Expand Up @@ -215,7 +202,7 @@ func (d *dataLoader[K, V]) reset() {
}

// scheduleBatch schedules a batch to be processed
func (d *dataLoader[K, V]) scheduleBatch(ctx context.Context, ch chan Result[V]) {
func (d *dataLoader[K, V]) scheduleBatch(ctx context.Context, stopSchedule <-chan struct{}) {
select {
case <-time.After(d.config.Wait):
d.mu.Lock()
Expand All @@ -224,8 +211,8 @@ func (d *dataLoader[K, V]) scheduleBatch(ctx context.Context, ch chan Result[V])
d.reset()
}
d.mu.Unlock()
case <-ctx.Done():
ch <- Result[V]{err: ctx.Err()}
case <-stopSchedule:
return
}
}

Expand Down Expand Up @@ -254,10 +241,15 @@ func (d *dataLoader[K, V]) processBatch(ctx context.Context, keys []K, chs map[K
d.cache.Add(key, results[i].data)
}

for _, ch := range chs[key] {
ch <- results[i]
close(ch)
}
sendResult(chs[key], results[i])
}
}

// sendResult delivers result to every waiting channel in chs, closing
// each channel afterwards so receivers observe exactly one value.
func sendResult[V any](chs []chan Result[V], result Result[V]) {
	for i := range chs {
		chs[i] <- result
		close(chs[i])
	}
}

Expand Down
70 changes: 60 additions & 10 deletions dataloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ func TestDataLoader(t *testing.T) {
t.Run("Panic recovered", testPanicRecovered)
t.Run("Prime", testPrime)
t.Run("Inflight", testInflight)
t.Run("Schedule batch", testScheduleBatch)
}

func testInflight(t *testing.T) {
func testScheduleBatch(t *testing.T) {
loader := New(func(ctx context.Context, keys []int) []Result[string] {
if len(keys) != 5 {
t.Errorf("Expected 5 keys, got %d", keys)
Expand All @@ -35,20 +36,58 @@ func testInflight(t *testing.T) {
results[i] = Result[string]{data: fmt.Sprintf("Result for %d", key)}
}
return results
}, WithBatchSize(5))
}, WithBatchSize(5), WithWait(100*time.Millisecond))

chs := make([]<-chan Result[string], 0)
for i := 0; i < 10; i++ {
chs = append(chs, loader.(*dataLoader[int, string]).goLoad(context.Background(), i/2))
for i := 0; i < 4; i++ {
chs = append(chs, loader.(*dataLoader[int, string]).goLoad(context.Background(), i))
}
time.Sleep(60 * time.Millisecond)

for i := 4; i < 10; i++ {
chs = append(chs, loader.(*dataLoader[int, string]).goLoad(context.Background(), i))
}

time.Sleep(60 * time.Millisecond)

for idx, ch := range chs {
result := <-ch
data, err := result.Unwrap()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if data != fmt.Sprintf("Result for %d", idx) {
t.Errorf("Unexpected result: %v", data)
}
}
}

func testInflight(t *testing.T) {
loader := New(func(ctx context.Context, keys []int) []Result[string] {
if len(keys) != 5 {
t.Errorf("Expected 5 keys, got %d", keys)
}

results := make([]Result[string], len(keys))
for i, key := range keys {
results[i] = Result[string]{data: fmt.Sprintf("Result for %d", key)}
}
return results
}, WithBatchSize(5))

ii := make([]int, 0)
for i := 0; i < 9; i++ {
ii = append(ii, i/2)
}

chs := loader.LoadMany(context.TODO(), ii)
for idx, result := range chs {
data, err := result.Unwrap()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

if data != fmt.Sprintf("Result for %d", idx/2) {
t.Errorf("Unexpected result: %v", data)
}
Expand Down Expand Up @@ -222,15 +261,26 @@ func testOptions(t *testing.T) {

func testContextCancellation(t *testing.T) {
loader := New(func(ctx context.Context, keys []int) []Result[string] {
<-ctx.Done()
return nil
})
results := make([]Result[string], len(keys))
for i, key := range keys {
results[i] = Result[string]{data: fmt.Sprintf("Result for %d", key)}
}

return results
}, WithBatchSize(2))

ctx, cancel := context.WithCancel(context.Background())
cancel()
result := loader.Load(ctx, 1)
if result.err == nil {
t.Error("Expected error when context is cancelled")

results := loader.LoadMany(ctx, []int{0, 1})
for idx, result := range results {
if result.err != nil {
t.Errorf("Unexpected error: %v", result.err)
}

if result.data != fmt.Sprintf("Result for %d", idx) {
t.Errorf("Unexpected result: %v", result.data)
}
}
}

Expand Down
17 changes: 17 additions & 0 deletions result.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dataloader

// Result is the result of a DataLoader operation. It carries either a
// successfully loaded value or the error produced while loading it.
type Result[V any] struct {
	data V     // the loaded value; meaningful only when err is nil
	err  error // non-nil when loading failed
}

// Wrap wraps data and an error into a Result.
func Wrap[V any](data V, err error) Result[V] {
	return Result[V]{data: data, err: err}
}

// Unwrap returns the data or an error.
func (r Result[V]) Unwrap() (V, error) {
	return r.data, r.err
}

0 comments on commit c3fc508

Please sign in to comment.