Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type error in PolicySaver.save() #929

Open
anmol438 opened this issue Jun 15, 2024 · 1 comment
Open

Type error in PolicySaver.save() #929

anmol438 opened this issue Jun 15, 2024 · 1 comment

Comments

@anmol438
Copy link

I have a DQN agent whose policy is of type <tf_agents.policies.greedy_policy.GreedyPolicy>, which I am training on a gym environment (CartPole-v1). I am using tf_agents 0.16.0 and gym 0.23.0.
When saving the policy with tf_agents.policies.policy_saver.PolicySaver, I get the following error:
policy_saver = PolicySaver(agent.policy)
policy_saver.save('./policy')

TypeError Traceback (most recent call last)
Cell In[8], line 1
----> 1 policy_saver = PolicySaver(agent.policy)
2 policy_saver.save('./policy')

File c:\Users\iitka.conda\envs\temp\lib\site-packages\tf_agents\policies\policy_saver.py:333, in PolicySaver.__init__(self, policy, batch_size, use_nest_path_signatures, seed, train_step, input_fn_and_spec, metadata)
326 get_initial_state_fn.get_concrete_function(*get_initial_state_input_specs)
328 train_step_fn = common.function(
329 lambda: saved_policy.train_step
330 ).get_concrete_function()
331 get_metadata_fn = common.function(
332 lambda: saved_policy.metadata
--> 333 ).get_concrete_function()
335 batched_time_step_spec = tf.nest.map_structure(
336 lambda spec: add_batch_dim(spec, [batch_size]), policy.time_step_spec
337 )
338 batched_time_step_spec = cast(ts.TimeStep, batched_time_step_spec)

File c:\Users\iitka.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:1227, in Function.get_concrete_function(self, *args, **kwargs)
1225 def get_concrete_function(self, *args, **kwargs):
1226 # Implements PolymorphicFunction.get_concrete_function.
-> 1227 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1228 concrete._garbage_collector.release() # pylint: disable=protected-access
1229 return concrete

File c:\Users\iitka.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:1197, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
1195 if self._variable_creation_config is None:
1196 initializers = []
-> 1197 self._initialize(args, kwargs, add_initializers_to=initializers)
1198 self._initialize_uninitialized_variables(initializers)
1200 if self._created_variables:
1201 # In this case we have created variables on the first call, so we run the
1202 # version which is guaranteed to never create variables.

File c:\Users\iitka.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:695, in Function._initialize(self, args, kwds, add_initializers_to)
690 self._variable_creation_config = self._generate_scoped_tracing_options(
691 variable_capturing_scope,
692 tracing_compilation.ScopeType.VARIABLE_CREATION,
693 )
694 # Force the definition of the function for these arguments
--> 695 self._concrete_variable_creation_fn = tracing_compilation.trace_function(
696 args, kwds, self._variable_creation_config
697 )
699 def invalid_creator_scope(*unused_args, **unused_kwds):
700 """Disables variable creation."""

File c:\Users\iitka.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:178, in trace_function(args, kwargs, tracing_options)
175 args = tracing_options.input_signature
176 kwargs = {}
--> 178 concrete_function = _maybe_define_function(
179 args, kwargs, tracing_options
180 )
182 if not tracing_options.bind_graph_to_function:
183 concrete_function._garbage_collector.release() # pylint: disable=protected-access

File c:\Users\iitka.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:283, in _maybe_define_function(args, kwargs, tracing_options)
281 else:
282 target_func_type = lookup_func_type
--> 283 concrete_function = _create_concrete_function(
284 target_func_type, lookup_func_context, func_graph, tracing_options
285 )
287 if tracing_options.function_cache is not None:
288 tracing_options.function_cache.add(
289 concrete_function, current_func_context
290 )

File c:\Users\iitka.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:331, in _create_concrete_function(function_type, type_context, func_graph, tracing_options)
328 tracing_options.function_captures.merge_by_ref_with(graph_capture_container)
330 # Create a new FunctionType including captures and outputs.
--> 331 output_type = trace_type.from_value(
332 traced_func_graph.structured_outputs, type_context
333 )
334 traced_func_type = function_type_lib.FunctionType(
335 function_type.parameters.values(),
336 traced_func_graph.function_captures.capture_types,
337 return_annotation=output_type,
338 )
340 concrete_function = concrete_function_lib.ConcreteFunction.from_func_graph(
341 traced_func_graph,
342 traced_func_type,
(...)
348 shared_func_graph=False,
349 )

File c:\Users\iitka.conda\envs\temp\lib\site-packages\tensorflow\core\function\trace_type\trace_type_builder.py:144, in from_value(value, context)
142 if context.is_legacy_signature and isinstance(value, trace.TraceType):
143 return value
--> 144 elif isinstance(value, trace.SupportsTracingProtocol):
145 generated_type = value.__tf_tracing_type__(context)
146 if not isinstance(generated_type, trace.TraceType):

File c:\Users\iitka.conda\envs\temp\lib\site-packages\typing_extensions.py:647, in _ProtocolMeta.__instancecheck__(cls, instance)
645 for attr in cls.__protocol_attrs__:
646 try:
--> 647 val = inspect.getattr_static(instance, attr)
648 except AttributeError:
649 break

File c:\Users\iitka.conda\envs\temp\lib\inspect.py:1743, in getattr_static(obj, attr, default)
1740 dict_attr = _shadowed_dict(klass)
1741 if (dict_attr is _sentinel or
1742 type(dict_attr) is types.MemberDescriptorType):
-> 1743 instance_result = _check_instance(obj, attr)
1744 else:
1745 klass = obj

File c:\Users\iitka.conda\envs\temp\lib\inspect.py:1690, in _check_instance(obj, attr)
1688 instance_dict = {}
1689 try:
-> 1690 instance_dict = object.__getattribute__(obj, "__dict__")
1691 except AttributeError:
1692 pass

TypeError: this __dict__ descriptor does not support '_DictWrapper' objects

What are the possible reasons and how can I resolve it?

@BaLinuss
Copy link

BaLinuss commented Jul 8, 2024

Hi,
I encountered the same problem with older code that worked previously.
Rolling back to Version 0.16.0 made it work again.
Note that you also need tf@2.12.0 to make it work.

Hope this works for you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants