[Misc] Use torch.Tensor for type annotation (vllm-project#6505)
WoosukKwon authored and dtrifiro committed Jul 17, 2024
1 parent 6aa0b94 commit 212a93b
Showing 2 changed files with 18 additions and 18 deletions.
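Context for the change: torch.tensor is the tensor-constructing function, while torch.Tensor is the tensor class, so only the latter is valid as a type annotation. A minimal standalone sketch (not part of this commit; the mypy message is paraphrased):

import torch

def before(x: torch.tensor) -> torch.tensor:  # mypy: Function "torch.tensor" is not valid as a type
    return x * 2

def after(x: torch.Tensor) -> torch.Tensor:  # torch.Tensor is a class, so this checks cleanly
    return x * 2

# The distinction also matters at runtime:
isinstance(torch.zeros(1), torch.Tensor)    # True
# isinstance(torch.zeros(1), torch.tensor)  # TypeError: isinstance() arg 2 must be a type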
34 changes: 17 additions & 17 deletions benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -20,18 +20,18 @@
 # helpers


-def to_fp8(tensor: torch.tensor) -> torch.tensor:
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     return torch.round(tensor.clamp(
         min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


-def to_int8(tensor: torch.tensor) -> torch.tensor:
+def to_int8(tensor: torch.Tensor) -> torch.Tensor:
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


 def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
-                      k: int) -> Tuple[torch.tensor, torch.tensor]:
+                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:

     a = torch.randn((m, k), device='cuda') * 5
     b = torch.randn((n, k), device='cuda').t() * 5
@@ -47,25 +47,25 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 # impl


-def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                    scale_b: torch.tensor,
-                    out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                    scale_b: torch.Tensor,
+                    out_dtype: torch.dtype) -> torch.Tensor:
     return torch.mm(a, b)


-def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                     scale_b: torch.tensor,
-                     out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: torch.dtype) -> torch.Tensor:
     return torch._scaled_mm(a,
                             b,
                             scale_a=scale_a,
                             scale_b=scale_b,
                             out_dtype=out_dtype)


-def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
-                                scale_a: torch.tensor, scale_b: torch.tensor,
-                                out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
+                                scale_a: torch.Tensor, scale_b: torch.Tensor,
+                                out_dtype: torch.dtype) -> torch.Tensor:
     return torch._scaled_mm(a,
                             b,
                             scale_a=scale_a,
@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
                             use_fast_accum=True)


-def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                 scale_b: torch.tensor,
-                 out_dtype: torch.dtype) -> torch.tensor:
+def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                 scale_b: torch.Tensor,
+                 out_dtype: torch.dtype) -> torch.Tensor:
     return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)


 # bench
-def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-             scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
+def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+             scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
              sub_label: str, fn: Callable, description: str) -> TMeasurement:

     min_run_time = 1
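For reference, a minimal sketch of the fp8 path benchmarked above (an assumption-laden example, not part of this commit: it needs a GPU and a PyTorch build with float8 support, and torch._scaled_mm is a private API whose return convention has varied across releases, so this mirrors the benchmark's usage rather than a stable contract):

import torch

m, n, k = 16, 32, 64
# b is transposed to column-major, as make_rand_tensors does above.
a = torch.randn((m, k), device='cuda').to(torch.float8_e4m3fn)
b = torch.randn((n, k), device='cuda').t().to(torch.float8_e4m3fn)
scale = torch.ones(1, device='cuda', dtype=torch.float32)

# On some PyTorch releases this returns an (out, amax) tuple instead.
out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)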
2 changes: 1 addition & 1 deletion vllm/worker/worker.py
@@ -105,7 +105,7 @@ def __init__(
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
         # Initialize gpu_cache as embedding models don't initialize kv_caches
-        self.gpu_cache: Optional[List[List[torch.tensor]]] = None
+        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None

     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
