Skip to content

Commit

Permalink
feat(mm): support generic API tokens via regex/token pairs in config
Browse files Browse the repository at this point in the history
A list of regex and token pairs is accepted. As a file is downloaded by the model installer, the URL is tested against the provided regex/token pairs. The token for the first matching regex is used during download, added as a bearer token.
  • Loading branch information
psychedelicious committed Mar 8, 2024
1 parent 97d5afe commit f63bc6e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
20 changes: 18 additions & 2 deletions invokeai/app/services/config/config_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,12 @@ class InvokeBatch(InvokeAISettings):
from __future__ import annotations

import os
import re
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional

from omegaconf import DictConfig, OmegaConf
from pydantic import Field
from pydantic import BaseModel, Field, field_validator
from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict

Expand Down Expand Up @@ -205,6 +206,21 @@ class Categories(object):
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}


class URLRegexToken(BaseModel):
url_regex: str = Field(description="Regular expression to match against the URL")
token: str = Field(description="Token to use when the URL matches the regex")

@field_validator("url_regex")
@classmethod
def validate_url_regex(cls, v: str) -> str:
"""Validate that the value is a valid regex."""
try:
re.compile(v)
except re.error as e:
raise ValueError(f"Invalid regex: {e}")
return v


class InvokeAIAppConfig(InvokeAISettings):
"""Configuration object for InvokeAI App."""

Expand Down Expand Up @@ -288,7 +304,7 @@ class InvokeAIAppConfig(InvokeAISettings):
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)

# MODEL IMPORT
remote_repo_api_key : Optional[str] = Field(default=os.environ.get("INVOKEAI_REMOTE_REPO_API_KEY"), description="API key used when downloading remote repositories", json_schema_extra=Categories.Other)
remote_api_tokens : Optional[list[URLRegexToken]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.", json_schema_extra=Categories.Other)

# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
Expand Down
5 changes: 1 addition & 4 deletions invokeai/app/services/download/download_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,12 @@ def _download_next_item(self) -> None:
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
query_params = url.query_params()
if job.access_token:
query_params.append(("access_token", job.access_token))
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"

# Make a streaming request. This will retrieve headers including
# content-length and content-disposition, but not fetch any content itself
resp = self._requests.get(str(url), params=query_params, headers=header, stream=True)
resp = self._requests.get(str(url), headers=header, stream=True)
if not resp.ok:
raise HTTPError(resp.reason)

Expand Down
9 changes: 8 additions & 1 deletion invokeai/app/services/model_install/model_install_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,16 @@ def heuristic_import(
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=self.app_config.remote_repo_api_key,
access_token=_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
Expand Down

0 comments on commit f63bc6e

Please sign in to comment.