Skip to content

Commit

Permalink
feat: mask api key for cerebriumai llm (#14272)
Browse files Browse the repository at this point in the history
- **Description:** Masking API key for CerebriumAI LLM to protect user
secrets.
 - **Issue:** #12165 
 - **Dependencies:** None
 - **Tag maintainer:** @eyurtsev

---------

Signed-off-by: Yuchen Liang <yuchenl3@andrew.cmu.edu>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
  • Loading branch information
yliang412 and hwchase17 authored Dec 6, 2023
1 parent d4d64da commit ad6dfb6
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
23 changes: 13 additions & 10 deletions libs/langchain/langchain/llms/cerebriumai.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import logging
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional, cast

import requests
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env

logger = logging.getLogger(__name__)


class CerebriumAI(LLM):
"""CerebriumAI large language models.
To use, you should have the ``cerebrium`` python package installed, and the
environment variable ``CEREBRIUMAI_API_KEY`` set with your API key.
To use, you should have the ``cerebrium`` python package installed.
You should also have the environment variable ``CEREBRIUMAI_API_KEY``
set with your API key or pass it as a named argument in the constructor.
Any parameters that are valid to be passed to the call can be passed
in, even if not explicitly saved on this class.
Expand All @@ -25,7 +26,7 @@ class CerebriumAI(LLM):
.. code-block:: python
from langchain.llms import CerebriumAI
cerebrium = CerebriumAI(endpoint_url="")
cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key")
"""

Expand All @@ -36,7 +37,7 @@ class CerebriumAI(LLM):
"""Holds any model parameters valid for `create` call not
explicitly specified."""

cerebriumai_api_key: Optional[str] = None
cerebriumai_api_key: Optional[SecretStr] = None

class Config:
"""Configuration for this pydantic config."""
Expand Down Expand Up @@ -64,8 +65,8 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
cerebriumai_api_key = get_from_dict_or_env(
values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY"
cerebriumai_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY")
)
values["cerebriumai_api_key"] = cerebriumai_api_key
return values
Expand All @@ -91,7 +92,9 @@ def _call(
**kwargs: Any,
) -> str:
headers: Dict = {
"Authorization": self.cerebriumai_api_key,
"Authorization": cast(
SecretStr, self.cerebriumai_api_key
).get_secret_value(),
"Content-Type": "application/json",
}
params = self.model_kwargs or {}
Expand Down
33 changes: 33 additions & 0 deletions libs/langchain/tests/unit_tests/llms/test_cerebriumai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Test CerebriumAI llm"""


from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain.llms.cerebriumai import CerebriumAI


def test_api_key_is_secret_string() -> None:
llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key")
assert isinstance(llm.cerebriumai_api_key, SecretStr)


def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
llm = CerebriumAI(cerebriumai_api_key="secret-api-key")
print(llm.cerebriumai_api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"


def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
monkeypatch.setenv("CEREBRIUMAI_API_KEY", "secret-api-key")
llm = CerebriumAI()
print(llm.cerebriumai_api_key, end="")
captured = capsys.readouterr()

assert captured.out == "**********"
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"

0 comments on commit ad6dfb6

Please sign in to comment.