
RNN变长输入设计 (RNN variable-length input design) #3011

Merged · 11 commits · Jul 28, 2017
Conversation

Superjomn (Contributor):

No description provided.

```c++
void InferShape(const std::shared_ptr<Scope>& scope) {
  CopyInSeqToOut();
  // ...
}
```
Collaborator:

The Op's unit tests need to guarantee this.

Superjomn (Contributor, Author):

OK. The plan now is to do this in the Op base class, so that the copy/propagation happens automatically.

If the current Op needs to modify the Seq information, it overrides InferShape in the subclass and changes it there.

The base class will have a unit test checking that the Seq information is copied correctly (roughly as sketched below).
@emailweixu
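A minimal, self-contained sketch of this default-copy behavior and the invariant such a unit test would check. All names here (SeqInfo, TensorStub, OpBase) are hypothetical stand-ins for illustration, not Paddle's actual API:

```c++
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for the sequence metadata discussed above.
struct SeqInfo {
  std::vector<int> start_positions;
};

// Hypothetical stand-in for a tensor carrying sequence metadata.
struct TensorStub {
  std::shared_ptr<SeqInfo> seq;
};

// Default base-class behavior: InferShape propagates the input's sequence
// info to the output unchanged. Ops that alter sequence structure override it.
struct OpBase {
  virtual ~OpBase() = default;
  virtual void InferShape(const TensorStub& in, TensorStub* out) {
    out->seq = in.seq;
  }
};

int main() {
  TensorStub in, out;
  in.seq = std::make_shared<SeqInfo>(SeqInfo{{0, 3, 7}});
  OpBase op;
  op.InferShape(in, &out);
  // The base-class unit test would assert exactly this invariant:
  assert(out.seq == in.seq);
  return 0;
}
```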

For learning over variable-length sequences, current mainstream frameworks such as TensorFlow, PyTorch, Caffe2, and MXNet all use padding, i.e. sequences of different lengths within a mini-batch are zero-padded to a fixed length for computation.

Paddle's existing `RecurrentLayerGroup` implements padding-free support for variable-length sequences; this document designs the refactored variable-length sequence support based on that module's approach.
Contributor:

It is not only RecurrentLayerGroup; the other LSTM/GRU implementations also support variable-length sequences without padding.

Since padding is a compromise frameworks make to implement variable-length sequences, users of RNN-family models naturally mind the presence of padding, hence the long discussion in PyTorch about supporting variable-length sequences without padding [3].

Since padding incurs extra memory and computation costs, TensorFlow and MXNet both use bucketing for optimization [1][2].
Contributor:
Typo: "就行" should be "进行".

-> sorted:

```
xx
xxx
xxxx
```
Contributor:

Sorted in descending order instead:

```
sorted:
xxxx
xxx
xx

x x x x
x x x
x x
```
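For illustration only, a tiny sketch of sorting a mini-batch of sequences longest-first, which (as the diagram above suggests) keeps the still-active sequences in a contiguous prefix at every time step. Nothing here is Paddle API:

```c++
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Toy stand-ins for the variable-length sequences of one mini-batch.
  std::vector<std::string> seqs = {"xx", "xxx", "xxxx"};
  // Sort longest-first, matching the diagram above.
  std::sort(seqs.begin(), seqs.end(),
            [](const std::string& a, const std::string& b) {
              return a.size() > b.size();
            });
  // At time step t, the sequences with length > t form a prefix of the batch.
  for (const auto& s : seqs) std::cout << s << "\n";  // xxxx, xxx, xx
  return 0;
}
```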
Contributor:

Lines 125-140 duplicate the content above.

`ConcatOutputs` needs to

- restore each time step's output to the original input sequence order (so that the order is not scrambled in the Infer phase)
- fold the sequences and unfold them along the batch dimension
Contributor:

What does "fold the sequences and unfold them along the batch dimension" mean?

```c++
};

std::vector<SortedSeqItem> sorted_seqs;
```
Contributor:

The sequence changes of the memory boot layer also need to be considered; it is split according to how the input sequence is split. See RecurrentGradientMachine::createMemoryFrameInfo for reference.

@Superjomn (Contributor, Author), Jul 28, 2017:

InitMemories below has the corresponding logic. @luotao1

```c++
struct SeqPos {
  int dim{1};
  std::vector<std::shared_ptr<std::vector<int>>> startPoses;
};
```
Collaborator:

I am afraid that a Tensor with some extensions is simply no longer a "tensor" at all. And it looks like these variable-length inputs are simply not tensors. So why not design a brand-new data structure for this brand-new kind of value?

For example, it seems that we can design a new data type in addition to Tensor; let's call it class Array:

```c++
typedef std::vector<std::shared_ptr<Variable>> Array;
```

and we can create a multi-dimensional Array:

```c++
Variable article;
article.GetMutable<Array>()  /* an array of paragraphs */
  ->GetMutable<Array>()      /* an array of sentences */
  ->GetMutable<Array>()      /* an array of words */
  ->GetMutable<Tensor>();    /* a tensor representing a word */
```

By simply assuming that an RNNOp consumes an Array, we can assume that its step-net knows the element type of that Array, so we don't even need to record the element type in class Array.

@wangkuiyi (Collaborator), Jul 26, 2017:

I see: the very basic element of a sequence has a fixed size, so it is reasonable for us to pack a sequence into a tensor and record the boundaries in additional data fields. So please ignore my above comment.

Superjomn (Contributor, Author):

Indeed, designing it that way is natural for RNNOp itself. But considering that RNNOp's inputs/outputs need to interact directly with the other ops in the computation network, RNNOp's inputs/outputs would then require a lot of work from the framework.

SeqPos has some special properties:

- Once SeqPos appears, it must be propagated all the way down the call chain.
  - For example, data_reader -> fc0 -> rnn0 -> fc1:
    - data_reader outputs SeqPos
    - rnn0 reads SeqPos
    - The current design turns every Tensor into a TensorWithSequence, so fc0 and even fc1 naturally pass SeqPos along, because some later op may still need it.
- SeqPos may be modified, and the modified SeqPos must be propagated along the dependencies.
  - For example, rnn0 -> max0 -> fc0 (where max takes the maximum over each sequence): every sequence is compressed to length 1, so fc0 receives a new SeqPos.

So there are really two design approaches:

1. Strong dependency analysis that treats SeqPos as an input/output of the analysis.
   - The different metadata each Op's inputs require can all be resolved through extra data such as SeqPos.
   - The most flexible approach, but it requires a very robust framework.
2. Simply stuff SeqPos into the Tensor as part of the Variable's content, and have each Op copy the input's SeqPos to its output by default, achieving propagation. This is the current design, TensorWithSequence.
   - No dependency analysis needed: copy unconditionally, modify on demand.
   - Probably the simplest approach.

@wangkuiyi

Contributor:

@wangkuiyi Paddle currently uses the second design approach:

- If an op does not modify seqPos, the output's seqPos shares the same memory address as the input's.
- If an op needs to modify seqPos, the output's seqPos gets a newly allocated block of memory to store the new seqPos, as sketched below.
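A minimal sketch of this share-or-reallocate convention. All names (SeqPos as a plain vector, Var, PassThroughOp, SeqMaxOp) are hypothetical stand-ins that only illustrate the pointer-sharing rule, not real Paddle ops:

```c++
#include <cassert>
#include <memory>
#include <vector>

using SeqPos = std::vector<int>;  // stand-in for the real SeqPos struct

struct Var {
  std::shared_ptr<SeqPos> seq_pos;
};

// An op that leaves sequence structure untouched shares the pointer.
void PassThroughOp(const Var& in, Var* out) { out->seq_pos = in.seq_pos; }

// An op that changes sequence structure (e.g. max over each sequence)
// allocates fresh storage for the new positions.
void SeqMaxOp(const Var& in, Var* out) {
  // Every sequence collapses to length 1, so the new start positions are
  // 0, 1, ..., n_seqs (with the total length stored at the end).
  int n_seqs = static_cast<int>(in.seq_pos->size()) - 1;
  auto new_pos = std::make_shared<SeqPos>();
  for (int i = 0; i <= n_seqs; ++i) new_pos->push_back(i);
  out->seq_pos = new_pos;
}

int main() {
  Var a, b, c;
  a.seq_pos = std::make_shared<SeqPos>(SeqPos{0, 3, 7, 9});  // 3 sequences
  PassThroughOp(a, &b);
  assert(b.seq_pos == a.seq_pos);  // shared: same address
  SeqMaxOp(a, &c);
  assert(c.seq_pos != a.seq_pos);  // reallocated: new address
  return 0;
}
```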


Here, startPoses can be used to store multi-dimensional sub-sequences, concretely:

- For a 1-dimensional sequence, `dim=1` and `startPoses.size() = 1`
Collaborator:

"维" (dimension) here should be changed to "层" (level), since it corresponds to how many levels of RNN the network has, right?

Superjomn (Contributor, Author):

Right, that is the correct way to put it.

```c++
struct TensorWithSequence {
  Tensor* tensor;
  std::shared_ptr<SeqPos> seq_pos;
};
```
@wangkuiyi (Collaborator), Jul 26, 2017:

My understanding is that

```c++
struct SeqPos {
  int dim{1};
  std::vector<std::shared_ptr<std::vector<int>>> startPoses;
};

struct TensorWithSequence {
  Tensor* tensor;
  std::shared_ptr<SeqPos> seq_pos;
};
```

can be simplified to

```c++
class RNNInput : public Tensor {
 public:
  size_t Elements() const { return this->dims()[0]; }
  size_t Levels() const { return start_positions_per_level_.size(); }
  const Tensor Element(size_t level, size_t element) const {
    size_t pos = (*start_positions_per_level_[level])[element];
    return this->Slice(pos, pos + 1);
  }
 private:
  std::vector<std::shared_ptr<std::vector<size_t>>> start_positions_per_level_;
};
```

Is this right?

@luotao1 (Contributor), Jul 27, 2017:

1. Elements() returns this->dims()[0], i.e. the number of all elements in the RNNInput. No problem.
2. Levels() returns how many levels of RNN there are. No problem.
3. Element takes a given element of a given RNN level, but it should be changed to the form below, otherwise every extracted Tensor has dims()[0] == 1. With the form below, the Tensor for a whole sentence can be extracted:

```c++
size_t pos_start = (*start_positions_per_level_[level])[element];
size_t pos_end = (*start_positions_per_level_[level])[element + 1];
return this->Slice(pos_start, pos_end);
```

Collaborator:

@Superjom @luotao1 An additional question: should we make RNNInput represent a whole minibatch of data?

If so, the first dimension of the tensor should indicate instances, and the second and later dimensions correspond to levels.

Contributor:

> If so, the first dimension of the tensor should indicate instances, the second and thereafter are levels.

Yes, it can be stored that way. Paddle currently stores data in exactly this minibatch format, which makes acceleration convenient.

Superjomn (Contributor, Author):

That is exactly the current idea: the first dimension stores the mini-batch information, and the levels below unpack it one by one, like a pyramid, compressed into a two-dimensional vector for storage.

Mini-batches are supported; I just walked through the format with @qingqing01.

But it is indeed a bit convoluted; it only becomes intuitive once wrapped in an interface like yours above. @wangkuiyi

```c++
/*
 * Get an element of the current level.
 */
TensorWithSequence Element(int element) const;
```
Contributor:

getElement?

```c++
/*
 * Copy another's sequence info for mutation.
 */
void CopySeqPosFrom(const TensorWithSequence &other);
```
Contributor:

Where will shareDataFrom, ShareSeqPosFrom, and CopySeqPosFrom be used?

```c++
 * 2nd level, element 1, start positions
 * ...
 * 2nd level, element n, start positions
 * ...
```
Contributor:

For this block of comments, the explanation in the appendix is clearer.

Therefore, taking an instance as the basic unit and recording the start positions is enough to recover the information of every sequence.

`seq_start_positions_` stores the elements of `level 0 ~ level L` from top to bottom; the smaller the level, the coarser the sequence granularity.
Contributor:

Personally I think the reverse is better: the smaller the level, the finer the granularity. seq_start_positions accompanies every op, and among those ops single-level sequences will clearly be used far more often than two-level ones, so storing it the other way around is more convenient: an ordinary op then only needs to take the sentence-level information.

- sentence 2 has 2 words
- paragraph 1 has 2 sentences:
  - sentence 0 has 5 words
  - sentence 1 has 3 words
Contributor:

I think:

- seq_start_positions_[0] stores: 0, 3, 7, 9, 14, 17
- seq_start_positions_[1] stores: 0, 9, 17

Reasons:

1. Considering the mini-batch case, everything is stored together.
2. Even when multi-level sequences are in use, single-level-sequence ops will outnumber multi-level ones, so "smaller level, finer granularity" suits the single-level ops.
3. A total length needs to be stored as well, hence the trailing element in each level (see the decoding sketch below).
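To make this layout concrete, a small decoding sketch using exactly the numbers above (plain standard C++, no Paddle types; the variable name seq_start_positions is only illustrative):

```c++
#include <iostream>
#include <vector>

int main() {
  // The proposed layout: level 0 = sentences, level 1 = paragraphs,
  // each level storing start positions plus the total length at the end.
  std::vector<std::vector<int>> seq_start_positions = {
      {0, 3, 7, 9, 14, 17},  // level 0: 5 sentences of lengths 3,4,2,5,3
      {0, 9, 17},            // level 1: 2 paragraphs of lengths 9 and 8
  };

  for (size_t level = 0; level < seq_start_positions.size(); ++level) {
    const auto& pos = seq_start_positions[level];
    std::cout << "level " << level << " lengths:";
    // Since the total length is stored as the last element, the length of
    // sequence i is simply pos[i + 1] - pos[i].
    for (size_t i = 0; i + 1 < pos.size(); ++i)
      std::cout << " " << (pos[i + 1] - pos[i]);
    std::cout << "\n";
  }
  return 0;
}
```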

Collaborator:

Makes sense. This way, after passing through each level of RNN, we only need to ignore the last element of the inner vector in

```c++
std::vector<std::shared_ptr<std::vector<size_t>>> start_positions_per_level_;
```

@Superjomn requested a review from hedaoyuan on Jul 28, 2017, 06:02
To support storage of `N-level` sequences, this document defines the sequence information as the following data structure:

```c++
std::shared_ptr<std::vector<std::vector<int>>> lod_start_pos_;
```
Collaborator:

We need to consider how to share sequence positions between GPU and CPU. If we have to copy them from host to device every time, it could be very slow.

Maybe we can use thrust here?

```c++
#ifdef PADDLE_ONLY_CPU
using level_t = std::vector<int>;
#else
using level_t = thrust::device_vector<int>;
#endif
```

`LODTensor` is concretely defined as follows:

```c++
class LODTensor : public Tensor {
  // ...
};
```
@reyoung (Collaborator), Jul 28, 2017:

This may just be my misunderstanding, but I am not sure that creating a new LODTensor via inheritance is a good idea.

It obviously leads to many Tensor types; at the very least SparseTensor is already another type. And what type would SparseLODTensor be? Do we resort to diamond inheritance?

```
        Tensor
          +
          |
    +--------------+
    |              |
    +              +
LODTensor   SparseTensor
    +              +
    |              |
    +--------------+
          |
          +
    SparseLODTensor
```

Collaborator:

Perhaps, like Go's approach to types, using composition instead of inheritance would be better for this type?

Sample code:

```c++
#include <iostream>
#include <memory>
#include <vector>

class Tensor {
public:
  void FunctionA() { std::cout << "FunctionA in Tensor" << std::endl; }
};

using Level = std::vector<int>;

struct LODTensor {
  Tensor tensor_;
  std::shared_ptr<std::vector<Level>> lod_;
};

struct SparseTensor {
  Tensor tensor_;
  std::vector<int> rows_;
  std::vector<int> cols_;
};

struct SparseLODTensor {
  Tensor tensor_;
  std::shared_ptr<std::vector<Level>> lod_;
  std::vector<int> rows_;
  std::vector<int> cols_;
};

template <typename T> void FunctionA(T *self) { self->tensor_.FunctionA(); }

template <> void FunctionA<Tensor>(Tensor *self) { self->FunctionA(); }

template <typename T> size_t LODLevels(T *self) {
  return self->lod_ == nullptr ? 0UL : self->lod_->size();
}

template <typename T> std::vector<Level> &LODMutableLevels(T *self) {
  if (self->lod_ == nullptr) {
    self->lod_.reset(new std::vector<Level>);
  }
  return *self->lod_;
}

int main() {
  LODTensor lod;
  FunctionA(&lod);
  LODMutableLevels(&lod).resize(10);
  std::cout << "LOD Tensor Levels" << LODLevels(&lod);
  return 0;
}
```

Collaborator:

Alternatively, if the template implementation feels too cumbersome, the following approach could be used:

```c++
#include <iostream>
#include <memory>
#include <vector>

class Tensor {
public:
  void FunctionA() { std::cout << "FunctionA in Tensor" << std::endl; }
};

using Level = std::vector<int>;

struct LODTensor {
  std::shared_ptr<Tensor> tensor_;
  std::shared_ptr<std::vector<Level>> lod_;

  LODTensor() : tensor_(new Tensor()), lod_(new std::vector<Level>) {}
  // Construct a view that shares state with another tensor.
  LODTensor(std::shared_ptr<Tensor> tensor,
            std::shared_ptr<std::vector<Level>> lod)
      : tensor_(std::move(tensor)), lod_(std::move(lod)) {}

  Tensor &tensor() { return *tensor_; }

  size_t Levels() const { return lod_->size(); }
};

struct SparseTensor {
  std::shared_ptr<Tensor> tensor_;
  std::shared_ptr<std::vector<int>> rows_;
  std::shared_ptr<std::vector<int>> cols_;

  SparseTensor()
      : tensor_(new Tensor()), rows_(new std::vector<int>),
        cols_(new std::vector<int>) {}
  // Construct a view that shares state with another tensor.
  SparseTensor(std::shared_ptr<Tensor> tensor,
               std::shared_ptr<std::vector<int>> rows,
               std::shared_ptr<std::vector<int>> cols)
      : tensor_(std::move(tensor)), rows_(std::move(rows)),
        cols_(std::move(cols)) {}

  Tensor &tensor() { return *tensor_; }
  const Tensor &tensor() const { return *tensor_; }

  size_t NNZ() const { return cols_->size(); }
};

struct SparseLODTensor {
  std::shared_ptr<Tensor> tensor_;
  std::shared_ptr<std::vector<Level>> lod_;
  std::shared_ptr<std::vector<int>> rows_;
  std::shared_ptr<std::vector<int>> cols_;

  SparseLODTensor()
      : tensor_(new Tensor), lod_(new std::vector<Level>),
        rows_(new std::vector<int>), cols_(new std::vector<int>) {}

  Tensor &tensor() { return *tensor_; }

  // Return by value: these are cheap shared-state views, and returning a
  // reference to a temporary would be undefined behavior.
  LODTensor lod_tensor() { return LODTensor{tensor_, lod_}; }

  SparseTensor sparse_tensor() { return SparseTensor{tensor_, rows_, cols_}; }
};

int main() {
  SparseLODTensor tensor;
  tensor.tensor().FunctionA();
  tensor.lod_tensor().Levels();
  tensor.sparse_tensor().NNZ();
  return 0;
}
```

@wangkuiyi (Collaborator), Jul 28, 2017:

Because the polymorphism of Variable is via templates, in particular Variable::Get<T>(), not via inheritance, it looks to me that both inheritance and embedding work from this perspective.

Also, SparseTensor differs so significantly from Tensor that they don't have to share the same base class.

I agree embedding could be a better choice, as it makes the logic cleaner:

```c++
class LODTensor {
 public:
  typedef std::vector<std::vector<int>> LOD;
 private:
  Tensor elements_;
  LOD index_;
};
```

```c++
class LODTensor : public Tensor {
 public:
  size_t Levels() const { return seq_start_positions_.size(); }
  // ...
};
```
Collaborator:

seq_start_positions_ does not exist; the code below also needs to be updated accordingly.

Here `lod_start_pos_` uses `shared_ptr` to reduce the cost of storage and copying; `LODTensor` can be regarded as an extension of `Tensor`, almost fully compatible with the original `Tensor` usage.
Collaborator:

This implementation is not "almost fully compatible" with Tensor; the semantics of inheritance is that a LODTensor is a Tensor.

@wangkuiyi merged commit ca8275d into PaddlePaddle:develop on Jul 28, 2017