Don't make multiple plots by default. Merge models.
Syver Døving Agdestein committed Oct 6, 2023
1 parent fd8764e commit b02bdc7
Showing 1 changed file with 63 additions and 83 deletions: tutorials/burgers.jl
@@ -753,14 +753,30 @@ end
# During training, we will monitor the error on the validation dataset with a
# callback. We will plot the history of the a priori and a posteriori errors.
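#
# Concretely, the a priori error measures how well the model reproduces the
# reference closure term $c$ directly, without solving the LES equations
# (this is what the `eprior` expression below computes),
# $$
# e_\text{prior} = \frac{\| m(u, \theta) - c \|_2}{\| c \|_2},
# $$
# while the a posteriori error is the relative trajectory error obtained by
# actually solving the closed LES equations and comparing with the reference
# solution.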

## Initial empty history (with no-model errors)
initial_callbackstate() = (; ihist = Int[], ehist_prior = zeros(0), ehist_post = zeros(0))

## Plot convergence
function plot_convergence(state, data)
    e_post_ref = trajectory_error(dns, data.u; data.dt, data.μ)
    fig = plot(; yscale = :log10, xlabel = "Iterations", title = "Relative error")
    hline!(fig, [1.0]; color = 1, linestyle = :dash, label = "A priori: No model")
    plot!(fig, state.ihist, state.ehist_prior; color = 1, label = "A priori: Model")
    hline!(
        fig,
        [e_post_ref];
        color = 2,
        linestyle = :dash,
        label = "A posteriori: No model",
    )
    plot!(fig, state.ihist, state.ehist_post; color = 2, label = "A posteriori: Model")
    fig
end

## Create callback for given model and dataset
function create_callback(m, data; doplot = false)
    (; u, c, dt, μ) = data
    uu, cc = reshape(u, size(u, 1), :), reshape(c, size(c, 1), :)
    function callback(i, θ, state)
        (; ihist, ehist_prior, ehist_post) = state
        eprior = norm(m(uu, θ) - cc) / norm(cc)
@@ -770,19 +786,8 @@ function create_callback(m, data)
            ehist_prior = vcat(ehist_prior, eprior),
            ehist_post = vcat(ehist_post, epost),
        )
        doplot && display(plot_convergence(state, data))
        @printf "Iteration %d,\t\ta priori error: %.4g,\t\ta posteriori error: %.4g\n" i eprior epost
        state
    end
end
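
# To see how the pieces fit together, here is a minimal sketch of a training
# loop that could consume such a callback, assuming an Optimisers.jl-style
# `Adam` and a Zygote-style `gradient` (the tutorial's actual `train`
# function is defined earlier in this file; `sketch_train` is only an
# illustration).

using Optimisers

function sketch_train(; θ, loss, niter, ncallback, callback, opt = Adam(1.0e-3))
    optstate = Optimisers.setup(opt, θ)
    callbackstate = initial_callbackstate()
    for i = 1:niter
        g, = gradient(loss, θ)                       ## Zygote returns a tuple
        optstate, θ = Optimisers.update(optstate, θ, g)
        ## Monitor the errors every `ncallback` iterations, threading the state
        i % ncallback == 0 && (callbackstate = callback(i, θ, callbackstate))
    end
    (; θ, callbackstate)
end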
@@ -1147,65 +1152,23 @@ end
# We start by defining the "no closure" model, where $m = 0$.
# This is the baseline, and corresponds to coarse DNS: no correction is added
# to the right hand side, so any improvement over it is attributable to the
# closure model.

m_0, θ_0, label_0 = (u, θ) -> zero(u), nothing, "m=0"

# ### Train a CNN
#
# We now create a closure model. Note that the last activation is `identity`, as we
# don't want to restrict the output values. We can inspect the structure in the
# wrapped Lux `Chain`.

m_cnn, θ_cnn = create_cnn(;
    radii = [2, 2, 2, 2],
    channels = [8, 8, 8, 1],
    activations = [leakyrelu, leakyrelu, leakyrelu, identity],
    use_bias = [true, true, true, false],
    input_channels = (u -> u, u -> u .^ 2),
    rng,
)
m_cnn.chain
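
# For intuition, here is a rough sketch of the kind of Lux `Chain` that
# `create_cnn` could assemble from these arguments (a hypothetical
# reconstruction; the actual implementation, defined earlier in this
# tutorial, additionally uses periodic padding, for which plain zero padding
# stands in below). Each radius `r` becomes a 1D convolution with kernel size
# `2r + 1`, and the two input channels are `u` and `u .^ 2`.

using Lux, Random

chain_sketch = Chain(
    ## Two input channels: u and u² (input shape: (nx, 1, nbatch))
    u -> cat(u, u .^ 2; dims = 2),
    ## Radius 2 ⇒ kernel size 5; `pad = 2` preserves the grid size
    Conv((5,), 2 => 8, leakyrelu; pad = 2),
    Conv((5,), 8 => 8, leakyrelu; pad = 2),
    Conv((5,), 8 => 8, leakyrelu; pad = 2),
    ## Last layer: `identity` activation and no bias
    Conv((5,), 8 => 1; pad = 2, use_bias = false),
)
θ_sketch, st_sketch = Lux.setup(Random.default_rng(), chain_sketch)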

# Create FNO. Like for the CNN, last activation is `identity`.
#-

m_fno, θ_fno = create_fno(;
    channels = [5, 5, 5, 5],
@@ -1216,50 +1179,62 @@ fno, θ_fno = create_fno(;
)
m_fno.chain
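
# For reference, each Fourier layer inside this chain combines a local linear
# term with a convolution carried out in Fourier space on a limited number of
# low frequencies,
# $$
# z \mapsto \sigma(W z + \mathcal{F}^{-1}(R \, \mathcal{F}(z))),
# $$
# where the weights $W$ and the spectral coefficients $R$ are learnable and
# $\sigma$ is the activation function.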

#-

m, θ, label = m_cnn, θ_cnn, "CNN"
## m, θ, label = m_fno, θ_fno, "FNO"

# Choose loss function

loss = create_randloss_commutator(m, data_train; nuse = 50)
## loss = create_randloss_trajectory(
##     les,
##     data_train;
##     nuse = 3,
##     n_unroll = 10,
##     data_train.μ,
##     m,
## )
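
# For intuition, a minimal sketch of what a randomized "commutator error"
# loss in the spirit of `create_randloss_commutator` could look like (an
# illustration only, not the tutorial's exact implementation;
# `sketch_randloss_commutator` is hypothetical): at each evaluation, `nuse`
# snapshots are drawn at random, making the loss a cheap stochastic sample of
# the full dataset error.

function sketch_randloss_commutator(m, data; nuse)
    uu = reshape(data.u, size(data.u, 1), :)
    cc = reshape(data.c, size(data.c, 1), :)
    function loss(θ)
        ## Fresh random subset of snapshots at every call
        i = rand(1:size(uu, 2), nuse)
        u, c = uu[:, i], cc[:, i]
        ## Relative (squared) error of the predicted closure term
        sum(abs2, m(u, θ) - c) / sum(abs2, c)
    end
end

# Drawing a new subset at every call is precisely what makes a stochastic
# optimizer appropriate, as discussed next.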

# Initialize the training state. Note that we have to provide an optimizer,
# here `Adam(η)` where `η` is the learning rate [^4]. This optimizer is well
# suited to the random nature of our loss function.
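#
# For reference, given the gradient $g_t$ at iteration $t$, Adam tracks
# exponentially decaying moment estimates (here $m_t$ is Adam's first moment,
# not to be confused with the closure model $m$)
# $$
# m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad
# v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2,
# $$
# and updates the parameters using the bias-corrected estimates $\hat{m}_t$
# and $\hat{v}_t$:
# $$
# \theta_t = \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}.
# $$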

trainstate = initial_trainstate(Adam(1.0e-3), θ)

# Model warm-up: trigger compilation and get indication of complexity

loss(θ)
gradient(loss, θ);
@time loss(θ);
@time gradient(loss, θ);

# Train the model. The cell below can be repeated to continue training where the
# previous training session left off.
# If you run this in a notebook, `doplot = true` will create a lot of plots
# below the cell.

trainstate = train(;
    trainstate...,
    loss,
    niter = 1000,
    ncallback = 20,
    callback = create_callback(m, data_valid; doplot = false),
)
plot_convergence(trainstate.callbackstate, data_valid)

# Final model weights

(; θ) = trainstate

# ### Model performance
#
# We will now make a comparison between our closure model, the baseline "no closure" model,
# and the reference testing data.

models = [
    (m_0, θ_0, label_0)
    (m, θ, label)
]

println("Relative a posteriori errors:")
@@ -1325,7 +1300,7 @@ end
# $$
# \frac{\mathrm{d} \bar{v}}{\mathrm{d} t} = m(\bar{v}, \theta).
# $$
# This is known as a _Neural ODE_ (see Chen [^5]).
# 1. Define a model that predicts the _entire_ right hand side.
# This can be done by using the following little "hack":
#
@@ -1412,7 +1387,12 @@ end
# arXiv:[2010.08895](https://arxiv.org/abs/2010.08895),
# 2021.
#
# [^4]: D. P. Kingma and J. Ba.
# _Adam: A method for stochastic optimization_.
# arXiv:[1412.6980](https://arxiv.org/abs/1412.6980),
# 2014.
#
# [^5]: R. T. Q. Chen, Y. Rubanova, J. Bettencourt, and D. Duvenaud.
# _Neural Ordinary Differential Equations_.
# arXiv:[1806.07366](https://arxiv.org/abs/1806.07366),
# 2018.
