diff --git a/test_agent_kerasrl.py b/test_agent_kerasrl.py index 2072a61..822646c 100644 --- a/test_agent_kerasrl.py +++ b/test_agent_kerasrl.py @@ -39,7 +39,7 @@ def evaluate( action_function ): # option 1: Boltzmann sampling from Q-function network output softmax = lambda x : np.exp( x ) / np.sum( np.exp( x )) - boltzmann_action = lambda x : np.argmax( np.random.multinomial( 1, softmax(x).flatten())) + boltzmann_action = lambda x : np.random.choice( range(len(x)), p=softmax(x).flatten()) # option 2: maximize the Q value, ignoring stochastic action space best_action = lambda x : np.argmax( x ) @@ -70,4 +70,4 @@ def f(bytez): print("Success rate of random chance: {}\n".format( len(random_success) / total )) print("Success rate (dqn): {}\n".format( len(dqn_success) / total ) ) print("Success rate (dqn): {}\n".format( len(dqn_score_success) / total ) ) - \ No newline at end of file +