Skip to content

Commit

Permalink
Merge pull request apache#50 from kasiabozek/acc_callback
Browse files Browse the repository at this point in the history
accuracy in callback apache#49
  • Loading branch information
pluskid committed Dec 9, 2015
2 parents 8d3a9a0 + 49a92d1 commit 9286bcc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ end
function every_n_epoch(callback :: Function, n :: Int; call_on_0 :: Bool = false)
EpochCallback(n, call_on_0, callback)
end
function Base.call(cb :: EpochCallback, model :: Any, state :: OptimizationState)
function Base.call{T<:Real}(cb :: EpochCallback, model :: Any, state :: OptimizationState, metric :: Vector{Tuple{Base.Symbol, T}})
if state.curr_epoch == 0
if cb.call_on_0
cb.callback(model, state)
cb.callback(model, state, metric)
end
elseif state.curr_epoch % cb.frequency == 0
cb.callback(model, state)
cb.callback(model, state, metric)
end
end

Expand All @@ -136,7 +136,7 @@ end
=#
function do_checkpoint(prefix::AbstractString; frequency::Int=1, save_epoch_0=false)
mkpath(dirname(prefix))
every_n_epoch(frequency, call_on_0=save_epoch_0) do model, state
every_n_epoch(frequency, call_on_0=save_epoch_0) do model, state, metric
save_checkpoint(model, prefix, state)
end
end
13 changes: 7 additions & 6 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,14 @@ end
callbacks :: Vector{AbstractCallback} = AbstractCallback[],
)

function _invoke_callbacks(self::FeedForward, callbacks::Vector{AbstractCallback},
state::OptimizationState, type_filter::Type)
function _invoke_callbacks{T<:Real}(self::FeedForward, callbacks::Vector{AbstractCallback},
state::OptimizationState, type_filter::Type;
metric::Vector{Tuple{Base.Symbol, T}} = Vector{Tuple{Base.Symbol, Real}}())
map(callbacks) do cb
if isa(cb, type_filter)
if type_filter == AbstractEpochCallback
# epoch callback have extra access to the model object
cb(self, state)
cb(self, state, metric)
else
cb(state)
end
Expand Down Expand Up @@ -465,9 +466,10 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra
end # end of one epoch

time_stop = time()
metric = get(opts.eval_metric)
info(format("== Epoch {1:0>3d} ==========", i_epoch))
info("## Training summary")
for (name, value) in get(opts.eval_metric)
for (name, value) in metric
info(format("{1:>18s} = {2:.4f}", string(name), value))
end
info(format("{1:>18s} = {2:.4f} seconds", "time", time_stop-time_start))
Expand Down Expand Up @@ -514,7 +516,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra
copy!(self.aux_params[name], aux_avg)
end
end
_invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback)
_invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback; metric=metric)
end # end of all epochs
end

Expand Down Expand Up @@ -573,4 +575,3 @@ function load_checkpoint(self :: FeedForward, prefix :: AbstractString, epoch ::
self.aux_params = aux_params
return self
end

0 comments on commit 9286bcc

Please sign in to comment.