Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to using SII and SciMLStructures #1057

Merged
merged 88 commits into from
Jun 28, 2024
Merged

Update to using SII and SciMLStructures #1057

merged 88 commits into from
Jun 28, 2024

Conversation

DhairyaLGandhi
Copy link
Member

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

Moves SciMLSensitivity to use the SII and SciMLStructures' interface.

Add any other context about the problem here.

ChrisRackauckas and others added 2 commits February 8, 2024 11:15
This uses the SciMLStructures Tunable interface https://github.com/SciML/SciMLStructures.jl in order to allow more generalized definitions of `p`.

- [ ] Ensure Lux.jl is well supported (componentarrays extension in SciMLStructures
- [ ] Add a test for a custom SciMLStructure
@DhairyaLGandhi DhairyaLGandhi changed the title Dg/structures Update to using SII and SciMLStructures May 22, 2024
Copy link
Member

@AayushSabharwal AayushSabharwal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just minor things

src/adjoint_common.jl Outdated Show resolved Hide resolved
@@ -158,7 +158,7 @@ function ODEForwardSensitivityProblem(f::F, args...; kwargs...) where {F}
end

function ODEForwardSensitivityProblem(prob::ODEProblem, alg; kwargs...)
ODEForwardSensitivityProblem(prob.f, prob.u0, prob.tspan, prob.p, alg; kwargs...)
ODEForwardSensitivityProblem(symbolic_container(prob), state_values(prob), prob.tspan, parameter_values(prob), alg; kwargs...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

symbolic_container technically does returns prob.f, however semantically it isn't guaranteed to do so and in such cases where you explicitly wand the function, I'd still use prob.f.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been addressed

@@ -354,14 +361,15 @@ function DiffEqBase._concrete_solve_adjoint(
sol = solve(_prob, alg, args...; save_noise = true, save_start = true,
save_end = true, kwargs_fwd...)
end
time = current_time(sol)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's current time from?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines 364 to 372
ts = sol.t
ts = time
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change the standard naming? Naming an array time instead of a scalar value is weird and not grammatically correct...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used it where there were multiple calls to sol.t. I wanted to avoid multiple calls to current_time when something like sol.t was referenced multiple times, since field access might be free, function calls potentially may not be.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can change the naming or potentially if there's a better solution, can implement that

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine, it could still be called ts. If you want to write it, times. But I don't understand why syntactic changes are mixed with a feature PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can change it to be ts - being shorter to write.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fixed now with 6952e1b and 95cad85

Project.toml Outdated Show resolved Hide resolved
@DhairyaLGandhi
Copy link
Member Author

DhairyaLGandhi commented Jun 25, 2024

The shadowing tests are passing now as well. The MWE for the DiffEqFlux failures is:

julia> ca = ComponentVector((x = rand(3), y = rand(4),))  # start with a component array                                                                                                          
ComponentVector{Float64}(x = [0.2755239169592659, 0.2144463395930084, 0.9085629964224934], y = [0.122401457810142, 0.498989539972836, 0.7718946152650026, 0.9715871491699061]
)                                                                                                                                                                            
                                                                                                                                                                             
julia> tr = Tracker.param(ca)  # use Tracker as the backend                                                                                                                                   
ComponentVector{Float64, TrackedVector{Float64, Vector{Float64}}, Tuple{Axis{(x = 1:3, y = 4:7)}}}(x = [0.2755239169592659, 0.2144463395930084, 0.9085629964224934] (tracked), y = [0.122401457810142, 0.498989539972836, 0.7718946152650026, 0.9715871491699061] (tracked))

julia> convert(typeof(ca), tr) # convert call masks the type of the underlying array
Tracked 7-element ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(x = 1:3, y = 4:7)}}} with indices 1:1:7:
 0.2755239169592659
 0.2144463395930084
 0.9085629964224934
 0.122401457810142
 0.498989539972836
 0.7718946152650026
 0.9715871491699061

julia> Float64.(tr)
Tracked 7-element ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(x = 1:3, y = 4:7)}}} with indices 1:1:7:
 0.2755239169592659
 0.2144463395930084
 0.9085629964224934
 0.122401457810142
 0.498989539972836
 0.7718946152650026
 0.9715871491699061

The specific convert method is https://github.com/jonniedie/ComponentArrays.jl/blob/554a9c03373680af84586762f68ebd32b6d34abe/src/similar_convert_copy.jl#L55 which seems to be missing a case where T1 === T2 and does an extra allocation. A dispatch could be defined to avoid this.

function Base.convert(::Type{ComponentArray{T,N,A1,Ax1}}, x::ComponentArray{T,N,A2,Ax2}) where {T,N,A1,A2,Ax1,Ax2}
    return x
end

@ChrisRackauckas
Copy link
Member

Can you PR that? If @jonniedie has a quick turnaround then I think we're close to done here.

@DhairyaLGandhi
Copy link
Member Author


if !isautojacvec
if DiffEqBase.has_paramjac(f)
f.paramjac(pJ, y, p, t) # Calculate the parameter Jacobian into pJ
else
pf.t = t
pf.u = y
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While working through this PR, I found this path which was not updating the caches. I have added this here, but wanted to highlight for correctness

@ChrisRackauckas
Copy link
Member

@wsmoses we're seeing something odd with the latest Enzyme release as well.

@wsmoses
Copy link

wsmoses commented Jun 27, 2024

@ChrisRackauckas how so, what’s the failure ?

@ChrisRackauckas
Copy link
Member

same as @avik-pal is mentioning is seen in CI here

Continuous and discrete costs: Error During Test at /home/runner/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:30
  Got exception outside of a @test
  LoadError: MethodError: no method matching cpu_features!(::LLVM.ModulePassManager)
  
  Closest candidates are:
    cpu_features!(::LLVM.Module)
     @ GPUCompiler ~/.julia/packages/GPUCompiler/05oYT/src/optim.jl:606
  
  Stacktrace:
    [1] (::Enzyme.Compiler.var"#28472#28473"{LLVM.Module, LLVM.TargetMachine})(pm::LLVM.ModulePassManager)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/qd8AI/src/compiler/optimize.jl:1941
    [2] LLVM.ModulePassManager(::Enzyme.Compiler.var"#28472#28473"{LLVM.Module, LLVM.TargetMachine}; kwargs::@Kwargs{})
      @ LLVM ~/.julia/packages/LLVM/5DlHM/src/passmanager.jl:33
    [3] LLVM.ModulePassManager(::Function)
      @ LLVM ~/.julia/packages/LLVM/5DlHM/src/passmanager.jl:30
    [4] optimize!(mod::LLVM.Module, tm::LLVM.TargetMachine)

@wsmoses
Copy link

wsmoses commented Jun 27, 2024 via email

@wsmoses
Copy link

wsmoses commented Jun 27, 2024 via email

@avik-pal
Copy link
Member

avik-pal commented Jun 27, 2024

Oh I think this is the llvm.jl bump changing api (it was done properly with a change which did semver but it passed ci so I assume we didn’t hit).

On enzyme CI it installed v7 so didn't really check the correct thing.

Really wish there was a way to remove old deps in CompatHelper PRs and push another change with the old deps inplace once tests pass. (I would have guessed the force_latest_versions kwarg in test does this)

@wsmoses
Copy link

wsmoses commented Jun 27, 2024

@ChrisRackauckas @avik-pal does this fix?

EnzymeAD/Enzyme.jl#1581

@@ -34,12 +34,14 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
prob = sol.prob
u0 = state_values(prob)
p = parameter_values(prob)
if p === nothing || p isa SciMLBase.NullParameters
if p === nothing || p === DiffEqBase.NullParameters()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the change to DiffEq? Should be SciML

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it

@ChrisRackauckas ChrisRackauckas merged commit 903f96d into master Jun 28, 2024
15 of 16 checks passed
@ChrisRackauckas ChrisRackauckas deleted the dg/structures branch June 28, 2024 00:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants