-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Gelu test gpu #6
base: main
Are you sure you want to change the base?
Conversation
op_acc_stable_run.py
Outdated
@@ -128,7 +128,7 @@ def op_acc_stable_run(test_obj, stable_num=100): | |||
ret = [] | |||
tmp_cache_path = getattr(test_obj, "tmp_cache_path", None) | |||
if not tmp_cache_path: | |||
tmp_cache_path = os.getenv("TMP_CACHE_PATH", "/dev/shm") | |||
tmp_cache_path = os.getenv("TMP_CACHE_PATH", "/home") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要改这个,这个具体设置可以在你的单测里改。
比如单测里写self.tmp_cache_path = "."
op_acc_stable_run.py
Outdated
@@ -168,6 +168,7 @@ def op_acc_stable_run(test_obj, stable_num=100): | |||
else: | |||
with {framework}.no_grad(): | |||
check_aadiff(prev_ret, outputs) | |||
print(i) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
删除无用代码。
op_acc_stable_run(GeluTest(shape=[1, 12288], dtype="bfloat16")) | ||
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="float32")) | ||
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="float16")) | ||
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="bfloat16")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修复最后一行的这个问题。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个最后的一行的问题指什么?我不太确定,指最后一行缺少空行吗?
@@ -45,4 +45,4 @@ def check_diff(self, paddle, pd_ret, th_ret): | |||
|
|||
|
|||
if __name__ == "__main__": | |||
op_acc_stable_run(SoftmaxTest) | |||
op_acc_stable_run(SoftmaxTest) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修复最后一行的这个问题。
No description provided.