-
Notifications
You must be signed in to change notification settings - Fork 12
/
redmule_x_buffer.sv
194 lines (180 loc) · 5.58 KB
/
redmule_x_buffer.sv
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
// Copyright 2023 ETH Zurich and University of Bologna.
// Solderpad Hardware License, Version 0.51, see LICENSE for details.
// SPDX-License-Identifier: SHL-0.51
//
// Yvan Tortorella <yvan.tortorella@unibo.it>
//
// X-operand buffer.
//
// Storage (as implemented below):
//   * x_pad_q    : D staging slots, each W x H elements of BITW bits.
//   * x_buffer_q : HALF_D slots of the same W x H shape; slot 0 drives x_buffer_o.
//
// Commands from ctrl_i (evaluated only when clear_i is low; if several are
// asserted in the same cycle, the later assignment to an element wins):
//   * load       : for every slot d and row h, writes entry [d][w_index][h] of
//                  x_pad_q from x_buffer_i[(H*d + h)*BITW +: BITW]; elements whose
//                  flat index H*d + h is >= depth are written as zero. w_index
//                  then advances.
//   * d_shift    : shifts x_pad_q down by one slot (top slot zero-filled) and
//                  moves x_pad_q[0] into x_buffer_q[HALF_D-1].
//   * blck_shift : copies x_pad_q[0 .. HALF_D-1] into x_buffer_q and shifts
//                  x_pad_q down by two slots.
//   * h_shift    : copies row h_index of x_buffer_q[1] into x_buffer_q[0] (all w)
//                  and advances h_index.
//
// clear_i synchronously zeroes the storage and restarts the counters; it has
// priority over every command, like the counters' own reset conditions.
//
// Flags:
//   * full  : w_index reached w_limit (ctrl_i.rows_lftovr when nonzero, else W) or W.
//   * empty : the shift counter d_shift equals empty_count_q.
module redmule_x_buffer
  import fpnew_pkg::*;
  import redmule_pkg::*;
#(
  parameter int unsigned           DW        = 288,
  parameter fpnew_pkg::fp_format_e FpFormat  = fpnew_pkg::FP16,
  parameter int unsigned           Height    = ARRAY_HEIGHT, // Number of PEs per row
  parameter int unsigned           Width     = ARRAY_WIDTH,  // Number of parallel index
  localparam int unsigned          BITW      = fpnew_pkg::fp_width(FpFormat), // Number of bits for the given format
  localparam int unsigned          H         = Height,
  localparam int unsigned          W         = Width,
  localparam int unsigned          D         = DW/(H*BITW),  // Slots per input word (integer division)
  localparam int unsigned          HALF_D    = D/2,
  localparam int unsigned          TOT_DEPTH = H*D           // Elements taken from one x_buffer_i word
)(
  input  logic                          clk_i     ,
  input  logic                          rst_ni    ,
  input  logic                          clear_i   ,
  input  x_buffer_ctrl_t                ctrl_i    ,
  output x_buffer_flgs_t                flags_o   ,
  output logic [W-1:0][H-1:0][BITW-1:0] x_buffer_o,
  input  logic [DW-1:0]                 x_buffer_i
);

  logic                                      rst_w_load, rst_d_shift, rst_h_shift, empty_rst;
  logic [$clog2(W):0]                        w_index, w_limit;
  logic [$clog2(H)-1:0]                      h_index;
  logic [$clog2(D):0]                        d_shift, empty_count_q;
  logic [$clog2(TOT_DEPTH):0]                depth;
  logic [D-1:0][W-1:0][H-1:0][BITW-1:0]      x_pad_q;
  logic [HALF_D-1:0][W-1:0][H-1:0][BITW-1:0] x_buffer_q;

  // ---------------------------------------------------------------------------
  // Storage: staging pad and output buffer
  // ---------------------------------------------------------------------------
  always_ff @(posedge clk_i or negedge rst_ni) begin : bump_register
    if (~rst_ni) begin
      x_pad_q    <= '0;
      x_buffer_q <= '0;
    end else if (clear_i) begin
      // Soft clear overrides any command issued in the same cycle, keeping the
      // storage consistent with the counters, which are also cleared.
      x_pad_q    <= '0;
      x_buffer_q <= '0;
    end else begin
      // Scatter one input word into entry w_index of every pad slot.
      if (ctrl_i.load) begin
        for (int d = 0; d < D; d++) begin
          for (int h = 0; h < H; h++) begin
            // Elements past the valid depth are zero-padded.
            x_pad_q[d][w_index][h] <= ((H*d + h) < depth) ? x_buffer_i[(H*d + h)*BITW +: BITW] : '0;
          end
        end
      end

      // Advance the pad by one slot; the slot leaving the pad enters the
      // top slot of the output buffer.
      if (ctrl_i.d_shift) begin
        for (int d = 0; d < D-1; d++)
          x_pad_q[d] <= x_pad_q[d+1];
        x_pad_q[D-1]         <= '0;
        x_buffer_q[HALF_D-1] <= x_pad_q[0];
      end

      // Move the lower HALF_D pad slots into the output buffer and advance the
      // pad by two slots.
      // NOTE(review): the pad advances by a literal 2 (and d_shift counts +2)
      // while HALF_D slots are transferred, and the zero-fill bound is HALF_D
      // rather than D-2; these agree for D == 4 - confirm intent for other D.
      if (ctrl_i.blck_shift) begin
        for (int d = 0; d < D; d++)
          x_pad_q[d] <= (d < HALF_D) ? x_pad_q[d+2] : '0;
        for (int dd = 0; dd < HALF_D; dd++)
          x_buffer_q[dd] <= x_pad_q[dd];
      end

      // Replace row h_index of the output slot with the same row of slot 1.
      if (ctrl_i.h_shift) begin
        for (int w = 0; w < W; w++)
          x_buffer_q[0][w][h_index] <= x_buffer_q[1][w][h_index];
      end
    end
  end

  // Number of valid elements in an input word: TOT_DEPTH by default, or
  // ctrl_i.cols_lftovr when that field is nonzero.
  assign depth = (ctrl_i.cols_lftovr == '0) ? TOT_DEPTH : ctrl_i.cols_lftovr;

  // ---------------------------------------------------------------------------
  // Load counter: selects which W entry the next load writes
  // ---------------------------------------------------------------------------
  always_ff @(posedge clk_i or negedge rst_ni) begin : row_loaded_counter
    if (~rst_ni)
      w_index <= '0;
    else if (rst_w_load || clear_i)
      w_index <= '0;
    else if (ctrl_i.load)
      w_index <= w_index + 1;
  end

  // Load limit: ctrl_i.rows_lftovr when nonzero, otherwise the full width W.
  assign w_limit = (ctrl_i.rows_lftovr != '0) ? ctrl_i.rows_lftovr : W;

  // Full when the limit (or the physical width) is reached; the same condition
  // restarts the load counter on the next clock edge.
  always_comb begin : load_count_rst
    rst_w_load   = (w_index == w_limit) || (w_index == W);
    flags_o.full = rst_w_load;
  end

  // ---------------------------------------------------------------------------
  // Depth shift counter: +1 per d_shift, +2 per blck_shift
  // ---------------------------------------------------------------------------
  always_ff @(posedge clk_i or negedge rst_ni) begin : d_shift_counter
    if (~rst_ni)
      d_shift <= '0;
    else if (rst_d_shift || clear_i)
      d_shift <= '0;
    else if (ctrl_i.blck_shift)
      d_shift <= d_shift + 2;
    else if (ctrl_i.d_shift)
      d_shift <= d_shift + 1;
  end

  // Shift count at which the buffer reports empty. Reloaded with D on clear or
  // empty_rst, otherwise follows ctrl_i.slots while ctrl_i.cols_lftovr is nonzero.
  // NOTE(review): the async reset value is '0 whereas clear loads D, so for one
  // cycle after reset d_shift == empty_count_q (empty asserted) until empty_rst
  // reloads D - confirm this is intended.
  always_ff @(posedge clk_i or negedge rst_ni) begin : empty_count_reg
    if (~rst_ni)
      empty_count_q <= '0;
    else if (clear_i || empty_rst)
      empty_count_q <= D;
    else if (ctrl_i.cols_lftovr != '0)
      empty_count_q <= ctrl_i.slots;
  end

  // Empty when the shift counter reaches the target; this also restarts the
  // shift counter and, when the target differs from depth, reloads the target.
  always_comb begin : empty_gen_and_shift_count_rst
    rst_d_shift   = (d_shift == empty_count_q);
    flags_o.empty = rst_d_shift;
    empty_rst     = rst_d_shift && (empty_count_q != depth);
  end

  // ---------------------------------------------------------------------------
  // Row shift counter: selects the row moved by each h_shift
  // ---------------------------------------------------------------------------
  // Wrap back to row 0 after row H-1. For a power-of-two H this matches the
  // natural overflow of h_index; for other H it keeps h_index within [0, H-1].
  assign rst_h_shift = ctrl_i.h_shift && (h_index == H-1);

  always_ff @(posedge clk_i or negedge rst_ni) begin : h_shift_counter
    if (~rst_ni)
      h_index <= '0;
    else if (rst_h_shift || clear_i)
      h_index <= '0;
    else if (ctrl_i.h_shift)
      h_index <= h_index + 1;
  end

  // Output: slot 0 of the output buffer.
  assign x_buffer_o = x_buffer_q[0];

endmodule : redmule_x_buffer