92
92
VariableResult ,
93
93
XComResult ,
94
94
)
95
- from airflow .sdk .execution_time .secrets_masker import SecretsMasker
96
95
from airflow .sdk .execution_time .supervisor import (
97
96
BUFFER_SIZE ,
98
97
ActivitySubprocess ,
@@ -977,9 +976,17 @@ def watched_subprocess(self, mocker):
977
976
978
977
return subprocess , read_end
979
978
980
- @patch ("airflow.sdk.execution_time.secrets_masker._secrets_masker " )
979
+ @patch ("airflow.sdk.execution_time.supervisor.mask_secret " )
981
980
@pytest .mark .parametrize (
982
- ["message" , "expected_buffer" , "client_attr_path" , "method_arg" , "method_kwarg" , "mock_response" ],
981
+ [
982
+ "message" ,
983
+ "expected_buffer" ,
984
+ "client_attr_path" ,
985
+ "method_arg" ,
986
+ "method_kwarg" ,
987
+ "mock_response" ,
988
+ "mask_secret_args" ,
989
+ ],
983
990
[
984
991
pytest .param (
985
992
GetConnection (conn_id = "test_conn" ),
@@ -988,15 +995,27 @@ def watched_subprocess(self, mocker):
988
995
("test_conn" ,),
989
996
{},
990
997
ConnectionResult (conn_id = "test_conn" , conn_type = "mysql" ),
998
+ None ,
991
999
id = "get_connection" ,
992
1000
),
1001
+ pytest .param (
1002
+ GetConnection (conn_id = "test_conn" ),
1003
+ b'{"conn_id":"test_conn","conn_type":"mysql","password":"password","type":"ConnectionResult"}\n ' ,
1004
+ "connections.get" ,
1005
+ ("test_conn" ,),
1006
+ {},
1007
+ ConnectionResult (conn_id = "test_conn" , conn_type = "mysql" , password = "password" ),
1008
+ ["password" ],
1009
+ id = "get_connection_with_password" ,
1010
+ ),
993
1011
pytest .param (
994
1012
GetConnection (conn_id = "test_conn" ),
995
1013
b'{"conn_id":"test_conn","conn_type":"mysql","schema":"mysql","type":"ConnectionResult"}\n ' ,
996
1014
"connections.get" ,
997
1015
("test_conn" ,),
998
1016
{},
999
1017
ConnectionResult (conn_id = "test_conn" , conn_type = "mysql" , schema = "mysql" ), # type: ignore[call-arg]
1018
+ None ,
1000
1019
id = "get_connection_with_alias" ,
1001
1020
),
1002
1021
pytest .param (
@@ -1006,6 +1025,7 @@ def watched_subprocess(self, mocker):
1006
1025
("test_key" ,),
1007
1026
{},
1008
1027
VariableResult (key = "test_key" , value = "test_value" ),
1028
+ ["test_value" , "test_key" ],
1009
1029
id = "get_variable" ,
1010
1030
),
1011
1031
pytest .param (
@@ -1015,6 +1035,7 @@ def watched_subprocess(self, mocker):
1015
1035
("test_key" , "test_value" , "test_description" ),
1016
1036
{},
1017
1037
OKResponse (ok = True ),
1038
+ None ,
1018
1039
id = "set_variable" ,
1019
1040
),
1020
1041
pytest .param (
@@ -1024,6 +1045,7 @@ def watched_subprocess(self, mocker):
1024
1045
("test_key" ,),
1025
1046
{},
1026
1047
OKResponse (ok = True ),
1048
+ None ,
1027
1049
id = "delete_variable" ,
1028
1050
),
1029
1051
pytest .param (
@@ -1033,6 +1055,7 @@ def watched_subprocess(self, mocker):
1033
1055
(TI_ID , DeferTask (next_method = "execute_callback" , classpath = "my-classpath" )),
1034
1056
{},
1035
1057
"" ,
1058
+ None ,
1036
1059
id = "patch_task_instance_to_deferred" ,
1037
1060
),
1038
1061
pytest .param (
@@ -1051,6 +1074,7 @@ def watched_subprocess(self, mocker):
1051
1074
),
1052
1075
{},
1053
1076
"" ,
1077
+ None ,
1054
1078
id = "patch_task_instance_to_up_for_reschedule" ,
1055
1079
),
1056
1080
pytest .param (
@@ -1060,6 +1084,7 @@ def watched_subprocess(self, mocker):
1060
1084
("test_dag" , "test_run" , "test_task" , "test_key" , None , False ),
1061
1085
{},
1062
1086
XComResult (key = "test_key" , value = "test_value" ),
1087
+ None ,
1063
1088
id = "get_xcom" ,
1064
1089
),
1065
1090
pytest .param (
@@ -1071,6 +1096,7 @@ def watched_subprocess(self, mocker):
1071
1096
("test_dag" , "test_run" , "test_task" , "test_key" , 2 , False ),
1072
1097
{},
1073
1098
XComResult (key = "test_key" , value = "test_value" ),
1099
+ None ,
1074
1100
id = "get_xcom_map_index" ,
1075
1101
),
1076
1102
pytest .param (
@@ -1080,6 +1106,7 @@ def watched_subprocess(self, mocker):
1080
1106
("test_dag" , "test_run" , "test_task" , "test_key" , None , False ),
1081
1107
{},
1082
1108
XComResult (key = "test_key" , value = None , type = "XComResult" ),
1109
+ None ,
1083
1110
id = "get_xcom_not_found" ,
1084
1111
),
1085
1112
pytest .param (
@@ -1095,6 +1122,7 @@ def watched_subprocess(self, mocker):
1095
1122
("test_dag" , "test_run" , "test_task" , "test_key" , None , True ),
1096
1123
{},
1097
1124
XComResult (key = "test_key" , value = None , type = "XComResult" ),
1125
+ None ,
1098
1126
id = "get_xcom_include_prior_dates" ,
1099
1127
),
1100
1128
pytest .param (
@@ -1118,6 +1146,7 @@ def watched_subprocess(self, mocker):
1118
1146
),
1119
1147
{},
1120
1148
OKResponse (ok = True ),
1149
+ None ,
1121
1150
id = "set_xcom" ,
1122
1151
),
1123
1152
pytest .param (
@@ -1142,6 +1171,7 @@ def watched_subprocess(self, mocker):
1142
1171
),
1143
1172
{},
1144
1173
OKResponse (ok = True ),
1174
+ None ,
1145
1175
id = "set_xcom_with_map_index" ,
1146
1176
),
1147
1177
pytest .param (
@@ -1167,6 +1197,7 @@ def watched_subprocess(self, mocker):
1167
1197
),
1168
1198
{},
1169
1199
OKResponse (ok = True ),
1200
+ None ,
1170
1201
id = "set_xcom_with_map_index_and_mapped_length" ,
1171
1202
),
1172
1203
pytest .param (
@@ -1188,6 +1219,7 @@ def watched_subprocess(self, mocker):
1188
1219
),
1189
1220
{},
1190
1221
OKResponse (ok = True ),
1222
+ None ,
1191
1223
id = "delete_xcom" ,
1192
1224
),
1193
1225
# we aren't adding all states under TaskInstanceState here, because this test's scope is only to check
@@ -1199,6 +1231,7 @@ def watched_subprocess(self, mocker):
1199
1231
(),
1200
1232
{},
1201
1233
"" ,
1234
+ None ,
1202
1235
id = "patch_task_instance_to_skipped" ,
1203
1236
),
1204
1237
pytest .param (
@@ -1214,6 +1247,7 @@ def watched_subprocess(self, mocker):
1214
1247
"rendered_map_index" : "test retry task" ,
1215
1248
},
1216
1249
"" ,
1250
+ None ,
1217
1251
id = "up_for_retry" ,
1218
1252
),
1219
1253
pytest .param (
@@ -1223,6 +1257,7 @@ def watched_subprocess(self, mocker):
1223
1257
(TI_ID , {"field1" : "rendered_value1" , "field2" : "rendered_value2" }),
1224
1258
{},
1225
1259
OKResponse (ok = True ),
1260
+ None ,
1226
1261
id = "set_rtif" ,
1227
1262
),
1228
1263
pytest .param (
@@ -1232,6 +1267,7 @@ def watched_subprocess(self, mocker):
1232
1267
[],
1233
1268
{"name" : "asset" },
1234
1269
AssetResult (name = "asset" , uri = "s3://bucket/obj" , group = "asset" ),
1270
+ None ,
1235
1271
id = "get_asset_by_name" ,
1236
1272
),
1237
1273
pytest .param (
@@ -1241,6 +1277,7 @@ def watched_subprocess(self, mocker):
1241
1277
[],
1242
1278
{"uri" : "s3://bucket/obj" },
1243
1279
AssetResult (name = "asset" , uri = "s3://bucket/obj" , group = "asset" ),
1280
+ None ,
1244
1281
id = "get_asset_by_uri" ,
1245
1282
),
1246
1283
pytest .param (
@@ -1263,6 +1300,7 @@ def watched_subprocess(self, mocker):
1263
1300
)
1264
1301
]
1265
1302
),
1303
+ None ,
1266
1304
id = "get_asset_events_by_uri_and_name" ,
1267
1305
),
1268
1306
pytest .param (
@@ -1285,6 +1323,7 @@ def watched_subprocess(self, mocker):
1285
1323
)
1286
1324
]
1287
1325
),
1326
+ None ,
1288
1327
id = "get_asset_events_by_uri" ,
1289
1328
),
1290
1329
pytest .param (
@@ -1307,6 +1346,7 @@ def watched_subprocess(self, mocker):
1307
1346
)
1308
1347
]
1309
1348
),
1349
+ None ,
1310
1350
id = "get_asset_events_by_name" ,
1311
1351
),
1312
1352
pytest .param (
@@ -1329,6 +1369,7 @@ def watched_subprocess(self, mocker):
1329
1369
)
1330
1370
]
1331
1371
),
1372
+ None ,
1332
1373
id = "get_asset_events_by_asset_alias" ,
1333
1374
),
1334
1375
pytest .param (
@@ -1346,6 +1387,7 @@ def watched_subprocess(self, mocker):
1346
1387
"rendered_map_index" : "test success task" ,
1347
1388
},
1348
1389
"" ,
1390
+ None ,
1349
1391
id = "succeed_task" ,
1350
1392
),
1351
1393
pytest .param (
@@ -1364,6 +1406,7 @@ def watched_subprocess(self, mocker):
1364
1406
data_interval_start = timezone .parse ("2025-01-10T12:00:00Z" ),
1365
1407
data_interval_end = timezone .parse ("2025-01-10T14:00:00Z" ),
1366
1408
),
1409
+ None ,
1367
1410
id = "get_prev_successful_dagrun" ,
1368
1411
),
1369
1412
pytest .param (
@@ -1379,6 +1422,7 @@ def watched_subprocess(self, mocker):
1379
1422
("test_dag" , "test_run" , {"key" : "value" }, timezone .datetime (2025 , 1 , 1 ), True ),
1380
1423
{},
1381
1424
OKResponse (ok = True ),
1425
+ None ,
1382
1426
id = "dag_run_trigger" ,
1383
1427
),
1384
1428
pytest .param (
@@ -1388,6 +1432,7 @@ def watched_subprocess(self, mocker):
1388
1432
("test_dag" , "test_run" , None , None , False ),
1389
1433
{},
1390
1434
ErrorResponse (error = ErrorType .DAGRUN_ALREADY_EXISTS ),
1435
+ None ,
1391
1436
id = "dag_run_trigger_already_exists" ,
1392
1437
),
1393
1438
pytest .param (
@@ -1397,6 +1442,7 @@ def watched_subprocess(self, mocker):
1397
1442
("test_dag" , "test_run" ),
1398
1443
{},
1399
1444
DagRunStateResult (state = DagRunState .RUNNING ),
1445
+ None ,
1400
1446
id = "get_dag_run_state" ,
1401
1447
),
1402
1448
pytest .param (
@@ -1406,6 +1452,7 @@ def watched_subprocess(self, mocker):
1406
1452
(TI_ID , 1 ),
1407
1453
{},
1408
1454
TaskRescheduleStartDate (start_date = timezone .parse ("2024-10-31T12:00:00Z" )),
1455
+ None ,
1409
1456
id = "get_task_reschedule_start_date" ,
1410
1457
),
1411
1458
pytest .param (
@@ -1423,6 +1470,7 @@ def watched_subprocess(self, mocker):
1423
1470
"task_ids" : ["task1" , "task2" ],
1424
1471
},
1425
1472
TICount (count = 2 ),
1473
+ None ,
1426
1474
id = "get_ti_count" ,
1427
1475
),
1428
1476
pytest .param (
@@ -1437,6 +1485,7 @@ def watched_subprocess(self, mocker):
1437
1485
"states" : ["success" , "failed" ],
1438
1486
},
1439
1487
DRCount (count = 2 ),
1488
+ None ,
1440
1489
id = "get_dr_count" ,
1441
1490
),
1442
1491
pytest .param (
@@ -1453,6 +1502,7 @@ def watched_subprocess(self, mocker):
1453
1502
"task_group_id" : "test_group" ,
1454
1503
},
1455
1504
TaskStatesResult (task_states = {"run_id" : {"task1" : "success" , "task2" : "failed" }}),
1505
+ None ,
1456
1506
id = "get_task_states" ,
1457
1507
),
1458
1508
pytest .param (
@@ -1468,6 +1518,7 @@ def watched_subprocess(self, mocker):
1468
1518
("test_dag" , "test_run" , "test_task" , "test_key" , 0 ),
1469
1519
{},
1470
1520
XComResult (key = "test_key" , value = "test_value" ),
1521
+ None ,
1471
1522
id = "get_xcom_seq_item" ,
1472
1523
),
1473
1524
pytest .param (
@@ -1483,13 +1534,14 @@ def watched_subprocess(self, mocker):
1483
1534
("test_dag" , "test_run" , "test_task" , "test_key" , 2 ),
1484
1535
{},
1485
1536
ErrorResponse (error = ErrorType .XCOM_NOT_FOUND ),
1537
+ None ,
1486
1538
id = "get_xcom_seq_item_not_found" ,
1487
1539
),
1488
1540
],
1489
1541
)
1490
1542
def test_handle_requests (
1491
1543
self ,
1492
- mock_secrets_masker ,
1544
+ mock_mask_secret ,
1493
1545
watched_subprocess ,
1494
1546
mocker ,
1495
1547
time_machine ,
@@ -1499,6 +1551,7 @@ def test_handle_requests(
1499
1551
method_arg ,
1500
1552
method_kwarg ,
1501
1553
mock_response ,
1554
+ mask_secret_args ,
1502
1555
):
1503
1556
"""
1504
1557
Test handling of different messages to the subprocess. For any new message type, add a
@@ -1511,7 +1564,6 @@ def test_handle_requests(
1511
1564
3. Checks that the buffer is updated with the expected response.
1512
1565
4. Verifies that the response is correctly decoded.
1513
1566
"""
1514
- mock_secrets_masker .return_value = SecretsMasker ()
1515
1567
watched_subprocess , read_socket = watched_subprocess
1516
1568
1517
1569
# Mock the client method. E.g. `client.variables.get` or `client.connections.get`
@@ -1524,6 +1576,10 @@ def test_handle_requests(
1524
1576
next (generator )
1525
1577
msg = message .model_dump_json ().encode () + b"\n "
1526
1578
generator .send (msg )
1579
+
1580
+ if mask_secret_args :
1581
+ mock_mask_secret .assert_called_with (* mask_secret_args )
1582
+
1527
1583
time_machine .move_to (timezone .datetime (2024 , 10 , 31 ), tick = False )
1528
1584
1529
1585
# Verify the correct client method was called
0 commit comments