Commit 5e7d5ed

test: add unit test covering the case where worker streams are stopped early (googleapis#2127)
* test: add unit test covering the case where worker streams are stopped early
* use older pyarrow.record_batch constructor
* remove flaky log-based tests from snippets
* add a gc.collect() call to make sure threads are supposed to be cleaned up
1 parent 54c8d07 commit 5e7d5ed

File tree

4 files changed: +146, -48 lines

google/cloud/bigquery/_pandas_helpers.py

Lines changed: 50 additions & 19 deletions
@@ -20,6 +20,7 @@
 from itertools import islice
 import logging
 import queue
+import threading
 import warnings
 from typing import Any, Union, Optional, Callable, Generator, List
 

@@ -119,6 +120,21 @@ def __init__(self):
         # be an atomic operation in the Python language definition (enforced by
         # the global interpreter lock).
         self.done = False
+        # To assist with testing and understanding the behavior of the
+        # download, use this object as shared state to track how many worker
+        # threads have started and have gracefully shutdown.
+        self._started_workers_lock = threading.Lock()
+        self.started_workers = 0
+        self._finished_workers_lock = threading.Lock()
+        self.finished_workers = 0
+
+    def start(self):
+        with self._started_workers_lock:
+            self.started_workers += 1
+
+    def finish(self):
+        with self._finished_workers_lock:
+            self.finished_workers += 1
 
 
 BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = {
@@ -786,25 +802,35 @@ def _bqstorage_page_to_dataframe(column_names, dtypes, page):
 def _download_table_bqstorage_stream(
     download_state, bqstorage_client, session, stream, worker_queue, page_to_item
 ):
-    reader = bqstorage_client.read_rows(stream.name)
+    download_state.start()
+    try:
+        reader = bqstorage_client.read_rows(stream.name)
 
-    # Avoid deprecation warnings for passing in unnecessary read session.
-    # https://github.com/googleapis/python-bigquery-storage/issues/229
-    if _versions_helpers.BQ_STORAGE_VERSIONS.is_read_session_optional:
-        rowstream = reader.rows()
-    else:
-        rowstream = reader.rows(session)
-
-    for page in rowstream.pages:
-        item = page_to_item(page)
-        while True:
-            if download_state.done:
-                return
-            try:
-                worker_queue.put(item, timeout=_PROGRESS_INTERVAL)
-                break
-            except queue.Full:  # pragma: NO COVER
-                continue
+        # Avoid deprecation warnings for passing in unnecessary read session.
+        # https://github.com/googleapis/python-bigquery-storage/issues/229
+        if _versions_helpers.BQ_STORAGE_VERSIONS.is_read_session_optional:
+            rowstream = reader.rows()
+        else:
+            rowstream = reader.rows(session)
+
+        for page in rowstream.pages:
+            item = page_to_item(page)
+
+            # Make sure we set a timeout on put() so that we give the worker
+            # thread opportunities to shutdown gracefully, for example if the
+            # parent thread shuts down or the parent generator object which
+            # collects rows from all workers goes out of scope. See:
+            # https://github.com/googleapis/python-bigquery/issues/2032
+            while True:
+                if download_state.done:
+                    return
+                try:
+                    worker_queue.put(item, timeout=_PROGRESS_INTERVAL)
+                    break
+                except queue.Full:
+                    continue
+    finally:
+        download_state.finish()
 
 
 def _nowait(futures):
@@ -830,6 +856,7 @@ def _download_table_bqstorage(
     page_to_item: Optional[Callable] = None,
     max_queue_size: Any = _MAX_QUEUE_SIZE_DEFAULT,
     max_stream_count: Optional[int] = None,
+    download_state: Optional[_DownloadState] = None,
 ) -> Generator[Any, None, None]:
     """Downloads a BigQuery table using the BigQuery Storage API.
 
@@ -857,6 +884,9 @@
             is True, the requested streams are limited to 1 regardless of the
             `max_stream_count` value. If 0 or None, then the number of
             requested streams will be unbounded. Defaults to None.
+        download_state (Optional[_DownloadState]):
+            A threadsafe state object which can be used to observe the
+            behavior of the worker threads created by this method.
 
     Yields:
         pandas.DataFrame: Pandas DataFrames, one for each chunk of data
@@ -915,7 +945,8 @@
 
     # Use _DownloadState to notify worker threads when to quit.
     # See: https://stackoverflow.com/a/29237343/101923
-    download_state = _DownloadState()
+    if download_state is None:
+        download_state = _DownloadState()
 
     # Create a queue to collect frames as they are created in each thread.
     #
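The hunks above thread a shared `_DownloadState` through the worker threads: `start()` and `finish()` bump lock-guarded counters, and the timed `put()` loop lets each worker notice `done` and exit instead of blocking forever on a full queue. A minimal, self-contained sketch of that producer-shutdown pattern follows; `DownloadState`, `worker`, and the queue names here are illustrative stand-ins, not the library's API.

import queue
import threading


class DownloadState:
    """Shared done flag plus lock-guarded counters, mirroring the diff above."""

    def __init__(self):
        self.done = False
        self._lock = threading.Lock()
        self.started_workers = 0
        self.finished_workers = 0

    def start(self):
        with self._lock:
            self.started_workers += 1

    def finish(self):
        with self._lock:
            self.finished_workers += 1


def worker(state, items, out_queue, timeout=0.1):
    state.start()
    try:
        for item in items:
            # A timed put() gives the thread a chance to notice that the
            # consumer stopped, instead of blocking forever on a full queue.
            while True:
                if state.done:
                    return
                try:
                    out_queue.put(item, timeout=timeout)
                    break
                except queue.Full:
                    continue
    finally:
        state.finish()


if __name__ == "__main__":
    state = DownloadState()
    frames = queue.Queue(maxsize=1)
    threads = [
        threading.Thread(target=worker, args=(state, range(100), frames))
        for _ in range(3)
    ]
    for thread in threads:
        thread.start()
    frames.get()        # consume a single item, then stop early
    state.done = True   # signal the workers to exit their put() loops
    for thread in threads:
        thread.join()
    assert state.started_workers == 3
    assert state.finished_workers == 3

The `try`/`finally` around the worker body is what makes the counters trustworthy: a worker that returns early because `done` was set still records a graceful finish.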

samples/tests/test_download_public_data.py

Lines changed: 1 addition & 14 deletions
@@ -12,29 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
-
 import pytest
 
 from .. import download_public_data
 
 pytest.importorskip("google.cloud.bigquery_storage_v1")
 
 
-def test_download_public_data(
-    caplog: pytest.LogCaptureFixture, capsys: pytest.CaptureFixture[str]
-) -> None:
-    # Enable debug-level logging to verify the BigQuery Storage API is used.
-    caplog.set_level(logging.DEBUG)
-
+def test_download_public_data(capsys: pytest.CaptureFixture[str]) -> None:
     download_public_data.download_public_data()
     out, _ = capsys.readouterr()
     assert "year" in out
     assert "gender" in out
     assert "name" in out
-
-    assert any(
-        "Started reading table 'bigquery-public-data.usa_names.usa_1910_current' with BQ Storage API session"
-        in message
-        for message in caplog.messages
-    )

samples/tests/test_download_public_data_sandbox.py

Lines changed: 2 additions & 15 deletions
@@ -12,29 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
-
 import pytest
 
 from .. import download_public_data_sandbox
 
 pytest.importorskip("google.cloud.bigquery_storage_v1")
 
 
-def test_download_public_data_sandbox(
-    caplog: pytest.LogCaptureFixture, capsys: pytest.CaptureFixture[str]
-) -> None:
-    # Enable debug-level logging to verify the BigQuery Storage API is used.
-    caplog.set_level(logging.DEBUG)
-
+def test_download_public_data_sandbox(capsys: pytest.CaptureFixture[str]) -> None:
     download_public_data_sandbox.download_public_data_sandbox()
-    out, err = capsys.readouterr()
+    out, _ = capsys.readouterr()
     assert "year" in out
     assert "gender" in out
    assert "name" in out
-
-    assert any(
-        # An anonymous table is used because this sample reads from query results.
-        ("Started reading table" in message and "BQ Storage API session" in message)
-        for message in caplog.messages
-    )
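With the flaky `caplog` assertions removed, both sample tests now rely only on `capsys` to verify the printed columns. A minimal sketch of that `capsys` pattern in isolation, where `print_columns` is a hypothetical stand-in for the samples:

import pytest


def print_columns() -> None:
    # Hypothetical stand-in for a sample that prints a header row to stdout.
    print("year gender name")


def test_print_columns(capsys: pytest.CaptureFixture[str]) -> None:
    print_columns()
    out, _ = capsys.readouterr()
    assert "year" in out
    assert "gender" in out
    assert "name" in out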

tests/unit/test__pandas_helpers.py

Lines changed: 93 additions & 0 deletions
@@ -16,6 +16,7 @@
 import datetime
 import decimal
 import functools
+import gc
 import operator
 import queue
 from typing import Union
@@ -1846,6 +1847,98 @@ def fake_download_stream(
         assert queue_used.maxsize == expected_maxsize
 
 
+@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
+def test__download_table_bqstorage_shuts_down_workers(
+    monkeypatch,
+    module_under_test,
+):
+    """Regression test for https://github.com/googleapis/python-bigquery/issues/2032
+
+    Make sure that when the top-level iterator goes out of scope (is deleted),
+    the child threads are also stopped.
+    """
+    from google.cloud.bigquery import dataset
+    from google.cloud.bigquery import table
+    import google.cloud.bigquery_storage_v1.reader
+    import google.cloud.bigquery_storage_v1.types
+
+    monkeypatch.setattr(
+        _versions_helpers.BQ_STORAGE_VERSIONS, "_installed_version", None
+    )
+    monkeypatch.setattr(bigquery_storage, "__version__", "2.5.0")
+
+    # Create a fake stream with a decent number of rows.
+    arrow_schema = pyarrow.schema(
+        [
+            ("int_col", pyarrow.int64()),
+            ("str_col", pyarrow.string()),
+        ]
+    )
+    arrow_rows = pyarrow.record_batch(
+        [
+            pyarrow.array([0, 1, 2], type=pyarrow.int64()),
+            pyarrow.array(["a", "b", "c"], type=pyarrow.string()),
+        ],
+        schema=arrow_schema,
+    )
+    session = google.cloud.bigquery_storage_v1.types.ReadSession()
+    session.data_format = "ARROW"
+    session.arrow_schema = {"serialized_schema": arrow_schema.serialize().to_pybytes()}
+    session.streams = [
+        google.cloud.bigquery_storage_v1.types.ReadStream(name=name)
+        for name in ("stream/s0", "stream/s1", "stream/s2")
+    ]
+    bqstorage_client = mock.create_autospec(
+        bigquery_storage.BigQueryReadClient, instance=True
+    )
+    reader = mock.create_autospec(
+        google.cloud.bigquery_storage_v1.reader.ReadRowsStream, instance=True
+    )
+    reader.__iter__.return_value = [
+        google.cloud.bigquery_storage_v1.types.ReadRowsResponse(
+            arrow_schema={"serialized_schema": arrow_schema.serialize().to_pybytes()},
+            arrow_record_batch={
+                "serialized_record_batch": arrow_rows.serialize().to_pybytes()
+            },
+        )
+        for _ in range(100)
+    ]
+    reader.rows.return_value = google.cloud.bigquery_storage_v1.reader.ReadRowsIterable(
+        reader, read_session=session
+    )
+    bqstorage_client.read_rows.return_value = reader
+    bqstorage_client.create_read_session.return_value = session
+    table_ref = table.TableReference(
+        dataset.DatasetReference("project-x", "dataset-y"),
+        "table-z",
+    )
+    download_state = module_under_test._DownloadState()
+    assert download_state.started_workers == 0
+    assert download_state.finished_workers == 0
+
+    result_gen = module_under_test._download_table_bqstorage(
+        "some-project",
+        table_ref,
+        bqstorage_client,
+        max_queue_size=1,
+        page_to_item=module_under_test._bqstorage_page_to_arrow,
+        download_state=download_state,
+    )
+
+    result_gen_iter = iter(result_gen)
+    next(result_gen_iter)
+    assert download_state.started_workers == 3
+    assert download_state.finished_workers == 0
+
+    # Stop iteration early and simulate the variables going out of scope
+    # to be doubly sure that the worker threads are supposed to be cleaned up.
+    del result_gen, result_gen_iter
+    gc.collect()
+
+    assert download_state.started_workers == 3
+    assert download_state.finished_workers == 3
+
+
 @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
 def test_download_arrow_row_iterator_unknown_field_type(module_under_test):
     fake_page = api_core.page_iterator.Page(
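The new test stops consuming after a single `next()`, then deletes both references and calls `gc.collect()`. When a suspended generator is garbage collected, Python closes it, raising `GeneratorExit` at the paused `yield` and running any pending cleanup code; that is what gives the download generator a chance to tell its workers to stop, so the test can assert `finished_workers == 3`. A tiny standalone illustration of that generator-finalization behavior, where the `rows` generator and `cleanup_ran` flag are purely illustrative:

import gc

cleanup_ran = {"value": False}


def rows():
    try:
        while True:
            yield "row"
    finally:
        # Runs when the suspended generator is closed or garbage collected,
        # which is what `del` plus gc.collect() triggers in the test above.
        cleanup_ran["value"] = True


gen = rows()
next(gen)      # start the generator; it is now suspended at the yield
del gen        # drop the only reference, as the unit test does
gc.collect()   # make finalization deterministic across Python implementations
assert cleanup_ran["value"] is True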
