diff --git a/src/_bentoml_sdk/models/huggingface.py b/src/_bentoml_sdk/models/huggingface.py index ee8816c269c..36fd18f207e 100644 --- a/src/_bentoml_sdk/models/huggingface.py +++ b/src/_bentoml_sdk/models/huggingface.py @@ -11,6 +11,7 @@ from bentoml._internal.cloud.schemas.modelschemas import ModelManifestSchema from bentoml._internal.cloud.schemas.schemasv1 import CreateModelSchema from bentoml._internal.models.model import ModelContext +from bentoml._internal.tag import GenericTag from bentoml._internal.tag import Tag from bentoml._internal.types import PathType @@ -34,7 +35,7 @@ class HuggingFaceModel(Model[str]): str: The downloaded model path. """ - tag: Tag = attrs.field(converter=Tag.from_taglike, alias="model_id") + tag: Tag = attrs.field(converter=GenericTag.from_taglike, alias="model_id") endpoint: str | None = attrs.field(factory=lambda: os.getenv("HF_ENDPOINT")) @property diff --git a/src/bentoml/_internal/runner/runner.py b/src/bentoml/_internal/runner/runner.py index a9e64e74c97..0fbb27c9203 100644 --- a/src/bentoml/_internal/runner/runner.py +++ b/src/bentoml/_internal/runner/runner.py @@ -12,7 +12,7 @@ from ...exceptions import StateException from ..configuration.containers import BentoMLContainer from ..models.model import Model -from ..tag import validate_tag_str +from ..tag import Tag from ..utils import first_not_none from .runnable import Runnable from .runner_handle import DummyRunnerHandle @@ -71,7 +71,7 @@ def _to_lower_name(name: str) -> str: def _validate_name(_: t.Any, attr: attr.Attribute[str], value: str): try: - validate_tag_str(value) + Tag.validate_tag_str(value) except ValueError as e: # TODO: link to tag validation documentation raise ValueError( diff --git a/src/bentoml/_internal/tag.py b/src/bentoml/_internal/tag.py index faf23c125bb..67b493643f1 100644 --- a/src/bentoml/_internal/tag.py +++ b/src/bentoml/_internal/tag.py @@ -37,50 +37,45 @@ def _join(match: re.Match[str]) -> str: return camelcase_re.sub(_join, name).lstrip("_") -def validate_tag_str(value: str): - """ - Validate that a tag value (either name or version) is a simple string that: - * Must be at most 63 characters long, - * Begin and end with an alphanumeric character (`[a-z0-9A-Z]`), and - * May contain dashes (`-`), underscores (`_`) dots (`.`), or alphanumerics - between. - """ - errors: list[str] = [] - if len(value) > tag_max_length: - errors.append(tag_max_length_error_msg) - if tag_regex.match(value) is None: - errors.append(tag_invalid_error_msg) - - if errors: - # TODO: link to tag documentation - raise ValueError( - f"{value} is not a valid BentoML tag: " + ", and ".join(errors) - ) - - @attr.define(slots=True) class Tag: name: str version: t.Optional[str] def __init__(self, name: str, version: t.Optional[str] = None): - lname = name.lower() - if name != lname: - logger.warning("Converting '%s' to lowercase: '%s'.", name, lname) - - validate_tag_str(lname) - - self.name = lname + self.name = self.validate_tag_str(name) if version is not None: - lversion = version.lower() - if version != lversion: - logger.warning("Converting '%s' to lowercase: '%s'.", version, lversion) - validate_tag_str(lversion) - self.version = lversion + self.version = self.validate_tag_str(version) else: self.version = None + @classmethod + def validate_tag_str(cls, value: str) -> str: + """ + Validate that a tag value (either name or version) is a simple string that: + * Must be at most 63 characters long, + * Begin and end with an alphanumeric character (`[a-z0-9A-Z]`), and + * May contain dashes (`-`), underscores (`_`) dots (`.`), or alphanumerics + between. + """ + if value != (lvalue := value.lower()): + logger.warning("Converting '%s' to lowercase: '%s'.", value, lvalue) + value = lvalue + + errors: list[str] = [] + if len(value) > tag_max_length: + errors.append(tag_max_length_error_msg) + if tag_regex.match(value) is None: + errors.append(tag_invalid_error_msg) + + if errors: + # TODO: link to tag documentation + raise ValueError( + f"{value} is not a valid BentoML tag: " + ", and ".join(errors) + ) + return value + def __str__(self): if self.version is None: return self.name @@ -112,7 +107,7 @@ def from_taglike(cls, taglike: t.Union["Tag", str]) -> "Tag": return cls.from_str(taglike) @classmethod - def from_str(cls, tag_str: str) -> "Tag": + def from_str(cls, tag_str: str) -> t.Self: if ":" not in tag_str: return cls(tag_str, None) try: @@ -126,7 +121,7 @@ def from_str(cls, tag_str: str) -> "Tag": except ValueError: raise BentoMLException(f"Invalid {cls.__name__} {tag_str}") - def make_new_version(self) -> "Tag": + def make_new_version(self) -> t.Self: if self.version is not None: raise ValueError( "tried to run 'make_new_version' on a Tag that already has a version" @@ -136,7 +131,7 @@ def make_new_version(self) -> "Tag": ver_bytes = ver_bytes[:6] + ver_bytes[8:12] encoded_ver = base64.b32encode(ver_bytes) - return Tag(self.name, encoded_ver.decode("ascii").lower()) + return type(self)(self.name, encoded_ver.decode("ascii").lower()) def path(self) -> str: if self.version is None: @@ -147,5 +142,12 @@ def latest_path(self) -> str: return fs.path.combine(self.name, "latest") -bentoml_cattr.register_structure_hook(Tag, lambda d, _: Tag.from_taglike(d)) # type: ignore[misc] +class GenericTag(Tag): + @classmethod + def validate_tag_str(cls, value: str) -> str: + """Allow any string as a tag""" + return value + + +bentoml_cattr.register_structure_hook(Tag, lambda d, cls: cls.from_taglike(d)) # type: ignore[misc] bentoml_cattr.register_unstructure_hook(Tag, str)