Skip to content

Commit

Permalink
zune-python: Add numpy support.
Browse files Browse the repository at this point in the history
Signed-off-by: caleb <[email protected]>
  • Loading branch information
etemesi254 committed Oct 16, 2023
1 parent 55dfa52 commit 744613c
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 4 deletions.
7 changes: 4 additions & 3 deletions crates/zune-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "zune_python"
name = "zune_image"
crate-type = ["cdylib"]

[dependencies]
pyo3 = "0.19.0"
pyo3 = "0.20.0"
zune-png = { path = "../zune-png" }
zune-jpeg = { path = "../zune-jpeg" }
zune-image = { path = "../zune-image" }
zune-core = { path = "../zune-core" }
zune-core = { path = "../zune-core" }
numpy = "0.20.0"
3 changes: 3 additions & 0 deletions crates/zune-python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ build-backend = "maturin"

[project]
name = "zune-image"
version = "0.1.1"
description = "The zune-image rust library python bindings"
license = "MIT"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Rust",
Expand Down
3 changes: 2 additions & 1 deletion crates/zune-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ mod py_image;

/// A Python module implemented in Rust.
#[pymodule]
fn zune_python(_py: Python, m: &PyModule) -> PyResult<()> {
#[pyo3(name = "zune_image")]
fn zune_image(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyImageFormats>()?;
m.add_class::<PyImageColorSpace>()?;
m.add_class::<PyImage>()?;
Expand Down
19 changes: 19 additions & 0 deletions crates/zune-python/src/py_image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
*
* You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
*/
mod numpy_bindings;

use std::fs::read;

use numpy::PyArray3;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use zune_image::filters::box_blur::BoxBlur;
Expand Down Expand Up @@ -222,6 +225,13 @@ impl PyImage {

/// Convert from one depth to another
///
/// The following are the depth conversion details
///
/// - INT->Float : Convert to float and divide by max value for the previous integer type(255 for u8,65535 for u16).
/// - Float->Int : Multiply by max value of the new depth (255->Eight,65535->16)
/// - smallInt->Int : Multiply by (MAX_LARGE_INT/MAX_SMALL_INT)
/// - LargeInt->SmallInt: Divide by (MAX_LARGE_INT/MAX_SMALL_INT)
///
/// # Arguments
/// - to: The new depth to convert to
/// - in_place: Whether to perform the conversion in place or to create a copy and convert that
Expand Down Expand Up @@ -678,6 +688,15 @@ impl PyImage {
Ok(Some(im_clone))
}
}
pub fn to_numpy_u8<'py>(&self, py: Python<'py>) -> PyResult<&'py PyArray3<u8>> {
self.to_numpy_generic(py, PyImageDepth::Eight)
}
pub fn to_numpy_u16<'py>(&self, py: Python<'py>) -> PyResult<&'py PyArray3<u16>> {
self.to_numpy_generic(py, PyImageDepth::Sixteen)
}
pub fn to_numpy_f32<'py>(&self, py: Python<'py>) -> PyResult<&'py PyArray3<f32>> {
self.to_numpy_generic(py, PyImageDepth::F32)
}
}

#[pyfunction]
Expand Down
191 changes: 191 additions & 0 deletions crates/zune-python/src/py_image/numpy_bindings.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
* Copyright (c) 2023.
*
* This software is free software;
*
* You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
*/

use numpy::PyArray3;
use pyo3::exceptions::PyException;
use pyo3::{PyErr, PyResult, Python};

use crate::py_enums::PyImageDepth;
use crate::py_image::PyImage;

impl PyImage {
pub(crate) fn to_numpy_generic<'py, T>(
&self, py: Python<'py>, expected: PyImageDepth
) -> PyResult<&'py PyArray3<T>>
where
T: Copy + Default + 'static + numpy::Element + Send
{
let arr = unsafe {
let colorspace = self.image.get_colorspace();
//PyArray3::uget_raw()
let arr = PyArray3::<T>::new(
py,
[self.height(), self.width(), colorspace.num_components()],
false
);

//obtain first channel
let channels = self.image.get_frames_ref()[0].get_channels_ref(colorspace, false);
for chan in channels {
if chan.reinterpret_as::<T>().is_err() {
return Err(PyErr::new::<PyException, _>(format!(
"The image depth {:?} is not u8 use image.convert_depth({:?}) to convert to 8 bit \nWe do not implicitly convert to desired depth", self.image.get_depth(), expected
)));
}
}
let reinterprets: Vec<&[T]> = channels
.iter()
.map(|z| z.reinterpret_as::<T>().unwrap())
.collect();

let width = self.width();
let height = self.height();

let dims = height.checked_mul(width);
// check for overflow
if dims.is_none() {
return Err(PyErr::new::<PyException, _>(format!(
"width * height overflowed to big of dimensions ({},{})",
width, height
)));
}
let dims = dims.unwrap();
// check that all reinterprets' length never passes dims
// SAFETY CHECK: DO NOT REMOVE
for chan in &reinterprets {
if dims != chan.len() {
return Err(PyErr::new::<PyException, _>(format!(
"[INTERNAL-ERROR]: length of one channel doesn't match the expected len={},expected={}",
chan.len(), dims
)));
}
}
// check that arr dims == length
match arr.dims()[2] {
1 => {
assert_eq!(reinterprets.len(), arr.dims()[2]);
// convert into u8
// get each pixel from each channel, so we iterate per row
for i in 0..arr.dims()[0] {
for j in 0..arr.dims()[1] {
let idx = (i * width) + j;
{
arr.uget_raw([i, j, 0])
.write(*reinterprets.get_unchecked(0).get_unchecked(idx));
}
}
}
}
2 => {
// convert into T
// get each pixel from each channel, so we iterate per row
// optimized to use unsafe.
//
// # SAFETY
// - Unchecked memory access
// - We have two memory accesses we care about,
// 1: uget_raw, that should never matter, since, we are iterating
// over arr.dims[0] and arr.dims[1],
// and we know arr_dims[2] is 2, (in this particular match)
// 2. reinterprets.get_unchecked(0), we assert below that the length is three
// 3. reinterprets.get_unchecked(0).get_unchecked(idx), we assert above(just above the match)
// that the overflow doesn't happen

// safety check, do not remove...
assert_eq!(reinterprets.len(), 2);
for i in 0..arr.dims()[0] {
for j in 0..arr.dims()[1] {
let idx = (i * width) + j;
arr.uget_raw([i, j, 0])
.write(*reinterprets.get_unchecked(0).get_unchecked(idx));
arr.uget_raw([i, j, 1])
.write(*reinterprets.get_unchecked(1).get_unchecked(idx));
}
}
}
3 => {
// convert into T
// get each pixel from each channel, so we iterate per row
// optimized to use unsafe.
//
// # SAFETY
// - Unchecked memory access
// - We have two memory accesses we care about,
// 1: uget_raw, that should never matter, since, we are iterating
// over arr.dims[0] and arr.dims[1],
// and we know arr_dims[2] is 3, (in this particular match)
// 2. reinterprets.get_unchecked(0), we assert below that the length is three
// 3. reinterprets.get_unchecked(0).get_unchecked(idx), we assert above(just above the match)
// that the overflow doesn't happen

// safety check, do not remove...
assert_eq!(reinterprets.len(), 3);
for i in 0..arr.dims()[0] {
for j in 0..arr.dims()[1] {
let idx = (i * width) + j;
arr.uget_raw([i, j, 0])
.write(*reinterprets.get_unchecked(0).get_unchecked(idx));
arr.uget_raw([i, j, 1])
.write(*reinterprets.get_unchecked(1).get_unchecked(idx));
arr.uget_raw([i, j, 2])
.write(*reinterprets.get_unchecked(2).get_unchecked(idx));
}
}
}
4 => {
// convert into T
// get each pixel from each channel, so we iterate per row
// optimized to use unsafe.
//
// # SAFETY
// - Unchecked memory access
// - We have two memory accesses we care about,
// 1: uget_raw, that should never matter, since, we are iterating
// over arr.dims[0] and arr.dims[1],
// and we know arr_dims[2] is 4, (in this particular match)
// 2. reinterprets.get_unchecked(0), we assert below that the length is three
// 3. reinterprets.get_unchecked(0).get_unchecked(idx), we assert above(just above the match)
// that the overflow doesn't happen

// safety check, do not remove...
assert_eq!(reinterprets.len(), 4);
for i in 0..arr.dims()[0] {
for j in 0..arr.dims()[1] {
let idx = (i * width) + j;
arr.uget_raw([i, j, 0])
.write(*reinterprets.get_unchecked(0).get_unchecked(idx));
arr.uget_raw([i, j, 1])
.write(*reinterprets.get_unchecked(1).get_unchecked(idx));
arr.uget_raw([i, j, 2])
.write(*reinterprets.get_unchecked(2).get_unchecked(idx));
arr.uget_raw([i, j, 3])
.write(*reinterprets.get_unchecked(3).get_unchecked(idx));
}
}
}
_ => {
assert_eq!(reinterprets.len(), arr.dims()[2]);
// convert into u8
// get each pixel from each channel, so we iterate per row
for i in 0..arr.dims()[0] {
for j in 0..arr.dims()[1] {
let idx = (i * width) + j;
for k in 0..arr.dims()[2] {
arr.uget_raw([i, j, k])
.write(*reinterprets.get_unchecked(k).get_unchecked(idx));
}
}
}
}
}

arr
};
return Ok(arr);
}
}

0 comments on commit 744613c

Please sign in to comment.