Skip to content

Commit

Permalink
Remove amax
Browse files Browse the repository at this point in the history
Signed-off-by: Jiang, Zhiwei <[email protected]>
  • Loading branch information
zhiweij1 committed Jul 10, 2024
1 parent b5dcbc7 commit fa9f8da
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions features/feature_case/cublasLt/matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -729,8 +729,7 @@ void fgemmlt(cublasLtHandle_t ltHandle, int m, int n, int k,
cublasLtMatrixLayout_t Adesc,
cublasLtMatrixLayout_t Bdesc,
cublasLtMatrixLayout_t Cdesc,
cublasLtMatrixLayout_t Ddesc,
float *amax_d) {
cublasLtMatrixLayout_t Ddesc) {
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);

Expand All @@ -747,7 +746,6 @@ void fgemmlt(cublasLtHandle_t ltHandle, int m, int n, int k,
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(scale_a));
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(scale_b));
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scale_d, sizeof(scale_d));
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &amax_d, sizeof(amax_d));

cublasLtEpilogue_t ep = CUBLASLT_EPILOGUE_RELU;
cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &ep, sizeof(ep));
Expand Down Expand Up @@ -818,11 +816,8 @@ bool test7() {

// Matmul

float *amax_d;
cudaMallocManaged(&amax_d, sizeof(float));

fgemmlt(ltHandle, m, n, k, (const float *)Adev, (const float *)Bdev, (const float *)Cdev, (float *)Ddev,
&alpha, &beta, lda, ldb, ldc, ldd, Adesc_col_major, Bdesc_col_major, Cdesc_col_major, Ddesc_col_major, amax_d);
&alpha, &beta, lda, ldb, ldc, ldd, Adesc_col_major, Bdesc_col_major, Cdesc_col_major, Ddesc_col_major);
cudaStreamSynchronize(0);

// Check result
Expand All @@ -837,14 +832,11 @@ bool test7() {
break;
}
}
if (*amax_d != 8300)
error = true;

printf("d:\n");
for (int i = 0; i < ldd * n; i++)
printf("%f, ", Dhost[i]);
printf("\n");
printf("amax_d:%f\n", *amax_d);

if (error) {
printf("error\n");
Expand Down

0 comments on commit fa9f8da

Please sign in to comment.