From 8d3ae7400a4d2414e5d2efc1db8d07fcb110d367 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 10 Oct 2024 11:35:52 +0200 Subject: [PATCH 1/5] make_zero! for immutable --- src/make_zero.jl | 6 ++++++ test/abi.jl | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/make_zero.jl b/src/make_zero.jl index 4f627581ea..68397ef7da 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -384,6 +384,12 @@ end push!(seen, prev) + # For make_zero!(NamedTuple) we want to recurse and zero out + # the storage + if !Base.ismutabletype(T) + return Base.make_zero_immutable!(prev, seen) + end + for i = 1:nf if isdefined(prev, i) xi = getfield(prev, i) diff --git a/test/abi.jl b/test/abi.jl index 7a7917553f..9e24474083 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -493,6 +493,23 @@ end @test dv.y ≈ 0.0 end +@testset "Make Zero!" begin + params = (; q = rand(64), p = rand(64)) + dparams = make_zero(params) + @test all(==(0), dparams.q) + @test all(==(0), dparams.p) + + rand!(dparams.q) + rand!(dparams.p) + + dparams2 = make_zero!(dparams) + @test all(==(0), dparams.q) + @test all(==(0), dparams.p) + + @test dparams2 == dparams +end + + @testset "Type inference" begin x = ones(10) @inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x)) From 68df161fdf959020dd67cc817830620587433cdc Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 10 Oct 2024 12:02:14 +0200 Subject: [PATCH 2/5] fixup! make_zero! for immutable --- test/abi.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/abi.jl b/test/abi.jl index 9e24474083..a08f73c36c 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -1,5 +1,6 @@ using Enzyme using Test +using Random @testset "ABI & Calling convention" begin From 9df5d4058e9adbac0845cde29498de66c1601376 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 10 Oct 2024 12:29:18 +0200 Subject: [PATCH 3/5] fixup! fixup! make_zero! for immutable --- src/make_zero.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/make_zero.jl b/src/make_zero.jl index 68397ef7da..3388193c4a 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -387,7 +387,7 @@ end # For make_zero!(NamedTuple) we want to recurse and zero out # the storage if !Base.ismutabletype(T) - return Base.make_zero_immutable!(prev, seen) + return make_zero_immutable!(prev, seen) end for i = 1:nf From a3def6ce1ebcb4857fe38414e7fc92102d418631 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 10 Oct 2024 12:37:53 +0200 Subject: [PATCH 4/5] fixup! fixup! fixup! make_zero! for immutable --- src/make_zero.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/make_zero.jl b/src/make_zero.jl index 3388193c4a..f131f58fce 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -198,7 +198,7 @@ function make_zero_immutable!(prev::T, seen::S)::T where {T<:Tuple,S} end function make_zero_immutable!(prev::NamedTuple{a,b}, seen::S)::NamedTuple{a,b} where {a,b,S} - NamedTuple{a,b}(ntuple(Val(length(T.parameters))) do i + NamedTuple{a,b}(ntuple(Val(length(a))) do i Base.@_inline_meta make_zero_immutable!(prev[a[i]], seen) end) From 0a4f468195a8697e0fe01f6e6ab915664a57b423 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 10 Oct 2024 15:38:36 +0200 Subject: [PATCH 5/5] fixup! fixup! fixup! fixup! make_zero! for immutable --- src/make_zero.jl | 3 ++- test/abi.jl | 4 +--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/make_zero.jl b/src/make_zero.jl index f131f58fce..b0566210cd 100644 --- a/src/make_zero.jl +++ b/src/make_zero.jl @@ -387,7 +387,8 @@ end # For make_zero!(NamedTuple) we want to recurse and zero out # the storage if !Base.ismutabletype(T) - return make_zero_immutable!(prev, seen) + make_zero_immutable!(prev, seen) + return nothing end for i = 1:nf diff --git a/test/abi.jl b/test/abi.jl index a08f73c36c..ee81186874 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -503,11 +503,9 @@ end rand!(dparams.q) rand!(dparams.p) - dparams2 = make_zero!(dparams) + make_zero!(dparams) @test all(==(0), dparams.q) @test all(==(0), dparams.p) - - @test dparams2 == dparams end