From 18c18b70fa29e9e2b46f4d90a13a930fc84e2e68 Mon Sep 17 00:00:00 2001 From: aweally Date: Tue, 26 Jun 2018 10:13:41 +0800 Subject: [PATCH] update boltzmann_action definition Replace np.random.multinomial with np.random.choice in boltzmann_action to avoid the "ValueError: sum(pvals[:-1]) > 1.0" raised by multinomial, as described in https://github.com/endgameinc/gym-malware/issues/7 --- test_agent_kerasrl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_agent_kerasrl.py b/test_agent_kerasrl.py index 2072a61..822646c 100644 --- a/test_agent_kerasrl.py +++ b/test_agent_kerasrl.py @@ -39,7 +39,7 @@ def evaluate( action_function ): # option 1: Boltzmann sampling from Q-function network output softmax = lambda x : np.exp( x ) / np.sum( np.exp( x )) - boltzmann_action = lambda x : np.argmax( np.random.multinomial( 1, softmax(x).flatten())) + boltzmann_action = lambda x : np.random.choice( range(len(x)), p=softmax(x).flatten()) # option 2: maximize the Q value, ignoring stochastic action space best_action = lambda x : np.argmax( x ) @@ -70,4 +70,4 @@ def f(bytez): print("Success rate of random chance: {}\n".format( len(random_success) / total )) print("Success rate (dqn): {}\n".format( len(dqn_success) / total ) ) print("Success rate (dqn): {}\n".format( len(dqn_score_success) / total ) ) - \ No newline at end of file +