Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix transformer #235

Open
wants to merge 3 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/how_to_guide/plot_04_population_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
"""

import jax.numpy as jnp
import nemos as nmo
import numpy as np
import matplotlib.pyplot as plt
import numpy as np

import nemos as nmo

np.random.seed(123)

Expand Down
5 changes: 3 additions & 2 deletions docs/how_to_guide/plot_05_batch_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

"""

import matplotlib.pyplot as plt
import numpy as np
import pynapple as nap

import nemos as nmo
import numpy as np
import matplotlib.pyplot as plt

nap.nap_config.suppress_conversion_warnings = True

Expand Down
9 changes: 4 additions & 5 deletions docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,19 @@
# ## Combining basis transformations and GLM in a pipeline
# Let's start by creating some toy data.

import nemos as nmo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

import nemos as nmo

# some helper plotting functions
from nemos import _documentation_utils as doc_plots


# predictors, shape (n_samples, n_features)
X = np.random.uniform(low=0, high=1, size=(1000, 1))
# observed counts, shape (n_samples,)
Expand Down
39 changes: 35 additions & 4 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,21 +934,44 @@ def __add__(self, other: Basis) -> AdditiveBasis:
"""
return AdditiveBasis(self, other)

def __mul__(self, other: Basis) -> MultiplicativeBasis:
def __len__(self) -> int:
"""
Multiply two Basis objects together.

Returns
-------
: int
Number of basis functions.
"""
return self.n_basis_funcs

def __mul__(self, other: (Basis, int)) -> MultiplicativeBasis:
"""
Multiply two Basis objects together or replicate the basis
by multiplying it with an integer.

Parameters
----------
other
The other Basis object to multiply.
The other Basis object to multiply or integer

Returns
-------
:
The resulting Basis object.
"""
return MultiplicativeBasis(self, other)
if isinstance(other, Basis):
return MultiplicativeBasis(self, other)
elif isinstance(other, int):
if other <= 0:
raise ValueError("Multiplier should be a non-negative integer!")
result = self
for _ in range(other - 1):
result = result + self
return result
else:
raise TypeError(
"Basis can only be multiplied with another basis or an integer!"
)

def __pow__(self, exponent: int) -> MultiplicativeBasis:
"""Exponentiation of a Basis object.
Expand Down Expand Up @@ -1044,6 +1067,14 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None:
self._basis2 = basis2
return

@property
def basis1(self):
return self._basis1

@property
def basis2(self):
return self._basis2

def _check_n_basis_min(self) -> None:
pass

Expand Down
59 changes: 59 additions & 0 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3202,6 +3202,65 @@ def test_compute_features_input(self, eval_input):
basis_obj = basis.MSplineBasis(5) + basis.MSplineBasis(5)
basis_obj.compute_features(*eval_input)

@pytest.mark.parametrize("n_basis_a", [5, 6])
@pytest.mark.parametrize("n_basis_b", [5, 6])
@pytest.mark.parametrize("basis_a", list_all_basis_classes())
@pytest.mark.parametrize("basis_b", list_all_basis_classes())
def test_len(
self, n_basis_a, n_basis_b, basis_a, basis_b,
):
"""
Test for __len__ of basis
"""
# define the two basis
basis_a_obj = self.instantiate_basis(
n_basis_a, basis_a, mode="eval"
)
basis_b_obj = self.instantiate_basis(
n_basis_b, basis_b, mode="eval"
)

basis_obj = basis_a_obj + basis_b_obj

assert hasattr(basis_obj, "__len__")
assert len(basis_a_obj) == basis_a_obj.n_basis_funcs
assert len(basis_b_obj) == basis_b_obj.n_basis_funcs
assert len(basis_obj) == basis_obj.n_basis_funcs

@pytest.mark.parametrize("n", [1, 6])
@pytest.mark.parametrize("basis", list_all_basis_classes())
def test_basis_multiply_with_integer(
self, n, basis,
):
"""
Test for __mul__ of basis with integer
"""
# define the two basis
basis_obj = self.instantiate_basis(
5, basis, mode="eval"
)
new_basis_obj = basis_obj * n

assert new_basis_obj.n_basis_funcs == n * basis_obj.n_basis_funcs

@pytest.mark.parametrize("basis", list_all_basis_classes())
@pytest.mark.parametrize("n, expected", [
(-2, pytest.raises(ValueError, match=r"Multiplier should be a non-negative integer!")),
("6", pytest.raises(TypeError, match=r"Basis can only be multiplied with another basis or an integer!"))
])
def test_basis_multiply_errors(
self, basis, n, expected
):
"""
Test for __mul__ of basis. raise errors
"""
# define the two basis
basis_obj = self.instantiate_basis(
5, basis, mode="eval"
)
with expected:
basis_obj * n

@pytest.mark.parametrize("n_basis_a", [5, 6])
@pytest.mark.parametrize("n_basis_b", [5, 6])
@pytest.mark.parametrize("sample_size", [10, 1000])
Expand Down
Loading