From 2d2cb345670fcfa9e5dd60cb3586a29f08f461d7 Mon Sep 17 00:00:00 2001 From: Simon Liu Date: Sun, 7 Jul 2024 12:21:32 +0800 Subject: [PATCH] Refactor: Unify f64-Tensor arithmetic operations using macro This commit refactors the previously separate implementations of arithmetic operations (Add, Sub, Mul, Div) between f64 and Tensor types into a single, reusable macro `impl_f64_tensor_ops`. --- candle-core/src/tensor.rs | 77 +++++++++------------------------------ 1 file changed, 18 insertions(+), 59 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index dd1b44b0a0..eab29ae5a8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2529,68 +2529,27 @@ bin_trait!(Sub, sub, |_| 1., |v: f64| -v); bin_trait!(Mul, mul, |v| v, |_| 0.); bin_trait!(Div, div, |v| 1. / v, |_| 0.); -impl std::ops::Add for f64 { - type Output = Result; - - fn add(self, rhs: Tensor) -> Self::Output { - rhs + self - } -} - -impl std::ops::Add<&Tensor> for f64 { - type Output = Result; - - fn add(self, rhs: &Tensor) -> Self::Output { - rhs + self - } -} - -impl std::ops::Mul for f64 { - type Output = Result; - - fn mul(self, rhs: Tensor) -> Self::Output { - rhs * self - } -} - -impl std::ops::Mul<&Tensor> for f64 { - type Output = Result; - - fn mul(self, rhs: &Tensor) -> Self::Output { - rhs * self - } -} - -impl std::ops::Sub for f64 { - type Output = Result; - - fn sub(self, rhs: Tensor) -> Self::Output { - rhs.affine(-1., self) - } -} - -impl std::ops::Sub<&Tensor> for f64 { - type Output = Result; +macro_rules! impl_f64_tensor_ops { + ($trait:ident, $method:ident, $impl:expr) => { + impl std::ops::$trait for f64 { + type Output = Result; - fn sub(self, rhs: &Tensor) -> Self::Output { - rhs.affine(-1., self) - } -} + fn $method(self, rhs: Tensor) -> Self::Output { + $impl(self, &rhs) + } + } -impl std::ops::Div for f64 { - type Output = Result; + impl std::ops::$trait<&Tensor> for f64 { + type Output = Result; - #[allow(clippy::suspicious_arithmetic_impl)] - fn div(self, rhs: Tensor) -> Self::Output { - rhs.recip()? * self + fn $method(self, rhs: &Tensor) -> Self::Output { + $impl(self, rhs) + } + } } } -impl std::ops::Div<&Tensor> for f64 { - type Output = Result; - - #[allow(clippy::suspicious_arithmetic_impl)] - fn div(self, rhs: &Tensor) -> Self::Output { - rhs.recip()? * self - } -} +impl_f64_tensor_ops!(Add, add, |self_: f64, rhs: &Tensor| rhs + self_); +impl_f64_tensor_ops!(Sub, sub, |self_: f64, rhs: &Tensor| rhs.affine(-1., self_)); +impl_f64_tensor_ops!(Mul, mul, |self_: f64, rhs: &Tensor| rhs * self_); +impl_f64_tensor_ops!(Div, div, |self_: f64, rhs: &Tensor| rhs.recip()? * self_);