diff --git a/paddle/phi/kernels/funcs/load_store_util.h b/paddle/phi/kernels/funcs/load_store_util.h
index 53fa916af7137..9c84bf05dba25 100644
--- a/paddle/phi/kernels/funcs/load_store_util.h
+++ b/paddle/phi/kernels/funcs/load_store_util.h
@@ -44,6 +44,17 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input,
       ClipFunc(quant_value, min_bound, max_bound));
 }
 
+template  // NOTE(review): no template parameter list is visible here; the body uses InType and OutType, so they are presumably declared on this line - confirm.
+__forceinline__ __device__ OutType FP8QuantHelperFunc(const InType input,  // Device helper: returns the clipped product max_bound * scale * input, cast for the caller.
+                                                      const float scale,  // multiplicative factor applied to `input`
+                                                      const int round_type,  // accepted but not read anywhere in this body
+                                                      const float max_bound,  // used both as a factor in the product and as a ClipFunc bound
+                                                      const float min_bound) {  // passed to ClipFunc only
+  float quant_value = max_bound * scale * input;  // the product is held in a float before clipping
+  return static_cast(  // NOTE(review): the cast target type is not visible here; presumably OutType (the declared return type) - confirm.
+      ClipFunc(quant_value, min_bound, max_bound));  // ClipFunc is defined outside this hunk; presumably limits the value to [min_bound, max_bound] - verify.
+}
+
 template
 struct Load {
   explicit Load(const T *src) : src_(src) {}
@@ -145,11 +156,19 @@ struct QuantStore {
     DstVec dst_vec;
 #pragma unroll
     for (int i = 0; i < VecSize; i++) {
-      dst_vec[i] = QuantHelperFunc(static_cast(src[i]),
-                                   quant_scale_,
-                                   quant_round_type_,
-                                   quant_max_bound_,
-                                   quant_min_bound_);
+      if constexpr (std::is_same_v) {  // compile-time branch; NOTE(review): the types compared by is_same_v are not visible here - confirm which output type selects this path.
+        dst_vec[i] = FP8QuantHelperFunc(static_cast(src[i]),  // taken when the is_same_v condition holds
+                                        quant_scale_,
+                                        quant_round_type_,  // forwarded, although FP8QuantHelperFunc does not read its round_type parameter
+                                        quant_max_bound_,
+                                        quant_min_bound_);
+      } else {  // otherwise keep the previous behavior: QuantHelperFunc with the same five arguments
+        dst_vec[i] = QuantHelperFunc(static_cast(src[i]),
+                                     quant_scale_,
+                                     quant_round_type_,
+                                     quant_max_bound_,
+                                     quant_min_bound_);
+      }
     }
 
     phi::Store(dst_vec, dst_ + idx);