diff --git a/codebasin/preprocessor.py b/codebasin/preprocessor.py index 03c189c..43234ec 100644 --- a/codebasin/preprocessor.py +++ b/codebasin/preprocessor.py @@ -11,8 +11,9 @@ import hashlib import logging import os -from collections.abc import Iterable +from collections.abc import Callable, Iterable from copy import copy +from enum import Enum from typing import Self import numpy as np @@ -539,6 +540,11 @@ class ParseError(ValueError): """ +class Visit(Enum): + NEXT = 0 + NEXT_SIBLING = 1 + + class Node: """ Base class for all other Node types. @@ -597,6 +603,22 @@ def walk(self) -> Iterable[Self]: for child in self.children: yield from child.walk() + def visit(self, visitor: Callable[[Self], Visit]): + """ + Visit all descendants of this node via a preorder traversal, using the + supplied visitor. + + Raises + ------ + TypeError + If `visitor` is not callable. + """ + if not callable(visitor): + raise TypeError("visitor is not callable.") + if visitor(self) != Visit.NEXT_SIBLING: + for child in self.children: + child.visit(visitor) + class FileNode(Node): """ @@ -2359,6 +2381,18 @@ def walk(self) -> Iterable[Node]: """ yield from self.root.walk() + def visit(self, visitor: Callable[[Node], Visit]): + """ + Visit each node in the tree via a preorder traversal, using the + supplied visitor. + + Raises + ------ + TypeError + If `visitor` is not callable. + """ + self.root.visit(visitor) + def associate_file(self, filename): self.root.filename = filename diff --git a/tests/source-tree/test_source_tree.py b/tests/source-tree/test_source_tree.py index 4d9463a..c47c0b9 100644 --- a/tests/source-tree/test_source_tree.py +++ b/tests/source-tree/test_source_tree.py @@ -7,7 +7,7 @@ import warnings from codebasin.file_parser import FileParser -from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode +from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode, Visit class TestSourceTree(unittest.TestCase): @@ -19,9 +19,6 @@ def setUp(self): logging.getLogger("codebasin").disabled = False warnings.simplefilter("ignore", ResourceWarning) - def test_walk(self): - """Check that walk() visits nodes in the expected order""" - # TODO: Revisit this when SourceTree can be built without a file. with tempfile.NamedTemporaryFile( mode="w", @@ -43,36 +40,94 @@ def test_walk(self): f.close() # TODO: Revisit this when __str__() is more reliable. - tree = FileParser(f.name).parse_file(summarize_only=False) - expected_types = [ - FileNode, - DirectiveNode, - CodeNode, - DirectiveNode, - CodeNode, - DirectiveNode, - CodeNode, - DirectiveNode, - CodeNode, - ] - expected_contents = [ - f.name, - "FOO", - "foo", - "BAR", - "bar", - "else", - "baz", - "endif", - "qux", - ] - for i, node in enumerate(tree.walk()): - self.assertTrue(isinstance(node, expected_types[i])) - if isinstance(node, CodeNode): - contents = node.spelling()[0] - else: - contents = str(node) - self.assertTrue(expected_contents[i] in contents) + self.tree = FileParser(f.name).parse_file(summarize_only=False) + self.filename = f.name + + def test_walk(self): + """Check that walk() visits nodes in the expected order""" + expected_types = [ + FileNode, + DirectiveNode, + CodeNode, + DirectiveNode, + CodeNode, + DirectiveNode, + CodeNode, + DirectiveNode, + CodeNode, + ] + expected_contents = [ + self.filename, + "FOO", + "foo", + "BAR", + "bar", + "else", + "baz", + "endif", + "qux", + ] + for i, node in enumerate(self.tree.walk()): + self.assertTrue(isinstance(node, expected_types[i])) + if isinstance(node, CodeNode): + contents = node.spelling()[0] + else: + contents = str(node) + self.assertTrue(expected_contents[i] in contents) + + def test_visit_types(self): + """Check that visit() validates inputs""" + + class valid_visitor: + def __call__(self, node): + return True + + self.tree.visit(valid_visitor()) + + def visitor_function(node): + return True + + self.tree.visit(visitor_function) + + with self.assertRaises(TypeError): + self.tree.visit(1) + + class invalid_visitor: + pass + + with self.assertRaises(TypeError): + self.tree.visit(invalid_visitor()) + + def test_visit(self): + """Check that visit() visits nodes as expected""" + + # Check that a trivial visitor visits all nodes. + class NodeCounter: + def __init__(self): + self.count = 0 + + def __call__(self, node): + self.count += 1 + + node_counter = NodeCounter() + self.tree.visit(node_counter) + self.assertEqual(node_counter.count, 9) + + # Check that returning NEXT_SIBLING prevents descent. + class TopLevelCounter: + def __init__(self): + self.count = 0 + + def __call__(self, node): + if not isinstance(node, FileNode): + self.count += 1 + if isinstance(node, DirectiveNode): + return Visit.NEXT_SIBLING + return Visit.NEXT + + top_level_counter = TopLevelCounter() + self.tree.visit(top_level_counter) + self.assertEqual(top_level_counter.count, 5) if __name__ == "__main__":