Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fourier #49

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion docs/developers_notes/01-basis_module.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ Abstract Class Basis
│ │
│ └─ Concrete Subclass RaisedCosineBasisLog
└─ Concrete Subclass OrthExponentialBasis
├─ Concrete Subclass OrthExponentialBasis
└─ Concrete Subclass FourierBasis
```

The super-class `Basis` provides two public methods, [`evaluate`](#the-public-method-evaluate) and [`evaluate_on_grid`](#the-public-method-evaluate_on_grid). These methods perform checks on both the input provided by the user and the output of the evaluation to ensure correctness, and are thus considered "safe". They both make use of the private abstract method `_evaluate` that is specific for each concrete class. See below for more details.
Expand Down
90 changes: 88 additions & 2 deletions docs/examples/plot_1D_basis_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@
# -----------------
# Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
# please refer to the [Code References](../../../reference/nemos/basis). After instantiation, all classes
# share the same syntax for basis evaluation. The following is an example of how to instantiate and
# evaluate a log-spaced cosine raised function basis.
# share the same syntax for basis evaluation.
#
# ### The Log-Spaced Raised Cosine Basis
# The following is an example of how to instantiate and evaluate a log-spaced cosine raised function basis.

# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter
raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10)
Expand All @@ -81,3 +83,87 @@
plt.plot(samples, eval_basis)
plt.show()

# %%
# ### The Fourier Basis
# Another type of basis available is the Fourier Basis. Fourier basis are ideal to capture periodic and
# quasi-periodic patterns. Such oscillatory, rhythmic behavior is a common signature of many neural signals.
# Additionally, the Fourier basis has the advantage of being orthogonal, which simplifies the estimation and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably explain orthogonal here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still think this. at least a foot note or link

# interpretation of the model parameters, each of which will represent the relative contribution of a specific
# oscillation frequency to the overall signal.


# A Fourier basis can be instantiated with the usual syntax.
billbrod marked this conversation as resolved.
Show resolved Hide resolved
# The user can pass the desired frequencies for the basis or
# the frequencies will be set to np.arange(n_basis_funcs//2).
billbrod marked this conversation as resolved.
Show resolved Hide resolved
# The number of basis function is required to be even.
fourier_basis = nmo.basis.FourierBasis(n_freqs=4)

# evaluate on equi-spaced samples
samples, eval_basis = fourier_basis.evaluate_on_grid(1000)

# plot the `sin` and `cos` separately
plt.figure(figsize=(6, 3))
plt.subplot(121)
plt.title("Cos")
plt.plot(samples, eval_basis[:, :4])
plt.subplot(122)
plt.title("Sin")
plt.plot(samples, eval_basis[:, 4:])
plt.tight_layout()

# %%
# !!! note "Fourier basis convolution and Fourier transform"
billbrod marked this conversation as resolved.
Show resolved Hide resolved
# The Fourier transform of a signal $ s(t) $ restricted to a temporal window $ [t_0,\;t_1] $ is
# $$ \\hat{x}(\\omega) = \\int_{t_0}^{t_1} s(\\tau) e^{-j\\omega \\tau} d\\tau. $$
# where $ e^{-j\\omega \\tau} = \\cos(\\omega \\tau) - j \\sin (\\omega \\tau) $.
#
# When computing the cross-correlation of a signal with the Fourier basis functions,
# we essentially measure how well the signal correlates with sinusoids of different frequencies,
# within a specified temporal window. This process mirrors the operation performed by the Fourier transform.
# Therefore, it becomes clear that computing the cross-correlation of a signal with the Fourier basis defined here
BalzaniEdoardo marked this conversation as resolved.
Show resolved Hide resolved
# is equivalent to computing the discrete Fourier transform on a sliding window of the same size
# as that of the basis.

n_samples = 1000
n_freqs = 20

# define a signal
signal = np.random.normal(size=n_samples)

# evaluate the basis
_, eval_basis = nmo.basis.FourierBasis(n_freqs=n_freqs).evaluate_on_grid(n_samples)

# compute the cross-corr with the signal and the basis
# Note that we are inverting the time axis of the basis because we are aiming
# for a cross-correlation, while np.convolve compute a convolution which would flip the time axis.
Comment on lines +139 to +140
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we just compute the correlation directly to avoid this confusion? It's true, but provides an extra hurdle for folks (and then we could call out this equivalency in an admonition)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still think this

xcorr = np.array(
[
np.convolve(eval_basis[::-1, k], signal, mode="valid")[0]
for k in range(2 * n_freqs - 1)
]
)

# compute the power (add back sin(0 * t) = 0)
fft_complex = np.fft.fft(signal)
fft_amplitude = np.abs(fft_complex[:n_freqs])
fft_phase = np.angle(fft_complex[:n_freqs])
# compute the phase and amplitude from the convolution
xcorr_phase = np.arctan2(np.hstack([[0], xcorr[n_freqs:]]), xcorr[:n_freqs])
xcorr_aplitude = np.sqrt(xcorr[:n_freqs] ** 2 + np.hstack([[0], xcorr[n_freqs:]]) ** 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably worth explaining why this is the same as computing the phase and amplitude

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably involve a bit more describing why we look at the phase and amplitude... Something like

"The Fourier transform is often interpreted by looking at its amplitude and phase, which is the polar representation: amplitude is the distance from the origin, while phase is the angle. The amplitude conveys how much "stuff" is present at a given frequency, whereas the phase describes its alignment [or something??]. Thus, to compute the amplitude we ... and to compute the phase we ..."

I'm sure Eero has a good intuitive explanation of this somewhere, maybe in the math tools notes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still think this


fig, ax = plt.subplots(1, 2)
ax[0].set_aspect("equal")
ax[0].set_title("Signal amplitude")
ax[0].scatter(fft_amplitude, xcorr_aplitude)
ax[0].set_xlabel("FFT")
ax[0].set_ylabel("cross-correlation")

ax[1].set_aspect("equal")
ax[1].set_title("Signal phase")
ax[1].scatter(fft_phase, xcorr_phase)
ax[1].set_xlabel("FFT")
ax[1].set_ylabel("cross-correlation")
plt.tight_layout()

print(f"Max Error Amplitude: {np.abs(fft_amplitude - xcorr_aplitude).max()}")
print(f"Max Error Phase: {np.abs(fft_phase - xcorr_phase).max()}")
10 changes: 1 addition & 9 deletions src/nemos/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
#!/usr/bin/env python3

from . import (
basis,
exceptions,
glm,
observation_models,
regularizer,
simulation,
utils,
)
from . import basis, exceptions, glm, observation_models, regularizer, simulation, utils
92 changes: 90 additions & 2 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"OrthExponentialBasis",
"AdditiveBasis",
"MultiplicativeBasis",
"FourierBasis",
]


Expand Down Expand Up @@ -103,7 +104,7 @@ def _check_evaluate_input(self, *xi: ArrayLike) -> Tuple[NDArray]:
# make sure array is at least 1d (so that we succeed when only
# passed a scalar)
xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi)
except TypeError:
except (TypeError, ValueError):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what additionally is being caught here?

raise TypeError("Input samples must be array-like of floats!")

# check for non-empty samples
Expand Down Expand Up @@ -1024,7 +1025,8 @@ def _check_rates(self):
"linearly dependent set of function for the basis."
)

def _check_sample_range(self, sample_pts: NDArray):
@staticmethod
def _check_sample_range(sample_pts: NDArray):
"""
Check if the sample points are all positive.

Expand Down Expand Up @@ -1115,6 +1117,92 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
return super().evaluate_on_grid(n_samples)


class FourierBasis(Basis):
"""Set of 1D Fourier basis.

Parameters
----------
n_freqs
Number of frequencies. The number of basis function will be 2*n_freqs - 1.
"""

def __init__(self, n_freqs: int):
super().__init__(n_basis_funcs=2 * n_freqs - 1)
billbrod marked this conversation as resolved.
Show resolved Hide resolved

self._frequencies = np.arange(n_freqs, dtype=np.float32)
self._n_input_dimensionality = 1

def _check_n_basis_min(self) -> None:
"""Check that the user required enough basis elements.

Checks that the number of basis is at least 1.

Raises
------
ValueError
If an insufficient number of basis element is requested for the basis type
"""
if self.n_basis_funcs < 1:
raise ValueError(
f"Object class {self.__class__.__name__} requires >= 1 basis elements. "
f"{self.n_basis_funcs} basis elements specified instead"
)

def evaluate(self, sample_pts: NDArray) -> NDArray:
"""Generate basis functions with given spacing.

Parameters
----------
sample_pts
Spacing for basis functions.

Returns
-------
basis_funcs
Evaluated Fourier basis, shape (n_samples, n_basis_funcs).

Notes
-----
If the frequencies provided are np.arange(n_freq), convolving a signal
of length n_samples with this basis is equivalent, but slower,
then computing the FFT truncated to the first n_freq components.

Therefore, convolving a signal with this basis is equivalent
to compute the FFT over sliding window.
billbrod marked this conversation as resolved.
Show resolved Hide resolved

Examples
--------
>>> import nemos as nmo
>>> import numpy as np
>>> n_samples, n_freqs = 1000, 10
>>> basis = nmo.basis.FourierBasis(n_freqs*2)
>>> eval_basis = basis.evaluate(np.linspace(0, 1, n_samples))
>>> sinusoid = np.cos(3 * np.arange(0, 1000) * np.pi * 2 / 1000.)
>>> conv = [np.convolve(eval_basis[::-1, k], sinusoid, mode='valid')[0] for k in range(2*n_freqs-1)]
>>> fft = np.fft.fft(sinusoid)
billbrod marked this conversation as resolved.
Show resolved Hide resolved
>>> print('FFT power: ', np.round(np.real(fft[:10]), 4))
>>> print('Convolution: ', np.round(conv[:10], 4))
billbrod marked this conversation as resolved.
Show resolved Hide resolved
"""
(sample_pts,) = self._check_evaluate_input(sample_pts)
# assumes equi-spaced samples.
if sample_pts.shape[0] / np.max(self._frequencies) < 2:
raise ValueError("Not enough samples, aliasing likely to occur!")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe report sample_pts.shape[0] and max(self._frequencies) to the user?


# rescale to [0, 2pi)
mn, mx = np.nanmin(sample_pts), np.nanmax(sample_pts)
# first sample in 0, last sample in 2 pi - 2 pi / n_samples.
sample_pts = (
2
* np.pi
* (sample_pts - mn)
/ (mx - mn)
* (1.0 - 1.0 / sample_pts.shape[0])
)
# create the basis
angles = np.einsum("i,j->ij", sample_pts, self._frequencies)
return np.concatenate([np.cos(angles), -np.sin(angles[:, 1:])], axis=1)


def mspline(x: NDArray, k: int, i: int, T: NDArray):
"""Compute M-spline basis function.

Expand Down
Loading