
Commit 90d95d7 · googleapis/python-aiplatform

vertex-sdk-bot authored and copybara-github committed
feat: LVM - Added support for Images from GCS uri for multimodal embeddings
PiperOrigin-RevId: 605748060
1 parent 716f3e1 · commit 90d95d7
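As a usage sketch (mirroring the new system test below; the project and location values are placeholders, and the GCS image is the sample asset the tests use):

from google.cloud import aiplatform
from vertexai.vision_models import Image, MultiModalEmbeddingModel

aiplatform.init(project="my-project", location="us-central1")

model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")

# load_from_file() now accepts a gs:// URI in addition to a local path.
image = Image.load_from_file(
    "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"
)

embeddings = model.get_embeddings(image=image, contextual_text="this is a car")
# The service returns 1408-dimensional vectors by default.
print(len(embeddings.image_embedding), len(embeddings.text_embedding))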

File tree: 3 files changed (+106, −7)

tests/system/aiplatform/test_vision_models.py

Lines changed: 22 additions & 0 deletions

@@ -38,6 +38,12 @@ def _create_blank_image(
     return vision_models.Image.load_from_file(image_path)
 
 
+def _load_image_from_gcs(
+    gcs_uri: str = "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
+) -> vision_models.Image:
+    return vision_models.Image.load_from_file(gcs_uri)
+
+
 class VisionModelTestSuite(e2e_base.TestEndToEnd):
     """System tests for vision models."""
 
@@ -85,6 +91,22 @@ def test_multi_modal_embedding_model(self):
         assert len(embeddings.image_embedding) == 1408
         assert len(embeddings.text_embedding) == 1408
 
+    def test_multi_modal_embedding_model_with_gcs_uri(self):
+        aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)
+
+        model = ga_vision_models.MultiModalEmbeddingModel.from_pretrained(
+            "multimodalembedding@001"
+        )
+        image = _load_image_from_gcs()
+        embeddings = model.get_embeddings(
+            image=image,
+            # Optional:
+            contextual_text="this is a car",
+        )
+        # The service is expected to return the embeddings of size 1408
+        assert len(embeddings.image_embedding) == 1408
+        assert len(embeddings.text_embedding) == 1408
+
     def test_image_generation_model_generate_images(self):
         """Tests the image generation model generating images."""
         model = vision_models.ImageGenerationModel.from_pretrained(

tests/unit/aiplatform/test_vision_models.py

Lines changed: 42 additions & 0 deletions

@@ -132,6 +132,12 @@ def generate_image_from_file(
     return ga_vision_models.Image.load_from_file(image_path)
 
 
+def generate_image_from_gcs_uri(
+    gcs_uri: str = "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png",
+) -> ga_vision_models.Image:
+    return ga_vision_models.Image.load_from_file(gcs_uri)
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestImageGenerationModels:
     """Unit tests for the image generation models."""
 
@@ -721,6 +727,42 @@ def test_image_embedding_model_with_lower_dimensions(self):
         assert embedding_response.image_embedding == test_embeddings
         assert embedding_response.text_embedding == test_embeddings
 
+    def test_image_embedding_model_with_gcs_uri(self):
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+        )
+        with mock.patch.object(
+            target=model_garden_service_client.ModelGardenServiceClient,
+            attribute="get_publisher_model",
+            return_value=gca_publisher_model.PublisherModel(
+                _IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
+            ),
+        ):
+            model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
+                "multimodalembedding@001"
+            )
+
+        test_embeddings = [0, 0]
+        gca_predict_response = gca_prediction_service.PredictResponse()
+        gca_predict_response.predictions.append(
+            {"imageEmbedding": test_embeddings, "textEmbedding": test_embeddings}
+        )
+
+        image = generate_image_from_gcs_uri()
+
+        with mock.patch.object(
+            target=prediction_service_client.PredictionServiceClient,
+            attribute="predict",
+            return_value=gca_predict_response,
+        ):
+            embedding_response = model.get_embeddings(
+                image=image, contextual_text="hello world"
+            )
+
+        assert embedding_response.image_embedding == test_embeddings
+        assert embedding_response.text_embedding == test_embeddings
+
 
 @pytest.mark.usefixtures("google_auth_mock")
 class ImageTextModelTests:
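Note that the unit test above mocks only the publisher-model lookup and PredictionServiceClient.predict; Cloud Storage is never mocked. With the _vision_models.py changes below, an Image built from a gs:// URI stores just the URI and defers any download until its bytes are actually read. A minimal sketch of that laziness (the bucket name is hypothetical; reading the bytes assumes default credentials are configured):

from vertexai.vision_models import Image

image = Image.load_from_file("gs://my-bucket/photo.png")  # hypothetical URI
# Nothing has been downloaded yet; only the URI is stored.
assert image._gcs_uri == "gs://my-bucket/photo.png"
# Accessing image._image_bytes would lazily fetch the blob from GCS using
# the credentials configured via aiplatform.init()/global_config.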

vertexai/vision_models/_vision_models.py

Lines changed: 42 additions & 7 deletions

@@ -23,6 +23,9 @@
 import typing
 from typing import Any, Dict, List, Optional, Union
 
+from google.cloud import storage
+
+from google.cloud.aiplatform import initializer as aiplatform_initializer
 from vertexai._model_garden import _model_garden_models
 
 # pylint: disable=g-import-not-at-top
@@ -45,31 +48,60 @@ class Image:
 
     __module__ = "vertexai.vision_models"
 
-    _image_bytes: bytes
+    _loaded_bytes: Optional[bytes] = None
     _loaded_image: Optional["PIL_Image.Image"] = None
+    _gcs_uri: Optional[str] = None
 
-    def __init__(self, image_bytes: bytes):
+    def __init__(
+        self,
+        image_bytes: Optional[bytes] = None,
+        gcs_uri: Optional[str] = None,
+    ):
         """Creates an `Image` object.
 
         Args:
             image_bytes: Image file bytes. Image can be in PNG or JPEG format.
+            gcs_uri: Image URI in Google Cloud Storage.
         """
+        if bool(image_bytes) == bool(gcs_uri):
+            raise ValueError("Either image_bytes or gcs_uri must be provided.")
+
         self._image_bytes = image_bytes
+        self._gcs_uri = gcs_uri
 
     @staticmethod
     def load_from_file(location: str) -> "Image":
-        """Loads image from file.
+        """Loads image from local file or Google Cloud Storage.
 
         Args:
-            location: Local path from where to load the image.
+            location: Local path or Google Cloud Storage uri from where to load
+                the image.
 
         Returns:
             Loaded image as an `Image` object.
         """
+        if location.startswith("gs://"):
+            return Image(gcs_uri=location)
+
         image_bytes = pathlib.Path(location).read_bytes()
         image = Image(image_bytes=image_bytes)
         return image
 
+    @property
+    def _image_bytes(self) -> bytes:
+        if self._loaded_bytes is None:
+            storage_client = storage.Client(
+                credentials=aiplatform_initializer.global_config.credentials
+            )
+            self._loaded_bytes = storage.Blob.from_string(
+                uri=self._gcs_uri, client=storage_client
+            ).download_as_bytes()
+        return self._loaded_bytes
+
+    @_image_bytes.setter
+    def _image_bytes(self, value: bytes):
+        self._loaded_bytes = value
+
     @property
     def _pil_image(self) -> "PIL_Image.Image":
         if self._loaded_image is None:
@@ -664,7 +696,7 @@ def get_embeddings(
             values: `128`, `256`, `512`, and `1408` (default).
 
         Returns:
-            ImageEmbeddingResponse:
+            MultiModalEmbeddingResponse:
                 The image and text embedding vectors.
         """
 
@@ -674,7 +706,10 @@ def get_embeddings(
         instance = {}
 
         if image:
-            instance["image"] = {"bytesBase64Encoded": image._as_base64_string()}
+            if image._gcs_uri:
+                instance["image"] = {"gcsUri": image._gcs_uri}
+            else:
+                instance["image"] = {"bytesBase64Encoded": image._as_base64_string()}
 
         if contextual_text:
             instance["text"] = contextual_text
@@ -702,7 +737,7 @@
 
 @dataclasses.dataclass
 class MultiModalEmbeddingResponse:
-    """The image embedding response.
+    """The multimodal embedding response.
 
     Attributes:
         image_embedding (List[float]):
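The net effect on the prediction request: when an Image carries a gcs_uri, get_embeddings() forwards the URI instead of inlining base64-encoded bytes, so the image never has to transit the client. A rough sketch of the two instance payloads the SDK builds (illustrative values only):

import base64

# Image constructed from a gs:// URI: the URI is passed through verbatim.
instance_from_gcs = {
    "image": {"gcsUri": "gs://cloud-samples-data/vertex-ai/llm/prompts/landmark1.png"},
    "text": "this is a car",
}

# Image constructed from local bytes: the payload inlines the encoded bytes.
raw_bytes = b"..."  # placeholder for real PNG/JPEG bytes
instance_from_bytes = {
    "image": {"bytesBase64Encoded": base64.b64encode(raw_bytes).decode("utf-8")},
    "text": "this is a car",
}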
