[WIP] Better tinygemm warning for T4 #1112

Open · wants to merge 6 commits into main
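For context, a minimal repro sketch of the behavior this change adds (the tensor shapes and the `"cuda"` device below are illustrative, not taken from the diff): on a GPU older than sm_80, such as a T4 (sm_75), `pack_tinygemm_scales_and_zeros` now fails fast with an explicit `NotImplementedError`.

```python
# Sketch only: expected post-change behavior on a pre-sm_80 GPU such as a T4 (sm_75).
# Shapes are illustrative; the guard itself only looks at torch.cuda.get_device_capability().
import torch
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

scales = torch.randn(64, 8, dtype=torch.bfloat16, device="cuda")
zeros = torch.randn_like(scales)

try:
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
except NotImplementedError as e:
    # On compute capability < (8, 0) this now raises:
    # "4 bit quantization with tinygemm is not supported on this device as it requires sm_8.0 or higher ..."
    print(e)
```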
40 changes: 24 additions & 16 deletions test/test_ops.py
@@ -14,6 +14,8 @@
from torchao.dtypes.floatx import from_scaled_tc_floatx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
import pytest
from unittest.mock import patch


if is_fbcode():
    pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels")
@@ -274,7 +276,6 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
    assert diff_op_ao < 1e-1

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
    n, k = shape
@@ -287,23 +288,30 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
    q_groups = k // group_size
    scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
    zeros = torch.randn_like(scales)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
    ]
    # TODO: Figure out why test fails unless torch >= 2.5
    if TORCH_VERSION_AT_LEAST_2_5:
        test_utils.append("test_aot_dispatch_dynamic")
    opcheck(
        torch.ops.torchao.dequantize_tensor_core_tiled_layout,
        (packed_w, scales_and_zeros, group_size, inner_k_tiles),
        test_utils=test_utils,
    )

    # Test the case where CUDA SM version is less than 8.0
    with patch('torch.cuda.get_device_capability', return_value=(7, 5)):
        with pytest.raises(NotImplementedError) as excinfo:
            scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
        assert "4 bit quantization with tinygemm is not supported on this device" in str(excinfo.value)

    with patch('torch.cuda.get_device_capability', return_value=(8, 0)):
        scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
        test_utils = [
            "test_schema",
            "test_autograd_registration",
            "test_faketensor",
        ]
        # TODO: Figure out why test fails unless torch >= 2.5
        if TORCH_VERSION_AT_LEAST_2_5:
            test_utils.append("test_aot_dispatch_dynamic")
        opcheck(
            torch.ops.torchao.dequantize_tensor_core_tiled_layout,
            (packed_w, scales_and_zeros, group_size, inner_k_tiles),
            test_utils=test_utils,
        )


MARLIN_24_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]
5 changes: 5 additions & 0 deletions torchao/quantization/utils.py
@@ -312,6 +312,11 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16


def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16):
    if torch.cuda.is_available():
        min_sm = (8, 0)
        if torch.cuda.get_device_capability() < min_sm:
            raise NotImplementedError(f"4 bit quantization with tinygemm is not supported on this device as it requires sm_{min_sm[0]}.{min_sm[1]} or higher but got {torch.cuda.get_device_capability()}")

    guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
    guard_dtype_size(zeros, "zeros", dtype=dtype)
    return (
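Side note on the guard: `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, so the `< (8, 0)` check is an ordinary Python tuple comparison. A small illustrative sketch (the device names are examples for orientation, not something the code inspects):

```python
# Illustrative only: how the (major, minor) comparison in the new guard classifies a few GPUs.
min_sm = (8, 0)
examples = [("T4", (7, 5)), ("A100", (8, 0)), ("A10G", (8, 6)), ("H100", (9, 0))]
for name, cap in examples:
    status = "supported" if cap >= min_sm else "raises NotImplementedError"
    print(f"{name} (sm_{cap[0]}{cap[1]}): {status}")
```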