Skip to content

Commit

Permalink
Merge branch 'master' into dg/structures
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed Jun 7, 2024
2 parents cf50e43 + 7569697 commit 47231ef
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 29 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ jobs:
- Core5
- Core6
- Core7
- DiffEq
- SDE1
- SDE2
- SDE3
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLSensitivity"
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
authors = ["Christopher Rackauckas <accounts@chrisrackauckas.com>", "Yingbo Ma <mayingbo5@gmail.com>"]
version = "7.59.0"
version = "7.60.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -76,7 +76,7 @@ NLsolve = "4.5.1"
NonlinearSolve = "3.0.1"
Optimization = "3.19.3"
OptimizationOptimisers = "0.1.6"
OrdinaryDiffEq = "6.68.1"
OrdinaryDiffEq = "6.81.1"
Parameters = "0.12"
Pkg = "1.10"
PreallocationTools = "0.4.4"
Expand Down
4 changes: 3 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pages = ["index.md",
"Training Techniques and Tips" => Any["tutorials/training_tips/local_minima.md",
"tutorials/training_tips/divergence.md",
"tutorials/training_tips/multiple_nn.md"]],
"Frequently Asked Questions (FAQ)" => "faq.md",
"Examples" => Any[
"Ordinary Differential Equations (ODEs)" => Any["examples/ode/exogenous_input.md",
"examples/ode/prediction_error_method.md",
Expand All @@ -28,7 +29,8 @@ pages = ["index.md",
"Optimal and Model Predictive Control" => Any[
"examples/optimal_control/optimal_control.md",
"examples/optimal_control/feedback_control.md"]],
"Manual and APIs" => Any["manual/differential_equation_sensitivities.md",
"Manual and APIs" => Any[
"manual/differential_equation_sensitivities.md",
"manual/nonlinear_solve_sensitivities.md",
"manual/direct_forward_sensitivity.md",
"manual/direct_adjoint_sensitivities.md"],
Expand Down
62 changes: 62 additions & 0 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Frequently Asked Questions (FAQ)

## How do I isolate potential gradient issues and improve performance?

If you see the warnings:

```julia
┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:145
┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:100
```

then you're in luck! Well, not really. But there are things you can do. You can isolate the
issue to automatic differentiation of your `f` function in order to either fix your `f`
function, or open an issue with the AD library directly without the ODE solver involved.

If you have an in-place function, then you will want to isolate it to Enzyme. This is done
as follows for an arbitrary problem:

```julia
using Enzyme
u0 = prob.u0
p = prob.p
tmp2 = Enzyme.make_zero(p)
t = prob.tspan[1]
du = zero(u0)

if DiffEqBase.isinplace(prob)
_f = prob.f
else
_f = (du,u,p,t) -> (du .= prob.f(u,p,t); nothing)
end

_tmp6 = Enzyme.make_zero(_f)
tmp3 = zero(u0)
tmp4 = zero(u0)
ytmp = zero(u0)
tmp1 = zero(u0)

Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
Enzyme.Duplicated(ytmp, tmp1),
Enzyme.Duplicated(p, tmp2),
Enzyme.Const(t))
```

This is exactly the inner core Enzyme call and if this fails, that is the issue that
needs to be fixed.

And similarly, for out-of-place functions the Zygote isolation is as follows:

```julia
p = prob.p
y = prob.u0
f = prob.f
λ = zero(prob.u0)
_dy, back = Zygote.pullback(y, p) do u, p
vec(f(u, p, t))
end
tmp1, tmp2 = back(λ)
```
3 changes: 2 additions & 1 deletion src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ import SciMLBase: unwrapped_f, _unwrap_val
import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm,
AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm,
AbstractSecondOrderSensitivityAlgorithm,
AbstractShadowingSensitivityAlgorithm
AbstractShadowingSensitivityAlgorithm,
AbstractTimeseriesSolution

include("parameters_handling.jl")
include("sensitivity_algorithms.jl")
Expand Down
14 changes: 14 additions & 0 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,17 @@ function out_and_ts(_ts, duplicate_iterator_times, sol)
end
return out, ts
end

if !hasmethod(Zygote.adjoint,
Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty),
SciMLBase.AbstractTimeseriesSolution, Val{:u}})
Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution,
::Val{:u})
function solu_adjoint(Δ)
zerou = zero(sol.prob.u0)
_Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
(SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
end
sol.u, solu_adjoint
end
end
32 changes: 28 additions & 4 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,22 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem,
# so if out-of-place, try Zygote

vjp = try
p = prob.p
y = prob.u0
f = prob.f
t = prob.tspan[1]
λ = zero(prob.u0)

if p === nothing || p isa SciMLBase.NullParameters
Zygote.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
_dy, back = Zygote.pullback(y) do u
vec(f(u, p, t))
end
tmp1 = back(λ)
else
Zygote.gradient((u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, p)
_dy, back = Zygote.pullback(y, p) do u, p
vec(f(u, p, t))
end
tmp1, tmp2 = back(λ)
end
ZygoteVJP()
catch e
Expand Down Expand Up @@ -122,10 +134,22 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem,

if vjp == false
vjp = try
p = prob.p
y = prob.u0
f = prob.f
t = prob.tspan[1]
λ = zero(prob.u0)

if p === nothing || p isa SciMLBase.NullParameters
Tracker.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
_dy, back = Tracker.forward(y) do u
vec(f(u, p, t))
end
tmp1 = back(λ)
else
Tracker.gradient((u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, p)
_dy, back = Tracker.forward(y, p) do u, p
vec(f(u, p, t))
end
tmp1, tmp2 = back(λ)
end
TrackerVJP()
catch e
Expand Down
13 changes: 11 additions & 2 deletions test/diffeq/default_alg_diff.jl → test/default_alg_diff.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ComponentArrays, DifferentialEquations, Lux, Random, SciMLSensitivity, Zygote
using ComponentArrays, OrdinaryDiffEq, Lux, Random, SciMLSensitivity, Zygote

function f(du, u, p, t)
du .= first(nn(u, p, st))
Expand All @@ -13,7 +13,16 @@ r = rand(Float32, 8, 64)

function f2(x)
prob = ODEProblem(f, r, (0.0f0, 1.0f0), x)
sol = solve(prob; sensealg = InterpolatingAdjoint(; autodiff = true, autojacvec = true))
sol = solve(prob, OrdinaryDiffEq.DefaultODEAlgorithm())
sum(last(sol.u))
end

f2(ps)
Zygote.gradient(f2, ps)

function f2(x)
prob = ODEProblem(f, r, (0.0f0, 1.0f0), x)
sol = solve(prob)
sum(last(sol.u))
end

Expand Down
5 changes: 0 additions & 5 deletions test/diffeq/Project.toml

This file was deleted.

14 changes: 1 addition & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ function activate_gpu_env()
Pkg.instantiate()
end

function activate_diffeq_env()
Pkg.activate("diffeq")
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
Pkg.instantiate()
end

@time @testset "SciMLSensitivity" begin
if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream"
@testset "Core1" begin
Expand Down Expand Up @@ -49,6 +43,7 @@ end

if GROUP == "All" || GROUP == "Core3" || GROUP == "Downstream"
@testset "Core 3" begin
@time @safetestset "Default DiffEq Alg" include("default_alg_diff.jl")
@time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
@time @safetestset "automatic sensealg choice" include("automatic_sensealg_choice.jl")
end
Expand Down Expand Up @@ -150,13 +145,6 @@ end
end
end

if GROUP == "DiffEq"
@testset "DiffEq" begin
activate_diffeq_env()
@time @safetestset "Default DiffEq Alg" include("diffeq/default_alg_diff.jl")
end
end

if GROUP == "GPU"
@testset "GPU" begin
activate_gpu_env()
Expand Down

0 comments on commit 47231ef

Please sign in to comment.