@@ -563,7 +563,7 @@ def reverse_string_2(s):""",
563563 "parameterType" : "STRING" ,
564564 },
565565 "base_model_version_id" : {
566- "defaultValue" : "textembedding-gecko@001 " ,
566+ "defaultValue" : "text-embedding-004 " ,
567567 "description" : "which base model to tune. This may be any stable\n numbered version, for example `textembedding-gecko@001`." ,
568568 "isOptional" : True ,
569569 "parameterType" : "STRING" ,
@@ -578,17 +578,15 @@ def reverse_string_2(s):""",
578578 "description" : "the GCS path to the corpus data location." ,
579579 "parameterType" : "STRING" ,
580580 },
581- "iterations" : {
582- "defaultValue" : 1000 ,
583- "description" : "the number of steps to perform fine-tuning." ,
581+ "encryption_spec_key_name" : {
582+ "defaultValue" : "" ,
584583 "isOptional" : True ,
585- "parameterType" : "NUMBER_INTEGER " ,
584+ "parameterType" : "STRING " ,
586585 },
587- "location" : {
588- "defaultValue" : "us-central1" ,
589- "description" : "GCP region to run the pipeline." ,
586+ "learning_rate_multiplier" : {
587+ "defaultValue" : 1.0 ,
590588 "isOptional" : True ,
591- "parameterType" : "STRING " ,
589+ "parameterType" : "NUMBER_DOUBLE " ,
592590 },
593591 "machine_type" : {
594592 "defaultValue" : "n1-standard-16" ,
@@ -602,9 +600,10 @@ def reverse_string_2(s):""",
602600 "isOptional" : True ,
603601 "parameterType" : "STRING" ,
604602 },
605- "project" : {
606- "description" : "user's project id." ,
607- "parameterType" : "STRING" ,
603+ "output_dimensionality" : {
604+ "defaultValue" : - 1 ,
605+ "isOptional" : True ,
606+ "parameterType" : "NUMBER_INTEGER" ,
608607 },
609608 "queries_path" : {
610609 "description" : "the GCS path to the queries location." ,
@@ -626,6 +625,12 @@ def reverse_string_2(s):""",
626625 "description" : "the GCS path to the train label data location." ,
627626 "parameterType" : "STRING" ,
628627 },
628+ "train_steps" : {
629+ "defaultValue" : 1000 ,
630+ "description" : "the number of steps to perform fine-tuning." ,
631+ "isOptional" : True ,
632+ "parameterType" : "NUMBER_INTEGER" ,
633+ },
629634 "validation_label_path" : {
630635 "defaultValue" : "" ,
631636 "description" : "The GCS path to the validation label data location." ,
@@ -2283,6 +2288,61 @@ def test_text_generation_response_repr(self):
22832288 ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest" ],
22842289 indirect = True ,
22852290 )
2291+ @pytest .mark .parametrize (
2292+ "base_model_version_id,tune_args,expected_pipeline_args" ,
2293+ [ # Do not pass any optional parameters.
2294+ (
2295+ "textembedding-gecko@003" ,
2296+ dict (
2297+ training_data = "gs://bucket/training.tsv" ,
2298+ corpus_data = "gs://bucket/corpus.jsonl" ,
2299+ queries_data = "gs://bucket/queries.jsonl" ,
2300+ ),
2301+ dict (
2302+ base_model_version_id = "textembedding-gecko@003" ,
2303+ train_label_path = "gs://bucket/training.tsv" ,
2304+ corpus_path = "gs://bucket/corpus.jsonl" ,
2305+ queries_path = "gs://bucket/queries.jsonl" ,
2306+ encryption_spec_key_name = _TEST_ENCRYPTION_KEY_NAME ,
2307+ ),
2308+ ),
2309+ # Pass all optional parameters.
2310+ (
2311+ "text-multilingual-embedding-002" ,
2312+ dict (
2313+ training_data = "gs://bucket/training.tsv" ,
2314+ corpus_data = "gs://bucket/corpus.jsonl" ,
2315+ queries_data = "gs://bucket/queries.jsonl" ,
2316+ test_data = "gs://bucket/test.tsv" ,
2317+ validation_data = "gs://bucket/validation.tsv" ,
2318+ tuned_model_location = "us-central1" ,
2319+ model_display_name = "my-tuned-model" ,
2320+ train_steps = 30 ,
2321+ batch_size = 256 ,
2322+ accelerator = "NVIDIA_TESLA_V100" ,
2323+ accelerator_count = 1 ,
2324+ machine_type = "n1-highmem-16" ,
2325+ task_type = "DEFAULT" ,
2326+ ),
2327+ dict (
2328+ train_steps = 30 ,
2329+ accelerator_type = "NVIDIA_TESLA_V100" ,
2330+ accelerator_count = 1 ,
2331+ machine_type = "n1-highmem-16" ,
2332+ base_model_version_id = "text-multilingual-embedding-002" ,
2333+ train_label_path = "gs://bucket/training.tsv" ,
2334+ corpus_path = "gs://bucket/corpus.jsonl" ,
2335+ queries_path = "gs://bucket/queries.jsonl" ,
2336+ test_label_path = "gs://bucket/test.tsv" ,
2337+ batch_size = 256 ,
2338+ model_display_name = "my-tuned-model" ,
2339+ validation_label_path = "gs://bucket/validation.tsv" ,
2340+ encryption_spec_key_name = _TEST_ENCRYPTION_KEY_NAME ,
2341+ task_type = "DEFAULT" ,
2342+ ),
2343+ ),
2344+ ],
2345+ )
22862346 def test_tune_text_embedding_model (
22872347 self ,
22882348 mock_pipeline_service_create ,
@@ -2294,6 +2354,9 @@ def test_tune_text_embedding_model(
22942354 mock_gcs_upload ,
22952355 mock_request_urlopen_gecko ,
22962356 mock_deploy_tuned_embedding_model ,
2357+ tune_args ,
2358+ expected_pipeline_args ,
2359+ base_model_version_id ,
22972360 ):
22982361 """Tests tuning the text embedding model."""
22992362 aiplatform .init (
@@ -2309,23 +2372,23 @@ def test_tune_text_embedding_model(
23092372 ),
23102373 ):
23112374 model = language_models .TextEmbeddingModel .from_pretrained (
2312- "textembedding-gecko@003"
2313- )
2314- tuning_job = model .tune_model (
2315- training_data = "gs://bucket/training.tsv" ,
2316- corpus_data = "gs://bucket/corpus.jsonl" ,
2317- queries_data = "gs://bucket/queries.jsonl" ,
2318- test_data = "gs://bucket/test.tsv" ,
2319- tuned_model_location = "us-central1" ,
2320- train_steps = 10 ,
2321- accelerator = "NVIDIA_TESLA_A100" ,
2375+ base_model_version_id
23222376 )
2377+ tuning_job = model .tune_model (** tune_args )
23232378 call_kwargs = mock_pipeline_service_create .call_args [1 ]
2324- pipeline_arguments = call_kwargs [
2325- "pipeline_job"
2326- ].runtime_config .parameter_values
2327- assert pipeline_arguments ["iterations" ] == 10
2328- assert pipeline_arguments ["accelerator_type" ] == "NVIDIA_TESLA_A100"
2379+ pipeline_arguments = dict (
2380+ call_kwargs ["pipeline_job" ].runtime_config .parameter_values
2381+ )
2382+
2383+ if (
2384+ "model_display_name" not in tune_args
2385+ and "model_display_name" in pipeline_arguments
2386+ ):
2387+ # This is automatically generated from some params, so don't
2388+ # check it.
2389+ del pipeline_arguments ["model_display_name" ]
2390+
2391+ assert pipeline_arguments == expected_pipeline_args
23292392
23302393 # Testing the tuned model
23312394 tuned_model = tuning_job .deploy_tuned_model ()
0 commit comments