Skip to content

Commit

Permalink
Merge pull request lukemelas#90 from lukemelas/swish
Browse files Browse the repository at this point in the history
Add memory-efficient and export-friendly swish activation functions
  • Loading branch information
lukemelas authored Oct 15, 2019
2 parents faa8430 + a8ce81a commit 8e8e137
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 15 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# EfficientNet PyTorch

### Update (October 12, 2019)
### Update (October 15, 2019)

This update makes the Swish activation function more memory-efficient. It also addresses pull requests #72, #73, #85, and #86. Thanks to the authors of all the pull requests!
This update allows you to choose whether to use a memory-efficient Swish activation. The memory-efficient version is chosen by default, but it cannot be used when exporting using PyTorch JIT. For this purpose, we have also included a standard (export-friendly) swish activation function. To switch to the export-friendly version, simply call `model.set_swish(memory_efficient=False)` after loading your desired model. This update addresses issues [#88](https://github.com/lukemelas/EfficientNet-PyTorch/pull/88) and [#89](https://github.com/lukemelas/EfficientNet-PyTorch/pull/89).

#### Update (October 12, 2019)

This update makes the Swish activation function more memory-efficient. It also addresses pull requests [#72](https://github.com/lukemelas/EfficientNet-PyTorch/pull/72), [#73](https://github.com/lukemelas/EfficientNet-PyTorch/pull/73), [#85](https://github.com/lukemelas/EfficientNet-PyTorch/pull/85), and [#86](https://github.com/lukemelas/EfficientNet-PyTorch/pull/86). Thanks to the authors of all the pull requests!

### Update (July 31, 2019)

Expand Down
2 changes: 1 addition & 1 deletion efficientnet_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.5.0"
__version__ = "0.5.1"
from .model import EfficientNet
from .utils import (
GlobalParams,
Expand Down
26 changes: 20 additions & 6 deletions efficientnet_pytorch/model.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from torch.nn import functional as F

from .utils import (
relu_fn,
round_filters,
round_repeats,
drop_connect,
get_same_padding_conv2d,
get_model_params,
efficientnet_params,
load_pretrained_weights,
Swish,
MemoryEfficientSwish,
)

class MBConvBlock(nn.Module):
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(self, block_args, global_params):
final_oup = self._block_args.output_filters
self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
self._swish = MemoryEfficientSwish()

def forward(self, inputs, drop_connect_rate=None):
"""
Expand All @@ -72,13 +74,13 @@ def forward(self, inputs, drop_connect_rate=None):
# Expansion and Depthwise Convolution
x = inputs
if self._block_args.expand_ratio != 1:
x = relu_fn(self._bn0(self._expand_conv(inputs)))
x = relu_fn(self._bn1(self._depthwise_conv(x)))
x = self._swish(self._bn0(self._expand_conv(inputs)))
x = self._swish(self._bn1(self._depthwise_conv(x)))

# Squeeze and Excitation
if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self._se_expand(relu_fn(self._se_reduce(x_squeezed)))
x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
x = torch.sigmoid(x_squeezed) * x

x = self._bn2(self._project_conv(x))
Expand All @@ -91,6 +93,10 @@ def forward(self, inputs, drop_connect_rate=None):
x = x + inputs # skip connection
return x

def set_swish(self, memory_efficient=True):
    """Select the swish implementation for this block.

    memory_efficient=True installs the autograd-based swish (saves memory
    during training); False installs the plain formulation, which is the
    one usable for JIT/ONNX export.
    """
    if memory_efficient:
        self._swish = MemoryEfficientSwish()
    else:
        self._swish = Swish()


class EfficientNet(nn.Module):
"""
Expand Down Expand Up @@ -153,12 +159,20 @@ def __init__(self, blocks_args=None, global_params=None):
self._avg_pooling = nn.AdaptiveAvgPool2d(1)
self._dropout = nn.Dropout(self._global_params.dropout_rate)
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
self._swish = MemoryEfficientSwish()

def set_swish(self, memory_efficient=True):
    """Select the swish implementation for the whole model.

    Replaces the model-level activation and propagates the same choice to
    every MBConv block. Pass memory_efficient=False before exporting via
    PyTorch JIT, since the memory-efficient variant is not export-friendly.
    """
    swish_cls = MemoryEfficientSwish if memory_efficient else Swish
    self._swish = swish_cls()
    for blk in self._blocks:
        blk.set_swish(memory_efficient)


def extract_features(self, inputs):
""" Returns output of the final convolution layer """

# Stem
x = relu_fn(self._bn0(self._conv_stem(inputs)))
x = self._swish(self._bn0(self._conv_stem(inputs)))

# Blocks
for idx, block in enumerate(self._blocks):
Expand All @@ -168,7 +182,7 @@ def extract_features(self, inputs):
x = block(x, drop_connect_rate=drop_connect_rate)

# Head
x = relu_fn(self._bn1(self._conv_head(x)))
x = self._swish(self._bn1(self._conv_head(x)))

return x

Expand Down
10 changes: 5 additions & 5 deletions efficientnet_pytorch/utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def backward(ctx, grad_output):
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish(nn.Module):
@staticmethod
def forward(x):
class MemoryEfficientSwish(nn.Module):
    """Swish activation implemented via a custom autograd Function.

    Delegates to SwishImplementation, whose backward recomputes the sigmoid
    term instead of storing intermediates, making it more memory-efficient
    than the plain ``x * sigmoid(x)`` formulation. Per the release notes for
    this commit, this variant cannot be used when exporting with PyTorch JIT
    — use ``Swish`` for export.
    """
    def forward(self, x):
        return SwishImplementation.apply(x)


relu_fn = Swish()
class Swish(nn.Module):
    """Standard (export-friendly) swish activation: ``x * sigmoid(x)``.

    Unlike MemoryEfficientSwish, this uses only traceable tensor ops and is
    therefore safe for PyTorch JIT export.
    """
    def forward(self, x):
        return torch.sigmoid(x) * x


def round_filters(filters, global_params):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
EMAIL = 'lmelaskyriazi@college.harvard.edu'
AUTHOR = 'Luke'
REQUIRES_PYTHON = '>=3.5.0'
VERSION = '0.5.0'
VERSION = '0.5.1'

# What packages are required for this module to be executed?
REQUIRED = [
Expand Down

0 comments on commit 8e8e137

Please sign in to comment.