Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove EnzymeInterpreter and instead reuse GPUInterpreter #1893

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4048,7 +4048,7 @@ if VERSION >= v"1.11.0-DEV.1552"
always_inline::Any
method_table::Core.MethodTable
param_type::Type
is_fwd::Bool
mode::API.CDerivativeMode
end

GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
Expand All @@ -4057,15 +4057,15 @@ if VERSION >= v"1.11.0-DEV.1552"
job.config.always_inline,
GPUCompiler.method_table(job),
typeof(job.config.params),
job.config.params.mode == API.DEM_ForwardMode,
API.DEM_ForwardMode,
)

GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
Interpreter.EnzymeInterpreter(
GPUCompiler.GPUInterpreter(
GPUCompiler.ci_cache_token(job),
GPUCompiler.method_table(job),
job.world,
job.config.params.mode,
meta=Interpreter.EnzymeMeta(job.config.params.mode),
)
else

Expand All @@ -4074,6 +4074,7 @@ else
# rule or not inlining a rev mode rule. Otherwise, all caches can be re-used.
const GLOBAL_FWD_CACHE = GPUCompiler.CodeCache()
const GLOBAL_REV_CACHE = GPUCompiler.CodeCache()
# TODO: Branch on target... otherwise GPU and CPU code end in the same cache
function enzyme_ci_cache(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams})
return if job.config.params.mode == API.DEM_ForwardMode
GLOBAL_FWD_CACHE
Expand All @@ -4086,11 +4087,11 @@ else
enzyme_ci_cache(job)

GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
Interpreter.EnzymeInterpreter(
GPUCompiler.GPUInterpreter(
enzyme_ci_cache(job),
GPUCompiler.method_table(job),
job.world,
job.config.params.mode,
meta=Interpreter.EnzymeMeta(job.config.params.mode),
)
end

Expand Down
279 changes: 28 additions & 251 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,5 @@
module Interpreter
import Enzyme: API
using Core.Compiler:
AbstractInterpreter,
InferenceResult,
InferenceParams,
InferenceState,
OptimizationParams,
MethodInstance
using GPUCompiler: @safe_debug
if VERSION < v"1.11.0-DEV.1552"
using GPUCompiler: CodeCache, WorldView, @safe_debug
end
const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552"

import ..Enzyme
import ..EnzymeRules

Expand All @@ -22,93 +9,10 @@ import ..EnzymeRules
else
import Core.Compiler: get_world_counter, get_world_counter as get_inference_world
end

struct EnzymeInterpreter <: AbstractInterpreter
@static if HAS_INTEGRATED_CACHE
token::Any
else
code_cache::CodeCache
end
method_table::Union{Nothing,Core.MethodTable}

# Cache of inference results for this particular interpreter
local_cache::Vector{InferenceResult}
# The world age we're working inside of
world::UInt

# Parameters for inference and optimization
inf_params::InferenceParams
opt_params::OptimizationParams

struct EnzymeMeta
mode::API.CDerivativeMode
end

function EnzymeInterpreter(
cache_or_token,
mt::Union{Nothing,Core.MethodTable},
world::UInt,
mode::API.CDerivativeMode,
)
@assert world <= Base.get_world_counter()

parms = @static if VERSION < v"1.12"
InferenceParams(unoptimize_throw_blocks = false)
else
InferenceParams()
end

return EnzymeInterpreter(
cache_or_token,
mt,

# Initially empty cache
Vector{InferenceResult}(),

# world age counter
world,

# parameters for inference and optimization
parms,
OptimizationParams(),
mode,
)
end

Core.Compiler.InferenceParams(interp::EnzymeInterpreter) = interp.inf_params
Core.Compiler.OptimizationParams(interp::EnzymeInterpreter) = interp.opt_params
get_inference_world(interp::EnzymeInterpreter) = interp.world
Core.Compiler.get_inference_cache(interp::EnzymeInterpreter) = interp.local_cache
@static if HAS_INTEGRATED_CACHE
Core.Compiler.cache_owner(interp::EnzymeInterpreter) = interp.token
else
Core.Compiler.code_cache(interp::EnzymeInterpreter) =
WorldView(interp.code_cache, interp.world)
end

# No need to do any locking since we're not putting our results into the runtime cache
Core.Compiler.lock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(::EnzymeInterpreter, ::MethodInstance) = nothing

Core.Compiler.may_optimize(::EnzymeInterpreter) = true
Core.Compiler.may_compress(::EnzymeInterpreter) = true
# From @aviatesk:
# `may_discard_trees = true`` means a complicated (in terms of inlineability) source will be discarded,
# but as far as I understand Enzyme wants "always inlining, except special cased functions",
# so I guess we really don't want to discard sources?
Core.Compiler.may_discard_trees(::EnzymeInterpreter) = false
Core.Compiler.verbose_stmt_info(::EnzymeInterpreter) = false

if isdefined(Base.Experimental, Symbol("@overlay"))
Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) =
Core.Compiler.OverlayMethodTable(interp.world, interp.method_table)
else

# On 1.6- CUDA.jl will poison the method table at the end of the world
# using GPUCompiler: WorldOverlayMethodTable
# Core.Compiler.method_table(interp::EnzymeInterpreter, sv::InferenceState) =
# WorldOverlayMethodTable(interp.world)
end

function is_alwaysinline_func(@nospecialize(TT))
isa(TT, DataType) || return false
return false
Expand Down Expand Up @@ -149,153 +53,41 @@ function simplify_kw(@nospecialize specTypes)
end
end

import Core.Compiler: CallInfo
struct NoInlineCallInfo <: CallInfo
info::CallInfo # wrapped call
tt::Any # ::Type
kind::Symbol
NoInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt), kind::Symbol) =
new(info, tt, kind)
end
Core.Compiler.nsplit_impl(info::NoInlineCallInfo) = Core.Compiler.nsplit(info.info)
Core.Compiler.getsplit_impl(info::NoInlineCallInfo, idx::Int) =
Core.Compiler.getsplit(info.info, idx)
Core.Compiler.getresult_impl(info::NoInlineCallInfo, idx::Int) =
Core.Compiler.getresult(info.info, idx)
struct AlwaysInlineCallInfo <: CallInfo
info::CallInfo # wrapped call
tt::Any # ::Type
AlwaysInlineCallInfo(@nospecialize(info::CallInfo), @nospecialize(tt)) = new(info, tt)
end
Core.Compiler.nsplit_impl(info::AlwaysInlineCallInfo) = Core.Compiler.nsplit(info.info)
Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) =
Core.Compiler.getsplit(info.info, idx)
Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) =
Core.Compiler.getresult(info.info, idx)

using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
function Core.Compiler.abstract_call_gf_by_type(
interp::EnzymeInterpreter,
@nospecialize(f),
arginfo::ArgInfo,
si::StmtInfo,
@nospecialize(atype),
sv::AbsIntState,
max_methods::Int,
)
ret = @invoke Core.Compiler.abstract_call_gf_by_type(
interp::AbstractInterpreter,
f::Any,
arginfo::ArgInfo,
si::StmtInfo,
atype::Any,
sv::AbsIntState,
max_methods::Int,
)
callinfo = ret.info
import GPUCompiler: GPUInterpreter, NoInlineCallInfo, AlwaysInlineCallInfo
function inlining_handler(meta::EnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo)
method_table = Core.Compiler.method_table(interp)
world = get_inference_world(interp)

specTypes = simplify_kw(atype)
if is_primitive_func(specTypes)
callinfo = NoInlineCallInfo(callinfo, atype, :primitive)
return NoInlineCallInfo(callinfo, atype, :primitive)
elseif is_alwaysinline_func(specTypes)
callinfo = AlwaysInlineCallInfo(callinfo, atype)
elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
elseif interp.mode == API.DEM_ForwardMode
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
return AlwaysInlineCallInfo(callinfo, atype)
elseif EnzymeRules.is_inactive_from_sig(specTypes; world, method_table)
return NoInlineCallInfo(callinfo, atype, :inactive)
elseif meta.mode == API.DEM_ForwardMode
if EnzymeRules.has_frule_from_sig(specTypes; world, method_table)
return NoInlineCallInfo(callinfo, atype, :frule)
end
elseif EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
end
@static if VERSION ≥ v"1.11-"
return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
else
return Core.Compiler.CallMeta(ret.rt, ret.effects, callinfo)
end
end

let # overload `inlining_policy`
@static if VERSION ≥ v"1.11.0-DEV.879"
sigs_ex = :(
interp::EnzymeInterpreter,
@nospecialize(src),
@nospecialize(info::Core.Compiler.CallInfo),
stmt_flag::UInt32,
)
args_ex = :(
interp::AbstractInterpreter,
src::Any,
info::Core.Compiler.CallInfo,
stmt_flag::UInt32,
)
else
sigs_ex = :(
interp::EnzymeInterpreter,
@nospecialize(src),
@nospecialize(info::Core.Compiler.CallInfo),
stmt_flag::UInt8,
mi::MethodInstance,
argtypes::Vector{Any},
)
args_ex = :(
interp::AbstractInterpreter,
src::Any,
info::Core.Compiler.CallInfo,
stmt_flag::UInt8,
mi::MethodInstance,
argtypes::Vector{Any},
)
end
@eval function Core.Compiler.inlining_policy($(sigs_ex.args...))
if info isa NoInlineCallInfo
if info.kind === :primitive
@safe_debug "Blocking inlining for primitive func" info.tt
elseif info.kind === :inactive
@safe_debug "Blocking inlining due to inactive rule" info.tt
elseif info.kind === :frule
@safe_debug "Blocking inlining due to frule" info.tt
else
@assert info.kind === :rrule
@safe_debug "Blocking inlining due to rrule" info.tt
end
return nothing
elseif info isa AlwaysInlineCallInfo
@safe_debug "Forcing inlining for primitive func" info.tt
return src
elseif meta.mode == API.DEM_ReverseModeCombined ||
meta.mode == API.DEM_ReverseModePrimal ||
meta.mode == API.DEM_ReverseModeGradient
if EnzymeRules.has_rrule_from_sig(specTypes; world, method_table)
return NoInlineCallInfo(callinfo, atype, :rrule)
end
return @invoke Core.Compiler.inlining_policy($(args_ex.args...))
end
return nothing
end

import Core.Compiler:
abstract_call,
abstract_call_known,
ArgInfo,
StmtInfo,
AbsIntState,
get_max_methods,
CallMeta,
Effects,
NoCallInfo,
widenconst,
mapany,
MethodResultPure

struct AutodiffCallInfo <: CallInfo
struct AutodiffCallInfo <: CC.CallInfo
# ...
info::CallInfo
info::CC.CallInfo
end

function abstract_call_known(
interp::EnzymeInterpreter,
@nospecialize(f),
arginfo::ArgInfo,
si::StmtInfo,
sv::AbsIntState,
max_methods::Int = get_max_methods(interp, f, sv),
)

import GPUCompiler: abstract_call_known
import CC: CallMeta, Effects, NoCallInfo
function abstract_call_known(meta::EnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
(; fargs, argtypes) = arginfo

if f === Enzyme.within_autodiff
Expand All @@ -307,18 +99,9 @@ function abstract_call_known(
end
end
@static if VERSION < v"1.11.0-"
return CallMeta(
Core.Const(true),
Core.Compiler.EFFECTS_TOTAL,
MethodResultPure(),
)
return CallMeta(Core.Const(true), CC.EFFECTS_TOTAL, MethodResultPure())
else
return CallMeta(
Core.Const(true),
Union{},
Core.Compiler.EFFECTS_TOTAL,
MethodResultPure(),
)
return CallMeta(Core.Const(true), Union{}, CC.EFFECTS_TOTAL, MethodResultPure(),)
end
end

Expand All @@ -331,6 +114,7 @@ function abstract_call_known(
[:(Enzyme.autodiff_deferred), fargs[2:end]...],
[Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...],
)
# FIXME: Use AutodiffCallInfo and a custom inlining handler
return abstract_call_known(
interp,
Enzyme.autodiff_deferred,
Expand All @@ -341,14 +125,7 @@ function abstract_call_known(
)
end
end
return Base.@invoke abstract_call_known(
interp::AbstractInterpreter,
f,
arginfo::ArgInfo,
si::StmtInfo,
sv::AbsIntState,
max_methods::Int,
)
return nothing
end

end
Loading