diff --git a/src/pygama/flow/__init__.py b/src/pygama/flow/__init__.py index 384d7603d..fcce73ed6 100644 --- a/src/pygama/flow/__init__.py +++ b/src/pygama/flow/__init__.py @@ -4,5 +4,6 @@ from .data_loader import DataLoader from .file_db import FileDB +from .data_group import DataGroup -__all__ = ["DataLoader", "FileDB"] +__all__ = ["DataLoader", "FileDB", "DataGroup"] diff --git a/src/pygama/flow/file_db.py b/src/pygama/flow/file_db.py index 4683737c0..2639b3122 100644 --- a/src/pygama/flow/file_db.py +++ b/src/pygama/flow/file_db.py @@ -226,7 +226,7 @@ def scan_files(self, dirs: list[str] = None) -> None: scan_dirs = [dirs] log.info(f"scanning {scan_dirs} with template {template}") - + print(f"scanning {scan_dirs} with template {template}") for scan_dir in scan_dirs: # some logic to guess where the scan directory is if not os.path.isabs(scan_dir): @@ -240,9 +240,11 @@ def scan_files(self, dirs: list[str] = None) -> None: scan_dir = os.path.join(os.getcwd(), scan_dir) log.debug(f"scanning {scan_dir}") + print(f"scanning {scan_dir}") for path, _, files in os.walk(scan_dir): log.debug(f"scanning {path}") + print(f"scanning {path}") n_files += len(files) for f in files: diff --git a/tests/flow/test_data_loader.py b/tests/flow/test_data_loader.py index 3c3b857f9..7236fa73a 100644 --- a/tests/flow/test_data_loader.py +++ b/tests/flow/test_data_loader.py @@ -84,22 +84,34 @@ def test_no_merge(test_dl): def test_outputs(test_dl): test_dl.set_files("type == 'phy'") - test_dl.set_datastreams([1057600, 1059201], "ch") + test_dl.set_datastreams([1057600, 1084803], "ch") test_dl.set_output( - fmt="pd.DataFrame", columns=["timestamp", "channel", "energies", "energy_in_pe"] + fmt="pd.DataFrame", columns=["energies", "trapEmax", "is_valid_0vbb", "is_valid_hit"] ) data = test_dl.load() assert isinstance(data, pd.DataFrame) - assert list(data.keys()) == [ + assert set(data.keys()) == set([ "hit_table", "hit_idx", "file", - "timestamp", - "channel", + # "timestamp", + # "channel", "energies", - "energy_in_pe", - ] + # "energy_in_pe", + "trapEmax", + "is_valid_0vbb", + "is_valid_hit" + ]) + print(data) + for id, row in data.iterrows(): + print() + if row.hit_table == 1057600: + print( row.is_valid_0vbb == False ) + print( row.trapEmax < 1e-20 ) + else: + print( (np.array(row.energies) < 1e-20).all() ) + print( True not in row.is_valid_hit ) def test_any_mode(test_dl):