Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions answer_rocket/chat.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import io
import logging
import pandas as pd
import uuid
from datetime import datetime
from typing import Literal, Optional

import pandas as pd
from sgqlc.types import Variable, non_null, String, Arg, list_of
from typing import Literal, Optional

from answer_rocket.client_config import ClientConfig
from answer_rocket.graphql.client import GraphQlClient
from answer_rocket.graphql.schema import (UUID, Int, DateTime, ChatDryRunType, MaxChatEntry, MaxChatThread,
SharedThread, MaxChatUser, ChatArtifact, MaxMutationResponse,
ChatArtifactSearchInput, PagingInput, PagedChatArtifacts)
from answer_rocket.graphql.sdk_operations import Operations
from answer_rocket.client_config import ClientConfig

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -243,18 +242,20 @@ def get_chat_thread(self, thread_id: str) -> MaxChatThread:
result = self.gql_client.submit(op, get_chat_thread_args)
return result.chat_thread

def create_new_thread(self, copilot_id: str) -> MaxChatThread:
def create_new_thread(self, copilot_id: str, thread_type: ThreadType = "CHAT") -> MaxChatThread:
"""
Create a new chat thread for the specified agent.

Args:
copilot_id: The ID of the agent to create the thread for
thread_type: The type of thread to create, defaults to CHAT. For most purposes CHAT is the only type needed.

Returns:
MaxChatThread: The newly created chat thread object
"""
create_chat_thread_args = {
'copilotId': copilot_id,
'threadType': thread_type
}

op = Operations.mutation.create_chat_thread
Expand Down
10 changes: 5 additions & 5 deletions answer_rocket/graphql/operations/chat.gql
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ query AllChatEntries($offset: Int, $limit: Int, $filters: JSON) {
}
}

mutation CreateChatThread($copilotId: UUID!) {
createChatThread(copilotId: $copilotId) {
id
copilotId
}
mutation CreateChatThread($copilotId: UUID!, $threadType: ThreadType) {
createChatThread(copilotId: $copilotId, threadType: $threadType) {
id
copilotId
}
}

mutation AddFeedback(
Expand Down
33 changes: 32 additions & 1 deletion answer_rocket/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,15 @@ class CreateDatasetFromTableResponse(sgqlc.types.Type):
error = sgqlc.types.Field(String, graphql_name='error')


class CreateMaxCopilotSkillChatQuestionResponse(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('copilot_skill_chat_question_id', 'success', 'code', 'error')
copilot_skill_chat_question_id = sgqlc.types.Field(UUID, graphql_name='copilotSkillChatQuestionId')
success = sgqlc.types.Field(sgqlc.types.non_null(Boolean), graphql_name='success')
code = sgqlc.types.Field(String, graphql_name='code')
error = sgqlc.types.Field(String, graphql_name='error')


class Database(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('database_id', 'name', 'dbms', 'description', 'llm_description', 'mermaid_er_diagram', 'k_shot_limit')
Expand Down Expand Up @@ -983,7 +992,28 @@ class MaxUser(sgqlc.types.Type):

class Mutation(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('create_max_copilot_question', 'update_max_copilot_question', 'delete_max_copilot_question', 'set_max_agent_workflow', 'import_copilot_skill_from_zip', 'sync_max_skill_repository', 'import_skill_from_repo', 'test_run_copilot_skill', 'get_test_run_output', 'reload_dataset', 'update_database_name', 'update_database_description', 'update_database_llm_description', 'update_database_mermaid_er_diagram', 'update_database_kshot_limit', 'update_dataset_name', 'update_dataset_description', 'update_dataset_date_range', 'update_dataset_data_interval', 'update_dataset_misc_info', 'update_dataset_source', 'update_dataset_query_row_limit', 'update_dataset_use_database_casing', 'update_dataset_kshot_limit', 'create_dataset', 'create_dataset_from_table', 'create_dimension', 'update_dimension', 'delete_dimension', 'create_metric', 'update_metric', 'delete_metric', 'create_database_kshot', 'update_database_kshot_question', 'update_database_kshot_rendered_prompt', 'update_database_kshot_explanation', 'update_database_kshot_sql', 'update_database_kshot_title', 'update_database_kshot_visualization', 'delete_database_kshot', 'update_chat_answer_payload', 'ask_chat_question', 'evaluate_chat_question', 'queue_chat_question', 'cancel_chat_question', 'create_chat_thread', 'add_feedback', 'set_skill_memory', 'share_thread', 'update_loading_message', 'create_chat_artifact', 'delete_chat_artifact')
__field_names__ = ('create_max_copilot_skill_chat_question', 'update_max_copilot_skill_chat_question', 'delete_max_copilot_skill_chat_question', 'create_max_copilot_question', 'update_max_copilot_question', 'delete_max_copilot_question', 'set_max_agent_workflow', 'import_copilot_skill_from_zip', 'sync_max_skill_repository', 'import_skill_from_repo', 'test_run_copilot_skill', 'get_test_run_output', 'reload_dataset', 'update_database_name', 'update_database_description', 'update_database_llm_description', 'update_database_mermaid_er_diagram', 'update_database_kshot_limit', 'update_dataset_name', 'update_dataset_description', 'update_dataset_date_range', 'update_dataset_data_interval', 'update_dataset_misc_info', 'update_dataset_source', 'update_dataset_query_row_limit', 'update_dataset_use_database_casing', 'update_dataset_kshot_limit', 'create_dataset', 'create_dataset_from_table', 'create_dimension', 'update_dimension', 'delete_dimension', 'create_metric', 'update_metric', 'delete_metric', 'create_database_kshot', 'update_database_kshot_question', 'update_database_kshot_rendered_prompt', 'update_database_kshot_explanation', 'update_database_kshot_sql', 'update_database_kshot_title', 'update_database_kshot_visualization', 'delete_database_kshot', 'update_chat_answer_payload', 'ask_chat_question', 'evaluate_chat_question', 'queue_chat_question', 'cancel_chat_question', 'create_chat_thread', 'add_feedback', 'set_skill_memory', 'share_thread', 'update_loading_message', 'create_chat_artifact', 'delete_chat_artifact')
create_max_copilot_skill_chat_question = sgqlc.types.Field(sgqlc.types.non_null(CreateMaxCopilotSkillChatQuestionResponse), graphql_name='createMaxCopilotSkillChatQuestion', args=sgqlc.types.ArgDict((
('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)),
('copilot_skill_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillId', default=None)),
('question', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='question', default=None)),
('expected_completion_response', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='expectedCompletionResponse', default=None)),
))
)
update_max_copilot_skill_chat_question = sgqlc.types.Field(sgqlc.types.non_null(MaxMutationResponse), graphql_name='updateMaxCopilotSkillChatQuestion', args=sgqlc.types.ArgDict((
('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)),
('copilot_skill_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillId', default=None)),
('copilot_skill_chat_question_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillChatQuestionId', default=None)),
('question', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='question', default=None)),
('expected_completion_response', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='expectedCompletionResponse', default=None)),
))
)
delete_max_copilot_skill_chat_question = sgqlc.types.Field(sgqlc.types.non_null(MaxMutationResponse), graphql_name='deleteMaxCopilotSkillChatQuestion', args=sgqlc.types.ArgDict((
('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)),
('copilot_skill_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillId', default=None)),
('copilot_skill_chat_question_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotSkillChatQuestionId', default=None)),
))
)
create_max_copilot_question = sgqlc.types.Field(sgqlc.types.non_null(MaxCreateCopilotQuestionResponse), graphql_name='createMaxCopilotQuestion', args=sgqlc.types.ArgDict((
('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)),
('copilot_question', sgqlc.types.Arg(sgqlc.types.non_null(MaxCopilotQuestionInput), graphql_name='copilotQuestion', default=None)),
Expand Down Expand Up @@ -1229,6 +1259,7 @@ class Mutation(sgqlc.types.Type):
)
create_chat_thread = sgqlc.types.Field(MaxChatThread, graphql_name='createChatThread', args=sgqlc.types.ArgDict((
('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)),
('thread_type', sgqlc.types.Arg(ThreadType, graphql_name='threadType', default=None)),
))
)
add_feedback = sgqlc.types.Field(Boolean, graphql_name='addFeedback', args=sgqlc.types.ArgDict((
Expand Down
4 changes: 2 additions & 2 deletions answer_rocket/graphql/sdk_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def mutation_cancel_chat_question():


def mutation_create_chat_thread():
_op = sgqlc.operation.Operation(_schema_root.mutation_type, name='CreateChatThread', variables=dict(copilotId=sgqlc.types.Arg(sgqlc.types.non_null(_schema.UUID))))
_op_create_chat_thread = _op.create_chat_thread(copilot_id=sgqlc.types.Variable('copilotId'))
_op = sgqlc.operation.Operation(_schema_root.mutation_type, name='CreateChatThread', variables=dict(copilotId=sgqlc.types.Arg(sgqlc.types.non_null(_schema.UUID)), threadType=sgqlc.types.Arg(_schema.ThreadType)))
_op_create_chat_thread = _op.create_chat_thread(copilot_id=sgqlc.types.Variable('copilotId'), thread_type=sgqlc.types.Variable('threadType'))
_op_create_chat_thread.id()
_op_create_chat_thread.copilot_id()
return _op
Expand Down