Skip to content

Commit

Permalink
Add SciMLStructures.jl methods
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio committed Nov 15, 2024
1 parent 42f827e commit bf3450d
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/time_evolution/time_evo_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,49 @@ Base.length(obj::TimeEvolutionParameters) = length(obj.params)
# return TimeEvolutionParameters(merge(a.params, b), a.expvals, a.progr, a.mcsolve_params)
# end

########## Mark the struct as a SciMLStructure ##########
# The NamedTuple `params` case still doesn't work, and it should be put as a `Vector` instead

isscimlstructure(::TimeEvolutionParameters) = true
# ismutablescimlstructure(::TimeEvolutionParameters{ParT}) where {ParT<:NamedTuple} = false
ismutablescimlstructure(::TimeEvolutionParameters{ParT}) where {ParT<:AbstractVector} = true

hasportion(::Tunable, ::TimeEvolutionParameters) = true

function _vectorize_params(p::TimeEvolutionParameters{ParT}) where {ParT<:NamedTuple}
buffer = isempty(p.params) ? eltype(p.expvals)[] : collect(values(p.params))
return (buffer, false)
end
_vectorize_params(p::TimeEvolutionParameters{ParT}) where {ParT<:AbstractVector} = (p.params, true)

function canonicalize(::Tunable, p::TimeEvolutionParameters)
buffer, aliases = _vectorize_params(p) # We are assuming that the values have the same type

# repack takes a new vector of the same length as `buffer`, and constructs
# a new `TimeEvolutionParameters` object using the values from the new vector for tunables
# and retaining old values for other parameters. This is exactly what replace does,
# so we can use that instead.
repack = let p = p
repack(newbuffer) = replace(Tunable(), p, newbuffer)
end
# the canonicalized vector, the repack function, and a boolean indicating
# whether the buffer aliases values in the parameter object
return buffer, repack, aliases
end

function replace(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:NamedTuple}
@assert length(newbuffer) == length(p.params)
new_params = NamedTuple{keys(p.params)}(Tuple(newbuffer))
return TimeEvolutionParameters(new_params, p.expvals, p.progr, p.mcsolve_params)
end

function replace(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:AbstractVector}
@assert length(newbuffer) == length(p.params)
return TimeEvolutionParameters(newbuffer, p.expvals, p.progr, p.mcsolve_params)
end

function replace!(::Tunable, p::TimeEvolutionParameters{ParT}, newbuffer) where {ParT<:AbstractVector}
@assert length(newbuffer) == length(p.params)
copyto!(p.params, newbuffer)
return p
end

0 comments on commit bf3450d

Please sign in to comment.