12 changes: 8 additions & 4 deletions src/mistralai/async_client.py
@@ -133,6 +133,7 @@ async def chat(
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
safe_mode: bool = False,
+safe_prompt: bool = False,
) -> ChatCompletionResponse:
"""A asynchronous chat endpoint that returns a single response.

@@ -145,7 +146,8 @@
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
Defaults to None.
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
-safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
+safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
+safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.

Returns:
ChatCompletionResponse: a response object containing the generated text.
@@ -158,7 +160,7 @@
top_p=top_p,
random_seed=random_seed,
stream=False,
-safe_mode=safe_mode,
+safe_prompt=safe_mode or safe_prompt,
)

single_response = self._request("post", request, "v1/chat/completions")
@@ -177,6 +179,7 @@ async def chat_stream(
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
safe_mode: bool = False,
+safe_prompt: bool = False,
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
"""An Asynchronous chat endpoint that streams responses.

@@ -189,7 +192,8 @@
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
Defaults to None.
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
-safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
+safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
+safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.

Returns:
AsyncGenerator[ChatCompletionStreamResponse, None]:
@@ -204,7 +208,7 @@
top_p=top_p,
random_seed=random_seed,
stream=True,
-safe_mode=safe_mode,
+safe_prompt=safe_mode or safe_prompt,
)
async_response = self._request(
"post", request, "v1/chat/completions", stream=True
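For context, a minimal usage sketch of the asynchronous client after this change. Only the safe_mode / safe_prompt keywords come from the diff above; MistralAsyncClient, ChatMessage, the "mistral-tiny" model id, and the response shape are assumptions about the rest of this SDK.

```python
import asyncio

# Assumed import paths for this SDK; not part of the diff above.
from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage


async def main() -> None:
    client = MistralAsyncClient(api_key="YOUR_API_KEY")
    messages = [ChatMessage(role="user", content="Hello!")]

    # New style: request the safety prompt explicitly.
    response = await client.chat(
        model="mistral-tiny",
        messages=messages,
        safe_prompt=True,
    )
    print(response.choices[0].message.content)

    # Old style keeps working: safe_mode=True is folded into safe_prompt
    # via `safe_prompt=safe_mode or safe_prompt` before the request is built.
    await client.chat(model="mistral-tiny", messages=messages, safe_mode=True)


asyncio.run(main())
```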
12 changes: 8 additions & 4 deletions src/mistralai/client.py
@@ -125,6 +125,7 @@ def chat(
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
safe_mode: bool = False,
+safe_prompt: bool = False,
) -> ChatCompletionResponse:
"""A chat endpoint that returns a single response.

@@ -137,7 +138,8 @@
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
Defaults to None.
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
-safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
+safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
+safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.

Returns:
ChatCompletionResponse: a response object containing the generated text.
@@ -150,7 +152,7 @@
top_p=top_p,
random_seed=random_seed,
stream=False,
-safe_mode=safe_mode,
+safe_prompt=safe_mode or safe_prompt,
)

single_response = self._request("post", request, "v1/chat/completions")
@@ -169,6 +171,7 @@ def chat_stream(
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
safe_mode: bool = False,
+safe_prompt: bool = False,
) -> Iterable[ChatCompletionStreamResponse]:
"""A chat endpoint that streams responses.

@@ -181,7 +184,8 @@
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
Defaults to None.
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
-safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
+safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
+safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.

Returns:
Iterable[ChatCompletionStreamResponse]:
@@ -195,7 +199,7 @@
top_p=top_p,
random_seed=random_seed,
stream=True,
-safe_mode=safe_mode,
+safe_prompt=safe_mode or safe_prompt,
)

response = self._request("post", request, "v1/chat/completions", stream=True)
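The synchronous client mirrors the same mapping. A hedged sketch of streaming with the new keyword follows; MistralClient, ChatMessage, and the `choices[0].delta.content` field on the stream chunks are assumed from the SDK's public surface, not shown in this diff.

```python
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key="YOUR_API_KEY")
messages = [ChatMessage(role="user", content="Tell me about prompt safety.")]

# Either keyword enables the safety prompt; safe_mode is kept only for
# backwards compatibility and is OR-ed into safe_prompt by the client.
for chunk in client.chat_stream(
    model="mistral-tiny",
    messages=messages,
    safe_prompt=True,
):
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="", flush=True)
```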
4 changes: 2 additions & 2 deletions src/mistralai/client_base.py
@@ -47,12 +47,12 @@ def _make_chat_request(
top_p: Optional[float] = None,
random_seed: Optional[int] = None,
stream: Optional[bool] = None,
-safe_mode: Optional[bool] = False,
+safe_prompt: Optional[bool] = False,
) -> Dict[str, Any]:
request_data: Dict[str, Any] = {
"model": model,
"messages": [msg.model_dump() for msg in messages],
"safe_prompt": safe_mode,
"safe_prompt": safe_prompt,
}
if temperature is not None:
request_data["temperature"] = temperature
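For reference, a rough sketch of the request body `_make_chat_request` now builds (field names taken from the hunk above; the model id and message payload are placeholder values, and the other optional fields are elided).

```python
import json

# Stand-in for [msg.model_dump() for msg in messages]; the real messages are
# pydantic ChatMessage objects, which are not part of this diff.
messages = [{"role": "user", "content": "Hello!"}]

request_data = {
    "model": "mistral-tiny",   # placeholder model id, not from the diff
    "messages": messages,
    "safe_prompt": True,       # previously populated from the safe_mode argument
}

print(json.dumps(request_data, indent=2))
```

Because the deprecated flag is resolved in the wrappers with `safe_prompt=safe_mode or safe_prompt`, `safe_mode` never reaches the wire; the request body always carries `safe_prompt`.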