@@ -36,7 +36,15 @@ def _get_target_class():
36
36
def _make_one (self , * args , ** kw ):
37
37
return self ._get_target_class ()(* args , ** kw )
38
38
39
- def _mock_client (self , rows = None , schema = None , num_dml_affected_rows = None ):
39
+ def _mock_client (
40
+ self ,
41
+ rows = None ,
42
+ schema = None ,
43
+ num_dml_affected_rows = None ,
44
+ default_query_job_config = None ,
45
+ dry_run_job = False ,
46
+ total_bytes_processed = 0 ,
47
+ ):
40
48
from google .cloud .bigquery import client
41
49
42
50
if rows is None :
@@ -49,8 +57,12 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
49
57
total_rows = total_rows ,
50
58
schema = schema ,
51
59
num_dml_affected_rows = num_dml_affected_rows ,
60
+ dry_run = dry_run_job ,
61
+ total_bytes_processed = total_bytes_processed ,
52
62
)
53
63
mock_client .list_rows .return_value = rows
64
+ mock_client ._default_query_job_config = default_query_job_config
65
+
54
66
return mock_client
55
67
56
68
def _mock_bqstorage_client (self , rows = None , stream_count = 0 ):
@@ -76,18 +88,31 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0):
76
88
77
89
return mock_client
78
90
79
- def _mock_job (self , total_rows = 0 , schema = None , num_dml_affected_rows = None ):
91
+ def _mock_job (
92
+ self ,
93
+ total_rows = 0 ,
94
+ schema = None ,
95
+ num_dml_affected_rows = None ,
96
+ dry_run = False ,
97
+ total_bytes_processed = 0 ,
98
+ ):
80
99
from google .cloud .bigquery import job
81
100
82
101
mock_job = mock .create_autospec (job .QueryJob )
83
102
mock_job .error_result = None
84
103
mock_job .state = "DONE"
85
- mock_job .result .return_value = mock_job
86
- mock_job ._query_results = self ._mock_results (
87
- total_rows = total_rows ,
88
- schema = schema ,
89
- num_dml_affected_rows = num_dml_affected_rows ,
90
- )
104
+ mock_job .dry_run = dry_run
105
+
106
+ if dry_run :
107
+ mock_job .result .side_effect = exceptions .NotFound
108
+ mock_job .total_bytes_processed = total_bytes_processed
109
+ else :
110
+ mock_job .result .return_value = mock_job
111
+ mock_job ._query_results = self ._mock_results (
112
+ total_rows = total_rows ,
113
+ schema = schema ,
114
+ num_dml_affected_rows = num_dml_affected_rows ,
115
+ )
91
116
92
117
if num_dml_affected_rows is None :
93
118
mock_job .statement_type = None # API sends back None for SELECT
@@ -373,7 +398,27 @@ def test_execute_custom_job_id(self):
373
398
self .assertEqual (args [0 ], "SELECT 1;" )
374
399
self .assertEqual (kwargs ["job_id" ], "foo" )
375
400
376
- def test_execute_custom_job_config (self ):
401
+ def test_execute_w_default_config (self ):
402
+ from google .cloud .bigquery .dbapi import connect
403
+ from google .cloud .bigquery import job
404
+
405
+ default_config = job .QueryJobConfig (use_legacy_sql = False , flatten_results = True )
406
+ client = self ._mock_client (
407
+ rows = [], num_dml_affected_rows = 0 , default_query_job_config = default_config
408
+ )
409
+ connection = connect (client )
410
+ cursor = connection .cursor ()
411
+
412
+ cursor .execute ("SELECT 1;" , job_id = "foo" )
413
+
414
+ _ , kwargs = client .query .call_args
415
+ used_config = kwargs ["job_config" ]
416
+ expected_config = job .QueryJobConfig (
417
+ use_legacy_sql = False , flatten_results = True , query_parameters = []
418
+ )
419
+ self .assertEqual (used_config ._properties , expected_config ._properties )
420
+
421
+ def test_execute_custom_job_config_wo_default_config (self ):
377
422
from google .cloud .bigquery .dbapi import connect
378
423
from google .cloud .bigquery import job
379
424
@@ -387,6 +432,29 @@ def test_execute_custom_job_config(self):
387
432
self .assertEqual (kwargs ["job_id" ], "foo" )
388
433
self .assertEqual (kwargs ["job_config" ], config )
389
434
435
+ def test_execute_custom_job_config_w_default_config (self ):
436
+ from google .cloud .bigquery .dbapi import connect
437
+ from google .cloud .bigquery import job
438
+
439
+ default_config = job .QueryJobConfig (use_legacy_sql = False , flatten_results = True )
440
+ client = self ._mock_client (
441
+ rows = [], num_dml_affected_rows = 0 , default_query_job_config = default_config
442
+ )
443
+ connection = connect (client )
444
+ cursor = connection .cursor ()
445
+ config = job .QueryJobConfig (use_legacy_sql = True )
446
+
447
+ cursor .execute ("SELECT 1;" , job_id = "foo" , job_config = config )
448
+
449
+ _ , kwargs = client .query .call_args
450
+ used_config = kwargs ["job_config" ]
451
+ expected_config = job .QueryJobConfig (
452
+ use_legacy_sql = True , # the config passed to execute() prevails
453
+ flatten_results = True , # from the default
454
+ query_parameters = [],
455
+ )
456
+ self .assertEqual (used_config ._properties , expected_config ._properties )
457
+
390
458
def test_execute_w_dml (self ):
391
459
from google .cloud .bigquery .dbapi import connect
392
460
@@ -442,6 +510,40 @@ def test_execute_w_query(self):
442
510
row = cursor .fetchone ()
443
511
self .assertIsNone (row )
444
512
513
+ def test_execute_w_query_dry_run (self ):
514
+ from google .cloud .bigquery .job import QueryJobConfig
515
+ from google .cloud .bigquery .schema import SchemaField
516
+ from google .cloud .bigquery import dbapi
517
+
518
+ connection = dbapi .connect (
519
+ self ._mock_client (
520
+ rows = [("hello" , "world" , 1 ), ("howdy" , "y'all" , 2 )],
521
+ schema = [
522
+ SchemaField ("a" , "STRING" , mode = "NULLABLE" ),
523
+ SchemaField ("b" , "STRING" , mode = "REQUIRED" ),
524
+ SchemaField ("c" , "INTEGER" , mode = "NULLABLE" ),
525
+ ],
526
+ dry_run_job = True ,
527
+ total_bytes_processed = 12345 ,
528
+ )
529
+ )
530
+ cursor = connection .cursor ()
531
+ cursor .execute (
532
+ "SELECT a, b, c FROM hello_world WHERE d > 3;" ,
533
+ job_config = QueryJobConfig (dry_run = True ),
534
+ )
535
+
536
+ self .assertIsNone (cursor .description )
537
+ self .assertEqual (cursor .rowcount , 1 )
538
+
539
+ rows = cursor .fetchall ()
540
+
541
+ # We expect a single row with one column - the estimated numbe of bytes
542
+ # that will be processed by the query.
543
+ self .assertEqual (len (rows ), 1 )
544
+ self .assertEqual (len (rows [0 ]), 1 )
545
+ self .assertEqual (rows [0 ][0 ], 12345 )
546
+
445
547
def test_execute_raises_if_result_raises (self ):
446
548
import google .cloud .exceptions
447
549
@@ -451,8 +553,10 @@ def test_execute_raises_if_result_raises(self):
451
553
from google .cloud .bigquery .dbapi import exceptions
452
554
453
555
job = mock .create_autospec (job .QueryJob )
556
+ job .dry_run = None
454
557
job .result .side_effect = google .cloud .exceptions .GoogleCloudError ("" )
455
558
client = mock .create_autospec (client .Client )
559
+ client ._default_query_job_config = None
456
560
client .query .return_value = job
457
561
connection = connect (client )
458
562
cursor = connection .cursor ()
0 commit comments