-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_ir.py
34 lines (33 loc) · 1.04 KB
/
test_ir.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
from mindspore.communication import init
from mindspore import nn, Model
from mindspore import context, Tensor
from mindspore import Parameter
class DenseLayer(nn.Cell):
    """Building block made of two chained ``nn.Dense(10, 10)`` layers.

    The input is passed through ``input_mapping`` and the result is then
    passed through ``output_mapping``.
    """

    def __init__(self):
        super(DenseLayer, self).__init__()
        self.input_mapping = nn.Dense(10, 10)
        self.output_mapping = nn.Dense(10, 10)

    def construct(self, x):
        """Apply ``input_mapping`` followed by ``output_mapping`` to ``x``."""
        return self.output_mapping(self.input_mapping(x))
class Net(nn.Cell):
    """Network of three ``DenseLayer`` blocks applied one after another.

    Each block is tagged with its own communication-fusion id through
    ``set_comm_fusion``: ``layer1`` -> 0, ``layer2`` -> 1, ``layer3`` -> 2.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = DenseLayer()
        self.layer2 = DenseLayer()
        self.layer3 = DenseLayer()
        # The fusion id of each block equals its position in the chain.
        blocks = (self.layer1, self.layer2, self.layer3)
        for fusion_id, block in enumerate(blocks):
            block.set_comm_fusion(fusion_id)

    def construct(self, x):
        """Run ``x`` through ``layer1``, ``layer2`` and ``layer3`` in order."""
        return self.layer3(self.layer2(self.layer1(x)))
# Enable graph saving before the network is built and executed.
context.set_context(save_graphs=True)

# Build the network and wrap it in the high-level Model API.
net = Net()
model = Model(net)

# One inference pass on a 10x10 float32 tensor filled with ones.
dummy_batch = Tensor(np.ones(shape=(10, 10), dtype=np.float32))
model.predict(dummy_batch)