diff options
author | Matt Joiner <anacrolix@gmail.com> | 2019-11-05 18:57:25 +1100 |
---|---|---|
committer | Matt Joiner <anacrolix@gmail.com> | 2019-11-05 18:57:25 +1100 |
commit | e8c572c45b4d028cfc0c608699434bc8e8808c67 (patch) | |
tree | 42d166c056f215a6b065fe4fca6a976aff3b4825 /stmutil | |
parent | Add failed commits profiling (diff) | |
download | stm-e8c572c45b4d028cfc0c608699434bc8e8808c67.tar.gz stm-e8c572c45b4d028cfc0c608699434bc8e8808c67.tar.xz |
Cache ContextDoneVars
Note that nothing currently flushes them. Could probably flush them when they're done, and the early return will take care of the rest.
Diffstat (limited to 'stmutil')
-rw-r--r-- | stmutil/context.go | 23 | ||||
-rw-r--r-- | stmutil/context_test.go | 20 |
2 files changed, 38 insertions, 5 deletions
diff --git a/stmutil/context.go b/stmutil/context.go index e82e522..98eb8aa 100644 --- a/stmutil/context.go +++ b/stmutil/context.go @@ -2,19 +2,32 @@ package stmutil import ( "context" + "sync" "github.com/anacrolix/stm" ) +var ( + mu sync.Mutex + ctxVars = map[context.Context]*stm.Var{} +) + func ContextDoneVar(ctx context.Context) (*stm.Var, func()) { + mu.Lock() + defer mu.Unlock() + if v, ok := ctxVars[ctx]; ok { + return v, func() {} + } if ctx.Err() != nil { - return stm.NewVar(true), func() {} + v := stm.NewVar(true) + ctxVars[ctx] = v + return v, func() {} } - ctx, cancel := context.WithCancel(ctx) - _var := stm.NewVar(false) + v := stm.NewVar(false) go func() { <-ctx.Done() - stm.AtomicSet(_var, true) + stm.AtomicSet(v, true) }() - return _var, cancel + ctxVars[ctx] = v + return v, func() {} } diff --git a/stmutil/context_test.go b/stmutil/context_test.go new file mode 100644 index 0000000..0a6b7c0 --- /dev/null +++ b/stmutil/context_test.go @@ -0,0 +1,20 @@ +package stmutil + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContextEquality(t *testing.T) { + ctx := context.Background() + assert.True(t, ctx == context.Background()) + childCtx, cancel := context.WithCancel(ctx) + assert.True(t, childCtx != ctx) + assert.True(t, childCtx != ctx) + assert.Equal(t, context.Background(), ctx) + cancel() + assert.Equal(t, context.Background(), ctx) + assert.NotEqual(t, ctx, childCtx) +} |