Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for HF token #193

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 84 additions & 10 deletions comfy_cli/command/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import pathlib
import sys
from typing import List, Optional, Tuple
from urllib.parse import unquote, urlparse

import requests
import typer
Expand All @@ -10,7 +12,7 @@
from comfy_cli import constants, tracking, ui
from comfy_cli.config_manager import ConfigManager
from comfy_cli.constants import DEFAULT_COMFY_MODEL_PATH
from comfy_cli.file_utils import DownloadException, download_file
from comfy_cli.file_utils import DownloadException, check_unauthorized, download_file
from comfy_cli.workspace_manager import WorkspaceManager

app = typer.Typer()
Expand All @@ -37,10 +39,41 @@ def potentially_strip_param_url(path_name: str) -> str:
return path_name


# Convert relative path to absolute path based on the current working
# directory
def check_huggingface_url(url: str) -> bool:
return "huggingface.co" in url
def check_huggingface_url(url: str) -> Tuple[bool, Optional[str], Optional[str], Optional[str], Optional[str]]:
"""
Check if the given URL is a Hugging Face URL and extract relevant information.

Args:
url (str): The URL to check.

Returns:
Tuple[bool, Optional[str], Optional[str], Optional[str], Optional[str]]:
- is_huggingface_url (bool): True if it's a Hugging Face URL, False otherwise.
- repo_id (Optional[str]): The repository ID if it's a Hugging Face URL, None otherwise.
- filename (Optional[str]): The filename if present, None otherwise.
- folder_name (Optional[str]): The folder name if present, None otherwise.
- branch_name (Optional[str]): The git branch name if present, None otherwise.
"""
parsed_url = urlparse(url)

if parsed_url.netloc != "huggingface.co" and parsed_url.netloc != "huggingface.com":
return False, None, None, None, None

path_parts = [p for p in parsed_url.path.split("/") if p]

if len(path_parts) < 5 or (path_parts[2] != "resolve" and path_parts[2] != "blob"):
return False, None, None, None, None
repo_id = f"{path_parts[0]}/{path_parts[1]}"
branch_name = path_parts[3]

remaining_path = "/".join(path_parts[4:])
folder_name = os.path.dirname(remaining_path) if "/" in remaining_path else None
filename = os.path.basename(remaining_path)

# URL decode the filename
filename = unquote(filename)

return True, repo_id, filename, folder_name, branch_name


def check_civitai_url(url: str) -> Tuple[bool, bool, int, int]:
Expand Down Expand Up @@ -154,6 +187,14 @@ def download(
show_default=False,
),
] = None,
set_hf_api_token: Annotated[
Optional[str],
typer.Option(
"--set-hf-api-token",
help="Set the HuggingFace API token to use for model listing.",
show_default=False,
),
] = None,
):
if relative_path is not None:
relative_path = os.path.expanduser(relative_path)
Expand All @@ -166,8 +207,12 @@ def download(
config_manager.set(constants.CIVITAI_API_TOKEN_KEY, set_civitai_api_token)
civitai_api_token = set_civitai_api_token

if set_hf_api_token is not None:
config_manager.set(constants.HF_API_TOKEN_KEY, set_hf_api_token)
hf_api_token = set_hf_api_token
else:
civitai_api_token = config_manager.get(constants.CIVITAI_API_TOKEN_KEY)
hf_api_token = config_manager.get(constants.HF_API_TOKEN_KEY)

if civitai_api_token is not None:
headers = {
Expand All @@ -176,6 +221,7 @@ def download(
}

is_civitai_model_url, is_civitai_api_url, model_id, version_id = check_civitai_url(url)
is_huggingface_url, repo_id, hf_filename, hf_folder_name, hf_branch_name = check_huggingface_url(url)

if is_civitai_model_url:
local_filename, url, model_type, basemodel = request_civitai_model_api(model_id, version_id, headers)
Expand All @@ -197,7 +243,9 @@ def download(
model_path = ui.prompt_input("Enter model type path (e.g. loras, checkpoints, ...)", default="")

relative_path = os.path.join(DEFAULT_COMFY_MODEL_PATH, model_path, basemodel)
elif check_huggingface_url(url):
elif is_huggingface_url:
model_id = "/".join(url.split("/")[-2:])

local_filename = potentially_strip_param_url(url.split("/")[-1])

if relative_path is None:
Expand Down Expand Up @@ -225,14 +273,40 @@ def download(

local_filepath = get_workspace() / relative_path / local_filename

# Check if the file already exists
if local_filepath.exists():
print(f"[bold red]File already exists: {local_filepath}[/bold red]")
return

# File does not exist, proceed with download
print(f"Start downloading URL: {url} into {local_filepath}")
download_file(url, local_filepath, headers)
if is_huggingface_url and check_unauthorized(url, headers):
if hf_api_token is None:
print(
"Unauthorized access to Hugging Face model. Please set the HuggingFace API token using --set-hf-api-token"
)
return
else:
try:
import huggingface_hub
except ImportError:
print("huggingface_hub not found. Installing...")
import subprocess

subprocess.check_call([sys.executable, "-m", "pip", "install", "huggingface_hub"])
import huggingface_hub

print(f"Downloading model {model_id} from Hugging Face...")
output_path = huggingface_hub.hf_hub_download(
repo_id=repo_id,
filename=hf_filename,
subfolder=hf_folder_name,
revision=hf_branch_name,
token=hf_api_token,
local_dir=get_workspace() / relative_path,
cache_dir=get_workspace() / relative_path,
)
print(f"Model downloaded successfully to: {output_path}")
else:
print(f"Start downloading URL: {url} into {local_filepath}")
download_file(url, local_filepath, headers)


@app.command()
Expand Down
1 change: 1 addition & 0 deletions comfy_cli/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class PROC(str, Enum):
CONFIG_KEY_BACKGROUND = "background"

CIVITAI_API_TOKEN_KEY = "civitai_api_token"
HF_API_TOKEN_KEY = "hf_api_token"

DEFAULT_TRACKING_VALUE = True

Expand Down
19 changes: 19 additions & 0 deletions comfy_cli/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ def parse_json(input_data):
return f"Unknown error occurred (status code: {status_code})"


def check_unauthorized(url: str, headers: Optional[dict] = None) -> bool:
"""
Perform a GET request to the given URL and check if the response status code is 401 (Unauthorized).

Args:
url (str): The URL to send the GET request to.
headers (Optional[dict]): Optional headers to include in the request.

Returns:
bool: True if the response status code is 401, False otherwise.
"""
try:
response = requests.get(url, headers=headers, allow_redirects=True)
return response.status_code == 401
except requests.RequestException:
# If there's an error making the request, we can't determine if it's unauthorized
return False


def download_file(url: str, local_filepath: pathlib.Path, headers: Optional[dict] = None):
"""Helper function to download a file."""
local_filepath.parent.mkdir(parents=True, exist_ok=True) # Ensure the directory exists
Expand Down
65 changes: 64 additions & 1 deletion tests/comfy_cli/command/models/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from comfy_cli.command.models.models import check_civitai_url
from comfy_cli.command.models.models import check_civitai_url, check_huggingface_url


def test_valid_model_url():
Expand Down Expand Up @@ -34,3 +34,66 @@ def test_malformed_url():
def test_malformed_query_url():
url = "https://civitai.com/models/43331?version="
assert check_civitai_url(url) == (False, False, None, None)


def test_valid_huggingface_url():
url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", None, "main")


def test_valid_huggingface_url_sd_audio():
url = "https://huggingface.co/stabilityai/stable-audio-open-1.0/blob/main/model.safetensors"
assert check_huggingface_url(url) == (True, "stabilityai/stable-audio-open-1.0", "model.safetensors", None, "main")


def test_valid_huggingface_url_with_folder():
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt"
assert check_huggingface_url(url) == (
True,
"runwayml/stable-diffusion-v1-5",
"v1-5-pruned-emaonly.ckpt",
None,
"main",
)


def test_valid_huggingface_url_with_subfolder():
url = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt"
assert check_huggingface_url(url) == (
True,
"stabilityai/stable-diffusion-2-1",
"v2-1_768-ema-pruned.ckpt",
None,
"main",
)


def test_valid_huggingface_url_with_encoded_filename():
url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4%20(1).ckpt"
assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4 (1).ckpt", None, "main")


def test_invalid_huggingface_url():
url = "https://example.com/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
assert check_huggingface_url(url) == (False, None, None, None, None)


def test_invalid_huggingface_url_structure():
url = "https://huggingface.co/CompVis/stable-diffusion-v1-4/main/sd-v1-4.ckpt"
assert check_huggingface_url(url) == (False, None, None, None, None)


def test_huggingface_url_with_com_domain():
url = "https://huggingface.com/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
assert check_huggingface_url(url) == (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", None, "main")


def test_huggingface_url_with_folder_structure():
url = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors"
assert check_huggingface_url(url) == (
True,
"stabilityai/stable-diffusion-xl-base-1.0",
"sd_xl_base_1.0.safetensors",
None,
"main",
)
Loading