Skip to content

Commit

Permalink
Building maps for nested matrix; More organised code.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ig-dolci committed Apr 26, 2024
1 parent 89f3df2 commit ddbccab
Showing 1 changed file with 54 additions and 151 deletions.
205 changes: 54 additions & 151 deletions pyop3/array/petsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,23 +318,31 @@ def _iter_nest_labels(
else:
yield (rlabel_acc_, clabel_acc_)

# @cached_property
def _block_axes(self, axes, shape, blocked=False):
def _block_axes(self, axes, shape, nested_index=None):
block_axes, target_paths, index_exprs = self._collect_block_axes(axes, shape)
block_axes_unindexed, _, _ = self._collect_block_axes(axes.unindexed, shape)
if self.nested:
block_axes_unindexed, _, _ = self._collect_block_axes(
axes.unindexed[nested_index], shape
)
else:
block_axes_unindexed, _, _ = self._collect_block_axes(
axes.unindexed, shape
)
return IndexedAxisTree(
block_axes.node_map, block_axes_unindexed,
target_paths=target_paths, index_exprs=index_exprs,
outer_loops=axes.outer_loops,
layout_exprs=None)

# @cached_property
def _nest_axes(self, axes, index, blocked=False):

def _nest_axes(self, axes, index):
if axes.size > 1:
axes = self._block_axes(axes, axes.size, nested_index=index)
axes_unindexed = AxisTree(axes.unindexed[index].node_map)
target_paths = thaw(axes.target_paths)
to_kill = target_paths.pop(None)
for key, target_paths_per_axis in target_paths.items():
target_paths[key].pop(to_kill)
for key, _ in target_paths.items():
for k, _ in to_kill.items():
target_paths[key].pop(k)
index_exprs = dict(axes.index_exprs)
index_exprs.pop(None)
return IndexedAxisTree(
Expand All @@ -344,7 +352,6 @@ def _nest_axes(self, axes, index, blocked=False):
outer_loops=axes.outer_loops,
layout_exprs=None)


def _collect_block_axes(self, axes, shape, axis=None):
from pyop3.axtree.layout import _axis_size
target_paths = {}
Expand All @@ -368,6 +375,30 @@ def _collect_block_axes(self, axes, shape, axis=None):
index_exprs.update(subindex_exprs)
return axis_tree, target_paths, index_exprs

def _map(self, axes):
from pyop3.axtree.layout import my_product
loop_index = just_one(axes.outer_loops)
iterset = AxisTree(loop_index.iterset.node_map)
map_axes = iterset.add_subtree(axes, *iterset.leaf)
global_map = HierarchicalArray(map_axes, dtype=IntType)
global_map = global_map[loop_index.local_index]
for idxs in my_product(axes.outer_loops):
# target_indices = {idx.index.id: idx.target_exprs for idx in idxs}
target_indices = merge_dicts([idx.replace_map for idx in idxs])

for p in axes.iter(idxs, include_ghost_points=True):
offset = axes.unindexed.offset(
p.target_exprs, p.target_path, loop_exprs=target_indices
)
global_map.set_value(
p.source_exprs,
offset,
p.source_path,
loop_exprs=target_indices,
)
return global_map


@cached_property
@PETSc.Log.EventDecorator()
def maps(self):
Expand All @@ -376,153 +407,25 @@ def maps(self):
# TODO: Don't think these need to be lists here.
# FIXME: This will only work for singly-nested matrices
if self.nested:
#for (row_index, col_index), submat_type in self.mat_type.items():

row_index = 0 # for now only!
#col_index = ???
#submat_type = self.mat_type[row_index, col_index]
row_index = 1 # for now only!
col_index = 1
submat_type = "aij"


raxes = self._nest_axes(self.raxes, row_index)

# if self.raxes.unindexed[row_index].size // 2 == 1:
# raxes = unindexed_row_axes
# else:
# raxes = self._block_axes(unindexed_row_axes, self.raxes.unindexed[row_index].size // 2)

loop_index = just_one(raxes.outer_loops)
iterset = AxisTree(loop_index.iterset.node_map)
rmap_axes = iterset.add_subtree(raxes, *iterset.leaf)
rmap = HierarchicalArray(rmap_axes, dtype=IntType)
rmap = rmap[loop_index.local_index]
for idxs in my_product(raxes.outer_loops):
# target_indices = {idx.index.id: idx.target_exprs for idx in idxs}
target_indices = merge_dicts([idx.replace_map for idx in idxs])

for p in raxes.iter(idxs, include_ghost_points=True):
offset = raxes.unindexed.offset(
p.target_exprs, p.target_path, loop_exprs=target_indices
)
rmap.set_value(
p.source_exprs,
offset,
p.source_path,
loop_exprs=target_indices,
)

# # # rfield_axis = self.raxes.unindexed.root
# # cfield_axis = self.caxes.unindexed.root

# if strictly_all(c.unit for c in rfield_axis.components):
# # This weird trick is because the right target path for the field
# # is actually tied to the root of the axis tree, rather than None.
# # This seems like a limitation of the _compose_bits function.
# rfield = single_valued(
# cpt
# for mycpt in self.raxes.root.components
# for ax, cpt in self.raxes.target_paths[
# self.raxes.root.id, mycpt.label
# ].items()
# if ax == rfield_axis.label
# )
# orig_raxes = AxisTree(self.raxes.unindexed[rfield].node_map)
# orig_raxess = [orig_raxes]
# dropped_rkeys = {rfield_axis.label}
# else:
# orig_raxess = [self.raxes.unindexed]
# dropped_rkeys = frozenset()

# if strictly_all(c.unit for c in cfield_axis.components):
# cfield = single_valued(
# cpt
# for mycpt in self.caxes.root.components
# for ax, cpt in self.caxes.target_paths[
# self.caxes.root.id, mycpt.label
# ].items()
# if ax == cfield_axis.label
# )
# orig_caxes = AxisTree(self.caxes.unindexed[cfield].node_map)
# orig_caxess = [orig_caxes]
# dropped_ckeys = {cfield_axis.label}
# else:
# orig_caxess = [self.caxes.unindexed]
# dropped_ckeys = set()
elif self.mat_type == "baij":
raxes = self._block_axes(self.raxes, self.block_shape)
caxes = self._block_axes(self.caxes, self.block_shape)
elif self.mat_type == "aij":
raxes = self.raxes
caxes = self.caxes
caxes = self._nest_axes(self.caxes, col_index)
rmap = self._map(raxes)
cmap = self._map(caxes)
elif self.mat_type == "baij" or self.mat_type == "aij":
if self.mat_type == "baij":
raxes = self._block_axes(self.raxes, self.block_shape)
caxes = self._block_axes(self.caxes, self.block_shape)
else:
raxes = self.raxes
caxes = self.caxes
rmap = self._map(raxes)
cmap = self._map(caxes)
else:
raise NotImplementedError

orig_raxess = [raxes.unindexed]
orig_caxess = [caxes.unindexed]
dropped_rkeys = set()
dropped_ckeys = set()

# TODO: are dropped_rkeys and dropped_ckeys still needed?
loop_index = just_one(raxes.outer_loops)
iterset = AxisTree(loop_index.iterset.node_map)

rmap_axes = iterset.add_subtree(raxes, *iterset.leaf)
rmap = HierarchicalArray(rmap_axes, dtype=IntType)
rmap = rmap[loop_index.local_index]

loop_index = just_one(caxes.outer_loops)
iterset = AxisTree(loop_index.iterset.node_map)

cmap_axes = iterset.add_subtree(caxes, *iterset.leaf)
cmap = HierarchicalArray(cmap_axes, dtype=IntType)
cmap = cmap[loop_index.local_index]

# TODO: Make the code below go into a separate function distinct
# from mat_type logic. Then can also share code for rmap and cmap.
for orig_raxes in orig_raxess:
for idxs in my_product(raxes.outer_loops):
# target_indices = {idx.index.id: idx.target_exprs for idx in idxs}
target_indices = merge_dicts([idx.replace_map for idx in idxs])

for p in raxes.iter(idxs, include_ghost_points=True): # seems to fix things
target_path = p.target_path
target_exprs = p.target_exprs
for key in dropped_rkeys:
target_path = target_path.remove(key)
target_exprs = target_exprs.remove(key)

offset = orig_raxes.offset(
target_exprs, target_path, loop_exprs=target_indices
)
rmap.set_value(
p.source_exprs,
offset,
p.source_path,
loop_exprs=target_indices,
)

for orig_caxes in orig_caxess:
for idxs in my_product(self.caxes.outer_loops):
# target_indices = {idx.index.id: idx.target_exprs for idx in idxs}
target_indices = merge_dicts([idx.replace_map for idx in idxs])

# for p in self.caxes.iter(idxs):
for p in caxes.iter(idxs, include_ghost_points=True): # seems to fix things
target_path = p.target_path
target_exprs = p.target_exprs
for key in dropped_ckeys:
target_path = target_path.remove(key)
target_exprs = target_exprs.remove(key)

offset = orig_caxes.offset(
target_exprs, target_path, loop_exprs=target_indices
)
cmap.set_value(
p.source_exprs,
offset,
p.source_path,
loop_exprs=target_indices,
)
return (rmap, cmap)

@property
Expand Down

0 comments on commit ddbccab

Please sign in to comment.