Skip to content

Commit db0dc85

Browse files
authored
Vision: Add batch processing (#2978)
* Add Vision batch support to the surface.
1 parent 4d2a7d1 commit db0dc85

File tree

13 files changed

+289
-46
lines changed

13 files changed

+289
-46
lines changed

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155

156156
vision-usage
157157
vision-annotations
158+
vision-batch
158159
vision-client
159160
vision-color
160161
vision-entity

docs/vision-batch.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Vision Batch
2+
============
3+
4+
Batch
5+
~~~~~
6+
7+
.. automodule:: google.cloud.vision.batch
8+
:members:
9+
:undoc-members:
10+
:show-inheritance:

docs/vision-usage.rst

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,40 @@ image and determine the dominant colors in the image.
283283
0.758658
284284
285285
286+
*********************
287+
Batch image detection
288+
*********************
289+
290+
Multiple images can be processed with a single request by passing
291+
:class:`~google.cloud.vision.image.Image` to
292+
:meth:`~google.cloud.vision.client.Client.batch()`.
293+
294+
.. code-block:: python
295+
296+
>>> from google.cloud import vision
297+
>>> from google.cloud.vision.feature import Feature
298+
>>> from google.cloud.vision.feature import FeatureTypes
299+
>>>
300+
>>> client = vision.Client()
301+
>>> batch = client.batch()
302+
>>>
303+
>>> image_one = client.image(source_uri='gs://my-test-bucket/image1.jpg')
304+
>>> image_two = client.image(source_uri='gs://my-test-bucket/image2.jpg')
305+
>>> face_feature = Feature(FeatureTypes.FACE_DETECTION, 2)
306+
>>> logo_feature = Feature(FeatureTypes.LOGO_DETECTION, 2)
307+
>>> batch.add_image(image_one, [face_feature, logo_feature])
308+
>>> batch.add_image(image_two, [logo_feature])
309+
>>> results = batch.detect()
310+
>>> for image in results:
311+
... for face in image.faces:
312+
... print('=' * 40)
313+
... print(face.joy)
314+
========================================
315+
<Likelihood.VERY_LIKELY: 'VERY_LIKELY'>
316+
========================================
317+
<Likelihood.VERY_LIKELY: 'POSSIBLE'>
318+
319+
286320
****************
287321
No results found
288322
****************

system_tests/vision.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from google.cloud import storage
2222
from google.cloud import vision
2323
from google.cloud.vision.entity import EntityAnnotation
24+
from google.cloud.vision.feature import Feature
25+
from google.cloud.vision.feature import FeatureTypes
2426

2527
from system_test_utils import unique_resource_id
2628
from retry import RetryErrors
@@ -507,3 +509,53 @@ def test_detect_properties_filename(self):
507509
image = client.image(filename=FACE_FILE)
508510
properties = image.detect_properties()
509511
self._assert_properties(properties)
512+
513+
514+
class TestVisionBatchProcessing(BaseVisionTestCase):
515+
def setUp(self):
516+
self.to_delete_by_case = []
517+
518+
def tearDown(self):
519+
for value in self.to_delete_by_case:
520+
value.delete()
521+
522+
def test_batch_detect_gcs(self):
523+
client = Config.CLIENT
524+
bucket_name = Config.TEST_BUCKET.name
525+
526+
# Logo GCS image.
527+
blob_name = 'logos.jpg'
528+
blob = Config.TEST_BUCKET.blob(blob_name)
529+
self.to_delete_by_case.append(blob) # Clean-up.
530+
with open(LOGO_FILE, 'rb') as file_obj:
531+
blob.upload_from_file(file_obj)
532+
533+
logo_source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
534+
535+
image_one = client.image(source_uri=logo_source_uri)
536+
logo_feature = Feature(FeatureTypes.LOGO_DETECTION, 2)
537+
538+
# Faces GCS image.
539+
blob_name = 'faces.jpg'
540+
blob = Config.TEST_BUCKET.blob(blob_name)
541+
self.to_delete_by_case.append(blob) # Clean-up.
542+
with open(FACE_FILE, 'rb') as file_obj:
543+
blob.upload_from_file(file_obj)
544+
545+
face_source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
546+
547+
image_two = client.image(source_uri=face_source_uri)
548+
face_feature = Feature(FeatureTypes.FACE_DETECTION, 2)
549+
550+
batch = client.batch()
551+
batch.add_image(image_one, [logo_feature])
552+
batch.add_image(image_two, [face_feature, logo_feature])
553+
results = batch.detect()
554+
self.assertEqual(len(results), 2)
555+
self.assertIsInstance(results[0], vision.annotations.Annotations)
556+
self.assertIsInstance(results[1], vision.annotations.Annotations)
557+
self.assertEqual(len(results[0].logos), 1)
558+
self.assertEqual(len(results[0].faces), 0)
559+
560+
self.assertEqual(len(results[1].logos), 0)
561+
self.assertEqual(len(results[1].faces), 2)

vision/google/cloud/vision/_gax.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,28 @@ def __init__(self, client=None):
3030
self._client = client
3131
self._annotator_client = image_annotator_client.ImageAnnotatorClient()
3232

33-
def annotate(self, image, features):
33+
def annotate(self, images):
3434
"""Annotate images through GAX.
3535
36-
:type image: :class:`~google.cloud.vision.image.Image`
37-
:param image: Instance of ``Image``.
38-
39-
:type features: list
40-
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
36+
:type images: list
37+
:param images: List containing pairs of
38+
:class:`~google.cloud.vision.image.Image` and
39+
:class:`~google.cloud.vision.feature.Feature`.
40+
e.g. [(image, [feature_one, feature_two]),]
4141
4242
:rtype: list
4343
:returns: List of
4444
:class:`~google.cloud.vision.annotations.Annotations`.
4545
"""
46-
gapic_features = [_to_gapic_feature(feature) for feature in features]
47-
gapic_image = _to_gapic_image(image)
48-
request = image_annotator_pb2.AnnotateImageRequest(
49-
image=gapic_image, features=gapic_features)
50-
requests = [request]
46+
requests = []
47+
for image, features in images:
48+
gapic_features = [_to_gapic_feature(feature)
49+
for feature in features]
50+
gapic_image = _to_gapic_image(image)
51+
request = image_annotator_pb2.AnnotateImageRequest(
52+
image=gapic_image, features=gapic_features)
53+
requests.append(request)
54+
5155
annotator_client = self._annotator_client
5256
responses = annotator_client.batch_annotate_images(requests).responses
5357
return [Annotations.from_pb(response) for response in responses]

vision/google/cloud/vision/_http.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,19 @@ def __init__(self, client):
2929
self._client = client
3030
self._connection = client._connection
3131

32-
def annotate(self, image, features):
32+
def annotate(self, images):
3333
"""Annotate an image to discover it's attributes.
3434
35-
:type image: :class:`~google.cloud.vision.image.Image`
36-
:param image: A instance of ``Image``.
35+
:type images: list of :class:`~google.cloud.vision.image.Image`
36+
:param images: A list of ``Image``.
3737
38-
:type features: list of :class:`~google.cloud.vision.feature.Feature`
39-
:param features: The type of detection that the Vision API should
40-
use to determine image attributes. Pricing is
41-
based on the number of Feature Types.
42-
43-
See: https://cloud.google.com/vision/docs/pricing
4438
:rtype: list
4539
:returns: List of :class:`~googe.cloud.vision.annotations.Annotations`.
4640
"""
47-
request = _make_request(image, features)
48-
49-
data = {'requests': [request]}
41+
requests = []
42+
for image, features in images:
43+
requests.append(_make_request(image, features))
44+
data = {'requests': requests}
5045
api_response = self._connection.api_request(
5146
method='POST', path='/images:annotate', data=data)
5247
responses = api_response.get('responses')

vision/google/cloud/vision/batch.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright 2017 Google Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Batch multiple images into one request."""
16+
17+
18+
class Batch(object):
19+
"""Batch of images to process.
20+
21+
:type client: :class:`~google.cloud.vision.client.Client`
22+
:param client: Vision client.
23+
"""
24+
def __init__(self, client):
25+
self._client = client
26+
self._images = []
27+
28+
def add_image(self, image, features):
29+
"""Add image to batch request.
30+
31+
:type image: :class:`~google.cloud.vision.image.Image`
32+
:param image: Istance of ``Image``.
33+
34+
:type features: list
35+
:param features: List of :class:`~google.cloud.vision.feature.Feature`.
36+
"""
37+
self._images.append((image, features))
38+
39+
@property
40+
def images(self):
41+
"""List of images to process.
42+
43+
:rtype: list
44+
:returns: List of :class:`~google.cloud.vision.image.Image`.
45+
"""
46+
return self._images
47+
48+
def detect(self):
49+
"""Perform batch detection of images.
50+
51+
:rtype: list
52+
:returns: List of
53+
:class:`~google.cloud.vision.annotations.Annotations`.
54+
"""
55+
results = self._client._vision_api.annotate(self.images)
56+
self._images = []
57+
return results

vision/google/cloud/vision/client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from google.cloud.environment_vars import DISABLE_GRPC
2121

2222
from google.cloud.vision._gax import _GAPICVisionAPI
23+
from google.cloud.vision.batch import Batch
2324
from google.cloud.vision.connection import Connection
2425
from google.cloud.vision.image import Image
2526
from google.cloud.vision._http import _HTTPVisionAPI
@@ -71,6 +72,14 @@ def __init__(self, project=None, credentials=None, http=None,
7172
else:
7273
self._use_gax = use_gax
7374

75+
def batch(self):
76+
"""Batch multiple images into a single API request.
77+
78+
:rtype: :class:`google.cloud.vision.batch.Batch`
79+
:returns: Instance of ``Batch``.
80+
"""
81+
return Batch(self)
82+
7483
def image(self, content=None, filename=None, source_uri=None):
7584
"""Get instance of Image using current client.
7685

vision/google/cloud/vision/image.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -94,21 +94,17 @@ def source(self):
9494
"""
9595
return self._source
9696

97-
def _detect_annotation(self, features):
97+
def _detect_annotation(self, images):
9898
"""Generic method for detecting annotations.
9999
100-
:type features: list
101-
:param features: List of :class:`~google.cloud.vision.feature.Feature`
102-
indicating the type of annotations to perform.
100+
:type images: list
101+
:param images: List of :class:`~google.cloud.vision.image.Image`.
103102
104103
:rtype: list
105104
:returns: List of
106-
:class:`~google.cloud.vision.entity.EntityAnnotation`,
107-
:class:`~google.cloud.vision.face.Face`,
108-
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
109-
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`,
105+
:class:`~google.cloud.vision.annotations.Annotations`.
110106
"""
111-
return self.client._vision_api.annotate(self, features)
107+
return self.client._vision_api.annotate(images)
112108

113109
def detect(self, features):
114110
"""Detect multiple feature types.
@@ -121,7 +117,8 @@ def detect(self, features):
121117
:returns: List of
122118
:class:`~google.cloud.vision.entity.EntityAnnotation`.
123119
"""
124-
return self._detect_annotation(features)
120+
images = ((self, features),)
121+
return self._detect_annotation(images)
125122

126123
def detect_faces(self, limit=10):
127124
"""Detect faces in image.
@@ -133,7 +130,7 @@ def detect_faces(self, limit=10):
133130
:returns: List of :class:`~google.cloud.vision.face.Face`.
134131
"""
135132
features = [Feature(FeatureTypes.FACE_DETECTION, limit)]
136-
annotations = self._detect_annotation(features)
133+
annotations = self.detect(features)
137134
return annotations[0].faces
138135

139136
def detect_labels(self, limit=10):
@@ -146,7 +143,7 @@ def detect_labels(self, limit=10):
146143
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
147144
"""
148145
features = [Feature(FeatureTypes.LABEL_DETECTION, limit)]
149-
annotations = self._detect_annotation(features)
146+
annotations = self.detect(features)
150147
return annotations[0].labels
151148

152149
def detect_landmarks(self, limit=10):
@@ -160,7 +157,7 @@ def detect_landmarks(self, limit=10):
160157
:class:`~google.cloud.vision.entity.EntityAnnotation`.
161158
"""
162159
features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)]
163-
annotations = self._detect_annotation(features)
160+
annotations = self.detect(features)
164161
return annotations[0].landmarks
165162

166163
def detect_logos(self, limit=10):
@@ -174,7 +171,7 @@ def detect_logos(self, limit=10):
174171
:class:`~google.cloud.vision.entity.EntityAnnotation`.
175172
"""
176173
features = [Feature(FeatureTypes.LOGO_DETECTION, limit)]
177-
annotations = self._detect_annotation(features)
174+
annotations = self.detect(features)
178175
return annotations[0].logos
179176

180177
def detect_properties(self, limit=10):
@@ -188,7 +185,7 @@ def detect_properties(self, limit=10):
188185
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.
189186
"""
190187
features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)]
191-
annotations = self._detect_annotation(features)
188+
annotations = self.detect(features)
192189
return annotations[0].properties
193190

194191
def detect_safe_search(self, limit=10):
@@ -202,7 +199,7 @@ def detect_safe_search(self, limit=10):
202199
:class:`~google.cloud.vision.sage.SafeSearchAnnotation`.
203200
"""
204201
features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)]
205-
annotations = self._detect_annotation(features)
202+
annotations = self.detect(features)
206203
return annotations[0].safe_searches
207204

208205
def detect_text(self, limit=10):
@@ -216,5 +213,5 @@ def detect_text(self, limit=10):
216213
:class:`~google.cloud.vision.entity.EntityAnnotation`.
217214
"""
218215
features = [Feature(FeatureTypes.TEXT_DETECTION, limit)]
219-
annotations = self._detect_annotation(features)
216+
annotations = self.detect(features)
220217
return annotations[0].texts

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy