Skip to content

Commit 931d7a3

Browse files
authored
Add inference helpers & tests (#57)
* Add inference helpers & tests * Support testing with hatch * fixes to hatch script * add inference test action * change workflow trigger * widen trigger to test * revert changes to workflow triggers * Install local python in action * Trigger on push again * fix python version * add CODEOWNERS and change triggers * Report tests results * update action versions * format * Fix typo and add refiner helper * use a shared path loaded from a secret for checkpoints source * typo fix * Use device from input and remove duplicated code * PR feedback * fix call to load_model_from_config * Move model to gpu * Refactor helpers * cleanup * test refiner, prep for 1.0, align with metadata * fix paths on second load * deduplicate streamlit code * filenames * fixes * add pydantic to requirements * fix usage of `msg` in demo script * remove double text * run black * fix streamlit sampling when returning latents * extract function for streamlit output * another fix for streamlit outputs * fix img2img in streamlit * Make fp16 optional and fix device param * PR feedback * fix dict cast for dataclass * run black, update ci script * cache pip dependencies on hosted runners, remove extra runs * install package in ci env * fix cache path * PR cleanup * one more cleanup * don't cache, it filled up
1 parent e596332 commit 931d7a3

File tree

11 files changed

+889
-346
lines changed

11 files changed

+889
-346
lines changed

.github/workflows/CODEOWNERS

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.github @Stability-AI/infrastructure

.github/workflows/black.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name: Run black
2-
on: [push, pull_request]
2+
on: [pull_request]
33

44
jobs:
55
lint:

.github/workflows/test-build.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ name: Build package
22

33
on:
44
push:
5+
branches: [ main ]
56
pull_request:
67

78
jobs:

.github/workflows/test-inference.yml

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
name: Test inference
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches:
7+
- main
8+
9+
jobs:
10+
test:
11+
name: "Test inference"
12+
# This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
13+
if: github.repository == 'stability-ai/generative-models'
14+
runs-on: [self-hosted, slurm, g40]
15+
steps:
16+
- uses: actions/checkout@v3
17+
- name: "Symlink checkpoints"
18+
run: ln -s ${{secrets.SGM_CHECKPOINTS_PATH}} checkpoints
19+
- name: "Setup python"
20+
uses: actions/setup-python@v4
21+
with:
22+
python-version: "3.10"
23+
- name: "Install Hatch"
24+
run: pip install hatch
25+
- name: "Run inference tests"
26+
run: hatch run ci:test-inference --junit-xml test-results.xml
27+
- name: Surface failing tests
28+
if: always()
29+
uses: pmeier/pytest-results-action@main
30+
with:
31+
path: test-results.xml
32+
summary: true
33+
display-options: fEX
34+
fail-on-empty: true

pyproject.toml

+14
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,17 @@ include = [
3232

3333
[tool.hatch.build.targets.wheel.force-include]
3434
"./configs" = "sgm/configs"
35+
36+
[tool.hatch.envs.ci]
37+
skip-install = false
38+
39+
dependencies = [
40+
"pytest"
41+
]
42+
43+
[tool.hatch.envs.ci.scripts]
44+
test-inference = [
45+
"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
46+
"pip install -r requirements/pt2.txt",
47+
"pytest -v tests/inference/test_inference.py {args}",
48+
]

pytest.ini

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
2+
markers =
3+
inference: mark as inference test (deselect with '-m "not inference"')

scripts/demo/sampling.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
import numpy as np
12
from pytorch_lightning import seed_everything
23

34
from scripts.demo.streamlit_helpers import *
45
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
6+
from sgm.inference.helpers import (
7+
do_img2img,
8+
do_sample,
9+
get_unique_embedder_keys_from_conditioner,
10+
perform_save_locally,
11+
)
512

613
SAVE_PATH = "outputs/demo/txt2img/"
714

@@ -131,6 +138,8 @@ def run_txt2img(
131138

132139
if st.button("Sample"):
133140
st.write(f"**Model I:** {version}")
141+
outputs = st.empty()
142+
st.text("Sampling")
134143
out = do_sample(
135144
state["model"],
136145
sampler,
@@ -144,6 +153,8 @@ def run_txt2img(
144153
return_latents=return_latents,
145154
filter=filter,
146155
)
156+
show_samples(out, outputs)
157+
147158
return out
148159

149160

@@ -175,6 +186,8 @@ def run_img2img(
175186
num_samples = num_rows * num_cols
176187

177188
if st.button("Sample"):
189+
outputs = st.empty()
190+
st.text("Sampling")
178191
out = do_img2img(
179192
repeat(img, "1 ... -> n ...", n=num_samples),
180193
state["model"],
@@ -185,6 +198,7 @@ def run_img2img(
185198
return_latents=return_latents,
186199
filter=filter,
187200
)
201+
show_samples(out, outputs)
188202
return out
189203

190204

@@ -249,8 +263,6 @@ def apply_refiner(
249263
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
250264

251265
state = init_st(version_dict)
252-
if state["msg"]:
253-
st.info(state["msg"])
254266
model = state["model"]
255267

256268
is_legacy = version_dict["is_legacy"]
@@ -275,7 +287,6 @@ def apply_refiner(
275287

276288
version_dict2 = VERSION2SPECS[version2]
277289
state2 = init_st(version_dict2)
278-
st.info(state2["msg"])
279290

280291
stage2strength = st.number_input(
281292
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
@@ -315,6 +326,7 @@ def apply_refiner(
315326
samples_z = None
316327

317328
if add_pipeline and samples_z is not None:
329+
outputs = st.empty()
318330
st.write("**Running Refinement Stage**")
319331
samples = apply_refiner(
320332
samples_z,
@@ -325,6 +337,7 @@ def apply_refiner(
325337
negative_prompt=negative_prompt if is_legacy else "",
326338
filter=filter,
327339
)
340+
show_samples(samples, outputs)
328341

329342
if save_locally and samples is not None:
330343
perform_save_locally(save_path, samples)

0 commit comments

Comments (0)