diff --git a/requirements.txt b/requirements.txt index 53d745a1b..6a5cccc92 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,11 @@ gguf==0.6.0 numpy==1.26.3 onnx==1.15.0 +# Model deps. +huggingface-hub==0.22.2 +transformers==4.40.0 +sentencepiece==0.2.0 + # It is expected that you have installed a PyTorch version/variant specific # to your needs, so we only include a minimum version spec. # TODO: Use a versioned release once 2.3.0 drops. diff --git a/sharktank/sharktank/examples/validate_llama_ref_model.py b/sharktank/sharktank/examples/validate_llama_ref_model.py index 61b95d0ce..8b6a8d40c 100644 --- a/sharktank/sharktank/examples/validate_llama_ref_model.py +++ b/sharktank/sharktank/examples/validate_llama_ref_model.py @@ -14,8 +14,16 @@ def main(args: list[str]): + from ..utils import cli + torch.no_grad().__enter__() - config = Dataset.load(args[0]) + + parser = cli.create_parser() + cli.add_gguf_dataset_options(parser) + args = cli.parse(parser) + + data_files = cli.get_gguf_data_files(args) + config = Dataset.load(data_files["gguf"]) hp = configs.LlamaHParams.from_gguf_props(config.properties) model = DirectCacheLlamaModelV1(config.root_theta, hp) diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py index 852fcc249..7747ce6ed 100644 --- a/sharktank/sharktank/models/llama/llama_ref.py +++ b/sharktank/sharktank/models/llama/llama_ref.py @@ -13,6 +13,7 @@ import torch.nn.functional as F from ...layers import * +from ...types import Theta __all__ = [ "DirectCacheLlamaModelV1",