@@ -373,11 +373,11 @@ Error Runner::generate(std::string prompt) {
373
373
uncond_emb_vec.data (),
374
374
{1 , 77 , 1024 },
375
375
encoder_method_meta.output_tensor_meta (0 )->scalar_type ());
376
- modules_[0 ]->set_output_data_ptr (cond_emb_tensor, 0 );
376
+ modules_[0 ]->set_output (cond_emb_tensor);
377
377
long encoder_start = util::time_in_ms ();
378
378
auto cond_res = modules_[0 ]->forward (cond_tokens_tensor);
379
379
stats_.text_encoder_execution_time += (util::time_in_ms () - encoder_start);
380
- modules_[0 ]->set_output_data_ptr (uncond_emb_tensor, 0 );
380
+ modules_[0 ]->set_output (uncond_emb_tensor);
381
381
encoder_start = util::time_in_ms ();
382
382
auto uncond_res = modules_[0 ]->forward (uncond_tokens_tensor);
383
383
stats_.text_encoder_execution_time += (util::time_in_ms () - encoder_start);
@@ -462,13 +462,13 @@ Error Runner::generate(std::string prompt) {
462
462
463
463
stats_.unet_aggregate_post_processing_time +=
464
464
(util::time_in_ms () - start_post_process);
465
- modules_[1 ]->set_output_data_ptr (noise_pred_text_tensor, 0 );
465
+ modules_[1 ]->set_output (noise_pred_text_tensor);
466
466
long start_unet_execution = util::time_in_ms ();
467
467
auto cond_res = modules_[1 ]->forward (
468
468
{latent_tensor, time_emb_tensors[step_index], cond_emb_tensor});
469
469
stats_.unet_aggregate_execution_time +=
470
470
(util::time_in_ms () - start_unet_execution);
471
- modules_[1 ]->set_output_data_ptr (noise_pred_uncond_tensor, 0 );
471
+ modules_[1 ]->set_output (noise_pred_uncond_tensor);
472
472
start_unet_execution = util::time_in_ms ();
473
473
auto uncond_res = modules_[1 ]->forward (
474
474
{latent_tensor,
@@ -519,7 +519,7 @@ Error Runner::generate(std::string prompt) {
519
519
520
520
quant_tensor (latent, vae_input, vae_input_scale_, vae_input_offset_);
521
521
522
- modules_[2 ]->set_output_data_ptr (output_tensor, 0 );
522
+ modules_[2 ]->set_output (output_tensor);
523
523
long start_vae_execution = util::time_in_ms ();
524
524
auto vae_res = modules_[2 ]->forward (vae_input_tensor);
525
525
stats_.vae_execution_time = (util::time_in_ms () - start_vae_execution);
0 commit comments