diff --git a/docs/src/faq.md b/docs/src/faq.md index 74de01a31..3f63d569f 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -4,7 +4,7 @@ If you see the warnings: -```julia +``` ┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs └ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:145 ┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call. @@ -29,7 +29,7 @@ du = zero(u0) if DiffEqBase.isinplace(prob) _f = prob.f else - _f = (du,u,p,t) -> (du .= prob.f(u,p,t); nothing) + _f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing) end _tmp6 = Enzyme.make_zero(_f) @@ -39,10 +39,10 @@ ytmp = zero(u0) tmp1 = zero(u0) Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6), - Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), - Enzyme.Duplicated(ytmp, tmp1), - Enzyme.Duplicated(p, tmp2), - Enzyme.Const(t)) + Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), + Enzyme.Duplicated(ytmp, tmp1), + Enzyme.Duplicated(p, tmp2), + Enzyme.Const(t)) ``` This is exactly the inner core Enzyme call and if this fails, that is the issue that diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index fe5a88b4f..92e83db8a 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -174,7 +174,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f tape = ReverseDiff.GradientTape((y, _p)) do u, p du1 = p !== nothing && p !== DiffEqBase.NullParameters() ? similar(p, size(u)) : similar(u) - copyto!(du1, false) + du1 .= false unwrappedf(du1, u, p, nothing) return vec(du1) end diff --git a/test/null_parameters.jl b/test/null_parameters.jl index 7085fd487..e302c7bce 100644 --- a/test/null_parameters.jl +++ b/test/null_parameters.jl @@ -81,7 +81,7 @@ function loss10(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) rollout = solve(problem, Tsit5(), u0 = u0, p = params, - sensealg = GaussAdjoint(autojacvec = EnzymeVJP())) + sensealg = GaussAdjoint(autojacvec = EnzymeVJP())) sum(Array(rollout)[:, end]) end @@ -89,7 +89,7 @@ function loss11(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) rollout = solve(problem, Tsit5(), u0 = u0, p = params, - sensealg = GaussAdjoint(autojacvec = ZygoteVJP())) + sensealg = GaussAdjoint(autojacvec = ZygoteVJP())) sum(Array(rollout)[:, end]) end