Make sharded Llama export test also compile to IREE module and verify numerics (#237)

Verifies the IREE module numerical accuracy compared to execution with
PyTorch.
The prefill step result has low absolute accuracy of around `1e-2` for FP32.
The cache state produced by prefill deviates substantially, and the decode
step results are also far outside tolerance.

This test is marked as skipped until
iree-org/iree#18663 is merged; without it,
IREE compilation will crash.
sogartar authored Oct 3, 2024
1 parent 8e074b4 commit 8727db0
265 changes: 230 additions & 35 deletions sharktank/tests/models/llama/sharded_llama_test.py
@@ -5,10 +5,17 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import unittest
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple, Union, OrderedDict
import collections.abc
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
import sharktank.ops as ops
from sharktank.types import Dataset
from sharktank.types import (
unbox_tensor,
ShardedTensor,
DefaultPrimitiveTensor,
Dataset,
AnyTensor,
)
from sharktank.models.llama.testing import make_random_llama_theta
from sharktank.models.llama.sharding import shard_theta
from sharktank.layers.configs import LlamaHParams
@@ -18,6 +25,93 @@
import torch
from copy import deepcopy
from shark_turbine.aot import FxProgramsBuilder, export
import iree.runtime
from pathlib import Path


def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]:
hal_driver = iree.runtime.get_driver(driver)
available_devices = hal_driver.query_available_devices()
# Use the same physical device for all logical devices.
return [hal_driver.create_device(available_devices[0]) for _ in range(device_count)]


def load_iree_module(
module_path: str,
parameters_path: str,
devices: List[iree.runtime.HalDevice],
) -> Tuple[iree.runtime.VmModule, iree.runtime.VmContext, iree.runtime.VmInstance]:
params_path = Path(parameters_path)
# TODO: make IREE able to load the parameters from the top parameter file
# without having to specify the parameter file for each shard separately.
parameter_index = iree.runtime.ParameterIndex()
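# Each shard's parameters live in a separate file derived from the top-level
# path, e.g. parameters.irpa -> parameters.rank0.irpa.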
for i in range(len(devices)):
parameter_index.load(
file_path=str(
params_path.with_suffix(f".rank{i}{params_path.suffix}")
)
)
parameter_provider = parameter_index.create_provider(scope="model")
vm_instance = iree.runtime.VmInstance()
parameters_module = iree.runtime.create_io_parameters_module(
vm_instance, parameter_provider
)
vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path))
hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices)
vm_context = iree.runtime.VmContext(
instance=vm_instance, modules=(hal_module, parameters_module, vm_module)
)
return vm_module, vm_context, vm_instance


def run_iree_module_function(
module: iree.runtime.VmModule,
vm_context: iree.runtime.VmContext,
function_name: str,
args: List[iree.runtime.DeviceArray],
driver: str,
) -> List[iree.runtime.DeviceArray]:
vm_function = module.lookup_function(function_name)
invoker = iree.runtime.FunctionInvoker(
vm_context=vm_context,
# TODO: rework iree.runtime.FunctionInvoker interface for multiple devices.
# This works, but does not look right.
device=iree.runtime.get_device(driver, cache=False),
vm_function=vm_function,
)
res = invoker(*args)
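# A single result comes back as a bare DeviceArray; normalize to a tuple.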
if isinstance(res, iree.runtime.DeviceArray):
res = (res,)
return res


def prepare_iree_module_function_args(
args: List[Union[AnyTensor, List[AnyTensor]]], devices: List[iree.runtime.HalDevice]
) -> List[iree.runtime.DeviceArray]:
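"""Flatten the arguments into a flat list of IREE device arrays.

A ShardedTensor contributes one device array per shard, each placed on its
corresponding device; nested sequences are flattened recursively.
"""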
res = []
for arg in args:
if isinstance(arg, ShardedTensor):
assert len(devices) == len(arg.shards)
res.extend(
[
prepare_iree_module_function_args([shard], [device])[0]
for shard, device in zip(arg.shards, devices)
]
)
elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)):
res.append(
iree.runtime.asdevicearray(
devices[0], unbox_tensor(arg).to("cpu").numpy()
)
)
else:
assert isinstance(arg, collections.abc.Sequence)
res.extend(prepare_iree_module_function_args(arg, devices))
return res


def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
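"""Copy IREE device arrays back to the host and wrap them as torch tensors."""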
return [torch.tensor(tensor.to_host()) for tensor in tensors]


class ShardedLlamaTest(unittest.TestCase):
@@ -53,6 +147,8 @@ def setUp(self):
activation_dtype=self.dtype,
attention_dtype=self.dtype,
)
self.sharded_config = deepcopy(self.config)
self.sharded_config.tensor_parallelism_size = 2
self.theta = make_random_llama_theta(
config=self.config,
vocab_size=self.vocabulary_size,
@@ -61,7 +157,7 @@ def setUp(self):
[14, 9, self.block_seq_stride - 1], dtype=torch.int32
)

def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
def make_prefill_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]:
batch_seq_len = round_up_to_multiple_of(
int(torch.max(self.prefill_seq_lens)), model.cache.pad_sequence_stride
)
@@ -79,16 +175,18 @@ def make_prefill_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
).view(self.batch_size, -1)
cache_state = model.cache.paged.allocate(page_count=self.cache_page_count)
cache_state = [torch.rand_like(cache_state[0])]
return {
"tokens": token_ids,
"attention_mask": attention_mask,
"seq_block_ids": seq_block_ids,
"cache_state": cache_state,
}
return OrderedDict(
[
("tokens", token_ids),
("attention_mask", attention_mask),
("seq_block_ids", seq_block_ids),
("cache_state", cache_state),
]
)

def make_equal_unsharded_and_sharded_prefill_args(
self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]:
prefill_args = self.make_prefill_args(model)
sharded_cache_state = sharded_model.cache.paged.allocate(
page_count=self.cache_page_count
@@ -103,7 +201,7 @@ def make_equal_unsharded_and_sharded_prefill_args(
sharded_prefill_args["cache_state"] = sharded_cache_state
return prefill_args, sharded_prefill_args

def make_decode_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
def make_decode_args(self, model: PagedLlamaModelV1) -> OrderedDict[str, Any]:
start_positions = self.prefill_seq_lens.clone()
seq_lens = self.prefill_seq_lens + 1
batch_seq_len = round_up_to_multiple_of(
@@ -123,17 +221,19 @@ def make_decode_args(self, model: PagedLlamaModelV1) -> Dict[str, Any]:
).view(self.batch_size, -1)
cache_state = model.cache.paged.allocate(page_count=self.cache_page_count)
cache_state = [torch.rand_like(cache_state[0])]
return {
"tokens": decode_token_ids,
"attention_mask": attention_mask,
"start_positions": start_positions,
"seq_block_ids": seq_block_ids,
"cache_state": cache_state,
}
return OrderedDict(
[
("tokens", decode_token_ids),
("attention_mask", attention_mask),
("start_positions", start_positions),
("seq_block_ids", seq_block_ids),
("cache_state", cache_state),
]
)

def make_equal_unsharded_and_sharded_decode_args(
self, model: PagedLlamaModelV1, sharded_model: PagedLlamaModelV1
):
) -> Tuple[OrderedDict[str, Any], OrderedDict[str, Any]]:
decode_args = self.make_decode_args(model)
sharded_decode_args = deepcopy(decode_args)
sharded_decode_args["cache_state"] = sharded_model.cache.paged.shard_state(
@@ -145,10 +245,8 @@ def testCompareToySizedModelToUnsharded(self):
"""Run a sharded variant of a toy model size and compare it against the
unsharded variant."""
model = PagedLlamaModelV1(self.theta, self.config)
sharded_config = deepcopy(self.config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(self.theta, sharded_config)
sharded_model = PagedLlamaModelV1(sharded_theta, sharded_config)
sharded_theta = shard_theta(self.theta, self.sharded_config)
sharded_model = PagedLlamaModelV1(sharded_theta, self.sharded_config)

# Verify prefill step.
(
@@ -180,7 +278,9 @@ def testCompareToySizedModelToUnsharded(self):
) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model)
expected_decode_result = model.decode(**decode_args)
sharded_decode_result = sharded_model.decode(**sharded_decode_args)
torch.testing.assert_close(sharded_decode_result, expected_decode_result)
torch.testing.assert_close(
sharded_decode_result, expected_decode_result, atol=1e-4, rtol=1e-5
)
expected_decode_cache_state = decode_args["cache_state"][0]
actual_decode_cache_state = ops.unshard(
sharded_model.cache.paged.unflatten_page_table(
@@ -194,20 +294,28 @@
actual_decode_cache_state, expected_decode_cache_state, atol=1e-4, rtol=1e-4
)

def testExportToySizedModelToMlir(self):
@unittest.skip(
(
"Before this does not crash at all we need "
"https://github.com/iree-org/iree/pull/18663 merged."
)
)
def testExportAndRunToySizedModelWithIree(self):
"""Test exporting to MLIR and compiling with IREE the sharded Llama model.
Test numerical accuracy of the IREE module against PyTorch."""

with tempfile.TemporaryDirectory() as temp_dir:
sharded_config = deepcopy(self.config)
sharded_config.tensor_parallelism_size = 2
sharded_theta = shard_theta(self.theta, sharded_config)
sharded_theta = shard_theta(self.theta, self.sharded_config)
sharded_theta.rename_tensors_to_paths()
sharded_dataset = Dataset({}, sharded_theta)
parameters_path = f"{temp_dir}/parameters.irpa"
sharded_dataset.save(f"{temp_dir}/parameters.irpa")
sharded_dataset = Dataset.load(parameters_path, mmap=False)
sharded_parameters_path = f"{temp_dir}/parameters.irpa"
sharded_dataset.save(sharded_parameters_path)
sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False)
iree_driver = "local-task"

model = PagedLlamaModelV1(self.theta, self.config)
sharded_model = PagedLlamaModelV1(
sharded_dataset.root_theta, sharded_config
sharded_dataset.root_theta, self.sharded_config
)
sharded_fxb = FxProgramsBuilder(sharded_model)

@@ -222,9 +330,10 @@ def testExportToySizedModelToMlir(self):
def _(model, *args, **kwargs) -> torch.Tensor:
return model.prefill(*args, **kwargs)

_, sharded_decode_args = self.make_equal_unsharded_and_sharded_decode_args(
model, sharded_model
)
(
_,
sharded_decode_args,
) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model)
# TODO: remove strict=False when
# https://github.com/pytorch/pytorch/issues/136757
# is resolved.
@@ -237,5 +346,91 @@ def _(model, *args, **kwargs) -> torch.Tensor:
def _(model, *args, **kwargs) -> torch.Tensor:
return model.decode(*args, **kwargs)

# Compile the IREE module.
output = export(sharded_fxb)
output.save_mlir(f"{temp_dir}/program.mlir")
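# Register one llvm-cpu target device per tensor-parallel shard.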
output.session.set_flags(
*[
f"--iree-hal-target-device=llvm-cpu[{i}]"
for i in range(self.sharded_config.tensor_parallelism_size)
]
)
iree_module_path = f"{temp_dir}/program.vmfb"
output.compile(
save_to=iree_module_path,
target_backends=None,
)

iree_devices = get_iree_devices(
driver=iree_driver,
device_count=self.sharded_config.tensor_parallelism_size,
)
iree_module, vm_context, vm_instance = load_iree_module(
module_path=iree_module_path,
devices=iree_devices,
parameters_path=sharded_parameters_path,
)

# Check IREE's prefill step is close to torch.
prefill_iree_args = prepare_iree_module_function_args(
args=deepcopy(sharded_prefill_args).values(), devices=iree_devices
)
prefill_iree_result = run_iree_module_function(
args=prefill_iree_args,
function_name="prefill",
module=iree_module,
vm_context=vm_context,
driver=iree_driver,
)
prefill_iree_result = iree_to_torch(*prefill_iree_result)
assert len(prefill_iree_result) == 1
expected_prefill_result = sharded_model.prefill(**sharded_prefill_args)
# TODO: The result is not entirely wrong, but investigate why the accuracy is
# this low for fp32 (atol=0.0011, rtol=0.013).
torch.testing.assert_close(
prefill_iree_result[0],
expected_prefill_result,
)
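# The trailing entries of the flattened IREE args are the cache state shards,
# which the module updates in place, so they can be compared against the
# torch-updated cache state below.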
prefill_iree_cache_state_shards = prefill_iree_args[
-self.sharded_config.tensor_parallelism_size :
]
prefill_iree_cache_state_shards = iree_to_torch(
*prefill_iree_cache_state_shards
)
for actual_cache_state_shard, expected_cache_state_shard in zip(
prefill_iree_cache_state_shards,
sharded_prefill_args["cache_state"][0].shards,
):
# TODO: debug inaccuracy.
torch.testing.assert_close(
actual_cache_state_shard, unbox_tensor(expected_cache_state_shard)
)

# Check IREE's decode step is close to torch.
decode_iree_args = prepare_iree_module_function_args(
args=deepcopy(sharded_decode_args).values(), devices=iree_devices
)
decode_iree_result = run_iree_module_function(
args=decode_iree_args,
function_name="decode",
module=iree_module,
vm_context=vm_context,
driver=iree_driver,
)
decode_iree_result = iree_to_torch(*decode_iree_result)
expected_decode_result = sharded_model.decode(**sharded_decode_args)
# TODO: debug inaccuracy.
torch.testing.assert_close(decode_iree_result[0], expected_decode_result)
decode_iree_cache_state_shards = decode_iree_args[
-self.sharded_config.tensor_parallelism_size :
]
decode_iree_cache_state_shards = iree_to_torch(
*decode_iree_cache_state_shards
)
for actual_cache_state_shard, expected_cache_state_shard in zip(
decode_iree_cache_state_shards,
sharded_decode_args["cache_state"][0].shards,
):
# TODO: debug inaccuracy.
torch.testing.assert_close(
actual_cache_state_shard, unbox_tensor(expected_cache_state_shard)
)
