From e3bd6a93b36ab9d4a5aa2c781bae0573b124d2a5 Mon Sep 17 00:00:00 2001
From: Yang Hau
Date: Fri, 2 Aug 2024 14:29:04 +0800
Subject: [PATCH] feat: Add v[max|min]nm[q]_[f32|f64]

---
 neon2rvv.h     | 64 ++++++++++++++++++++++++++++++++++++++------
 tests/impl.cpp | 72 +++++++++++++++++++++++++++++++++++++++++++++++---
 tests/impl.h   |  8 +++---
 3 files changed, 128 insertions(+), 16 deletions(-)

diff --git a/neon2rvv.h b/neon2rvv.h
index dbe6724b..884a15ce 100644
--- a/neon2rvv.h
+++ b/neon2rvv.h
@@ -3829,21 +3829,69 @@ FORCE_INLINE float64x2_t vmaxq_f64(float64x2_t a, float64x2_t b) {
   return __riscv_vmerge_vvm_f64m1(vdupq_n_f64(NAN), max_res, mask, 2);
 }

-FORCE_INLINE float32x2_t vmaxnm_f32(float32x2_t a, float32x2_t b) { return __riscv_vfmax_vv_f32m1(a, b, 2); }
+FORCE_INLINE float32x2_t vmaxnm_f32(float32x2_t a, float32x2_t b) {
+  vbool32_t a_non_nan_mask = __riscv_vmfeq_vv_f32m1_b32(a, a, 2);  // self-compare is false only in NaN lanes
+  vbool32_t b_non_nan_mask = __riscv_vmfeq_vv_f32m1_b32(b, b, 2);
+  float32x2_t a_replace = __riscv_vmerge_vvm_f32m1(b, a, a_non_nan_mask, 2);  // NaN lanes of a take b's value
+  float32x2_t b_replace = __riscv_vmerge_vvm_f32m1(a, b, b_non_nan_mask, 2);  // NaN lanes of b take a's value
+  return __riscv_vfmax_vv_f32m1(a_replace, b_replace, 2);
+}

-FORCE_INLINE float32x4_t vmaxnmq_f32(float32x4_t a, float32x4_t b) { return __riscv_vfmax_vv_f32m1(a, b, 4); }
+FORCE_INLINE float32x4_t vmaxnmq_f32(float32x4_t a, float32x4_t b) {
+  vbool32_t a_non_nan_mask = __riscv_vmfeq_vv_f32m1_b32(a, a, 4);
+  vbool32_t b_non_nan_mask = __riscv_vmfeq_vv_f32m1_b32(b, b, 4);
+  float32x4_t a_replace = __riscv_vmerge_vvm_f32m1(b, a, a_non_nan_mask, 4);
+  float32x4_t b_replace = __riscv_vmerge_vvm_f32m1(a, b, b_non_nan_mask, 4);
+  return __riscv_vfmax_vv_f32m1(a_replace, b_replace, 4);
+}

-// FORCE_INLINE float64x1_t vmaxnm_f64(float64x1_t a, float64x1_t b);
+FORCE_INLINE float64x1_t vmaxnm_f64(float64x1_t a, float64x1_t b) {
+  vbool64_t a_non_nan_mask = __riscv_vmfeq_vv_f64m1_b64(a, a, 1);
+  vbool64_t b_non_nan_mask = __riscv_vmfeq_vv_f64m1_b64(b, b, 1);
+  float64x1_t a_replace = __riscv_vmerge_vvm_f64m1(b, a, a_non_nan_mask, 1);
+  float64x1_t b_replace = __riscv_vmerge_vvm_f64m1(a, b, b_non_nan_mask, 1);
+  return __riscv_vfmax_vv_f64m1(a_replace, b_replace, 1);
+}

-// FORCE_INLINE float64x2_t vmaxnmq_f64(float64x2_t a, float64x2_t b);
+FORCE_INLINE float64x2_t vmaxnmq_f64(float64x2_t a, float64x2_t b) {
+  vbool64_t a_non_nan_mask = __riscv_vmfeq_vv_f64m1_b64(a, a, 2);
+  vbool64_t b_non_nan_mask = __riscv_vmfeq_vv_f64m1_b64(b, b, 2);
+  float64x2_t a_replace = __riscv_vmerge_vvm_f64m1(b, a, a_non_nan_mask, 2);
+  float64x2_t b_replace = __riscv_vmerge_vvm_f64m1(a, b, b_non_nan_mask, 2);
+  return __riscv_vfmax_vv_f64m1(a_replace, b_replace, 2);
+}

-FORCE_INLINE float32x2_t vminnm_f32(float32x2_t a, float32x2_t b) { return __riscv_vfmin_vv_f32m1(a, b, 2); }
+FORCE_INLINE float32x2_t vminnm_f32(float32x2_t a, float32x2_t b) {
+  vbool32_t a_non_nan_mask = __riscv_vmfeq_vv_f32m1_b32(a, a, 2);
+  vbool32_t b_non_nan_mask = __riscv_vmfeq_vv_f32m1_b32(b, b, 2);
+  float32x2_t a_replace = __riscv_vmerge_vvm_f32m1(b, a, a_non_nan_mask, 2);
+  float32x2_t b_replace = __riscv_vmerge_vvm_f32m1(a, b, b_non_nan_mask, 2);
+  return __riscv_vfmin_vv_f32m1(a_replace, b_replace, 2);
+}

-FORCE_INLINE float32x4_t vminnmq_f32(float32x4_t a, float32x4_t b) { return __riscv_vfmin_vv_f32m1(a, b, 4); }
+FORCE_INLINE float32x4_t vminnmq_f32(float32x4_t a, float32x4_t b) {
+  vbool32_t a_non_nan_mask = __riscv_vmfeq_vv_f32m1_b32(a, a, 4);
+  vbool32_t b_non_nan_mask = __riscv_vmfeq_vv_f32m1_b32(b, b, 4);
+  float32x4_t a_replace = __riscv_vmerge_vvm_f32m1(b, a, a_non_nan_mask, 4);
+  float32x4_t b_replace = __riscv_vmerge_vvm_f32m1(a, b, b_non_nan_mask, 4);
+  return __riscv_vfmin_vv_f32m1(a_replace, b_replace, 4);
+}

-// FORCE_INLINE float64x1_t vminnm_f64(float64x1_t a, float64x1_t b);
+FORCE_INLINE float64x1_t vminnm_f64(float64x1_t a, float64x1_t b) {
+  vbool64_t a_non_nan_mask = __riscv_vmfeq_vv_f64m1_b64(a, a, 1);
+  vbool64_t b_non_nan_mask = __riscv_vmfeq_vv_f64m1_b64(b, b, 1);
+  float64x1_t a_replace = __riscv_vmerge_vvm_f64m1(b, a, a_non_nan_mask, 1);
+  float64x1_t b_replace = __riscv_vmerge_vvm_f64m1(a, b, b_non_nan_mask, 1);
+  return __riscv_vfmin_vv_f64m1(a_replace, b_replace, 1);
+}

-// FORCE_INLINE float64x2_t vminnmq_f64(float64x2_t a, float64x2_t b);
+FORCE_INLINE float64x2_t vminnmq_f64(float64x2_t a, float64x2_t b) {
+  vbool64_t a_non_nan_mask = __riscv_vmfeq_vv_f64m1_b64(a, a, 2);
+  vbool64_t b_non_nan_mask = __riscv_vmfeq_vv_f64m1_b64(b, b, 2);
+  float64x2_t a_replace = __riscv_vmerge_vvm_f64m1(b, a, a_non_nan_mask, 2);
+  float64x2_t b_replace = __riscv_vmerge_vvm_f64m1(a, b, b_non_nan_mask, 2);
+  return __riscv_vfmin_vv_f64m1(a_replace, b_replace, 2);
+}

 FORCE_INLINE uint8x16_t vmaxq_u8(uint8x16_t a, uint8x16_t b) { return __riscv_vmaxu_vv_u8m1(a, b, 16); }

diff --git a/tests/impl.cpp b/tests/impl.cpp
index 92d056fc..82c557c2 100644
--- a/tests/impl.cpp
+++ b/tests/impl.cpp
@@ -14080,9 +14080,41 @@ result_t test_vmaxnmq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 #endif // ENABLE_TEST_ALL
 }

-result_t test_vmaxnm_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vmaxnm_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const double *_a = (const double *)impl.test_cases_float_pointer1;
+  const double *_b = (const double *)impl.test_cases_float_pointer2;
+  double _c[1];
+  for (int i = 0; i < 1; i++) {
+    _c[i] = _a[i] > _b[i] ? _a[i] : _b[i];
+  }
+
+  float64x1_t a = vld1_f64(_a);
+  float64x1_t b = vld1_f64(_b);
+  float64x1_t c = vmaxnm_f64(a, b);
+  return validate_double(c, _c[0]);
+#else
+  return TEST_UNIMPL;
+#endif // ENABLE_TEST_ALL
+}

-result_t test_vmaxnmq_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vmaxnmq_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const double *_a = (const double *)impl.test_cases_float_pointer1;
+  const double *_b = (const double *)impl.test_cases_float_pointer2;
+  double _c[2];
+  for (int i = 0; i < 2; i++) {
+    _c[i] = _a[i] > _b[i] ? _a[i] : _b[i];
+  }
+
+  float64x2_t a = vld1q_f64(_a);
+  float64x2_t b = vld1q_f64(_b);
+  float64x2_t c = vmaxnmq_f64(a, b);
+  return validate_double(c, _c[0], _c[1]);
+#else
+  return TEST_UNIMPL;
+#endif // ENABLE_TEST_ALL
+}

 result_t test_vminnm_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 #ifdef ENABLE_TEST_ALL
@@ -14120,9 +14152,41 @@ result_t test_vminnmq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 #endif // ENABLE_TEST_ALL
 }

-result_t test_vminnm_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vminnm_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const double *_a = (const double *)impl.test_cases_float_pointer1;
+  const double *_b = (const double *)impl.test_cases_float_pointer2;
+  double _c[1];
+  for (int i = 0; i < 1; i++) {
+    _c[i] = _a[i] < _b[i] ? _a[i] : _b[i];
+  }
+
+  float64x1_t a = vld1_f64(_a);
+  float64x1_t b = vld1_f64(_b);
+  float64x1_t c = vminnm_f64(a, b);
+  return validate_double(c, _c[0]);
+#else
+  return TEST_UNIMPL;
+#endif // ENABLE_TEST_ALL
+}

-result_t test_vminnmq_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
+result_t test_vminnmq_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
+#ifdef ENABLE_TEST_ALL
+  const double *_a = (const double *)impl.test_cases_float_pointer1;
+  const double *_b = (const double *)impl.test_cases_float_pointer2;
+  double _c[2];
+  for (int i = 0; i < 2; i++) {
+    _c[i] = _a[i] < _b[i] ? _a[i] : _b[i];
+  }
+
+  float64x2_t a = vld1q_f64(_a);
+  float64x2_t b = vld1q_f64(_b);
+  float64x2_t c = vminnmq_f64(a, b);
+  return validate_double(c, _c[0], _c[1]);
+#else
+  return TEST_UNIMPL;
+#endif // ENABLE_TEST_ALL
+}

 result_t test_vmaxq_u8(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
 #ifdef ENABLE_TEST_ALL
diff --git a/tests/impl.h b/tests/impl.h
index 3cd61178..29d96a07 100644
--- a/tests/impl.h
+++ b/tests/impl.h
@@ -776,12 +776,12 @@
   _(vmaxq_f64) \
   _(vmaxnm_f32) \
   _(vmaxnmq_f32) \
-  /*_(vmaxnm_f64) */ \
-  /*_(vmaxnmq_f64) */ \
+  _(vmaxnm_f64) \
+  _(vmaxnmq_f64) \
   _(vminnm_f32) \
   _(vminnmq_f32) \
-  /*_(vminnm_f64) */ \
-  /*_(vminnmq_f64) */ \
+  _(vminnm_f64) \
+  _(vminnmq_f64) \
   _(vmaxq_u8) \
   _(vmaxq_u16) \
   _(vmaxq_u32) \
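
Note for reviewers: every new intrinsic in the neon2rvv.h hunk uses the same
NaN-handling pattern. __riscv_vmfeq_vv compares a vector with itself and
yields false exactly in the NaN lanes (NaN != NaN), so the two vmerge calls
replace each NaN lane with the corresponding lane of the other operand before
the final vfmax/vfmin. This makes the maxNum/minNum behavior of Arm's
vmaxnm/vminnm explicit: when exactly one operand is NaN, the other operand is
returned; when both are NaN, NaN is returned. A minimal scalar sketch of the
per-lane rule (maxnm_lane is a hypothetical helper for illustration, not part
of the patch):

  #include <math.h>

  static inline double maxnm_lane(double a, double b) {
    double a_replace = isnan(a) ? b : a; /* merge under a_non_nan_mask */
    double b_replace = isnan(b) ? a : b; /* merge under b_non_nan_mask */
    return fmax(a_replace, b_replace);   /* the final vfmax */
  }

The test references use plain ordering ternaries (_a[i] > _b[i] ? _a[i] :
_b[i]), matching the existing f32 tests; they exercise only the ordinary
max/min path and assume the test pools do not feed NaN into these cases.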