diff --git a/Project.toml b/Project.toml index 8caea86c2..fb41a9df9 100644 --- a/Project.toml +++ b/Project.toml @@ -32,6 +32,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -67,7 +68,7 @@ Functors = "0.4" GPUArraysCore = "0.1" LinearAlgebra = "1.10" LinearSolve = "2" -Lux = "0.5.51" +Lux = "1" Markdown = "1.10" ModelingToolkit = "9" NLsolve = "4.5.1" @@ -85,6 +86,7 @@ Reexport = "1.0" ReverseDiff = "1.15.1" SafeTestsets = "0.1.0" SciMLBase = "2.51.4" +SciMLJacobianOperators = "0.1" SciMLOperators = "0.3" SciMLStructures = "1.3" SparseArrays = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index dbc64e153..4e1017435 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -45,7 +45,7 @@ Enzyme = "0.12, 0.13" Flux = "0.14" ForwardDiff = "0.10" IterTools = "1" -Lux = "0.5.7, 1" +Lux = "1" LuxCUDA = "0.3" Optimization = "3.9, 4" OptimizationOptimJL = "0.2, 0.3, 0.4" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 5eae8ef30..d05c054f4 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -21,7 +21,7 @@ using RandomNumbers: Xorshifts using RecursiveArrayTools: RecursiveArrayTools, AbstractDiffEqArray, AbstractVectorOfArray, ArrayPartition, DiffEqArray, VectorOfArray -# using SciMLJacobianOperators: VecJacOperator # TODO: Replace uses of VecJac +using SciMLJacobianOperators: VecJacOperator, StatefulJacobianOperator using SciMLStructures: SciMLStructures, canonicalize, Tunable, isscimlstructure using SymbolicIndexingInterface: SymbolicIndexingInterface, current_time, getu, parameter_values, state_values @@ -32,7 +32,7 @@ using SciMLBase: SciMLBase, AbstractOverloadingSensitivityAlgorithm, AbstractShadowingSensitivityAlgorithm, AbstractTimeseriesSolution, AbstractNonlinearProblem, AbstractSensitivityAlgorithm, AbstractDiffEqFunction, AbstractODEFunction, unwrapped_f, CallbackSet, - ContinuousCallback, DESolution, + ContinuousCallback, DESolution, NonlinearFunction, NonlinearProblem, DiscreteCallback, LinearProblem, ODEFunction, ODEProblem, RODEFunction, RODEProblem, ReturnCode, SDEFunction, SDEProblem, VectorContinuousCallback, deleteat!, diff --git a/src/sensitivity_algorithms.jl b/src/sensitivity_algorithms.jl index d2913c4e0..133267f7a 100644 --- a/src/sensitivity_algorithms.jl +++ b/src/sensitivity_algorithms.jl @@ -1311,8 +1311,8 @@ struct ForwardDiffOverAdjoint{A} <: adjalg::A end -function get_autodiff_from_vjp(vjp::ReverseDiffVJP{compile}) where {compile} - AutoReverseDiff(; compile) +function get_autodiff_from_vjp(::ReverseDiffVJP{compile}) where {compile} + return AutoReverseDiff(; compile) end get_autodiff_from_vjp(::ZygoteVJP) = AutoZygote() get_autodiff_from_vjp(::EnzymeVJP) = AutoEnzyme() diff --git a/src/steadystate_adjoint.jl b/src/steadystate_adjoint.jl index b546b0475..e7f34de43 100644 --- a/src/steadystate_adjoint.jl +++ b/src/steadystate_adjoint.jl @@ -85,12 +85,21 @@ end end if !needs_jac + # Current SciMLJacobianOperators requires specifying the problem as a NonlinearProblem usize = size(y) - __f = y -> vec(f(reshape(y, usize), p, nothing)) - operator = VecJac(__f, vec(y); - autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) - linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) - solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ) + if SciMLBase.isinplace(f) + nlfunc = NonlinearFunction{true}((du, u, p) -> unwrapped_f(f)( + reshape(u, usize), reshape(u, usize), p, nothing)) + else + nlfunc = NonlinearFunction{false}((u, p) -> unwrapped_f(f)( + reshape(u, usize), p, nothing)) + end + nlprob = NonlinearProblem(nlfunc, vec(λ), p) + operator = VecJacOperator( + nlprob, vec(y), (λ); autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) + soperator = StatefulJacobianOperator(operator, vec(λ), p) + linear_problem = LinearProblem(soperator, vec(dgdu_val); u0 = vec(λ)) + solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) else if linsolve === nothing && isempty(sensealg.linsolve_kwargs) # For the default case use `\` to avoid any form of unnecessary cache allocation diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml index 1763c5871..b3f667084 100644 --- a/test/gpu/Project.toml +++ b/test/gpu/Project.toml @@ -5,7 +5,7 @@ DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" [compat] -CUDA = "3.12, 4, 5" -DiffEqCallbacks = "2.24, 3" -DiffEqFlux = "3, 4" +CUDA = "5" +DiffEqCallbacks = "3" +DiffEqFlux = "4" LuxCUDA = "0.3.1"