diff --git a/src/mistralai/async_client.py b/src/mistralai/async_client.py index 710c9b49..ce9ce986 100644 --- a/src/mistralai/async_client.py +++ b/src/mistralai/async_client.py @@ -134,6 +134,7 @@ async def chat( random_seed: Optional[int] = None, safe_mode: bool = False, safe_prompt: bool = False, + stop: Optional[List[str]] = None ) -> ChatCompletionResponse: """A asynchronous chat endpoint that returns a single response. @@ -148,6 +149,8 @@ async def chat( random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None. safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False. safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False. + stop (Optional[List[str]], optional): string upon which the API will stop generating further tokens. The + returned text will not contain the stop sequence. e.g. ['Observations'] Defaults to None. Returns: ChatCompletionResponse: a response object containing the generated text. @@ -180,6 +183,7 @@ async def chat_stream( random_seed: Optional[int] = None, safe_mode: bool = False, safe_prompt: bool = False, + stop: Optional[List[str]] = None ) -> AsyncGenerator[ChatCompletionStreamResponse, None]: """An Asynchronous chat endpoint that streams responses. @@ -194,6 +198,8 @@ async def chat_stream( random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None. safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False. safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False. + stop (Optional[List[str]], optional): string upon which the API will stop generating further tokens. The + returned text will not contain the stop sequence. e.g. ['Observations'] Defaults to None. Returns: AsyncGenerator[ChatCompletionStreamResponse, None]: @@ -209,6 +215,7 @@ async def chat_stream( random_seed=random_seed, stream=True, safe_prompt=safe_mode or safe_prompt, + stop=stop ) async_response = self._request( "post", request, "v1/chat/completions", stream=True diff --git a/src/mistralai/client.py b/src/mistralai/client.py index 365079d9..54dd83ce 100644 --- a/src/mistralai/client.py +++ b/src/mistralai/client.py @@ -126,6 +126,7 @@ def chat( random_seed: Optional[int] = None, safe_mode: bool = False, safe_prompt: bool = False, + stop: Optional[List[str]] = None ) -> ChatCompletionResponse: """A chat endpoint that returns a single response. @@ -140,6 +141,8 @@ def chat( random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None. safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False. safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False. + stop (Optional[List[str]], optional): string upon which the API will stop generating further tokens. The + returned text will not contain the stop sequence. e.g. ['Observations'] Defaults to None. Returns: ChatCompletionResponse: a response object containing the generated text. @@ -153,6 +156,7 @@ def chat( random_seed=random_seed, stream=False, safe_prompt=safe_mode or safe_prompt, + stop=stop ) single_response = self._request("post", request, "v1/chat/completions") @@ -172,6 +176,7 @@ def chat_stream( random_seed: Optional[int] = None, safe_mode: bool = False, safe_prompt: bool = False, + stop: Optional[List[str]] = None ) -> Iterable[ChatCompletionStreamResponse]: """A chat endpoint that streams responses. @@ -186,6 +191,8 @@ def chat_stream( random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None. safe_mode (bool, optional): deprecated, use safe_prompt instead. Defaults to False. safe_prompt (bool, optional): whether to use safe prompt, e.g. true. Defaults to False. + stop (Optional[List[str]], optional): string upon which the API will stop generating further tokens. The + returned text will not contain the stop sequence. e.g. ['Observations'] Defaults to None. Returns: Iterable[ChatCompletionStreamResponse]: @@ -200,6 +207,7 @@ def chat_stream( random_seed=random_seed, stream=True, safe_prompt=safe_mode or safe_prompt, + stop=stop ) response = self._request("post", request, "v1/chat/completions", stream=True) diff --git a/src/mistralai/client_base.py b/src/mistralai/client_base.py index f6ab4354..90b9fee7 100644 --- a/src/mistralai/client_base.py +++ b/src/mistralai/client_base.py @@ -48,6 +48,7 @@ def _make_chat_request( random_seed: Optional[int] = None, stream: Optional[bool] = None, safe_prompt: Optional[bool] = False, + stop: Optional[List[str]] = None ) -> Dict[str, Any]: request_data: Dict[str, Any] = { "model": model, @@ -64,12 +65,14 @@ def _make_chat_request( request_data["random_seed"] = random_seed if stream is not None: request_data["stream"] = stream - + if stop is not None: + request_data["stop"] = stop self._logger.debug(f"Chat request: {request_data}") return request_data - def _check_response_status_codes(self, response: Response) -> None: + @staticmethod + def _check_response_status_codes(response: Response) -> None: if response.status_code in RETRY_STATUS_CODES: raise MistralAPIStatusException.from_response( response, diff --git a/src/mistralai/models/chat_completion.py b/src/mistralai/models/chat_completion.py index ef1b09e6..9af250c9 100644 --- a/src/mistralai/models/chat_completion.py +++ b/src/mistralai/models/chat_completion.py @@ -19,6 +19,7 @@ class DeltaMessage(BaseModel): class FinishReason(Enum): stop = "stop" length = "length" + stop_param = "stop_param" class ChatCompletionResponseStreamChoice(BaseModel):