-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
是否有API可以获取到所有参数权重的name #26519
Comments
没有直接获取名称的api,但有获取参数的,可以先获取参数再处理下就能得到name列表。 请问您使用的是动态图还是静态图 |
@chenwhql |
方法1不可行,因为存的格式numpy不能识别 方法2可行,但得看看您的情况,看您的描述,是所有参数都保存的到一个文件里了吗?你要的是值还是变量名呢? |
谢谢你的回复: |
我们有个接口叫load_program_state,在这种情况下如果给定所有的参数名,可以拿到参数的dict,但参数名信息在model里 目前确实还不太方便,没有一个接口就能搞定的,近期会发布的2.0-beta版本里会实现用一个接口做这件事 |
谢谢。 |
目前的话分两步吧:
第1步:目前应该没有接口,只能自己写段code了,比如下面这样,当然也可以用别的方法,这段code我试了下,应该是ok的,这段code拿到了所有存储参数的name,就是你标题里的问题
第2步:用下面这个接口 这种情况下,model_path应该制定你参数名文件的路径,比如,your_save_path/params 不过这都是临时方案,后续我们会把这里完善下 |
非常感谢!我尝试一下。 |
同学,这个问题后来解决了吗 |
您好,第一步可以获取到 var_names。第二步用fluid.load_program_state这个接口,第一个参数传的是/inference/ch_det_mv3_db/params,第二个参数传的是 params。会报错。 |
噢噢,第二个接口load_program_state要求传入的var_list都是Variable类型,比较麻烦了,第二步不太可行。 目前暂时还是用你说的这个方法吧,fluid.global_scope().find_var("weight_name").get_tensor() 这个我们记个需求尽快完善下,让一个接口可以把这些事做完 |
您好,我刚才试了下 fluid.global_scope().find_var("weight_name").get_tensor() 这个接口 |
@chenwhql |
这里要load_inferenece_model之后在inference_program里面搜索 |
这里返回的是一个dict吧,前面已经拿到了name,根据名称找到参数? |
嗯嗯,是的,根据名称从这个dict里找参数 |
您好,您的问题应该解决了吧,这个可以关闭了吗 |
嗯嗯,解决了,谢谢您! |
您好,这个“在inference_program里面搜索"是要怎么做呢? |
是否有api可以获取到所有参数权重的名称,例如返回所有的 ParamAttr(name=name + '_weights') 中的name的name列表。
The text was updated successfully, but these errors were encountered: