-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[New Feature]add normalized_emb #443
Conversation
@@ -240,6 +241,100 @@ def get_kv_creator(mpi_size: int, | |||
return de.CuckooHashTableCreator(saver=saver) | |||
|
|||
|
|||
class DynamicLayerNormalization(LayerNormalization): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool! It is really useful, and I recommend this layer could be a regular API, not only a demo.
from tensorflow.keras.layers import LayerNormalization | ||
|
||
|
||
class DynamicLayerNormalization(LayerNormalization): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is recommended to keep the name of LayerNormalization
and just use a different prefix sort of de.keras.layers.LayerNormalization
. Accordingly, the import statement needs to be modified to from tensorflow.keras.layers import TFLayerNormalization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you mean from tensorflow.keras.layers import LayerNormalization as TFLayerNormalization ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct! And using de.keras.layers.LayerNormalization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about de.LayerNormalization?
DynamicLayerNormalization | ||
|
||
|
||
class DynamicLayerNormalizationTest(tf.test.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'd better to have one case testing the scenario in which the new layer works with tfra API like the demo shows.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Description
add DynamicLayerNormalization since tf LayerNormalization doesn't support dynamic shape
Brief Description of the PR:
Fixes # (issue)
Type of change
Checklist:
How Has This Been Tested?
If you're adding a bugfix or new feature please describe the tests that you ran to verify your changes:
*