Skip to content

Commit

Permalink
Feat: core.ToContext(ctx, v) for ctx initialization (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
rurirei committed Apr 4, 2021
1 parent 0dcd1f4 commit aa40b8b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
16 changes: 16 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,19 @@ func MustFromContext(ctx context.Context) *Instance {
}
return v
}

// ToContext returns ctx from the given context, or creates an Instance if the context doesn't find that.
func ToContext(ctx context.Context, v *Instance) context.Context {
if FromContext(ctx) != v {
ctx = context.WithValue(ctx, v2rayKey, v)
}
return ctx
}

// MustToContext returns ctx from the given context, or panics if not found that.
func MustToContext(ctx context.Context, v *Instance) context.Context {
if c := ToContext(ctx, v); c != ctx {
panic("V is not in context.")
}
return ctx
}
13 changes: 12 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
. "github.com/v2fly/v2ray-core/v4"
)

func TestContextPanic(t *testing.T) {
func TestFromContextPanic(t *testing.T) {
defer func() {
r := recover()
if r == nil {
Expand All @@ -17,3 +17,14 @@ func TestContextPanic(t *testing.T) {

MustFromContext(context.Background())
}

func TestToContextPanic(t *testing.T) {
defer func() {
r := recover()
if r == nil {
t.Error("expect panic, but nil")
}
}()

MustToContext(context.Background(), &Instance{})
}
10 changes: 3 additions & 7 deletions functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func CreateObject(v *Instance, config interface{}) (interface{}, error) {
var ctx context.Context
if v != nil {
ctx = context.WithValue(v.ctx, v2rayKey, v)
ctx = ToContext(v.ctx, v)
}
return common.CreateObject(ctx, config)
}
Expand Down Expand Up @@ -47,9 +47,7 @@ func StartInstance(configFormat string, configBytes []byte) (*Instance, error) {
//
// v2ray:api:stable
func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, error) {
if FromContext(ctx) == nil {
ctx = context.WithValue(ctx, v2rayKey, v)
}
ctx = ToContext(ctx, v)

dispatcher := v.GetFeature(routing.DispatcherType())
if dispatcher == nil {
Expand All @@ -76,9 +74,7 @@ func Dial(ctx context.Context, v *Instance, dest net.Destination) (net.Conn, err
//
// v2ray:api:beta
func DialUDP(ctx context.Context, v *Instance) (net.PacketConn, error) {
if FromContext(ctx) == nil {
ctx = context.WithValue(ctx, v2rayKey, v)
}
ctx = ToContext(ctx, v)

dispatcher := v.GetFeature(routing.DispatcherType())
if dispatcher == nil {
Expand Down

0 comments on commit aa40b8b

Please sign in to comment.