diff --git a/src/lgdo/compression/generic.py b/src/lgdo/compression/generic.py index 07e038e8..82dbb81b 100644 --- a/src/lgdo/compression/generic.py +++ b/src/lgdo/compression/generic.py @@ -40,6 +40,7 @@ def encode( def decode( obj: lgdo.VectorOfEncodedVectors | lgdo.ArrayOfEncodedEqualSizedArrays, + out_buf: lgdo.ArrayOfEqualSizedArrays = None, ) -> lgdo.VectorOfVectors | lgdo.ArrayOfEqualsizedArrays: """Decode encoded LGDOs. @@ -51,6 +52,9 @@ def decode( ---------- obj LGDO array type. + out_buf + pre-allocated LGDO for the decoded signals. See documentation of + wrapped encoders for limitations. """ if "codec" not in obj.attrs: raise RuntimeError( @@ -61,9 +65,11 @@ def decode( log.debug(f"decoding {repr(obj)} with {codec}") if _is_codec(codec, radware.RadwareSigcompress): - return radware.decode(obj, shift=int(obj.attrs.get("codec_shift", 0))) + return radware.decode( + obj, sig_out=out_buf, shift=int(obj.attrs.get("codec_shift", 0)) + ) elif _is_codec(codec, varlen.ULEB128ZigZagDiff): - return varlen.decode(obj) + return varlen.decode(obj, sig_out=out_buf) else: raise ValueError(f"'{codec}' not supported") diff --git a/src/lgdo/compression/radware.py b/src/lgdo/compression/radware.py index f8235d7e..0332dfb7 100644 --- a/src/lgdo/compression/radware.py +++ b/src/lgdo/compression/radware.py @@ -120,7 +120,7 @@ def encode( return sig_out, nbytes elif isinstance(sig_in, lgdo.VectorOfVectors): - if sig_out: + if sig_out is not None: log.warning( "a pre-allocated VectorOfEncodedVectors was given " "to hold an encoded ArrayOfEqualSizedArrays. " @@ -143,7 +143,7 @@ def encode( return sig_out elif isinstance(sig_in, lgdo.ArrayOfEqualSizedArrays): - if sig_out: + if sig_out is not None: log.warning( "a pre-allocated ArrayOfEncodedEqualSizedArrays was given " "to hold an encoded ArrayOfEqualSizedArrays. " @@ -243,7 +243,7 @@ def decode( return sig_out, siglen elif isinstance(sig_in, lgdo.ArrayOfEncodedEqualSizedArrays): - if not sig_out: + if sig_out is None: # initialize output structure with decoded_size sig_out = lgdo.ArrayOfEqualSizedArrays( dims=(1, 1), @@ -263,7 +263,7 @@ def decode( # can now decode on the 2D matrix together with number of bytes to read per row _, siglen = decode( (sig_in.encoded_data.to_aoesa(preserve_dtype=True).nda, nbytes), - sig_out.nda, + sig_out if isinstance(sig_out, np.ndarray) else sig_out.nda, shift=shift, ) diff --git a/src/lgdo/compression/varlen.py b/src/lgdo/compression/varlen.py index e3a4846e..f273e038 100644 --- a/src/lgdo/compression/varlen.py +++ b/src/lgdo/compression/varlen.py @@ -94,7 +94,7 @@ def encode( return sig_out, nbytes elif isinstance(sig_in, lgdo.VectorOfVectors): - if sig_out: + if sig_out is not None: log.warning( "a pre-allocated VectorOfEncodedVectors was given " "to hold an encoded ArrayOfEqualSizedArrays. " @@ -208,7 +208,7 @@ def decode( return sig_out, siglen elif isinstance(sig_in, lgdo.ArrayOfEncodedEqualSizedArrays): - if not sig_out: + if sig_out is None: # initialize output structure with decoded_size sig_out = lgdo.ArrayOfEqualSizedArrays( dims=(1, 1), diff --git a/src/lgdo/lh5_store.py b/src/lgdo/lh5_store.py index 6c430459..dfdfa87f 100644 --- a/src/lgdo/lh5_store.py +++ b/src/lgdo/lh5_store.py @@ -531,26 +531,31 @@ def read_object( elif obj_buf is None and decompress: return compress.decode(rawdata), n_rows_read + # eventually expand provided obj_buf, if too short + buf_size = obj_buf_start + n_rows_read + if len(obj_buf) < buf_size: + obj_buf.resize(buf_size) + # use the (decoded object type) buffer otherwise - if enc_lgdo == VectorOfEncodedVectors and not isinstance( - obj_buf, VectorOfVectors - ): - raise ValueError( - f"obj_buf for decoded '{name}' not a VectorOfVectors" - ) - elif enc_lgdo == ArrayOfEncodedEqualSizedArrays and not isinstance( - obj_buf, ArrayOfEqualSizedArrays - ): - raise ValueError( - f"obj_buf for decoded '{name}' not an ArrayOfEqualSizedArrays" - ) + if enc_lgdo == ArrayOfEncodedEqualSizedArrays: + if not isinstance(obj_buf, ArrayOfEqualSizedArrays): + raise ValueError( + f"obj_buf for decoded '{name}' not an ArrayOfEqualSizedArrays" + ) + + compress.decode(rawdata, obj_buf[obj_buf_start:buf_size]) + + elif enc_lgdo == VectorOfEncodedVectors: + if not isinstance(obj_buf, VectorOfVectors): + raise ValueError( + f"obj_buf for decoded '{name}' not a VectorOfVectors" + ) - # FIXME: not a good idea. an in place decoding version - # of decode would be needed to avoid extra memory - # allocations - # FIXME: obj_buf_start??? Write a unit test - for i, wf in enumerate(compress.decode(rawdata)): - obj_buf[i] = wf + # FIXME: not a good idea. an in place decoding version + # of decode would be needed to avoid extra memory + # allocations + for i, wf in enumerate(compress.decode(rawdata)): + obj_buf[obj_buf_start + i] = wf return obj_buf, n_rows_read diff --git a/tests/compression/test_radware_sigcompress.py b/tests/compression/test_radware_sigcompress.py index ffeb44af..aacf38f6 100644 --- a/tests/compression/test_radware_sigcompress.py +++ b/tests/compression/test_radware_sigcompress.py @@ -107,8 +107,9 @@ def test_wrapper(wftable): enc_wfs = np.zeros(s[:-1] + (2 * s[-1],), dtype="ubyte") enclen = np.empty(s[0], dtype="uint32") + _shift = np.full(s[0], shift, dtype="int32") - _radware_sigcompress_encode(wfs, enc_wfs, shift, enclen, _mask) + _radware_sigcompress_encode(wfs, enc_wfs, _shift, enclen, _mask) # test if the wrapper gives the same result w_enc_wfs = np.zeros(s[:-1] + (2 * s[-1],), dtype="ubyte") @@ -167,6 +168,13 @@ def test_aoesa(wftable): for wf1, wf2 in zip(dec_aoesa, wftable.values): assert np.array_equal(wf1, wf2) + # test using pre-allocated decoded array + sig_out = ArrayOfEqualSizedArrays( + shape=wftable.values.nda.shape, dtype=wftable.values.dtype + ) + decode(enc_vov, sig_out=sig_out, shift=shift) + assert wftable.values == sig_out + def test_performance(lgnd_test_data): store = LH5Store() diff --git a/tests/test_lh5_store.py b/tests/test_lh5_store.py index c220c63a..13f775b2 100644 --- a/tests/test_lh5_store.py +++ b/tests/test_lh5_store.py @@ -340,23 +340,27 @@ def test_read_wftable_encoded(lh5_file): assert lh5_obj.values.attrs["codec"] == "radware_sigcompress" assert "codec_shift" in lh5_obj.values.attrs + lh5_obj, n_rows = store.read_object("/data/struct/wftable_enc/values", lh5_file) + assert isinstance(lh5_obj, lgdo.ArrayOfEqualSizedArrays) + assert n_rows == 3 + lh5_obj, n_rows = store.read_object("/data/struct/wftable_enc", lh5_file) assert isinstance(lh5_obj, lgdo.WaveformTable) assert isinstance(lh5_obj.values, lgdo.ArrayOfEqualSizedArrays) assert n_rows == 3 - lh5_obj, n_rows = store.read_object("/data/struct/wftable_enc/values", lh5_file) - assert isinstance(lh5_obj, lgdo.ArrayOfEqualSizedArrays) - assert n_rows == 3 - - lh5_obj, n_rows = store.read_object( + lh5_obj_chain, n_rows = store.read_object( "/data/struct/wftable_enc", [lh5_file, lh5_file], decompress=False ) assert n_rows == 6 + assert isinstance(lh5_obj_chain.values, lgdo.ArrayOfEncodedEqualSizedArrays) - lh5_obj, n_rows = store.read_object( + lh5_obj_chain, n_rows = store.read_object( "/data/struct/wftable_enc", [lh5_file, lh5_file], decompress=True ) + assert isinstance(lh5_obj_chain.values, lgdo.ArrayOfEqualSizedArrays) + assert np.array_equal(lh5_obj_chain.values[:3], lh5_obj.values) + assert np.array_equal(lh5_obj_chain.values[3:], lh5_obj.values) assert n_rows == 6