diff --git a/src/OutputReaders/field_dataset.jl b/src/OutputReaders/field_dataset.jl index dc6072eb6d..cf5c1507f3 100644 --- a/src/OutputReaders/field_dataset.jl +++ b/src/OutputReaders/field_dataset.jl @@ -1,7 +1,8 @@ -struct FieldDataset{F, M, P} - fields :: F - metadata :: M - filepath :: P +struct FieldDataset{F, M, P, KW} + fields :: F + metadata :: M + filepath :: P + backend_kw :: KW end """ @@ -22,17 +23,24 @@ linearly. `file["metadata"]`. - `grid`: May be specified to override the grid used in the JLD2 file. + +- `backend_kw`: A dictionary of keyword arguments to pass to the backend (currently only JLD2) + to be used when opening files. """ function FieldDataset(filepath; - architecture=CPU(), grid=nothing, backend=InMemory(), metadata_paths=["metadata"]) + architecture = CPU(), + grid = nothing, + backend = InMemory(), + metadata_paths = ["metadata"], + backend_kw = Dict{Symbol, Any}()) - file = jldopen(filepath) + file = jldopen(filepath; backend_kw...) field_names = keys(file["timeseries"]) filter!(k -> k != "t", field_names) # Time is not a field. ds = Dict{String, FieldTimeSeries}( - name => FieldTimeSeries(filepath, name; architecture, backend, grid) + name => FieldTimeSeries(filepath, name; architecture, backend, grid, backend_kw) for name in field_names ) @@ -44,7 +52,7 @@ function FieldDataset(filepath; close(file) - return FieldDataset(ds, metadata, abspath(filepath)) + return FieldDataset(ds, metadata, abspath(filepath), backend_kw) end Base.getindex(fds::FieldDataset, inds...) = Base.getindex(fds.fields, inds...) diff --git a/src/OutputReaders/field_time_series.jl b/src/OutputReaders/field_time_series.jl index b7b8504e36..7bb6fee317 100644 --- a/src/OutputReaders/field_time_series.jl +++ b/src/OutputReaders/field_time_series.jl @@ -85,7 +85,7 @@ period = t[end] - t[1] + Δt """ struct Cyclical{FT} period :: FT -end +end Cyclical() = Cyclical(nothing) @@ -164,7 +164,7 @@ Nt = 5 backend = InMemory(4, 3) # so we have (4, 5, 1) n = 1 # so, the right answer is m̃ = 3 m = 1 - (4 - 1) # = -2 -m̃ = mod1(-2, 5) # = 3 ✓ +m̃ = mod1(-2, 5) # = 3 ✓ ``` # Another shifting + wrapping example @@ -213,7 +213,7 @@ Base.length(backend::PartlyInMemory) = backend.length ##### FieldTimeSeries ##### -mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: AbstractField{LX, LY, LZ, G, ET, 4} +mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW} <: AbstractField{LX, LY, LZ, G, ET, 4} data :: D grid :: G backend :: K @@ -223,16 +223,18 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A path :: P name :: N time_indexing :: TI - + backend_kw :: KW + function FieldTimeSeries{LX, LY, LZ}(data::D, grid::G, backend::K, bcs::B, - indices::I, + indices::I, times, path, name, - time_indexing) where {LX, LY, LZ, K, D, G, B, I} + time_indexing, + backend_kw) where {LX, LY, LZ, K, D, G, B, I} ET = eltype(data) @@ -250,7 +252,7 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A times = on_architecture(architecture(grid), times) end - + if time_indexing isa Cyclical{Nothing} # we have to infer the period Δt = @allowscalar times[end] - times[end-1] period = @allowscalar times[end] - times[1] + Δt @@ -261,23 +263,25 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A TI = typeof(time_indexing) P = typeof(path) N = typeof(name) + KW = typeof(backend_kw) - return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N}(data, grid, backend, bcs, - indices, times, path, name, - time_indexing) + return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW}(data, grid, backend, bcs, + indices, times, path, name, + time_indexing, backend_kw) end end -on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} = +on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} = FieldTimeSeries{LX, LY, LZ}(on_architecture(to, fts.data), on_architecture(to, fts.grid), on_architecture(to, fts.backend), on_architecture(to, fts.bcs), - on_architecture(to, fts.indices), + on_architecture(to, fts.indices), on_architecture(to, fts.times), on_architecture(to, fts.path), on_architecture(to, fts.name), - on_architecture(to, fts.time_indexing)) + on_architecture(to, fts.time_indexing), + on_architecture(to, fts.backend_kw)) ##### ##### Minimal implementation of FieldTimeSeries for use in GPU kernels @@ -290,7 +294,7 @@ struct GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K, ET, D, χ} <: AbstractField{ times :: χ backend :: K time_indexing :: TI - + function GPUAdaptedFieldTimeSeries{LX, LY, LZ}(data::D, times::χ, backend::K, @@ -313,7 +317,7 @@ const FTS{LX, LY, LZ, TI, K} = FieldTimeSeries{LX, LY, LZ, TI, K} w const GPUFTS{LX, LY, LZ, TI, K} = GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K} where {LX, LY, LZ, TI, K} const FlavorOfFTS{LX, LY, LZ, TI, K} = Union{GPUFTS{LX, LY, LZ, TI, K}, - FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K} + FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K} const InMemoryFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:AbstractInMemoryBackend} const OnDiskFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:OnDisk} @@ -345,7 +349,7 @@ instantiate(T::Type) = T() new_data(FT, grid, loc, indices, ::Nothing) = nothing # Apparently, not explicitly specifying Int64 in here makes this function -# fail on x86 processors where `Int` is implied to be `Int32` +# fail on x86 processors where `Int` is implied to be `Int32` # see ClimaOcean commit 3c47d887659d81e0caed6c9df41b7438e1f1cd52 at https://github.com/CliMA/ClimaOcean.jl/actions/runs/8804916198/job/24166354095) function new_data(FT, grid, loc, indices, Nt::Union{Int, Int64}) space_size = total_size(grid, loc, indices) @@ -360,12 +364,13 @@ time_indices_length(backend::PartlyInMemory, times) = length(backend) time_indices_length(::OnDisk, times) = nothing function FieldTimeSeries(loc, grid, times=(); - indices = (:, :, :), + indices = (:, :, :), backend = InMemory(), - path = nothing, + path = nothing, name = nothing, time_indexing = Linear(), - boundary_conditions = nothing) + boundary_conditions = nothing, + backend_kw = Dict{Symbol, Any}()) LX, LY, LZ = loc @@ -379,9 +384,9 @@ function FieldTimeSeries(loc, grid, times=(); isnothing(path) && error(ArgumentError("Must provide the keyword argument `path` when `backend=OnDisk()`.")) isnothing(name) && error(ArgumentError("Must provide the keyword argument `name` when `backend=OnDisk()`.")) end - - return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, - indices, times, path, name, time_indexing) + + return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices, + times, path, name, time_indexing, backend_kw) end """ @@ -408,10 +413,16 @@ end struct UnspecifiedBoundaryConditions end """ - FieldTimeSeries(path, name, backend = InMemory(); + FieldTimeSeries(path, name; + backend = InMemory(), + architecture = nothing, grid = nothing, + location = nothing, + boundary_conditions = UnspecifiedBoundaryConditions(), + time_indexing = Linear(), iterations = nothing, - times = nothing) + times = nothing, + backend_kw = Dict{Symbol, Any}()) Return a `FieldTimeSeries` containing a time-series of the field `name` load from JLD2 output located at `path`. @@ -430,6 +441,9 @@ Keyword arguments - `times`: Save times to load, as determined through an approximate floating point comparison to recorded save times. Defaults to times associated with `iterations`. Takes precedence over `iterations` if `times` is specified. + +- `backend_kw`: A dictionary of keyword arguments to pass to the backend (currently only JLD2) + to be used when opening files. """ function FieldTimeSeries(path::String, name::String; backend = InMemory(), @@ -439,9 +453,10 @@ function FieldTimeSeries(path::String, name::String; boundary_conditions = UnspecifiedBoundaryConditions(), time_indexing = Linear(), iterations = nothing, - times = nothing) + times = nothing, + backend_kw = Dict{Symbol, Any}()) - file = jldopen(path) + file = jldopen(path; backend_kw...) # Defaults isnothing(iterations) && (iterations = parse.(Int, keys(file["timeseries/t"]))) @@ -523,8 +538,8 @@ function FieldTimeSeries(path::String, name::String; Nt = time_indices_length(backend, times) data = new_data(eltype(grid), grid, loc, indices, Nt) - time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, - indices, times, path, name, time_indexing) + time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices, + times, path, name, time_indexing, backend_kw) set!(time_series, path, name) @@ -536,7 +551,8 @@ end grid = nothing, architecture = nothing, indices = (:, :, :), - boundary_conditions = nothing) + boundary_conditions = nothing, + backend_kw = Dict{Symbol, Any}()) Load a field called `name` saved in a JLD2 file at `path` at `iter`ation. Unless specified, the `grid` is loaded from `path`. @@ -545,7 +561,8 @@ function Field(location, path::String, name::String, iter; grid = nothing, architecture = nothing, indices = (:, :, :), - boundary_conditions = nothing) + boundary_conditions = nothing, + backend_kw = Dict{Symbol, Any}()) # Default to CPU if neither architecture nor grid is specified if isnothing(architecture) @@ -555,9 +572,9 @@ function Field(location, path::String, name::String, iter; architecture = Architectures.architecture(grid) end end - + # Load the grid and data from file - file = jldopen(path) + file = jldopen(path; backend_kw...) isnothing(grid) && (grid = file["serialized/grid"]) raw_data = file["timeseries/$name/$iter"] @@ -568,7 +585,7 @@ function Field(location, path::String, name::String, iter; grid = on_architecture(architecture, grid) raw_data = on_architecture(architecture, raw_data) data = offset_data(raw_data, grid, location, indices) - + return Field(location, grid; boundary_conditions, indices, data) end @@ -630,4 +647,3 @@ function fill_halo_regions!(fts::InMemoryFTS) return nothing end - diff --git a/src/OutputReaders/field_time_series_indexing.jl b/src/OutputReaders/field_time_series_indexing.jl index 6a4683f9d0..ac08258444 100644 --- a/src/OutputReaders/field_time_series_indexing.jl +++ b/src/OutputReaders/field_time_series_indexing.jl @@ -79,7 +79,7 @@ import Base: getindex function getindex(fts::OnDiskFTS, n::Int) # Load data arch = architecture(fts) - file = jldopen(fts.path) + file = jldopen(fts.path; fts.backend_kw...) iter = keys(file["timeseries/t"])[n] raw_data = on_architecture(arch, file["timeseries/$(fts.name)/$iter"]) close(file) diff --git a/src/OutputReaders/set_field_time_series.jl b/src/OutputReaders/set_field_time_series.jl index d450926b3e..01a97a50b7 100644 --- a/src/OutputReaders/set_field_time_series.jl +++ b/src/OutputReaders/set_field_time_series.jl @@ -11,7 +11,7 @@ find_time_index(time::Number, file_times) = findfirst(t -> t ≈ time, fil find_time_index(time::AbstractTime, file_times) = findfirst(t -> t == time, file_times) function set!(fts::InMemoryFTS, path::String=fts.path, name::String=fts.name) - file = jldopen(path) + file = jldopen(path; fts.backend_kw...) file_iterations = iterations_from_file(file) file_times = [file["timeseries/t/$i"] for i in file_iterations] close(file) @@ -51,7 +51,7 @@ set!(fts::InMemoryFTS, value, n::Int) = set!(fts[n], value) function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField}) raw_data = parent(fts) - file = jldopen(path) + file = jldopen(path; fts.backend_kw...) for (n, field) in enumerate(fields_vector) nth_raw_data = view(raw_data, :, :, :, n) diff --git a/test/test_forcings.jl b/test/test_forcings.jl index e4e4e48294..c4056ea7e0 100644 --- a/test/test_forcings.jl +++ b/test/test_forcings.jl @@ -118,7 +118,7 @@ end function time_step_with_field_time_series_forcing(arch) grid = RectilinearGrid(arch, size=(1, 1, 1), extent=(1, 1, 1)) - + u_forcing = FieldTimeSeries{Face, Center, Center}(grid, 0:1:3) for (t, time) in enumerate(u_forcing.times) @@ -134,14 +134,14 @@ function time_step_with_field_time_series_forcing(arch) model = NonhydrostaticModel(; grid, forcing=(; u=u_forcing)) time_step!(model, 2) time_step!(model, 2) - + @test u_forcing.backend.start == 4 return true end function relaxed_time_stepping(arch) - x_relax = Relaxation(rate = 1/60, mask = GaussianMask{:x}(center=0.5, width=0.1), + x_relax = Relaxation(rate = 1/60, mask = GaussianMask{:x}(center=0.5, width=0.1), target = LinearTarget{:x}(intercept=π, gradient=ℯ)) y_relax = Relaxation(rate = 1/60, mask = GaussianMask{:y}(center=0.5, width=0.1), @@ -197,7 +197,7 @@ end function two_forcings(arch) grid = RectilinearGrid(arch, size=(4, 5, 6), extent=(1, 1, 1), halo=(4, 4, 4)) - + forcing1 = Relaxation(rate=1) forcing2 = Relaxation(rate=2) @@ -221,7 +221,7 @@ function seven_forcings(arch) peculiar_forcing(x, y, z, t) = 2t / z eccentric_forcing(x, y, z, t) = x + y + z + t unconventional_forcing(x, y, z, t) = 10x * y - + F1 = Forcing(weird_forcing) F2 = Forcing(wonky_forcing) F3 = Forcing(strange_forcing) @@ -269,7 +269,7 @@ end @test time_step_with_multiple_field_dependent_forcing(arch) @test time_step_with_parameterized_field_dependent_forcing(arch) - end + end @testset "Relaxation forcing functions [$A]" begin @info " Testing relaxation forcing functions [$A]..."