Skip to content

Commit

Permalink
feat(25): make models optionals
Browse files Browse the repository at this point in the history
  • Loading branch information
TheJoin95 committed Apr 24, 2024
1 parent 3dfa415 commit 8003ed4
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 19 deletions.
44 changes: 35 additions & 9 deletions ImageGoNord/GoNord.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
import ffmpeg
import uuid
import shutil
import requests

import torch
import skimage.io as io
import skimage.color as convertor
import torchvision.transforms as transforms
try:
import torch
import skimage.color as convertor
import torchvision.transforms as transforms
except ImportError:
print("Please install the dependencies required for the AI feature")


try:
Expand All @@ -31,7 +34,11 @@
from ImageGoNord.utility.quantize import quantize_to_palette
import ImageGoNord.utility.palette_loader as pl
from ImageGoNord.utility.ConvertUtility import ConvertUtility
from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder

try:
from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder
except ImportError:
print("Please install the dependencies required for the AI feature")


class NordPaletteFile:
Expand Down Expand Up @@ -158,6 +165,8 @@ class GoNord(object):
TRANSPARENCY_TOLERANCE = 190
MAX_THREADS = 10

PALETTE_NET_REPO_FOLDER = 'https://github.com/Schrodinger-Hat/ImageGoNord-pip/raw/master/ImageGoNord/models/PaletteNet/'

AVAILABLE_PALETTE = []
PALETTE_DATA = {}

Expand Down Expand Up @@ -425,6 +434,16 @@ def converted_loop(self, is_rgba, pixels, original_pixels, maxRow, maxCol, minRo
pixels[row, col] = tuple(colors_list)
return pixels

def load_and_save_models(self):
rd_model = requests.get(self.PALETTE_NET_REPO_FOLDER + 'RD.state_dict.pt')
fe_model = requests.get(self.PALETTE_NET_REPO_FOLDER + 'FE.state_dict.pt')

with open(os.path.dirname(palette_net.__file__) + '/FE.state_dict.pt', "wb") as f:
f.write(fe_model.content)

with open(os.path.dirname(palette_net.__file__) + '/RD.state_dict.pt', "wb") as f:
f.write(rd_model.content)

def convert_image_by_model(self, image, use_model_cpu=False):
"""
Process a Pillow image by using a PyTorch model "PaletteNet" for recoloring the image
Expand All @@ -444,8 +463,14 @@ def convert_image_by_model(self, image, use_model_cpu=False):
FE = FeatureEncoder() # torch.Size([64, 3, 3, 3])
RD = RecoloringDecoder() # torch.Size([530, 256, 3, 3])

FE.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "FE.state_dict.pt")))
RD.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "RD.state_dict.pt")))
if (
os.path.exists(os.path.dirname(palette_net.__file__) + '/FE.state_dict.pt')
and os.path.exists(os.path.dirname(palette_net.__file__) + '/RD.state_dict.pt')
):
FE.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "FE.state_dict.pt")))
RD.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "RD.state_dict.pt")))
else:
self.load_and_save_models()

if use_model_cpu:
FE.to("cpu")
Expand All @@ -472,7 +497,8 @@ def convert_image_by_model(self, image, use_model_cpu=False):
try:
pal_np = np.array(palette).reshape(1,6,3)/255
except:
print("You have too many colors in your palette for the model, this feature is limited to 6 colours, now you have: ", len(palette), "! I'll take the first 6!")
# this feature is limited to 6 colours
# we're taking the first six
pal_np = np.array(palette[0:6]).reshape(1,6,3)/255

pal = torch.Tensor((convertor.rgb2lab(pal_np) - [50,0,0] ) / [50,128,128]).unsqueeze(0)
Expand Down Expand Up @@ -517,7 +543,7 @@ def convert_image(self, image, save_path='', use_model=False, use_model_cpu=Fals
pixels = self.load_pixel_image(image)
is_rgba = (image.mode == 'RGBA')

if use_model:
if use_model and torch != None:
image = self.convert_image_by_model(image, use_model_cpu)
else:
if not parallel_threading:
Expand Down
2 changes: 1 addition & 1 deletion ImageGoNord/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# gonord version
__version__ = "1.0.2"
__version__ = "1.1.0"

from ImageGoNord.GoNord import *
8 changes: 3 additions & 5 deletions index.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
# go_nord.add_file_to_palette(NordPaletteFile.AURORA)
# go_nord.add_file_to_palette(NordPaletteFile.FROST)

# image = go_nord.open_image("images/valley.jpg")
# go_nord.convert_image(image, save_path="images/test-valley-ai.jpg", use_model=True)

output_path = go_nord.convert_video('videos/SampleVideo_720x480.mp4', 'custom_palette', save_path='videos/SampleVideo_converted.mp4')
print(output_path)
image = go_nord.open_image("images/valley.jpg")
go_nord.convert_image(image, save_path="images/test-valley-ai.jpg", use_model=True)
exit()
# output_path = go_nord.convert_video('videos/SampleVideo_720x480.mp4', 'custom_palette', save_path='videos/SampleVideo_converted.mp4')

image = go_nord.open_image("images/test.jpg")
resized_img = go_nord.resize_image(image)
Expand Down
11 changes: 7 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="image-go-nord",
version="1.0.2",
version="1.1.0",
description="A tool to convert any RGB image or video to any theme or color palette input by the user",
long_description=README,
long_description_content_type="text/markdown",
Expand All @@ -17,7 +17,7 @@
author_email="schrodinger.hat.show@gmail.com",
license="AGPL-3.0",
classifiers=[
'Development Status :: 5 - Production/Stable', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Topic :: Software Development :: Build Tools',
"License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
Expand All @@ -30,8 +30,11 @@
"Bug Reports": "https://github.com/Schrodinger-Hat/ImageGoNord-pip/issues",
},
packages=find_packages(),
package_data={'': ['*.txt', 'palettes/*.txt', 'models/*.pt', '*.pt', '*.state_dict.*']},
package_data={'': ['*.txt', 'palettes/*.txt']},
include_package_data=True,
install_requires=["Pillow", "ffmpeg-python", "numpy", "torch", "scikit-image", "torchvision"],
install_requires=["Pillow", "ffmpeg-python", "numpy", "requests"],
extras_require = {
'AI': ["torch", "scikit-image", "torchvision"]
},
python_requires=">=3.5"
)

0 comments on commit 8003ed4

Please sign in to comment.