|
18 | 18 |
|
19 | 19 | import pkg_resources
|
20 | 20 | import logging
|
| 21 | +import six |
21 | 22 |
|
22 | 23 | from google.api_core.gapic_v1 import client_info
|
23 | 24 | from google.api_core import exceptions
|
24 | 25 | from google.cloud.automl_v1beta1 import gapic
|
25 |
| -from google.cloud.automl_v1beta1.proto import data_types_pb2, data_items_pb2 |
| 26 | +from google.cloud.automl_v1beta1.proto import data_items_pb2 |
26 | 27 | from google.cloud.automl_v1beta1.tables import gcs_client
|
27 | 28 | from google.protobuf import struct_pb2
|
28 | 29 |
|
|
31 | 32 | _LOGGER = logging.getLogger(__name__)
|
32 | 33 |
|
33 | 34 |
|
| 35 | +def to_proto_value(value): |
| 36 | + """translates a Python value to a google.protobuf.Value. |
| 37 | +
|
| 38 | + Args: |
| 39 | + value: The Python value to be translated. |
| 40 | +
|
| 41 | + Returns: |
| 42 | + Tuple of the translated google.protobuf.Value and error if any. |
| 43 | + """ |
| 44 | + # possible Python types (this is a Python3 module): |
| 45 | + # https://simplejson.readthedocs.io/en/latest/#encoders-and-decoders |
| 46 | + # JSON Python 2 Python 3 |
| 47 | + # object dict dict |
| 48 | + # array list list |
| 49 | + # string unicode str |
| 50 | + # number (int) int, long int |
| 51 | + # number (real) float float |
| 52 | + # true True True |
| 53 | + # false False False |
| 54 | + # null None None |
| 55 | + if value is None: |
| 56 | + # translate null to an empty value. |
| 57 | + return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), None |
| 58 | + elif isinstance(value, bool): |
| 59 | + # This check needs to happen before isinstance(value, int), |
| 60 | + # isinstance(value, int) returns True when value is bool. |
| 61 | + return struct_pb2.Value(bool_value=value), None |
| 62 | + elif isinstance(value, six.integer_types) or isinstance(value, float): |
| 63 | + return struct_pb2.Value(number_value=value), None |
| 64 | + elif isinstance(value, six.string_types) or isinstance(value, six.text_type): |
| 65 | + return struct_pb2.Value(string_value=value), None |
| 66 | + elif isinstance(value, dict): |
| 67 | + struct_value = struct_pb2.Struct() |
| 68 | + for key, v in value.items(): |
| 69 | + field_value, err = to_proto_value(v) |
| 70 | + if err is not None: |
| 71 | + return None, err |
| 72 | + |
| 73 | + struct_value.fields[key].CopyFrom(field_value) |
| 74 | + return struct_pb2.Value(struct_value=struct_value), None |
| 75 | + elif isinstance(value, list): |
| 76 | + list_value = [] |
| 77 | + for v in value: |
| 78 | + proto_value, err = to_proto_value(v) |
| 79 | + if err is not None: |
| 80 | + return None, err |
| 81 | + list_value.append(proto_value) |
| 82 | + return ( |
| 83 | + struct_pb2.Value(list_value=struct_pb2.ListValue(values=list_value)), |
| 84 | + None, |
| 85 | + ) |
| 86 | + else: |
| 87 | + return None, "unsupport data type: {}".format(type(value)) |
| 88 | + |
| 89 | + |
34 | 90 | class TablesClient(object):
|
35 | 91 | """
|
36 | 92 | AutoML Tables API helper.
|
@@ -404,42 +460,6 @@ def __column_spec_name_from_args(
|
404 | 460 |
|
405 | 461 | return column_spec_name
|
406 | 462 |
|
407 |
| - def __data_type_to_proto_value(self, data_type, value): |
408 |
| - type_code = data_type.type_code |
409 |
| - if value is None: |
410 |
| - return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) |
411 |
| - elif type_code == data_types_pb2.FLOAT64: |
412 |
| - return struct_pb2.Value(number_value=value) |
413 |
| - elif ( |
414 |
| - type_code == data_types_pb2.TIMESTAMP |
415 |
| - or type_code == data_types_pb2.STRING |
416 |
| - or type_code == data_types_pb2.CATEGORY |
417 |
| - ): |
418 |
| - return struct_pb2.Value(string_value=value) |
419 |
| - elif type_code == data_types_pb2.ARRAY: |
420 |
| - if isinstance(value, struct_pb2.ListValue): |
421 |
| - # in case the user passed in a ListValue. |
422 |
| - return struct_pb2.Value(list_value=value) |
423 |
| - array = [] |
424 |
| - for item in value: |
425 |
| - array.append( |
426 |
| - self.__data_type_to_proto_value(data_type.list_element_type, item) |
427 |
| - ) |
428 |
| - return struct_pb2.Value(list_value=struct_pb2.ListValue(values=array)) |
429 |
| - elif type_code == data_types_pb2.STRUCT: |
430 |
| - if isinstance(value, struct_pb2.Struct): |
431 |
| - # in case the user passed in a Struct. |
432 |
| - return struct_pb2.Value(struct_value=value) |
433 |
| - struct_value = struct_pb2.Struct() |
434 |
| - for k, v in value.items(): |
435 |
| - field_value = self.__data_type_to_proto_value( |
436 |
| - data_type.struct_type.fields[k], v |
437 |
| - ) |
438 |
| - struct_value.fields[k].CopyFrom(field_value) |
439 |
| - return struct_pb2.Value(struct_value=struct_value) |
440 |
| - else: |
441 |
| - raise ValueError("Unknown type_code: {}".format(type_code)) |
442 |
| - |
443 | 463 | def __ensure_gcs_client_is_initialized(self, credentials, project):
|
444 | 464 | """Checks if GCS client is initialized. Initializes it if not.
|
445 | 465 |
|
@@ -2714,7 +2734,9 @@ def predict(
|
2714 | 2734 |
|
2715 | 2735 | values = []
|
2716 | 2736 | for i, c in zip(inputs, column_specs):
|
2717 |
| - value_type = self.__data_type_to_proto_value(c.data_type, i) |
| 2737 | + value_type, err = to_proto_value(i) |
| 2738 | + if err is not None: |
| 2739 | + raise ValueError(err) |
2718 | 2740 | values.append(value_type)
|
2719 | 2741 |
|
2720 | 2742 | row = data_items_pb2.Row(values=values)
|
|
0 commit comments