diff --git a/src/fasthep_flow/config.py b/src/fasthep_flow/config.py
index 784b4d3..ddde62f 100644
--- a/src/fasthep_flow/config.py
+++ b/src/fasthep_flow/config.py
@@ -32,8 +32,7 @@ def validate_type(cls, value: str) -> str:
         """Validate the type field
         Any specified type needs to be a Python class that can be imported"""
         # Split the string to separate the module from the class name
-        if value in ALIASES:
-            value = ALIASES[value]
+        value = ALIASES.get(value, value)
         module_path, class_name = value.rsplit(".", 1)
         try:
             # Import the module
diff --git a/src/fasthep_flow/orchestration.py b/src/fasthep_flow/orchestration.py
index 9243953..4d5eda8 100644
--- a/src/fasthep_flow/orchestration.py
+++ b/src/fasthep_flow/orchestration.py
@@ -21,6 +21,7 @@
 
 
 def create_dask_cluster() -> Any:
+    """Create a Dask cluster - to be used with Hamilton Dask adapter."""
     cluster = LocalCluster()
     client = Client(cluster)
     logger.info(client.cluster)
@@ -34,6 +35,7 @@ def create_dask_cluster() -> Any:
 
 
 def create_dask_adapter(client_type: str) -> Any:
+    """Create a Hamilton adapter for Dask execution"""
     from hamilton.plugins import h_dask
 
     client = DASK_CLIENTS[client_type]()
@@ -48,6 +50,7 @@ def create_dask_adapter(client_type: str) -> Any:
 
 
 def create_local_adapter() -> Any:
+    """Create a Hamilton adapter for local execution."""
     return base.SimplePythonGraphAdapter(base.DictResult())
 
 
@@ -62,7 +65,7 @@ def workflow_to_hamilton_dag(
     output_path: str,
     # method: str = "local"
 ) -> Any:
-    """Convert a workflow into a Hamilton flow."""
+    """Convert a workflow into a Hamilton DAG."""
     task_functions = load_tasks_module(workflow)
     # adapter = PRECONFIGURED_ADAPTERS[method]()
     cache_dir = Path(output_path) / ".hamilton_cache"
diff --git a/src/fasthep_flow/workflow.py b/src/fasthep_flow/workflow.py
index b4db05e..6bef29e 100644
--- a/src/fasthep_flow/workflow.py
+++ b/src/fasthep_flow/workflow.py
@@ -54,6 +54,7 @@ def {task_name}() -> dict[str, Any]:
 
 
 def get_task_source(obj: Any, task_name: str) -> str:
+    """Retrieve the source code of a task object and return a function definition."""
     # Capture the object definition
     obj_attrs = {}
 
@@ -72,11 +73,21 @@ def get_task_source(obj: Any, task_name: str) -> str:
 
 
 def get_config_hash(config_file: Path) -> str:
+    """Reads the config file and returns a shortened hash."""
     with config_file.open("rb") as f:
         return hashlib.file_digest(f, "sha256").hexdigest()[:8]
 
 
 def create_save_path(base_path: Path, workflow_name: str, config_hash: str) -> Path:
+    """
+    Creates a save path for the workflow and returns the generated path.
+
+    @param base_path: Base path for the save location.
+    @param workflow_name: Name of the workflow.
+    @param config_hash: Hash of the configuration file.
+
+    returns: Path to the save location.
+    """
     date = datetime.now().strftime("%Y.%m.%d")
     # TODO: instead of date, create a "touched" file that is updated every time the workflow is saved
     path = Path(f"{base_path}/{workflow_name}/{date}/{config_hash}/").resolve()
@@ -163,6 +174,10 @@ def save(self, base_path: Path = Path("~/.fasthep/flow")) -> str:
 
     @staticmethod
     def load(path: Path | str) -> Workflow:
+        """
+        Load a workflow from a file.
+        @param path: Path to the directory containing the workflow file.
+        """
         path = Path(path)
         workflow_file = path / "workflow.pkl"
         with workflow_file.open("rb") as f:
@@ -172,6 +187,7 @@ def load(path: Path | str) -> Workflow:
 
 
 def load_tasks_module(workflow: Workflow) -> ModuleType:
+    """Load tasks from a tasks.py file in the workflow save path."""
    task_location = workflow.save_path
    task_spec = importlib.machinery.PathFinder().find_spec("tasks", [task_location])
    if task_spec is None: