-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enrich ConvShift to support sequence data input #2133
Changes from 1 commit
6adf4ac
8cd2222
25cdee6
167ceb5
cec53b9
aa4ac87
a4e5e66
490e35d
eeccac1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,6 +52,9 @@ class ConvShiftLayer : public Layer { | |
|
||
void forward(PassType passType) override; | ||
void backward(const UpdateCallback& callback = nullptr) override; | ||
bool isSeqType(); | ||
void circularConvSeq(); | ||
void circularConvSeqDerivative(); | ||
}; | ||
|
||
REGISTER_LAYER(conv_shift, ConvShiftLayer); | ||
|
@@ -66,42 +69,161 @@ bool ConvShiftLayer::init(const LayerMap& layerMap, | |
return true; | ||
} | ||
|
||
bool ConvShiftLayer::isSeqType() { | ||
const Argument& inLayer0 = getInput(0); | ||
if (nullptr == inLayer0.sequenceStartPositions) | ||
return false; | ||
else | ||
return true; | ||
} | ||
|
||
void ConvShiftLayer::circularConvSeq() { | ||
const Argument& inLayer0 = getInput(0); | ||
MatrixPtr in0 = inLayer0.value; | ||
MatrixPtr in1 = getInputValue(1); | ||
MatrixPtr out = getOutputValue(); | ||
const ICpuGpuVectorPtr& sequenceStartPositions = | ||
inLayer0.sequenceStartPositions; | ||
|
||
size_t width0 = in0->getWidth(); | ||
size_t numSeqs = sequenceStartPositions->getSize() - 1; | ||
size_t height0 = in0->getHeight(); | ||
size_t width1 = in1->getWidth(); | ||
size_t height1 = in1->getHeight(); | ||
|
||
CHECK_EQ(numSeqs, height1); | ||
CHECK_EQ(width0, out->getWidth()); | ||
CHECK_EQ(height0, out->getHeight()); | ||
|
||
CHECK_EQ(width1 % 2, 1U); | ||
|
||
real* inV0 = in0->getData(); | ||
const int* startPosIntPtr = sequenceStartPositions->getData(false); | ||
real* inV1 = in1->getData(); | ||
real* outV = out->getData(); | ||
|
||
int leftCtxLen = (width1 - 1) / 2; | ||
for (size_t x = 0; x < numSeqs - 1; x++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for循环最后的x++统一改成++x,下同 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
int curSeqLen = startPosIntPtr[x + 1]; | ||
size_t curSeqWidth = curSeqLen * width0; | ||
for (size_t i = 0; i < curSeqWidth; i++) { | ||
for (size_t j = 0; j < width1; ++j) { | ||
int index = i + j - leftCtxLen; | ||
index = (index + curSeqWidth) % curSeqWidth; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 111行和112行可以写在一块。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
int outVRowOffset = i / width0; | ||
int outVColOffset = i % width0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 113和114行只和i有关,应该放在109行的for循环里面,不要放在110行的for循环里面,这样会多算很多次。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
int inV0RowOffset = index / width0; | ||
int inV0ColOffset = index % width0; | ||
(outV + outVRowOffset)[outVColOffset] += | ||
(inV0 + inV0RowOffset)[inV0ColOffset] * inV1[j]; | ||
} | ||
} | ||
outV += curSeqWidth; | ||
inV0 += curSeqWidth; | ||
inV1 += width1; | ||
} | ||
} | ||
|
||
void ConvShiftLayer::circularConvSeqDerivative() { | ||
const Argument& inLayer0 = getInput(0); | ||
MatrixPtr in0 = inLayer0.value; | ||
MatrixPtr in1 = getInputValue(1); | ||
MatrixPtr inG0 = getInputGrad(0); | ||
MatrixPtr inG1 = getInputGrad(1); | ||
MatrixPtr outG = getOutputGrad(); | ||
const ICpuGpuVectorPtr& sequenceStartPositions = | ||
inLayer0.sequenceStartPositions; | ||
|
||
size_t height0 = in0->getHeight(); | ||
size_t height1 = in1->getHeight(); | ||
size_t numSeqs = sequenceStartPositions->getSize() - 1; | ||
size_t width0 = in0->getWidth(); | ||
size_t width1 = in1->getWidth(); | ||
|
||
CHECK_EQ(height1, numSeqs); | ||
CHECK_EQ(height0, inG0->getHeight()); | ||
CHECK_EQ(width0, inG0->getWidth()); | ||
CHECK_EQ(height1, inG1->getHeight()); | ||
CHECK_EQ(width1, inG1->getWidth()); | ||
CHECK_EQ(height0, outG->getHeight()); | ||
CHECK_EQ(width0, outG->getWidth()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cheack_eq有点多,可以只保留143,148,149三行 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个check是必要的,梯度、输入等的维度必须严格一致 |
||
|
||
const int* startPosIntPtr = sequenceStartPositions->getData(false); | ||
real* outGV = outG->getData(); | ||
real* inV0 = in0->getData(); | ||
real* inV1 = in1->getData(); | ||
real* inGV0 = inG0->getData(); | ||
real* inGV1 = inG1->getData(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 命名的时候,GV都可以改成G,下同。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 延伸了Matrix函数的命名方式,GV代表真实的矩阵,感觉比较合适 |
||
|
||
int leftCtxLen = (width1 - 1) / 2; | ||
for (size_t x = 0; x < numSeqs - 1; x++) { | ||
int curSeqLen = startPosIntPtr[x + 1]; | ||
size_t curSeqWidth = curSeqLen * width0; | ||
for (size_t j = 0; j < width1; j++) { | ||
for (size_t i = 0; i < curSeqWidth; i++) { | ||
int index = i + j - leftCtxLen; | ||
index = (index + curSeqWidth) % curSeqWidth; | ||
int inGV0RowOffset = index / width0; | ||
int inGV0ColOffset = index % width0; | ||
int outGVRowOffset = i / width0; | ||
int outGVColOffset = i % width0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 修改意见同forward There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
(inGV0 + inGV0RowOffset)[inGV0ColOffset] += | ||
(outGV + outGVRowOffset)[outGVColOffset] * inV1[j]; | ||
inGV1[j] += (outGV + outGVRowOffset)[outGVColOffset] * | ||
(inGV0 + inGV0RowOffset)[inGV0ColOffset]; | ||
} | ||
} | ||
outGV += curSeqWidth; | ||
inV0 += curSeqWidth; | ||
inV1 += width1; | ||
inGV0 += curSeqWidth; | ||
inGV1 += width1; | ||
} | ||
} | ||
|
||
/**
 * Forward pass: resets the output to the shape of input 0 and fills it with
 * the circular convolution of input 0 and input 1.
 *
 * Non-sequence input goes through Matrix::circularConv and requires input 1
 * to have as many rows as input 0. Sequence input (input 0 carries
 * sequenceStartPositions) goes through circularConvSeq(), which expects one
 * row of input 1 per sequence instead.
 *
 * @param passType forwarded unchanged to Layer::forward.
 */
void ConvShiftLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr inV0 = getInputValue(0);

  size_t batchSize = inV0->getHeight();
  size_t dataDim = inV0->getWidth();

  CHECK_EQ(dataDim, getSize());

  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(batchSize, dataDim);
  }

  REGISTER_TIMER_INFO("FwConvShiftTimer", getName().c_str());
  if (!isSeqType()) {
    // The row-count check only applies here: in the sequence case input 1 has
    // one row per sequence, not one per row of input 0.
    MatrixPtr inV1 = getInputValue(1);
    CHECK_EQ(batchSize, inV1->getHeight());
    MatrixPtr outV = getOutputValue();
    outV->circularConv(*inV0, *inV1);
  } else {
    circularConvSeq();
  }
}
|
||
void ConvShiftLayer::backward(const UpdateCallback& callback) { | ||
MatrixPtr inV0 = getInputValue(0); | ||
MatrixPtr inV1 = getInputValue(1); | ||
MatrixPtr outG = getOutputGrad(); | ||
MatrixPtr inG0 = getInputGrad(0); | ||
MatrixPtr inG1 = getInputGrad(1); | ||
|
||
REGISTER_TIMER_INFO("BwConvShiftTimer", getName().c_str()); | ||
|
||
if (inG0 && inG1) { | ||
if (!(inG0 && inG1)) { | ||
CHECK(!inG0 || !inG1) << "Not supported"; | ||
} | ||
|
||
if (!isSeqType()) { | ||
MatrixPtr inV0 = getInputValue(0); | ||
MatrixPtr inV1 = getInputValue(1); | ||
MatrixPtr outG = getOutputGrad(); | ||
outG->circularConvDerivative(*outG, *inV0, *inV1, *inG0, *inG1); | ||
} else { | ||
CHECK(!inG0 || !inG1) << "Not supported"; | ||
circularConvSeqDerivative(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里修改同forward函数 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
22 ~ 42行的注释也需要对应地修改。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经移到Matrix中