diff --git a/example/auth_directive/module.go b/example/auth_directive/module.go index a81f5abc..11fd66ec 100644 --- a/example/auth_directive/module.go +++ b/example/auth_directive/module.go @@ -18,7 +18,7 @@ directive @auth on FIELD_DEFINITION type Query { # @candi:queryRoot - user: UserQueryResolver + user: UserQueryResolver @auth } type Mutation { @@ -60,7 +60,7 @@ enum FilterSortEnum { # UserModule Resolver Area type UserQueryResolver { - getAllUser(filter: FilterListInputResolver): UserListResolver! @auth + getAllUser(filter: FilterListInputResolver): UserListResolver! getDetailUser(id: String!): UserResolver! } diff --git a/example/auth_directive/server/server.go b/example/auth_directive/server/server.go index d235f74a..cf8f3fdb 100644 --- a/example/auth_directive/server/server.go +++ b/example/auth_directive/server/server.go @@ -3,7 +3,6 @@ package main import ( "context" "encoding/json" - "errors" "io" "log" "net/http" @@ -14,27 +13,40 @@ import ( "github.com/golangid/graphql-go/ws" ) +type extensionser struct { +} + +func (extensionser) Error() string { + return "Unauthorized" +} +func (extensionser) Extensions() map[string]interface{} { + return map[string]interface{}{ + "code": 401, + } +} + type authDirective struct { } -func (v *authDirective) Exec(ctx context.Context, directive *types.Directive, input interface{}) (context.Context, error) { +func (v *authDirective) Auth(ctx context.Context, directive *types.Directive, input interface{}) (context.Context, error) { headers, _ := ctx.Value("header").(http.Header) if headers.Get("Authorization") == "" { - return ctx, errors.New("Unauthorized") + return ctx, &extensionser{} } return context.WithValue(ctx, "claim", "wkwkwkwk"), nil } func main() { + dir := &authDirective{} opts := []graphql.SchemaOpt{ graphql.UseStringDescriptions(), graphql.UseFieldResolvers(), graphql.MaxParallelism(20), - graphql.DirectiveExecutors( - map[string]types.DirectiveExecutor{ - "auth": &authDirective{}, + graphql.DirectiveFuncs( + map[string]types.DirectiveFunc{ + "auth": dir.Auth, }, ), } diff --git a/graphql.go b/graphql.go index 4aa39c1b..243e88b9 100644 --- a/graphql.go +++ b/graphql.go @@ -84,7 +84,7 @@ type Schema struct { useStringDescriptions bool disableIntrospection bool subscribeResolverTimeout time.Duration - executors map[string]types.DirectiveExecutor + directiveFuncs map[string]types.DirectiveFunc } func (s *Schema) ASTSchema() *types.Schema { @@ -171,11 +171,11 @@ func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt { } } -// DirectiveExecutors allows to pass custom directive visitors that will be able to handle +// DirectiveFuncs allows to pass custom directive visitors that will be able to handle // your GraphQL schema directives. -func DirectiveExecutors(executors map[string]types.DirectiveExecutor) SchemaOpt { +func DirectiveFuncs(dirFuncs map[string]types.DirectiveFunc) SchemaOpt { return func(s *Schema) { - s.executors = executors + s.directiveFuncs = dirFuncs } } @@ -264,11 +264,11 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str Schema: s.schema, DisableIntrospection: s.disableIntrospection, }, - Limiter: make(chan struct{}, s.maxParallelism), - Tracer: s.tracer, - Logger: s.logger, - PanicHandler: s.panicHandler, - DirectiveExecutors: s.executors, + Limiter: make(chan struct{}, s.maxParallelism), + Tracer: s.tracer, + Logger: s.logger, + PanicHandler: s.panicHandler, + DirectiveFuncs: s.directiveFuncs, } varTypes := make(map[string]*introspection.Type) for _, v := range op.Vars { diff --git a/internal/exec/exec.go b/internal/exec/exec.go index f443d73d..10dddf22 100644 --- a/internal/exec/exec.go +++ b/internal/exec/exec.go @@ -26,7 +26,7 @@ type Request struct { Logger log.Logger PanicHandler errors.PanicHandler SubscribeResolverTimeout time.Duration - DirectiveExecutors map[string]types.DirectiveExecutor + DirectiveFuncs map[string]types.DirectiveFunc } func (r *Request) handlePanic(ctx context.Context) { @@ -229,21 +229,22 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f } res := f.resolver - if f.field.UseMethodResolver() { - - for _, directive := range f.field.Directives { - if executor, ok := r.DirectiveExecutors[directive.Name.Name]; ok { - var dirErr error - ctx, dirErr = executor.Exec(ctx, directive, f.field.PackedArgs) - if dirErr != nil { - err := errors.Errorf("%s", dirErr) - err.Path = path.toSlice() - err.ResolverError = dirErr - return err + for _, directive := range f.field.Directives { + if dirFunc, ok := r.DirectiveFuncs[directive.Name.Name]; ok { + var dirErr error + ctx, dirErr = dirFunc(ctx, directive, f.field.PackedArgs) + if dirErr != nil { + err := errors.Errorf("%s", dirErr) + err.Path = path.toSlice() + err.ResolverError = dirErr + if ex, ok := dirErr.(extensionser); ok { + err.Extensions = ex.Extensions() } + return err } } - + } + if f.field.UseMethodResolver() { var in []reflect.Value if f.field.HasContext { in = append(in, reflect.ValueOf(ctx)) diff --git a/internal/exec/subscribe.go b/internal/exec/subscribe.go index ef2afbff..7249be48 100644 --- a/internal/exec/subscribe.go +++ b/internal/exec/subscribe.go @@ -30,11 +30,16 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *types } for _, directive := range f.field.Directives { - if executor, ok := r.DirectiveExecutors[directive.Name.Name]; ok { + if dirFunc, ok := r.DirectiveFuncs[directive.Name.Name]; ok { var dirErr error - ctx, dirErr = executor.Exec(ctx, directive, f.field.PackedArgs) + ctx, dirErr = dirFunc(ctx, directive, f.field.PackedArgs) if dirErr != nil { - return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{errors.Errorf("%s", dirErr)}}) + err := errors.Errorf("%s", dirErr) + err.ResolverError = dirErr + if ex, ok := dirErr.(extensionser); ok { + err.Extensions = ex.Extensions() + } + return sendAndReturnClosed(&Response{Errors: []*errors.QueryError{err}}) } } } @@ -65,6 +70,9 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *types resolverErr := callOut[1].Interface().(error) err = errors.Errorf("%s", resolverErr) err.ResolverError = resolverErr + if ex, ok := callOut[1].Interface().(extensionser); ok { + err.Extensions = ex.Extensions() + } } if err != nil { @@ -117,10 +125,10 @@ func (r *Request) Subscribe(ctx context.Context, s *resolvable.Schema, op *types Vars: r.Request.Vars, Schema: r.Request.Schema, }, - Limiter: r.Limiter, - Tracer: r.Tracer, - Logger: r.Logger, - DirectiveExecutors: r.DirectiveExecutors, + Limiter: r.Limiter, + Tracer: r.Tracer, + Logger: r.Logger, + DirectiveFuncs: r.DirectiveFuncs, } var out bytes.Buffer func() { diff --git a/subscriptions.go b/subscriptions.go index 3cde025a..01cf5a9c 100644 --- a/subscriptions.go +++ b/subscriptions.go @@ -59,7 +59,7 @@ func (s *Schema) subscribe(ctx context.Context, queryString string, operationNam Logger: s.logger, PanicHandler: s.panicHandler, SubscribeResolverTimeout: s.subscribeResolverTimeout, - DirectiveExecutors: s.executors, + DirectiveFuncs: s.directiveFuncs, } varTypes := make(map[string]*introspection.Type) for _, v := range op.Vars { diff --git a/types/directive.go b/types/directive.go index fa5c2559..58cf2394 100644 --- a/types/directive.go +++ b/types/directive.go @@ -28,9 +28,7 @@ type DirectiveDefinition struct { type DirectiveList []*Directive -type DirectiveExecutor interface { - Exec(ctx context.Context, directive *Directive, input interface{}) (context.Context, error) -} +type DirectiveFunc func(ctx context.Context, directive *Directive, input interface{}) (context.Context, error) // Returns the Directive in the DirectiveList by name or nil if not found. func (l DirectiveList) Get(name string) *Directive {