-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Adds ProvidedRequiredArgumentsOnDirectivesRule
- Loading branch information
1 parent
0e41379
commit 560ac6f
Showing
3 changed files
with
183 additions
and
1 deletion.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
Sources/GraphQL/Validation/Rules/ProvidedRequiredArgumentsOnDirectivesRule.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
|
||
func ProvidedRequiredArgumentsOnDirectivesRule( | ||
context: SDLorNormalValidationContext | ||
) -> Visitor { | ||
var requiredArgsMap = [String: [String: String]]() | ||
|
||
let schema = context.getSchema() | ||
let definedDirectives = schema?.directives ?? specifiedDirectives | ||
for directive in definedDirectives { | ||
var requiredArgs = [String: String]() | ||
for arg in directive.args.filter({ isRequiredArgument($0) }) { | ||
requiredArgs[arg.name] = arg.type.debugDescription | ||
} | ||
requiredArgsMap[directive.name] = requiredArgs | ||
} | ||
|
||
let astDefinitions = context.ast.definitions | ||
for def in astDefinitions { | ||
if let def = def as? DirectiveDefinition { | ||
let argNodes = def.arguments | ||
var requiredArgs = [String: String]() | ||
for arg in argNodes.filter({ isRequiredArgumentNode($0) }) { | ||
requiredArgs[arg.name.value] = print(ast: arg.type) | ||
} | ||
requiredArgsMap[def.name.value] = requiredArgs | ||
} | ||
} | ||
|
||
return Visitor( | ||
// Validate on leave to allow for deeper errors to appear first. | ||
leave: { node, _, _, _, _ in | ||
if let directiveNode = node as? Directive { | ||
let directiveName = directiveNode.name.value | ||
if let requiredArgs = requiredArgsMap[directiveName] { | ||
let argNodes = directiveNode.arguments | ||
let argNodeMap = Set(argNodes.map(\.name.value)) | ||
for (argName, argType) in requiredArgs { | ||
if !argNodeMap.contains(argName) { | ||
context.report( | ||
error: GraphQLError( | ||
message: "Argument \"@\(directiveName)(\(argName):)\" of type \"\(argType)\" is required, but it was not provided.", | ||
nodes: [directiveNode] | ||
) | ||
) | ||
} | ||
} | ||
} | ||
} | ||
return .continue | ||
} | ||
) | ||
} | ||
|
||
func isRequiredArgumentNode( | ||
arg: InputValueDefinition | ||
) -> Bool { | ||
return arg.type.kind == .nonNullType && arg.defaultValue == nil | ||
} | ||
|
||
func isRequiredArgumentNode( | ||
arg: VariableDefinition | ||
) -> Bool { | ||
return arg.type.kind == .nonNullType && arg.defaultValue == nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
Tests/GraphQLTests/ValidationTests/ProvidedRequiredArgumentsOnDirectivesRuleTests.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
@testable import GraphQL | ||
import XCTest | ||
|
||
class ProvidedRequiredArgumentsOnDirectivesRuleTests: SDLValidationTestCase { | ||
override func setUp() { | ||
rule = ProvidedRequiredArgumentsOnDirectivesRule | ||
} | ||
|
||
func testMissingOptionalArgsOnDirectiveDefinedInsideSDL() throws { | ||
try assertValidationErrors( | ||
""" | ||
type Query { | ||
foo: String @test | ||
} | ||
directive @test(arg1: String, arg2: String! = "") on FIELD_DEFINITION | ||
""", | ||
[] | ||
) | ||
} | ||
|
||
func testMissingArgOnDirectiveDefinedInsideSDL() throws { | ||
try assertValidationErrors( | ||
""" | ||
type Query { | ||
foo: String @test | ||
} | ||
directive @test(arg: String!) on FIELD_DEFINITION | ||
""", | ||
[ | ||
GraphQLError( | ||
message: #"Argument "@test(arg:)" of type "String!" is required, but it was not provided."#, | ||
locations: [.init(line: 2, column: 15)] | ||
), | ||
] | ||
) | ||
} | ||
|
||
func testMissingArgOnStandardDirective() throws { | ||
try assertValidationErrors( | ||
""" | ||
type Query { | ||
foo: String @include | ||
} | ||
""", | ||
[ | ||
GraphQLError( | ||
message: #"Argument "@include(if:)" of type "Boolean!" is required, but it was not provided."#, | ||
locations: [.init(line: 2, column: 15)] | ||
), | ||
] | ||
) | ||
} | ||
|
||
func testMissingArgOnOveriddenStandardDirective() throws { | ||
try assertValidationErrors( | ||
""" | ||
type Query { | ||
foo: String @deprecated | ||
} | ||
directive @deprecated(reason: String!) on FIELD | ||
""", | ||
[ | ||
GraphQLError( | ||
message: #"Argument "@deprecated(reason:)" of type "String!" is required, but it was not provided."#, | ||
locations: [.init(line: 2, column: 15)] | ||
), | ||
] | ||
) | ||
} | ||
|
||
func testMissingArgOnDirectiveDefinedInSchemaExtension() throws { | ||
let schema = try buildSchema(source: """ | ||
type Query { | ||
foo: String | ||
} | ||
""") | ||
let sdl = """ | ||
directive @test(arg: String!) on OBJECT | ||
extend type Query @test | ||
""" | ||
try assertValidationErrors( | ||
sdl, | ||
schema: schema, | ||
[ | ||
GraphQLError( | ||
message: #"Argument "@test(arg:)" of type "String!" is required, but it was not provided."#, | ||
locations: [.init(line: 3, column: 20)] | ||
), | ||
] | ||
) | ||
} | ||
|
||
func testMissingArgOnDirectiveUsedInSchemaExtension() throws { | ||
let schema = try buildSchema(source: """ | ||
directive @test(arg: String!) on OBJECT | ||
type Query { | ||
foo: String | ||
} | ||
""") | ||
let sdl = """ | ||
extend type Query @test | ||
""" | ||
try assertValidationErrors( | ||
sdl, | ||
schema: schema, | ||
[ | ||
GraphQLError( | ||
message: #"Argument "@test(arg:)" of type "String!" is required, but it was not provided."#, | ||
locations: [.init(line: 1, column: 19)] | ||
), | ||
] | ||
) | ||
} | ||
} |