Goroutines are one of the most powerful features in Go. They let us run tasks concurrently and are lightweight compared to OS threads. However, testing code that spawns goroutines can be challenging because the scheduler does not guarantee when, or in what order, they run. The test may finish before a goroutine completes, which leads to missing results or tests that fail intermittently. In this article, we’ll explore how to test goroutines in Go using only the standard library.
Let’s consider this simple example:
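```go
// task_runner.go
package how_to_test_goroutines

import (
	"context"
	"log/slog"
)

// Task is a unit of work executed by the TaskRunner.
type Task func(ctx context.Context, args []string)

type TaskRunner struct {
	log   *slog.Logger
	tasks []Task
}

func NewTaskRunner(l *slog.Logger, tasks ...Task) *TaskRunner {
	return &TaskRunner{log: l, tasks: tasks}
}

// Run starts every task in its own goroutine and returns immediately.
// context.WithoutCancel (Go 1.21+) keeps the tasks running even if the
// caller's context is canceled.
func (r *TaskRunner) Run(ctx context.Context, args []string) {
	r.log.InfoContext(ctx, "Run tasks", "args", args)
	ctx = context.WithoutCancel(ctx)
	for _, task := range r.tasks {
		go task(ctx, args)
	}
}
```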
The Run method starts each task in its own goroutine and then forgets about it (fire and forget). This pattern is common in real-world applications. For example, when sending emails to multiple recipients, goroutines let us send the emails concurrently, which makes the process faster.
The code looks correct, but how can we test it? Let’s first write a test without any synchronization and observe the problem.
```go
// task_runner_test.go
package how_to_test_goroutines

import (
	"bytes"
	"context"
	"log/slog"
	"math/rand"
	"testing"
	"time"
)

// NewTask returns a Task that simulates work with a random delay,
// or stops early if the context is canceled.
func NewTask(l *slog.Logger, name string) Task {
	return func(ctx context.Context, args []string) {
		l.InfoContext(ctx, "Task started", "name", name)
		delay := time.Duration(rand.Intn(10)*100) * time.Millisecond
		select {
		case <-ctx.Done():
			l.InfoContext(ctx, "Task canceled", "name", name)
		case <-time.After(delay): // simulate some work.
			l.InfoContext(ctx, "Task finished", "name", name, "args", args)
		}
	}
}

func TestTaskRunner_Run(t *testing.T) {
	var logHistory bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&logHistory, &slog.HandlerOptions{}))
	defer func() { t.Log(logHistory.String()) }()

	task1 := NewTask(logger, "task1")
	task2 := NewTask(logger, "task2")
	task3 := NewTask(logger, "task3")
	runner := NewTaskRunner(logger, task1, task2, task3)

	ctx := context.Background()
	args := []string{"a", "b", "c"}
	runner.Run(ctx, args)
}
```
The test is straightforward. The NewTask constructor creates a Task that simulates work with a random delay: the task blocks until either the context is canceled or the delay elapses. We create three tasks and run them concurrently with the Run method.
If you run the test, you will find that some logs are missing. This is what I got (results may vary):
```text
=== RUN TestTaskRunner_Run
    async_event_test.go:28: time=2024-04-09T21:19:52.255+07:00 level=INFO msg="Run tasks" args="[a b c]"
--- PASS: TestTaskRunner_Run (0.00s)
PASS
```
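Only the "Run tasks" line shows up. Because Run returns immediately, the test function returns and its deferred t.Log prints the buffer before the task goroutines have had a chance to log anything. Here are two ways to fix this; let’s see which one suits your case.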
First Approach: Using sync.WaitGroup and Mocking the Task
The first approach is to use sync.WaitGroup and mock the Task function. Since the tasks are injected into the TaskRunner, we have full control over their behavior.
We modify the NewTask function to accept a *sync.WaitGroup as its third argument and call defer wg.Done() at the start of the returned task, so the WaitGroup is notified when the task finishes. In the test, we call wg.Add(3) before running the tasks and wg.Wait() at the end to block until all of them are done.
Here is the git diff:
```diff
diff --git a/how-to-test-goroutines/task_runner_test.go b/how-to-test-goroutines/task_runner_test.go
index aeac6a6..4e8b96d 100644
--- a/how-to-test-goroutines/task_runner_test.go
+++ b/how-to-test-goroutines/task_runner_test.go
@@ -5,12 +5,15 @@ import (
 	"context"
 	"log/slog"
 	"math/rand"
+	"sync"
 	"testing"
 	"time"
 )

-func NewTask(l *slog.Logger, name string) Task {
+func NewTask(l *slog.Logger, name string, wg *sync.WaitGroup) Task {
 	return func(ctx context.Context, args []string) {
+		defer wg.Done()
+
 		l.InfoContext(ctx, "Task started", "name", name)
 		delay := time.Duration(rand.Intn(5)*100) * time.Millisecond
 		select {
@@ -26,13 +29,18 @@ func TestTaskRunner_Run(t *testing.T) {
 	logger := slog.New(slog.NewTextHandler(&logHistory, &slog.HandlerOptions{}))
 	defer func() { t.Log(logHistory.String()) }()

-	task1 := NewTask(logger, "task1")
-	task2 := NewTask(logger, "task2")
-	task3 := NewTask(logger, "task3")
+	var wg sync.WaitGroup
+	wg.Add(3)
+
+	task1 := NewTask(logger, "task1", &wg)
+	task2 := NewTask(logger, "task2", &wg)
+	task3 := NewTask(logger, "task3", &wg)
 	runner := NewTaskRunner(logger, task1, task2, task3)
 	ctx := context.Background()
 	args := []string{"a", "b", "c"}
 	runner.Run(ctx, args)
+
+	wg.Wait()
 }
```
If you run the test, you will see that all logs are printed. This is what I got:
```text
=== RUN TestTaskRunner_Run
    task_runner_test.go:30: time=2024-04-09T21:44:58.959+07:00 level=INFO msg="Run tasks" args="[a b c]"
        time=2024-04-09T21:44:58.959+07:00 level=INFO msg="Task started" name=task3
        time=2024-04-09T21:44:58.959+07:00 level=INFO msg="Task started" name=task1
        time=2024-04-09T21:44:58.959+07:00 level=INFO msg="Task started" name=task2
        time=2024-04-09T21:44:59.060+07:00 level=INFO msg="Task finished" name=task1 args="[a b c]"
        time=2024-04-09T21:44:59.060+07:00 level=INFO msg="Task finished" name=task3 args="[a b c]"
        time=2024-04-09T21:44:59.060+07:00 level=INFO msg="Task finished" name=task2 args="[a b c]"
--- PASS: TestTaskRunner_Run (0.10s)
PASS
```
Second Approach: Using context.Context and sync.WaitGroup
I was happy with the first approach because, in most cases, we have full control over the tasks; that’s the beauty of dependency injection. However, sometimes we don’t control the task, or we simply don’t want to mock it. In such cases, we can use context.Context and sync.WaitGroup to enable or disable synchronization depending on the context.
In the first approach, the synchronization mechanism is exposed throughout the test code. In the second approach, we encapsulate it within the TaskRunner itself. We create a new package called await to handle the synchronization: it lets Run register each task and wait for all of them to finish. The only change needed in the test code is to wrap context.Background() with await.Context().
Here is the git diff:
```diff
diff --git a/how-to-test-goroutines/task_runner.go b/how-to-test-goroutines/task_runner.go
index 4e73a49..5cc8036 100644
--- a/how-to-test-goroutines/task_runner.go
+++ b/how-to-test-goroutines/task_runner.go
@@ -3,6 +3,8 @@ package how_to_test_goroutines
 import (
 	"context"
 	"log/slog"
+
+	"github.com/josestg/gotips/how-to-test-goroutines/await"
 )

 type Task func(ctx context.Context, args []string)
@@ -19,7 +21,15 @@ func NewTaskRunner(l *slog.Logger, tasks ...Task) *TaskRunner {
 func (r *TaskRunner) Run(ctx context.Context, args []string) {
 	r.log.InfoContext(ctx, "Run tasks", "args", args)
 	ctx = context.WithoutCancel(ctx)
+
+	awaiter := await.FromContext(ctx)
 	for _, task := range r.tasks {
-		go task(ctx, args)
+		awaiter.Add(1)
+		task := task
+		go func() {
+			defer awaiter.Done()
+			task(ctx, args)
+		}()
 	}
+	awaiter.Wait()
 }
diff --git a/how-to-test-goroutines/task_runner_test.go b/how-to-test-goroutines/task_runner_test.go
index aeac6a6..b1cf9a7 100644
--- a/how-to-test-goroutines/task_runner_test.go
+++ b/how-to-test-goroutines/task_runner_test.go
@@ -7,6 +7,8 @@ import (
 	"math/rand"
 	"testing"
 	"time"
+
+	"github.com/josestg/gotips/how-to-test-goroutines/await"
 )

 func NewTask(l *slog.Logger, name string) Task {
@@ -33,6 +35,7 @@ func TestTaskRunner_Run(t *testing.T) {
 	runner := NewTaskRunner(logger, task1, task2, task3)
 	ctx := context.Background()
+	ctx = await.Context(ctx)
 	args := []string{"a", "b", "c"}
 	runner.Run(ctx, args)
 }
```
At first glance, it seems we have changed the expected behavior: previously the tasks were fire-and-forget, but now Run waits for all of them to finish, which blocks. However, whether Run actually blocks depends on a value in the context. To understand how, let’s examine the await package:
```go
// await/await.go
package await

import (
	"context"
	"sync"
)

type contextKey struct{}

var awaitKey = &contextKey{}

// Awaiter is an interface describing the sync.WaitGroup methods we need.
type Awaiter interface {
	Add(delta int)
	Done()
	Wait()
}

// nopAwaiter is a no-op implementation of Awaiter.
type nopAwaiter struct{}

func (nopAwaiter) Add(_ int) {}
func (nopAwaiter) Done()     {}
func (nopAwaiter) Wait()     {}

// Context returns a new context carrying an Awaiter backed by a *sync.WaitGroup.
func Context(ctx context.Context) context.Context {
	var wg Awaiter = &sync.WaitGroup{}
	return context.WithValue(ctx, awaitKey, wg)
}

// FromContext returns the Awaiter from the context if it exists. Otherwise, it returns a no-op Awaiter.
func FromContext(ctx context.Context) Awaiter {
	wg, ok := ctx.Value(awaitKey).(Awaiter)
	if !ok {
		return &nopAwaiter{}
	}
	return wg
}
```
The secret recipe lies in the FromContext function. When the context doesn’t carry an Awaiter, it returns a no-op Awaiter, so Add, Done, and Wait do nothing and Run keeps its fire-and-forget behavior. Since we only need synchronization in tests, the test opts in by wrapping its context with await.Context(), which makes Run wait for all tasks to finish.
Because we need different behaviors depending on the context, this is where interfaces shine. By defining an Awaiter interface that describes the synchronization mechanism, we can switch the TaskRunner’s behavior simply by changing the context value.
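Any type with those three methods can be swapped in. The sketch below adds compile-time checks that *sync.WaitGroup and nopAwaiter satisfy Awaiter, plus a hypothetical countingAwaiter test double that also records how many tasks were registered. Putting it into a context would require an extra constructor in the await package (say, a hypothetical WithAwaiter(ctx, a)), which the package shown above does not provide:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// Awaiter mirrors the interface from the await package.
type Awaiter interface {
	Add(delta int)
	Done()
	Wait()
}

// nopAwaiter mirrors the no-op implementation from the await package.
type nopAwaiter struct{}

func (nopAwaiter) Add(int) {}
func (nopAwaiter) Done()   {}
func (nopAwaiter) Wait()   {}

// countingAwaiter is a hypothetical test double: it behaves like a
// sync.WaitGroup but also records the total delta passed to Add.
type countingAwaiter struct {
	wg    sync.WaitGroup
	added atomic.Int64
}

func (c *countingAwaiter) Add(delta int) {
	c.added.Add(int64(delta))
	c.wg.Add(delta)
}
func (c *countingAwaiter) Done() { c.wg.Done() }
func (c *countingAwaiter) Wait() { c.wg.Wait() }

// Compile-time checks: the build fails if any type stops satisfying Awaiter.
var (
	_ Awaiter = (*sync.WaitGroup)(nil)
	_ Awaiter = nopAwaiter{}
	_ Awaiter = (*countingAwaiter)(nil)
)

func main() {
	c := &countingAwaiter{}
	var a Awaiter = c
	for i := 0; i < 3; i++ {
		a.Add(1)
		go func() { defer a.Done() }()
	}
	a.Wait()
	fmt.Println("tasks registered:", c.added.Load())
}
```

Because Run only ever sees the Awaiter interface, it cannot tell these implementations apart, which is what makes the context-driven switch possible.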
And that’s it!
Conclusion
Personally, I prefer the second approach because it encapsulates the synchronization mechanism in the business logic rather than exposing it in the test code. However, it depends on the case; sometimes the first approach is more suitable. The key takeaway is to understand the problem and choose the solution that fits it. I hope this article helps you test goroutines in Go. If you have any questions or suggestions, feel free to leave a comment below. Thank you for reading!
You can find the complete code in the GitHub repository.