Skip to content

Commit

Permalink
Replace CLIF SbsWriter with pybind-based gcpp extension
Browse files Browse the repository at this point in the history
Maintains compatibility with previous version.

PiperOrigin-RevId: 693774946
  • Loading branch information
pchx authored and copybara-github committed Nov 13, 2024
1 parent 719699f commit e5994f2
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 21 deletions.
1 change: 1 addition & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ bazel_dep(name = "googletest", version = "1.15.2")
bazel_dep(name = "highway", version = "1.1.0")
bazel_dep(name = "nlohmann_json", version = "3.11.3")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "pybind11_bazel", version = "2.12.0")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "rules_license", version = "0.0.7")
bazel_dep(name = "google_benchmark", version = "1.8.5")
Expand Down
11 changes: 6 additions & 5 deletions compression/python/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
# [internal] load strict.bzl
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(
default_applicable_licenses = [
Expand All @@ -12,21 +12,22 @@ cc_library(
name = "compression_clif_aux",
srcs = ["compression_clif_aux.cc"],
hdrs = ["compression_clif_aux.h"],
visibility = ["//visibility:private"],
deps = [
"//third_party/absl/types:span",
"@abseil-cpp//absl/types:span",
"//compression:compress",
"//compression:io",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

py_clif_cc(
pybind_extension(
name = "compression",
srcs = ["compression.clif"],
srcs = ["compression_extension.cc"],
deps = [
":compression_clif_aux",
"//third_party/absl/python/numpy:span_clif_lib",
"@abseil-cpp//absl/types:span",
],
)

Expand Down
14 changes: 0 additions & 14 deletions compression/python/compression.clif

This file was deleted.

2 changes: 1 addition & 1 deletion compression/python/compression_clif_aux.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#ifndef GEMMA_ONCE
#define GEMMA_ONCE

#include "third_party/absl/types/span.h"
#include "absl/types/span.h"
#include "compression/io.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
Expand Down
2 changes: 1 addition & 1 deletion compression/python/compression_clif_aux.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <string>
#include <vector>

#include "third_party/absl/types/span.h"
#include "absl/types/span.h"

namespace gcpp {

Expand Down
38 changes: 38 additions & 0 deletions compression/python/compression_extension.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <pybind11/pybind11.h>

#include <exception>
#include <stdexcept>
#include <string>

#include "absl/types/span.h"
#include "compression/python/compression_clif_aux.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

using gcpp::SbsWriter;

namespace py = pybind11;

namespace {
template <auto Func>
void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
throw std::domain_error("Input array must be 1D and densely packed.");
}
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
}
} // namespace

PYBIND11_MODULE(compression, m) {
py::class_<SbsWriter>(m, "SbsWriter")
.def(py::init<>())
// NOTE: Individual compression backends may impose constraints on the
// array length, such as a minimum of (say) 32 elements.
.def("insert", wrap_span<&SbsWriter::Insert>)
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
.def("add_scales", &SbsWriter::AddScales)
.def("write", &SbsWriter::Write);
}

0 comments on commit e5994f2

Please sign in to comment.