Skip to content

Commit

Permalink
Update on "Move QAT out of prototype"
Browse files Browse the repository at this point in the history
**Summary:** Move QAT out of prototype so we can provide stronger
BC guarantees moving forward.

**BC-breaking notes**

Before:
```
from torchao.quantization.prototype.qat import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.prototype.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.prototype.qat.fake_quantizer import (
    FakeQuantizer,
)
```

After:
```
from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
    disable_4w_fake_quant,
    disable_8da4w_fake_quant,
    enable_4w_fake_quant,
    enable_8da4w_fake_quant,
    Int8DynActInt4WeightQATLinear,
)
from torchao.quantization.qat.api import (
    FakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizer,
)
```

**Test Plan:**
python test/quantization/test_qat.py

[ghstack-poisoned]
  • Loading branch information
andrewor14 committed Oct 17, 2024
2 parents a5a0428 + bdb89a0 commit 58f402d
Show file tree
Hide file tree
Showing 8 changed files with 349 additions and 175 deletions.
55 changes: 55 additions & 0 deletions .github/workflows/build_wheels_aarch64_linux.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# From https://github.com/pytorch/test-infra/wiki/Using-Nova-Reusable-Build-Workflows
name: Build AARCH64 Linux Wheels

on:
  pull_request:
    paths:
      - build/packaging/**
      - .github/workflows/build_wheels_aarch64_linux.yml
      - setup.py
  push:
    branches:
      - nightly
      - main
      - release/*
    tags:
      # NOTE: Binary build pipelines should only get triggered on release candidate builds
      # Release candidate tags look like: v1.11.0-rc1
      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
  schedule:
    - cron: '0 0 * * *'  # Runs at midnight UTC every day
  workflow_dispatch:

jobs:
  generate-matrix:
    uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
    with:
      package-type: wheel
      os: linux-aarch64
      test-infra-repository: pytorch/test-infra
      test-infra-ref: main
      with-cuda: disable

  build:
    needs: generate-matrix
    permissions:
      id-token: write
      contents: read
    uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
    with:
      # Set the ref to an empty string instead of the default nightly because
      # torchao doesn't have nightly branch setup yet, instead the build is
      # triggered daily from main with a schedule
      repository: pytorch/ao
      ref: ""
      build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
      env-var-script: packaging/env_var_script_linux.sh
      pre-script: packaging/pre_build_script.sh
      # post-script: packaging/post_build_script.sh
      smoke-test-script: packaging/smoke_test.py
      package-name: torchao
      trigger-event: ${{ github.event_name }}
      architecture: aarch64
      setup-miniconda: false
    secrets:
      PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
2 changes: 1 addition & 1 deletion .github/workflows/build_wheels_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
os: linux
with-cpu: enable
with-cuda: enable
with-rocm: disable
with-rocm: enable

build:
needs: generate-matrix
Expand Down
62 changes: 62 additions & 0 deletions .github/workflows/build_wheels_windows.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
name: Build Windows Wheels

on:
  pull_request:
    paths:
      - build/packaging/**
      - .github/workflows/build_wheels_windows.yml
      - setup.py
  push:
    branches:
      - nightly
      - main
      - release/*
    tags:
      # NOTE: Binary build pipelines should only get triggered on release candidate builds
      # Release candidate tags look like: v1.11.0-rc1
      - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
  schedule:
    - cron: '0 0 * * *'  # Runs at midnight UTC every day
  workflow_dispatch:

permissions:
  id-token: write
  contents: read

jobs:
  generate-matrix:
    uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
    with:
      package-type: wheel
      os: windows
      test-infra-repository: pytorch/test-infra
      test-infra-ref: main
      with-xpu: enable
      with-cuda: disable

  build:
    needs: generate-matrix
    strategy:
      fail-fast: false
      matrix:
        include:
          - repository: pytorch/ao
            pre-script: packaging/pre_build_script.sh
            env-script: packaging/vc_env_helper.bat
            # post-script: "python packaging/wheel/relocate.py"
            smoke-test-script: packaging/smoke_test.py
            package-name: torchao
    name: ${{ matrix.repository }}
    uses: pytorch/test-infra/.github/workflows/build_wheels_windows.yml@main
    with:
      repository: ${{ matrix.repository }}
      ref: ""
      test-infra-repository: pytorch/test-infra
      test-infra-ref: main
      build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
      pre-script: ${{ matrix.pre-script }}
      env-script: ${{ matrix.env-script }}
      post-script: ${{ matrix.post-script }}
      package-name: ${{ matrix.package-name }}
      smoke-test-script: ${{ matrix.smoke-test-script }}
      trigger-event: ${{ github.event_name }}
43 changes: 43 additions & 0 deletions packaging/vc_env_helper.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
@echo on

:: Set up a Visual Studio build environment (vcvarsall) matching %VC_YEAR%,
:: then run the command passed on this script's command line inside it.
:: Default probe range is VS 2022 (version 17.x); older VC_YEAR values
:: narrow the range accordingly.
set VC_VERSION_LOWER=17
set VC_VERSION_UPPER=18
:: NOTE: each `set` must be on its own line (or joined with `&`) inside a
:: parenthesized block; a single line `( set A=1 set B=2 )` would assign the
:: whole remainder of the line to A.
if "%VC_YEAR%" == "2019" (
    set VC_VERSION_LOWER=16
    set VC_VERSION_UPPER=17
)
if "%VC_YEAR%" == "2017" (
    set VC_VERSION_LOWER=15
    set VC_VERSION_UPPER=16
)

:: Locate a VS installation whose version is in [LOWER, UPPER) using vswhere,
:: and remember its vcvarsall.bat. First match wins.
for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do (
    if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" (
        set "VS15INSTALLDIR=%%i"
        set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat"
        goto vswhere
    )
)

:vswhere
:: Enter the x64 build environment; abort on failure.
if "%VSDEVCMD_ARGS%" == "" (
    call "%VS15VCVARSALL%" x64 || exit /b 1
) else (
    call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1
)

@echo on

:: For XPU builds also source Intel oneAPI environment variables.
if "%CU_VERSION%" == "xpu" call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"

set DISTUTILS_USE_SDK=1

:: Re-assemble all command-line arguments into %args% (shift consumes %1).
set args=%1
shift
:start
if [%1] == [] goto done
set args=%args% %1
shift
goto start

:done
if "%args%" == "" (
    echo Usage: vc_env_helper.bat [command] [args]
    echo e.g. vc_env_helper.bat cl /c test.cpp
)

:: Run the requested command inside the configured environment.
%args% || exit /b 1
50 changes: 49 additions & 1 deletion test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
)
from torchao.quantization.qat.linear import (
FakeQuantizedLinear,
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear
)
from torchao.quantization.qat.utils import (
_choose_qparams_per_token_asymmetric,
Expand Down Expand Up @@ -66,6 +68,10 @@
TORCH_VERSION_AT_LEAST_2_5,
)

from torchao.quantization.GPTQ import (
_replace_linear_8da4w,
_replace_linear_int4
)

# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
Expand Down Expand Up @@ -854,6 +860,48 @@ def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
fq_out = fq_linear(x)
baseline_out = linear_forward_4w(x2, fq_linear.weight)
torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4")
def test_replace_linear_8da4w(self):
    """`_replace_linear_8da4w` swaps eligible `nn.Linear` modules for
    `Int8DynActInt4WeightQATLinear`, leaving linears with bias untouched
    (the QAT linear path only handles bias=False).
    """
    # Linear with bias=True: must NOT be replaced.
    module = torch.nn.ModuleList([
        torch.nn.Linear(in_features=256, out_features=50, bias=True)
    ])
    _replace_linear_8da4w(
        module, 256, False, torch.float32, torch.float32,
        Int8DynActInt4WeightQATLinear, copy_weights=True,
    )
    assert not isinstance(module[0], Int8DynActInt4WeightQATLinear)
    assert isinstance(module[0], torch.nn.Linear)

    # Linear with bias=False: must be replaced.
    module = torch.nn.ModuleList([
        torch.nn.Linear(in_features=256, out_features=50, bias=False)
    ])
    _replace_linear_8da4w(
        module, 256, False, torch.float32, torch.float32,
        Int8DynActInt4WeightQATLinear, copy_weights=True,
    )
    assert isinstance(module[0], Int8DynActInt4WeightQATLinear)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4")
def test_replace_linear_int4(self):
    """`_replace_linear_int4` swaps eligible `nn.Linear` modules for
    `Int4WeightOnlyQATLinear`, leaving linears with bias untouched
    (the QAT linear path only handles bias=False).
    """
    # Linear with bias=True: must NOT be replaced.
    module = torch.nn.ModuleList([
        torch.nn.Linear(in_features=256, out_features=50, bias=True)
    ])
    _replace_linear_int4(
        module,
        256,
        8,
        padding_allowed=True,
        precision=torch.bfloat16,
        scales_precision=torch.bfloat16,
        linear_class=Int4WeightOnlyQATLinear,
        copy_weights=True,
    )
    assert not isinstance(module[0], Int4WeightOnlyQATLinear)
    assert isinstance(module[0], torch.nn.Linear)

    # Linear with bias=False: must be replaced.
    module = torch.nn.ModuleList([
        torch.nn.Linear(in_features=256, out_features=50, bias=False)
    ])
    _replace_linear_int4(
        module,
        256,
        8,
        padding_allowed=True,
        precision=torch.bfloat16,
        scales_precision=torch.bfloat16,
        linear_class=Int4WeightOnlyQATLinear,
        copy_weights=True,
    )
    assert isinstance(module[0], Int4WeightOnlyQATLinear)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantized_embedding_4w(self):
Expand Down Expand Up @@ -891,4 +939,4 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:


if __name__ == "__main__":
unittest.main()
unittest.main()
17 changes: 8 additions & 9 deletions torchao/csrc/cuda/fp6_llm/fp6_linear.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,16 @@
// Returns true iff the current CUDA device has compute capability exactly
// SM 7.5 (the source span interleaved the old cudaGetDeviceProperties-based
// version with the new attribute-based one; this is the coherent new version).
// Any CUDA API failure conservatively reports "not SM75".
inline bool isSM75GPU() {
    int device;
    cudaError_t err = cudaGetDevice(&device);
    if (err != cudaSuccess) return false;

    // Query major/minor compute capability directly via device attributes,
    // avoiding the heavier cudaGetDeviceProperties call.
    int major, minor;
    err = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    if (err != cudaSuccess) return false;

    err = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    if (err != cudaSuccess) return false;

    return (major == 7) && (minor == 5);
}

template<typename TilingConfig, typename OutputDataType, int EXPONENT, int MANTISSA>
Expand Down
Loading

0 comments on commit 58f402d

Please sign in to comment.