Skip to content

Commit

Permalink
[Enhancement] Change the order of condition to make fx wok (#2883)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Aug 3, 2023
1 parent f64d485 commit 9036241
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
14 changes: 7 additions & 7 deletions mmcv/cnn/bricks/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def backward(ctx, grad: torch.Tensor) -> tuple:
class Conv2d(nn.Conv2d):

def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride, self.dilation):
Expand All @@ -62,7 +62,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class Conv3d(nn.Conv3d):

def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride, self.dilation):
Expand All @@ -84,7 +84,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class ConvTranspose2d(nn.ConvTranspose2d):

def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
self.padding, self.stride,
Expand All @@ -106,7 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class ConvTranspose3d(nn.ConvTranspose3d):

def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
if obsolete_torch_version(TORCH_VERSION, (1, 4)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
self.padding, self.stride,
Expand All @@ -127,7 +127,7 @@ class MaxPool2d(nn.MaxPool2d):

def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
_pair(self.padding), _pair(self.stride),
Expand All @@ -145,7 +145,7 @@ class MaxPool3d(nn.MaxPool3d):

def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
if obsolete_torch_version(TORCH_VERSION, (1, 9)) and x.numel() == 0:
out_shape = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
_triple(self.padding),
Expand All @@ -164,7 +164,7 @@ class Linear(torch.nn.Linear):

def forward(self, x: torch.Tensor) -> torch.Tensor:
# empty tensor forward of Linear layer is supported in Pytorch 1.6
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
if obsolete_torch_version(TORCH_VERSION, (1, 5)) and x.numel() == 0:
out_shape = [x.shape[0], self.out_features]
empty = NewEmptyTensorOp.apply(x, out_shape)
if self.training:
Expand Down
20 changes: 20 additions & 0 deletions tests/test_cnn/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pytest
import torch
import torch.nn as nn
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from mmcv.cnn.bricks import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
Linear, MaxPool2d, MaxPool3d)
Expand Down Expand Up @@ -374,3 +376,21 @@ def test_nn_op_forward_called():
wrapper = Linear(3, 3)
wrapper(x_normal)
nn_module_forward.assert_called_with(x_normal)


@pytest.mark.skipif(
digit_version(TORCH_VERSION) < digit_version('1.10'),
reason='MaxPool2d and MaxPool3d will fail fx for torch<=1.9')
def test_fx_compatibility():
from torch import fx

# ensure the fx trace can pass the network
for Net in (MaxPool2d, MaxPool3d):
net = Net(1)
gm_module = fx.symbolic_trace(net) # noqa: F841
for Net in (Linear, ):
net = Net(1, 1)
gm_module = fx.symbolic_trace(net) # noqa: F841
for Net in (Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d):
net = Net(1, 1, 1)
gm_module = fx.symbolic_trace(net) # noqa: F841

0 comments on commit 9036241

Please sign in to comment.