-
Notifications
You must be signed in to change notification settings - Fork 471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support of AOT compilation (refine #6992) #7581
base: master
Are you sure you want to change the base?
Conversation
@@ -12,6 +13,8 @@ namespace runtime { | |||
|
|||
std::atomic<bool> g_computation_client_initialized(false); | |||
|
|||
std::string aot_topology = ""; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this state is only ever used by PjRtCompilationClient
, would it make sense to make this a static field on PjRtCompilationClient
?
I think it's probably even better to just imperatively initialize this client when set_virtual_topology()
happens... but that also would introduce some extra edge cases for initialization. For the POC, a global var is okay, but I think a static field makes more sense to make it clear what uses that state.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moving aot_topology
to pjrt_compilation_client.h
may cause circular dependency issue because setting the aot_topology need to check g_computation_client_initialized
in the current setup.
Let's use the global var for now since this makes it clear in runtime.cc that we have three kinds of clients all together.
|
||
// Builds a map from the device's global ordinal to its index in the `devices` | ||
// array. | ||
std::unordered_map<int, int> build_index_map( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see some duplication of helpers here and in the PJRT/IFRT client implementations. It would be cleaner to factor it out rather than have 3 copies, but both AOT and IFRT are just in a concept stage...
I'll let @JackCaoG have the final word on readability so I don't have to decide.
This is a follow up PR to refine #6992.
In this PR, I created the
PjRtCompilationClient
to serve the ahead of time compilation. In this way, we don't need to create theCompileOnlyPjRtClient
,CompileOnlyPjRtDevice
etc. This makes it easier during openxla pin update.Instructions on how to run AOT compilation has been updated. We need to specify two extra flags when run on CPU device:
XLA_PERSISTENT_CACHE_PATH
as follows:----------------------- ON CPU--------------------
aot_encode.py:
This will genereate a hashing file named like 229013763457648799216243727807636414712, which can be deserialized by running the same graph code on a TPU:
-----------------------ON TPU v4-8--------------------
aot_decode.py: