Skip to content

Commit

Permalink
fuse some kernels in pnpn_res_stress_device.F90
Browse files Browse the repository at this point in the history
  • Loading branch information
Shiyu-Sandy-Du committed Nov 6, 2024
1 parent 863c5d3 commit 219d1f5
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ extern "C" {

void pnpn_prs_stress_res_part1_cuda(void *ta1, void *ta2, void *ta3,
void *wa1, void *wa2, void *wa3,
void *s11, void *s22, void *s33,
void *s12, void *s13, void *s23,
void *f_u, void *f_v, void *f_w,
void *B, void *h1, void *rho, int *n) {

Expand All @@ -53,6 +55,9 @@ extern "C" {
<<<nblcks, nthrds, 0, stream>>>((real *) ta1, (real *) ta2,
(real *) ta3, (real *) wa1,
(real *) wa2, (real *) wa3,
(real *) s11, (real *) s22,
(real *) s33, (real *) s12,
(real *) s13, (real *) s23,
(real *) f_u, (real *) f_v,
(real *) f_w, (real *) B,
(real *) rho, *n);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ template< typename T >
__global__ void prs_stress_res_part1_kernel(T * __restrict__ ta1,
T * __restrict__ ta2,
T * __restrict__ ta3,
const T * __restrict__ wa1,
const T * __restrict__ wa2,
const T * __restrict__ wa3,
T * __restrict__ wa1,
T * __restrict__ wa2,
T * __restrict__ wa3,
const T * __restrict__ s11,
const T * __restrict__ s22,
const T * __restrict__ s33,
const T * __restrict__ s12,
const T * __restrict__ s13,
const T * __restrict__ s23,
const T * __restrict__ f_u,
const T * __restrict__ f_v,
const T * __restrict__ f_w,
Expand All @@ -53,6 +59,16 @@ __global__ void prs_stress_res_part1_kernel(T * __restrict__ ta1,
const int str = blockDim.x * gridDim.x;

for (int i = idx; i < n; i += str) {
wa1[i] -= 2.0 * (ta1[i] * s11[i]
+ ta2[i] * s12[i]
+ ta3[i] * s13[i]);
wa2[i] -= 2.0 * (ta1[i] * s12[i]
+ ta2[i] * s22[i]
+ ta3[i] * s23[i]);
wa3[i] -= 2.0 * (ta1[i] * s13[i]
+ ta2[i] * s23[i]
+ ta3[i] * s33[i]);

ta1[i] = (f_u[i] / rho[i]) - ((wa1[i] / rho[i]) * B[i]);
ta2[i] = (f_v[i] / rho[i]) - ((wa2[i] / rho[i]) * B[i]);
ta3[i] = (f_w[i] / rho[i]) - ((wa3[i] / rho[i]) * B[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ extern "C" {

void pnpn_prs_stress_res_part1_hip(void *ta1, void *ta2, void *ta3,
void *wa1, void *wa2, void *wa3,
void *s11, void *s22, void *s33,
void *s12, void *s13, void *s23,
void *f_u, void *f_v, void *f_w,
void *B, void *h1, void *rho, int *n) {

Expand All @@ -52,6 +54,8 @@ extern "C" {
nblcks, nthrds, 0, (hipStream_t) glb_cmd_queue,
(real *) ta1, (real *) ta2, (real *) ta3,
(real *) wa1, (real *) wa2, (real *) wa3,
(real *) s11, (real *) s22, (real *) s33,
(real *) s12, (real *) s13, (real *) s23,
(real *) f_u, (real *) f_v, (real *) f_w,
(real *) B, (real *) rho, *n);
HIP_CHECK(hipGetLastError());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ template< typename T >
__global__ void prs_stress_res_part1_kernel(T * __restrict__ ta1,
T * __restrict__ ta2,
T * __restrict__ ta3,
const T * __restrict__ wa1,
const T * __restrict__ wa2,
const T * __restrict__ wa3,
T * __restrict__ wa1,
T * __restrict__ wa2,
T * __restrict__ wa3,
const T * __restrict__ s11,
const T * __restrict__ s22,
const T * __restrict__ s33,
const T * __restrict__ s12,
const T * __restrict__ s13,
const T * __restrict__ s23,
const T * __restrict__ f_u,
const T * __restrict__ f_v,
const T * __restrict__ f_w,
Expand All @@ -53,6 +59,16 @@ __global__ void prs_stress_res_part1_kernel(T * __restrict__ ta1,
const int str = blockDim.x * gridDim.x;

for (int i = idx; i < n; i += str) {
wa1[i] -= 2.0 * (ta1[i] * s11[i]
+ ta2[i] * s12[i]
+ ta3[i] * s13[i]);
wa2[i] -= 2.0 * (ta1[i] * s12[i]
+ ta2[i] * s22[i]
+ ta3[i] * s23[i]);
wa3[i] -= 2.0 * (ta1[i] * s13[i]
+ ta2[i] * s23[i]
+ ta3[i] * s33[i]);

ta1[i] = (f_u[i] / rho[i]) - ((wa1[i] / rho[i]) * B[i]);
ta2[i] = (f_v[i] / rho[i]) - ((wa2[i] / rho[i]) * B[i]);
ta3[i] = (f_w[i] / rho[i]) - ((wa3[i] / rho[i]) * B[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ module pnpn_res_stress_device
#ifdef HAVE_HIP
interface
subroutine pnpn_prs_stress_res_part1_hip(ta1_d, ta2_d, ta3_d, &
wa1_d, wa2_d, wa3_d, f_u_d, f_v_d, f_w_d, &
wa1_d, wa2_d, wa3_d, s11_d, s22_d, s33_d, &
s12_d, s13_d, s23_d, f_u_d, f_v_d, f_w_d, &
B_d, h1_d, rho_d, n) &
bind(c, name = 'pnpn_prs_stress_res_part1_hip')
use, intrinsic :: iso_c_binding
import c_rp
implicit none
type(c_ptr), value :: ta1_d, ta2_d, ta3_d
type(c_ptr), value :: wa1_d, wa2_d, wa3_d
type(c_ptr), value :: s11_d, s22_d, s33_d, s12_d, s13_d, s23_d
type(c_ptr), value :: f_u_d, f_v_d, f_w_d
type(c_ptr), value :: B_d, h1_d, rho_d
integer(c_int) :: n
Expand Down Expand Up @@ -89,14 +91,16 @@ end subroutine pnpn_vel_res_update_hip
#elif HAVE_CUDA
interface
subroutine pnpn_prs_stress_res_part1_cuda(ta1_d, ta2_d, ta3_d, &
wa1_d, wa2_d, wa3_d, f_u_d, f_v_d, f_w_d, &
wa1_d, wa2_d, wa3_d, s11_d, s22_d, s33_d, &
s12_d, s13_d, s23_d, f_u_d, f_v_d, f_w_d, &
B_d, h1_d, rho_d, n) &
bind(c, name = 'pnpn_prs_stress_res_part1_cuda')
use, intrinsic :: iso_c_binding
import c_rp
implicit none
type(c_ptr), value :: ta1_d, ta2_d, ta3_d
type(c_ptr), value :: wa1_d, wa2_d, wa3_d
type(c_ptr), value :: s11_d, s22_d, s33_d, s12_d, s13_d, s23_d
type(c_ptr), value :: f_u_d, f_v_d, f_w_d
type(c_ptr), value :: B_d, h1_d, rho_d
integer(c_int) :: n
Expand Down Expand Up @@ -270,35 +274,17 @@ subroutine pnpn_prs_res_stress_device_compute(p, p_res, u, v, w, u_e, v_e,&
call dudxyz(ta2%x, mu%x, c_Xh%drdy, c_Xh%dsdy, c_Xh%dtdy, c_Xh)
call dudxyz(ta3%x, mu%x, c_Xh%drdz, c_Xh%dsdz, c_Xh%dtdz, c_Xh)

call device_cmult(ta1%x_d, 2.0_rp, n)
call device_cmult(ta2%x_d, 2.0_rp, n)
call device_cmult(ta3%x_d, 2.0_rp, n)

! S^T grad \mu
call device_vdot3 (work1%x_d, ta1%x_d, ta2%x_d, ta3%x_d, &
s11%x_d, s12%x_d, s13%x_d, n)

call device_vdot3 (work2%x_d, ta1%x_d, ta2%x_d, ta3%x_d, &
s12%x_d, s22%x_d, s23%x_d, n)

call device_vdot3 (work3%x_d, ta1%x_d, ta2%x_d, ta3%x_d, &
s13%x_d, s23%x_d, s33%x_d, n)

! Subtract the two terms of the viscous stress to get
! \nabla x \nabla u - S^T \nabla \mu
! The sign is consitent with the fact that we subtract the term
! below.
call device_sub2(wa1%x_d, work1%x_d, n)
call device_sub2(wa2%x_d, work2%x_d, n)
call device_sub2(wa3%x_d, work3%x_d, n)

#ifdef HAVE_HIP
call pnpn_prs_stress_res_part1_hip(ta1%x_d, ta2%x_d, ta3%x_d, &
wa1%x_d, wa2%x_d, wa3%x_d, f_x%x_d, f_y%x_d, f_z%x_d, &
wa1%x_d, wa2%x_d, wa3%x_d, &
s11%x_d, s22%x_d, s33%x_d, s12%x_d, s13%x_d, s23%x_d, &
f_x%x_d, f_y%x_d, f_z%x_d, &
c_Xh%B_d, c_Xh%h1_d, rho%x_d, n)
#elif HAVE_CUDA
call pnpn_prs_stress_res_part1_cuda(ta1%x_d, ta2%x_d, ta3%x_d, &
wa1%x_d, wa2%x_d, wa3%x_d, f_x%x_d, f_y%x_d, f_z%x_d, &
wa1%x_d, wa2%x_d, wa3%x_d, &
s11%x_d, s22%x_d, s33%x_d, s12%x_d, s13%x_d, s23%x_d, &
f_x%x_d, f_y%x_d, f_z%x_d, &
c_Xh%B_d, c_Xh%h1_d, rho%x_d, n)
#else
call neko_error('No device backend configured')
Expand Down

0 comments on commit 219d1f5

Please sign in to comment.