Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recursive map generalizing the make_zero mechanism #1852

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from

Conversation

danielwe
Copy link
Contributor

This is to explore functionality for realizing JuliaMath/QuadGK.jl#120. The current draft cuts time and allocations in half for the MWE in that PR compared to the make_zero hack from the comments. Not sure if modifying the existing recursive_* functions like this is appropriate or whether it would be better to implement a separate deep_recursive_accumulate.

This probably breaks some existing uses of recursive_accumulate, like the Holomorphic derivative code, because recursive_accumulate now traverses most/all of the structure on its own and will double-accumulate when combined with the iteration over the seen IdDicts. Curious to see the total impact on the test suite.

This doesn't yet have any concept of seen and will thus double-accumulate if the structure has internal aliasing. That obviously needs to be fixed. Perhaps we can factor out and share the recursion code from make_zero.

A bit of a tangent, but perhaps a final version of this PR should include migrating ClosureVector to Enzyme from the QuadGK ext as suggested in JuliaMath/QuadGK.jl#110 (comment). Looks like that's the most relevant application of fully recursive accumulation at the moment.


Let me also throw out another suggestion: what if we implement a recursive generalization of broadcasting with an arbitrary number of arguments, i.e., recursive_broadcast!(f, a, b, c, ...) as a recursive generalization of a .= f.(b, c, ...), free of intermediate allocations whenever possible (and similarly an out-of-place recursive_broadcast(f, a, b, c...) generalizing f.(a, b, c...) that only materializes/allocates once if possible). That would enable more optimized custom rules with Duplicated args, such as having the QuadGK rule call the in-place version quadgk!(f!, result, segs...). Not sure if it would be hard to correctly handle aliasing without being overly defensive, or if that could mostly be taken care of by proper reuse of the existing broadcasting functionality.

@danielwe danielwe changed the title Make recursive_acc/accumulate more recursive Make recursive_add/accumulate more recursive Sep 18, 2024
@danielwe danielwe force-pushed the recursive_accumulate branch 2 times, most recently from 2161e03 to 545bf9b Compare September 25, 2024 00:13
@danielwe danielwe changed the title Make recursive_add/accumulate more recursive Add recursive map generalizing the make_zero mechanism Sep 25, 2024
@danielwe danielwe force-pushed the recursive_accumulate branch 4 times, most recently from 4fbdc47 to 74b212f Compare October 1, 2024 15:03
@danielwe danielwe marked this pull request as ready for review October 7, 2024 01:01
@danielwe
Copy link
Contributor Author

danielwe commented Oct 7, 2024

Alright, I could take some feedback/discussion on this now.

  • This implements a generic recursive_map for mapping a function over the differentiable values in arbitrary tuples of identical data structures. There's an in-place equivalent recursive_map! for mutable values, built on top of a bangbang-style recursive_map!! that works on arbitrary types and reuses all the mutable memory (similar to the old make_zero_immutable! but without code duplication).
  • The code is diffed with the old make_zero(!) code on github, but this is a complete rewrite and the diff will probably not be helpful for reviewing. Let me know if you want me to rename files or something to get rid github's diff view.
    • The implementation is leaner and simpler than the old one, and even though recursive_map{!!} aren't public I wrote extensive docstrings to clarify the spec for myself and others, so I don't think the code should be too difficult to review from scratch.
  • I added fast paths such that new structs are allocated using splatnew with a tuple instead of ccall with a vector in the common case where there are no undefined fields. This gives a substantial speedup in many cases, which is good since recursive_map will be called in hot loops in custom quadrature rules and the like.
  • I have refactored make_zero and make_zero! to be minimal wrappers around recursive_map{!}, without changing their public API.
    • To stay safe while doing such a big refactoring, I wrote extensive tests with ~full coverage of both the old and new implementations of make_zero(!) (a small number of edge case branches aren't covered only because I can't find a way to reach them from any public entry point).
    • These tests uncovered quite a few bugs in the existing make_zero(!) implementations. See the following commit on a separate branch in my fork for the necessary fixes to get the old code to pass the new tests: danielwe@7a6ca9f. (Note, one of the tests still errors due to active_reg_inner with justActive = true incorrect with immutable types that can be incompletely initialized #1935.)
  • I have not refactored recursive_add and recursive_accumulate, since these currently have different semantics where they don't recurse into mutable values. I'm happy to go ahead and do the refactoring if you're OK with changing their semantics.

TLDR: Should I rewrite recursive_add and recursive_accumulate to be based on recursive_map{!} and have full recursion semantics? Anything else?

@gdalle Promised to tag you when this was ready for review, but note that this PR only deals with the low-level, non-public guts of the implementation. I'll do the vector space wrapper in a separate PR as soon as this is merged (hopefully that won't be long, I really need that QuadGK rule for my research 📐)

return seen[prev]
xs::NTuple{N,T},
::Val{copy_if_inactive}=Val(false),
isleaftype::L=Returns(false),
Copy link
Contributor Author

@danielwe danielwe Oct 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering whether this is necessary or if the leaf types could just be hardcoded to Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}. I'll make a prototype of the vector space wrapper and the updated QuadGK rules to see if customizable leaf types comes in handy.

@danielwe danielwe marked this pull request as draft October 9, 2024 19:34
@danielwe
Copy link
Contributor Author

Update for anyone who's following: I've implemented the VectorSpace wrapper, which prompted me to adjust the recursive_map implementation a bit, all for the better. It's looking good and will make writing custom higher-order rules as well as the DI wrappers a lot nicer for arbitrary types. However, it dawned on me that you probably want make_zero to be easily extensible by just adding methods, like what's already done in the StaticArrays extension. That will require a bit of redesign, nothing too hard, but I've got weekend plans so might not get to it until next week.

@wsmoses
Copy link
Member

wsmoses commented Oct 11, 2024

awesome, sorry I haven't had a chance to review let [just a bunch of schenanigans atm], I'll try to take a closer look next week and ping me if not

@danielwe
Copy link
Contributor Author

No worries! I restored the draft label when I realized there was a bit more to do and will remove it again once I think this is ready for review. No need to look at it until then, the current state here on github doesn't reflect what I'm working with locally anyway.

isleaftype::L=Returns(false),
) where {T,F,N,L,copy_if_inactive}
x1 = first(xs)
if guaranteed_const_nongen(T, nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, this is only for make_zero, and not for add/etc?

Because this case here already feels specific to the context

Copy link
Contributor Author

@danielwe danielwe Oct 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's going to look a bit different once I push the next update (hopefully tomorrow), but no, after some experimenting it seemed best to me to always skip guaranteed inactive subtrees and restrict recursive_map to applying f to the differentiable values only. I tried doing the opposite initially, leaving it as part of the isleaftype filter and handling the possible deepcopy within the mapped function f, but it made things a lot more complicated. I think the main issue was that the whole mechanism with seen and keeping track of object identity then becomes the purview of the mapped function f instead of recursive_map itself, increasing boilerplate and complicating the contract between recursive_map and its callers. I couldn't think of a use case within Enzyme where you're interested in mapping over the guaranteed inactive parts anyway, and not recursing through inactive subtrees saves you from having to deal with deconstruction/reconstruction of a few specialized types (deepcopy has a lot more methods than recursive_map). So I went with this solution instead.

Of course, adding a skip_guaranteed_const flag would be straightforward (or combining it with copy_if_inactive into a single inactive_mode parameter). Do you think this is warranted?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants