forked from devitocodes/devito
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from ZoeLeibowitz/core_petsc
Core PETSc objects/functions.
- Loading branch information
Showing
5 changed files
with
170 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from devito.tools import CustomDtype | ||
from devito.types import LocalObject | ||
from devito.types.array import ArrayBasic | ||
import numpy as np | ||
from cached_property import cached_property | ||
|
||
|
||
class DM(LocalObject): | ||
""" | ||
PETSc Data Management object (DM). | ||
""" | ||
dtype = CustomDtype('DM') | ||
|
||
|
||
class Mat(LocalObject): | ||
""" | ||
PETSc Matrix object (Mat). | ||
""" | ||
dtype = CustomDtype('Mat') | ||
|
||
|
||
class Vec(LocalObject): | ||
""" | ||
PETSc Vector object (Vec). | ||
""" | ||
dtype = CustomDtype('Vec') | ||
|
||
|
||
class PetscMPIInt(LocalObject): | ||
""" | ||
PETSc datatype used to represent `int` parameters | ||
to MPI functions. | ||
""" | ||
dtype = CustomDtype('PetscMPIInt') | ||
|
||
|
||
class KSP(LocalObject): | ||
""" | ||
PETSc KSP : Linear Systems Solvers. | ||
Manages Krylov Methods. | ||
""" | ||
dtype = CustomDtype('KSP') | ||
|
||
|
||
class PC(LocalObject): | ||
""" | ||
PETSc object that manages all preconditioners (PC). | ||
""" | ||
dtype = CustomDtype('PC') | ||
|
||
|
||
class KSPConvergedReason(LocalObject): | ||
""" | ||
PETSc object - reason a Krylov method was determined | ||
to have converged or diverged. | ||
""" | ||
dtype = CustomDtype('KSPConvergedReason') | ||
|
||
|
||
class PETScArray(ArrayBasic): | ||
""" | ||
PETScArrays are generated by the compiler only and represent | ||
a customised variant of ArrayBasic. They are designed to | ||
avoid generating a cast in the low-level code. | ||
""" | ||
|
||
_data_alignment = False | ||
|
||
@classmethod | ||
def __dtype_setup__(cls, **kwargs): | ||
return kwargs.get('dtype', np.float32) | ||
|
||
@cached_property | ||
def _C_ctype(self): | ||
petsc_type = dtype_to_petsctype(self.dtype) | ||
modifier = '*' * len(self.dimensions) | ||
return CustomDtype(petsc_type, modifier=modifier) | ||
|
||
@property | ||
def _C_name(self): | ||
return self.name | ||
|
||
|
||
def dtype_to_petsctype(dtype): | ||
"""Map numpy types to PETSc datatypes.""" | ||
|
||
return { | ||
np.int32: 'PetscInt', | ||
np.float32: 'PetscScalar', | ||
np.int64: 'PetscInt', | ||
np.float64: 'PetscScalar' | ||
}[dtype] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from devito import Grid | ||
from devito.ir.iet import Call, ElementalFunction, Definition, DummyExpr | ||
from devito.passes.iet.languages.C import CDataManager | ||
from devito.types import (DM, Mat, Vec, PetscMPIInt, KSP, | ||
PC, KSPConvergedReason, PETScArray) | ||
import numpy as np | ||
|
||
|
||
def test_petsc_local_object(): | ||
""" | ||
Test C++ support for PETSc LocalObjects. | ||
""" | ||
lo0 = DM('da') | ||
lo1 = Mat('A') | ||
lo2 = Vec('x') | ||
lo3 = PetscMPIInt('size') | ||
lo4 = KSP('ksp') | ||
lo5 = PC('pc') | ||
lo6 = KSPConvergedReason('reason') | ||
|
||
iet = Call('foo', [lo0, lo1, lo2, lo3, lo4, lo5, lo6]) | ||
iet = ElementalFunction('foo', iet, parameters=()) | ||
|
||
dm = CDataManager(sregistry=None) | ||
iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] | ||
|
||
assert 'DM da;' in str(iet) | ||
assert 'Mat A;' in str(iet) | ||
assert 'Vec x;' in str(iet) | ||
assert 'PetscMPIInt size;' in str(iet) | ||
assert 'KSP ksp;' in str(iet) | ||
assert 'PC pc;' in str(iet) | ||
assert 'KSPConvergedReason reason;' in str(iet) | ||
|
||
|
||
def test_petsc_functions(): | ||
""" | ||
Test C++ support for PETScArrays. | ||
""" | ||
grid = Grid((2, 2)) | ||
x, y = grid.dimensions | ||
|
||
ptr0 = PETScArray(name='ptr0', dimensions=grid.dimensions, dtype=np.float32) | ||
ptr1 = PETScArray(name='ptr1', dimensions=grid.dimensions, dtype=np.float32, | ||
is_const=True) | ||
ptr2 = PETScArray(name='ptr2', dimensions=grid.dimensions, dtype=np.float64, | ||
is_const=True) | ||
ptr3 = PETScArray(name='ptr3', dimensions=grid.dimensions, dtype=np.int32) | ||
ptr4 = PETScArray(name='ptr4', dimensions=grid.dimensions, dtype=np.int64, | ||
is_const=True) | ||
|
||
defn0 = Definition(ptr0) | ||
defn1 = Definition(ptr1) | ||
defn2 = Definition(ptr2) | ||
defn3 = Definition(ptr3) | ||
defn4 = Definition(ptr4) | ||
|
||
expr = DummyExpr(ptr0.indexed[x, y], ptr1.indexed[x, y] + 1) | ||
|
||
assert str(defn0) == 'PetscScalar**restrict ptr0;' | ||
assert str(defn1) == 'const PetscScalar**restrict ptr1;' | ||
assert str(defn2) == 'const PetscScalar**restrict ptr2;' | ||
assert str(defn3) == 'PetscInt**restrict ptr3;' | ||
assert str(defn4) == 'const PetscInt**restrict ptr4;' | ||
assert str(expr) == 'ptr0[x][y] = ptr1[x][y] + 1;' |