diff --git a/CHANGELOG.md b/CHANGELOG.md index 07502b32f32..10ef6d19738 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +* Added `compas.datastructures.TreeNode` and `compas.datastructures.Tree` classes. * Added `EllipseArtist` to `compas_rhino` and `compas_ghpython`. ### Changed diff --git a/src/compas/datastructures/__init__.py b/src/compas/datastructures/__init__.py index 1ef6a20fb3b..cd761fe275a 100644 --- a/src/compas/datastructures/__init__.py +++ b/src/compas/datastructures/__init__.py @@ -159,6 +159,8 @@ from .cell_network.cell_network import CellNetwork +from .tree.tree import Tree, TreeNode + BaseNetwork = Network BaseMesh = Mesh BaseVolMesh = VolMesh @@ -277,6 +279,9 @@ "Feature", "GeometricFeature", "ParametricFeature", + # Trees + "Tree", + "TreeNode", ] if not compas.IPY: diff --git a/src/compas/datastructures/tree/__init__.py b/src/compas/datastructures/tree/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/compas/datastructures/tree/tree.py b/src/compas/datastructures/tree/tree.py new file mode 100644 index 00000000000..7d55c7fcb94 --- /dev/null +++ b/src/compas/datastructures/tree/tree.py @@ -0,0 +1,449 @@ +from __future__ import print_function +from __future__ import absolute_import +from __future__ import division + +from compas.datastructures import Datastructure +from compas.data import Data + + +class TreeNode(Data): + """A node of a tree data structure. + + Parameters + ---------- + name : str, optional + The name of the tree ndoe. + attributes : dict[str, Any], optional + User-defined attributes of the datastructure. + + Attributes + ---------- + name : str + The name of the datastructure. + attributes : dict[str, Any] + User-defined attributes of the datastructure. + parent : :class:`~compas.datastructures.TreeNode` + The parent node of the tree node. + children : list[:class:`~compas.datastructures.TreeNode`] + The children of the tree node. + tree : :class:`~compas.datastructures.Tree` + The tree to which the node belongs. + is_root : bool + True if the node is the root node of the tree. + is_leaf : bool + True if the node is a leaf node of the tree. + is_branch : bool + True if the node is a branch node of the tree. + acestors : generator[:class:`~compas.datastructures.TreeNode`] + A generator of the acestors of the tree node. + descendants : generator[:class:`~compas.datastructures.TreeNode`] + A generator of the descendants of the tree node, using a depth-first preorder traversal. + + """ + + DATASCHEMA = { + "type": "object", + "$recursiveAnchor": True, + "properties": { + "name": {"type": "string"}, + "attributes": {"type": "object"}, + "children": {"type": "array", "items": {"$recursiveRef": "#"}}, + }, + "required": ["name", "attributes", "children"], + } + + def __init__(self, name=None, attributes=None): + super(TreeNode, self).__init__(name=name) + self.attributes = attributes or {} + self._parent = None + self._children = [] + self._tree = None + + def __repr__(self): + return "".format(self.name) + + @property + def is_root(self): + return self._parent is None + + @property + def is_leaf(self): + return not self._children + + @property + def is_branch(self): + return not self.is_root and not self.is_leaf + + @property + def parent(self): + return self._parent + + @property + def children(self): + return self._children + + @property + def tree(self): + if self.is_root: + return self._tree + else: + return self.parent.tree + + @property + def data(self): + return { + "name": self.name, + "attributes": self.attributes, + "children": [child.data for child in self.children], + } + + @classmethod + def from_data(cls, data): + node = cls(data["name"], data["attributes"]) + for child in data["children"]: + node.add(cls.from_data(child)) + return node + + def add(self, node): + """ + Add a child node to this node. + + Parameters + ---------- + node : :class:`~compas.datastructures.TreeNode` + The node to add. + + Returns + ------- + None + + Raises + ------ + TypeError + If the node is not a :class:`~compas.datastructures.TreeNode` object. + + """ + if not isinstance(node, TreeNode): + raise TypeError("The node is not a TreeNode object.") + if node not in self._children: + self._children.append(node) + node._parent = self + + def remove(self, node): + """ + Remove a child node from this node. + + Parameters + ---------- + node : :class:`~compas.datastructures.TreeNode` + The node to remove. + + Returns + ------- + None + + """ + self._children.remove(node) + node._parent = None + + @property + def ancestors(self): + this = self + while this: + yield this + this = this.parent + + @property + def descendants(self): + for child in self.children: + yield child + for descendant in child.descendants: + yield descendant + + def traverse(self, strategy="depthfirst", order="preorder"): + """ + Traverse the tree from this node. + + Parameters + ---------- + strategy : {"depthfirst", "breadthfirst"}, optional + The traversal strategy. + order : {"preorder", "postorder"}, optional + The traversal order. This parameter is only used for depth-first traversal. + + Yields + ------ + :class:`~compas.datastructures.TreeNode` + The next node in the traversal. + + Raises + ------ + ValueError + If the strategy is not ``"depthfirst"`` or ``"breadthfirst"``. + If the order is not ``"preorder"`` or ``"postorder"``. + + """ + if strategy == "depthfirst": + if order == "preorder": + yield self + for child in self.children: + for node in child.traverse(strategy, order): + yield node + elif order == "postorder": + for child in self.children: + for node in child.traverse(strategy, order): + yield node + yield self + else: + raise ValueError("Unknown traversal order: {}".format(order)) + elif strategy == "breadthfirst": + queue = [self] + while queue: + node = queue.pop(0) + yield node + queue.extend(node.children) + else: + raise ValueError("Unknown traversal strategy: {}".format(strategy)) + + +class Tree(Datastructure): + """A hierarchical data structure that organizes elements into parent-child relationships. The tree starts from a unique root node, and every node (excluding the root) has exactly one parent. + + Parameters + ---------- + name : str, optional + The name of the datastructure. + attributes : dict[str, Any], optional + User-defined attributes of the datastructure. + + Attributes + ---------- + name : str + The name of the datastructure. + attributes : dict[str, Any] + User-defined attributes of the datastructure. + root : :class:`~compas.datastructures.TreeNode` + The root node of the tree. + nodes : generator[:class:`~compas.datastructures.TreeNode`] + The nodes of the tree. + leaves : generator[:class:`~compas.datastructures.TreeNode`] + A generator of the leaves of the tree. + + Examples + -------- + >>> from compas.datastructures import Tree, TreeNode + >>> tree = Tree() + >>> root = TreeNode('root') + >>> branch = TreeNode('branch') + >>> leaf1 = TreeNode('leaf1') + >>> leaf2 = TreeNode('leaf2') + >>> tree.add(root) + >>> root.add(branch) + >>> branch.add(leaf1) + >>> branch.add(leaf2) + >>> print(tree) + + >>> tree.print() + + + + + + """ + + DATASCHEMA = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "root": TreeNode.DATASCHEMA, + "attributes": {"type": "object"}, + }, + "required": ["name", "root", "attributes"], + } + + def __init__(self, name=None, attributes=None): + super(Tree, self).__init__(name=name) + self.attributes.update(attributes or {}) + self._root = None + + @property + def data(self): + return { + "name": self.name, + "root": self.root.data, + "attributes": self.attributes, + } + + @classmethod + def from_data(cls, data): + tree = cls(data["name"], data["attributes"]) + root = TreeNode.from_data(data["root"]) + tree.add(root) + return tree + + @property + def root(self): + return self._root + + def add(self, node, parent=None): + """ + Add a node to the tree. + + Parameters + ---------- + node : :class:`~compas.datastructures.TreeNode` + The node to add. + parent : :class:`~compas.datastructures.TreeNode`, optional + The parent node of the node to add. + Default is ``None``, in which case the node is added as a root node. + + Returns + ------- + None + + Raises + ------ + TypeError + If the node is not a :class:`~compas.datastructures.TreeNode` object. + If the supplied parent node is not a :class:`~compas.datastructures.TreeNode` object. + ValueError + If the node is already part of another tree. + If the supplied parent node is not part of this tree. + If the tree already has a root node, when trying to add a root node. + + """ + if not isinstance(node, TreeNode): + raise TypeError("The node is not a TreeNode object.") + + if node.parent: + raise ValueError("The node already has a parent, remove it from that parent first.") + + if parent is None: + # add the node as a root node + if self.root is not None: + raise ValueError("The tree already has a root node, remove it first.") + + self._root = node + node._tree = self + + else: + # add the node as a child of the parent node + if not isinstance(parent, TreeNode): + raise TypeError("The parent node is not a TreeNode object.") + + if parent.tree is not self: + raise ValueError("The parent node is not part of this tree.") + + parent.add(node) + + @property + def nodes(self): + if self.root: + for node in self.root.traverse(): + yield node + + def remove(self, node): + """ + Remove a node from the tree. + + Parameters + ---------- + node : :class:`~compas.datastructures.TreeNode` + The node to remove. + + Returns + ------- + None + + """ + if node == self.root: + self._root = None + node._tree = None + else: + node.parent.remove(node) + + @property + def leaves(self): + for node in self.nodes: + if node.is_leaf: + yield node + + def traverse(self, strategy="depthfirst", order="preorder"): + """ + Traverse the tree from the root node. + + Parameters + ---------- + strategy : {"depthfirst", "breadthfirst"}, optional + The traversal strategy. + order : {"preorder", "postorder"}, optional + The traversal order. This parameter is only used for depth-first traversal. + + Yields + ------ + :class:`~compas.datastructures.TreeNode` + The next node in the traversal. + + Raises + ------ + ValueError + If the strategy is not ``"depthfirst"`` or ``"breadthfirst"``. + If the order is not ``"preorder"`` or ``"postorder"``. + + """ + if self.root: + for node in self.root.traverse(strategy=strategy, order=order): + yield node + + def get_node_by_name(self, name): + """ + Get a node by its name. + + Parameters + ---------- + name : str + The name of the node. + + Returns + ------- + :class:`~compas.datastructures.TreeNode` + The node. + + """ + for node in self.nodes: + if node.name == name: + return node + + def get_nodes_by_name(self, name): + """ + Get all nodes by their name. + + Parameters + ---------- + name : str + The name of the node. + + Returns + ------- + list[:class:`~compas.datastructures.TreeNode`] + The nodes. + + """ + nodes = [] + for node in self.nodes: + if node.name == name: + nodes.append(node) + return nodes + + def __repr__(self): + return "".format(len(list(self.nodes))) + + def print(self): + """Print the spatial hierarchy of the tree.""" + + def _print(node, depth=0): + print(" " * depth + str(node)) + for child in node.children: + _print(child, depth + 1) + + _print(self.root) diff --git a/tests/compas/datastructures/test_tree.py b/tests/compas/datastructures/test_tree.py new file mode 100644 index 00000000000..384a9091f12 --- /dev/null +++ b/tests/compas/datastructures/test_tree.py @@ -0,0 +1,169 @@ +import pytest +import compas +import json + +from compas.datastructures import Tree, TreeNode +from compas.data import json_dumps, json_loads + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def simple_tree(): + # A basic fixture for a simple tree + root = TreeNode(name="root") + branch1 = TreeNode(name="branch1") + branch2 = TreeNode(name="branch2") + leaf1_1 = TreeNode(name="leaf1_1") + leaf1_2 = TreeNode(name="leaf1_2") + leaf2_1 = TreeNode(name="leaf2_1") + leaf2_2 = TreeNode(name="leaf2_2") + + tree = Tree() + tree.add(root) + tree.add(branch1, parent=root) + tree.add(branch2, parent=root) + tree.add(leaf1_1, parent=branch1) + tree.add(leaf1_2, parent=branch1) + tree.add(leaf2_1, parent=branch2) + tree.add(leaf2_2, parent=branch2) + return tree + + +# ============================================================================= +# Basics +# ============================================================================= + + +def test_treenode_initialization(): + node = TreeNode(name="test") + assert node.name == "test" + assert node.parent is None + assert node.tree is None + assert len(node.children) == 0 + + +def test_tree_initialization(): + tree = Tree(name="test") + assert tree.name == "test" + assert tree.root is None + + +# ============================================================================= +# TreeNode Properties +# ============================================================================= + + +def test_treenode_properties(simple_tree): + root = simple_tree.root + branch1, branch2 = list(root.children) + leaf1_1, leaf1_2 = list(branch1.children) + leaf2_1, leaf2_2 = list(branch2.children) + + assert root.is_root is True + assert root.is_leaf is False + assert root.is_branch is False + + assert branch1.is_root is False + assert branch1.is_leaf is False + assert branch1.is_branch is True + + assert branch2.is_root is False + assert branch2.is_leaf is False + assert branch2.is_branch is True + + assert leaf1_1.is_root is False + assert leaf1_1.is_leaf is True + assert leaf1_1.is_branch is False + + assert leaf1_2.is_root is False + assert leaf1_2.is_leaf is True + assert leaf1_2.is_branch is False + + assert leaf2_1.is_root is False + assert leaf2_1.is_leaf is True + assert leaf2_1.is_branch is False + + assert leaf2_2.is_root is False + assert leaf2_2.is_leaf is True + assert leaf2_2.is_branch is False + + +# ============================================================================= +# Tree Properties +# ============================================================================= + + +def test_tree_properties(simple_tree): + nodes = list(simple_tree.nodes) + leaves = list(simple_tree.leaves) + + assert len(nodes) == 7 + assert len(leaves) == 4 + + +# ============================================================================= +# Tree Traversal +# ============================================================================= + + +def test_tree_traversal(simple_tree): + nodes = [node.name for node in simple_tree.traverse(strategy="depthfirst", order="preorder")] + assert nodes == ["root", "branch1", "leaf1_1", "leaf1_2", "branch2", "leaf2_1", "leaf2_2"] + + nodes = [node.name for node in simple_tree.traverse(strategy="depthfirst", order="postorder")] + assert nodes == ["leaf1_1", "leaf1_2", "branch1", "leaf2_1", "leaf2_2", "branch2", "root"] + + nodes = [node.name for node in simple_tree.traverse(strategy="breadthfirst")] + assert nodes == ["root", "branch1", "branch2", "leaf1_1", "leaf1_2", "leaf2_1", "leaf2_2"] + + +# ============================================================================= +# Tree Manipulation +# ============================================================================= + + +def test_tree_add_node(simple_tree): + branch2 = simple_tree.get_node_by_name("branch2") + branch2.add(TreeNode(name="test")) + + assert len(list(branch2.children)) == 3 + assert len(list(simple_tree.nodes)) == 8 + + +def test_tree_remove_node(simple_tree): + branch2 = simple_tree.get_node_by_name("branch2") + leaf2_1 = simple_tree.get_node_by_name("leaf2_1") + branch2.remove(leaf2_1) + + assert len(list(branch2.children)) == 1 + assert len(list(simple_tree.nodes)) == 6 + + root = simple_tree.root + branch1 = simple_tree.get_node_by_name("branch1") + root.remove(branch1) + + assert len(list(root.children)) == 1 + assert len(list(simple_tree.nodes)) == 3 + + +# ============================================================================= +# Tree Serialization +# ============================================================================= + + +def test_tree_serialization(simple_tree): + serialized = json_dumps(simple_tree) + deserialized = json_loads(serialized) + assert simple_tree.data == deserialized.data + + test_tree_properties(deserialized) + test_tree_traversal(deserialized) + test_tree_add_node(deserialized) + test_tree_remove_node(json_loads(serialized)) + + if not compas.IPY: + data = json.loads(serialized)["data"] + assert Tree.validate_data(data)