Skip to content

Commit

Permalink
Merge branch 'Lucky4Luuk-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
milanboers committed Jan 23, 2024
2 parents 5bbaeb6 + 69f23f1 commit ee3b604
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/dqn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ where
Some(target.into())
}

+ #[allow(clippy::boxed_local)]
pub fn train_dqn(
&mut self,
- states: [[f32; STATE_SIZE]; BATCH],
+ states: Box<[[f32; STATE_SIZE]; BATCH]>,
actions: [[f32; ACTION_SIZE]; BATCH],
- next_states: [[f32; STATE_SIZE]; BATCH],
+ next_states: Box<[[f32; STATE_SIZE]; BATCH]>,
rewards: [f32; BATCH],
dones: [bool; BATCH],
) {
Expand All @@ -148,7 +149,7 @@ where

// Convert to tensors and normalize the states for better training
let states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
- self.dev.tensor(states).normalize::<Axis<1>>(0.001);
+ self.dev.tensor(*states).normalize::<Axis<1>>(0.001);

// Convert actions to tensors and get the max action for each batch
let actions: Tensor<Rank1<BATCH>, usize, _> = self.dev.tensor(actions.map(|a| {
Expand All @@ -165,7 +166,7 @@ where

// Convert to tensors and normalize the states for better training
let next_states: Tensor<Rank2<BATCH, STATE_SIZE>, f32, _> =
- self.dev.tensor(next_states).normalize::<Axis<1>>(0.001);
+ self.dev.tensor(*next_states).normalize::<Axis<1>>(0.001);

// Compute the estimated Q-value for the action
for _step in 0..20 {
Expand Down Expand Up @@ -203,9 +204,21 @@ where
) {
loop {
// Initialize batch
- let mut states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
+ let mut states: Box<[[f32; STATE_SIZE]; BATCH]> = {
+ let b = vec![0.0; STATE_SIZE].into_boxed_slice();
+ let big = unsafe { Box::from_raw(Box::into_raw(b) as *mut [f32; STATE_SIZE]) };
+
+ let b = vec![*big; BATCH].into_boxed_slice();
+ unsafe { Box::from_raw(Box::into_raw(b) as *mut [[f32; STATE_SIZE]; BATCH]) }
+ };
let mut actions: [[f32; ACTION_SIZE]; BATCH] = [[0.0; ACTION_SIZE]; BATCH];
- let mut next_states: [[f32; STATE_SIZE]; BATCH] = [[0.0; STATE_SIZE]; BATCH];
+ let mut next_states: Box<[[f32; STATE_SIZE]; BATCH]> = {
+ let b = vec![0.0; STATE_SIZE].into_boxed_slice();
+ let big = unsafe { Box::from_raw(Box::into_raw(b) as *mut [f32; STATE_SIZE]) };
+
+ let b = vec![*big; BATCH].into_boxed_slice();
+ unsafe { Box::from_raw(Box::into_raw(b) as *mut [[f32; STATE_SIZE]; BATCH]) }
+ };
let mut rewards: [f32; BATCH] = [0.0; BATCH];
let mut dones = [false; BATCH];

Expand Down

0 comments on commit ee3b604

Please sign in to comment.