Skip to content

Commit

Permalink
Revert "Contiguous PA (#424)"
Browse files Browse the repository at this point in the history
This reverts commit 5b7f685.
  • Loading branch information
madamczykhabana authored Oct 25, 2024
1 parent 5b7f685 commit 7bf2a9e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 29 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@6cb6e19
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@341a77f
54 changes: 26 additions & 28 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,10 @@ def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
bs_buckets = warmup_range(bs_bucket_config)
block_buckets = warmup_range(blocks_bucket_config)
bmin, bstep, bmax = blocks_bucket_config
last_bucket = max_blocks
last_bucket = round_up(max_blocks, bstep)
for bs in bs_buckets:
for blocks in block_buckets:
if blocks > last_bucket:
buckets.append((bs, last_bucket))
break
buckets.append((bs, blocks))
return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
Expand Down Expand Up @@ -1003,40 +1002,39 @@ def _prepare_decode(

num_decode_tokens = sum(seq_lens)

block_list = list(itertools.chain(*block_tables))

max_idx = max(block_list)
max_blocks = max(max_idx + 1, len(block_list))
block_bucket_size = find_bucket(
max_blocks, self.bucketing_global_state.decode_block_bucket_cfg)
block_bucket_size = min(block_bucket_size,
self.cache_config.num_gpu_blocks)

block_mapping: List[Union[None, int]] = [None] * block_bucket_size
block_usage: List[Union[None, int]] = [None] * block_bucket_size
block_scales: List[Union[None, float]] = [None] * block_bucket_size

blocks_used = [len(bt) for bt in block_tables if bt]
block_list = []
block_scales = []
for i, bt in enumerate(block_tables):
if bt:
blocks_in_group = len(bt)
block_list.extend(bt)
blocks_in_group = len(bt)
if blocks_in_group > 0:
scale = 1.0 / blocks_in_group
for b in bt:
if block_mapping[b] is None:
block_mapping[b] = i
block_usage[b] = self.block_size
block_scales[b] = scale
block_scales.extend([scale] * blocks_in_group)

block_mapping = [b if b is not None else -1 for b in block_mapping]
block_scales = [b if b is not None else 0.0 for b in block_scales]
block_mapping_nested: List[List[int]] = [
[i] * b_u for i, b_u in enumerate(blocks_used)
]
block_mapping: List[int] = list(
itertools.chain.from_iterable(block_mapping_nested))

for bt, sl in zip(block_tables, slot_mapping):
if bt:
block_usage[bt[-1]] = sl[-1] % self.block_size + 1
block_usage = [u if u is not None else 1 for u in block_usage]
last_block = [
sl % self.block_size + 1 for sl in itertools.chain(*slot_mapping)
]
block_usage = [[self.block_size] * (b_u - 1) + [lb]
for b_u, lb in zip(blocks_used, last_block)]
block_usage = list(itertools.chain(*block_usage))

block_bucket_size = find_bucket(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
block_groups = pad_list(block_mapping, block_bucket_size,
len(block_tables))
block_mapping = pad_list(block_mapping, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
block_scales = pad_list(block_scales, block_bucket_size, 0.0)

block_list = torch.tensor(block_list,
dtype=torch.int,
device=self.device)
Expand Down

0 comments on commit 7bf2a9e

Please sign in to comment.