Skip to content

Commit b3729c1

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: LVM - Added support for Image Generation models
Features: * Generate images from text prompt (prompt, negative prompt, width, height, seed, guidance scale) * Edit an existing images using a text prompt (and optional mask) * Upscale image * Show image (works in notebook environments) * Save image (saved image also includes the image generation parameters) Example usage: ```python model = ImageGenerationModel.from_pretrained("imagegeneration@002") images = model.generate_images( prompt="Astronaut riding a horse", # Optional: number_of_images=1, width=1024, height=768, seed=1, guidance_scale=15, ) images[0].show() images[0].save("image1.png") ``` PiperOrigin-RevId: 557736987
1 parent caee592 commit b3729c1

File tree

4 files changed

+793
-1
lines changed

4 files changed

+793
-1
lines changed

tests/system/aiplatform/test_vision_models.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,77 @@ def test_multi_modal_embedding_model(self):
8484
# The service is expected to return the embeddings of size 1408
8585
assert len(embeddings.image_embedding) == 1408
8686
assert len(embeddings.text_embedding) == 1408
87+
88+
def test_image_generation_model_generate_images(self):
89+
"""Tests the image generation model generating images."""
90+
model = vision_models.ImageGenerationModel.from_pretrained(
91+
"imagegeneration@001"
92+
)
93+
94+
width = 1024
95+
height = 768
96+
number_of_images = 4
97+
seed = 1
98+
guidance_scale = 15
99+
100+
prompt1 = "Astronaut riding a horse"
101+
negative_prompt1 = "bad quality"
102+
image_response = model.generate_images(
103+
prompt=prompt1,
104+
# Optional:
105+
negative_prompt=negative_prompt1,
106+
number_of_images=number_of_images,
107+
width=width,
108+
height=height,
109+
seed=seed,
110+
guidance_scale=guidance_scale,
111+
)
112+
113+
assert len(image_response.images) == number_of_images
114+
for idx, image in enumerate(image_response):
115+
assert image._pil_image.size == (width, height)
116+
assert image.generation_parameters
117+
assert image.generation_parameters["prompt"] == prompt1
118+
assert image.generation_parameters["negative_prompt"] == negative_prompt1
119+
assert image.generation_parameters["width"] == width
120+
assert image.generation_parameters["height"] == height
121+
assert image.generation_parameters["seed"] == seed
122+
assert image.generation_parameters["guidance_scale"] == guidance_scale
123+
assert image.generation_parameters["index_of_image_in_batch"] == idx
124+
125+
# Test saving and loading images
126+
with tempfile.TemporaryDirectory() as temp_dir:
127+
image_path = os.path.join(temp_dir, "image.png")
128+
image_response[0].save(location=image_path)
129+
image1 = vision_models.GeneratedImage.load_from_file(image_path)
130+
assert image1._pil_image.size == (width, height)
131+
assert image1.generation_parameters
132+
assert image1.generation_parameters["prompt"] == prompt1
133+
134+
# Preparing mask
135+
mask_path = os.path.join(temp_dir, "mask.png")
136+
mask_pil_image = PIL_Image.new(mode="RGB", size=(width, height))
137+
mask_pil_image.save(mask_path, format="PNG")
138+
mask_image = vision_models.Image.load_from_file(mask_path)
139+
140+
# Test generating image from base image
141+
prompt2 = "Ancient book style"
142+
image_response2 = model.edit_image(
143+
prompt=prompt2,
144+
# Optional:
145+
number_of_images=number_of_images,
146+
seed=seed,
147+
guidance_scale=guidance_scale,
148+
base_image=image1,
149+
mask=mask_image,
150+
)
151+
assert len(image_response2.images) == number_of_images
152+
for idx, image in enumerate(image_response2):
153+
assert image._pil_image.size == (width, height)
154+
assert image.generation_parameters
155+
assert image.generation_parameters["prompt"] == prompt2
156+
assert image.generation_parameters["seed"] == seed
157+
assert image.generation_parameters["guidance_scale"] == guidance_scale
158+
assert image.generation_parameters["index_of_image_in_batch"] == idx
159+
assert "base_image_hash" in image.generation_parameters
160+
assert "mask_hash" in image.generation_parameters

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy