Skip to content

Commit

Permalink
Merge pull request apache#74 from oist/vc/temp_space
Browse files Browse the repository at this point in the history
Output temp space
  • Loading branch information
pluskid committed Apr 10, 2016
2 parents f7562e9 + acd2a74 commit 34efa24
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ function _setup_predictor(self :: FeedForward, overwrite :: Bool=false; data_sha

# the predictor use only the first device
self.pred_exec = simple_bind(self.arch, self.ctx[1]; grad_req=GRAD_NOP, data_shapes...)
dbg_str = mx.debug_str(self.pred_exec)
info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[1]))
copy_params_from(self.pred_exec, self.arg_params, self.aux_params)
else
# make sure the new setup is compatible with the existing one
Expand Down Expand Up @@ -345,6 +347,8 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra
data_shapes = [k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in provide_data(data)]
label_shapes = [k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in provide_label(data)]
train_execs[i] = simple_bind(self.arch, self.ctx[i]; grad_req=GRAD_WRITE, data_shapes..., label_shapes...)
dbg_str = mx.debug_str(train_execs[i])
info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[i]))

copy_params_from(train_execs[i], self.arg_params, self.aux_params)
end
Expand Down

0 comments on commit 34efa24

Please sign in to comment.