-
Notifications
You must be signed in to change notification settings - Fork 5
/
mpi_evo.py
229 lines (198 loc) · 8.71 KB
/
mpi_evo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# from mpi4py import MPI
from timeit import default_timer as timer
from circuit_dynamics_init import *
from pybind_circuit import unitary_cxx_parallel
import sys
start = timer()


# Global error handler to avoid MPI deadlock (a known bug in mpi4py)
# source https://github.com/chainer/chainermn/issues/236
def global_except_hook(exctype, value, traceback):
    """Report an uncaught exception, then abort the whole MPI job.

    Installed as ``sys.excepthook`` so that one rank dying on an exception
    calls ``COMM_WORLD.Abort(1)`` instead of leaving the other ranks
    blocked in a collective call.

    Args:
        exctype: Exception class (standard ``sys.excepthook`` argument).
        value: Exception instance.
        traceback: Traceback object of the exception.

    Raises:
        Exception: Whatever importing mpi4py or calling ``Abort`` raised,
            re-raised after writing a "this process will hang" warning.
    """
    import sys
    from traceback import print_exception

    stderr = sys.stderr
    try:
        import mpi4py.MPI
        stderr.write("\n*****************************************************\n")
        stderr.write("Uncaught exception was detected on rank {}. \n".format(
            mpi4py.MPI.COMM_WORLD.Get_rank()))
        print_exception(exctype, value, traceback)
        stderr.write("*****************************************************\n\n\n")
        stderr.write("\n")
        stderr.write("Calling MPI_Abort() to shut down MPI processes...\n")
        stderr.flush()
    finally:
        # Always attempt the abort, even if reporting above failed.
        try:
            import mpi4py.MPI
            mpi4py.MPI.COMM_WORLD.Abort(1)
        except Exception as abort_error:
            stderr.write("*****************************************************\n")
            stderr.write("Sorry, we failed to stop MPI, this process will hang.\n")
            stderr.write("*****************************************************\n")
            stderr.flush()
            raise abort_error


sys.excepthook = global_except_hook
# Read the run parameters from file, one per line: system size L,
# measurement probability and number of discrete time steps.
# The `with` block closes the file handle (the original never closed it).
with open('para_haar.txt', 'r') as para_file:
    para = para_file.readlines()
L, pro, time = int(para[0]), float(para[1]), int(para[2])

# System partition.
# With PBC, the system is split into 4 parts where A and B are separated by
# C1 and C2; C1 and C2 are effectively connected, so the system is composed
# of A, B and C.  Integer floor division gives floor(L/8), floor(L/4)
# exactly, without a round trip through floating point.
lc1, la, lb = L // 8, L // 4, L // 4
lc2 = L - lc1 - la - lb
# Pack the partition into an array: [L, la, lb, lc1, lc2].
part = np.array([L, la, lb, lc1, lc2], dtype="int64")

# Initial wavefunction: a product state with all spins aligned up, i.e.
# amplitude 1 on the first basis state and 0 on the other 2**L - 1 states.
p1 = np.ones(1, dtype='c16')
p2 = np.zeros(2**L - 1, dtype='c16')
psi = np.concatenate((p1, p2), axis=0)
# MPI session: every rank works on COMM_WORLD; rank 0 is the root that holds
# the full wavefunction and writes the results.
import mpi4py.MPI

comm = mpi4py.MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
def unitary_mpi(wave, i, l):
    """Apply ``l`` rounds of random two-site gates, shifting one site per round (pure MPI).

    In every round rank 0 draws a 4x4 Haar-random unitary ``u``, builds
    ``kron(u, I)`` with ``I`` the identity of size ``2**(l-2*i-2)`` and
    broadcasts it.  The full operator is block diagonal with ``4**i`` copies
    of that ``2**(l-2*i)``-sized block.  The wavefunction is scattered in
    equal contiguous chunks, each rank multiplies its blocks, the result is
    gathered on rank 0 and the spins are cyclically shifted by one site so
    that the next round acts on the next link.

    Parameters
    ----------
    wave : ndarray of complex128, length 2**l
        Wavefunction.  Only rank 0's copy is read (it is scattered from root).
    i : int
        Block exponent: the operator has ``4**i`` diagonal blocks.
    l : int
        System size (number of spins).

    Returns
    -------
    ndarray
        The evolved wavefunction on rank 0; other ranks return ``wave``
        unchanged.

    Raises
    ------
    ValueError
        If ``l < 2*i + 2`` or the number of ranks does not divide ``4**i``.
        The check depends only on values every rank shares, so all ranks
        raise together before any collective call instead of one rank
        failing mid-collective.
    """
    if 2 * i + 2 > l:
        raise ValueError(
            "unitary_mpi: need l >= 2*i + 2, got i=%d, l=%d" % (i, l))
    if (4**i) % size != 0:
        raise ValueError(
            "unitary_mpi: number of ranks (%d) must divide 4**i = %d"
            % (size, 4**i))

    dim_sub = 2**(l - 2*i)        # size of one kron(u, I) block
    shape_b = 2**(l - 2*i - 2)    # size of the identity factor
    # factor 16 is due to 4*4 random unitary is dense
    len_u = 16 * shape_b          # length of the packed (data, row, col) arrays
    chunk = 2**l // size          # wavefunction entries held by each rank
    sub_blocks = 4**i // size     # kron(u, I) blocks inside each chunk

    for _ in range(l):
        # Rank 0 computes the Kronecker product between the random unitary and
        # the identity, packed as rows (data, row, col), and broadcasts it.
        if rank == 0:
            u = coo_matrix(unitary_group.rvs(4), dtype='c16')
            d_b = np.ones(shape_b)
            r_b = np.arange(shape_b)
            un_coo = kron_raw(u.data, u.row, u.col, d_b, r_b, r_b, shape_b)
            un_pack = np.array(un_coo, dtype='c16')
        else:
            un_pack = np.empty((3, len_u), dtype='c16')
        comm.Bcast(un_pack, root=0)

        # Row/column indices travel as the real part of complex entries;
        # cast them back to integers before building the sparse block.
        rows = un_pack[1].real.astype(np.int64)
        cols = un_pack[2].real.astype(np.int64)
        un_sub = csr_matrix((un_pack[0], (rows, cols)),
                            shape=(dim_sub, dim_sub))

        # Scatter equal contiguous chunks of the wavefunction from the root.
        sendbuf = None
        if rank == 0:
            sendbuf = np.array(np.split(wave, size), dtype='c16')
        recvbuf = np.empty(chunk, dtype='c16')
        comm.Scatter(sendbuf, recvbuf, root=0)

        # The local chunk is `sub_blocks` consecutive blocks of length
        # dim_sub; applying un_sub to each one is a single sparse matmul on
        # the (dim_sub, sub_blocks) matrix whose columns are those blocks.
        blocks = recvbuf.reshape(sub_blocks, dim_sub)
        wave_split = np.ascontiguousarray((un_sub @ blocks.T).T).reshape(-1)

        # Gather the updated chunks back on the root.
        recvbuf = None
        if rank == 0:
            recvbuf = np.empty((size, chunk), dtype='c16')
        comm.Gather(wave_split, recvbuf, root=0)

        if rank == 0:
            # Row r of recvbuf is rank r's contiguous chunk, so C order
            # reassembles the state.  Moving the last spin axis to the front
            # and flattening in C order shifts every spin by one site (the
            # same permutation unitary_hybrid applies).  The previous mix of
            # a C-order reshape with an F-order ravel was not that shift.
            wave = recvbuf.reshape(2, 2**(l - 2), 2)
            wave = np.moveaxis(wave, -1, 0).ravel(order='C')
    return wave
import cppimport
# import c++ module to perform the dot product: the "eigen_dot" extension
# provides cxx.dot(...), which unitary_hybrid uses for the local sparse
# matrix-vector work.
# NOTE(review): every MPI rank executes this line; confirm the installed
# cppimport version serializes concurrent builds (or pre-build the module)
# so ranks do not race to compile it.
cxx = cppimport.imp("eigen_dot")
# mpi+openmp version
def unitary_hybrid(wave, i, l):
    """Apply ``l`` rounds of random two-site gates, shifting one site per round (MPI + C++).

    In every round rank 0 draws a 4x4 Haar-random unitary ``u``, builds
    ``kron(u, I)`` with ``I`` the identity of size ``2**(l-2*i-2)`` and
    broadcasts it as a ``2**(l-2*i)``-sized sparse block.  The wavefunction
    is scattered in equal contiguous chunks, the local product is done by
    the compiled ``cxx.dot`` (presumably OpenMP-parallel, per the "mpi+openmp"
    label; TODO confirm), the result is gathered on rank 0 and the spins are
    cyclically shifted by one site so the next round acts on the next link.

    Parameters
    ----------
    wave : ndarray of complex128, length 2**l
        Wavefunction.  Only rank 0's copy is read (it is scattered from root).
    i : int
        Block exponent, forwarded to ``cxx.dot``.
    l : int
        System size (number of spins).

    Returns
    -------
    ndarray
        The evolved wavefunction on rank 0; other ranks return ``wave``
        unchanged.

    Raises
    ------
    ValueError
        If ``l < 2*i + 2`` or the number of ranks does not divide ``2**l``.
        Checked on every rank before any collective call, so all ranks fail
        together with a clear message.
    """
    if 2 * i + 2 > l:
        raise ValueError(
            "unitary_hybrid: need l >= 2*i + 2, got i=%d, l=%d" % (i, l))
    if (2**l) % size != 0:
        raise ValueError(
            "unitary_hybrid: number of ranks (%d) must divide 2**l = %d"
            % (size, 2**l))

    dim_sub = 2**(l - 2*i)        # size of one kron(u, I) block
    shape_b = 2**(l - 2*i - 2)    # size of the identity factor
    # factor 16 is due to 4*4 random unitary is dense
    len_u = 16 * shape_b          # length of the packed (data, row, col) arrays
    chunk = 2**l // size          # wavefunction entries held by each rank

    for _ in range(l):
        # Rank 0 computes the Kronecker product between the random unitary and
        # the identity, packed as rows (data, row, col), and broadcasts it.
        if rank == 0:
            u = coo_matrix(unitary_group.rvs(4), dtype='c16')
            d_b = np.ones(shape_b)
            r_b = np.arange(shape_b)
            un_coo = kron_raw(u.data, u.row, u.col, d_b, r_b, r_b, shape_b)
            un_pack = np.array(un_coo, dtype='c16')
        else:
            un_pack = np.empty((3, len_u), dtype='c16')
        comm.Bcast(un_pack, root=0)

        # Unpack into the sparse block; row/column indices travel as the real
        # part of complex entries.
        un_sub = csr_matrix((un_pack[0], (un_pack[1].real, un_pack[2].real)),
                            shape=(dim_sub, dim_sub))

        # Scatter equal contiguous chunks of the wavefunction from the root.
        sendbuf = None
        if rank == 0:
            sendbuf = np.array(np.split(wave, size), dtype='c16')
        recvbuf = np.empty(chunk, dtype='c16')
        comm.Scatter(sendbuf, recvbuf, root=0)

        # Dot product between the local chunk and un_sub, done in C++.
        wave_split = cxx.dot(i, l, un_sub, recvbuf, size)

        # Gather the updated chunks back on the root.
        recvbuf = None
        if rank == 0:
            recvbuf = np.empty((size, chunk), dtype='c16')
        comm.Gather(wave_split, recvbuf, root=0)

        if rank == 0:
            # Row r of recvbuf is rank r's contiguous chunk, so a C-order
            # reshape reassembles the state.  Moving the last spin axis to
            # the front and flattening shifts every spin by one site.
            wave = recvbuf.reshape(2, 2**(l - 2), 2)
            wave = np.moveaxis(wave, -1, 0).ravel(order='C')
    return wave
def evo_parallel(steps, wave, prob, l=L, n=2, partition=part, block_exp=4):
    """Run ``steps`` rounds of unitary evolution followed by a measurement layer.

    Parameters
    ----------
    steps : int
        Number of discrete time steps.
    wave : ndarray
        Initial wavefunction.  Only rank 0's copy is evolved and measured.
    prob : float
        Measurement probability, forwarded to ``measure``.
    l : int
        System size.
    n : int
        Index forwarded to ``ent`` and ``logneg`` (presumably the Renyi
        index; confirm against their definitions).
    partition : ndarray
        ``[L, la, lb, lc1, lc2]`` partition forwarded to ``logneg``.
    block_exp : int, keyword, default 4
        Block exponent passed as ``i`` to ``unitary_hybrid`` (previously a
        hard-coded 4).  It must satisfy ``l >= 2*block_exp + 2``, which
        ``unitary_hybrid`` checks.

    Returns
    -------
    ndarray of shape (5, steps)
        Rows: von-Neumann entropy, Renyi entropy, logarithmic negativity,
        mutual information (von-Neumann), mutual information (Renyi).
        Only rank 0 fills them; on other ranks they stay zero.
    """
    von = np.zeros(steps, dtype='float64')    # von-Neumann entropy
    renyi = np.zeros(steps, dtype='float64')  # Renyi entropy
    neg = np.zeros(steps, dtype='float64')    # logarithmic negativity
    mut = np.zeros(steps, dtype='float64')    # mutual information using von-Neumann entropy
    mutr = np.zeros(steps, dtype='float64')   # mutual information in terms of Renyi entropy
    for t in range(steps):
        # Evolve over ALL links.  This uses MPI collectives, so every rank
        # must make the call.
        wave = unitary_hybrid(wave, block_exp, l)
        # Measurement layer.  With this protocol, the measurement rate needs
        # to be doubled.  Only rank 0 receives the evolved state from
        # unitary_hybrid, so measurement and observables run on the root.
        if rank == 0:
            for i in range(l):
                wave = measure(wave, prob, i, l)
            result = ent(wave, n, l//2, l)  # half-chain entanglement entropy
            von[t] = result[0]
            renyi[t] = result[1]
            # logarithmic negativity according to preset partition
            result = logneg(wave, n, partition)
            neg[t] = result[0]
            mut[t] = result[1]
            mutr[t] = result[2]
    return np.array([von, renyi, neg, mut, mutr])
# Run the evolution; every rank must participate in the collectives inside.
result = evo_parallel(time, psi, pro)
if rank == 0:
    # Only rank 0 fills the observables (other ranks hold zeros), so only
    # the root saves them.
    np.savez('dynamics_L=%s_p=%s_t=%s'%(L, pro, time), ent=result[0], renyi=result[1], neg=result[2], mut=result[3], mutr=result[4])
    end = timer()
    # Report the elapsed wall-clock time once, from the root, rather than
    # once per rank.
    print("Elapsed = %s" % (end - start))