From 82e3e487d5203a9218dfbae6e368a5ae7a4f8dc7 Mon Sep 17 00:00:00 2001 From: kylemann16 Date: Fri, 13 Sep 2024 09:51:09 -0500 Subject: [PATCH] fixing custom metrics usage --- src/silvimetric/cli/common.py | 50 +++++++++++++------------- src/silvimetric/resources/attribute.py | 8 +++-- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/src/silvimetric/cli/common.py b/src/silvimetric/cli/common.py index 80abc40..a21cf62 100644 --- a/src/silvimetric/cli/common.py +++ b/src/silvimetric/cli/common.py @@ -54,6 +54,7 @@ def convert(self, value, param, ctx) -> list[Metric]: metrics: set[Metric] = set() for val in parsed_values: if '.py' in val: + # user imported metrics from external file try: import importlib.util import os @@ -77,31 +78,32 @@ def convert(self, value, param, ctx) -> list[Metric]: if not isinstance(m, Metric): self.fail(f"Invalid Metric supplied: {m}") - metrics.update(list(user_metrics.metrics().values())) - - try: - if val == 'stats': - metrics.update(list(statistics.values())) - elif val == 'p_moments': - metrics.update(list(product_moments.values())) - elif val == 'l_moments': - metrics.update(list(l_moments.values())) - elif val == 'percentiles': - metrics.update(list(percentiles.values())) - elif val == 'aad': - metrics.update(list(aad.aad.values())) - elif val == 'grid_metrics': - metrics.update(list(grid_metrics.values())) - elif val == 'all': - metrics.update(list(all_metrics.values())) - else: - m = all_metrics[val] - if isinstance(m, Metric): - metrics.add(m) + metrics.update(list(user_metrics.metrics())) + else: + # SilviMetric defined metrics + try: + if val == 'stats': + metrics.update(list(statistics.values())) + elif val == 'p_moments': + metrics.update(list(product_moments.values())) + elif val == 'l_moments': + metrics.update(list(l_moments.values())) + elif val == 'percentiles': + metrics.update(list(percentiles.values())) + elif val == 'aad': + metrics.update(list(aad.aad.values())) + elif val == 'grid_metrics': + metrics.update(list(grid_metrics.values())) + elif val == 'all': + metrics.update(list(all_metrics.values())) else: - metrics.udpate(list(m)) - except Exception as e: - self.fail(f"{val!r} is not available in Metrics", param, ctx) + m = all_metrics[val] + if isinstance(m, Metric): + metrics.add(m) + else: + metrics.udpate(list(m)) + except Exception as e: + self.fail(f"{val!r} is not available in Metrics", param, ctx) return list(metrics) def dask_handle(dasktype: str, scheduler: str, workers: int, threads: int, diff --git a/src/silvimetric/resources/attribute.py b/src/silvimetric/resources/attribute.py index 4780f2d..05dff52 100644 --- a/src/silvimetric/resources/attribute.py +++ b/src/silvimetric/resources/attribute.py @@ -15,10 +15,12 @@ def __init__(self, name: str, dtype) -> None: """SilviMetric representation of array of numpy dtype""" if isinstance(dtype, AttributeDtype): self.dtype = dtype - elif isinstance(dtype, np.dtype): - self.dtype = AttributeDtype(subtype=dtype) else: - raise AttributeError(f"Invalid dtype passed to Attribute: {dtype}") + # AttributeDtype takes any dtype that can be passed to np.dtype + try: + self.dtype = AttributeDtype(subtype=dtype) + except Exception as e: + raise AttributeError(f"Invalid dtype passed to Attribute: {dtype}") from e def make_array(self, data, copy=False): return AttributeArray(data=data, copy=copy)