Skip to content

Commit

Permalink
Merge pull request #121 from Pennycook/tree-visit
Browse files Browse the repository at this point in the history
Add visit() function to source tree
  • Loading branch information
Pennycook authored Nov 5, 2024
2 parents 9231ff8 + 72b3d1b commit 29e7c77
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 35 deletions.
36 changes: 35 additions & 1 deletion codebasin/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import hashlib
import logging
import os
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from copy import copy
from enum import Enum
from typing import Self

import numpy as np
Expand Down Expand Up @@ -539,6 +540,11 @@ class ParseError(ValueError):
"""


class Visit(Enum):
NEXT = 0
NEXT_SIBLING = 1


class Node:
"""
Base class for all other Node types.
Expand Down Expand Up @@ -597,6 +603,22 @@ def walk(self) -> Iterable[Self]:
for child in self.children:
yield from child.walk()

def visit(self, visitor: Callable[[Self], Visit]):
"""
Visit all descendants of this node via a preorder traversal, using the
supplied visitor.
Raises
------
TypeError
If `visitor` is not callable.
"""
if not callable(visitor):
raise TypeError("visitor is not callable.")
if visitor(self) != Visit.NEXT_SIBLING:
for child in self.children:
child.visit(visitor)


class FileNode(Node):
"""
Expand Down Expand Up @@ -2359,6 +2381,18 @@ def walk(self) -> Iterable[Node]:
"""
yield from self.root.walk()

def visit(self, visitor: Callable[[Node], Visit]):
"""
Visit each node in the tree via a preorder traversal, using the
supplied visitor.
Raises
------
TypeError
If `visitor` is not callable.
"""
self.root.visit(visitor)

def associate_file(self, filename):
self.root.filename = filename

Expand Down
123 changes: 89 additions & 34 deletions tests/source-tree/test_source_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings

from codebasin.file_parser import FileParser
from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode
from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode, Visit


class TestSourceTree(unittest.TestCase):
Expand All @@ -19,9 +19,6 @@ def setUp(self):
logging.getLogger("codebasin").disabled = False
warnings.simplefilter("ignore", ResourceWarning)

def test_walk(self):
"""Check that walk() visits nodes in the expected order"""

# TODO: Revisit this when SourceTree can be built without a file.
with tempfile.NamedTemporaryFile(
mode="w",
Expand All @@ -43,36 +40,94 @@ def test_walk(self):
f.close()

# TODO: Revisit this when __str__() is more reliable.
tree = FileParser(f.name).parse_file(summarize_only=False)
expected_types = [
FileNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
]
expected_contents = [
f.name,
"FOO",
"foo",
"BAR",
"bar",
"else",
"baz",
"endif",
"qux",
]
for i, node in enumerate(tree.walk()):
self.assertTrue(isinstance(node, expected_types[i]))
if isinstance(node, CodeNode):
contents = node.spelling()[0]
else:
contents = str(node)
self.assertTrue(expected_contents[i] in contents)
self.tree = FileParser(f.name).parse_file(summarize_only=False)
self.filename = f.name

def test_walk(self):
"""Check that walk() visits nodes in the expected order"""
expected_types = [
FileNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
DirectiveNode,
CodeNode,
]
expected_contents = [
self.filename,
"FOO",
"foo",
"BAR",
"bar",
"else",
"baz",
"endif",
"qux",
]
for i, node in enumerate(self.tree.walk()):
self.assertTrue(isinstance(node, expected_types[i]))
if isinstance(node, CodeNode):
contents = node.spelling()[0]
else:
contents = str(node)
self.assertTrue(expected_contents[i] in contents)

def test_visit_types(self):
"""Check that visit() validates inputs"""

class valid_visitor:
def __call__(self, node):
return True

self.tree.visit(valid_visitor())

def visitor_function(node):
return True

self.tree.visit(visitor_function)

with self.assertRaises(TypeError):
self.tree.visit(1)

class invalid_visitor:
pass

with self.assertRaises(TypeError):
self.tree.visit(invalid_visitor())

def test_visit(self):
"""Check that visit() visits nodes as expected"""

# Check that a trivial visitor visits all nodes.
class NodeCounter:
def __init__(self):
self.count = 0

def __call__(self, node):
self.count += 1

node_counter = NodeCounter()
self.tree.visit(node_counter)
self.assertEqual(node_counter.count, 9)

# Check that returning NEXT_SIBLING prevents descent.
class TopLevelCounter:
def __init__(self):
self.count = 0

def __call__(self, node):
if not isinstance(node, FileNode):
self.count += 1
if isinstance(node, DirectiveNode):
return Visit.NEXT_SIBLING
return Visit.NEXT

top_level_counter = TopLevelCounter()
self.tree.visit(top_level_counter)
self.assertEqual(top_level_counter.count, 5)


if __name__ == "__main__":
Expand Down

0 comments on commit 29e7c77

Please sign in to comment.