
Commit

Merge pull request #1 from mistralai/bam4d/json_update
added a few tricks to function calling
Bam4d authored Feb 26, 2024
2 parents ba1e050 + fab57aa commit 8c83c25
Showing 6 changed files with 113 additions and 17 deletions.
6 changes: 4 additions & 2 deletions examples/function_calling.py
@@ -68,7 +68,7 @@ def retrieve_payment_date(df: pd.DataFrame, transaction_id: str) -> str:


api_key = os.environ["MISTRAL_API_KEY"]
model = "mistral-large"
model = "mistral-large-latest"

client = MistralClient(api_key=api_key)

@@ -81,7 +81,9 @@ def retrieve_payment_date(df: pd.DataFrame, transaction_id: str) -> str:
messages.append(ChatMessage(role="assistant", content=response.choices[0].message.content))
messages.append(ChatMessage(role="user", content="My transaction ID is T1001."))

response = client.chat(model=model, messages=messages, tools=tools)
response = client.chat(
model=model, messages=messages, tools=tools
)

tool_call = response.choices[0].message.tool_calls[0]
function_name = tool_call.function.name
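For context, a minimal sketch of how the tool call extracted above might be dispatched; the names_to_functions mapping and the json parsing step are illustrative assumptions, not part of this diff:

import json

# Hypothetical dispatch table mapping the model-chosen function name to a local callable.
names_to_functions = {
    "retrieve_payment_date": lambda transaction_id: retrieve_payment_date(df, transaction_id),
}

function_args = json.loads(tool_call.function.arguments)  # arguments arrive as a JSON-encoded string
print(names_to_functions[function_name](**function_args))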
25 changes: 25 additions & 0 deletions examples/json_format.py
@@ -0,0 +1,25 @@
#!/usr/bin/env python

import os

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage


def main():
api_key = os.environ["MISTRAL_API_KEY"]
model = "mistral-large-latest"

client = MistralClient(api_key=api_key)

chat_response = client.chat(
model=model,
response_format={"type": "json_object"},
messages=[ChatMessage(role="user", content="What is the best French cheese?")],
)
print(chat_response.choices[0].message.content)


if __name__ == "__main__":
main()
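The same request can also be expressed with the new ResponseFormat model introduced in this change instead of a raw dict; a minimal sketch, reusing the client and model variables from the example above:

from mistralai.models.chat_completion import ResponseFormat, ResponseFormats

chat_response = client.chat(
    model=model,
    # Equivalent to response_format={"type": "json_object"}; the client accepts either form.
    response_format=ResponseFormat(type=ResponseFormats.json_object),
    messages=[ChatMessage(role="user", content="What is the best French cheese?")],
)
print(chat_response.choices[0].message.content)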
23 changes: 18 additions & 5 deletions src/mistralai/async_client.py
@@ -21,7 +21,12 @@
MistralConnectionException,
MistralException,
)
from mistralai.models.chat_completion import ChatCompletionResponse, ChatCompletionStreamResponse
from mistralai.models.chat_completion import (
ChatCompletionResponse,
ChatCompletionStreamResponse,
ResponseFormat,
ToolChoice,
)
from mistralai.models.embeddings import EmbeddingResponse
from mistralai.models.models import ModelList

@@ -116,15 +121,17 @@ async def _request(

async def chat(
self,
model: str,
messages: List[Any],
model: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
safe_mode: bool = False,
safe_prompt: bool = False,
tool_choice: Optional[Union[str, ToolChoice]] = None,
response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
) -> ChatCompletionResponse:
"""A asynchronous chat endpoint that returns a single response.
@@ -144,15 +151,17 @@ async def chat(
ChatCompletionResponse: a response object containing the generated text.
"""
request = self._make_chat_request(
model,
messages,
model,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
random_seed=random_seed,
stream=False,
safe_prompt=safe_mode or safe_prompt,
tool_choice=tool_choice,
response_format=response_format,
)

single_response = self._request("post", request, "v1/chat/completions")
@@ -164,15 +173,17 @@

async def chat_stream(
self,
model: str,
messages: List[Any],
model: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
safe_mode: bool = False,
safe_prompt: bool = False,
tool_choice: Optional[Union[str, ToolChoice]] = None,
response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
"""An Asynchronous chat endpoint that streams responses.
@@ -195,15 +206,17 @@ async def chat_stream(
"""

request = self._make_chat_request(
model,
messages,
model,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
random_seed=random_seed,
stream=True,
safe_prompt=safe_mode or safe_prompt,
tool_choice=tool_choice,
response_format=response_format,
)
async_response = self._request("post", request, "v1/chat/completions", stream=True)

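For reference, a hedged sketch of how the new keyword arguments might be used with the async client; the prompt and asyncio wiring are illustrative only:

import asyncio
import os

from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage


async def main():
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])
    # response_format and tool_choice are the keyword arguments added by this change.
    response = await client.chat(
        model="mistral-large-latest",
        messages=[ChatMessage(role="user", content="List three French cheeses as a JSON object.")],
        response_format={"type": "json_object"},
    )
    print(response.choices[0].message.content)


asyncio.run(main())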
23 changes: 18 additions & 5 deletions src/mistralai/client.py
@@ -14,7 +14,12 @@
MistralConnectionException,
MistralException,
)
from mistralai.models.chat_completion import ChatCompletionResponse, ChatCompletionStreamResponse
from mistralai.models.chat_completion import (
ChatCompletionResponse,
ChatCompletionStreamResponse,
ResponseFormat,
ToolChoice,
)
from mistralai.models.embeddings import EmbeddingResponse
from mistralai.models.models import ModelList

@@ -109,15 +114,17 @@ def _request(

def chat(
self,
model: str,
messages: List[Any],
model: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
safe_mode: bool = False,
safe_prompt: bool = False,
tool_choice: Optional[Union[str, ToolChoice]] = None,
response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
) -> ChatCompletionResponse:
"""A chat endpoint that returns a single response.
@@ -138,15 +145,17 @@ def chat(
ChatCompletionResponse: a response object containing the generated text.
"""
request = self._make_chat_request(
model,
messages,
model,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
random_seed=random_seed,
stream=False,
safe_prompt=safe_mode or safe_prompt,
tool_choice=tool_choice,
response_format=response_format,
)

single_response = self._request("post", request, "v1/chat/completions")
@@ -158,15 +167,17 @@

def chat_stream(
self,
model: str,
messages: List[Any],
model: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
safe_mode: bool = False,
safe_prompt: bool = False,
tool_choice: Optional[Union[str, ToolChoice]] = None,
response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
) -> Iterable[ChatCompletionStreamResponse]:
"""A chat endpoint that streams responses.
@@ -188,15 +199,17 @@ def chat_stream(
A generator that yields ChatCompletionStreamResponse objects.
"""
request = self._make_chat_request(
model,
messages,
model,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
random_seed=random_seed,
stream=True,
safe_prompt=safe_mode or safe_prompt,
tool_choice=tool_choice,
response_format=response_format,
)

response = self._request("post", request, "v1/chat/completions", stream=True)
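A corresponding sketch for the synchronous client, forcing a tool call via the new tool_choice argument; the tool definition below is an assumption modeled on examples/function_calling.py, not part of this diff:

import os

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage, ToolChoice

tools = [
    {
        "type": "function",
        "function": {
            "name": "retrieve_payment_status",
            "description": "Get the payment status of a transaction",
            "parameters": {
                "type": "object",
                "properties": {
                    "transaction_id": {"type": "string", "description": "The transaction id."}
                },
                "required": ["transaction_id"],
            },
        },
    }
]

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])
# ToolChoice.any asks the model to call one of the supplied tools instead of replying in plain text.
response = client.chat(
    model="mistral-large-latest",
    messages=[ChatMessage(role="user", content="What is the status of transaction T1001?")],
    tools=tools,
    tool_choice=ToolChoice.any,
)
print(response.choices[0].message.tool_calls[0].function.name)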
36 changes: 32 additions & 4 deletions src/mistralai/client_base.py
@@ -1,7 +1,7 @@
import logging
import os
from abc import ABC
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import orjson
from httpx import Response
@@ -12,7 +12,7 @@
MistralAPIStatusException,
MistralException,
)
from mistralai.models.chat_completion import ChatMessage, Function
from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice

logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
@@ -35,6 +35,10 @@ def __init__(
self._api_key = api_key
self._logger = logging.getLogger(__name__)

# For Azure endpoints, we default to the mistral model
if "inference.azure.com" in self._endpoint:
self._default_model = "mistral"

# This should be automatically updated by the deploy script
self._version = "0.0.1"

@@ -53,6 +57,16 @@ def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:

return parsed_tools

def _parse_tool_choice(self, tool_choice: Union[str, ToolChoice]) -> str:
if isinstance(tool_choice, ToolChoice):
return tool_choice.value
return tool_choice

def _parse_response_format(self, response_format: Union[Dict[str, Any], ResponseFormat]) -> Dict[str, Any]:
if isinstance(response_format, ResponseFormat):
return response_format.model_dump(exclude_none=True)
return response_format

def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
parsed_messages: List[Dict[str, Any]] = []
for message in messages:
@@ -65,21 +79,30 @@ def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:

def _make_chat_request(
self,
model: str,
messages: List[Any],
model: Optional[str] = None,
tools: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
stream: Optional[bool] = None,
safe_prompt: Optional[bool] = False,
tool_choice: Optional[Union[str, ToolChoice]] = None,
response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
) -> Dict[str, Any]:
request_data: Dict[str, Any] = {
"model": model,
"messages": self._parse_messages(messages),
"safe_prompt": safe_prompt,
}

if model is not None:
request_data["model"] = model
else:
if self._default_model is None:
raise MistralException(message="model must be provided")
request_data["model"] = self._default_model

if tools is not None:
request_data["tools"] = self._parse_tools(tools)
if temperature is not None:
@@ -93,6 +116,11 @@
if stream is not None:
request_data["stream"] = stream

if tool_choice is not None:
request_data["tool_choice"] = self._parse_tool_choice(tool_choice)
if response_format is not None:
request_data["response_format"] = self._parse_response_format(response_format)

self._logger.debug(f"Chat request: {request_data}")

return request_data
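To illustrate the new default-model behavior added above: when the configured endpoint contains "inference.azure.com" the client falls back to a "mistral" default model, and omitting model on any other endpoint now raises MistralException. A hedged sketch with a placeholder endpoint and key:

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

# Hypothetical Azure-hosted deployment; the hostname is what triggers the default model.
azure_client = MistralClient(
    api_key="<azure-api-key>",
    endpoint="https://my-deployment.inference.azure.com",
)

# No model argument: _make_chat_request substitutes the "mistral" default for Azure endpoints.
response = azure_client.chat(
    messages=[ChatMessage(role="user", content="Hello!")],
)
print(response.choices[0].message.content)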
17 changes: 16 additions & 1 deletion src/mistralai/models/chat_completion.py
@@ -27,6 +27,21 @@ class ToolCall(BaseModel):
function: FunctionCall


class ResponseFormats(str, Enum):
text: str = "text"
json_object: str = "json_object"


class ToolChoice(str, Enum):
auto: str = "auto"
any: str = "any"
none: str = "none"


class ResponseFormat(BaseModel):
type: ResponseFormats = ResponseFormats.text


class ChatMessage(BaseModel):
role: str
content: Union[str, List[str]]
@@ -40,7 +55,7 @@ class DeltaMessage(BaseModel):
tool_calls: Optional[List[ToolCall]] = None


class FinishReason(Enum):
class FinishReason(str, Enum):
stop = "stop"
length = "length"
error = "error"
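A small sketch of how these new types map onto the raw request fields handled by the _parse_* helpers in client_base.py; the behavior follows from the enum and model definitions above:

from mistralai.models.chat_completion import ResponseFormat, ResponseFormats, ToolChoice

# ToolChoice is a str Enum, so .value yields the wire string ("auto", "any", or "none").
print(ToolChoice.any.value)

# ResponseFormat defaults to text; request JSON output by setting the type explicitly.
json_format = ResponseFormat(type=ResponseFormats.json_object)
print(json_format.model_dump(exclude_none=True))  # the dict form sent by _parse_response_format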
