From 3f8d54754fb37933542053b9526d3e4a53c14549 Mon Sep 17 00:00:00 2001 From: Yang Hau Date: Fri, 2 Aug 2024 00:38:58 +0800 Subject: [PATCH] feat: Add vsqrt[q]_[f32|f64] --- neon2rvv.h | 8 +++--- tests/impl.cpp | 68 +++++++++++++++++++++++++++++++++++++++++++++++--- tests/impl.h | 8 +++--- 3 files changed, 72 insertions(+), 12 deletions(-) diff --git a/neon2rvv.h b/neon2rvv.h index 8cdec040..dbe6724b 100644 --- a/neon2rvv.h +++ b/neon2rvv.h @@ -4800,13 +4800,13 @@ FORCE_INLINE float32_t vrecpss_f32(float32_t a, float32_t b) { return 2.0 - a * FORCE_INLINE float64_t vrecpsd_f64(float64_t a, float64_t b) { return 2.0 - a * b; } -// FORCE_INLINE float32x2_t vsqrt_f32(float32x2_t a); +FORCE_INLINE float32x2_t vsqrt_f32(float32x2_t a) { return __riscv_vfsqrt_v_f32m1(a, 2); } -// FORCE_INLINE float32x4_t vsqrtq_f32(float32x4_t a); +FORCE_INLINE float32x4_t vsqrtq_f32(float32x4_t a) { return __riscv_vfsqrt_v_f32m1(a, 4); } -// FORCE_INLINE float64x1_t vsqrt_f64(float64x1_t a); +FORCE_INLINE float64x1_t vsqrt_f64(float64x1_t a) { return __riscv_vfsqrt_v_f64m1(a, 1); } -// FORCE_INLINE float64x2_t vsqrtq_f64(float64x2_t a); +FORCE_INLINE float64x2_t vsqrtq_f64(float64x2_t a) { return __riscv_vfsqrt_v_f64m1(a, 2); } FORCE_INLINE float32x2_t vrsqrts_f32(float32x2_t a, float32x2_t b) { return __riscv_vfdiv_vf_f32m1(__riscv_vfnmsac_vv_f32m1(vdup_n_f32(3.0), a, b, 2), 2.0, 2); diff --git a/tests/impl.cpp b/tests/impl.cpp index 1339c891..92d056fc 100644 --- a/tests/impl.cpp +++ b/tests/impl.cpp @@ -17010,13 +17010,73 @@ result_t test_vrecpsd_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { #endif // ENABLE_TEST_ALL } -result_t test_vsqrt_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vsqrt_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const float *_a = (const float *)impl.test_cases_float_pointer1; + float _c[2]; + for (int i = 0; i < 2; i++) { + _c[i] = sqrtf(_a[i]); + } + + float32x2_t a = vld1_f32(_a); + float32x2_t c = vsqrt_f32(a); + + return validate_float_error(c, _c[0], _c[1], 0.0001f); +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} + +result_t test_vsqrtq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const float *_a = (const float *)impl.test_cases_float_pointer1; + float _c[4]; + for (int i = 0; i < 4; i++) { + _c[i] = sqrtf(_a[i]); + } + + float32x4_t a = vld1q_f32(_a); + float32x4_t c = vsqrtq_f32(a); + + return validate_float_error(c, _c[0], _c[1], _c[2], _c[3], 0.0001f); +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} + +result_t test_vsqrt_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const double *_a = (const double *)impl.test_cases_float_pointer1; + double _c[2]; + for (int i = 0; i < 1; i++) { + _c[i] = sqrt(_a[i]); + } + + float64x1_t a = vld1_f64(_a); + float64x1_t c = vsqrt_f64(a); + + return validate_double_error(c, _c[0], 0.0001f); +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} -result_t test_vsqrtq_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vsqrtq_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const double *_a = (const double *)impl.test_cases_float_pointer1; + double _c[4]; + for (int i = 0; i < 4; i++) { + _c[i] = sqrt(_a[i]); + } -result_t test_vsqrt_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } + float64x2_t a = vld1q_f64(_a); + float64x2_t c = vsqrtq_f64(a); -result_t test_vsqrtq_f64(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } + return validate_double_error(c, _c[0], _c[1], 0.0001f); +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} result_t test_vrsqrts_f32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { #ifdef ENABLE_TEST_ALL diff --git a/tests/impl.h b/tests/impl.h index b053816d..3cd61178 100644 --- a/tests/impl.h +++ b/tests/impl.h @@ -961,10 +961,10 @@ _(vrecpsq_f64) \ _(vrecpss_f32) \ _(vrecpsd_f64) \ - /*_(vsqrt_f32) */ \ - /*_(vsqrtq_f32) */ \ - /*_(vsqrt_f64) */ \ - /*_(vsqrtq_f64) */ \ + _(vsqrt_f32) \ + _(vsqrtq_f32) \ + _(vsqrt_f64) \ + _(vsqrtq_f64) \ _(vrsqrts_f32) \ _(vrsqrtsq_f32) \ _(vrsqrts_f64) \