You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I have a DQN agent whose policy is of type `tf_agents.policies.greedy_policy.GreedyPolicy`, trained on a Gym environment (CartPole-v1). I am using tf_agents 0.16.0 and gym 0.23.0.
Saving the policy with `tf_agents.policies.policy_saver.PolicySaver` fails with the error below. This is the code I run:
```python
policy_saver = PolicySaver(agent.policy)
policy_saver.save('./policy')
```
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:1197, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
1195 if self._variable_creation_config is None:
1196 initializers = []
-> 1197 self._initialize(args, kwargs, add_initializers_to=initializers)
1198 self._initialize_uninitialized_variables(initializers)
1200 if self._created_variables:
1201 # In this case we have created variables on the first call, so we run the
1202 # version which is guaranteed to never create variables.
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:695, in Function._initialize(self, args, kwds, add_initializers_to)
690 self._variable_creation_config = self._generate_scoped_tracing_options(
691 variable_capturing_scope,
692 tracing_compilation.ScopeType.VARIABLE_CREATION,
693 )
694 # Force the definition of the function for these arguments
--> 695 self._concrete_variable_creation_fn = tracing_compilation.trace_function(
696 args, kwds, self._variable_creation_config
697 )
699 def invalid_creator_scope(*unused_args, **unused_kwds):
700 """Disables variable creation."""
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\core\function\trace_type\trace_type_builder.py:144, in from_value(value, context)
142 if context.is_legacy_signature and isinstance(value, trace.TraceType):
143 return value
--> 144 elif isinstance(value, trace.SupportsTracingProtocol):
145 generated_type = value.__tf_tracing_type__(context)
146 if not isinstance(generated_type, trace.TraceType):
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\typing_extensions.py:647, in _ProtocolMeta.__instancecheck__(cls, instance)
645 for attr in cls.__protocol_attrs__:
646 try:
--> 647 val = inspect.getattr_static(instance, attr)
648 except AttributeError:
649 break
File c:\Users\iitka\.conda\envs\temp\lib\inspect.py:1743, in getattr_static(obj, attr, default)
1740 dict_attr = _shadowed_dict(klass)
1741 if (dict_attr is _sentinel or
1742 type(dict_attr) is types.MemberDescriptorType):
-> 1743 instance_result = _check_instance(obj, attr)
1744 else:
1745 klass = obj
Hi,
I ran into the same problem with older code that used to work.
Rolling back to version 0.16.0 (of tf_agents) made it work again.
Note that you also need TensorFlow 2.12.0 (`tensorflow==2.12.0`) for this to work.
I have a DQN agent whose policy is of type `tf_agents.policies.greedy_policy.GreedyPolicy`, trained on a Gym environment (CartPole-v1). I am using tf_agents 0.16.0 and gym 0.23.0.
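Saving the policy with `tf_agents.policies.policy_saver.PolicySaver` fails with the error below. The traceback shows it is already raised while constructing the `PolicySaver`, before `save` is called. This is the code I run:
```python
policy_saver = PolicySaver(agent.policy)
policy_saver.save('./policy')
```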
TypeError Traceback (most recent call last)
Cell In[8], line 1
----> 1 policy_saver = PolicySaver(agent.policy)
2 policy_saver.save('./policy')
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tf_agents\policies\policy_saver.py:333, in PolicySaver.__init__(self, policy, batch_size, use_nest_path_signatures, seed, train_step, input_fn_and_spec, metadata)
326 get_initial_state_fn.get_concrete_function(*get_initial_state_input_specs)
328 train_step_fn = common.function(
329 lambda: saved_policy.train_step
330 ).get_concrete_function()
331 get_metadata_fn = common.function(
332 lambda: saved_policy.metadata
--> 333 ).get_concrete_function()
335 batched_time_step_spec = tf.nest.map_structure(
336 lambda spec: add_batch_dim(spec, [batch_size]), policy.time_step_spec
337 )
338 batched_time_step_spec = cast(ts.TimeStep, batched_time_step_spec)
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:1227, in Function.get_concrete_function(self, *args, **kwargs)
1225 def get_concrete_function(self, *args, **kwargs):
1226 # Implements PolymorphicFunction.get_concrete_function.
-> 1227 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1228 concrete._garbage_collector.release() # pylint: disable=protected-access
1229 return concrete
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:1197, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
1195 if self._variable_creation_config is None:
1196 initializers = []
-> 1197 self._initialize(args, kwargs, add_initializers_to=initializers)
1198 self._initialize_uninitialized_variables(initializers)
1200 if self._created_variables:
1201 # In this case we have created variables on the first call, so we run the
1202 # version which is guaranteed to never create variables.
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\polymorphic_function.py:695, in Function._initialize(self, args, kwds, add_initializers_to)
690 self._variable_creation_config = self._generate_scoped_tracing_options(
691 variable_capturing_scope,
692 tracing_compilation.ScopeType.VARIABLE_CREATION,
693 )
694 # Force the definition of the function for these arguments
--> 695 self._concrete_variable_creation_fn = tracing_compilation.trace_function(
696 args, kwds, self._variable_creation_config
697 )
699 def invalid_creator_scope(*unused_args, **unused_kwds):
700 """Disables variable creation."""
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:178, in trace_function(args, kwargs, tracing_options)
175 args = tracing_options.input_signature
176 kwargs = {}
--> 178 concrete_function = _maybe_define_function(
179 args, kwargs, tracing_options
180 )
182 if not tracing_options.bind_graph_to_function:
183 concrete_function._garbage_collector.release() # pylint: disable=protected-access
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:283, in _maybe_define_function(args, kwargs, tracing_options)
281 else:
282 target_func_type = lookup_func_type
--> 283 concrete_function = _create_concrete_function(
284 target_func_type, lookup_func_context, func_graph, tracing_options
285 )
287 if tracing_options.function_cache is not None:
288 tracing_options.function_cache.add(
289 concrete_function, current_func_context
290 )
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\python\eager\polymorphic_function\tracing_compilation.py:331, in _create_concrete_function(function_type, type_context, func_graph, tracing_options)
328 tracing_options.function_captures.merge_by_ref_with(graph_capture_container)
330 # Create a new FunctionType including captures and outputs.
--> 331 output_type = trace_type.from_value(
332 traced_func_graph.structured_outputs, type_context
333 )
334 traced_func_type = function_type_lib.FunctionType(
335 function_type.parameters.values(),
336 traced_func_graph.function_captures.capture_types,
337 return_annotation=output_type,
338 )
340 concrete_function = concrete_function_lib.ConcreteFunction.from_func_graph(
341 traced_func_graph,
342 traced_func_type,
(...)
348 shared_func_graph=False,
349 )
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\tensorflow\core\function\trace_type\trace_type_builder.py:144, in from_value(value, context)
142 if context.is_legacy_signature and isinstance(value, trace.TraceType):
143 return value
--> 144 elif isinstance(value, trace.SupportsTracingProtocol):
145 generated_type = value.__tf_tracing_type__(context)
146 if not isinstance(generated_type, trace.TraceType):
File c:\Users\iitka\.conda\envs\temp\lib\site-packages\typing_extensions.py:647, in _ProtocolMeta.__instancecheck__(cls, instance)
645 for attr in cls.__protocol_attrs__:
646 try:
--> 647 val = inspect.getattr_static(instance, attr)
648 except AttributeError:
649 break
File c:\Users\iitka\.conda\envs\temp\lib\inspect.py:1743, in getattr_static(obj, attr, default)
1740 dict_attr = _shadowed_dict(klass)
1741 if (dict_attr is _sentinel or
1742 type(dict_attr) is types.MemberDescriptorType):
-> 1743 instance_result = _check_instance(obj, attr)
1744 else:
1745 klass = obj
File c:\Users\iitka\.conda\envs\temp\lib\inspect.py:1690, in _check_instance(obj, attr)
1688 instance_dict = {}
1689 try:
-> 1690 instance_dict = object.__getattribute__(obj, "__dict__")
1691 except AttributeError:
1692 pass
TypeError: this __dict__ descriptor does not support '_DictWrapper' objects
What are the possible reasons and how can I resolve it?
The text was updated successfully, but these errors were encountered: