Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/mistralai/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ async def chat(
random_seed: Optional[int] = None,
safe_mode: bool = False,
safe_prompt: bool = False,
stop: Optional[List[str]] = None
) -> ChatCompletionResponse:
"""A asynchronous chat endpoint that returns a single response.

Expand All @@ -148,6 +149,8 @@ async def chat(
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.
stop (Optional[List[str]], optional): a list of strings upon any of which the API will stop generating
further tokens. The returned text will not contain the stop sequence, e.g. ['Observations']. Defaults to None.

Returns:
ChatCompletionResponse: a response object containing the generated text.
Expand Down Expand Up @@ -180,6 +183,7 @@ async def chat_stream(
random_seed: Optional[int] = None,
safe_mode: bool = False,
safe_prompt: bool = False,
stop: Optional[List[str]] = None
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
"""An Asynchronous chat endpoint that streams responses.

Expand All @@ -194,6 +198,8 @@ async def chat_stream(
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.
stop (Optional[List[str]], optional): a list of strings upon any of which the API will stop generating
further tokens. The returned text will not contain the stop sequence, e.g. ['Observations']. Defaults to None.

Returns:
AsyncGenerator[ChatCompletionStreamResponse, None]:
Expand All @@ -209,6 +215,7 @@ async def chat_stream(
random_seed=random_seed,
stream=True,
safe_prompt=safe_mode or safe_prompt,
stop=stop
)
async_response = self._request(
"post", request, "v1/chat/completions", stream=True
Expand Down
8 changes: 8 additions & 0 deletions src/mistralai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def chat(
random_seed: Optional[int] = None,
safe_mode: bool = False,
safe_prompt: bool = False,
stop: Optional[List[str]] = None
) -> ChatCompletionResponse:
"""A chat endpoint that returns a single response.

Expand All @@ -140,6 +141,8 @@ def chat(
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.
stop (Optional[List[str]], optional): a list of strings upon any of which the API will stop generating
further tokens. The returned text will not contain the stop sequence, e.g. ['Observations']. Defaults to None.

Returns:
ChatCompletionResponse: a response object containing the generated text.
Expand All @@ -153,6 +156,7 @@ def chat(
random_seed=random_seed,
stream=False,
safe_prompt=safe_mode or safe_prompt,
stop=stop
)

single_response = self._request("post", request, "v1/chat/completions")
Expand All @@ -172,6 +176,7 @@ def chat_stream(
random_seed: Optional[int] = None,
safe_mode: bool = False,
safe_prompt: bool = False,
stop: Optional[List[str]] = None
) -> Iterable[ChatCompletionStreamResponse]:
"""A chat endpoint that streams responses.

Expand All @@ -186,6 +191,8 @@ def chat_stream(
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False.
safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False.
stop (Optional[List[str]], optional): a list of strings upon any of which the API will stop generating
further tokens. The returned text will not contain the stop sequence, e.g. ['Observations']. Defaults to None.

Returns:
Iterable[ChatCompletionStreamResponse]:
Expand All @@ -200,6 +207,7 @@ def chat_stream(
random_seed=random_seed,
stream=True,
safe_prompt=safe_mode or safe_prompt,
stop=stop
)

response = self._request("post", request, "v1/chat/completions", stream=True)
Expand Down
7 changes: 5 additions & 2 deletions src/mistralai/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def _make_chat_request(
random_seed: Optional[int] = None,
stream: Optional[bool] = None,
safe_prompt: Optional[bool] = False,
stop: Optional[List[str]] = None
) -> Dict[str, Any]:
request_data: Dict[str, Any] = {
"model": model,
Expand All @@ -64,12 +65,14 @@ def _make_chat_request(
request_data["random_seed"] = random_seed
if stream is not None:
request_data["stream"] = stream

if stop is not None:
request_data["stop"] = stop
self._logger.debug(f"Chat request: {request_data}")

return request_data

def _check_response_status_codes(self, response: Response) -> None:
@staticmethod
def _check_response_status_codes(response: Response) -> None:
if response.status_code in RETRY_STATUS_CODES:
raise MistralAPIStatusException.from_response(
response,
Expand Down
1 change: 1 addition & 0 deletions src/mistralai/models/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class DeltaMessage(BaseModel):
class FinishReason(Enum):
stop = "stop"
length = "length"
stop_param = "stop_param"


class ChatCompletionResponseStreamChoice(BaseModel):
Expand Down