fid_evaluation.py
import math
import os
import numpy as np
import torch
from einops import rearrange, repeat
from pytorch_fid.fid_score import calculate_frechet_distance
from pytorch_fid.inception import InceptionV3
from torch.nn.functional import adaptive_avg_pool2d
from tqdm.auto import tqdm
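
# FID evaluation helper: computes the Fréchet Inception Distance between features
# of real images drawn from a dataloader and images produced by a sampler, using
# the pytorch_fid InceptionV3 feature extractor.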


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
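

# Example: num_to_groups splits a sample count into batch-sized chunks,
# e.g. num_to_groups(10, 4) returns [4, 4, 2] (two full batches plus the remainder).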


class FIDEvaluation:
    def __init__(
        self,
        batch_size,
        dl,
        sampler,
        channels=3,
        accelerator=None,
        stats_dir="./results",
        device="cuda",
        num_fid_samples=50000,
        inception_block_idx=2048,
    ):
        self.batch_size = batch_size
        self.n_samples = num_fid_samples
        self.device = device
        self.channels = channels
        self.dl = dl
        self.sampler = sampler
        self.stats_dir = stats_dir
        self.print_fn = print if accelerator is None else accelerator.print
        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
        self.inception_v3 = InceptionV3([block_idx]).to(device)
        self.dataset_stats_loaded = False

    def calculate_inception_features(self, samples):
        if self.channels == 1:
            samples = repeat(samples, "b 1 ... -> b c ...", c=3)

        self.inception_v3.eval()
        features = self.inception_v3(samples)[0]

        if features.size(2) != 1 or features.size(3) != 1:
            features = adaptive_avg_pool2d(features, output_size=(1, 1))

        features = rearrange(features, "... 1 1 -> ...")
        return features

    def load_or_precalc_dataset_stats(self):
        path = os.path.join(self.stats_dir, "dataset_stats")
        try:
            ckpt = np.load(path + ".npz")
            self.m2, self.s2 = ckpt["m2"], ckpt["s2"]
            self.print_fn("Dataset stats loaded from disk.")
            ckpt.close()
        except OSError:
            num_batches = int(math.ceil(self.n_samples / self.batch_size))
            stacked_real_features = []
            self.print_fn(
                f"Stacking Inception features for {self.n_samples} samples from the real dataset."
            )
            for _ in tqdm(range(num_batches)):
                try:
                    real_samples = next(self.dl)
                except StopIteration:
                    break
                real_samples = real_samples.to(self.device)
                real_features = self.calculate_inception_features(real_samples)
                stacked_real_features.append(real_features)
            stacked_real_features = (
                torch.cat(stacked_real_features, dim=0).cpu().numpy()
            )
            m2 = np.mean(stacked_real_features, axis=0)
            s2 = np.cov(stacked_real_features, rowvar=False)
            np.savez_compressed(path, m2=m2, s2=s2)
            self.print_fn(f"Dataset stats cached to {path}.npz for future use.")
            self.m2, self.s2 = m2, s2

        self.dataset_stats_loaded = True
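
    # fid_score() draws self.n_samples generated images from the sampler in
    # batches, pools their Inception features, and compares the resulting
    # (mean, covariance) statistics against the cached real-dataset stats via
    # the Fréchet distance.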
    @torch.inference_mode()
    def fid_score(self):
        if not self.dataset_stats_loaded:
            self.load_or_precalc_dataset_stats()

        self.sampler.eval()
        batches = num_to_groups(self.n_samples, self.batch_size)
        stacked_fake_features = []
        self.print_fn(
            f"Stacking Inception features for {self.n_samples} generated samples."
        )
        for batch in tqdm(batches):
            fake_samples = self.sampler.sample(batch_size=batch)
            fake_features = self.calculate_inception_features(fake_samples)
            stacked_fake_features.append(fake_features)
        stacked_fake_features = torch.cat(stacked_fake_features, dim=0).cpu().numpy()
        m1 = np.mean(stacked_fake_features, axis=0)
        s1 = np.cov(stacked_fake_features, rowvar=False)

        return calculate_frechet_distance(m1, s1, self.m2, self.s2)
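

# Usage sketch, assuming a trained sampler and a real-image dataloader; the names
# `diffusion_model` and `train_loader` below are placeholders, not part of this
# module. The sampler only needs .eval() and .sample(batch_size=n), and `dl` must
# be an iterator over batches of real image tensors (e.g. iter(train_loader)).
#
# fid_eval = FIDEvaluation(
#     batch_size=128,
#     dl=iter(train_loader),
#     sampler=diffusion_model,
#     channels=3,
#     stats_dir="./results",
#     device="cuda",
#     num_fid_samples=50000,
# )
# print(f"FID: {fid_eval.fid_score():.4f}")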