Better Typing for Tool Execution Plumbing #18626

Merged · merged 2 commits on Aug 6, 2024
4 changes: 2 additions & 2 deletions lib/galaxy/managers/datasets.py
@@ -547,7 +547,7 @@ def set_metadata(self, trans, dataset_assoc, overwrite=False, validate=True):
if overwrite:
self.overwrite_metadata(data)

-job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute(
+job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute_via_trans(
self.app.datatypes_registry.set_external_metadata_tool,
trans,
incoming={"input1": data, "validate": validate},
@@ -883,7 +883,7 @@ def deserialize_datatype(self, item, key, val, **context):
assert (
trans
), "Logic error in Galaxy, deserialize_datatype not send a transation object" # TODO: restructure this for stronger typing
-job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute(
+job, *_ = self.app.datatypes_registry.set_external_metadata_tool.tool_action.execute_via_trans(
self.app.datatypes_registry.set_external_metadata_tool, trans, incoming={"input1": item}, overwrite=False
) # overwrite is False as per existing behavior
trans.app.job_manager.enqueue(job, tool=trans.app.datatypes_registry.set_external_metadata_tool)
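Note: the rename from tool_action.execute to execute_via_trans makes the transaction-bound entry point explicit; an execute_via_app counterpart for app-level callers appears later in this diff. A minimal sketch of the call shape after this change, with the tool, trans, and data objects assumed from the surrounding code:

# Sketch only; mirrors the call sites above, all setup objects are assumed.
tool = app.datatypes_registry.set_external_metadata_tool
job, *_ = tool.tool_action.execute_via_trans(
    tool,
    trans,
    incoming={"input1": data, "validate": True},
    overwrite=False,  # per the existing-behavior comment above
)
app.job_manager.enqueue(job, tool=tool)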
2 changes: 1 addition & 1 deletion lib/galaxy/managers/histories.py
@@ -424,7 +424,7 @@ def queue_history_export(

# Run job to do export.
history_exp_tool = trans.app.toolbox.get_tool(export_tool_id)
-job, *_ = history_exp_tool.execute(trans, incoming=params, history=history, set_output_hid=True)
+job, *_ = history_exp_tool.execute(trans, incoming=params, history=history)
trans.app.job_manager.enqueue(job, tool=history_exp_tool)
return job
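Note: dropping set_output_hid=True here is behavior-preserving. The retyped Tool.execute below defaults set_output_hid to DEFAULT_SET_OUTPUT_HID, so the explicit argument was redundant. The constant's definition is not shown in this diff; this call site implies it is:

# Assumed definition in lib/galaxy/tools/execute.py (inferred, not shown in this diff):
DEFAULT_SET_OUTPUT_HID: bool = True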

101 changes: 82 additions & 19 deletions lib/galaxy/tools/__init__.py
@@ -51,6 +51,7 @@
StoredWorkflow,
)
from galaxy.model.base import transaction
+from galaxy.model.dataset_collections.matching import MatchingCollections
from galaxy.tool_shed.util.repository_util import get_installed_repository
from galaxy.tool_shed.util.shed_util_common import set_image_paths
from galaxy.tool_util.deps import (
@@ -113,6 +114,7 @@
from galaxy.tools.actions.model_operations import ModelOperationToolAction
from galaxy.tools.cache import ToolDocumentCache
from galaxy.tools.evaluation import global_tool_errors
+from galaxy.tools.execution_helpers import ToolExecutionCache
from galaxy.tools.imp_exp import JobImportHistoryArchiveWrapper
from galaxy.tools.parameters import (
check_param,
@@ -183,8 +185,18 @@
from galaxy.version import VERSION_MAJOR
from galaxy.work.context import proxy_work_context_for_history
from .execute import (
+DatasetCollectionElementsSliceT,
+DEFAULT_JOB_CALLBACK,
+DEFAULT_PREFERRED_OBJECT_STORE_ID,
+DEFAULT_RERUN_REMAP_JOB_ID,
+DEFAULT_SET_OUTPUT_HID,
+DEFAULT_USE_CACHED_JOB,
execute as execute_job,
+ExecutionSlice,
+JobCallbackT,
MappingParameters,
+ToolParameterRequestInstanceT,
+ToolParameterRequestT,
)

if TYPE_CHECKING:
@@ -1862,11 +1874,11 @@ def expand_incoming(self, trans, incoming, request_context, input_format="legacy"
def handle_input(
self,
trans,
-incoming,
-history=None,
-use_cached_job=False,
-preferred_object_store_id: Optional[str] = None,
-input_format="legacy",
+incoming: ToolParameterRequestT,
+history: Optional[model.History] = None,
+use_cached_job: bool = DEFAULT_USE_CACHED_JOB,
+preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
+input_format: str = "legacy",
):
"""
Process incoming parameters for this tool from the dict `incoming`,
@@ -1942,23 +1954,23 @@ def handle_incoming_errors(self, all_errors):
def handle_single_execution(
self,
trans,
-rerun_remap_job_id,
-execution_slice,
-history,
-execution_cache=None,
-completed_job=None,
-collection_info=None,
-job_callback=None,
-preferred_object_store_id=None,
-flush_job=True,
-skip=False,
+rerun_remap_job_id: Optional[int],
+execution_slice: ExecutionSlice,
+history: model.History,
+execution_cache: ToolExecutionCache,
+completed_job: Optional[model.Job],
+collection_info: Optional[MatchingCollections],
+job_callback: Optional[JobCallbackT],
+preferred_object_store_id: Optional[str],
+flush_job: bool,
+skip: bool,
):
"""
Return a pair with whether execution is successful as well as either
resulting output data or an error message indicating the problem.
"""
try:
-rval = self.execute(
+rval = self._execute(
trans,
incoming=execution_slice.param_combination,
history=history,
@@ -2045,18 +2057,67 @@ def get_static_param_values(self, trans):
args[key] = param.get_initial_value(trans, None)
return args

-def execute(self, trans, incoming=None, set_output_hid=True, history=None, **kwargs):
+def execute(
+self,
+trans,
+incoming: Optional[ToolParameterRequestInstanceT] = None,
+history: Optional[model.History] = None,
+set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
+flush_job: bool = True,
+):
"""
Execute the tool using parameter values in `incoming`. This just
dispatches to the `ToolAction` instance specified by
`self.tool_action`. In general this will create a `Job` that
when run will build the tool's outputs, e.g. `DefaultToolAction`.

+_execute has many more options but should be accessed through
+handle_single_execution. The public interface to execute should be
+rarely used and in more specific ways.
"""
+return self._execute(
+trans,
+incoming=incoming,
+history=history,
+set_output_hid=set_output_hid,
+flush_job=flush_job,
+)
+
+def _execute(
+self,
+trans,
+incoming: Optional[ToolParameterRequestInstanceT] = None,
+history: Optional[model.History] = None,
+rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
+execution_cache: Optional[ToolExecutionCache] = None,
+dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] = None,
+completed_job: Optional[model.Job] = None,
+collection_info: Optional[MatchingCollections] = None,
+job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
+preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
+set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
+flush_job: bool = True,
+skip: bool = False,
+):
if incoming is None:
incoming = {}
try:
return self.tool_action.execute(
-self, trans, incoming=incoming, set_output_hid=set_output_hid, history=history, **kwargs
+self,
+trans,
+incoming=incoming,
+history=history,
+job_params=None,
+rerun_remap_job_id=rerun_remap_job_id,
+execution_cache=execution_cache,
+dataset_collection_elements=dataset_collection_elements,
+completed_job=completed_job,
+collection_info=collection_info,
+job_callback=job_callback,
+preferred_object_store_id=preferred_object_store_id,
+set_output_hid=set_output_hid,
+flush_job=flush_job,
+skip=skip,
)
except exceptions.ToolExecutionError as exc:
job = exc.job
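Note: this hunk splits the old kwargs-driven execute into a thin public execute and a private _execute that names every keyword the plumbing forwards, so the type checker can verify the whole path instead of a **kwargs hole. A hedged sketch of the intended calling convention (tool, trans, hda, and history are assumed objects):

# Outside callers stay on the narrow public surface:
job, *_ = tool.execute(trans, incoming={"input1": hda}, history=history)

# Mapping-over-collections code should not call _execute directly; it goes
# through handle_single_execution, which forwards ExecutionSlice state into
# _execute's wider, fully typed keyword set.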
@@ -2988,7 +3049,9 @@ class SetMetadataTool(Tool):
requires_setting_metadata = False
tool_action: "SetMetadataToolAction"

-def regenerate_imported_metadata_if_needed(self, hda, history, user, session_id):
+def regenerate_imported_metadata_if_needed(
+self, hda: model.HistoryDatasetAssociation, history: model.History, user: model.User, session_id: int
+):
if hda.has_metadata_files:
job, *_ = self.tool_action.execute_via_app(
self,
132 changes: 55 additions & 77 deletions lib/galaxy/tools/actions/__init__.py
@@ -9,7 +9,9 @@
cast,
Dict,
List,
+Optional,
Set,
+Tuple,
TYPE_CHECKING,
Union,
)
@@ -24,15 +26,32 @@
from galaxy.job_execution.actions.post import ActionBox
from galaxy.managers.context import ProvidesHistoryContext
from galaxy.model import (
+History,
HistoryDatasetAssociation,
+Job,
LibraryDatasetDatasetAssociation,
WorkflowRequestInputParameter,
)
from galaxy.model.base import transaction
from galaxy.model.dataset_collections.builder import CollectionBuilder
+from galaxy.model.dataset_collections.matching import MatchingCollections
from galaxy.model.none_like import NoneDataset
from galaxy.objectstore import ObjectStorePopulator
+from galaxy.tools.execute import (
+DatasetCollectionElementsSliceT,
+DEFAULT_DATASET_COLLECTION_ELEMENTS,
+DEFAULT_JOB_CALLBACK,
+DEFAULT_PREFERRED_OBJECT_STORE_ID,
+DEFAULT_RERUN_REMAP_JOB_ID,
+DEFAULT_SET_OUTPUT_HID,
+JobCallbackT,
+ToolParameterRequestInstanceT,
+)
+from galaxy.tools.execution_helpers import (
+filter_output,
+on_text_for_names,
+ToolExecutionCache,
+)
from galaxy.tools.parameters import update_dataset_ids
from galaxy.tools.parameters.basic import (
DataCollectionToolParameter,
@@ -54,32 +73,8 @@
log = logging.getLogger(__name__)


-class ToolExecutionCache:
-"""An object mean to cache calculation caused by repeatedly evaluting
-the same tool by the same user with slightly different parameters.
-"""
-
-def __init__(self, trans):
-self.trans = trans
-self.current_user_roles = trans.get_current_user_roles()
-self.chrom_info = {}
-self.cached_collection_elements = {}
-
-def get_chrom_info(self, tool_id, input_dbkey):
-genome_builds = self.trans.app.genome_builds
-custom_build_hack_get_len_from_fasta_conversion = tool_id != "CONVERTER_fasta_to_len"
-if custom_build_hack_get_len_from_fasta_conversion and input_dbkey in self.chrom_info:
-return self.chrom_info[input_dbkey]
-
-chrom_info_pair = genome_builds.get_chrom_info(
-input_dbkey,
-trans=self.trans,
-custom_build_hack_get_len_from_fasta_conversion=custom_build_hack_get_len_from_fasta_conversion,
-)
-if custom_build_hack_get_len_from_fasta_conversion:
-self.chrom_info[input_dbkey] = chrom_info_pair
-
-return chrom_info_pair
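Note: ToolExecutionCache is not deleted, it is relocated to galaxy.tools.execution_helpers (see the new import block above). It amortizes the user-role lookup and per-dbkey chrom-info lookups across the many executions of a single mapped-over request. A usage sketch, assuming the relocated class keeps this interface:

from galaxy.tools.execution_helpers import ToolExecutionCache

cache = ToolExecutionCache(trans)  # snapshots current_user_roles once
# Repeated lookups for the same dbkey hit the in-memory dict, not genome_builds:
chrom_info_pair = cache.get_chrom_info("some_tool_id", "hg38")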
+OutputDatasetsT = Dict[str, "DatasetInstance"]
+ToolActionExecuteResult = Union[Tuple[Job, OutputDatasetsT, Optional[History]], Tuple[Job, OutputDatasetsT]]
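Note: the two aliases pin down the return contract: every action yields a Job plus a name-to-dataset mapping, and some actions also return the History that was used or created. Unpacking that either-shape result, with tool_action, tool, trans, and params assumed to be in scope:

# Handles both Tuple[Job, OutputDatasetsT] and Tuple[Job, OutputDatasetsT, Optional[History]]:
job, out_data, *rest = tool_action.execute(tool, trans, incoming=params)
history = rest[0] if rest else None  # present only when the action returns it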


class ToolAction:
@@ -89,14 +84,31 @@ class ToolAction:
"""

@abstractmethod
-def execute(self, tool, trans, incoming=None, set_output_hid=True, **kwargs):
-pass
+def execute(
+self,
+tool,
+trans,
+incoming: Optional[ToolParameterRequestInstanceT] = None,
+history: Optional[History] = None,
+job_params=None,
+rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
+execution_cache: Optional[ToolExecutionCache] = None,
+dataset_collection_elements: Optional[DatasetCollectionElementsSliceT] = DEFAULT_DATASET_COLLECTION_ELEMENTS,
+completed_job: Optional[Job] = None,
+collection_info: Optional[MatchingCollections] = None,
+job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
+preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
+set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
+flush_job: bool = True,
+skip: bool = False,
+) -> ToolActionExecuteResult:
+"""Perform target tool action."""


class DefaultToolAction(ToolAction):
"""Default tool action is to run an external command"""

-produces_real_jobs = True
+produces_real_jobs: bool = True

def _collect_input_datasets(
self,
@@ -389,21 +401,20 @@ def execute(
self,
tool,
trans,
-incoming=None,
-return_job=False,
-set_output_hid=True,
-history=None,
+incoming: Optional[ToolParameterRequestInstanceT] = None,
+history: Optional[History] = None,
job_params=None,
-rerun_remap_job_id=None,
-execution_cache=None,
+rerun_remap_job_id: Optional[int] = DEFAULT_RERUN_REMAP_JOB_ID,
+execution_cache: Optional[ToolExecutionCache] = None,
dataset_collection_elements=None,
-completed_job=None,
-collection_info=None,
-job_callback=None,
-preferred_object_store_id=None,
-flush_job=True,
-skip=False,
-):
+completed_job: Optional[Job] = None,
+collection_info: Optional[MatchingCollections] = None,
+job_callback: Optional[JobCallbackT] = DEFAULT_JOB_CALLBACK,
+preferred_object_store_id: Optional[str] = DEFAULT_PREFERRED_OBJECT_STORE_ID,
+set_output_hid: bool = DEFAULT_SET_OUTPUT_HID,
+flush_job: bool = True,
+skip: bool = False,
+) -> ToolActionExecuteResult:
"""
Executes a tool, creating job and tool outputs, associating them, and
submitting the job to the job queue. If history is not specified, use
Expand All @@ -424,6 +435,7 @@ def execute(
preserved_hdca_tags,
all_permissions,
) = self._collect_inputs(tool, trans, incoming, history, current_user_roles, collection_info)
+assert history # tell type system we've set history and it is no longer optional
# Build name for output datasets based on tool name and input names
on_text = self._get_on_text(inp_data)
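Note: the added assert history is for the type checker, not runtime safety: history enters as Optional[History], _collect_inputs resolves a default, and the assert narrows the type for the rest of the method. The same pattern in isolation (hypothetical helper, for illustration only):

from typing import Optional

def describe(history: Optional[History]) -> str:
    assert history  # mypy narrows Optional[History] to History past this line
    return f"history #{history.id}"  # no optional-member-access error here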

@@ -846,7 +858,7 @@ def _get_on_text(self, inp_data):

return on_text_for_names(input_names)

-def _new_job_for_session(self, trans, tool, history):
+def _new_job_for_session(self, trans, tool, history) -> Tuple[model.Job, Optional[model.GalaxySession]]:
job = trans.app.model.Job()
job.galaxy_version = trans.app.config.version_major
galaxy_session = None
@@ -1097,40 +1109,6 @@ def check_elements(elements):
self.out_collection_instances[name] = hdca


-def on_text_for_names(input_names):
-# input_names may contain duplicates... this is because the first value in
-# multiple input dataset parameters will appear twice once as param_name
-# and once as param_name1.
-unique_names = []
-for name in input_names:
-if name not in unique_names:
-unique_names.append(name)
-input_names = unique_names
-
-# Build name for output datasets based on tool name and input names
-if len(input_names) == 0:
-on_text = ""
-elif len(input_names) == 1:
-on_text = input_names[0]
-elif len(input_names) == 2:
-on_text = "{} and {}".format(*input_names)
-elif len(input_names) == 3:
-on_text = "{}, {}, and {}".format(*input_names)
-else:
-on_text = "{}, {}, and others".format(*input_names[:2])
-return on_text
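Note: on_text_for_names moves to galaxy.tools.execution_helpers rather than being dropped. Its behavior, read straight off the deleted lines (after de-duplicating the incoming names):

on_text_for_names([])                    # ""
on_text_for_names(["a"])                 # "a"
on_text_for_names(["a", "b"])            # "a and b"
on_text_for_names(["a", "b", "c"])       # "a, b, and c"
on_text_for_names(["a", "b", "c", "d"])  # "a, b, and others"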


-def filter_output(tool, output, incoming):
-for filter in output.filters:
-try:
-if not eval(filter.text.strip(), globals(), incoming):
-return True # do not create this dataset
-except Exception as e:
-log.debug(f"Tool {tool.id} output {output.name}: dataset output filter ({filter.text}) failed: {e}")
-return False
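Note: filter_output likewise moves to execution_helpers. It evaluates each output <filter> expression with the tool's parameter dict as the local namespace: a truthy result keeps the dataset, a falsy one suppresses it, and an expression that raises is logged at debug level and skipped. Illustrated with a hypothetical parameter dict:

incoming = {"out_format": "bam"}
eval("out_format == 'bam'", globals(), incoming)  # True: dataset is created
eval("out_format == 'sam'", globals(), incoming)  # False: filter_output returns True, dataset suppressed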


def get_ext_or_implicit_ext(hda):
if hda.implicitly_converted_parent_datasets:
# implicitly_converted_parent_datasets is a list of ImplicitlyConvertedDatasetAssociation