Skip to content

Commit

Permalink
Merge pull request #66 from FAST-HEP/BK_fix_overwrite_treebranches
Browse files Browse the repository at this point in the history
Fix itervalues for overwritten branches
  • Loading branch information
benkrikler authored Aug 19, 2019
2 parents da6b5b6 + 4c03474 commit 2ca4ad7
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

## [0.13.2] - 2019-08-19
### Changed
- Protect against overwriting branches and add tests, pull request #66 [@benkrikler](https://github.com/benkrikler)

## [0.13.1] - 2019-08-05
### Added
- Adds support for masking variables in their definition, issue #59 [@benkrikler](https://github.com/benkrikler)
Expand Down
3 changes: 3 additions & 0 deletions fast_carpenter/tree_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def __len__(self):
return len(self._values)

def new_variable(self, name, value):
if name in self:
msg = "Trying to overwrite existing variable: '%s'"
raise ValueError(msg % name)
if len(value) != len(self):
msg = "New array %s does not have the right length: %d not %d"
raise ValueError(msg % (name, len(value), len(self)))
Expand Down
2 changes: 1 addition & 1 deletion fast_carpenter/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ def split_version(version):
return tuple(result)


__version__ = '0.13.1'
__version__ = '0.13.2'
version_info = split_version(__version__) # noqa
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.13.1
current_version = 0.13.2
commit = True
tag = False

Expand Down
9 changes: 9 additions & 0 deletions tests/test_tree_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import numpy as np


Expand All @@ -10,3 +11,11 @@ def test_add_retrieve(wrapped_tree):
wrapped_tree.new_variable("Muon_momentum", muon_momentum)
retrieve_momentum = wrapped_tree.array("Muon_momentum")
assert (retrieve_momentum == muon_momentum).flatten().all()


def test_overwrite(wrapped_tree):
muon_px = wrapped_tree.array("Muon_Px")
with pytest.raises(ValueError) as err:
wrapped_tree.new_variable("Muon_Px", muon_px / muon_px)
assert "Muon_Px" in str(err)
assert len(wrapped_tree.keys(filtername=lambda x: x.decode() == "Muon_Px")) == 1

0 comments on commit 2ca4ad7

Please sign in to comment.