Skip to content

Commit 7fea754

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: LLM - Support tuning of new text embedding models by migrating to the new v1.1.3 pipeline.
PiperOrigin-RevId: 631887159
1 parent 3938107 commit 7fea754

File tree

3 files changed

+115
-47
lines changed

3 files changed

+115
-47
lines changed

tests/unit/aiplatform/test_language_models.py

Lines changed: 90 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ def reverse_string_2(s):""",
563563
"parameterType": "STRING",
564564
},
565565
"base_model_version_id": {
566-
"defaultValue": "textembedding-gecko@001",
566+
"defaultValue": "text-embedding-004",
567567
"description": "which base model to tune. This may be any stable\nnumbered version, for example `textembedding-gecko@001`.",
568568
"isOptional": True,
569569
"parameterType": "STRING",
@@ -578,17 +578,15 @@ def reverse_string_2(s):""",
578578
"description": "the GCS path to the corpus data location.",
579579
"parameterType": "STRING",
580580
},
581-
"iterations": {
582-
"defaultValue": 1000,
583-
"description": "the number of steps to perform fine-tuning.",
581+
"encryption_spec_key_name": {
582+
"defaultValue": "",
584583
"isOptional": True,
585-
"parameterType": "NUMBER_INTEGER",
584+
"parameterType": "STRING",
586585
},
587-
"location": {
588-
"defaultValue": "us-central1",
589-
"description": "GCP region to run the pipeline.",
586+
"learning_rate_multiplier": {
587+
"defaultValue": 1.0,
590588
"isOptional": True,
591-
"parameterType": "STRING",
589+
"parameterType": "NUMBER_DOUBLE",
592590
},
593591
"machine_type": {
594592
"defaultValue": "n1-standard-16",
@@ -602,9 +600,10 @@ def reverse_string_2(s):""",
602600
"isOptional": True,
603601
"parameterType": "STRING",
604602
},
605-
"project": {
606-
"description": "user's project id.",
607-
"parameterType": "STRING",
603+
"output_dimensionality": {
604+
"defaultValue": -1,
605+
"isOptional": True,
606+
"parameterType": "NUMBER_INTEGER",
608607
},
609608
"queries_path": {
610609
"description": "the GCS path to the queries location.",
@@ -626,6 +625,12 @@ def reverse_string_2(s):""",
626625
"description": "the GCS path to the train label data location.",
627626
"parameterType": "STRING",
628627
},
628+
"train_steps": {
629+
"defaultValue": 1000,
630+
"description": "the number of steps to perform fine-tuning.",
631+
"isOptional": True,
632+
"parameterType": "NUMBER_INTEGER",
633+
},
629634
"validation_label_path": {
630635
"defaultValue": "",
631636
"description": "The GCS path to the validation label data location.",
@@ -2283,6 +2288,61 @@ def test_text_generation_response_repr(self):
22832288
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
22842289
indirect=True,
22852290
)
2291+
@pytest.mark.parametrize(
2292+
"base_model_version_id,tune_args,expected_pipeline_args",
2293+
[ # Do not pass any optional parameters.
2294+
(
2295+
"textembedding-gecko@003",
2296+
dict(
2297+
training_data="gs://bucket/training.tsv",
2298+
corpus_data="gs://bucket/corpus.jsonl",
2299+
queries_data="gs://bucket/queries.jsonl",
2300+
),
2301+
dict(
2302+
base_model_version_id="textembedding-gecko@003",
2303+
train_label_path="gs://bucket/training.tsv",
2304+
corpus_path="gs://bucket/corpus.jsonl",
2305+
queries_path="gs://bucket/queries.jsonl",
2306+
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
2307+
),
2308+
),
2309+
# Pass all optional parameters.
2310+
(
2311+
"text-multilingual-embedding-002",
2312+
dict(
2313+
training_data="gs://bucket/training.tsv",
2314+
corpus_data="gs://bucket/corpus.jsonl",
2315+
queries_data="gs://bucket/queries.jsonl",
2316+
test_data="gs://bucket/test.tsv",
2317+
validation_data="gs://bucket/validation.tsv",
2318+
tuned_model_location="us-central1",
2319+
model_display_name="my-tuned-model",
2320+
train_steps=30,
2321+
batch_size=256,
2322+
accelerator="NVIDIA_TESLA_V100",
2323+
accelerator_count=1,
2324+
machine_type="n1-highmem-16",
2325+
task_type="DEFAULT",
2326+
),
2327+
dict(
2328+
train_steps=30,
2329+
accelerator_type="NVIDIA_TESLA_V100",
2330+
accelerator_count=1,
2331+
machine_type="n1-highmem-16",
2332+
base_model_version_id="text-multilingual-embedding-002",
2333+
train_label_path="gs://bucket/training.tsv",
2334+
corpus_path="gs://bucket/corpus.jsonl",
2335+
queries_path="gs://bucket/queries.jsonl",
2336+
test_label_path="gs://bucket/test.tsv",
2337+
batch_size=256,
2338+
model_display_name="my-tuned-model",
2339+
validation_label_path="gs://bucket/validation.tsv",
2340+
encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
2341+
task_type="DEFAULT",
2342+
),
2343+
),
2344+
],
2345+
)
22862346
def test_tune_text_embedding_model(
22872347
self,
22882348
mock_pipeline_service_create,
@@ -2294,6 +2354,9 @@ def test_tune_text_embedding_model(
22942354
mock_gcs_upload,
22952355
mock_request_urlopen_gecko,
22962356
mock_deploy_tuned_embedding_model,
2357+
tune_args,
2358+
expected_pipeline_args,
2359+
base_model_version_id,
22972360
):
22982361
"""Tests tuning the text embedding model."""
22992362
aiplatform.init(
@@ -2309,23 +2372,23 @@ def test_tune_text_embedding_model(
23092372
),
23102373
):
23112374
model = language_models.TextEmbeddingModel.from_pretrained(
2312-
"textembedding-gecko@003"
2313-
)
2314-
tuning_job = model.tune_model(
2315-
training_data="gs://bucket/training.tsv",
2316-
corpus_data="gs://bucket/corpus.jsonl",
2317-
queries_data="gs://bucket/queries.jsonl",
2318-
test_data="gs://bucket/test.tsv",
2319-
tuned_model_location="us-central1",
2320-
train_steps=10,
2321-
accelerator="NVIDIA_TESLA_A100",
2375+
base_model_version_id
23222376
)
2377+
tuning_job = model.tune_model(**tune_args)
23232378
call_kwargs = mock_pipeline_service_create.call_args[1]
2324-
pipeline_arguments = call_kwargs[
2325-
"pipeline_job"
2326-
].runtime_config.parameter_values
2327-
assert pipeline_arguments["iterations"] == 10
2328-
assert pipeline_arguments["accelerator_type"] == "NVIDIA_TESLA_A100"
2379+
pipeline_arguments = dict(
2380+
call_kwargs["pipeline_job"].runtime_config.parameter_values
2381+
)
2382+
2383+
if (
2384+
"model_display_name" not in tune_args
2385+
and "model_display_name" in pipeline_arguments
2386+
):
2387+
# This is automatically generated from some params, so don't
2388+
# check it.
2389+
del pipeline_arguments["model_display_name"]
2390+
2391+
assert pipeline_arguments == expected_pipeline_args
23292392

23302393
# Testing the tuned model
23312394
tuned_model = tuning_job.deploy_tuned_model()

vertexai/_model_garden/_model_garden_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@
3939
"chat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
4040
"codechat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
4141
"codechat-bison-32k": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
42-
"textembedding-gecko": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2",
43-
"textembedding-gecko-multilingual": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2",
42+
"textembedding-gecko": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3",
43+
"textembedding-gecko-multilingual": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3",
44+
"text-embedding-004": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3",
45+
"text-multilingual-embedding-002": "https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.3",
4446
}
4547

4648
_LOGGER = base.Logger(__name__)

vertexai/language_models/_language_models.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -414,20 +414,24 @@ def _tune_model(
414414
model_id=self._model_id,
415415
schema_to_class_map={self._INSTANCE_SCHEMA_URI: type(self)},
416416
)
417-
if model_info.tuning_pipeline_uri.startswith(
418-
"https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model"
419-
):
420-
train_steps = tuning_parameters.pop("train_steps", None)
421-
if train_steps:
422-
tuning_parameters["iterations"] = train_steps
417+
if _is_text_embedding_tuning_pipeline(model_info.tuning_pipeline_uri):
423418
tunable_base_model_id = self._model_id.rpartition("/")[-1]
424419
tuning_parameters["base_model_version_id"] = tunable_base_model_id
425420
else:
426421
tuning_parameters["large_model_reference"] = model_info.tuning_model_id
427-
if aiplatform_initializer.global_config.encryption_spec_key_name:
428-
tuning_parameters[
429-
"encryption_spec_key_name"
430-
] = aiplatform_initializer.global_config.encryption_spec_key_name
422+
tuning_parameters.update(
423+
{
424+
"project": aiplatform_initializer.global_config.project,
425+
# TODO(b/275444096): Remove the explicit location once tuning
426+
# can happen in all regions.
427+
# "location": aiplatform_initializer.global_config.location,
428+
"location": tuned_model_location,
429+
}
430+
)
431+
if aiplatform_initializer.global_config.encryption_spec_key_name:
432+
tuning_parameters[
433+
"encryption_spec_key_name"
434+
] = aiplatform_initializer.global_config.encryption_spec_key_name
431435

432436
if not model_info.tuning_pipeline_uri:
433437
raise RuntimeError(f"The {self._model_id} model does not support tuning")
@@ -3890,6 +3894,12 @@ def _maybe_upload_training_data(
38903894
)
38913895

38923896

3897+
def _is_text_embedding_tuning_pipeline(pipeline_uri: str) -> bool:
3898+
return pipeline_uri.startswith(
3899+
"https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model"
3900+
)
3901+
3902+
38933903
def _launch_tuning_job(
38943904
training_data: Union[str, "pandas.core.frame.DataFrame"],
38953905
model_id: str,
@@ -3931,16 +3941,9 @@ def _launch_tuning_job(
39313941
model_display_name = name[:max_display_name_length]
39323942

39333943
pipeline_arguments = {
3934-
"project": aiplatform_initializer.global_config.project,
3935-
# TODO(b/275444096): Remove the explicit location once tuning can happen in all regions
3936-
# "location": aiplatform_initializer.global_config.location,
3937-
"location": tuned_model_location,
39383944
"model_display_name": model_display_name,
39393945
}
3940-
3941-
if tuning_pipeline_uri.startswith(
3942-
"https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model"
3943-
):
3946+
if _is_text_embedding_tuning_pipeline(tuning_pipeline_uri):
39443947
pipeline_arguments["train_label_path"] = training_data_path
39453948
elif training_data_path.startswith("gs://"):
39463949
pipeline_arguments["dataset_uri"] = training_data_path

0 commit comments

Comments (0)