From fa9f8da267ab201ad0ca33bbbdd2c88e47786260 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Wed, 10 Jul 2024 13:16:14 +0800 Subject: [PATCH] Remove amax Signed-off-by: Jiang, Zhiwei --- features/feature_case/cublasLt/matmul.cu | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/features/feature_case/cublasLt/matmul.cu b/features/feature_case/cublasLt/matmul.cu index 9384795ee..49ec4d314 100644 --- a/features/feature_case/cublasLt/matmul.cu +++ b/features/feature_case/cublasLt/matmul.cu @@ -729,8 +729,7 @@ void fgemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, cublasLtMatrixLayout_t Cdesc, - cublasLtMatrixLayout_t Ddesc, - float *amax_d) { + cublasLtMatrixLayout_t Ddesc) { cublasLtMatmulDesc_t matmulDesc = NULL; cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); @@ -747,7 +746,6 @@ void fgemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(scale_a)); cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(scale_b)); cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scale_d, sizeof(scale_d)); - cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &amax_d, sizeof(amax_d)); cublasLtEpilogue_t ep = CUBLASLT_EPILOGUE_RELU; cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &ep, sizeof(ep)); @@ -818,11 +816,8 @@ bool test7() { // Matmul - float *amax_d; - cudaMallocManaged(&amax_d, sizeof(float)); - fgemmlt(ltHandle, m, n, k, (const float *)Adev, (const float *)Bdev, (const float *)Cdev, (float *)Ddev, - &alpha, &beta, lda, ldb, ldc, ldd, Adesc_col_major, Bdesc_col_major, Cdesc_col_major, Ddesc_col_major, amax_d); + &alpha, &beta, lda, ldb, ldc, ldd, Adesc_col_major, Bdesc_col_major, Cdesc_col_major, Ddesc_col_major); cudaStreamSynchronize(0); // Check result @@ -837,14 +832,11 @@ bool test7() { break; } } - if (*amax_d != 8300) - error = true; printf("d:\n"); for (int i = 0; i < ldd * n; i++) printf("%f, ", Dhost[i]); printf("\n"); - printf("amax_d:%f\n", *amax_d); if (error) { printf("error\n");