aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--stmutil/context.go23
-rw-r--r--stmutil/context_test.go20
2 files changed, 38 insertions, 5 deletions
diff --git a/stmutil/context.go b/stmutil/context.go
index e82e522..98eb8aa 100644
--- a/stmutil/context.go
+++ b/stmutil/context.go
@@ -2,19 +2,32 @@ package stmutil
import (
"context"
+ "sync"
"github.com/anacrolix/stm"
)
+var (
+ mu sync.Mutex
+ ctxVars = map[context.Context]*stm.Var{}
+)
+
func ContextDoneVar(ctx context.Context) (*stm.Var, func()) {
+ mu.Lock()
+ defer mu.Unlock()
+ if v, ok := ctxVars[ctx]; ok {
+ return v, func() {}
+ }
if ctx.Err() != nil {
- return stm.NewVar(true), func() {}
+ v := stm.NewVar(true)
+ ctxVars[ctx] = v
+ return v, func() {}
}
- ctx, cancel := context.WithCancel(ctx)
- _var := stm.NewVar(false)
+ v := stm.NewVar(false)
go func() {
<-ctx.Done()
- stm.AtomicSet(_var, true)
+ stm.AtomicSet(v, true)
}()
- return _var, cancel
+ ctxVars[ctx] = v
+ return v, func() {}
}
diff --git a/stmutil/context_test.go b/stmutil/context_test.go
new file mode 100644
index 0000000..0a6b7c0
--- /dev/null
+++ b/stmutil/context_test.go
@@ -0,0 +1,20 @@
+package stmutil
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestContextEquality(t *testing.T) {
+ ctx := context.Background()
+ assert.True(t, ctx == context.Background())
+ childCtx, cancel := context.WithCancel(ctx)
+ assert.True(t, childCtx != ctx)
+ assert.True(t, childCtx != ctx)
+ assert.Equal(t, context.Background(), ctx)
+ cancel()
+ assert.Equal(t, context.Background(), ctx)
+ assert.NotEqual(t, ctx, childCtx)
+}