Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rvv-intrinsic-generator to define new RVV C intrinsic API for bf16 type #229

Closed
wants to merge 14 commits into from
2 changes: 2 additions & 0 deletions rvv-intrinsic-generator/rvv_intrinsic_gen/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
WFSEWS = [16, 32]
NSEWS = [16, 32, 64]
TYPES = ["float", "int", "uint"]
TYPES = ["float", "int", "uint", "bfloat"]
ITYPES = ["int", "uint"]
FTYPES = ["float"]
BFTYPES = ["bfloat"]
MTYPES = ["bool"]
MLENS = [1, 2, 4, 8, 16, 32, 64]
REF_DOC_URL = "../rvv-intrinsic-api.md"
Expand Down
27 changes: 24 additions & 3 deletions rvv-intrinsic-generator/rvv_intrinsic_gen/inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from templates import mask_load_store_template
from templates import permute_template
from constants import LMULS,WLMULS,NCVTLMULS,SEWS,WSEWS,FSEWS,WFSEWS,NSEWS,\
TYPES,ITYPES,FTYPES,MTYPES,MLENS,REF_DOC_URL
TYPES,BTYPES,ITYPES,FTYPES,BFTYPES,MTYPES,MLENS,REF_DOC_URL
from generator import CompatibleHeaderGenerator


Expand All @@ -68,11 +68,11 @@ def gen(g):

g.function_group(load_template, "Vector Unit-Stride Load Functions",
REF_DOC_URL + "#74-vector-unit-stride-operations", ["vle"],
TYPES, SEWS, LMULS, decorators.has_masking_maskedoff_policy)
BTYPES, SEWS, LMULS, decorators.has_masking_maskedoff_policy)

g.function_group(store_template, "Vector Unit-Stride Store Functions",
REF_DOC_URL + "#74-vector-unit-stride-operations", ["vse"],
TYPES, SEWS, LMULS, decorators.has_masking_no_maskedoff)
BTYPES, SEWS, LMULS, decorators.has_masking_no_maskedoff)

g.function_group(load_template, "Vector Strided Load Functions",
REF_DOC_URL + "#75-vector-strided-loadstore-operations",
Expand Down Expand Up @@ -408,6 +408,27 @@ def gen(g):
"Narrowing Floating-Point/Integer Type-Convert Functions", REF_DOC_URL +
"#1419-narrowing-floating-pointinteger-type-convert-operations", ["ncvt"],
"", NSEWS, NCVTLMULS, decorators.has_masking_maskedoff_policy)

####################################################################
g.start_group("Vector BFloat16 Functions (still on the draft status)")

g.function_group(
mac_template, "Vector BFloat16 Widening Multiply-Add Functions",
REF_DOC_URL + "#1420-vector-bf16-widening-multiply-add-operations",
["wmacc"], BFTYPES, WFSEWS, WLMULS,
decorators.has_masking_no_maskedoff_policy)

g.function_group(
cvt_op_template, "Widening BFloat16/FP32 Type-Convert Functions",
REF_DOC_URL +
"#1421-widening-bf16-fp32-type-convert-operations", ["wcvtbf16"],
"", WSEWS, WLMULS, decorators.has_masking_maskedoff_policy)

g.function_group(
cvt_op_template,
"Narrowing FP32/BFloat16 Type-Convert Functions", REF_DOC_URL +
"#1422-narrowing-fp32-bf16-type-convert-operations", ["ncvtbf16"],
"", NSEWS, NCVTLMULS, decorators.has_masking_maskedoff_policy)

####################################################################
g.start_group("Vector Reduction Functions")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def render(G, op_list, type_list, sew_list, lmul_list, decorator_list):
convert_set = [["int", "x", "float", "f"], ["int", "x", "int", "x"],
["uint", "x", "uint", "x"], ["uint", "xu", "float", "f"],
["float", "f", "int", "x"], ["float", "f", "uint", "xu"],
["float", "f", "float", "f"]]
["float", "f", "float", "f"], ["bfloat", "bf", "float", "f"],
["float", "f", "bfloat", "bf"]]
for args in prod(
OP=op_list, SEW=sew_list, TYPES=convert_set, LMUL=lmul_list):
op = args["OP"]
Expand All @@ -54,11 +55,19 @@ def render(G, op_list, type_list, sew_list, lmul_list, decorator_list):
if (op == "cvt" and args["TYPES1"] == args["TYPES3"]):
continue

if ((args["TYPES1"] == "bf" or args["TYPES3"] == "bf") and
op != "wcvtbf16" and op != "ncvtbf16"):
continue

if ((op == "wcvtbf16" and args["TYPES3"] != "bf" ) or
(op == "ncvtbf16" and args["TYPES1"] != "bf" )):
continue

args["MIDDLE"] = "v"
factor = ""
if op == "wcvt":
if op == "wcvt" or op == "wcvtbf16":
factor = "W"
if op == "ncvt":
if op == "ncvt" or op == "ncvtbf16":
factor = "N"
args["MIDDLE"] = "w"

Expand Down Expand Up @@ -101,7 +110,7 @@ def render(G, op_list, type_list, sew_list, lmul_list, decorator_list):
**decorator.tu_dest_args(rt),
src=src_type,
vl=type_helper.size_t)
if args["TYPES1"] != args["TYPES3"] and args["TYPES3"] == "f":
if args["TYPES1"] != args["TYPES3"] and args["TYPES3"] == "f" and args["TYPES1"] != "bf":
args["OP"] = args["OP"] + "_rtz"
inst_info = InstInfo.get(
args, decorator, InstType.VV, extra_attr=extra_attr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def render(G, op_list, type_list, sew_list, lmul_list, decorator_list):
args["S_TYPE"] = "f"
args["OP"] = "f" + op
inst_type = InstType.VVF
elif data_type == "bfloat":
args["S_TYPE"] = "f"
args["OP"] = "f" + op + "bf16"
inst_type = InstType.VVF
else:
args["S_TYPE"] = "x"
inst_type = InstType.VVX
Expand Down Expand Up @@ -146,6 +150,30 @@ def render(G, op_list, type_list, sew_list, lmul_list, decorator_list):
vs1=type_helper.s,
vs2=type_helper.v,
vl=type_helper.size_t)
elif data_type == "bfloat":
if "wmacc" in op and args["SEW"] == 16:
G.func(
inst_info_vv,
name="{OP}_vv_f{WSEW}m{WLMUL}".format_map(args) +
decorator.func_suffix,
return_type="vfloat{WSEW}m{WLMUL}_t".format_map(args),
**decorator.mask_args(type_helper.m, type_helper.v),
vd="vfloat{WSEW}m{WLMUL}_t".format_map(args),
vs1=type_helper.v,
vs2=type_helper.v,
vl=type_helper.size_t)
G.func(
inst_info_vs,
name="{OP}_v{S_TYPE}_f{WSEW}m{WLMUL}".format_map(args) +
decorator.func_suffix,
return_type="vfloat{WSEW}m{WLMUL}_t".format_map(args),
**decorator.mask_args(type_helper.m, type_helper.v),
vd="vfloat{WSEW}m{WLMUL}_t".format_map(args),
vs1=type_helper.s,
vs2=type_helper.v,
vl=type_helper.size_t)
else:
continue
else:
G.func(
inst_info_vv,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ def render(G, op_list, type_list, sew_list, lmul_list, decorator_list):
# Variable in list means
# [dst type, dst short type, src type, src short type]
convert_set = [["float", "f", "int", "i"], ["float", "f", "uint", "u"],
["bfloat", "bf", "int", "i"], ["bfloat", "bf", "uint", "u"],
["uint", "u", "int", "i"], ["int", "i", "uint", "u"],
["int", "i", "float", "f"], ["uint", "u", "float", "f"]]
["int", "i", "float", "f"], ["uint", "u", "float", "f"],
["int", "i", "bfloat", "bf"], ["uint", "u", "bfloat", "bf"]]

for args in prod(
OP=op_list, SEW=sew_list, TYPES=convert_set, LMUL=lmul_list):
Expand Down
48 changes: 25 additions & 23 deletions rvv-intrinsic-rfc.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,34 @@ Further, individual intrinsic functions depend on the availability of the corres

Encode `SEW` and `LMUL` into data types. We enforce the constraint `LMUL ≥ SEW/ELEN` in the implementation. There are the following data types for `ELEN` = 64.

| Types | LMUL = 1 | LMUL = 2 | LMUL = 4 | LMUL = 8 | LMUL = 1/2 | LMUL = 1/4 | LMUL = 1/8
| ------------ | ------------ | ------------ | ------------ | ----------- | ------------- | ------------- | --------------
| **int64_t** | vint64m1_t | vint64m2_t | vint64m4_t | vint64m8_t | N/A | N/A | N/A
| **uint64_t** | vuint64m1_t | vuint64m2_t | vuint64m4_t | vuint64m8_t | N/A | N/A | N/A
| **int32_t** | vint32m1_t | vint32m2_t | vint32m4_t | vint32m8_t | vint32mf2_t | N/A | N/A
| **uint32_t** | vuint32m1_t | vuint32m2_t | vuint32m4_t | vuint32m8_t | vuint32mf2_t | N/A | N/A
| **int16_t** | vint16m1_t | vint16m2_t | vint16m4_t | vint16m8_t | vint16mf2_t | vint16mf4_t | N/A
| **uint16_t** | vuint16m1_t | vuint16m2_t | vuint16m4_t | vuint16m8_t | vuint16mf2_t | vuint16mf4_t | N/A
| **int8_t** | vint8m1_t | vint8m2_t | vint8m4_t | vint8m8_t | vint8mf2_t | vint8mf4_t | vint8mf8_t
| **uint8_t** | vuint8m1_t | vuint8m2_t | vuint8m4_t | vuint8m8_t | vuint8mf2_t | vuint8mf4_t | vuint8mf8_t
| **vfloat64** | vfloat64m1_t | vfloat64m2_t | vfloat64m4_t | vfloat64m8_t | N/A | N/A | N/A
| **vfloat32** | vfloat32m1_t | vfloat32m2_t | vfloat32m4_t | vfloat32m8_t | vfloat32mf2_t | N/A | N/A
| **vfloat16** | vfloat16m1_t | vfloat16m2_t | vfloat16m4_t | vfloat16m8_t | vfloat16mf2_t | vfloat16mf4_t | N/A
| Types | LMUL = 1 | LMUL = 2 | LMUL = 4 | LMUL = 8 | LMUL = 1/2 | LMUL = 1/4 | LMUL = 1/8
| ------------ | ------------ | ------------ | ------------ | ----------- | ------------- | ------------- | --------------
| **int64_t** | vint64m1_t | vint64m2_t | vint64m4_t | vint64m8_t | N/A | N/A | N/A
| **uint64_t** | vuint64m1_t | vuint64m2_t | vuint64m4_t | vuint64m8_t | N/A | N/A | N/A
| **int32_t** | vint32m1_t | vint32m2_t | vint32m4_t | vint32m8_t | vint32mf2_t | N/A | N/A
| **uint32_t** | vuint32m1_t | vuint32m2_t | vuint32m4_t | vuint32m8_t | vuint32mf2_t | N/A | N/A
| **int16_t** | vint16m1_t | vint16m2_t | vint16m4_t | vint16m8_t | vint16mf2_t | vint16mf4_t | N/A
| **uint16_t** | vuint16m1_t | vuint16m2_t | vuint16m4_t | vuint16m8_t | vuint16mf2_t | vuint16mf4_t | N/A
| **int8_t** | vint8m1_t | vint8m2_t | vint8m4_t | vint8m8_t | vint8mf2_t | vint8mf4_t | vint8mf8_t
| **uint8_t** | vuint8m1_t | vuint8m2_t | vuint8m4_t | vuint8m8_t | vuint8mf2_t | vuint8mf4_t | vuint8mf8_t
| **vfloat64** | vfloat64m1_t | vfloat64m2_t | vfloat64m4_t | vfloat64m8_t | N/A | N/A | N/A
| **vfloat32** | vfloat32m1_t | vfloat32m2_t | vfloat32m4_t | vfloat32m8_t | vfloat32mf2_t | N/A | N/A
| **vfloat16** | vfloat16m1_t | vfloat16m2_t | vfloat16m4_t | vfloat16m8_t | vfloat16mf2_t | vfloat16mf4_t | N/A
| **vbfloat16** | vbfloat16m1_t | vbfloat16m2_t | vbfloat16m4_t | vbfloat16m8_t | vbfloat16mf2_t | vbfloat16mf4_t | N/A

There are the following data types for `ELEN` = 32.

| Types | LMUL = 1 | LMUL = 2 | LMUL = 4 | LMUL = 8 | LMUL = 1/2 | LMUL = 1/4 | LMUL = 1/8
| ------------ | ------------ | ------------ | ------------ | ----------- | ------------- | ------------- | --------------
| **int32_t** | vint32m1_t | vint32m2_t | vint32m4_t | vint32m8_t | N/A | N/A | N/A
| **uint32_t** | vuint32m1_t | vuint32m2_t | vuint32m4_t | vuint32m8_t | N/A | N/A | N/A
| **int16_t** | vint16m1_t | vint16m2_t | vint16m4_t | vint16m8_t | vint16mf2_t | N/A | N/A
| **uint16_t** | vuint16m1_t | vuint16m2_t | vuint16m4_t | vuint16m8_t | vuint16mf2_t | N/A | N/A
| **int8_t** | vint8m1_t | vint8m2_t | vint8m4_t | vint8m8_t | vint8mf2_t | vint8mf4_t | N/A
| **uint8_t** | vuint8m1_t | vuint8m2_t | vuint8m4_t | vuint8m8_t | vuint8mf2_t | vuint8mf4_t | N/A
| **vfloat32** | vfloat32m1_t | vfloat32m2_t | vfloat32m4_t | vfloat32m8_t | N/A | N/A | N/A
| **vfloat16** | vfloat16m1_t | vfloat16m2_t | vfloat16m4_t | vfloat16m8_t | vfloat16mf2_t | N/A | N/A
| Types | LMUL = 1 | LMUL = 2 | LMUL = 4 | LMUL = 8 | LMUL = 1/2 | LMUL = 1/4 | LMUL = 1/8
| ------------ | ------------ | ------------ | ------------ | ----------- | ------------- | ------------- | --------------
| **int32_t** | vint32m1_t | vint32m2_t | vint32m4_t | vint32m8_t | N/A | N/A | N/A
| **uint32_t** | vuint32m1_t | vuint32m2_t | vuint32m4_t | vuint32m8_t | N/A | N/A | N/A
| **int16_t** | vint16m1_t | vint16m2_t | vint16m4_t | vint16m8_t | vint16mf2_t | N/A | N/A
| **uint16_t** | vuint16m1_t | vuint16m2_t | vuint16m4_t | vuint16m8_t | vuint16mf2_t | N/A | N/A
| **int8_t** | vint8m1_t | vint8m2_t | vint8m4_t | vint8m8_t | vint8mf2_t | vint8mf4_t | N/A
| **uint8_t** | vuint8m1_t | vuint8m2_t | vuint8m4_t | vuint8m8_t | vuint8mf2_t | vuint8mf4_t | N/A
| **vfloat32** | vfloat32m1_t | vfloat32m2_t | vfloat32m4_t | vfloat32m8_t | N/A | N/A | N/A
| **vfloat16** | vfloat16m1_t | vfloat16m2_t | vfloat16m4_t | vfloat16m8_t | vfloat16mf2_t | N/A | N/A
| **vbfloat16** | vbfloat16m1_t | vbfloat16m2_t | vbfloat16m4_t | vbfloat16m8_t | vbfloat16mf2_t | N/A | N/A

### Mask Types<a name="mask-types"></a>

Expand Down