Experiment #139

Open · wants to merge 12 commits into base: main
10 changes: 10 additions & 0 deletions .gitignore
@@ -1,3 +1,13 @@
.DS_Store
__pycache__/

ckpt/
*.ipynb
datasets/cifar_test
datasets/cifar_train
.venv/
anaconda.sh
improved_diffusion.egg-info/
wandb/
*.pt
results/
141 changes: 76 additions & 65 deletions improved_diffusion/train_util.py
@@ -20,6 +20,9 @@
from .nn import update_ema
from .resample import LossAwareSampler, UniformSampler

import wandb
from tqdm import tqdm

# For ImageNet experiments, this was a good default value.
# We found that the lg_loss_scale quickly climbed to
# 20-21 within the first ~1K steps of training.
@@ -52,11 +55,7 @@ def __init__(
self.batch_size = batch_size
self.microbatch = microbatch if microbatch > 0 else batch_size
self.lr = lr
self.ema_rate = (
[ema_rate]
if isinstance(ema_rate, float)
else [float(x) for x in ema_rate.split(",")]
)
self.ema_rate = [ema_rate] if isinstance(ema_rate, float) else [float(x) for x in ema_rate.split(",")]
self.log_interval = log_interval
self.save_interval = save_interval
self.resume_checkpoint = resume_checkpoint
@@ -84,13 +83,9 @@ def __init__(
self._load_optimizer_state()
# Model was resumed, either due to a restart or a checkpoint
# being specified at the command line.
self.ema_params = [
self._load_ema_parameters(rate) for rate in self.ema_rate
]
self.ema_params = [self._load_ema_parameters(rate) for rate in self.ema_rate]
else:
self.ema_params = [
copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
]
self.ema_params = [copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))]

if th.cuda.is_available():
self.use_ddp = True
@@ -104,25 +99,37 @@ def __init__(
)
else:
if dist.get_world_size() > 1:
logger.warn(
"Distributed training requires CUDA. "
"Gradients will not be synchronized properly!"
)
logger.warn("Distributed training requires CUDA. " "Gradients will not be synchronized properly!")
self.use_ddp = False
self.ddp_model = self.model

wandb.init(project="imporved_diffusion", entity="quasar529")
wandb.config.update(
{
"batch_size": batch_size,
"microbatch": microbatch,
"lr": lr,
"ema_rate": ema_rate,
"log_interval": log_interval,
"save_interval": save_interval,
"resume_checkpoint": resume_checkpoint,
"use_fp16": use_fp16,
"fp16_scale_growth": fp16_scale_growth,
"schedule_sampler": schedule_sampler,
"weight_decay": weight_decay,
"lr_anneal_steps": lr_anneal_steps,
}
)
wandb.run.name = f"diffusion_batch_{batch_size}_steps{lr_anneal_steps}_{self.resume_checkpoint}"

def _load_and_sync_parameters(self):
resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint

if resume_checkpoint:
self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
if dist.get_rank() == 0:
logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
self.model.load_state_dict(
dist_util.load_state_dict(
resume_checkpoint, map_location=dist_util.dev()
)
)
self.model.load_state_dict(dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev()))

dist_util.sync_params(self.model.parameters())

@@ -134,45 +141,40 @@ def _load_ema_parameters(self, rate):
if ema_checkpoint:
if dist.get_rank() == 0:
logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
state_dict = dist_util.load_state_dict(
ema_checkpoint, map_location=dist_util.dev()
)
state_dict = dist_util.load_state_dict(ema_checkpoint, map_location=dist_util.dev())
ema_params = self._state_dict_to_master_params(state_dict)

dist_util.sync_params(ema_params)
return ema_params

def _load_optimizer_state(self):
main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
opt_checkpoint = bf.join(
bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
)
opt_checkpoint = bf.join(bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt")
if bf.exists(opt_checkpoint):
logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
state_dict = dist_util.load_state_dict(
opt_checkpoint, map_location=dist_util.dev()
)
state_dict = dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev())
self.opt.load_state_dict(state_dict)

def _setup_fp16(self):
self.master_params = make_master_params(self.model_params)
self.model.convert_to_fp16()

def run_loop(self):
while (
not self.lr_anneal_steps
or self.step + self.resume_step < self.lr_anneal_steps
):
batch, cond = next(self.data)
self.run_step(batch, cond)
if self.step % self.log_interval == 0:
logger.dumpkvs()
if self.step % self.save_interval == 0:
self.save()
# Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
return
self.step += 1
total_steps = self.lr_anneal_steps if self.lr_anneal_steps else float("inf")
with tqdm(total=total_steps, desc="Training Progress", dynamic_ncols=True) as pbar:
while not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps:
batch, cond = next(self.data)
self.run_step(batch, cond)
if self.step % self.log_interval == 0:
logger.dumpkvs()
if self.step % self.save_interval == 0:
self.save()
# Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
return
self.step += 1
pbar.update(1) # Update the progress bar

# Save the last checkpoint if it wasn't already saved.
if (self.step - 1) % self.save_interval != 0:
self.save()
@@ -187,12 +189,17 @@ def run_step(self, batch, cond):

def forward_backward(self, batch, cond):
zero_grad(self.model_params)
total_batches = (batch.shape[0] + self.microbatch - 1) // self.microbatch # total number of microbatches
progress_bar = tqdm(total=total_batches, desc="Processing Batches", leave=False, dynamic_ncols=True)

for i in range(0, batch.shape[0], self.microbatch):
current_batch = i // self.microbatch + 1 # current microbatch number (1-based)
progress_bar.set_description(
f"Processing Batches {current_batch}/{total_batches}"
) # update the progress-bar description

micro = batch[i : i + self.microbatch].to(dist_util.dev())
micro_cond = {
k: v[i : i + self.microbatch].to(dist_util.dev())
for k, v in cond.items()
}
micro_cond = {k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items()}
last_batch = (i + self.microbatch) >= batch.shape[0]
t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

@@ -211,19 +218,21 @@ def forward_backward(self, batch, cond):
losses = compute_losses()

if isinstance(self.schedule_sampler, LossAwareSampler):
self.schedule_sampler.update_with_local_losses(
t, losses["loss"].detach()
)
self.schedule_sampler.update_with_local_losses(t, losses["loss"].detach())

loss = (losses["loss"] * weights).mean()
log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
)
wandb.log({"train/loss": loss.item()})
log_loss_dict(self.diffusion, t, {k: v * weights for k, v in losses.items()})

if self.use_fp16:
loss_scale = 2 ** self.lg_loss_scale
loss_scale = 2**self.lg_loss_scale
(loss * loss_scale).backward()
else:
loss.backward()
progress_bar.set_postfix(loss=loss.item())
progress_bar.update(1) # advance the microbatch progress bar

progress_bar.close() # close the microbatch progress bar

def optimize_fp16(self):
if any(not th.isfinite(p.grad).all() for p in self.model_params):
Expand All @@ -232,7 +241,7 @@ def optimize_fp16(self):
return

model_grads_to_master_grads(self.model_params, self.master_params)
self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
self.master_params[0].grad.mul_(1.0 / (2**self.lg_loss_scale))
self._log_grad_norm()
self._anneal_lr()
self.opt.step()
@@ -251,7 +260,8 @@ def optimize_normal(self):
def _log_grad_norm(self):
sqsum = 0.0
for p in self.master_params:
sqsum += (p.grad ** 2).sum().item()
sqsum += (p.grad**2).sum().item()
wandb.log({"train/grad_norm": np.sqrt(sqsum)})
logger.logkv_mean("grad_norm", np.sqrt(sqsum))

def _anneal_lr(self):
@@ -261,14 +271,17 @@ def _anneal_lr(self):
lr = self.lr * (1 - frac_done)
for param_group in self.opt.param_groups:
param_group["lr"] = lr
wandb.log({"train/lr": lr})

def log_step(self):
logger.logkv("step", self.step + self.resume_step)
logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
wandb.log({"train/step": self.step + self.resume_step})
wandb.log({"train/samples": (self.step + self.resume_step + 1) * self.global_batch})
if self.use_fp16:
logger.logkv("lg_loss_scale", self.lg_loss_scale)

def save(self):
def save(self, save_dir="/home/jun/improved-diffusion/results"):
def save_checkpoint(rate, params):
state_dict = self._master_params_to_state_dict(params)
if dist.get_rank() == 0:
@@ -277,27 +290,25 @@ def save_checkpoint(rate, params):
filename = f"model{(self.step+self.resume_step):06d}.pt"
else:
filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
save_path = bf.join(save_dir, filename)
with bf.BlobFile(save_path, "wb") as f:
th.save(state_dict, f)

save_checkpoint(0, self.master_params)
for rate, params in zip(self.ema_rate, self.ema_params):
save_checkpoint(rate, params)

if dist.get_rank() == 0:
with bf.BlobFile(
bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
"wb",
) as f:
opt_filename = f"opt{(self.step+self.resume_step):06d}.pt"
opt_save_path = bf.join(save_dir, opt_filename)
with bf.BlobFile(opt_save_path, "wb") as f:
th.save(self.opt.state_dict(), f)

dist.barrier()

def _master_params_to_state_dict(self, master_params):
if self.use_fp16:
master_params = unflatten_master_params(
self.model.parameters(), master_params
)
master_params = unflatten_master_params(self.model.parameters(), master_params)
state_dict = self.model.state_dict()
for i, (name, _value) in enumerate(self.model.named_parameters()):
assert name in state_dict
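For reference, the W&B and tqdm wiring added to TrainLoop above follows a standard pattern: initialize a run, record hyperparameters in wandb.config, set a run name, and log scalars each step while a tqdm bar tracks progress. The sketch below is a minimal, self-contained illustration of that pattern (toy loop, dummy loss, placeholder project name, offline mode so no W&B account is needed); it is not the TrainLoop code itself.

# Minimal sketch of the logging pattern used in train_util.py above.
# The project name, dummy loss, and toy loop are placeholders.
import wandb
from tqdm import tqdm


def toy_training_run(lr_anneal_steps=100, log_interval=10):
    # mode="offline" writes logs locally, so the sketch runs without a W&B account
    run = wandb.init(project="improved_diffusion", mode="offline")
    wandb.config.update({"lr_anneal_steps": lr_anneal_steps, "log_interval": log_interval})
    run.name = f"diffusion_steps{lr_anneal_steps}"

    # tqdm accepts float("inf") when no step limit is set, matching run_loop above
    total = lr_anneal_steps if lr_anneal_steps else float("inf")
    step = 0
    with tqdm(total=total, desc="Training Progress", dynamic_ncols=True) as pbar:
        while not lr_anneal_steps or step < lr_anneal_steps:
            loss = 1.0 / (step + 1)  # stand-in for the diffusion training loss
            if step % log_interval == 0:
                wandb.log({"train/loss": loss}, step=step)
            step += 1
            pbar.update(1)
    run.finish()


if __name__ == "__main__":
    toy_training_run()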
27 changes: 10 additions & 17 deletions scripts/image_sample.py
@@ -9,7 +9,10 @@
import numpy as np
import torch as th
import torch.distributed as dist
import sys

print(sys.path)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from improved_diffusion import dist_util, logger
from improved_diffusion.script_util import (
NUM_CLASSES,
@@ -27,12 +30,8 @@ def main():
logger.configure()

logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu")
)
model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys()))
model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu"))
model.to(dist_util.dev())
model.eval()

@@ -42,13 +41,9 @@ def main():
while len(all_images) * args.batch_size < args.num_samples:
model_kwargs = {}
if args.class_cond:
classes = th.randint(
low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
)
classes = th.randint(low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev())
model_kwargs["y"] = classes
sample_fn = (
diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
)
sample_fn = diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
sample = sample_fn(
model,
(args.batch_size, 3, args.image_size, args.image_size),
@@ -63,9 +58,7 @@ def main():
dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
if args.class_cond:
gathered_labels = [
th.zeros_like(classes) for _ in range(dist.get_world_size())
]
gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())]
dist.all_gather(gathered_labels, classes)
all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
logger.log(f"created {len(all_images) * args.batch_size} samples")
@@ -91,9 +84,9 @@ def main():
def create_argparser():
defaults = dict(
clip_denoised=True,
num_samples=10000,
num_samples=10,
batch_size=16,
use_ddim=False,
use_ddim=True,
model_path="",
)
defaults.update(model_and_diffusion_defaults())
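The sys.path.append added at the top of scripts/image_sample.py lets the script import improved_diffusion when run directly from the scripts/ directory without installing the package. A standalone version of that workaround (same idea, generic names, plus a guard against duplicate path entries) could look like the sketch below.

# Sketch of the import-path workaround used in image_sample.py above.
import os
import sys

REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if REPO_ROOT not in sys.path:  # avoid appending the same path twice
    sys.path.append(REPO_ROOT)

# After this, packages in the repository root resolve as usual, e.g.:
# from improved_diffusion import dist_util, logger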
13 changes: 7 additions & 6 deletions scripts/image_train.py
@@ -23,10 +23,10 @@ def main():
logger.configure()

logger.log("creating model and diffusion...")
model, diffusion = create_model_and_diffusion(
**args_to_dict(args, model_and_diffusion_defaults().keys())
)
model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys()))
#model.load_state_dict(dist_util.load_state_dict(args.model_path, map_location="cpu"))
model.to(dist_util.dev())
print(dist_util.dev())
schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

logger.log("creating data loader...")
@@ -63,15 +63,16 @@ def create_argparser():
schedule_sampler="uniform",
lr=1e-4,
weight_decay=0.0,
lr_anneal_steps=0,
lr_anneal_steps=10000,
batch_size=1,
microbatch=-1, # -1 disables microbatches
ema_rate="0.9999", # comma-separated list of EMA values
log_interval=10,
save_interval=10000,
save_interval=5000,
resume_checkpoint="",
use_fp16=False,
use_fp16=True,
fp16_scale_growth=1e-3,
model_path="",
)
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
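The defaults dict above is the single source of truth for the training script's hyperparameters; it is merged with model_and_diffusion_defaults() and each key is exposed as a command-line flag. The helper below is only an illustrative reimplementation of that idea under assumed behavior, not the repo's own script_util code; the function names and the boolean parsing are assumptions.

# Hypothetical helper showing how a defaults dict can become argparse flags.
import argparse


def str2bool(value):
    # accept the usual spellings of True/False on the command line
    if isinstance(value, bool):
        return value
    return str(value).lower() in ("1", "true", "yes", "y")


def parser_from_defaults(defaults):
    parser = argparse.ArgumentParser()
    for key, default in defaults.items():
        if isinstance(default, bool):
            parser.add_argument(f"--{key}", default=default, type=str2bool)
        else:
            parser.add_argument(f"--{key}", default=default, type=type(default) if default is not None else str)
    return parser


if __name__ == "__main__":
    defaults = dict(lr=1e-4, lr_anneal_steps=10000, batch_size=1, use_fp16=True, resume_checkpoint="")
    args = parser_from_defaults(defaults).parse_args(["--batch_size", "4", "--use_fp16", "False"])
    print(args.lr, args.batch_size, args.use_fp16)  # 0.0001 4 False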