int8 and tf32 peakflops warnings and notes
carstenbauer committed Mar 15, 2022
1 parent db259b4 commit 11209bd
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 6 additions & 0 deletions docs/src/examples/peakflops_gpu.md
@@ -87,4 +87,10 @@ Peakflops (TFLOP/s):
 ├ tensorcores: true
 ├ dtype: Float16
 └ max: 311.2
+
+julia> peakflops_gpu(; dtype=:TensorFloat32, tensorcores=true); # as of writing, requires Julia 1.8 and https://github.com/JuliaGPU/CUDA.jl/pull/1419
+Peakflops (TFLOP/s):
+├ tensorcores: true
+├ dtype: TensorFloat32
+└ max: 155.5
 ```
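The TensorFloat32 example above has a counterpart in the Int8 path guarded by the source change below. A minimal sketch of how that path would be exercised, assuming `peakflops_gpu` forwards `dtype` to `peakflops_gpu_wmmas` as in the existing examples (no measured output is shown, since the call errors on older CUDA.jl versions):

```julia
julia> using GPUInspector, CUDA

julia> # Hypothetical invocation of the Int8 Tensor Core benchmark; as of this
       # commit it throws unless CUDA.jl is new enough (see the guard below).
       peakflops_gpu(; dtype=Int8, tensorcores=true);
```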
6 changes: 4 additions & 2 deletions src/peakflops_gpu_wmmas.jl
Expand Up @@ -75,7 +75,7 @@ it takes to perform `_kernel_wmma_nwmmas()` many WMMAs on Tensor Cores.
**Keyword arguments:**
* `device` (default: `CUDA.device()`): CUDA device to be used.
* `dtype` (default: `Float16`): element type of the matrices. We currently only support `Float16`, `Int8` (`:TensorFloat32`, `:BFloat16`, and `Float64` might or might not work).
* `dtype` (default: `Float16`): element type of the matrices. We currently only support `Float16` (`Int8`, `:TensorFloat32`, `:BFloat16`, and `Float64` might or might not work).
* `nkernel` (default: `10`): number of kernel calls that make up one benchmarking sample.
* `nbench` (default: `5`): number of measurements to be performed the best of which is used for the TFLOP/s computation.
* `threads` (default: max. threads per block): how many threads to use per block (part of the kernel launch configuration).
@@ -101,7 +101,9 @@ function peakflops_gpu_wmmas(;
         dtype_a = dtype_b = Float16
         dtype_c = dtype_d = Float32
     elseif Symbol(dtype) == :Int8
-        # requires CUDA.jl >= 3.8.4
+        if pkgversion(CUDA) < v"3.8.6"
+            error("At the time of writing, CUDA#master is required for Int8 WMMA support.")
+        end
         m = n = k = 16
         dtype_a = dtype_b = Int8
         dtype_c = dtype_d = Int32
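The new guard relies on `pkgversion(CUDA)`. `Base.pkgversion` only appeared in Julia 1.9, so a package targeting earlier Julia versions presumably ships its own helper; a minimal sketch of what such a helper could look like, built on `Pkg.dependencies()` (hypothetical, not necessarily this package's actual implementation):

```julia
using Pkg

# Look up the installed version of the package that owns module `m`.
# Sketch only: Base.pkgversion exists from Julia 1.9 on, so older code
# needs a hand-rolled equivalent like this.
function pkgversion(m::Module)
    uuid = Base.PkgId(m).uuid
    deps = Pkg.dependencies()  # Dict{UUID, PackageInfo}
    return haskey(deps, uuid) ? deps[uuid].version : nothing
end

# Usage mirroring the guard in the diff:
# pkgversion(CUDA) < v"3.8.6" && error("CUDA#master is required for Int8 WMMA support.")
```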
