-
Notifications
You must be signed in to change notification settings - Fork 8
/
mmoe.py
82 lines (74 loc) · 2.74 KB
/
mmoe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import Dict, Text, Tuple

import tensorflow as tf
import tensorflow_recommenders as tfrs

from trainer.models.common.basic_layers import MLPLayer
from trainer.models.common.multi_task import MMOELayer, UncertaintyWeightingLayer
from trainer.util.tools import ObjectDict


class MMOE(tfrs.Model):
    """Two-task MMOE ranking model supervised on CTR and CTCVR.

    A shared embedding model feeds an MMOE layer whose output is split per
    gate. Gate 0 drives a tower predicting pCTR and gate 1 drives a tower
    predicting pCVR. In the style of ESMM, pCTCVR is formed as
    ``pCTR * pCVR`` and supervised directly, so the pCVR tower is only trained
    through that product.

    Args:
        hparams: Hyper-parameters. Fields read by this class: ``expert_num``,
            ``gate_num`` (must be >= 2), ``label`` (feature key of the raw
            label), ``pctr_weight`` and ``pctcvr_weight``.
        ranking_emb: Keras model mapping the feature dict to the shared
            embedding that is fed into the MMOE layer.

    Raises:
        ValueError: If ``hparams.gate_num`` is smaller than 2.
    """

    def __init__(
        self,
        hparams: ObjectDict,
        ranking_emb: tf.keras.Model,
    ):
        # Each tower consumes its own gate output (gate 0 and gate 1), so fewer
        # than two gates can never yield a valid forward pass. Fail here with a
        # clear message instead of an IndexError while tracing call().
        if hparams.gate_num < 2:
            raise ValueError(
                "MMOE requires hparams.gate_num >= 2 (one gate per task), "
                f"got {hparams.gate_num}"
            )
        super().__init__()
        # Attribute names of tracked sub-layers are used as keys in TF
        # object-based checkpoints; keep them stable to stay restorable.
        self.ranking_emb = ranking_emb
        self.hparams = hparams
        self.pctr_task: tf.keras.layers.Layer = self._make_ranking_task()
        self.pctcvr_task: tf.keras.layers.Layer = self._make_ranking_task()
        # NOTE(review): these two weights are not read anywhere in this class;
        # compute_loss delegates loss weighting to UncertaintyWeightingLayer.
        # Kept for callers that may read them — confirm whether still needed.
        self.pctr_weight = hparams.pctr_weight
        self.pctcvr_weight = hparams.pctcvr_weight
        self.mmoe = MMOELayer(hparams.expert_num, hparams.gate_num)
        self.tower1 = self._make_tower()  # pCTR tower, fed by gate 0.
        self.tower2 = self._make_tower()  # pCVR tower, fed by gate 1.
        # NOTE(review): sized by gate_num, while compute_loss always passes
        # exactly two losses; confirm this is intended when gate_num > 2.
        self.MTO = UncertaintyWeightingLayer(hparams.gate_num)

    @staticmethod
    def _make_ranking_task() -> tfrs.tasks.Ranking:
        """Build a binary ranking task with BCE loss plus BCE and AUC metrics."""
        return tfrs.tasks.Ranking(
            loss=tf.keras.losses.BinaryCrossentropy(),
            metrics=[tf.keras.metrics.BinaryCrossentropy(), tf.keras.metrics.AUC()],
        )

    @staticmethod
    def _make_tower() -> tf.keras.Sequential:
        """Build one prediction tower: MLPLayer followed by a sigmoid Dense(1)."""
        return tf.keras.Sequential(
            [
                MLPLayer(),
                tf.keras.layers.Dense(1, "sigmoid"),
            ]
        )

    def call(
        self, features: Dict[Text, tf.Tensor], training=False
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Run the shared embedding, the MMOE layer and both towers.

        Args:
            features: Feature dict passed to ``ranking_emb``.
            training: Forwarded to every sub-model and layer.

        Returns:
            ``(pctr, pcvr)``: the sigmoid outputs of the two ``Dense(1)``
            towers.
        """
        shared_emb = self.ranking_emb(features, training=training)
        # Split the MMOE output into gate_num equal chunks along axis 1, one
        # per gate. NOTE(review): assumes MMOELayer concatenates its gate
        # outputs on axis 1 — confirm against its implementation.
        gated_list = tf.split(
            self.mmoe(shared_emb, training=training),
            num_or_size_splits=self.hparams.gate_num,
            axis=1,
        )
        # Gate 0 -> CTR tower, gate 1 -> CVR tower; any extra gates are unused.
        pctr = self.tower1(gated_list[0], training=training)
        pcvr = self.tower2(gated_list[1], training=training)
        return pctr, pcvr

    def compute_loss(
        self, features: Dict[Text, tf.Tensor], training=False
    ) -> tf.Tensor:
        """Compute the combined CTR / CTCVR training loss.

        Binary targets are derived from ``features[hparams.label]``. A label
        greater than 0 is a CTR positive, and a label greater than 3 is a
        CTCVR positive.

        Args:
            features: Feature dict; must contain the ``hparams.label`` key.
            training: Forwarded to the forward pass and the ranking tasks.

        Returns:
            Output of the ``UncertaintyWeightingLayer`` applied to
            ``[pctr_loss, pctcvr_loss]``.
        """
        raw_label = features[self.hparams.label]
        # 0/1 int targets with a trailing axis so they line up with the
        # Dense(1) tower outputs. NOTE(review): assumes the raw label is a
        # [batch_size] tensor — confirm.
        ctr_label = tf.expand_dims(tf.where(raw_label > 0, 1, 0), axis=-1)
        ctcvr_label = tf.expand_dims(tf.where(raw_label > 3, 1, 0), axis=-1)

        pctr, pcvr = self(features, training=training)
        # pCVR has no direct loss: it is learned only through this product,
        # which is supervised against the CTCVR target.
        pctcvr = pctr * pcvr

        pctr_loss = self.pctr_task(
            labels=ctr_label,
            predictions=pctr,
            training=training,
        )
        pctcvr_loss = self.pctcvr_task(
            labels=ctcvr_label,
            predictions=pctcvr,
            training=training,
        )
        return self.MTO([pctr_loss, pctcvr_loss])