From 42aea01fb5b3c583852a4d44cca0b6c57be9039d Mon Sep 17 00:00:00 2001 From: yuhan6665 <1588741+yuhan6665@users.noreply.github.com> Date: Thu, 26 Dec 2024 07:55:12 -0500 Subject: [PATCH] Core: Add mutex to injection resolution (#4206) * Revert "Add RequireFeaturesAsync() that works regardless order of app init" * Add mutex to injection resolution - Turns out we already support async DI resolution regardless of feature ordering Previous code contain a race condition causing some resolution is lost - Note that the new mutex cover s.pendingResolutions and s.features but must not cover callbackResolution() due to deadlock - Refactor some method names and simplify code * Add OptionalFeatures injection For example OptionalFeatures() is useful for fakedns module --- app/dispatcher/default.go | 2 +- app/dns/nameserver.go | 2 +- app/observatory/command/command.go | 2 +- app/proxyman/command/command.go | 2 +- app/router/balancing.go | 6 +- app/router/command/command.go | 2 +- app/router/strategy_leastload.go | 6 +- app/router/strategy_leastping.go | 6 +- app/router/strategy_random.go | 6 +- core/xray.go | 158 +++++++++++++++-------------- core/xray_test.go | 2 +- proxy/dns/dns.go | 2 +- 12 files changed, 105 insertions(+), 91 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index dbf58dad623c..7bc580565f67 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -106,7 +106,7 @@ func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { d := new(DefaultDispatcher) if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional + core.OptionalFeatures(ctx, func(fdns dns.FakeDNSEngine) { d.fdns = fdns }) return d.Init(config.(*Config), om, router, pm, sm, dc) diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index de6e1686f017..9c2668d90b01 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -56,7 +56,7 @@ func NewServer(ctx context.Context, dest net.Destination, dispatcher routing.Dis return NewTCPLocalNameServer(u, queryStrategy) case strings.EqualFold(u.String(), "fakedns"): var fd dns.FakeDNSEngine - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional + core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { fd = fdns }) return NewFakeDNSServer(fd), nil diff --git a/app/observatory/command/command.go b/app/observatory/command/command.go index aab85e801b0e..f9bb58e36d40 100644 --- a/app/observatory/command/command.go +++ b/app/observatory/command/command.go @@ -38,7 +38,7 @@ func init() { sv := &service{v: s} err := s.RequireFeatures(func(Observatory extension.Observatory) { sv.observatory = Observatory - }) + }, false) if err != nil { return nil, err } diff --git a/app/proxyman/command/command.go b/app/proxyman/command/command.go index 3c7824d2d65f..ef71052111fa 100644 --- a/app/proxyman/command/command.go +++ b/app/proxyman/command/command.go @@ -177,7 +177,7 @@ func (s *service) Register(server *grpc.Server) { common.Must(s.v.RequireFeatures(func(im inbound.Manager, om outbound.Manager) { hs.ihm = im hs.ohm = om - })) + }, false)) RegisterHandlerServiceServer(server, hs) // For compatibility purposes diff --git a/app/router/balancing.go b/app/router/balancing.go index 7d8bb022b709..5f8cb1c21b0b 100644 --- a/app/router/balancing.go +++ b/app/router/balancing.go @@ -5,6 +5,7 @@ import ( sync "sync" "github.com/xtls/xray-core/app/observatory" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -31,9 +32,10 @@ type RoundRobinStrategy struct { func (s *RoundRobinStrategy) InjectContext(ctx context.Context) { s.ctx = ctx if len(s.FallbackTag) > 0 { - core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { s.observatory = observatory - }) + return nil + })) } } diff --git a/app/router/command/command.go b/app/router/command/command.go index baf76b8b89a3..fd9caa2228b0 100644 --- a/app/router/command/command.go +++ b/app/router/command/command.go @@ -135,7 +135,7 @@ func (s *service) Register(server *grpc.Server) { vCoreDesc := RoutingService_ServiceDesc vCoreDesc.ServiceName = "v2ray.core.app.router.command.RoutingService" server.RegisterService(&vCoreDesc, rs) - })) + }, false)) } func init() { diff --git a/app/router/strategy_leastload.go b/app/router/strategy_leastload.go index a4ef1c122a67..1bf3cbc09a22 100644 --- a/app/router/strategy_leastload.go +++ b/app/router/strategy_leastload.go @@ -7,6 +7,7 @@ import ( "time" "github.com/xtls/xray-core/app/observatory" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" @@ -59,9 +60,10 @@ type node struct { func (s *LeastLoadStrategy) InjectContext(ctx context.Context) { s.ctx = ctx - core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { s.observer = observatory - }) + return nil + })) } func (s *LeastLoadStrategy) PickOutbound(candidates []string) string { diff --git a/app/router/strategy_leastping.go b/app/router/strategy_leastping.go index b13d1a7ddce5..ada3492d35a1 100644 --- a/app/router/strategy_leastping.go +++ b/app/router/strategy_leastping.go @@ -4,6 +4,7 @@ import ( "context" "github.com/xtls/xray-core/app/observatory" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -20,9 +21,10 @@ func (l *LeastPingStrategy) GetPrincipleTarget(strings []string) []string { func (l *LeastPingStrategy) InjectContext(ctx context.Context) { l.ctx = ctx - core.RequireFeaturesAsync(l.ctx, func(observatory extension.Observatory) { + common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error { l.observatory = observatory - }) + return nil + })) } func (l *LeastPingStrategy) PickOutbound(strings []string) string { diff --git a/app/router/strategy_random.go b/app/router/strategy_random.go index 9f4cdd774c33..ea9b7add094a 100644 --- a/app/router/strategy_random.go +++ b/app/router/strategy_random.go @@ -4,6 +4,7 @@ import ( "context" "github.com/xtls/xray-core/app/observatory" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -20,9 +21,10 @@ type RandomStrategy struct { func (s *RandomStrategy) InjectContext(ctx context.Context) { s.ctx = ctx if len(s.FallbackTag) > 0 { - core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { s.observatory = observatory - }) + return nil + })) } } diff --git a/core/xray.go b/core/xray.go index 5ab106035046..f6ccc27d617e 100644 --- a/core/xray.go +++ b/core/xray.go @@ -4,7 +4,6 @@ import ( "context" "reflect" "sync" - "time" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" @@ -45,22 +44,13 @@ func getFeature(allFeatures []features.Feature, t reflect.Type) features.Feature return nil } -func (r *resolution) resolve(allFeatures []features.Feature) (bool, error) { - var fs []features.Feature - for _, d := range r.deps { - f := getFeature(allFeatures, d) - if f == nil { - return false, nil - } - fs = append(fs, f) - } - +func (r *resolution) callbackResolution(allFeatures []features.Feature) error { callback := reflect.ValueOf(r.callback) var input []reflect.Value callbackType := callback.Type() for i := 0; i < callbackType.NumIn(); i++ { pt := callbackType.In(i) - for _, f := range fs { + for _, f := range allFeatures { if reflect.TypeOf(f).AssignableTo(pt) { input = append(input, reflect.ValueOf(f)) break @@ -85,15 +75,17 @@ func (r *resolution) resolve(allFeatures []features.Feature) (bool, error) { } } - return true, err + return err } // Instance combines all Xray features. type Instance struct { - access sync.Mutex - features []features.Feature - featureResolutions []resolution - running bool + statusLock sync.Mutex + features []features.Feature + pendingResolutions []resolution + pendingOptionalResolutions []resolution + running bool + resolveLock sync.Mutex ctx context.Context } @@ -154,13 +146,14 @@ func addOutboundHandlers(server *Instance, configs []*OutboundHandlerConfig) err // See Instance.RequireFeatures for more information. func RequireFeatures(ctx context.Context, callback interface{}) error { v := MustFromContext(ctx) - return v.RequireFeatures(callback) + return v.RequireFeatures(callback, false) } -// RequireFeaturesAsync registers a callback, which will be called when all dependent features are registered. The order of app init doesn't matter -func RequireFeaturesAsync(ctx context.Context, callback interface{}) { +// OptionalFeatures is a helper function to aquire features from Instance in context. +// See Instance.RequireFeatures for more information. +func OptionalFeatures(ctx context.Context, callback interface{}) error { v := MustFromContext(ctx) - v.RequireFeaturesAsync(callback) + return v.RequireFeatures(callback, true) } // New returns a new Xray instance based on given configuration. @@ -234,9 +227,12 @@ func initInstanceWithConfig(config *Config, server *Instance) (bool, error) { }(), ) - if server.featureResolutions != nil { + server.resolveLock.Lock() + if server.pendingResolutions != nil { + server.resolveLock.Unlock() return true, errors.New("not all dependencies are resolved.") } + server.resolveLock.Unlock() if err := addInboundHandlers(server, config.Inbound); err != nil { return true, err @@ -255,8 +251,8 @@ func (s *Instance) Type() interface{} { // Close shutdown the Xray instance. func (s *Instance) Close() error { - s.access.Lock() - defer s.access.Unlock() + s.statusLock.Lock() + defer s.statusLock.Unlock() s.running = false @@ -275,7 +271,7 @@ func (s *Instance) Close() error { // RequireFeatures registers a callback, which will be called when all dependent features are registered. // The callback must be a func(). All its parameters must be features.Feature. -func (s *Instance) RequireFeatures(callback interface{}) error { +func (s *Instance) RequireFeatures(callback interface{}, optional bool) error { callbackType := reflect.TypeOf(callback) if callbackType.Kind() != reflect.Func { panic("not a function") @@ -290,47 +286,32 @@ func (s *Instance) RequireFeatures(callback interface{}) error { deps: featureTypes, callback: callback, } - if finished, err := r.resolve(s.features); finished { - return err - } - s.featureResolutions = append(s.featureResolutions, r) - return nil -} -// RequireFeaturesAsync registers a callback, which will be called when all dependent features are registered. The order of app init doesn't matter -func (s *Instance) RequireFeaturesAsync(callback interface{}) { - callbackType := reflect.TypeOf(callback) - if callbackType.Kind() != reflect.Func { - panic("not a function") - } - - var featureTypes []reflect.Type - for i := 0; i < callbackType.NumIn(); i++ { - featureTypes = append(featureTypes, reflect.PtrTo(callbackType.In(i))) - } - - r := resolution{ - deps: featureTypes, - callback: callback, + s.resolveLock.Lock() + foundAll := true + for _, d := range r.deps { + f := getFeature(s.features, d) + if f == nil { + foundAll = false + break + } } - go func() { - var finished = false - for i := 0; !finished; i++ { - if i > 100000 { - errors.LogError(s.ctx, "RequireFeaturesAsync failed after count ", i) - break; - } - finished, _ = r.resolve(s.features) - time.Sleep(time.Millisecond) + if foundAll { + s.resolveLock.Unlock() + return r.callbackResolution(s.features) + } else { + if optional { + s.pendingOptionalResolutions = append(s.pendingOptionalResolutions, r) + } else { + s.pendingResolutions = append(s.pendingResolutions, r) } - s.featureResolutions = append(s.featureResolutions, r) - }() + s.resolveLock.Unlock() + return nil + } } // AddFeature registers a feature into current Instance. func (s *Instance) AddFeature(feature features.Feature) error { - s.features = append(s.features, feature) - if s.running { if err := feature.Start(); err != nil { errors.LogInfoInner(s.ctx, err, "failed to start feature") @@ -338,27 +319,52 @@ func (s *Instance) AddFeature(feature features.Feature) error { return nil } - if s.featureResolutions == nil { - return nil + s.resolveLock.Lock() + s.features = append(s.features, feature) + + var availableResolution []resolution + var pending []resolution + for _, r := range s.pendingResolutions { + foundAll := true + for _, d := range r.deps { + f := getFeature(s.features, d) + if f == nil { + foundAll = false + break + } + } + if foundAll { + availableResolution = append(availableResolution, r) + } else { + pending = append(pending, r) + } } + s.pendingResolutions = pending - var pendingResolutions []resolution - for _, r := range s.featureResolutions { - finished, err := r.resolve(s.features) - if finished && err != nil { - return err + var pendingOptional []resolution + for _, r := range s.pendingOptionalResolutions { + foundAll := true + for _, d := range r.deps { + f := getFeature(s.features, d) + if f == nil { + foundAll = false + break + } } - if !finished { - pendingResolutions = append(pendingResolutions, r) + if foundAll { + availableResolution = append(availableResolution, r) + } else { + pendingOptional = append(pendingOptional, r) } } - if len(pendingResolutions) == 0 { - s.featureResolutions = nil - } else if len(pendingResolutions) < len(s.featureResolutions) { - s.featureResolutions = pendingResolutions + s.pendingOptionalResolutions = pendingOptional + s.resolveLock.Unlock() + + var err error + for _, r := range availableResolution { + err = r.callbackResolution(s.features) // only return the last error for now } - - return nil + return err } // GetFeature returns a feature of the given type, or nil if such feature is not registered. @@ -371,8 +377,8 @@ func (s *Instance) GetFeature(featureType interface{}) features.Feature { // // xray:api:stable func (s *Instance) Start() error { - s.access.Lock() - defer s.access.Unlock() + s.statusLock.Lock() + defer s.statusLock.Unlock() s.running = true for _, f := range s.features { diff --git a/core/xray_test.go b/core/xray_test.go index 43d021efd2b7..f4cb11abe7a0 100644 --- a/core/xray_test.go +++ b/core/xray_test.go @@ -30,7 +30,7 @@ func TestXrayDependency(t *testing.T) { t.Error("expected dns client fulfilled, but actually nil") } wait <- true - }) + }, false) instance.AddFeature(localdns.New()) <-wait } diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 790c80c1425b..b7a3264a62a2 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -27,7 +27,7 @@ func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { h := new(Handler) if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional + core.OptionalFeatures(ctx, func(fdns dns.FakeDNSEngine) { h.fdns = fdns }) return h.Init(config.(*Config), dnsClient, policyManager)