-
Notifications
You must be signed in to change notification settings - Fork 1
/
normalization.py
38 lines (29 loc) · 1.29 KB
/
normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# -*- coding:utf-8 -*-
"""
Author:
Weichen Shen,wcshen1994@163.com
"""
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.initializers import Ones, Zeros
from tensorflow.python.keras.layers import Layer
class LayerNormalization(Layer):
    """Layer normalization with a learnable scale and offset.

    Normalizes ``x`` to zero mean and unit variance along ``axis`` and then
    applies ``gamma * normalized + beta``.

    Args:
        axis: Axis (or axes) passed to the mean / variance reductions.
            Defaults to the last axis.
        eps: Small constant added to the variance *inside* the square root.
            It keeps both the division and its gradient finite when the
            input is constant along ``axis``.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.

    Note:
        ``gamma`` and ``beta`` are always sized by the LAST input dimension
        (``input_shape[-1:]``), regardless of ``axis``.
    """

    def __init__(self, axis=-1, eps=1e-9, **kwargs):
        self.axis = axis
        self.eps = eps
        super(LayerNormalization, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the trainable scale ``gamma`` (ones) and offset ``beta`` (zeros)."""
        # Both weights match the last input dimension, so they broadcast
        # over every leading dimension of the input.
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:],
                                     initializer=Ones(), trainable=True)
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:],
                                    initializer=Zeros(), trainable=True)
        super(LayerNormalization, self).build(input_shape)

    def call(self, x):
        """Return ``gamma * (x - mean) / sqrt(var + eps) + beta``."""
        mean = K.mean(x, axis=self.axis, keepdims=True)
        centered = x - mean
        # Population variance: the same quantity K.std takes the root of.
        variance = K.mean(K.square(centered), axis=self.axis, keepdims=True)
        # eps goes inside the sqrt. Adding it to K.std(x) instead leaves the
        # derivative of sqrt at variance == 0 infinite. Backprop then computes
        # inf * 0 = NaN for inputs that are constant along ``axis``
        # (e.g. all-zero padded positions).
        std = K.sqrt(variance + self.eps)
        return self.gamma * centered / std + self.beta

    def compute_output_shape(self, input_shape):
        """Normalization is element-wise, so the output shape equals the input shape."""
        return input_shape

    def get_config(self):
        """Return the serializable config: base ``Layer`` config plus ``axis`` and ``eps``."""
        config = {'axis': self.axis, 'eps': self.eps}
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))