Skip to content

Commit 4a0a35b

Browse files
Batch of small fixes and updates to Animation SDK (#235)
* avoid 'prefill' border when 'inpaint_border' is disabled * Provide user more feedback when stopping anim/post and while creating mp4 from frames. * Ensure animation dimensions are a multiple of 64 * Switch from Colab form UI to `getpass` for entering API keys. --------- Co-authored-by: Dmitrii Tochilkin <[email protected]>
1 parent 1e85abd commit 4a0a35b

File tree

4 files changed

+89
-83
lines changed

4 files changed

+89
-83
lines changed

nbs/animation.ipynb

+31-29
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@
1010
"# Animation SDK example"
1111
]
1212
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 1,
16+
"metadata": {
17+
"cellView": "form",
18+
"id": "eeA1mYLdxr2j"
19+
},
20+
"outputs": [],
21+
"source": [
22+
"#@title Install the Stability SDK\n",
23+
"%%capture captured --no-stderr\n",
24+
"%pip install stability-sdk[anim]"
25+
]
26+
},
1327
{
1428
"cell_type": "code",
1529
"execution_count": null,
@@ -32,45 +46,40 @@
3246
},
3347
{
3448
"cell_type": "code",
35-
"execution_count": 2,
49+
"execution_count": null,
3650
"metadata": {
3751
"cellView": "form",
3852
"id": "zj56t6tc3prF"
3953
},
4054
"outputs": [],
4155
"source": [
42-
"%%capture\n",
4356
"#@title Connect to the Stability API\n",
44-
"\n",
45-
"# install Stability Animation SDK for Python\n",
46-
"%pip install stability-sdk[anim]\n",
47-
"\n",
4857
"import datetime\n",
58+
"import getpass\n",
4959
"import json\n",
5060
"import os\n",
5161
"import panel as pn\n",
5262
"import param\n",
53-
"import shutil\n",
54-
"import sys\n",
5563
"\n",
5664
"from base64 import b64encode\n",
5765
"from IPython import display\n",
58-
"from pathlib import Path\n",
59-
"from PIL import Image\n",
6066
"from tqdm import tqdm\n",
6167
"from types import SimpleNamespace\n",
6268
"\n",
6369
"from stability_sdk.api import Context\n",
6470
"from stability_sdk.animation import AnimationArgs, Animator\n",
6571
"from stability_sdk.utils import create_video_from_frames\n",
6672
"\n",
67-
"\n",
68-
"# Enter your API key from dreamstudio.ai\n",
73+
"# @markdown To get your API key visit https://dreamstudio.ai/account\n",
6974
"STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n",
70-
"STABILITY_KEY = \"\" #@param {type:\"string\"}\n",
75+
"STABILITY_KEY = getpass.getpass('Enter your API Key')\n",
7176
"\n",
7277
"# Connect to Stability API\n",
73-
"api_context = Context(STABILITY_HOST, STABILITY_KEY)"
78+
"context = Context(STABILITY_HOST, STABILITY_KEY)\n",
79+
"\n",
80+
"# Test the connection\n",
81+
"context.get_user_info()\n",
82+
"print(\"Connection successful!\")"
7483
]
7584
},
7685
{
@@ -84,12 +93,10 @@
8493
"source": [
8594
"# @title Settings\n",
8695
"\n",
87-
"# @markdown Run this cell to reveal the settings UI. After entering values, move on to the next step.\n",
96+
"# @markdown Run this cell to reveal the settings UI grouped across several tabs. After entering values, move on to the next step.\n",
8897
"\n",
8998
"# @markdown To reset values to default, simply re-run this cell.\n",
9099
"\n",
91-
"# @markdown NB: Settings are grouped across several tabs.\n",
92-
"\n",
93100
"show_documentation = True # @param {type:'boolean'}\n",
94101
"\n",
95102
"# #@markdown ####**Resume:**\n",
@@ -158,6 +165,7 @@
158165
]
159166
},
160167
{
168+
"attachments": {},
161169
"cell_type": "markdown",
162170
"metadata": {
163171
"id": "_SudvbZG3prI"
@@ -168,7 +176,7 @@
168176
},
169177
{
170178
"cell_type": "code",
171-
"execution_count": 4,
179+
"execution_count": 7,
172180
"metadata": {
173181
"id": "FT9slDSw3prJ"
174182
},
@@ -228,7 +236,7 @@
228236
"print(f\"Saving animation frames to {out_dir}...\")\n",
229237
"\n",
230238
"animator = Animator(\n",
231-
" api_context=api_context,\n",
239+
" api_context=context,\n",
232240
" animation_prompts=animation_prompts,\n",
233241
" args=args,\n",
234242
" out_dir=out_dir, \n",
@@ -271,13 +279,12 @@
271279
],
272280
"metadata": {
273281
"colab": {
274-
"collapsed_sections": [],
275282
"provenance": []
276283
},
277284
"kernelspec": {
278-
"display_name": "client",
285+
"display_name": "venv",
279286
"language": "python",
280-
"name": "client"
287+
"name": "python3"
281288
},
282289
"language_info": {
283290
"codemirror_mode": {
@@ -291,13 +298,8 @@
291298
"pygments_lexer": "ipython3",
292299
"version": "3.9.5"
293300
},
294-
"orig_nbformat": 4,
295-
"vscode": {
296-
"interpreter": {
297-
"hash": "fb02550c4ef2b9a37ba5f7f381e893a74079cea154f791601856f87ae67cf67c"
298-
}
299-
}
301+
"orig_nbformat": 4
300302
},
301303
"nbformat": 4,
302-
"nbformat_minor": 4
304+
"nbformat_minor": 0
303305
}

nbs/animation_gradio.ipynb

+28-17
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,30 @@
1515
"execution_count": null,
1616
"metadata": {
1717
"cellView": "form",
18-
"colab": {
19-
"base_uri": "https://localhost:8080/"
20-
},
21-
"id": "LUMF8i8BTwYH",
22-
"outputId": "f61c635e-bc57-48ab-cc3d-5166286b158f"
18+
"id": "enjwV3WW1yxL"
19+
},
20+
"outputs": [],
21+
"source": [
22+
"#@title Install the Stability SDK\n",
23+
"%%capture captured --no-stderr\n",
24+
"%pip install stability-sdk[anim_ui]"
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": null,
30+
"metadata": {
31+
"cellView": "form",
32+
"id": "LUMF8i8BTwYH"
2333
},
2434
"outputs": [],
2535
"source": [
2636
"#@title Mount Google Drive\n",
27-
"import os\n",
2837
"try:\n",
2938
" from google.colab import drive\n",
3039
" drive.mount('/content/gdrive')\n",
3140
" outputs_path = \"/content/gdrive/MyDrive/AI/StableAnimation\"\n",
32-
" os.makedirs(outputs_path, exist_ok=True)\n",
41+
" !mkdir -p $outputs_path\n",
3342
"except:\n",
3443
" outputs_path = \".\"\n",
3544
"print(f\"Animations will be saved to {outputs_path}\")"
@@ -44,18 +53,21 @@
4453
},
4554
"outputs": [],
4655
"source": [
47-
"#@title Install Animation SDK and connect to the Stability API\n",
48-
"%pip install stability-sdk[anim_ui]\n",
49-
"\n",
56+
"#@title Connect to the Stability API\n",
57+
"import getpass\n",
5058
"from stability_sdk.api import Context\n",
5159
"from stability_sdk.animation_ui import create_ui\n",
5260
"\n",
53-
"# Enter your API key from dreamstudio.ai\n",
61+
"# @markdown To get your API key visit https://dreamstudio.ai/account\n",
5462
"STABILITY_HOST = \"grpc.stability.ai:443\" #@param {type:\"string\"}\n",
55-
"STABILITY_KEY = \"\" #@param {type:\"string\"}\n",
63+
"STABILITY_KEY = getpass.getpass('Enter your API Key')\n",
5664
"\n",
5765
"# Connect to Stability API\n",
58-
"api_context = Context(STABILITY_HOST, STABILITY_KEY)"
66+
"context = Context(STABILITY_HOST, STABILITY_KEY)\n",
67+
"\n",
68+
"# Test the connection\n",
69+
"context.get_user_info()\n",
70+
"print(\"Connection successful!\")"
5971
]
6072
},
6173
{
@@ -70,8 +82,7 @@
7082
"#@title Animation UI\n",
7183
"show_ui_in_notebook = True #@param {type:\"boolean\"}\n",
7284
"\n",
73-
"ui = create_ui(api_context, outputs_path)\n",
74-
"\n",
85+
"ui = create_ui(context, outputs_path)\n",
7586
"ui.queue(concurrency_count=2, max_size=2)\n",
7687
"ui.launch(show_api=False, debug=True, inline=show_ui_in_notebook, height=768, share=True, show_error=True)"
7788
]
@@ -106,5 +117,5 @@
106117
}
107118
},
108119
"nbformat": 4,
109-
"nbformat_minor": 4
110-
}
120+
"nbformat_minor": 0
121+
}

src/stability_sdk/animation.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import param
1111
import random
1212
import shutil
13-
import subprocess
1413

1514
from collections import OrderedDict, deque
1615
from dataclasses import dataclass, fields
@@ -805,10 +804,18 @@ def setup_animation(self, resume):
805804
# select image generation model
806805
self.api._generate.engine_id = args.custom_model if args.model == "custom" else args.model
807806

807+
# validate dimensions
808+
if args.width % 64 != 0 or args.height % 64 != 0:
809+
args.width, args.height = map(lambda x: x - x % 64, (args.width, args.height))
810+
logger.warning(f"Adjusted dimensions to {args.width}x{args.height} to be multiples of 64.")
811+
808812
# validate border settings
809813
if args.border == 'wrap' and args.animation_mode != '2D':
810814
args.border = 'reflect'
811815
logger.warning(f"Border 'wrap' is only supported in 2D mode, switching to '{args.border}'.")
816+
if args.border == 'prefill' and args.animation_mode in ('2D', '3D warp') and not args.inpaint_border:
817+
args.border = 'reflect'
818+
logger.warning(f"Border 'prefill' is only supported when 'inpaint_border' is enabled, switching to '{args.border}'.")
812819

813820
# validate clip guidance setting against selected model and sampler
814821
if args.clip_guidance.lower() != 'none':
@@ -817,11 +824,8 @@ def setup_animation(self, resume):
817824
logger.warning(f"CLIP guidance is not supported by {unsupported}, disabling guidance.")
818825
args.clip_guidance = 'None'
819826

820-
def curve_to_series(curve: str) -> List[float]:
821-
return curve_from_cn_string(curve)
822-
823827
# expand key frame strings to per frame series
824-
frame_args_dict = {f.name: curve_to_series(getattr(args, f.name)) for f in fields(FrameArgs)}
828+
frame_args_dict = {f.name: curve_from_cn_string(getattr(args, f.name)) for f in fields(FrameArgs)}
825829
self.frame_args = FrameArgs(**frame_args_dict)
826830

827831
# prepare sorted list of key frames
@@ -953,7 +957,6 @@ def transform_video(self, frame_idx) -> Optional[Image.Image]:
953957
mask = masks[0]
954958
self.prior_frames.extend(transformed_prior_frames)
955959
self.video_prev_frame = video_next_frame
956-
self.color_match_image = video_next_frame
957960
return mask
958961
return None
959962

0 commit comments

Comments
 (0)