-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_gmm.py
27 lines (23 loc) · 862 Bytes
/
train_gmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from sklearn.mixture import GaussianMixture
import lmdb
import numpy as np
import lmdb
from tqdm import tqdm
# LMDB environment holding the feature vectors consumed below.
# Opened read-only because this script never writes to it, and with
# create=False so a wrong path raises instead of silently creating an empty
# database. map_size is 1200 * 1_000_000 bytes (~1.2 GB).
DB_features = lmdb.open(
    './clean/features.lmdb/',
    map_size=1200 * 1_000_000,
    readonly=True,
    create=False,
)
def get_all_data(db, size=20000):
    """Read up to ``size`` feature vectors from an LMDB environment.

    Each value in the environment's default database is decoded with
    ``np.frombuffer(..., dtype=np.float32)``; keys are ignored.

    Args:
        db: An open ``lmdb`` environment (this script passes ``DB_features``).
        size: Maximum number of records to read. Values <= 0 yield an empty
            list.

    Returns:
        list of 1-D ``np.float32`` arrays, one per record, in cursor order.
        Each array owns its memory, so it stays valid after the read
        transaction ends. NOTE(review): vectors are presumably all the same
        length (the caller stacks them with ``np.array``); not checked here.
    """
    from itertools import islice

    features = []
    with db.begin(buffers=True) as txn, txn.cursor() as curs:
        # islice stops after exactly `size` records instead of fetching one
        # extra item and then breaking.
        records = islice(curs.iternext(keys=True, values=True), max(size, 0))
        for _key, value in tqdm(records):
            # buffers=True makes `value` a zero-copy view into LMDB's memory
            # that is only valid inside this transaction. np.frombuffer does
            # not copy either, so copy now; otherwise the returned arrays
            # would reference memory that is invalid once the `with` exits.
            features.append(np.frombuffer(value, dtype=np.float32).copy())
    return features
# Load up to 300k feature vectors and stack them into one 2-D matrix
# (one row per record).
all_features = np.array(get_all_data(DB_features, 300000))
# Reading is finished; release the LMDB environment.
DB_features.close()

# Fail early with a clear message instead of an obscure error inside fit():
# an empty database gives shape (0,), and vectors of differing lengths cannot
# be stacked into a 2-D matrix.
if all_features.ndim != 2 or all_features.shape[0] == 0:
    raise ValueError(
        f"expected a non-empty 2-D feature matrix, got shape {all_features.shape}"
    )

# 16-component GMM with a full covariance matrix per component.
# NOTE: random_state is not set, so initialization (and therefore the fitted
# model) can differ between runs.
gmm = GaussianMixture(n_components=16, covariance_type='full')
gmm.fit(all_features)

import pickle

# Persist the fitted model. Only unpickle this file from a trusted source.
with open("./gmm.model", "wb") as model_file:
    pickle.dump(gmm, model_file)