Commit
Merge branch 'distributed-ndcg' of https://github.com/ili0820/ignite into distributed-ndcg
ili0820 committed Sep 20, 2023
2 parents 3920b2e + 3964133 commit 6a4bc7f
Showing 5 changed files with 51 additions and 1 deletion.
ignite/contrib/handlers/tqdm_logger.py (2 changes: 1 addition & 1 deletion)
@@ -90,7 +90,7 @@ class ProgressBar(BaseLogger):
     Note:
-        When adding attaching the progress bar to an engine, it is recommend that you replace
+        When attaching the progress bar to an engine, it is recommended that you replace
         every print operation in the engine's handlers triggered every iteration with
         ``pbar.log_message`` to guarantee the correct format of the stdout.
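For context, here is what that recommendation looks like in practice. This is a minimal sketch with a placeholder train_step: any print() inside a per-iteration handler is replaced with pbar.log_message so the progress bar's output stays intact.

from ignite.engine import Engine, Events
from ignite.contrib.handlers import ProgressBar

def train_step(engine, batch):
    return {"loss": 0.0}  # placeholder step, not from this commit

trainer = Engine(train_step)
pbar = ProgressBar()
pbar.attach(trainer)

@trainer.on(Events.ITERATION_COMPLETED)
def log_iteration(engine):
    # using pbar.log_message instead of print() keeps the bar formatting intact
    pbar.log_message(f"iteration {engine.state.iteration} finished")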
ignite/handlers/checkpoint.py (8 changes: 8 additions & 0 deletions)
@@ -669,10 +669,18 @@ def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, *
         If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
         `DataParallel`_, method ``load_state_dict`` will applied to their internal wrapped model (``obj.module``).
+        Note:
+            This method works only when the ``save_handler`` is of types string,
+            :class:`~pathlib.Path` or :class:`~ignite.handlers.checkpoint.DiskSaver`.
         .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
             torch.nn.parallel.DistributedDataParallel.html
         .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
         """
+        if not isinstance(self.save_handler, DiskSaver):
+            raise AttributeError(
+                f"Checkpoint's `save_handler` should be of type `DiskSaver`, given {type(self.save_handler)}"
+            )
+
         global_step = filename_components.get("global_step", None)
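To illustrate the new restriction, a minimal sketch follows; the model and the checkpoint directory are hypothetical. reload_objects is allowed when save_handler resolves to a DiskSaver (a str or pathlib.Path also qualifies, since those are backed by DiskSaver), and now raises AttributeError otherwise.

import torch.nn as nn
from ignite.handlers import Checkpoint

model = nn.Linear(4, 2)  # hypothetical model
to_save = {"model": model}

# str/Path save_handler is backed by DiskSaver, so reload_objects is permitted
ckpt = Checkpoint(to_save, save_handler="/tmp/checkpoints")
# ... after a checkpoint file has actually been written to disk ...
# ckpt.reload_objects(to_load=to_save, global_step=5)

# a plain callable save_handler is not a DiskSaver: reload_objects now raises
ckpt_cb = Checkpoint(to_save, save_handler=lambda checkpoint, filename: None)
# ckpt_cb.reload_objects(to_load=to_save)  # AttributeError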
ignite/metrics/ssim.py (7 changes: 7 additions & 0 deletions)
@@ -99,6 +99,7 @@ def __init__(

         super(SSIM, self).__init__(output_transform=output_transform, device=device)
         self.gaussian = gaussian
+        self.data_range = data_range
         self.c1 = (k1 * data_range) ** 2
         self.c2 = (k2 * data_range) ** 2
         self.pad_h = (self.kernel_size[0] - 1) // 2
@@ -157,6 +158,12 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
)

# converts potential integer tensor to fp
if not y.is_floating_point():
y = y.float()
if not y_pred.is_floating_point():
y_pred = y_pred.float()

nb_channel = y_pred.size(1)
if self._kernel is None or self._kernel.shape[0] != nb_channel:
self._kernel = self._kernel_2d.expand(nb_channel, 1, -1, -1)
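The effect of the update() change, as a minimal sketch: uint8 images can now be fed to SSIM directly, because non-floating-point tensors are cast to float internally before the SSIM computation; data_range=255 matches the uint8 value range. The shapes here are arbitrary.

import torch
from ignite.metrics import SSIM

metric = SSIM(data_range=255)  # data_range chosen to match uint8 inputs
y_pred = torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8)
y = torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8)

metric.update((y_pred, y))  # previously required floating-point tensors
score = metric.compute()    # a Python float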
tests/ignite/handlers/test_checkpoint.py (4 changes: 4 additions & 0 deletions)
@@ -78,6 +78,10 @@ def __len__(self):
     with pytest.raises(TypeError, match="If `include_self` is True, then `to_save` must be mutable"):
         Checkpoint(ImmutableMapping(), lambda x: x, include_self=True)

+    checkpoint = Checkpoint(to_save, lambda x: x)
+    with pytest.raises(AttributeError, match="Checkpoint's `save_handler` should be of type `DiskSaver`"):
+        checkpoint.reload_objects(to_save)
+

 def test_save_handler_as_str(dirname):
     to_save = {"model": model}
tests/ignite/metrics/test_ssim.py (31 changes: 31 additions & 0 deletions)
@@ -222,6 +222,37 @@ def test_cuda_ssim_dtypes(available_device, dtype, precision):
     compare_ssim_ignite_skiimg(y_pred, y, available_device, precision)


+@pytest.mark.parametrize(
+    "shape, kernel_size, gaussian, use_sample_covariance",
+    [[(8, 3, 224, 224), 7, False, True], [(12, 3, 28, 28), 11, True, False]],
+)
+def test_ssim_uint8(available_device, shape, kernel_size, gaussian, use_sample_covariance):
+    y_pred = torch.randint(0, 255, shape, device=available_device, dtype=torch.uint8)
+    y = (y_pred * 0.8).to(dtype=torch.uint8)
+
+    sigma = 1.5
+    data_range = 255
+    ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device)
+    ssim.update((y_pred, y))
+    ignite_ssim = ssim.compute()
+
+    skimg_pred = y_pred.cpu().numpy()
+    skimg_y = (skimg_pred * 0.8).astype(np.uint8)
+    skimg_ssim = ski_ssim(
+        skimg_pred,
+        skimg_y,
+        win_size=kernel_size,
+        sigma=sigma,
+        channel_axis=1,
+        gaussian_weights=gaussian,
+        data_range=data_range,
+        use_sample_covariance=use_sample_covariance,
+    )
+
+    assert isinstance(ignite_ssim, float)
+    assert np.allclose(ignite_ssim, skimg_ssim, atol=1e-5)
+
+
 @pytest.mark.parametrize("metric_device", ["cpu", "process_device"])
 def test_distrib_integration(distributed, metric_device):
     from ignite.engine import Engine
