Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

是否有API可以获取到所有参数权重的name #26519

Closed
supyer9 opened this issue Aug 21, 2020 · 19 comments
Closed

是否有API可以获取到所有参数权重的name #26519

supyer9 opened this issue Aug 21, 2020 · 19 comments

Comments

@supyer9
Copy link

supyer9 commented Aug 21, 2020

是否有api可以获取到所有参数权重的名称,例如返回所有的 ParamAttr(name=name + '_weights') 中的name的name列表。

@chenwhql
Copy link
Contributor

没有直接获取名称的api,但有获取参数的,可以先获取参数再处理下就能得到name列表。

请问您使用的是动态图还是静态图

@supyer9
Copy link
Author

supyer9 commented Aug 21, 2020

@chenwhql
是静态图。我想把我的问题描述的再具体一些:
1.我有一个paddle训练好的模型,其中包括model和params两个文件,能否单独的用numpy读取出params中的内容。得到权重的名称和值。
2.如果上述方法不可行的话。我知道用fluid.global_scope().find_var("weight_name").get_tensor()可以获取到某一个权重具体的值,那么是否存在调用一个函数,能够直接获取到网络所有权重的值呢?

@chenwhql
Copy link
Contributor

方法1不可行,因为存的格式numpy不能识别

方法2可行,但得看看您的情况,看您的描述,是所有参数都保存的到一个文件里了吗?你要的是值还是变量名呢?

@supyer9
Copy link
Author

supyer9 commented Aug 21, 2020

方法1不可行,因为存的格式numpy不能识别

方法2可行,但得看看您的情况,看您的描述,是所有参数都保存的到一个文件里了吗?你要的是值还是变量名呢?

谢谢你的回复:
方法1,如果不可行,那有没有什么办法将params中的内容读取出来,进行一系列的转化,将其内容打印到屏幕上,之后保存成其他格式呢?
方法2,目前来看,所有的参数都保存到了params里,我想拿到参数值

@chenwhql
Copy link
Contributor

我们有个接口叫load_program_state,在这种情况下如果给定所有的参数名,可以拿到参数的dict,但参数名信息在model里

目前确实还不太方便,没有一个接口就能搞定的,近期会发布的2.0-beta版本里会实现用一个接口做这件事

@supyer9
Copy link
Author

supyer9 commented Aug 21, 2020

我们有个接口叫load_program_state,在这种情况下如果给定所有的参数名,可以拿到参数的dict,但参数名信息在model里

目前确实还不太方便,没有一个接口就能搞定的,近期会发布的2.0-beta版本里会实现用一个接口做这件事

谢谢。
还想请教下,paddle.fluid.io.load_program_state(model_path, var_list=None), 这个model_path参数,是指文件名为‘model’的这个文件的路径吗?

@chenwhql
Copy link
Contributor

chenwhql commented Aug 21, 2020

目前的话分两步吧:

  1. 拿到所有的存储参数的name
  2. 根据name列表拿到所有的参数值

第1步:目前应该没有接口,只能自己写段code了,比如下面这样,当然也可以用别的方法,这段code我试了下,应该是ok的,这段code拿到了所有存储参数的name,就是你标题里的问题

import six
import paddle
from paddle.fluid import core

def load_program_desc(model_file_path):
    # 1. parse program desc
    with open(model_file_path, "rb") as f:
        program_desc_str = f.read()

    program_desc = core.ProgramDesc(program_desc_str)
    if not core._is_program_version_supported(program_desc._version()):
        raise ValueError("Unsupported program version: %d\n" %
                         program_desc._version())

    return program_desc


def is_persistable(var_desc):
    if var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var_desc.type() == core.VarDesc.VarType.FETCH_LIST or \
            var_desc.type() == core.VarDesc.VarType.READER or \
            var_desc.type() == core.VarDesc.VarType.RAW:
        return False
    return var_desc.persistable()


def get_persistable_vars(program_desc):
    persistable_vars = []
    for i in six.moves.range(program_desc.num_blocks()):
        block = program_desc.block(i)
        persistable_vars.extend(list(filter(is_persistable, block.all_vars())))
    return persistable_vars


def get_persistable_var_names(program_desc):
    """
    Get all persistable variable names in ProgramDesc.
    """
    var_names = []
    persistable_vars = get_persistable_vars(program_desc)
    for var in persistable_vars:
        var_names.append(var.name())
    return var_names

program_desc = load_program_desc("./dy2stat_infer_model/__model__")
var_names = get_persistable_var_names(program_desc)
print(var_names)

第2步:用下面这个接口

https://www.paddlepaddle.org.cn/documentation/docs/zh/1.7/api_cn/io_cn/load_program_state_cn.html#load-program-state

这种情况下,model_path应该制定你参数名文件的路径,比如,your_save_path/params

不过这都是临时方案,后续我们会把这里完善下

@supyer9
Copy link
Author

supyer9 commented Aug 21, 2020

目前的话分两步吧:

  1. 拿到所有的存储参数的name
  2. 根据name列表拿到所有的参数值

第1步:目前应该没有接口,只能自己写段code了,比如下面这样,当然也可以用别的方法,这段code我试了下,应该是ok的,这段code拿到了所有存储参数的name,就是你标题里的问题

import six
import paddle
from paddle.fluid import core

def load_program_desc(model_file_path):
    # 1. parse program desc
    with open(model_file_path, "rb") as f:
        program_desc_str = f.read()

    program_desc = core.ProgramDesc(program_desc_str)
    if not core._is_program_version_supported(program_desc._version()):
        raise ValueError("Unsupported program version: %d\n" %
                         program_desc._version())

    return program_desc


def is_persistable(var_desc):
    if var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var_desc.type() == core.VarDesc.VarType.FETCH_LIST or \
            var_desc.type() == core.VarDesc.VarType.READER or \
            var_desc.type() == core.VarDesc.VarType.RAW:
        return False
    return var_desc.persistable()


def get_persistable_vars(program_desc):
    persistable_vars = []
    for i in six.moves.range(program_desc.num_blocks()):
        block = program_desc.block(i)
        persistable_vars.extend(list(filter(is_persistable, block.all_vars())))
    return persistable_vars


def get_persistable_var_names(program_desc):
    """
    Get all persistable variable names in ProgramDesc.
    """
    var_names = []
    persistable_vars = get_persistable_vars(program_desc)
    for var in persistable_vars:
        var_names.append(var.name())
    return var_names

program_desc = load_program_desc("./dy2stat_infer_model/__model__")
var_names = get_persistable_var_names(program_desc)
print(var_names)

第2步:用下面这个接口

https://www.paddlepaddle.org.cn/documentation/docs/zh/1.7/api_cn/io_cn/load_program_state_cn.html#load-program-state

这种情况下,model_path应该制定你参数名文件的路径,比如,your_save_path/params

不过这都是临时方案,后续我们会把这里完善下

非常感谢!我尝试一下。

@chenwhql
Copy link
Contributor

同学,这个问题后来解决了吗

@supyer9
Copy link
Author

supyer9 commented Aug 24, 2020

同学,这个问题后来解决了吗

您好,第一步可以获取到 var_names。第二步用fluid.load_program_state这个接口,第一个参数传的是/inference/ch_det_mv3_db/params,第二个参数传的是 params。会报错。
报错信息:
/inference/ch_det_mv3_db/params.pdparams not found
TypeError: value in var_list must be variable
请问下问题出在哪呢?

@chenwhql
Copy link
Contributor

噢噢,第二个接口load_program_state要求传入的var_list都是Variable类型,比较麻烦了,第二步不太可行。

目前暂时还是用你说的这个方法吧,fluid.global_scope().find_var("weight_name").get_tensor()

这个我们记个需求尽快完善下,让一个接口可以把这些事做完

@supyer9
Copy link
Author

supyer9 commented Aug 24, 2020

噢噢,第二个接口load_program_state要求传入的var_list都是Variable类型,比较麻烦了,第二步不太可行。

目前暂时还是用你说的这个方法吧,fluid.global_scope().find_var("weight_name").get_tensor()

这个我们记个需求尽快完善下,让一个接口可以把这些事做完

您好,我刚才试了下 fluid.global_scope().find_var("weight_name").get_tensor() 这个接口
会报如下错误,请问是参数传的不对吗?var_names[0]是第一步取到的参数列表中的第一个字符串
program_state = fluid.global_scope().find_var(var_names[0]).get_tensor()
AttributeError: 'NoneType' object has no attribute 'get_tensor'

@supyer9
Copy link
Author

supyer9 commented Aug 24, 2020

@chenwhql
我下载了模型的checkpoint模型(之前使用的是inference模型),用fluid.io.load_program_state接口拿到了所有的参数名称和值,vallist为none。不过目前只能拿出所有的参数,没有发现怎么只拿出特定的参数

@chenwhql
Copy link
Contributor

噢噢,第二个接口load_program_state要求传入的var_list都是Variable类型,比较麻烦了,第二步不太可行。
目前暂时还是用你说的这个方法吧,fluid.global_scope().find_var("weight_name").get_tensor()
这个我们记个需求尽快完善下,让一个接口可以把这些事做完

您好,我刚才试了下 fluid.global_scope().find_var("weight_name").get_tensor() 这个接口
会报如下错误,请问是参数传的不对吗?var_names[0]是第一步取到的参数列表中的第一个字符串
program_state = fluid.global_scope().find_var(var_names[0]).get_tensor()
AttributeError: 'NoneType' object has no attribute 'get_tensor'

这里要load_inferenece_model之后在inference_program里面搜索

@chenwhql
Copy link
Contributor

@chenwhql
我下载了模型的checkpoint模型(之前使用的是inference模型),用fluid.io.load_program_state接口拿到了所有的参数名称和值,vallist为none。不过目前只能拿出所有的参数,没有发现怎么只拿出特定的参数

这里返回的是一个dict吧,前面已经拿到了name,根据名称找到参数?

@supyer9
Copy link
Author

supyer9 commented Aug 25, 2020

@chenwhql
我下载了模型的checkpoint模型(之前使用的是inference模型),用fluid.io.load_program_state接口拿到了所有的参数名称和值,vallist为none。不过目前只能拿出所有的参数,没有发现怎么只拿出特定的参数

这里返回的是一个dict吧,前面已经拿到了name,根据名称找到参数?

嗯嗯,是的,根据名称从这个dict里找参数

@chenwhql
Copy link
Contributor

chenwhql commented Sep 1, 2020

您好,您的问题应该解决了吧,这个可以关闭了吗

@supyer9
Copy link
Author

supyer9 commented Sep 1, 2020

您好,您的问题应该解决了吧,这个可以关闭了吗

嗯嗯,解决了,谢谢您!

@chenwhql chenwhql closed this as completed Sep 1, 2020
@yhhu99
Copy link

yhhu99 commented Nov 29, 2020

噢噢,第二个接口load_program_state要求传入的var_list都是Variable类型,比较麻烦了,第二步不太可行。
目前暂时还是用你说的这个方法吧,fluid.global_scope().find_var("weight_name").get_tensor()
这个我们记个需求尽快完善下,让一个接口可以把这些事做完

您好,我刚才试了下 fluid.global_scope().find_var("weight_name").get_tensor() 这个接口
会报如下错误,请问是参数传的不对吗?var_names[0]是第一步取到的参数列表中的第一个字符串
program_state = fluid.global_scope().find_var(var_names[0]).get_tensor()
AttributeError: 'NoneType' object has no attribute 'get_tensor'

这里要load_inferenece_model之后在inference_program里面搜索

您好,这个“在inference_program里面搜索"是要怎么做呢?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants