From 82604e5fa6f2a939673d04ad875e18e82fab5a7f Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 21 May 2024 15:08:35 +0530 Subject: [PATCH 01/11] chore: move literal_getproperty here --- src/adjoint_common.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 7c356e58d..7cb80017f 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -651,3 +651,13 @@ function out_and_ts(_ts, duplicate_iterator_times, sol) end return out, ts end + +Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.prob.u0) + _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) + (build_solution(sol.prob, sol.alg, sol.t, _Δ),) + end + sol.u, solu_adjoint +end From 242d10be2e2d71caea8f7e8e3469291ae69c11a6 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 21 May 2024 15:46:09 +0530 Subject: [PATCH 02/11] chore: qualify bukld_solution --- src/adjoint_common.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 7cb80017f..6659174a4 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -657,7 +657,7 @@ Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolut function solu_adjoint(Δ) zerou = zero(sol.prob.u0) _Δ = @. 
ifelse(Δ === nothing, (zerou,), Δ) - (build_solution(sol.prob, sol.alg, sol.t, _Δ),) + (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) end sol.u, solu_adjoint end From c9d87eea81ebb4141e6a41748674b4a99f1b5d3a Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 21 May 2024 17:03:53 +0530 Subject: [PATCH 03/11] chore: import AbstractTimeSeriesSolution --- src/SciMLSensitivity.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 7529176bd..d037929f2 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -37,7 +37,8 @@ import SciMLBase: unwrapped_f, _unwrap_val import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAlgorithm, AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm, AbstractSecondOrderSensitivityAlgorithm, - AbstractShadowingSensitivityAlgorithm + AbstractShadowingSensitivityAlgorithm, + AbstractTimeSeriesSolution include("parameters_handling.jl") include("sensitivity_algorithms.jl") From ac2f59a1c6b5b8460b202cb4b41b8cfe77092381 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 21 May 2024 17:16:35 +0530 Subject: [PATCH 04/11] chore: fix typo --- src/SciMLSensitivity.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index d037929f2..5e6156aae 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -38,7 +38,7 @@ import SciMLBase: AbstractOverloadingSensitivityAlgorithm, AbstractSensitivityAl AbstractForwardSensitivityAlgorithm, AbstractAdjointSensitivityAlgorithm, AbstractSecondOrderSensitivityAlgorithm, AbstractShadowingSensitivityAlgorithm, - AbstractTimeSeriesSolution + AbstractTimeseriesSolution include("parameters_handling.jl") include("sensitivity_algorithms.jl") From 9226662435ecc337c4effa3ba6ea5712ed2babf8 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 26 May 2024 10:29:24 -0400 Subject: [PATCH 
05/11] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 914f0457f..fda091cc8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLSensitivity" uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1" authors = ["Christopher Rackauckas ", "Yingbo Ma "] -version = "7.59.0" +version = "7.60.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From a78fbde86bd36f2f84414839712229d1f100889d Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 27 May 2024 00:48:02 +0530 Subject: [PATCH 06/11] chore: check if method literal_getproperty exists --- src/adjoint_common.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 6659174a4..1cf0582f0 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -652,12 +652,15 @@ function out_and_ts(_ts, duplicate_iterator_times, sol) return out, ts end -Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution, - ::Val{:u}) - function solu_adjoint(Δ) - zerou = zero(sol.prob.u0) - _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) - (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) +if !hasmethod(Zygote.adjoint, Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty), SciMLBase.AbstractTimeseriesSolution, Val{:u}}) + Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.prob.u0) + _Δ = @. 
ifelse(Δ === nothing, (zerou,), Δ) + (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) + end + @show "here" + sol.u, solu_adjoint end - sol.u, solu_adjoint end From 0a852349665d28335e8abb8b993e92d72f513d26 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 27 May 2024 00:49:11 +0530 Subject: [PATCH 07/11] chore: rm debug --- src/adjoint_common.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 1cf0582f0..ed7d92462 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -660,7 +660,6 @@ if !hasmethod(Zygote.adjoint, Tuple{Zygote.AContext, typeof(Zygote.literal_getpr _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) (SciMLBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) end - @show "here" sol.u, solu_adjoint end end From cca0c2ff0567dadadcfb86b78fbd3e10d06bfd6a Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 27 May 2024 00:50:22 +0530 Subject: [PATCH 08/11] chore: format --- src/adjoint_common.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index ed7d92462..e511fae81 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -652,7 +652,9 @@ function out_and_ts(_ts, duplicate_iterator_times, sol) return out, ts end -if !hasmethod(Zygote.adjoint, Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty), SciMLBase.AbstractTimeseriesSolution, Val{:u}}) +if !hasmethod(Zygote.adjoint, + Tuple{Zygote.AContext, typeof(Zygote.literal_getproperty), + SciMLBase.AbstractTimeseriesSolution, Val{:u}}) Zygote.@adjoint function Zygote.literal_getproperty(sol::AbstractTimeseriesSolution, ::Val{:u}) function solu_adjoint(Δ) From b02bf7d0e2b0e2d520340ea80cd7b5147e0afd63 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 28 May 2024 15:32:01 +0200 Subject: [PATCH 09/11] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 
fda091cc8..ce8c18cad 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLSensitivity" uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1" authors = ["Christopher Rackauckas ", "Yingbo Ma "] -version = "7.60.0" +version = "7.60.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From f28a2d7b2c4a743509e4a00cd979cc7d1fd4106d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 6 Jun 2024 06:54:39 -0400 Subject: [PATCH 10/11] Update default ODE solver tests to use new OrdinaryDiffEq infrastructure Final piece of https://github.com/SciML/SciMLSensitivity.jl/issues/1035 --- .github/workflows/CI.yml | 1 - Project.toml | 2 +- test/{diffeq => }/default_alg_diff.jl | 13 +++++++++++-- test/diffeq/Project.toml | 5 ----- test/runtests.jl | 14 +------------- 5 files changed, 13 insertions(+), 22 deletions(-) rename test/{diffeq => }/default_alg_diff.jl (52%) delete mode 100644 test/diffeq/Project.toml diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8623d4c8b..870d0421f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -24,7 +24,6 @@ jobs: - Core5 - Core6 - Core7 - - DiffEq - SDE1 - SDE2 - SDE3 diff --git a/Project.toml b/Project.toml index ce8c18cad..08841e074 100644 --- a/Project.toml +++ b/Project.toml @@ -74,7 +74,7 @@ NLsolve = "4.5.1" NonlinearSolve = "3.0.1" Optimization = "3.19.3" OptimizationOptimisers = "0.1.6" -OrdinaryDiffEq = "6.68.1" +OrdinaryDiffEq = "6.81.1" Parameters = "0.12" Pkg = "1.10" PreallocationTools = "0.4.4" diff --git a/test/diffeq/default_alg_diff.jl b/test/default_alg_diff.jl similarity index 52% rename from test/diffeq/default_alg_diff.jl rename to test/default_alg_diff.jl index 7387e3b77..557f59c87 100644 --- a/test/diffeq/default_alg_diff.jl +++ b/test/default_alg_diff.jl @@ -1,4 +1,4 @@ -using ComponentArrays, DifferentialEquations, Lux, Random, SciMLSensitivity, Zygote +using ComponentArrays, OrdinaryDiffEq, Lux, Random, SciMLSensitivity, Zygote function f(du, u, p, 
t) du .= first(nn(u, p, st)) @@ -13,7 +13,16 @@ r = rand(Float32, 8, 64) function f2(x) prob = ODEProblem(f, r, (0.0f0, 1.0f0), x) - sol = solve(prob; sensealg = InterpolatingAdjoint(; autodiff = true, autojacvec = true)) + sol = solve(prob, OrdinaryDiffEq.DefaultODEAlgorithm()) + sum(last(sol.u)) +end + +f2(ps) +Zygote.gradient(f2, ps) + +function f2(x) + prob = ODEProblem(f, r, (0.0f0, 1.0f0), x) + sol = solve(prob) sum(last(sol.u)) end diff --git a/test/diffeq/Project.toml b/test/diffeq/Project.toml deleted file mode 100644 index 6dde5fc89..000000000 --- a/test/diffeq/Project.toml +++ /dev/null @@ -1,5 +0,0 @@ -[deps] -DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" - -[compat] -DifferentialEquations = "7" diff --git a/test/runtests.jl b/test/runtests.jl index c956d69da..d26d6789a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,12 +9,6 @@ function activate_gpu_env() Pkg.instantiate() end -function activate_diffeq_env() - Pkg.activate("diffeq") - Pkg.develop(PackageSpec(path = dirname(@__DIR__))) - Pkg.instantiate() -end - @time @testset "SciMLSensitivity" begin if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" @testset "Core1" begin @@ -49,6 +43,7 @@ end if GROUP == "All" || GROUP == "Core3" || GROUP == "Downstream" @testset "Core 3" begin + @time @safetestset "Default DiffEq Alg" include("default_alg_diff.jl") @time @safetestset "Adjoint Sensitivity" include("adjoint.jl") @time @safetestset "automatic sensealg choice" include("automatic_sensealg_choice.jl") end @@ -150,13 +145,6 @@ end end end - if GROUP == "DiffEq" - @testset "DiffEq" begin - activate_diffeq_env() - @time @safetestset "Default DiffEq Alg" include("diffeq/default_alg_diff.jl") - end - end - if GROUP == "GPU" @testset "GPU" begin activate_gpu_env() From bc3bcd52989304e1342ccb6c26e55424a0931c04 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 6 Jun 2024 09:07:00 -0400 Subject: [PATCH 11/11] Do not require gradient for vjp choice Last little 
bit to fix https://github.com/SciML/DiffEqFlux.jl/issues/928 and make that nicer --- docs/pages.jl | 4 ++- docs/src/faq.md | 62 +++++++++++++++++++++++++++++++++++++++++++ src/concrete_solve.jl | 32 +++++++++++++++++++--- 3 files changed, 93 insertions(+), 5 deletions(-) create mode 100644 docs/src/faq.md diff --git a/docs/pages.jl b/docs/pages.jl index af24391c3..482e9e2c4 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -8,6 +8,7 @@ pages = ["index.md", "Training Techniques and Tips" => Any["tutorials/training_tips/local_minima.md", "tutorials/training_tips/divergence.md", "tutorials/training_tips/multiple_nn.md"]], + "Frequently Asked Questions (FAQ)" => "faq.md", "Examples" => Any[ "Ordinary Differential Equations (ODEs)" => Any["examples/ode/exogenous_input.md", "examples/ode/prediction_error_method.md", @@ -28,7 +29,8 @@ pages = ["index.md", "Optimal and Model Predictive Control" => Any[ "examples/optimal_control/optimal_control.md", "examples/optimal_control/feedback_control.md"]], - "Manual and APIs" => Any["manual/differential_equation_sensitivities.md", + "Manual and APIs" => Any[ + "manual/differential_equation_sensitivities.md", "manual/nonlinear_solve_sensitivities.md", "manual/direct_forward_sensitivity.md", "manual/direct_adjoint_sensitivities.md"], diff --git a/docs/src/faq.md b/docs/src/faq.md new file mode 100644 index 000000000..74de01a31 --- /dev/null +++ b/docs/src/faq.md @@ -0,0 +1,62 @@ +# Frequently Asked Questions (FAQ) + +## How do I isolate potential gradient issues and improve performance? + +If you see the warnings: + +```julia +┌ Warning: Reverse-Mode AD VJP choices all failed. Falling back to numerical VJPs +└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:145 +┌ Warning: Potential performance improvement omitted. EnzymeVJP tried and failed in the automated AD choice algorithm. To show the stack trace, set SciMLSensitivity.STACKTRACE_WITH_VJPWARN[] = true.
To turn off this printing, add `verbose = false` to the `solve` call. +└ @ SciMLSensitivity C:\Users\accou\.julia\dev\SciMLSensitivity\src\concrete_solve.jl:100 +``` + +then you're in luck! Well, not really. But there are things you can do. You can isolate the +issue to automatic differentiation of your `f` function in order to either fix your `f` +function, or open an issue with the AD library directly without the ODE solver involved. + +If you have an in-place function, then you will want to isolate it to Enzyme. This is done +as follows for an arbitrary problem: + +```julia +using Enzyme +u0 = prob.u0 +p = prob.p +tmp2 = Enzyme.make_zero(p) +t = prob.tspan[1] +du = zero(u0) + +if DiffEqBase.isinplace(prob) + _f = prob.f +else + _f = (du,u,p,t) -> (du .= prob.f(u,p,t); nothing) +end + +_tmp6 = Enzyme.make_zero(_f) +tmp3 = zero(u0) +tmp4 = zero(u0) +ytmp = zero(u0) +tmp1 = zero(u0) + +Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(_f, _tmp6), + Enzyme.Const, Enzyme.Duplicated(tmp3, tmp4), + Enzyme.Duplicated(ytmp, tmp1), + Enzyme.Duplicated(p, tmp2), + Enzyme.Const(t)) +``` + +This is exactly the inner core Enzyme call and if this fails, that is the issue that +needs to be fixed. 
+ +And similarly, for out-of-place functions the Zygote isolation is as follows: + +```julia +p = prob.p +y = prob.u0 +f = prob.f +λ = zero(prob.u0) +_dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, prob.tspan[1])) +end +tmp1, tmp2 = back(λ) +``` diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 6a2202d00..d0d7f4f4f 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -89,10 +89,22 @@ function automatic_sensealg_choice( # so if out-of-place, try Zygote vjp = try + p = prob.p + y = prob.u0 + f = prob.f + t = prob.tspan[1] + λ = zero(prob.u0) + if p === nothing || p isa SciMLBase.NullParameters - Zygote.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0) + _dy, back = Zygote.pullback(y) do u + vec(f(u, p, t)) + end + tmp1 = back(λ) else - Zygote.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p) + _dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, t)) + end + tmp1, tmp2 = back(λ) end ZygoteVJP() catch e @@ -124,10 +136,22 @@ function automatic_sensealg_choice( if vjp == false vjp = try + p = prob.p + y = prob.u0 + f = prob.f + t = prob.tspan[1] + λ = zero(prob.u0) + if p === nothing || p isa SciMLBase.NullParameters - Tracker.gradient((u) -> sum(prob.f(u, p, prob.tspan[1])), u0) + _dy, back = Tracker.forward(y) do u + vec(f(u, p, t)) + end + tmp1 = back(λ) else - Tracker.gradient((u, p) -> sum(prob.f(u, p, prob.tspan[1])), u0, p) + _dy, back = Tracker.forward(y, p) do u, p + vec(f(u, p, t)) + end + tmp1, tmp2 = back(λ) end TrackerVJP() catch e