diff --git a/fast_carpenter/expressions.py b/fast_carpenter/expressions.py index 495feec..bce6d27 100644 --- a/fast_carpenter/expressions.py +++ b/fast_carpenter/expressions.py @@ -1,3 +1,4 @@ +import re import numexpr import tokenize import awkward @@ -34,13 +35,15 @@ class TreeToDictAdaptor(): """ Make an uproot tree look like a dict for numexpr """ - def __init__(self, tree): + def __init__(self, tree, alias_dict): self.tree = tree self.starts = None self.stops = None + self.aliases = alias_dict def __getitem__(self, item): - array = self.tree.array(item) + full_item = self.aliases.get(item, item) + array = self.tree.array(full_item) starts = getattr(array, "starts", None) if starts is not None: self.set_starts_stop(starts, array.stops) @@ -48,7 +51,7 @@ def __getitem__(self, item): return array def __contains__(self, item): - return item in self.tree + return item in self.tree or item in self.aliases def __iter__(self): for i in self.tree: @@ -63,9 +66,25 @@ def set_starts_stop(self, starts, stops): self.stops = stops +attribute_re = re.compile(r"([a-zA-Z]\w*)\s*\.\s*(\w+)") + + +def preprocess_expression(expression): + alias_dict = {} + replace_dict = {} + for match in attribute_re.finditer(expression): + original = match.group(0) + alias = match.expand(r"\1__DOT__\2") + alias_dict[alias] = original + replace_dict[original] = alias + clean_expr = attribute_re.sub(lambda x: replace_dict[x.group(0)], expression) + return clean_expr, alias_dict + + def evaluate(tree, expression): - adaptor = TreeToDictAdaptor(tree) - result = numexpr.evaluate(expression, local_dict=adaptor) + cleaned_expression, alias_dict = preprocess_expression(expression) + adaptor = TreeToDictAdaptor(tree, alias_dict) + result = numexpr.evaluate(cleaned_expression, local_dict=adaptor) if adaptor.starts is not None: result = awkward.JaggedArray(adaptor.starts, adaptor.stops, result) return result diff --git a/fast_carpenter/version.py b/fast_carpenter/version.py index 9b19412..2ceaa0e 100644 --- a/fast_carpenter/version.py +++ b/fast_carpenter/version.py @@ -12,5 +12,5 @@ def split_version(version): return tuple(result) -__version__ = '0.11.0' +__version__ = '0.11.1' version_info = split_version(__version__) # noqa diff --git a/setup.cfg b/setup.cfg index 1ce8092..7e7a349 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.11.0-rc5 +current_version = 0.11.1 commit = True tag = False diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 8d39037..89a83ea 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -23,3 +23,9 @@ def test_evaluate(wrapped_tree): def test_evaluate_bool(wrapped_tree): all_true = expressions.evaluate(wrapped_tree, "Muon_Px == Muon_Px") assert all(all_true.all()) + + +def test_evaluate_dot(wrapped_tree): + wrapped_tree.new_variable("Muon.Px", wrapped_tree.array("Muon_Px")) + all_true = expressions.evaluate(wrapped_tree, "Muon.Px == Muon_Px") + assert all(all_true.all())