diff --git a/nbs/stable_video_v1_REST_API_alpha.ipynb b/nbs/stable_video_v1_REST_API_alpha.ipynb new file mode 100644 index 0000000..1808603 --- /dev/null +++ b/nbs/stable_video_v1_REST_API_alpha.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "AvTo6cVeF3ip" + }, + "outputs": [], + "source": [ + "#@title Install requirements\n", + "import base64\n", + "from io import BytesIO\n", + "import json\n", + "import mimetypes\n", + "import os\n", + "from PIL import Image\n", + "import requests\n", + "import time" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "pVBZ1o3fH1HX" + }, + "outputs": [], + "source": [ + "#@title Define helper functions\n", + "\n", + "def image_to_bytes(\n", + " img: Image,\n", + " format=\"PNG\"\n", + "):\n", + " im_file = BytesIO()\n", + " img.save(im_file, format=format)\n", + " img_bytes = im_file.getvalue()\n", + " return img_bytes\n", + "\n", + "def get_image_format(\n", + " image_path : str\n", + "):\n", + " image_mime_type = mimetypes.guess_type(image_path)[0]\n", + " if image_mime_type is None:\n", + " raise ValueError(f\"Unknown image mime type for {image_path}\")\n", + " image_format = image_mime_type.split(\"/\")[-1].upper()\n", + " return image_format\n", + "\n", + "def resize_image_to_bytes(\n", + " image_path : str,\n", + " size : tuple[int,int] = None\n", + "):\n", + " # Resize image from file and convert to bytes\n", + " image = Image.open(image_path)\n", + " format = get_image_format(image_path)\n", + " if size is None:\n", + " width, height = get_closest_valid_dims(image)\n", + " else:\n", + " width, height = size\n", + " image = image.resize((width, height))\n", + " image_bytes = image_to_bytes(image, format=format)\n", + " return image_bytes\n", + "\n", + "def get_closest_valid_dims(\n", + " image : Image\n", + "):\n", + " # Finds the closest aspect ratio to the input image that are valid for SSC\n", + " # Valid dims are 1024x576, 768x768, 1024\n", + " w,h = image.size\n", + " aspect_ratio = w/h\n", + " portrait_aspect_ratio = 9/16\n", + " landscape_aspect_ratio = 16/9\n", + " portrait_aspect_ratio_midpoint = (portrait_aspect_ratio + 1)/2\n", + " landscape_aspect_ratio_midpoint = (landscape_aspect_ratio + 1)/2\n", + " if aspect_ratio < 1.0:\n", + " # portrait\n", + " width,height = (576,1024) if aspect_ratio < portrait_aspect_ratio_midpoint else (768,768)\n", + " else:\n", + " # landscape\n", + " width,height = (1024,576) if aspect_ratio < landscape_aspect_ratio_midpoint else (768,768)\n", + "\n", + " return width, height\n", + "\n", + "\n", + "def image_to_valid_bytes(\n", + " image_path : str\n", + " ):\n", + " image = Image.open(image_path)\n", + " width, height = get_closest_valid_dims(image)\n", + " format = get_image_format(image_path)\n", + " image = image.resize((width, height))\n", + " image_bytes = image_to_bytes(image, format=format)\n", + " return image_bytes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "dtw-2LAC7NgM" + }, + "outputs": [], + "source": [ + "#@title Set up credentials\n", + "\n", + "import getpass\n", + "# @markdown To get your API key visit https://platform.stability.ai/account/keys\n", + "STABILITY_KEY = getpass.getpass('Enter your API Key')\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "cellView": "form", + "id": "0lDpGa2jAmAs" + }, + "outputs": [], + "source": [ + "#@title Define input\n", + "\n", + "#@markdown - Drag and drop image to file folder on left\n", + "#@markdown - Right click it and choose Copy path\n", + "#@markdown - Paste that path into init_image field below\n", + "#@markdown

\n", + "\n", + "init_image = \"/content/img.jpg\" #@param {type:\"string\"}\n", + "seed = 0 #@param {type:\"integer\"}\n", + "cfg_scale = 2.5 #@param {type:\"number\"}\n", + "motion_bucket_id = 40 #@param {type:\"integer\"}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YNrNGa107s1M" + }, + "outputs": [], + "source": [ + "#@title Use REST API\n", + "\n", + "headers = {\n", + " \"Accept\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {STABILITY_KEY}\"\n", + "}\n", + "host = f\"https://api.stability.ai/v2alpha/generation/image-to-video\"\n", + "\n", + "init_image_bytes = image_to_valid_bytes(init_image)\n", + "image_mime_type = mimetypes.guess_type(init_image)[0]\n", + "files = {\n", + " \"image\": (\"file\", init_image_bytes, image_mime_type),\n", + " }\n", + "params = {\n", + " \"seed\": seed,\n", + " \"cfg_scale\": cfg_scale,\n", + " \"motion_bucket_id\": motion_bucket_id\n", + " }\n", + "\n", + "for k,v in params.items():\n", + " if isinstance(v, bool):\n", + " v = str(v).lower()\n", + " files[k] = (None, str(v).encode('utf-8'))\n", + "\n", + "print(f\"Sending REST request to {host}...\")\n", + "\n", + "response = requests.post(\n", + " host,\n", + " headers=headers,\n", + " files=files,\n", + " )\n", + "\n", + "if not response.ok:\n", + " raise Exception(f\"HTTP {response.status_code}: {response.text}\")\n", + "\n", + "#\n", + "# Process async response\n", + "#\n", + "response_dict = json.loads(response.text)\n", + "request_id = response_dict.get(\"id\", None)\n", + "assert request_id is not None, \"Expected id in response\"\n", + "\n", + "# Loop until video result or timeout\n", + "timeout = int(os.getenv(\"WORKER_TIMEOUT\", 500))\n", + "start = time.time()\n", + "status_code = 202\n", + "while status_code == 202:\n", + "\n", + " response = requests.get(\n", + " f\"{host}/result/{request_id}\",\n", + " headers=headers,\n", + " )\n", + "\n", + " if not response.ok:\n", + " raise Exception(f\"HTTP {response.status_code}: {response.text}\")\n", + " status_code = response.status_code\n", + " time.sleep(2)\n", + " if time.time() - start > timeout:\n", + " raise Exception(f\"Timeout after {timeout} seconds\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "id": "q6aXWOraCW0T" + }, + "outputs": [], + "source": [ + "#@title Decode response\n", + "json_data = response.json()\n", + "\n", + "video = base64.b64decode(json_data[\"video\"])\n", + "seed = json_data[\"seed\"]\n", + "finish_reason = json_data[\"finishReason\"]\n", + "\n", + "if finish_reason == 'CONTENT_FILTERED':\n", + " raise Warning(\"Video failed NSFW classifier\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PenTBs3oGZ2b" + }, + "outputs": [], + "source": [ + "#@title Save and display result\n", + "\n", + "filename = f\"video_{seed}.mp4\"\n", + "with open(filename, \"wb\") as f:\n", + " f.write(video)\n", + "print(f\"Saved video {filename}\")\n", + "\n", + "import IPython\n", + "mp4 = open(filename,'rb').read()\n", + "data_url = f\"data:video/mp4;base64,\" + base64.b64encode(mp4).decode()\n", + "IPython.display.display(IPython.display.HTML(f''))" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "L-0LuAXdba8V" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}