diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index aeb70cd3a..fca07266f 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -66,11 +66,14 @@ def main():
     dataset = cli.get_input_dataset(args)
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
+    tensor_parallelism_size = (
+        dataset.properties["tensor_parallelism_size"]
+        if "tensor_parallelism_size" in dataset.properties
+        else 1
+    )
     llama_config = LlamaModelConfig(
         hp,
-        tensor_parallelism_size=dataset.properties["tensor_parallelism_size"]
-        if "tensor_parallelism_size" in dataset.properties
-        else 1,
+        tensor_parallelism_size=tensor_parallelism_size,
         use_hf=False,
         static_tables=False,  # Rely on the compiler for hoisting tables.
         kv_cache_type="direct" if args.bs == [1] else "paged",
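
Note on the hoisted lookup above: the extracted conditional is a key-lookup-with-default, which could also be written with dict.get if dataset.properties behaves like a plain Python dict (an assumption here, not confirmed by the diff). A minimal, self-contained sketch:

# Sketch only, under the assumption that the properties object is a plain dict.
# "properties" below is hypothetical example data, not taken from the diff.
properties = {"hparams": "..."}  # no "tensor_parallelism_size" key present
tensor_parallelism_size = properties.get("tensor_parallelism_size", 1)
print(tensor_parallelism_size)  # prints 1 because the key is absent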