diff --git a/answer_rocket/graphql/operations/llm.gql b/answer_rocket/graphql/operations/llm.gql new file mode 100644 index 0000000..7c3ff7a --- /dev/null +++ b/answer_rocket/graphql/operations/llm.gql @@ -0,0 +1,14 @@ +query GenerateEmbeddings( + $texts: [String!]!, + $modelOverride: String +) { + generateEmbeddings(texts: $texts, modelOverride: $modelOverride) { + success + code + error + embeddings { + text + vector + } + } +} diff --git a/answer_rocket/graphql/schema.py b/answer_rocket/graphql/schema.py index 601cd6f..33e1cad 100644 --- a/answer_rocket/graphql/schema.py +++ b/answer_rocket/graphql/schema.py @@ -569,6 +569,13 @@ class DomainObjectResponse(sgqlc.types.Type): domain_object = sgqlc.types.Field(MaxDomainObject, graphql_name='domainObject') +class EmbeddingResult(sgqlc.types.Type): + __schema__ = schema + __field_names__ = ('text', 'vector') + text = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='text') + vector = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(Float))), graphql_name='vector') + + class EvaluateChatQuestionResponse(sgqlc.types.Type): __schema__ = schema __field_names__ = ('success', 'eval_results') @@ -607,6 +614,15 @@ class ExecuteSqlQueryResponse(sgqlc.types.Type): data = sgqlc.types.Field(JSON, graphql_name='data') +class GenerateEmbeddingsResponse(sgqlc.types.Type): + __schema__ = schema + __field_names__ = ('success', 'code', 'error', 'embeddings') + success = sgqlc.types.Field(sgqlc.types.non_null(Boolean), graphql_name='success') + code = sgqlc.types.Field(Int, graphql_name='code') + error = sgqlc.types.Field(String, graphql_name='error') + embeddings = sgqlc.types.Field(sgqlc.types.list_of(sgqlc.types.non_null(EmbeddingResult)), graphql_name='embeddings') + + class GenerateVisualizationResponse(sgqlc.types.Type): __schema__ = schema __field_names__ = ('success', 'code', 'error', 'visualization') @@ -1002,6 +1018,20 @@ class MaxReportResult(sgqlc.types.Type): final_message = sgqlc.types.Field(String, graphql_name='finalMessage') preview = sgqlc.types.Field(String, graphql_name='preview') + +class MaxSkillComponent(sgqlc.types.Type): + __schema__ = schema + __field_names__ = ('skill_component_id', 'node_type', 'organization', 'name', 'description', 'input_properties', 'output_properties', 'component_data') + skill_component_id = sgqlc.types.Field(sgqlc.types.non_null(UUID), graphql_name='skillComponentId') + node_type = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='nodeType') + organization = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='organization') + name = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='name') + description = sgqlc.types.Field(String, graphql_name='description') + input_properties = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null('MaxSkillComponentInputProperty'))), graphql_name='inputProperties') + output_properties = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null('MaxSkillComponentOutputProperty'))), graphql_name='outputProperties') + component_data = sgqlc.types.Field(sgqlc.types.non_null(JSON), graphql_name='componentData') + + class MaxSkillComponentInputProperty(sgqlc.types.Type): __schema__ = schema __field_names__ = ('name', 'label', 'description', 'type', 'is_required', 'is_list', 'can_wire_from_output') @@ -1474,8 +1504,7 @@ class ParameterDefinition(sgqlc.types.Type): class Query(sgqlc.types.Type): __schema__ = schema - __field_names__ = ('ping', 'current_user', 'get_copilot_skill_artifact_by_path', 'get_copilots', 'get_copilot_info', 'get_copilot_skill', 'run_copilot_skill', 'get_copilot_hydrated_reports', 'get_async_skill_run_status', 'get_max_agent_workflow', 'execute_sql_query', 'execute_rql_query', 'get_databases', 'get_database', 'get_database_tables', 'get_dataset_id', 'get_dataset', 'get_dataset2', 'get_datasets', 'get_domain_object', 'get_domain_object_by_name', 'get_grounded_value', 'get_database_kshots', 'get_database_kshot_by_id', 'get_dataset_kshots', 'get_dataset_kshot_by_id', 'run_max_sql_gen', 'run_sql_ai', 'generate_visualization', 'llmapi_config_for_sdk', 'get_max_llm_prompt', 'user_chat_threads', 'user_chat_entries', 'chat_thread', 'chat_entry', 'user', 'all_chat_entries', 'skill_memory', 'chat_completion', 'narrative_completion', 'narrative_completion_with_prompt', 'sql_completion', 'research_completion', 'chat_completion_with_prompt', 'research_completion_with_prompt', 'get_chat_artifact', 'get_chat_artifacts') - + __field_names__ = ('ping', 'current_user', 'get_copilot_skill_artifact_by_path', 'get_copilots', 'get_copilot_info', 'get_copilot_skill', 'run_copilot_skill', 'get_skill_components', 'get_copilot_hydrated_reports', 'get_async_skill_run_status', 'get_max_agent_workflow', 'execute_sql_query', 'execute_rql_query', 'get_databases', 'get_database', 'get_database_tables', 'get_dataset_id', 'get_dataset', 'get_dataset2', 'get_datasets', 'get_domain_object', 'get_domain_object_by_name', 'get_grounded_value', 'get_database_kshots', 'get_database_kshot_by_id', 'get_dataset_kshots', 'get_dataset_kshot_by_id', 'run_max_sql_gen', 'run_sql_ai', 'generate_visualization', 'llmapi_config_for_sdk', 'generate_embeddings', 'get_max_llm_prompt', 'user_chat_threads', 'user_chat_entries', 'chat_thread', 'chat_entry', 'user', 'all_chat_entries', 'skill_memory', 'chat_completion', 'narrative_completion', 'narrative_completion_with_prompt', 'sql_completion', 'research_completion', 'chat_completion_with_prompt', 'research_completion_with_prompt', 'get_chat_artifact', 'get_chat_artifacts') ping = sgqlc.types.Field(String, graphql_name='ping') current_user = sgqlc.types.Field(MaxUser, graphql_name='currentUser') get_copilot_skill_artifact_by_path = sgqlc.types.Field(CopilotSkillArtifact, graphql_name='getCopilotSkillArtifactByPath', args=sgqlc.types.ArgDict(( @@ -1504,7 +1533,7 @@ class Query(sgqlc.types.Type): ('validate_parameters', sgqlc.types.Arg(Boolean, graphql_name='validateParameters', default=None)), )) ) - + get_skill_components = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(MaxSkillComponent))), graphql_name='getSkillComponents') get_copilot_hydrated_reports = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(HydratedReport))), graphql_name='getCopilotHydratedReports', args=sgqlc.types.ArgDict(( ('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)), ('override_dataset_id', sgqlc.types.Arg(UUID, graphql_name='overrideDatasetId', default=None)), @@ -1629,6 +1658,11 @@ class Query(sgqlc.types.Type): ) llmapi_config_for_sdk = sgqlc.types.Field(LLMApiConfig, graphql_name='LLMApiConfigForSdk', args=sgqlc.types.ArgDict(( ('model_type', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='modelType', default=None)), +)) + ) + generate_embeddings = sgqlc.types.Field(sgqlc.types.non_null(GenerateEmbeddingsResponse), graphql_name='generateEmbeddings', args=sgqlc.types.ArgDict(( + ('texts', sgqlc.types.Arg(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(String))), graphql_name='texts', default=None)), + ('model_override', sgqlc.types.Arg(String, graphql_name='modelOverride', default=None)), )) ) get_max_llm_prompt = sgqlc.types.Field(MaxLLmPrompt, graphql_name='getMaxLlmPrompt', args=sgqlc.types.ArgDict(( diff --git a/answer_rocket/graphql/sdk_operations.py b/answer_rocket/graphql/sdk_operations.py index b392cd0..8f7a9af 100644 --- a/answer_rocket/graphql/sdk_operations.py +++ b/answer_rocket/graphql/sdk_operations.py @@ -1007,6 +1007,18 @@ def query_get_dataset_kshot_by_id(): return _op +def query_generate_embeddings(): + _op = sgqlc.operation.Operation(_schema_root.query_type, name='GenerateEmbeddings', variables=dict(texts=sgqlc.types.Arg(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(_schema.String)))), modelOverride=sgqlc.types.Arg(_schema.String))) + _op_generate_embeddings = _op.generate_embeddings(texts=sgqlc.types.Variable('texts'), model_override=sgqlc.types.Variable('modelOverride')) + _op_generate_embeddings.success() + _op_generate_embeddings.code() + _op_generate_embeddings.error() + _op_generate_embeddings_embeddings = _op_generate_embeddings.embeddings() + _op_generate_embeddings_embeddings.text() + _op_generate_embeddings_embeddings.vector() + return _op + + def query_chat_completion(): _op = sgqlc.operation.Operation(_schema_root.query_type, name='ChatCompletion', variables=dict(messages=sgqlc.types.Arg(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(_schema.LlmChatMessage)))), modelSelection=sgqlc.types.Arg(_schema.LlmModelSelection), llmMeta=sgqlc.types.Arg(_schema.LlmMeta), functions=sgqlc.types.Arg(sgqlc.types.list_of(sgqlc.types.non_null(_schema.LlmFunction))))) _op.chat_completion(messages=sgqlc.types.Variable('messages'), model_selection=sgqlc.types.Variable('modelSelection'), llm_meta=sgqlc.types.Variable('llmMeta'), functions=sgqlc.types.Variable('functions')) @@ -1078,6 +1090,7 @@ class Query: chat_thread = query_chat_thread() current_user = query_current_user() dataframes_for_entry = query_dataframes_for_entry() + generate_embeddings = query_generate_embeddings() get_async_skill_run_status = query_get_async_skill_run_status() get_chat_artifact = query_get_chat_artifact() get_chat_artifacts = query_get_chat_artifacts() diff --git a/answer_rocket/llm.py b/answer_rocket/llm.py index d01cbb8..f0fa0ed 100644 --- a/answer_rocket/llm.py +++ b/answer_rocket/llm.py @@ -355,3 +355,31 @@ def research_completion_with_prompt(self, prompt_name: str, prompt_variables: di } gql_response = self.gql_client.submit(op, args) return gql_response.research_completion_with_prompt + + def generate_embeddings(self, texts: list[str], model_override: str | None = None): + """ + Generate embeddings for the provided texts. + + Parameters + ---------- + texts : list[str] + List of text strings to generate embeddings for. + model_override : str | None, optional + Model name or ID to use instead of configured default. + + Returns + ------- + dict + The response containing success status, error (if any), and embeddings. + Each embedding includes the original text and its vector representation. + """ + query_args = { + 'texts': texts, + 'modelOverride': model_override, + } + + op = Operations.query.generate_embeddings + + result = self.gql_client.submit(op, query_args) + + return result.generate_embeddings