diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 8623d4c8b..870d0421f 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -24,7 +24,6 @@ jobs:
           - Core5
           - Core6
           - Core7
-          - DiffEq
           - SDE1
           - SDE2
           - SDE3
diff --git a/Project.toml b/Project.toml
index 268eb0854..a22eee543 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SciMLSensitivity"
 uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 authors = ["Christopher Rackauckas ", "Yingbo Ma "]
-version = "7.59.0"
+version = "7.60.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -76,7 +76,7 @@ NLsolve = "4.5.1"
 NonlinearSolve = "3.0.1"
 Optimization = "3.19.3"
 OptimizationOptimisers = "0.1.6"
-OrdinaryDiffEq = "6.68.1"
+OrdinaryDiffEq = "6.81.1"
 Parameters = "0.12"
 Pkg = "1.10"
 PreallocationTools = "0.4.4"
diff --git a/docs/pages.jl b/docs/pages.jl
index af24391c3..482e9e2c4 100644
--- a/docs/pages.jl
+++ b/docs/pages.jl
@@ -8,6 +8,7 @@ pages = ["index.md",
         "Training Techniques and Tips" => Any["tutorials/training_tips/local_minima.md",
             "tutorials/training_tips/divergence.md",
             "tutorials/training_tips/multiple_nn.md"]],
+    "Frequently Asked Questions (FAQ)" => "faq.md",
     "Examples" => Any[
         "Ordinary Differential Equations (ODEs)" => Any["examples/ode/exogenous_input.md",
             "examples/ode/prediction_error_method.md",
@@ -28,7 +29,8 @@
     "Optimal and Model Predictive Control" => Any[
         "examples/optimal_control/optimal_control.md",
         "examples/optimal_control/feedback_control.md"]],
-    "Manual and APIs" => Any["manual/differential_equation_sensitivities.md",
+    "Manual and APIs" => Any[
+        "manual/differential_equation_sensitivities.md",
     "manual/nonlinear_solve_sensitivities.md",
     "manual/direct_forward_sensitivity.md",
     "manual/direct_adjoint_sensitivities.md"],
diff --git a/docs/src/faq.md b/docs/src/faq.md
new file mode 100644
index 000000000..74de01a31
--- /dev/null
+++ b/docs/src/faq.md
@@ -0,0 +1,62 @@
+# Frequently Asked Questions (FAQ)
+
+## How do I isolate potential gradient issues and improve performance?
+
+If you see the warnings:
+
+```julia
+┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs
+└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:145
+┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true. To turn off this printing, add `verbose = false` to the `solve` call.
+└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:100
+```
+
+then you're in luck! Well, not really. But there are things you can do. You can isolate the
+issue to automatic differentiation of your `f` function in order to either fix your `f`
+function, or open an issue with the AD library directly without the ODE solver involved.
+
+If you have an in-place function, then you will want to isolate it to Enzyme.
+This is done as follows for an arbitrary problem:
+
+```julia
+using Enzyme
+u0 = prob.u0
+p = prob.p
+tmp2 = Enzyme.make_zero(p)
+t = prob.tspan[1]
+du = zero(u0)
+
+if DiffEqBase.isinplace(prob)
+    _f = prob.f
+else
+    _f = (du,u,p,t) -> (du .= prob.f(u,p,t); nothing)
+end
+
+_tmp6 = Enzyme.make_zero(_f)
+tmp3 = zero(u0)
+tmp4 = zero(u0)
+ytmp = zero(u0)
+tmp1 = zero(u0)
+
+Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6),
+    Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4),
+    Enzyme.Duplicated(ytmp, tmp1),
+    Enzyme.Duplicated(p, tmp2),
+    Enzyme.Const(t))
+```
+
+This is exactly the inner core Enzyme call and if this fails, that is the issue that
+needs to be fixed.
+
+And similarly, for out-of-place functions the Zygote isolation is as follows:
+
+```julia
+p = prob.p
+y = prob.u0
+f = prob.f
+λ = zero(prob.u0)
+_dy, back = Zygote.pullback(y, p) do u, p
+    vec(f(u, p, t))
+end
+tmp1, tmp2 = back(λ)
+```
diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl
index 4ed109ce6..48c32ffa6 100644
--- a/src/SciMLSensitivity.jl
+++ b/src/SciMLSensitivity.jl
@@ -40,7 +40,8 @@ import SciMLBase: unwrapped_f, _unwrap_val
 import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm,
                   AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm,
                   AbstractSecondOrderSensitivityAlgorithm,
-                  AbstractShadowingSensitivityAlgorithm
+                  AbstractShadowingSensitivityAlgorithm,
+                  AbstractTimeseriesSolution
 
 include("parameters_handling.jl")
 include("sensitivity_algorithms.jl")
diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl
index 018ad3ffc..b98b5eaa2 100644
--- a/src/adjoint_common.jl
+++ b/src/adjoint_common.jl
@@ -662,3 +662,17 @@ function out_and_ts(_ts, duplicate_iterator_times, sol)
     end
     return out, ts
 end
+
+if !hasmethod(Zygote.adjoint,
+    Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty),
+        SciMLBase.AbstractTimeseriesSolution, Val{:u}})
+    Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution,
+        ::Val{:u})
+        function solu_adjoint(Δ)
+            zerou = zero(sol.prob.u0)
+            _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
+            (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),)
+        end
+        sol.u, solu_adjoint
+    end
+end
diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl
index b71c52e8c..7de8adb34 100644
--- a/src/concrete_solve.jl
+++ b/src/concrete_solve.jl
@@ -87,10 +87,22 @@ function automatic_sensealg_choice(prob::Union{SciMLBase.AbstractODEProblem,
         # so if out-of-place, try Zygote
         vjp = try
+            p = prob.p
+            y = prob.u0
+            f = prob.f
+            t = prob.tspan[1]
+            λ = zero(prob.u0)
+
             if p === nothing || p isa SciMLBase.NullParameters
-                Zygote.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
+                _dy, back = Zygote.pullback(y) do u
+                    vec(f(u, p, t))
+                end
+                tmp1 = back(λ)
             else
-                Zygote.gradient((u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, p)
+                _dy, back = Zygote.pullback(y, p) do u, p
+                    vec(f(u, p, t))
+                end
+                tmp1, tmp2 = back(λ)
             end
             ZygoteVJP()
         catch e
@@ -122,10 +134,22 @@
 
         if vjp == false
             vjp = try
+                p = prob.p
+                y = prob.u0
+                f = prob.f
+                t = prob.tspan[1]
+                λ = zero(prob.u0)
+
                 if p === nothing || p isa SciMLBase.NullParameters
-                    Tracker.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0)
+                    _dy, back = Tracker.forward(y) do u
+                        vec(f(u, p, t))
+                    end
+                    tmp1 = back(λ)
                 else
-                    Tracker.gradient((u, _p) -> sum(prob.f(u, repack(_p), prob.tspan[1])), u0, p)
+                    _dy, back = Tracker.forward(y, p) do u, p
+                        vec(f(u, p, t))
+                    end
+                    tmp1, tmp2 = back(λ)
                 end
                 TrackerVJP()
             catch e
diff --git a/test/diffeq/default_alg_diff.jl b/test/default_alg_diff.jl
similarity index 52%
rename from test/diffeq/default_alg_diff.jl
rename to test/default_alg_diff.jl
index 7387e3b77..557f59c87 100644
--- a/test/diffeq/default_alg_diff.jl
+++ b/test/default_alg_diff.jl
@@ -1,4 +1,4 @@
-using ComponentArrays, DifferentialEquations, Lux, Random, SciMLSensitivity, Zygote
+using ComponentArrays, OrdinaryDiffEq, Lux, Random, SciMLSensitivity, Zygote
 
 function f(du, u, p, t)
     du .= first(nn(u, p, st))
@@ -13,7 +13,16 @@ r = rand(Float32, 8, 64)
 
 function f2(x)
     prob = ODEProblem(f, r, (0.0f0, 1.0f0), x)
-    sol = solve(prob; sensealg = InterpolatingAdjoint(; autodiff = true, autojacvec = true))
+    sol = solve(prob, OrdinaryDiffEq.DefaultODEAlgorithm())
+    sum(last(sol.u))
+end
+
+f2(ps)
+Zygote.gradient(f2, ps)
+
+function f2(x)
+    prob = ODEProblem(f, r, (0.0f0, 1.0f0), x)
+    sol = solve(prob)
     sum(last(sol.u))
 end
 
diff --git a/test/diffeq/Project.toml b/test/diffeq/Project.toml
deleted file mode 100644
index 6dde5fc89..000000000
--- a/test/diffeq/Project.toml
+++ /dev/null
@@ -1,5 +0,0 @@
-[deps]
-DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
-
-[compat]
-DifferentialEquations = "7"
diff --git a/test/runtests.jl b/test/runtests.jl
index c956d69da..d26d6789a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,12 +9,6 @@ function activate_gpu_env()
     Pkg.instantiate()
 end
 
-function activate_diffeq_env()
-    Pkg.activate("diffeq")
-    Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
-    Pkg.instantiate()
-end
-
 @time @testset "SciMLSensitivity" begin
     if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream"
         @testset "Core1" begin
@@ -49,6 +43,7 @@
 
     if GROUP == "All" || GROUP == "Core3" || GROUP == "Downstream"
         @testset "Core 3" begin
+            @time @safetestset "Default DiffEq Alg" include("default_alg_diff.jl")
             @time @safetestset "Adjoint Sensitivity" include("adjoint.jl")
             @time @safetestset "automatic sensealg choice" include("automatic_sensealg_choice.jl")
         end
@@ -150,13 +145,6 @@
         end
     end
 
-    if GROUP == "DiffEq"
-        @testset "DiffEq" begin
-            activate_diffeq_env()
-            @time @safetestset "Default DiffEq Alg" include("diffeq/default_alg_diff.jl")
-        end
-    end
-
     if GROUP == "GPU"
         @testset "GPU" begin
            activate_gpu_env()