Skip to content

Commit

Permalink
Merge pull request #744 from lukeyeager/iter_size
Browse files Browse the repository at this point in the history
Expose iter_size solver option
  • Loading branch information
lukeyeager committed May 18, 2016
2 parents 0ce826d + 44be59a commit 03e7f11
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 39 deletions.
5 changes: 5 additions & 0 deletions digits/frameworks/caffe_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,8 @@ def get_network_visualization(self, desc):
net.name = 'Network'
return '<image src="data:image/png;base64,' + caffe.draw.draw_net(net, 'UD').encode('base64') + '" style="max-width:100%" />'

@override
def can_accumulate_gradients(self):
    """Report whether the configured Caffe can accumulate gradients.

    Compares the 'version' entry of the 'caffe_root' config value against
    0.14.0-alpha and answers True only when the configured version is
    strictly greater.
    """
    configured_version = config_value('caffe_root')['version']
    threshold_version = parse_version('0.14.0-alpha')
    return configured_version > threshold_version

5 changes: 2 additions & 3 deletions digits/frameworks/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def get_network_visualization(self, desc):
"""
raise NotImplementedError('Please implement me')




def can_accumulate_gradients(self):
    """Report whether this framework can accumulate gradients over batches.

    The base implementation always answers False; a framework that
    supports the feature is expected to provide its own version.
    """
    return False

9 changes: 9 additions & 0 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def validate_py_ext(form, field):
tooltip = "How many images to process at once. If blank, values are used from the network definition."
)

# Solver option: number of batches over which gradients are accumulated.
# Defaults to 1; a supplied value must be at least 1.
# NOTE(review): validators.Optional() is listed after NumberRange --
# presumably a blank submission is still accepted (leaving the field data
# as None); confirm against the WTForms version in use.
batch_accumulation = utils.forms.IntegerField('Batch Accumulation',
default=1,
validators = [
validators.NumberRange(min=1),
validators.Optional(),
],
tooltip = "Accumulate gradients over multiple batches (useful when you need a bigger batch size for training but it doesn't fit in memory)."
)

### Solver types

solver_type = utils.forms.SelectField(
Expand Down
35 changes: 18 additions & 17 deletions digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,23 +253,24 @@ def create():
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
val_interval = form.val_interval.data,
pretrained_model= pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
batch_accumulation = form.batch_accumulation.data,
val_interval = form.val_interval.data,
pretrained_model = pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
)
)

Expand Down
35 changes: 18 additions & 17 deletions digits/model/images/generic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,23 +205,24 @@ def create():
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
val_interval = form.val_interval.data,
pretrained_model= pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
job = job,
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
batch_accumulation = form.batch_accumulation.data,
val_interval = form.val_interval.data,
pretrained_model = pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
)
)

Expand Down
10 changes: 10 additions & 0 deletions digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,11 @@ def save_files_classification(self):

solver.snapshot_prefix = self.snapshot_prefix

# Batch accumulation
# Imported locally rather than at module level -- presumably to avoid a
# circular import with digits.frameworks (TODO confirm).
from digits.frameworks import CaffeFramework
# self.batch_accumulation is None when the option was left blank or never
# supplied. solver is presumably a protobuf SolverParameter, which rejects
# None for an integer field, so only set iter_size when a value is present
# and otherwise keep the solver's own default.
if self.batch_accumulation and CaffeFramework().can_accumulate_gradients():
    solver.iter_size = self.batch_accumulation

# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_data_layer.data_param.batch_size))
solver.max_iter = train_iter * self.train_epochs
Expand Down Expand Up @@ -623,6 +628,11 @@ def save_files_generic(self):

solver.snapshot_prefix = self.snapshot_prefix

# Batch accumulation
# Imported locally rather than at module level -- presumably to avoid a
# circular import with digits.frameworks (TODO confirm).
from digits.frameworks import CaffeFramework
# self.batch_accumulation is None when the option was left blank or never
# supplied. solver is presumably a protobuf SolverParameter, which rejects
# None for an integer field, so only set iter_size when a value is present
# and otherwise keep the solver's own default.
if self.batch_accumulation and CaffeFramework().can_accumulate_gradients():
    solver.iter_size = self.batch_accumulation

# Epochs -> Iterations
train_iter = int(math.ceil(float(self.dataset.get_entry_count(constants.TRAIN_DB)) / train_image_data_layer.data_param.batch_size))
solver.max_iter = train_iter * self.train_epochs
Expand Down
2 changes: 2 additions & 0 deletions digits/model/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, job, dataset, train_epochs, snapshot_interval, learning_rate,
gpu_count -- how many GPUs to use for training (integer)
selected_gpus -- a list of GPU indexes to be used for training
batch_size -- if set, override any network specific batch_size with this value
batch_accumulation -- accumulate gradients over multiple batches
val_interval -- how many epochs between validating the model with an epoch of validation data
pretrained_model -- filename for a model to use for fine-tuning
crop_size -- crop each image down to a square of this size
Expand All @@ -47,6 +48,7 @@ def __init__(self, job, dataset, train_epochs, snapshot_interval, learning_rate,
self.gpu_count = kwargs.pop('gpu_count', None)
self.selected_gpus = kwargs.pop('selected_gpus', None)
self.batch_size = kwargs.pop('batch_size', None)
self.batch_accumulation = kwargs.pop('batch_accumulation', None)
self.val_interval = kwargs.pop('val_interval', None)
self.pretrained_model = kwargs.pop('pretrained_model', None)
self.crop_size = kwargs.pop('crop_size', None)
Expand Down
17 changes: 16 additions & 1 deletion digits/templates/models/images/classification/new.html
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ <h4>Solver Options</h4>
</small>
{{form.batch_size(class='form-control', placeholder='[network defaults]')}}
</div>
{# Batch accumulation input. Rendered hidden (display:none); the framework-switching script on this page shows it when the selected framework reports can_accumulate_gradients. #}
<div class="form-group{{mark_errors([form.batch_accumulation])}}"
id="batch-accumulation-option" style="display:none;">
{{form.batch_accumulation.label}}
{{form.batch_accumulation.tooltip}}
{{form.batch_accumulation(class='form-control')}}
</div>
<div class="form-group{{mark_errors([form.solver_type])}}">
{{form.solver_type.label}}
{{form.solver_type.tooltip}}
Expand Down Expand Up @@ -407,7 +413,8 @@ <h4>Solver Options</h4>
{% for fw in frameworks %}
framework = {
name : '{{ fw.get_name() }}',
can_shuffle : '{{ fw.can_shuffle_data() }}'=='True'
can_shuffle : '{{ fw.can_shuffle_data() }}'=='True',
can_accumulate_gradients : '{{ fw.can_accumulate_gradients() }}'=='True',
};
frameworks['{{ fw.get_id() }}'] = framework;
{% endfor %}
Expand All @@ -432,6 +439,14 @@ <h4>Solver Options</h4>
$("#torch-warning").hide();
$('#stdnetRole a[href="'+"#"+fwid+"_standard"+'"]').tab('show');
$('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');

// Offer the batch accumulation option only for frameworks that support it.
// The input is disabled whenever its container is hidden -- presumably so
// that it is left out of the submitted form (confirm against the server side).
$('#batch_accumulation').prop('disabled', !frameworks[fwid].can_accumulate_gradients);
$('#batch-accumulation-option').toggle(Boolean(frameworks[fwid].can_accumulate_gradients));
}
</script>

Expand Down
17 changes: 16 additions & 1 deletion digits/templates/models/images/generic/new.html
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ <h4>Solver Options</h4>
</small>
{{form.batch_size(class='form-control', placeholder='[network defaults]')}}
</div>
{# Batch accumulation input. Rendered hidden (display:none); the framework-switching script on this page shows it when the selected framework reports can_accumulate_gradients. #}
<div class="form-group{{mark_errors([form.batch_accumulation])}}"
id="batch-accumulation-option" style="display:none;">
{{form.batch_accumulation.label}}
{{form.batch_accumulation.tooltip}}
{{form.batch_accumulation(class='form-control')}}
</div>
<div class="form-group{{mark_errors([form.solver_type])}}">
{{form.solver_type.label}}
{{form.solver_type.tooltip}}
Expand Down Expand Up @@ -404,7 +410,8 @@ <h4>Solver Options</h4>
{% for fw in frameworks %}
framework = {
name : '{{ fw.get_name() }}',
can_shuffle : '{{ fw.can_shuffle_data() }}'=='True'
can_shuffle : '{{ fw.can_shuffle_data() }}'=='True',
can_accumulate_gradients : '{{ fw.can_accumulate_gradients() }}'=='True',
};
frameworks['{{ fw.get_id() }}'] = framework;
{% endfor %}
Expand All @@ -422,6 +429,14 @@ <h4>Solver Options</h4>
$("select[name=solver_type] > option:selected").prop("selected", false);
}
$('#customFramework a[href="'+"#"+fwid+"_custom"+'"]').tab('show');

// Offer the batch accumulation option only for frameworks that support it.
// The input is disabled whenever its container is hidden -- presumably so
// that it is left out of the submitted form (confirm against the server side).
$('#batch_accumulation').prop('disabled', !frameworks[fwid].can_accumulate_gradients);
$('#batch-accumulation-option').toggle(Boolean(frameworks[fwid].can_accumulate_gradients));
}
</script>

Expand Down

0 comments on commit 03e7f11

Please sign in to comment.