Skip to content

Commit 1a104dc

Browse files
committed
make forward/backward pathes same ref #1363
1 parent 58fb648 commit 1a104dc

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

networks/control_net_lllite_for_train.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import torch
88
from library import sdxl_original_unet
99
from library.utils import setup_logging
10+
1011
setup_logging()
1112
import logging
13+
1214
logger = logging.getLogger(__name__)
1315

1416
# input_blocksに適用するかどうか / if True, input_blocks are not applied
@@ -103,19 +105,15 @@ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplie
103105
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
104106

105107
self.cond_image = None
106-
self.cond_emb = None
107108

108109
def set_cond_image(self, cond_image):
109110
self.cond_image = cond_image
110-
self.cond_emb = None
111111

112112
def forward(self, x):
113113
if not self.enabled:
114114
return super().forward(x)
115115

116-
if self.cond_emb is None:
117-
self.cond_emb = self.lllite_conditioning1(self.cond_image)
118-
cx = self.cond_emb
116+
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
119117

120118
# reshape / b,c,h,w -> b,h*w,c
121119
n, c, h, w = cx.shape
@@ -159,9 +157,7 @@ def forward(self, x): # , cond_image=None):
159157
if not self.enabled:
160158
return super().forward(x)
161159

162-
if self.cond_emb is None:
163-
self.cond_emb = self.lllite_conditioning1(self.cond_image)
164-
cx = self.cond_emb
160+
cx = self.lllite_conditioning1(self.cond_image)
165161

166162
cx = torch.cat([cx, self.down(x)], dim=1)
167163
cx = self.mid(cx)

0 commit comments

Comments
 (0)