diff --git a/setup.py b/setup.py index d854d41..cc9670c 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ 'numpy', 'qastle>=0.16.0', 'uproot>=5', - 'vector', + 'vector>=1.1.0', ], extras_require={'test': ['flake8', 'pytest', 'pytest-cov']}, author='Mason Proffitt', diff --git a/tests/test_executor.py b/tests/test_executor.py index 71b3df1..bde9624 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -539,6 +539,21 @@ def test_ast_executor_tofourmomentum_mass(): assert np.allclose(result[2].tolist(), [8.6420747579]) +def test_ast_executor_tofourmomentum_sum(): + python_source = ( + "Select(EventDataset('tests/four-vector_tree_file.root', 'tree')," + + "lambda row: Zip({'pt': row.pt_vector_branch, 'eta': row.eta_vector_branch," + + "'phi': row.phi_vector_branch, 'e': row.e_vector_branch})" + + ".Select(lambda four_vector: four_vector.ToFourMomentum()).Sum())" + ) + python_ast = ast.parse(python_source) + result = ast_executor(python_ast) + assert np.allclose(result.pt, [0, 4.223308959107644, 11.11]) + assert np.allclose(result.eta, [0.0, -5.135728460688669, 0.1212]) + assert np.allclose(result.phi, [0.0, 0.37848189207358623, 3.13]) + assert np.allclose(result.e, [0.0, 1107.9, 14.14]) + + def test_ast_executor_tofourmomenta(): python_source = ( "Select(EventDataset('tests/four-vector_tree_file.root', 'tree')," @@ -579,6 +594,20 @@ def test_ast_executor_tofourmomenta_mass(): assert np.allclose(result[2].tolist(), [8.6420747579]) +def test_ast_executor_tofourmomenta_sum(): + python_source = ( + "Select(EventDataset('tests/four-vector_tree_file.root', 'tree')," + + "lambda row: Zip({'pt': row.pt_vector_branch, 'eta': row.eta_vector_branch," + + "'phi': row.phi_vector_branch, 'e': row.e_vector_branch}).ToFourMomenta().Sum())" + ) + python_ast = ast.parse(python_source) + result = ast_executor(python_ast) + assert np.allclose(result.pt, [0, 4.223308959107644, 11.11]) + assert np.allclose(result.eta, [0.0, -5.135728460688669, 0.1212]) + assert np.allclose(result.phi, [0.0, 0.37848189207358623, 3.13]) + assert np.allclose(result.e, [0.0, 1107.9, 14.14]) + + # def test_ast_executor_orderby_same_scalar_branch(): # python_source = ( # "OrderBy(EventDataset('tests/scalars_tree_file.root', 'tree'),"