diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c868d566..5b1c4b7a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -71,12 +71,8 @@ jobs: shell: bash working-directory: pyop3 run: | - # Gross: - pip install toml - python scripts/requirements.py build | pip install -r /dev/stdin - pip install --no-build-isolation . - python scripts/requirements.py run | pip install -r /dev/stdin pip install pytest pytest-cov pytest-timeout pytest-xdist pytest-timeout + pip install . - name: Run tests shell: bash @@ -92,5 +88,3 @@ jobs: -n 12 --dist worksteal \ -v tests timeout-minutes: 10 - - diff --git a/pyop3/axes/tree.py b/pyop3/axes/tree.py index 371dffc9..2eceec67 100644 --- a/pyop3/axes/tree.py +++ b/pyop3/axes/tree.py @@ -136,7 +136,7 @@ def with_context(self, context): key = {} for loop_index, path in context.items(): if loop_index in self.keys: - key |= {loop_index: path} + key.update({loop_index: path}) key = pmap(key) return self.context_map[key] @@ -474,9 +474,10 @@ def _collect_datamap(axis, *subdatamaps, axes): datamap = {} for cidx, component in enumerate(axis.components): if isinstance(count := component.count, MultiArray): - datamap |= count.datamap + datamap.update(count.datamap) - return datamap | merge_dicts(subdatamaps) + datamap.update(merge_dicts(subdatamaps)) + return datamap class AxisComponent(LabelledImmutableRecord): @@ -516,7 +517,7 @@ class AxisComponent(LabelledImmutableRecord): def __init__( self, count, - label: Hashable | None = None, + label: Optional[Hashable] = None, *, indices=None, overlap=None, @@ -623,10 +624,10 @@ class Axis(StrictLabelledNode, LoopIterable): def __init__( self, - components: Sequence[AxisComponent] | AxisComponent | int, - label: Hashable | None = None, + components: Union[Sequence[AxisComponent], AxisComponent, int], + label: Optional[Hashable] = None, *, - permutation: Sequence[int] | None = None, + permutation: Optional[Sequence[int]] = None, **kwargs, ): components = tuple(_as_axis_component(cpt) for cpt in as_tuple(components)) @@ -745,8 +746,8 @@ class AxisTree(StrictLabelledTree, LoopIterable, ContextFree): # fields = StrictLabelledTree.fields | {"target_paths", "index_exprs", "layout_exprs", "orig_axes", "sf", "shared_sf", "comm"} def __init__( self, - root: MultiAxis | None = None, - parent_to_children: dict | None = None, + root: Optional[MultiAxis] = None, + parent_to_children: Optional[Dict] = None, *, target_paths=None, index_exprs=None, @@ -914,11 +915,9 @@ def parse_bits( if target_axis.id in new_visited_target_axes: continue new_visited_target_axes |= {target_axis.id} - new_target_path_per_cpt[ - axis.id, component.label - ] |= self.target_path_per_component[ - target_axis.id, target_cpt.label - ] + new_target_path_per_cpt[axis.id, component.label].update( + self.target_path_per_component[target_axis.id, target_cpt.label] + ) # do a replacement orig_index_exprs = self.index_exprs_per_component[ @@ -953,9 +952,9 @@ def parse_bits( partial_layout_exprs=new_partial_layout_exprs, visited_target_axes=new_visited_target_axes, ) - new_target_path_per_cpt |= retval[0] - new_index_exprs_per_cpt |= retval[1] - new_layout_exprs_per_cpt |= retval[2] + new_target_path_per_cpt.update(retval[0]) + new_index_exprs_per_cpt.update(retval[1]) + new_layout_exprs_per_cpt.update(retval[2]) else: pass @@ -1061,7 +1060,7 @@ def datamap(self) -> dict[str:DistributedArray]: for cleverdict in [self.layouts, self.orig_layout_fn]: for layout in cleverdict.values(): for array in MultiArrayCollector()(layout): - dmap |= array.datamap + dmap.update(array.datamap) # TODO # for cleverdict in [self.index_exprs, self.layout_exprs]: @@ -1069,9 +1068,8 @@ def datamap(self) -> dict[str:DistributedArray]: for exprs in cleverdict.values(): for expr in exprs.values(): for array in MultiArrayCollector()(expr): - dmap |= array.datamap - # breakpoint() - return dmap + dmap.update(array.datamap) + return pmap(dmap) def _make_target_paths(self): return tuple(self.path(ax, cpt) for ax, cpt in self.leaves) @@ -1191,8 +1189,8 @@ def leaf_component(self): return self.leaf[1] def child( - self, parent: Axis, component: AxisComponent | ComponentLabel - ) -> Axis | None: + self, parent: Axis, component: Union[AxisComponent, ComponentLabel] + ) -> Optional[Axis]: cpt_label = _as_axis_component_label(component) return super().child(parent, cpt_label) @@ -1540,7 +1538,7 @@ def _compute_layouts( ) sublayoutss.append(sublayouts) csubtrees.append(csubtree) - steps |= substeps + steps.update(substeps) else: csubtrees.append(None) sublayoutss.append(collections.defaultdict(list)) @@ -1584,26 +1582,29 @@ def _compute_layouts( ctree = None for c in axis.components: step = step_size(axes, axis, c) - layouts |= { - path - # | {axis.label: c.label}: AffineLayout(axis.label, c.label, step) - | {axis.label: c.label}: AxisVariable(axis.label) * step - } + layouts.update( + { + path + # | {axis.label: c.label}: AffineLayout(axis.label, c.label, step) + | {axis.label: c.label}: AxisVariable(axis.label) * step + } + ) else: croot = CustomNode( [(cpt.count, axis.label, cpt.label) for cpt in axis.components] ) if strictly_all(sub is not None for sub in csubtrees): - cparent_to_children = { - croot.id: [sub.root for sub in csubtrees] - } | merge_dicts(sub.parent_to_children for sub in csubtrees) + cparent_to_children = pmap( + {croot.id: [sub.root for sub in csubtrees]} + ) | merge_dicts(sub.parent_to_children for sub in csubtrees) else: cparent_to_children = {} ctree = StrictLabelledTree(croot, cparent_to_children) # layouts and steps are just propagated from below - return layouts | merge_dicts(sublayoutss), ctree, steps + layouts.update(merge_dicts(sublayoutss)) + return layouts, ctree, steps # 2. add layouts here else: @@ -1623,9 +1624,9 @@ def _compute_layouts( bits.append((cpt.count, axlabel, clabel)) croot = CustomNode(bits) if strictly_all(sub is not None for sub in csubtrees): - cparent_to_children = { - croot.id: [sub.root for sub in csubtrees] - } | merge_dicts(sub.parent_to_children for sub in csubtrees) + cparent_to_children = pmap( + {croot.id: [sub.root for sub in csubtrees]} + ) | merge_dicts(sub.parent_to_children for sub in csubtrees) else: cparent_to_children = {} ctree = StrictLabelledTree(croot, cparent_to_children) @@ -1641,7 +1642,8 @@ def _compute_layouts( ctree = None steps = {path: _axis_size(axes, axis)} - return layouts | merge_dicts(sublayoutss), ctree, steps + layouts.update(merge_dicts(sublayoutss)) + return layouts, ctree, steps # must therefore be affine else: @@ -1661,7 +1663,7 @@ def _compute_layouts( sublayouts[path | {axis.label: mycomponent.label}] = new_layout start += _axis_component_size(axes, axis, mycomponent) - layouts |= sublayouts + layouts.update(sublayouts) steps = {path: _axis_size(axes, axis)} return layouts, None, steps @@ -1698,11 +1700,13 @@ def _create_count_array_tree( ) arrays[new_path] = countarray else: - arrays |= _create_count_array_tree( - ctree, - child, - counts | current_node.counts[cidx], - new_path, + arrays.update( + _create_count_array_tree( + ctree, + child, + counts | current_node.counts[cidx], + new_path, + ) ) return arrays @@ -1790,7 +1794,7 @@ def _tabulate_count_array_tree( def _collect_at_leaves( axes, values, - axis: Axis | None = None, + axis: Optional[Axis] = None, path=pmap(), prior=0, ): @@ -1804,7 +1808,7 @@ def _collect_at_leaves( else: prior_ = prior if subaxis := axes.child(axis, cpt): - acc |= _collect_at_leaves(axes, values, subaxis, new_path, prior_) + acc.update(_collect_at_leaves(axes, values, subaxis, new_path, prior_)) else: acc[new_path] = prior_ @@ -1877,7 +1881,7 @@ def _(arg: numbers.Real, path: Mapping, indices: Mapping): def _path_and_indices_from_index_tuple( axes, index_tuple -) -> tuple[pmap[Label, Label], pmap[Label, int]]: +) -> Tuple[pmap[Label, Label], pmap[Label, int]]: path = pmap() indices = pmap() axis = axes.root diff --git a/pyop3/codegen/ir.py b/pyop3/codegen/ir.py index 8ba220a4..ed301c43 100644 --- a/pyop3/codegen/ir.py +++ b/pyop3/codegen/ir.py @@ -765,7 +765,7 @@ def map_called_map(self, expr): map_array, pmap({rootaxis.label: just_one(rootaxis.components).label}) | pmap({inner_axis.label: inner_cpt.label}), - {rootaxis.label: inner_expr[0]} | {inner_axis.label: inner_expr[1]}, + {rootaxis.label: inner_expr[0], inner_axis.label: inner_expr[1]}, self._codegen_context, ) return jname_expr diff --git a/pyop3/distarray/multiarray.py b/pyop3/distarray/multiarray.py index b81656ff..a88c4492 100644 --- a/pyop3/distarray/multiarray.py +++ b/pyop3/distarray/multiarray.py @@ -153,13 +153,12 @@ def data_wo(self): @functools.cached_property def datamap(self) -> dict[str:DistributedArray]: - # FIXME when we use proper index trees - # return {self.name: self} | self.axes.datamap | merge_dicts([idxs.datamap for idxs in self.indicess]) - return ( - {self.name: self} - | self.axes.datamap - | merge_dicts([idx.datamap for idxs in self.indicess for idx in idxs]) + datamap = {self.name: self} + datamap.update(self.axes.datamap) + datamap.update( + merge_dicts([idx.datamap for idxs in self.indicess for idx in idxs]) ) + return datamap @property def alloc_size(self): diff --git a/pyop3/extras/debug.py b/pyop3/extras/debug.py index 20ad4daa..0c76e6df 100644 --- a/pyop3/extras/debug.py +++ b/pyop3/extras/debug.py @@ -1,13 +1,17 @@ +from typing import Optional, Union + from mpi4py import MPI from petsc4py import PETSc -def print_with_rank(*args, comm: PETSc.Comm | MPI.Comm | None = None) -> None: +def print_with_rank(*args, comm: Optional[Union[PETSc.Comm, MPI.Comm]] = None) -> None: comm = comm or PETSc.Sys.getDefaultComm() print(f"[rank {comm.rank}] : ", *args, sep="", flush=True) -def print_if_rank(rank: int, *args, comm: PETSc.Comm | MPI.Comm | None = None) -> None: +def print_if_rank( + rank: int, *args, comm: Optional[Union[PETSc.Comm, MPI.Comm]] = None +) -> None: comm = comm or PETSc.Sys.getDefaultComm() if rank == comm.rank: print(*args, flush=True) diff --git a/pyop3/indices/tree.py b/pyop3/indices/tree.py index 1d62c418..4832d264 100644 --- a/pyop3/indices/tree.py +++ b/pyop3/indices/tree.py @@ -77,8 +77,8 @@ def parse_parent_to_children(parent_to_children, parent, loop_context): parse_parent_to_children(parent_to_children, child, loop_context) ) - return pmap( - {parent.id: tuple(new_children)} | merge_dicts(subparents_to_children) + return pmap({parent.id: tuple(new_children)}) | merge_dicts( + subparents_to_children ) else: return pmap() @@ -155,12 +155,9 @@ def datamap(self): for expr_per_leaf in self.layout_exprs.values(): for expr in expr_per_leaf.values(): dmap |= collect_datamap_from_expression(expr) - - return ( - dmap - | self.array.datamap - | merge_dicts([axes.datamap for axes in self.axis_trees.values()]) - ) + dmap.update(self.array.datamap) + dmap.update(merge_dicts([axes.datamap for axes in self.axis_trees.values()])) + return pmap(dmap) @property def name(self): @@ -191,7 +188,7 @@ def __init__(self, component, start=None, stop=None, step=None, **kwargs): @property def datamap(self): - return {} + return pmap() class Subset(SliceComponent): @@ -301,8 +298,8 @@ def datamap(self): data = {} for bit in self.connectivity.values(): for map_cpt in bit: - data |= map_cpt.datamap - return data + data.update(map_cpt.datamap) + return pmap(data) class Index(LabelledNode): @@ -395,7 +392,9 @@ def target_paths(self, context): for leaf in iterset.leaves: target_path = {} for axis, cpt in iterset.path_with_nodes(*leaf).items(): - target_path |= iterset.target_path_per_component.get((axis.id, cpt), {}) + target_path.update( + iterset.target_path_per_component.get((axis.id, cpt), {}) + ) target_paths_.append(pmap(target_path)) return tuple(target_paths_) @@ -607,11 +606,11 @@ def _(arg: LoopIndex): for axis, cpt in axis_tree.path_with_nodes( *leaf, and_components=True ).items(): - target_path |= axis_tree.target_path_per_component[ - axis.id, cpt.label - ] - extra_source_context |= source_path - extracontext |= target_path + target_path.update( + axis_tree.target_path_per_component[axis.id, cpt.label] + ) + extra_source_context.update(source_path) + extracontext.update(target_path) contexts.append( loop_context | {arg: (pmap(extra_source_context), pmap(extracontext))} ) @@ -628,7 +627,9 @@ def _(arg: LoopIndex): for axis, cpt in iterset.path_with_nodes( *leaf, and_components=True ).items(): - target_path |= iterset.target_path_per_component[axis.id, cpt.label] + target_path.update( + iterset.target_path_per_component[axis.id, cpt.label] + ) contexts.append(pmap({arg: (source_path, pmap(target_path))})) return tuple(contexts) @@ -722,7 +723,7 @@ def index_tree_from_iterable( subtrees.append(subtree) root = index - parent_to_children = pmap({index.id: children} | merge_dicts(subtrees)) + parent_to_children = pmap({index.id: children}) | merge_dicts(subtrees) else: root = index parent_to_children = pmap() @@ -963,9 +964,9 @@ def _(called_map: CalledMap, **kwargs): called_map, prior_target_path, prior_index_exprs ) axes = axes.add_subaxis(subaxis, prior_leaf_axis, prior_leaf_cpt) - target_path_per_cpt |= subtarget_paths - index_exprs_per_cpt |= subindex_exprs - layout_exprs_per_cpt |= sublayout_exprs + target_path_per_cpt.update(subtarget_paths) + index_exprs_per_cpt.update(subindex_exprs) + layout_exprs_per_cpt.update(sublayout_exprs) return (axes, target_path_per_cpt, index_exprs_per_cpt, layout_exprs_per_cpt) @@ -1084,9 +1085,9 @@ def _index_axes_rec( layout_exprs_per_cpt_per_index[key] | retval[3][key] ) else: - target_path_per_cpt_per_index |= {key: retval[1][key]} - index_exprs_per_cpt_per_index |= {key: retval[2][key]} - layout_exprs_per_cpt_per_index |= {key: retval[3][key]} + target_path_per_cpt_per_index.update({key: retval[1][key]}) + index_exprs_per_cpt_per_index.update({key: retval[2][key]}) + layout_exprs_per_cpt_per_index.update({key: retval[3][key]}) target_path_per_component = pmap(target_path_per_cpt_per_index) index_exprs_per_component = pmap(index_exprs_per_cpt_per_index) diff --git a/pyop3/lang.py b/pyop3/lang.py index f1a54453..c3d8009b 100644 --- a/pyop3/lang.py +++ b/pyop3/lang.py @@ -220,9 +220,7 @@ def __init__(self, function, arguments): @functools.cached_property def datamap(self) -> dict[str, DistributedArray]: - return functools.reduce( - operator.or_, [arg.datamap for arg in self.arguments], {} - ) + return merge_dicts([arg.datamap for arg in self.arguments]) # @property # def operands(self): diff --git a/pyop3/space.py b/pyop3/space.py index 7992fcf5..a3a1e64f 100644 --- a/pyop3/space.py +++ b/pyop3/space.py @@ -1,6 +1,6 @@ import functools import numbers -from typing import Any, Hashable +from typing import Any, Dict, FrozenSet, Hashable, Optional import numpy as np import pytools @@ -27,7 +27,7 @@ def __init__( axis: Axis, *, priority: int = DEFAULT_AXIS_PRIORITY, - within_labels: frozenset[Hashable] = frozenset(), + within_labels: FrozenSet[Hashable] = frozenset(), ): self.axis = axis self.priority = priority @@ -184,8 +184,8 @@ def _insert_axis( axes: AxisTree, new_caxis: ConstrainedAxis, current_axis: Axis, - axis_to_caxis: dict[Axis, ConstrainedAxis], - path: dict[Hashable] | None = None, + axis_to_caxis: Dict[Axis, ConstrainedAxis], + path: Optional[Dict[Hashable, Dict]] = None, ): path = path or {} diff --git a/pyop3/tree.py b/pyop3/tree.py index b7583ae7..7ee9d5d8 100644 --- a/pyop3/tree.py +++ b/pyop3/tree.py @@ -3,7 +3,7 @@ import collections import functools from collections.abc import Hashable, Sequence -from typing import Any, Mapping +from typing import Any, Dict, FrozenSet, List, Mapping, Optional, Tuple, Union import pyrsistent import pytools @@ -66,7 +66,7 @@ def __init__(self, root, parent_to_children): def __str__(self): return self._stringify() - def __contains__(self, node: Node | str) -> bool: + def __contains__(self, node: Union[Node, str]) -> bool: return self._as_node(node) in self.nodes @property @@ -75,10 +75,10 @@ def is_empty(self) -> bool: def _stringify( self, - node: Node | Hashable | None = None, + node: Optional[Union[Node, Hashable]] = None, begin_prefix: str = "", cont_prefix: str = "", - ) -> list[str] | str: + ) -> Union[List[str], str]: if self.is_empty: return "empty" node = self._as_node(node) if node else self.root @@ -102,8 +102,8 @@ def _stringify( class StrictLabelledTree(LabelledTree): def __init__( self, - root: Node | None = None, - parent_to_children: Mapping[Id, Node] | None = None, + root: Optional[Node] = None, + parent_to_children: Optional[Mapping[Id, Node]] = None, ) -> None: if root: if parent_to_children: @@ -112,11 +112,15 @@ def __init__( for parent_id, children in parent_to_children.items() } - parent_to_children |= { - node.id: (None,) * node.degree - for node in filter(None, flatten(list(parent_to_children.values()))) - if node.id not in parent_to_children - } + parent_to_children.update( + { + node.id: (None,) * node.degree + for node in filter( + None, flatten(list(parent_to_children.values())) + ) + if node.id not in parent_to_children + } + ) node_ids = [ node.id @@ -148,12 +152,12 @@ def depth(self) -> int: count = lambda _, *o: max(o or [0]) + 1 return postvisit(self, count) - def children(self, node: Node | str) -> tuple[Node]: + def children(self, node: Union[Node, str]) -> Tuple[Node]: node_id = self._as_node_id(node) return self.parent_to_children[node_id] def child( - self, parent: LabelledNode | NodeId, component_label: ComponentLabel + self, parent: Union[LabelledNode, NodeId], component_label: ComponentLabel ) -> LabelledNode: parent = self._as_node(parent) cpt_index = parent.component_labels.index(component_label) @@ -162,8 +166,8 @@ def child( def add_node( self, node: Node, - parent: Node | Id | None = None, - parent_component: Label | None = None, + parent: Optional[Union[Node, Id]] = None, + parent_component: Optional[Label] = None, uniquify: bool = False, ) -> None: if parent is None: @@ -205,7 +209,7 @@ def add_node( # old alias put_node = add_node - def replace_node(self, old: Node | Id, new: Node) -> LabelledTree: + def replace_node(self, old: Union[Node, Id], new: Node) -> LabelledTree: old = self._as_node(old) new_root = self.root @@ -231,7 +235,7 @@ def node_ids(self) -> frozenset[Id]: return frozenset(node.id for node in self.nodes) @functools.cached_property - def child_to_parent(self) -> dict[Node, tuple[Node, NodeComponent]]: + def child_to_parent(self) -> Dict[Node, Tuple[Node, NodeComponent]]: child_to_parent_ = {} for parent_id, children in self.parent_to_children.items(): parent = self._as_node(parent_id) @@ -246,20 +250,20 @@ def id_to_node(self): return {node.id: node for node in self.nodes} @functools.cached_property - def nodes(self) -> frozenset[Node]: + def nodes(self) -> Frozenset[Node]: return frozenset({self.root}) | { node for node in filter(None, flatten(list(self.parent_to_children.values()))) } - def parent(self, node: Node | Id) -> tuple[Node, NodeComponent] | None: + def parent(self, node: Union[Node, Id]) -> Optional[Tuple[Node, NodeComponent]]: node = self._as_node(node) if node == self.root: return None else: return self.child_to_parent[node] - def pop_subtree(self, subroot: Node | str) -> "Tree": + def pop_subtree(self, subroot: Union[Node, str]) -> Tree: subroot = self._as_node(subroot) self._check_exists(subroot) @@ -293,8 +297,8 @@ def collect_node_and_parent(node, _): def add_subtree( self, subtree: LabelledTree, - parent: Node | Id | None = None, - component: NodeComponent | Label | None = None, + parent: Optional[Union[Node, Id]] = None, + component: Optional[Union[NodeComponent, Label]] = None, uniquify: bool = False, ) -> None: """ @@ -322,18 +326,18 @@ def add_subtree( p: list(ch) for p, ch in self.parent_to_children.items() } new_parent_to_children[parent.id][cpt_index] = subtree.root - new_parent_to_children |= subtree.parent_to_children + new_parent_to_children.update(subtree.parent_to_children) return self.copy(parent_to_children=new_parent_to_children) # alias, better? def _to_node_id(self, arg): return self._as_node_id(arg) - def _check_exists(self, node: Node | str) -> None: + def _check_exists(self, node: Union[Node, str]) -> None: if (node_id := self._as_node(node).id) not in self.node_ids: raise NodeNotFoundException(f"{node_id} is not present in the tree") - def _first_unique_id(self, node: Node | Hashable, sep: str = "_") -> str: + def _first_unique_id(self, node: Union[Node, Hashable], sep: str = "_") -> str: orig_node_id = self._as_node_id(node) if orig_node_id not in self: return orig_node_id @@ -345,23 +349,26 @@ def _first_unique_id(self, node: Node | Hashable, sep: str = "_") -> str: node_id = f"{orig_node_id}{sep}{counter}" return node_id - def _as_node(self, node: LabelledNode | Id) -> Node: + def _as_node(self, node: Union[LabelledNode, Id]) -> Node: return node if isinstance(node, Node) else self.id_to_node[node] - def _as_node_id(self, node: Node | Id) -> Id: + def _as_node_id(self, node: Union[Node, Id]) -> Id: return node.id if isinstance(node, Node) else node - def with_modified_node(self, node: Node | Id, **kwargs): + def with_modified_node(self, node: Union[Node, Id], **kwargs): return self.replace_node(node, node.copy(**kwargs)) def with_modified_component( - self, node: Node, component: NodeComponent | Label | None = None, **kwargs + self, + node: Node, + component: Optional[Union[NodeComponent, Label]] = None, + **kwargs, ): return self.replace_node( node, node.with_modified_component(component, **kwargs) ) - def pop_subtree(self, subroot: Node | str) -> "Tree": + def pop_subtree(self, subroot: Union[Node, str]) -> "Tree": subroot = self._as_node(subroot) self._check_exists(subroot) @@ -399,7 +406,7 @@ def collect_node_and_parent(node, _): return subtree - def child_by_label(self, node: LabelledNode | Hashable, label: Hashable): + def child_by_label(self, node: Union[LabelledNode, Hashable], label: Hashable): node_id = self._as_node_id(node) child = self._parent_and_label_to_child[node_id, label] if child is not None: @@ -410,7 +417,7 @@ def child_by_label(self, node: LabelledNode | Hashable, label: Hashable): @classmethod def from_dict( cls, - node_dict: dict[Node, Hashable], + node_dict: Dict[Node, Hashable], set_up: bool = False, ) -> "LabelledTree": # -> subclass? tree = cls() @@ -467,7 +474,7 @@ def paths_fn(node, component_label, current_path): return pmap(paths_) @functools.cached_property - def leaves(self) -> tuple[tuple[Node, ComponentLabel]]: + def leaves(self) -> Tuple[Tuple[Node, ComponentLabel]]: """Return the leaves of the tree.""" leaves_ = [] @@ -482,7 +489,7 @@ def leaves_fn(node, cpt, prev): def leaf(self) -> Node: return just_one(self.leaves) - def is_leaf(self, node: Node | str) -> bool: + def is_leaf(self, node: Union[Node, str]) -> bool: node = self._as_node(node) self._check_exists(node) return all(child is None for child in self.parent_to_children[node.id]) @@ -520,7 +527,7 @@ def path_with_nodes( else: return pmap(path_) - def _node_from_path(self, path: Mapping[Node | Hashable, int]) -> Node: + def _node_from_path(self, path: Mapping[Union[Node, Hashable], int]) -> Node: if not path: return None @@ -539,28 +546,15 @@ def _node_from_path(self, path: Mapping[Node | Hashable, int]) -> Node: assert False, "shouldn't get this far" -NodePath = dict[Hashable, Hashable] +NodePath = Dict[Hashable, Hashable] """Mapping from axis labels to component labels.""" # wrong now -# def previsit( -# tree, fn, current_node: Node | None = None, prev_result: Any | None = None -# ) -> Any: -# if tree.is_empty: -# raise RuntimeError("Cannot traverse an empty tree") -# -# current_node = current_node or tree.root -# -# result = fn(current_node, prev_result) -# for child in tree.children(current_node): -# previsit(tree, fn, child, result) -# -# def previsit( tree, fn, - current_node: Node | None = None, + current_node: Optional[Node] = None, prev=None, ) -> Any: if tree.is_empty: @@ -573,7 +567,7 @@ def previsit( previsit(tree, fn, subnode, next) -def postvisit(tree, fn, current_node: Node | None = None, **kwargs) -> Any: +def postvisit(tree, fn, current_node: Optional[Node] = None, **kwargs) -> Any: """Traverse the tree in postorder. # TODO rewrite diff --git a/pyop3/utils.py b/pyop3/utils.py index cf40e0d0..1fe2e002 100644 --- a/pyop3/utils.py +++ b/pyop3/utils.py @@ -2,9 +2,10 @@ import functools import itertools import operator -from typing import Any, Collection, Hashable +from typing import Any, Collection, Hashable, Optional import pytools +from pyrsistent import pmap class UniqueNameGenerator(pytools.UniqueNameGenerator): @@ -32,7 +33,7 @@ def unique_name(prefix: str) -> str: class UniquelyIdentifiedImmutableRecord(pytools.ImmutableRecord): fields = {"id"} - def __init__(self, id: Id | None = None): + def __init__(self, id: Optional[Id] = None): pytools.ImmutableRecord.__init__(self) self.id = id if id is not None else self.unique_id() @@ -44,7 +45,7 @@ def unique_id(cls): class LabelledImmutableRecord(UniquelyIdentifiedImmutableRecord): fields = {"label"} | UniquelyIdentifiedImmutableRecord.fields - def __init__(self, label: Label | None = None, **kwargs): + def __init__(self, label: Optional[Label] = None, **kwargs): super().__init__(**kwargs) self.label = ( label if label is not None else unique_name(f"_{type(self).__name__}_label") @@ -95,12 +96,11 @@ def pad(iterable, length, after=True, padding_value=None): is_single_valued = pytools.is_single_valued -def merge_dicts(dicts): +def merge_dicts(dicts, persistent=True): merged = {} for dict_ in dicts: merged.update(dict_) - return merged - # return functools.reduce(operator.or_, dicts, {}) + return pmap(merged) if persistent else merged def unique(iterable): diff --git a/pyproject.toml b/pyproject.toml index 35ad9086..6d8625b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,6 @@ [build-system] requires = [ "setuptools", - "cython<3", - "numpy", - "petsc4py" ] build-backend = "setuptools.build_meta" @@ -11,11 +8,11 @@ build-backend = "setuptools.build_meta" name = "pyop3" version = "0.1" dependencies = [ - "pyrsistent", + "mpi4py", "numpy", + "petsc4py", + "pyrsistent", "loopy @ git+https://github.com/firedrakeproject/loopy.git", - "mpi4py", - "petsc4py" ] [project.optional-dependencies] diff --git a/setup.py b/setup.py deleted file mode 100644 index 3eb05a3b..00000000 --- a/setup.py +++ /dev/null @@ -1,40 +0,0 @@ -import os - -import numpy as np -import petsc4py -from setuptools import Extension, setup - - -def get_petsc_dirs(): - try: - petsc_dir = os.environ["PETSC_DIR"] - petsc_arch = os.environ["PETSC_ARCH"] - except KeyError: - raise RuntimeError("PETSC_DIR and PETSC_ARCH variables not defined") - return (petsc_dir, f"{petsc_dir}/{petsc_arch}") - - -def make_sparsity_extension(): - petsc_dirs = get_petsc_dirs() - include_dirs = [np.get_include(), petsc4py.get_include()] + [ - f"{dir}/include" for dir in petsc_dirs - ] - extra_link_args = [f"-L{dir}/lib" for dir in petsc_dirs] + [ - f"-Wl,-rpath,{dir}/lib" for dir in petsc_dirs - ] - - return Extension( - name="pyop3.sparsity", - sources=["pyop3/sparsity.pyx"], - include_dirs=include_dirs, - language="c", - libraries=["petsc"], - extra_link_args=extra_link_args, - ) - - -if __name__ == "__main__": - # sparsity_ext = make_sparsity_extension() - - # setup(ext_modules=[sparsity_ext]) - setup()