Skip to content

Commit

Permalink
Reorder methods for easier reading
Browse files Browse the repository at this point in the history
  • Loading branch information
danielwe committed Sep 30, 2024
1 parent bd1cca0 commit 4fbdc47
Showing 1 changed file with 143 additions and 143 deletions.
286 changes: 143 additions & 143 deletions src/make_zero.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,3 @@
const _RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}

@inline function EnzymeCore.make_zero(prev::FT) where {FT<:_RealOrComplexFloat}
return Base.zero(prev)::FT
end

@inline function EnzymeCore.make_zero(
::Type{FT},
@nospecialize(seen::IdDict),
prev::FT,
@nospecialize(_::Val{copy_if_inactive}=Val(false)),
) where {FT<:_RealOrComplexFloat,copy_if_inactive}
return EnzymeCore.make_zero(prev)::FT
end

@inline function EnzymeCore.make_zero(prev::Array{FT,N}) where {FT<:_RealOrComplexFloat,N}
# convert because Base.zero may return different eltype when FT is not concrete
return convert(Array{FT,N}, Base.zero(prev))::Array{FT,N}
end

@inline function EnzymeCore.make_zero(
::Type{Array{FT,N}},
seen::IdDict,
prev::Array{FT,N},
@nospecialize(_::Val{copy_if_inactive}=Val(false)),
) where {FT<:_RealOrComplexFloat,N,copy_if_inactive}
if haskey(seen, prev)
return seen[prev]::Array{FT,N}
end
newa = EnzymeCore.make_zero(prev)
seen[prev] = newa
return newa::Array{FT,N}
end

@inline function EnzymeCore.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
) where {RT,copy_if_inactive}
isleaftype(_) = false
isleaftype(::Type{<:Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}}) = true
f(p) = EnzymeCore.make_zero(Core.Typeof(p), seen, p, Val(copy_if_inactive))
return recursive_map(RT, f, seen, (prev,), Val(copy_if_inactive), isleaftype)::RT
end

recursive_map(f::F, xs::T...) where {F,T} = recursive_map(T, f, IdDict(), xs)::T

@inline function recursive_map(
Expand All @@ -59,24 +16,6 @@ recursive_map(f::F, xs::T...) where {F,T} = recursive_map(T, f, IdDict(), xs)::T
return _recursive_map(RT, f, seen, xs, Val(copy_if_inactive), isleaftype)::RT
end

@inline function _recursive_map(
::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args...
) where {RT<:Array,F,N}
if haskey(seen, xs)
return seen[xs]::RT
end
y = RT(undef, size(first(xs)))
seen[xs] = y
for I in eachindex(xs...)
if all(x -> isassigned(x, I), xs)
xIs = ntuple(j -> xs[j][I], N)
ST = Core.Typeof(first(xIs))
@inbounds y[I] = recursive_map(ST, f, seen, xIs, args...)
end
end
return y
end

@inline function _recursive_map(
::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args...
) where {RT,F,N}
Expand Down Expand Up @@ -127,6 +66,103 @@ end
return y
end

@inline function _recursive_map(
::Type{RT}, f::F, seen::IdDict, xs::NTuple{N,RT}, args...
) where {RT<:Array,F,N}
if haskey(seen, xs)
return seen[xs]::RT
end
y = RT(undef, size(first(xs)))
seen[xs] = y
for I in eachindex(xs...)
if all(x -> isassigned(x, I), xs)
xIs = ntuple(j -> xs[j][I], N)
ST = Core.Typeof(first(xIs))
@inbounds y[I] = recursive_map(ST, f, seen, xIs, args...)
end
end
return y
end

@inline function recursive_map!(f::F, y::T, xs::T...) where {F,T}
return recursive_map!(f, y, Base.IdSet(), xs)::Nothing
end

@inline function recursive_map!(
f::F, y::T, seen::Base.IdSet, xs::NTuple{N,T}, isleaftype::L=Returns(false)
) where {F,T,N,L}
if guaranteed_const_nongen(T, nothing)
return nothing
elseif isleaftype(T)
# If there exist T such that isleaftype(T) and T has mutable content that is not
# guaranteed const, including mutables nested inside immutables like Tuple{Vector},
# then f must have a corresponding mutating method:
f(y, xs...)
return nothing
end
return _recursive_map!(f, y, seen, xs, isleaftype)::Nothing
end

@inline function _recursive_map!(
f::F, y::T, seen, xs::NTuple{N,T}, isleaftype
) where {F,T,N}
if y in seen
return nothing
end
@assert !Base.isabstracttype(T)
@assert Base.isconcretetype(T)
nf = fieldcount(T)
if nf == 0
return nothing
end
push!(seen, y)
for i = 1:nf
if isdefined(y, i) && all(x -> isdefined(x, i), xs)
yi = getfield(y, i)
xis = ntuple(j -> getfield(xs[j], i), N)
SBT = Core.Typeof(yi)
activitystate = active_reg_inner(SBT, (), nothing, Val(false))
if activitystate == AnyState
continue
elseif activitystate == DupState
recursive_map!(f, yi, seen, xis, isleaftype)
else
yi = recursive_map_immutable!(f, yi, seen, xis, isleaftype)
if Base.isconst(T, i)
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, yi)
else
setfield!(y, i, yi)
end
end
end
end
return nothing
end

@inline function _recursive_map!(
f::F, y::Array{T,M}, seen, xs::NTuple{N,Array{T,M}}, isleaftype
) where {F,T,M,N}
if y in seen
return nothing
end
push!(seen, y)
for I in eachindex(y, xs...)
if isassigned(y, I) && all(x -> isassigned(x, I), xs)
yvalue = y[I]
xvalues = ntuple(j -> xs[j][I], N)
SBT = Core.Typeof(yvalue)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
@inbounds y[I] = recursive_map_immutable!(
f, yvalue, seen, xvalues, isleaftype
)
else
recursive_map!(f, yvalue, seen, xvalues, isleaftype)
end
end
end
return nothing
end

@inline function recursive_map_immutable!(f::F, y::T, xs::T...) where {F,T}
return recursive_map_immutable!(f, y, Base.IdSet(), xs)::T
end
Expand Down Expand Up @@ -185,20 +221,47 @@ end
return newy
end

@inline function EnzymeCore.make_zero!(prev::Array{T,N}) where {T<:_RealOrComplexFloat,N}
fill!(prev, zero(T))
return nothing
const _RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}

@inline function EnzymeCore.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
) where {RT,copy_if_inactive}
isleaftype(_) = false
isleaftype(::Type{<:Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}}) = true
f(p) = EnzymeCore.make_zero(Core.Typeof(p), seen, p, Val(copy_if_inactive))
return recursive_map(RT, f, seen, (prev,), Val(copy_if_inactive), isleaftype)::RT
end

@inline function EnzymeCore.make_zero!(
prev::Array{T,N}, seen::Base.IdSet,
) where {T<:_RealOrComplexFloat,N}
if prev in seen
return nothing
@inline function EnzymeCore.make_zero(
::Type{FT},
@nospecialize(seen::IdDict),
prev::FT,
@nospecialize(_::Val{copy_if_inactive}=Val(false)),
) where {FT<:_RealOrComplexFloat,copy_if_inactive}
return EnzymeCore.make_zero(prev)::FT
end

@inline function EnzymeCore.make_zero(prev::FT) where {FT<:_RealOrComplexFloat}
return Base.zero(prev)::FT
end

@inline function EnzymeCore.make_zero(
::Type{Array{FT,N}},
seen::IdDict,
prev::Array{FT,N},
@nospecialize(_::Val{copy_if_inactive}=Val(false)),
) where {FT<:_RealOrComplexFloat,N,copy_if_inactive}
if haskey(seen, prev)
return seen[prev]::Array{FT,N}
end
push!(seen, prev)
EnzymeCore.make_zero!(prev)
return nothing
newa = EnzymeCore.make_zero(prev)
seen[prev] = newa
return newa::Array{FT,N}
end

@inline function EnzymeCore.make_zero(prev::Array{FT,N}) where {FT<:_RealOrComplexFloat,N}
# convert because Base.zero may return different eltype when FT is not concrete
return convert(Array{FT,N}, Base.zero(prev))::Array{FT,N}
end

@inline function EnzymeCore.make_zero!(prev, seen::Base.IdSet=Base.IdSet())
Expand All @@ -213,81 +276,18 @@ end
return recursive_map!(f, prev, seen, (prev,), isleaftype)::Nothing
end

@inline function recursive_map!(f::F, y::T, xs::T...) where {F,T}
return recursive_map!(f, y, Base.IdSet(), xs)::Nothing
end

@inline function recursive_map!(
f::F, y::T, seen::Base.IdSet, xs::NTuple{N,T}, isleaftype::L=Returns(false)
) where {F,T,N,L}
if guaranteed_const_nongen(T, nothing)
return nothing
elseif isleaftype(T)
# If there exist T such that isleaftype(T) and T has mutable content that is not
# guaranteed const, including mutables nested inside immutables like Tuple{Vector},
# then f must have a corresponding mutating method:
f(y, xs...)
return nothing
end
return _recursive_map!(f, y, seen, xs, isleaftype)::Nothing
end

@inline function _recursive_map!(
f::F, y::Array{T,M}, seen, xs::NTuple{N,Array{T,M}}, isleaftype
) where {F,T,M,N}
if y in seen
@inline function EnzymeCore.make_zero!(
prev::Array{T,N}, seen::Base.IdSet,
) where {T<:_RealOrComplexFloat,N}
if prev in seen
return nothing
end
push!(seen, y)
for I in eachindex(y, xs...)
if isassigned(y, I) && all(x -> isassigned(x, I), xs)
yvalue = y[I]
xvalues = ntuple(j -> xs[j][I], N)
SBT = Core.Typeof(yvalue)
if active_reg_inner(SBT, (), nothing, Val(true)) == ActiveState #=justActive=#
@inbounds y[I] = recursive_map_immutable!(
f, yvalue, seen, xvalues, isleaftype
)
else
recursive_map!(f, yvalue, seen, xvalues, isleaftype)
end
end
end
push!(seen, prev)
EnzymeCore.make_zero!(prev)
return nothing
end

@inline function _recursive_map!(
f::F, y::T, seen, xs::NTuple{N,T}, isleaftype
) where {F,T,N}
if y in seen
return nothing
end
@assert !Base.isabstracttype(T)
@assert Base.isconcretetype(T)
nf = fieldcount(T)
if nf == 0
return nothing
end
push!(seen, y)
for i = 1:nf
if isdefined(y, i) && all(x -> isdefined(x, i), xs)
yi = getfield(y, i)
xis = ntuple(j -> getfield(xs[j], i), N)
SBT = Core.Typeof(yi)
activitystate = active_reg_inner(SBT, (), nothing, Val(false))
if activitystate == AnyState
continue
elseif activitystate == DupState
recursive_map!(f, yi, seen, xis, isleaftype)
else
yi = recursive_map_immutable!(f, yi, seen, xis, isleaftype)
if Base.isconst(T, i)
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, yi)
else
setfield!(y, i, yi)
end
end
end
end
@inline function EnzymeCore.make_zero!(prev::Array{T,N}) where {T<:_RealOrComplexFloat,N}
fill!(prev, zero(T))
return nothing
end

0 comments on commit 4fbdc47

Please sign in to comment.