Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow FieldTimeSeries to pass keyword arguments to jldopen #3739

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
24 changes: 16 additions & 8 deletions src/OutputReaders/field_dataset.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
struct FieldDataset{F, M, P}
fields :: F
metadata :: M
filepath :: P
struct FieldDataset{F, M, P, KW}
fields :: F
metadata :: M
filepath :: P
backend_kw :: KW
end

"""
Expand All @@ -22,17 +23,24 @@ linearly.
`file["metadata"]`.

- `grid`: May be specified to override the grid used in the JLD2 file.

- `backend_kw`: A dictionary of keyword arguments to pass to the backend (currently only JLD2)
to be used when opening files.
"""
function FieldDataset(filepath;
architecture=CPU(), grid=nothing, backend=InMemory(), metadata_paths=["metadata"])
architecture = CPU(),
grid = nothing,
backend = InMemory(),
metadata_paths = ["metadata"],
backend_kw = Dict{Symbol, Any}())

file = jldopen(filepath)
file = jldopen(filepath; backend_kw...)

field_names = keys(file["timeseries"])
filter!(k -> k != "t", field_names) # Time is not a field.

ds = Dict{String, FieldTimeSeries}(
name => FieldTimeSeries(filepath, name; architecture, backend, grid)
name => FieldTimeSeries(filepath, name; architecture, backend, grid, backend_kw)
for name in field_names
)

Expand All @@ -44,7 +52,7 @@ function FieldDataset(filepath;

close(file)

return FieldDataset(ds, metadata, abspath(filepath))
return FieldDataset(ds, metadata, abspath(filepath), backend_kw)
end

Base.getindex(fds::FieldDataset, inds...) = Base.getindex(fds.fields, inds...)
Expand Down
84 changes: 50 additions & 34 deletions src/OutputReaders/field_time_series.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ period = t[end] - t[1] + Δt
"""
struct Cyclical{FT}
period :: FT
end
end

Cyclical() = Cyclical(nothing)

Expand Down Expand Up @@ -164,7 +164,7 @@ Nt = 5
backend = InMemory(4, 3) # so we have (4, 5, 1)
n = 1 # so, the right answer is m̃ = 3
m = 1 - (4 - 1) # = -2
m̃ = mod1(-2, 5) # = 3 ✓
m̃ = mod1(-2, 5) # = 3 ✓
```

# Another shifting + wrapping example
Expand Down Expand Up @@ -213,7 +213,7 @@ Base.length(backend::PartlyInMemory) = backend.length
##### FieldTimeSeries
#####

mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: AbstractField{LX, LY, LZ, G, ET, 4}
mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW} <: AbstractField{LX, LY, LZ, G, ET, 4}
data :: D
grid :: G
backend :: K
Expand All @@ -223,16 +223,18 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A
path :: P
name :: N
time_indexing :: TI

backend_kw :: KW

function FieldTimeSeries{LX, LY, LZ}(data::D,
grid::G,
backend::K,
bcs::B,
indices::I,
indices::I,
times,
path,
name,
time_indexing) where {LX, LY, LZ, K, D, G, B, I}
time_indexing,
backend_kw) where {LX, LY, LZ, K, D, G, B, I}

ET = eltype(data)

Expand All @@ -250,7 +252,7 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A

times = on_architecture(architecture(grid), times)
end

if time_indexing isa Cyclical{Nothing} # we have to infer the period
Δt = @allowscalar times[end] - times[end-1]
period = @allowscalar times[end] - times[1] + Δt
Expand All @@ -261,23 +263,25 @@ mutable struct FieldTimeSeries{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N} <: A
TI = typeof(time_indexing)
P = typeof(path)
N = typeof(name)
KW = typeof(backend_kw)

return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N}(data, grid, backend, bcs,
indices, times, path, name,
time_indexing)
return new{LX, LY, LZ, TI, K, I, D, G, ET, B, χ, P, N, KW}(data, grid, backend, bcs,
indices, times, path, name,
time_indexing, backend_kw)
end
end

on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} =
on_architecture(to, fts::FieldTimeSeries{LX, LY, LZ}) where {LX, LY, LZ} =
FieldTimeSeries{LX, LY, LZ}(on_architecture(to, fts.data),
on_architecture(to, fts.grid),
on_architecture(to, fts.backend),
on_architecture(to, fts.bcs),
on_architecture(to, fts.indices),
on_architecture(to, fts.indices),
on_architecture(to, fts.times),
on_architecture(to, fts.path),
on_architecture(to, fts.name),
on_architecture(to, fts.time_indexing))
on_architecture(to, fts.time_indexing),
on_architecture(to, fts.backend_kw))

#####
##### Minimal implementation of FieldTimeSeries for use in GPU kernels
Expand All @@ -290,7 +294,7 @@ struct GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K, ET, D, χ} <: AbstractField{
times :: χ
backend :: K
time_indexing :: TI

function GPUAdaptedFieldTimeSeries{LX, LY, LZ}(data::D,
times::χ,
backend::K,
Expand All @@ -313,7 +317,7 @@ const FTS{LX, LY, LZ, TI, K} = FieldTimeSeries{LX, LY, LZ, TI, K} w
const GPUFTS{LX, LY, LZ, TI, K} = GPUAdaptedFieldTimeSeries{LX, LY, LZ, TI, K} where {LX, LY, LZ, TI, K}

const FlavorOfFTS{LX, LY, LZ, TI, K} = Union{GPUFTS{LX, LY, LZ, TI, K},
FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K}
FTS{LX, LY, LZ, TI, K}} where {LX, LY, LZ, TI, K}

const InMemoryFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:AbstractInMemoryBackend}
const OnDiskFTS = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:OnDisk}
Expand Down Expand Up @@ -345,7 +349,7 @@ instantiate(T::Type) = T()
new_data(FT, grid, loc, indices, ::Nothing) = nothing

# Apparently, not explicitly specifying Int64 in here makes this function
# fail on x86 processors where `Int` is implied to be `Int32`
# fail on x86 processors where `Int` is implied to be `Int32`
# see ClimaOcean commit 3c47d887659d81e0caed6c9df41b7438e1f1cd52 at https://github.com/CliMA/ClimaOcean.jl/actions/runs/8804916198/job/24166354095)
function new_data(FT, grid, loc, indices, Nt::Union{Int, Int64})
space_size = total_size(grid, loc, indices)
Expand All @@ -360,12 +364,13 @@ time_indices_length(backend::PartlyInMemory, times) = length(backend)
time_indices_length(::OnDisk, times) = nothing

function FieldTimeSeries(loc, grid, times=();
indices = (:, :, :),
indices = (:, :, :),
backend = InMemory(),
path = nothing,
path = nothing,
name = nothing,
time_indexing = Linear(),
boundary_conditions = nothing)
boundary_conditions = nothing,
backend_kw = Dict{Symbol, Any}())

LX, LY, LZ = loc

Expand All @@ -379,9 +384,9 @@ function FieldTimeSeries(loc, grid, times=();
isnothing(path) && error(ArgumentError("Must provide the keyword argument `path` when `backend=OnDisk()`."))
isnothing(name) && error(ArgumentError("Must provide the keyword argument `name` when `backend=OnDisk()`."))
end
return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions,
indices, times, path, name, time_indexing)

return FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices,
times, path, name, time_indexing, backend_kw)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think theere needs to be changes to adapt_structure which I do not see in this PR. Possibly this is not tested right now. @simone-silvestri

Copy link
Member Author

@ali-ramadhan ali-ramadhan Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review!

Ah true. I guess originally we didn't envision passing FieldTimeSeries into GPU kernels and I don't see a test that does this.

But with features like FieldTimeSeries forcing (PR #3760) an adapt_structure will be necessary. Actually not sure how tests will pass in that PR without adapting.

I can add some tests here though. There are no tests that use backend_kw...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I think there are no tests about this. The adapt for a FieldTimeSeries returns a different type: GPUAdaptedFieldTimeSeries that is a little slimmer in terms of parameter space.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok but this PR breaks FieldTimeSeries on GPU which will break everything in ClimaOcean. @simone-silvestri can you add some tests here?

end

"""
Expand All @@ -408,10 +413,16 @@ end
struct UnspecifiedBoundaryConditions end

"""
FieldTimeSeries(path, name, backend = InMemory();
FieldTimeSeries(path, name;
backend = InMemory(),
architecture = nothing,
grid = nothing,
location = nothing,
boundary_conditions = UnspecifiedBoundaryConditions(),
time_indexing = Linear(),
iterations = nothing,
times = nothing)
times = nothing,
backend_kw = Dict{Symbol, Any}())

Return a `FieldTimeSeries` containing a time-series of the field `name`
load from JLD2 output located at `path`.
Expand All @@ -430,6 +441,9 @@ Keyword arguments
- `times`: Save times to load, as determined through an approximate floating point
comparison to recorded save times. Defaults to times associated with `iterations`.
Takes precedence over `iterations` if `times` is specified.

- `backend_kw`: A dictionary of keyword arguments to pass to the backend (currently only JLD2)
to be used when opening files.
"""
function FieldTimeSeries(path::String, name::String;
backend = InMemory(),
Expand All @@ -439,9 +453,10 @@ function FieldTimeSeries(path::String, name::String;
boundary_conditions = UnspecifiedBoundaryConditions(),
time_indexing = Linear(),
iterations = nothing,
times = nothing)
times = nothing,
backend_kw = Dict{Symbol, Any}())

file = jldopen(path)
file = jldopen(path; backend_kw...)

# Defaults
isnothing(iterations) && (iterations = parse.(Int, keys(file["timeseries/t"])))
Expand Down Expand Up @@ -523,8 +538,8 @@ function FieldTimeSeries(path::String, name::String;
Nt = time_indices_length(backend, times)
data = new_data(eltype(grid), grid, loc, indices, Nt)

time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions,
indices, times, path, name, time_indexing)
time_series = FieldTimeSeries{LX, LY, LZ}(data, grid, backend, boundary_conditions, indices,
times, path, name, time_indexing, backend_kw)

set!(time_series, path, name)

Expand All @@ -536,7 +551,8 @@ end
grid = nothing,
architecture = nothing,
indices = (:, :, :),
boundary_conditions = nothing)
boundary_conditions = nothing,
backend_kw = Dict{Symbol, Any}())

Load a field called `name` saved in a JLD2 file at `path` at `iter`ation.
Unless specified, the `grid` is loaded from `path`.
Expand All @@ -545,7 +561,8 @@ function Field(location, path::String, name::String, iter;
grid = nothing,
architecture = nothing,
indices = (:, :, :),
boundary_conditions = nothing)
boundary_conditions = nothing,
backend_kw = Dict{Symbol, Any}())

# Default to CPU if neither architecture nor grid is specified
if isnothing(architecture)
Expand All @@ -555,9 +572,9 @@ function Field(location, path::String, name::String, iter;
architecture = Architectures.architecture(grid)
end
end

# Load the grid and data from file
file = jldopen(path)
file = jldopen(path; backend_kw...)

isnothing(grid) && (grid = file["serialized/grid"])
raw_data = file["timeseries/$name/$iter"]
Expand All @@ -568,7 +585,7 @@ function Field(location, path::String, name::String, iter;
grid = on_architecture(architecture, grid)
raw_data = on_architecture(architecture, raw_data)
data = offset_data(raw_data, grid, location, indices)

return Field(location, grid; boundary_conditions, indices, data)
end

Expand Down Expand Up @@ -630,4 +647,3 @@ function fill_halo_regions!(fts::InMemoryFTS)

return nothing
end

2 changes: 1 addition & 1 deletion src/OutputReaders/field_time_series_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ import Base: getindex
function getindex(fts::OnDiskFTS, n::Int)
# Load data
arch = architecture(fts)
file = jldopen(fts.path)
file = jldopen(fts.path; fts.backend_kw...)
iter = keys(file["timeseries/t"])[n]
raw_data = on_architecture(arch, file["timeseries/$(fts.name)/$iter"])
close(file)
Expand Down
4 changes: 2 additions & 2 deletions src/OutputReaders/set_field_time_series.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ find_time_index(time::Number, file_times) = findfirst(t -> t ≈ time, fil
find_time_index(time::AbstractTime, file_times) = findfirst(t -> t == time, file_times)

function set!(fts::InMemoryFTS, path::String=fts.path, name::String=fts.name)
file = jldopen(path)
file = jldopen(path; fts.backend_kw...)
file_iterations = iterations_from_file(file)
file_times = [file["timeseries/t/$i"] for i in file_iterations]
close(file)
Expand Down Expand Up @@ -51,7 +51,7 @@ set!(fts::InMemoryFTS, value, n::Int) = set!(fts[n], value)

function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField})
raw_data = parent(fts)
file = jldopen(path)
file = jldopen(path; fts.backend_kw...)

for (n, field) in enumerate(fields_vector)
nth_raw_data = view(raw_data, :, :, :, n)
Expand Down
12 changes: 6 additions & 6 deletions test/test_forcings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end
function time_step_with_field_time_series_forcing(arch)

grid = RectilinearGrid(arch, size=(1, 1, 1), extent=(1, 1, 1))

u_forcing = FieldTimeSeries{Face, Center, Center}(grid, 0:1:3)

for (t, time) in enumerate(u_forcing.times)
Expand All @@ -134,14 +134,14 @@ function time_step_with_field_time_series_forcing(arch)
model = NonhydrostaticModel(; grid, forcing=(; u=u_forcing))
time_step!(model, 2)
time_step!(model, 2)

@test u_forcing.backend.start == 4

return true
end

function relaxed_time_stepping(arch)
x_relax = Relaxation(rate = 1/60, mask = GaussianMask{:x}(center=0.5, width=0.1),
x_relax = Relaxation(rate = 1/60, mask = GaussianMask{:x}(center=0.5, width=0.1),
target = LinearTarget{:x}(intercept=π, gradient=ℯ))

y_relax = Relaxation(rate = 1/60, mask = GaussianMask{:y}(center=0.5, width=0.1),
Expand Down Expand Up @@ -197,7 +197,7 @@ end

function two_forcings(arch)
grid = RectilinearGrid(arch, size=(4, 5, 6), extent=(1, 1, 1), halo=(4, 4, 4))

forcing1 = Relaxation(rate=1)
forcing2 = Relaxation(rate=2)

Expand All @@ -221,7 +221,7 @@ function seven_forcings(arch)
peculiar_forcing(x, y, z, t) = 2t / z
eccentric_forcing(x, y, z, t) = x + y + z + t
unconventional_forcing(x, y, z, t) = 10x * y

F1 = Forcing(weird_forcing)
F2 = Forcing(wonky_forcing)
F3 = Forcing(strange_forcing)
Expand Down Expand Up @@ -269,7 +269,7 @@ end

@test time_step_with_multiple_field_dependent_forcing(arch)
@test time_step_with_parameterized_field_dependent_forcing(arch)
end
end

@testset "Relaxation forcing functions [$A]" begin
@info " Testing relaxation forcing functions [$A]..."
Expand Down