diff --git a/src/forward_sensitivity.jl b/src/forward_sensitivity.jl index b6527400a..f23cceed4 100644 --- a/src/forward_sensitivity.jl +++ b/src/forward_sensitivity.jl @@ -675,8 +675,14 @@ function SciMLBase.remake( {uType, tType, isinplace, P, F, K} _p = p === nothing ? parameter_values(prob) : p _f = f === nothing ? prob.f.f : f - _u0 = u0 === nothing ? state_values(prob, 1:(prob.f.numindvar)) : - u0[1:(prob.f.numindvar)] + + if typeof(_f) <: ODEForwardSensitivityFunction + _u0 = u0 === nothing ? state_values(prob, 1:(_f.numindvar)) : + u0[1:(_f.numindvar)] + else + _u0 = u0 === nothing ? state_values(prob) : u0 + end + _tspan = tspan === nothing ? prob.tspan : tspan ODEForwardSensitivityProblem(_f, _u0, _tspan, _p; sensealg = prob.problem_type.sensealg, diff --git a/test/forward.jl b/test/forward.jl index 08158aec3..71a819c8b 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -271,3 +271,15 @@ f = prob.f @assert f isa ODEForwardSensitivityFunction @test hasproperty(f, :observed) @test f.observed == SciMLBase.DEFAULT_OBSERVED + +# `remake`: https://github.com/SciML/SciMLSensitivity.jl/issues/1137 + +function ff3(du, u, p, t) + du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2] + du[2] = dy = -p[3] * u[2] + u[1] * u[2] +end + +p = [1.5, 1.0, 3.0] +ts = (0, 10) +prob = ODEForwardSensitivityProblem(ff3, [1.0; 1.0], ts, p, sensealg=ForwardDiffSensitivity()) +sol = solve(prob, Tsit5()) \ No newline at end of file