Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jun 6, 2024
1 parent 2380c6a commit 89d38f7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

If you see the warnings:

```julia
```
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:145
┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
Expand All @@ -29,7 +29,7 @@ du = zero(u0)
if DiffEqBase.isinplace(prob)
_f = prob.f
else
_f = (du,u,p,t) -> (du .= prob.f(u,p,t); nothing)
_f = (du, u, p, t) -> (du .= prob.f(u, p, t); nothing)
end

_tmp6 = Enzyme.make_zero(_f)
Expand All @@ -39,10 +39,10 @@ ytmp = zero(u0)
tmp1 = zero(u0)

Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
Enzyme.Duplicated(p, tmp2),
Enzyme.Const(t))
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
Enzyme.Duplicated(p, tmp2),
Enzyme.Const(t))
```

This is exactly the inner core Enzyme call and if this fails, that is the issue that
Expand Down
2 changes: 1 addition & 1 deletion src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
tape = ReverseDiff.GradientTape((y, _p)) do u, p
du1 = p !== nothing && p !== DiffEqBase.NullParameters() ?
similar(p, size(u)) : similar(u)
copyto!(du1, false)
du1 .= false
unwrappedf(du1, u, p, nothing)
return vec(du1)
end
Expand Down
4 changes: 2 additions & 2 deletions test/null_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ function loss10(params)
u0 = zeros(2)
problem = ODEProblem(dynamics, u0, (0.0, 1.0))
rollout = solve(problem, Tsit5(), u0 = u0, p = params,
sensealg = GaussAdjoint(autojacvec = EnzymeVJP()))
sensealg = GaussAdjoint(autojacvec = EnzymeVJP()))
sum(Array(rollout)[:, end])
end

function loss11(params)
u0 = zeros(2)
problem = ODEProblem(dynamics, u0, (0.0, 1.0))
rollout = solve(problem, Tsit5(), u0 = u0, p = params,
sensealg = GaussAdjoint(autojacvec = ZygoteVJP()))
sensealg = GaussAdjoint(autojacvec = ZygoteVJP()))
sum(Array(rollout)[:, end])
end

Expand Down

0 comments on commit 89d38f7

Please sign in to comment.