Commit 3f628ab

add aunet
1 parent 9fe0e73 commit 3f628ab

21 files changed

+2175 −30 lines changed

check_model.ipynb

+59
@@ -294,6 +294,65 @@
     "\n",
     "test(20.3423, 3.1427)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "24M\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from utils.bridge.models import UNetModel\n",
+    "\n",
+    "kwargs = {'in_channels': 2, \n",
+    "          'model_channels': 64, \n",
+    "          'out_channels': 1, \n",
+    "          'num_res_blocks': 4, \n",
+    "          'attention_resolutions': (0,), \n",
+    "          'dropout': 0.0, \n",
+    "          'channel_mult': (1, 2, 4), \n",
+    "          'num_classes': None, \n",
+    "          'use_checkpoint': False, \n",
+    "          'num_heads': 8, \n",
+    "          'num_heads_upsample': -1, \n",
+    "          'use_scale_shift_norm': True\n",
+    "          }\n",
+    "\n",
+    "model = UNetModel(**kwargs)\n",
+    "before_train = None\n",
+    "after_train = None\n",
+    "print(f\"{int(sum(p.numel() for p in model.parameters())/1e6)}M\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([100, 1, 28, 28])"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x = torch.rand(100, 2, 28, 28)\n",
+    "t = torch.rand(100)\n",
+    "model(x, t).shape"
+   ]
  }
 ],
 "metadata": {

test.py

+6 −12
@@ -1,31 +1,24 @@
 import numpy as np
-import matplotlib.pyplot as plt
 import pickle
 
 import torch
-from torch import nn, optim
-from torch.utils.data import Dataset, DataLoader
-import torch.nn.functional as F
 
-from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, StepLR, OneCycleLR
 from pathlib import Path
-from sklearn.datasets import *
 
 from rich.panel import Panel
 from rich.pretty import Pretty
 from rich.console import Console
 from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
 import time as tt
 
-from utils.Datasets import BBdataset, MNISTdataset
 from utils.utils import plot_source_and_target_mnist, binary, save_gif_frame_mnist
-from utils.data_utils import gen_mnist_data, reverse_normalize_dataset, normalize_dataset_with_metadata, gen_ds
+from utils.data_utils import gen_mnist_data, reverse_normalize_dataset, normalize_dataset_with_metadata
 from utils.model_utils import get_model_before_after
 import argparse
 
 def check_model_task(args):
     if args.task == 'gaussian2mnist':
-        assert args.model in ['tunet++', 'unet++', 'unet']
+        assert args.model in ['tunet++', 'aunet', 'unet++', 'unet']
         args.time_expand = False
     else:
         assert args.model in ['mlp', 'unet++', 'unet']
@@ -45,7 +38,7 @@ def main():
     parser.add_argument('--lr', type=float, default=1e-4)
     parser.add_argument('--iter_nums', type=int, default=1)
     parser.add_argument('--epoch_nums', type=int, default=2)
-    parser.add_argument('--batch_size', type=int, default=8000)
+    parser.add_argument('-b', '--batch_size', type=int, default=8000)
     parser.add_argument('-n','--normalize', action='store_true')
     parser.add_argument('--tarined_data', action='store_true')
     parser.add_argument('--filter_number', type=int)
@@ -60,7 +53,7 @@ def main():
     np.random.seed(seed)
 
     experiment_name = args.task
-    if args.change_epsilons:
+    if args.change_epsilons:
         experiment_name += '_change_epsilons'
     if args.filter_number is not None and 'mnist' in args.task:
         experiment_name += f'_filter{args.filter_number}'
@@ -138,10 +131,12 @@ def main_worker(args):
     for i in range(len(test_ts) - 1):
         dt = (test_ts[i+1] - test_ts[i])
         test_source_reshaped = test_source
+
         if args.time_expand:
             test_ts_reshaped = test_ts[i].repeat(test_source.shape[0]).reshape(-1, 1, 1, 1).repeat(1, 1, 28, 28)
         else:
             test_ts_reshaped = torch.unsqueeze(test_ts[i], dim=0).T
+
         pred_bridge_reshaped = pred_bridge[i]
 
         ret = normalize_dataset_with_metadata(real_metadata, source=test_source_reshaped, ts=test_ts_reshaped, bridge=pred_bridge_reshaped)
@@ -156,7 +151,6 @@ def main_worker(args):
         time = test_ts_reshaped.to(args.device)
         if before_train is not None:
             x = before_train(x)
-
         x = x.to(args.device)
         model = model.to(args.device)
         dydt = model(x, time) if time is not None else model(x)

train.py

+10 −7
@@ -24,7 +24,7 @@
 
 def check_model_task(args):
     if args.task.startswith('gaussian2mnist'):
-        assert args.model in ['tunet++', 'unet++', 'unet']
+        assert args.model in ['tunet++', 'aunet']
         args.time_expand = False
     else:
         assert args.model in ['mlp', 'unet++', 'unet']
@@ -43,7 +43,7 @@ def main():
     parser.add_argument('--lr', type=float, default=1e-4)
     parser.add_argument('--iter_nums', type=int, default=1)
     parser.add_argument('--epoch_nums', type=int, default=3)
-    parser.add_argument('--batch_size', type=int, default=8000)
+    parser.add_argument('-b','--batch_size', type=int, default=8000)
     parser.add_argument('-n','--normalize', action='store_true')
     parser.add_argument('--num_workers', type=int, default=20)
     parser.add_argument('--filter_number', type=int)
@@ -80,11 +80,10 @@ def main():
 def train(args, model, train_dl, optimizer, scheduler, loss_fn, before_train=None, after_train=None):
     losses = 0
     for data in train_dl:
-        if args.model == 'tunet++':
+        if isinstance(data, list):
             training_data, time = data
         else:
-            training_data = data
-            time = None
+            training_data, time = data, None
 
         training_data = training_data.squeeze().float().cpu()
         x, y = training_data[:, :-args.dim], training_data[:, -args.dim:]
@@ -168,15 +167,19 @@ def main_worker(args):
             progress.remove_task(task2)
             torch.save(model.state_dict(), args.log_dir / f'model_{model.__class__.__name__}_{int(iter)}.pth')
             progress.update(task1, advance=1, description="[red]Training whole dataset (lr: %2.5f) (loss=%2.5f)" % (cur_lr, now_loss))
-            progress.log(f"[green]sub dataset {int(iter%ds_info['nums_sub_ds'])} finished; Loss: {now_loss}")
+            progress.log("[green]sub dataset %d finished; Loss: %2.5f" % (int(iter%ds_info['nums_sub_ds']), now_loss))
+
+    console.rule("[bold bright_green blink]Finished Training")
+    console.log("Final loss: %2.5f" % (loss_list[-1]))
     # Draw loss curve
     fig, ax = plt.subplots(figsize=(10, 5))
     ax.plot(loss_list)
     ax.set_title("Loss")
     fig.savefig(args.log_dir / 'loss.png')
+    console.log("Loss curve saved to {}".format(args.log_dir / 'loss.png'))
 
     torch.save(model.state_dict(), args.log_dir / f'model_{model.__class__.__name__}_final.pth')
-
+    console.log("Model saved to {}".format(args.log_dir / f'model_{model.__class__.__name__}_final.pth'))
 
 if __name__ == '__main__':
     main()

utils/bridge/__init__.py

Whitespace-only changes.

utils/bridge/langevin.py

+127
@@ -0,0 +1,127 @@
+import copy
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+import os
+import numpy as np
+
+
+def grad_gauss(x, m, var):
+    xout = (x - m) / var
+    return -xout
+
+
+def ornstein_ulhenbeck(x, gradx, gamma):
+    xout = x + gamma * gradx + \
+        torch.sqrt(2 * gamma) * torch.randn(x.shape, device=x.device)
+    return xout
+
+
+class Langevin(torch.nn.Module):
+
+    def __init__(self, num_steps, shape, gammas, time_sampler, device=None,
+                 mean_final=torch.tensor([0., 0.]), var_final=torch.tensor([.5, .5]), mean_match=True):
+        super().__init__()
+
+        self.mean_match = mean_match
+        self.mean_final = mean_final
+        self.var_final = var_final
+
+        self.num_steps = num_steps  # num diffusion steps
+        self.d = shape  # shape of object to diffuse
+        self.gammas = gammas.float()  # schedule
+        gammas_vec = torch.ones(self.num_steps, *self.d, device=device)
+        for k in range(num_steps):
+            gammas_vec[k] = gammas[k].float()
+        self.gammas_vec = gammas_vec
+
+        if device is not None:
+            self.device = device
+        else:
+            self.device = gammas.device
+
+        self.steps = torch.arange(self.num_steps).to(self.device)
+        self.time = torch.cumsum(self.gammas, 0).to(self.device).float()
+        self.time_sampler = time_sampler
+
+    def record_init_langevin(self, init_samples):
+        mean_final = self.mean_final
+        var_final = self.var_final
+
+        x = init_samples
+        N = x.shape[0]
+        steps = self.steps.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
+        time = self.time.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
+        gammas = self.gammas.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
+        steps = time
+
+        x_tot = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
+        out = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
+        store_steps = self.steps
+        num_iter = self.num_steps
+        steps_expanded = time
+
+        for k in range(num_iter):
+            gamma = self.gammas[k]
+            gradx = grad_gauss(x, mean_final, var_final)
+            t_old = x + gamma * gradx
+            z = torch.randn(x.shape, device=x.device)
+            x = t_old + torch.sqrt(2 * gamma)*z
+            gradx = grad_gauss(x, mean_final, var_final)
+            t_new = x + gamma * gradx
+
+            x_tot[:, k, :] = x
+            out[:, k, :] = (t_old - t_new)  # / (2 * gamma)
+
+        return x_tot, out, steps_expanded
+
+    def record_langevin_seq(self, net, init_samples, t_batch=None, ipf_it=0, sample=False):
+        mean_final = self.mean_final
+        var_final = self.var_final
+
+        x = init_samples
+        N = x.shape[0]
+        steps = self.steps.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
+        time = self.time.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
+        gammas = self.gammas.reshape((1, self.num_steps, 1)).repeat((N, 1, 1))
+        steps = time
+
+        x_tot = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
+        out = torch.Tensor(N, self.num_steps, *self.d).to(x.device)
+        store_steps = self.steps
+        steps_expanded = steps
+        num_iter = self.num_steps
+
+        if self.mean_match:
+            for k in range(num_iter):
+                gamma = self.gammas[k]
+                t_old = net(x, steps[:, k, :])
+
+                if sample & (k == num_iter-1):
+                    x = t_old
+                else:
+                    z = torch.randn(x.shape, device=x.device)
+                    x = t_old + torch.sqrt(2 * gamma) * z
+
+                t_new = net(x, steps[:, k, :])
+                x_tot[:, k, :] = x
+                out[:, k, :] = (t_old - t_new)
+        else:
+            for k in range(num_iter):
+                gamma = self.gammas[k]
+                t_old = x + net(x, steps[:, k, :])
+
+                if sample & (k == num_iter-1):
+                    x = t_old
+                else:
+                    z = torch.randn(x.shape, device=x.device)
+                    x = t_old + torch.sqrt(2 * gamma) * z
+                t_new = x + net(x, steps[:, k, :])
+
+                x_tot[:, k, :] = x
+                out[:, k, :] = (t_old - t_new)
+
+        return x_tot, out, steps_expanded
+
+    def forward(self, net, init_samples, t_batch, ipf_it):
+        return self.record_langevin_seq(net, init_samples, t_batch, ipf_it)

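Note: the new Langevin module records the forward Langevin chain toward the Gaussian prior (record_init_langevin) and the network-driven chain (record_langevin_seq). Below is a minimal usage sketch; the constant gamma schedule, the 2-D sample shape, and the identity placeholder network are illustrative assumptions, not part of the commit.

import torch
from utils.bridge.langevin import Langevin

num_steps, shape = 10, (2,)                  # ten diffusion steps over 2-D samples (assumed)
gammas = torch.full((num_steps,), 1e-2)      # assumed constant step-size schedule

langevin = Langevin(num_steps, shape, gammas, time_sampler=None)

x0 = torch.randn(64, 2)                      # 64 initial samples
# Forward chain toward the Gaussian defined by mean_final/var_final.
x_tot, out, steps = langevin.record_init_langevin(x0)
print(x_tot.shape, out.shape, steps.shape)   # (64, 10, 2), (64, 10, 2), (64, 10, 1)

# record_langevin_seq drives the chain with a network net(x, t).
def net(x, t):                               # identity placeholder standing in for a trained model
    return x

x_tot, out, steps = langevin.record_langevin_seq(net, x0, sample=True)
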
utils/bridge/models/__init__.py

+2
@@ -0,0 +1,2 @@
+from .basic import ScoreNetwork
+from .unet import UNetModel

utils/bridge/models/basic/__init__.py

+1
@@ -0,0 +1 @@
+from .basic import ScoreNetwork

utils/bridge/models/basic/basic.py

+37
@@ -0,0 +1,37 @@
+import torch
+from .layers import MLP
+from .time_embedding import get_timestep_embedding
+
+class ScoreNetwork(torch.nn.Module):
+
+    def __init__(self, encoder_layers=[16], pos_dim=16, decoder_layers=[128,128], x_dim=2):
+        super().__init__()
+        self.temb_dim = pos_dim
+        t_enc_dim = pos_dim *2
+        self.locals = [encoder_layers, pos_dim, decoder_layers, x_dim]
+
+        self.net = MLP(2 * t_enc_dim,
+                       layer_widths=decoder_layers +[x_dim],
+                       activate_final = False,
+                       activation_fn=torch.nn.LeakyReLU())
+
+        self.t_encoder = MLP(pos_dim,
+                             layer_widths=encoder_layers +[t_enc_dim],
+                             activate_final = False,
+                             activation_fn=torch.nn.LeakyReLU())
+
+        self.x_encoder = MLP(x_dim,
+                             layer_widths=encoder_layers +[t_enc_dim],
+                             activate_final = False,
+                             activation_fn=torch.nn.LeakyReLU())
+
+    def forward(self, x, t):
+        if len(x.shape) == 1:
+            x = x.unsqueeze(0)
+
+        temb = get_timestep_embedding(t, self.temb_dim)
+        temb = self.t_encoder(temb)
+        xemb = self.x_encoder(x)
+        h = torch.cat([xemb ,temb], -1)
+        out = self.net(h)
+        return out

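Note: ScoreNetwork embeds the timestep with get_timestep_embedding, encodes x and t with separate MLP encoders, concatenates the two embeddings, and decodes back to x_dim outputs. A minimal usage sketch follows; it is an illustrative assumption that relies on the MLP and time-embedding helpers from the sibling modules, which are not shown in this excerpt.

import torch
from utils.bridge.models import ScoreNetwork

# Default configuration: 2-D inputs, 16-dim time embedding, [128, 128] decoder.
net = ScoreNetwork(encoder_layers=[16], pos_dim=16, decoder_layers=[128, 128], x_dim=2)

x = torch.randn(128, 2)   # batch of 2-D samples
t = torch.rand(128)       # one timestep per sample
out = net(x, t)
print(out.shape)          # expected: torch.Size([128, 2])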