-
Notifications
You must be signed in to change notification settings - Fork 1
/
base.py
156 lines (127 loc) · 4.91 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from collections import deque, OrderedDict
class Connection(object):
    """One-slot buffer that carries a value from one element to another.

    The source element writes into the buffer with :meth:`receive` (using
    ``key_from``); :meth:`send` then forwards the buffered value to
    ``obj_to`` under ``key_to``. The buffer starts at, and resets to, ``0``.
    """

    def __init__(self, obj_from, obj_to, key_from=None, key_to=None):
        # Endpoints of the link.
        self.obj_from, self.obj_to = obj_from, obj_to
        # Key the source must write with / key the target is written under.
        self.key_from, self.key_to = key_from, key_to
        # Most recently received value.
        self.data = 0

    def send(self):
        """Deliver the buffered value to ``obj_to`` under ``key_to``."""
        target = self.obj_to
        target.receive(self.data, self.key_to)

    def receive(self, data, key):
        """Buffer ``data``; ``key`` must equal ``key_from`` (asserted)."""
        assert key == self.key_from
        self.data = data

    def reset(self):
        """Clear the buffered value back to ``0``."""
        self.data = 0
class AbstractNode(object):
    """Base class for a processing cell of the array.

    Subclasses list their port names in ``KEYS`` and implement
    :meth:`process`. Each port holds one value in ``self.data``, an
    ``OrderedDict`` in ``KEYS`` order whose values start at ``0``.
    """

    # Port names; subclasses override this.
    KEYS = []

    def __init__(self):
        self.data = self._fresh_data()
        # Outgoing connections, at most one per port name: registering a
        # second connection under the same key replaces the first.
        self.connections = {}

    def _fresh_data(self):
        """Return a new port table with every ``KEYS`` entry set to ``0``."""
        return OrderedDict.fromkeys(self.KEYS, 0)

    def add_connection(self, conn, key):
        """Register ``conn`` as the outgoing connection for port ``key``."""
        assert key in self.KEYS
        self.connections[key] = conn

    def send(self):
        """Push each connected port's current value into its connection."""
        for port, link in self.connections.items():
            value = self.data.get(port, 0)
            link.receive(value, port)

    def receive(self, data, key):
        """Store ``data`` in port ``key``, which must be one of ``KEYS``."""
        assert key in self.KEYS
        self.data[key] = data

    def process(self):
        """Compute new port values; subclasses must implement this."""
        raise NotImplementedError

    def reset(self):
        """Zero every port. Registered connections are kept."""
        self.data = self._fresh_data()
class DataStream(object):
    """Finite source of input values, emitting one value per :meth:`send`.

    Once the data is exhausted, :meth:`send` keeps emitting ``0``.
    """

    def __init__(self, data, key):
        """
        :param data: iterable of values to emit, in order. It is snapshotted
            into a list up front, so one-shot iterators (generators,
            ``iter(...)``) can be replayed by :meth:`reset`, and later
            mutation of a caller-owned list does not leak into the stream.
        :param key: key under which values are passed to the connection's
            ``receive`` (a ``Connection`` asserts it equals its ``key_from``).
        """
        # TODO: add possibility to specify fp/path to file with data
        self.initial_data = list(data)
        self.queue = deque(self.initial_data)
        self.key = key
        self.connection = None

    def add_connection(self, conn, key=None):
        """Set ``conn`` as the single outgoing connection.

        ``key`` is ignored; it is accepted so a stream can be wired with the
        same ``add_connection(conn, key)`` call used for nodes.
        """
        self.connection = conn

    def send(self):
        """Pass the next queued value (or ``0`` when exhausted) downstream.

        Must only be called after :meth:`add_connection`.
        """
        value = self.queue.popleft() if self.queue else 0
        self.connection.receive(value, self.key)

    def reset(self):
        """Refill the queue with the snapshot taken at construction."""
        self.queue = deque(self.initial_data)
class SystolicArray(object):
    """Grid of processing nodes wired together by ``Connection`` objects.

    Each step of :meth:`iterate` runs four phases, in this order:

    1. every input ``DataStream`` pushes its next value into its connection;
    2. every node pushes its outgoing port values into its connections;
    3. every connection delivers its buffered value to its target;
    4. every node runs ``process()``.
    """

    def __init__(self, size, nodes_by_class, input_streams, connections):
        """Build the node grid and wire up the connections.

        :param size: ``(n_rows, n_cols)`` of the grid.
        :param nodes_by_class: mapping ``node_cls -> [(rows, cols), ...]``.
            ``rows``/``cols`` are an ``int``, ``range`` or ``list`` of
            indexes; every cell in their cross product gets a fresh
            ``node_cls()``. Every cell of the grid must end up with a node.
        :param input_streams: mapping whose values are a ``DataStream`` or a
            ``list`` of them; these are fed on every step (keys are unused).
        :param connections: iterable of
            ``(objs_from, objs_to, key_from, key_to)`` tuples. ``objs_from``
            and ``objs_to`` are anything accepted by
            :meth:`get_elements_sequence` and must resolve to sequences of
            equal length; their elements are connected pairwise.
        """
        self.current_step = 0
        self.n, self.m = size
        self.array = [[None] * self.m for _ in range(self.n)]
        for node_cls, positions in nodes_by_class.items():
            for rows, cols in positions:
                for r in self.fix_sequence_of_indexes(rows):
                    for c in self.fix_sequence_of_indexes(cols):
                        self.array[r][c] = node_cls()
        for r, row in enumerate(self.array):
            assert None not in row, \
                "Row {row} has cells without a node".format(row=r)
        self.input_streams = input_streams
        self.connections = []
        for objs_from, objs_to, key_from, key_to in connections:
            objs_from = self.get_elements_sequence(objs_from)
            objs_to = self.get_elements_sequence(objs_to)
            assert len(objs_from) == len(objs_to), \
                "Cannot pair {n_from} source(s) with {n_to} target(s)".format(
                    n_from=len(objs_from), n_to=len(objs_to)
                )
            for obj_from, obj_to in zip(objs_from, objs_to):
                conn = Connection(obj_from, obj_to, key_from, key_to)
                # Only the source registers the connection: it pushes data
                # in (streams in phase 1, nodes in phase 2), and the array
                # itself delivers it to ``obj_to`` in phase 3.
                obj_from.add_connection(conn, key_from)
                self.connections.append(conn)

    @staticmethod
    def fix_sequence_of_indexes(seq):
        """Normalize an index spec: wrap an ``int`` in a one-element list.

        A ``range`` or ``list`` is returned unchanged; anything else fails
        an assertion.
        """
        if isinstance(seq, int):
            seq = [seq]
        assert isinstance(seq, (range, list)), \
            "Expected `int`, `range` or `list` of indexes, but got `{type}`".format(
                type=type(seq)
            )
        return seq

    def get_nodes_by_indexes(self, rows, cols):
        """Return the nodes at ``rows x cols``, in row-major order."""
        rows = self.fix_sequence_of_indexes(rows)
        cols = self.fix_sequence_of_indexes(cols)
        return [self.array[r][c] for r in rows for c in cols]

    def get_elements_sequence(self, seq):
        """Resolve an element spec to a list of elements (nodes or streams).

        * ``(rows, cols)`` tuple -> nodes, via :meth:`get_nodes_by_indexes`;
        * a single ``DataStream`` -> one-element list;
        * a ``list`` -> returned unchanged.
        """
        if isinstance(seq, tuple):
            assert len(seq) == 2
            return self.get_nodes_by_indexes(*seq)
        if isinstance(seq, DataStream):
            return [seq]
        assert isinstance(seq, list)
        return seq

    def _iter_input_streams(self):
        """Yield every input stream, normalizing single streams to lists.

        Lazy on purpose: each entry is validated just before its streams
        are yielded, so validation interleaves with the caller's actions.
        """
        for streams in self.input_streams.values():
            if isinstance(streams, DataStream):
                streams = [streams]
            assert isinstance(streams, list), \
                "Expected `list`, but got `{obj}` of `{type}` type".format(
                    obj=streams, type=type(streams)
                )
            for stream in streams:
                yield stream

    def _iter_nodes(self):
        """Yield every node of the grid in row-major order."""
        for r in range(self.n):
            for c in range(self.m):
                yield self.array[r][c]

    def iterate(self, n_times=1):
        """Advance the simulation by ``n_times`` steps (see class docstring).

        ``current_step`` is incremented after each completed step, so it
        stays accurate even if a node raises partway through.
        """
        for _ in range(n_times):
            for stream in self._iter_input_streams():
                stream.send()
            for node in self._iter_nodes():
                node.send()
            for conn in self.connections:
                conn.send()
            for node in self._iter_nodes():
                node.process()
            self.current_step += 1

    def reset(self):
        """Rewind streams, nodes and connections, and zero ``current_step``."""
        for stream in self._iter_input_streams():
            stream.reset()
        for node in self._iter_nodes():
            node.reset()
        for conn in self.connections:
            conn.reset()
        self.current_step = 0