mirror of https://go.googlesource.com/go
144 lines
3.4 KiB
Go
144 lines
3.4 KiB
Go
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package context_test
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// afterFuncContext is a context.Context implementation that is not one
// of the types defined in context.go, yet still lets callers register
// functions through an AfterFunc method.
//
// The zero value is ready to use.
type afterFuncContext struct {
	// mu is held by the methods of this type while they read or write
	// the fields below.
	mu sync.Mutex

	// afterFuncs holds the currently registered functions. Each
	// registration is keyed by a freshly allocated *byte so that it can
	// be removed individually later.
	afterFuncs map[*byte]func()

	// done is the channel handed out by Done; it is allocated on the
	// first call to Done.
	done chan struct{}

	// err is the value reported by Err; it is recorded by cancel.
	err error
}
|
|
|
|
func newAfterFuncContext() context.Context {
|
|
return &afterFuncContext{}
|
|
}
|
|
|
|
func (c *afterFuncContext) Deadline() (time.Time, bool) {
|
|
return time.Time{}, false
|
|
}
|
|
|
|
func (c *afterFuncContext) Done() <-chan struct{} {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.done == nil {
|
|
c.done = make(chan struct{})
|
|
}
|
|
return c.done
|
|
}
|
|
|
|
func (c *afterFuncContext) Err() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
return c.err
|
|
}
|
|
|
|
func (c *afterFuncContext) Value(key any) any {
|
|
return nil
|
|
}
|
|
|
|
func (c *afterFuncContext) AfterFunc(f func()) func() bool {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
k := new(byte)
|
|
if c.afterFuncs == nil {
|
|
c.afterFuncs = make(map[*byte]func())
|
|
}
|
|
c.afterFuncs[k] = f
|
|
return func() bool {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
_, ok := c.afterFuncs[k]
|
|
delete(c.afterFuncs, k)
|
|
return ok
|
|
}
|
|
}
|
|
|
|
func (c *afterFuncContext) cancel(err error) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.err != nil {
|
|
return
|
|
}
|
|
c.err = err
|
|
for _, f := range c.afterFuncs {
|
|
go f()
|
|
}
|
|
c.afterFuncs = nil
|
|
}
|
|
|
|
func TestCustomContextAfterFuncCancel(t *testing.T) {
|
|
ctx0 := &afterFuncContext{}
|
|
ctx1, cancel := context.WithCancel(ctx0)
|
|
defer cancel()
|
|
ctx0.cancel(context.Canceled)
|
|
<-ctx1.Done()
|
|
}
|
|
|
|
func TestCustomContextAfterFuncTimeout(t *testing.T) {
|
|
ctx0 := &afterFuncContext{}
|
|
ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
|
|
defer cancel()
|
|
ctx0.cancel(context.Canceled)
|
|
<-ctx1.Done()
|
|
}
|
|
|
|
func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
|
|
ctx0 := &afterFuncContext{}
|
|
donec := make(chan struct{})
|
|
stop := context.AfterFunc(ctx0, func() {
|
|
close(donec)
|
|
})
|
|
defer stop()
|
|
ctx0.cancel(context.Canceled)
|
|
<-donec
|
|
}
|
|
|
|
func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
|
|
ctx0 := &afterFuncContext{}
|
|
_, cancel1 := context.WithCancel(ctx0)
|
|
_, cancel2 := context.WithCancel(ctx0)
|
|
if got, want := len(ctx0.afterFuncs), 2; got != want {
|
|
t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
|
|
}
|
|
cancel1()
|
|
cancel2()
|
|
if got, want := len(ctx0.afterFuncs), 0; got != want {
|
|
t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
|
|
ctx0 := &afterFuncContext{}
|
|
_, cancel := context.WithTimeout(ctx0, veryLongDuration)
|
|
if got, want := len(ctx0.afterFuncs), 1; got != want {
|
|
t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
|
|
}
|
|
cancel()
|
|
if got, want := len(ctx0.afterFuncs), 0; got != want {
|
|
t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
|
|
ctx0 := &afterFuncContext{}
|
|
stop := context.AfterFunc(ctx0, func() {})
|
|
if got, want := len(ctx0.afterFuncs), 1; got != want {
|
|
t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
|
|
}
|
|
stop()
|
|
if got, want := len(ctx0.afterFuncs), 0; got != want {
|
|
t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
|
|
}
|
|
}
|