
Add the unified LSTM api [WIP] #25502

Closed
wants to merge 1 commit

Conversation

guoshengCS
Contributor

PR types

New features

PR changes

APIs

Describe

Add a unified LSTM API, which switches between cuDNN and non-cuDNN implementations
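
For orientation, a minimal usage sketch assembled from the test scripts later in this thread (the constructor call mirrors those scripts; the automatic cuDNN/non-cuDNN dispatch is the behavior this PR works toward, not yet complete):

import numpy as np
import paddle
import paddle.fluid as fluid

fluid.enable_dygraph(fluid.CUDAPlace(0))

# Random time-major input: [seq_len, batch_size, input_size].
x = fluid.dygraph.to_variable(np.random.random((20, 2, 32)).astype("float32"))

lstm = paddle.nn.layer.rnn.LSTM(32, 32, num_layers=1, direction="forward",
                                time_major=True, dtype="float32")
out, state = lstm(x)  # intended to dispatch to cuDNN when it is usable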

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@guoshengCS
Contributor Author

guoshengCS commented Jul 13, 2020

The LSTM-related work consists of the following (see #25502 (comment) for the test code):

  1. Fix the cudnn lstm op

    • Dygraph support:

      • What: modify cudnn_lstm_op so that it works in dygraph mode
      • Details: remove the use of scope from cudnn_lstm_op (scope is only used in static graph mode) and move the tensors stored in CudnnRNNCache (previously created inside the scope) into the op's inputs/outputs
      • Status: mostly done, remove scope in cudnn lstm #25188
    • Fix other bugs:

      • What: bidirectional and multi-layer configurations either fail to run or produce wrong results and must be fixed; inputs with wrong shapes still produce results (see the test code below); the requirement to preset max_len should be removed
      • Details: the cause of the wrong results is still under investigation; max_len can simply be replaced with the actual length of the data
      • Status: TO BE DONE
  2. Enhance the cudnn lstm op

    • Support padded data

      • What: the cuDNN interface the op currently uses takes packed data; with padded data the last-step states of all layers cannot be obtained (if only the last layer's state is needed, it can be gathered with sequence_length — see the numpy sketch after this comment), so the enhanced op must support sequence_length together with padded data.
      • Details: either adopt PyTorch's approach and introduce pack/unpack, or, like TF, use a newer version of the cuDNN interface. The latter is probably easier to implement.
      • Status: TO BE DONE
    • Support more data types

      • What: support the double and half types
      • Details: the current code is not very extensible; the whole cudnn lstm op may need to be reimplemented
      • Status: TO BE DONE
  3. Unified LSTM API

    • Automatically switch to cuDNN

      • What: when the user's arguments allow it and cuDNN is usable (the library is present and a CUDAPlace is used), use the cuDNN implementation automatically
      • Details: the framework needs to expose execution information (whether cuDNN is usable) so the switch can happen on the Python side; alternatively, implement a CPU lstm op and switch at run time in C++. It is debatable, though, whether querying execution information is appropriate while building a static graph. The current PR only uses an incomplete check (sketched after this comment).
      • Status: TO BE DONE
    • Parameter conversion between the cuDNN and non-cuDNN implementations

      • What: cuDNN uses its own parameter layout, so if we switch to cuDNN automatically the parameters must be converted automatically as well (the cuDNN implementation uses merged parameters, while the non-cuDNN implementation stores them separately)
      • Details: there are three possible approaches:
        1. Produce new parameters on the Python side with concat and a few other ops. The original parameters are then no longer updated, so user access to them should be hidden as much as possible, and extra interfaces should be provided to sync them implicitly when needed.
        2. Convert the parameters on the C++ side and let the new and original parameters share storage; the new parameters stay hidden and users keep using the original ones. When a parameter is detected to be invalid (e.g. it points into another parameter and is no longer contiguous), reconvert at execution time and emit a warning (PyTorch's approach). Most of the work is then on the C++ side.
        3. Implement a CPU lstm op that uses the same parameter layout as cuDNN, avoiding the conversion entirely.
      • Status: the current PR mostly implements the first approach; warnings were added to some calls, and interfaces (set_parameter_values(self, values), state_dict(self)) are provided to keep the parameters correct at initialization and save time. Whether the framework permits the shared storage required by the second approach still has to be decided.
    • Work in both dygraph and static graph mode

      • Note: like other APIs, the layer must be usable for building networks in both dygraph and static graph mode
      • Details: because of the automatic parameter conversion, the conversion must also take effect in static graph mode (by adding it to the startup_program), and the extra interfaces must work in both modes.
      • Status: mostly done in the current PR. As with other APIs, special dygraph-only usages are not considered.
    • Compatibility with the various existing LSTM implementations

      • Note: concrete LSTM implementations differ slightly from each other
      • Details: the cuDNN implementation differs from the original paper's formulation
      • Status: tentatively only the implementation that matches cuDNN's computation is supported (cuDNN < 8.0; from 8.0 on cuDNN supports a richer set of variants)

Dividing the work between C++ ops and the Python API: overall, items 1 and 2 are the C++-side requirements, and I personally lean toward solving both by reimplementing the cudnn lstm op. The choice of approach in item 3 will also put some extra requirements on the C++ side.
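
Two quick sketches for the points flagged above. First, the incomplete Python-side check from item 3 amounts to roughly the following (can_use_cudnn is a hypothetical helper; only fluid.core.is_compiled_with_cuda() and fluid.CUDAPlace are real Paddle APIs):

import paddle.fluid as fluid

def can_use_cudnn(place):
    # Hypothetical Python-side switch: the cuDNN kernel is only an option
    # when Paddle was built with CUDA and the layer runs on a CUDAPlace.
    # A complete check would also have to query whether the cuDNN library
    # is actually available at execution time, which is exactly the
    # execution-information interface requested above.
    return fluid.core.is_compiled_with_cuda() and isinstance(
        place, fluid.CUDAPlace)

Second, the sequence_length workaround from item 2 for recovering the last layer's final state from padded output, in plain numpy (a sketch, not the op's code):

import numpy as np

def gather_last_states(padded_out, sequence_length):
    # padded_out: [seq_len, batch_size, hidden], time-major with padding.
    # sequence_length: [batch_size] actual length of each sequence.
    # Picks row (length - 1) of every batch column, i.e. the last valid
    # step of each sequence. This recovers only the last layer's state,
    # not the per-layer states cuDNN returns for packed input.
    batch_size = padded_out.shape[1]
    return padded_out[sequence_length - 1, np.arange(batch_size)]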

@guoshengCS
Contributor Author

guoshengCS commented Jul 13, 2020

Test code; changing the number of layers or enabling bidirection produces incorrect results:

import paddle
import paddle.fluid as fluid

from paddle.nn.layer import LSTM

import torch
import numpy as np

np.random.seed(123)
torch.manual_seed(123)
dtype = "float32"
device = torch.device('cuda:0')

x_size = seq_length, batch_size, input_size = 20, 2, 32
hidden_size = 32
gate_size = 4 * hidden_size
# changing the number of layers or enabling bidirection gives wrong results
n_layer, is_bi = 1, False
n_direct = 2 if is_bi else 1

dygraph = True

input = np.random.random(x_size).astype(dtype)
np_all_weights = []
for layer in range(n_layer):
    for direction in range(n_direct):
        layer_input_size = input_size if layer == 0 else hidden_size * n_direct
        w_ih = np.random.rand(gate_size, layer_input_size).astype(dtype)
        w_hh = np.random.rand(gate_size, hidden_size).astype(dtype)
        b_ih = np.random.rand(gate_size).astype(dtype)
        b_hh = np.random.rand(gate_size).astype(dtype)
        layer_params = [w_ih, w_hh, b_ih, b_hh]
        np_all_weights.extend(layer_params)

lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=n_layer, bidirectional=is_bi).to(device)
for i, param in enumerate(lstm.parameters()):
    param.data = torch.from_numpy(np_all_weights[i]).to(device)

x = torch.from_numpy(input).to(device)
out, state = lstm(x)
print(out.cpu().detach().numpy(), out.shape)

# Helpers to read/write parameter values in both dygraph (VarBase) and
# static graph (global scope tensor) modes.
def to_tensor(var, val):
    def set_var(var, ndarray):
        assert fluid.executor.global_scope().find_var(
            var.name) and fluid.executor.global_scope().find_var(
                var.name).get_tensor(
                ), "Please do parameter initialization."
        t = fluid.executor.global_scope().find_var(
            var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())
        t.set(ndarray, place)

    val = to_numpy(val)
    if isinstance(var, fluid.core.VarBase):
        return var.set_value(val)
    set_var(var, val)
    return var

def to_numpy(var):
    if isinstance(var, np.ndarray):
        return var
    if isinstance(var, fluid.core.VarBase):
        return var.numpy()
    assert fluid.executor.global_scope(
    ).find_var(var.name) and fluid.executor.global_scope().find_var(
        var.name).get_tensor(), "Please do parameter initialization."
    t = fluid.executor.global_scope().find_var(var.name).get_tensor()
    return np.array(t)

place = paddle.fluid.CUDAPlace(0)
if dygraph:
    paddle.fluid.enable_dygraph(place)
pd_lstm = paddle.nn.layer.rnn.LSTM(
    input_size, hidden_size, num_layers=n_layer,
    direction="bidirect" if is_bi else "forward",
    time_major=True, dtype=dtype)
if not dygraph:
    executor = paddle.fluid.Executor(place)
    executor.run(paddle.fluid.default_startup_program())

pd_lstm.set_parameter_values(np_all_weights)

for p1, p2 in zip(lstm.parameters(), pd_lstm.lstm.parameters()):
    p1 = p1.cpu().detach().numpy()
    p2 = to_numpy(p2)
    assert np.allclose(p1, p2)

x = paddle.fluid.dygraph.to_variable(input) if dygraph else fluid.data(name='x', shape=input.shape, dtype=input.dtype)
x.stop_gradient = False
out, state = pd_lstm(x)
out = out.numpy() if dygraph else executor.run(feed={'x': input}, fetch_list=[out])[0]
print(out, out.shape)

Using the static-graph API directly also produces abnormal results; e.g. an init_h with a wrong shape still yields a computed result:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers

fluid.default_main_program().random_seed = 123
fluid.default_startup_program().random_seed = 123
np.random.seed(123)

emb_dim = 20  # 256
vocab_size = 10000
data = fluid.data(name='x', shape=[None, None], dtype='int64')
emb = fluid.embedding(input=data, size=[vocab_size, emb_dim], is_sparse=True)
batch_size = 1  # 20
max_len = 20  # 10
dropout_prob = 0.0
input_size = 100
hidden_size = 20  # 150
num_layers = 1
# init_h/init_c deliberately use a wrong last dim (hidden_size - 2 instead
# of hidden_size) and a wrong first dim for a bidirectional LSTM, yet the
# op still runs and returns results.
init_h = layers.fill_constant([num_layers * 1, batch_size, hidden_size - 2], 'float32', 0.0)
init_c = layers.fill_constant([num_layers * 1, batch_size, hidden_size - 2], 'float32', 0.0)
rnn_out, last_h, last_c = layers.lstm(emb, init_h, init_c,
        max_len + 10,  # also a deliberately wrong max_len
        hidden_size, num_layers,
        dropout_prob=dropout_prob, seed=123, is_bidirec=True)
rnn_out.shape  # (-1, 100, 150)
last_h.shape  # (1, 20, 150)
last_c.shape  # (1, 20, 150)
place = fluid.CUDAPlace(0)
executor = fluid.Executor(place)
executor.run(fluid.default_startup_program())
out = executor.run(feed={'x': np.random.randint(0, 100, (max_len, batch_size)).astype('int64')}, fetch_list=[rnn_out])[0]
print(out, out.shape)
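
For reference, cuDNN expects LSTM initial states of shape [num_layers * num_directions, batch_size, hidden_size], so the snippet above should be rejected. A hypothetical sketch of the validation the op could perform (not existing Paddle behavior):

def check_init_state(shape, num_layers, is_bidirec, batch_size, hidden_size):
    # cuDNN convention for the initial h/c states of an LSTM.
    num_directions = 2 if is_bidirec else 1
    expected = (num_layers * num_directions, batch_size, hidden_size)
    assert tuple(shape) == expected, (
        "init state shape %s, expected %s" % (tuple(shape), expected))

# With the values above this rejects the deliberately wrong shape:
# check_init_state([1, 1, 18], num_layers, True, batch_size, hidden_size)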

@guoshengCS
Contributor Author

guoshengCS commented Jul 14, 2020

The dygraph and static-graph tests are merged into one script, toggled by the dygraph flag. It prints the outputs of torch.nn.LSTM and paddle.LSTM, and in static-graph mode additionally prints the output of paddle.lstm. Testing shows that bidirectional and multi-layer configurations either fail to run or produce wrong results, and wrong shapes still yield results.

import torch
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.nn.layer import LSTM
import paddle.fluid.layers as layers
fluid.default_main_program().random_seed = 123
fluid.default_startup_program().random_seed = 123

np.random.seed(123)
torch.manual_seed(123)
dtype = "float32"
device = torch.device('cuda:0')

x_size = seq_length, batch_size, input_size = 20, 2, 32
hidden_size = 32
gate_size = 4 * hidden_size
n_layer, is_bi = 2, False
n_direct = 2 if is_bi else 1

dygraph = False  # True

input = np.random.random(x_size).astype(dtype)
np_all_weights = []
for layer in range(n_layer):
    for direction in range(n_direct):
        layer_input_size = input_size if layer == 0 else hidden_size * n_direct
        w_ih = np.random.rand(gate_size, layer_input_size).astype(dtype)
        w_hh = np.random.rand(gate_size, hidden_size).astype(dtype)
        b_ih = np.random.rand(gate_size).astype(dtype)
        b_hh = np.random.rand(gate_size).astype(dtype)
        layer_params = [w_ih, w_hh, b_ih, b_hh]
        np_all_weights.extend(layer_params)
np_flat_weight = np.concatenate([p.reshape([-1]) for p in np_all_weights])

lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=n_layer, bidirectional=is_bi).to(device)
for i, param in enumerate(lstm.parameters()):
    param.data = torch.from_numpy(np_all_weights[i]).to(device)

x = torch.from_numpy(input).to(device)
out, state = lstm(x)
print(out.cpu().detach().numpy(), out.shape)

def to_tensor(var, val):
    def set_var(var, ndarray):
        assert fluid.executor.global_scope().find_var(
            var.name) and fluid.executor.global_scope().find_var(
                var.name).get_tensor(
                ), "Please do parameter initialization."
        t = fluid.executor.global_scope().find_var(
            var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())
        t.set(ndarray, place)

    val = to_numpy(val)
    if isinstance(var, fluid.core.VarBase):
        return var.set_value(val)
    set_var(var, val)
    return var

def to_numpy(var):
    if isinstance(var, np.ndarray):
        return var
    if isinstance(var, fluid.core.VarBase):
        return var.numpy()
    assert fluid.executor.global_scope(
    ).find_var(var.name) and fluid.executor.global_scope().find_var(
        var.name).get_tensor(), "Please do parameter initialization."
    t = fluid.executor.global_scope().find_var(var.name).get_tensor()
    return np.array(t)

place = paddle.fluid.CUDAPlace(0)
if dygraph:
    paddle.fluid.enable_dygraph(place)

x = paddle.fluid.dygraph.to_variable(input) if dygraph else fluid.data(name='x', shape=input.shape, dtype=input.dtype)
x.stop_gradient = False
outs = []
pd_lstm = paddle.nn.layer.rnn.LSTM(
    input_size, hidden_size, num_layers=n_layer,
    direction="bidirect" if is_bi else "forward",
    time_major=True, dtype=dtype)

if not dygraph:
    # init_h/init_c deliberately use a wrong last dim (hidden_size - 10)
    # and a wrong max_len (50), yet layers.lstm still runs.
    init_h = layers.fill_constant([n_layer * n_direct, batch_size, hidden_size - 10], dtype, 0.0)
    init_c = layers.fill_constant([n_layer * n_direct, batch_size, hidden_size - 10], dtype, 0.0)
    rnn_out, last_h, last_c = layers.lstm(x, init_h, init_c,
            50, hidden_size, n_layer,
            dropout_prob=0., seed=123, is_bidirec=is_bi)
    outs.append(rnn_out)
    executor = paddle.fluid.Executor(place)
    executor.run(paddle.fluid.default_startup_program())

all_params = fluid.default_main_program().all_parameters()
for p in all_params:
    if p.shape == np_flat_weight.shape:
        to_tensor(p, np_flat_weight)
pd_lstm.set_parameter_values(np_all_weights)

out, state = pd_lstm(x)
outs.append(out)

if not dygraph:
    outs = executor.run(feed={'x': input}, fetch_list=outs)
for out in outs:
    out = to_numpy(out)
    print(out, out.shape)

# Check that the parameters still match after the cuDNN <-> non-cuDNN
# parameter conversion.
for p1, p2 in zip(lstm.parameters(), pd_lstm.lstm.parameters()):
    p1 = p1.cpu().detach().numpy()
    p2 = to_numpy(p2)
    print(np.allclose(p1, p2))

@paddle-bot-old

Since you haven't replied for more than a year, we have closed this issue/PR.
If the problem is not solved or there is a follow-up, please reopen it at any time and we will continue to follow up.
