From c4f3c53f6ee9f3f2f95d8825bb826e532c4184cb Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 2 Oct 2024 17:14:57 -0400 Subject: [PATCH 1/2] First commit --- src/nemos/basis.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 2cc48f95..3bf13b50 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -185,6 +185,7 @@ def _unpack_inputs(X: FeatureMatrix): A tuple of each individual input. """ + return (X[:, k] for k in range(X.shape[1])) def fit(self, X: FeatureMatrix, y=None): @@ -934,21 +935,41 @@ def __add__(self, other: Basis) -> AdditiveBasis: """ return AdditiveBasis(self, other) - def __mul__(self, other: Basis) -> MultiplicativeBasis: + def __len__(self) -> int: + """ + + Returns + ------- + : int + Number of basis functions. + """ + return self.n_basis_funcs + + def __mul__(self, other: (Basis, int)) -> MultiplicativeBasis: """ Multiply two Basis objects together. Parameters ---------- other - The other Basis object to multiply. + The other Basis object to multiply or integer Returns ------- : The resulting Basis object. """ - return MultiplicativeBasis(self, other) + if isinstance(other, Basis): + return MultiplicativeBasis(self, other) + elif isinstance(other, int): + if other <= 0: + raise ValueError("Multiplier should be a non-negative integer!") + result = self + for _ in range(other-1): + result = result + self + return result + else: + raise TypeError("Basis can only be multiplied with another basis or an integer!") def __pow__(self, exponent: int) -> MultiplicativeBasis: """Exponentiation of a Basis object. @@ -1044,6 +1065,14 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None: self._basis2 = basis2 return + @property + def basis1(self): + return self._basis1 + + @property + def basis2(self): + return self._basis2 + def _check_n_basis_min(self) -> None: pass From a8c8e3683cd589e39540892106c1dda8b7ac8555 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 10 Oct 2024 16:36:51 -0400 Subject: [PATCH 2/2] Adding tests --- docs/how_to_guide/plot_04_population_glm.py | 5 +- docs/how_to_guide/plot_05_batch_glm.py | 5 +- .../plot_06_sklearn_pipeline_cv_demo.py | 9 ++- src/nemos/basis.py | 10 ++-- tests/test_basis.py | 59 +++++++++++++++++++ 5 files changed, 75 insertions(+), 13 deletions(-) diff --git a/docs/how_to_guide/plot_04_population_glm.py b/docs/how_to_guide/plot_04_population_glm.py index 84282477..70dac9cd 100644 --- a/docs/how_to_guide/plot_04_population_glm.py +++ b/docs/how_to_guide/plot_04_population_glm.py @@ -23,9 +23,10 @@ """ import jax.numpy as jnp -import nemos as nmo -import numpy as np import matplotlib.pyplot as plt +import numpy as np + +import nemos as nmo np.random.seed(123) diff --git a/docs/how_to_guide/plot_05_batch_glm.py b/docs/how_to_guide/plot_05_batch_glm.py index f9e758fc..84f64d98 100644 --- a/docs/how_to_guide/plot_05_batch_glm.py +++ b/docs/how_to_guide/plot_05_batch_glm.py @@ -6,10 +6,11 @@ """ +import matplotlib.pyplot as plt +import numpy as np import pynapple as nap + import nemos as nmo -import numpy as np -import matplotlib.pyplot as plt nap.nap_config.suppress_conversion_warnings = True diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py index b7168e33..ca9b167a 100644 --- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py +++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py @@ -71,20 +71,19 @@ # ## Combining basis transformations and GLM in a pipeline # Let's start by creating some toy data. -import nemos as nmo +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scipy.stats -import matplotlib.pyplot as plt import seaborn as sns - -from sklearn.pipeline import Pipeline from sklearn.model_selection import GridSearchCV +from sklearn.pipeline import Pipeline + +import nemos as nmo # some helper plotting functions from nemos import _documentation_utils as doc_plots - # predictors, shape (n_samples, n_features) X = np.random.uniform(low=0, high=1, size=(1000, 1)) # observed counts, shape (n_samples,) diff --git a/src/nemos/basis.py b/src/nemos/basis.py index 3bf13b50..97c099a7 100644 --- a/src/nemos/basis.py +++ b/src/nemos/basis.py @@ -185,7 +185,6 @@ def _unpack_inputs(X: FeatureMatrix): A tuple of each individual input. """ - return (X[:, k] for k in range(X.shape[1])) def fit(self, X: FeatureMatrix, y=None): @@ -947,7 +946,8 @@ def __len__(self) -> int: def __mul__(self, other: (Basis, int)) -> MultiplicativeBasis: """ - Multiply two Basis objects together. + Multiply two Basis objects together or replicate the basis + by multiplying it with an integer. Parameters ---------- @@ -965,11 +965,13 @@ def __mul__(self, other: (Basis, int)) -> MultiplicativeBasis: if other <= 0: raise ValueError("Multiplier should be a non-negative integer!") result = self - for _ in range(other-1): + for _ in range(other - 1): result = result + self return result else: - raise TypeError("Basis can only be multiplied with another basis or an integer!") + raise TypeError( + "Basis can only be multiplied with another basis or an integer!" + ) def __pow__(self, exponent: int) -> MultiplicativeBasis: """Exponentiation of a Basis object. diff --git a/tests/test_basis.py b/tests/test_basis.py index 6e81d142..7bc3d797 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -3202,6 +3202,65 @@ def test_compute_features_input(self, eval_input): basis_obj = basis.MSplineBasis(5) + basis.MSplineBasis(5) basis_obj.compute_features(*eval_input) + @pytest.mark.parametrize("n_basis_a", [5, 6]) + @pytest.mark.parametrize("n_basis_b", [5, 6]) + @pytest.mark.parametrize("basis_a", list_all_basis_classes()) + @pytest.mark.parametrize("basis_b", list_all_basis_classes()) + def test_len( + self, n_basis_a, n_basis_b, basis_a, basis_b, + ): + """ + Test for __len__ of basis + """ + # define the two basis + basis_a_obj = self.instantiate_basis( + n_basis_a, basis_a, mode="eval" + ) + basis_b_obj = self.instantiate_basis( + n_basis_b, basis_b, mode="eval" + ) + + basis_obj = basis_a_obj + basis_b_obj + + assert hasattr(basis_obj, "__len__") + assert len(basis_a_obj) == basis_a_obj.n_basis_funcs + assert len(basis_b_obj) == basis_b_obj.n_basis_funcs + assert len(basis_obj) == basis_obj.n_basis_funcs + + @pytest.mark.parametrize("n", [1, 6]) + @pytest.mark.parametrize("basis", list_all_basis_classes()) + def test_basis_multiply_with_integer( + self, n, basis, + ): + """ + Test for __mul__ of basis with integer + """ + # define the two basis + basis_obj = self.instantiate_basis( + 5, basis, mode="eval" + ) + new_basis_obj = basis_obj * n + + assert new_basis_obj.n_basis_funcs == n * basis_obj.n_basis_funcs + + @pytest.mark.parametrize("basis", list_all_basis_classes()) + @pytest.mark.parametrize("n, expected", [ + (-2, pytest.raises(ValueError, match=r"Multiplier should be a non-negative integer!")), + ("6", pytest.raises(TypeError, match=r"Basis can only be multiplied with another basis or an integer!")) + ]) + def test_basis_multiply_errors( + self, basis, n, expected + ): + """ + Test for __mul__ of basis. raise errors + """ + # define the two basis + basis_obj = self.instantiate_basis( + 5, basis, mode="eval" + ) + with expected: + basis_obj * n + @pytest.mark.parametrize("n_basis_a", [5, 6]) @pytest.mark.parametrize("n_basis_b", [5, 6]) @pytest.mark.parametrize("sample_size", [10, 1000])