Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(25): make models optionals #27

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions ImageGoNord/GoNord.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
import ffmpeg
import uuid
import shutil
import requests

import torch
import skimage.io as io
import skimage.color as convertor
import torchvision.transforms as transforms
try:
import torch
import skimage.color as convertor
import torchvision.transforms as transforms
except ImportError:
# AI feature disabled
pass


try:
Expand All @@ -31,7 +35,12 @@
from ImageGoNord.utility.quantize import quantize_to_palette
import ImageGoNord.utility.palette_loader as pl
from ImageGoNord.utility.ConvertUtility import ConvertUtility
from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder

try:
from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder
except ImportError:
# AI feature disabled
pass


class NordPaletteFile:
Expand Down Expand Up @@ -158,6 +167,8 @@ class GoNord(object):
TRANSPARENCY_TOLERANCE = 190
MAX_THREADS = 10

PALETTE_NET_REPO_FOLDER = 'https://github.com/Schrodinger-Hat/ImageGoNord-pip/raw/master/ImageGoNord/models/PaletteNet/'

AVAILABLE_PALETTE = []
PALETTE_DATA = {}

Expand Down Expand Up @@ -425,6 +436,16 @@ def converted_loop(self, is_rgba, pixels, original_pixels, maxRow, maxCol, minRo
pixels[row, col] = tuple(colors_list)
return pixels

def load_and_save_models(self):
rd_model = requests.get(self.PALETTE_NET_REPO_FOLDER + 'RD.state_dict.pt')
fe_model = requests.get(self.PALETTE_NET_REPO_FOLDER + 'FE.state_dict.pt')

with open(os.path.dirname(palette_net.__file__) + '/FE.state_dict.pt', "wb") as f:
f.write(fe_model.content)

with open(os.path.dirname(palette_net.__file__) + '/RD.state_dict.pt', "wb") as f:
f.write(rd_model.content)

def convert_image_by_model(self, image, use_model_cpu=False):
"""
Process a Pillow image by using a PyTorch model "PaletteNet" for recoloring the image
Expand All @@ -444,8 +465,14 @@ def convert_image_by_model(self, image, use_model_cpu=False):
FE = FeatureEncoder() # torch.Size([64, 3, 3, 3])
RD = RecoloringDecoder() # torch.Size([530, 256, 3, 3])

FE.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "FE.state_dict.pt")))
RD.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "RD.state_dict.pt")))
if (
os.path.exists(os.path.dirname(palette_net.__file__) + '/FE.state_dict.pt')
and os.path.exists(os.path.dirname(palette_net.__file__) + '/RD.state_dict.pt')
):
FE.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "FE.state_dict.pt")))
RD.load_state_dict(torch.load(pkg_resources.open_binary(palette_net, "RD.state_dict.pt")))
else:
self.load_and_save_models()

if use_model_cpu:
FE.to("cpu")
Expand All @@ -472,7 +499,8 @@ def convert_image_by_model(self, image, use_model_cpu=False):
try:
pal_np = np.array(palette).reshape(1,6,3)/255
except:
print("You have too many colors in your palette for the model, this feature is limited to 6 colours, now you have: ", len(palette), "! I'll take the first 6!")
# this feature is limited to 6 colours
# we're taking the first six
pal_np = np.array(palette[0:6]).reshape(1,6,3)/255

pal = torch.Tensor((convertor.rgb2lab(pal_np) - [50,0,0] ) / [50,128,128]).unsqueeze(0)
Expand Down Expand Up @@ -518,7 +546,10 @@ def convert_image(self, image, save_path='', use_model=False, use_model_cpu=Fals
is_rgba = (image.mode == 'RGBA')

if use_model:
image = self.convert_image_by_model(image, use_model_cpu)
if torch != None:
image = self.convert_image_by_model(image, use_model_cpu)
else:
print('Please install the dependencies required for the AI feature: pip install image-go-nord[AI]')
else:
if not parallel_threading:
self.converted_loop(is_rgba, pixels, original_pixels, image.size[0], image.size[1])
Expand Down
2 changes: 1 addition & 1 deletion ImageGoNord/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# gonord version
__version__ = "1.0.2"
__version__ = "1.1.0"

from ImageGoNord.GoNord import *
8 changes: 3 additions & 5 deletions index.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@
# go_nord.add_file_to_palette(NordPaletteFile.AURORA)
# go_nord.add_file_to_palette(NordPaletteFile.FROST)

# image = go_nord.open_image("images/valley.jpg")
# go_nord.convert_image(image, save_path="images/test-valley-ai.jpg", use_model=True)

output_path = go_nord.convert_video('videos/SampleVideo_720x480.mp4', 'custom_palette', save_path='videos/SampleVideo_converted.mp4')
print(output_path)
image = go_nord.open_image("images/valley.jpg")
go_nord.convert_image(image, save_path="images/test-valley-ai.jpg", use_model=True)
exit()
# output_path = go_nord.convert_video('videos/SampleVideo_720x480.mp4', 'custom_palette', save_path='videos/SampleVideo_converted.mp4')

image = go_nord.open_image("images/test.jpg")
resized_img = go_nord.resize_image(image)
Expand Down
11 changes: 7 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="image-go-nord",
version="1.0.2",
version="1.1.0",
description="A tool to convert any RGB image or video to any theme or color palette input by the user",
long_description=README,
long_description_content_type="text/markdown",
Expand All @@ -17,7 +17,7 @@
author_email="schrodinger.hat.show@gmail.com",
license="AGPL-3.0",
classifiers=[
'Development Status :: 5 - Production/Stable', # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Topic :: Software Development :: Build Tools',
"License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
Expand All @@ -30,8 +30,11 @@
"Bug Reports": "https://github.com/Schrodinger-Hat/ImageGoNord-pip/issues",
},
packages=find_packages(),
package_data={'': ['*.txt', 'palettes/*.txt', 'models/*.pt', '*.pt', '*.state_dict.*']},
package_data={'': ['*.txt', 'palettes/*.txt']},
include_package_data=True,
install_requires=["Pillow", "ffmpeg-python", "numpy", "torch", "scikit-image", "torchvision"],
install_requires=["Pillow", "ffmpeg-python", "numpy", "requests"],
extras_require = {
'AI': ["torch", "scikit-image", "torchvision"]
},
python_requires=">=3.5"
)