Add option to regularize linear fits (#17)
* Add option to regularize linear fits

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add new tests for regularization

* Add option to de-mean the data when regularizing

* Add regularization to leastsq mode

* Add tests for de-meaning data

* More documentation for ridge_alpha parameter

* Correct documentation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Exaggerate vertical offset to pass tests

* Remove leftover print statement

* Check that reg for all modes gives same answer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
tyler-a-cox and pre-commit-ci[bot] authored Aug 23, 2023
1 parent a4ff591 commit 5c8c10f
Showing 2 changed files with 118 additions and 15 deletions.
84 changes: 69 additions & 15 deletions hera_filters/dspec.py
@@ -270,8 +270,9 @@ def calc_width(filter_size, real_delta, nsamples):
lthresh = nsamples
return (uthresh, lthresh)

def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
filter_dims=1, skip_wgt=0.1, zero_residual_flags=True, **filter_kwargs):
def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode, ridge_alpha=0.0,
fit_intercept=False, filter_dims=1, skip_wgt=0.1, zero_residual_flags=True,
**filter_kwargs):
'''
A filtering function that wraps up all functionality of high_pass_fourier_filter
and adds support for additional linear fitting options.
@@ -352,6 +353,19 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
'dpss_matrix' method (see above)
'dayenu_clean', apply dayenu filter to data. Deconvolve
subtracted foregrounds with 'clean'.
ridge_alpha: float, optional
Regularization parameter used in ridge regression. Default is 0; if the value is equal to zero,
no regularization is applied. If the value is greater than zero, ridge_alpha is used as
the regularization parameter in ridge regression (specifically, the main diagonal of the XTX product
is multiplied by a factor of (1 + ridge_alpha)). Only used in the following linear modes
(dpss_leastsq, dft_leastsq, dpss_solve, dft_solve, dpss_matrix, dft_matrix). Reasonable values
for ridge_alpha when using the DPSS and DFT modes for inpainting wide gaps are between 1e-5 and 1e-2,
but will depend on factors such as the noise level in the data and the flagging mask.
fit_intercept: bool, optional
If True, subtracts the weighted average of the data before fitting the model, and adds
it back to the model before returning. Default is False. Can be useful if the data are not
centered around zero and a regularized linear model is being fit (i.e. if ridge_alpha > 0.0);
otherwise the model will likely trend toward zero in wide gaps.
zero_residual_flags : bool, optional.
If true, set flagged channels in the residual equal to zero.
Default is True.
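For orientation, here is a minimal usage sketch of the two new options. The array values are illustrative placeholders, and the call mirrors the new test_regularized_regression test added below.

import numpy as np
from hera_filters import dspec

nfreqs = 500
freqs = np.linspace(50e6, 250e6, nfreqs)
data = np.ones(nfreqs, dtype=complex)   # placeholder spectrum; use real visibilities in practice
wgts = np.ones(nfreqs)
wgts[200:260] = 0.0                     # a wide flagged gap to inpaint

model, resid, info = dspec.fourier_filter(
    freqs, data, wgts, filter_centers=[0.], filter_half_widths=[700e-9],
    mode='dpss_solve', suppression_factors=[0.], eigenval_cutoff=[1e-12],
    ridge_alpha=1e-3,    # mild regularization; the docstring suggests 1e-5 to 1e-2 for wide gaps
    fit_intercept=True,  # de-mean so the regularized model does not trend to zero in the gap
)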
@@ -479,6 +493,8 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
raise ValueError("data must be a 1D or 2D ndarray")
if not ndim_wgts == ndim_data:
raise ValueError("Number of dimensions in weights, %d does not equal number of dimensions in data, %d!"%(ndim_wgts, ndim_data))

assert ridge_alpha >= 0.0, "ridge_alpha must be greater than or equal to zero."
#The core code of this method will always assume 2d data
if ndim_data == 1:
data = np.asarray([data])
@@ -535,6 +551,12 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
else:
defaults = CLEAN_DEFAULTS_1D

if fit_intercept:
# subtract off mean of data
mean = np.sum(data * wgts, axis=tuple(filter_dims), keepdims=True) / np.sum(wgts, axis=tuple(filter_dims), keepdims=True)
data = np.copy(data) # make a copy so we don't modify the original data
data -= mean

_process_filter_kwargs(filter_kwargs, defaults)
if 'dft' in mode:
fp = np.asarray(filter_kwargs['fundamental_period']).flatten()
@@ -574,7 +596,7 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
skip_wgt=skip_wgt, basis=mode[1], method=mode[2], wgts=wgts, basis_options=filter_kwargs,
filter_half_widths=filter_half_widths, suppression_factors=suppression_factors,
cache=cache, max_contiguous_edge_flags=max_contiguous_edge_flags,
zero_residual_flags=zero_residual_flags)
zero_residual_flags=zero_residual_flags, ridge_alpha=ridge_alpha)
info['info_deconv']=info_deconv

elif mode[0] in ['dft', 'dpss']:
@@ -594,7 +616,7 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
skip_wgt=skip_wgt, basis=mode[0], method=mode[1], wgts=wgts, basis_options=filter_kwargs,
filter_half_widths=filter_half_widths, suppression_factors=suppression_factors,
cache=cache, max_contiguous_edge_flags=max_contiguous_edge_flags,
zero_residual_flags=zero_residual_flags)
zero_residual_flags=zero_residual_flags, ridge_alpha=ridge_alpha)
elif mode[0] == 'clean':
if zero_residual_flags is None:
zero_residual_flags = False
@@ -627,6 +649,13 @@ def fourier_filter(x, data, wgts, filter_centers, filter_half_widths, mode,
if ndim_data == 1:
model = model.flatten()
residual = residual.flatten()

if fit_intercept:
if ndim_data == 1:
mean = mean.flatten()

model += mean # add back mean of data to the model

return model, residual, info

def vis_clean(data, wgts, filter_size, real_delta, clean2d=False, tol=1e-9, window='none',
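As a rough sketch of what fit_intercept does (array names and shapes here are illustrative, not the library's internals): a weighted mean along the filtered axis is removed before the linear fit and restored in the returned model, exactly as in the two hunks above.

import numpy as np

rng = np.random.default_rng(0)
data = 20.0 + rng.standard_normal((10, 64))   # (ntimes, nfreqs) data with a large offset
wgts = np.ones_like(data)
wgts[:, 20:30] = 0.0                          # flagged gap along the frequency axis
filter_dims = (1,)

# weighted mean over the unflagged samples, kept broadcastable
mean = np.sum(data * wgts, axis=filter_dims, keepdims=True) / np.sum(wgts, axis=filter_dims, keepdims=True)
demeaned = data - mean                        # this is what the linear fit sees
model = np.zeros_like(demeaned)               # stand-in for the fitted model
model += mean                                 # fourier_filter restores the offset before returning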
@@ -1561,7 +1590,7 @@ def delay_filter_leastsq(data, flags, sigma, nmax, add_noise=False,

def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
basis_options, suppression_factors=None, hash_decimal=10,
method='leastsq', basis='dft', cache=None):
method='leastsq', basis='dft', cache=None, ridge_alpha=0.0):
r"""
A 1d linear-least-squares fitting function for computing models and residuals for fitting of the form
y_model = A @ c
@@ -1627,6 +1656,14 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
using scipy.optimize.leastsq
*'matrix' derive model by directly calculate the fitting matrix
[A^T W A]^{-1} A^T W and applying it to the y vector.
ridge_alpha: float, optional
Regularization parameter used in ridge regression. Default is 0; if the value is equal to zero,
no regularization is applied. If the value is greater than zero, ridge_alpha is used as
the regularization parameter in ridge regression (specifically, the main diagonal of the XTX product
is multiplied by a factor of (1 + ridge_alpha)). Only used in the following linear modes
(dpss_leastsq, dft_leastsq, dpss_solve, dft_solve, dpss_matrix, dft_matrix). Reasonable values
for ridge_alpha when using the DPSS and DFT modes for inpainting wide gaps are between 1e-5 and 1e-2,
but will depend on factors such as the noise level in the data and the flagging mask.
Returns:
@@ -1689,7 +1726,7 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
x=x, hash_decimal=hash_decimal, label='covariance')
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis, mode=method)
label='fitting matrix', basis=basis, mode=method, ridge_alpha=ridge_alpha)
if square_key in cache:
covmat = cache[square_key]
else:
@@ -1698,6 +1735,7 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,

if not fm_key in cache:
XTX = covmat - np.conj(amat[flags]).T @ amat[flags]
XTX.flat[::XTX.shape[0] + 1] *= (1 + ridge_alpha) # add regularization term

Xy = np.conj(amat[mask]).T @ y[mask]
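The in-place diagonal update above is equivalent to adding ridge_alpha * diag(XTX) to XTX, i.e. a (generalized) ridge penalty. A small self-contained check of that identity, using illustrative arrays:

import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((20, 4))
XTX = X.T @ X
ridge_alpha = 1e-2

scaled = XTX.copy()
scaled.flat[::scaled.shape[0] + 1] *= (1 + ridge_alpha)    # what the code above does
added = XTX + ridge_alpha * np.diag(np.diag(XTX))          # explicit ridge-penalty form
assert np.allclose(scaled, added)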

@@ -1736,9 +1774,10 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
raise ValueError("Provided 'method', '%s', is not in ['leastsq', 'matrix', 'solve', 'cholesky']."%(method))
else:
if method == 'leastsq':
a = np.atleast_2d(w).T * amat
a = np.dot((np.atleast_2d(w).T * amat).T.conj(), amat)
a.flat[::a.shape[0] + 1] *= (1 + ridge_alpha) # add regularization term
try:
res = lsq_linear(a, w * y)
res = lsq_linear(a, amat.T.conj().dot(w * y))
cn_out = res.x
# np.linalg.LinAlgError catches "SVD did not converge."
# which can happen if the solution is under-constrained.
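The updated 'leastsq' branch now hands scipy.optimize.lsq_linear the (optionally regularized) normal equations rather than the weighted design matrix itself. A real-valued toy version of that path, with illustrative names and binary flag-style weights assumed:

import numpy as np
from scipy.optimize import lsq_linear

rng = np.random.default_rng(2)
x = np.linspace(0.0, 1.0, 100)
amat = np.vander(x, 5, increasing=True)                  # toy design matrix
w = (rng.uniform(size=x.size) > 0.2).astype(float)       # binary weights (flags)
y = amat @ np.array([1.0, -2.0, 0.5, 0.0, 3.0]) + 0.01 * rng.standard_normal(x.size)
ridge_alpha = 1e-3

a = np.dot((w[:, None] * amat).T, amat)                  # A^T W A
a.flat[::a.shape[0] + 1] *= (1 + ridge_alpha)            # ridge term on the diagonal
cn_out = lsq_linear(a, amat.T.dot(w * y)).x              # coefficients, as in the branch above
print(np.linalg.norm(cn_out))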
@@ -1750,23 +1789,24 @@ def _fit_basis_1d(x, y, w, filter_centers, filter_half_widths,
elif method == 'matrix':
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis)
label='fitting matrix', basis=basis, ridge_alpha=ridge_alpha)
if basis.lower() == 'dft':
fm_key = fm_key + (basis_options['fundamental_period'], )
elif basis.lower() == 'dpss':
fm_key = fm_key + tuple(nterms)
fmat = fit_solution_matrix(w, amat, cache=cache, fit_mat_key=fm_key)
fmat = fit_solution_matrix(w, amat, cache=cache, fit_mat_key=fm_key, ridge_alpha=ridge_alpha)
info['fitting_matrix'] = fmat
cn_out = fmat @ y

elif method == 'solve':
fm_key = _fourier_filter_hash(filter_centers=filter_centers, filter_half_widths=filter_half_widths,
filter_factors=suppression_vector, x=x, w=w, hash_decimal=hash_decimal,
label='fitting matrix', basis=basis, mode=method)
label='fitting matrix', basis=basis, mode=method, alpha=ridge_alpha)
if fm_key in cache:
L = cache[fm_key]
else:
XTX = np.dot(np.conj(amat).T * w, amat)
XTX.flat[::XTX.shape[0] + 1] *= (1 + ridge_alpha) # add regularization term
L = linalg.lu_factor(XTX)
cache[fm_key] = L

@@ -1942,7 +1982,7 @@ def _clean_filter(x, data, wgts, filter_centers, filter_half_widths,
def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
basis_options, suppression_factors=None,
method='leastsq', basis='dft', cache=None,
filter_dims = 1, skip_wgt=0.1, max_contiguous_edge_flags=5,
filter_dims = 1, skip_wgt=0.1, max_contiguous_edge_flags=5, ridge_alpha=0.0,
zero_residual_flags=True):
r"""
A 1d linear-least-squares fitting function for computing models and residuals for fitting of the form
@@ -2024,6 +2064,15 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
zero_residual_flags : bool, optional.
If true, set flagged channels in the residual equal to zero.
Default is True.
ridge_alpha: float, optional
Regularization parameter used in ridge regression. Default is 0; if the value is equal to zero,
no regularization is applied. If the value is greater than zero, ridge_alpha is used as
the regularization parameter in ridge regression (specifically, the main diagonal of the XTX product
is multiplied by a factor of (1 + ridge_alpha)). Only used in the following linear modes
(dpss_leastsq, dft_leastsq, dpss_solve, dft_solve, dpss_matrix, dft_matrix). Reasonable values
for ridge_alpha when using the DPSS and DFT modes for inpainting wide gaps are between 1e-5 and 1e-2,
but will depend on factors such as the noise level in the data and the flagging mask.
Returns
-------
model: array-like
@@ -2077,7 +2126,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
filter_half_widths=filter_half_widths[1],
suppression_factors=suppression_factors[1],
basis_options=basis_options[1], method=method,
basis=basis, cache=cache)
basis=basis, cache=cache, ridge_alpha=ridge_alpha)
if info_t['skipped']:
info['status']['axis_1'][i] = 'skipped'
else:
@@ -2107,7 +2156,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
filter_half_widths=filter_half_widths[0],
suppression_factors=suppression_factors[0],
basis_options=basis_options[0], method=method,
basis=basis, cache=cache)
basis=basis, cache=cache, ridge_alpha=ridge_alpha)
if info_t['skipped']:
info['status']['axis_0'][i] = 'skipped'
else:
@@ -2137,7 +2186,7 @@ def _fit_basis_2d(x, data, wgts, filter_centers, filter_half_widths,
return model, residual, info


def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit_mat_key=None):
def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit_mat_key=None, ridge_alpha=0.0):
"""
Calculate the linear least squares solution matrix
from a design matrix, A and a weights matrix W
@@ -2156,6 +2205,9 @@ def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit
fit_mat_key: optional hashable variable
optional key. If none is used, hash fit matrix against design and
weighting matrix.
ridge_alpha: float, optional
Regularization parameter used in ridge regression. If greater than zero, the main diagonal
of the weighted normal matrix [A^T W A] is multiplied by (1 + ridge_alpha) before inversion.
Default is 0.0.
Returns
-----------
@@ -2185,6 +2237,8 @@ def fit_solution_matrix(weights, design_matrix, cache=None, hash_decimal=10, fit
xwmat = np.conj(design_matrix.T) @ weights
cmat = xwmat @ design_matrix

cmat.flat[::cmat.shape[0] + 1] *= (1 + ridge_alpha)

#should there be a conjugation!?!
if np.linalg.cond(cmat)>=1e9:
warn('Warning!!!!: Poorly conditioned matrix! Your linear inpainting IS WRONG!')
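One practical effect of the (1 + ridge_alpha) diagonal scaling is better conditioning of the normal matrix, which the warning above checks for. A small illustration on synthetic, nearly degenerate columns (names are illustrative, not from the library):

import numpy as np

rng = np.random.default_rng(3)
A = rng.standard_normal((50, 10))
A[:, -1] = A[:, 0] + 1e-8 * rng.standard_normal(50)   # nearly duplicate column
cmat = A.T @ A

reg = cmat.copy()
reg.flat[::reg.shape[0] + 1] *= 1 + 1e-3               # ridge_alpha = 1e-3
print(np.linalg.cond(cmat), np.linalg.cond(reg))       # the regularized matrix is far better conditioned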
49 changes: 49 additions & 0 deletions hera_filters/tests/test_dspec.py
@@ -999,6 +999,55 @@ def get_snr(clean, fftax=1, avgax=0, modes=[2, 20]):
assert np.isclose(info_dft['filter_params']['axis_0']['basis_options']['fundamental_period'],
dft_options2_2d['fundamental_period'][0])

def test_regularized_regression():
nfreqs = 500
freqs = np.linspace(50e6, 250e6, nfreqs)

# Simulate some data
C = np.sinc(2 * (freqs[None] - freqs[:, None]) * 100e-9)
y = np.random.multivariate_normal(np.zeros(nfreqs), C) + 1j * np.random.multivariate_normal(np.zeros(nfreqs), C)
d = y + np.random.normal(0, 0.1, size=nfreqs) + np.random.normal(0, 0.1, size=nfreqs) * 1j

# Create a mask to simulate missing data
w = np.ones(nfreqs)
w[200 : 200 + int((1 / (700e-9 / 4)) / np.diff(freqs)[0] + 1)] -= 1

# Compare regularized regression to standard least squares for each dpss mode
mdls = []
for mode in ['dpss_solve', 'dpss_leastsq', 'dpss_matrix']:
mdl_reg, res_reg, _ = dspec.fourier_filter(freqs, d, w, [0.], [700e-9], suppression_factors=[0.],
mode=mode, ridge_alpha=1e-3, eigenval_cutoff=[1e-12])
mdl, res, _ = dspec.fourier_filter(freqs, d, w, [0.], [700e-9], suppression_factors=[0.],
mode=mode, eigenval_cutoff=[1e-12])
mdls.append(mdl_reg)
# Check that the regularized regression has a smaller residual norm in the flagged region
assert np.linalg.norm((d - mdl_reg)[~w.astype(bool)]) < np.linalg.norm((d - mdl)[~w.astype(bool)])

# Check that the regularized regression models are close to each other
assert np.all(np.isclose(mdls[0], mdls[1]))
assert np.all(np.isclose(mdls[0], mdls[2]))

# Check that de-meaning the data improves interpolation
d += 20 + 20j
mdls = []
for mode in ['dpss_solve', 'dpss_leastsq', 'dpss_matrix']:
mdl_reg_demean, res_reg_demean, _ = dspec.fourier_filter(freqs, d, w, [0.], [700e-9], suppression_factors=[0.],
mode=mode, ridge_alpha=1e-3, eigenval_cutoff=[1e-12], fit_intercept=True)
mdls.append(mdl_reg_demean)
mdl_reg, res_reg, _ = dspec.fourier_filter(freqs, d, w, [0.], [700e-9], suppression_factors=[0.],
mode=mode, ridge_alpha=1e-3, eigenval_cutoff=[1e-12])

# Check that the demeaned regularized regression has a smaller residual norm in the flagged region than
# the non-demeaned regularized regression. This is because ridge regression reduces the amplitude of the
# coefficients, leading to a near-zero mean in the flagged region, which can be a poor prediction of the
# inpainted region given non-zero mean data.
assert np.linalg.norm((d - mdl_reg_demean)[~w.astype(bool)]) < np.linalg.norm((d - mdl_reg)[~w.astype(bool)])

# Check that the regularized regression models are close to each other
assert np.all(np.isclose(mdls[0], mdls[1]))
assert np.all(np.isclose(mdls[0], mdls[2]))


def test_vis_clean():
# validate that fourier_filter in various clean modes gives close values to vis_clean with equivalent parameters!
uvd = UVData()
