diff --git a/test/null_parameters.jl b/test/null_parameters.jl index e302c7bce..d1ed9576c 100644 --- a/test/null_parameters.jl +++ b/test/null_parameters.jl @@ -81,7 +81,7 @@ function loss10(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) rollout = solve(problem, Tsit5(), u0 = u0, p = params, - sensealg = GaussAdjoint(autojacvec = EnzymeVJP())) + sensealg = QuadratureAdjoint(autojacvec = EnzymeVJP())) sum(Array(rollout)[:, end]) end @@ -89,7 +89,7 @@ function loss11(params) u0 = zeros(2) problem = ODEProblem(dynamics, u0, (0.0, 1.0)) rollout = solve(problem, Tsit5(), u0 = u0, p = params, - sensealg = GaussAdjoint(autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) sum(Array(rollout)[:, end]) end