[functorch] test: try using reference_inputs in vmap tests #91355

Closed
2 changes: 2 additions & 0 deletions test/functorch/common_utils.py
@@ -81,6 +81,8 @@ def is_valid_inplace_sample_input(sample_input, op, inplace_variant):
return False
if sample_input.broadcasts_input:
return False
if not isinstance(sample_input.input, torch.Tensor):
return False

# Check if input's dtype matches the output's dtype
args = (sample_input.input,) + sample_input.args
117 changes: 94 additions & 23 deletions test/functorch/test_vmap.py
@@ -19,7 +19,7 @@
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCUDAIfNoMagma
skipCUDAIfNoMagma, OpDTypes
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_utils import (
parametrize,
@@ -3361,11 +3361,45 @@ def test():
vmap(op, in_dims)(*args, **kwargs)

# Sample inputs check
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
sample_inputs_op = {
# Take too long
"special.chebyshev_polynomial_t",
Collaborator Author:
These ops already have a skip for taking an unreasonably long time with reference inputs.

E.g.:

BinaryUfuncInfo(
    "special.chebyshev_polynomial_t",
    dtypes=all_types_and(torch.bool),
    promotes_int_to_float=True,
    skips=(
        DecorateInfo(unittest.skip("Skipped!"), "TestCudaFuserOpInfo"),
        DecorateInfo(unittest.skip("Skipped!"), "TestNNCOpInfo"),
        DecorateInfo(
            unittest.skip("testing takes an unreasonably long time, #79528"),
            "TestCommon",
            "test_compare_cpu",
        ),
    ),
    supports_one_python_scalar=True,
    supports_autograd=False,
),

"special.chebyshev_polynomial_u",
"special.chebyshev_polynomial_v",
"special.chebyshev_polynomial_w",
"special.hermite_polynomial_he",
"special.laguerre_polynomial_l",
"special.legendre_polynomial_p",
"special.shifted_chebyshev_polynomial_t",
"special.shifted_chebyshev_polynomial_u",
"special.shifted_chebyshev_polynomial_v",
"special.shifted_chebyshev_polynomial_w",
# Leads to Illegal Memory Access on CUDA
# locally
"fft.fft",
"fft.fft2",
"fft.fftn",
"fft.hfft",
"fft.hfft2",
"fft.hfftn",
"fft.ifft",
"fft.ifft2",
"fft.ifftn",
"fft.ihfft",
"fft.ihfft2",
"fft.ihfftn",
}
if op.name in sample_inputs_op:
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
else:
sample_inputs_itr = op.reference_inputs(device, dtype, requires_grad=False)
aliases, inplace_aliases = discover_variants(op)
check_shape_only = op.name in ('empty_like', 'new_empty')
for sample_input in sample_inputs_itr:
args = (sample_input.input,) + sample_input.args
if not any(map(lambda arg: isinstance(arg, torch.Tensor), args)):
# At least one tensor is required for vmap.
continue
kwargs = sample_input.kwargs
is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs)
for args, in_dims, _ in generate_vmap_inputs(
@@ -3420,16 +3454,6 @@ def test():
xfail('nn.functional.fractional_max_pool2d'), # randomness
xfail('pca_lowrank', ''), # random operation
xfail('svd_lowrank', ''), # random operation
xfail('linspace', ''), # test runner can't handle factory functions
xfail('arange', ''), # test runner can't handle factory functions
xfail('logspace', ''), # test runner can't handle factory functions
xfail('scalar_tensor'), # test runner can't handle factory functions
xfail('empty', ''), # test runner can't handle factory functions
xfail('ones', ''), # test runner can't handle factory functions
xfail('zeros', ''), # test runner can't handle factory functions
xfail('full', ''), # test runner can't handle factory functions
xfail('eye', ''), # non-tensor input
xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops
xfail('sparse.sampled_addmm'), # sparse
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable autograd.Function
skip('_softmax_backward_data'),
@@ -3449,7 +3473,6 @@ def test():
xfail('nn.functional.gaussian_nll_loss'), # data-dependent control flow error
xfail('nn.functional.embedding_bag'), # embedding renorm vmap inplace incompatible
xfail('__rpow__'), # https://github.com/pytorch/functorch/issues/617
xfail('column_stack', ''), # Batching rule not implemented for aten::column_stack
xfail('narrow'), # Batching rule not implemented for aten::narrow.Tensor

# required rank 4 tensor to use channels_last format
@@ -3473,11 +3496,34 @@ def test():
xfail('jiterator_unary', device_type='cuda'), # NYI: querying is_contiguous inside of vmap
xfail('jiterator_2inputs_2outputs', device_type='cuda'), # NYI: querying is_contiguous inside of vmap
# ---------------------------------------------------------------------

# TypeError: expected Tensor as element 0 in argument 0, but got NotImplementedType
xfail('__rsub__'),
# RuntimeError: Batching rule not implemented for aten::moveaxis.int;
# the fallback path doesn't work on out= or view ops.
xfail('movedim'),
# RuntimeError: NYI: querying is_contiguous inside of vmap for
# memory_format other than torch.contiguous_format
xfail('contiguous'),
# RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only supported
# with memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
xfail('clone'),
# RuntimeError: When vmap-ing torch.nn.functional.one_hot,
# please provide an explicit positive num_classes argument.
xfail('nn.functional.one_hot'),
Comment on lines +3492 to +3504
Contributor:
Normally I'd feel bad about adding these xfails, but we do have manual tests for contiguous, clone, one_hot, and sub in the codebase; and movedim is tested just by virtue of being a part of the vmap implementation.
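A minimal sketch (mine, not part of the PR) of the kind of manual coverage referred to here: the batched call works once the workaround named in the error message is applied, e.g. one_hot with an explicit num_classes. The functorch.vmap import is an assumption about the API in use at the time.

```python
# Hedged illustration, not code from this PR.
import torch
from functorch import vmap  # assumed functorch-era import; torch.vmap on newer builds

idx = torch.randint(0, 5, (3, 7))  # a batch of 3 index tensors of shape (7,)

# The OpInfo sample omits num_classes, which the vmap batching rule rejects
# (see the error quoted above); passing it explicitly makes the call succeed.
out = vmap(lambda i: torch.nn.functional.one_hot(i, num_classes=5))(idx)
print(out.shape)  # torch.Size([3, 7, 5])
```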

# RuntimeError: Expected all tensors to be on the same device,
# but found at least two devices, cuda:0 and cpu!
xfail('eq', device_type='cuda'),
xfail('ge', device_type='cuda'),
xfail('gt', device_type='cuda'),
xfail('le', device_type='cuda'),
xfail('lt', device_type='cuda'),
xfail('ne', device_type='cuda'),
}

@_set_autograd_function_extension_enabled()
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one)
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', (
tol1('linalg.det',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
Expand All @@ -3486,7 +3532,8 @@ def test():
tol1('nn.functional.conv_transpose3d',
{torch.float32: tol(atol=1e-04, rtol=1e-02)}, device_type='cuda'),
))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04),
torch.complex64: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
# RuntimeError: Batch norm got a batched tensor as input while the running_mean or running_var,
# which will be updated in place, were not batched.
Expand All @@ -3497,6 +3544,15 @@ def test():
# The error inputs are vectors, that pass when batched as they are treated as a matrix
xfail('trace'),
xfail('as_strided', 'partial_views'),

# RuntimeError: output with shape [4, 4] doesn't match the broadcast shape [1, 4, 4]
xfail('addcdiv'),
xfail('addcmul'),
xfail('clamp'),
# AssertionError: Tensor-likes are not equal!
xfail('bitwise_left_shift', device_type='cpu'),
xfail('bitwise_right_shift', device_type='cpu'),
xfail('narrow_copy', device_type='cpu'),
}))
def test_vmap_exhaustive(self, device, dtype, op):
# needs to be fixed
@@ -3506,12 +3562,12 @@ def test_vmap_exhaustive(self, device, dtype, op):
skip_inplace=inplace_failure_list)

@_set_autograd_function_extension_enabled()
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one)
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', (
tol1('linalg.det',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04), torch.complex64: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
xfail('as_strided', 'partial_views'),
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
@@ -3589,7 +3645,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('native_dropout_backward'),
xfail('nn.functional.kl_div', ''),
xfail('multinomial', ''),
xfail('column_stack', ''),
xfail('pca_lowrank', ''),
xfail('normal', ''),
xfail('nn.functional.dropout2d', ''),
@@ -3618,7 +3673,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('nn.functional.max_unpool3d', ''),
xfail('linalg.ldl_solve', '', device_type='cpu'),
xfail('chalf', ''),
xfail('arange', ''),
xfail('clamp_max', ''),
xfail('jiterator_binary_return_by_ref', device_type='cuda'),
xfail('special.spherical_bessel_j0'),
@@ -3633,10 +3687,7 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('special.modified_bessel_k1'),
xfail('segment_reduce', 'offsets'),
xfail('special.bessel_j1'),
xfail('logspace', ''),
xfail('empty', ''),
xfail('index_reduce', ''),
xfail('linspace', ''),
xfail('special.laguerre_polynomial_l'),
xfail('special.hermite_polynomial_h'),
xfail('jiterator_binary', device_type='cuda'),
@@ -3650,7 +3701,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('special.scaled_modified_bessel_k0'),
xfail('nn.functional.dropout3d', ''),
xfail('special.scaled_modified_bessel_k1'),
xfail('broadcast_shapes', ''),
xfail('special.modified_bessel_k0'),
xfail('linalg.vecdot', ''),
xfail('linalg.ldl_factor', ''),
@@ -3661,6 +3711,25 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('linalg.lu', ''),
skip('linalg.ldl_solve', ''),
skip('_softmax_backward_data'),
# One or more of the overloads doesn't have a batching rule.
xfail('where'),
xfail('bincount'),
xfail('bitwise_and'),
xfail('bitwise_or'),
xfail('bitwise_xor'),
xfail('bitwise_left_shift'),
xfail('bitwise_right_shift'),
xfail('float_power'),
xfail('ge'),
xfail('gt'),
xfail('le'),
xfail('lt'),
xfail('ne'),
# AssertionError
# Mismatched elements: 18 / 20 (90.0%)
# Greatest absolute difference: 14.031710147857666 at index (0, 5) (up to 0.0001 allowed)
# Greatest relative difference: 2.9177700113052603 at index (0, 3) (up to 0.0001 allowed)
xfail('narrow_copy', device_type='cpu'),
Comment on lines +3724 to +3728
Contributor:
Can you file an issue for silent correctness? Also, do you know which of the following is the actual problem?

  • the non-contiguous test is failing
  • the batching rule is bogus?
  • narrow_copy has inconsistent semantics on cpu/cuda?

Collaborator Author:

Sure, will file an issue.

  • I don't think a non-contiguous sample is the issue, as we haven't added non-contiguous testing to the vmap tests.
  • The batching rule for narrow_copy seems innocuous and doesn't have special handling for CPU vs. CUDA.

So maybe the operator itself has an issue.

Batching Rule Ref:

std::tuple<Tensor, optional<int64_t>> narrow_copy_batch_rule(
    const Tensor &self, optional<int64_t> self_bdim, int64_t dim, c10::SymInt start, c10::SymInt length)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  auto result = self_.narrow_copy_symint(dim, start, length);
  return std::make_tuple(result, 0);
}

Contributor:
If the operator is the problem: it would be great if we could come up with a repro that doesn't involve vmap and shows that, on the same input (on CPU/CUDA with the same strides), it produces different outputs. One idea to "get rid of the vmap" is to use make_fx to trace out what's happening.
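A rough sketch of that suggestion (my illustration, not code from the thread): trace the vmapped call with make_fx, then replay the captured aten ops without going through vmap and compare against a per-slice reference. The import paths and the choice of narrow_copy arguments are assumptions.

```python
# Hedged sketch of the suggested make_fx repro route; paths/args are assumptions.
import torch
from functorch import vmap
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # narrow_copy on dim 1 of each batch element, mirroring an OpInfo-style sample
    return x.narrow_copy(1, 0, 3)

x = torch.randn(4, 5, 6)

# Trace vmap(f); the printed graph lists the aten ops the batching rule dispatches to.
gm = make_fx(vmap(f))(x)
print(gm.code)

# Replay the traced graph eagerly (no vmap) and compare with a vmap-free reference;
# running this separately on CPU and CUDA would show whether the operator misbehaves.
ref = torch.stack([f(xi) for xi in x])
print(torch.allclose(gm(x), ref))
```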

Collaborator Author:
Sure. Thanks!

Have assigned the issue to myself. Will have a look soon.

Collaborator Author:
More info here: #91690

}))
def test_op_has_batch_rule(self, device, dtype, op):
# needs to be fixed
@@ -3688,12 +3757,14 @@ def test_op_has_batch_rule(self, device, dtype, op):
'div',
'floor_divide',
'fmod',
'gcd',
'heaviside',
'hypot',
'igamma',
'igammac',
'index_add',
'index_copy',
'lcm',
'ldexp',
'lerp',
'neg',