|
25 | 25 | from datetime import datetime, timedelta
|
26 | 26 | from io import BytesIO
|
27 | 27 | from unittest import mock
|
| 28 | +from unittest.mock import MagicMock |
28 | 29 |
|
29 | 30 | import dateutil
|
30 | 31 | import pytest
|
@@ -1279,6 +1280,47 @@ def test_should_overwrite_files(self, mock_get_conn, mock_delete, mock_rewrite,
|
1279 | 1280 | )
|
1280 | 1281 | mock_copy.assert_not_called()
|
1281 | 1282 |
|
| 1283 | + @mock.patch(GCS_STRING.format("GCSHook.copy")) |
| 1284 | + @mock.patch(GCS_STRING.format("GCSHook.rewrite")) |
| 1285 | + @mock.patch(GCS_STRING.format("GCSHook.delete")) |
| 1286 | + @mock.patch(GCS_STRING.format("GCSHook.get_conn")) |
| 1287 | + def test_should_overwrite_cmek_files(self, mock_get_conn, mock_delete, mock_rewrite, mock_copy): |
| 1288 | + source_bucket = self._create_bucket(name="SOURCE_BUCKET") |
| 1289 | + source_bucket.list_blobs.return_value = [ |
| 1290 | + self._create_blob("FILE_A", "C1", kms_key_name="KMS_KEY_1", generation=1), |
| 1291 | + self._create_blob("FILE_B", "C1"), |
| 1292 | + ] |
| 1293 | + destination_bucket = self._create_bucket(name="DEST_BUCKET") |
| 1294 | + destination_bucket.list_blobs.return_value = [ |
| 1295 | + self._create_blob("FILE_A", "C2", kms_key_name="KMS_KEY_2", generation=2), |
| 1296 | + self._create_blob("FILE_B", "C2"), |
| 1297 | + ] |
| 1298 | + mock_get_conn.return_value.bucket.side_effect = [source_bucket, destination_bucket] |
| 1299 | + self.gcs_hook.sync( |
| 1300 | + source_bucket="SOURCE_BUCKET", destination_bucket="DEST_BUCKET", allow_overwrite=True |
| 1301 | + ) |
| 1302 | + mock_delete.assert_not_called() |
| 1303 | + source_bucket.get_blob.assert_called_once_with("FILE_A", generation=1) |
| 1304 | + destination_bucket.get_blob.assert_called_once_with("FILE_A", generation=2) |
| 1305 | + mock_rewrite.assert_has_calls( |
| 1306 | + [ |
| 1307 | + mock.call( |
| 1308 | + source_bucket="SOURCE_BUCKET", |
| 1309 | + source_object="FILE_B", |
| 1310 | + destination_bucket="DEST_BUCKET", |
| 1311 | + destination_object="FILE_B", |
| 1312 | + ), |
| 1313 | + mock.call( |
| 1314 | + source_bucket="SOURCE_BUCKET", |
| 1315 | + source_object=source_bucket.get_blob.return_value.name, |
| 1316 | + destination_bucket="DEST_BUCKET", |
| 1317 | + destination_object=source_bucket.get_blob.return_value.name.__getitem__.return_value, |
| 1318 | + ), |
| 1319 | + ], |
| 1320 | + any_order=True, |
| 1321 | + ) |
| 1322 | + mock_copy.assert_not_called() |
| 1323 | + |
1282 | 1324 | @mock.patch(GCS_STRING.format("GCSHook.copy"))
|
1283 | 1325 | @mock.patch(GCS_STRING.format("GCSHook.rewrite"))
|
1284 | 1326 | @mock.patch(GCS_STRING.format("GCSHook.delete"))
|
@@ -1440,11 +1482,20 @@ def test_should_not_overwrite_when_overwrite_is_disabled(
|
1440 | 1482 | mock_rewrite.assert_not_called()
|
1441 | 1483 | mock_copy.assert_not_called()
|
1442 | 1484 |
|
1443 |
| - def _create_blob(self, name: str, crc32: str, bucket=None): |
| 1485 | + def _create_blob( |
| 1486 | + self, |
| 1487 | + name: str, |
| 1488 | + crc32: str, |
| 1489 | + bucket: MagicMock | None = None, |
| 1490 | + kms_key_name: str | None = None, |
| 1491 | + generation: int = 0, |
| 1492 | + ): |
1444 | 1493 | blob = mock.MagicMock(name=f"BLOB:{name}")
|
1445 | 1494 | blob.name = name
|
1446 | 1495 | blob.crc32 = crc32
|
1447 | 1496 | blob.bucket = bucket
|
| 1497 | + blob.kms_key_name = kms_key_name |
| 1498 | + blob.generation = generation |
1448 | 1499 | return blob
|
1449 | 1500 |
|
1450 | 1501 | def _create_bucket(self, name: str):
|
|
0 commit comments