Skip to content

Commit

Permalink
Particles: Allocator Support
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Feb 4, 2023
1 parent bc83999 commit 9218519
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 16 deletions.
17 changes: 15 additions & 2 deletions src/Base/PODVector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ namespace
}

template <class T, class Allocator = std::allocator<T> >
void make_PODVector(py::module &m, std::string typestr)
void make_PODVector(py::module &m, std::string typestr, std::string allocstr)
{
using PODVector_type = PODVector<T, Allocator>;
auto const podv_name = std::string("PODVector_").append(typestr);
auto const podv_name = std::string("PODVector_").append(typestr)
.append("_").append(allocstr);

py::class_<PODVector_type>(m, podv_name.c_str())
.def("__repr__",
Expand Down Expand Up @@ -110,6 +111,18 @@ void make_PODVector(py::module &m, std::string typestr)
;
}

template <class T>
void make_PODVector(py::module &m, std::string typestr)
{
// see Src/Base/AMReX_GpuContainers.H
make_PODVector<T, std::allocator<T>> (m, typestr, "std");
make_PODVector<T, amrex::ArenaAllocator<T>> (m, typestr, "arena");
make_PODVector<T, amrex::DeviceArenaAllocator<T>> (m, typestr, "device");
make_PODVector<T, amrex::ManagedArenaAllocator<T>> (m, typestr, "managed");
make_PODVector<T, amrex::PinnedArenaAllocator<T>> (m, typestr, "pinned");
make_PODVector<T, amrex::AsyncArenaAllocator<T>> (m, typestr, "async");
}

void init_PODVector(py::module& m) {
make_PODVector<ParticleReal> (m, "real");
make_PODVector<int> (m, "int");
Expand Down
32 changes: 23 additions & 9 deletions src/Particle/ArrayOfStructs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ namespace
template <int NReal, int NInt,
template<class> class Allocator=DefaultAllocator>
py::dict
array_interface(ArrayOfStructs<NReal, NInt> const & aos)
array_interface(ArrayOfStructs<NReal, NInt, Allocator> const & aos)
{
using ParticleType = Particle<NReal, NInt>;
using RealType = typename ParticleType::RealType;
using RealType = typename ParticleType::RealType;

auto d = py::dict();
bool const read_only = false;
Expand Down Expand Up @@ -62,12 +62,15 @@ namespace

template <int NReal, int NInt,
template<class> class Allocator=DefaultAllocator>
void make_ArrayOfStructs(py::module &m)
void make_ArrayOfStructs(py::module &m, std::string allocstr)
{
using AOSType = ArrayOfStructs<NReal, NInt>;
using AOSType = ArrayOfStructs<NReal, NInt, Allocator>;
using ParticleType = Particle<NReal, NInt>;

auto const aos_name = std::string("ArrayOfStructs_").append(std::to_string(NReal) + "_" + std::to_string(NInt));
auto const aos_name = std::string("ArrayOfStructs_")
.append(std::to_string(NReal)).append("_")
.append(std::to_string(NInt)).append("_")
.append(allocstr);
py::class_<AOSType>(m, aos_name.c_str())
.def(py::init())
// TODO:
Expand Down Expand Up @@ -117,9 +120,20 @@ void make_ArrayOfStructs(py::module &m)
;
}

template <int NReal, int NInt>
void make_ArrayOfStructs(py::module &m)
{
// see Src/Base/AMReX_GpuContainers.H
make_ArrayOfStructs<NReal, NInt, std::allocator> (m, "std");
make_ArrayOfStructs<NReal, NInt, amrex::ArenaAllocator> (m, "arena");
make_ArrayOfStructs<NReal, NInt, amrex::DeviceArenaAllocator> (m, "device");
make_ArrayOfStructs<NReal, NInt, amrex::ManagedArenaAllocator> (m, "managed");
make_ArrayOfStructs<NReal, NInt, amrex::PinnedArenaAllocator> (m, "pinned");
make_ArrayOfStructs<NReal, NInt, amrex::AsyncArenaAllocator> (m, "async");
}

void init_ArrayOfStructs(py::module& m) {
make_ArrayOfStructs< 0, 0> (m);
make_ArrayOfStructs< 7, 0> (m);
make_ArrayOfStructs< 1, 1> (m);
make_ArrayOfStructs< 2, 1> (m);
make_ArrayOfStructs<0, 0> (m); // WarpX 22.07, ImpactX 22.07, HiPACE++ 22.07
make_ArrayOfStructs<1, 1> (m); // test in ParticleContainer
make_ArrayOfStructs<2, 1> (m); // test
}
6 changes: 3 additions & 3 deletions tests/test_aos.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@


def test_aos_init():
aos = amrex.ArrayOfStructs_2_1()
aos = amrex.ArrayOfStructs_2_1_std()

assert aos.numParticles() == 0
assert aos.numTotalParticles() == aos.numRealParticles() == 0
assert aos.empty()


def test_aos_push_pop():
aos = amrex.ArrayOfStructs_2_1()
aos = amrex.ArrayOfStructs_2_1_std()
p1 = amrex.Particle_2_1()
p1.set_rdata([1.5, 2.2])
p1.set_idata([3])
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_aos_push_pop():


def test_array_interface():
aos = amrex.ArrayOfStructs_2_1()
aos = amrex.ArrayOfStructs_2_1_std()
p1 = amrex.Particle_2_1()
p1.setPos([1, 2, 3])
p1.set_rdata([4.5, 5.2])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_podvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def test_podvector_init():
podv = amrex.PODVector_real()
podv = amrex.PODVector_real_std()
print(podv.__array_interface__)
# podv[0] = 1
# podv[2] = 3
Expand All @@ -28,7 +28,7 @@ def test_podvector_init():


def test_array_interface():
podv = amrex.PODVector_int()
podv = amrex.PODVector_int_std()
podv.push_back(1)
podv.push_back(2)
podv.push_back(1)
Expand Down

0 comments on commit 9218519

Please sign in to comment.