Add inference helpers & tests (#57)
* Add inference helpers & tests

* Support testing with hatch

* fixes to hatch script

* add inference test action

* change workflow trigger

* widen trigger to test

* revert changes to workflow triggers

* Install local python in action

* Trigger on push again

* fix python version

* add CODEOWNERS and change triggers

* Report tests results

* update action versions

* format

* Fix typo and add refiner helper

* use a shared path loaded from a secret for checkpoints source

* typo fix

* Use device from input and remove duplicated code

* PR feedback

* fix call to load_model_from_config

* Move model to gpu

* Refactor helpers

* cleanup

* test refiner, prep for 1.0, align with metadata

* fix paths on second load

* deduplicate streamlit code

* filenames

* fixes

* add pydantic to requirements

* fix usage of `msg` in demo script

* remove double text

* run black

* fix streamlit sampling when returning latents

* extract function for streamlit output

* another fix for streamlit outputs

* fix img2img in streamlit

* Make fp16 optional and fix device param

* PR feedback

* fix dict cast for dataclass

* run black, update ci script

* cache pip dependencies on hosted runners, remove extra runs

* install package in ci env

* fix cache path

* PR cleanup

* one more cleanup

* don't cache, it filled up
palp authored Jul 26, 2023
1 parent e596332 commit 931d7a3
Showing 11 changed files with 889 additions and 346 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CODEOWNERS
@@ -0,0 +1 @@
+.github @Stability-AI/infrastructure
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
@@ -1,5 +1,5 @@
 name: Run black
-on: [push, pull_request]
+on: [pull_request]

 jobs:
   lint:
1 change: 1 addition & 0 deletions .github/workflows/test-build.yaml
@@ -2,6 +2,7 @@ name: Build package

 on:
   push:
+    branches: [ main ]
   pull_request:

 jobs:
34 changes: 34 additions & 0 deletions .github/workflows/test-inference.yml
@@ -0,0 +1,34 @@
+name: Test inference
+
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+
+jobs:
+  test:
+    name: "Test inference"
+    # This action is designed only to run on the Stability research cluster at this time, so many assumptions are made about the environment
+    if: github.repository == 'stability-ai/generative-models'
+    runs-on: [self-hosted, slurm, g40]
+    steps:
+      - uses: actions/checkout@v3
+      - name: "Symlink checkpoints"
+        run: ln -s ${{secrets.SGM_CHECKPOINTS_PATH}} checkpoints
+      - name: "Setup python"
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: "Install Hatch"
+        run: pip install hatch
+      - name: "Run inference tests"
+        run: hatch run ci:test-inference --junit-xml test-results.xml
+      - name: Surface failing tests
+        if: always()
+        uses: pmeier/pytest-results-action@main
+        with:
+          path: test-results.xml
+          summary: true
+          display-options: fEX
+          fail-on-empty: true
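Note that the `--junit-xml test-results.xml` flag in the test step is forwarded to pytest through the `{args}` placeholder of the Hatch script defined in pyproject.toml below; that report is what pmeier/pytest-results-action reads to surface per-test results.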
14 changes: 14 additions & 0 deletions pyproject.toml
@@ -32,3 +32,17 @@ include = [

 [tool.hatch.build.targets.wheel.force-include]
 "./configs" = "sgm/configs"
+
+[tool.hatch.envs.ci]
+skip-install = false
+
+dependencies = [
+    "pytest"
+]
+
+[tool.hatch.envs.ci.scripts]
+test-inference = [
+    "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
+    "pip install -r requirements/pt2.txt",
+    "pytest -v tests/inference/test_inference.py {args}",
+]
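With this env defined, `pip install hatch` followed by `hatch run ci:test-inference` should reproduce the CI run locally, assuming a CUDA 11.8 machine and a `checkpoints/` symlink like the one the workflow creates. Because `skip-install = false`, Hatch also installs the package itself into the `ci` environment before the script runs.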
3 changes: 3 additions & 0 deletions pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+markers =
+    inference: mark as inference test (deselect with '-m "not inference"')
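A test opts into this marker with pytest's standard decorator. A minimal sketch (the test name and body are illustrative, not taken from this commit's tests/inference/test_inference.py):

import pytest

@pytest.mark.inference
def test_sampling_smoke():
    # A real test would load a checkpoint and run a sampler;
    # this stub only demonstrates the marker wiring.
    assert True

Registering the marker here keeps pytest from warning about unknown marks and lets CPU-only machines skip the GPU-bound suite with `pytest -m "not inference"`.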
19 changes: 16 additions & 3 deletions scripts/demo/sampling.py
@@ -1,7 +1,14 @@
+import numpy as np
 from pytorch_lightning import seed_everything

 from scripts.demo.streamlit_helpers import *
 from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
+from sgm.inference.helpers import (
+    do_img2img,
+    do_sample,
+    get_unique_embedder_keys_from_conditioner,
+    perform_save_locally,
+)

 SAVE_PATH = "outputs/demo/txt2img/"

@@ -131,6 +138,8 @@ def run_txt2img(

     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
+        outputs = st.empty()
+        st.text("Sampling")
         out = do_sample(
             state["model"],
             sampler,
@@ -144,6 +153,8 @@
             return_latents=return_latents,
             filter=filter,
         )
+        show_samples(out, outputs)
+
         return out

@@ -175,6 +186,8 @@ def run_img2img(
     num_samples = num_rows * num_cols

     if st.button("Sample"):
+        outputs = st.empty()
+        st.text("Sampling")
         out = do_img2img(
             repeat(img, "1 ... -> n ...", n=num_samples),
             state["model"],
@@ -185,6 +198,7 @@
             return_latents=return_latents,
             filter=filter,
         )
+        show_samples(out, outputs)
         return out

@@ -249,8 +263,6 @@ def apply_refiner(
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))

     state = init_st(version_dict)
-    if state["msg"]:
-        st.info(state["msg"])
     model = state["model"]

     is_legacy = version_dict["is_legacy"]
@@ -275,7 +287,6 @@

         version_dict2 = VERSION2SPECS[version2]
         state2 = init_st(version_dict2)
-        st.info(state2["msg"])

         stage2strength = st.number_input(
             "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
@@ -315,6 +326,7 @@
         samples_z = None

     if add_pipeline and samples_z is not None:
+        outputs = st.empty()
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
             samples_z,
@@ -325,6 +337,7 @@
             negative_prompt=negative_prompt if is_legacy else "",
             filter=filter,
         )
+        show_samples(samples, outputs)

     if save_locally and samples is not None:
         perform_save_locally(save_path, samples)
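The show_samples calls added above stream results into the st.empty() placeholder created before sampling starts. Its definition belongs to the deduplicated streamlit helpers changed elsewhere in this commit and is not shown in this excerpt; what follows is a hypothetical sketch of the pattern, assuming samples arrive as a float tensor in [0, 1] with shape (n, c, h, w) and may be a (samples, latents) tuple when return_latents is set:

from einops import rearrange
import streamlit as st

def show_samples(samples, outputs):
    # Hypothetical sketch; the real helper lives in
    # scripts/demo/streamlit_helpers.py and may also embed
    # watermarks or arrange samples in a grid.
    if isinstance(samples, tuple):
        # return_latents=True yields (images, latents); show only the images.
        samples, _ = samples
    # Stack the batch vertically into one image and hand it to the placeholder.
    grid = rearrange(samples, "n c h w -> (n h) w c")
    outputs.image(grid.cpu().numpy(), clamp=True)

Treat this as illustrative only; the exact rendering in the commit may differ.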
