Skip to content

Commit

Permalink
Merge pull request #110 from iris-hep/fix_98_four-vector-sum
Browse files Browse the repository at this point in the history
Test `Sum()` of four-vectors
  • Loading branch information
masonproffitt authored Aug 15, 2023
2 parents 74857e0 + 524d8dc commit b47629c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
'numpy',
'qastle>=0.16.0',
'uproot>=5',
'vector',
'vector>=1.1.0',
],
extras_require={'test': ['flake8', 'pytest', 'pytest-cov']},
author='Mason Proffitt',
Expand Down
29 changes: 29 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,21 @@ def test_ast_executor_tofourmomentum_mass():
assert np.allclose(result[2].tolist(), [8.6420747579])


def test_ast_executor_tofourmomentum_sum():
python_source = (
"Select(EventDataset('tests/four-vector_tree_file.root', 'tree'),"
+ "lambda row: Zip({'pt': row.pt_vector_branch, 'eta': row.eta_vector_branch,"
+ "'phi': row.phi_vector_branch, 'e': row.e_vector_branch})"
+ ".Select(lambda four_vector: four_vector.ToFourMomentum()).Sum())"
)
python_ast = ast.parse(python_source)
result = ast_executor(python_ast)
assert np.allclose(result.pt, [0, 4.223308959107644, 11.11])
assert np.allclose(result.eta, [0.0, -5.135728460688669, 0.1212])
assert np.allclose(result.phi, [0.0, 0.37848189207358623, 3.13])
assert np.allclose(result.e, [0.0, 1107.9, 14.14])


def test_ast_executor_tofourmomenta():
python_source = (
"Select(EventDataset('tests/four-vector_tree_file.root', 'tree'),"
Expand Down Expand Up @@ -579,6 +594,20 @@ def test_ast_executor_tofourmomenta_mass():
assert np.allclose(result[2].tolist(), [8.6420747579])


def test_ast_executor_tofourmomenta_sum():
python_source = (
"Select(EventDataset('tests/four-vector_tree_file.root', 'tree'),"
+ "lambda row: Zip({'pt': row.pt_vector_branch, 'eta': row.eta_vector_branch,"
+ "'phi': row.phi_vector_branch, 'e': row.e_vector_branch}).ToFourMomenta().Sum())"
)
python_ast = ast.parse(python_source)
result = ast_executor(python_ast)
assert np.allclose(result.pt, [0, 4.223308959107644, 11.11])
assert np.allclose(result.eta, [0.0, -5.135728460688669, 0.1212])
assert np.allclose(result.phi, [0.0, 0.37848189207358623, 3.13])
assert np.allclose(result.e, [0.0, 1107.9, 14.14])


# def test_ast_executor_orderby_same_scalar_branch():
# python_source = (
# "OrderBy(EventDataset('tests/scalars_tree_file.root', 'tree'),"
Expand Down

0 comments on commit b47629c

Please sign in to comment.