From 7a537151aafe445ecc8ca0f2c006069f571d2dbe Mon Sep 17 00:00:00 2001 From: elftausend <76885970+elftausend@users.noreply.github.com> Date: Sun, 13 Oct 2024 13:28:25 +0200 Subject: [PATCH] Add WrappedCopy impls --- Cargo.toml | 2 +- src/buffer.rs | 9 +++++--- src/devices/cpu/cpu_ptr.rs | 11 +++++++++- src/devices/cuda/cuda_ptr.rs | 11 +++++++++- src/devices/opencl/cl_ptr.rs | 11 +++++++++- src/devices/stack_array.rs | 12 ++++++++++- src/devices/vulkan/vk_array.rs | 11 +++++++++- src/features.rs | 2 +- src/lib.rs | 5 +++++ src/modules/autograd/wrapper.rs | 18 +++++++++++++++- src/modules/lazy/wrapper.rs | 29 ++++++++++++++++++++++---- src/modules/lazy/wrapper/maybe_data.rs | 8 +++---- src/range.rs | 6 ++---- 13 files changed, 112 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9f556d70..a90bb07a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,7 +51,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true } # min-cl = { version = "0.3.0", optional=true } [features] -default = ["cpu", "opencl", "blas", "static-api", "stack", "macro", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"] +default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "cuda", "vulkan", "stack"] # default = ["cpu"] # default = ["no-std"] diff --git a/src/buffer.rs b/src/buffer.rs index 31638197..ce74e220 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -12,7 +12,7 @@ use crate::CPU; use crate::{ flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, Device, DevicelessAble, HasId, IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, Unit, - WrappedData, WriteBuf, ZeroGrad, + WrappedCopy, WrappedData, WriteBuf, ZeroGrad, }; pub use self::num::Num; @@ -477,11 +477,14 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> { pub fn to_dims(self) -> Buffer<'a, T, D, O> where D: crate::ToDim, - D::Data: ShallowCopy, + D::Data: WrappedCopy>, + D::Base: ShallowCopy, { + let base = unsafe { (*self).shallow() }; + let data = self.data.wrapped_copy(base); let buf = ManuallyDrop::new(self); - let mut data = buf.device().to_dim(unsafe { buf.data.shallow() }); + let mut data = buf.device().to_dim(data); unsafe { data.set_flag(AllocFlag::None) }; Buffer { diff --git a/src/devices/cpu/cpu_ptr.rs b/src/devices/cpu/cpu_ptr.rs index e2a41801..9709c3e5 100644 --- a/src/devices/cpu/cpu_ptr.rs +++ b/src/devices/cpu/cpu_ptr.rs @@ -7,7 +7,7 @@ use core::{ use std::alloc::handle_alloc_error; -use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy}; +use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy, WrappedCopy}; /// The pointer used for `CPU` [`Buffer`](crate::Buffer)s #[derive(Debug)] @@ -229,6 +229,15 @@ impl PtrType for CPUPtr { } } +impl WrappedCopy for CPUPtr { + type Base = Self; + + #[inline] + fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { + to_wrap + } +} + impl ShallowCopy for CPUPtr { #[inline] unsafe fn shallow(&self) -> Self { diff --git a/src/devices/cuda/cuda_ptr.rs b/src/devices/cuda/cuda_ptr.rs index c5891183..794c90de 100644 --- a/src/devices/cuda/cuda_ptr.rs +++ b/src/devices/cuda/cuda_ptr.rs @@ -1,5 +1,5 @@ use super::api::{cu_read, cufree, cumalloc, CudaResult}; -use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy}; +use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, WrappedCopy}; use core::marker::PhantomData; /// The pointer used for `CUDA` [`Buffer`](crate::Buffer)s @@ -76,6 +76,15 @@ impl Drop for CUDAPtr { } } +impl WrappedCopy for CUDAPtr { + type Base = Self; + + #[inline] + fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { + to_wrap + } +} + impl ShallowCopy for CUDAPtr { #[inline] unsafe fn shallow(&self) -> Self { diff --git a/src/devices/opencl/cl_ptr.rs b/src/devices/opencl/cl_ptr.rs index 56467465..c7b80fb3 100644 --- a/src/devices/opencl/cl_ptr.rs +++ b/src/devices/opencl/cl_ptr.rs @@ -8,7 +8,7 @@ use crate::HostPtr; use min_cl::api::release_mem_object; -use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy}; +use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, WrappedCopy}; /// The pointer used for `OpenCL` [`Buffer`](crate::Buffer)s #[derive(Debug, PartialEq, Eq)] @@ -59,6 +59,15 @@ impl CLPtr { } } +impl WrappedCopy for CLPtr { + type Base = Self; + + #[inline] + fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { + to_wrap + } +} + impl ShallowCopy for CLPtr { #[inline] unsafe fn shallow(&self) -> Self { diff --git a/src/devices/stack_array.rs b/src/devices/stack_array.rs index 01ef5e8a..d4862c28 100644 --- a/src/devices/stack_array.rs +++ b/src/devices/stack_array.rs @@ -1,6 +1,6 @@ use core::ops::{Deref, DerefMut}; -use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy}; +use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, WrappedCopy}; /// A possibly multi-dimensional array allocated on the stack. /// It uses `S:`[`Shape`] to get the type of the array. @@ -137,6 +137,16 @@ impl HostPtr for StackArray { } } + +impl WrappedCopy for StackArray { + type Base = Self; + + #[inline] + fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { + to_wrap + } +} + impl ShallowCopy for StackArray where S::ARR: Copy, diff --git a/src/devices/vulkan/vk_array.rs b/src/devices/vulkan/vk_array.rs index fde22b15..883c0df5 100644 --- a/src/devices/vulkan/vk_array.rs +++ b/src/devices/vulkan/vk_array.rs @@ -9,7 +9,7 @@ use core::{ }; use std::rc::Rc; -use crate::{flag::AllocFlag, HasId, HostPtr, PtrType, ShallowCopy}; +use crate::{flag::AllocFlag, HasId, HostPtr, PtrType, ShallowCopy, WrappedCopy}; use super::{context::Context, submit_and_wait}; @@ -228,6 +228,15 @@ impl VkArray { } } +impl WrappedCopy for VkArray { + type Base = Self; + + #[inline] + fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { + to_wrap + } +} + impl ShallowCopy for VkArray { #[inline] unsafe fn shallow(&self) -> Self { diff --git a/src/features.rs b/src/features.rs index 6f4ae691..6bdbec26 100644 --- a/src/features.rs +++ b/src/features.rs @@ -80,7 +80,7 @@ pub trait Cursor { } #[inline] - fn cached(&self, cb: impl Fn()) + fn cached(&self, cb: impl Fn()) where Self: Sized, { diff --git a/src/lib.rs b/src/lib.rs index a17e8ca0..dccda7c7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -180,6 +180,11 @@ pub trait Unit {} // useful for Sync and Send or 'static impl Unit for T {} +pub trait WrappedCopy { + type Base; + fn wrapped_copy(&self, to_wrap: Self::Base) -> Self; +} + /// Used to shallow-copy a pointer. Use is discouraged. pub trait ShallowCopy { /// # Safety diff --git a/src/modules/autograd/wrapper.rs b/src/modules/autograd/wrapper.rs index dbb4d25d..ada63d94 100644 --- a/src/modules/autograd/wrapper.rs +++ b/src/modules/autograd/wrapper.rs @@ -1,6 +1,6 @@ use core::marker::PhantomData; -use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, WrappedData}; +use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, WrappedCopy, WrappedData}; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct ReqGradWrapper { @@ -74,6 +74,22 @@ impl PtrType for ReqGradWrapper { } } +impl WrappedCopy for ReqGradWrapper +where + Data: WrappedCopy, +{ + type Base = T; + + #[inline] + fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { + Self { + requires_grad: self.requires_grad, + data: self.data.wrapped_copy(to_wrap), + _pd: PhantomData, + } + } +} + impl ShallowCopy for ReqGradWrapper where Data: ShallowCopy, diff --git a/src/modules/lazy/wrapper.rs b/src/modules/lazy/wrapper.rs index 9ac1000a..c534098a 100644 --- a/src/modules/lazy/wrapper.rs +++ b/src/modules/lazy/wrapper.rs @@ -6,7 +6,9 @@ use core::{ ops::{Deref, DerefMut}, }; -use crate::{flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, WrappedData}; +use crate::{ + flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, WrappedCopy, WrappedData, +}; #[derive(Debug, Default)] pub struct LazyWrapper { @@ -42,7 +44,7 @@ impl HasId for LazyWrapper { match self.maybe_data { MaybeData::Data(ref data) => data.id(), MaybeData::Id(id) => id, - MaybeData::None => unimplemented!() + MaybeData::None => unimplemented!(), } } } @@ -53,13 +55,14 @@ impl PtrType for LazyWrapper { match self.maybe_data { MaybeData::Data(ref data) => data.size(), MaybeData::Id(id) => id.len, - MaybeData::None => unimplemented!() + MaybeData::None => unimplemented!(), } } #[inline] fn flag(&self) -> AllocFlag { - self.maybe_data.data() + self.maybe_data + .data() .map(|data| data.flag()) .unwrap_or(AllocFlag::Lazy) } @@ -101,6 +104,24 @@ impl> HostPtr for LazyWrapper { } } +impl WrappedCopy for LazyWrapper +where + Data: WrappedCopy, +{ + type Base = T; + + fn wrapped_copy(&self, to_wrap: Self::Base) -> Self { + LazyWrapper { + maybe_data: match &self.maybe_data { + MaybeData::Data(data) => MaybeData::Data(data.wrapped_copy(to_wrap)), + MaybeData::Id(id) => MaybeData::Id(*id), + MaybeData::None => unimplemented!(), + }, + _pd: PhantomData, + } + } +} + impl ShallowCopy for LazyWrapper { #[inline] unsafe fn shallow(&self) -> Self { diff --git a/src/modules/lazy/wrapper/maybe_data.rs b/src/modules/lazy/wrapper/maybe_data.rs index dfeab997..bf89e5aa 100644 --- a/src/modules/lazy/wrapper/maybe_data.rs +++ b/src/modules/lazy/wrapper/maybe_data.rs @@ -17,16 +17,16 @@ impl MaybeData { MaybeData::None => None, } } - + #[inline] pub fn data_mut(&mut self) -> Option<&mut Data> { match self { MaybeData::Data(data) => Some(data), MaybeData::Id(_id) => None, - MaybeData::None => None + MaybeData::None => None, } } - + #[inline] pub fn id(&self) -> Option<&Id> { match self { @@ -35,7 +35,7 @@ impl MaybeData { MaybeData::None => None, } } - + #[inline] pub fn id_mut(&mut self) -> Option<&mut Id> { match self { diff --git a/src/range.rs b/src/range.rs index 3153deaa..5961586c 100644 --- a/src/range.rs +++ b/src/range.rs @@ -23,7 +23,7 @@ impl<'a, D: Cursor> CursorRangeIter<'a, D> { pub fn previous_cursor(&self) -> &usize { &self.previous_cursor } - + #[inline] pub fn cursor_range(&self) -> &CursorRange<'a, D> { &self.range @@ -68,7 +68,6 @@ pub trait AsRange { fn end(&self) -> usize; } - // Implementing AsRange for standard Range (e.g., 0..10) impl AsRange for Range { #[inline] @@ -173,7 +172,6 @@ impl AsRange for RangeToInclusive { } } - #[cfg(test)] mod tests { #[cfg(feature = "cpu")] @@ -243,7 +241,7 @@ mod tests { unsafe { device.bump_cursor() }; assert_eq!(device.cursor(), 10); } - } + } #[cfg(feature = "cpu")] #[cfg(feature = "cached")]