Skip to content

Commit ac60da8

Browse files
Pre release changes for production (Stability-AI#59)
* clean requirements * rm taming deps * isort, black * mv lipips, license * clean vq, fix path * fix loss path, gitignore * tested requirements pt13 * fix numpy req for python3.8, add tests * fix name * fix dep scipy 3.8 pt2 * add black test formatter
1 parent 99ece15 commit ac60da8

31 files changed

+641
-127
lines changed

.github/workflows/black.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
name: Run black
2+
on: [push, pull_request]
3+
4+
jobs:
5+
lint:
6+
runs-on: ubuntu-latest
7+
steps:
8+
- uses: actions/checkout@v3
9+
- name: Install venv
10+
run: |
11+
sudo apt-get -y install python3.10-venv
12+
- uses: psf/black@stable
13+
with:
14+
options: "--check --verbose -l88"
15+
src: "./sgm ./scripts ./main.py"

.github/workflows/test-build.yaml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: Build package
2+
3+
on:
4+
push:
5+
pull_request:
6+
7+
jobs:
8+
build:
9+
name: Build
10+
runs-on: ubuntu-latest
11+
strategy:
12+
fail-fast: false
13+
matrix:
14+
python-version: ["3.8", "3.10"]
15+
requirements-file: ["pt2", "pt13"]
16+
steps:
17+
- uses: actions/checkout@v2
18+
- name: Set up Python ${{ matrix.python-version }}
19+
uses: actions/setup-python@v2
20+
with:
21+
python-version: ${{ matrix.python-version }}
22+
- name: Install dependencies
23+
run: |
24+
python -m pip install --upgrade pip
25+
pip install -r requirements/${{ matrix.requirements-file }}.txt
26+
pip install .

.gitignore

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
# extensions
12
*.egg-info
23
*.py[cod]
4+
5+
# envs
36
.pt13
47
.pt2
5-
.pt2_2
8+
9+
# directories
610
/checkpoints
711
/dist
812
/outputs
9-
build
13+
/build
14+
/src

README.md

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,9 @@ This is assuming you have navigated to the `generative-models` root after clonin
5959

6060
```shell
6161
# install required packages from pypi
62-
python3 -m venv .pt1
63-
source .pt1/bin/activate
64-
pip3 install wheel
65-
pip3 install -r requirements_pt13.txt
62+
python3 -m venv .pt13
63+
source .pt13/bin/activate
64+
pip3 install -r requirements/pt13.txt
6665
```
6766

6867
**PyTorch 2.0**
@@ -72,8 +71,20 @@ pip3 install -r requirements_pt13.txt
7271
# install required packages from pypi
7372
python3 -m venv .pt2
7473
source .pt2/bin/activate
75-
pip3 install wheel
76-
pip3 install -r requirements_pt2.txt
74+
pip3 install -r requirements/pt2.txt
75+
```
76+
77+
78+
#### 3. Install `sgm`
79+
80+
```shell
81+
pip3 install .
82+
```
83+
84+
#### 4. Install `sdata` for training
85+
86+
```shell
87+
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
7788
```
7889

7990
## Packaging

main.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,18 @@
1212
import torch
1313
import torchvision
1414
import wandb
15-
from PIL import Image
1615
from matplotlib import pyplot as plt
1716
from natsort import natsorted
1817
from omegaconf import OmegaConf
1918
from packaging import version
19+
from PIL import Image
2020
from pytorch_lightning import seed_everything
2121
from pytorch_lightning.callbacks import Callback
2222
from pytorch_lightning.loggers import WandbLogger
2323
from pytorch_lightning.trainer import Trainer
2424
from pytorch_lightning.utilities import rank_zero_only
2525

26-
from sgm.util import (
27-
exists,
28-
instantiate_from_config,
29-
isheatmap,
30-
)
26+
from sgm.util import exists, instantiate_from_config, isheatmap
3127

3228
MULTINODE_HACKS = True
3329

@@ -910,11 +906,12 @@ def divein(*args, **kwargs):
910906
trainer.test(model, data)
911907
except RuntimeError as err:
912908
if MULTINODE_HACKS:
913-
import requests
914909
import datetime
915910
import os
916911
import socket
917912

913+
import requests
914+
918915
device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
919916
hostname = socket.gethostname()
920917
ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")

requirements/pt13.txt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
black==23.7.0
2+
chardet>=5.1.0
3+
clip @ git+https://github.com/openai/CLIP.git
4+
einops>=0.6.1
5+
fairscale>=0.4.13
6+
fire>=0.5.0
7+
fsspec>=2023.6.0
8+
invisible-watermark>=0.2.0
9+
kornia==0.6.9
10+
matplotlib>=3.7.2
11+
natsort>=8.4.0
12+
numpy>=1.24.4
13+
omegaconf>=2.3.0
14+
onnx<=1.12.0
15+
open-clip-torch>=2.20.0
16+
opencv-python==4.6.0.66
17+
pandas>=2.0.3
18+
pillow>=9.5.0
19+
pudb>=2022.1.3
20+
pytorch-lightning==1.8.5
21+
pyyaml>=6.0.1
22+
scipy>=1.10.1
23+
streamlit>=1.25.0
24+
tensorboardx==2.5.1
25+
timm>=0.9.2
26+
tokenizers==0.12.1
27+
--extra-index-url https://download.pytorch.org/whl/cu117
28+
torch==1.13.1+cu117
29+
torchaudio==0.13.1
30+
torchdata==0.5.1
31+
torchmetrics>=1.0.1
32+
torchvision==0.14.1+cu117
33+
tqdm>=4.65.0
34+
transformers==4.19.1
35+
triton==2.0.0.post1
36+
urllib3<1.27,>=1.25.4
37+
wandb>=0.15.6
38+
webdataset>=0.2.33
39+
wheel>=0.41.0
40+
xformers==0.0.16

requirements/pt2.txt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
black==23.7.0
2+
chardet==5.1.0
3+
clip @ git+https://github.com/openai/CLIP.git
4+
einops>=0.6.1
5+
fairscale>=0.4.13
6+
fire>=0.5.0
7+
fsspec>=2023.6.0
8+
invisible-watermark>=0.2.0
9+
kornia==0.6.9
10+
matplotlib>=3.7.2
11+
natsort>=8.4.0
12+
ninja>=1.11.1
13+
numpy>=1.24.4
14+
omegaconf>=2.3.0
15+
open-clip-torch>=2.20.0
16+
opencv-python==4.6.0.66
17+
pandas>=2.0.3
18+
pillow>=9.5.0
19+
pudb>=2022.1.3
20+
pytorch-lightning==2.0.1
21+
pyyaml>=6.0.1
22+
scipy>=1.10.1
23+
streamlit>=0.73.1
24+
tensorboardx==2.6
25+
timm>=0.9.2
26+
tokenizers==0.12.1
27+
torch>=2.0.1
28+
torchaudio>=2.0.2
29+
torchdata==0.6.1
30+
torchmetrics>=1.0.1
31+
torchvision>=0.15.2
32+
tqdm>=4.65.0
33+
transformers==4.19.1
34+
triton==2.0.0
35+
urllib3<1.27,>=1.25.4
36+
wandb>=0.15.6
37+
webdataset>=0.2.33
38+
wheel>=0.41.0
39+
xformers>=0.0.20

requirements_pt13.txt

Lines changed: 0 additions & 41 deletions
This file was deleted.

requirements_pt2.txt

Lines changed: 0 additions & 41 deletions
This file was deleted.

scripts/demo/sampling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pytorch_lightning import seed_everything
2+
23
from scripts.demo.streamlit_helpers import *
34
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
45

scripts/demo/streamlit_helpers.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
1+
import math
12
import os
2-
from typing import Union, List
3+
from typing import List, Union
34

4-
import math
55
import numpy as np
66
import streamlit as st
77
import torch
8-
from PIL import Image
98
from einops import rearrange, repeat
109
from imwatermark import WatermarkEncoder
11-
from omegaconf import OmegaConf, ListConfig
10+
from omegaconf import ListConfig, OmegaConf
11+
from PIL import Image
12+
from safetensors.torch import load_file as load_safetensors
1213
from torch import autocast
1314
from torchvision import transforms
1415
from torchvision.utils import make_grid
15-
from safetensors.torch import load_file as load_safetensors
1616

1717
from sgm.modules.diffusionmodules.sampling import (
18+
DPMPP2MSampler,
19+
DPMPP2SAncestralSampler,
20+
EulerAncestralSampler,
1821
EulerEDMSampler,
1922
HeunEDMSampler,
20-
EulerAncestralSampler,
21-
DPMPP2SAncestralSampler,
22-
DPMPP2MSampler,
2323
LinearMultistepSampler,
2424
)
25-
from sgm.util import append_dims
26-
from sgm.util import instantiate_from_config
25+
from sgm.util import append_dims, instantiate_from_config
2726

2827

2928
class WatermarkEmbedder:

scripts/util/detection/nsfw_and_watermark_dectection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
2-
import torch
2+
3+
import clip
34
import numpy as np
5+
import torch
46
import torchvision.transforms as T
57
from PIL import Image
6-
import clip
78

89
RESOURCES_ROOT = "scripts/util/detection/"
910

sgm/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from .data import StableDataModuleFromConfig
21
from .models import AutoencodingEngine, DiffusionEngine
3-
from .util import instantiate_from_config, get_configs_path
2+
from .util import get_configs_path, instantiate_from_config
43

5-
__version__ = "0.0.1"
4+
__version__ = "0.1.0"

sgm/data/cifar10.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import torchvision
21
import pytorch_lightning as pl
3-
from torchvision import transforms
2+
import torchvision
43
from torch.utils.data import DataLoader, Dataset
4+
from torchvision import transforms
55

66

77
class CIFAR10DataDictWrapper(Dataset):

sgm/data/mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import torchvision
21
import pytorch_lightning as pl
3-
from torchvision import transforms
2+
import torchvision
43
from torch.utils.data import DataLoader, Dataset
4+
from torchvision import transforms
55

66

77
class MNISTDataDictWrapper(Dataset):

0 commit comments

Comments
 (0)