From 6e0c2595835c5eb782faf95de460c52a0939d764 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 5 Apr 2024 01:36:45 +0900 Subject: [PATCH] optimize construction of `InferenceResult` for constant inference --- base/compiler/abstractinterpretation.jl | 38 ++++---- base/compiler/compiler.jl | 2 +- base/compiler/inferenceresult.jl | 110 +++++++----------------- base/compiler/inferencestate.jl | 5 +- base/compiler/ssair/legacy.jl | 2 +- base/compiler/typeinfer.jl | 8 +- base/compiler/types.jl | 29 ++++--- test/compiler/inference.jl | 2 +- 8 files changed, 81 insertions(+), 115 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index a113d028248db..329d4b2c4b9ad 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1239,7 +1239,7 @@ const_prop_result(inf_result::InferenceResult) = ConstCallResults(inf_result.result, inf_result.exc_result, ConstPropResult(inf_result), inf_result.ipo_effects, inf_result.linfo) -# return cached constant analysis result +# return cached result of constant analysis return_cached_result(::AbstractInterpreter, inf_result::InferenceResult, ::AbsIntState) = const_prop_result(inf_result) @@ -1249,7 +1249,14 @@ function const_prop_call(interp::AbstractInterpreter, inf_cache = get_inference_cache(interp) ๐•ƒแตข = typeinf_lattice(interp) argtypes = has_conditional(๐•ƒแตข, sv) ? ConditionalArgtypes(arginfo, sv) : SimpleArgtypes(arginfo.argtypes) - given_argtypes, overridden_by_const = matching_cache_argtypes(๐•ƒแตข, mi, argtypes) + # use `cache_argtypes` that has been constructed for fresh regular inference if available + volatile_inf_result = result.volatile_inf_result + if volatile_inf_result !== nothing + cache_argtypes = volatile_inf_result.inf_result.argtypes + else + cache_argtypes = matching_cache_argtypes(๐•ƒแตข, mi) + end + given_argtypes = matching_cache_argtypes(๐•ƒแตข, mi, argtypes, cache_argtypes) inf_result = cache_lookup(๐•ƒแตข, mi, given_argtypes, inf_cache) if inf_result !== nothing # found the cache for this constant prop' @@ -1260,12 +1267,18 @@ function const_prop_call(interp::AbstractInterpreter, @assert inf_result.linfo === mi "MethodInstance for cached inference result does not match" return return_cached_result(interp, inf_result, sv) end - # perform fresh constant prop' - inf_result = InferenceResult(mi, given_argtypes, overridden_by_const) - if !any(inf_result.overridden_by_const) + overridden_by_const = falses(length(given_argtypes)) + for i = 1:length(given_argtypes) + if given_argtypes[i] !== cache_argtypes[i] + overridden_by_const[i] = true + end + end + if !any(overridden_by_const) add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes") return nothing end + # perform fresh constant prop' + inf_result = InferenceResult(mi, given_argtypes, overridden_by_const) frame = InferenceState(inf_result, #=cache_mode=#:local, interp) if frame === nothing add_remark!(interp, sv, "[constprop] Could not retrieve the source") @@ -1287,26 +1300,19 @@ end # TODO implement MustAlias forwarding -struct ConditionalArgtypes <: ForwardableArgtypes +struct ConditionalArgtypes arginfo::ArgInfo sv::InferenceState end -""" - matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, - conditional_argtypes::ConditionalArgtypes) - -The implementation is able to forward `Conditional` of `conditional_argtypes`, -as well as the other general extended lattice information. -""" function matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, - conditional_argtypes::ConditionalArgtypes) + conditional_argtypes::ConditionalArgtypes, + cache_argtypes::Vector{Any}) (; arginfo, sv) = conditional_argtypes (; fargs, argtypes) = arginfo given_argtypes = Vector{Any}(undef, length(argtypes)) def = mi.def::Method nargs = Int(def.nargs) - cache_argtypes, overridden_by_const = matching_cache_argtypes(๐•ƒ, mi) local condargs = nothing for i in 1:length(argtypes) argtype = argtypes[i] @@ -1349,7 +1355,7 @@ function matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, else given_argtypes = va_process_argtypes(๐•ƒ, given_argtypes, mi) end - return pick_const_args!(๐•ƒ, cache_argtypes, overridden_by_const, given_argtypes) + return pick_const_args!(๐•ƒ, given_argtypes, cache_argtypes) end # This is only for use with `Conditional`. diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 8b9c26be2ec81..12d6d5eb38764 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -203,6 +203,7 @@ include("compiler/ssair/ir.jl") include("compiler/ssair/tarjan.jl") include("compiler/abstractlattice.jl") +include("compiler/stmtinfo.jl") include("compiler/inferenceresult.jl") include("compiler/inferencestate.jl") @@ -210,7 +211,6 @@ include("compiler/typeutils.jl") include("compiler/typelimits.jl") include("compiler/typelattice.jl") include("compiler/tfuncs.jl") -include("compiler/stmtinfo.jl") include("compiler/abstractinterpretation.jl") include("compiler/typeinfer.jl") diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index 06fbffaa7aa04..574685c1e38d6 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ -1,63 +1,30 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license -""" - matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance) -> - (cache_argtypes::Vector{Any}, overridden_by_const::BitVector) - -Returns argument types `cache_argtypes::Vector{Any}` for `mi` that are in the native -Julia type domain. `overridden_by_const::BitVector` is all `false` meaning that -there is no additional extended lattice information there. - - matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, argtypes::ForwardableArgtypes) -> - (cache_argtypes::Vector{Any}, overridden_by_const::BitVector) - -Returns cache-correct extended lattice argument types `cache_argtypes::Vector{Any}` -for `mi` given some `argtypes` accompanied by `overridden_by_const::BitVector` -that marks which argument contains additional extended lattice information. - -In theory, there could be a `cache` containing a matching `InferenceResult` -for the provided `mi` and `given_argtypes`. The purpose of this function is -to return a valid value for `cache_lookup(๐•ƒ, mi, argtypes, cache).argtypes`, -so that we can construct cache-correct `InferenceResult`s in the first place. -""" -function matching_cache_argtypes end - function matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance) - method = isa(mi.def, Method) ? mi.def::Method : nothing - cache_argtypes = most_general_argtypes(method, mi.specTypes) - overridden_by_const = falses(length(cache_argtypes)) - return cache_argtypes, overridden_by_const + (; def, specTypes) = mi + return most_general_argtypes(isa(def, Method) ? def : nothing, specTypes) end -struct SimpleArgtypes <: ForwardableArgtypes +struct SimpleArgtypes argtypes::Vector{Any} end -""" - matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, argtypes::SimpleArgtypes) - -The implementation for `argtypes` with general extended lattice information. -This is supposed to be used for debugging and testing or external `AbstractInterpreter` -usages and in general `matching_cache_argtypes(::MethodInstance, ::ConditionalArgtypes)` -is more preferred it can forward `Conditional` information. -""" -function matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, simple_argtypes::SimpleArgtypes) +function matching_cache_argtypes(๐•ƒ::AbstractLattice, mi::MethodInstance, + simple_argtypes::SimpleArgtypes, + cache_argtypes::Vector{Any}) (; argtypes) = simple_argtypes given_argtypes = Vector{Any}(undef, length(argtypes)) for i = 1:length(argtypes) given_argtypes[i] = widenslotwrapper(argtypes[i]) end given_argtypes = va_process_argtypes(๐•ƒ, given_argtypes, mi) - return pick_const_args(๐•ƒ, mi, given_argtypes) + return pick_const_args!(๐•ƒ, given_argtypes, cache_argtypes) end -function pick_const_args(๐•ƒ::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any}) - cache_argtypes, overridden_by_const = matching_cache_argtypes(๐•ƒ, mi) - return pick_const_args!(๐•ƒ, cache_argtypes, overridden_by_const, given_argtypes) -end - -function pick_const_args!(๐•ƒ::AbstractLattice, cache_argtypes::Vector{Any}, overridden_by_const::BitVector, given_argtypes::Vector{Any}) - for i = 1:length(given_argtypes) +function pick_const_args!(๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, cache_argtypes::Vector{Any}) + nargtypes = length(given_argtypes) + @assert nargtypes == length(cache_argtypes) #= == nargs =# "invalid `given_argtypes` for `mi`" + for i = 1:nargtypes given_argtype = given_argtypes[i] cache_argtype = cache_argtypes[i] if !is_argtype_match(๐•ƒ, given_argtype, cache_argtype, false) @@ -66,13 +33,13 @@ function pick_const_args!(๐•ƒ::AbstractLattice, cache_argtypes::Vector{Any}, ov !โŠ(๐•ƒ, given_argtype, cache_argtype)) # if the type information of this `PartialStruct` is less strict than # declared method signature, narrow it down using `tmeet` - given_argtype = tmeet(๐•ƒ, given_argtype, cache_argtype) + given_argtypes[i] = tmeet(๐•ƒ, given_argtype, cache_argtype) end - cache_argtypes[i] = given_argtype - overridden_by_const[i] = true + else + given_argtypes[i] = cache_argtype end end - return cache_argtypes, overridden_by_const + return given_argtypes end function is_argtype_match(๐•ƒ::AbstractLattice, @@ -89,9 +56,9 @@ end va_process_argtypes(๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) = va_process_argtypes(Returns(nothing), ๐•ƒ, given_argtypes, mi) function va_process_argtypes(@specialize(va_handler!), ๐•ƒ::AbstractLattice, given_argtypes::Vector{Any}, mi::MethodInstance) - def = mi.def - isva = isa(def, Method) ? def.isva : false - nargs = isa(def, Method) ? Int(def.nargs) : length(mi.specTypes.parameters) + def = mi.def::Method + isva = def.isva + nargs = Int(def.nargs) if isva || isvarargtype(given_argtypes[end]) isva_given_argtypes = Vector{Any}(undef, nargs) for i = 1:(nargs-isva) @@ -112,14 +79,11 @@ function va_process_argtypes(@specialize(va_handler!), ๐•ƒ::AbstractLattice, gi return given_argtypes end -function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(specTypes), - withfirst::Bool = true) +function most_general_argtypes(method::Union{Method,Nothing}, @nospecialize(specTypes)) toplevel = method === nothing isva = !toplevel && method.isva mi_argtypes = Any[(unwrap_unionall(specTypes)::DataType).parameters...] nargs::Int = toplevel ? 0 : method.nargs - # For opaque closure, the closure environment is processed elsewhere - withfirst || (nargs -= 1) cache_argtypes = Vector{Any}(undef, nargs) # First, if we're dealing with a varargs method, then we set the last element of `args` # to the appropriate `Tuple` type or `PartialStruct` instance. @@ -162,17 +126,16 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe cache_argtypes[nargs] = vargtype nargs -= 1 end - # Now, we propagate type info from `linfo_argtypes` into `cache_argtypes`, improving some + # Now, we propagate type info from `mi_argtypes` into `cache_argtypes`, improving some # type info as we go (where possible). Note that if we're dealing with a varargs method, # we already handled the last element of `cache_argtypes` (and decremented `nargs` so that # we don't overwrite the result of that work here). if mi_argtypes_length > 0 - n = mi_argtypes_length > nargs ? nargs : mi_argtypes_length - tail_index = n + tail_index = nargtypes = min(mi_argtypes_length, nargs) local lastatype - for i = 1:n + for i = 1:nargtypes atyp = mi_argtypes[i] - if i == n && isvarargtype(atyp) + if i == nargtypes && isvarargtype(atyp) atyp = unwrapva(atyp) tail_index -= 1 end @@ -185,16 +148,16 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe else atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes)) end - i == n && (lastatype = atyp) + i == nargtypes && (lastatype = atyp) cache_argtypes[i] = atyp end - for i = (tail_index + 1):nargs + for i = (tail_index+1):nargs cache_argtypes[i] = lastatype end else @assert nargs == 0 "invalid specialization of method" # wrong number of arguments end - cache_argtypes + return cache_argtypes end # eliminate free `TypeVar`s in order to make the life much easier down the road: @@ -213,22 +176,15 @@ end function cache_lookup(๐•ƒ::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult}) method = mi.def::Method - nargs = Int(method.nargs) - method.isva && (nargs -= 1) - length(given_argtypes) โ‰ฅ nargs || return nothing + nargtypes = length(given_argtypes) + @assert nargtypes == Int(method.nargs) "invalid `given_argtypes` for `mi`" for cached_result in cache - cached_result.linfo === mi || continue + cached_result.linfo === mi || @goto next_cache cache_argtypes = cached_result.argtypes - cache_overridden_by_const = cached_result.overridden_by_const - for i in 1:nargs - if !is_argtype_match(๐•ƒ, widenmustalias(given_argtypes[i]), - cache_argtypes[i], cache_overridden_by_const[i]) - @goto next_cache - end - end - if method.isva - if !is_argtype_match(๐•ƒ, tuple_tfunc(๐•ƒ, given_argtypes[(nargs + 1):end]), - cache_argtypes[end], cache_overridden_by_const[end]) + @assert length(cache_argtypes) == nargtypes "invalid `cache_argtypes` for `mi`" + cache_overridden_by_const = cached_result.overridden_by_const::BitVector + for i in 1:nargtypes + if !is_argtype_match(๐•ƒ, given_argtypes[i], cache_argtypes[i], cache_overridden_by_const[i]) @goto next_cache end end diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index 3c71a88ab8a80..73c697d81b184 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -832,7 +832,10 @@ end frame_parent(sv::InferenceState) = sv.parent::Union{Nothing,AbsIntState} frame_parent(sv::IRInterpretationState) = sv.parent::Union{Nothing,AbsIntState} -is_constproped(sv::InferenceState) = any(sv.result.overridden_by_const) +function is_constproped(sv::InferenceState) + (;overridden_by_const) = sv.result + return overridden_by_const !== nothing +end is_constproped(::IRInterpretationState) = true is_cached(sv::InferenceState) = !iszero(sv.cache_mode & CACHE_MODE_GLOBAL) diff --git a/base/compiler/ssair/legacy.jl b/base/compiler/ssair/legacy.jl index f5181691cdf7b..b45db03875801 100644 --- a/base/compiler/ssair/legacy.jl +++ b/base/compiler/ssair/legacy.jl @@ -10,7 +10,7 @@ the original `ci::CodeInfo` are modified. """ function inflate_ir!(ci::CodeInfo, mi::MethodInstance) sptypes = sptypes_from_meth_instance(mi) - argtypes, _ = matching_cache_argtypes(fallback_lattice, mi) + argtypes = matching_cache_argtypes(fallback_lattice, mi) return inflate_ir!(ci, sptypes, argtypes) end function inflate_ir!(ci::CodeInfo, sptypes::Vector{VarState}, argtypes::Vector{Any}) diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 825045670e350..22b0afc9b03c1 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -810,7 +810,7 @@ struct EdgeCallResult end end -# return cached regular inference result +# return cached result of regular inference function return_cached_result(::AbstractInterpreter, codeinst::CodeInstance, caller::AbsIntState) rt = cached_return_type(codeinst) effects = ipo_effects(codeinst) @@ -869,10 +869,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize effects = isinferred ? frame.result.ipo_effects : adjust_effects(Effects(), method) # effects are adjusted already within `finish` for ipo_effects exc_bestguess = refine_exception_type(frame.exc_bestguess, effects) # propagate newly inferred source to the inliner, allowing efficient inlining w/o deserialization: - # note that this result is cached globally exclusively, we can use this local result destructively - volatile_inf_result = (isinferred && (force_inline || - src_inlining_policy(interp, result.src, NoCallInfo(), IR_FLAG_NULL))) ? - VolatileInferenceResult(result) : nothing + # note that this result is cached globally exclusively, so we can use this local result destructively + volatile_inf_result = isinferred ? VolatileInferenceResult(result) : nothing return EdgeCallResult(frame.bestguess, exc_bestguess, edge, effects, volatile_inf_result) elseif frame === true # unresolvable cycle diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 12c747d0c5707..088f8be234eb9 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -57,8 +57,6 @@ struct VarState VarState(@nospecialize(typ), undef::Bool) = new(typ, undef) end -abstract type ForwardableArgtypes end - struct AnalysisResults result next::AnalysisResults @@ -70,16 +68,19 @@ end const NULL_ANALYSIS_RESULTS = AnalysisResults(nothing) """ - InferenceResult(mi::MethodInstance, [argtypes::ForwardableArgtypes, ๐•ƒ::AbstractLattice]) + result::InferenceResult A type that represents the result of running type inference on a chunk of code. - -See also [`matching_cache_argtypes`](@ref). +There are two constructor available: +- `InferenceResult(mi::MethodInstance, [๐•ƒ::AbstractLattice])` for regualr inference, + without extended lattice information included in `result.argtypes`. +- `InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::BitVector)` + for constant inference, with extended lattice information included in `result.argtypes`. """ mutable struct InferenceResult const linfo::MethodInstance const argtypes::Vector{Any} - const overridden_by_const::BitVector + const overridden_by_const::Union{Nothing,BitVector} result # extended lattice element if inferred, nothing otherwise exc_result # like `result`, but for the thrown value src # ::Union{CodeInfo, IRCode, OptimizationState} if inferred copy is available, nothing otherwise @@ -89,16 +90,18 @@ mutable struct InferenceResult analysis_results::AnalysisResults # AnalysisResults with e.g. result::ArgEscapeCache if optimized, otherwise NULL_ANALYSIS_RESULTS is_src_volatile::Bool # `src` has been cached globally as the compressed format already, allowing `src` to be used destructively ci::CodeInstance # CodeInstance if this result has been added to the cache - function InferenceResult(mi::MethodInstance, cache_argtypes::Vector{Any}, overridden_by_const::BitVector) - # def = mi.def - # nargs = def isa Method ? Int(def.nargs) : 0 - # @assert length(cache_argtypes) == nargs - return new(mi, cache_argtypes, overridden_by_const, nothing, nothing, nothing, + function InferenceResult(mi::MethodInstance, argtypes::Vector{Any}, overridden_by_const::Union{Nothing,BitVector}) + def = mi.def + nargs = def isa Method ? Int(def.nargs) : 0 + @assert length(argtypes) == nargs "invalid `argtypes` for `mi`" + return new(mi, argtypes, overridden_by_const, nothing, nothing, nothing, WorldRange(), Effects(), Effects(), NULL_ANALYSIS_RESULTS, false) end end -InferenceResult(mi::MethodInstance, ๐•ƒ::AbstractLattice=fallback_lattice) = - InferenceResult(mi, matching_cache_argtypes(๐•ƒ, mi)...) +function InferenceResult(mi::MethodInstance, ๐•ƒ::AbstractLattice=fallback_lattice) + argtypes = matching_cache_argtypes(๐•ƒ, mi) + return InferenceResult(mi, argtypes, #=overridden_by_const=#nothing) +end function stack_analysis_result!(inf_result::InferenceResult, @nospecialize(result)) return inf_result.analysis_results = AnalysisResults(result, inf_result.analysis_results) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 3accb1edf54e0..7bc1032828c80 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -4485,7 +4485,7 @@ let # Vararg #=va=# Bound, unbound, # => Tuple{Integer,Integer} (invalid `TypeVar` widened beforehand) } where Bound<:Integer - argtypes = Core.Compiler.most_general_argtypes(method, specTypes, true) + argtypes = Core.Compiler.most_general_argtypes(method, specTypes) popfirst!(argtypes) @test argtypes[1] == Integer @test argtypes[2] == Integer