diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp index 8bc20c215..5fcc3a017 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp @@ -283,6 +283,9 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( std::optional max_float8_D, std::optional fp8_exponent_bits, std::optional fp8_exponent_bias) { + if (offsets.scalar_type() != indices.scalar_type()) { + offsets = offsets.toType(indices.scalar_type()); + } if (static_cast(pooling_mode) == PoolingMode::NONE) { std::vector max_D_list{ max_int2_D,