Skip to content

Commit

Permalink
Merge pull request #7 from ggmarshall/xtalk
Browse files Browse the repository at this point in the history
evt fixes
  • Loading branch information
tdixon97 authored May 7, 2024
2 parents 75451ce + b3e9973 commit ff93678
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 40 deletions.
51 changes: 30 additions & 21 deletions src/pygama/evt/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def evaluate_to_scalar(
def evaluate_at_channel(
datainfo,
tcm,
channels,
channels_skip,
expr,
field_list,
Expand All @@ -253,6 +254,8 @@ def evaluate_at_channel(
input and output LH5 datainfo with HDF5 groups where tables are found.
tcm
TCM data arrays in an object that can be accessed by attribute.
channels
list of channels to be included for evaluation.
channels_skip
list of channels to be skipped from evaluation and set to default value.
expr
Expand Down Expand Up @@ -281,7 +284,7 @@ def evaluate_at_channel(
evt_ids_ch = np.searchsorted(
tcm.cumulative_length, np.where(tcm.id == ch)[0], "right"
)
if table_name not in channels_skip:
if (table_name in channels) and (table_name not in channels_skip):
res = utils.get_data_at_channel(
datainfo=datainfo,
ch=table_name,
Expand All @@ -307,6 +310,7 @@ def evaluate_at_channel_vov(
expr,
field_list,
ch_comp,
channels,
channels_skip,
pars_dict=None,
default_value=np.nan,
Expand All @@ -326,6 +330,8 @@ def evaluate_at_channel_vov(
list of `dsp/hit/evt` parameter tuples in expression ``(tier, field)``.
ch_comp
array of "rawid"s at which the expression is evaluated.
channels
list of channels to be included for evaluation.
channels_skip
list of channels to be skipped from evaluation and set to default value.
pars_dict
Expand All @@ -335,20 +341,19 @@ def evaluate_at_channel_vov(
"""
f = utils.make_files_config(datainfo)

# blow up vov to aoesa
out = ak.Array([[] for _ in range(len(ch_comp))])
ch_comp_channels = np.unique(ch_comp.flattened_data.nda).astype(int)

channels = np.unique(ch_comp.flattened_data.nda).astype(int)
ch_comp = ch_comp.view_as("ak")
out = np.full(
len(ch_comp.flattened_data.nda), default_value, dtype=type(default_value)
)

type_name = None
for ch in channels:
for ch in ch_comp_channels:
table_name = utils.get_table_name_by_pattern(f.hit.table_fmt, ch)

evt_ids_ch = np.searchsorted(
tcm.cumulative_length, np.where(tcm.id == ch)[0], "right"
)
if table_name not in channels_skip:
if (table_name in channels) and (table_name not in channels_skip):
res = utils.get_data_at_channel(
datainfo=datainfo,
ch=table_name,
Expand All @@ -357,23 +362,27 @@ def evaluate_at_channel_vov(
field_list=field_list,
pars_dict=pars_dict,
)
else:
idx_ch = tcm.idx[tcm.id == ch]
res = np.full(len(idx_ch), default_value)

# see in which events the current channel is present
mask = ak.to_numpy(ak.any(ch_comp == ch, axis=-1), allow_missing=False)
cv = np.full(len(ch_comp), np.nan)
cv[evt_ids_ch] = res
cv[~mask] = np.nan
cv = ak.drop_none(ak.nan_to_none(ak.Array(cv)[:, None]))
new_evt_ids_ch = np.searchsorted(
ch_comp.cumulative_length,
np.where(ch_comp.flattened_data.nda == ch)[0],
"right",
)
matches = np.isin(evt_ids_ch, new_evt_ids_ch)
out[ch_comp.flattened_data.nda == ch] = res[matches]

out = ak.concatenate((out, cv), axis=-1)
else:
length = len(np.where(ch_comp.flattened_data.nda == ch)[0])
res = np.full(length, np.nan)
out[ch_comp.flattened_data.nda == ch] = res

if ch == channels[0]:
if ch == ch_comp_channels[0]:
out = out.astype(res.dtype)
type_name = res.dtype

return types.VectorOfVectors(ak.values_astype(out, type_name))
return types.VectorOfVectors(
flattened_data=types.Array(out, dtype=type_name),
cumulative_length=ch_comp.cumulative_length,
)


def evaluate_to_aoesa(
Expand Down
2 changes: 2 additions & 0 deletions src/pygama/evt/build_evt.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def evaluate_expression(
return aggregators.evaluate_at_channel(
datainfo=datainfo,
tcm=tcm,
channels=channels,
channels_skip=channels_skip,
expr=expr,
field_list=field_list,
Expand All @@ -512,6 +513,7 @@ def evaluate_expression(
expr=expr,
field_list=field_list,
ch_comp=ch_comp,
channels=channels,
channels_skip=channels_skip,
pars_dict=pars_dict,
default_value=default_value,
Expand Down
3 changes: 2 additions & 1 deletion src/pygama/evt/modules/geds.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def apply_xtalk_correction_and_calibrate(
xtalk_rawid_obj: str = "xtc/rawid_index",
xtalk_matrix_obj: str = "xtc/xtalk_matrix_negative",
positive_xtalk_matrix_obj: str = "xtc/xtalk_matrix_positive",
uncal_var: str = "dsp.cuspEmax",
recal_var: str = "hit.cuspEmax_ctc_cal",
) -> types.VectorOfVectors:
"""Applies the cross-talk correction to the energy observable.
Expand Down Expand Up @@ -191,7 +192,7 @@ def apply_xtalk_correction_and_calibrate(
energy_corr,
xtalk_matrix_rawids,
cal_par_files,
uncal_energy_expr,
uncal_var,
recal_var,
)

Expand Down
57 changes: 39 additions & 18 deletions src/pygama/evt/modules/xtalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def build_tcm_index_array(
"""

# initialise the output object
tcm_indexs_out = np.full((len(rawids), len(tcm.cumulative_length)), np.nan)
tcm_indexs_out = np.full((len(tcm.cumulative_length), len(rawids)), np.nan)

# parse observables string. default to hit tier
for idx_chan, channel in enumerate(rawids):
Expand All @@ -37,11 +37,15 @@ def build_tcm_index_array(
datainfo._asdict()["dsp"].table_fmt, f"ch{channel}"
)
tcm_indexs = np.where(tcm.id == table_id)[0]
idx_events = ak.to_numpy(tcm.idx[tcm.id == table_id])
tcm_indexs_out[idx_chan][idx_events] = tcm_indexs
evt_ids_ch = np.searchsorted(
tcm.cumulative_length,
np.where(tcm.id == channel)[0],
"right",
)
tcm_indexs_out[evt_ids_ch, idx_chan] = tcm_indexs

# return object where each row is an event and each column is a rawid index
return tcm_indexs_out.T
return tcm_indexs_out


def gather_energy(
Expand Down Expand Up @@ -84,18 +88,23 @@ def gather_energy(

for idx_chan, channel in enumerate(rawids):
tbl = types.Table()
idx_events = ak.to_numpy(tcm.idx[tcm.id == channel])
hit_idx = ak.to_numpy(tcm.idx[tcm.id == channel])
evt_ids_ch = np.searchsorted(
tcm.cumulative_length,
np.where(tcm.id == channel)[0],
"right",
)

for name, file, group, column in tier_params:
try:
# read the energy data
data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=idx_events)
data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=hit_idx)
tbl.add_column(name, data)
except (lh5.exceptions.LH5DecodeError, KeyError):
tbl.add_column(name, types.Array(np.full_like(idx_events, np.nan)))
tbl.add_column(name, types.Array(np.full_like(evt_ids_ch, np.nan)))

res = tbl.eval(observable)
energy_out[idx_events, idx_chan] = res.nda
energy_out[evt_ids_ch, idx_chan] = res.nda

return energy_out

Expand Down Expand Up @@ -146,23 +155,29 @@ def filter_hits(

for idx_chan, channel in enumerate(rawids):
tbl = types.Table()
idx_events = ak.to_numpy(tcm.idx[tcm.id == channel])

hit_idx = ak.to_numpy(tcm.idx[tcm.id == channel])
evt_ids_ch = np.searchsorted(
tcm.cumulative_length,
np.where(tcm.id == channel)[0],
"right",
)

for name, file, group, column in tier_params:
try:
# read the energy data
data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=idx_events)
data = lh5.read(f"ch{channel}/{group}/{column}", file, idx=hit_idx)

tbl.add_column(name, data)
except (lh5.exceptions.LH5DecodeError, KeyError):
tbl.add_column(name, types.Array(np.full_like(idx_events, np.nan)))
tbl.add_column(name, types.Array(np.full_like(evt_ids_ch, np.nan)))

# add the corrected energy to the table
tbl.add_column(
"xtalk_corr_energy", types.Array(xtalk_corr_energy[idx_events, idx_chan])
"xtalk_corr_energy", types.Array(xtalk_corr_energy[evt_ids_ch, idx_chan])
)
res = tbl.eval(filter_expr)
mask[idx_events, idx_chan] = res.nda
mask[evt_ids_ch, idx_chan] = res.nda

return mask

Expand Down Expand Up @@ -269,7 +284,7 @@ def calibrate_energy(
Parameters
---------
datainfo
utils.DataInfo object containg the paths etc to the data
utils.DataInfo object containing the paths etc to the data
tcm
utils.TCMData object
energy_corr
Expand Down Expand Up @@ -305,25 +320,31 @@ def calibrate_energy(

# get the event indices
table_id = utils.get_tcm_id_by_pattern(table_fmt, f"ch{chan}")
idx_events = ak.to_numpy(tcm.idx[tcm.id == table_id])

hit_idx = ak.to_numpy(tcm.idx[tcm.id == table_id])
evt_ids_ch = np.searchsorted(
tcm.cumulative_length,
np.where(tcm.id == table_id)[0],
"right",
)

# read the dsp data
outtbl_obj = lh5.read(
f"ch{chan}/dsp/", file, idx=idx_events, field_mask=chan_inputs
f"ch{chan}/dsp/", file, idx=hit_idx, field_mask=chan_inputs
)

# add the uncalibrated energy to the table
outtbl_obj.add_column(
uncal_energy_var.split(".")[-1],
types.Array(energy_corr[idx_events, i]),
types.Array(energy_corr[evt_ids_ch, i]),
)

for outname, info in cfg.items():
outcol = outtbl_obj.eval(
info["expression"], info.get("parameters", None)
)
outtbl_obj.add_column(outname, outcol)
out_arr[idx_events, i] = outtbl_obj[recal_energy_var.split(".")[-1]].nda
out_arr[evt_ids_ch, i] = outtbl_obj[recal_energy_var.split(".")[-1]].nda
except KeyError:
out_arr[:, i] = np.nan

Expand Down

0 comments on commit ff93678

Please sign in to comment.