diff --git a/answer_rocket/chat.py b/answer_rocket/chat.py index b863f63..bb56af6 100644 --- a/answer_rocket/chat.py +++ b/answer_rocket/chat.py @@ -1,18 +1,17 @@ import io import logging +import pandas as pd import uuid from datetime import datetime -from typing import Literal, Optional - -import pandas as pd from sgqlc.types import Variable, non_null, String, Arg, list_of +from typing import Literal, Optional +from answer_rocket.client_config import ClientConfig from answer_rocket.graphql.client import GraphQlClient from answer_rocket.graphql.schema import (UUID, Int, DateTime, ChatDryRunType, MaxChatEntry, MaxChatThread, SharedThread, MaxChatUser, ChatArtifact, MaxMutationResponse, ChatArtifactSearchInput, PagingInput, PagedChatArtifacts) from answer_rocket.graphql.sdk_operations import Operations -from answer_rocket.client_config import ClientConfig logger = logging.getLogger(__name__) @@ -243,18 +242,20 @@ def get_chat_thread(self, thread_id: str) -> MaxChatThread: result = self.gql_client.submit(op, get_chat_thread_args) return result.chat_thread - def create_new_thread(self, copilot_id: str) -> MaxChatThread: + def create_new_thread(self, copilot_id: str, thread_type: ThreadType = "CHAT") -> MaxChatThread: """ Create a new chat thread for the specified agent. Args: copilot_id: The ID of the agent to create the thread for + thread_type: The type of thread to create, defaults to CHAT. For most purposes CHAT is the only type needed. Returns: MaxChatThread: The newly created chat thread object """ create_chat_thread_args = { 'copilotId': copilot_id, + 'threadType': thread_type } op = Operations.mutation.create_chat_thread diff --git a/answer_rocket/graphql/operations/chat.gql b/answer_rocket/graphql/operations/chat.gql index 4dba627..2cdb931 100644 --- a/answer_rocket/graphql/operations/chat.gql +++ b/answer_rocket/graphql/operations/chat.gql @@ -149,11 +149,11 @@ query AllChatEntries($offset: Int, $limit: Int, $filters: JSON) { } } -mutation CreateChatThread($copilotId: UUID!) { - createChatThread(copilotId: $copilotId) { - id - copilotId - } +mutation CreateChatThread($copilotId: UUID!, $threadType: ThreadType) { + createChatThread(copilotId: $copilotId, threadType: $threadType) { + id + copilotId + } } mutation AddFeedback( diff --git a/answer_rocket/graphql/schema.py b/answer_rocket/graphql/schema.py index ad3d060..f2d8dcf 100644 --- a/answer_rocket/graphql/schema.py +++ b/answer_rocket/graphql/schema.py @@ -404,6 +404,15 @@ class CreateDatasetFromTableResponse(sgqlc.types.Type): error = sgqlc.types.Field(String, graphql_name='error') +class CreateMaxCopilotSkillChatQuestionResponse(sgqlc.types.Type): + __schema__ = schema + __field_names__ = ('copilot_skill_chat_question_id', 'success', 'code', 'error') + copilot_skill_chat_question_id = sgqlc.types.Field(UUID, graphql_name='copilotSkillChatQuestionId') + success = sgqlc.types.Field(sgqlc.types.non_null(Boolean), graphql_name='success') + code = sgqlc.types.Field(String, graphql_name='code') + error = sgqlc.types.Field(String, graphql_name='error') + + class Database(sgqlc.types.Type): __schema__ = schema __field_names__ = ('database_id', 'name', 'dbms', 'description', 'llm_description', 'mermaid_er_diagram', 'k_shot_limit') @@ -983,7 +992,28 @@ class MaxUser(sgqlc.types.Type): class Mutation(sgqlc.types.Type): __schema__ = schema - __field_names__ = ('create_max_copilot_question', 'update_max_copilot_question', 'delete_max_copilot_question', 'set_max_agent_workflow', 'import_copilot_skill_from_zip', 'sync_max_skill_repository', 'import_skill_from_repo', 'test_run_copilot_skill', 'get_test_run_output', 'reload_dataset', 'update_database_name', 'update_database_description', 'update_database_llm_description', 'update_database_mermaid_er_diagram', 'update_database_kshot_limit', 'update_dataset_name', 'update_dataset_description', 'update_dataset_date_range', 'update_dataset_data_interval', 'update_dataset_misc_info', 'update_dataset_source', 'update_dataset_query_row_limit', 'update_dataset_use_database_casing', 'update_dataset_kshot_limit', 'create_dataset', 'create_dataset_from_table', 'create_dimension', 'update_dimension', 'delete_dimension', 'create_metric', 'update_metric', 'delete_metric', 'create_database_kshot', 'update_database_kshot_question', 'update_database_kshot_rendered_prompt', 'update_database_kshot_explanation', 'update_database_kshot_sql', 'update_database_kshot_title', 'update_database_kshot_visualization', 'delete_database_kshot', 'update_chat_answer_payload', 'ask_chat_question', 'evaluate_chat_question', 'queue_chat_question', 'cancel_chat_question', 'create_chat_thread', 'add_feedback', 'set_skill_memory', 'share_thread', 'update_loading_message', 'create_chat_artifact', 'delete_chat_artifact') + __field_names__ = ('create_max_copilot_skill_chat_question', 'update_max_copilot_skill_chat_question', 'delete_max_copilot_skill_chat_question', 'create_max_copilot_question', 'update_max_copilot_question', 'delete_max_copilot_question', 'set_max_agent_workflow', 'import_copilot_skill_from_zip', 'sync_max_skill_repository', 'import_skill_from_repo', 'test_run_copilot_skill', 'get_test_run_output', 'reload_dataset', 'update_database_name', 'update_database_description', 'update_database_llm_description', 'update_database_mermaid_er_diagram', 'update_database_kshot_limit', 'update_dataset_name', 'update_dataset_description', 'update_dataset_date_range', 'update_dataset_data_interval', 'update_dataset_misc_info', 'update_dataset_source', 'update_dataset_query_row_limit', 'update_dataset_use_database_casing', 'update_dataset_kshot_limit', 'create_dataset', 'create_dataset_from_table', 'create_dimension', 'update_dimension', 'delete_dimension', 'create_metric', 'update_metric', 'delete_metric', 'create_database_kshot', 'update_database_kshot_question', 'update_database_kshot_rendered_prompt', 'update_database_kshot_explanation', 'update_database_kshot_sql', 'update_database_kshot_title', 'update_database_kshot_visualization', 'delete_database_kshot', 'update_chat_answer_payload', 'ask_chat_question', 'evaluate_chat_question', 'queue_chat_question', 'cancel_chat_question', 'create_chat_thread', 'add_feedback', 'set_skill_memory', 'share_thread', 'update_loading_message', 'create_chat_artifact', 'delete_chat_artifact') + create_max_copilot_skill_chat_question = sgqlc.types.Field(sgqlc.types.non_null(CreateMaxCopilotSkillChatQuestionResponse), graphql_name='createMaxCopilotSkillChatQuestion', args=sgqlc.types.ArgDict(( + ('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)), + ('copilot_skill_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillId', default=None)), + ('question', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='question', default=None)), + ('expected_completion_response', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='expectedCompletionResponse', default=None)), +)) + ) + update_max_copilot_skill_chat_question = sgqlc.types.Field(sgqlc.types.non_null(MaxMutationResponse), graphql_name='updateMaxCopilotSkillChatQuestion', args=sgqlc.types.ArgDict(( + ('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)), + ('copilot_skill_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillId', default=None)), + ('copilot_skill_chat_question_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillChatQuestionId', default=None)), + ('question', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='question', default=None)), + ('expected_completion_response', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='expectedCompletionResponse', default=None)), +)) + ) + delete_max_copilot_skill_chat_question = sgqlc.types.Field(sgqlc.types.non_null(MaxMutationResponse), graphql_name='deleteMaxCopilotSkillChatQuestion', args=sgqlc.types.ArgDict(( + ('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)), + ('copilot_skill_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillId', default=None)), + ('copilot_skill_chat_question_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillChatQuestionId', default=None)), +)) + ) create_max_copilot_question = sgqlc.types.Field(sgqlc.types.non_null(MaxCreateCopilotQuestionResponse), graphql_name='createMaxCopilotQuestion', args=sgqlc.types.ArgDict(( ('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)), ('copilot_question', sgqlc.types.Arg(sgqlc.types.non_null(MaxCopilotQuestionInput), graphql_name='copilotQuestion', default=None)), @@ -1229,6 +1259,7 @@ class Mutation(sgqlc.types.Type): ) create_chat_thread = sgqlc.types.Field(MaxChatThread, graphql_name='createChatThread', args=sgqlc.types.ArgDict(( ('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)), + ('thread_type', sgqlc.types.Arg(ThreadType, graphql_name='threadType', default=None)), )) ) add_feedback = sgqlc.types.Field(Boolean, graphql_name='addFeedback', args=sgqlc.types.ArgDict(( diff --git a/answer_rocket/graphql/sdk_operations.py b/answer_rocket/graphql/sdk_operations.py index 428d618..7c812bb 100644 --- a/answer_rocket/graphql/sdk_operations.py +++ b/answer_rocket/graphql/sdk_operations.py @@ -68,8 +68,8 @@ def mutation_cancel_chat_question(): def mutation_create_chat_thread(): - _op = sgqlc.operation.Operation(_schema_root.mutation_type, name='CreateChatThread', variables=dict(copilotId=sgqlc.types.Arg(sgqlc.types.non_null(_schema.UUID)))) - _op_create_chat_thread = _op.create_chat_thread(copilot_id=sgqlc.types.Variable('copilotId')) + _op = sgqlc.operation.Operation(_schema_root.mutation_type, name='CreateChatThread', variables=dict(copilotId=sgqlc.types.Arg(sgqlc.types.non_null(_schema.UUID)), threadType=sgqlc.types.Arg(_schema.ThreadType))) + _op_create_chat_thread = _op.create_chat_thread(copilot_id=sgqlc.types.Variable('copilotId'), thread_type=sgqlc.types.Variable('threadType')) _op_create_chat_thread.id() _op_create_chat_thread.copilot_id() return _op