Skip to content

Commit

Permalink
Add example Video REST api notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
enzymezoo-code committed Dec 22, 2023
1 parent ff797c0 commit 4f5e4c5
Showing 1 changed file with 270 additions and 0 deletions.
270 changes: 270 additions & 0 deletions nbs/stable_video_v1_REST_API_alpha.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "AvTo6cVeF3ip"
},
"outputs": [],
"source": [
"#@title Install requirements\n",
"import base64\n",
"from io import BytesIO\n",
"import json\n",
"import mimetypes\n",
"import os\n",
"from PIL import Image\n",
"import requests\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "pVBZ1o3fH1HX"
},
"outputs": [],
"source": [
"#@title Define helper functions\n",
"\n",
"def image_to_bytes(\n",
" img: Image,\n",
" format=\"PNG\"\n",
"):\n",
" im_file = BytesIO()\n",
" img.save(im_file, format=format)\n",
" img_bytes = im_file.getvalue()\n",
" return img_bytes\n",
"\n",
"def get_image_format(\n",
" image_path : str\n",
"):\n",
" image_mime_type = mimetypes.guess_type(image_path)[0]\n",
" if image_mime_type is None:\n",
" raise ValueError(f\"Unknown image mime type for {image_path}\")\n",
" image_format = image_mime_type.split(\"/\")[-1].upper()\n",
" return image_format\n",
"\n",
"def resize_image_to_bytes(\n",
" image_path : str,\n",
" size : tuple[int,int] = None\n",
"):\n",
" # Resize image from file and convert to bytes\n",
" image = Image.open(image_path)\n",
" format = get_image_format(image_path)\n",
" if size is None:\n",
" width, height = get_closest_valid_dims(image)\n",
" else:\n",
" width, height = size\n",
" image = image.resize((width, height))\n",
" image_bytes = image_to_bytes(image, format=format)\n",
" return image_bytes\n",
"\n",
"def get_closest_valid_dims(\n",
" image : Image\n",
"):\n",
" # Finds the closest aspect ratio to the input image that are valid for SSC\n",
" # Valid dims are 1024x576, 768x768, 1024\n",
" w,h = image.size\n",
" aspect_ratio = w/h\n",
" portrait_aspect_ratio = 9/16\n",
" landscape_aspect_ratio = 16/9\n",
" portrait_aspect_ratio_midpoint = (portrait_aspect_ratio + 1)/2\n",
" landscape_aspect_ratio_midpoint = (landscape_aspect_ratio + 1)/2\n",
" if aspect_ratio < 1.0:\n",
" # portrait\n",
" width,height = (576,1024) if aspect_ratio < portrait_aspect_ratio_midpoint else (768,768)\n",
" else:\n",
" # landscape\n",
" width,height = (1024,576) if aspect_ratio < landscape_aspect_ratio_midpoint else (768,768)\n",
"\n",
" return width, height\n",
"\n",
"\n",
"def image_to_valid_bytes(\n",
" image_path : str\n",
" ):\n",
" image = Image.open(image_path)\n",
" width, height = get_closest_valid_dims(image)\n",
" format = get_image_format(image_path)\n",
" image = image.resize((width, height))\n",
" image_bytes = image_to_bytes(image, format=format)\n",
" return image_bytes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "dtw-2LAC7NgM"
},
"outputs": [],
"source": [
"#@title Set up credentials\n",
"\n",
"import getpass\n",
"# @markdown To get your API key visit https://platform.stability.ai/account/keys\n",
"STABILITY_KEY = getpass.getpass('Enter your API Key')\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"cellView": "form",
"id": "0lDpGa2jAmAs"
},
"outputs": [],
"source": [
"#@title Define input\n",
"\n",
"#@markdown - Drag and drop image to file folder on left\n",
"#@markdown - Right click it and choose Copy path\n",
"#@markdown - Paste that path into init_image field below\n",
"#@markdown <br><br>\n",
"\n",
"init_image = \"/content/img.jpg\" #@param {type:\"string\"}\n",
"seed = 0 #@param {type:\"integer\"}\n",
"cfg_scale = 2.5 #@param {type:\"number\"}\n",
"motion_bucket_id = 40 #@param {type:\"integer\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YNrNGa107s1M"
},
"outputs": [],
"source": [
"#@title Use REST API\n",
"\n",
"headers = {\n",
" \"Accept\": \"application/json\",\n",
" \"Authorization\": f\"Bearer {STABILITY_KEY}\"\n",
"}\n",
"host = f\"https://api.stability.ai/v2alpha/generation/image-to-video\"\n",
"\n",
"init_image_bytes = image_to_valid_bytes(init_image)\n",
"image_mime_type = mimetypes.guess_type(init_image)[0]\n",
"files = {\n",
" \"image\": (\"file\", init_image_bytes, image_mime_type),\n",
" }\n",
"params = {\n",
" \"seed\": seed,\n",
" \"cfg_scale\": cfg_scale,\n",
" \"motion_bucket_id\": motion_bucket_id\n",
" }\n",
"\n",
"for k,v in params.items():\n",
" if isinstance(v, bool):\n",
" v = str(v).lower()\n",
" files[k] = (None, str(v).encode('utf-8'))\n",
"\n",
"print(f\"Sending REST request to {host}...\")\n",
"\n",
"response = requests.post(\n",
" host,\n",
" headers=headers,\n",
" files=files,\n",
" )\n",
"\n",
"if not response.ok:\n",
" raise Exception(f\"HTTP {response.status_code}: {response.text}\")\n",
"\n",
"#\n",
"# Process async response\n",
"#\n",
"response_dict = json.loads(response.text)\n",
"request_id = response_dict.get(\"id\", None)\n",
"assert request_id is not None, \"Expected id in response\"\n",
"\n",
"# Loop until video result or timeout\n",
"timeout = int(os.getenv(\"WORKER_TIMEOUT\", 500))\n",
"start = time.time()\n",
"status_code = 202\n",
"while status_code == 202:\n",
"\n",
" response = requests.get(\n",
" f\"{host}/result/{request_id}\",\n",
" headers=headers,\n",
" )\n",
"\n",
" if not response.ok:\n",
" raise Exception(f\"HTTP {response.status_code}: {response.text}\")\n",
" status_code = response.status_code\n",
" time.sleep(2)\n",
" if time.time() - start > timeout:\n",
" raise Exception(f\"Timeout after {timeout} seconds\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "q6aXWOraCW0T"
},
"outputs": [],
"source": [
"#@title Decode response\n",
"json_data = response.json()\n",
"\n",
"video = base64.b64decode(json_data[\"video\"])\n",
"seed = json_data[\"seed\"]\n",
"finish_reason = json_data[\"finishReason\"]\n",
"\n",
"if finish_reason == 'CONTENT_FILTERED':\n",
" raise Warning(\"Video failed NSFW classifier\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PenTBs3oGZ2b"
},
"outputs": [],
"source": [
"#@title Save and display result\n",
"\n",
"filename = f\"video_{seed}.mp4\"\n",
"with open(filename, \"wb\") as f:\n",
" f.write(video)\n",
"print(f\"Saved video {filename}\")\n",
"\n",
"import IPython\n",
"mp4 = open(filename,'rb').read()\n",
"data_url = f\"data:video/mp4;base64,\" + base64.b64encode(mp4).decode()\n",
"IPython.display.display(IPython.display.HTML(f'<video controls loop><source src=\"{data_url}\" type=\"video/mp4\"></video>'))"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "L-0LuAXdba8V"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 4f5e4c5

Please sign in to comment.