Skip to content

Commit fdb7ceb

Browse files
committed
Added "Copied from" comments in test_modeling_pvt_v2.py
1 parent a64e779 commit fdb7ceb

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tests/models/pvt_v2/test_modeling_pvt_v2.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from PIL import Image
4848

4949

50+
# Copied from tests.models.pvt.test_modeling_pvt
5051
class PvtV2ConfigTester(ConfigTester):
5152
def run_common_tests(self):
5253
config = self.config_class(**self.inputs_dict)
@@ -123,13 +124,15 @@ def get_config(self):
123124
out_indices=self.out_indices,
124125
)
125126

127+
# Copied from tests.models.pvt.test_modeling_pvt
126128
def create_and_check_model(self, config, pixel_values, labels):
127129
model = PvtV2Model(config=config)
128130
model.to(torch_device)
129131
model.eval()
130132
result = model(pixel_values)
131133
self.parent.assertIsNotNone(result.last_hidden_state)
132134

135+
# Copied from tests.models.resnet.test_modeling_resnet
133136
def create_and_check_backbone(self, config, pixel_values, labels):
134137
model = PvtV2Backbone(config=config)
135138
model.to(torch_device)
@@ -177,6 +180,7 @@ def create_and_check_for_image_classification(self, config, pixel_values, labels
177180
result = model(pixel_values)
178181
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
179182

183+
# Copied from tests.models.pvt.test_modeling_pvt
180184
def prepare_config_and_inputs_for_common(self):
181185
config_and_inputs = self.prepare_config_and_inputs()
182186
config, pixel_values, labels = config_and_inputs
@@ -185,6 +189,7 @@ def prepare_config_and_inputs_for_common(self):
185189

186190

187191
# We will verify our results on an image of cute cats
192+
# Copied from tests.models.pvt.test_modeling_pvt
188193
def prepare_img():
189194
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
190195
return image
@@ -205,21 +210,26 @@ class PvtV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
205210
test_torchscript = False
206211
has_attentions = False
207212

213+
# Copied from tests.models.pvt.test_modeling_pvt
208214
def setUp(self):
209215
self.model_tester = PvtV2ModelTester(self)
210216
self.config_tester = PvtV2ConfigTester(self, config_class=PvtV2Config)
211217

218+
# Copied from tests.models.pvt.test_modeling_pvt
212219
def test_config(self):
213220
self.config_tester.run_common_tests()
214221

222+
# Copied from tests.models.pvt.test_modeling_pvt
215223
def test_model(self):
216224
config_and_inputs = self.model_tester.prepare_config_and_inputs()
217225
self.model_tester.create_and_check_model(*config_and_inputs)
218226

227+
# Copied from tests.models.pvt.test_modeling_pvt
219228
@unittest.skip("Pvt-V2 does not use inputs_embeds")
220229
def test_inputs_embeds(self):
221230
pass
222231

232+
# Copied from tests.models.pvt.test_modeling_pvt
223233
@unittest.skip("Pvt-V2 does not have get_input_embeddings method and get_output_embeddings methods")
224234
def test_model_common_attributes(self):
225235
pass
@@ -235,6 +245,7 @@ def test_training_gradient_checkpointing_use_reentrant(self):
235245
# torch.utils.checkpoint.checkpoint
236246
self.check_training_gradient_checkpointing(gradient_checkpointing_kwargs={"use_reentrant": True})
237247

248+
# Copied from tests.models.pvt.test_modeling_pvt
238249
def test_initialization(self):
239250
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
240251

@@ -282,6 +293,7 @@ def check_hidden_states_output(inputs_dict, config, model_class):
282293

283294
check_hidden_states_output(inputs_dict, config, model_class)
284295

296+
# Copied from tests.models.pvt.test_modeling_pvt
285297
def test_training(self):
286298
if not self.model_tester.is_training:
287299
return
@@ -311,6 +323,7 @@ def test_forward_signature(self):
311323
expected_arg_names = ["pixel_values"]
312324
self.assertListEqual(arg_names[:1], expected_arg_names)
313325

326+
# Copied from tests.models.pvt.test_modeling_pvt
314327
@slow
315328
def test_model_from_pretrained(self):
316329
for model_name in PVT_V2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -363,6 +376,7 @@ def test_inference_model(self):
363376

364377
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
365378

379+
# Copied from tests.models.pvt.test_modeling_pvt
366380
@slow
367381
@require_accelerate
368382
@require_torch_accelerator

0 commit comments

Comments
 (0)