 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import gc
 import os
 import sys
@@ -162,6 +163,105 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))

+    def test_lora_expansion_works_for_absent_keys(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # Modify the config to have a layer which won't be present in the second LoRA we will load.
+        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+        modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+        pipe.transformer.add_adapter(modified_denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertFalse(
+            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
+
+            # Modify the state dict to exclude "x_embedder" related LoRA params.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+
+            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
+            pipe.set_adapters(["one", "two"])
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+            images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertFalse(
+            np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+            "Different LoRAs should lead to different results.",
+        )
+        self.assertFalse(
+            np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
+    def test_lora_expansion_works_for_extra_keys(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # Modify the config to have a layer which won't be present in the first LoRA we will load.
+        modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+        modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+        pipe.transformer.add_adapter(modified_denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        self.assertFalse(
+            np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+            self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            pipe.unload_lora_weights()
+            # Modify the state dict to exclude "x_embedder" related LoRA params.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+            pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
+
+            # Load state dict with `x_embedder`.
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
+
+            pipe.set_adapters(["one", "two"])
+            self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+            images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+        self.assertFalse(
+            np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
+            "Different LoRAs should lead to different results.",
+        )
+        self.assertFalse(
+            np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
+            "LoRA should lead to different results.",
+        )
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
|