Unpin CUDA Nightly (#1064)
msaroufim authored Oct 13, 2024
1 parent e6ceb95 commit e7b33bc
Showing 4 changed files with 26 additions and 11 deletions.
.github/workflows/regression_test.yml (4 changes: 2 additions & 2 deletions)
@@ -38,9 +38,9 @@ jobs:
       torch-spec: 'torch==2.4.0'
       gpu-arch-type: "cuda"
       gpu-arch-version: "12.1"
-    - name: CUDA Nightly (Oct 1)
+    - name: CUDA Nightly
       runs-on: linux.g5.12xlarge.nvidia.gpu
-      torch-spec: '--pre torch==2.6.0.dev20241001+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
+      torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
       gpu-arch-type: "cuda"
       gpu-arch-version: "12.1"

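With the date pin removed, the CUDA Nightly job installs whatever cu121 nightly is current when the workflow runs. A quick sanity check one might run inside the job (not part of this commit) to confirm which build was picked up:

    # Not in the commit: confirm which nightly build the job actually installed.
    import torch

    print(torch.__version__)   # e.g. "2.6.0.devYYYYMMDD+cu121" for a cu121 nightly wheel
    print(torch.version.cuda)  # CUDA version the wheel was built against, e.g. "12.1"
    assert "dev" in torch.__version__, "expected a nightly (dev) build"
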
torchao/prototype/quantized_training/bitnet.py (5 changes: 2 additions & 3 deletions)
@@ -92,9 +92,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         # return new unwrapped object
         return out
 
-    # new signature https://github.com/pytorch/pytorch/pull/136129
-    # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5
-    def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None):
+    # FSDP all-gather extension v1
+    def fsdp_pre_all_gather(self, mesh):
         # quantize and pack into 2-bit to save comm bandwidth
         if self._precomputed_scale is not None:
             scale = self._precomputed_scale
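
For the bitnet prototype the commit reverts to the original one-argument hook, so only the device mesh is taken. A rough sketch of that v1 contract on a toy wrapper subclass (illustrative only, not the torchao BitNet implementation; the int8 "packing" below stands in for the real 2-bit packing): fsdp_pre_all_gather returns the tensors to all-gather plus optional metadata, and fsdp_post_all_gather rebuilds the parameter from the gathered outputs.

    # Rough sketch of the v1 hook pair on a toy subclass (not the torchao BitNet code).
    import torch

    class ToyQuantizedWeight(torch.Tensor):
        @staticmethod
        def __new__(cls, data):
            return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

        def __init__(self, data):
            self._data = data

        def fsdp_pre_all_gather(self, mesh):
            # quantize before the all-gather so less data goes over the wire
            scale = self._data.abs().mean() + 1e-12
            packed = (self._data / scale).round().clamp(-1, 1).to(torch.int8)  # stand-in for 2-bit packing
            return (packed, scale.unsqueeze(0)), None  # (tensors to all-gather, metadata)

        def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
            packed, scale = all_gather_outputs
            weight = packed.to(param_dtype) * scale  # dequantize the gathered shards
            return weight, all_gather_outputs  # (unsharded param, tensors for FSDP to free later)
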
torchao/prototype/quantized_training/int8.py (14 changes: 11 additions & 3 deletions)
@@ -99,9 +99,17 @@ def __repr__(self):
f"requires_grad={self.requires_grad})"
)

# require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype
# we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5
def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None):
# FSDP all-gather extension v2
# https://github.com/pytorch/pytorch/pull/137005
# we need default values so this method still works with PyTorch 2.4 and 2.5
def fsdp_pre_all_gather(
self,
mesh,
outer_size=None,
outer_stride=None,
module=None,
mp_policy=None,
):
scale = self.scale
if mp_policy is not None:
scale = scale.to(mp_policy.param_dtype)
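
The extra parameters follow the newer all-gather extension signature from pytorch/pytorch#137005, and the None defaults are what keep the method callable from PyTorch 2.4/2.5, which only pass the mesh. A minimal illustration of that compatibility trick (toy class and placeholder arguments, not the torchao or FSDP internals):

    # Toy illustration of why the defaults matter; not the torchao implementation.
    class Weight:
        def fsdp_pre_all_gather(
            self,
            mesh,
            outer_size=None,    # extra args only supplied by newer PyTorch (pytorch/pytorch#137005)
            outer_stride=None,
            module=None,
            mp_policy=None,
        ):
            # fall back gracefully when the runtime did not pass an mp_policy
            dtype = mp_policy.param_dtype if mp_policy is not None else None
            return mesh, dtype

    w = Weight()
    w.fsdp_pre_all_gather("mesh")                              # older runtimes pass only the mesh
    w.fsdp_pre_all_gather("mesh", (4, 4), (4, 1), None, None)  # newer runtimes pass the extra args
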
torchao/prototype/quantized_training/int8_mixed_precision.py (14 changes: 11 additions & 3 deletions)
@@ -110,9 +110,17 @@ def unwrap(x: cls):
         # return new unwrapped object
         return out
 
-    # require https://github.com/pytorch/pytorch/pull/136129 for mixed-precision param_dtype
-    # we need default None for module and mp_policy so this method still works with PyTorch 2.4 and 2.5
-    def fsdp_pre_all_gather(self, mesh, module=None, mp_policy=None):
+    # FSDP all-gather extension v2
+    # https://github.com/pytorch/pytorch/pull/137005
+    # we need default values so this method still works with PyTorch 2.4 and 2.5
+    def fsdp_pre_all_gather(
+        self,
+        mesh,
+        outer_size=None,
+        outer_stride=None,
+        module=None,
+        mp_policy=None,
+    ):
         # TODO: pre-quantize weight here -> reduce comm bandwidth.
         # we will need another tensor subclass to hold the quantized weight.
         data = self._data
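
int8_mixed_precision.py gets the same signature update. Because the weight is not pre-quantized here yet (see the TODO above), the immediately useful new argument is mp_policy; a hedged sketch of how its param_dtype could drive the cast before the gather (a guess at the intent, not the actual method body):

    # Hedged sketch: honoring MixedPrecisionPolicy.param_dtype before the all-gather.
    import torch

    def cast_for_all_gather(data: torch.Tensor, mp_policy=None) -> torch.Tensor:
        # FSDP2's MixedPrecisionPolicy.param_dtype defaults to None; only cast when it is set.
        if mp_policy is not None and mp_policy.param_dtype is not None:
            return data.to(mp_policy.param_dtype)
        return data

    data = torch.randn(16, 16)
    print(cast_for_all_gather(data).dtype)  # torch.float32 when no policy is given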
