
Check if function is being called inside autodiff #1761

Closed
avik-pal opened this issue Aug 29, 2024 · 6 comments

@avik-pal
Contributor

Conversation from Slack

@avik-pal

In Enzyme, can we check whether a function is being called within an autodiff call? Something equivalent to the following ChainRules version:

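using Static: True, False; import ChainRulesCore as CRC
const ∂∅ = CRC.NoTangent()  # assumed meaning of ∂∅ in this snippet
within_gradient(_) = False()
CRC.rrule(::typeof(within_gradient), x) = True(), _ -> (∂∅, ∂∅)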

I tried defining:

function EnzymeRules.forward(
        ::EnzymeCore.Const{typeof(within_gradient)}, ::Type{RT}, x) where {RT}
    error("within_gradient")
end

@wsmoses

I think this is getting optimized out lol
But yeah, we can add something (the C++ interface already has it, so we should bring Julia to parity :P)
(Julia will still constprop through things with enzyme rules atm)
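A minimal plain-Julia sketch (no Enzyme required) of the constant-propagation effect described above. It assumes `within_gradient` returns a `Bool` instead of `Static.False()` so that it can be branched on directly, and the caller `g` is hypothetical. `Base.return_types` reports what Julia's inference concludes before any AD tool sees the code:

```julia
# Same idea as the ChainRules version, but returning a plain Bool.
within_gradient(_) = false

# Hypothetical caller: the `true` branch returns a String, the other returns x.
g(x) = within_gradient(x) ? "inside AD" : x

# Inference sees that within_gradient(::Float64) always returns the constant
# `false`, prunes the String branch, and infers only Float64.
Base.return_types(g, (Float64,))
```

Because the call is resolved to a constant during inference, a custom rule attached to `within_gradient` plausibly never gets consulted, which is consistent with the call being "optimized out".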

@avik-pal
Contributor Author

@wsmoses is this as simple as adding a ccall somewhere? I can make a PR if you point me to the function

@wsmoses
Member

wsmoses commented Aug 31, 2024

we have this magic function: https://github.com/EnzymeAD/Enzyme/blob/7f614f43808e5bd3960f3712ac880b38eedc01d6/enzyme/test/Integration/ReverseMode/mycos.c#L39

which has the semantics that __enzyme_iter(x, y) = x if not differentiated, x + y at first order, etc. [used to keep the Taylor series terms in the right order].

It probably needs a bit of Julia integration, but maybe something like this?
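To make the described semantics concrete, here is a hypothetical plain-Julia model. It is not Enzyme's implementation: `enzyme_iter_model` and `in_ad_flag` are illustrative names, and the explicit `order` argument stands in for the differentiation order Enzyme would be operating at. Only orders 0 and 1 are modeled, since higher orders are only described as "etc.":

```julia
# Hypothetical model of __enzyme_iter(x, y); not how Enzyme implements it.
# `order` = 0 means "not differentiated", 1 means first-order differentiation.
function enzyme_iter_model(x, y, order::Integer)
    order == 0 && return x        # not differentiated: x unchanged
    order == 1 && return x + y    # first order: x + y
    throw(ArgumentError("only orders 0 and 1 are modeled here"))
end

# With x = 0 and y = 1, a nonzero result signals "being differentiated".
in_ad_flag(order) = enzyme_iter_model(0, 1, order) != 0
```

Passing `0, 1` and testing `!= 0` is the same trick used in the `ccall` attempt that follows in this thread.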

@avik-pal
Contributor Author

Is this not an "actual" function present in the binary? I tried

function within_autodiff()
    return ccall((:__enzyme_iter, libEnzyme), UInt64, (UInt64, UInt64), 0, 1) != 0
end

But that symbol is not present in libEnzyme.

@wsmoses
Member

wsmoses commented Aug 31, 2024 via email

@wsmoses
Member

wsmoses commented Sep 1, 2024

hm okay, I realize this won't quite work the same way in Julia, since we won't have control over the non-differentiated optimization pipeline.

I think the solution here is that we just need to tell the abstract interpreter to block constant propagation (constprop) for a special function, which we then rewrite if we are inside an Enzyme-differentiated context. cc @vchuravy @aviatesk
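The "block constprop" half of this idea can be illustrated in plain Julia. This sketch is a stand-in, not Enzyme's mechanism: it hides the value in a mutable `Ref` (a hypothetical global `IN_AUTODIFF`), so inference only knows the result is a `Bool`, not which one. The names `within_autodiff_sketch` and `h` are hypothetical:

```julia
# Hypothetical stand-in: the flag lives in a mutable Ref, so inference can
# only infer its type (Bool), not a constant value.
const IN_AUTODIFF = Ref(false)
within_autodiff_sketch() = IN_AUTODIFF[]

# Hypothetical caller: one branch returns a String, the other returns x.
h(x) = within_autodiff_sketch() ? "inside AD" : x

# Neither branch is pruned, so inference reports a Union of both results.
Base.return_types(h, (Float64,))
```

In the approach proposed above, the abstract interpreter itself would keep the special function from being folded and substitute `true` when compiling a differentiated context; that rewrite is not modeled here.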

I know some early work on overriding the abstract interpreter (absint) was done in #1443, but I think it needs someone to push it forward atm.

@wsmoses
Member

wsmoses commented Sep 18, 2024

@avik-pal so this should basically be as simple as defining a new function that returns false and then doing the same as #1839 to change it to true [since this interpreter always runs in an autodiff context].
