Skip to content

Commit 000d6b0

Browse files
authored
Merge pull request #1 from scxue/feat/sa-solver
add sa solver test file
2 parents c81e607 + 4c8f855 commit 000d6b0

File tree

1 file changed

+166
-0
lines changed

1 file changed

+166
-0
lines changed
+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import torch
2+
3+
from diffusers import SASolverScheduler
4+
from diffusers.utils.testing_utils import require_torchsde, torch_device
5+
6+
from .test_schedulers import SchedulerCommonTest
7+
8+
9+
@require_torchsde
10+
class SASolverSchedulerTest(SchedulerCommonTest):
11+
scheduler_classes = (SASolverScheduler,)
12+
num_inference_steps = 10
13+
14+
def get_scheduler_config(self, **kwargs):
15+
config = {
16+
"num_train_timesteps": 1100,
17+
"beta_start": 0.0001,
18+
"beta_end": 0.02,
19+
"beta_schedule": "linear",
20+
}
21+
22+
config.update(**kwargs)
23+
return config
24+
25+
def test_timesteps(self):
26+
for timesteps in [10, 50, 100, 1000]:
27+
self.check_over_configs(num_train_timesteps=timesteps)
28+
29+
def test_betas(self):
30+
for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
31+
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
32+
33+
def test_schedules(self):
34+
for schedule in ["linear", "scaled_linear"]:
35+
self.check_over_configs(beta_schedule=schedule)
36+
37+
def test_prediction_type(self):
38+
for prediction_type in ["epsilon", "v_prediction"]:
39+
self.check_over_configs(prediction_type=prediction_type)
40+
41+
def test_full_loop_no_noise(self):
42+
scheduler_class = self.scheduler_classes[0]
43+
scheduler_config = self.get_scheduler_config()
44+
scheduler = scheduler_class(**scheduler_config)
45+
46+
scheduler.set_timesteps(self.num_inference_steps)
47+
48+
model = self.dummy_model()
49+
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
50+
sample = sample.to(torch_device)
51+
52+
for i, t in enumerate(scheduler.timesteps):
53+
sample = scheduler.scale_model_input(sample, t)
54+
55+
model_output = model(sample, t)
56+
57+
output = scheduler.step(model_output, t, sample)
58+
sample = output.prev_sample
59+
60+
result_sum = torch.sum(torch.abs(sample))
61+
result_mean = torch.mean(torch.abs(sample))
62+
63+
if torch_device in ["mps"]:
64+
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
65+
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
66+
elif torch_device in ["cuda"]:
67+
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
68+
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
69+
else:
70+
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
71+
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
72+
73+
def test_full_loop_with_v_prediction(self):
74+
scheduler_class = self.scheduler_classes[0]
75+
scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
76+
scheduler = scheduler_class(**scheduler_config)
77+
78+
scheduler.set_timesteps(self.num_inference_steps)
79+
80+
model = self.dummy_model()
81+
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
82+
sample = sample.to(torch_device)
83+
84+
for i, t in enumerate(scheduler.timesteps):
85+
sample = scheduler.scale_model_input(sample, t)
86+
87+
model_output = model(sample, t)
88+
89+
output = scheduler.step(model_output, t, sample)
90+
sample = output.prev_sample
91+
92+
result_sum = torch.sum(torch.abs(sample))
93+
result_mean = torch.mean(torch.abs(sample))
94+
95+
if torch_device in ["mps"]:
96+
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
97+
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
98+
elif torch_device in ["cuda"]:
99+
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
100+
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
101+
else:
102+
assert abs(result_sum.item() - 119.8487548828125) < 1e-2
103+
assert abs(result_mean.item() - 0.1560530662536621) < 1e-3
104+
105+
def test_full_loop_device(self):
106+
scheduler_class = self.scheduler_classes[0]
107+
scheduler_config = self.get_scheduler_config()
108+
scheduler = scheduler_class(**scheduler_config)
109+
110+
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
111+
112+
model = self.dummy_model()
113+
sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
114+
115+
for t in scheduler.timesteps:
116+
sample = scheduler.scale_model_input(sample, t)
117+
118+
model_output = model(sample, t)
119+
120+
output = scheduler.step(model_output, t, sample)
121+
sample = output.prev_sample
122+
123+
result_sum = torch.sum(torch.abs(sample))
124+
result_mean = torch.mean(torch.abs(sample))
125+
126+
if torch_device in ["mps"]:
127+
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
128+
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
129+
elif torch_device in ["cuda"]:
130+
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
131+
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
132+
else:
133+
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
134+
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
135+
136+
def test_full_loop_device_karras_sigmas(self):
137+
scheduler_class = self.scheduler_classes[0]
138+
scheduler_config = self.get_scheduler_config()
139+
scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)
140+
141+
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
142+
143+
model = self.dummy_model()
144+
sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
145+
sample = sample.to(torch_device)
146+
147+
for t in scheduler.timesteps:
148+
sample = scheduler.scale_model_input(sample, t)
149+
150+
model_output = model(sample, t)
151+
152+
output = scheduler.step(model_output, t, sample)
153+
sample = output.prev_sample
154+
155+
result_sum = torch.sum(torch.abs(sample))
156+
result_mean = torch.mean(torch.abs(sample))
157+
158+
if torch_device in ["mps"]:
159+
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
160+
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
161+
elif torch_device in ["cuda"]:
162+
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
163+
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
164+
else:
165+
assert abs(result_sum.item() - 170.3135223388672) < 1e-2
166+
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2

0 commit comments

Comments
 (0)