Commit c3a4f3e

Extend sendlist, nlist, and other tensors (still has bugs)

CaRoLZhangxy committed Oct 18, 2024
1 parent 3466e34 commit c3a4f3e
Showing 3 changed files with 111 additions and 15 deletions.
60 changes: 46 additions & 14 deletions deepmd/pt/model/model/spin_model.py
@@ -66,6 +66,7 @@ def process_spin_input_lower(
extended_spin,
nlist,
mapping: Optional[torch.Tensor] = None,
recv_num: Optional[torch.Tensor] = None,
):
"""
Add `extended_spin` into `extended_coord` to generate virtual atoms, and extend `nlist` and `mapping`.
@@ -82,18 +83,18 @@ def process_spin_input_lower(
)[extended_atype].reshape([nframes, nall, 1])
virtual_extended_atype = extended_atype + self.ntypes_real
extended_coord_updated = self.concat_switch_virtual(
extended_coord, virtual_extended_coord, nloc
extended_coord, virtual_extended_coord, nloc, recv_num=recv_num
)
extended_atype_updated = self.concat_switch_virtual(
extended_atype, virtual_extended_atype, nloc
extended_atype, virtual_extended_atype, nloc, recv_num=recv_num
)
if mapping is not None:
virtual_mapping = mapping + nloc
mapping_updated = self.concat_switch_virtual(mapping, virtual_mapping, nloc)
mapping_updated = self.concat_switch_virtual(mapping, virtual_mapping, nloc, recv_num=recv_num)
else:
mapping_updated = None
# extend the nlist
nlist_updated = self.extend_nlist(extended_atype, nlist)
nlist_updated = self.extend_nlist(extended_atype, nlist, recv_num=recv_num)
return (
extended_coord_updated,
extended_atype_updated,
@@ -176,7 +177,7 @@ def process_spin_output_lower(
return extended_out_real, extended_out_mag, atomic_mask > 0.0

@staticmethod
def extend_nlist(extended_atype, nlist):
def extend_nlist(extended_atype, nlist, recv_num: Optional[torch.Tensor] = None):
nframes, nloc, nnei = nlist.shape
nall = extended_atype.shape[1]
nlist_mask = nlist != -1
@@ -203,10 +204,21 @@ def extend_nlist(extended_atype, nlist):
second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc))
extended_nlist[first_part_index] += nloc
extended_nlist[second_part_index] -= nall - nloc
if recv_num is not None:
    # recv_num holds the doubled per-swap ghost counts; halve it to get the
    # per-swap counts of real ghost atoms.
    origin_recv_num = torch.div(recv_num, 2).to(torch.int)
    prefix_sum = torch.cumsum(origin_recv_num, dim=0)
    # keep the leading zero on the same device as recv_num
    prefix_sum = torch.cat((torch.tensor([0], device=recv_num.device), prefix_sum))
    # Shift each swap's ghost segment so its real and virtual halves land in
    # the right slots of the interleaved layout.
    index_part = []
    for i in range(recv_num.size(0)):
        index_part.append((nloc * 2 + prefix_sum[i] <= extended_nlist) & (extended_nlist < nloc * 2 + prefix_sum[i + 1]))
        index_part.append((nloc + nall + prefix_sum[i] <= extended_nlist) & (extended_nlist < nloc + nall + prefix_sum[i + 1]))
    for i in range(recv_num.size(0)):
        extended_nlist[index_part[2 * i]] += prefix_sum[i]
        extended_nlist[index_part[2 * i + 1]] -= nall - nloc - prefix_sum[i + 1]
return extended_nlist
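
For reference, a minimal sketch of the index remapping that extend_nlist performs in the single-process case (toy sizes assumed, not from the repository): real ghosts shift up by nloc, and virtual locals drop down to sit right behind the real locals.

import torch

nloc, nall = 2, 5  # assumed toy sizes: 2 local atoms, 3 ghost atoms

# Indices into the concatenated [real(nall), virtual(nall)] buffer.
idx = torch.arange(2 * nall)
remapped = idx.clone()
first_part = (nloc <= idx) & (idx < nall)          # real ghost atoms
second_part = (nall <= idx) & (idx < nall + nloc)  # virtual local atoms
remapped[first_part] += nloc          # real ghosts follow the 2 * nloc locals
remapped[second_part] -= nall - nloc  # virtual locals sit after real locals
print(remapped)  # tensor([0, 1, 4, 5, 6, 2, 3, 7, 8, 9])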

@staticmethod
def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int):
def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int, recv_num: Optional[torch.Tensor] = None):
"""
Concat real and virtual extended tensors, and switch all the local ones to the first nloc * 2 atoms.
- [:, :nloc]: original nloc real atoms.
@@ -230,6 +242,15 @@ def concat_switch_virtual(extended_tensor, extended_tensor_virtual, nloc: int):
:, nloc:
]
extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:]
if recv_num is not None:
    # recv_num holds the doubled per-swap ghost counts; origin_recv_num is
    # the number of real ghost atoms received in each swap.
    origin_recv_num = torch.div(recv_num, 2).to(torch.int)
    prefix_sum = torch.cumsum(recv_num, dim=0)
    prefix_sum = torch.cat((torch.tensor([0], device=recv_num.device), prefix_sum))
    origin_prefix_sum = torch.cumsum(origin_recv_num, dim=0)
    origin_prefix_sum = torch.cat((torch.tensor([0], device=recv_num.device), origin_prefix_sum))
    # Re-lay each swap's ghost segment: its real atoms first, then its virtual atoms.
    for i in range(recv_num.size(0)):
        real_lo = 2 * nloc + prefix_sum[i]
        virt_lo = real_lo + origin_recv_num[i]
        extended_tensor_updated[:, real_lo:virt_lo] = extended_tensor[:, nloc + origin_prefix_sum[i] : nloc + origin_prefix_sum[i + 1]]
        extended_tensor_updated[:, virt_lo : 2 * nloc + prefix_sum[i + 1]] = extended_tensor_virtual[:, nloc + origin_prefix_sum[i] : nloc + origin_prefix_sum[i + 1]]
return extended_tensor_updated.view(out_shape)
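
To make the multi-swap layout concrete, here is a small sketch (the recv_num values are made up) that prints where each swap's real and virtual ghost segments land after concat_switch_virtual:

import torch

nloc = 3
recv_num = torch.tensor([4, 6])  # assumed doubled per-swap ghost counts
origin_recv_num = torch.div(recv_num, 2).to(torch.int)
prefix_sum = torch.cat((torch.tensor([0]), torch.cumsum(recv_num, dim=0)))

for i in range(recv_num.size(0)):
    real_lo = 2 * nloc + int(prefix_sum[i])
    virt_lo = real_lo + int(origin_recv_num[i])
    seg_hi = 2 * nloc + int(prefix_sum[i + 1])
    print(f"swap {i}: real ghosts [{real_lo}, {virt_lo}), virtual ghosts [{virt_lo}, {seg_hi})")
# swap 0: real ghosts [6, 8), virtual ghosts [8, 10)
# swap 1: real ghosts [10, 13), virtual ghosts [13, 16)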

@staticmethod
@@ -475,14 +496,25 @@ def forward_common_lower(
extra_nlist_sort: bool = False,
):
nframes, nloc = nlist.shape[:2]
(
extended_coord_updated,
extended_atype_updated,
nlist_updated,
mapping_updated,
) = self.process_spin_input_lower(
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
)
if comm_dict is not None:
    assert "recv_num" in comm_dict
(
    extended_coord_updated,
    extended_atype_updated,
    nlist_updated,
    mapping_updated,
) = self.process_spin_input_lower(
    extended_coord,
    extended_atype,
    extended_spin,
    nlist,
    mapping=mapping,
    recv_num=comm_dict["recv_num"] if comm_dict is not None else None,
)
if aparam is not None:
aparam = self.expand_aparam(aparam, nloc * 2)
model_ret = self.backbone_model.forward_common_lower(
1 change: 1 addition & 0 deletions source/api_cc/include/DeepPotPT.h
@@ -428,6 +428,7 @@ class DeepPotPT : public DeepPotBase {
bool gpu_enabled;
at::Tensor firstneigh_tensor;
torch::Dict<std::string, torch::Tensor> comm_dict;
int** spin_sendlist;  // per-swap send lists doubled to include virtual (spin) atoms
/**
* @brief Translate PyTorch exceptions to the DeePMD-kit exception.
* @param[in] f The function to run.
65 changes: 64 additions & 1 deletion source/api_cc/src/DeepPotPT.cc
@@ -373,6 +373,69 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
nlist_data.padding();
if (do_message_passing == 1 && nghost > 0) {
int nswap = lmp_list.nswap;
spin_sendlist = new int*[nswap];
// Prefix sums of the per-swap ghost-atom counts: ghost atoms received in
// swap i occupy indices [nloc + prefixSum[i], nloc + prefixSum[i + 1]).
std::vector<int> prefixSum(nswap);
prefixSum[0] = 0;
for (int i = 1; i < nswap; ++i) {
  prefixSum[i] = prefixSum[i - 1] + lmp_list.recvnum[i - 1];
}
for (int i = 0; i < nswap; ++i) {
  spin_sendlist[i] = new int[lmp_list.sendnum[i] * 2];
  // sendlist_part[ii]: first position in swap i's send list whose atom is a
  // ghost received in swap ii or later; -1 if no such atom is sent.
  std::vector<int> sendlist_part(nswap, -1);
  for (int j = 0; j < lmp_list.sendnum[i]; ++j) {
    for (int ii = 0; ii < nswap; ++ii) {
      if (lmp_list.sendlist[i][j] >= nloc + prefixSum[ii] &&
          sendlist_part[ii] == -1) {
        sendlist_part[ii] = j;
      }
    }
  }
  // Unset boundaries default to the end of the send list (all nswap entries,
  // since sendlist_part[nswap - 1] is used below).
  for (int j = 0; j < nswap; ++j) {
    if (sendlist_part[j] == -1) {
      sendlist_part[j] = lmp_list.sendnum[i];
    }
  }
  // Local atoms: the virtual partner of atom `index` sits at index + nloc.
  int j = 0;
  for (; j < sendlist_part[0]; ++j) {
    long index = lmp_list.sendlist[i][j];
    spin_sendlist[i][j] = index;
    spin_sendlist[i][j + sendlist_part[0]] = index + nloc;
  }
  // Ghost atoms from intermediate swaps: shift by the doubled prefix offsets.
  for (int ii = 1; ii < nswap; ++ii) {
    for (; j < sendlist_part[ii]; ++j) {
      long index = lmp_list.sendlist[i][j];
      spin_sendlist[i][j + sendlist_part[ii - 1]] = index + nloc + prefixSum[ii - 1];
      spin_sendlist[i][j + sendlist_part[ii]] = index + nloc + prefixSum[ii];
    }
  }
  // Ghost atoms from the last swap.
  for (; j < lmp_list.sendnum[i]; ++j) {
    long index = lmp_list.sendlist[i][j];
    spin_sendlist[i][j + sendlist_part[nswap - 1]] = index + nloc + prefixSum[nswap - 1];
    spin_sendlist[i][j + lmp_list.sendnum[i]] = index + nloc + nghost_real;
  }
  // Each swap now sends and receives twice as many atoms.
  lmp_list.recvnum[i] *= 2;
  lmp_list.sendnum[i] *= 2;
}
torch::Tensor sendproc_tensor =
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
torch::Tensor recvproc_tensor =
@@ -391,7 +454,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
int total_send =
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
torch::Tensor sendlist_tensor =
torch::from_blob(lmp_list.sendlist, {total_send}, int32_option);
torch::from_blob(spin_sendlist, {total_send}, int32_option);
comm_dict.insert("send_list", sendlist_tensor);
comm_dict.insert("send_proc", sendproc_tensor);
comm_dict.insert("recv_proc", recvproc_tensor);
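
The C++ loop above is easier to follow in a small Python sketch (hypothetical helper, covering only the simplest case where a swap sends local atoms only): each swap's send list is doubled by appending, after the real indices, the index of each atom's virtual partner, which for a local atom sits nloc positions later in the interleaved layout.

def double_sendlist_local(sendlist, nloc):
    """Toy version of the spin_sendlist construction for a swap that sends
    only local atoms (index < nloc): real indices first, then the matching
    virtual indices shifted by nloc. Ghost atoms would additionally need the
    prefix-sum shifts applied in the C++ loop above."""
    real = list(sendlist)                       # real atoms keep their index
    virtual = [idx + nloc for idx in sendlist]  # virtual partners
    return real + virtual

print(double_sendlist_local([0, 2, 1], nloc=3))  # [0, 2, 1, 3, 5, 4]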
