Skip to content

Commit

Permalink
fix: huggingface model push
Browse files Browse the repository at this point in the history
Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming committed Aug 30, 2024
1 parent 24bbf3f commit 23335ae
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
14 changes: 10 additions & 4 deletions src/_bentoml_sdk/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ class HuggingFaceModel(Model[str]):
"""A model reference to a Hugging Face model.
Args:
model_id (Tag): The model tag. E.g. "google-bert/bert-base-uncased".
You can specify a rev or commit hash by appending it to the model name separated by a colon:
google-bert/bert-base-uncased:main
google-bert/bert-base-uncased:86b5e0934494bd15c9632b12f734a8a67f723594
model_id (str): The model tag. E.g. "google-bert/bert-base-uncased".
revision (str, optional): The revision to use. Defaults to "main".
endpoint (str, optional): The Hugging Face endpoint to use. Defaults to https://huggingface.co.
Returns:
Expand All @@ -39,6 +37,14 @@ class HuggingFaceModel(Model[str]):
revision: str = "main"
endpoint: str | None = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT"))

@classmethod
def from_tag(cls, tag: Tag, endpoint: str | None = None) -> HuggingFaceModel:
return cls(
model_id=tag.name.replace("--", "/"),
revision=tag.version or "main",
endpoint=endpoint,
)

@cached_property
def commit_hash(self) -> str:
from huggingface_hub import get_hf_file_metadata
Expand Down
11 changes: 6 additions & 5 deletions src/bentoml/_internal/cloud/bentocloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,18 @@ def _do_push_bento(
models_to_push: list[Model[t.Any]] = []
for model in info.all_models:
if model.registry == "huggingface":
models_to_push.append(HuggingFaceModel(model.tag, model.endpoint))
models_to_push.append(
HuggingFaceModel.from_tag(model.tag, model.endpoint)
)
else:
model = BentoModel(model.tag)
if model.stored is not None:
models_to_push.append(model)
with ThreadPoolExecutor(max_workers=max(len(models_to_push), 1)) as executor:

def push_model(model: BentoModel) -> None:
def push_model(model: Model[t.Any]) -> None:
model_upload_task_id = self.spinner.transmission_progress.add_task(
f'Pushing model "{model.tag}"', start=False, visible=False
f'Pushing model "{model}"', start=False, visible=False
)
self._do_push_model(
model,
Expand All @@ -111,7 +113,7 @@ def push_model(model: BentoModel) -> None:
threads=threads,
)

executor.map(push_model, models_to_push)
executor.map(push_model, models_to_push[1:])
with self.spinner.spin(text=f'Fetching Bento repository "{name}"'):
bento_repository = rest_client.v1.get_bento_repository(
bento_repository_name=name
Expand Down Expand Up @@ -550,7 +552,6 @@ def _do_push_model(
from _bentoml_sdk.models import BentoModel

model_info = model.to_info()

name = model_info.tag.name
version = model_info.tag.version
if version is None:
Expand Down
2 changes: 1 addition & 1 deletion src/bentoml/_internal/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def construct_containerfile(
model_ref = BentoModel(model.tag)
model_ref.resolve(bento_model_dir)
else:
model_ref = HuggingFaceModel(model.tag)
model_ref = HuggingFaceModel.from_tag(model.tag, model.endpoint)
model_ref.resolve(hf_model_dir)

# NOTE: dockerfile_template is already included in the
Expand Down

0 comments on commit 23335ae

Please sign in to comment.