
Add support for using non-AbstractArrays with OrdinaryDiffEq and Zygote #1092

Merged (2 commits) on Sep 2, 2024

apkille (Contributor) commented Aug 16, 2024

Continuing my quest to open up SciML to custom non-AbstractArray types :). See SciML/StochasticDiffEq.jl#579 and SciML/OrdinaryDiffEq.jl#2368 for related work.

ChrisRackauckas (Member) commented:

These changes don't look too unreasonable. However, don't we already support non-AbstractArrays? For parameters we support SciMLStructures types and Functors. If you define the custom type as a SciMLStructure with the right canonicalization dispatch, does it just work?
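For readers unfamiliar with the interface: in SciMLStructures, canonicalizing the Tunable portion of a parameter object returns three things: a flat vector of values, a repack function that rebuilds the object from such a vector, and a Bool saying whether that vector aliases the original data. The Base-only sketch below illustrates that contract. The wrapper type `MyWrapper` and the helper `tunable_triple` are hypothetical names used for illustration only; they are not part of SciMLStructures and this is not the library's implementation.

```julia
# Hypothetical stand-in for a custom non-AbstractArray type: a plain struct
# that wraps a Vector in a field named `x` (assumption for illustration).
struct MyWrapper{T}
    x::Vector{T}
end

# Hypothetical helper returning a (values, repack, aliases) triple in the same
# shape as a Tunable canonicalization. A sketch of the contract only.
function tunable_triple(p::MyWrapper)
    flat = vec(p.x)                      # for a Vector, vec returns p.x itself
    repack = v -> MyWrapper(collect(v))  # rebuild a MyWrapper from any flat vector
    aliases = true                       # `flat` shares memory with p.x
    return flat, repack, aliases
end

p = MyWrapper([1.0, 2.0])
vals, repack, aliases = tunable_triple(p)

# An optimizer or AD pass works on the flat vector, then repacks the result.
q = repack(2 .* vals)
```

The point of the triple is that generic code (optimizers, adjoint passes) only ever sees a flat vector, while the repack function carries the knowledge of how to turn that vector back into the user's type.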

apkille (Contributor, Author) commented Aug 17, 2024

@ChrisRackauckas Defining the SciMLStructure interface for my custom type as follows:

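# Extend the SciMLStructures interface for CustomArray
# (CustomArray is the contributor's own type with a field `x`; its definition is not shown).
import SciMLStructures: isscimlstructure, hasportion, canonicalize,
                        Tunable, Constants, Caches, Discrete, ArrayRepack

isscimlstructure(::CustomArray) = true
hasportion(::Tunable, ::CustomArray) = true
hasportion(::Constants, ::CustomArray) = false
hasportion(::Caches, ::CustomArray) = false
hasportion(::Discrete, ::CustomArray) = false

# Tunable portion: the wrapped vector, a repack function, and aliases = true
canonicalize(::Tunable, x::CustomArray) = vec(x.x), ArrayRepack(x.x), true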

and running the test on master gives the following error:

ERROR: MethodError: no method matching vec(::ChainRulesCore.Tangent{Any, @NamedTuple{x::Vector{Float64}}})

Closest candidates are:
  vec(::StaticArraysCore.SizedArray{S, T, N, M} where {T, N, M}) where S
   @ StaticArrays ~/.julia/packages/StaticArrays/MSJcA/src/SizedArray.jl:171
  vec(::LazyArrays.ApplyArray{T, 2, typeof(hcat)} where T)
   @ LazyArrays ~/.julia/packages/LazyArrays/jaUBE/src/lazyconcat.jl:408
  vec(::Distributions.Distribution{<:Distributions.ArrayLikeVariate})
   @ Distributions ~/.julia/packages/Distributions/ji8PW/src/reshaped.jl:143
  ...

Stacktrace:
  [1] (::SciMLSensitivity.var"#332#341"{Vector{Any}, CustomArray{Float64, 1}})(i::Int64)
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/concrete_solve.jl:1089
  [2] _mapreduce(f::SciMLSensitivity.var"#332#341"{…}, op::typeof(Base.add_sum), ::IndexLinear, A::Base.OneTo{…})
    @ Base ./reduce.jl:440
  [3] _mapreduce_dim(f::Function, op::Function, ::Base._InitialValue, A::Base.OneTo{Int64}, ::Colon)
    @ Base ./reducedim.jl:365
  [4] mapreduce
    @ ./reducedim.jl:357 [inlined]
  [5] _sum
    @ ./reducedim.jl:1015 [inlined]
  [6] sum(f::Function, a::Base.OneTo{Int64})
    @ Base ./reducedim.jl:1011
  [7] (::SciMLSensitivity.var"#330#339"{…})()
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/concrete_solve.jl:1074
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/I1EbV/src/tangent_types/thunks.jl:204 [inlined]
  [9] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:110 [inlined]
 [10] map
    @ ./tuple.jl:293 [inlined]
 [11] map (repeats 3 times)
    @ ./tuple.jl:294 [inlined]
 [12] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:111 [inlined]
 [13] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [14] (::Zygote.var"#kw_zpullback#53"{SciMLSensitivity.var"#forward_sensitivity_backpass#335"{…}})(dy::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [15] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [16] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.var"#kw_zpullback#53"{…}}})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [17] #solve#51
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:1003 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [19] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [20] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [21] solve
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:993 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [23] cost
    @ ~/Documents/Julia Packages/SciML/zcratch/sciml_structure.jl:103 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [26] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [27] (::var"#38#42"{var"#cost#39"{Tsit5{…}}})()
    @ Main ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:898
 [28] (::Base.RedirectStdStream)(thunk::var"#38#42"{var"#cost#39"{Tsit5{…}}}, stream::IOStream)
    @ Base ./stream.jl:1429
 [29] #37
    @ ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:897 [inlined]
 [30] open(::var"#37#41"{var"#cost#39"{Tsit5{…}}}, ::String, ::Vararg{String}; kwargs::@Kwargs{})
    @ Base ./io.jl:396
 [31] open(::Function, ::String, ::String)
    @ Base ./io.jl:393
 [32] macro expansion
    @ ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:896 [inlined]
 [33] top-level scope
    @ ~/Documents/Julia Packages/SciML/zcratch/sciml_structure.jl:106
Some type information was truncated. Use `show(err)` to see complete types.

That said, I'm not sure I'm doing this correctly, since the SciMLStructures docs and tests lack examples. If this approach can be made to work, I'm happy to write up a docs example for a simple custom type, as it would save me a ton of work.
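The MethodError above shows that, during the reverse pass, `vec` was handed a `ChainRulesCore.Tangent` (a field-wise, structural cotangent whose backing NamedTuple has a single field `x`) instead of an array. The Base-only sketch below reproduces that mismatch using a plain NamedTuple as a stand-in for the Tangent (an assumption, so the example runs without ChainRulesCore). It also shows one way to flatten such a cotangent. `tangent_to_vec` is a hypothetical helper for illustration; it is not SciMLSensitivity code and not the change that was merged in this PR.

```julia
# Stand-in for the structural cotangent in the error: same field layout as
# Tangent{Any, @NamedTuple{x::Vector{Float64}}}, but a plain NamedTuple.
dstruct = (x = [0.5, -1.0],)

# Base defines no `vec` method for a NamedTuple, so this throws a MethodError,
# mirroring `vec(::ChainRulesCore.Tangent{...})` in the trace above.
threw_method_error = try
    vec(dstruct)
    false
catch err
    err isa MethodError
end

# Hypothetical helper: flatten a cotangent that is either already an array
# or a field-wise container with a single field `x`.
tangent_to_vec(t::AbstractArray) = vec(t)
tangent_to_vec(t::NamedTuple{(:x,)}) = vec(t.x)

flat = tangent_to_vec(dstruct)
```

Code that assumes every cotangent is an AbstractArray will hit exactly this kind of MethodError as soon as the state or parameters are plain structs.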

ChrisRackauckas (Member) commented:

Try setting sensealg = GaussAdjoint()?

For forward mode, #1085, which landed yesterday, should handle this. Can you give master a try?

apkille (Contributor, Author) commented Aug 22, 2024

@ChrisRackauckas I'm still getting an error. On master, running

using OrdinaryDiffEq, SciMLSensitivity, Zygote, LinearAlgebra, Test
import SciMLStructures: isscimlstructure, hasportion, canonicalize,
                        Tunable, Constants, Caches, Discrete, ArrayRepack

# CustomArray is the contributor's own non-AbstractArray type (definition not shown).
isscimlstructure(::CustomArray) = true
hasportion(::Tunable, ::CustomArray) = true
hasportion(::Constants, ::CustomArray) = false
hasportion(::Caches, ::CustomArray) = false
hasportion(::Discrete, ::CustomArray) = false

canonicalize(::Tunable, x::CustomArray) = vec(x.x), ArrayRepack(x.x), true

ca0 = CustomArray(ones(2))
tspan = (0.0, 1.0)
par = [rand(), rand()]

algs = [Tsit5(), BS3(), Vern9(), DP5()]

for alg in algs
    # Scalar loss on the final state, differentiated w.r.t. both u0 and p
    function cost(ca0, p)
        prob = ODEProblem((du, u, p, t) -> (du[1] = p[1]*u[1] + p[2]*u[2]; du[2] = p[2]*u[1]), ca0, tspan, p)
        sol = solve(prob, alg; sensealg = GaussAdjoint(), save_everystep = false)
        return 1 - norm(sol[end])^2
    end
    @test_nowarn Zygote.gradient(cost, ca0, par)
end

gives

ERROR: MethodError: no method matching ForwardDiff.JacobianConfig(::SciMLBase.ParamJacobianWrapper{…}, ::CustomArray{…}, ::Vector{…}, ::ForwardDiff.Chunk{…})

Closest candidates are:
  ForwardDiff.JacobianConfig(::F, ::AbstractArray{Y}, ::AbstractArray{X}, ::ForwardDiff.Chunk{N}) where {F, Y, X, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/config.jl:179
  ForwardDiff.JacobianConfig(::F, ::AbstractArray{Y}, ::AbstractArray{X}, ::ForwardDiff.Chunk{N}, ::T) where {F, Y, X, N, T}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/config.jl:179
  ForwardDiff.JacobianConfig(::F, ::AbstractArray{V}, ::ForwardDiff.Chunk{N}, ::T) where {F, V, N, T}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/config.jl:154
  ...

Stacktrace:
  [1] build_param_jac_config(alg::GaussAdjoint{…}, pf::Function, u::CustomArray{…}, p::Vector{…})
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/derivative_wrappers.jl:1063
  [2] SciMLSensitivity.GaussIntegrand(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, checkpoints::Vector{…}, dgdp::Nothing)
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:440
  [3] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:558
  [4] adjoint_sensitivities(sol::ODESolution{…}, args::Tsit5{…}; sensealg::GaussAdjoint{…}, verbose::Bool, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/sensitivity_interface.jl:397
  [5] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#310"{…})(Δ::Vector{…})
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/concrete_solve.jl:633
  [6] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
  [7] (::Zygote.var"#kw_zpullback#53"{SciMLSensitivity.var"#adjoint_sensitivity_backpass#310"{…}})(dy::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
  [8] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
  [9] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.var"#kw_zpullback#53"{…}}})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [10] #solve#51
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:1003 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [12] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [13] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [14] solve
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:993 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [16] cost
    @ ~/Documents/Julia Packages/SciML/zcratch/sciml_structure.jl:103 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [19] gradient(::Function, ::CustomArray{Float64, 1}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [20] (::var"#10#14"{var"#cost#11"{Tsit5{…}}})()
    @ Main ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:898
 [21] (::Base.RedirectStdStream)(thunk::var"#10#14"{var"#cost#11"{Tsit5{…}}}, stream::IOStream)
    @ Base ./stream.jl:1429
 [22] #9
    @ ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:897 [inlined]
 [23] open(::var"#9#13"{var"#cost#11"{Tsit5{…}}}, ::String, ::Vararg{String}; kwargs::@Kwargs{})
    @ Base ./io.jl:396
 [24] open(::Function, ::String, ::String)
    @ Base ./io.jl:393
 [25] macro expansion
    @ ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:896 [inlined]
 [26] top-level scope
    @ ~/Documents/Julia Packages/SciML/zcratch/sciml_structure.jl:106

caused by: AssertionError: sensealg isa QuadratureAdjoint
Stacktrace:
  [1] (::SciMLSensitivity.ReverseLossCallback{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/adjoint_common.jl:563
  [2] (::DiffEqCallbacks.var"#112#116"{…})(c::DiscreteCallback{…}, u::CustomArray{…}, t::Float64, integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/kbsPG/src/preset_time.jl:75
  [3] initialize!
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/callbacks.jl:18 [inlined]
  [4] initialize!
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/callbacks.jl:14 [inlined]
  [5] initialize!
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/callbacks.jl:7 [inlined]
  [6] initialize_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…}, initialize_save::Bool)
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:667
  [7] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Vector{…}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:523
  [8] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:11 [inlined]
  [9] #__solve#433
    @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:6 [inlined]
 [10] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/s27pa/src/solve.jl:1 [inlined]
 [11] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:612
 [12] solve_call
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:569 [inlined]
 [13] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:1080 [inlined]
 [14] solve_up
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:1066 [inlined]
 [15] #solve#51
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:1003 [inlined]
 [16] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:580
 [17] _adjoint_sensitivities
    @ ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/gauss_adjoint.jl:533 [inlined]
 [18] adjoint_sensitivities(sol::ODESolution{…}, args::Tsit5{…}; sensealg::GaussAdjoint{…}, verbose::Bool, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/sensitivity_interface.jl:393
 [19] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#310"{…})(Δ::Vector{…})
    @ SciMLSensitivity ~/Documents/Julia Packages/SciML/SciMLSensitivity.jl/src/concrete_solve.jl:633
 [20] ZBack
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
 [21] (::Zygote.var"#kw_zpullback#53"{SciMLSensitivity.var"#adjoint_sensitivity_backpass#310"{…}})(dy::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
 [22] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [23] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.var"#kw_zpullback#53"{…}}})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [24] #solve#51
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:1003 [inlined]
 [25] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] #291
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
 [27] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.Pullback{…}}})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [28] solve
    @ ~/.julia/packages/DiffEqBase/V6SCE/src/solve.jl:993 [inlined]
 [29] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Vector{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] cost
    @ ~/Documents/Julia Packages/SciML/zcratch/sciml_structure.jl:103 [inlined]
 [31] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [32] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [33] gradient(::Function, ::CustomArray{Float64, 1}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [34] (::var"#10#14"{var"#cost#11"{Tsit5{…}}})()
    @ Main ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:898
 [35] (::Base.RedirectStdStream)(thunk::var"#10#14"{var"#cost#11"{Tsit5{…}}}, stream::IOStream)
    @ Base ./stream.jl:1429
 [36] #9
    @ ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:897 [inlined]
 [37] open(::var"#9#13"{var"#cost#11"{Tsit5{…}}}, ::String, ::Vararg{String}; kwargs::@Kwargs{})
    @ Base ./io.jl:396
 [38] open(::Function, ::String, ::String)
    @ Base ./io.jl:393
 [39] macro expansion
    @ ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:896 [inlined]
 [40] top-level scope
    @ ~/Documents/Julia Packages/SciML/zcratch/sciml_structure.jl:106
Some type information was truncated. Use `show(err)` to see complete types.

apkille (Contributor, Author) commented Aug 22, 2024

Maybe the discussion about using SciMLStructures with SciMLSensitivity should move to a separate issue; this PR only deals with using non-AbstractArrays directly with Zygote and SciMLSensitivity.

apkille (Contributor, Author) commented Aug 25, 2024

@ChrisRackauckas pinging in case you missed my comment above.

ChrisRackauckas (Member) commented:

Rebase this PR onto the latest master. If all went well, it shouldn't need the code changes anymore, just the tests.

apkille (Contributor, Author) commented Sep 2, 2024

@ChrisRackauckas the tests pass with two small changes to concrete_solve.jl.

ChrisRackauckas merged commit 552d22b into SciML:master on Sep 2, 2024 (13 of 16 checks passed).
apkille deleted the ca branch on September 2, 2024 at 14:49.