Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to regularize linear fits #17

Merged
merged 15 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 69 additions & 15 deletions hera_filters/dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,9 @@ def calc_width(filter_size, real_delta, nsamples):
lthresh = nsamples
return (uthresh, lthresh)

def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
filter_dims=1, skip_wgt=0.1, zero_residual_flags=True, **filter_kwargs):
def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode, ridge_alpha=0.0,
fit_intercept=False, filter_dims=1, skip_wgt=0.1, zero_residual_flags=True,
**filter_kwargs):
'''
A filtering function that wraps up all functionality of high_pass_fourier_filter
and add support for additional linear fitting options.
Expand Down Expand Up @@ -352,6 +353,19 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
'dpss_matrix' method (see above)
'dayenu_clean', apply dayenu filter to data. Deconvolve
subtracted foregrounds with 'clean'.
ridge_alpha: float, optional
Regularization parameter used in ridge regression. Default is 0, if value is equal to zero,
then no regularization is applied. If value is greater than zero, ridge_alpha is used as
the regularization parameter in ridge regression (specifically the main diagonal of the XTX product
is multiplied by a value of (1 + ridge_alpha)). Only used in the following linear modes
(dpss_leastsq, dft_leastsq, dpss_solve, dft_solve, dpss_matrix, dft_matrix). Reasonable values
for ridge_alpha when using the DPSS and DFT modes for inpainting wide gaps are between 1e-5 and 1e-2,
but will depend on factors such as the noise level in the data and the flagging mask.
fit_intercept: bool, optional
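If True, subtracts the weighted average of the data (weighted by wgts and taken along the
filtered dimensions) before fitting the model, and adds it back to the returned model.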
Default is False. Can be useful if the data is not centered around zero and
the user is fitting a regularized linear model (i.e if ridge_alpha > 0.0),
otherwise model will likely trend to zero in wide gaps.
zero_residual_flags : bool, optional.
If true, set flagged channels in the residual equal to zero.
Default is True.
Expand Down Expand Up @@ -479,6 +493,8 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
raise ValueError("data must be a 1D or 2D ndarray")
if not ndim_wgts == ndim_data:
raise ValueError("Number of dimensions in weights, %d does not equal number of dimensions in data, %d!"%(ndim_wgts, ndim_data))

assert ridge_alpha >= 0.0, "ridge_alpha must be greater than or equal to zero."
#The core code of this method will always assume 2d data
if ndim_data == 1:
data = np.asarray([data])
Expand Down Expand Up @@ -535,6 +551,12 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
else:
defaults = CLEAN_DEFAULTS_1D

if fit_intercept:
# subtract off mean of data
mean = np.sum(data * wgts, axis=tuple(filter_dims), keepdims=True) / np.sum(wgts, axis=tuple(filter_dims), keepdims=True)
data = np.copy(data) # make a copy so we don't modify the original data
data -= mean

_process_filter_kwargs(filter_kwargs, defaults)
if 'dft' in mode:
fp = np.asarray(filter_kwargs['fundamental_period']).flatten()
Expand Down Expand Up @@ -574,7 +596,7 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
skip_wgt=skip_wgt, basis=mode[1], method=mode[2], wgts=wgts, basis_options=filter_kwargs,
filter_half_widths=filter_half_widths, suppression_factors=suppression_factors,
cache=cache, max_contiguous_edge_flags=max_contiguous_edge_flags,
zero_residual_flags=zero_residual_flags)
zero_residual_flags=zero_residual_flags, ridge_alpha=ridge_alpha)
info['info_deconv']=info_deconv

elif mode[0] in ['dft', 'dpss']:
Expand All @@ -594,7 +616,7 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
skip_wgt=skip_wgt, basis=mode[0], method=mode[1], wgts=wgts, basis_options=filter_kwargs,
filter_half_widths=filter_half_widths, suppression_factors=suppression_factors,
cache=cache, max_contiguous_edge_flags=max_contiguous_edge_flags,
zero_residual_flags=zero_residual_flags)
zero_residual_flags=zero_residual_flags, ridge_alpha=ridge_alpha)
elif mode[0] == 'clean':
if zero_residual_flags is None:
zero_residual_flags = False
Expand Down Expand Up @@ -627,6 +649,13 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
if ndim_data == 1:
model = model.flatten()
residual = residual.flatten()

if fit_intercept:
if ndim_data == 1:
mean = mean.flatten()

model += mean # add back mean of data to the model

return model, residual, info

def vis_clean(data, wgts, filter_size, real_delta, clean2d=False, tol=1e-9, window='none',
Expand Down Expand Up @@ -1561,7 +1590,7 @@ def delay_filter_leastsq(data, flags, sigma, nmax, add_noise=False,

def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
basis_options, suppression_factors=None, hash_decimal=10,
method='leastsq', basis='dft', cache=None):
method='leastsq', basis='dft', cache=None, ridge_alpha=0.0):
r"""
A 1d linear-least-squares fitting function for computing models and residuals for fitting of the form
y_model = A @ c
Expand Down Expand Up @@ -1627,6 +1656,14 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
using scipy.optimize.leastsq
*'matrix' derive model by directly calculate the fitting matrix
[A^T W A]^{-1} A^T W and applying it to the y vector.
ridge_alpha: float, optional
Regularization parameter used in ridge regression. Default is 0, if value is equal to zero,
then no regularization is applied. If value is greater than zero, ridge_alpha is used as
the regularization parameter in ridge regression (specifically the main diagonal of the XTX product
is multiplied by a value of (1 + ridge_alpha)). Only used in the following linear modes
(dpss_leastsq, dft_leastsq, dpss_solve, dft_solve, dpss_matrix, dft_matrix). Reasonable values
for ridge_alpha when using the DPSS and DFT modes for inpainting wide gaps are between 1e-5 and 1e-2,
but will depend on factors such as the noise level in the data and the flagging mask.


Returns:
Expand Down Expand Up @@ -1689,7 +1726,7 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
x=x, hash_decimal=hash_decimal, label='covariance')
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis, mode=method)
label='fitting matrix', basis=basis, mode=method, ridge_alpha=ridge_alpha)
if square_key in cache:
covmat = cache[square_key]
else:
Expand All @@ -1698,6 +1735,7 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,

if not fm_key in cache:
XTX = covmat - np.conj(amat[flags]).T @ amat[flags]
XTX.flat[::XTX.shape[0] + 1] *= (1 + ridge_alpha) # add regularization term

Xy = np.conj(amat[mask]).T @ y[mask]

Expand Down Expand Up @@ -1736,9 +1774,10 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
raise ValueError("Provided 'method', '%s', is not in ['leastsq', 'matrix', 'solve', 'cholesky']."%(method))
else:
if method == 'leastsq':
a = np.atleast_2d(w).T * amat
a = np.dot((np.atleast_2d(w).T * amat).T.conj(), amat)
a.flat[::a.shape[0] + 1] *= (1 + ridge_alpha) # add regularization term
try:
res = lsq_linear(a, w * y)
res = lsq_linear(a, amat.T.conj().dot(w * y))
cn_out = res.x
# np.linalg.LinAlgError catches "SVD did not converge."
# which can happen if the solution is under-constrained.
Expand All @@ -1750,23 +1789,24 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
elif method == 'matrix':
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis)
label='fitting matrix', basis=basis, ridge_alpha=ridge_alpha)
if basis.lower() == 'dft':
fm_key = fm_key + (basis_options['fundamental_period'], )
elif basis.lower() == 'dpss':
fm_key = fm_key + tuple(nterms)
fmat = fit_solution_matrix(w, amat, cache=cache, fit_mat_key=fm_key)
fmat = fit_solution_matrix(w, amat, cache=cache, fit_mat_key=fm_key, ridge_alpha=ridge_alpha)
info['fitting_matrix'] = fmat
cn_out = fmat @ y

elif method == 'solve':
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis, mode=method)
label='fitting matrix', basis=basis, mode=method, alpha=ridge_alpha)
if fm_key in cache:
L = cache[fm_key]
else:
XTX = np.dot(np.conj(amat).T * w, amat)
XTX.flat[::XTX.shape[0] + 1] *= (1 + ridge_alpha) # add regularization term
L = linalg.lu_factor(XTX)
cache[fm_key] = L

Expand Down Expand Up @@ -1942,7 +1982,7 @@ def _clean_filter(x, data, wgts, filter_centers, filter_half_widths,
def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
basis_options, suppression_factors=None,
method='leastsq', basis='dft', cache=None,
filter_dims = 1, skip_wgt=0.1, max_contiguous_edge_flags=5,
filter_dims = 1, skip_wgt=0.1, max_contiguous_edge_flags=5, ridge_alpha=0.0,
zero_residual_flags=True):
r"""
A 1d linear-least-squares fitting function for computing models and residuals for fitting of the form
Expand Down Expand Up @@ -2024,6 +2064,15 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
zero_residual_flags : bool, optional.
If true, set flagged channels in the residual equal to zero.
Default is True.
ridge_alpha: float, optional
Regularization parameter used in ridge regression. Default is 0, if value is equal to zero,
then no regularization is applied. If value is greater than zero, ridge_alpha is used as
the regularization parameter in ridge regression (specifically the main diagonal of the XTX product
is multiplied by a value of (1 + ridge_alpha)). Only used in the following linear modes
(dpss_leastsq, dft_leastsq, dpss_solve, dft_solve, dpss_matrix, dft_matrix). Reasonable values
for ridge_alpha when using the DPSS and DFT modes for inpainting wide gaps are between 1e-5 and 1e-2,
but will depend on factors such as the noise level in the data and the flagging mask.

Returns
-------
model: array-like
Expand Down Expand Up @@ -2077,7 +2126,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
filter_half_widths=filter_half_widths[1],
suppression_factors=suppression_factors[1],
basis_options=basis_options[1], method=method,
basis=basis, cache=cache)
basis=basis, cache=cache, ridge_alpha=ridge_alpha)
if info_t['skipped']:
info['status']['axis_1'][i] = 'skipped'
else:
Expand Down Expand Up @@ -2107,7 +2156,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
filter_half_widths=filter_half_widths[0],
suppression_factors=suppression_factors[0],
basis_options=basis_options[0], method=method,
basis=basis, cache=cache)
basis=basis, cache=cache, ridge_alpha=ridge_alpha)
if info_t['skipped']:
info['status']['axis_0'][i] = 'skipped'
else:
Expand Down Expand Up @@ -2137,7 +2186,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
return model, residual, info


def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit_mat_key=None):
def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit_mat_key=None, ridge_alpha=0.0):
"""
Calculate the linear least squares solution matrix
from a design matrix, A and a weights matrix W
Expand All @@ -2156,6 +2205,9 @@ def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit
fit_mat_key: optional hashable variable
optional key. If none is used, hash fit matrix against design and
weighting matrix.
ridge_alpha: float, optional
Regularization parameter. If non-zero, the main diagonal of the
(design_matrix^H @ weights @ design_matrix) product is multiplied by
(1 + ridge_alpha). Default is 0.0.

Returns
-----------
Expand Down Expand Up @@ -2185,6 +2237,8 @@ def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit
xwmat = np.conj(design_matrix.T) @ weights
cmat = xwmat @ design_matrix

cmat.flat[::cmat.shape[0] + 1] *= (1 + ridge_alpha)

#should there be a conjugation!?!
if np.linalg.cond(cmat)>=1e9:
warn('Warning!!!!: Poorly conditioned matrix! Your linear inpainting IS WRONG!')
Expand Down
37 changes: 37 additions & 0 deletions hera_filters/tests/test_dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,43 @@ def get_snr(clean, fftax=1, avgax=0, modes=[2, 20]):
assert np.isclose(info_dft['filter_params']['axis_0']['basis_options']['fundamental_period'],
dft_options2_2d['fundamental_period'][0])

def test_regularized_regression():
nfreqs = 500
freqs = np.linspace(50e6, 250e6, nfreqs)

# Simulate some data
C = np.sinc(2 * (freqs[None] - freqs[:, None]) * 100e-9)
y = np.random.multivariate_normal(np.zeros(nfreqs), C) + 1j * np.random.multivariate_normal(np.zeros(nfreqs), C)
d = y + np.random.normal(0, 0.1, size=nfreqs) + np.random.normal(0, 0.1, size=nfreqs) * 1j

# Create a mask to simulate missing data
w = np.ones(nfreqs)
w[200 : 200 + int((1 / (700e-9 / 4)) / np.diff(freqs)[0] + 1)] -= 1

# Compare regularized regression to standard least squares
mdl_reg, res_reg, _ = dspec.fourier_filter(freqs, d, w, [0.], [700e-9], suppression_factors=[0.],
mode='dpss_solve', ridge_alpha=1e-3, eigenval_cutoff=[1e-12])
mdl, res, _ = dspec.fourier_filter(freqs, d, w, [0.], [700e-9], suppression_factors=[0.],
mode='dpss_solve', eigenval_cutoff=[1e-12])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we test the other solvers that use ridge alphas?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically matrix and leastsq


# Check that the regularized regression has a smaller residual norm in the flagged region
assert np.linalg.norm((d - mdl_reg)[~w.astype(bool)]) < np.linalg.norm((d - mdl)[~w.astype(bool)])

# Check that de-meaning the data improves interpolation
d += 20 + 20j

mdl_reg_demean, res_reg_demean, _ = dspec.fourier_filter(freqs, d, w, [0.], [700e-9], suppression_factors=[0.],
mode='dpss_solve', ridge_alpha=1e-3, eigenval_cutoff=[1e-12], fit_intercept=True)
mdl_reg, res_reg, _ = dspec.fourier_filter(freqs, d, w, [0.], [700e-9], suppression_factors=[0.],
mode='dpss_solve', ridge_alpha=1e-3, eigenval_cutoff=[1e-12])

# Check that the demeaned regularized regression has a smaller residual norm in the flagged region than
# the non-demeaned regularized regression. This is because ridge regression reduces the amplitude of the
# coefficients, leading to a near-zero mean in the flagged region, which can be a poor prediction of the
# inpainted region given non-zero mean data.
assert np.linalg.norm((d - mdl_reg_demean)[~w.astype(bool)]) < np.linalg.norm((d - mdl_reg)[~w.astype(bool)])


def test_vis_clean():
# validate that fourier_filter in various clean modes gives close values to vis_clean with equivalent parameters!
uvd = UVData()
Expand Down
Loading