Skip to content

Commit

Permalink
Fix fused_bias_act's output for fp8 (#70434)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wangzheee authored Dec 25, 2024
1 parent 2df97ab commit c2ee0b0
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions paddle/phi/kernels/funcs/load_store_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input,
ClipFunc<float>(quant_value, min_bound, max_bound));
}

// Quantization helper used for FP8 outputs.
//
// Computes `max_bound * scale * input` as a float, passes that value through
// ClipFunc<float> together with `min_bound` and `max_bound`, and converts the
// result to OutType with static_cast.
//
// Parameters:
//   input      - value to quantize.
//   scale      - multiplier applied together with max_bound.
//   round_type - not referenced in this function body; the signature mirrors
//                the argument list of QuantHelperFunc.
//   max_bound  - used both as a multiplier and as the bound handed to ClipFunc.
//   min_bound  - the other bound handed to ClipFunc.
//
// NOTE(review): ClipFunc is defined elsewhere in this file; presumably it
// clamps its first argument to [min_bound, max_bound] - confirm there.
template <typename InType, typename OutType>
__forceinline__ __device__ OutType FP8QuantHelperFunc(const InType input,
                                                      const float scale,
                                                      const int round_type,
                                                      const float max_bound,
                                                      const float min_bound) {
  // Keep the left-to-right evaluation order: (max_bound * scale) * input.
  const float scaled_value = max_bound * scale * input;
  const float clipped_value =
      ClipFunc<float>(scaled_value, min_bound, max_bound);
  return static_cast<OutType>(clipped_value);
}

template <typename T>
struct Load {
explicit Load(const T *src) : src_(src) {}
Expand Down Expand Up @@ -145,11 +156,19 @@ struct QuantStore {
DstVec dst_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
dst_vec[i] = QuantHelperFunc<float, OutT>(static_cast<float>(src[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
if constexpr (std::is_same_v<OutT, phi::dtype::float8_e4m3fn>) {
dst_vec[i] = FP8QuantHelperFunc<float, OutT>(static_cast<float>(src[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
} else {
dst_vec[i] = QuantHelperFunc<float, OutT>(static_cast<float>(src[i]),
quant_scale_,
quant_round_type_,
quant_max_bound_,
quant_min_bound_);
}
}

phi::Store<OutT, VecSize>(dst_vec, dst_ + idx);
Expand Down

0 comments on commit c2ee0b0

Please sign in to comment.