@@ -779,8 +779,33 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
779
779
#ifdef SD_EXAMPLES_GLOVE_GUI_DESKTOP
780
780
#pragma GLOVE_APP_MSVC_NO_CONSOLE
781
781
#endif
782
+ struct RecurrentStruct {
783
+ sd_ctx_t * sd_ctx = NULL ;
784
+ GlvSDParams params;
785
+ bool model_updated (const GlvSDParams& _params) {
786
+ bool l_model_upddated = _params.get_model () != params.get_model ();
787
+ l_model_upddated |= _params.get_diffusion_model () != params.get_diffusion_model ();
788
+ l_model_upddated |= _params.get_model_addons () != params.get_model_addons ();
789
+ l_model_upddated |= _params.get_advanced_params ().get_taesd () != params.get_advanced_params ().get_taesd ();
790
+ l_model_upddated |= _params.get_advanced_params ().get_control_net () != params.get_advanced_params ().get_control_net ();
791
+ l_model_upddated |= _params.get_advanced_params ().get_embd_dir () != params.get_advanced_params ().get_embd_dir ();
792
+ l_model_upddated |= _params.get_photomaker_params ().get_stacked_id_embd_dir () != params.get_photomaker_params ().get_stacked_id_embd_dir ();
793
+ l_model_upddated |= _params.get_advanced_params ().get_vae_tiling () != params.get_advanced_params ().get_vae_tiling ();
794
+ l_model_upddated |= _params.get_advanced_params ().get_threads () != params.get_advanced_params ().get_threads ();
795
+ l_model_upddated |= _params.get_advanced_params ().get_type () != params.get_advanced_params ().get_type ();
796
+ l_model_upddated |= _params.get_rng () != params.get_rng ();
797
+ l_model_upddated |= _params.get_advanced_params ().get_schedule () != params.get_advanced_params ().get_schedule ();
798
+ l_model_upddated |= _params.get_advanced_params ().get_clip_on_cpu () != params.get_advanced_params ().get_clip_on_cpu ();
799
+ l_model_upddated |= _params.get_advanced_params ().get_control_net_cpu () != params.get_advanced_params ().get_control_net_cpu ();
800
+ l_model_upddated |= _params.get_advanced_params ().get_vae_on_cpu () != params.get_advanced_params ().get_vae_on_cpu ();
801
+ l_model_upddated |= _params.get_advanced_params ().get_diffusion_fa () != params.get_advanced_params ().get_diffusion_fa ();
802
+ return l_model_upddated;
803
+ }
804
+ };
782
805
#endif
783
806
807
+
808
+
784
809
int main (int argc, char * argv[]) {
785
810
786
811
#ifdef SD_EXAMPLES_GLOVE_GUI
@@ -797,10 +822,14 @@ int main(int argc, char* argv[]) {
797
822
GlvApp::get_progression (" Result" );
798
823
799
824
#ifdef SD_EXAMPLES_GLOVE_GUI_DESKTOP
800
- GLOVE_APP_PARAM_AUTO (GlvSDParams);
801
- #else
802
- GLOVE_APP_PARAM (GlvSDParams);
825
+ #define GLOVE_APP_AUTO true
803
826
#endif
827
+ #define GLOVE_APP_RECURRENT_MODE true
828
+ #define GLOVE_APP_RECURRENT_TYPE RecurrentStruct
829
+ RecurrentStruct glove_recurrent_var;
830
+
831
+ GLOVE_APP_PARAM (GlvSDParams);
832
+
804
833
#endif
805
834
806
835
SDParams params;
@@ -893,31 +922,53 @@ int main(int argc, char* argv[]) {
893
922
}
894
923
}
895
924
896
- sd_ctx_t * sd_ctx = new_sd_ctx (params.model_path .c_str (),
897
- params.clip_l_path .c_str (),
898
- params.clip_g_path .c_str (),
899
- params.t5xxl_path .c_str (),
900
- params.diffusion_model_path .c_str (),
901
- params.vae_path .c_str (),
902
- params.taesd_path .c_str (),
903
- params.controlnet_path .c_str (),
904
- params.lora_model_dir .c_str (),
905
- params.embeddings_path .c_str (),
906
- params.stacked_id_embeddings_path .c_str (),
907
- vae_decode_only,
908
- params.vae_tiling ,
909
- true ,
910
- params.n_threads ,
911
- params.wtype ,
912
- params.rng_type ,
913
- params.schedule ,
914
- params.clip_on_cpu ,
915
- params.control_net_cpu ,
916
- params.vae_on_cpu ,
917
- params.diffusion_flash_attn );
925
+ sd_ctx_t * sd_ctx;
926
+ #ifdef SD_EXAMPLES_GLOVE_GUI
927
+ if (!is_glove_recurrent || !glove_recurrent_var.sd_ctx || glove_recurrent_var.model_updated (glove_parametrization)) {
928
+ if (glove_recurrent_var.sd_ctx ) {
929
+ free_sd_ctx (glove_recurrent_var.sd_ctx );
930
+ }
931
+ #endif
932
+ sd_ctx = new_sd_ctx (params.model_path .c_str (),
933
+ params.clip_l_path .c_str (),
934
+ params.clip_g_path .c_str (),
935
+ params.t5xxl_path .c_str (),
936
+ params.diffusion_model_path .c_str (),
937
+ params.vae_path .c_str (),
938
+ params.taesd_path .c_str (),
939
+ params.controlnet_path .c_str (),
940
+ params.lora_model_dir .c_str (),
941
+ params.embeddings_path .c_str (),
942
+ params.stacked_id_embeddings_path .c_str (),
943
+ vae_decode_only,
944
+ params.vae_tiling ,
945
+ #ifdef SD_EXAMPLES_GLOVE_GUI
946
+ !is_glove_recurrent,
947
+ #else
948
+ true ,
949
+ #endif
950
+ params.n_threads ,
951
+ params.wtype ,
952
+ params.rng_type ,
953
+ params.schedule ,
954
+ params.clip_on_cpu ,
955
+ params.control_net_cpu ,
956
+ params.vae_on_cpu ,
957
+ params.diffusion_flash_attn );
958
+ #ifdef SD_EXAMPLES_GLOVE_GUI
959
+ if (is_glove_recurrent) {
960
+ glove_recurrent_var.sd_ctx = sd_ctx;
961
+ }
962
+ } else {
963
+ sd_ctx = glove_recurrent_var.sd_ctx ;
964
+ }
965
+ #endif
918
966
919
967
if (sd_ctx == NULL ) {
920
968
printf (" new_sd_ctx_t failed\n " );
969
+ #ifdef SD_EXAMPLES_GLOVE_GUI
970
+ GlvApp::show (SlvStatus (SlvStatus::statusType::warning, " new_sd_ctx_t failed" ), true );
971
+ #endif
921
972
return 1 ;
922
973
}
923
974
@@ -992,7 +1043,12 @@ int main(int argc, char* argv[]) {
992
1043
params.seed );
993
1044
if (results == NULL ) {
994
1045
printf (" generate failed\n " );
995
- free_sd_ctx (sd_ctx);
1046
+ #ifdef SD_EXAMPLES_GLOVE_GUI
1047
+ GlvApp::show (SlvStatus (SlvStatus::statusType::warning, " generate failed" ), true );
1048
+ if (!is_glove_recurrent)
1049
+ #endif
1050
+ free_sd_ctx (sd_ctx);
1051
+
996
1052
return 1 ;
997
1053
}
998
1054
size_t last = params.output_path .find_last_of (" ." );
@@ -1009,7 +1065,10 @@ int main(int argc, char* argv[]) {
1009
1065
results[i].data = NULL ;
1010
1066
}
1011
1067
free (results);
1012
- free_sd_ctx (sd_ctx);
1068
+ #ifdef SD_EXAMPLES_GLOVE_GUI
1069
+ if (!is_glove_recurrent)
1070
+ #endif
1071
+ free_sd_ctx (sd_ctx);
1013
1072
return 0 ;
1014
1073
} else {
1015
1074
results = img2img (sd_ctx,
@@ -1041,7 +1100,12 @@ int main(int argc, char* argv[]) {
1041
1100
1042
1101
if (results == NULL ) {
1043
1102
printf (" generate failed\n " );
1044
- free_sd_ctx (sd_ctx);
1103
+ #ifdef SD_EXAMPLES_GLOVE_GUI
1104
+ GlvApp::show (SlvStatus (SlvStatus::statusType::warning, " generate failed" ), true );
1105
+ if (!is_glove_recurrent)
1106
+ #endif
1107
+ free_sd_ctx (sd_ctx);
1108
+
1045
1109
return 1 ;
1046
1110
}
1047
1111
@@ -1074,10 +1138,14 @@ int main(int argc, char* argv[]) {
1074
1138
1075
1139
size_t last = params.output_path .find_last_of (" ." );
1076
1140
std::string dummy_name = last != std::string::npos ? params.output_path .substr (0 , last) : params.output_path ;
1141
+ #ifdef SD_EXAMPLES_GLOVE_GUI
1077
1142
GlvApp::get_progression (" Result" )->set_message (" Saving images" );
1078
1143
SlvProgressionQt& p = *GlvApp::get_progression (" Result" );
1079
1144
for (p = 0 ; p << params.batch_count ; p++) {
1080
1145
int i = p;
1146
+ #else
1147
+ for (int i = 0 ; i < params.batch_count ; i++) {
1148
+ #endif
1081
1149
if (results[i].data == NULL ) {
1082
1150
continue ;
1083
1151
}
@@ -1089,9 +1157,16 @@ int main(int argc, char* argv[]) {
1089
1157
results[i].data = NULL ;
1090
1158
}
1091
1159
free (results);
1092
- free_sd_ctx (sd_ctx);
1160
+ #ifdef SD_EXAMPLES_GLOVE_GUI
1161
+ if (!is_glove_recurrent)
1162
+ #endif
1163
+ free_sd_ctx (sd_ctx);
1093
1164
free (control_image_buffer);
1094
1165
free (input_image_buffer);
1095
1166
1167
+ #ifdef SD_EXAMPLES_GLOVE_GUI
1168
+ glove_recurrent_var.params = glove_parametrization;
1169
+ #endif
1170
+
1096
1171
return 0 ;
1097
1172
}
0 commit comments