@@ -114,6 +114,18 @@ def make_image_generation_response(
     return {"predictions": predictions}
 
 
+def make_image_generation_response_gcs(count: int = 1) -> Dict[str, Any]:
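+    """Builds a fake predict response whose predictions reference GCS image URIs."""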
+    predictions = []
+    for _ in range(count):
+        predictions.append(
+            {
+                "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
+                "mimeType": "image/png",
+            }
+        )
+    return {"predictions": predictions}
+
+
 def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
     predictions = {
         "bytesBase64Encoded": make_image_base64(upscale_size, upscale_size),
@@ -122,6 +134,14 @@ def make_image_upscale_response(upscale_size: int) -> Dict[str, Any]:
     return {"predictions": [predictions]}
 
 
+def make_image_upscale_response_gcs() -> Dict[str, Any]:
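+    """Builds a fake upscale predict response that references a GCS image URI."""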
+    predictions = {
+        "gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
+        "mimeType": "image/png",
+    }
+    return {"predictions": [predictions]}
+
+
 def generate_image_from_file(
     width: int = 100, height: int = 100
 ) -> ga_vision_models.Image:
@@ -332,6 +352,111 @@ def test_generate_images(self):
                 assert image.generation_parameters["mask_hash"]
                 assert image.generation_parameters["language"] == language
 
+    def test_generate_images_gcs(self):
+        """Tests the image generation model with GCS output."""
+        model = self._get_image_generation_model()
+
+        # TODO(b/295946075) The service stopped supporting image sizes.
+        # height = 768
+        number_of_images = 4
+        seed = 1
+        guidance_scale = 15
+        language = "en"
+        output_gcs_uri = "gs://test-bucket/"
+
+        image_generation_response = make_image_generation_response_gcs(
+            count=number_of_images
+        )
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.extend(
+            image_generation_response["predictions"]
+        )
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ) as mock_predict:
+            prompt1 = "Astronaut riding a horse"
+            negative_prompt1 = "bad quality"
+            image_response = model.generate_images(
+                prompt=prompt1,
+                # Optional:
+                negative_prompt=negative_prompt1,
+                number_of_images=number_of_images,
+                # TODO(b/295946075) The service stopped supporting image sizes.
+                # width=width,
+                # height=height,
+                seed=seed,
+                guidance_scale=guidance_scale,
+                language=language,
+                output_gcs_uri=output_gcs_uri,
+            )
+            predict_kwargs = mock_predict.call_args[1]
+            actual_parameters = predict_kwargs["parameters"]
+            actual_instance = predict_kwargs["instances"][0]
+            assert actual_instance["prompt"] == prompt1
+            assert actual_parameters["negativePrompt"] == negative_prompt1
+            # TODO(b/295946075) The service stopped supporting image sizes.
+            # assert actual_parameters["sampleImageSize"] == str(max(width, height))
+            # assert actual_parameters["aspectRatio"] == f"{width}:{height}"
+            assert actual_parameters["seed"] == seed
+            assert actual_parameters["guidanceScale"] == guidance_scale
+            assert actual_parameters["language"] == language
+            assert actual_parameters["storageUri"] == output_gcs_uri
+
+            assert len(image_response.images) == number_of_images
+            for idx, image in enumerate(image_response):
+                assert image.generation_parameters
+                assert image.generation_parameters["prompt"] == prompt1
+                assert image.generation_parameters["negative_prompt"] == negative_prompt1
+                # TODO(b/295946075) The service stopped supporting image sizes.
+                # assert image.generation_parameters["width"] == width
+                # assert image.generation_parameters["height"] == height
+                assert image.generation_parameters["seed"] == seed
+                assert image.generation_parameters["guidance_scale"] == guidance_scale
+                assert image.generation_parameters["language"] == language
+                assert image.generation_parameters["index_of_image_in_batch"] == idx
+                assert image.generation_parameters["storage_uri"] == output_gcs_uri
+
+        image1 = generate_image_from_gcs_uri()
+        mask_image = generate_image_from_gcs_uri()
+
+        # Test editing an image using a base image and mask stored in GCS
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ) as mock_predict:
+            prompt2 = "Ancient book style"
+            image_response2 = model.edit_image(
+                prompt=prompt2,
+                # Optional:
+                number_of_images=number_of_images,
+                seed=seed,
+                guidance_scale=guidance_scale,
+                base_image=image1,
+                mask=mask_image,
+                language=language,
+                output_gcs_uri=output_gcs_uri,
+            )
+            predict_kwargs = mock_predict.call_args[1]
+            actual_parameters = predict_kwargs["parameters"]
+            actual_instance = predict_kwargs["instances"][0]
+            assert actual_instance["prompt"] == prompt2
+            assert actual_instance["image"]["gcsUri"]
+            assert actual_instance["mask"]["image"]["gcsUri"]
+            assert actual_parameters["language"] == language
+
+            assert len(image_response2.images) == number_of_images
+            for image in image_response2:
+                assert image.generation_parameters
+                assert image.generation_parameters["prompt"] == prompt2
+                assert image.generation_parameters["base_image_uri"]
+                assert image.generation_parameters["mask_uri"]
+                assert image.generation_parameters["language"] == language
+                assert image.generation_parameters["storage_uri"] == output_gcs_uri
+
     @unittest.skip(reason="b/295946075 The service stopped supporting image sizes.")
     def test_generate_images_requests_square_images_by_default(self):
         """Tests that the model class generates square image by default."""