Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions answer_rocket/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def get_domain_object(self, dataset_id: UUID, domain_object_id: str) -> DomainOb

return domain_object_result

def get_grounded_value(self, dataset_id: UUID, value: str, domain_entity: Optional[str] = None, copilot_id: Optional[UUID] = None) -> GroundedValueResponse:
def get_grounded_value(self, dataset_id: UUID, value: str, dimension_name: Optional[str] = None, copilot_id: Optional[UUID] = None) -> GroundedValueResponse:
"""
Get grounded values for fuzzy matching against domain values.

Expand All @@ -771,9 +771,9 @@ def get_grounded_value(self, dataset_id: UUID, value: str, domain_entity: Option
The UUID of the dataset.
value : str
The value to ground (single string).
domain_entity : str, optional
The domain entity to search within. Can be "metrics", "dimensions",
a specific domain attribute name, or None to search all. Defaults to None.
dimension_name : str, optional
The dimension name to search within. Can be
a specific dimension attribute name, or None to search all. Defaults to None.
copilot_id : UUID, optional
The UUID of the copilot. Defaults to the configured copilot_id.

Expand All @@ -787,7 +787,7 @@ def get_grounded_value(self, dataset_id: UUID, value: str, domain_entity: Option
query_args = {
'datasetId': str(dataset_id),
'value': value,
'domainEntity': domain_entity,
'dimensionName': dimension_name,
'copilotId': copilot_id or self.copilot_id,
}

Expand Down
6 changes: 3 additions & 3 deletions answer_rocket/graphql/operations/data.gql
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
query GetGroundedValue(
$datasetId: UUID!
$value: String!
$domainEntity: String
$dimensionName: String
$copilotId: UUID
) {
getGroundedValue(
datasetId: $datasetId
value: $value
domainEntity: $domainEntity
dimensionName: $dimensionName
copilotId: $copilotId
) {
matchedValue
Expand All @@ -16,7 +16,7 @@ query GetGroundedValue(
mappedIndicator
mappedValue
preferred
domainEntity
dimensionName
otherMatches
}
}
Expand Down
25 changes: 20 additions & 5 deletions answer_rocket/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class LlmResponse(sgqlc.types.Scalar):
__schema__ = schema


class MatchType(sgqlc.types.Enum):
__schema__ = schema
__choices__ = ('DEFAULT', 'EMBEDDED', 'EXACT', 'FUZZY')


class MetricType(sgqlc.types.Enum):
__schema__ = schema
__choices__ = ('BASIC', 'RATIO', 'SHARE')
Expand Down Expand Up @@ -705,17 +710,27 @@ class GenerateVisualizationResponse(sgqlc.types.Type):
visualization = sgqlc.types.Field(JSON, graphql_name='visualization')


class GroundedMatch(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('value', 'score', 'match_type', 'dimension_name', 'mapped_value')
value = sgqlc.types.Field(sgqlc.types.non_null(String), graphql_name='value')
score = sgqlc.types.Field(sgqlc.types.non_null(Float), graphql_name='score')
match_type = sgqlc.types.Field(sgqlc.types.non_null(MatchType), graphql_name='matchType')
dimension_name = sgqlc.types.Field(String, graphql_name='dimensionName')
mapped_value = sgqlc.types.Field(String, graphql_name='mappedValue')


class GroundedValueResponse(sgqlc.types.Type):
__schema__ = schema
__field_names__ = ('matched_value', 'match_quality', 'match_type', 'mapped_indicator', 'mapped_value', 'preferred', 'domain_entity', 'other_matches')
__field_names__ = ('matched_value', 'match_quality', 'match_type', 'mapped_indicator', 'mapped_value', 'preferred', 'dimension_name', 'other_matches')
matched_value = sgqlc.types.Field(String, graphql_name='matchedValue')
match_quality = sgqlc.types.Field(Float, graphql_name='matchQuality')
match_type = sgqlc.types.Field(String, graphql_name='matchType')
match_type = sgqlc.types.Field(MatchType, graphql_name='matchType')
mapped_indicator = sgqlc.types.Field(Boolean, graphql_name='mappedIndicator')
mapped_value = sgqlc.types.Field(String, graphql_name='mappedValue')
preferred = sgqlc.types.Field(Boolean, graphql_name='preferred')
domain_entity = sgqlc.types.Field(String, graphql_name='domainEntity')
other_matches = sgqlc.types.Field(sgqlc.types.list_of(sgqlc.types.non_null(JSON)), graphql_name='otherMatches')
dimension_name = sgqlc.types.Field(String, graphql_name='dimensionName')
other_matches = sgqlc.types.Field(sgqlc.types.list_of(sgqlc.types.non_null(GroundedMatch)), graphql_name='otherMatches')


class HydratedReport(sgqlc.types.Type):
Expand Down Expand Up @@ -1772,7 +1787,7 @@ class Query(sgqlc.types.Type):
)
get_grounded_value = sgqlc.types.Field(GroundedValueResponse, graphql_name='getGroundedValue', args=sgqlc.types.ArgDict((
('dataset_id', sgqlc.types.Arg(sgqlc.types.non_null(UUID), graphql_name='datasetId', default=None)),
('domain_entity', sgqlc.types.Arg(String, graphql_name='domainEntity', default=None)),
('dimension_name', sgqlc.types.Arg(String, graphql_name='dimensionName', default=None)),
('value', sgqlc.types.Arg(sgqlc.types.non_null(String), graphql_name='value', default=None)),
('copilot_id', sgqlc.types.Arg(UUID, graphql_name='copilotId', default=None)),
))
Expand Down
6 changes: 3 additions & 3 deletions answer_rocket/graphql/sdk_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,15 +980,15 @@ def query_get_copilot_question_folders():


def query_get_grounded_value():
_op = sgqlc.operation.Operation(_schema_root.query_type, name='GetGroundedValue', variables=dict(datasetId=sgqlc.types.Arg(sgqlc.types.non_null(_schema.UUID)), value=sgqlc.types.Arg(sgqlc.types.non_null(_schema.String)), domainEntity=sgqlc.types.Arg(_schema.String), copilotId=sgqlc.types.Arg(_schema.UUID)))
_op_get_grounded_value = _op.get_grounded_value(dataset_id=sgqlc.types.Variable('datasetId'), value=sgqlc.types.Variable('value'), domain_entity=sgqlc.types.Variable('domainEntity'), copilot_id=sgqlc.types.Variable('copilotId'))
_op = sgqlc.operation.Operation(_schema_root.query_type, name='GetGroundedValue', variables=dict(datasetId=sgqlc.types.Arg(sgqlc.types.non_null(_schema.UUID)), value=sgqlc.types.Arg(sgqlc.types.non_null(_schema.String)), dimensionName=sgqlc.types.Arg(_schema.String), copilotId=sgqlc.types.Arg(_schema.UUID)))
_op_get_grounded_value = _op.get_grounded_value(dataset_id=sgqlc.types.Variable('datasetId'), value=sgqlc.types.Variable('value'), dimension_name=sgqlc.types.Variable('dimensionName'), copilot_id=sgqlc.types.Variable('copilotId'))
_op_get_grounded_value.matched_value()
_op_get_grounded_value.match_quality()
_op_get_grounded_value.match_type()
_op_get_grounded_value.mapped_indicator()
_op_get_grounded_value.mapped_value()
_op_get_grounded_value.preferred()
_op_get_grounded_value.domain_entity()
_op_get_grounded_value.dimension_name()
_op_get_grounded_value.other_matches()
return _op

Expand Down