diff --git a/csrc/ops.cu b/csrc/ops.cu index 1f259d67f..8c72b22b4 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -824,8 +824,8 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { - int num_blocks = (m+7)/8; - kgemm_4bit_inference_naive<<< num_blocks, 256, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + int num_blocks = (m+3)/4; + kgemm_4bit_inference_naive<<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); CUDA_CHECK_RETURN(cudaPeekAtLastError()); }