Skip to content

Commit

Permalink
convert data type of indices and offsets in int_nbit_split_embedding_…
Browse files Browse the repository at this point in the history
…codegen_lookup_function

Summary:
X-link: facebookresearch/FBGEMM#336

Convert offsets.dtype to indices.dtype in gpu ops. The two tensor need to keep same dtype.

Reviewed By: 842974287

Differential Revision: D64085749

fbshipit-source-id: 0c07d85267f11a097d9dd6711fe4591cb5e11226
  • Loading branch information
mrmiywj authored and facebook-github-bot committed Oct 10, 2024
1 parent 8e7beba commit 3172e6f
Showing 1 changed file with 3 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
std::optional<int64_t> max_float8_D,
std::optional<int64_t> fp8_exponent_bits,
std::optional<int64_t> fp8_exponent_bias) {
if (offsets.scalar_type() != indices.scalar_type()) {
offsets = offsets.toType(indices.scalar_type());
}
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
std::vector<int64_t> max_D_list{
max_int2_D,
Expand Down

0 comments on commit 3172e6f

Please sign in to comment.