Skip to content

Commit

Permalink
fixing custom metrics usage
Browse files Browse the repository at this point in the history
  • Loading branch information
kylemann16 committed Sep 13, 2024
1 parent 987c413 commit 82e3e48
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 27 deletions.
50 changes: 26 additions & 24 deletions src/silvimetric/cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def convert(self, value, param, ctx) -> list[Metric]:
metrics: set[Metric] = set()
for val in parsed_values:
if '.py' in val:
# user imported metrics from external file
try:
import importlib.util
import os
Expand All @@ -77,31 +78,32 @@ def convert(self, value, param, ctx) -> list[Metric]:
if not isinstance(m, Metric):
self.fail(f"Invalid Metric supplied: {m}")

metrics.update(list(user_metrics.metrics().values()))

try:
if val == 'stats':
metrics.update(list(statistics.values()))
elif val == 'p_moments':
metrics.update(list(product_moments.values()))
elif val == 'l_moments':
metrics.update(list(l_moments.values()))
elif val == 'percentiles':
metrics.update(list(percentiles.values()))
elif val == 'aad':
metrics.update(list(aad.aad.values()))
elif val == 'grid_metrics':
metrics.update(list(grid_metrics.values()))
elif val == 'all':
metrics.update(list(all_metrics.values()))
else:
m = all_metrics[val]
if isinstance(m, Metric):
metrics.add(m)
metrics.update(list(user_metrics.metrics()))
else:
# SilviMetric defined metrics
try:
if val == 'stats':
metrics.update(list(statistics.values()))
elif val == 'p_moments':
metrics.update(list(product_moments.values()))
elif val == 'l_moments':
metrics.update(list(l_moments.values()))
elif val == 'percentiles':
metrics.update(list(percentiles.values()))
elif val == 'aad':
metrics.update(list(aad.aad.values()))
elif val == 'grid_metrics':
metrics.update(list(grid_metrics.values()))
elif val == 'all':
metrics.update(list(all_metrics.values()))
else:
metrics.udpate(list(m))
except Exception as e:
self.fail(f"{val!r} is not available in Metrics", param, ctx)
m = all_metrics[val]
if isinstance(m, Metric):
metrics.add(m)
else:
metrics.udpate(list(m))
except Exception as e:
self.fail(f"{val!r} is not available in Metrics", param, ctx)
return list(metrics)

def dask_handle(dasktype: str, scheduler: str, workers: int, threads: int,
Expand Down
8 changes: 5 additions & 3 deletions src/silvimetric/resources/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ def __init__(self, name: str, dtype) -> None:
"""SilviMetric representation of array of numpy dtype"""
if isinstance(dtype, AttributeDtype):
self.dtype = dtype
elif isinstance(dtype, np.dtype):
self.dtype = AttributeDtype(subtype=dtype)
else:
raise AttributeError(f"Invalid dtype passed to Attribute: {dtype}")
# AttributeDtype takes any dtype that can be passed to np.dtype
try:
self.dtype = AttributeDtype(subtype=dtype)
except Exception as e:
raise AttributeError(f"Invalid dtype passed to Attribute: {dtype}") from e

def make_array(self, data, copy=False):
return AttributeArray(data=data, copy=copy)
Expand Down

0 comments on commit 82e3e48

Please sign in to comment.