From f28a2d7b2c4a743509e4a00cd979cc7d1fd4106d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 6 Jun 2024 06:54:39 -0400 Subject: [PATCH] Update default ODE solver tests to use new OrdinaryDiffEq infrastructure Final piece of https://github.com/SciML/SciMLSensitivity.jl/issues/1035 --- .github/workflows/CI.yml | 1 - Project.toml | 2 +- test/{diffeq => }/default_alg_diff.jl | 13 +++++++++++-- test/diffeq/Project.toml | 5 ----- test/runtests.jl | 14 +------------- 5 files changed, 13 insertions(+), 22 deletions(-) rename test/{diffeq => }/default_alg_diff.jl (52%) delete mode 100644 test/diffeq/Project.toml diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 8623d4c8b..870d0421f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -24,7 +24,6 @@ jobs: - Core5 - Core6 - Core7 - - DiffEq - SDE1 - SDE2 - SDE3 diff --git a/Project.toml b/Project.toml index ce8c18cad..08841e074 100644 --- a/Project.toml +++ b/Project.toml @@ -74,7 +74,7 @@ NLsolve = "4.5.1" NonlinearSolve = "3.0.1" Optimization = "3.19.3" OptimizationOptimisers = "0.1.6" -OrdinaryDiffEq = "6.68.1" +OrdinaryDiffEq = "6.81.1" Parameters = "0.12" Pkg = "1.10" PreallocationTools = "0.4.4" diff --git a/test/diffeq/default_alg_diff.jl b/test/default_alg_diff.jl similarity index 52% rename from test/diffeq/default_alg_diff.jl rename to test/default_alg_diff.jl index 7387e3b77..557f59c87 100644 --- a/test/diffeq/default_alg_diff.jl +++ b/test/default_alg_diff.jl @@ -1,4 +1,4 @@ -using ComponentArrays, DifferentialEquations, Lux, Random, SciMLSensitivity, Zygote +using ComponentArrays, OrdinaryDiffEq, Lux, Random, SciMLSensitivity, Zygote function f(du, u, p, t) du .= first(nn(u, p, st)) @@ -13,7 +13,16 @@ r = rand(Float32, 8, 64) function f2(x) prob = ODEProblem(f, r, (0.0f0, 1.0f0), x) - sol = solve(prob; sensealg = InterpolatingAdjoint(; autodiff = true, autojacvec = true)) + sol = solve(prob, OrdinaryDiffEq.DefaultODEAlgorithm()) + sum(last(sol.u)) +end + +f2(ps) +Zygote.gradient(f2, ps) + +function f2(x) + prob = ODEProblem(f, r, (0.0f0, 1.0f0), x) + sol = solve(prob) sum(last(sol.u)) end diff --git a/test/diffeq/Project.toml b/test/diffeq/Project.toml deleted file mode 100644 index 6dde5fc89..000000000 --- a/test/diffeq/Project.toml +++ /dev/null @@ -1,5 +0,0 @@ -[deps] -DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" - -[compat] -DifferentialEquations = "7" diff --git a/test/runtests.jl b/test/runtests.jl index c956d69da..d26d6789a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,12 +9,6 @@ function activate_gpu_env() Pkg.instantiate() end -function activate_diffeq_env() - Pkg.activate("diffeq") - Pkg.develop(PackageSpec(path = dirname(@__DIR__))) - Pkg.instantiate() -end - @time @testset "SciMLSensitivity" begin if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" @testset "Core1" begin @@ -49,6 +43,7 @@ end if GROUP == "All" || GROUP == "Core3" || GROUP == "Downstream" @testset "Core 3" begin + @time @safetestset "Default DiffEq Alg" include("default_alg_diff.jl") @time @safetestset "Adjoint Sensitivity" include("adjoint.jl") @time @safetestset "automatic sensealg choice" include("automatic_sensealg_choice.jl") end @@ -150,13 +145,6 @@ end end end - if GROUP == "DiffEq" - @testset "DiffEq" begin - activate_diffeq_env() - @time @safetestset "Default DiffEq Alg" include("diffeq/default_alg_diff.jl") - end - end - if GROUP == "GPU" @testset "GPU" begin activate_gpu_env()