-
Notifications
You must be signed in to change notification settings - Fork 339
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ff797c0
commit 4f5e4c5
Showing
1 changed file
with
270 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"id": "AvTo6cVeF3ip" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Install requirements\n", | ||
"import base64\n", | ||
"from io import BytesIO\n", | ||
"import json\n", | ||
"import mimetypes\n", | ||
"import os\n", | ||
"from PIL import Image\n", | ||
"import requests\n", | ||
"import time" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": { | ||
"id": "pVBZ1o3fH1HX" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Define helper functions\n", | ||
"\n", | ||
"def image_to_bytes(\n", | ||
" img: Image,\n", | ||
" format=\"PNG\"\n", | ||
"):\n", | ||
" im_file = BytesIO()\n", | ||
" img.save(im_file, format=format)\n", | ||
" img_bytes = im_file.getvalue()\n", | ||
" return img_bytes\n", | ||
"\n", | ||
"def get_image_format(\n", | ||
" image_path : str\n", | ||
"):\n", | ||
" image_mime_type = mimetypes.guess_type(image_path)[0]\n", | ||
" if image_mime_type is None:\n", | ||
" raise ValueError(f\"Unknown image mime type for {image_path}\")\n", | ||
" image_format = image_mime_type.split(\"/\")[-1].upper()\n", | ||
" return image_format\n", | ||
"\n", | ||
"def resize_image_to_bytes(\n", | ||
" image_path : str,\n", | ||
" size : tuple[int,int] = None\n", | ||
"):\n", | ||
" # Resize image from file and convert to bytes\n", | ||
" image = Image.open(image_path)\n", | ||
" format = get_image_format(image_path)\n", | ||
" if size is None:\n", | ||
" width, height = get_closest_valid_dims(image)\n", | ||
" else:\n", | ||
" width, height = size\n", | ||
" image = image.resize((width, height))\n", | ||
" image_bytes = image_to_bytes(image, format=format)\n", | ||
" return image_bytes\n", | ||
"\n", | ||
"def get_closest_valid_dims(\n", | ||
" image : Image\n", | ||
"):\n", | ||
" # Finds the closest aspect ratio to the input image that are valid for SSC\n", | ||
" # Valid dims are 1024x576, 768x768, 1024\n", | ||
" w,h = image.size\n", | ||
" aspect_ratio = w/h\n", | ||
" portrait_aspect_ratio = 9/16\n", | ||
" landscape_aspect_ratio = 16/9\n", | ||
" portrait_aspect_ratio_midpoint = (portrait_aspect_ratio + 1)/2\n", | ||
" landscape_aspect_ratio_midpoint = (landscape_aspect_ratio + 1)/2\n", | ||
" if aspect_ratio < 1.0:\n", | ||
" # portrait\n", | ||
" width,height = (576,1024) if aspect_ratio < portrait_aspect_ratio_midpoint else (768,768)\n", | ||
" else:\n", | ||
" # landscape\n", | ||
" width,height = (1024,576) if aspect_ratio < landscape_aspect_ratio_midpoint else (768,768)\n", | ||
"\n", | ||
" return width, height\n", | ||
"\n", | ||
"\n", | ||
"def image_to_valid_bytes(\n", | ||
" image_path : str\n", | ||
" ):\n", | ||
" image = Image.open(image_path)\n", | ||
" width, height = get_closest_valid_dims(image)\n", | ||
" format = get_image_format(image_path)\n", | ||
" image = image.resize((width, height))\n", | ||
" image_bytes = image_to_bytes(image, format=format)\n", | ||
" return image_bytes" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"cellView": "form", | ||
"id": "dtw-2LAC7NgM" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Set up credentials\n", | ||
"\n", | ||
"import getpass\n", | ||
"# @markdown To get your API key visit https://platform.stability.ai/account/keys\n", | ||
"STABILITY_KEY = getpass.getpass('Enter your API Key')\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"metadata": { | ||
"cellView": "form", | ||
"id": "0lDpGa2jAmAs" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Define input\n", | ||
"\n", | ||
"#@markdown - Drag and drop image to file folder on left\n", | ||
"#@markdown - Right click it and choose Copy path\n", | ||
"#@markdown - Paste that path into init_image field below\n", | ||
"#@markdown <br><br>\n", | ||
"\n", | ||
"init_image = \"/content/img.jpg\" #@param {type:\"string\"}\n", | ||
"seed = 0 #@param {type:\"integer\"}\n", | ||
"cfg_scale = 2.5 #@param {type:\"number\"}\n", | ||
"motion_bucket_id = 40 #@param {type:\"integer\"}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "YNrNGa107s1M" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Use REST API\n", | ||
"\n", | ||
"headers = {\n", | ||
" \"Accept\": \"application/json\",\n", | ||
" \"Authorization\": f\"Bearer {STABILITY_KEY}\"\n", | ||
"}\n", | ||
"host = f\"https://api.stability.ai/v2alpha/generation/image-to-video\"\n", | ||
"\n", | ||
"init_image_bytes = image_to_valid_bytes(init_image)\n", | ||
"image_mime_type = mimetypes.guess_type(init_image)[0]\n", | ||
"files = {\n", | ||
" \"image\": (\"file\", init_image_bytes, image_mime_type),\n", | ||
" }\n", | ||
"params = {\n", | ||
" \"seed\": seed,\n", | ||
" \"cfg_scale\": cfg_scale,\n", | ||
" \"motion_bucket_id\": motion_bucket_id\n", | ||
" }\n", | ||
"\n", | ||
"for k,v in params.items():\n", | ||
" if isinstance(v, bool):\n", | ||
" v = str(v).lower()\n", | ||
" files[k] = (None, str(v).encode('utf-8'))\n", | ||
"\n", | ||
"print(f\"Sending REST request to {host}...\")\n", | ||
"\n", | ||
"response = requests.post(\n", | ||
" host,\n", | ||
" headers=headers,\n", | ||
" files=files,\n", | ||
" )\n", | ||
"\n", | ||
"if not response.ok:\n", | ||
" raise Exception(f\"HTTP {response.status_code}: {response.text}\")\n", | ||
"\n", | ||
"#\n", | ||
"# Process async response\n", | ||
"#\n", | ||
"response_dict = json.loads(response.text)\n", | ||
"request_id = response_dict.get(\"id\", None)\n", | ||
"assert request_id is not None, \"Expected id in response\"\n", | ||
"\n", | ||
"# Loop until video result or timeout\n", | ||
"timeout = int(os.getenv(\"WORKER_TIMEOUT\", 500))\n", | ||
"start = time.time()\n", | ||
"status_code = 202\n", | ||
"while status_code == 202:\n", | ||
"\n", | ||
" response = requests.get(\n", | ||
" f\"{host}/result/{request_id}\",\n", | ||
" headers=headers,\n", | ||
" )\n", | ||
"\n", | ||
" if not response.ok:\n", | ||
" raise Exception(f\"HTTP {response.status_code}: {response.text}\")\n", | ||
" status_code = response.status_code\n", | ||
" time.sleep(2)\n", | ||
" if time.time() - start > timeout:\n", | ||
" raise Exception(f\"Timeout after {timeout} seconds\")\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 33, | ||
"metadata": { | ||
"id": "q6aXWOraCW0T" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Decode response\n", | ||
"json_data = response.json()\n", | ||
"\n", | ||
"video = base64.b64decode(json_data[\"video\"])\n", | ||
"seed = json_data[\"seed\"]\n", | ||
"finish_reason = json_data[\"finishReason\"]\n", | ||
"\n", | ||
"if finish_reason == 'CONTENT_FILTERED':\n", | ||
" raise Warning(\"Video failed NSFW classifier\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "PenTBs3oGZ2b" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Save and display result\n", | ||
"\n", | ||
"filename = f\"video_{seed}.mp4\"\n", | ||
"with open(filename, \"wb\") as f:\n", | ||
" f.write(video)\n", | ||
"print(f\"Saved video {filename}\")\n", | ||
"\n", | ||
"import IPython\n", | ||
"mp4 = open(filename,'rb').read()\n", | ||
"data_url = f\"data:video/mp4;base64,\" + base64.b64encode(mp4).decode()\n", | ||
"IPython.display.display(IPython.display.HTML(f'<video controls loop><source src=\"{data_url}\" type=\"video/mp4\"></video>'))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": { | ||
"id": "L-0LuAXdba8V" | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |