47
47
from PIL import Image
48
48
49
49
50
+ # Copied from tests.models.pvt.test_modeling_pvt
50
51
class PvtV2ConfigTester (ConfigTester ):
51
52
def run_common_tests (self ):
52
53
config = self .config_class (** self .inputs_dict )
@@ -123,13 +124,15 @@ def get_config(self):
123
124
out_indices = self .out_indices ,
124
125
)
125
126
127
+ # Copied from tests.models.pvt.test_modeling_pvt
126
128
def create_and_check_model (self , config , pixel_values , labels ):
127
129
model = PvtV2Model (config = config )
128
130
model .to (torch_device )
129
131
model .eval ()
130
132
result = model (pixel_values )
131
133
self .parent .assertIsNotNone (result .last_hidden_state )
132
134
135
+ # Copied from tests.models.resnet.test_modeling_resnet
133
136
def create_and_check_backbone (self , config , pixel_values , labels ):
134
137
model = PvtV2Backbone (config = config )
135
138
model .to (torch_device )
@@ -177,6 +180,7 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
177
180
result = model (pixel_values )
178
181
self .parent .assertEqual (result .logits .shape , (self .batch_size , self .type_sequence_label_size ))
179
182
183
+ # Copied from tests.models.pvt.test_modeling_pvt
180
184
def prepare_config_and_inputs_for_common (self ):
181
185
config_and_inputs = self .prepare_config_and_inputs ()
182
186
config , pixel_values , labels = config_and_inputs
@@ -185,6 +189,7 @@ def prepare_config_and_inputs_for_common(self):
185
189
186
190
187
191
# We will verify our results on an image of cute cats
192
+ # Copied from tests.models.pvt.test_modeling_pvt
188
193
def prepare_img ():
189
194
image = Image .open ("./tests/fixtures/tests_samples/COCO/000000039769.png" )
190
195
return image
@@ -205,21 +210,26 @@ class PvtV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
205
210
test_torchscript = False
206
211
has_attentions = False
207
212
213
+ # Copied from tests.models.pvt.test_modeling_pvt
208
214
def setUp (self ):
209
215
self .model_tester = PvtV2ModelTester (self )
210
216
self .config_tester = PvtV2ConfigTester (self , config_class = PvtV2Config )
211
217
218
+ # Copied from tests.models.pvt.test_modeling_pvt
212
219
def test_config (self ):
213
220
self .config_tester .run_common_tests ()
214
221
222
+ # Copied from tests.models.pvt.test_modeling_pvt
215
223
def test_model (self ):
216
224
config_and_inputs = self .model_tester .prepare_config_and_inputs ()
217
225
self .model_tester .create_and_check_model (* config_and_inputs )
218
226
227
+ # Copied from tests.models.pvt.test_modeling_pvt
219
228
@unittest .skip ("Pvt-V2 does not use inputs_embeds" )
220
229
def test_inputs_embeds (self ):
221
230
pass
222
231
232
+ # Copied from tests.models.pvt.test_modeling_pvt
223
233
@unittest .skip ("Pvt-V2 does not have get_input_embeddings method and get_output_embeddings methods" )
224
234
def test_model_common_attributes (self ):
225
235
pass
@@ -235,6 +245,7 @@ def test_training_gradient_checkpointing_use_reentrant(self):
235
245
# torch.utils.checkpoint.checkpoint
236
246
self .check_training_gradient_checkpointing (gradient_checkpointing_kwargs = {"use_reentrant" : True })
237
247
248
+ # Copied from tests.models.pvt.test_modeling_pvt
238
249
def test_initialization (self ):
239
250
config , inputs_dict = self .model_tester .prepare_config_and_inputs_for_common ()
240
251
@@ -282,6 +293,7 @@ def check_hidden_states_output(inputs_dict, config, model_class):
282
293
283
294
check_hidden_states_output (inputs_dict , config , model_class )
284
295
296
+ # Copied from tests.models.pvt.test_modeling_pvt
285
297
def test_training (self ):
286
298
if not self .model_tester .is_training :
287
299
return
@@ -311,6 +323,7 @@ def test_forward_signature(self):
311
323
expected_arg_names = ["pixel_values" ]
312
324
self .assertListEqual (arg_names [:1 ], expected_arg_names )
313
325
326
+ # Copied from tests.models.pvt.test_modeling_pvt
314
327
@slow
315
328
def test_model_from_pretrained (self ):
316
329
for model_name in PVT_V2_PRETRAINED_MODEL_ARCHIVE_LIST [:1 ]:
@@ -363,6 +376,7 @@ def test_inference_model(self):
363
376
364
377
self .assertTrue (torch .allclose (outputs .last_hidden_state [0 , :3 , :3 ], expected_slice , atol = 1e-4 ))
365
378
379
+ # Copied from tests.models.pvt.test_modeling_pvt
366
380
@slow
367
381
@require_accelerate
368
382
@require_torch_accelerator
0 commit comments