From d18c5d36a878b46e16445f619ecf5de359d63755 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 17 Apr 2024 16:53:19 -0400 Subject: [PATCH] Manually fix a few remaining lints (#238) --- generate.py | 14 +++++++++++--- quantized_ops.py | 30 +++++++++++++++--------------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/generate.py b/generate.py index a199c7e4b..2dfdcac45 100644 --- a/generate.py +++ b/generate.py @@ -314,6 +314,9 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"): return torch.tensor(tokens, dtype=torch.int, device=device) +B_INST, E_INST = "[INST]", "[/INST]" + + def _main( builder_args: BuilderArgs, speculative_builder_args: BuilderArgs, @@ -330,6 +333,7 @@ def _main( # from tp import maybe_init_dist # rank = maybe_init_dist() use_tp = False + rank: Optional[int] = None # if use_tp: # if rank != 0: # # only print on rank 0 @@ -417,8 +421,9 @@ def _main( period_id = tokenizer.encode(".")[0] done_generating = False - def callback(x): - nonlocal done_generating + def callback( + x, buffer=buffer, period_id=period_id, done_generating=done_generating + ): if done_generating: return buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) @@ -430,7 +435,10 @@ def callback(x): # print(, end='', flush=True) else: - callback = lambda x: x + + def callback(x): + return x + t0 = time.perf_counter() import contextlib diff --git a/quantized_ops.py b/quantized_ops.py index b4b0c8ae8..bf225469c 100644 --- a/quantized_ops.py +++ b/quantized_ops.py @@ -8,7 +8,7 @@ import torch import torch.nn.functional as F -from torch.library import impl, impl_abstract +from torch.library import impl torchchat_lib = torch.library.Library("torchchat", "DEF") @@ -25,21 +25,21 @@ def embedding_int8( ) -> torch.Tensor: indices = input # embedding_byte_weight_checks(weight, weight_scales, weight_zero_points) - groupsize = weight.size(1) // (scales.size(1) if scales.dim() == 2 else 1) + # groupsize = weight.size(1) // (scales.size(1) if scales.dim() == 2 else 1) # ET definition - if False: - weight_zero_points = None - weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default( - weight, - weight_scales, - weight_zero_points, - weight_quant_min, - weight_quant_max, - weight.dtype, - groupsize, - weight_scales.dtype, - ) - return torch.ops.aten.embedding.default(weight, indices) + # if False: + # weight_zero_points = None + # weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default( + # weight, + # weight_scales, + # weight_zero_points, + # weight_quant_min, + # weight_quant_max, + # weight.dtype, + # groupsize, + # weight_scales.dtype, + # ) + # return torch.ops.aten.embedding.default(weight, indices) scales = scales.view(weight.shape[0], -1) result_weights = F.embedding(indices, weight)