diff --git a/docs/source/openvino/reference.mdx b/docs/source/openvino/reference.mdx index c5043d877..7547a56b0 100644 --- a/docs/source/openvino/reference.mdx +++ b/docs/source/openvino/reference.mdx @@ -19,7 +19,7 @@ limitations under the License. ## Generic model classes [[autodoc]] openvino.modeling_base.OVBaseModel - - _from_pretrained + - from_pretrained - reshape ## Natural Language Processing diff --git a/docs/source/openvino/tutorials/diffusers.mdx b/docs/source/openvino/tutorials/diffusers.mdx index dad09420b..c9fe214bd 100644 --- a/docs/source/openvino/tutorials/diffusers.mdx +++ b/docs/source/openvino/tutorials/diffusers.mdx @@ -50,18 +50,14 @@ To further speed up inference, the model can be statically reshaped : ```python # Define the shapes related to the inputs and desired outputs -batch_size = 1 -num_images_per_prompt = 1 -height = 512 -width = 512 - +batch_size, num_images, height, width = 1, 1, 512, 512 # Statically reshape the model -pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images_per_prompt) +pipeline.reshape(batch_size=batch_size, height=height, width=width, num_images_per_prompt=num_images) # Compile the model before the first inference pipeline.compile() # Run inference -images = pipeline(prompt, height=height, width=width, num_images_per_prompt=num_images_per_prompt).images +images = pipeline(prompt, height=height, width=width, num_images_per_prompt=num_images).images ``` In case you want to change any parameters such as the outputs height or width, you'll need to statically reshape your model once again. diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index c4b6ef0cd..842198625 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -49,15 +49,6 @@ import torch -_COMPRESSION_OPTIONS = { - "int8": {"bits": 8}, - "int4_sym_g128": {"bits": 4, "sym": True, "group_size": 128}, - "int4_asym_g128": {"bits": 4, "sym": False, "group_size": 128}, - "int4_sym_g64": {"bits": 4, "sym": True, "group_size": 64}, - "int4_asym_g64": {"bits": 4, "sym": False, "group_size": 64}, -} - - logger = logging.getLogger(__name__) @@ -108,8 +99,6 @@ def main_export( model_kwargs: Optional[Dict[str, Any]] = None, custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, - compression_option: Optional[str] = None, - compression_ratio: Optional[float] = None, ov_config: "OVConfig" = None, stateful: bool = True, convert_tokenizer: bool = False, @@ -171,11 +160,6 @@ def main_export( fn_get_submodels (`Optional[Callable]`, defaults to `None`): Experimental usage: Override the default submodels that are used at the export. This is especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. - compression_option (`Optional[str]`, defaults to `None`): - The weight compression option, e.g. `f16` stands for float16 weights, `i8` - INT8 weights, `int4_sym_g128` - INT4 symmetric weights w/ group size 128, `int4_asym_g128` - as previous but asymmetric w/ zero-point, - `int4_sym_g64` - INT4 symmetric weights w/ group size 64, "int4_asym_g64" - as previous but asymmetric w/ zero-point, `f32` - means no compression. 
- compression_ratio (`Optional[float]`, defaults to `None`): - Compression ratio between primary and backup precision (only relevant to INT4). stateful (`bool`, defaults to `True`): Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. Applicable only for decoder models. **kwargs_shapes (`Dict`): @@ -198,28 +182,6 @@ def main_export( raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") token = use_auth_token - if compression_option is not None: - logger.warning( - "The `compression_option` argument is deprecated and will be removed in optimum-intel v1.17.0. " - "Please, pass an `ov_config` argument instead `OVConfig(..., quantization_config=quantization_config)`." - ) - - if compression_ratio is not None: - logger.warning( - "The `compression_ratio` argument is deprecated and will be removed in optimum-intel v1.17.0. " - "Please, pass an `ov_config` argument instead `OVConfig(quantization_config={ratio=compression_ratio})`." - ) - - if ov_config is None and compression_option is not None: - from ...intel.openvino.configuration import OVConfig - - if compression_option == "fp16": - ov_config = OVConfig(dtype="fp16") - elif compression_option != "fp32": - q_config = _COMPRESSION_OPTIONS[compression_option] if compression_option in _COMPRESSION_OPTIONS else {} - q_config["ratio"] = compression_ratio or 1.0 - ov_config = OVConfig(quantization_config=q_config) - original_task = task task = infer_task( task, model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py index 7f76c2854..c70dc4676 100644 --- a/optimum/intel/__init__.py +++ b/optimum/intel/__init__.py @@ -24,6 +24,7 @@ is_neural_compressor_available, is_nncf_available, is_openvino_available, + is_sentence_transformers_available, ) from .version import __version__ @@ -179,6 +180,21 @@ _import_structure["neural_compressor"].append("INCStableDiffusionPipeline") +try: + if not (is_openvino_available() and is_sentence_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + _import_structure["utils.dummy_openvino_and_sentence_transformers_objects"] = [ + "OVSentenceTransformer", + ] +else: + _import_structure["openvino"].extend( + [ + "OVSentenceTransformer", + ] + ) + + if TYPE_CHECKING: try: if not is_ipex_available(): @@ -302,6 +318,18 @@ else: from .neural_compressor import INCStableDiffusionPipeline + try: + if not (is_openvino_available() and is_sentence_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_openvino_and_sentence_transformers_objects import ( + OVSentenceTransformer, + ) + else: + from .openvino import ( + OVSentenceTransformer, + ) + else: import sys diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py index 4ee285f07..929bdf1be 100644 --- a/optimum/intel/openvino/__init__.py +++ b/optimum/intel/openvino/__init__.py @@ -15,7 +15,12 @@ import logging import warnings -from ..utils.import_utils import is_accelerate_available, is_diffusers_available, is_nncf_available +from ..utils.import_utils import ( + is_accelerate_available, + is_diffusers_available, + is_nncf_available, + is_sentence_transformers_available, +) from .utils import ( OV_DECODER_NAME, OV_DECODER_WITH_PAST_NAME, @@ -77,3 +82,7 @@ OVStableDiffusionXLImg2ImgPipeline, 
OVStableDiffusionXLPipeline, ) + + +if is_sentence_transformers_available(): + from .modeling_sentence_transformers import OVSentenceTransformer diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py index ba2d26434..786704682 100644 --- a/optimum/intel/openvino/modeling.py +++ b/optimum/intel/openvino/modeling.py @@ -14,7 +14,6 @@ import logging import os -import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict, Optional, Union @@ -370,6 +369,13 @@ class OVModelForFeatureExtraction(OVModel): auto_model_class = AutoModel def __init__(self, model=None, config=None, **kwargs): + if {"token_embeddings", "sentence_embedding"}.issubset( + {name for output in model.outputs for name in output.names} + ): # Sentence Transormers outputs + raise ValueError( + "This model is a Sentence Transformers model. Please use `OVSentenceTransformer` to load this model." + ) + super().__init__(model, config, **kwargs) @add_start_docstrings_to_model_forward( @@ -417,7 +423,6 @@ def _from_transformers( cls, model_id: str, config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -430,15 +435,6 @@ def _from_transformers( quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, **kwargs, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) # This attribute is needed to keep one reference on the temporary directory, since garbage collecting @@ -591,7 +587,6 @@ def from_pretrained( model_id: Union[str, Path], export: bool = False, config: Optional["PretrainedConfig"] = None, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -602,15 +597,6 @@ def from_pretrained( trust_remote_code: bool = False, **kwargs, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - # Fix the mismatch between timm_config and huggingface_config local_timm_model = _is_timm_ov_dir(model_id) if local_timm_model or (not os.path.isdir(model_id) and model_info(model_id).library_name == "timm"): diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 98fec1735..e8dc11312 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -28,12 +28,14 @@ from transformers import GenerationConfig, PretrainedConfig from transformers.file_utils import add_start_docstrings from transformers.generation import GenerationMixin +from transformers.utils import is_offline_mode from optimum.exporters.onnx import OnnxConfig -from optimum.modeling_base import OptimizedModel +from optimum.modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel from ...exporters.openvino import export, main_export from ..utils.import_utils import is_nncf_available +from ..utils.modeling_utils import _find_files_matching_pattern from .configuration import OVConfig, OVDynamicQuantizationConfig, OVWeightQuantizationConfig from .utils import ONNX_WEIGHTS_NAME, OV_TO_PT_TYPE, OV_XML_FILE_NAME, _print_compiled_model_properties @@ -52,6 +54,7 @@ class OVBaseModel(OptimizedModel): auto_model_class = None export_feature = None _supports_cache_class = False + _library_name = "transformers" def __init__( self, @@ -220,7 +223,6 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -242,8 +244,6 @@ def _from_pretrained( Can be either: - The model id of a pretrained model hosted inside a model repo on huggingface.co. - The path to a directory containing the model weights. - use_auth_token (Optional[Union[bool, str]], defaults to `None`): - Deprecated. Please use `token` instead. token (Optional[Union[bool, str]], defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). @@ -263,15 +263,6 @@ def _from_pretrained( load_in_8bit (`bool`, *optional*, defaults to `False`): Whether or not to apply 8-bit weight quantization. """ - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - model_path = Path(model_id) default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME file_name = file_name or default_file_name @@ -312,6 +303,87 @@ def _from_pretrained( **kwargs, ) + @classmethod + @add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING) + def from_pretrained( + cls, + model_id: Union[str, Path], + export: bool = False, + force_download: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + subfolder: str = "", + config: Optional[PretrainedConfig] = None, + local_files_only: bool = False, + trust_remote_code: bool = False, + revision: Optional[str] = None, + **kwargs, + ): + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") + token = use_auth_token + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + _export = export + try: + if local_files_only: + object_id = model_id.replace("/", "--") + cached_model_dir = os.path.join(cache_dir, f"models--{object_id}") + refs_file = os.path.join(os.path.join(cached_model_dir, "refs"), revision or "main") + with open(refs_file) as f: + revision = f.read() + model_dir = os.path.join(cached_model_dir, "snapshots", revision) + else: + model_dir = model_id + + ov_files = _find_files_matching_pattern( + model_dir, + pattern=r"(.*)?openvino(.*)?\_model.xml", + subfolder=subfolder, + use_auth_token=token, + revision=revision, + ) + _export = len(ov_files) == 0 + if _export ^ export: + if export: + logger.warning( + f"The model {model_id} was already converted to the OpenVINO IR but got `export=True`, the model will be converted to OpenVINO once again. " + "Don't forget to save the resulting model with `.save_pretrained()`" + ) + _export = True + else: + logger.warning( + f"No OpenVINO files were found for {model_id}, setting `export=True` to convert the model to the OpenVINO IR. 
" + "Don't forget to save the resulting model with `.save_pretrained()`" + ) + except Exception as exception: + logger.warning( + f"Could not infer whether the model was already converted or not to the OpenVINO IR, keeping `export={export}`.\n{exception}" + ) + + return super().from_pretrained( + model_id, + export=_export, + force_download=force_download, + token=token, + cache_dir=cache_dir, + subfolder=subfolder, + config=config, + local_files_only=local_files_only, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + @staticmethod def _prepare_weight_quantization_config( quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, load_in_8bit: bool = False @@ -337,7 +409,6 @@ def _set_ov_config_parameters(self): @staticmethod def _cached_file( model_path: Union[Path, str], - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -346,15 +417,6 @@ def _cached_file( subfolder: str = "", local_files_only: bool = False, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - # locates a file in a local folder and repo, downloads and cache it if necessary. model_path = Path(model_path) if model_path.is_dir(): @@ -385,7 +447,6 @@ def _from_transformers( cls, model_id: str, config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -409,8 +470,6 @@ def _from_transformers( - The path to a directory containing the model weights. save_dir (`str` or `Path`): The directory where the exported ONNX model should be saved, default to `transformers.file_utils.default_cache_path`, which is the cache directory for transformers. - use_auth_token (`Optional[str]`, defaults to `None`): - Deprecated. Please use `token` instead. token (Optional[Union[bool, str]], defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). @@ -419,15 +478,6 @@ def _from_transformers( kwargs (`Dict`, *optional*): kwargs will be passed to the model during initialization """ - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) # This attribute is needed to keep one reference on the temporary directory, since garbage collecting @@ -452,6 +502,7 @@ def _from_transformers( force_download=force_download, trust_remote_code=trust_remote_code, ov_config=ov_config, + library_name=cls._library_name, ) config.save_pretrained(save_dir_path) @@ -469,7 +520,6 @@ def _to_load( model, config: PretrainedConfig, onnx_config: OnnxConfig, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -478,15 +528,6 @@ def _to_load( stateful: bool = False, **kwargs, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index 95ffbc930..bf34abf5f 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -14,7 +14,6 @@ import logging import os -import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict, Optional, Union @@ -120,7 +119,6 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -144,8 +142,6 @@ def _from_pretrained( Can be either: - The model id of a pretrained model hosted inside a model repo on huggingface.co. - The path to a directory containing the model weights. - use_auth_token (Optional[Union[bool, str]], defaults to `None`): - Deprecated. Please use `token` instead. token (Optional[Union[bool, str]], defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). @@ -169,15 +165,6 @@ def _from_pretrained( local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). """ - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - default_encoder_file_name = ONNX_ENCODER_NAME if from_onnx else OV_ENCODER_NAME default_decoder_file_name = ONNX_DECODER_NAME if from_onnx else OV_DECODER_NAME default_decoder_with_past_file_name = ONNX_DECODER_WITH_PAST_NAME if from_onnx else OV_DECODER_WITH_PAST_NAME @@ -256,7 +243,6 @@ def _from_transformers( cls, model_id: str, config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -282,8 +268,6 @@ def _from_transformers( save_dir (`str` or `Path`): The directory where the exported ONNX model should be saved, defaults to `transformers.file_utils.default_cache_path`, which is the cache directory for transformers. - use_auth_token (`Optional[str]`, defaults to `None`): - Deprecated. Please use `token` instead. token (Optional[Union[bool, str]], defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). @@ -292,15 +276,6 @@ def _from_transformers( kwargs (`Dict`, *optional*): kwargs will be passed to the model during initialization """ - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 23117e936..534edd20b 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -14,7 +14,6 @@ import copy import logging import os -import warnings from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union @@ -244,7 +243,6 @@ def _from_transformers( cls, model_id: str, config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -258,15 +256,6 @@ def _from_transformers( quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, **kwargs, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) # This attribute is needed to keep one reference on the temporary directory, since garbage collecting @@ -307,6 +296,7 @@ def _from_transformers( ov_config=ov_export_config, stateful=stateful, model_loading_kwargs=model_loading_kwargs, + library_name=cls._library_name, ) config.is_decoder = True @@ -749,7 +739,6 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[Union[str, None]] = None, force_download: bool = False, @@ -762,15 +751,6 @@ def _from_pretrained( quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, **kwargs, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - model_path = Path(model_id) default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME file_name = file_name or default_file_name diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 9b945caab..a39c9a80b 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -16,7 +16,6 @@ import logging import os import shutil -import warnings from copy import deepcopy from pathlib import Path from tempfile import TemporaryDirectory, gettempdir @@ -79,6 +78,7 @@ class OVStableDiffusionPipelineBase(OVBaseModel, OVTextualInversionLoaderMixin): auto_model_class = StableDiffusionPipeline config_name = "model_index.json" export_feature = "text-to-image" + _library_name = "diffusers" def __init__( self, @@ -210,7 +210,6 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: Dict[str, Any], - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, cache_dir: str = HUGGINGFACE_HUB_CACHE, @@ -226,15 +225,6 @@ def _from_pretrained( quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, **kwargs, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. 
Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME vae_decoder_file_name = vae_decoder_file_name or default_file_name text_encoder_file_name = text_encoder_file_name or default_file_name @@ -349,7 +339,6 @@ def _from_transformers( cls, model_id: str, config: Dict[str, Any], - use_auth_token: Optional[Union[bool, str]] = None, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -363,15 +352,6 @@ def _from_transformers( quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, **kwargs, ): - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", - FutureWarning, - ) - if token is not None: - raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.") - token = use_auth_token - save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -393,6 +373,7 @@ def _from_transformers( local_files_only=local_files_only, force_download=force_download, ov_config=ov_config, + library_name=cls._library_name, ) return cls._from_pretrained( diff --git a/optimum/intel/openvino/modeling_sentence_transformers.py b/optimum/intel/openvino/modeling_sentence_transformers.py new file mode 100644 index 000000000..d523993cf --- /dev/null +++ b/optimum/intel/openvino/modeling_sentence_transformers.py @@ -0,0 +1,142 @@ +from pathlib import Path +from types import MethodType +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from sentence_transformers import SentenceTransformer +from transformers import AutoTokenizer, PretrainedConfig +from transformers.file_utils import add_start_docstrings + +from .modeling import MODEL_START_DOCSTRING, OVModel + + +@add_start_docstrings( + """ + OpenVINO Model for feature extraction tasks for Sentence Transformers. 
+ """, + MODEL_START_DOCSTRING, +) +class OVSentenceTransformer(OVModel): + export_feature = "feature-extraction" + _library_name = "sentence_transformers" + + def __init__(self, model=None, config=None, tokenizer=None, **kwargs): + super().__init__(model, config, **kwargs) + + self.encode = MethodType(SentenceTransformer.encode, self) + self._text_length = MethodType(SentenceTransformer._text_length, self) + self.default_prompt_name = None + self.truncate_dim = None + self.tokenizer = tokenizer + + def _save_pretrained(self, save_directory: Union[str, Path]): + super()._save_pretrained(save_directory) + self.tokenizer.save_pretrained(save_directory) + + def forward(self, inputs: Dict[str, torch.Tensor]): + self.compile() + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + token_type_ids = inputs.get("token_type_ids") + + np_inputs = isinstance(input_ids, np.ndarray) + if not np_inputs: + input_ids = np.array(input_ids) + attention_mask = np.array(attention_mask) + token_type_ids = np.array(token_type_ids) if token_type_ids is not None else token_type_ids + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + # Add the token_type_ids when needed + if "token_type_ids" in self.input_names: + inputs["token_type_ids"] = token_type_ids if token_type_ids is not None else np.zeros_like(input_ids) + + outputs = self._inference(inputs) + return { + "token_embeddings": torch.from_numpy(outputs["token_embeddings"]).to(self.device), + "sentence_embedding": torch.from_numpy(outputs["sentence_embedding"]).to(self.device), + } + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: PretrainedConfig, + token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: str = HUGGINGFACE_HUB_CACHE, + file_name: Optional[str] = None, + subfolder: str = "", + from_onnx: bool = False, + local_files_only: bool = False, + **kwargs, + ): + trust_remote_code = kwargs.pop("trust_remote_code", False) + tokenizer_kwargs = kwargs.pop("tokenizer_kwargs", None) + + tokenizer_args = { + "token": token, + "trust_remote_code": trust_remote_code, + "revision": revision, + "local_files_only": local_files_only, + } + if tokenizer_kwargs: + kwargs["tokenizer_args"].update(tokenizer_kwargs) + + tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_args) + + return super()._from_pretrained( + model_id=model_id, + config=config, + token=token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + from_onnx=from_onnx, + local_files_only=local_files_only, + tokenizer=tokenizer, + **kwargs, + ) + + def tokenize( + self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]], padding: Union[str, bool] = True + ) -> Dict[str, torch.Tensor]: + """Tokenizes a text and maps tokens to token-ids""" + output = {} + if isinstance(texts[0], str): + to_tokenize = [texts] + elif isinstance(texts[0], dict): + to_tokenize = [] + output["text_keys"] = [] + for lookup in texts: + text_key, text = next(iter(lookup.items())) + to_tokenize.append(text) + output["text_keys"].append(text_key) + to_tokenize = [to_tokenize] + else: + batch1, batch2 = [], [] + for text_tuple in texts: + batch1.append(text_tuple[0]) + batch2.append(text_tuple[1]) + to_tokenize = [batch1, batch2] + + # strip + to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize] + + output.update( + self.tokenizer( + *to_tokenize, + 
padding=padding, + truncation="longest_first", + return_tensors="pt", + ) + ) + return output diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py index ae3b5d3eb..d26d8c42b 100644 --- a/optimum/intel/pipelines/pipeline_base.py +++ b/optimum/intel/pipelines/pipeline_base.py @@ -48,7 +48,6 @@ is_ipex_available, is_openvino_available, ) -from optimum.intel.utils.modeling_utils import _find_files_matching_pattern if is_ipex_available(): @@ -228,20 +227,9 @@ def load_openvino_model( model_kwargs = model_kwargs or {} ov_model_class = SUPPORTED_TASKS[targeted_task]["class"][0] - if model is None: - model_id = SUPPORTED_TASKS[targeted_task]["default"] - model = ov_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs) - elif isinstance(model, str): - model_id = model - pattern = r"(.*)?openvino(.*)?\_model.xml" - ov_files = _find_files_matching_pattern( - model, - pattern, - use_auth_token=hub_kwargs.get("token", None), - revision=hub_kwargs.get("revision", None), - ) - export = len(ov_files) == 0 - model = ov_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs) + if isinstance(model, str) or model is None: + model_id = model or SUPPORTED_TASKS[targeted_task]["default"] + model = ov_model_class.from_pretrained(model_id, **hub_kwargs, **model_kwargs) elif isinstance(model, OVBaseModel): model_id = model.model_save_dir else: diff --git a/optimum/intel/utils/__init__.py b/optimum/intel/utils/__init__.py index 50cdfa143..b79deeb62 100644 --- a/optimum/intel/utils/__init__.py +++ b/optimum/intel/utils/__init__.py @@ -24,6 +24,7 @@ is_nncf_available, is_numa_available, is_openvino_available, + is_sentence_transformers_available, is_torch_version, is_transformers_available, is_transformers_version, diff --git a/optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py b/optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py new file mode 100644 index 000000000..fd13e5f56 --- /dev/null +++ b/optimum/intel/utils/dummy_openvino_and_sentence_transformers_objects.py @@ -0,0 +1,26 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .import_utils import DummyObject, requires_backends + + +class OVSentenceTransformer(metaclass=DummyObject): + _backends = ["openvino", "sentence_transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["openvino", "sentence_transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["openvino", "sentence_transformers"]) diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py index 032280e94..8024d2389 100644 --- a/optimum/intel/utils/import_utils.py +++ b/optimum/intel/utils/import_utils.py @@ -159,6 +159,16 @@ _numa_available = False +_sentence_transformers_available = importlib.util.find_spec("sentence_transformers") is not None +_sentence_transformers_available = "N/A" + +if _sentence_transformers_available: + try: + _sentence_transformers_available = importlib_metadata.version("sentence_transformers") + except importlib_metadata.PackageNotFoundError: + _sentence_transformers_available = False + + def is_transformers_available(): return _transformers_available @@ -280,6 +290,10 @@ def is_accelerate_available(): return _accelerate_available +def is_sentence_transformers_available(): + return _sentence_transformers_available + + def is_numa_available(): return _numa_available diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py index 1d2f7b03c..672636ff1 100644 --- a/optimum/intel/utils/modeling_utils.py +++ b/optimum/intel/utils/modeling_utils.py @@ -24,6 +24,8 @@ import torch from huggingface_hub import HfApi, HfFolder +from optimum.exporters import TasksManager + from .import_utils import is_numa_available @@ -102,19 +104,26 @@ def _find_files_matching_pattern( Returns: `List[Path]` """ - model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path - pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern) - subfolder = subfolder or "." + model_path = Path(model_name_or_path) if not isinstance(model_name_or_path, Path) else model_name_or_path + + if isinstance(use_auth_token, bool): + token = HfFolder().get_token() + else: + token = use_auth_token + + library_name = TasksManager.infer_library_from_model( + str(model_name_or_path), subfolder=subfolder, revision=revision, token=token + ) + if library_name == "diffusers": + subfolder = os.path.join(subfolder, "unet") + else: + subfolder = subfolder or "." 
if model_path.is_dir(): glob_pattern = subfolder + "/*" files = model_path.glob(glob_pattern) files = [p for p in files if re.search(pattern, str(p))] else: - if isinstance(use_auth_token, bool): - token = HfFolder().get_token() - else: - token = use_auth_token repo_files = map(Path, HfApi().list_repo_files(model_name_or_path, revision=revision, token=token)) files = [Path(p) for p in repo_files if re.match(pattern, str(p)) and str(p.parent) == subfolder] diff --git a/tests/openvino/test_diffusion.py b/tests/openvino/test_diffusion.py index 03c1e2048..8236c53db 100644 --- a/tests/openvino/test_diffusion.py +++ b/tests/openvino/test_diffusion.py @@ -94,7 +94,7 @@ class OVStableDiffusionPipelineBaseTest(unittest.TestCase): @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_num_images_per_prompt(self, model_arch: str): model_id = MODEL_NAMES[model_arch] - pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False) + pipeline = self.MODEL_CLASS.from_pretrained(model_id, compile=False) pipeline.to("cpu") pipeline.compile() self.assertEqual(pipeline.vae_scale_factor, 2) diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index ef20ed5a2..d48e86fe2 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -16,7 +16,6 @@ import unittest from pathlib import Path from tempfile import TemporaryDirectory -from typing import Optional import torch from parameterized import parameterized @@ -76,7 +75,6 @@ class ExportModelTest(unittest.TestCase): def _openvino_export( self, model_type: str, - compression_option: Optional[str] = None, stateful: bool = True, patch_16bit_model: bool = False, ): @@ -106,7 +104,6 @@ def _openvino_export( output=Path(tmpdirname), task=supported_task, preprocessors=preprocessors, - compression_option=compression_option, stateful=stateful, ) diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index 6380a5288..b95c83881 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -36,6 +36,7 @@ OVModelForSeq2SeqLM, OVModelForSequenceClassification, OVModelForTokenClassification, + OVSentenceTransformer, OVStableDiffusionPipeline, OVStableDiffusionXLPipeline, ) @@ -108,16 +109,12 @@ class OVCLIExportTestCase(unittest.TestCase): ), ] - def _openvino_export( - self, model_name: str, task: str, compression_option: str = None, compression_ratio: float = None - ): + def _openvino_export(self, model_name: str, task: str): with TemporaryDirectory() as tmpdir: main_export( model_name_or_path=model_name, output=tmpdir, task=task, - compression_option=compression_option, - compression_ratio=compression_ratio, ) @parameterized.expand(SUPPORTED_ARCHITECTURES) @@ -320,5 +317,5 @@ def test_exporters_cli_sentence_transformers(self): shell=True, check=True, ) - model = eval(_HEAD_TO_AUTOMODELS["feature-extraction"]).from_pretrained(tmpdir, compile=False) + model = OVSentenceTransformer.from_pretrained(tmpdir, compile=False) self.assertFalse("last_hidden_state" in model.output_names) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 6f24ea0de..30c70c7c9 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -17,6 +17,7 @@ import tempfile import time import unittest +from pathlib import Path from typing import Dict import numpy as np @@ -26,6 +27,7 @@ import torch from datasets import load_dataset from evaluate import evaluator +from huggingface_hub import HfApi from parameterized import 
parameterized from PIL import Image from transformers import ( @@ -75,6 +77,7 @@ OVModelForSpeechSeq2Seq, OVModelForTokenClassification, OVModelForVision2Seq, + OVSentenceTransformer, OVStableDiffusionPipeline, ) from optimum.intel.openvino import OV_DECODER_NAME, OV_DECODER_WITH_PAST_NAME, OV_ENCODER_NAME, OV_XML_FILE_NAME @@ -84,6 +87,7 @@ from optimum.intel.openvino.utils import _print_compiled_model_properties from optimum.intel.pipelines import pipeline as optimum_pipeline from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version +from optimum.intel.utils.modeling_utils import _find_files_matching_pattern from optimum.utils import ( DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, @@ -253,14 +257,64 @@ def test_load_from_hub_and_save_stable_diffusion_model(self): @pytest.mark.run_slow @slow def test_load_model_from_hub_private_with_token(self): + model_id = "optimum-internal-testing/tiny-random-phi-private" token = os.environ.get("HF_HUB_READ_TOKEN", None) if token is None: self.skipTest("Test requires a token `HF_HUB_READ_TOKEN` in the environment variable") - model = OVModelForCausalLM.from_pretrained( - "optimum-internal-testing/tiny-random-phi-private", token=token, revision="openvino" - ) + model = OVModelForCausalLM.from_pretrained(model_id, token=token, revision="openvino") self.assertIsInstance(model.config, PretrainedConfig) + self.assertTrue(model.stateful) + + def test_infer_export_when_loading(self): + model_id = MODEL_NAMES["phi"] + model = AutoModelForCausalLM.from_pretrained(model_id) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(Path(tmpdirname) / "original") + # Load original model and convert + model = OVModelForCausalLM.from_pretrained(Path(tmpdirname) / "original") + model.save_pretrained(Path(tmpdirname) / "openvino") + # Load openvino model + model = OVModelForCausalLM.from_pretrained(Path(tmpdirname) / "openvino") + del model + gc.collect() + + def test_find_files_matching_pattern(self): + model_id = "echarlaix/tiny-random-PhiForCausalLM" + pattern = r"(.*)?openvino(.*)?\_model.xml" + # hub model + for revision in ("main", "ov", "itrex"): + ov_files = _find_files_matching_pattern( + model_id, pattern=pattern, revision=revision, subfolder="openvino" if revision == "itrex" else "" + ) + self.assertTrue(len(ov_files) == 0 if revision == "main" else len(ov_files) > 0) + + # local model + api = HfApi() + with tempfile.TemporaryDirectory() as tmpdirname: + for revision in ("main", "ov", "itrex"): + local_dir = Path(tmpdirname) / revision + api.snapshot_download(repo_id=model_id, local_dir=local_dir, revision=revision) + ov_files = _find_files_matching_pattern( + local_dir, pattern=pattern, revision=revision, subfolder="openvino" if revision == "itrex" else "" + ) + self.assertTrue(len(ov_files) == 0 if revision == "main" else len(ov_files) > 0) + + @parameterized.expand(("stable-diffusion", "stable-diffusion-openvino")) + def test_find_files_matching_pattern_sd(self, model_arch): + pattern = r"(.*)?openvino(.*)?\_model.xml" + model_id = MODEL_NAMES[model_arch] + # hub model + ov_files = _find_files_matching_pattern(model_id, pattern=pattern) + self.assertTrue(len(ov_files) > 0 if "openvino" in model_id else len(ov_files) == 0) + + # local model + api = HfApi() + with tempfile.TemporaryDirectory() as tmpdirname: + local_dir = Path(tmpdirname) / "model" + api.snapshot_download(repo_id=model_id, local_dir=local_dir) + ov_files = _find_files_matching_pattern(local_dir, pattern=pattern) + 
self.assertTrue(len(ov_files) > 0 if "openvino" in model_id else len(ov_files) == 0) class PipelineTest(unittest.TestCase): @@ -350,7 +404,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): set_seed(SEED) model_id = MODEL_NAMES[model_arch] - model = OVModelForSequenceClassification.from_pretrained(model_id, export=True, compile=False) + model = OVModelForSequenceClassification.from_pretrained(model_id, compile=False) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("text-classification", model=model, tokenizer=tokenizer) @@ -433,7 +487,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): set_seed(SEED) model_id = MODEL_NAMES[model_arch] - model = OVModelForQuestionAnswering.from_pretrained(model_id, export=True) + model = OVModelForQuestionAnswering.from_pretrained(model_id) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("question-answering", model=model, tokenizer=tokenizer) @@ -509,7 +563,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): set_seed(SEED) model_id = MODEL_NAMES[model_arch] - model = OVModelForTokenClassification.from_pretrained(model_id, export=True) + model = OVModelForTokenClassification.from_pretrained(model_id) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("token-classification", model=model, tokenizer=tokenizer) @@ -586,7 +640,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): set_seed(SEED) model_id = MODEL_NAMES[model_arch] - model = OVModelForFeatureExtraction.from_pretrained(model_id, export=True) + model = OVModelForFeatureExtraction.from_pretrained(model_id) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("feature-extraction", model=model, tokenizer=tokenizer) @@ -602,6 +656,20 @@ def test_pipeline(self, model_arch): del model gc.collect() + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_sentence_transformers_pipeline(self, model_arch): + """ + Check if we call OVModelForFeatureExtraction passing saved ir-model with outputs + from Sentence Transformers then an appropriate exception raises. 
+ """ + model_id = MODEL_NAMES[model_arch] + with tempfile.TemporaryDirectory() as tmp_dir: + save_dir = str(tmp_dir) + OVSentenceTransformer.from_pretrained(model_id, export=True).save_pretrained(save_dir) + with self.assertRaises(Exception) as context: + OVModelForFeatureExtraction.from_pretrained(save_dir) + self.assertIn("Please use `OVSentenceTransformer`", str(context.exception)) + class OVModelForCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( @@ -789,9 +857,7 @@ def test_pipeline(self, model_arch): if model_arch == "qwen": tokenizer._convert_tokens_to_ids = lambda x: 0 - model = OVModelForCausalLM.from_pretrained( - model_id, export=True, use_cache=False, compile=False, **model_kwargs - ) + model = OVModelForCausalLM.from_pretrained(model_id, use_cache=False, compile=False, **model_kwargs) model.eval() model.config.encoder_no_repeat_ngram_size = 0 model.to("cpu") @@ -1116,7 +1182,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - model = OVModelForMaskedLM.from_pretrained(model_id, export=True) + model = OVModelForMaskedLM.from_pretrained(model_id) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_id) pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer) @@ -1185,7 +1251,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): set_seed(SEED) model_id = MODEL_NAMES[model_arch] - model = OVModelForImageClassification.from_pretrained(model_id, export=True) + model = OVModelForImageClassification.from_pretrained(model_id) model.eval() preprocessor = AutoFeatureExtractor.from_pretrained(model_id) pipe = pipeline("image-classification", model=model, feature_extractor=preprocessor) @@ -1292,7 +1358,7 @@ def test_pipeline(self, model_arch): model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) inputs = "This is a test" - model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, compile=False) + model = OVModelForSeq2SeqLM.from_pretrained(model_id, compile=False) model.eval() model.half() model.to("cpu") @@ -1433,7 +1499,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): set_seed(SEED) model_id = MODEL_NAMES[model_arch] - model = OVModelForAudioClassification.from_pretrained(model_id, export=True) + model = OVModelForAudioClassification.from_pretrained(model_id) model.eval() preprocessor = AutoFeatureExtractor.from_pretrained(model_id) pipe = pipeline("audio-classification", model=model, feature_extractor=preprocessor) @@ -1753,8 +1819,7 @@ def test_compare_to_transformers(self, model_arch): def test_pipeline(self, model_arch): set_seed(SEED) model_id = MODEL_NAMES[model_arch] - model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True) - + model = OVModelForSpeechSeq2Seq.from_pretrained(model_id) processor = get_preprocessor(model_id) pipe = pipeline( "automatic-speech-recognition", @@ -1861,7 +1926,7 @@ def test_compare_to_transformers(self, model_arch: str): def test_pipeline(self, model_arch: str): set_seed(SEED) model_id = MODEL_NAMES[model_arch] - ov_model = OVModelForVision2Seq.from_pretrained(model_id, export=True, compile=False) + ov_model = OVModelForVision2Seq.from_pretrained(model_id, compile=False) feature_extractor, tokenizer = self._get_preprocessors(model_id) ov_model.reshape(1, -1) ov_model.compile() diff --git a/tests/openvino/test_modeling_sentence_transformers.py 
b/tests/openvino/test_modeling_sentence_transformers.py new file mode 100644 index 000000000..acda04512 --- /dev/null +++ b/tests/openvino/test_modeling_sentence_transformers.py @@ -0,0 +1,74 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized +from sentence_transformers import SentenceTransformer +from transformers import ( + PretrainedConfig, + set_seed, +) + +from optimum.intel import OVSentenceTransformer + + +SEED = 42 + +F32_CONFIG = {"INFERENCE_PRECISION_HINT": "f32"} + +MODEL_NAMES = { + "bert": "sentence-transformers/all-MiniLM-L6-v2", + "mpnet": "sentence-transformers/all-mpnet-base-v2", +} + + +class OVModelForSTFeatureExtractionIntegrationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( + "bert", + "mpnet", + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + ov_model = OVSentenceTransformer.from_pretrained(model_id, export=True, ov_config=F32_CONFIG) + self.assertIsInstance(ov_model.config, PretrainedConfig) + self.assertTrue(hasattr(ov_model, "encode")) + st_model = SentenceTransformer(model_id) + sentences = ["This is an example sentence", "Each sentence is converted"] + st_embeddings = st_model.encode(sentences) + ov_embeddings = ov_model.encode(sentences) + # Compare tensor outputs + self.assertTrue(np.allclose(ov_embeddings, st_embeddings, atol=1e-4)) + del st_embeddings + del ov_model + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_sentence_transformers_save_and_infer(self, model_arch): + model_id = MODEL_NAMES[model_arch] + ov_model = OVSentenceTransformer.from_pretrained(model_id, export=True, ov_config=F32_CONFIG) + with tempfile.TemporaryDirectory() as tmpdirname: + model_save_path = os.path.join(tmpdirname, "sentence_transformers_ov_model") + ov_model.save_pretrained(model_save_path) + model = OVSentenceTransformer.from_pretrained(model_save_path) + sentences = ["This is an example sentence", "Each sentence is converted"] + model.encode(sentences) + gc.collect() diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 869d5897e..7dffdb3a1 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -112,6 +112,7 @@ "speech_to_text": "hf-internal-testing/tiny-random-Speech2TextModel", "squeezebert": "hf-internal-testing/tiny-random-squeezebert", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-openvino": "hf-internal-testing/tiny-stable-diffusion-openvino", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "stable-diffusion-xl-refiner": "echarlaix/tiny-random-stable-diffusion-xl-refiner", "stablelm": "hf-internal-testing/tiny-random-StableLmForCausalLM",
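
Note on the `from_pretrained` change above: with the new `OVBaseModel.from_pretrained` override, the `export` argument is inferred by scanning the checkpoint for OpenVINO IR files matching `(.*)?openvino(.*)?\_model.xml`, which is why the updated tests drop `export=True` for PyTorch checkpoints. A minimal sketch of the resulting user-facing behaviour (the `gpt2` checkpoint is only an illustrative example, not part of this patch):

```python
from optimum.intel import OVModelForCausalLM

# PyTorch-only checkpoint: no openvino_model.xml is found, so the model is
# converted to the OpenVINO IR on the fly (as if export=True had been passed).
model = OVModelForCausalLM.from_pretrained("gpt2")
model.save_pretrained("gpt2-openvino")

# Folder that already contains the OpenVINO IR: the XML files are detected and
# the model is loaded directly (as if export=False had been passed).
model = OVModelForCausalLM.from_pretrained("gpt2-openvino")
```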
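
For the new `OVSentenceTransformer` class, a short usage sketch mirroring the tests added in `tests/openvino/test_modeling_sentence_transformers.py` (the checkpoint id is the one used in those tests):

```python
from optimum.intel import OVSentenceTransformer

model_id = "sentence-transformers/all-MiniLM-L6-v2"

# Export the Sentence Transformers checkpoint to the OpenVINO IR; the wrapper
# exposes the familiar encode() API and saves its tokenizer alongside the model.
model = OVSentenceTransformer.from_pretrained(model_id, export=True)

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)

# The exported model can be saved and reloaded later without re-exporting.
model.save_pretrained("all-MiniLM-L6-v2-openvino")
model = OVSentenceTransformer.from_pretrained("all-MiniLM-L6-v2-openvino")
```

Loading such a checkpoint with `OVModelForFeatureExtraction` now raises a `ValueError` pointing to `OVSentenceTransformer`, as exercised by `test_sentence_transformers_pipeline` above.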