Add meta function for PT2 wrappers (#3240)
Summary:
Pull Request resolved: #3240

X-link: facebookresearch/FBGEMM#341

To be PT2 compliant, add meta functions for the PT2 wrappers and dispatch the ops to the Meta dispatch key.

Without meta dispatch, PT2 tests fail under `torch.compile` with an error like:
```
Unsupported: Backend compiler failed with a fake tensor exception
...
Caused by UnsupportedOperatorException: fbgemm.split_embedding_backward_codegen_sgd_unweighted_vbe_pt2_wrapper.default
```
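
For context, the sketch below illustrates the pattern this change applies. It is a minimal, hypothetical example (the `fbgemm_example::example_lookup` op and both kernel names are made up, not the actual FBGEMM codegen wrappers): a meta kernel computes only output shapes and dtypes, and is registered under `c10::DispatchKey::Meta` next to the real kernel so that FakeTensor tracing in `torch.compile` can reason about the op.

```cpp
// Minimal sketch only -- hypothetical op and names, not the FBGEMM codegen wrappers.
#include <ATen/ATen.h>
#include <torch/library.h>

// Real kernel: gathers rows of `weights` (stands in for an embedding lookup).
at::Tensor example_lookup_cuda(const at::Tensor& weights, const at::Tensor& indices) {
  return weights.index_select(0, indices);
}

// Meta kernel: touches no data; it only propagates output shape and dtype so
// that FakeTensor tracing under torch.compile can handle the op.
at::Tensor example_lookup_meta(const at::Tensor& weights, const at::Tensor& indices) {
  auto out_sizes = weights.sizes().vec();
  out_sizes[0] = indices.size(0);
  return at::empty(out_sizes, weights.options().device(at::kMeta));
}

TORCH_LIBRARY_FRAGMENT(fbgemm_example, m) {
  m.def("example_lookup(Tensor weights, Tensor indices) -> Tensor");
  m.impl("example_lookup",
         torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(example_lookup_cuda)));
  m.impl("example_lookup",
         torch::dispatch(c10::DispatchKey::Meta, TORCH_FN(example_lookup_meta)));
}
```

The registrations added in this diff follow the same shape: each generated `*_pt2_meta_wrapper` is registered under the Meta key alongside the existing `*_pt2_cuda_wrapper`.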

Reviewed By: q10

Differential Revision: D64255747

fbshipit-source-id: 0e4fcbd061f2a6b3080ff157de07d83eec08c05b
spcyppt authored and facebook-github-bot committed Oct 14, 2024
1 parent b260e98 commit c8e32e4
Showing 12 changed files with 127 additions and 47 deletions.
@@ -45,6 +45,7 @@ enum SSDTensor {
{%- for vbe in ([True, False] if has_vbe_support else [False]) %}
{%- set vdesc = "_vbe" if vbe else "" %}

{%- for dispatch_type in ["cuda", "meta"] %}
{%- for weighted in [True, False] %}
{%- for nobag in ([False] if (weighted or vbe) else [True, False]) %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
@@ -61,9 +62,8 @@ enum SSDTensor {
{%- set gwddesc = "_gwd" if is_gwd else "" %}
{%- set desc_suffix = wdesc + vdesc + gwddesc %}

{%- if is_forward %}
{#-/* PT2 wrapper function for forward CUDA */#}
{%- for dispatch_type in ["cuda", "meta"] %}
{%- if is_forward %}
Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt2_{{ dispatch_type }}_wrapper(
const Tensor& /*host_weights*/,
const Tensor& dev_weights,
@@ -199,11 +199,10 @@ Tensor {{ fwd_mdesc }}_embedding{{ ndesc }}_codegen_forward_{{ desc_suffix }}_pt
is_experimental
);
};
{%- endfor %} {#-/*for dispatch_type in ["cuda", "meta"]*/#}
{%- else %}

{#-/* PT2 wrapper function for backward CUDA */#}
Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_pt2_cuda_wrapper(
Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc_suffix }}_pt2_{{ dispatch_type }}_wrapper(
const Tensor& grad_output,
const Tensor& /*host_weights*/,
const Tensor& dev_weights,
@@ -372,14 +371,14 @@ Tensor {{ bwd_mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{
}

{%- endif %}
{%- endfor %} {#-/*for weighted*/#}
{%- endfor %} {#-/*for is_gwd*/#}
{%- endfor %} {#-/*for nobag*/#}
{%- endfor %} {#-/*for weighted*/#}


{%- if is_forward %}
{#-/* PT2 wrapper function for backward grad_indice_weights CUDA */#}
Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda_wrapper(
Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_{{ dispatch_type }}_wrapper(
const Tensor& grad_output,
const Tensor& /*host_weights*/,
const Tensor& dev_weights,
@@ -464,6 +463,8 @@ Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda
);
}
{%- endif %}
{%- endfor %} {#-/*for dispatch_type*/#}

////////////////////////////////////////////////////////////////////////////////
// Op registrations
////////////////////////////////////////////////////////////////////////////////
@@ -616,6 +617,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"{{ embedding_codegen_backward_op }}_wrapper",
{{ embedding_codegen_backward_op }}_cuda_wrapper
);
m.impl("{{ embedding_codegen_backward_op }}_wrapper", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ embedding_codegen_backward_op }}_meta_wrapper)));
{%- endif %} {#-/* if is_forward */#}
{%- endfor %} {#-/*for is_gwd*/#}
{%- endfor %} {#-/*for nobag*/#}
@@ -660,6 +662,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"{{ embedding_codegen_grad_indice_weights_op }}_wrapper",
{{ embedding_codegen_grad_indice_weights_op }}_cuda_wrapper
);
m.impl("{{ embedding_codegen_grad_indice_weights_op }}_wrapper", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ embedding_codegen_grad_indice_weights_op }}_meta_wrapper)));
{%- endif %}

}
2 changes: 2 additions & 0 deletions fbgemm_gpu/test/tbe/training/backward_adagrad_common.py
@@ -45,6 +45,7 @@
if open_source:
# pyre-ignore[21]
from test_utils import (
additional_decorators,
gpu_available,
gpu_unavailable,
gradcheck,
@@ -55,6 +56,7 @@
)
else:
from fbgemm_gpu.test.test_utils import ( # noqa F401
additional_decorators,
gpu_available,
gpu_unavailable,
gradcheck,
@@ -24,6 +24,7 @@
from hypothesis import given, settings

from .backward_adagrad_common import ( # noqa
additional_decorators,
adjust_mixed_B_st,
common_settings,
common_strategy,
@@ -356,7 +357,7 @@ def execute_global_weight_decay( # noqa C901
)


@optests.generate_opcheck_tests(fast=True)
@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
class BackwardAdagradGlobalWeightDecay(unittest.TestCase):
@unittest.skipIf(*gpu_unavailable)
@given(
@@ -15,6 +15,7 @@
from hypothesis import given, settings

from .backward_adagrad_common import (
additional_decorators,
adjust_mixed_B_st,
common_settings,
common_strategy,
@@ -32,7 +33,7 @@
test_st["D"] = st.integers(min_value=128, max_value=512)


@optests.generate_opcheck_tests(fast=True)
@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
class BackwardAdagradLargeDimTest(unittest.TestCase):
@skipIfRocm("Unblock large dim enablement on other GPUs")
@unittest.skipIf(*gpu_unavailable)
3 changes: 2 additions & 1 deletion fbgemm_gpu/test/tbe/training/backward_adagrad_test.py
@@ -15,6 +15,7 @@
from hypothesis import given, settings

from .backward_adagrad_common import (
additional_decorators,
adjust_mixed_B_st,
common_settings,
common_strategy,
@@ -31,7 +32,7 @@
test_st["D"] = st.integers(min_value=2, max_value=128)


@optests.generate_opcheck_tests(fast=True)
@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
class BackwardAdagradTest(unittest.TestCase):
@unittest.skipIf(*gpu_unavailable)
@given(mixed_B=st.booleans(), **test_st)
11 changes: 9 additions & 2 deletions fbgemm_gpu/test/tbe/training/backward_dense_test.py
@@ -36,9 +36,16 @@

if open_source:
# pyre-ignore[21]
from test_utils import gradcheck, optests, skipIfRocm, use_cpu_strategy
from test_utils import (
additional_decorators,
gradcheck,
optests,
skipIfRocm,
use_cpu_strategy,
)
else:
from fbgemm_gpu.test.test_utils import (
additional_decorators,
gradcheck,
optests,
skipIfRocm,
@@ -49,7 +56,7 @@
VERBOSITY: Verbosity = Verbosity.verbose


@optests.generate_opcheck_tests(fast=True)
@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
class BackwardDenseTest(unittest.TestCase):
@skipIfRocm("Currently runs into memory access issues")
@given(
46 changes: 32 additions & 14 deletions fbgemm_gpu/test/tbe/training/backward_none_test.py
@@ -11,7 +11,7 @@

import random
import unittest
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, List, Optional, Union

import hypothesis.strategies as st
import numpy as np
@@ -44,23 +44,41 @@

if open_source:
# pyre-ignore[21]
from test_utils import gpu_unavailable, optests, TEST_WITH_ROCM
from test_utils import (
additional_decorators,
gpu_unavailable,
optests,
TEST_WITH_ROCM,
)
else:
from fbgemm_gpu.test.test_utils import gpu_unavailable, optests, TEST_WITH_ROCM
from fbgemm_gpu.test.test_utils import (
additional_decorators,
gpu_unavailable,
optests,
TEST_WITH_ROCM,
)
VERBOSITY: Verbosity = Verbosity.verbose

# pyre-ignore
additional_decorators: Dict[str, List[Callable]] = {
"test_schema__test_backward_none_with_rowwise_adagrad": [
unittest.skip("Cannot access data pointer of Tensor that doesn't have storage")
],
"test_faketensor__test_backward_none_with_rowwise_adagrad": [
unittest.skip("Cannot access data pointer of Tensor that doesn't have storage")
],
"test_autograd_registration__test_backward_none_with_rowwise_adagrad": [
unittest.skip("Cannot access data pointer of Tensor that doesn't have storage")
],
}
additional_decorators.update(
{
"test_schema__test_backward_none_with_rowwise_adagrad": [
unittest.skip(
"Cannot access data pointer of Tensor that doesn't have storage"
)
],
"test_faketensor__test_backward_none_with_rowwise_adagrad": [
unittest.skip(
"Cannot access data pointer of Tensor that doesn't have storage"
)
],
"test_autograd_registration__test_backward_none_with_rowwise_adagrad": [
unittest.skip(
"Cannot access data pointer of Tensor that doesn't have storage"
)
],
}
)


@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
11 changes: 9 additions & 2 deletions fbgemm_gpu/test/tbe/training/backward_optimizers_test.py
@@ -53,9 +53,16 @@

if open_source:
# pyre-ignore[21]
from test_utils import gpu_unavailable, optests, TEST_WITH_ROCM, use_cpu_strategy
from test_utils import (
additional_decorators,
gpu_unavailable,
optests,
TEST_WITH_ROCM,
use_cpu_strategy,
)
else:
from fbgemm_gpu.test.test_utils import (
additional_decorators,
gpu_unavailable,
optests,
TEST_WITH_ROCM,
@@ -66,7 +73,7 @@
VERBOSITY: Verbosity = Verbosity.verbose


@optests.generate_opcheck_tests(fast=True)
@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
class BackwardOptimizersTest(unittest.TestCase):
def assert_close_optim_state(self, test: torch.Tensor, ref: torch.Tensor) -> None:
tolerance = 1.0e-4 if test.dtype == torch.float else 1.0e-2
11 changes: 9 additions & 2 deletions fbgemm_gpu/test/tbe/training/backward_sgd_test.py
@@ -43,9 +43,16 @@

if open_source:
# pyre-ignore[21]
from test_utils import gpu_unavailable, optests, TEST_WITH_ROCM, use_cpu_strategy
from test_utils import (
additional_decorators,
gpu_unavailable,
optests,
TEST_WITH_ROCM,
use_cpu_strategy,
)
else:
from fbgemm_gpu.test.test_utils import (
additional_decorators,
gpu_unavailable,
optests,
TEST_WITH_ROCM,
@@ -56,7 +63,7 @@
VERBOSITY: Verbosity = Verbosity.verbose


@optests.generate_opcheck_tests(fast=True)
@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
class BackwardSGDTest(unittest.TestCase):
def execute_backward_sgd_( # noqa C901
self,
10 changes: 10 additions & 0 deletions fbgemm_gpu/test/tbe/training/failures_dict_fast.json
@@ -371,6 +371,16 @@
"fbgemm::split_embedding_codegen_lookup_partial_rowwise_lamb_function": {},
"fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function": {},
"fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_cpu": {},
"fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_function_pt2": {
"ForwardTest.test_faketensor__test_forward_cpu_fp32": {
"comment": "",
"status": "xfail"
},
"ForwardTest.test_schema__test_forward_cpu_fp32": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::split_embedding_codegen_lookup_rowwise_adagrad_with_counter_function": {},
"fbgemm::split_embedding_codegen_lookup_rowwise_weighted_adagrad_function": {},
"fbgemm::split_embedding_codegen_lookup_sgd_function": {},
47 changes: 29 additions & 18 deletions fbgemm_gpu/test/tbe/training/forward_test.py
@@ -11,7 +11,6 @@

import random
import unittest
from typing import Callable, Dict, List

import hypothesis.strategies as st
import numpy as np
@@ -45,28 +44,40 @@

if open_source:
# pyre-ignore[21]
from test_utils import gpu_unavailable, optests, TEST_WITH_ROCM
from test_utils import (
additional_decorators,
gpu_unavailable,
optests,
TEST_WITH_ROCM,
)
else:
from fbgemm_gpu.test.test_utils import gpu_unavailable, optests, TEST_WITH_ROCM
from fbgemm_gpu.test.test_utils import (
additional_decorators,
gpu_unavailable,
optests,
TEST_WITH_ROCM,
)

VERBOSITY: Verbosity = Verbosity.verbose

# pyre-ignore
additional_decorators: Dict[str, List[Callable]] = {
# TODO: Implement the operator registrations later
"test_faketensor__test_forward_cpu_int8": [
unittest.skip("Operator not implemented for Meta tensors"),
],
"test_faketensor__test_forward_fused_pooled_emb_quant": [
unittest.skip("Operator not implemented for Meta tensors"),
],
"test_faketensor__test_forward_gpu_no_cache_int8": [
unittest.skip("Operator not implemented for Meta tensors"),
],
"test_faketensor__test_forward_gpu_uvm_cache_int8": [
unittest.skip("Operator not implemented for Meta tensors"),
],
}
additional_decorators.update(
{
# TODO: Implement the operator registrations later
"test_faketensor__test_forward_cpu_int8": [
unittest.skip("Operator not implemented for Meta tensors"),
],
"test_faketensor__test_forward_fused_pooled_emb_quant": [
unittest.skip("Operator not implemented for Meta tensors"),
],
"test_faketensor__test_forward_gpu_no_cache_int8": [
unittest.skip("Operator not implemented for Meta tensors"),
],
"test_faketensor__test_forward_gpu_uvm_cache_int8": [
unittest.skip("Operator not implemented for Meta tensors"),
],
}
)


@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
12 changes: 12 additions & 0 deletions fbgemm_gpu/test/test_utils.py
@@ -25,6 +25,18 @@

TEST_WITH_ROCM: bool = os.getenv("FBGEMM_TEST_WITH_ROCM", "0") == "1"

# Skip pt2 compliant tag test for certain operators
# TODO: remove this once the operators are pt2 compliant
# pyre-ignore
additional_decorators: Dict[str, List[Callable]] = {
# vbe_generate_metadata_cpu returns different values from vbe_generate_metadata_meta;
# this fails the fake_tensor test, which expects them to be the same.
# The fake_tensor test is listed in failures_dict, but a failing fake_tensor test still causes the pt2_compliant tag test to fail.
"test_pt2_compliant_tag_fbgemm_split_embedding_codegen_lookup_rowwise_adagrad_function_pt2": [
unittest.skip("Operator failed on pt2 compliant tag"),
]
}

# Used for `@unittest.skipIf`
gpu_unavailable: Tuple[bool, str] = (
not torch.cuda.is_available() or torch.cuda.device_count() == 0,
