-
Notifications
You must be signed in to change notification settings - Fork 0
/
GUI.py
113 lines (94 loc) · 4.3 KB
/
GUI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
# Form implementation generated from reading ui file 'pyui.ui'
#
# Created by: PyQt5 UI code generator 5.15.6
#
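# WARNING: This file was generated by pyuic5 from 'pyui.ui' and has since been
# edited by hand (model loading, show_heatmaps, the f() slot and the __main__
# block). Re-running pyuic5 will overwrite the file and lose those edits, so
# merge any regenerated layout code back in manually.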
import sys
import matplotlib.ticker as ticker
import torch
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QApplication
from d2l import torch as d2l
from nltk import word_tokenize
from test import init, eva_one_sentence
from setting import TRAIN_FILE, DEV_FILE
from train import model
from data_pre import PrepareData
import matplotlib.pyplot as plt
# Restore the trained weights, mapping every tensor onto the CPU so the GUI
# runs on machines without a GPU, then put the model in evaluation mode.
# NOTE(review): torch.load unpickles the checkpoint file, so only load trusted
# checkpoints; the path is resolved relative to the current working directory.
_checkpoint = torch.load('save/model_150ep.pt', map_location=torch.device('cpu'))
model.load_state_dict(_checkpoint)
del _checkpoint
model.eval()
# Dataset wrapper handed to eva_one_sentence in the translate slot; presumably
# it carries the source/target vocabularies — confirm in data_pre.py.
data = PrepareData(TRAIN_FILE, DEV_FILE)
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(10, 10),
cmap='Reds', sentence=None ):
"""Show heatmaps of matrices.
Defined in :numref:`sec_attention-cues`"""
d2l.use_svg_display()
num_rows, num_cols = matrices.shape[0], matrices.shape[1]
fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
pcm = ax.imshow(d2l.numpy(matrix), cmap=cmap)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if titles:
ax.set_title(titles[j])
ll = len(sentence)
ax.set_xticks(range(ll)) #设置哪些成为坐标
ax.set_yticks(range(ll))
ax.set_xticklabels(sentence,rotation = 30) # 设置坐标的label
ax.set_yticklabels(sentence)
ax.grid(True)
fig.colorbar(pcm, ax=axes, shrink=0.6);
class Ui_MiniTranslater(object):
    """Form for the English -> Chinese MiniTranslater window.

    Lays out an input text box (left), an output text box (right), a
    translate button and a direction label, and wires the button to
    :meth:`f`, which translates the input with the module-level ``model`` and
    plots the first encoder layer's self-attention.
    """

    def setupUi(self, MiniTranslater):
        """Create, position and connect the child widgets of *MiniTranslater*."""
        MiniTranslater.setObjectName("MiniTranslater")
        MiniTranslater.resize(800, 439)
        MiniTranslater.setAcceptDrops(False)
        # Left pane: English input, pre-filled with a sample sentence.
        self.textEdit = QtWidgets.QTextEdit(MiniTranslater)
        self.textEdit.setGeometry(QtCore.QRect(0, 50, 400, 300))
        self.textEdit.setMinimumSize(QtCore.QSize(341, 271))
        self.textEdit.setText("nice to meet you !")
        self.textEdit.setStyleSheet("")
        self.textEdit.setObjectName("textEdit")
        # Right pane: translation output.
        self.textEdit_2 = QtWidgets.QTextEdit(MiniTranslater)
        self.textEdit_2.setGeometry(QtCore.QRect(400, 50, 400, 300))
        self.textEdit_2.setObjectName("textEdit_2")
        self.pushButton = QtWidgets.QPushButton(MiniTranslater)
        self.pushButton.setGeometry(QtCore.QRect(10, 360, 113, 32))
        self.pushButton.setObjectName("pushButton")
        self.label = QtWidgets.QLabel(MiniTranslater)
        self.label.setGeometry(QtCore.QRect(380, 20, 60, 16))
        self.label.setObjectName("label")
        self.retranslateUi(MiniTranslater)
        # Connect here rather than in retranslateUi: retranslateUi may be
        # re-run (e.g. on a language change), and every extra connect() would
        # make a single click run the translation one more time.
        self.pushButton.clicked.connect(self.f)
        QtCore.QMetaObject.connectSlotsByName(MiniTranslater)

    def f(self):
        """Translate the input text and show encoder self-attention heatmaps.

        Writes the result of ``eva_one_sentence`` into the right-hand text
        box, then plots the attention weights stored on the first encoder
        layer, labelled with the tokenized input.
        """
        text = self.textEdit.toPlainText()
        # NOTE(review): assumes eva_one_sentence lowercases, tokenizes and
        # wraps the input with BOS/EOS the same way, so these labels line up
        # with the attention matrix axes — confirm against test.py.
        tokens = ["BOS"] + word_tokenize(text.lower()) + ["EOS"]
        self.textEdit_2.setText(eva_one_sentence(data, model, str=text))
        # Attention weights left on the layer by the forward pass above;
        # assumed 4-D with the last two axes being (queries, keys) and the
        # first two multiplying to the number of maps (8 for the trained
        # model, i.e. batch * heads) — TODO confirm.
        attn = model.encoder.layers[0].self_attn.attn.cpu()
        num_maps = attn.shape[0] * attn.shape[1]
        # Two rows when the maps split evenly (2x4 for 8 heads), else one row.
        num_rows = 2 if num_maps % 2 == 0 else 1
        grid = attn.reshape(num_rows, num_maps // num_rows,
                            attn.shape[2], attn.shape[3])
        show_heatmaps(grid, xlabel='Keys', ylabel='Queries', sentence=tokens)
        plt.show()

    def retranslateUi(self, MiniTranslater):
        """Set all user-visible strings through Qt's translation function."""
        _translate = QtCore.QCoreApplication.translate
        MiniTranslater.setWindowTitle(_translate("MiniTranslater", "MiniTranslater"))
        # Button caption: "Translate it".
        self.pushButton.setText(_translate("MiniTranslater", "翻译一下"))
        # Direction label: "English -> Chinese".
        self.label.setText(_translate("MiniTranslater", "英 -> 中"))
from PyQt5 import QtWidgets,QtCore
import sys
if __name__ == "__main__":
    # The high-DPI attribute has to be set before the QApplication exists.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    application = QtWidgets.QApplication(sys.argv)
    main_window = QtWidgets.QMainWindow()
    # The generated form class here is Ui_MiniTranslater (pyuic5's default
    # would be Ui_MainWindow); if you regenerate or rename the form, update
    # this line to match.
    form = Ui_MiniTranslater()
    form.setupUi(main_window)
    main_window.show()
    # Run the Qt event loop and propagate its return code to the shell.
    exit_code = application.exec_()
    sys.exit(exit_code)