diff --git a/src/pygama/evt/aggregators.py b/src/pygama/evt/aggregators.py
index c9adee29b..9a82e9ce2 100644
--- a/src/pygama/evt/aggregators.py
+++ b/src/pygama/evt/aggregators.py
@@ -238,6 +238,7 @@ def evaluate_to_scalar(
 def evaluate_at_channel(
     datainfo,
     tcm,
+    channels,
     channels_skip,
     expr,
     field_list,
@@ -253,6 +254,8 @@ def evaluate_at_channel(
         input and output LH5 datainfo with HDF5 groups where tables are found.
     tcm
         TCM data arrays in an object that can be accessed by attribute.
+    channels
+        list of channels to be included for evaluation.
     channels_skip
         list of channels to be skipped from evaluation and set to default value.
     expr
@@ -281,7 +284,7 @@ def evaluate_at_channel(
         evt_ids_ch = np.searchsorted(
             tcm.cumulative_length, np.where(tcm.id == ch)[0], "right"
         )
-        if table_name not in channels_skip:
+        if (table_name in channels) and (table_name not in channels_skip):
             res = utils.get_data_at_channel(
                 datainfo=datainfo,
                 ch=table_name,
@@ -307,6 +310,7 @@ def evaluate_at_channel_vov(
     expr,
     field_list,
     ch_comp,
+    channels,
     channels_skip,
     pars_dict=None,
     default_value=np.nan,
@@ -326,6 +330,8 @@ def evaluate_at_channel_vov(
         list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``.
     ch_comp
         array of "rawid"s at which the expression is evaluated.
+    channels
+        list of channels to be included for evaluation.
     channels_skip
         list of channels to be skipped from evaluation and set to default value.
     pars_dict
@@ -335,20 +341,19 @@ def evaluate_at_channel_vov(
     """
     f = utils.make_files_config(datainfo)
 
-    # blow up vov to aoesa
-    out = ak.Array([[] for _ in range(len(ch_comp))])
+    ch_comp_channels = np.unique(ch_comp.flattened_data.nda).astype(int)
 
-    channels = np.unique(ch_comp.flattened_data.nda).astype(int)
-    ch_comp = ch_comp.view_as("ak")
+    out = np.full(
+        len(ch_comp.flattened_data.nda), default_value, dtype=type(default_value)
+    )
 
     type_name = None
-    for ch in channels:
+    for ch in ch_comp_channels:
         table_name = utils.get_table_name_by_pattern(f.hit.table_fmt, ch)
-
         evt_ids_ch = np.searchsorted(
             tcm.cumulative_length, np.where(tcm.id == ch)[0], "right"
         )
-        if table_name not in channels_skip:
+        if (table_name in channels) and (table_name not in channels_skip):
             res = utils.get_data_at_channel(
                 datainfo=datainfo,
                 ch=table_name,
@@ -357,23 +362,27 @@ def evaluate_at_channel_vov(
                 field_list=field_list,
                 pars_dict=pars_dict,
             )
-        else:
-            idx_ch = tcm.idx[tcm.id == ch]
-            res = np.full(len(idx_ch), default_value)
-
-        # see in which events the current channel is present
-        mask = ak.to_numpy(ak.any(ch_comp == ch, axis=-1), allow_missing=False)
-        cv = np.full(len(ch_comp), np.nan)
-        cv[evt_ids_ch] = res
-        cv[~mask] = np.nan
-        cv = ak.drop_none(ak.nan_to_none(ak.Array(cv)[:, None]))
+            new_evt_ids_ch = np.searchsorted(
+                ch_comp.cumulative_length,
+                np.where(ch_comp.flattened_data.nda == ch)[0],
+                "right",
+            )
+            matches = np.isin(evt_ids_ch, new_evt_ids_ch)
+            out[ch_comp.flattened_data.nda == ch] = res[matches]
 
-        out = ak.concatenate((out, cv), axis=-1)
+        else:
+            length = len(np.where(ch_comp.flattened_data.nda == ch)[0])
+            res = np.full(length, np.nan)
+            out[ch_comp.flattened_data.nda == ch] = res
 
-        if ch == channels[0]:
+        if ch == ch_comp_channels[0]:
+            out = out.astype(res.dtype)
             type_name = res.dtype
 
-    return types.VectorOfVectors(ak.values_astype(out, type_name))
+    return types.VectorOfVectors(
+        flattened_data=types.Array(out, dtype=type_name),
+        cumulative_length=ch_comp.cumulative_length,
+    )
 
 
 def evaluate_to_aoesa(
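For readers unfamiliar with the indexing idiom that replaces the old `tcm.idx` lookups above: `np.searchsorted(tcm.cumulative_length, np.where(tcm.id == ch)[0], "right")` maps each hit of a channel (its position in the flattened TCM) to the event it belongs to, because `cumulative_length[i]` is the exclusive end of event `i`. Below is a minimal, self-contained numpy sketch of that mapping; the rawids and TCM arrays are toy values invented for illustration, not the pygama API.

import numpy as np

# toy stand-ins for tcm.id (one table id per hit, flattened across events)
# and tcm.cumulative_length (exclusive end of each event in that flat list)
tcm_id = np.array([1084803, 1084804, 1084803, 1121600, 1084804])
cumulative_length = np.array([2, 4, 5])  # event 0 -> hits 0:2, event 1 -> 2:4, event 2 -> 4:5

ch = 1084804  # made-up rawid of the channel being evaluated

# positions of this channel's hits in the flattened TCM ...
hit_positions = np.where(tcm_id == ch)[0]  # -> [1 4]

# ... and the event index each hit belongs to: the first event whose
# cumulative_length is strictly greater than the hit position
evt_ids_ch = np.searchsorted(cumulative_length, hit_positions, "right")  # -> [0 2]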
diff --git a/src/pygama/evt/build_evt.py b/src/pygama/evt/build_evt.py
index 3620dd373..66be91bf6 100644
--- a/src/pygama/evt/build_evt.py
+++ b/src/pygama/evt/build_evt.py
@@ -497,6 +497,7 @@ def evaluate_expression(
         return aggregators.evaluate_at_channel(
             datainfo=datainfo,
             tcm=tcm,
+            channels=channels,
             channels_skip=channels_skip,
             expr=expr,
             field_list=field_list,
@@ -512,6 +513,7 @@
             expr=expr,
             field_list=field_list,
             ch_comp=ch_comp,
+            channels=channels,
             channels_skip=channels_skip,
             pars_dict=pars_dict,
             default_value=default_value,
diff --git a/src/pygama/evt/modules/geds.py b/src/pygama/evt/modules/geds.py
index e806bd150..ecd926b9b 100644
--- a/src/pygama/evt/modules/geds.py
+++ b/src/pygama/evt/modules/geds.py
@@ -132,6 +132,7 @@ def apply_xtalk_correction_and_calibrate(
     xtalk_rawid_obj: str = "xtc/rawid_index",
     xtalk_matrix_obj: str = "xtc/xtalk_matrix_negative",
     positive_xtalk_matrix_obj: str = "xtc/xtalk_matrix_positive",
+    uncal_var: str = "dsp.cuspEmax",
     recal_var: str = "hit.cuspEmax_ctc_cal",
 ) -> types.VectorOfVectors:
     """Applies the cross-talk correction to the energy observable.
@@ -191,7 +192,7 @@ def apply_xtalk_correction_and_calibrate(
         energy_corr,
         xtalk_matrix_rawids,
         cal_par_files,
-        uncal_energy_expr,
+        uncal_var,
         recal_var,
     )
 
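The new `uncal_var` default, like `recal_var`, is a tier-qualified column name ("dsp.cuspEmax" means column `cuspEmax` of the `dsp` tier); `calibrate_energy` in the next file keeps only the part after the dot when naming table columns. A small illustrative snippet of that convention follows; the helper function is hypothetical and not part of the patch.

# hypothetical helper showing the "tier.column" naming convention assumed
# by uncal_var ("dsp.cuspEmax") and recal_var ("hit.cuspEmax_ctc_cal")
def split_tier_var(var: str) -> tuple[str, str]:
    tier, column = var.split(".", 1)
    return tier, column

assert split_tier_var("dsp.cuspEmax") == ("dsp", "cuspEmax")
assert split_tier_var("hit.cuspEmax_ctc_cal") == ("hit", "cuspEmax_ctc_cal")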
diff --git a/src/pygama/evt/modules/xtalk.py b/src/pygama/evt/modules/xtalk.py
index 6f10677bf..e4cc05a0f 100644
--- a/src/pygama/evt/modules/xtalk.py
+++ b/src/pygama/evt/modules/xtalk.py
@@ -27,7 +27,7 @@ def build_tcm_index_array(
     """
 
     # initialise the output object
-    tcm_indexs_out = np.full((len(rawids), len(tcm.cumulative_length)), np.nan)
+    tcm_indexs_out = np.full((len(tcm.cumulative_length), len(rawids)), np.nan)
 
     # parse observables string. default to hit tier
     for idx_chan, channel in enumerate(rawids):
@@ -37,11 +37,15 @@
             datainfo._asdict()["dsp"].table_fmt, f"ch{channel}"
         )
         tcm_indexs = np.where(tcm.id == table_id)[0]
-        idx_events = ak.to_numpy(tcm.idx[tcm.id == table_id])
-        tcm_indexs_out[idx_chan][idx_events] = tcm_indexs
+        evt_ids_ch = np.searchsorted(
+            tcm.cumulative_length,
+            np.where(tcm.id == channel)[0],
+            "right",
+        )
+        tcm_indexs_out[evt_ids_ch, idx_chan] = tcm_indexs
 
     # transpose to return object where row is events and column rawid idx
-    return tcm_indexs_out.T
+    return tcm_indexs_out
 
 
 def gather_energy(
@@ -84,18 +88,23 @@ def gather_energy(
     for idx_chan, channel in enumerate(rawids):
         tbl = types.Table()
-        idx_events = ak.to_numpy(tcm.idx[tcm.id == channel])
+        hit_idx = ak.to_numpy(tcm.idx[tcm.id == channel])
+        evt_ids_ch = np.searchsorted(
+            tcm.cumulative_length,
+            np.where(tcm.id == channel)[0],
+            "right",
+        )
 
         for name, file, group, column in tier_params:
             try:
                 # read the energy data
-                data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=idx_events)
+                data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=hit_idx)
                 tbl.add_column(name, data)
             except (lh5.exceptions.LH5DecodeError, KeyError):
-                tbl.add_column(name, types.Array(np.full_like(idx_events, np.nan)))
+                tbl.add_column(name, types.Array(np.full_like(evt_ids_ch, np.nan)))
 
         res = tbl.eval(observable)
-        energy_out[idx_events, idx_chan] = res.nda
+        energy_out[evt_ids_ch, idx_chan] = res.nda
 
     return energy_out
 
@@ -146,23 +155,29 @@ def filter_hits(
     for idx_chan, channel in enumerate(rawids):
         tbl = types.Table()
-        idx_events = ak.to_numpy(tcm.idx[tcm.id == channel])
+
+        hit_idx = ak.to_numpy(tcm.idx[tcm.id == channel])
+        evt_ids_ch = np.searchsorted(
+            tcm.cumulative_length,
+            np.where(tcm.id == channel)[0],
+            "right",
+        )
 
         for name, file, group, column in tier_params:
             try:
                 # read the energy data
-                data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=idx_events)
+                data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=hit_idx)
                 tbl.add_column(name, data)
             except (lh5.exceptions.LH5DecodeError, KeyError):
-                tbl.add_column(name, types.Array(np.full_like(idx_events, np.nan)))
+                tbl.add_column(name, types.Array(np.full_like(evt_ids_ch, np.nan)))
 
         # add the corrected energy to the table
         tbl.add_column(
-            "xtalk_corr_energy", types.Array(xtalk_corr_energy[idx_events, idx_chan])
+            "xtalk_corr_energy", types.Array(xtalk_corr_energy[evt_ids_ch, idx_chan])
         )
 
         res = tbl.eval(filter_expr)
-        mask[idx_events, idx_chan] = res.nda
+        mask[evt_ids_ch, idx_chan] = res.nda
 
     return mask
 
@@ -269,7 +284,7 @@ def calibrate_energy(
     Parameters
     ---------
    datainfo
-        utils.DataInfo object containg the paths etc to the data
+        utils.DataInfo object containing the paths etc to the data
     tcm
         utils.TCMData object
     energy_corr
@@ -305,17 +320,23 @@ def calibrate_energy(
             # get the event indices
             table_id = utils.get_tcm_id_by_pattern(table_fmt, f"ch{chan}")
-            idx_events = ak.to_numpy(tcm.idx[tcm.id == table_id])
+
+            hit_idx = ak.to_numpy(tcm.idx[tcm.id == table_id])
+            evt_ids_ch = np.searchsorted(
+                tcm.cumulative_length,
+                np.where(tcm.id == table_id)[0],
+                "right",
+            )
 
             # read the dsp data
             outtbl_obj = lh5.read(
-                f"ch{chan}/dsp/", file, idx=idx_events, field_mask=chan_inputs
+                f"ch{chan}/dsp/", file, idx=hit_idx, field_mask=chan_inputs
             )
 
             # add the uncalibrated energy to the table
             outtbl_obj.add_column(
                 uncal_energy_var.split(".")[-1],
-                types.Array(energy_corr[idx_events, i]),
+                types.Array(energy_corr[evt_ids_ch, i]),
             )
 
             for outname, info in cfg.items():
@@ -323,7 +344,7 @@
                     info["expression"], info.get("parameters", None)
                 )
                 outtbl_obj.add_column(outname, outcol)
-                out_arr[idx_events, i] = outtbl_obj[recal_energy_var.split(".")[-1]].nda
+                out_arr[evt_ids_ch, i] = outtbl_obj[recal_energy_var.split(".")[-1]].nda
 
         except KeyError:
             out_arr[:, i] = np.nan
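The same event-index mapping drives every (n_events, n_rawids) scatter in this file (energy_out, mask, out_arr, tcm_indexs_out): one row per event, one column per rawid, NaN where a channel has no hit. A rough standalone illustration with toy data only; the rawids are made up and the per-hit values stand in for what the code reads with lh5.read(..., idx=hit_idx).

import numpy as np

# toy TCM: one table id per hit plus exclusive event boundaries
tcm_id = np.array([1084803, 1084804, 1084803, 1121600, 1084804])
cumulative_length = np.array([2, 4, 5])
rawids = [1084803, 1084804, 1121600]  # made-up channel list

energy_out = np.full((len(cumulative_length), len(rawids)), np.nan)

for idx_chan, channel in enumerate(rawids):
    hit_positions = np.where(tcm_id == channel)[0]
    evt_ids_ch = np.searchsorted(cumulative_length, hit_positions, "right")
    # stand-in for the per-hit values read from the dsp/hit tier
    values = 100.0 * (idx_chan + 1) + np.arange(len(hit_positions))
    # scatter into the events-by-rawids matrix; events without a hit stay NaN
    energy_out[evt_ids_ch, idx_chan] = values

print(energy_out)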