Skip to content

Commit

Permalink
Merge pull request #9 from sameeul/follow_up_TS_writer
Browse files Browse the repository at this point in the history
Follow-up of #8
  • Loading branch information
sameeul authored Jun 27, 2024
2 parents 0b9d111 + 6d68864 commit 66cdc40
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 77 deletions.
6 changes: 3 additions & 3 deletions ci-utils/install_prereq_linux.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ fi

mkdir -p $LOCAL_INSTALL_DIR

curl -L https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.zip -o v2.11.1.zip
unzip v2.11.1.zip
cd pybind11-2.11.1
curl -L https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.zip -o v2.12.0.zip
unzip v2.12.0.zip
cd pybind11-2.12.0
mkdir build_man
cd build_man
cmake -DCMAKE_INSTALL_PREFIX=../../$LOCAL_INSTALL_DIR/ -DPYBIND11_TEST=OFF ..
Expand Down
6 changes: 3 additions & 3 deletions ci-utils/install_prereq_win.bat
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ mkdir local_install
mkdir local_install\include


curl -L https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.zip -o v2.11.1.zip
tar -xvf v2.11.1.zip
pushd pybind11-2.11.1
curl -L https://github.com/pybind/pybind11/archive/refs/tags/v2.12.0.zip -o v2.12.0.zip
tar -xvf v2.12.0.zip
pushd pybind11-2.12.0
mkdir build_man
pushd build_man
cmake -DCMAKE_INSTALL_PREFIX=../../local_install/ -DPYBIND11_TEST=OFF ..
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/interface/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,5 @@ PYBIND11_MODULE(libbfiocpp, m) {
// Writer class
py::class_<bfiocpp::TsWriterCPP, std::shared_ptr<bfiocpp::TsWriterCPP>>(m, "TsWriterCPP")
.def(py::init<const std::string&, const std::vector<std::int64_t>&, const std::vector<std::int64_t>&, const std::string&>())
.def("write", &bfiocpp::TsWriterCPP::write_image);
.def("write_image_data", &bfiocpp::TsWriterCPP::WriteImageData);
}
77 changes: 48 additions & 29 deletions src/cpp/utilities/utilities.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
#include <iomanip>
#include <ctime>
#include <chrono>
#include "utilities.h"
#include <cassert>
#include <tiffio.h>
#include <thread>

#include "tensorstore/driver/zarr/dtype.h"


using ::tensorstore::internal_zarr::ChooseBaseDType;

namespace bfiocpp {
tensorstore::Spec GetOmeTiffSpecToRead(const std::string& filename){
return tensorstore::Spec::FromJson({{"driver", "ometiff"},
Expand Down Expand Up @@ -42,9 +46,49 @@ uint16_t GetDataTypeCode (std::string_view type_name){
else if (type_name == std::string_view{"int64"}) {return 128;}
else if (type_name == std::string_view{"float32"}) {return 256;}
else if (type_name == std::string_view{"float64"}) {return 512;}
else if (type_name == std::string_view{"double"}) {return 512;}
else {return 2;}
}

std::string GetEncodedType(uint16_t data_type_code){
switch (data_type_code)
{
case 1:
return ChooseBaseDType(tensorstore::dtype_v<std::uint8_t>).value().encoded_dtype;
break;
case 2:
return ChooseBaseDType(tensorstore::dtype_v<std::uint16_t>).value().encoded_dtype;
break;
case 4:
return ChooseBaseDType(tensorstore::dtype_v<std::uint32_t>).value().encoded_dtype;
break;
case 8:
return ChooseBaseDType(tensorstore::dtype_v<std::uint16_t>).value().encoded_dtype;
break;
case 16:
return ChooseBaseDType(tensorstore::dtype_v<std::int8_t>).value().encoded_dtype;
break;
case 32:
return ChooseBaseDType(tensorstore::dtype_v<std::int16_t>).value().encoded_dtype;
break;
case 64:
return ChooseBaseDType(tensorstore::dtype_v<std::int32_t>).value().encoded_dtype;
break;
case 128:
return ChooseBaseDType(tensorstore::dtype_v<std::int64_t>).value().encoded_dtype;
break;
case 256:
return ChooseBaseDType(tensorstore::dtype_v<float>).value().encoded_dtype;
break;
case 512:
return ChooseBaseDType(tensorstore::dtype_v<double>).value().encoded_dtype;
break;
default:
return ChooseBaseDType(tensorstore::dtype_v<std::uint16_t>).value().encoded_dtype;
break;
}
}

std::string GetUTCString() {
// Get the current UTC time
auto now = std::chrono::system_clock::now();
Expand Down Expand Up @@ -112,6 +156,9 @@ tensorstore::Spec GetZarrSpecToWrite( const std::string& filename,
const std::vector<std::int64_t>& image_shape,
const std::vector<std::int64_t>& chunk_shape,
const std::string& dtype){

// valid values for dtype are subset of
// https://google.github.io/tensorstore/spec.html#json-dtype
return tensorstore::Spec::FromJson({{"driver", "zarr"},
{"kvstore", {{"driver", "file"},
{"path", filename}}
Expand All @@ -128,33 +175,5 @@ tensorstore::Spec GetZarrSpecToWrite( const std::string& filename,
{"dtype", dtype},
},
}}).value();
}

// Function to get the TensorStore DataType based on a string identifier
tensorstore::DataType GetTensorStoreDataType(const std::string& type_str) {
if (type_str == "uint8") {
return tensorstore::dtype_v<std::uint8_t>;
} else if (type_str == "uint16") {
return tensorstore::dtype_v<std::uint16_t>;
} else if (type_str == "uint32") {
return tensorstore::dtype_v<std::uint32_t>;
} else if (type_str == "uint64") {
return tensorstore::dtype_v<std::uint64_t>;
} else if (type_str == "int8") {
return tensorstore::dtype_v<std::int8_t>;
} else if (type_str == "int16") {
return tensorstore::dtype_v<std::int16_t>;
} else if (type_str == "int32") {
return tensorstore::dtype_v<std::int32_t>;
} else if (type_str == "int64") {
return tensorstore::dtype_v<std::int64_t>;
} else if (type_str == "float") {
return tensorstore::dtype_v<float>;
} else if (type_str == "double" || type_str == "float64") { // handle float64 from numpy
return tensorstore::dtype_v<double>;
} else {
throw std::invalid_argument("Unknown data type string: " + type_str);
}
}

} // ns bfiocpp
2 changes: 1 addition & 1 deletion src/cpp/utilities/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ tensorstore::Spec GetOmeTiffSpecToRead(const std::string& filename);
tensorstore::Spec GetZarrSpecToRead(const std::string& filename);

uint16_t GetDataTypeCode (std::string_view type_name);
std::string GetEncodedType(uint16_t data_type_code);
std::string GetUTCString();
std::string GetOmeXml(const std::string& file_path);
std::tuple<std::optional<int>, std::optional<int>, std::optional<int>>ParseMultiscaleMetadata(const std::string& axes_list, int len);
tensorstore::DataType GetTensorStoreDataType(const std::string& type_str);
tensorstore::Spec GetZarrSpecToWrite(const std::string& filename,
const std::vector<std::int64_t>& image_shape,
const std::vector<std::int64_t>& chunk_shape,
Expand Down
31 changes: 12 additions & 19 deletions src/cpp/writer/tswriter.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include "tswriter.h"
#include <string>

#include "../utilities/utilities.h"
#include "tensorstore/array.h"
#include "tensorstore/open.h"

#include <variant>
#include <string>
#include "tswriter.h"
#include "../utilities/utilities.h"

using ::tensorstore::internal_zarr::ChooseBaseDType;

namespace bfiocpp {

Expand All @@ -14,27 +14,20 @@ TsWriterCPP::TsWriterCPP(
const std::vector<std::int64_t>& image_shape,
const std::vector<std::int64_t>& chunk_shape,
const std::string& dtype_str
): _filename(fname), _image_shape(image_shape), _chunk_shape(chunk_shape) {

_dtype_code = GetDataTypeCode(dtype_str);

std::string dtype_str_converted = (dtype_str == "float64") ? "double" : dtype_str; // change float64 numpy type to double

auto dtype = GetTensorStoreDataType(dtype_str_converted);

auto dtype_base = ChooseBaseDType(dtype).value().encoded_dtype;

auto spec = GetZarrSpecToWrite(_filename, image_shape, chunk_shape, dtype_base);

): _filename(fname),
_image_shape(image_shape),
_chunk_shape(chunk_shape),
_dtype_code(GetDataTypeCode(dtype_str)) {

TENSORSTORE_CHECK_OK_AND_ASSIGN(_source, tensorstore::Open(
spec,
GetZarrSpecToWrite(_filename, _image_shape, _chunk_shape, GetEncodedType(_dtype_code)),
tensorstore::OpenMode::create |
tensorstore::OpenMode::delete_existing,
tensorstore::ReadWriteMode::write).result());
}


void TsWriterCPP::write_image(py::array& py_image) {
void TsWriterCPP::WriteImageData(py::array& py_image) {

// use switch instead of template to avoid creating functions for each datatype
switch(_dtype_code)
Expand Down
20 changes: 2 additions & 18 deletions src/cpp/writer/tswriter.h
Original file line number Diff line number Diff line change
@@ -1,25 +1,9 @@
#pragma once

#include <string>
#include <memory>
#include <vector>
#include <variant>
#include <iostream>
#include <tuple>
#include <optional>
#include <unordered_map>
#include "../reader/sequence.h"

#include "tensorstore/tensorstore.h"
#include "tensorstore/context.h"
#include "tensorstore/array.h"
#include "tensorstore/driver/zarr/dtype.h"
#include "tensorstore/index_space/dim_expression.h"
#include "tensorstore/kvstore/kvstore.h"
#include "tensorstore/open.h"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "../reader/sequence.h"
#include <pybind11/numpy.h>

namespace py = pybind11;
Expand All @@ -30,7 +14,7 @@ class TsWriterCPP{
public:
TsWriterCPP(const std::string& fname, const std::vector<std::int64_t>& image_shape, const std::vector<std::int64_t>& chunk_shape, const std::string& dtype);

void write_image(py::array& py_image);
void WriteImageData(py::array& py_image);

private:
std::string _filename;
Expand Down
4 changes: 2 additions & 2 deletions src/python/bfiocpp/tswriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
file_name, image_shape, chunk_shape, str(dtype)
)

def write_image(self, image_data: np.ndarray):
def write_image_data(self, image_data: np.ndarray):
"""Write image data to file
image_data: 5d numpy array containing image data
Expand All @@ -27,7 +27,7 @@ def write_image(self, image_data: np.ndarray):
raise ValueError("Image data must be a 5d numpy array")

try:
self._image_writer.write(image_data.flatten())
self._image_writer.write_image_data(image_data.flatten())

except Exception as e:
raise RuntimeError(f"Error writing image data: {e.what}")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_write_zarr_2d(self):
test_file_path = os.path.join(dir, 'out/test.ome.zarr')

bw = TSWriter(test_file_path, tmp.shape, tmp.shape, str(tmp.dtype))
bw.write_image(tmp)
bw.write_image_data(tmp)
bw.close()

br = TSReader(
Expand Down

0 comments on commit 66cdc40

Please sign in to comment.