kernels: disambiguate quantized types via a new ScalarType
Co-authored-by: Lucas Wilkinson <[email protected]>
AlpinDale and LucasWilkinson committed Sep 1, 2024
1 parent 7844103 commit 141672a
Showing 27 changed files with 1,008 additions and 295 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/publish.yml
@@ -4,9 +4,13 @@
 name: Create Release
 
 on:
+  schedule:
+    - cron: '0 2 * * *'
   push:
+    branches:
+      - 'rc_054'
     tags:
-      - v*
+      - 'v*'
 
 # Needed to create release and upload assets
 permissions:
@@ -55,6 +59,8 @@ jobs:
     steps:
       - name: Checkout
         uses: actions/checkout@v3
+        with:
+          ref: 'rc_054'
 
       - name: Set up Linux Env
         if: ${{ runner.os == 'Linux' }}
52 changes: 35 additions & 17 deletions CMakeLists.txt
@@ -66,6 +66,39 @@ endif()
 #
 find_package(Torch REQUIRED)
 
+#
+# Add the `default` target which detects which extensions should be
+# built based on platform/architecture. This is the same logic that
+# setup.py uses to select which extensions should be built and should
+# be kept in sync.
+#
+# The `default` target makes direct use of cmake easier since knowledge
+# of which extensions are supported has been factored in, e.g.
+#
+# mkdir build && cd build
+# cmake -G Ninja -DAPHRODITE_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite ..
+# cmake --build . --target default
+#
+add_custom_target(default)
+message(STATUS "Enabling core extension.")
+
+# Define _core_C extension
+# built for (almost) every target platform, (excludes TPU and Neuron)
+
+set(APHRODITE_EXT_SRC
+  "kernels/core/torch_bindings.cpp")
+
+define_gpu_extension_target(
+  _core_C
+  DESTINATION aphrodite
+  LANGUAGE CXX
+  SOURCES ${APHRODITE_EXT_SRC}
+  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
+  USE_SABI 3
+  WITH_SOABI)
+
+add_dependencies(default _core_C)
+
 #
 # Forward the non-CUDA device extensions to external CMake scripts.
 #
@@ -74,7 +107,7 @@ if (NOT APHRODITE_TARGET_DEVICE STREQUAL "cuda" AND
     if (APHRODITE_TARGET_DEVICE STREQUAL "cpu")
         include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
     else()
-        message(FATAL_ERROR "Unsupported Aphrodite target device: ${APHRODITE_TARGET_DEVICE}")
+        return()
     endif()
     return()
 endif()
@@ -132,7 +165,7 @@ if(NVCC_THREADS AND APHRODITE_GPU_LANG STREQUAL "CUDA")
 endif()
 
 #
-# Define extension targets
+# Define other extension targets
 #
 
 #
@@ -227,21 +260,6 @@ define_gpu_extension_target(
   USE_SABI 3
   WITH_SOABI)
 
-#
-# Add the `default` target which detects which extensions should be
-# built based on platform/architecture. This is the same logic that
-# setup.py uses to select which extensions should be built and should
-# be kept in sync.
-#
-# The `default` target makes direct use of cmake easier since knowledge
-# of which extensions are supported has been factored in, e.g.
-#
-# mkdir build && cd build
-# cmake -G Ninja -DAPHRODITE_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite ..
-# cmake --build . --target default
-#
-add_custom_target(default)
-
 if(APHRODITE_GPU_LANG STREQUAL "CUDA" OR APHRODITE_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling C extension.")
   add_dependencies(default _C)
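Because the `default` target and `_core_C` are now defined before the device-specific branches, the core extension can be built on its own using the same out-of-tree recipe quoted in the comment above. A sketch (same assumed build directory, generator, and paths as in that comment):

    mkdir build && cd build
    cmake -G Ninja -DAPHRODITE_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../aphrodite ..
    cmake --build . --target _core_C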
3 changes: 3 additions & 0 deletions Dockerfile.openvino
@@ -13,6 +13,9 @@ COPY requirements-common.txt /workspace/aphrodite-engine/
 COPY requirements-openvino.txt /workspace/aphrodite-engine/
 
 COPY aphrodite/ /workspace/aphrodite-engine/aphrodite
+COPY kernels/core /workspace/aphrodite-engine/kernels/core
+COPY cmake/utils.cmake /workspace/aphrodite-engine/cmake/
+COPY CMakeLists.txt /workspace/aphrodite-engine/
 COPY setup.py /workspace/aphrodite-engine/
 
 # install build requirements
176 changes: 176 additions & 0 deletions aphrodite/_core_ext.py
@@ -0,0 +1,176 @@
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import torch
from loguru import logger

core_C_available = importlib.util.find_spec('._core_C',
                                            'aphrodite') is not None


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


if TYPE_CHECKING or not core_C_available:
    # On platforms where we cannot use/build the C++ core extension (namely
    # Neuron and TPU), we define a mock ScalarType class here that partially
    # mimics the C++ ScalarType class.
    #
    # We also use this to provide type signatures to the Python LSP for the
    # methods in the C++ ScalarType class, so these type signatures should
    # be kept in sync with kernels/core/scalar_type.hpp

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ScalarType:
        """
        ScalarType can represent a wide range of floating point and integer
        types, in particular sub-byte data types (something that torch.dtype
        currently does not support). It can also represent types with a bias,
        i.e. `stored_value = value + bias`, which is useful for quantized
        types (e.g. standard GPTQ 4-bit uses a bias of 8). The implementation
        of this class can be found in kernels/core/scalar_type.hpp; these
        type signatures should be kept in sync with that file.
        """

        exponent: int
        """
        Number of bits in the exponent if this is a floating point type
        (zero if this is an integer type).
        """

        mantissa: int
        """
        Number of bits in the mantissa if this is a floating point type,
        or the number of bits representing an integer (excluding the sign
        bit) if this is an integer type.
        """

        bias: int
        """
        Bias used to encode the values in this scalar type
        (value = stored_value - bias, default 0). For example, if we store
        the type as an unsigned integer with a bias of 128, then the value
        0 is stored as 128, -1 is stored as 127, and 1 is stored as 129.
        """

        signed: bool
        "If the type is signed (i.e. has a sign bit)."

        _finite_values_only: bool = False
        """
        Private: whether only finite values are representable; use
        `has_infs()` instead of reading this directly.
        """

        nan_repr: int = NanRepr.IEEE_754.value
        """
        How NaNs are represented in this scalar type (a NanRepr value;
        not applicable for integer types).
        """

        @property
        def size_bits(self) -> int:
            return self.exponent + self.mantissa + int(self.signed)

        def min(self) -> Union[int, float]:
            """
            Min representable value for this scalar type
            (accounting for bias if there is one).
            """
            raise NotImplementedError

        def max(self) -> Union[int, float]:
            """
            Max representable value for this scalar type
            (accounting for bias if there is one).
            """
            raise NotImplementedError

        def is_signed(self) -> bool:
            """
            If the type is signed (i.e. has a sign bit); same as `signed`,
            added for consistency with:
            https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
            """
            ...

        def is_floating_point(self) -> bool:
            "If the type is a floating point type."
            return self.exponent != 0

        def is_integer(self) -> bool:
            "If the type is an integer type."
            return self.exponent == 0

        def has_bias(self) -> bool:
            "If the type has a non-zero bias."
            return self.bias != 0

        def has_infs(self) -> bool:
            "If the type is floating point and supports infinity."
            return not self._finite_values_only

        def has_nans(self) -> bool:
            "If the type supports NaN values."
            return self.nan_repr != NanRepr.NONE.value

        def is_ieee_754(self) -> bool:
            """
            If the type is a floating point type that follows IEEE 754
            conventions.
            """
            return self.nan_repr == NanRepr.IEEE_754.value and \
                not self._finite_values_only

        def __str__(self) -> str:
            raise NotImplementedError

        def __repr__(self) -> str:
            raise NotImplementedError

        #
        # Convenience Constructors
        #

        @classmethod
        def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            "Create a signed integer scalar type (size_bits includes sign-bit)."
            # Integer types use a zero-width exponent; the mantissa holds the
            # value bits (sign bit excluded).
            return cls(0, size_bits - 1, bias if bias else 0, True)

        @classmethod
        def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
            """Create an unsigned integer scalar type."""
            return cls(0, size_bits, bias if bias else 0, False)

        @classmethod
        def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
            """
            Create a standard floating point type
            (i.e. follows IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True)

        @classmethod
        def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
                   nan_repr: int) -> 'ScalarType':
            """
            Create a non-standard floating point type
            (i.e. does not follow IEEE 754 conventions).
            """
            return cls(exponent, mantissa, 0, True, finite_values_only,
                       nan_repr)

elif core_C_available:
    try:
        import aphrodite._core_C  # noqa: F401
    except ImportError as e:
        logger.warning(f"Failed to import from aphrodite._core_C with {e}")

    ScalarType = torch.classes._core_C.ScalarType
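A minimal usage sketch of the pure-Python mock above (values per the docstrings: standard GPTQ 4-bit is an unsigned 4-bit integer with bias 8; the FP8-E4M3 parameters, finite-only values with EXTD_RANGE_MAX_MIN NaNs, follow the OCP FP8 convention and are an assumption here):

    from aphrodite._core_ext import NanRepr, ScalarType

    gptq_4bit = ScalarType.uint(4, 8)  # exponent=0, mantissa=4, bias=8
    assert gptq_4bit.size_bits == 4 and gptq_4bit.is_integer()
    assert gptq_4bit.has_bias()

    fp8_e4m3 = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN.value)
    assert fp8_e4m3.is_floating_point() and not fp8_e4m3.has_infs()

    fp16 = ScalarType.float_IEEE754(5, 10)  # 5 exponent bits, 10 mantissa bits
    assert fp16.is_ieee_754() and fp16.size_bits == 16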
30 changes: 20 additions & 10 deletions aphrodite/_custom_ops.py
@@ -5,6 +5,8 @@
 import torch
 from loguru import logger
 
+from aphrodite._core_ext import ScalarType
+
 try:
     import aphrodite._C
 except ImportError as e:
@@ -217,10 +219,10 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 # marlin_24
 def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                         b_meta: torch.Tensor, b_scales: torch.Tensor,
-                        workspace: torch.Tensor, num_bits: int, size_m: int,
-                        size_n: int, size_k: int) -> torch.Tensor:
+                        workspace: torch.Tensor, b_q_type: ScalarType,
+                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
-                                            workspace, num_bits, size_m,
+                                            workspace, b_q_type, size_m,
                                             size_n, size_k)
 
 
@@ -284,14 +286,22 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
     return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
 
 
-def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
-                     b_scales: torch.Tensor, b_zeros: torch.Tensor,
-                     g_idx: torch.Tensor, perm: torch.Tensor,
-                     workspace: torch.Tensor, num_bits: int, size_m: int,
-                     size_n: int, size_k: int, is_k_full: bool, has_zp: bool,
-                     use_fp32_reduce: bool) -> torch.Tensor:
+def gptq_marlin_gemm(a: torch.Tensor,
+                     b_q_weight: torch.Tensor,
+                     b_scales: torch.Tensor,
+                     b_zeros: torch.Tensor,
+                     g_idx: torch.Tensor,
+                     perm: torch.Tensor,
+                     workspace: torch.Tensor,
+                     b_q_type: ScalarType,
+                     size_m: int,
+                     size_n: int,
+                     size_k: int,
+                     is_k_full: bool,
+                     has_zp: bool = False,
+                     use_fp32_reduce: bool = False) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
-                                         g_idx, perm, workspace, num_bits,
+                                         g_idx, perm, workspace, b_q_type,
                                          size_m, size_n, size_k, is_k_full,
                                          has_zp, use_fp32_reduce)
 
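A hedged sketch of the resulting call-site change (the tensors and Marlin preprocessing here are assumed to exist already; the point is only that the old `num_bits` integer becomes a `ScalarType`):

    from aphrodite import _custom_ops as ops
    from aphrodite._core_ext import ScalarType

    b_q_type = ScalarType.uint(4, 8)  # previously expressed as num_bits=4
    out = ops.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, g_idx, perm,
                               workspace, b_q_type, size_m, size_n, size_k,
                               is_k_full=True)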
(diff truncated: the remaining changed files are not shown)
