Skip to content

Commit

Permalink
num_samples向下取整,防止数据集的溢出 (PaddlePaddle#8691)
Browse files Browse the repository at this point in the history
  • Loading branch information
JunnYu authored and DesmonDay committed Sep 5, 2024
1 parent 6e24524 commit a55589f
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions paddlenlp/utils/batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

from __future__ import division, print_function

import math

import paddle

__all__ = ["DistributedBatchSampler"]
Expand Down Expand Up @@ -110,7 +108,7 @@ def __init__(
# In pre-training mode when using distributed dataloader, the input dataset can be None. We should handle this situation.
self.num_samples = 0
else:
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.num_samples = int(len(self.dataset) * 1.0 / self.nranks)
self.total_size = self.num_samples * self.nranks

def get_start_end_idx(self):
Expand All @@ -125,7 +123,7 @@ def __iter__(self):
self.consumed_samples,
self.nranks,
)
self.remain_num_samples = int(math.ceil((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks))
self.remain_num_samples = int((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks)
self.remain_total_size = self.remain_num_samples * self.nranks
self.batch_size_times_rank_size = self.batch_size * self.nranks

Expand Down

0 comments on commit a55589f

Please sign in to comment.