diff --git a/Sources/GraphQL/Validation/Rules/ProvidedRequiredArgumentsOnDirectivesRule.swift b/Sources/GraphQL/Validation/Rules/ProvidedRequiredArgumentsOnDirectivesRule.swift new file mode 100644 index 00000000..f4211e15 --- /dev/null +++ b/Sources/GraphQL/Validation/Rules/ProvidedRequiredArgumentsOnDirectivesRule.swift @@ -0,0 +1,64 @@ + +func ProvidedRequiredArgumentsOnDirectivesRule( + context: SDLorNormalValidationContext +) -> Visitor { + var requiredArgsMap = [String: [String: String]]() + + let schema = context.getSchema() + let definedDirectives = schema?.directives ?? specifiedDirectives + for directive in definedDirectives { + var requiredArgs = [String: String]() + for arg in directive.args.filter({ isRequiredArgument($0) }) { + requiredArgs[arg.name] = arg.type.debugDescription + } + requiredArgsMap[directive.name] = requiredArgs + } + + let astDefinitions = context.ast.definitions + for def in astDefinitions { + if let def = def as? DirectiveDefinition { + let argNodes = def.arguments + var requiredArgs = [String: String]() + for arg in argNodes.filter({ isRequiredArgumentNode($0) }) { + requiredArgs[arg.name.value] = print(ast: arg.type) + } + requiredArgsMap[def.name.value] = requiredArgs + } + } + + return Visitor( + // Validate on leave to allow for deeper errors to appear first. + leave: { node, _, _, _, _ in + if let directiveNode = node as? Directive { + let directiveName = directiveNode.name.value + if let requiredArgs = requiredArgsMap[directiveName] { + let argNodes = directiveNode.arguments + let argNodeMap = Set(argNodes.map(\.name.value)) + for (argName, argType) in requiredArgs { + if !argNodeMap.contains(argName) { + context.report( + error: GraphQLError( + message: "Argument \"@\(directiveName)(\(argName):)\" of type \"\(argType)\" is required, but it was not provided.", + nodes: [directiveNode] + ) + ) + } + } + } + } + return .continue + } + ) +} + +func isRequiredArgumentNode( + arg: InputValueDefinition +) -> Bool { + return arg.type.kind == .nonNullType && arg.defaultValue == nil +} + +func isRequiredArgumentNode( + arg: VariableDefinition +) -> Bool { + return arg.type.kind == .nonNullType && arg.defaultValue == nil +} diff --git a/Sources/GraphQL/Validation/SpecifiedRules.swift b/Sources/GraphQL/Validation/SpecifiedRules.swift index fd2ad278..fe46dd29 100644 --- a/Sources/GraphQL/Validation/SpecifiedRules.swift +++ b/Sources/GraphQL/Validation/SpecifiedRules.swift @@ -51,5 +51,5 @@ public let specifiedSDLRules: [SDLValidationRule] = [ KnownArgumentNamesOnDirectivesRule, UniqueArgumentNamesRule, UniqueInputFieldNamesRule, -// ProvidedRequiredArgumentsOnDirectivesRule, + ProvidedRequiredArgumentsOnDirectivesRule, ] diff --git a/Tests/GraphQLTests/ValidationTests/ProvidedRequiredArgumentsOnDirectivesRuleTests.swift b/Tests/GraphQLTests/ValidationTests/ProvidedRequiredArgumentsOnDirectivesRuleTests.swift new file mode 100644 index 00000000..6d8a5f67 --- /dev/null +++ b/Tests/GraphQLTests/ValidationTests/ProvidedRequiredArgumentsOnDirectivesRuleTests.swift @@ -0,0 +1,118 @@ +@testable import GraphQL +import XCTest + +class ProvidedRequiredArgumentsOnDirectivesRuleTests: SDLValidationTestCase { + override func setUp() { + rule = ProvidedRequiredArgumentsOnDirectivesRule + } + + func testMissingOptionalArgsOnDirectiveDefinedInsideSDL() throws { + try assertValidationErrors( + """ + type Query { + foo: String @test + } + + directive @test(arg1: String, arg2: String! = "") on FIELD_DEFINITION + """, + [] + ) + } + + func testMissingArgOnDirectiveDefinedInsideSDL() throws { + try assertValidationErrors( + """ + type Query { + foo: String @test + } + + directive @test(arg: String!) on FIELD_DEFINITION + """, + [ + GraphQLError( + message: #"Argument "@test(arg:)" of type "String!" is required, but it was not provided."#, + locations: [.init(line: 2, column: 15)] + ), + ] + ) + } + + func testMissingArgOnStandardDirective() throws { + try assertValidationErrors( + """ + type Query { + foo: String @include + } + """, + [ + GraphQLError( + message: #"Argument "@include(if:)" of type "Boolean!" is required, but it was not provided."#, + locations: [.init(line: 2, column: 15)] + ), + ] + ) + } + + func testMissingArgOnOveriddenStandardDirective() throws { + try assertValidationErrors( + """ + type Query { + foo: String @deprecated + } + directive @deprecated(reason: String!) on FIELD + """, + [ + GraphQLError( + message: #"Argument "@deprecated(reason:)" of type "String!" is required, but it was not provided."#, + locations: [.init(line: 2, column: 15)] + ), + ] + ) + } + + func testMissingArgOnDirectiveDefinedInSchemaExtension() throws { + let schema = try buildSchema(source: """ + type Query { + foo: String + } + """) + let sdl = """ + directive @test(arg: String!) on OBJECT + + extend type Query @test + """ + try assertValidationErrors( + sdl, + schema: schema, + [ + GraphQLError( + message: #"Argument "@test(arg:)" of type "String!" is required, but it was not provided."#, + locations: [.init(line: 3, column: 20)] + ), + ] + ) + } + + func testMissingArgOnDirectiveUsedInSchemaExtension() throws { + let schema = try buildSchema(source: """ + directive @test(arg: String!) on OBJECT + + type Query { + foo: String + } + """) + let sdl = """ + extend type Query @test + """ + try assertValidationErrors( + sdl, + schema: schema, + [ + GraphQLError( + message: #"Argument "@test(arg:)" of type "String!" is required, but it was not provided."#, + locations: [.init(line: 1, column: 19)] + ), + ] + ) + } +}