Skip to content

Commit

Permalink
add dropout backward
Browse files Browse the repository at this point in the history
  • Loading branch information
Yin Hongyun committed Nov 29, 2024
1 parent 4a1484c commit 6e71c12
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
23 changes: 22 additions & 1 deletion diopi_test/python/conformance/diopi_manual_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
# Copyright (c) 2023, DeepLink.
# -*- coding: UTF-8 -*-
import numpy as np
import diopilib
from diopilib import build_generator_state
from .diopi_runtime import Tensor, Generator, default_context
from . import diopi_functions as F


class ManualTest(object):

def test_dropout_backward(input, p, atol, rtol):
import torch
grad_in = Tensor(input.size().data, input.get_dtype())
torch_input = torch.from_numpy(input.numpy()).requires_grad_(False)
torch_input[torch_input==0] = 0.1
torch_input = torch_input.requires_grad_()
torch_ones = torch.ones_like(torch_input)
grad_outputs = Tensor.from_numpy(torch_ones.numpy())
out = torch.nn.functional.dropout(torch_input, p=p, training=True)
out.backward(torch_ones)
mask = Tensor.from_numpy(out.ne(0).to(torch.float32).numpy())

diopilib.diopiDropoutBackward(input.context(), grad_in, grad_outputs, mask, p)
assert np.allclose(grad_in.numpy(), torch_input.grad.numpy(), rtol=rtol, atol=atol)

def test_dropout_(func, input, p=0.5, training=True, inplace=False):
input_numpy = input.numpy()
state = build_generator_state(input.context())
Expand All @@ -30,6 +47,10 @@ def test_dropout_(func, input, p=0.5, training=True, inplace=False):
ref = input_numpy[mask_numpy == 1]
assert np.allclose(remains, ref / (1 - p), rtol=rtol, atol=atol), \
f"failed to execute {name}, dropout value doesn't matches."

if name == 'dropout':
ManualTest.test_dropout_backward(input, p, atol, rtol)

if mask.numel() > 100:
# 0.05 is from pytorch
assert np.abs(real_ratio - (1 - p)) < 0.05, \
Expand All @@ -43,7 +64,7 @@ def test_dropout(input, p=0.5, training=True, inplace=False):

def test_dropout2d(input, p=0.5, training=True, inplace=False):
ManualTest.test_dropout_(F.dropout2d, input, p, training, inplace)

def test_randperm(n):
state = build_generator_state(default_context)
generator = Generator(state)
Expand Down
14 changes: 14 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2496,6 +2496,20 @@ diopiError_t diopiDropoutInp(diopiContextHandle_t ctx, diopiTensorHandle_t input
return diopiSuccess;
}

diopiError_t diopiDropoutBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output,
diopiTensorHandle_t mask, double p){
impl::aten::setCurStream(ctx);
auto atGradInput = impl::aten::buildATen(grad_input);
auto atGradOutput = impl::aten::buildATen(grad_output);
auto atMask = impl::aten::buildATen(mask);

atMask.mul_(atGradOutput);
atMask.div_(1 - p);
impl::aten::updateATen2Tensor(ctx, atMask, grad_input);

return diopiSuccess;
}

diopiError_t diopiMSELoss(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t target,
diopiReduction_t reduction) {
impl::aten::setCurStream(ctx);
Expand Down
6 changes: 3 additions & 3 deletions proto/include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -721,10 +721,10 @@ DIOPI_API diopiError_t diopiDropout(diopiContextHandle_t ctx, diopiTensorHandle_
bool train, diopiGeneratorHandle_t generator);

/**
*@brief Randomly zeroes some of the elements of the input tensor with probability p
*@brief Compute the backward pass of diopiDropout().
*/
DIOPI_API diopiError_t diopiDropout(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t mask, diopiConstTensorHandle_t input, double p,
bool train, diopiGeneratorHandle_t generator);
DIOPI_API diopiError_t diopiDropoutBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output,
diopiTensorHandle_t mask, double p);

/**
* @brief The in-place version of diopiDropout().
Expand Down

0 comments on commit 6e71c12

Please sign in to comment.