Release 0.2.0 #94

Merged
2 commits merged on May 23, 2024
Changes from all commits
90 changes: 90 additions & 0 deletions .github/workflows/build_publish.yaml
@@ -0,0 +1,90 @@
name: Lint / Test / Publish

on:
push:
branches: ["main"]

# The publish job only runs from version tags; pushes to main just lint and test
tags:
# Only run on tags that match the following pattern
# (matches tags like 1.0.0, 1.0.1, etc.)
- "[0-9]+.[0-9]+.[0-9]+"

# Lint and test on pull requests
pull_request:

jobs:
lint_and_test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
# Checkout the repository
- name: Checkout
uses: actions/checkout@v4

# Set the Python version from the test matrix
- name: Set Python version
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

# Install build dependencies with Poetry
- name: Install Dependencies
run: |
pip install poetry \
&& poetry config virtualenvs.create false \
&& poetry install

# Ruff
- name: Ruff check
run: |
poetry run ruff check .

- name: Ruff format check
run: |
poetry run ruff format . --check

# Mypy
- name: Mypy Check
run: |
poetry run mypy .

# Tests
- name: Run Tests
run: |
poetry run pytest .

publish:
if: startsWith(github.ref, 'refs/tags')
runs-on: ubuntu-latest
needs: lint_and_test
steps:
# Checkout the repository
- name: Checkout
uses: actions/checkout@v4

# Set Python version to 3.11
- name: Set Python version
uses: actions/setup-python@v4
with:
python-version: 3.11

# Install build dependencies with Poetry
- name: Install Dependencies
run: |
pip install poetry \
&& poetry config virtualenvs.create false \
&& poetry install

# Build the package with Poetry
- name: Build Package
run: |
poetry build

# Publish to PyPI
- name: Publish to PyPI
run: |
poetry config pypi-token.pypi ${{ secrets.PYPI_TOKEN }}
poetry publish
36 changes: 9 additions & 27 deletions examples/chatbot_with_streaming.py
@@ -63,9 +63,7 @@ def completer(text, state):


class ChatBot:
def __init__(
self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE
):
def __init__(self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE):
if not api_key:
raise ValueError("An API key must be provided to use the Mistral API.")
self.client = MistralClient(api_key=api_key)
@@ -89,15 +87,11 @@ def opening_instructions(self):

def new_chat(self):
print("")
print(
f"Starting new chat with model: {self.model}, temperature: {self.temperature}"
)
print(f"Starting new chat with model: {self.model}, temperature: {self.temperature}")
print("")
self.messages = []
if self.system_message:
self.messages.append(
ChatMessage(role="system", content=self.system_message)
)
self.messages.append(ChatMessage(role="system", content=self.system_message))

def switch_model(self, input):
model = self.get_arguments(input)
@@ -146,13 +140,9 @@ def run_inference(self, content):
self.messages.append(ChatMessage(role="user", content=content))

assistant_response = ""
logger.debug(
f"Running inference with model: {self.model}, temperature: {self.temperature}"
)
logger.debug(f"Running inference with model: {self.model}, temperature: {self.temperature}")
logger.debug(f"Sending messages: {self.messages}")
for chunk in self.client.chat_stream(
model=self.model, temperature=self.temperature, messages=self.messages
):
for chunk in self.client.chat_stream(model=self.model, temperature=self.temperature, messages=self.messages):
response = chunk.choices[0].delta.content
if response is not None:
print(response, end="", flush=True)
@@ -161,9 +151,7 @@ def run_inference(self, content):
print("", flush=True)

if assistant_response:
self.messages.append(
ChatMessage(role="assistant", content=assistant_response)
)
self.messages.append(ChatMessage(role="assistant", content=assistant_response))
logger.debug(f"Current messages: {self.messages}")

def get_command(self, input):
@@ -215,9 +203,7 @@ def exit(self):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="A simple chatbot using the Mistral API"
)
parser = argparse.ArgumentParser(description="A simple chatbot using the Mistral API")
parser.add_argument(
"--api-key",
default=os.environ.get("MISTRAL_API_KEY"),
@@ -230,19 +216,15 @@ def exit(self):
default=DEFAULT_MODEL,
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s",
)
parser.add_argument(
"-s", "--system-message", help="Optional system message to prepend."
)
parser.add_argument("-s", "--system-message", help="Optional system message to prepend.")
parser.add_argument(
"-t",
"--temperature",
type=float,
default=DEFAULT_TEMPERATURE,
help="Optional temperature for chat inference. Defaults to %(default)s",
)
parser.add_argument(
"-d", "--debug", action="store_true", help="Enable debug logging"
)
parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging")

args = parser.parse_args()

13 changes: 7 additions & 6 deletions examples/function_calling.py
@@ -15,23 +15,26 @@
"payment_status": ["Paid", "Unpaid", "Paid", "Paid", "Pending"],
}

def retrieve_payment_status(data: Dict[str,List], transaction_id: str) -> str:

def retrieve_payment_status(data: Dict[str, List], transaction_id: str) -> str:
for i, r in enumerate(data["transaction_id"]):
if r == transaction_id:
return json.dumps({"status": data["payment_status"][i]})
else:
return json.dumps({"status": "Error - transaction id not found"})


def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
for i, r in enumerate(data["transaction_id"]):
if r == transaction_id:
return json.dumps({"date": data["payment_date"][i]})
else:
return json.dumps({"status": "Error - transaction id not found"})


names_to_functions = {
"retrieve_payment_status": functools.partial(retrieve_payment_status, data=data),
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data)
"retrieve_payment_status": functools.partial(retrieve_payment_status, data=data),
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data),
}

tools = [
@@ -75,9 +78,7 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
messages.append(ChatMessage(role="assistant", content=response.choices[0].message.content))
messages.append(ChatMessage(role="user", content="My transaction ID is T1001."))

response = client.chat(
model=model, messages=messages, tools=tools
)
response = client.chat(model=model, messages=messages, tools=tools)

tool_call = response.choices[0].message.tool_calls[0]
function_name = tool_call.function.name
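The collapsed remainder of this example presumably executes the selected tool through names_to_functions. A minimal sketch of that dispatch step, continuing from tool_call and function_name above (the exact code in the full file may differ):

# json is already imported at the top of the example
function_params = json.loads(tool_call.function.arguments)
# data is pre-bound via functools.partial, so only transaction_id is passed here
function_result = names_to_functions[function_name](**function_params)
print(function_result)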
1 change: 0 additions & 1 deletion examples/json_format.py
@@ -16,7 +16,6 @@ def main():
model=model,
response_format={"type": "json_object"},
messages=[ChatMessage(role="user", content="What is the best French cheese? Answer shortly in JSON.")],

)
print(chat_response.choices[0].message.content)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistralai"
version = "0.0.1"
version = "0.2.0"
description = ""
authors = ["Bam4d <bam4d@mistral.ai>"]
readme = "README.md"
1 change: 0 additions & 1 deletion src/mistralai/async_client.py
@@ -1,5 +1,4 @@
import asyncio
import os
import posixpath
from json import JSONDecodeError
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
1 change: 0 additions & 1 deletion src/mistralai/client.py
@@ -1,4 +1,3 @@
import os
import posixpath
import time
from json import JSONDecodeError
9 changes: 4 additions & 5 deletions src/mistralai/client_base.py
@@ -10,6 +10,8 @@
)
from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice

CLIENT_VERSION = "0.2.0"


class ClientBase(ABC):
def __init__(
@@ -25,9 +27,7 @@ def __init__(
if api_key is None:
api_key = os.environ.get("MISTRAL_API_KEY")
if api_key is None:
raise MistralException(
message="API key not provided. Please set MISTRAL_API_KEY environment variable."
)
raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
self._api_key = api_key
self._endpoint = endpoint
self._logger = logging.getLogger(__name__)
@@ -36,8 +36,7 @@ def __init__(
if "inference.azure.com" in self._endpoint:
self._default_model = "mistral"

# This should be automatically updated by the deploy script
self._version = "0.0.1"
self._version = CLIENT_VERSION

def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
parsed_tools: List[Dict[str, Any]] = []
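With the client version now hard-coded here as CLIENT_VERSION in addition to pyproject.toml, a small guard test along the following lines could catch the two values drifting apart (an illustrative sketch, not part of this diff; the test name is assumed):

from importlib import metadata

from mistralai.client_base import CLIENT_VERSION


def test_client_version_matches_package_version():
    # The installed distribution metadata is generated from pyproject.toml,
    # so this fails whenever only one of the two version strings is bumped.
    assert CLIENT_VERSION == metadata.version("mistralai")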
2 changes: 0 additions & 2 deletions src/mistralai/constants.py
@@ -1,5 +1,3 @@


RETRY_STATUS_CODES = {429, 500, 502, 503, 504}

ENDPOINT = "https://api.mistral.ai"
6 changes: 3 additions & 3 deletions src/mistralai/exceptions.py
@@ -35,9 +35,7 @@ def __init__(
self.headers = headers or {}

@classmethod
def from_response(
cls, response: Response, message: Optional[str] = None
) -> MistralAPIException:
def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException:
return cls(
message=message or response.text,
http_status=response.status_code,
@@ -47,8 +45,10 @@ def from_response(
def __repr__(self) -> str:
return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"


class MistralAPIStatusException(MistralAPIException):
"""Returned when we receive a non-200 response from the API that we should retry"""


class MistralConnectionException(MistralException):
"""Returned when the SDK can not reach the API server for any reason"""
1 change: 1 addition & 0 deletions src/mistralai/models/models.py
@@ -17,6 +17,7 @@ class ModelPermission(BaseModel):
group: Optional[str] = None
is_blocking: bool = False


class ModelCard(BaseModel):
id: str
object: str
19 changes: 19 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,19 @@
from unittest import mock

import pytest
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient


@pytest.fixture()
def client():
client = MistralClient(api_key="test_api_key")
client._client = mock.MagicMock()
return client


@pytest.fixture()
def async_client():
client = MistralAsyncClient(api_key="test_api_key")
client._client = mock.AsyncMock()
return client
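For reference, a test consuming these fixtures might look like the sketch below (illustrative only; the tests added in this release live in their own modules):

from unittest import mock


def test_client_fixture_is_mocked(client):
    # The fixture injects a fake API key and swaps the underlying HTTP
    # transport for a MagicMock, so no real network calls are made.
    assert client._api_key == "test_api_key"
    assert isinstance(client._client, mock.MagicMock)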