Commit 65ced1f: fix alignment bug (#39747)
sneaxiy committed Feb 21, 2022 (1 parent: 496aadf)

Showing 2 changed files with 11 additions and 17 deletions.
paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu (26 changes: 10 additions & 16 deletions)
@@ -63,19 +63,6 @@ static size_t GetAlignSize(size_t n, size_t alignment) {
   return remainder == 0 ? n : n + alignment - remainder;
 }
 
-// gcd(x, y) = gcd(y, x % y)
-// gcd(x, 0) = x
-static size_t GCD(size_t x, size_t y) {
-  while (y > 0) {
-    auto tmp = x;
-    x = y;
-    y = tmp % y;
-  }
-  return x;
-}
-
-static size_t LCM(size_t x, size_t y) { return x / GCD(x, y) * y; }
-
 // Shard the ParamGradInfo list by the numel size [start_size, end_size)
 // The final results should be:
 //
@@ -155,11 +142,18 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
 
   size_t total_numel_sum_with_padding = 0;
   size_t n = infos->size();
-  auto lcm = LCM(alignment, nranks);
   for (size_t i = 0; i < n; ++i) {
     auto &info = (*infos)[i];
-    size_t numel_with_padding =
-        GetAlignSize(info.numel, i + 1 == n ? lcm : alignment);
+    size_t numel_with_padding;
+    if (i + 1 == n) {
+      // the total fused numel must be a multiple of alignment * nranks
+      numel_with_padding =
+          GetAlignSize(info.numel + total_numel_sum_with_padding,
+                       alignment * nranks) -
+          total_numel_sum_with_padding;
+    } else {
+      numel_with_padding = GetAlignSize(info.numel, alignment);
+    }
     info.numel_with_padding = numel_with_padding;
     info.numel_offset = total_numel_sum_with_padding;
     total_numel_sum_with_padding += numel_with_padding;
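Why this change matters, as an illustrative sketch (not code from the repository; align_up mirrors GetAlignSize, and alignment = 4, nranks = 3, and the tensor sizes are made-up values): the old code padded only the last tensor's own numel up to LCM(alignment, nranks), so the fused total could still fail to split evenly across ranks; the fix pads the last tensor so that the running total becomes a multiple of alignment * nranks.

from math import gcd

def align_up(n, a):
    # mirrors GetAlignSize: round n up to the next multiple of a
    r = n % a
    return n if r == 0 else n + a - r

alignment, nranks = 4, 3
numels = [4, 10]  # hypothetical parameter sizes

# Old scheme: align only the last numel to LCM(alignment, nranks).
lcm = alignment * nranks // gcd(alignment, nranks)  # 12
old_total = align_up(numels[0], alignment) + align_up(numels[1], lcm)
assert old_total == 16 and old_total % nranks != 0  # 16 elements cannot shard over 3 ranks

# Fixed scheme: pad the last numel so the TOTAL is a multiple of alignment * nranks.
running = align_up(numels[0], alignment)  # numel_offset of the last tensor
new_total = align_up(numels[1] + running, alignment * nranks)
assert new_total == 24 and (new_total // nranks) % alignment == 0  # aligned 8-element shards

Requiring a multiple of alignment * nranks, rather than of the LCM, also guarantees that the per-rank shard size total / nranks stays aligned for any nranks: with alignment = 4 and nranks = 6, an LCM-aligned total of 12 would give each rank an unaligned 2-element shard.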
Second changed file (Python unit test; 1 addition & 1 deletion):
@@ -72,7 +72,7 @@ class TestDistributedFusedLambWithClip(unittest.TestCase):
     def test_1(self):
         run_test(clip_after_allreduce=True, max_global_norm=0.01)
 
-    def _test_2(self):
+    def test_2(self):
         run_test(clip_after_allreduce=False, max_global_norm=0.01)
 


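Side note on the test change (editorial context, not part of the commit): Python's unittest loader only collects methods whose names start with "test", so the leading underscore in _test_2 had effectively disabled that case, and the rename re-enables it. A minimal sketch of the discovery rule:

import unittest

class Demo(unittest.TestCase):
    def test_collected(self):
        self.assertTrue(True)  # runs: the name starts with "test"

    def _test_ignored(self):
        raise RuntimeError("never executed by the default test loader")

if __name__ == "__main__":
    unittest.main()  # reports 1 test, not 2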
