-
Notifications
You must be signed in to change notification settings - Fork 11
/
db.py
274 lines (216 loc) · 9.14 KB
/
db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
import sqlite3
import os
# Open (creating if needed) the SQLite database at ./db/imagenet.db and make
# sure the schema exists. This runs at import time and exposes a module-level
# ``conn``.
# exist_ok=True replaces the old exists()-then-makedirs() check, which could
# raise FileExistsError if another process created ./db in between.
os.makedirs('./db', exist_ok=True)
conn = sqlite3.connect('./db/imagenet.db')
cursor = conn.cursor()
# NOTE: the FOREIGN KEY clauses below are declarative only; SQLite does not
# enforce them unless ``PRAGMA foreign_keys = ON`` is issued, which this
# module does not do.
# 'hashes': one row per known hash value.
cursor.execute('''
CREATE TABLE IF NOT EXISTS hashes (
hash_value TEXT PRIMARY KEY
)
''')
# 'i2iprompts': image-to-image prompts; image_hash holds the input image hash.
cursor.execute('''
CREATE TABLE IF NOT EXISTS i2iprompts (
id INTEGER PRIMARY KEY,
batch_id INTEGER,
hash_value TEXT,
image_order_id INTEGER,
uid INTEGER,
prompt TEXT,
negative TEXT,
seed INTEGER,
height INTEGER,
width INTEGER,
timestamp INTEGER,
image_hash TEXT,
FOREIGN KEY (hash_value) REFERENCES hashes (hash_value),
FOREIGN KEY (image_hash) REFERENCES hashes (hash_value)
)
''')
# 'prompts': text-to-image prompts.
cursor.execute('''
CREATE TABLE IF NOT EXISTS prompts (
id INTEGER PRIMARY KEY,
batch_id INTEGER,
hash_value TEXT,
image_order_id INTEGER,
uid INTEGER,
prompt TEXT,
negative TEXT,
seed INTEGER,
height INTEGER,
width INTEGER,
timestamp INTEGER,
FOREIGN KEY (hash_value) REFERENCES hashes (hash_value)
)
''')
# 'batches': one row per batch with its creation timestamp.
cursor.execute('''
CREATE TABLE IF NOT EXISTS batches (
id INTEGER PRIMARY KEY,
timestamp INTEGER
)
''')
# Explicit commit so the schema is persisted whatever transaction mode the
# connection uses (a no-op when nothing is pending).
conn.commit()
# The setup cursor is not reused; release it but keep the module attribute.
cursor.close()
cursor = None
def delete_prompt_by_id(conn, prompt_id):
    """Delete the ``prompts`` row whose primary key is ``prompt_id`` and commit.

    An id with no matching row is a silent no-op; the commit still runs.
    """
    conn.cursor().execute('DELETE FROM prompts WHERE id = ?', (prompt_id,))
    conn.commit()
def delete_i2iprompt_by_id(conn, prompt_id):
    """Delete the ``i2iprompts`` row with primary key ``prompt_id`` and commit.

    Missing ids are ignored (the DELETE matches nothing) but still committed.
    """
    params = (prompt_id,)
    cur = conn.cursor()
    cur.execute('DELETE FROM i2iprompts WHERE id = ?', params)
    conn.commit()
def delete_hash_if_no_prompts(conn, hash_value):
    """Drop ``hash_value`` from ``hashes`` when no prompt row still uses it.

    Both ``prompts`` and ``i2iprompts`` are checked on their ``hash_value``
    column (both queries always run). Only when neither table has a matching
    row is the ``hashes`` row deleted and the deletion committed; otherwise
    nothing is written.
    """
    cur = conn.cursor()
    key = (hash_value,)
    cur.execute('SELECT COUNT(*) FROM prompts WHERE hash_value = ?', key)
    t2i_refs = cur.fetchone()[0]
    cur.execute('SELECT COUNT(*) FROM i2iprompts WHERE hash_value = ?', key)
    i2i_refs = cur.fetchone()[0]
    # Still referenced by at least one prompt: keep the hash.
    if t2i_refs or i2i_refs:
        return
    cur.execute('DELETE FROM hashes WHERE hash_value = ?', key)
    conn.commit()
def delete_batch_if_no_prompts(conn, batch_id):
    """Remove batch ``batch_id`` once neither prompt table references it.

    Rows carrying this ``batch_id`` are counted in ``prompts`` and in
    ``i2iprompts`` (both queries always run). Only when both counts are zero
    is the ``batches`` row deleted and the change committed.
    """
    cur = conn.cursor()
    key = (batch_id,)
    reference_counts = []
    for count_query in ('SELECT COUNT(*) FROM prompts WHERE batch_id = ?',
                        'SELECT COUNT(*) FROM i2iprompts WHERE batch_id = ?'):
        cur.execute(count_query, key)
        reference_counts.append(cur.fetchone()[0])
    # Some prompt still belongs to this batch: leave it in place.
    if any(reference_counts):
        return
    cur.execute('DELETE FROM batches WHERE id = ?', key)
    conn.commit()
def delete_prompts_by_timestamp(conn, timestamp):
    """Delete every prompt newer than ``timestamp`` and prune orphaned rows.

    Rows in ``prompts`` and ``i2iprompts`` whose ``timestamp`` is strictly
    greater than ``timestamp`` are deleted one at a time. After each deletion,
    the row's hash and batch are also removed if no remaining prompt of
    either kind references them.

    Args:
        conn: open sqlite3 connection to the prompt database.
        timestamp: exclusive lower bound; rows with a larger timestamp are deleted.
    """
    cursor = conn.cursor()
    # Text-to-image prompts.
    cursor.execute('SELECT id, hash_value, batch_id FROM prompts WHERE timestamp > ?', (timestamp,))
    for prompt_id, hash_value, batch_id in cursor.fetchall():
        delete_prompt_by_id(conn, prompt_id)
        delete_hash_if_no_prompts(conn, hash_value)
        delete_batch_if_no_prompts(conn, batch_id)
    # Image-to-image prompts. batch_id is selected here as well so each row
    # prunes its own batch; without it, the loop would use whatever batch_id
    # the loop above left behind (or raise NameError if that loop was empty).
    cursor.execute('SELECT id, hash_value, batch_id FROM i2iprompts WHERE timestamp > ?', (timestamp,))
    for prompt_id, hash_value, batch_id in cursor.fetchall():
        delete_i2iprompt_by_id(conn, prompt_id)
        delete_hash_if_no_prompts(conn, hash_value)
        delete_batch_if_no_prompts(conn, batch_id)
def delete_prompts_by_uid(conn, uid):
    """Delete every prompt belonging to ``uid`` and prune orphaned rows.

    Rows in ``prompts`` and ``i2iprompts`` whose ``uid`` equals ``uid`` are
    deleted one at a time. After each deletion, the row's hash and batch are
    also removed if no remaining prompt of either kind references them.

    Args:
        conn: open sqlite3 connection to the prompt database.
        uid: value matched against the ``uid`` column of both tables.
    """
    cursor = conn.cursor()
    # Text-to-image prompts owned by this uid.
    cursor.execute('SELECT id, hash_value, batch_id FROM prompts WHERE uid = ?', (uid,))
    for prompt_id, hash_value, batch_id in cursor.fetchall():
        delete_prompt_by_id(conn, prompt_id)
        delete_hash_if_no_prompts(conn, hash_value)
        delete_batch_if_no_prompts(conn, batch_id)
    # Image-to-image prompts owned by this uid. batch_id is selected here as
    # well so each row prunes its own batch; without it, the loop would use
    # whatever batch_id the loop above left behind (or raise NameError if
    # that loop was empty).
    cursor.execute('SELECT id, hash_value, batch_id FROM i2iprompts WHERE uid = ?', (uid,))
    for prompt_id, hash_value, batch_id in cursor.fetchall():
        delete_i2iprompt_by_id(conn, prompt_id)
        delete_hash_if_no_prompts(conn, hash_value)
        delete_batch_if_no_prompts(conn, batch_id)
def create_or_get_hash_id(conn, hash_value):
    """Ensure ``hash_value`` has a row in ``hashes``.

    Returns:
        A ``(hash_value, existed)`` tuple. ``existed`` is True when the row
        was already present; nothing is written in that case. It is False
        when this call inserted the row, and that insert is committed
        immediately.
    """
    cur = conn.cursor()
    cur.execute('SELECT hash_value FROM hashes WHERE hash_value = ?', (hash_value,))
    row = cur.fetchone()
    if row is not None:
        return row[0], True
    cur.execute('INSERT INTO hashes (hash_value) VALUES (?)', (hash_value,))
    conn.commit()
    return hash_value, False
def create_prompt(conn, batch_id, hash_value, image_order_id, uid, prompt, negative, seed, height, width, timestamp, input_image_hash=None):
    """Store one prompt row, registering its ``hash_value`` first.

    The hash is looked up or inserted via ``create_or_get_hash_id``. With
    ``input_image_hash`` left as None the row goes into ``prompts``.
    Otherwise it goes into ``i2iprompts``, with ``input_image_hash`` stored
    in the ``image_hash`` column. The prompt insert is then committed.

    Returns:
        bool: True if the hash already existed, False if this call created it.
    """
    cur = conn.cursor()
    hash_value, already_existed = create_or_get_hash_id(conn, hash_value)
    # Columns shared by both tables, in insertion order.
    shared = (batch_id, hash_value, image_order_id, uid, prompt, negative, seed, height, width)
    if input_image_hash is not None:
        cur.execute('''
INSERT INTO i2iprompts (batch_id, hash_value, image_order_id, uid, prompt, negative, seed, height, width, image_hash, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', shared + (input_image_hash, timestamp))
    else:
        cur.execute('''
INSERT INTO prompts (batch_id, hash_value, image_order_id, uid, prompt, negative, seed, height, width, timestamp)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', shared + (timestamp,))
    conn.commit()
    return already_existed
def create_batch(conn, timestamp):
    """Insert a batch stamped with ``int(timestamp)``, commit, and return its row id."""
    cur = conn.execute('INSERT INTO batches (timestamp) VALUES (?)', (int(timestamp),))
    conn.commit()
    return cur.lastrowid
def get_batch(conn, batch_id):
    """Return every ``prompts`` row of batch ``batch_id`` as ``Prompt`` objects.

    Only the ``prompts`` table is queried; ``i2iprompts`` rows for the same
    batch are not included. An unknown batch yields an empty list.
    """
    rows = conn.cursor().execute('SELECT * FROM prompts WHERE batch_id = ?', (batch_id,)).fetchall()
    return list(map(Prompt, rows))
def get_random_batch_id(conn, start_timestamp=None, end_timestamp=None):
    """Pick one batch id at random among batches inside a timestamp window.

    Args:
        start_timestamp: inclusive lower bound; None means 0.
        end_timestamp: inclusive upper bound; None means 9999999999.

    Returns:
        The chosen batch id, or None when no batch falls inside the window.
    """
    # NOTE(review): the default upper bound only covers every batch if
    # timestamps are stored in seconds; millisecond values would exceed it.
    # Confirm the units.
    lower = 0 if start_timestamp is None else start_timestamp
    upper = 9999999999 if end_timestamp is None else end_timestamp
    cur = conn.cursor()
    cur.execute('''
SELECT id FROM batches
WHERE timestamp >= ? AND timestamp <= ?
ORDER BY RANDOM() LIMIT 1
''', (lower, upper))
    picked = cur.fetchone()
    return None if picked is None else picked[0]
def get_prompts_of_random_batch(conn, start_timestamp=None, end_timestamp=None):
    """Choose a random batch in the timestamp window and return its prompts.

    Delegates to ``get_random_batch_id`` and then ``get_batch``. Returns None
    when no batch lies in the window; otherwise returns the list that
    ``get_batch`` produces.
    """
    chosen_id = get_random_batch_id(conn, start_timestamp, end_timestamp)
    return None if chosen_id is None else get_batch(conn, chosen_id)
class Prompt:
    """Attribute view over one row from the ``prompts`` or ``i2iprompts`` table.

    The first eleven row items are mapped positionally onto the names in
    ``_COLUMNS``. If the row has a twelfth item (as ``i2iprompts`` rows do),
    it becomes ``input_image_hash``. Shorter rows get no
    ``input_image_hash`` attribute at all.
    """

    # Positional column names shared by both tables, as returned by SELECT *.
    _COLUMNS = ('id', 'batch_id', 'hash_value', 'image_order_id', 'uid',
                'prompt', 'negative', 'seed', 'height', 'width', 'timestamp')

    def __init__(self, prompt):
        for position, attr_name in enumerate(self._COLUMNS):
            setattr(self, attr_name, prompt[position])
        extra_index = len(self._COLUMNS)
        if len(prompt) > extra_index:
            self.input_image_hash = prompt[extra_index]

    def __str__(self):
        return str(vars(self))