Skip to content

Commit

Permalink
Merge branch 'main' into detectron_model_health
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 authored Jul 17, 2024
2 parents a311878 + f391ebf commit 1667f2c
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
fi
- name: Set up Node.js
uses: actions/[email protected].2
uses: actions/[email protected].3
with:
node-version: 18
cache: yarn
Expand Down
2 changes: 1 addition & 1 deletion cyclops/data/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def normalize_events(
data = data.infer_objects()
data[event_name_col] = normalize_names(data[event_name_col])

if event_value_col and data[event_value_col].dtypes == object:
if event_value_col and data[event_value_col].dtypes == object: # noqa: E721
data[event_value_col] = normalize_values(data[event_value_col])
log_df_counts(data, event_name_col, "Normalized values...", columns=True)

Expand Down
2 changes: 1 addition & 1 deletion cyclops/data/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ def is_datetime(
return False
if isinstance(value, (list, np.ndarray)):
return all((is_datetime(v) for v in value))
if isinstance(value, (datetime.datetime, np.datetime64)):
if isinstance(value, (datetime.datetime, np.datetime64)): # noqa: SIM103
return True

return False
Expand Down
8 changes: 6 additions & 2 deletions cyclops/evaluate/metrics/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,13 @@ def __new__( # type: ignore # mypy expects a subclass of AveragePrecision
pos_label=pos_label,
)
if task == "multiclass":
NotImplementedError("Multiclass average precision is not implemented.")
raise NotImplementedError(
"Multiclass average precision is not implemented."
)
if task == "multilabel":
NotImplementedError("Multilabel average precision is not implemented.")
raise NotImplementedError(
"Multilabel average precision is not implemented."
)

raise ValueError(
"Expected argument `task` to be either 'binary', 'multiclass' or "
Expand Down
4 changes: 2 additions & 2 deletions cyclops/evaluate/metrics/functional/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ def average_precision(
if task == "binary":
return binary_average_precision(target, preds, thresholds, pos_label)
if task == "multiclass":
NotImplementedError("Multiclass average precision is not implemented.")
raise NotImplementedError("Multiclass average precision is not implemented.")
if task == "multilabel":
NotImplementedError("Multilabel average precision is not implemented.")
raise NotImplementedError("Multilabel average precision is not implemented.")

raise ValueError(
"Expected argument `task` to be either 'binary', 'multiclass' or "
Expand Down
4 changes: 2 additions & 2 deletions cyclops/report/model_card/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def update_field(self, name: str, value: Any) -> None:
raise ValueError(f"Field {name} does not exist.")

field = self.__fields__[name]
if field.default_factory == list or isinstance(getattr(self, name), list):
if field.default_factory == list or isinstance(getattr(self, name), list): # noqa: E721
# NOTE: pydantic does not trigger validation when appending to a list,
# but if `validate_assignment` is set to `True`, then validation will
# be triggered when the list is assigned to the field.
Expand Down Expand Up @@ -173,6 +173,6 @@ def add_field(self, name: str, value: Any) -> None:
model_config=BaseModelCardField.Config,
default_factory=default_factory,
field_info=FieldInfo(unique_items=True)
if default_factory == list
if default_factory == list # noqa: E721
else None,
)
2 changes: 1 addition & 1 deletion tests/cyclops/evaluate/metrics/experimental/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_add_state_factory():
# default_factory is 'list'
metric.add_state_default_factory("b", list) # type: ignore
assert (
metric._default_factories.get("b") == list
metric._default_factories.get("b") == list # noqa: E721
), "Default factory should be 'list'."

# dist_reduce_fn is "sum"
Expand Down

0 comments on commit 1667f2c

Please sign in to comment.