-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add fourier #49
base: main
Are you sure you want to change the base?
Add fourier #49
Changes from 6 commits
5e58d74
290f223
0a57674
8ebf2dc
2908a5c
497b31f
332d555
17a05a9
00aaa95
e94b606
aa67d9c
adb5efb
a62712e
07bd9f6
b704c98
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,8 +65,10 @@ | |
# ----------------- | ||
# Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description, | ||
# please refer to the [Code References](../../../reference/nemos/basis). After instantiation, all classes | ||
# share the same syntax for basis evaluation. The following is an example of how to instantiate and | ||
# evaluate a log-spaced cosine raised function basis. | ||
# share the same syntax for basis evaluation. | ||
# | ||
# ### The Log-Spaced Raised Cosine Basis | ||
# The following is an example of how to instantiate and evaluate a log-spaced cosine raised function basis. | ||
|
||
# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter | ||
raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10) | ||
|
@@ -81,3 +83,87 @@ | |
plt.plot(samples, eval_basis) | ||
plt.show() | ||
|
||
# %% | ||
# ### The Fourier Basis | ||
# Another type of basis available is the Fourier Basis. Fourier basis are ideal to capture periodic and | ||
# quasi-periodic patterns. Such oscillatory, rhythmic behavior is a common signature of many neural signals. | ||
# Additionally, the Fourier basis has the advantage of being orthogonal, which simplifies the estimation and | ||
# interpretation of the model parameters, each of which will represent the relative contribution of a specific | ||
# oscillation frequency to the overall signal. | ||
|
||
|
||
# A Fourier basis can be instantiated with the usual syntax. | ||
billbrod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# The user can pass the desired frequencies for the basis or | ||
# the frequencies will be set to np.arange(n_basis_funcs//2). | ||
billbrod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# The number of basis function is required to be even. | ||
fourier_basis = nmo.basis.FourierBasis(n_freqs=4) | ||
|
||
# evaluate on equi-spaced samples | ||
samples, eval_basis = fourier_basis.evaluate_on_grid(1000) | ||
|
||
# plot the `sin` and `cos` separately | ||
plt.figure(figsize=(6, 3)) | ||
plt.subplot(121) | ||
plt.title("Cos") | ||
plt.plot(samples, eval_basis[:, :4]) | ||
plt.subplot(122) | ||
plt.title("Sin") | ||
plt.plot(samples, eval_basis[:, 4:]) | ||
plt.tight_layout() | ||
|
||
# %% | ||
# !!! note "Fourier basis convolution and Fourier transform" | ||
billbrod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# The Fourier transform of a signal $ s(t) $ restricted to a temporal window $ [t_0,\;t_1] $ is | ||
# $$ \\hat{x}(\\omega) = \\int_{t_0}^{t_1} s(\\tau) e^{-j\\omega \\tau} d\\tau. $$ | ||
# where $ e^{-j\\omega \\tau} = \\cos(\\omega \\tau) - j \\sin (\\omega \\tau) $. | ||
# | ||
# When computing the cross-correlation of a signal with the Fourier basis functions, | ||
# we essentially measure how well the signal correlates with sinusoids of different frequencies, | ||
# within a specified temporal window. This process mirrors the operation performed by the Fourier transform. | ||
# Therefore, it becomes clear that computing the cross-correlation of a signal with the Fourier basis defined here | ||
BalzaniEdoardo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# is equivalent to computing the discrete Fourier transform on a sliding window of the same size | ||
# as that of the basis. | ||
|
||
n_samples = 1000 | ||
n_freqs = 20 | ||
|
||
# define a signal | ||
signal = np.random.normal(size=n_samples) | ||
|
||
# evaluate the basis | ||
_, eval_basis = nmo.basis.FourierBasis(n_freqs=n_freqs).evaluate_on_grid(n_samples) | ||
|
||
# compute the cross-corr with the signal and the basis | ||
# Note that we are inverting the time axis of the basis because we are aiming | ||
# for a cross-correlation, while np.convolve compute a convolution which would flip the time axis. | ||
Comment on lines
+139
to
+140
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we just compute the correlation directly to avoid this confusion? It's true, but provides an extra hurdle for folks (and then we could call out this equivalency in an admonition) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still think this |
||
xcorr = np.array( | ||
[ | ||
np.convolve(eval_basis[::-1, k], signal, mode="valid")[0] | ||
for k in range(2 * n_freqs - 1) | ||
] | ||
) | ||
|
||
# compute the power (add back sin(0 * t) = 0) | ||
fft_complex = np.fft.fft(signal) | ||
fft_amplitude = np.abs(fft_complex[:n_freqs]) | ||
fft_phase = np.angle(fft_complex[:n_freqs]) | ||
# compute the phase and amplitude from the convolution | ||
xcorr_phase = np.arctan2(np.hstack([[0], xcorr[n_freqs:]]), xcorr[:n_freqs]) | ||
xcorr_aplitude = np.sqrt(xcorr[:n_freqs] ** 2 + np.hstack([[0], xcorr[n_freqs:]]) ** 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. probably worth explaining why this is the same as computing the phase and amplitude There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would probably involve a bit more describing why we look at the phase and amplitude... Something like "The Fourier transform is often interpreted by looking at its amplitude and phase, which is the polar representation: amplitude is the distance from the origin, while phase is the angle. The amplitude conveys how much "stuff" is present at a given frequency, whereas the phase describes its alignment [or something??]. Thus, to compute the amplitude we ... and to compute the phase we ..." I'm sure Eero has a good intuitive explanation of this somewhere, maybe in the math tools notes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still think this |
||
|
||
fig, ax = plt.subplots(1, 2) | ||
ax[0].set_aspect("equal") | ||
ax[0].set_title("Signal amplitude") | ||
ax[0].scatter(fft_amplitude, xcorr_aplitude) | ||
ax[0].set_xlabel("FFT") | ||
ax[0].set_ylabel("cross-correlation") | ||
|
||
ax[1].set_aspect("equal") | ||
ax[1].set_title("Signal phase") | ||
ax[1].scatter(fft_phase, xcorr_phase) | ||
ax[1].set_xlabel("FFT") | ||
ax[1].set_ylabel("cross-correlation") | ||
plt.tight_layout() | ||
|
||
print(f"Max Error Amplitude: {np.abs(fft_amplitude - xcorr_aplitude).max()}") | ||
print(f"Max Error Phase: {np.abs(fft_phase - xcorr_phase).max()}") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,3 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from . import ( | ||
basis, | ||
exceptions, | ||
glm, | ||
observation_models, | ||
regularizer, | ||
simulation, | ||
utils, | ||
) | ||
from . import basis, exceptions, glm, observation_models, regularizer, simulation, utils |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
"OrthExponentialBasis", | ||
"AdditiveBasis", | ||
"MultiplicativeBasis", | ||
"FourierBasis", | ||
] | ||
|
||
|
||
|
@@ -103,7 +104,7 @@ def _check_evaluate_input(self, *xi: ArrayLike) -> Tuple[NDArray]: | |
# make sure array is at least 1d (so that we succeed when only | ||
# passed a scalar) | ||
xi = tuple(np.atleast_1d(np.asarray(x, dtype=float)) for x in xi) | ||
except TypeError: | ||
except (TypeError, ValueError): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what additionally is being caught here? |
||
raise TypeError("Input samples must be array-like of floats!") | ||
|
||
# check for non-empty samples | ||
|
@@ -1024,7 +1025,8 @@ def _check_rates(self): | |
"linearly dependent set of function for the basis." | ||
) | ||
|
||
def _check_sample_range(self, sample_pts: NDArray): | ||
@staticmethod | ||
def _check_sample_range(sample_pts: NDArray): | ||
""" | ||
Check if the sample points are all positive. | ||
|
||
|
@@ -1115,6 +1117,92 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]: | |
return super().evaluate_on_grid(n_samples) | ||
|
||
|
||
class FourierBasis(Basis): | ||
"""Set of 1D Fourier basis. | ||
|
||
Parameters | ||
---------- | ||
n_freqs | ||
Number of frequencies. The number of basis function will be 2*n_freqs - 1. | ||
""" | ||
|
||
def __init__(self, n_freqs: int): | ||
super().__init__(n_basis_funcs=2 * n_freqs - 1) | ||
billbrod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self._frequencies = np.arange(n_freqs, dtype=np.float32) | ||
self._n_input_dimensionality = 1 | ||
|
||
def _check_n_basis_min(self) -> None: | ||
"""Check that the user required enough basis elements. | ||
|
||
Checks that the number of basis is at least 1. | ||
|
||
Raises | ||
------ | ||
ValueError | ||
If an insufficient number of basis element is requested for the basis type | ||
""" | ||
if self.n_basis_funcs < 1: | ||
raise ValueError( | ||
f"Object class {self.__class__.__name__} requires >= 1 basis elements. " | ||
f"{self.n_basis_funcs} basis elements specified instead" | ||
) | ||
|
||
def evaluate(self, sample_pts: NDArray) -> NDArray: | ||
"""Generate basis functions with given spacing. | ||
|
||
Parameters | ||
---------- | ||
sample_pts | ||
Spacing for basis functions. | ||
|
||
Returns | ||
------- | ||
basis_funcs | ||
Evaluated Fourier basis, shape (n_samples, n_basis_funcs). | ||
|
||
Notes | ||
----- | ||
If the frequencies provided are np.arange(n_freq), convolving a signal | ||
of length n_samples with this basis is equivalent, but slower, | ||
then computing the FFT truncated to the first n_freq components. | ||
|
||
Therefore, convolving a signal with this basis is equivalent | ||
to compute the FFT over sliding window. | ||
billbrod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Examples | ||
-------- | ||
>>> import nemos as nmo | ||
>>> import numpy as np | ||
>>> n_samples, n_freqs = 1000, 10 | ||
>>> basis = nmo.basis.FourierBasis(n_freqs*2) | ||
>>> eval_basis = basis.evaluate(np.linspace(0, 1, n_samples)) | ||
>>> sinusoid = np.cos(3 * np.arange(0, 1000) * np.pi * 2 / 1000.) | ||
>>> conv = [np.convolve(eval_basis[::-1, k], sinusoid, mode='valid')[0] for k in range(2*n_freqs-1)] | ||
>>> fft = np.fft.fft(sinusoid) | ||
billbrod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
>>> print('FFT power: ', np.round(np.real(fft[:10]), 4)) | ||
>>> print('Convolution: ', np.round(conv[:10], 4)) | ||
billbrod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
(sample_pts,) = self._check_evaluate_input(sample_pts) | ||
# assumes equi-spaced samples. | ||
if sample_pts.shape[0] / np.max(self._frequencies) < 2: | ||
raise ValueError("Not enough samples, aliasing likely to occur!") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe report |
||
|
||
# rescale to [0, 2pi) | ||
mn, mx = np.nanmin(sample_pts), np.nanmax(sample_pts) | ||
# first sample in 0, last sample in 2 pi - 2 pi / n_samples. | ||
sample_pts = ( | ||
2 | ||
* np.pi | ||
* (sample_pts - mn) | ||
/ (mx - mn) | ||
* (1.0 - 1.0 / sample_pts.shape[0]) | ||
) | ||
# create the basis | ||
angles = np.einsum("i,j->ij", sample_pts, self._frequencies) | ||
return np.concatenate([np.cos(angles), -np.sin(angles[:, 1:])], axis=1) | ||
|
||
|
||
def mspline(x: NDArray, k: int, i: int, T: NDArray): | ||
"""Compute M-spline basis function. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should probably explain orthogonal here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
still think this. at least a foot note or link