channel dimensions to DenseAffineIntegralOperator and KernelIntegralOperator; additional tests for channels
zbmorro committed Mar 25, 2024
1 parent 44a7961 commit 411eb80
Showing 3 changed files with 233 additions and 73 deletions.
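As orientation for the diff below, a minimal usage sketch of the channelized KernelIntegralOperator this commit introduces, mirroring test_parameterized_kernels_parallel_channels in the updated test file. Import paths and MaternKernel arguments follow that test; passing plain torch tensors (instead of the tests' tw array wrapper) is an assumption.

    # Sketch only, not part of the commit; torch-tensor inputs are assumed.
    import torch
    from pyapprox.sciml.kernels import MaternKernel
    from pyapprox.sciml.quadrature import Fixed1DGaussLegendreIOQuadRule
    from pyapprox.sciml.integraloperators import KernelIntegralOperator

    ninputs = 21
    # One kernel per input channel; len(kernels) must equal channel_in
    matern_sqexp = MaternKernel(float('inf'), [0.2], [0.01, 0.5], 1)
    matern_exp = MaternKernel(0.5, [0.2], [0.01, 0.5], 1)
    quad = Fixed1DGaussLegendreIOQuadRule(ninputs)
    iop = KernelIntegralOperator([matern_sqexp, matern_exp], quad, quad,
                                 channel_in=2, channel_out=2)

    # y has shape (n_x, channel_in, n_train); each channel is integrated
    # against its own kernel
    xx = torch.linspace(0, 1, ninputs, dtype=torch.double)[:, None]
    y = torch.hstack([xx, xx])[..., None]
    u = iop(y)  # shape (n_x, channel_out, n_train)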
142 changes: 83 additions & 59 deletions pyapprox/sciml/integraloperators.py
@@ -32,8 +32,8 @@ def _format_nx(self, nx):
 
 
 class EmbeddingOperator(IntegralOperator):
-    def __init__(self, dim_in: int, dim_out: int, nx=None, v0=None):
-        nvars_mat = dim_in*dim_out
+    def __init__(self, channel_in: int, channel_out: int, nx=None, v0=None):
+        nvars_mat = channel_in*channel_out
         bounds = np.tile([-np.inf, np.inf], nvars_mat)
         W = np.empty((nvars_mat,), dtype=float)
         W[:] = (np.random.normal(0, 1/nvars_mat, nvars_mat) if v0 is None
@@ -42,18 +42,17 @@ def __init__(self, dim_in: int, dim_out: int, nx=None, v0=None):
             "W_embedding", nvars_mat, W, bounds,
             IdentityHyperParameterTransform())
         self._hyp_list = HyperParameterList([self._W])
-        self._dim_in = dim_in
-        self._dim_out = dim_out
+        self._channel_in = channel_in
+        self._channel_out = channel_out
         self._format_nx(nx)
 
     def _integrate(self, y_k_samples):
         if y_k_samples.ndim < 3:
             raise ValueError('y_k_samples must have shape (n_x, d_c, n_train)')
         if self._nx is None:
             self._format_nx(y_k_samples.shape[:-2])
-        W = self._hyp_list.get_values().reshape(self._dim_out, self._dim_in)
-        # y_k_flat = y_k_samples.reshape(self._N, self._dim_in,
-        #                                y_k_samples.shape[-1])
+        W = self._hyp_list.get_values().reshape(self._channel_out,
+                                                self._channel_in)
         return einsum('ij,...jl->...il', W, y_k_samples)
 
     def _format_nx(self, nx):
@@ -76,73 +75,106 @@ def __init__(self, d_c=1, nx=None, v0=None):
 
 
 class KernelIntegralOperator(IntegralOperator):
-    def __init__(self, kernel, quad_rule_k, quad_rule_kp1):
-        self._kernel = kernel
-        self._hyp_list = self._kernel.hyp_list
+    def __init__(self, kernels, quad_rule_k, quad_rule_kp1, channel_in=1,
+                 channel_out=1):
+        if not hasattr(kernels, '__iter__'):
+            self._kernels = channel_in*[kernels]
+        elif len(kernels) != channel_in:
+            raise ValueError('len(kernels) must equal channel_in')
+        else:
+            self._kernels = kernels
+        self._hyp_list = sum([kernel.hyp_list for kernel in self._kernels])
 
         self._quad_rule_k = quad_rule_k
         self._quad_rule_kp1 = quad_rule_kp1
 
     def _integrate(self, y_k_samples):
-        z_k_samples, z_k_weights = self._quad_rule_k.get_samples_weights()
+        # Apply matvec to each channel in parallel
+        z_k_samples, w_k = self._quad_rule_k.get_samples_weights()
         z_kp1_samples = self._quad_rule_kp1.get_samples_weights()[0]
-        K_mat = self._kernel(z_kp1_samples, z_k_samples)
-        WK_mat = K_mat * z_k_weights[:, 0]  # equivalent to K @ diag(w)
-        u_samples = WK_mat.double() @ y_k_samples.double()
+        self._WK_mat = zeros(z_kp1_samples.shape[1], z_k_samples.shape[1],
+                             len(self._kernels))
+        for k in range(len(self._kernels)):
+            self._WK_mat[..., k] = (
+                self._kernels[k](z_kp1_samples, z_k_samples) * w_k[:, 0])
+
+        u_samples = einsum('ijk,jk...->ik...', self._WK_mat.double(),
+                           y_k_samples.double())
         return u_samples
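In the new _integrate, each channel gets its own weighted kernel matrix, stacked into self._WK_mat, and the einsum applies them in parallel. A standalone numpy shape check of that contraction (illustrative only; the names are made up):

    import numpy as np
    n_kp1, n_k, d_c, n_train = 5, 4, 2, 3
    WK = np.random.rand(n_kp1, n_k, d_c)   # stacked K @ diag(w), one per channel
    y = np.random.rand(n_k, d_c, n_train)  # (n_x, channel, train)
    u = np.einsum('ijk,jk...->ik...', WK, y)
    # equivalent explicit per-channel loop
    u_loop = np.stack([WK[..., c] @ y[:, c, :] for c in range(d_c)], axis=1)
    assert u.shape == (n_kp1, d_c, n_train) and np.allclose(u, u_loop)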


 class DenseAffineIntegralOperator(IntegralOperator):
-    def __init__(self, ninputs: int, noutputs: int, v0=None):
+    def __init__(self, ninputs: int, noutputs: int, v0=None, channel_in=1,
+                 channel_out=1):
         '''
         Implements the usual fully connected layer of an MLP:
-            u_{k+1} = W_k y_k + b_k
+            u_{k+1} = W_k y_k + b_k    (single channel)
         where W_k is a 2D array of shape (N_{k+1}, N_k), y_k is a 1D array of
         shape (N_k,), and b_k is a 1D array of shape (N_{k+1},)
         '''
         self._ninputs = ninputs
         self._noutputs = noutputs
-        nvars_mat = self._noutputs * (self._ninputs+1)
+        self._channel_in = channel_in
+        self._channel_out = channel_out
+        self._b_size = self._noutputs*self._channel_out
+        self._nvars_mat = (self._noutputs * self._channel_out * (
+            self._ninputs * self._channel_in + 1))
 
-        weights_biases = self._default_values(nvars_mat, v0)
-        bounds = self._default_bounds(nvars_mat)
+        weights_biases = self._default_values(v0)
+        bounds = self._default_bounds()
         self._weights_biases = HyperParameter(
-            "weights_biases", nvars_mat, weights_biases,
-            bounds, IdentityHyperParameterTransform())
+            "weights_biases", self._nvars_mat, weights_biases, bounds,
+            IdentityHyperParameterTransform())
 
         self._hyp_list = HyperParameterList([self._weights_biases])
 
-    def _default_values(self, nvars_mat, v0):
-        weights_biases = np.empty((nvars_mat,), dtype=float)
+    def _default_values(self, v0):
+        weights_biases = np.empty((self._nvars_mat,), dtype=float)
         weights_biases[:] = (
-            np.random.normal(0, 1, nvars_mat) if v0 is None else v0)
+            np.random.normal(0, 1, self._nvars_mat) if v0 is None else v0)
         return weights_biases
 
-    def _default_bounds(self, nvars_mat):
-        return np.tile([-np.inf, np.inf], nvars_mat)
+    def _default_bounds(self):
+        return np.tile([-np.inf, np.inf], self._nvars_mat)
 
     def _integrate(self, y_k_samples):
-        mat = self._weights_biases.get_values().reshape(
-            self._noutputs, self._ninputs+1)
-        W = mat[:, :-1]
-        b = mat[:, -1:]
-        return W @ y_k_samples + b
+        if y_k_samples.shape[-2] != self._channel_in:
+            if self._channel_in == 1:
+                y_k_samples = y_k_samples[..., None, :]
+            else:
+                raise ValueError(
+                    'Could not infer channel dimension. y_k_samples.shape[-2] '
+                    'must be channel_in.')
+
+        ntrain = y_k_samples.shape[-1]
+        W = (self._weights_biases.get_values()[:-self._b_size].reshape(
+            self._noutputs, self._ninputs, self._channel_out,
+            self._channel_in))
+        b = (self._weights_biases.get_values()[-self._b_size:].reshape(
+            self._noutputs, self._channel_out))
+        if self._channel_in > 1 or self._channel_out > 1:
+            return einsum('ijkl,jlm->ikm', W, y_k_samples) + b[..., None]
+        else:
+            # handle separately for speed
+            return W.squeeze() @ y_k_samples[..., 0, :] + b
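The channelized affine map stores W as (noutputs, ninputs, channel_out, channel_in) and b as (noutputs, channel_out), so the parameter count is noutputs*channel_out*(ninputs*channel_in + 1), matching self._nvars_mat above. A numpy sketch of the 'ijkl,jlm->ikm' contraction (illustrative names, not from the diff):

    import numpy as np
    n_out, n_in, c_out, c_in, n_train = 5, 4, 3, 2, 7
    W = np.random.rand(n_out, n_in, c_out, c_in)
    b = np.random.rand(n_out, c_out)
    y = np.random.rand(n_in, c_in, n_train)
    u = np.einsum('ijkl,jlm->ikm', W, y) + b[..., None]
    assert u.shape == (n_out, c_out, n_train)
    # weights plus biases match the operator's parameter count
    assert W.size + b.size == n_out*c_out*(n_in*c_in + 1)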


 class DenseAffineIntegralOperatorFixedBias(DenseAffineIntegralOperator):
-    def __init__(self, ninputs: int, noutputs: int, v0=None):
-        super().__init__(ninputs, noutputs, v0)
+    def __init__(self, ninputs: int, noutputs: int, v0=None, channel_in=1,
+                 channel_out=1):
+        super().__init__(ninputs, noutputs, v0, channel_in, channel_out)
 
-    def _default_values(self, nvars_mat, v0):
-        weights_biases = super()._default_values(nvars_mat, v0)
-        weights_biases[self._ninputs::self._ninputs+1] = 0.
+    def _default_values(self, v0):
+        weights_biases = super()._default_values(v0)
+        weights_biases[-self._b_size:] = 0.
         return weights_biases
 
-    def _default_bounds(self, nvars_mat):
-        bounds = super()._default_bounds(nvars_mat).reshape((nvars_mat, 2))
-        bounds[self._ninputs::self._ninputs+1, 0] = np.nan
-        bounds[self._ninputs::self._ninputs+1, 1] = np.nan
+    def _default_bounds(self):
+        bounds = super()._default_bounds().reshape(self._nvars_mat, 2)
+        bounds[-self._b_size:, 0] = np.nan
+        bounds[-self._b_size:, 1] = np.nan
         return bounds.flatten()


@@ -227,15 +259,11 @@ def _integrate(self, y_k_samples):
 
         # R[n, d_c, d_c] = c_n, -kmax <= n <= kmax
         R = vstack([flip(conj(v[1:, ...]), dims=[0]), v])
+        R = R.reshape(*fftshift_y_proj.shape[:-2], self._channel_out,
+                      self._channel_in)
 
         # Do convolution and lift into original spatial resolution
-        fftshift_y_proj_flat = fftshift_y_proj.reshape(R.shape[0],
-                                                       self._channel_in,
-                                                       ntrain)
-        conv_shift = einsum('ijk,ikl->ijl', R, fftshift_y_proj_flat)
-        conv_shift = conv_shift.reshape((*fftshift_y_proj.shape[:-2],
-                                         self._channel_out,
-                                         ntrain))
+        conv_shift = einsum('...jk,...kl->...jl', R, fftshift_y_proj)
         conv_shift_lift = zeros((*fft_y.shape[:-2], self._channel_out, ntrain),
                                 dtype=cfloat)
         conv_shift_lift[freq_slices] = conv_shift
Expand Down Expand Up @@ -288,8 +316,8 @@ def _precompute_weights(self):
W_tot_ifct *= W_ifct[k]

self._N_tot = N_tot
self._W_tot_R = W_tot.flatten()
self._W_tot_ifct = W_tot_ifct.flatten()
self._W_tot_R = W_tot
self._W_tot_ifct = W_tot_ifct

def _integrate(self, y_k_samples):
# If channel_in is not explicit in y_k_samples, then assume
@@ -332,19 +360,15 @@ def _integrate(self, y_k_samples):
         # Construct convolution factor R; keep books on weights
         if self._W_tot_R is None:
             self._precompute_weights()
-        P = diag(self._N_tot / self._W_tot_R)
-        fct_y_proj_flat = fct_y_proj.reshape(P.shape[0], self._channel_in,
-                                             ntrain)
-        fct_y_proj_precond = einsum('ij,jkl->ikl', P, fct_y_proj_flat)
-        R = self._hyp_list.get_values().reshape(P.shape[0], self._channel_out,
+        P = self._N_tot / self._W_tot_R
+        fct_y_proj_precond = einsum('...,...jk->...jk', P, fct_y_proj)
+        R = self._hyp_list.get_values().reshape(*fct_y_proj.shape[:-2],
+                                                self._channel_out,
                                                 self._channel_in)
 
         # Do convolution and lift into original spatial resolution
-        r_conv_y = einsum('ijk,ikl->ijl', R, fct_y_proj_precond)
-        r_conv_y = r_conv_y.reshape((*fct_y_proj.shape[:-2], self._channel_out,
-                                     ntrain))
-        conv_lift = zeros((*fct_y.shape[:-2], self._channel_out,
-                           fct_y.shape[-1]))
+        r_conv_y = einsum('...jk,...kl->...jl', R, fct_y_proj_precond)
+        conv_lift = zeros((*self._nx, self._channel_out, fct_y.shape[-1]))
         conv_lift[deg_slices] = r_conv_y
         res = fct.ifct(conv_lift, W_tot=self._W_tot_ifct)
         return res.reshape(output_shape)
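Both convolution operators now share the broadcasted contraction '...jk,...kl->...jl': at every retained mode, a (channel_out, channel_in) matrix mixes the channels. A numpy shape check (illustrative only):

    import numpy as np
    modes, c_out, c_in, n_train = (4, 3), 2, 2, 5
    R = np.random.rand(*modes, c_out, c_in)   # per-mode channel mixer
    y = np.random.rand(*modes, c_in, n_train)
    out = np.einsum('...jk,...kl->...jl', R, y)
    assert out.shape == (*modes, c_out, n_train)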
46 changes: 38 additions & 8 deletions pyapprox/sciml/tests/test_integral_operators.py
@@ -8,10 +8,12 @@
 from pyapprox.sciml.integraloperators import (
     FourierConvolutionOperator, ChebyshevConvolutionOperator,
     DenseAffineIntegralOperator, DenseAffineIntegralOperatorFixedBias,
-    ChebyshevIntegralOperator)
+    ChebyshevIntegralOperator, KernelIntegralOperator)
 from pyapprox.sciml.layers import Layer
 from pyapprox.sciml.activations import IdentityActivation
 from pyapprox.sciml.optimizers import Adam
+from pyapprox.sciml.kernels import MaternKernel
+from pyapprox.sciml.quadrature import Fixed1DGaussLegendreIOQuadRule
 
 
 class TestIntegralOperators(unittest.TestCase):
@@ -88,7 +90,6 @@ def test_chebyshev_convolution_operator_1d(self):
         tol = 4e-4
         relerr = (tw.norm(fct.fct(v)[:kmax+1] - ctn._hyp_list.get_values()) /
                   tw.norm(fct.fct(v)[:kmax+1]))
-        print(relerr)
         assert relerr < tol, f'Relative error = {relerr:.2e} > {tol:.2e}'
 
def test_chebyshev_convolution_operator_multidim(self):
Expand Down Expand Up @@ -184,7 +185,7 @@ def test_dense_affine_integral_operator(self):
ctn = CERTANN(N0, [Layer([DenseAffineIntegralOperator(N0, N1)])],
[IdentityActivation()])
ctn.fit(XX, YY, tol=1e-14)
assert np.allclose(tw.hstack([W, b]).flatten(),
assert np.allclose(tw.hstack([W.flatten(), b.flatten()]),
ctn._hyp_list.get_values())

ctn = CERTANN(
@@ -193,24 +194,53 @@
             optimizer=Adam(epochs=1000, lr=1e-2, batches=5))
         ctn.fit(XX, YY, tol=1e-12)
 
-        tol = 5e-3
-        relerr = (tw.norm(tw.hstack([W, b]).flatten() -
+        tol = 1e-8
+        relerr = (tw.norm(tw.hstack([W.flatten(), b.flatten()]) -
                   ctn._hyp_list.get_values()) /
-                  tw.norm(ctn._hyp_list.get_values()))
+                  tw.norm(tw.hstack([W.flatten(), b.flatten()])))
         assert relerr < tol, f'Relative error = {relerr:.2e} > {tol:.2e}'
 
     def test_dense_affine_integral_operator_fixed_bias(self):
         N0, N1 = 3, 5
         XX = tw.asarray(np.random.normal(0, 1, (N0, 20)))
         iop = DenseAffineIntegralOperatorFixedBias(N0, N1)
         b = tw.full((N1, 1), 0)
-        W = iop._weights_biases.get_values().reshape(
-            iop._noutputs, iop._ninputs+1)[:, :-1]
+        W = iop._weights_biases.get_values()[:-N1].reshape(iop._noutputs,
+                                                           iop._ninputs)
         YY = W @ XX + b
         assert np.allclose(iop._integrate(XX), YY), 'Quadrature error'
         assert np.allclose(iop._hyp_list.nactive_vars(), N0*N1), ('Dimension '
                                                                   'mismatch')
 
+    def test_parameterized_kernels_parallel_channels(self):
+        ninputs = 21
+
+        matern_sqexp = MaternKernel(tw.inf, [0.2], [0.01, 0.5], 1)
+        matern_exp = MaternKernel(0.5, [0.2], [0.01, 0.5], 1)
+
+        # One block, two channels
+        quad_rule_k = Fixed1DGaussLegendreIOQuadRule(ninputs)
+        quad_rule_kp1 = Fixed1DGaussLegendreIOQuadRule(ninputs)
+        iop = KernelIntegralOperator([matern_sqexp, matern_exp], quad_rule_k,
+                                     quad_rule_kp1, channel_in=2,
+                                     channel_out=2)
+        xx = tw.asarray(np.linspace(0, 1, ninputs))[:, None]
+        samples = tw.hstack([xx, xx])[..., None]
+        values = iop(samples)
+
+        # Two blocks, one channel
+        iop_sqexp = KernelIntegralOperator([matern_sqexp], quad_rule_k,
+                                           quad_rule_kp1, channel_in=1,
+                                           channel_out=1)
+        iop_exp = KernelIntegralOperator([matern_exp], quad_rule_k,
+                                         quad_rule_kp1, channel_in=1,
+                                         channel_out=1)
+
+        # Results should be identical
+        assert (np.allclose(iop_sqexp(xx), values[:, 0]) and
+                np.allclose(iop_exp(xx), values[:, 1])), ('Kernel integral '
+                'operators not acting on channels in parallel')
 
     def test_chebno_channels(self):
         n = 21
         w = fct.make_weights(n)[:, None]
