From 9532d1e69caae435c538a7e2a9b6a495083da88d Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 2 Oct 2024 15:23:48 +0200 Subject: [PATCH] Fix pytorch versions CI failures --- .github/workflows/pytorch-version-tests.yml | 14 ++++++++------ ignite/metrics/hsic.py | 1 + 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pytorch-version-tests.yml b/.github/workflows/pytorch-version-tests.yml index f268669158e..27fe019664f 100644 --- a/.github/workflows/pytorch-version-tests.yml +++ b/.github/workflows/pytorch-version-tests.yml @@ -17,13 +17,8 @@ jobs: matrix: python-version: [3.8, 3.9, "3.10"] pytorch-version: - [2.3.1, 2.2.2, 2.1.2, 2.0.1, 1.13.1, 1.12.1, 1.10.0, 1.8.1, 1.5.1] + [2.3.1, 2.2.2, 2.1.2, 2.0.1, 1.13.1, 1.12.1, 1.10.0, 1.8.1] exclude: - - pytorch-version: 1.5.1 - python-version: 3.9 - - pytorch-version: 1.5.1 - python-version: "3.10" - # disabling python 3.9 support with PyTorch 1.7.1 and 1.8.1, to stop repeated pytorch-version test fail. # https://github.com/pytorch/ignite/issues/2383 - pytorch-version: 1.8.1 @@ -72,6 +67,13 @@ jobs: shell: bash -l {0} run: | conda install pytorch=${{ matrix.pytorch-version }} torchvision cpuonly python=${{ matrix.python-version }} -c pytorch + + # We should install numpy<2.0 for pytorch<2.3 + numpy_one_pth_version=$(python -c "import torch; print(float('.'.join(torch.__version__.split('.')[:2])) < 2.3)") + if [ "${numpy_one_pth_version}" == "True" ]; then + pip install -U "numpy<2.0" + fi + pip install -r requirements-dev.txt python setup.py install diff --git a/ignite/metrics/hsic.py b/ignite/metrics/hsic.py index a35d47f258b..2967db19e4f 100644 --- a/ignite/metrics/hsic.py +++ b/ignite/metrics/hsic.py @@ -134,6 +134,7 @@ def update(self, output: Sequence[Tensor]) -> None: vx: Union[Tensor, float] if self.sigma_x < 0: + # vx = torch.quantile(dxx, 0.5) vx = torch.quantile(dxx, 0.5) else: vx = self.sigma_x**2