Skip to content

Commit

Permalink
Add new download model logic
Browse files: browse the repository at this point in the history
  • Loading branch information
derneuere committed Nov 2, 2023
1 parent 3ad9c7e commit 7aef292
Show file tree
Hide file tree
Showing 11 changed files with 429 additions and 67 deletions.
9 changes: 7 additions & 2 deletions api/batch_jobs.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,6 +13,8 @@
from api.models.photo import Photo
from api.semantic_search.semantic_search import semantic_search_instance

from api.ml_models import download_models


def create_batch_job(job_type, user):
job_id = uuid.uuid4()
Expand All @@ -25,6 +27,8 @@ def create_batch_job(job_type, user):

if job_type == LongRunningJob.JOB_CALCULATE_CLIP_EMBEDDINGS:
AsyncTask(batch_calculate_clip_embedding, job_id, user).run()
if job_type == LongRunningJob.JOB_DOWNLOAD_MODELS:
AsyncTask(download_models, job_id).run()

lrj.save()

Expand All @@ -40,11 +44,12 @@ def batch_calculate_clip_embedding(job_id, user):
).count()
lrj.result = {"progress": {"current": 0, "target": count}}
lrj.save()

if not torch.cuda.is_available():
num_threads = max(1, site_config.HEAVYWEIGHT_PROCESS)
torch.set_num_threads(num_threads)
os.environ["OMP_NUM_THREADS"] = str(num_threads)
os.environ["OMP_NUM_THREADS"] = str(num_threads)
else:
torch.multiprocessing.set_start_method("spawn", force=True)

BATCH_SIZE = 64
util.logger.info("Using threads: {}".format(torch.get_num_threads()))
Expand Down
2 changes: 1 addition & 1 deletion api/im2txt/data_loader.py
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
import os
from PIL import Image

import torch
import torch.utils.data as data
from PIL import Image


class CocoDataset(data.Dataset):
Expand Down
12 changes: 2 additions & 10 deletions api/im2txt/model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,18 +34,10 @@ def __init__(
self.linear = nn.Linear(hidden_size, vocab_size)
self.max_seg_length = max_seq_length

def forward(self, features, captions, lengths):
"""Decode image feature vectors and generates captions."""
embeddings = self.embed(captions)
embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
hiddens, _ = self.lstm(packed)
outputs = self.linear(hiddens[0])
return outputs

def sample(self, features, states=None):
def forward(self, features):
"""Generate captions for given image features using greedy search."""
sampled_ids = []
states = None
inputs = features.unsqueeze(1)
for i in range(self.max_seg_length):
hiddens, states = self.lstm(
Expand Down
Loading

0 comments on commit 7aef292

Please sign in to comment.