[quantizer] add QHardsigmoid (#379)
* [quantizer] add QHardsigmoid

* update docs
peterjc123 authored Dec 10, 2024
1 parent f79b0cc commit ccf5d73
Showing 3 changed files with 22 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/quantization_support.md
@@ -70,6 +70,7 @@ Quantized OPs that are natively not supported by PyTorch (and possibly TFLite).
| `softmax` | For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}`<br>For TFLiteConverter, set `rewrite_quantizable=True` |
| `sum` | For TFLiteConverter, set `rewrite_quantizable=True` |
| `torch.nn.GLU` | No action needed |
| `torch.nn.Hardsigmoid` | No action needed |
| `torch.nn.LogSoftmax` | For QATQuantizer/PostQuantizer, set `config={"set_quantizable_op_stats": True}`<br>For TFLiteConverter, set `rewrite_quantizable=True` |
| `torch.nn.PReLU` | No action needed |
| `torch.nn.SiLU` | No action needed |
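As the table notes, `torch.nn.Hardsigmoid` now quantizes with the default settings. A minimal sketch of what "No action needed" means in practice, assuming the usual QATQuantizer entry point from the TinyNN README (the exact constructor arguments are an assumption, not part of this commit):

import torch
import torch.nn as nn

from tinynn.graph.quantization.quantizer import QATQuantizer


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.act = nn.Hardsigmoid()

    def forward(self, x):
        return self.act(self.conv(x))


model = TinyModel()
dummy_input = torch.rand(1, 3, 32, 32)

# No extra config entries (e.g. set_quantizable_op_stats) are needed for Hardsigmoid.
quantizer = QATQuantizer(model, dummy_input, work_dir='out')
qat_model = quantizer.quantize()  # nn.Hardsigmoid is rewritten to QHardsigmoid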
19 changes: 19 additions & 0 deletions tinynn/graph/quantization/modules.py
@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.quantization as torch_q


class QPReLU(nn.Module):
@@ -70,3 +71,21 @@ def __init__(self, glu: nn.GLU) -> None:
def forward(self, input: torch.Tensor) -> torch.Tensor:
slices = torch.chunk(input, 2, self.dim)
return self.f_mul.mul(slices[0], self.sigmoid(slices[1]))


class QHardsigmoid(nn.Module):
def __init__(self, hardsigmoid: nn.Hardsigmoid) -> None:
super().__init__()

self.f_mul = nnq.FloatFunctional()
self.f_add = nnq.FloatFunctional()
self.q = torch_q.QuantStub()
self.dq = torch_q.DeQuantStub()
self.act_hs = nn.Hardsigmoid()
self.act_r = nn.ReLU6()

def forward(self, input: torch.Tensor) -> torch.Tensor:
x1 = self.f_add.add_scalar(input, 3.0)
x2 = self.act_r(x1)
x3 = self.q(self.dq(x2))
return self.f_mul.mul_scalar(x3, 1 / 6)
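
Before prepare/convert, FloatFunctional, QuantStub and DeQuantStub are all pass-through, so in float mode QHardsigmoid computes relu6(x + 3) * (1/6), which is exactly hardsigmoid(x). A quick sanity-check sketch (not part of the commit):

import torch
import torch.nn as nn

from tinynn.graph.quantization.modules import QHardsigmoid

hs = nn.Hardsigmoid()
q_hs = QHardsigmoid(hs)

x = torch.randn(4, 16)
# Float-mode outputs should agree up to rounding of multiplying by 1/6 vs dividing by 6.
assert torch.allclose(q_hs(x), hs(x))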
3 changes: 2 additions & 1 deletion tinynn/graph/quantization/quantizer.py
@@ -23,7 +23,7 @@
FakeQuantizeBFloat16,
FakeQuantizeTFLite,
)
from tinynn.graph.quantization.modules import QGLU, QPReLU, QSiLU
from tinynn.graph.quantization.modules import QGLU, QHardsigmoid, QPReLU, QSiLU
from tinynn.graph.quantization.observer import (
HistogramObserverKL,
MinMaxObserver,
@@ -223,6 +223,7 @@
Q_MODULES_MAPPING = {
nn.PReLU: QPReLU,
nn.GLU: QGLU,
nn.Hardsigmoid: QHardsigmoid,
}

FUNCTIONAL_MODULE_MAPPING = {
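The new Q_MODULES_MAPPING entry tells the quantizer to swap nn.Hardsigmoid for QHardsigmoid before QAT preparation. An illustrative sketch of that kind of swap (a simplified stand-in for the quantizer's graph-based rewrite, not the actual TinyNN code path):

import torch.nn as nn

from tinynn.graph.quantization.modules import QGLU, QHardsigmoid, QPReLU

Q_MODULES_MAPPING = {
    nn.PReLU: QPReLU,
    nn.GLU: QGLU,
    nn.Hardsigmoid: QHardsigmoid,
}


def swap_q_modules(model: nn.Module) -> nn.Module:
    # Recursively replace mapped modules with their quantization-friendly wrappers.
    for name, child in model.named_children():
        q_cls = Q_MODULES_MAPPING.get(type(child))
        if q_cls is not None:
            setattr(model, name, q_cls(child))
        else:
            swap_q_modules(child)
    return model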
