-
Notifications
You must be signed in to change notification settings - Fork 22.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[functorch] test: try using reference_inputs in vmap tests #91355
Changes from 3 commits
c7c4e66
adae0a4
fa1bd02
7d04fb4
faefe61
93c8aa7
613bd06
c91d9b2
e2ad2bc
ccf7621
311af6a
b6ba10c
4b2e3b5
21da1aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,7 +19,7 @@ | |||||||||||||||||||||
from torch.testing._internal.common_methods_invocations import op_db | ||||||||||||||||||||||
from torch.testing._internal.common_cuda import with_tf32_off | ||||||||||||||||||||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \ | ||||||||||||||||||||||
skipCUDAIfNoMagma | ||||||||||||||||||||||
skipCUDAIfNoMagma, OpDTypes | ||||||||||||||||||||||
from torch.testing._internal.common_device_type import ops | ||||||||||||||||||||||
from torch.testing._internal.common_utils import ( | ||||||||||||||||||||||
parametrize, | ||||||||||||||||||||||
|
@@ -3361,11 +3361,45 @@ def test(): | |||||||||||||||||||||
vmap(op, in_dims)(*args, **kwargs) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Sample inputs check | ||||||||||||||||||||||
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) | ||||||||||||||||||||||
sample_inputs_op = { | ||||||||||||||||||||||
# Take too long | ||||||||||||||||||||||
"special.chebyshev_polynomial_t", | ||||||||||||||||||||||
"special.chebyshev_polynomial_u", | ||||||||||||||||||||||
"special.chebyshev_polynomial_v", | ||||||||||||||||||||||
"special.chebyshev_polynomial_w", | ||||||||||||||||||||||
"special.hermite_polynomial_he", | ||||||||||||||||||||||
"special.laguerre_polynomial_l", | ||||||||||||||||||||||
"special.legendre_polynomial_p", | ||||||||||||||||||||||
"special.shifted_chebyshev_polynomial_t", | ||||||||||||||||||||||
"special.shifted_chebyshev_polynomial_u", | ||||||||||||||||||||||
"special.shifted_chebyshev_polynomial_v", | ||||||||||||||||||||||
"special.shifted_chebyshev_polynomial_w", | ||||||||||||||||||||||
# Leads to Illegal Memory Access on CUDA | ||||||||||||||||||||||
# locally | ||||||||||||||||||||||
"fft.fft", | ||||||||||||||||||||||
"fft.fft2", | ||||||||||||||||||||||
"fft.fftn", | ||||||||||||||||||||||
"fft.hfft", | ||||||||||||||||||||||
"fft.hfft2", | ||||||||||||||||||||||
"fft.hfftn", | ||||||||||||||||||||||
"fft.ifft", | ||||||||||||||||||||||
"fft.ifft2", | ||||||||||||||||||||||
"fft.ifftn", | ||||||||||||||||||||||
"fft.ihfft", | ||||||||||||||||||||||
"fft.ihfft2", | ||||||||||||||||||||||
"fft.ihfftn", | ||||||||||||||||||||||
} | ||||||||||||||||||||||
if op.name in sample_inputs_op: | ||||||||||||||||||||||
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
sample_inputs_itr = op.reference_inputs(device, dtype, requires_grad=False) | ||||||||||||||||||||||
aliases, inplace_aliases = discover_variants(op) | ||||||||||||||||||||||
check_shape_only = op.name in ('empty_like', 'new_empty') | ||||||||||||||||||||||
for sample_input in sample_inputs_itr: | ||||||||||||||||||||||
args = (sample_input.input,) + sample_input.args | ||||||||||||||||||||||
if not any(map(lambda arg: isinstance(arg, torch.Tensor), args)): | ||||||||||||||||||||||
# Atleast one tensor required for vmap. | ||||||||||||||||||||||
continue | ||||||||||||||||||||||
kwargs = sample_input.kwargs | ||||||||||||||||||||||
is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs) | ||||||||||||||||||||||
for args, in_dims, _ in generate_vmap_inputs( | ||||||||||||||||||||||
|
@@ -3420,16 +3454,6 @@ def test(): | |||||||||||||||||||||
xfail('nn.functional.fractional_max_pool2d'), # randomness | ||||||||||||||||||||||
xfail('pca_lowrank', ''), # random operation | ||||||||||||||||||||||
xfail('svd_lowrank', ''), # random operation | ||||||||||||||||||||||
xfail('linspace', ''), # test runner can't handle factory functions | ||||||||||||||||||||||
xfail('arange', ''), # test runner can't handle factory functions | ||||||||||||||||||||||
xfail('logspace', ''), # test runner can't handle factory functions | ||||||||||||||||||||||
xfail('scalar_tensor'), # test runner can't handle factory functions | ||||||||||||||||||||||
xfail('empty', ''), # test runner can't handle factory functions | ||||||||||||||||||||||
xfail('ones', ''), # test runner can't handle factory functions | ||||||||||||||||||||||
xfail('zeros', ''), # test runner can't handle factory functions | ||||||||||||||||||||||
xfail('full', ''), # test runner can't handle factory functions | ||||||||||||||||||||||
xfail('eye', ''), # non-tensor input | ||||||||||||||||||||||
xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops | ||||||||||||||||||||||
xfail('sparse.sampled_addmm'), # sparse | ||||||||||||||||||||||
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable autograd.Function | ||||||||||||||||||||||
skip('_softmax_backward_data'), | ||||||||||||||||||||||
|
@@ -3449,7 +3473,6 @@ def test(): | |||||||||||||||||||||
xfail('nn.functional.gaussian_nll_loss'), # data-dependent control flow error | ||||||||||||||||||||||
xfail('nn.functional.embedding_bag'), # embedding renorm vmap inplace incompatible | ||||||||||||||||||||||
xfail('__rpow__'), # https://github.com/pytorch/functorch/issues/617 | ||||||||||||||||||||||
xfail('column_stack', ''), # Batching rule not implemented for aten::column_stack | ||||||||||||||||||||||
xfail('narrow'), # Batching rule not implemented for aten::narrow.Tensor | ||||||||||||||||||||||
|
||||||||||||||||||||||
# required rank 4 tensor to use channels_last format | ||||||||||||||||||||||
|
@@ -3473,11 +3496,34 @@ def test(): | |||||||||||||||||||||
xfail('jiterator_unary', device_type='cuda'), # NYI: querying is_contiguous inside of vmap | ||||||||||||||||||||||
xfail('jiterator_2inputs_2outputs', device_type='cuda'), # NYI: querying is_contiguous inside of vmap | ||||||||||||||||||||||
# --------------------------------------------------------------------- | ||||||||||||||||||||||
|
||||||||||||||||||||||
# TypeError: expected Tensor as element 0 in argument 0, but got NotImplementedType | ||||||||||||||||||||||
xfail('__rsub__'), | ||||||||||||||||||||||
# RuntimeError: Batching rule not implemented for aten::moveaxis.int; | ||||||||||||||||||||||
# the fallback path doesn't work on out= or view ops. | ||||||||||||||||||||||
xfail('movedim'), | ||||||||||||||||||||||
# RuntimeError: NYI: querying is_contiguous inside of vmap for | ||||||||||||||||||||||
# memory_format other than torch.contiguous_format | ||||||||||||||||||||||
xfail('contiguous'), | ||||||||||||||||||||||
# RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only supported | ||||||||||||||||||||||
# with memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast) | ||||||||||||||||||||||
xfail('clone'), | ||||||||||||||||||||||
# RuntimeError: When vmap-ing torch.nn.functional.one_hot, | ||||||||||||||||||||||
# please provide an explicit positive num_classes argument. | ||||||||||||||||||||||
xfail('nn.functional.one_hot'), | ||||||||||||||||||||||
Comment on lines
+3492
to
+3504
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normally I'd feel bad about adding these xfails, but we do have manual tests for contiguous, clone, one_hot, sub, in the codebase; and movedim is tested just by virtue of being a part of the vmap implementation. |
||||||||||||||||||||||
# RuntimeError: Expected all tensors to be on the same device, | ||||||||||||||||||||||
# but found at least two devices, cuda:0 and cpu! | ||||||||||||||||||||||
xfail('eq', device_type='cuda'), | ||||||||||||||||||||||
xfail('ge', device_type='cuda'), | ||||||||||||||||||||||
xfail('gt', device_type='cuda'), | ||||||||||||||||||||||
xfail('le', device_type='cuda'), | ||||||||||||||||||||||
xfail('lt', device_type='cuda'), | ||||||||||||||||||||||
xfail('ne', device_type='cuda'), | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
@_set_autograd_function_extension_enabled() | ||||||||||||||||||||||
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798 | ||||||||||||||||||||||
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) | ||||||||||||||||||||||
@ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one) | ||||||||||||||||||||||
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', ( | ||||||||||||||||||||||
tol1('linalg.det', | ||||||||||||||||||||||
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'), | ||||||||||||||||||||||
|
@@ -3486,7 +3532,8 @@ def test(): | |||||||||||||||||||||
tol1('nn.functional.conv_transpose3d', | ||||||||||||||||||||||
{torch.float32: tol(atol=1e-04, rtol=1e-02)}, device_type='cuda'), | ||||||||||||||||||||||
)) | ||||||||||||||||||||||
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) | ||||||||||||||||||||||
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04), | ||||||||||||||||||||||
torch.complex64: tol(atol=1e-04, rtol=1e-04)}) | ||||||||||||||||||||||
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({ | ||||||||||||||||||||||
# RuntimeError: Batch norm got a batched tensor as input while the running_mean or running_var, | ||||||||||||||||||||||
# which will be updated in place, were not batched. | ||||||||||||||||||||||
|
@@ -3497,6 +3544,15 @@ def test(): | |||||||||||||||||||||
# The error inputs are vectors, that pass when batched as they are treated as a matrix | ||||||||||||||||||||||
xfail('trace'), | ||||||||||||||||||||||
xfail('as_strided', 'partial_views'), | ||||||||||||||||||||||
|
||||||||||||||||||||||
# RuntimeError: output with shape [4, 4] doesn't match the broadcast shape [1, 4, 4] | ||||||||||||||||||||||
xfail('addcdiv'), | ||||||||||||||||||||||
xfail('addcmul'), | ||||||||||||||||||||||
xfail('clamp'), | ||||||||||||||||||||||
# AssertionError: Tensor-likes are not equal! | ||||||||||||||||||||||
xfail('bitwise_left_shift', device_type='cpu'), | ||||||||||||||||||||||
xfail('bitwise_right_shift', device_type='cpu'), | ||||||||||||||||||||||
xfail('narrow_copy', device_type='cpu'), | ||||||||||||||||||||||
})) | ||||||||||||||||||||||
def test_vmap_exhaustive(self, device, dtype, op): | ||||||||||||||||||||||
# needs to be fixed | ||||||||||||||||||||||
|
@@ -3506,12 +3562,12 @@ def test_vmap_exhaustive(self, device, dtype, op): | |||||||||||||||||||||
skip_inplace=inplace_failure_list) | ||||||||||||||||||||||
|
||||||||||||||||||||||
@_set_autograd_function_extension_enabled() | ||||||||||||||||||||||
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,)) | ||||||||||||||||||||||
@ops(op_db + additional_op_db + autograd_function_db, dtypes=OpDTypes.any_one) | ||||||||||||||||||||||
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', ( | ||||||||||||||||||||||
tol1('linalg.det', | ||||||||||||||||||||||
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'), | ||||||||||||||||||||||
)) | ||||||||||||||||||||||
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) | ||||||||||||||||||||||
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04), torch.complex64: tol(atol=1e-04, rtol=1e-04)}) | ||||||||||||||||||||||
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({ | ||||||||||||||||||||||
xfail('as_strided', 'partial_views'), | ||||||||||||||||||||||
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format | ||||||||||||||||||||||
|
@@ -3589,7 +3645,6 @@ def test_vmap_exhaustive(self, device, dtype, op): | |||||||||||||||||||||
xfail('native_dropout_backward'), | ||||||||||||||||||||||
xfail('nn.functional.kl_div', ''), | ||||||||||||||||||||||
xfail('multinomial', ''), | ||||||||||||||||||||||
xfail('column_stack', ''), | ||||||||||||||||||||||
xfail('pca_lowrank', ''), | ||||||||||||||||||||||
xfail('normal', ''), | ||||||||||||||||||||||
xfail('nn.functional.dropout2d', ''), | ||||||||||||||||||||||
|
@@ -3618,7 +3673,6 @@ def test_vmap_exhaustive(self, device, dtype, op): | |||||||||||||||||||||
xfail('nn.functional.max_unpool3d', ''), | ||||||||||||||||||||||
xfail('linalg.ldl_solve', '', device_type='cpu'), | ||||||||||||||||||||||
xfail('chalf', ''), | ||||||||||||||||||||||
xfail('arange', ''), | ||||||||||||||||||||||
xfail('clamp_max', ''), | ||||||||||||||||||||||
xfail('jiterator_binary_return_by_ref', device_type='cuda'), | ||||||||||||||||||||||
xfail('special.spherical_bessel_j0'), | ||||||||||||||||||||||
|
@@ -3633,10 +3687,7 @@ def test_vmap_exhaustive(self, device, dtype, op): | |||||||||||||||||||||
xfail('special.modified_bessel_k1'), | ||||||||||||||||||||||
xfail('segment_reduce', 'offsets'), | ||||||||||||||||||||||
xfail('special.bessel_j1'), | ||||||||||||||||||||||
xfail('logspace', ''), | ||||||||||||||||||||||
xfail('empty', ''), | ||||||||||||||||||||||
xfail('index_reduce', ''), | ||||||||||||||||||||||
xfail('linspace', ''), | ||||||||||||||||||||||
xfail('special.laguerre_polynomial_l'), | ||||||||||||||||||||||
xfail('special.hermite_polynomial_h'), | ||||||||||||||||||||||
xfail('jiterator_binary', device_type='cuda'), | ||||||||||||||||||||||
|
@@ -3650,7 +3701,6 @@ def test_vmap_exhaustive(self, device, dtype, op): | |||||||||||||||||||||
xfail('special.scaled_modified_bessel_k0'), | ||||||||||||||||||||||
xfail('nn.functional.dropout3d', ''), | ||||||||||||||||||||||
xfail('special.scaled_modified_bessel_k1'), | ||||||||||||||||||||||
xfail('broadcast_shapes', ''), | ||||||||||||||||||||||
xfail('special.modified_bessel_k0'), | ||||||||||||||||||||||
xfail('linalg.vecdot', ''), | ||||||||||||||||||||||
xfail('linalg.ldl_factor', ''), | ||||||||||||||||||||||
|
@@ -3661,6 +3711,25 @@ def test_vmap_exhaustive(self, device, dtype, op): | |||||||||||||||||||||
xfail('linalg.lu', ''), | ||||||||||||||||||||||
skip('linalg.ldl_solve', ''), | ||||||||||||||||||||||
skip('_softmax_backward_data'), | ||||||||||||||||||||||
# One or more of the overload doesn't have a Batch rule. | ||||||||||||||||||||||
xfail('where'), | ||||||||||||||||||||||
xfail('bincount'), | ||||||||||||||||||||||
xfail('bitwise_and'), | ||||||||||||||||||||||
xfail('bitwise_or'), | ||||||||||||||||||||||
xfail('bitwise_xor'), | ||||||||||||||||||||||
xfail('bitwise_left_shift'), | ||||||||||||||||||||||
xfail('bitwise_right_shift'), | ||||||||||||||||||||||
xfail('float_power'), | ||||||||||||||||||||||
xfail('ge'), | ||||||||||||||||||||||
xfail('gt'), | ||||||||||||||||||||||
xfail('le'), | ||||||||||||||||||||||
xfail('lt'), | ||||||||||||||||||||||
xfail('ne'), | ||||||||||||||||||||||
# AssertionError | ||||||||||||||||||||||
# Mismatched elements: 18 / 20 (90.0%) | ||||||||||||||||||||||
# Greatest absolute difference: 14.031710147857666 at index (0, 5) (up to 0.0001 allowed) | ||||||||||||||||||||||
# Greatest relative difference: 2.9177700113052603 at index (0, 3) (up to 0.0001 allowed) | ||||||||||||||||||||||
xfail('narrow_copy', device_type='cpu'), | ||||||||||||||||||||||
Comment on lines
+3724
to
+3728
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you file an issue for silent correctness? Also, do you know which of the following is the actual problem?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, will file an issue.
So maybe the operator has some issue. Batching Rule Ref: pytorch/aten/src/ATen/functorch/BatchRulesViews.cpp Lines 506 to 515 in 3120054
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the operator is a problem: if we can come up with some repro that doesn't involve vmap that shows that on the same input (on cpu/cuda with the same strides), it produces different outputs, then that would be great. One idea to "get rid of the vmap" is to use make_fx to trace out what's happening There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. Thanks! Have assigned the issue to myself. Will have a look soon. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More info here : #91690 |
||||||||||||||||||||||
})) | ||||||||||||||||||||||
def test_op_has_batch_rule(self, device, dtype, op): | ||||||||||||||||||||||
# needs to be fixed | ||||||||||||||||||||||
|
@@ -3688,12 +3757,14 @@ def test_op_has_batch_rule(self, device, dtype, op): | |||||||||||||||||||||
'div', | ||||||||||||||||||||||
'floor_divide', | ||||||||||||||||||||||
'fmod', | ||||||||||||||||||||||
'gcd', | ||||||||||||||||||||||
'heaviside', | ||||||||||||||||||||||
'hypot', | ||||||||||||||||||||||
'igamma', | ||||||||||||||||||||||
'igammac', | ||||||||||||||||||||||
'index_add', | ||||||||||||||||||||||
'index_copy', | ||||||||||||||||||||||
'lcm', | ||||||||||||||||||||||
'ldexp', | ||||||||||||||||||||||
'lerp', | ||||||||||||||||||||||
'neg', | ||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These ops already have skip for taking long time with reference inputs
Eg.
pytorch/torch/testing/_internal/opinfo/definitions/special.py
Lines 374 to 388 in 39d49db