diff --git a/src/stability_sdk/animation_ui.py b/src/stability_sdk/animation_ui.py index d340de86..ff625108 100644 --- a/src/stability_sdk/animation_ui.py +++ b/src/stability_sdk/animation_ui.py @@ -253,7 +253,7 @@ def project_load(title: str): global project project = next(p for p in projects if p.title == title) try: - data = project.load_settings() + data = project.get_settings() except OutOfCreditsException as e: log = f"Not enough credits to load project '{title}'\n{e.details}" returns = args_to_controls(get_default_project()) diff --git a/src/stability_sdk/api.py b/src/stability_sdk/api.py index 18813a87..2ecd101a 100644 --- a/src/stability_sdk/api.py +++ b/src/stability_sdk/api.py @@ -1,20 +1,24 @@ import grpc import json import logging +import os import random +import shutil import time +import uuid import warnings from google.protobuf.struct_pb2 import Struct from PIL import Image from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from abc import ABC, abstractmethod try: import cv2 import numpy as np except ImportError: warnings.warn( - "Failed to import animation reqs. To use the animation toolchain, install the requisite dependencies via:" + "Failed to import animation reqs. To use the animation toolchain, install the requisite dependencies via:" " pip install --upgrade stability_sdk[anim]" ) @@ -27,6 +31,7 @@ from .utils import ( image_mix, + image_to_png_bytes, image_to_prompt, tensor_to_prompt, ) @@ -40,7 +45,7 @@ def open_channel(host: str, api_key: str = None, max_message_len: int = 10*1024* options=[ ('grpc.max_send_message_length', max_message_len), ('grpc.max_receive_message_length', max_message_len), - ] + ] if host.endswith(":443"): call_credentials = [grpc.access_token_call_credentials(api_key)] channel_credentials = grpc.composite_channel_credentials( @@ -68,36 +73,93 @@ def __init__(self, stub, engine_id): self.stub = stub self.engine_id = engine_id -class Project(): - def __init__(self, context: 'Context', project: project.Project): + +class StorageBackend(ABC): + def __init__(self, context: 'Context', primary: bool = False, primary_fs: bool = False): self._context = context - self._project = project + self.primary = primary + self.primary_fs = primary_fs - @property - def id(self) -> str: - return self._project.id + @staticmethod + @abstractmethod + def create_project( + context: 'Context', + title: str, + access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE, + proj_id_to_use: str = None + ) -> 'Project': + pass - @property - def file_id(self) -> str: - return self._project.file.id + @staticmethod + @abstractmethod + def get_project( + context: 'Context', + id: str + ) -> 'Project': + pass - @property - def title(self) -> str: - return self._project.title + @staticmethod + @abstractmethod + def delete_project(context: 'Context', id: str) -> None: + pass @staticmethod - def create( - context: 'Context', - title: str, - access: project.ProjectAccess=project.PROJECT_ACCESS_PRIVATE, - status: project.ProjectStatus=project.PROJECT_STATUS_ACTIVE + @abstractmethod + def list_projects(context: 'Context') -> List['Project']: + pass + + @abstractmethod + def get_project_settings(self, proj: 'Project', asset_id: str = None) -> dict: + pass + + @abstractmethod + def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) -> str: + pass + + def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> Image.Image: + pass + + @abstractmethod + def put_image_asset(self, proj: 'Project', image: Union[Image.Image, np.ndarray], use: generation.AssetUse, asset_id: str = None) -> str: + pass + + def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> bytes: + pass + + @abstractmethod + def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str) -> str: + pass + + def update_project(context: 'Context', proj: 'Project', title: str = None, file_id: str = None, file_uri: str = None) -> None: + pass + + +class AssetServiceBackend(StorageBackend): + def __init__(self, context: 'Context', primary: bool = False): + super().__init__(context, primary) + + @staticmethod + def create_project( + context: 'Context', + title: str, + access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE, + proj_id_to_use: str = None ) -> 'Project': req = project.CreateProjectRequest(title=title, access=access, status=status) proj: project.Project = context._proj_stub.Create(req, wait_for_ready=True) return Project(context, proj) - def delete(self): - self._context._proj_stub.Delete(project.DeleteProjectRequest(id=self.id)) + @staticmethod + def get_project(context: 'Context', id: str) -> 'Project': + req = project.GetProjectRequest(id=id) + proj: project.Project = context._proj_stub.Get(req, wait_for_ready=True) + return Project(context, proj) + + @staticmethod + def delete_project(context: 'Context', id: str) -> None: + context._proj_stub.Delete(project.DeleteProjectRequest(id=id)) @staticmethod def list_projects(context: 'Context') -> List['Project']: @@ -108,82 +170,458 @@ def list_projects(context: 'Context') -> List['Project']: results.sort(key=lambda x: x.title.lower()) return results - def load_settings(self) -> dict: + def get_project_settings(self, proj: 'Project', asset_id: str = None) -> dict: + asset_id = asset_id if asset_id else proj.file.id request = generation.Request( engine_id=self._context._asset.engine_id, prompt=[generation.Prompt( artifact=generation.Artifact( type=generation.ARTIFACT_TEXT, mime="application/json", - uuid=self.file_id, + uuid=asset_id, ) )], asset=generation.AssetParameters( - action=generation.ASSET_GET, - project_id=self.id, + action=generation.ASSET_GET, + project_id=proj.id, use=generation.ASSET_USE_PROJECT ) ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: - return json.loads(results[generation.ARTIFACT_TEXT][0]) - raise Exception(f"Failed to load project file for {self.id}") + settings_json = json.loads(results[generation.ARTIFACT_TEXT][0]) + return settings_json + raise Exception(f"Failed to load project file for {proj.id}") - def save_settings(self, data: dict) -> str: + def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) -> str: contents = json.dumps(data) request = generation.Request( - engine_id=self._context._asset.engine_id, + engine_id=context._asset.engine_id, prompt=[generation.Prompt( artifact=generation.Artifact( type=generation.ARTIFACT_TEXT, text=contents, mime="application/json", - uuid=self.file_id + uuid=proj.file_id ) )], asset=generation.AssetParameters( - action=generation.ASSET_PUT, - project_id=self.id, + action=generation.ASSET_PUT, + project_id=proj.id, use=generation.ASSET_USE_PROJECT ) ) - results = self._context._run_request(self._context._asset, request) + results = context._run_request(context._asset, request) if generation.ARTIFACT_TEXT in results: return results[generation.ARTIFACT_TEXT][0] - raise Exception(f"Failed to save project file for {self.id}") + raise Exception(f"Failed to save project file for {proj.id}") - def put_image_asset( - self, - image: Union[Image.Image, np.ndarray], - use: generation.AssetUse=generation.ASSET_USE_OUTPUT - ): + def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> Image.Image: + request = generation.Request( + engine_id=self._context._asset.engine_id, + prompt=[generation.Prompt( + artifact=generation.Artifact(type=generation.ARTIFACT_IMAGE, mime="image/png", uuid=asset_id) + )], + asset=generation.AssetParameters( + action=generation.ASSET_GET, + project_id=proj.id, + use=use + ) + ) + results = self._context._run_request(self._context._asset, request) + if generation.ARTIFACT_IMAGE in results: + img = results[generation.ARTIFACT_IMAGE][0] + pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + return pil_img + raise Exception(f"Failed to load image asset {asset_id} for project {proj.id}") + + def put_image_asset(self, proj: 'Project', image: Union[Image.Image, np.ndarray], use: generation.AssetUse, asset_id: str = None) -> str: request = generation.Request( engine_id=self._context._asset.engine_id, prompt=[image_to_prompt(image)], asset=generation.AssetParameters( - action=generation.ASSET_PUT, - project_id=self.id, + action=generation.ASSET_PUT, + project_id=proj.id, use=use ) ) results = self._context._run_request(self._context._asset, request) if generation.ARTIFACT_TEXT in results: return results[generation.ARTIFACT_TEXT][0] - raise Exception(f"Failed to store image asset for project {self.id}") + raise Exception(f"Failed to store image asset for project {proj.id}") + + def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> str: + request = generation.Request( + engine_id=self._context._asset.engine_id, + prompt=[generation.Prompt( + artifact=generation.Artifact(type=generation.ARTIFACT_VIDEO, mime="video/mp4", uuid=asset_id) + )], + asset=generation.AssetParameters( + action=generation.ASSET_GET, + project_id=proj.id, + use=use + ) + ) + results = self._context._run_request(self._context._asset, request) + if generation.ARTIFACT_VIDEO in results: + return results[generation.ARTIFACT_VIDEO][0] + raise Exception(f"Failed to load video asset {asset_id} for project {proj.id}") + + def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str) -> str: + if not os.path.isfile(video_path) or not video_path.endswith(".mp4"): + raise ValueError("Invalid video file path. Must be an existing .mp4 file.") + + with open(video_path, "rb") as f: + binary_data = f.read() + + request = generation.Request( + engine_id=self._context._asset.engine_id, + prompt=[ + generation.Prompt( + artifact=generation.Artifact( + type=generation.ARTIFACT_VIDEO, + mime="video/mp4", + binary=binary_data, + ) + ) + ], + asset=generation.AssetParameters( + action=generation.ASSET_PUT, + project_id=proj.id, + use=generation.ASSET_USE_INPUT, + ), + ) + results = self._context._run_request(self._context._asset, request) + if generation.ARTIFACT_TEXT in results: + return results[generation.ARTIFACT_TEXT][0] + raise Exception(f"Failed to store video asset for project {proj.id}") - def update(self, title:str=None, file_id:str=None, file_uri:str=None): + def update_project(context: 'Context', proj: 'Project', title: str = None, file_id: str = None, file_uri: str = None) -> None: file = project.ProjectAsset( id=file_id, uri=file_uri, use=project.PROJECT_ASSET_USE_PROJECT, ) if file_id and file_uri else None - - self._context._proj_stub.Update(project.UpdateProjectRequest( - id=self.id, + + context._proj_stub.Update(project.UpdateProjectRequest( + id=proj.id, title=title, file=file )) + +class LocalFileBackend(StorageBackend): + _projects_root = None + + def __init__(self, context: 'Context', primary: bool = False, primary_fs: bool = True, projects_root = 'projects'): + super().__init__(context, primary, primary_fs = primary_fs) + LocalFileBackend._projects_root = projects_root + + @staticmethod + def create_project( + context: 'Context', + title: str, + access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE, + proj_id_to_use: str = None + ) -> 'Project': + proj_id = proj_id_to_use if proj_id_to_use else str(uuid.uuid4()) + proj_file_id = proj_id # Let's keep it the same as the proj_id for now + proj = {"id": proj_id, + "title": title, + "file": {"id": proj_file_id}} + output_path = os.path.join(LocalFileBackend._projects_root, proj_id, proj_file_id) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as file: + json.dump(proj, file) + return Project(context, proj) + + @staticmethod + def get_project(context: 'Context', id: str) -> 'Project': + input_path = os.path.join(LocalFileBackend._projects_root, id, id) + with open(input_path, "r") as file: + proj = json.load(file) + return Project(context, proj) + + @staticmethod + def delete_project(context: 'Context', id: str): + if not id: + raise ValueError("Delete project requires a project id") + project_dir_path = os.path.join(LocalFileBackend._projects_root, id) + shutil.rmtree(project_dir_path) + + @staticmethod + def list_projects(context: 'Context') -> List['Project']: + # This returns a listing of directories in the projects root. + proj_root = LocalFileBackend._projects_root + all_entries = os.listdir(proj_root) + directories = [entry for entry in all_entries if os.path.isdir(os.path.join(proj_root, entry))] + projects = [] + for proj_id in directories: + proj_path = LocalFileBackend.get_path_for_asset(proj_id, proj_id) + try: + with open(proj_path, "r") as file: + proj_json = json.load(file) + proj_data = {"id": proj_json["id"], + "title": proj_json["title"], + "file": {"id": proj_json["file"]["id"]}} + projects.append(Project(context, proj_data)) + except FileNotFoundError: + pass + return projects + + def get_project_settings(self, proj: 'Project', asset_id: str = None) -> dict: + input_path = self.get_path_for_asset(proj.id, "project_settings.json") + with open(input_path, "r") as file: + settings_json = json.load(file) + return settings_json + + def put_project_settings(self, context: 'Context', proj: 'Project', data: dict) -> str: + filename = "project_settings.json" + output_path = self.get_path_for_asset(proj.id, filename) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w", encoding="utf-8") as file: + json.dump(data, file) + return filename + + def get_image_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> Image.Image: + input_path = self.get_path_for_asset(proj.id, asset_id + '.png') + pil_image = Image.open(input_path) + return pil_image + + def put_image_asset(self, proj: 'Project', + image: Union[Image.Image, np.ndarray], + use: generation.AssetUse, + asset_id: str = None) -> str: + png = image_to_png_bytes(image) + if asset_id is not None: + filename = asset_id + else: + if not self.primary: + raise ValueError("If asset_id is None, then LocalFileBackend must be primary.") + filename = str(uuid.uuid4()) + output_path = self.get_path_for_asset(proj.id, filename) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path + '.png', "wb") as file: + file.write(png) + return filename + + def get_video_asset(self, proj: 'Project', asset_id: str, use: generation.AssetUse) -> bytes: + input_path = self.get_path_for_asset(proj.id, asset_id) + with open(input_path, 'rb') as file: + binary_data = file.read() + return binary_data + + def put_video_asset(self, proj: 'Project', video_path: str, asset_id: str = None) -> str: + if not os.path.isfile(video_path) or not video_path.endswith(".mp4"): + raise ValueError("Invalid video file path. Must be an existing .mp4 file.") + + if asset_id is not None: + filename = asset_id + else: + if not self.primary: + raise ValueError("If name is None, then LocalFileBackend must be primary.") + filename = str(uuid.uuid4()) + output_path = self.get_path_for_asset(proj.id, filename) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + shutil.copy(video_path, output_path) + return filename + + @staticmethod + def get_path_for_asset(project_id: str, filename: str): + path = os.path.join(LocalFileBackend._projects_root, project_id, filename) + return path + + def update_project(context: 'Context', proj: 'Project', title: str = None, file_id: str = None, file_uri: str = None): + proj_file_id = proj.file.id + proj = {"id": proj.id, + "title": title if title is not None else proj.title, + "file": {"id": file_id if file_id is not None else proj_file_id}} + output_path = os.path.join(LocalFileBackend._projects_root, proj.id, proj_file_id) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as file: + json.dump(proj, file) + return Project(context, proj) + + +class Project(): + _backends = None + _metadata_index = None + + def __init__(self, context: 'Context', proj: Union[project.Project, dict]): + ## __init__ could take backends: Optional[List[StorageBackend]] = None + # self._backends = backends if backends else [AssetServiceBackend(primary=True)] + self._context = context + + # proj should be project.Project or dict + # Currently, a supplied project.Project may contain additional properties that are ignored. + if isinstance(proj, dict): + self._project = project.Project() + self._project.id = proj["id"] + self._project.title = proj["title"] + self._project.file.id = proj["file"]["id"] + else: + self._project = proj + + def _primary_backend(self) -> Optional[StorageBackend]: + for backend in self.backends: + if backend.primary: + return backend + return None + + @property + def backends(self) -> str: + return Project._backends + + @property + def id(self) -> str: + return self._project.id + + @property + def file_id(self) -> str: + return self._project.file.id + + @property + def title(self) -> str: + return self._project.title + + @classmethod + def init_backends(cls, context: 'Context'): + cls._backends = [ + AssetServiceBackend(context=context, primary=True), + LocalFileBackend(context=context, primary=False)] + #cls._backends = [LocalFileBackend(context=context, primary=True)] + cls._metadata_index = cls.load_metadata_index() + + @staticmethod + def create( + context: 'Context', + title: str, + access: project.ProjectAccess = project.PROJECT_ACCESS_PRIVATE, + status: project.ProjectStatus = project.PROJECT_STATUS_ACTIVE + ) -> 'Project': + asset_id = None + for backend in Project._backends: + proj = backend.create_project(context, title, access, status, asset_id) + if isinstance(proj, dict): + proj_id = proj["id"] + proj_title = proj["title"] + else: + proj_id = proj.id + proj_title = proj.title + if backend.primary: + asset_id = proj_id + if backend.primary_fs: + filename = proj_id + proj_file_id = proj.file_id + mimetype = "application/json" + Project.add_asset_metadata(proj_id, asset_id, mimetype, filename, project_key="project_file_id") + return proj + + @classmethod + def get(cls, + context: 'Context', + id: str + ) -> 'Project': + for backend in cls._backends: + if backend.primary: + proj = backend.get_project(context, id) + return proj + raise Exception(f"Failed to list projects") + + def list_assets(self): + req = project.QueryAssetsRequest(id=self.id) + query_assets_response: project.QueryAssetsResponse = self._context._proj_stub.QueryAssets(req, + wait_for_ready=True) + return query_assets_response.assets + + def delete(self): + for backend in self.backends: + backend.delete_project(self._context, self.id) + Project.delete_project_metadata(self.id) + + @classmethod + def list_projects(cls, context: 'Context') -> List['Project']: + for backend in cls._backends: + if backend.primary: + results = backend.list_projects(context) + return results + raise Exception(f"Failed to list projects") + + def get_settings(self) -> dict: + for backend in self.backends: + if backend.primary: + result = backend.get_project_settings(self, self._metadata_index[self.id]["project_file_id"]) + return result + raise Exception(f"Failed to load project file for {self.id}") + + def save_settings(self, data: dict) -> str: + asset_id = None + filename = None + for backend in self.backends: + temp = backend.put_project_settings(self._context, self, data) + if backend.primary: + rsplit_res = temp.rsplit('/', 1) + asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] + if backend.primary_fs: + filename = temp + mimetype = "application/json" + Project.add_asset_metadata(self.id, asset_id, mimetype, filename, project_key="project_file_id") + return asset_id + + def get_image_asset(self, asset_id: str, use: generation.AssetUse = generation.ASSET_USE_PROJECT) -> Image.Image: + for backend in self.backends: + if backend.primary: + result = backend.get_image_asset(self, asset_id, use) + return result + raise Exception(f"Failed to load image asset {asset_id}") + + def put_image_asset( + self, + image: Union[Image.Image, np.ndarray], + use: generation.AssetUse = generation.ASSET_USE_PROJECT + ): + results = [] + asset_id = None + filename = None + for backend in self.backends: + result = backend.put_image_asset(self, image, use, asset_id=asset_id) + if backend.primary: + rsplit_res = result.rsplit('/', 1) + asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] + results.append(asset_id) + if backend.primary_fs: + filename = result + mimetype = "image/png" + Project.add_asset_metadata(self.id, asset_id, mimetype, filename) + return results + + def get_video_asset(self, asset_id: str, use: generation.AssetUse = generation.ASSET_USE_INPUT) -> bytes: + for backend in self.backends: + if backend.primary: + result = backend.get_video_asset(self, asset_id, use) + return result + raise Exception(f"Failed to load video asset {asset_id}") + + def put_video_asset(self, video_path: str) -> List[str]: + results = [] + filename = None + asset_id = None + for backend in self.backends: + result = backend.put_video_asset(self, video_path, asset_id=asset_id) + if backend.primary: + rsplit_res = result.rsplit('/', 1) + asset_id = rsplit_res[1] if len(rsplit_res) > 1 else rsplit_res[0] + results.append(asset_id) + if backend.primary_fs: + filename = result + mimetype = "video/mp4" + Project.add_asset_metadata(self.id, asset_id, mimetype, filename) + return results + + def update_project(self, title: str = None, file_id: str = None, file_uri: str = None): + for backend in self.backends: + result = backend.update_project(self._context, self, title, file_id, file_uri) if title: self._project.title = title if file_id: @@ -192,8 +630,44 @@ def update(self, title:str=None, file_id:str=None, file_uri:str=None): self._project.file.uri = file_uri + @staticmethod + def add_asset_metadata(project_id: str, asset_id: str, mime_type: str, filename: str, project_key: str = None) -> None: + # metadata_index = self.load_metadata_index() # I assume metadata is updated by each operation + if project_id not in Project._metadata_index: + Project._metadata_index[project_id] = {} + Project._metadata_index[project_id][asset_id] = { + "mime_type": mime_type + } + if filename is not None: + Project._metadata_index[project_id][asset_id]["filename"] = filename + if project_key is not None: + Project._metadata_index[project_id][project_key] = asset_id + Project.save_metadata_index() + + @staticmethod + def delete_project_metadata(project_id: str) -> None: + Project._metadata_index.pop(project_id, None) + + @classmethod + def save_metadata_index(cls, metadata_index: dict = None) -> None: + if metadata_index is None: + metadata_index = cls._metadata_index + index_file = f"metadata_index.json" + with open(index_file, "w") as f: + json.dump(metadata_index, f) + + @classmethod + def load_metadata_index(cls) -> dict: + index_file = "metadata_index.json" + if os.path.exists(index_file): + with open(index_file, "r") as f: + metadata_index = json.load(f) + return metadata_index + return {} + + class Context: - def __init__(self, host: str="", api_key: str=None, stub: generation_grpc.GenerationServiceStub=None): + def __init__(self, host: str = "", api_key: str = None, stub: generation_grpc.GenerationServiceStub = None): if not host and stub is None: raise Exception("Must provide either GRPC host or stub to Api") channel = open_channel(host, api_key) if host else None @@ -332,7 +806,7 @@ def inpaint( ) -> Dict[int, List[Union[np.ndarray, Any]]]: """ Apply inpainting to an image. - + :param image: Source image :param mask: Mask image with 0 for pixels to change and 255 for pixels to keep :param prompts: List of text prompts @@ -340,7 +814,7 @@ def inpaint( :param steps: Number of steps to run :param seed: Random seed :param samples: Number of samples to generate - :param cfg_scale: Classifier free guidance scale + :param cfg_scale: Classifier free guidance scale :param sampler: Sampler to use for the diffusion process :param init_strength: Strength of the initial image :param init_noise_scale: Scale of the initial noise @@ -396,7 +870,7 @@ def interpolate( elif ratios[0] == 1.0: return [images[1]] elif mode == generation.INTERPOLATE_LINEAR: - return [image_mix(images[0], images[1], ratios[0])] + return [image_mix(images[0], images[1], ratios[0])] p = [image_to_prompt(image) for image in images] request = generation.Request( @@ -584,7 +1058,7 @@ def transform_3d( id=op_id, request=rq_transform, on_status=[generation.OnStatus(action=[generation.STAGE_ACTION_RETURN])] - ) + ) ]) results = self._run_request(self._transform, chain_rq) @@ -661,6 +1135,8 @@ def _process_response(self, response) -> Dict[int, List[np.ndarray]]: results[artifact.type].append(artifact.tensor) elif artifact.type == generation.ARTIFACT_TEXT: results[artifact.type].append(artifact.text) + elif artifact.type == generation.ARTIFACT_VIDEO: + results[artifact.type].append(artifact.binary) return results def _run_request( @@ -688,13 +1164,13 @@ def _run_request( except ClassifierException as ce: if attempt == self._max_retries or not self._retry_obfuscation: raise ce - + for exceed in ce.classifier_result.exceeds: logger.warning(f"Received classifier obfuscation. Exceeded {exceed.name} threshold") for concept in exceed.concepts: if concept.HasField("threshold"): logger.warning(f" {concept.concept} ({concept.threshold})") - + if isinstance(request, generation.Request) and request.HasField("image"): self._adjust_request_for_retry(request, attempt) elif isinstance(request, generation.ChainRequest):