Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions answer_rocket/graphql/operations/llm.gql
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
query GenerateEmbeddings(
$texts: [String!]!,
$modelOverride: String
) {
generateEmbeddings(texts: $texts, modelOverride: $modelOverride) {
success
code
error
embeddings {
text
vector
}
}
}
40 changes: 37 additions & 3 deletions answer_rocket/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,13 @@ class DomainObjectResponse(sgqlc.types.Type):
domain_object = sgqlc.types.Field(MaxDomainObject, graphql_name='domainObject')


class EmbeddingResult(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('text', 'vector')
text = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='text')
vector = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(Float))), graphql_name='vector')


class EvaluateChatQuestionResponse(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('success', 'eval_results')
Expand Down Expand Up @@ -607,6 +614,15 @@ class ExecuteSqlQueryResponse(sgqlc.types.Type):
data = sgqlc.types.Field(JSON, graphql_name='data')


class GenerateEmbeddingsResponse(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('success', 'code', 'error', 'embeddings')
success = sgqlc.types.Field(sgqlc.types.non_null(Boolean), graphql_name='success')
code = sgqlc.types.Field(Int, graphql_name='code')
error = sgqlc.types.Field(String, graphql_name='error')
embeddings = sgqlc.types.Field(sgqlc.types.list_of(sgqlc.types.non_null(EmbeddingResult)), graphql_name='embeddings')


class GenerateVisualizationResponse(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('success', 'code', 'error', 'visualization')
Expand Down Expand Up @@ -1002,6 +1018,20 @@ class MaxReportResult(sgqlc.types.Type):
final_message = sgqlc.types.Field(String, graphql_name='finalMessage')
preview = sgqlc.types.Field(String, graphql_name='preview')


class MaxSkillComponent(sgqlc.types.Type):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't mine, but got picked up in the schema generation running against my Max branch (was created from develop this morning).

__schema__ = schema
__field_names__ = ('skill_component_id', 'node_type', 'organization', 'name', 'description', 'input_properties', 'output_properties', 'component_data')
skill_component_id = sgqlc.types.Field(sgqlc.types.non_null(UUID), graphql_name='skillComponentId')
node_type = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='nodeType')
organization = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='organization')
name = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='name')
description = sgqlc.types.Field(String, graphql_name='description')
input_properties = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null('MaxSkillComponentInputProperty'))), graphql_name='inputProperties')
output_properties = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null('MaxSkillComponentOutputProperty'))), graphql_name='outputProperties')
component_data = sgqlc.types.Field(sgqlc.types.non_null(JSON), graphql_name='componentData')


class MaxSkillComponentInputProperty(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('name', 'label', 'description', 'type', 'is_required', 'is_list', 'can_wire_from_output')
Expand Down Expand Up @@ -1474,8 +1504,7 @@ class ParameterDefinition(sgqlc.types.Type):

class Query(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('ping', 'current_user', 'get_copilot_skill_artifact_by_path', 'get_copilots', 'get_copilot_info', 'get_copilot_skill', 'run_copilot_skill', 'get_copilot_hydrated_reports', 'get_async_skill_run_status', 'get_max_agent_workflow', 'execute_sql_query', 'execute_rql_query', 'get_databases', 'get_database', 'get_database_tables', 'get_dataset_id', 'get_dataset', 'get_dataset2', 'get_datasets', 'get_domain_object', 'get_domain_object_by_name', 'get_grounded_value', 'get_database_kshots', 'get_database_kshot_by_id', 'get_dataset_kshots', 'get_dataset_kshot_by_id', 'run_max_sql_gen', 'run_sql_ai', 'generate_visualization', 'llmapi_config_for_sdk', 'get_max_llm_prompt', 'user_chat_threads', 'user_chat_entries', 'chat_thread', 'chat_entry', 'user', 'all_chat_entries', 'skill_memory', 'chat_completion', 'narrative_completion', 'narrative_completion_with_prompt', 'sql_completion', 'research_completion', 'chat_completion_with_prompt', 'research_completion_with_prompt', 'get_chat_artifact', 'get_chat_artifacts')

__field_names__ = ('ping', 'current_user', 'get_copilot_skill_artifact_by_path', 'get_copilots', 'get_copilot_info', 'get_copilot_skill', 'run_copilot_skill', 'get_skill_components', 'get_copilot_hydrated_reports', 'get_async_skill_run_status', 'get_max_agent_workflow', 'execute_sql_query', 'execute_rql_query', 'get_databases', 'get_database', 'get_database_tables', 'get_dataset_id', 'get_dataset', 'get_dataset2', 'get_datasets', 'get_domain_object', 'get_domain_object_by_name', 'get_grounded_value', 'get_database_kshots', 'get_database_kshot_by_id', 'get_dataset_kshots', 'get_dataset_kshot_by_id', 'run_max_sql_gen', 'run_sql_ai', 'generate_visualization', 'llmapi_config_for_sdk', 'generate_embeddings', 'get_max_llm_prompt', 'user_chat_threads', 'user_chat_entries', 'chat_thread', 'chat_entry', 'user', 'all_chat_entries', 'skill_memory', 'chat_completion', 'narrative_completion', 'narrative_completion_with_prompt', 'sql_completion', 'research_completion', 'chat_completion_with_prompt', 'research_completion_with_prompt', 'get_chat_artifact', 'get_chat_artifacts')
ping = sgqlc.types.Field(String, graphql_name='ping')
current_user = sgqlc.types.Field(MaxUser, graphql_name='currentUser')
get_copilot_skill_artifact_by_path = sgqlc.types.Field(CopilotSkillArtifact, graphql_name='getCopilotSkillArtifactByPath', args=sgqlc.types.ArgDict((
Expand Down Expand Up @@ -1504,7 +1533,7 @@ class Query(sgqlc.types.Type):
('validate_parameters', sgqlc.types.Arg(Boolean, graphql_name='validateParameters', default=None)),
))
)

get_skill_components = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(MaxSkillComponent))), graphql_name='getSkillComponents')
get_copilot_hydrated_reports = sgqlc.types.Field(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(HydratedReport))), graphql_name='getCopilotHydratedReports', args=sgqlc.types.ArgDict((
('copilot_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='copilotId', default=None)),
('override_dataset_id', sgqlc.types.Arg(UUID, graphql_name='overrideDatasetId', default=None)),
Expand Down Expand Up @@ -1629,6 +1658,11 @@ class Query(sgqlc.types.Type):
)
llmapi_config_for_sdk = sgqlc.types.Field(LLMApiConfig, graphql_name='LLMApiConfigForSdk', args=sgqlc.types.ArgDict((
('model_type', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='modelType', default=None)),
))
)
generate_embeddings = sgqlc.types.Field(sgqlc.types.non_null(GenerateEmbeddingsResponse), graphql_name='generateEmbeddings', args=sgqlc.types.ArgDict((
('texts', sgqlc.types.Arg(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(String))), graphql_name='texts', default=None)),
('model_override', sgqlc.types.Arg(String, graphql_name='modelOverride', default=None)),
))
)
get_max_llm_prompt = sgqlc.types.Field(MaxLLmPrompt, graphql_name='getMaxLlmPrompt', args=sgqlc.types.ArgDict((
Expand Down
13 changes: 13 additions & 0 deletions answer_rocket/graphql/sdk_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,18 @@ def query_get_dataset_kshot_by_id():
return _op


def query_generate_embeddings():
_op = sgqlc.operation.Operation(_schema_root.query_type, name='GenerateEmbeddings', variables=dict(texts=sgqlc.types.Arg(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(_schema.String)))), modelOverride=sgqlc.types.Arg(_schema.String)))
_op_generate_embeddings = _op.generate_embeddings(texts=sgqlc.types.Variable('texts'), model_override=sgqlc.types.Variable('modelOverride'))
_op_generate_embeddings.success()
_op_generate_embeddings.code()
_op_generate_embeddings.error()
_op_generate_embeddings_embeddings = _op_generate_embeddings.embeddings()
_op_generate_embeddings_embeddings.text()
_op_generate_embeddings_embeddings.vector()
return _op


def query_chat_completion():
_op = sgqlc.operation.Operation(_schema_root.query_type, name='ChatCompletion', variables=dict(messages=sgqlc.types.Arg(sgqlc.types.non_null(sgqlc.types.list_of(sgqlc.types.non_null(_schema.LlmChatMessage)))), modelSelection=sgqlc.types.Arg(_schema.LlmModelSelection), llmMeta=sgqlc.types.Arg(_schema.LlmMeta), functions=sgqlc.types.Arg(sgqlc.types.list_of(sgqlc.types.non_null(_schema.LlmFunction)))))
_op.chat_completion(messages=sgqlc.types.Variable('messages'), model_selection=sgqlc.types.Variable('modelSelection'), llm_meta=sgqlc.types.Variable('llmMeta'), functions=sgqlc.types.Variable('functions'))
Expand Down Expand Up @@ -1078,6 +1090,7 @@ class Query:
chat_thread = query_chat_thread()
current_user = query_current_user()
dataframes_for_entry = query_dataframes_for_entry()
generate_embeddings = query_generate_embeddings()
get_async_skill_run_status = query_get_async_skill_run_status()
get_chat_artifact = query_get_chat_artifact()
get_chat_artifacts = query_get_chat_artifacts()
Expand Down
28 changes: 28 additions & 0 deletions answer_rocket/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,31 @@ def research_completion_with_prompt(self, prompt_name: str, prompt_variables: di
}
gql_response = self.gql_client.submit(op, args)
return gql_response.research_completion_with_prompt

def generate_embeddings(self, texts: list[str], model_override: str | None = None):
"""
Generate embeddings for the provided texts.

Parameters
----------
texts : list[str]
List of text strings to generate embeddings for.
model_override : str | None, optional
Model name or ID to use instead of configured default.

Returns
-------
dict
The response containing success status, error (if any), and embeddings.
Each embedding includes the original text and its vector representation.
"""
query_args = {
'texts': texts,
'modelOverride': model_override,
}

op = Operations.query.generate_embeddings

result = self.gql_client.submit(op, query_args)

return result.generate_embeddings