Skip to content

Commit

Permalink
Fix: error in GPU
Browse files: browse the repository at this point in the history
  • Loading branch information
dyzheng committed Sep 5, 2024
1 parent 83f0acd commit 8443c42
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -925,8 +925,8 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force_dftu(int ik,
const std::complex<FPTYPE>* vu,
const int size_vu)
{
int* orbital_corr_tmp = const_cast<int*>(orbital_corr);
std::complex<FPTYPE>* vu_tmp = const_cast<std::complex<FPTYPE>*>(vu);
int* orbital_corr_tmp = nullptr;
std::complex<FPTYPE>* vu_tmp = nullptr;
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
Expand All @@ -935,7 +935,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force_dftu(int ik,
resmem_complex_op()(this->ctx, vu_tmp, size_vu);
syncmem_complex_h2d_op()(this->ctx, cpu_ctx, vu_tmp, vu, size_vu);
}
else
#endif
{
orbital_corr_tmp = const_cast<int*>(orbital_corr);
vu_tmp = const_cast<std::complex<FPTYPE>*>(vu);
}
const int force_nc = 3;
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
npm,
Expand Down Expand Up @@ -976,14 +981,18 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force_dspin(int ik,
lambda_array[iat * 3 + 1] = lambda[iat].y;
lambda_array[iat * 3 + 2] = lambda[iat].z;
}
FPTYPE* lambda_tmp = lambda_array.data();
FPTYPE* lambda_tmp = nullptr;
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
resmem_var_op()(this->ctx, lambda_tmp, this->ucell_->nat * 3);
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, lambda_tmp, lambda_array.data(), this->ucell_->nat * 3);
}
else
#endif
{
lambda_tmp = lambda_array.data();
}
const int force_nc = 3;
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
npm,
Expand Down Expand Up @@ -1018,8 +1027,8 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_stress_dftu(int ik,
const std::complex<FPTYPE>* vu,
const int size_vu)
{
int* orbital_corr_tmp = const_cast<int*>(orbital_corr);
std::complex<FPTYPE>* vu_tmp = const_cast<std::complex<FPTYPE>*>(vu);
int* orbital_corr_tmp = nullptr;
std::complex<FPTYPE>* vu_tmp = nullptr;
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
Expand All @@ -1028,7 +1037,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_stress_dftu(int ik,
resmem_complex_op()(this->ctx, vu_tmp, size_vu);
syncmem_complex_h2d_op()(this->ctx, cpu_ctx, vu_tmp, vu, size_vu);
}
else
#endif
{
orbital_corr_tmp = const_cast<int*>(orbital_corr);
vu_tmp = const_cast<std::complex<FPTYPE>*>(vu);
}
cal_stress_nl_op()(this->ctx,
nkb,
npm,
Expand Down Expand Up @@ -1065,14 +1079,18 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_stress_dspin(int ik,
lambda_array[iat * 3 + 1] = lambda[iat].y;
lambda_array[iat * 3 + 2] = lambda[iat].z;
}
FPTYPE* lambda_tmp = lambda_array.data();
FPTYPE* lambda_tmp = nullptr;
#if defined(__CUDA) || defined(__ROCM)
if (this->device == base_device::GpuDevice)
{
resmem_var_op()(this->ctx, lambda_tmp, this->ucell_->nat * 3);
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, lambda_tmp, lambda_array.data(), this->ucell_->nat * 3);
}
else
#endif
{
lambda_tmp = lambda_array.data();
}
const int force_nc = 3;
cal_stress_nl_op()(this->ctx,
nkb,
Expand Down

0 comments on commit 8443c42

Please sign in to comment.