@@ -29,17 +29,7 @@
 from pipeline_txt2img_xl import Txt2ImgXLPipeline
 
 
-def run_demo():
-    """Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""
-
-    args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")
-
-    prompt, negative_prompt = repeat_prompt(args)
-
-    # Recommend image size as one of those used in training (see Appendix I in https://arxiv.org/pdf/2307.01952.pdf).
-    image_height = args.height
-    image_width = args.width
-
+def load_pipelines(args, batch_size):
     # Register TensorRT plugins
     engine_type = get_engine_type(args.engine)
     if engine_type == EngineType.TRT:
@@ -49,37 +39,83 @@ def run_demo():
 
     max_batch_size = 16
     if (engine_type in [EngineType.ORT_TRT, EngineType.TRT]) and (
-        args.build_dynamic_shape or image_height > 512 or image_width > 512
+        args.build_dynamic_shape or args.height > 512 or args.width > 512
     ):
         max_batch_size = 4
 
-    batch_size = len(prompt)
     if batch_size > max_batch_size:
         raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")
 
+    # For TensorRT, performance of an engine built with dynamic shape is very sensitive to the range of image sizes.
+    # Here, we reduce the range of image sizes for TensorRT to trade off flexibility against performance.
+    # This range covers the most frequent shapes: landscape (832x1216), portrait (1216x832) and square (1024x1024).
+    min_image_size = 832 if args.engine != "ORT_CUDA" else 512
+    max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048
+
     # No VAE decoder in base when it outputs latent instead of image.
-    base_info = PipelineInfo(args.version, use_vae=False)
-    base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size)
+    base_info = PipelineInfo(
+        args.version, use_vae=args.disable_refiner, min_image_size=min_image_size, max_image_size=max_image_size
+    )
 
-    refiner_info = PipelineInfo(args.version, is_refiner=True)
-    refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size)
+    # Ideally, the optimized batch size and image size for the TRT engine should align with the user's preference,
+    # i.e. optimize the shape used most frequently. We can let the user configure it when we develop a UI plugin.
+    # In this demo, we optimize for batch size 1 and image size 1024x1024 for the SD XL dynamic engine.
+    # This is mainly for benchmarking, to simulate the case where we have no knowledge of the user's preference.
+    opt_batch_size = 1 if args.build_dynamic_batch else batch_size
+    opt_image_height = base_info.default_image_size() if args.build_dynamic_shape else args.height
+    opt_image_width = base_info.default_image_size() if args.build_dynamic_shape else args.width
+
+    base = init_pipeline(
+        Txt2ImgXLPipeline,
+        base_info,
+        engine_type,
+        args,
+        max_batch_size,
+        opt_batch_size,
+        opt_image_height,
+        opt_image_width,
+    )
+
+    refiner = None
+    if not args.disable_refiner:
+        refiner_info = PipelineInfo(
+            args.version, is_refiner=True, min_image_size=min_image_size, max_image_size=max_image_size
+        )
+        refiner = init_pipeline(
+            Img2ImgXLPipeline,
+            refiner_info,
+            engine_type,
+            args,
+            max_batch_size,
+            opt_batch_size,
+            opt_image_height,
+            opt_image_width,
+        )
 
     if engine_type == EngineType.TRT:
-        max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory())
+        max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
         _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
         base.backend.activate_engines(shared_device_memory)
-        refiner.backend.activate_engines(shared_device_memory)
+        if refiner:
+            refiner.backend.activate_engines(shared_device_memory)
 
     if engine_type == EngineType.ORT_CUDA:
         enable_vae_slicing = args.enable_vae_slicing
         if batch_size > 4 and not enable_vae_slicing:
             print("Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4.")
             enable_vae_slicing = True
         if enable_vae_slicing:
-            refiner.backend.enable_vae_slicing()
+            (refiner or base).backend.enable_vae_slicing()
+    return base, refiner
+
 
+def run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False):
+    image_height = args.height
+    image_width = args.width
+    batch_size = len(prompt)
     base.load_resources(image_height, image_width, batch_size)
-    refiner.load_resources(image_height, image_width, batch_size)
+    if refiner:
+        refiner.load_resources(image_height, image_width, batch_size)
 
     def run_base_and_refiner(warmup=False):
         images, time_base = base.run(
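Note on the shape ranges above: an engine built with dynamic shapes is compiled against a (min, opt, max) profile per dimension, and throughput tends to drop as that range widens, which is why load_pipelines narrows [min_image_size, max_image_size] for TensorRT while leaving ORT_CUDA wide open. A minimal sketch of how those choices combine into a profile; the helper and dict layout are illustrative only, not this repo's API:

    def shape_profile(engine, height, width):
        # Mirrors the ranges chosen in load_pipelines above.
        min_size = 832 if engine != "ORT_CUDA" else 512
        max_size = 1216 if engine != "ORT_CUDA" else 2048
        return {
            "height": (min_size, height, max_size),  # (min, opt, max)
            "width": (min_size, width, max_size),
        }

    # TRT keeps the range tight around landscape/portrait/square sizes:
    print(shape_profile("TRT", 1024, 1024))
    # {'height': (832, 1024, 1216), 'width': (832, 1024, 1216)}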
@@ -91,8 +127,13 @@ def run_base_and_refiner(warmup=False):
             denoising_steps=args.denoising_steps,
             guidance=args.guidance,
             seed=args.seed,
-            return_type="latent",
+            return_type="latent" if refiner else "image",
         )
+        if refiner is None:
+            return images, time_base
+
+        # Use the same seed in base and refiner.
+        seed = base.get_current_seed()
 
         images, time_refiner = refiner.run(
             prompt,
@@ -103,7 +144,7 @@ def run_base_and_refiner(warmup=False):
             warmup=warmup,
             denoising_steps=args.denoising_steps,
             guidance=args.guidance,
-            seed=args.seed,
+            seed=seed,
         )
 
         return images, time_base + time_refiner
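The handoff in run_base_and_refiner is the ensemble-of-expert-denoisers pattern: with a refiner present, the base returns a latent rather than a decoded image, and the refiner resumes denoising from that latent using the same seed. For comparison, a minimal sketch of the equivalent flow in Hugging Face diffusers, assuming the stock SDXL checkpoints; this is not the code path of this demo:

    import torch
    from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline

    base = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
    ).to("cuda")

    prompt = "starry night over Golden Gate Bridge by van gogh"
    seed = 42  # same seed for both stages, as in the demo
    latent = base(
        prompt, output_type="latent", generator=torch.Generator("cuda").manual_seed(seed)
    ).images  # skip the VAE decode in the base stage
    image = refiner(
        prompt, image=latent, generator=torch.Generator("cuda").manual_seed(seed)
    ).images[0]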
@@ -112,25 +153,104 @@ def run_base_and_refiner(warmup=False):
         # inference once to get cuda graph
         _, _ = run_base_and_refiner(warmup=True)
 
-    print("[I] Warming up ..")
+    if args.num_warmup_runs > 0:
+        print("[I] Warming up ..")
     for _ in range(args.num_warmup_runs):
         _, _ = run_base_and_refiner(warmup=True)
 
+    if is_warm_up:
+        return
+
     print("[I] Running StableDiffusion XL pipeline")
     if args.nvtx_profile:
         cudart.cudaProfilerStart()
     _, latency = run_base_and_refiner(warmup=False)
     if args.nvtx_profile:
         cudart.cudaProfilerStop()
 
-    base.teardown()
-
     print("|------------|--------------|")
     print("| {:^10} | {:>9.2f} ms |".format("e2e", latency))
     print("|------------|--------------|")
-    refiner.teardown()
+
+
+def run_demo(args):
+    """Run Stable Diffusion XL Base + Refiner together (known as ensemble of expert denoisers) to generate an image."""
+
+    prompt, negative_prompt = repeat_prompt(args)
+    batch_size = len(prompt)
+    base, refiner = load_pipelines(args, batch_size)
+    run_pipelines(args, base, refiner, prompt, negative_prompt)
+    base.teardown()
+    if refiner:
+        refiner.teardown()
+
+
+def run_dynamic_shape_demo(args):
+    """Run a demo that generates images with varied settings using the ORT CUDA provider."""
+    args.engine = "ORT_CUDA"
+    args.disable_cuda_graph = True
+    base, refiner = load_pipelines(args, 1)
+
+    prompts = [
+        "starry night over Golden Gate Bridge by van gogh",
+        "beautiful photograph of Mt. Fuji during cherry blossom",
+        "little cute gremlin sitting on a bed, cinematic",
+        "cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
+        "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
+        "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
+    ]
+
+    # batch size, height, width, scheduler, steps, prompt, seed
+    configs = [
+        (1, 832, 1216, "UniPC", 8, prompts[0], None),
+        (1, 1024, 1024, "DDIM", 24, prompts[1], None),
+        (1, 1216, 832, "UniPC", 16, prompts[2], None),
+        (1, 1344, 768, "DDIM", 24, prompts[3], None),
+        (2, 640, 1536, "UniPC", 16, prompts[4], 4312973633252712),
+        (2, 1152, 896, "DDIM", 24, prompts[5], 1964684802882906),
+    ]
+
+    # Warm up each combination of (batch size, height, width) once before serving.
+    args.prompt = ["warm up"]
+    args.num_warmup_runs = 1
+    for batch_size, height, width, _, _, _, _ in configs:
+        args.batch_size = batch_size
+        args.height = height
+        args.width = width
+        print(f"\nWarm up batch_size={batch_size}, height={height}, width={width}")
+        prompt, negative_prompt = repeat_prompt(args)
+        run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=True)
+
+    # Run the pipelines on a list of prompts.
+    args.num_warmup_runs = 0
+    for batch_size, height, width, scheduler, steps, example_prompt, seed in configs:
+        args.prompt = [example_prompt]
+        args.batch_size = batch_size
+        args.height = height
+        args.width = width
+        args.scheduler = scheduler
+        args.denoising_steps = steps
+        args.seed = seed
+        base.set_scheduler(scheduler)
+        if refiner:
+            refiner.set_scheduler(scheduler)
+        print(
+            f"\nbatch_size={batch_size}, height={height}, width={width}, scheduler={scheduler}, steps={steps}, prompt={example_prompt}, seed={seed}"
+        )
+        prompt, negative_prompt = repeat_prompt(args)
+        run_pipelines(args, base, refiner, prompt, negative_prompt, is_warm_up=False)
+
+    base.teardown()
+    if refiner:
+        refiner.teardown()
 
 
 if __name__ == "__main__":
     coloredlogs.install(fmt="%(funcName)20s: %(message)s")
-    run_demo()
+
+    args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo")
+    no_prompt = isinstance(args.prompt, list) and len(args.prompt) == 1 and not args.prompt[0]
+    if no_prompt:
+        run_dynamic_shape_demo(args)
+    else:
+        run_demo(args)
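With the dispatch above, an empty prompt routes to run_dynamic_shape_demo and any non-empty prompt to run_demo. Hypothetical invocations; the positional prompt argument is inferred from how args.prompt is used above, so check parse_arguments for the exact interface:

    # Single image with base + refiner, using whatever engine parse_arguments defaults to:
    python demo_txt2img_xl.py "starry night over Golden Gate Bridge by van gogh"

    # Empty prompt: sweep the (batch, height, width, scheduler) configs with ORT_CUDA.
    python demo_txt2img_xl.py ""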