
Intrinsic Support for BF16 Extension #223

Open
joshua-arch1 opened this issue Apr 26, 2023 · 7 comments
Labels
Revisit after v1.0 Features or problems we will revisit after the v1.0 release

Comments

@joshua-arch1

The BF16 extension has recently been proposed, adding three extra instruction set extensions.
https://github.com/riscv/riscv-bfloat16

I'm wondering how we plan to handle the existing RVV intrinsics. Do we need to add new intrinsics tailored to the bf16 datatype? If so, do we need to provide a bf16 version of every intrinsic that takes floating-point types? Maybe we can raise an issue for discussion.

@joshua-arch1
Author

joshua-arch1 commented Apr 26, 2023

Maybe we can reuse most RVV intrinsics unless we want to generate the new Zvfbfmin/Zvfbfwma instructions.

I have checked the LLVM implementation of the vector BF16 widening multiply-accumulate for AArch64 SVE. It uses llvm.aarch64.sve.fmlalb.nxv4f32 for fmlalb and llvm.aarch64.sve.bfmlalb for bfmlalb. Therefore, what is certain for RISC-V now is that we need to define a new intrinsic for vfwmaccbf16.

That is to say, alongside llvm.riscv.vfwmacc.nxv8f32.nxv8f16 we need a separate llvm.riscv.vfwmaccbf16 for the bf16 format.

@kito-cheng
Collaborator

We definitely need to introduce new types like vbfloat16m1_t and corresponding RVV C intrinsic API.

> Maybe we can reuse most RVV intrinsics unless we want to generate the new Zvfbfmin/Zvfbfwma instructions.
>
> I have checked the vector BF16 widening multiply-accumulate implementation in AArch64 SVE. It uses llvm.aarch64.sve.fmlalb.nxv4f32 for fmlalb and llvm.aarch64.sve.bfmlalb for bfmlalb. Therefore what is certain for RISC-V now is to define a new intrinsic for vfwmaccbf16.
>
> That is to say, for llvm.riscv.vfwmacc.nxv8f32.nxv8f16, we need to use llvm.riscv.vfwmaccbf16 in a given bf16 format.

That's an LLVM implementation detail, which should not be specified in the intrinsic API spec.

@joshua-arch1
Author

> We definitely need to introduce new types like vbfloat16m1_t and corresponding RVV C intrinsic API.
>
>> Maybe we can reuse most RVV intrinsics unless we want to generate the new Zvfbfmin/Zvfbfwma instructions.
>>
>> I have checked the vector BF16 widening multiply-accumulate implementation in AArch64 SVE. It uses llvm.aarch64.sve.fmlalb.nxv4f32 for fmlalb and llvm.aarch64.sve.bfmlalb for bfmlalb. Therefore what is certain for RISC-V now is to define a new intrinsic for vfwmaccbf16.
>>
>> That is to say, for llvm.riscv.vfwmacc.nxv8f32.nxv8f16, we need to use llvm.riscv.vfwmaccbf16 in a given bf16 format.
>
> That's an LLVM implementation detail, which should not be specified in the intrinsic API spec.

But I don't think we need to add a bfloat16 type for all the RVV floating-point intrinsics if we define functions to convert bf16 to fp32/fp16. Z(v)fbfmin has the corresponding instructions.

@kito-cheng
Collaborator

> But I don't think we need to add bfloat16 type for all the rvv floating-point intrinsics if we define a function to convert bf16 to fp32/fp16. Z(v)fbfmin has corresponding instructions.

At least we should define intrinsics for the convert instructions, define __riscv_vfwmaccbf16_[vv|vf]_bf16* for Zvfbfwma, and also some type-utility functions like reinterpret.
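As a sketch of what the reinterpret utilities would mean (the struct and function names below are illustrative stand-ins, modeled on the existing __riscv_vreinterpret_* pattern; the real API would use vector types like vbfloat16m1_t): reinterpreting between a bf16 vector and a u16 vector changes only the C type, never the bits.

```c
#include <stdint.h>
#include <string.h>

/* Storage-only scalar bf16 model: the value is just its 16-bit pattern. */
typedef struct { uint16_t bits; } bf16_t;

/* Model of a hypothetical __riscv_vreinterpret_v_bf16m1_u16m1:
 * a pure bit copy, with no value conversion of any kind. */
static uint16_t bf16_reinterpret_u16(bf16_t v) {
    uint16_t out;
    memcpy(&out, &v.bits, sizeof out);
    return out;
}

/* And the reverse direction, u16 -> bf16. */
static bf16_t u16_reinterpret_bf16(uint16_t u) {
    bf16_t out;
    memcpy(&out.bits, &u, sizeof u);
    return out;
}
```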

@joshua-arch1
Author

joshua-arch1 commented Apr 27, 2023

So I'll add bf16-format intrinsics mirroring the following float16-typed functions.

Reinterpret Cast Conversion Functions:
vfloat16mf4_t __riscv_vreinterpret_v_i16mf4_f16mf4 (vint16mf4_t src);
vfloat16mf2_t __riscv_vreinterpret_v_i16mf2_f16mf2 (vint16mf2_t src);
vfloat16m1_t __riscv_vreinterpret_v_i16m1_f16m1 (vint16m1_t src);
vfloat16m2_t __riscv_vreinterpret_v_i16m2_f16m2 (vint16m2_t src);
vfloat16m4_t __riscv_vreinterpret_v_i16m4_f16m4 (vint16m4_t src);
vfloat16m8_t __riscv_vreinterpret_v_i16m8_f16m8 (vint16m8_t src);
vfloat16mf4_t __riscv_vreinterpret_v_u16mf4_f16mf4 (vuint16mf4_t src);
vfloat16mf2_t __riscv_vreinterpret_v_u16mf2_f16mf2 (vuint16mf2_t src);
vfloat16m1_t __riscv_vreinterpret_v_u16m1_f16m1 (vuint16m1_t src);
vfloat16m2_t __riscv_vreinterpret_v_u16m2_f16m2 (vuint16m2_t src);
vfloat16m4_t __riscv_vreinterpret_v_u16m4_f16m4 (vuint16m4_t src);
vfloat16m8_t __riscv_vreinterpret_v_u16m8_f16m8 (vuint16m8_t src);
vint16mf4_t __riscv_vreinterpret_v_f16mf4_i16mf4 (vfloat16mf4_t src);
vint16mf2_t __riscv_vreinterpret_v_f16mf2_i16mf2 (vfloat16mf2_t src);
vint16m1_t __riscv_vreinterpret_v_f16m1_i16m1 (vfloat16m1_t src);
vint16m2_t __riscv_vreinterpret_v_f16m2_i16m2 (vfloat16m2_t src);
vint16m4_t __riscv_vreinterpret_v_f16m4_i16m4 (vfloat16m4_t src);
vint16m8_t __riscv_vreinterpret_v_f16m8_i16m8 (vfloat16m8_t src);
vuint16mf4_t __riscv_vreinterpret_v_f16mf4_u16mf4 (vfloat16mf4_t src);
vuint16mf2_t __riscv_vreinterpret_v_f16mf2_u16mf2 (vfloat16mf2_t src);
vuint16m1_t __riscv_vreinterpret_v_f16m1_u16m1 (vfloat16m1_t src);
vuint16m2_t __riscv_vreinterpret_v_f16m2_u16m2 (vfloat16m2_t src);
vuint16m4_t __riscv_vreinterpret_v_f16m4_u16m4 (vfloat16m4_t src);
vuint16m8_t __riscv_vreinterpret_v_f16m8_u16m8 (vfloat16m8_t src);

Single-Width Floating-Point/Integer Type-Convert Functions:
vint16mf4_t __riscv_vfcvt_x_f_v_i16mf4 (vfloat16mf4_t src, size_t vl);
vint16mf4_t __riscv_vfcvt_rtz_x_f_v_i16mf4 (vfloat16mf4_t src, size_t vl);
vint16mf2_t __riscv_vfcvt_x_f_v_i16mf2 (vfloat16mf2_t src, size_t vl);
vint16mf2_t __riscv_vfcvt_rtz_x_f_v_i16mf2 (vfloat16mf2_t src, size_t vl);
vint16m1_t __riscv_vfcvt_x_f_v_i16m1 (vfloat16m1_t src, size_t vl);
vint16m1_t __riscv_vfcvt_rtz_x_f_v_i16m1 (vfloat16m1_t src, size_t vl);
vint16m2_t __riscv_vfcvt_x_f_v_i16m2 (vfloat16m2_t src, size_t vl);
vint16m2_t __riscv_vfcvt_rtz_x_f_v_i16m2 (vfloat16m2_t src, size_t vl);
vint16m4_t __riscv_vfcvt_x_f_v_i16m4 (vfloat16m4_t src, size_t vl);
vint16m4_t __riscv_vfcvt_rtz_x_f_v_i16m4 (vfloat16m4_t src, size_t vl);
vint16m8_t __riscv_vfcvt_x_f_v_i16m8 (vfloat16m8_t src, size_t vl);
vint16m8_t __riscv_vfcvt_rtz_x_f_v_i16m8 (vfloat16m8_t src, size_t vl);
vuint16mf4_t __riscv_vfcvt_xu_f_v_u16mf4 (vfloat16mf4_t src, size_t vl);
vuint16mf4_t __riscv_vfcvt_rtz_xu_f_v_u16mf4 (vfloat16mf4_t src, size_t vl);
vuint16mf2_t __riscv_vfcvt_xu_f_v_u16mf2 (vfloat16mf2_t src, size_t vl);
vuint16mf2_t __riscv_vfcvt_rtz_xu_f_v_u16mf2 (vfloat16mf2_t src, size_t vl);
vuint16m1_t __riscv_vfcvt_xu_f_v_u16m1 (vfloat16m1_t src, size_t vl);
vuint16m1_t __riscv_vfcvt_rtz_xu_f_v_u16m1 (vfloat16m1_t src, size_t vl);
vuint16m2_t __riscv_vfcvt_xu_f_v_u16m2 (vfloat16m2_t src, size_t vl);
vuint16m2_t __riscv_vfcvt_rtz_xu_f_v_u16m2 (vfloat16m2_t src, size_t vl);
vuint16m4_t __riscv_vfcvt_xu_f_v_u16m4 (vfloat16m4_t src, size_t vl);
vuint16m4_t __riscv_vfcvt_rtz_xu_f_v_u16m4 (vfloat16m4_t src, size_t vl);
vuint16m8_t __riscv_vfcvt_xu_f_v_u16m8 (vfloat16m8_t src, size_t vl);
vuint16m8_t __riscv_vfcvt_rtz_xu_f_v_u16m8 (vfloat16m8_t src, size_t vl);
vfloat16mf4_t __riscv_vfcvt_f_x_v_f16mf4 (vint16mf4_t src, size_t vl);
vfloat16mf2_t __riscv_vfcvt_f_x_v_f16mf2 (vint16mf2_t src, size_t vl);
vfloat16m1_t __riscv_vfcvt_f_x_v_f16m1 (vint16m1_t src, size_t vl);
vfloat16m2_t __riscv_vfcvt_f_x_v_f16m2 (vint16m2_t src, size_t vl);
vfloat16m4_t __riscv_vfcvt_f_x_v_f16m4 (vint16m4_t src, size_t vl);
vfloat16m8_t __riscv_vfcvt_f_x_v_f16m8 (vint16m8_t src, size_t vl);
vfloat16mf4_t __riscv_vfcvt_f_xu_v_f16mf4 (vuint16mf4_t src, size_t vl);
vfloat16mf2_t __riscv_vfcvt_f_xu_v_f16mf2 (vuint16mf2_t src, size_t vl);
vfloat16m1_t __riscv_vfcvt_f_xu_v_f16m1 (vuint16m1_t src, size_t vl);
vfloat16m2_t __riscv_vfcvt_f_xu_v_f16m2 (vuint16m2_t src, size_t vl);
vfloat16m4_t __riscv_vfcvt_f_xu_v_f16m4 (vuint16m4_t src, size_t vl);
vfloat16m8_t __riscv_vfcvt_f_xu_v_f16m8 (vuint16m8_t src, size_t vl);

Widening Floating-Point/Integer Type-Convert Functions:
vfloat16mf4_t __riscv_vfwcvt_f_x_v_f16mf4 (vint8mf8_t src, size_t vl);
vfloat16mf2_t __riscv_vfwcvt_f_x_v_f16mf2 (vint8mf4_t src, size_t vl);
vfloat16m1_t __riscv_vfwcvt_f_x_v_f16m1 (vint8mf2_t src, size_t vl);
vfloat16m2_t __riscv_vfwcvt_f_x_v_f16m2 (vint8m1_t src, size_t vl);
vfloat16m4_t __riscv_vfwcvt_f_x_v_f16m4 (vint8m2_t src, size_t vl);
vfloat16m8_t __riscv_vfwcvt_f_x_v_f16m8 (vint8m4_t src, size_t vl);
vfloat16mf4_t __riscv_vfwcvt_f_xu_v_f16mf4 (vuint8mf8_t src, size_t vl);
vfloat16mf2_t __riscv_vfwcvt_f_xu_v_f16mf2 (vuint8mf4_t src, size_t vl);
vfloat16m1_t __riscv_vfwcvt_f_xu_v_f16m1 (vuint8mf2_t src, size_t vl);
vfloat16m2_t __riscv_vfwcvt_f_xu_v_f16m2 (vuint8m1_t src, size_t vl);
vfloat16m4_t __riscv_vfwcvt_f_xu_v_f16m4 (vuint8m2_t src, size_t vl);
vfloat16m8_t __riscv_vfwcvt_f_xu_v_f16m8 (vuint8m4_t src, size_t vl);
vint32mf2_t __riscv_vfwcvt_x_f_v_i32mf2 (vfloat16mf4_t src, size_t vl);
vint32mf2_t __riscv_vfwcvt_rtz_x_f_v_i32mf2 (vfloat16mf4_t src, size_t vl);
vint32m1_t __riscv_vfwcvt_x_f_v_i32m1 (vfloat16mf2_t src, size_t vl);
vint32m1_t __riscv_vfwcvt_rtz_x_f_v_i32m1 (vfloat16mf2_t src, size_t vl);
vint32m2_t __riscv_vfwcvt_x_f_v_i32m2 (vfloat16m1_t src, size_t vl);
vint32m2_t __riscv_vfwcvt_rtz_x_f_v_i32m2 (vfloat16m1_t src, size_t vl);
vint32m4_t __riscv_vfwcvt_x_f_v_i32m4 (vfloat16m2_t src, size_t vl);
vint32m4_t __riscv_vfwcvt_rtz_x_f_v_i32m4 (vfloat16m2_t src, size_t vl);
vint32m8_t __riscv_vfwcvt_x_f_v_i32m8 (vfloat16m4_t src, size_t vl);
vint32m8_t __riscv_vfwcvt_rtz_x_f_v_i32m8 (vfloat16m4_t src, size_t vl);
vuint32mf2_t __riscv_vfwcvt_xu_f_v_u32mf2 (vfloat16mf4_t src, size_t vl);
vuint32mf2_t __riscv_vfwcvt_rtz_xu_f_v_u32mf2 (vfloat16mf4_t src, size_t vl);
vuint32m1_t __riscv_vfwcvt_xu_f_v_u32m1 (vfloat16mf2_t src, size_t vl);
vuint32m1_t __riscv_vfwcvt_rtz_xu_f_v_u32m1 (vfloat16mf2_t src, size_t vl);
vuint32m2_t __riscv_vfwcvt_xu_f_v_u32m2 (vfloat16m1_t src, size_t vl);
vuint32m2_t __riscv_vfwcvt_rtz_xu_f_v_u32m2 (vfloat16m1_t src, size_t vl);
vuint32m4_t __riscv_vfwcvt_xu_f_v_u32m4 (vfloat16m2_t src, size_t vl);
vuint32m4_t __riscv_vfwcvt_rtz_xu_f_v_u32m4 (vfloat16m2_t src, size_t vl);
vuint32m8_t __riscv_vfwcvt_xu_f_v_u32m8 (vfloat16m4_t src, size_t vl);
vuint32m8_t __riscv_vfwcvt_rtz_xu_f_v_u32m8 (vfloat16m4_t src, size_t vl);

Narrowing Floating-Point/Integer Type-Convert Functions:
vint8mf8_t __riscv_vfncvt_x_f_w_i8mf8 (vfloat16mf4_t src, size_t vl);
vint8mf8_t __riscv_vfncvt_rtz_x_f_w_i8mf8 (vfloat16mf4_t src, size_t vl);
vint8mf4_t __riscv_vfncvt_x_f_w_i8mf4 (vfloat16mf2_t src, size_t vl);
vint8mf4_t __riscv_vfncvt_rtz_x_f_w_i8mf4 (vfloat16mf2_t src, size_t vl);
vint8mf2_t __riscv_vfncvt_x_f_w_i8mf2 (vfloat16m1_t src, size_t vl);
vint8mf2_t __riscv_vfncvt_rtz_x_f_w_i8mf2 (vfloat16m1_t src, size_t vl);
vint8m1_t __riscv_vfncvt_x_f_w_i8m1 (vfloat16m2_t src, size_t vl);
vint8m1_t __riscv_vfncvt_rtz_x_f_w_i8m1 (vfloat16m2_t src, size_t vl);
vint8m2_t __riscv_vfncvt_x_f_w_i8m2 (vfloat16m4_t src, size_t vl);
vint8m2_t __riscv_vfncvt_rtz_x_f_w_i8m2 (vfloat16m4_t src, size_t vl);
vint8m4_t __riscv_vfncvt_x_f_w_i8m4 (vfloat16m8_t src, size_t vl);
vint8m4_t __riscv_vfncvt_rtz_x_f_w_i8m4 (vfloat16m8_t src, size_t vl);
vuint8mf8_t __riscv_vfncvt_xu_f_w_u8mf8 (vfloat16mf4_t src, size_t vl);
vuint8mf8_t __riscv_vfncvt_rtz_xu_f_w_u8mf8 (vfloat16mf4_t src, size_t vl);
vuint8mf4_t __riscv_vfncvt_xu_f_w_u8mf4 (vfloat16mf2_t src, size_t vl);
vuint8mf4_t __riscv_vfncvt_rtz_xu_f_w_u8mf4 (vfloat16mf2_t src, size_t vl);
vuint8mf2_t __riscv_vfncvt_xu_f_w_u8mf2 (vfloat16m1_t src, size_t vl);
vuint8mf2_t __riscv_vfncvt_rtz_xu_f_w_u8mf2 (vfloat16m1_t src, size_t vl);
vuint8m1_t __riscv_vfncvt_xu_f_w_u8m1 (vfloat16m2_t src, size_t vl);
vuint8m1_t __riscv_vfncvt_rtz_xu_f_w_u8m1 (vfloat16m2_t src, size_t vl);
vuint8m2_t __riscv_vfncvt_xu_f_w_u8m2 (vfloat16m4_t src, size_t vl);
vuint8m2_t __riscv_vfncvt_rtz_xu_f_w_u8m2 (vfloat16m4_t src, size_t vl);
vuint8m4_t __riscv_vfncvt_xu_f_w_u8m4 (vfloat16m8_t src, size_t vl);
vuint8m4_t __riscv_vfncvt_rtz_xu_f_w_u8m4 (vfloat16m8_t src, size_t vl);
vfloat16mf4_t __riscv_vfncvt_f_x_w_f16mf4 (vint32mf2_t src, size_t vl);
vfloat16mf2_t __riscv_vfncvt_f_x_w_f16mf2 (vint32m1_t src, size_t vl);
vfloat16m1_t __riscv_vfncvt_f_x_w_f16m1 (vint32m2_t src, size_t vl);
vfloat16m2_t __riscv_vfncvt_f_x_w_f16m2 (vint32m4_t src, size_t vl);
vfloat16m4_t __riscv_vfncvt_f_x_w_f16m4 (vint32m8_t src, size_t vl);
vfloat16mf4_t __riscv_vfncvt_f_xu_w_f16mf4 (vuint32mf2_t src, size_t vl);
vfloat16mf2_t __riscv_vfncvt_f_xu_w_f16mf2 (vuint32m1_t src, size_t vl);
vfloat16m1_t __riscv_vfncvt_f_xu_w_f16m1 (vuint32m2_t src, size_t vl);
vfloat16m2_t __riscv_vfncvt_f_xu_w_f16m2 (vuint32m4_t src, size_t vl);
vfloat16m4_t __riscv_vfncvt_f_xu_w_f16m4 (vuint32m8_t src, size_t vl);

As for __riscv_vfwcvt_f_f_v_f32 and __riscv_vfncvt_f_f_w_f16, I prefer to use a new naming format following vfwcvtbf16.f.f.v and vfncvtbf16.f.f.w in the new Zvfbfmin extension, so I didn't include them.
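For context, the narrowing direction (the per-element behavior one would expect from vfncvtbf16.f.f.w, assuming the default round-to-nearest-even rounding mode) can be modeled in scalar C; the function name here is illustrative, not a proposed intrinsic:

```c
#include <stdint.h>
#include <string.h>

/* Scalar model of an fp32 -> bf16 narrowing conversion with
 * round-to-nearest-even: keep the upper 16 bits, rounding on the
 * discarded lower half. */
static uint16_t f32_to_bf16_rne(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    /* NaN input: quiet it and return, so the rounding increment below
     * cannot carry a NaN payload into an infinity pattern. */
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u)
        return (uint16_t)((bits >> 16) | 0x0040u);
    /* Round to nearest, ties to even: add 0x7FFF plus the bit that will
     * become the new LSB, then truncate the low 16 bits. */
    bits += 0x7FFFu + ((bits >> 16) & 1u);
    return (uint16_t)(bits >> 16);
}
```

For example, 1.0f (0x3F800000) narrows to 0x3F80, and the fp32 value of pi (0x40490FDB) narrows to 0x4049.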

@fuhle044

> But I don't think we need to add bfloat16 type for all the rvv floating-point intrinsics if we define a function to convert bf16 to fp32/fp16. Z(v)fbfmin has corresponding instructions.
>
> At least we should define intrinsic for convert instruction, and define _riscv_vfwmaccbf16[vv|vf]_bf16* for zvfbfwma, also some type utils functions like reinterpret.

Would it make sense to introduce BF16 load/store intrinsics that do a 16-bit integer load followed by a reinterpret cast? This would simplify the user interface considerably.
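A scalar model of that composition (the real intrinsics would operate on vector types; the struct and function names below are hypothetical): the "load" is a plain 16-bit integer load of the raw pattern, and the reinterpret step changes only the type, not the bits.

```c
#include <stdint.h>
#include <string.h>

/* Storage-only scalar bf16 model; the value is just its 16-bit pattern. */
typedef struct { uint16_t bits; } bf16_t;

/* Model of a bf16 "load" built from a 16-bit integer load followed by a
 * reinterpret cast, as suggested above: no conversion happens, only a
 * copy of the bit pattern out of memory. */
static bf16_t bf16_load(const void *p) {
    bf16_t v;
    memcpy(&v.bits, p, sizeof v.bits);  /* the u16 load */
    return v;                           /* the reinterpret is a no-op */
}

/* The matching store: reinterpret back to u16 and store the raw bits. */
static void bf16_store(void *p, bf16_t v) {
    memcpy(p, &v.bits, sizeof v.bits);
}
```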

@joshua-arch1
Author

Thank you for your suggestion. I'll add load/store intrinsics in my PR.

@eopXD eopXD added the Revisit after v1.0 Features or problems we will revisit after the v1.0 release label Oct 25, 2023