
Commit

Add download link to README.md and refactor code for progress window
ProfFan committed Feb 15, 2024
1 parent cc8fcd9 · commit 2571e82
Showing 2 changed files with 71 additions and 3 deletions.
README.md (2 additions, 0 deletions)
@@ -4,6 +4,8 @@ Snap2LaTeX is a tool that converts a picture of a mathematical equation into a L

# Usage

Download from [releases](https://github.com/ProfFan/Snap2LaTeX/releases).

Run the application.

There will be an icon in the system tray.
standalone_app/Snap2LaTeX.py (69 additions, 3 deletions)
@@ -5,17 +5,86 @@
from transformers import NougatImageProcessor
import accelerate

import re
import logging
from logging import info
import io
import sys
from multiprocessing.queues import Queue

from PyQt6.QtGui import *
from PyQt6.QtWidgets import *
from PyQt6.QtCore import Qt

import multiprocessing as mp

class StdoutQueue(Queue):
    def __init__(self, maxsize=-1, block=True, timeout=None):
        self.block = block
        self.timeout = timeout
        super().__init__(maxsize, ctx=mp.get_context())

    def write(self, msg):
        self.put(msg)

    def flush(self):
        sys.__stdout__.flush()

def load_model_proc(model_name, q: StdoutQueue):
    sys.stdout = q
    sys.stderr = q
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    # init model
    model = VisionEncoderDecoderModel.from_pretrained(model_name, device_map=device)

    q.close()

def app_show_progress(model_name):
    app = QApplication([])

    q = StdoutQueue()
    load_process = mp.Process(target=load_model_proc, args=(model_name, q))
    load_process.start()

    # Progress window
    progress = QProgressDialog()
    progress.setLabelText("Loading model...")
    progress.setWindowModality(Qt.WindowModality.WindowModal)
    # disable the cancel button
    progress.setCancelButton(None)

    while load_process.is_alive():
        while not q.empty():
            msg = q.get()
            # match the ": 5%|" pattern with the regex
            # and extract the percentage
            match = re.search(r":\s+(\d+)%\|", msg)
            if match:
                progress.setValue(int(match.group(1)))

            # show the message in the progress window
            progress.setLabelText(msg)
        app.processEvents()

    print("Model loaded.")

    load_process.join()
    progress.close()
    app.quit()
    print("Model check complete.")

if __name__ == "__main__":
    mp.freeze_support()

    model_name = "Norm/nougat-latex-base"
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    app_show_progress(model_name)

    app = QApplication([])
    app.setQuitOnLastWindowClosed(False)

    # init model
    model = VisionEncoderDecoderModel.from_pretrained(model_name, device_map=device)

@@ -26,9 +95,6 @@

info("Loaded model.")

app = QApplication([])
app.setQuitOnLastWindowClosed(False)

# Create the icon
from os import path

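The pattern in this commit: load_model_proc runs in a child process with sys.stdout and sys.stderr redirected into a StdoutQueue, so the progress-bar text printed during the model download lands in the queue, and app_show_progress polls that queue and parses lines like "model.safetensors: 5%|..." to update the dialog. Below is a minimal standalone sketch of just that parsing step; the worker function and its sample strings are invented for illustration and stand in for the real download output.

import re
import multiprocessing as mp


def fake_loader(q):
    # stand-in for load_model_proc: emit tqdm-style progress lines
    for pct in (5, 50, 100):
        q.put(f"model.safetensors: {pct}%|{'#' * (pct // 10)}| {pct}/100")


if __name__ == "__main__":
    q = mp.Queue()
    proc = mp.Process(target=fake_loader, args=(q,))
    proc.start()
    proc.join()

    while not q.empty():
        msg = q.get()
        # same pattern as the commit: extract the integer before "%|"
        match = re.search(r":\s+(\d+)%\|", msg)
        if match:
            print("progress:", int(match.group(1)))  # 5, 50, 100

The regex is the one used in app_show_progress; only the worker and its messages are made up here.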
