Enrich ConvShift to support sequence data input #2132

Closed
pkuyym opened this issue May 15, 2017 · 1 comment

@pkuyym
Contributor

pkuyym commented May 15, 2017

The ConvShift layer implements circular convolution, a required step of location-based addressing in the Neural Turing Machine (NTM). NTM was originally designed to process sequence data, yet the current ConvShift layer only accepts non-sequence input.
The original implementation calls the circularConv and circularConvDerivative functions implemented in class Matrix:

```cpp
// Original forward pass: each row of input 0 is circularly convolved with the
// corresponding row of input 1, so both inputs must have the same height.
void ConvShiftLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr inV0 = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);

  size_t batchSize = inV0->getHeight();
  size_t dataDim = inV0->getWidth();

  CHECK_EQ(batchSize, inV1->getHeight());
  CHECK_EQ(dataDim, getSize());

  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(batchSize, dataDim);
  }

  MatrixPtr outV = getOutputValue();

  REGISTER_TIMER_INFO("FwConvShiftTimer", getName().c_str());
  outV->circularConv(*inV0, *inV1);
}

// Original backward pass: gradients for both inputs are computed together.
void ConvShiftLayer::backward(const UpdateCallback& callback) {
  MatrixPtr inV0 = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);
  MatrixPtr outG = getOutputGrad();
  MatrixPtr inG0 = getInputGrad(0);
  MatrixPtr inG1 = getInputGrad(1);

  REGISTER_TIMER_INFO("BwConvShiftTimer", getName().c_str());

  if (inG0 && inG1) {
    outG->circularConvDerivative(*outG, *inV0, *inV1, *inG0, *inG1);
  } else {
    // In this branch !inG0 || !inG1 always holds, so this CHECK never fires
    // and a partially requested gradient is silently skipped.
    CHECK(!inG0 || !inG1) << "Not supported";
  }
}
```
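For reference, here is a minimal standalone sketch of what a row-wise circular convolution and its gradient compute. It assumes the centered convention out[i] = Σ_j in0[(i + j − h) mod n] · in1[j], with an odd shift width m and h = (m − 1)/2, which is how NTM-style shift weightings are commonly laid out; the exact index convention inside Paddle's `Matrix::circularConv` should be checked against the source. `Dense`, `wrapIndex`, `circularConvSketch` and `circularConvDerivativeSketch` are hypothetical names for illustration, not Paddle APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major matrix standing in for Paddle's Matrix class.
struct Dense {
  std::size_t rows, cols;
  std::vector<float> data;
  Dense(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
  float* row(std::size_t r) { return data.data() + r * cols; }
  const float* row(std::size_t r) const { return data.data() + r * cols; }
};

// Wrapped index of the input element paired with output i and shift weight j.
// Assumed centered convention; Paddle's exact indexing may differ.
inline std::size_t wrapIndex(long i, long j, long half, long n) {
  return static_cast<std::size_t>(((i + j - half) % n + n) % n);
}

// Forward: out[r][i] = sum_j in0[r][(i + j - half) mod n] * in1[r][j].
void circularConvSketch(const Dense& in0, const Dense& in1, Dense& out) {
  assert(in0.rows == in1.rows && out.rows == in0.rows && out.cols == in0.cols);
  assert(in1.cols % 2 == 1);  // odd width: shifts -half .. +half
  const long n = static_cast<long>(in0.cols);
  const long m = static_cast<long>(in1.cols);
  const long half = (m - 1) / 2;
  for (std::size_t r = 0; r < in0.rows; ++r) {
    const float* a = in0.row(r);
    const float* w = in1.row(r);
    float* y = out.row(r);
    for (long i = 0; i < n; ++i) {
      float acc = 0.0f;
      for (long j = 0; j < m; ++j) acc += a[wrapIndex(i, j, half, n)] * w[j];
      y[i] = acc;
    }
  }
}

// Backward: accumulates (+=) dL/d(in0) into g0 and dL/d(in1) into g1.
void circularConvDerivativeSketch(const Dense& outG, const Dense& in0,
                                  const Dense& in1, Dense& g0, Dense& g1) {
  assert(outG.rows == in0.rows && outG.cols == in0.cols);
  assert(g0.rows == in0.rows && g0.cols == in0.cols);
  assert(g1.rows == in1.rows && g1.cols == in1.cols);
  const long n = static_cast<long>(in0.cols);
  const long m = static_cast<long>(in1.cols);
  const long half = (m - 1) / 2;
  for (std::size_t r = 0; r < in0.rows; ++r) {
    const float* gy = outG.row(r);
    const float* a = in0.row(r);
    const float* w = in1.row(r);
    float* ga = g0.row(r);
    float* gw = g1.row(r);
    for (long i = 0; i < n; ++i) {
      for (long j = 0; j < m; ++j) {
        const std::size_t k = wrapIndex(i, j, half, n);
        ga[k] += gy[i] * w[j];  // out[i] depends on in0[k] through w[j]
        gw[j] += gy[i] * a[k];  // and on w[j] through in0[k]
      }
    }
  }
}
```

Under this convention, with m = 3 the shift vector (0, 1, 0) leaves a row unchanged and (1, 0, 0) rotates it right by one position.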

First, I will check the type of the input data. For non-sequence input, the layer runs the original logic; otherwise it calls circularConvSeq and circularConvSeqDerivative.

```cpp
// Proposed forward pass: non-sequence input keeps the original row-wise
// logic; sequence input is dispatched to circularConvSeq().
void ConvShiftLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr inV0 = getInputValue(0);

  size_t batchSize = inV0->getHeight();
  size_t dataDim = inV0->getWidth();

  CHECK_EQ(dataDim, getSize());

  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(batchSize, dataDim);
  }

  REGISTER_TIMER_INFO("FwConvShiftTimer", getName().c_str());
  if (!isSeqType()) {
    // Non-sequence input: one shift vector per row, as before.
    MatrixPtr inV1 = getInputValue(1);
    CHECK_EQ(batchSize, inV1->getHeight());
    MatrixPtr outV = getOutputValue();
    outV->circularConv(*inV0, *inV1);
  } else {
    circularConvSeq();
  }
}

void ConvShiftLayer::backward(const UpdateCallback& callback) {
  MatrixPtr inG0 = getInputGrad(0);
  MatrixPtr inG1 = getInputGrad(1);

  REGISTER_TIMER_INFO("BwConvShiftTimer", getName().c_str());

  if (!(inG0 && inG1)) {
    // Gradients are computed only when both inputs need them; requesting
    // just one is not supported. Return early so that the null gradient
    // pointers are never dereferenced below.
    CHECK(!inG0 && !inG1) << "Not supported";
    return;
  }

  if (!isSeqType()) {
    MatrixPtr inV0 = getInputValue(0);
    MatrixPtr inV1 = getInputValue(1);
    MatrixPtr outG = getOutputGrad();
    outG->circularConvDerivative(*outG, *inV0, *inV1, *inG0, *inG1);
  } else {
    circularConvSeqDerivative();
  }
}
```
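The issue does not spell out what `circularConvSeq` and `circularConvSeqDerivative` compute. Since the proposed `forward` checks that input 1 has the same height as input 0 only in the non-sequence branch, one plausible reading is that input 0 holds the packed rows of all sequences while input 1 carries one shift vector per sequence. The sketch below is written under that assumption, with sequence boundaries given as a vector of start offsets; `SeqShape`, `wrapIndex` and the `*Sketch` functions are hypothetical names, and the actual semantics live in PR #2133.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical layout (an assumption, not taken from PR #2133):
//   in0:       packed rows of all sequences, row-major, `dim` columns each
//   in1:       one shift vector of odd width `shiftDim` per sequence
//   seqStarts: numSeqs + 1 row offsets into in0, first 0, last = total rows
struct SeqShape {
  std::size_t dim;
  std::size_t shiftDim;
  std::vector<std::size_t> seqStarts;
};

// Wrapped index of the input element paired with output i and shift weight j.
inline std::size_t wrapIndex(long i, long j, long half, long n) {
  return static_cast<std::size_t>(((i + j - half) % n + n) % n);
}

// Forward: every row of sequence s is convolved with row s of in1.
void circularConvSeqSketch(const std::vector<float>& in0,
                           const std::vector<float>& in1, const SeqShape& sh,
                           std::vector<float>& out) {
  assert(sh.shiftDim % 2 == 1 && !sh.seqStarts.empty());
  const std::size_t numSeqs = sh.seqStarts.size() - 1;
  assert(in0.size() == sh.seqStarts.back() * sh.dim);
  assert(in1.size() == numSeqs * sh.shiftDim);
  out.assign(in0.size(), 0.0f);
  const long n = static_cast<long>(sh.dim);
  const long m = static_cast<long>(sh.shiftDim);
  const long half = (m - 1) / 2;
  for (std::size_t s = 0; s < numSeqs; ++s) {
    const float* w = in1.data() + s * sh.shiftDim;  // shared by all steps of s
    for (std::size_t r = sh.seqStarts[s]; r < sh.seqStarts[s + 1]; ++r) {
      const float* a = in0.data() + r * sh.dim;
      float* y = out.data() + r * sh.dim;
      for (long i = 0; i < n; ++i)
        for (long j = 0; j < m; ++j) y[i] += a[wrapIndex(i, j, half, n)] * w[j];
    }
  }
}

// Backward: accumulates (+=) into g0 (shape of in0) and g1 (shape of in1).
void circularConvSeqDerivativeSketch(const std::vector<float>& outG,
                                     const std::vector<float>& in0,
                                     const std::vector<float>& in1,
                                     const SeqShape& sh, std::vector<float>& g0,
                                     std::vector<float>& g1) {
  assert(outG.size() == in0.size() && g0.size() == in0.size());
  assert(g1.size() == in1.size() && sh.shiftDim % 2 == 1);
  const std::size_t numSeqs = sh.seqStarts.size() - 1;
  const long n = static_cast<long>(sh.dim);
  const long m = static_cast<long>(sh.shiftDim);
  const long half = (m - 1) / 2;
  for (std::size_t s = 0; s < numSeqs; ++s) {
    const float* w = in1.data() + s * sh.shiftDim;
    float* gw = g1.data() + s * sh.shiftDim;  // sums over every row of s
    for (std::size_t r = sh.seqStarts[s]; r < sh.seqStarts[s + 1]; ++r) {
      const float* a = in0.data() + r * sh.dim;
      const float* gy = outG.data() + r * sh.dim;
      float* ga = g0.data() + r * sh.dim;
      for (long i = 0; i < n; ++i) {
        for (long j = 0; j < m; ++j) {
          const std::size_t k = wrapIndex(i, j, half, n);
          ga[k] += gy[i] * w[j];
          gw[j] += gy[i] * a[k];
        }
      }
    }
  }
}
```

Because every step of a sequence shares one shift vector in this reading, that vector's gradient accumulates contributions from all of the sequence's rows, which is the main difference from the non-sequence derivative.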

See the related PR #2133 for implementation details.

@typhoonzero
Contributor

Closed due to low activity; please feel free to reopen it.

heavengate pushed a commit to heavengate/Paddle that referenced this issue Aug 16, 2021