Skip to content

Commit ccc60d5

Browse files
committed
fix: Use len() to get size of dataset
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent eb39f9c commit ccc60d5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

Diff for: py/trtorch/ptq.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ def get_batch_size(self):
2626

2727

2828
def get_batch(self, names):
29-
if self.current_batch_idx + self.batch_size > self.data_loader.dataset.data.shape[0]:
29+
print("Current batch idx: ", self.current_batch_idx, " Dataset size: ", len(self.data_loader.dataset))
30+
if self.current_batch_idx + self.batch_size > len(self.data_loader.dataset):
3031
return None
3132

3233
batch = self.dataset_iterator.next()

0 commit comments

Comments
 (0)