Skip to content

Commit

Permalink
Fix pt2_wrapper registration for unified TBE interface (#3238)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#339

Pull Request resolved: #3238

`*_pt2_wrapper` registration is currently in gpu source files so the wrappers are not found in CPU builds.

To fix this, add codegen `m.def` in cpu source files for ops that have cpu support. Keep the `m.def` in gpu source files for GPU-only support (i.e., `ssd` and global weight decay `gwd` kernels).

Reviewed By: q10

Differential Revision: D64201777

fbshipit-source-id: 3f48ec82938f5881aed9aad9ec2de6fd192942e1
  • Loading branch information
spcyppt authored and facebook-github-bot committed Oct 11, 2024
1 parent f9f0600 commit f59d5ee
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ Tensor split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper(
offsets,
feature_requires_grad);
}
{%- else %}
{%- endif %}
{%- for weighted in [True, False] %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
Expand Down Expand Up @@ -256,32 +255,143 @@ Tensor split_embedding_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vdesc }}_p

namespace {
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
{%- if is_forward %}
DISPATCH_TO_CPU(
"split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_wrapper",
split_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cpu_wrapper);
{%- endif %}

{%- for weighted in [True, False] %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}

{%- if is_forward %}
{%- set embedding_codegen_forward_op = "split_embedding_codegen_forward_{}{}_pt2".format(
wdesc, vdesc
)
%}
m.def("{{ embedding_codegen_forward_op }}_wrapper("
" Tensor host_weights, "
" Tensor dev_weights, "
" Tensor uvm_weights, "
" Tensor lxu_cache_weights, "
" Tensor weights_placements, "
" Tensor weights_offsets, "
{%- if nobag %}
" SymInt D, "
{%- else %}
" Tensor D_offsets, "
" SymInt total_D, "
" SymInt max_D, "
{%- endif %}
" Tensor hash_size_cumsum, "
" Tensor indices, "
" Tensor offsets, "
{%- if not nobag %}
" int pooling_mode, "
" Tensor indice_weights, "
{%- endif %}
" Tensor lxu_cache_locations, "
" Tensor uvm_cache_stats, "
{%- if vbe %}
" Tensor vbe_row_output_offsets, "
" Tensor vbe_b_t_map, "
" SymInt vbe_output_size, "
" int info_B_num_bits, "
" int info_B_mask_int64, "
{%- endif %}
" bool is_experimental, "
" int output_dtype "
") -> Tensor"
{%- if not nobag and not vbe %}
// only split_embedding_codegen_forward_[un]weighted_cuda
// are tested to be PT2 compliant
, {PT2_COMPLIANT_TAG}
{%- endif %}
);
DISPATCH_TO_CPU("{{ embedding_codegen_forward_op }}_wrapper", {{ embedding_codegen_forward_op }}_cpu_wrapper);
{%- else %}

{%- else %} {#-/* backward */#}
{%- set embedding_codegen_backward_op = "split_embedding_backward_codegen_{}_{}{}_pt2".format(
optimizer, wdesc, vdesc
)
%}
m.def("{{ embedding_codegen_backward_op }}_wrapper("
" Tensor grad_output, "
" Tensor(a!) host_weights, "
" Tensor(b!) dev_weights, "
" Tensor(c!) uvm_weights, "
" Tensor lxu_cache_weights, "
" Tensor weights_placements, "
" Tensor weights_offsets, "
{%- if nobag %}
" SymInt D, "
{%- else %}
" Tensor D_offsets, "
" SymInt max_D, "
{%- endif %}
" Tensor hash_size_cumsum, "
" int total_hash_size_bits, "
" Tensor indices, "
" Tensor offsets, "
{%- if not nobag %}
" int pooling_mode, "
" Tensor indice_weights, "
{%- endif %}
" Tensor lxu_cache_locations, "
" int BT_block_size, "
" int max_segment_length_per_warp, "
{%- if optimizer != "none" %}
" bool stochastic_rounding, "
{%- endif %}
" int info_B_num_bits, "
" int info_B_mask_int64, "
{%- if vbe %}
" Tensor B_offsets, "
" Tensor vbe_row_output_offsets, "
" Tensor vbe_b_t_map, "
{%- endif %}
" bool use_uniq_cache_locations, "
" bool use_homogeneous_placements,"
" {{ args_pt2.split_function_schemas | join(", ") }} "
{%- if not nobag %}
" , int output_dtype=0 "
{%- endif %}
") -> Tensor");
DISPATCH_TO_CPU("{{ embedding_codegen_backward_op }}_wrapper", {{ embedding_codegen_backward_op }}_cpu_wrapper);
{%- endif %}
{%- endif %} {#-/*if is_forward*/#}
{%- endfor %} {#-/*for weighted*/#}

{%- if is_forward %}
{%- set embedding_codegen_grad_indice_weights_op =
"split_embedding_codegen_grad_indice_weights{}_pt2".format(
vdesc
)
%}
m.def("{{ embedding_codegen_grad_indice_weights_op }}_wrapper("
" Tensor grad_output, "
" Tensor host_weights, "
" Tensor dev_weights, "
" Tensor uvm_weights, "
" Tensor lxu_cache_weights, "
" Tensor weights_placements, "
" Tensor weights_offsets, "
" Tensor D_offsets, "
" SymInt max_D, "
" Tensor indices, "
" Tensor offsets, "
" Tensor lxu_cache_locations, "
{%- if vbe %}
" Tensor feature_requires_grad, "
" Tensor vbe_row_output_offsets, "
" Tensor vbe_b_t_map, "
" int info_B_num_bits, "
" int info_B_mask_int64"
{%- else %}
" Tensor feature_requires_grad"
{%- endif %}
") -> Tensor");

DISPATCH_TO_CPU(
"{{ embedding_codegen_grad_indice_weights_op }}_wrapper",
{{ embedding_codegen_grad_indice_weights_op }}_cpu_wrapper);
{%- endif %}
}
} // namespace
{%- endfor %} {#-/* for vbe in [True, False] */#}

{% endif %} // if has_cpu_support
{% endif %} {#/* if has_cpu_support */#}
// clang-format on
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,6 @@ Tensor {{ fwd_mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_pt2_cuda
////////////////////////////////////////////////////////////////////////////////

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {

{%- for weighted in [True, False] %}
{%- for nobag in ([False] if (weighted or vbe) else [True, False]) %}
{%- set wdesc = "weighted" if weighted else "unweighted" %}
Expand All @@ -485,11 +484,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
else [False]) %}
{%- set gwddesc = "_gwd" if is_gwd else "" %}
{%- set desc_suffix = wdesc + vdesc + gwddesc %}

{%- if is_forward %}
{%- set embedding_codegen_forward_op = "{}_embedding{}_codegen_forward_{}_pt2".format(
fwd_mdesc, ndesc, desc_suffix
)
%}
{%- if ssd or is_gwd or nobag %}
/* Register schema for wrappers with GPU-only support */
m.def("{{ embedding_codegen_forward_op }}_wrapper("
" Tensor host_weights, "
" Tensor dev_weights, "
Expand Down Expand Up @@ -540,17 +542,20 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
, {PT2_COMPLIANT_TAG}
{%- endif %}
);

{%- endif %}
DISPATCH_TO_CUDA(
"{{ embedding_codegen_forward_op }}_wrapper",
{{ embedding_codegen_forward_op }}_cuda_wrapper
);
m.impl("{{ embedding_codegen_forward_op }}_wrapper", torch::dispatch(c10::DispatchKey::Meta, TORCH_FN({{ embedding_codegen_forward_op }}_meta_wrapper)));
{%- else %}

{%- else %} {#-/* backward */#}
{%- set embedding_codegen_backward_op = "{}_embedding{}_backward_codegen_{}_{}_pt2".format(
bwd_mdesc, ndesc, optimizer, desc_suffix
)
%}
{%- if ssd or is_gwd or nobag or not has_cpu_support %}
/* Register schema for wrappers with GPU-only support */
m.def("{{ embedding_codegen_backward_op }}_wrapper("
" Tensor grad_output, "
" Tensor(a!) host_weights, "
Expand Down Expand Up @@ -606,11 +611,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" , int output_dtype=0 "
{%- endif %}
") -> Tensor");
{%- endif %}
DISPATCH_TO_CUDA(
"{{ embedding_codegen_backward_op }}_wrapper",
{{ embedding_codegen_backward_op }}_cuda_wrapper
);
{%- endif %}
{%- endif %} {#-/* if is_forward */#}
{%- endfor %} {#-/*for is_gwd*/#}
{%- endfor %} {#-/*for nobag*/#}
{%- endfor %} {#-/*for weighted*/#}
Expand All @@ -620,6 +626,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
fwd_mdesc, vdesc
)
%}
{%- if ssd %}
/* Register schema for wrappers with GPU-only support */
m.def("{{ embedding_codegen_grad_indice_weights_op }}_wrapper("
" Tensor grad_output, "
" Tensor host_weights, "
Expand Down Expand Up @@ -647,7 +655,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" Tensor feature_requires_grad"
{%- endif %}
") -> Tensor");

{%- endif %}
DISPATCH_TO_CUDA(
"{{ embedding_codegen_grad_indice_weights_op }}_wrapper",
{{ embedding_codegen_grad_indice_weights_op }}_cuda_wrapper
Expand Down

0 comments on commit f59d5ee

Please sign in to comment.