diff --git a/data_utils.py b/data_utils.py index cd997fa..fae525a 100644 --- a/data_utils.py +++ b/data_utils.py @@ -195,10 +195,19 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): if idx_bucket != -1: buckets[idx_bucket].append(i) - for i in range(len(buckets) - 1, 0, -1): - if len(buckets[i]) == 0: - buckets.pop(i) - self.boundaries.pop(i + 1) + try: + for i in range(len(buckets) - 1, 0, -1): + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) + assert all(len(bucket) > 0 for bucket in buckets) + # When one bucket is not traversed + except Exception as e: + print('Bucket warning ', e) + for i in range(len(buckets) - 1, -1, -1): + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i + 1) num_samples_per_bucket = [] for i in range(len(buckets)): @@ -264,4 +273,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): return -1 def __len__(self): - return self.num_samples // self.batch_size \ No newline at end of file + return self.num_samples // self.batch_size