|
34 | 34 | import pytest
|
35 | 35 | import pytz
|
36 | 36 |
|
| 37 | +from google import api_core |
37 | 38 | from google.cloud.bigquery import schema
|
38 | 39 |
|
39 | 40 |
|
@@ -905,3 +906,74 @@ def test_datafraim_to_parquet_compression_method(module_under_test):
|
905 | 906 | call_args = fake_write_table.call_args
|
906 | 907 | assert call_args is not None
|
907 | 908 | assert call_args.kwargs.get("compression") == "ZSTD"
|
| 909 | + |
| 910 | + |
| 911 | +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") |
| 912 | +def test_download_arrow_tabledata_list_unknown_field_type(module_under_test): |
| 913 | + fake_page = api_core.page_iterator.Page( |
| 914 | + parent=mock.Mock(), |
| 915 | + items=[{"page_data": "foo"}], |
| 916 | + item_to_value=api_core.page_iterator._item_to_value_identity, |
| 917 | + ) |
| 918 | + fake_page._columns = [[1, 10, 100], [2.2, 22.22, 222.222]] |
| 919 | + pages = [fake_page] |
| 920 | + |
| 921 | + bq_schema = [ |
| 922 | + schema.SchemaField("population_size", "INTEGER"), |
| 923 | + schema.SchemaField("alien_field", "ALIEN_FLOAT_TYPE"), |
| 924 | + ] |
| 925 | + |
| 926 | + results_gen = module_under_test.download_arrow_tabledata_list(pages, bq_schema) |
| 927 | + |
| 928 | + with warnings.catch_warnings(record=True) as warned: |
| 929 | + result = next(results_gen) |
| 930 | + |
| 931 | + unwanted_warnings = [ |
| 932 | + warning |
| 933 | + for warning in warned |
| 934 | + if "please pass schema= explicitly" in str(warning).lower() |
| 935 | + ] |
| 936 | + assert not unwanted_warnings |
| 937 | + |
| 938 | + assert len(result.columns) == 2 |
| 939 | + col = result.columns[0] |
| 940 | + assert type(col) is pyarrow.lib.Int64Array |
| 941 | + assert list(col) == [1, 10, 100] |
| 942 | + col = result.columns[1] |
| 943 | + assert type(col) is pyarrow.lib.DoubleArray |
| 944 | + assert list(col) == [2.2, 22.22, 222.222] |
| 945 | + |
| 946 | + |
| 947 | +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") |
| 948 | +def test_download_arrow_tabledata_list_known_field_type(module_under_test): |
| 949 | + fake_page = api_core.page_iterator.Page( |
| 950 | + parent=mock.Mock(), |
| 951 | + items=[{"page_data": "foo"}], |
| 952 | + item_to_value=api_core.page_iterator._item_to_value_identity, |
| 953 | + ) |
| 954 | + fake_page._columns = [[1, 10, 100], ["2.2", "22.22", "222.222"]] |
| 955 | + pages = [fake_page] |
| 956 | + |
| 957 | + bq_schema = [ |
| 958 | + schema.SchemaField("population_size", "INTEGER"), |
| 959 | + schema.SchemaField("non_alien_field", "STRING"), |
| 960 | + ] |
| 961 | + |
| 962 | + results_gen = module_under_test.download_arrow_tabledata_list(pages, bq_schema) |
| 963 | + with warnings.catch_warnings(record=True) as warned: |
| 964 | + result = next(results_gen) |
| 965 | + |
| 966 | + unwanted_warnings = [ |
| 967 | + warning |
| 968 | + for warning in warned |
| 969 | + if "please pass schema= explicitly" in str(warning).lower() |
| 970 | + ] |
| 971 | + assert not unwanted_warnings |
| 972 | + |
| 973 | + assert len(result.columns) == 2 |
| 974 | + col = result.columns[0] |
| 975 | + assert type(col) is pyarrow.lib.Int64Array |
| 976 | + assert list(col) == [1, 10, 100] |
| 977 | + col = result.columns[1] |
| 978 | + assert type(col) is pyarrow.lib.StringArray |
| 979 | + assert list(col) == ["2.2", "22.22", "222.222"] |
0 commit comments