From b02bdc7d7984e82a7932f27fe778b81f0573a450 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Syver=20D=C3=B8ving=20Agdestein?=
Date: Fri, 6 Oct 2023 17:10:17 +0200
Subject: [PATCH] Don't make multiple plots by default. Merge models.

---
 tutorials/burgers.jl | 146 +++++++++++++++++++------------------------
 1 file changed, 63 insertions(+), 83 deletions(-)

diff --git a/tutorials/burgers.jl b/tutorials/burgers.jl
index 63d9b29..0bc4b8e 100644
--- a/tutorials/burgers.jl
+++ b/tutorials/burgers.jl
@@ -753,14 +753,30 @@ end
 # During training, we will monitor the error on the validation dataset with a
 # callback. We will plot the history of the a priori and a posteriori errors.
 
-## Initial empty history
+## Initial empty history (with no-model errors)
 initial_callbackstate() = (; ihist = Int[], ehist_prior = zeros(0), ehist_post = zeros(0))
 
+## Plot convergence
+function plot_convergence(state, data)
+    e_post_ref = trajectory_error(dns, data.u; data.dt, data.μ)
+    fig = plot(; yscale = :log10, xlabel = "Iterations", title = "Relative error")
+    hline!(fig, [1.0]; color = 1, linestyle = :dash, label = "A priori: No model")
+    plot!(fig, state.ihist, state.ehist_prior; color = 1, label = "A priori: Model")
+    hline!(
+        fig,
+        [e_post_ref];
+        color = 2,
+        linestyle = :dash,
+        label = "A posteriori: No model",
+    )
+    plot!(fig, state.ihist, state.ehist_post; color = 2, label = "A posteriori: Model")
+    fig
+end
+
 ## Create callback for given model and dataset
-function create_callback(m, data)
+function create_callback(m, data; doplot = false)
     (; u, c, dt, μ) = data
     uu, cc = reshape(u, size(u, 1), :), reshape(c, size(c, 1), :)
-    e_post_ref = trajectory_error(dns, u; dt, μ)
     function callback(i, θ, state)
         (; ihist, ehist_prior, ehist_post) = state
         eprior = norm(m(uu, θ) - cc) / norm(cc)
@@ -770,19 +786,8 @@ function create_callback(m, data)
             ehist_prior = vcat(ehist_prior, eprior),
             ehist_post = vcat(ehist_post, epost),
         )
-        fig = plot(; yscale = :log10, xlabel = "Iterations", title = "Relative error")
-        hline!(fig, [1.0]; color = 1, linestyle = :dash, label = "A priori: No model")
-        plot!(fig, state.ihist, state.ehist_prior; color = 1, label = "A priori: Model")
-        hline!(
-            fig,
-            [e_post_ref];
-            color = 2,
-            linestyle = :dash,
-            label = "A posteriori: No model",
-        )
-        plot!(fig, state.ihist, state.ehist_post; color = 2, label = "A posteriori: Model")
-        display(fig)
-        @printf "Iteration %d\ta priori error: %.4g\ta posteriori error: %.4g\n" i eprior epost
+        doplot && display(plot_convergence(state, data))
+        @printf "Iteration %d,\t\ta priori error: %.4g,\t\ta posteriori error: %.4g\n" i eprior epost
         state
     end
 end
@@ -1147,15 +1152,13 @@ end
 # We start by defining the "no closure" model, where $m = 0$.
 # This is the baseline, and corresponds to coarse DNS.
 
-noclosure, θ_noclosure = (u, θ) -> zero(u), nothing
+m_0, θ_0, label_0 = (u, θ) -> zero(u), nothing, "m=0"
 
-# ### Train a CNN
-#
-# We now create CNN model. Note that the last activation is `identity`, as we
+# We now create a closure model. Note that the last activation is `identity`, as we
 # don't want to restrict the output values. We can inspect the structure in the
 # wrapped Lux `Chain`.
 
-cnn, θ_cnn = create_cnn(;
+m_cnn, θ_cnn = create_cnn(;
     radii = [2, 2, 2, 2],
     channels = [8, 8, 8, 1],
     activations = [leakyrelu, leakyrelu, leakyrelu, identity],
@@ -1163,49 +1166,9 @@ cnn, θ_cnn = create_cnn(;
     input_channels = (u -> u, u -> u .^ 2),
     rng,
 )
-cnn.chain
+m_cnn.chain
 
-# Choose loss function
-
-loss = create_randloss_commutator(cnn, data_train; nuse = 50)
-## loss = create_randloss_trajectory(
-##     les,
-##     data_train;
-##     nuse = 3,
-##     n_unroll = 10,
-##     data_train.μ,
-##     m = cnn,
-## )
-
-# Initilize CNN training state
-
-trainstate = initial_trainstate(Adam(1.0e-3), θ_cnn)
-
-# Model warm-up: trigger compilation and get indication of complexity
-
-loss(θ_cnn)
-gradient(loss, θ_cnn);
-@time loss(θ_cnn);
-@time gradient(loss, θ_cnn);
-
-# Train the CNN. The cell below can be repeated to continue training where the
-# previous training session left off.
-
-trainstate = train(;
-    trainstate...,
-    loss,
-    niter = 1000,
-    ncallback = 20,
-    callback = create_callback(cnn, data_valid),
-)
-
-# Final CNN weights
-
-θ_cnn = trainstate.θ
-
-# #### Train an FNO
-#
-# Create FNO. Like for the CNN, last activation is `identity`.
+#-
 
 fno, θ_fno = create_fno(;
     channels = [5, 5, 5, 5],
@@ -1216,50 +1179,62 @@ fno, θ_fno = create_fno(;
 )
 fno.chain
 
+#-
+
+m, θ, label = m_cnn, θ_cnn, "CNN"
+## m, θ, label = fno, θ_fno, "FNO"
+
 # Choose loss function
 
-loss = create_randloss_commutator(fno, data_train; nuse = 50)
+loss = create_randloss_commutator(m, data_train; nuse = 50)
 ## loss = create_randloss_trajectory(
 ##     les,
 ##     data_train;
 ##     nuse = 3,
 ##     n_unroll = 10,
 ##     data_train.μ,
-##     m = fno,
+##     m,
 ## )
 
-trainstate = initial_trainstate(Adam(1.0e-3), θ_fno)
+# Initialize the training state. Note that we have to provide an optimizer, here
+# `Adam(η)` where `η` is the learning rate [^4]. This optimizer is well suited
+# to the random (stochastic) nature of our loss function.
 
-## Model warm-up: trigger compilation and get indication of complexity
-loss(θ_fno);
-gradient(loss, θ_fno);
-@time loss(θ_fno);
-@time gradient(loss, θ_fno);
+trainstate = initial_trainstate(Adam(1.0e-3), θ)
 
-# Train the FNO. The cell below can be repeated to continue training where the
+# Model warm-up: trigger compilation and get an indication of complexity
+
+loss(θ)
+gradient(loss, θ);
+@time loss(θ);
+@time gradient(loss, θ);
+
+# Train the model. The cell below can be repeated to continue training where the
 # previous training session left off.
+# If you run this in a notebook, setting `doplot = true` will create a lot of
+# plots below the cell.
 
 trainstate = train(;
     trainstate...,
     loss,
     niter = 1000,
    ncallback = 20,
-    callback = create_callback(fno, data_valid),
+    callback = create_callback(m, data_valid; doplot = false),
 )
+plot_convergence(trainstate.callbackstate, data_valid)
 
-# Final FNO weights
-θ_fno = trainstate.θ
+# Final model weights
+
+(; θ) = trainstate
 
 # ### Model performance
 #
-# We will now make a comparison of the three closure models (including the
-# "no-model" where $m = 0$, which corresponds to solving the DNS equations on the
-# LES grid).
+# We will now compare our trained closure model with the baseline "no closure"
+# model on the reference testing data.
 
 models = [
-    (noclosure, θ_noclosure, "m=0")
-    (cnn, θ_cnn, "CNN")
-    (fno, θ_fno, "FNO")
+    (m_0, θ_0, label_0)
+    (m, θ, label)
 ]
 
 println("Relative a posteriori errors:")
@@ -1325,7 +1300,7 @@ end
 # $$
 # \frac{\mathrm{d} \bar{v}}{\mathrm{d} t} = m(\bar{v}, \theta).
 # $$
-# This is known as a _Neural ODE_ (see Chen [^4]).
+# This is known as a _Neural ODE_ (see Chen [^5]).
 # 1. Define a model that predicts the _entire_ right hand side.
 #    This can be done by using the following little "hack":
 #
@@ -1412,7 +1387,12 @@ end
 #       arXiv:[2010.08895](https://arxiv.org/abs/2010.08895),
 #       2021.
 #
-# [^4]: R. T. Q. Chen, Y. Rubanova, J. Bettencourt, and D. Duvenaud.
+# [^4]: D. P. Kingma and J. Ba.
+#       _Adam: A method for stochastic optimization_.
+#       arXiv:[1412.6980](https://arxiv.org/abs/1412.6980),
+#       2014.
+#
+# [^5]: R. T. Q. Chen, Y. Rubanova, J. Bettencourt, and D. Duvenaud.
 #       _Neural Ordinary Differential Equations_.
 #       arXiv:[1806.07366](https://arxiv.org/abs/1806.07366),
 #       2018.
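
With the merged interface introduced by this patch, training the FNO closure instead of the CNN only requires changing the model selection line and re-running the same cells. The lines below are a minimal usage sketch (not taken from the patch), assuming the tutorial's `fno`, `θ_fno`, `data_train`, `data_valid`, and the helpers shown above (`create_randloss_commutator`, `initial_trainstate`, `train`, `create_callback`, `plot_convergence`) are in scope:

    ## Select the FNO closure instead of the CNN
    m, θ, label = fno, θ_fno, "FNO"

    ## Same loss, optimizer, and training loop as for the CNN
    loss = create_randloss_commutator(m, data_train; nuse = 50)
    trainstate = initial_trainstate(Adam(1.0e-3), θ)
    trainstate = train(;
        trainstate...,
        loss,
        niter = 1000,
        ncallback = 20,
        callback = create_callback(m, data_valid; doplot = false),
    )
    plot_convergence(trainstate.callbackstate, data_valid)
    (; θ) = trainstate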