diff --git a/dmoj/commands/__init__.py b/dmoj/commands/__init__.py index 06c50933f..e3b8051a8 100644 --- a/dmoj/commands/__init__.py +++ b/dmoj/commands/__init__.py @@ -11,6 +11,7 @@ from dmoj.commands.submissions import ListSubmissionsCommand from dmoj.commands.submit import SubmitCommand from dmoj.commands.test import TestCommand +from dmoj.commands.validate import ValidateCommand all_commands: List[Type[Command]] = [ ListProblemsCommand, @@ -23,4 +24,5 @@ ShowCommand, HelpCommand, QuitCommand, + ValidateCommand, ] diff --git a/dmoj/commands/validate.py b/dmoj/commands/validate.py new file mode 100644 index 000000000..fe6d3bc75 --- /dev/null +++ b/dmoj/commands/validate.py @@ -0,0 +1,123 @@ +import os +from itertools import groupby +from operator import itemgetter +from typing import List, Optional, Tuple, Union + + +from dmoj import executors +from dmoj.commands.base_command import Command +from dmoj.error import CompileError, InvalidCommandException +from dmoj.graders import StandardGrader +from dmoj.judgeenv import get_problem_root, get_supported_problems +from dmoj.problem import BatchedTestCase, Problem, ProblemConfig, ProblemDataManager, TestCase +from dmoj.result import CheckerResult, Result +from dmoj.utils.ansi import print_ansi +from dmoj.utils.unicode import utf8bytes, utf8text + + +all_executors = executors.executors + + +class ValidationGrader(StandardGrader): + def check_result(self, case, result): + return CheckerResult(not result.result_flag, (not result.result_flag) * case.points, result.proc_output) + + +class ValidateCommand(Command): + name = 'validate' + help = 'Validates input for problems.' + + def _populate_parser(self) -> None: + self.arg_parser.add_argument('problem_ids', nargs='+', help='ids of problems to test') + + def execute(self, line: str) -> int: + args = self.arg_parser.parse_args(line) + + problem_ids = args.problem_ids + supported_problems = set(get_supported_problems()) + + unknown_problems = ', '.join(f"'{i}'" for i in problem_ids if i not in supported_problems) + if unknown_problems: + raise InvalidCommandException(f'unknown problem(s) {unknown_problems}') + total_fails = 0 + for problem_id in problem_ids: + fails = self.validate_problem(problem_id) + if fails: + print_ansi(f'Problem #ansi[{problem_id}](cyan|bold) #ansi[failed](red|bold).') + total_fails += 1 + else: + print_ansi(f'Problem #ansi[{problem_id}](cyan|bold) passed with flying colours.') + print() + + print() + print('Test complete.') + if total_fails: + print_ansi(f'#ansi[A total of {total_fails} problems have invalid input](red|bold)') + else: + print_ansi('#ansi[All problems validated.](green|bold)') + + return total_fails + + def validate_problem(self, problem_id: str) -> int: + print_ansi(f'Validating problem #ansi[{problem_id}](cyan|bold)...') + + problem_root = get_problem_root(problem_id) + config = ProblemConfig(ProblemDataManager(problem_root)) + + if not config or 'validator' not in config: + print_ansi('\t#ansi[Skipped](magenta|bold) - No validator found') + return 0 + + validator_config = config['validator'] + language = validator_config['language'] + if language not in all_executors: + print_ansi('\t\t#ansi[Skipped](magenta|bold) - Language not supported') + return 0 + time_limit = validator_config['time'] + memory_limit = validator_config['memory'] + with open(os.path.join(problem_root, validator_config['source'])) as f: + source = f.read() + + problem = Problem(problem_id, time_limit, memory_limit, {}) + + try: + real_grader = problem.grader_class(self.judge, problem, language, utf8bytes(source)) + validation_grader = ValidationGrader(self.judge, problem, language, utf8bytes(source)) + except CompileError as compilation_error: + error = compilation_error.message or 'compiler exited abnormally' + print_ansi('#ansi[Failed compiling validator!](red|bold)') + print(error.rstrip()) + return 1 + + flattened_cases: List[Tuple[Optional[int], Union[TestCase, BatchedTestCase]]] = [] + batch_number = 0 + for case in real_grader.cases(): + if isinstance(case, BatchedTestCase): + batch_number += 1 + for batched_case in case.batched_cases: + flattened_cases.append((batch_number, batched_case)) + else: + flattened_cases.append((None, case)) + + case_number = 0 + fail = 0 + for batch_number, cases in groupby(flattened_cases, key=itemgetter(0)): + if batch_number: + print_ansi(f'#ansi[Batch #{batch_number}](yellow|bold)') + for _, case in cases: + case_number += 1 + result = validation_grader.grade(case) + + codes = result.readable_codes() + + colored_codes = [f'#ansi[{x}]({Result.COLORS_BYID[x]}|bold)' for x in codes] + colored_aux_codes = f'{{{", ".join(colored_codes[1:])}}}' if len(codes) > 1 else '' + colored_feedback = f'(#ansi[{utf8text(result.feedback)}](|underline)) ' if result.feedback else '' + case_info = f'{colored_feedback}{colored_aux_codes}' + case_padding = ' ' if batch_number is not None else '' + print_ansi(f'{case_padding}Test case {case_number:2d} {colored_codes[0]:3s} {case_info}') + + if result.result_flag: + fail = 1 + + return fail