Skip to content

Commit

Permalink
Use if check on torch.double usages for MPS backend
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Sep 3, 2024
1 parent cb6a328 commit 248fe89
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
12 changes: 7 additions & 5 deletions ignite/metrics/mean_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
).where(rec_thresh_indices != recall.size(-1), 0)
recall = rec_thresholds
recall_differential = recall.diff(
dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=torch.double)
dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=recall.dtype)
)
return torch.sum(recall_differential * precision, dim=-1)

Expand Down Expand Up @@ -327,7 +327,9 @@ def _compute_recall_and_precision(
`(recall, precision)`
"""
indices = torch.argsort(y_pred, stable=True, descending=True)
tp_summation = y_true[indices].cumsum(dim=0).double()
tp_summation = y_true[indices].cumsum(dim=0)
if tp_summation.device != torch.device("mps"):
tp_summation = tp_summation.double()

# Adopted from Scikit-learn's implementation
unique_scores_indices = torch.nonzero(
Expand Down Expand Up @@ -360,16 +362,16 @@ def compute(self) -> Union[torch.Tensor, float]:
torch.long if self._type == "multiclass" else torch.uint8,
self._device,
)

y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), torch.double, self._device)
fp_precision = torch.double if self._device != torch.device("mps") else torch.float32
y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), fp_precision, self._device)

if self._type == "multiclass":
y_true = to_onehot(y_true, num_classes=num_classes).T
if self.class_mean == "micro":
y_true = y_true.reshape(1, -1)
y_pred = y_pred.view(1, -1)
y_true_positive_count = y_true.sum(dim=-1)
average_precisions = torch.zeros_like(y_true_positive_count, device=self._device, dtype=torch.double)
average_precisions = torch.zeros_like(y_true_positive_count, device=self._device, dtype=fp_precision)
for cls in range(y_true_positive_count.size(0)):
recall, precision = self._compute_recall_and_precision(y_true[cls], y_pred[cls], y_true_positive_count[cls])
average_precisions[cls] = self._compute_average_precision(recall, precision)
Expand Down
24 changes: 18 additions & 6 deletions ignite/metrics/vision/object_detection_average_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,19 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
except ImportError:
raise ModuleNotFoundError("This metric requires torchvision to be installed.")

precision = torch.double if not torch.device(device) != torch.device("mps") else torch.float32

if iou_thresholds is None:
iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double)
iou_thresholds = torch.linspace(0.5, 0.95, 10, device=device, dtype=precision)

self._iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds")
self._iou_thresholds = self._iou_thresholds.to(device=device, dtype=precision)

if rec_thresholds is None:
rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=torch.double)
rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=precision)

self._rec_thresholds = self._setup_thresholds(rec_thresholds, "rec_thresholds")
self._rec_thresholds = self._rec_thresholds.to(device=device, dtype=precision)

self._num_classes = num_classes
self._area_range = area_range
Expand Down Expand Up @@ -204,9 +210,14 @@ def _compute_recall_and_precision(
"""
indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
tp = TP[..., indices]
tp_summation = tp.cumsum(dim=-1).double()
tp_summation = tp.cumsum(dim=-1)
if tp_summation.device != torch.device("mps"):
tp_summation = tp_summation.double()

fp = FP[..., indices]
fp_summation = fp.cumsum(dim=-1).double()
fp_summation = fp.cumsum(dim=-1)
if fp_summation.device != torch.device("mps"):
fp_summation = fp_summation.double()

recall = tp_summation / y_true_count
predicted_positive = tp_summation + fp_summation
Expand Down Expand Up @@ -342,12 +353,13 @@ def _compute(self) -> torch.Tensor:
pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.long, self._device)
TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device)
FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device)
scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), torch.double, self._device)
fp_precision = torch.double if self._device != torch.device("mps") else torch.float32
scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), fp_precision, self._device)

average_precisions_recalls = -torch.ones(
(2, self._num_classes, len(self._iou_thresholds)),
device=self._device,
dtype=torch.double,
dtype=fp_precision,
)
for cls in range(self._num_classes):
if self._y_true_count[cls] == 0:
Expand Down

0 comments on commit 248fe89

Please sign in to comment.