From 5b08e9b9bb626f9a38b0bc7d40056a7e528302e7 Mon Sep 17 00:00:00 2001 From: John Chilton Date: Mon, 28 Oct 2024 16:29:16 -0400 Subject: [PATCH] Integrate some new Pydantic models into Planemo lint. --- lib/galaxy/tool_util/linters/tests.py | 62 ++++++++++ lib/galaxy/tool_util/parameters/__init__.py | 6 +- lib/galaxy/tool_util/parameters/case.py | 9 +- lib/galaxy/tool_util/parameters/models.py | 13 +-- lib/galaxy/tool_util/parameters/state.py | 47 ++++---- test/unit/tool_util/test_tool_linters.py | 120 ++++++++++++++++---- 6 files changed, 199 insertions(+), 58 deletions(-) diff --git a/lib/galaxy/tool_util/linters/tests.py b/lib/galaxy/tool_util/linters/tests.py index 90cec700b06b..f6f483279003 100644 --- a/lib/galaxy/tool_util/linters/tests.py +++ b/lib/galaxy/tool_util/linters/tests.py @@ -1,5 +1,6 @@ """This module contains a linting functions for tool tests.""" +from io import StringIO from typing import ( Iterator, List, @@ -8,6 +9,8 @@ ) from galaxy.tool_util.lint import Linter +from galaxy.tool_util.parameters import validate_test_cases_for_tool_source +from galaxy.tool_util.verify.assertion_models import assertion_list from galaxy.util import asbool from ._util import is_datasource @@ -134,6 +137,65 @@ def lint(cls, tool_source: "ToolSource", lint_ctx: "LintContext"): ) +class TestsAssertionValidation(Linter): + @classmethod + def lint(cls, tool_source: "ToolSource", lint_ctx: "LintContext"): + try: + raw_tests_dict = tool_source.parse_tests_to_dict() + except Exception: + lint_ctx.warn("Failed to parse test dictionaries from tool - cannot lint assertions") + return + assert "tests" in raw_tests_dict + for test_idx, test in enumerate(raw_tests_dict["tests"], start=1): + # TODO: validate command, command_version, element tests. What about children? + for output in test["outputs"]: + asserts_raw = output.get("attributes", {}).get("assert_list") or [] + to_yaml_assertions = [] + for raw_assert in asserts_raw: + to_yaml_assertions.append({"that": raw_assert["tag"], **raw_assert.get("attributes", {})}) + try: + assertion_list.model_validate(to_yaml_assertions) + except Exception as e: + error_str = _cleanup_pydantic_error(e) + lint_ctx.warn( + f"Test {test_idx}: failed to validate assertions. Validation errors are [{error_str}]" + ) + + +class TestsCaseValidation(Linter): + @classmethod + def lint(cls, tool_source: "ToolSource", lint_ctx: "LintContext"): + try: + validation_results = validate_test_cases_for_tool_source(tool_source, use_latest_profile=True) + except Exception as e: + lint_ctx.warn( + f"Serious problem parsing tool source or tests - cannot validate test cases. The exception is [{e}]", + linter=cls.name(), + ) + return + for test_idx, validation_result in enumerate(validation_results, start=1): + error = validation_result.validation_error + if error: + error_str = _cleanup_pydantic_error(error) + lint_ctx.warn( + f"Test {test_idx}: failed to validate test parameters against inputs - tests won't run on a modern Galaxy tool profile version. Validation errors are [{error_str}]", + linter=cls.name(), + ) + + +def _cleanup_pydantic_error(error) -> str: + full_validation_error = f"{error}" + new_error = StringIO("") + for line in full_validation_error.splitlines(): + # this repeated over and over isn't useful in the context of how we're building the dynamic models, + # tool authors should not be looking up pydantic docs on models they cannot even really inspect + if line.strip().startswith("For further information visit https://errors.pydantic"): + continue + else: + new_error.write(f"{line}\n") + return new_error.getvalue().strip() + + class TestsExpectNumOutputs(Linter): @classmethod def lint(cls, tool_source: "ToolSource", lint_ctx: "LintContext"): diff --git a/lib/galaxy/tool_util/parameters/__init__.py b/lib/galaxy/tool_util/parameters/__init__.py index 45cd770a5e3f..22dd7e6053aa 100644 --- a/lib/galaxy/tool_util/parameters/__init__.py +++ b/lib/galaxy/tool_util/parameters/__init__.py @@ -1,4 +1,7 @@ -from .case import test_case_state +from .case import ( + test_case_state, + validate_test_cases_for_tool_source, +) from .convert import ( decode, dereference, @@ -139,6 +142,7 @@ "ToolParameterT", "to_json_schema_string", "test_case_state", + "validate_test_cases_for_tool_source", "RequestToolState", "RequestInternalToolState", "RequestInternalDereferencedToolState", diff --git a/lib/galaxy/tool_util/parameters/case.py b/lib/galaxy/tool_util/parameters/case.py index d1ff72b67d24..c371d5e3b12e 100644 --- a/lib/galaxy/tool_util/parameters/case.py +++ b/lib/galaxy/tool_util/parameters/case.py @@ -164,12 +164,12 @@ def test_case_state( def test_case_validation( - test_dict: ToolSourceTest, tool_parameter_bundle: List[ToolParameterT], profile: str + test_dict: ToolSourceTest, tool_parameter_bundle: List[ToolParameterT], profile: str, name: Optional[str] = None ) -> TestCaseStateValidationResult: test_case_state_and_warnings = test_case_state(test_dict, tool_parameter_bundle, profile, validate=False) exception: Optional[Exception] = None try: - test_case_state_and_warnings.tool_state.validate(tool_parameter_bundle) + test_case_state_and_warnings.tool_state.validate(tool_parameter_bundle, name=name) for input_name in test_case_state_and_warnings.unhandled_inputs: raise Exception(f"Invalid parameter name found {input_name}") except Exception as e: @@ -323,8 +323,9 @@ def _input_for(flat_state_path: str, inputs: ToolSourceTestInputs) -> Optional[T def validate_test_cases_for_tool_source( - tool_source: ToolSource, use_latest_profile: bool = False + tool_source: ToolSource, use_latest_profile: bool = False, name: Optional[str] = None ) -> List[TestCaseStateValidationResult]: + name = name or f"PydanticModelFor[{tool_source.parse_id()}]" tool_parameter_bundle = input_models_for_tool_source(tool_source) if use_latest_profile: # this might get old but it is fine, just needs to be updated when test case changes are made @@ -334,6 +335,6 @@ def validate_test_cases_for_tool_source( test_cases: List[ToolSourceTest] = tool_source.parse_tests_to_dict()["tests"] results_by_test: List[TestCaseStateValidationResult] = [] for test_case in test_cases: - validation_result = test_case_validation(test_case, tool_parameter_bundle.parameters, profile) + validation_result = test_case_validation(test_case, tool_parameter_bundle.parameters, profile, name=name) results_by_test.append(validation_result) return results_by_test diff --git a/lib/galaxy/tool_util/parameters/models.py b/lib/galaxy/tool_util/parameters/models.py index 5019a92f5f70..d0e7e6bbcb8f 100644 --- a/lib/galaxy/tool_util/parameters/models.py +++ b/lib/galaxy/tool_util/parameters/models.py @@ -1489,8 +1489,8 @@ def create_model_strict(*args, **kwd) -> Type[BaseModel]: def create_model_factory(state_representation: StateRepresentationT): - def create_method(tool: ToolParameterBundle, name: str = DEFAULT_MODEL_NAME) -> Type[BaseModel]: - return create_field_model(tool.parameters, name, state_representation) + def create_method(tool: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_field_model(tool.parameters, name or DEFAULT_MODEL_NAME, state_representation) return create_method @@ -1546,15 +1546,14 @@ def validate_against_model(pydantic_model: Type[BaseModel], parameter_state: Dic class ValidationFunctionT(Protocol): - def __call__(self, tool: ToolParameterBundle, request: RawStateDict, name: str = DEFAULT_MODEL_NAME) -> None: ... + def __call__(self, tool: ToolParameterBundle, request: RawStateDict, name: Optional[str] = None) -> None: ... def validate_model_type_factory(state_representation: StateRepresentationT) -> ValidationFunctionT: - def validate_request(tool: ToolParameterBundle, request: Dict[str, Any], name: str = DEFAULT_MODEL_NAME) -> None: - pydantic_model = create_field_model( - tool.parameters, name=DEFAULT_MODEL_NAME, state_representation=state_representation - ) + def validate_request(tool: ToolParameterBundle, request: Dict[str, Any], name: Optional[str] = None) -> None: + name = name or DEFAULT_MODEL_NAME + pydantic_model = create_field_model(tool.parameters, name=name, state_representation=state_representation) validate_against_model(pydantic_model, request) return validate_request diff --git a/lib/galaxy/tool_util/parameters/state.py b/lib/galaxy/tool_util/parameters/state.py index 3edc96f69c6e..df040e1803ca 100644 --- a/lib/galaxy/tool_util/parameters/state.py +++ b/lib/galaxy/tool_util/parameters/state.py @@ -6,6 +6,7 @@ Any, Dict, List, + Optional, Type, Union, ) @@ -42,8 +43,8 @@ def __init__(self, input_state: Dict[str, Any]): def _validate(self, pydantic_model: Type[BaseModel]) -> None: validate_against_model(pydantic_model, self.input_state) - def validate(self, parameters: HasToolParameters) -> None: - base_model = self.parameter_model_for(parameters) + def validate(self, parameters: HasToolParameters, name: Optional[str] = None) -> None: + base_model = self.parameter_model_for(parameters, name=name) if base_model is None: raise NotImplementedError( f"Validating tool state against state representation {self.state_representation} is not implemented." @@ -56,17 +57,17 @@ def state_representation(self) -> StateRepresentationT: """Get state representation of the inputs.""" @classmethod - def parameter_model_for(cls, parameters: HasToolParameters) -> Type[BaseModel]: + def parameter_model_for(cls, parameters: HasToolParameters, name: Optional[str] = None) -> Type[BaseModel]: bundle: ToolParameterBundle if isinstance(parameters, list): bundle = ToolParameterBundleModel(parameters=parameters) else: bundle = parameters - return cls._parameter_model_for(bundle) + return cls._parameter_model_for(bundle, name=name) @classmethod @abstractmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: """Return a model type for this tool state kind.""" @@ -74,70 +75,70 @@ class RequestToolState(ToolState): state_representation: Literal["request"] = "request" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: - return create_request_model(parameters) + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_request_model(parameters, name) class RequestInternalToolState(ToolState): state_representation: Literal["request_internal"] = "request_internal" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: - return create_request_internal_model(parameters) + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_request_internal_model(parameters, name) class LandingRequestToolState(ToolState): state_representation: Literal["landing_request"] = "landing_request" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: - return create_landing_request_model(parameters) + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_landing_request_model(parameters, name) class LandingRequestInternalToolState(ToolState): state_representation: Literal["landing_request_internal"] = "landing_request_internal" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: - return create_landing_request_internal_model(parameters) + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_landing_request_internal_model(parameters, name) class RequestInternalDereferencedToolState(ToolState): state_representation: Literal["request_internal_dereferenced"] = "request_internal_dereferenced" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: - return create_request_internal_dereferenced_model(parameters) + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_request_internal_dereferenced_model(parameters, name) class JobInternalToolState(ToolState): state_representation: Literal["job_internal"] = "job_internal" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: - return create_job_internal_model(parameters) + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_job_internal_model(parameters, name) class TestCaseToolState(ToolState): state_representation: Literal["test_case_xml"] = "test_case_xml" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: # implement a test case model... - return create_test_case_model(parameters) + return create_test_case_model(parameters, name) class WorkflowStepToolState(ToolState): state_representation: Literal["workflow_step"] = "workflow_step" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: - return create_workflow_step_model(parameters) + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_workflow_step_model(parameters, name) class WorkflowStepLinkedToolState(ToolState): state_representation: Literal["workflow_step_linked"] = "workflow_step_linked" @classmethod - def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]: - return create_workflow_step_linked_model(parameters) + def _parameter_model_for(cls, parameters: ToolParameterBundle, name: Optional[str] = None) -> Type[BaseModel]: + return create_workflow_step_linked_model(parameters, name) diff --git a/test/unit/tool_util/test_tool_linters.py b/test/unit/tool_util/test_tool_linters.py index 7e6d3b62ec93..d693bae26390 100644 --- a/test/unit/tool_util/test_tool_linters.py +++ b/test/unit/tool_util/test_tool_linters.py @@ -26,7 +26,9 @@ xsd, ) from galaxy.tool_util.loader_directory import load_tool_sources_from_path +from galaxy.tool_util.parser.interface import ToolSource from galaxy.tool_util.parser.xml import XmlToolSource +from galaxy.tool_util.unittest_utils import functional_test_tool_path from galaxy.util import ( ElementTree, submodules, @@ -698,6 +700,11 @@ TESTS_ABSENT = """ """ +TESTS_ABSENT_YAML = """ +class: GalaxyTool +name: name +id: id +""" TESTS_ABSENT_DATA_SOURCE = """ """ @@ -794,6 +801,38 @@ """ +INVALID_CENTER_OF_MASS = """ + + + + + + + + + + + + + + +""" +VALID_CENTER_OF_MASS = """ + + + + + + + + + + + + + + +""" TESTS_VALID = """ @@ -801,7 +840,7 @@ - + @@ -950,6 +989,19 @@ def get_xml_tool_source(xml_string: str) -> XmlToolSource: return XmlToolSource(get_xml_tree(xml_string)) +def get_tool_source(source_contents: str) -> ToolSource: + if "GalaxyTool" in source_contents: + with tempfile.NamedTemporaryFile(mode="w", suffix="tool.yml") as tmp: + tmp.write(source_contents) + tmp.flush() + tool_sources = load_tool_sources_from_path(tmp.name) + assert len(tool_sources) == 1, "Expected 1 tool source" + tool_source = tool_sources[0][1] + return tool_source + else: + return get_xml_tool_source(source_contents) + + def run_lint_module(lint_ctx, lint_module, lint_target): lint_tool_source_with_modules(lint_ctx, lint_target, list({lint_module, xsd})) @@ -1751,13 +1803,14 @@ def test_stdio_invalid_match(lint_ctx): def test_tests_absent(lint_ctx): - tool_source = get_xml_tool_source(TESTS_ABSENT) - run_lint_module(lint_ctx, tests, tool_source) - assert "No tests found, most tools should define test cases." in lint_ctx.warn_messages - assert not lint_ctx.info_messages - assert not lint_ctx.valid_messages - assert len(lint_ctx.warn_messages) == 1 - assert not lint_ctx.error_messages + for test_contents in [TESTS_ABSENT, TESTS_ABSENT_YAML]: + tool_source = get_tool_source(test_contents) + run_lint_module(lint_ctx, tests, tool_source) + assert "No tests found, most tools should define test cases." in lint_ctx.warn_messages + assert not lint_ctx.info_messages + assert not lint_ctx.valid_messages + assert len(lint_ctx.warn_messages) == 1 + assert not lint_ctx.error_messages def test_tests_data_source(lint_ctx): @@ -1791,7 +1844,6 @@ def test_tests_param_output_names(lint_ctx): ) assert not lint_ctx.info_messages assert len(lint_ctx.valid_messages) == 1 - assert not lint_ctx.warn_messages assert len(lint_ctx.error_messages) == 6 @@ -1806,7 +1858,7 @@ def test_tests_expect_failure_output(lint_ctx): ) assert not lint_ctx.info_messages assert not lint_ctx.valid_messages - assert len(lint_ctx.warn_messages) == 1 + assert len(lint_ctx.warn_messages) == 3 assert len(lint_ctx.error_messages) == 2 @@ -1854,10 +1906,24 @@ def test_tests_asserts(lint_ctx): assert "Test 1: 'has_size' must not specify 'value' and 'size'" in lint_ctx.error_messages assert "Test 1: 'has_n_columns' needs to specify 'n', 'min', or 'max'" in lint_ctx.error_messages assert "Test 1: 'has_n_lines' needs to specify 'n', 'min', or 'max'" in lint_ctx.error_messages - assert not lint_ctx.warn_messages assert len(lint_ctx.error_messages) == 9 +def test_tests_assertion_models_valid(lint_ctx): + tool_source = get_xml_tool_source(VALID_CENTER_OF_MASS) + run_lint_module(lint_ctx, tests, tool_source) + assert len(lint_ctx.error_messages) == 0 + assert len(lint_ctx.warn_messages) == 0 + + +def test_tests_assertion_models_invalid(lint_ctx): + tool_source = get_xml_tool_source(INVALID_CENTER_OF_MASS) + run_lint_module(lint_ctx, tests, tool_source) + assert len(lint_ctx.error_messages) == 0 + assert len(lint_ctx.warn_messages) == 1 + assert "Test 1: failed to validate assertions. Validation errors are " in lint_ctx.warn_messages + + def test_tests_output_type_mismatch(lint_ctx): tool_source = get_xml_tool_source(TESTS_OUTPUT_TYPE_MISMATCH) run_lint_module(lint_ctx, tests, tool_source) @@ -1869,7 +1935,6 @@ def test_tests_output_type_mismatch(lint_ctx): "Test 1: test collection output 'data_name' does not correspond to a 'output_collection' output, but a 'data'" in lint_ctx.error_messages ) - assert not lint_ctx.warn_messages assert len(lint_ctx.error_messages) == 2 @@ -1892,7 +1957,6 @@ def test_tests_discover_outputs(lint_ctx): "Test 5: test collection 'collection_name' must contain nested 'element' tags and/or element children with a 'count' attribute" in lint_ctx.error_messages ) - assert not lint_ctx.warn_messages assert len(lint_ctx.error_messages) == 4 @@ -1911,7 +1975,6 @@ def test_tests_compare_attrib_incompatibility(lint_ctx): assert 'Test 1: Attribute sort is incompatible with compare="contains".' in lint_ctx.error_messages assert not lint_ctx.info_messages assert len(lint_ctx.valid_messages) == 1 - assert not lint_ctx.warn_messages assert len(lint_ctx.error_messages) == 2 @@ -2077,14 +2140,8 @@ def test_tool_and_macro_xml(lint_ctx_xpath, lint_ctx): def test_linting_yml_tool(lint_ctx): - with tempfile.TemporaryDirectory() as tmp: - tool_path = os.path.join(tmp, "tool.yml") - with open(tool_path, "w") as tmpf: - tmpf.write(YAML_TOOL) - tool_sources = load_tool_sources_from_path(tmp) - assert len(tool_sources) == 1, "Expected 1 tool source" - tool_source = tool_sources[0][1] - lint_tool_source_with(lint_ctx, tool_source) + tool_source = get_tool_source(YAML_TOOL) + lint_tool_source_with(lint_ctx, tool_source) assert "Tool defines a version [1.0]." in lint_ctx.valid_messages assert "Tool defines a name [simple_constructs_y]." in lint_ctx.valid_messages assert "Tool defines an id [simple_constructs_y]." in lint_ctx.valid_messages @@ -2145,7 +2202,7 @@ def test_skip_by_module(lint_ctx): def test_list_linters(): linter_names = Linter.list_listers() # make sure to add/remove a test for new/removed linters if this number changes - assert len(linter_names) == 132 + assert len(linter_names) == 134 assert "Linter" not in linter_names # make sure that linters from all modules are available for prefix in [ @@ -2164,6 +2221,23 @@ def test_list_linters(): assert len([x for x in linter_names if x.startswith(prefix)]) +def test_linting_functional_tool_multi_select(lint_ctx): + tool_source = functional_test_tool_source("multi_select.xml") + run_lint_module(lint_ctx, tests, tool_source) + warn_message = lint_ctx.warn_messages[0] + assert ( + "Test 2: failed to validate test parameters against inputs - tests won't run on a modern Galaxy tool profile version. Validation errors are [5 validation errors for" + in str(warn_message) + ) + + +def functional_test_tool_source(name: str) -> ToolSource: + tool_sources = load_tool_sources_from_path(functional_test_tool_path(name)) + assert len(tool_sources) == 1, "Expected 1 tool source" + tool_source = tool_sources[0][1] + return tool_source + + def test_linter_module_list(): linter_modules = submodules.import_submodules(galaxy.tool_util.linters) linter_module_names = [m.__name__.split(".")[-1] for m in linter_modules]