Commit 595981c

Author: Vincent Koc
Light Refactor and Clean-up AWS Provider (#23907)
1 parent ab1f637 · commit 595981c

25 files changed: +70 −95 lines changed

airflow/providers/amazon/aws/hooks/athena.py

Lines changed: 1 addition & 2 deletions
@@ -91,8 +91,7 @@ def run_query(
         if client_request_token:
             params['ClientRequestToken'] = client_request_token
         response = self.get_conn().start_query_execution(**params)
-        query_execution_id = response['QueryExecutionId']
-        return query_execution_id
+        return response['QueryExecutionId']
 
     def check_query_status(self, query_execution_id: str) -> Optional[str]:
         """

airflow/providers/amazon/aws/hooks/glue.py

Lines changed: 4 additions & 5 deletions
@@ -118,8 +118,8 @@ def initialize_job(
 
         try:
             job_name = self.get_or_create_glue_job()
-            job_run = glue_client.start_job_run(JobName=job_name, Arguments=script_arguments, **run_kwargs)
-            return job_run
+            return glue_client.start_job_run(JobName=job_name, Arguments=script_arguments, **run_kwargs)
+
         except Exception as general_error:
             self.log.error("Failed to run aws glue job, error: %s", general_error)
             raise
@@ -134,8 +134,7 @@ def get_job_state(self, job_name: str, run_id: str) -> str:
         """
         glue_client = self.get_conn()
         job_run = glue_client.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
-        job_run_state = job_run['JobRun']['JobRunState']
-        return job_run_state
+        return job_run['JobRun']['JobRunState']
 
     def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]:
         """
@@ -155,7 +154,7 @@ def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]:
                 self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state)
                 return {'JobRunState': job_run_state, 'JobRunId': run_id}
             if job_run_state in failed_states:
-                job_error_message = "Exiting Job " + run_id + " Run State: " + job_run_state
+                job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}"
                 self.log.info(job_error_message)
                 raise AirflowException(job_error_message)
             else:
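The last hunk swaps string concatenation for an f-string, a change that recurs in sensors/glue.py, test_athena.py and test_dms_describe_tasks.py below. A small self-contained equivalence check, with made-up values for run_id and job_run_state:

run_id = 'jr_0123456789'          # hypothetical Glue job run id
job_run_state = 'FAILED'

# Before: concatenation only works when every piece is already a str.
concatenated = "Exiting Job " + run_id + " Run State: " + job_run_state

# After: the f-string reads closer to the final message and formats
# non-string values without explicit str() calls.
formatted = f"Exiting Job {run_id} Run State: {job_run_state}"

assert concatenated == formatted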

airflow/providers/amazon/aws/hooks/glue_crawler.py

Lines changed: 13 additions & 16 deletions
@@ -102,8 +102,7 @@ def create_crawler(self, **crawler_kwargs) -> str:
         """
         crawler_name = crawler_kwargs['Name']
         self.log.info("Creating crawler: %s", crawler_name)
-        crawler = self.glue_client.create_crawler(**crawler_kwargs)
-        return crawler
+        return self.glue_client.create_crawler(**crawler_kwargs)
 
     def start_crawler(self, crawler_name: str) -> dict:
         """
@@ -113,8 +112,7 @@ def start_crawler(self, crawler_name: str) -> dict:
         :return: Empty dictionary
         """
         self.log.info("Starting crawler %s", crawler_name)
-        crawler = self.glue_client.start_crawler(Name=crawler_name)
-        return crawler
+        return self.glue_client.start_crawler(Name=crawler_name)
 
     def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5) -> str:
         """
@@ -137,18 +135,17 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5)
                 crawler_status = crawler['LastCrawl']['Status']
                 if crawler_status in failed_status:
                     raise AirflowException(f"Status: {crawler_status}")
-                else:
-                    metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[
-                        'CrawlerMetricsList'
-                    ][0]
-                    self.log.info("Status: %s", crawler_status)
-                    self.log.info("Last Runtime Duration (seconds): %s", metrics['LastRuntimeSeconds'])
-                    self.log.info("Median Runtime Duration (seconds): %s", metrics['MedianRuntimeSeconds'])
-                    self.log.info("Tables Created: %s", metrics['TablesCreated'])
-                    self.log.info("Tables Updated: %s", metrics['TablesUpdated'])
-                    self.log.info("Tables Deleted: %s", metrics['TablesDeleted'])
-
-                    return crawler_status
+                metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[
+                    'CrawlerMetricsList'
+                ][0]
+                self.log.info("Status: %s", crawler_status)
+                self.log.info("Last Runtime Duration (seconds): %s", metrics['LastRuntimeSeconds'])
+                self.log.info("Median Runtime Duration (seconds): %s", metrics['MedianRuntimeSeconds'])
+                self.log.info("Tables Created: %s", metrics['TablesCreated'])
+                self.log.info("Tables Updated: %s", metrics['TablesUpdated'])
+                self.log.info("Tables Deleted: %s", metrics['TablesDeleted'])
+
+                return crawler_status
 
             else:
                 self.log.info("Polling for AWS Glue crawler: %s ", crawler_name)

airflow/providers/amazon/aws/hooks/kinesis.py

Lines changed: 1 addition & 3 deletions
@@ -43,9 +43,7 @@ def __init__(self, delivery_stream: str, *args, **kwargs) -> None:
 
     def put_records(self, records: Iterable):
         """Write batch records to Kinesis Firehose"""
-        response = self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)
-
-        return response
+        return self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)
 
 
 class AwsFirehoseHook(FirehoseHook):

airflow/providers/amazon/aws/hooks/lambda_function.py

Lines changed: 2 additions & 4 deletions
@@ -66,8 +66,7 @@ def invoke_lambda(
             "Payload": payload,
             "Qualifier": qualifier,
         }
-        response = self.conn.invoke(**{k: v for k, v in invoke_args.items() if v is not None})
-        return response
+        return self.conn.invoke(**{k: v for k, v in invoke_args.items() if v is not None})
 
     def create_lambda(
         self,
@@ -118,10 +117,9 @@ def create_lambda(
             "CodeSigningConfigArn": code_signing_config_arn,
             "Architectures": architectures,
         }
-        response = self.conn.create_function(
+        return self.conn.create_function(
             **{k: v for k, v in create_function_args.items() if v is not None},
         )
-        return response
 
 
 class AwsLambdaHook(LambdaHook):

airflow/providers/amazon/aws/hooks/redshift_sql.py

Lines changed: 1 addition & 3 deletions
@@ -130,6 +130,4 @@ def get_conn(self) -> RedshiftConnection:
         conn_params = self._get_conn_params()
         conn_kwargs_dejson = self.conn.extra_dejson
         conn_kwargs: Dict = {**conn_params, **conn_kwargs_dejson}
-        conn: RedshiftConnection = redshift_connector.connect(**conn_kwargs)
-
-        return conn
+        return redshift_connector.connect(**conn_kwargs)

airflow/providers/amazon/aws/hooks/s3.py

Lines changed: 5 additions & 8 deletions
@@ -277,11 +277,10 @@ def list_prefixes(
             Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
         )
 
-        prefixes = []
+        prefixes = []  # type: List[str]
         for page in response:
             if 'CommonPrefixes' in page:
-                for common_prefix in page['CommonPrefixes']:
-                    prefixes.append(common_prefix['Prefix'])
+                prefixes.extend(common_prefix['Prefix'] for common_prefix in page['CommonPrefixes'])
 
         return prefixes
 
@@ -366,12 +365,10 @@ def _is_in_period(input_date: datetime) -> bool:
                 StartAfter=start_after_key,
             )
 
-            keys = []
+            keys = []  # type: List[str]
             for page in response:
                 if 'Contents' in page:
-                    for k in page['Contents']:
-                        keys.append(k)
-
+                    keys.extend(iter(page['Contents']))
             if self.object_filter_usr is not None:
                 return self.object_filter_usr(keys, from_datetime, to_datetime)
 
@@ -604,7 +601,7 @@ def load_file(
             extra_args['ServerSideEncryption'] = "AES256"
         if gzip:
             with open(filename, 'rb') as f_in:
-                filename_gz = f_in.name + '.gz'
+                filename_gz = f'{f_in.name}.gz'
                 with gz.open(filename_gz, 'wb') as f_out:
                     shutil.copyfileobj(f_in, f_out)
                     filename = filename_gz
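Both paginator loops now use list.extend() with a generator expression instead of appending inside a nested for. A minimal sketch with a fake page structure (the real pages come from a boto3 paginator, not this list):

# Fake paginator output; the real pages come from the S3 list paginator.
pages = [
    {'CommonPrefixes': [{'Prefix': 'a/'}, {'Prefix': 'b/'}]},
    {},  # pages without the key are simply skipped
    {'CommonPrefixes': [{'Prefix': 'c/'}]},
]

# Before: nested loop with append().
prefixes_before = []
for page in pages:
    if 'CommonPrefixes' in page:
        for common_prefix in page['CommonPrefixes']:
            prefixes_before.append(common_prefix['Prefix'])

# After: extend() with a generator, one level of nesting less.
prefixes_after = []
for page in pages:
    if 'CommonPrefixes' in page:
        prefixes_after.extend(cp['Prefix'] for cp in page['CommonPrefixes'])

assert prefixes_before == prefixes_after == ['a/', 'b/', 'c/']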

airflow/providers/amazon/aws/log/s3_task_handler.py

Lines changed: 2 additions & 2 deletions
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import os
+import pathlib
 import sys
 
 if sys.version_info >= (3, 8):
@@ -92,8 +93,7 @@ def close(self):
         remote_loc = os.path.join(self.remote_base, self.log_relative_path)
         if os.path.exists(local_loc):
             # read log and remove old logs to get just the latest additions
-            with open(local_loc) as logfile:
-                log = logfile.read()
+            log = pathlib.Path(local_loc).read_text()
             self.s3_write(log, remote_loc)
 
         # Mark closed so we don't double write if close is called twice
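pathlib.Path.read_text() opens, reads and closes the file in one call, which is why the with open(...) block collapses to a single line. A quick sketch against a temporary file:

import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    local_loc = pathlib.Path(tmp) / 'task.log'  # throwaway file for the sketch
    local_loc.write_text('line 1\nline 2\n')

    # Before: explicit context manager and read().
    with open(local_loc) as logfile:
        log_before = logfile.read()

    # After: one call that opens, reads and closes the file.
    log_after = pathlib.Path(local_loc).read_text()

    assert log_before == log_after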

airflow/providers/amazon/aws/operators/glacier.py

Lines changed: 1 addition & 2 deletions
@@ -51,5 +51,4 @@ def __init__(
 
     def execute(self, context: 'Context'):
         hook = GlacierHook(aws_conn_id=self.aws_conn_id)
-        response = hook.retrieve_inventory(vault_name=self.vault_name)
-        return response
+        return hook.retrieve_inventory(vault_name=self.vault_name)

airflow/providers/amazon/aws/secrets/secrets_manager.py

Lines changed: 15 additions & 18 deletions
@@ -134,7 +134,7 @@ def __init__(
         self.profile_name = profile_name
         self.sep = sep
         self.full_url_mode = full_url_mode
-        self.extra_conn_words = extra_conn_words if extra_conn_words else {}
+        self.extra_conn_words = extra_conn_words or {}
         self.kwargs = kwargs
 
     @cached_property
@@ -178,9 +178,7 @@ def get_uri_from_secret(self, secret):
 
         conn_string = "{conn_type}://{user}:{password}@{host}:{port}/{schema}".format(**conn_d)
 
-        connection = self._format_uri_with_extra(secret, conn_string)
-
-        return connection
+        return self._format_uri_with_extra(secret, conn_string)
 
     def get_conn_value(self, conn_id: str):
         """
@@ -193,20 +191,19 @@ def get_conn_value(self, conn_id: str):
 
         if self.full_url_mode:
             return self._get_secret(self.connections_prefix, conn_id)
-        else:
-            try:
-                secret_string = self._get_secret(self.connections_prefix, conn_id)
-                # json.loads gives error
-                secret = ast.literal_eval(secret_string) if secret_string else None
-            except ValueError:  # 'malformed node or string: ' error, for empty conns
-                connection = None
-                secret = None
-
-            # These lines will check if we have with some denomination stored an username, password and host
-            if secret:
-                connection = self.get_uri_from_secret(secret)
-
-            return connection
+        try:
+            secret_string = self._get_secret(self.connections_prefix, conn_id)
+            # json.loads gives error
+            secret = ast.literal_eval(secret_string) if secret_string else None
+        except ValueError:  # 'malformed node or string: ' error, for empty conns
+            connection = None
+            secret = None
+
+        # These lines will check if we have with some denomination stored an username, password and host
+        if secret:
+            connection = self.get_uri_from_secret(secret)
+
+        return connection
 
     def get_conn_uri(self, conn_id: str) -> Optional[str]:
         """

airflow/providers/amazon/aws/secrets/systems_manager.py

Lines changed: 1 addition & 2 deletions
@@ -158,8 +158,7 @@ def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]:
         ssm_path = self.build_path(path_prefix, secret_id)
         try:
             response = self.client.get_parameter(Name=ssm_path, WithDecryption=True)
-            value = response["Parameter"]["Value"]
-            return value
+            return response["Parameter"]["Value"]
         except self.client.exceptions.ParameterNotFound:
             self.log.debug("Parameter %s not found.", ssm_path)
             return None

airflow/providers/amazon/aws/sensors/emr.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def get_hook(self) -> EmrHook:
     def poke(self, context: 'Context'):
         response = self.get_emr_response()
 
-        if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
             self.log.info('Bad HTTP response: %s', response)
             return False
 

airflow/providers/amazon/aws/sensors/glue.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def poke(self, context: 'Context'):
             self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state)
             return True
         elif job_state in self.errored_states:
-            job_error_message = "Exiting Job " + self.run_id + " Run State: " + job_state
+            job_error_message = f"Exiting Job {self.run_id} Run State: {job_state}"
             raise AirflowException(job_error_message)
         else:
             return False

airflow/providers/amazon/aws/sensors/sagemaker.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ def get_hook(self) -> SageMakerHook:
 
     def poke(self, context: 'Context'):
         response = self.get_sagemaker_response()
-        if not (response['ResponseMetadata']['HTTPStatusCode'] == 200):
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
             self.log.info('Bad HTTP response: %s', response)
             return False
         state = self.state_from_response(response)
@@ -225,7 +225,7 @@ def init_log_resource(self, hook: SageMakerHook) -> None:
         self.instance_count = description['ResourceConfig']['InstanceCount']
         status = description['TrainingJobStatus']
         job_already_completed = status not in self.non_terminal_states()
-        self.state = LogState.TAILING if (not job_already_completed) else LogState.COMPLETE
+        self.state = LogState.COMPLETE if job_already_completed else LogState.TAILING
         self.last_description = description
         self.last_describe_job_call = time.monotonic()
         self.log_resource_inited = True
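The sensor changes here and in sensors/emr.py are pure readability: x != 200 replaces not x == 200, and the conditional expression is flipped so the condition is stated positively. A tiny equivalence check, with a hypothetical Enum standing in for the hook's LogState constants:

from enum import Enum

class LogState(Enum):  # hypothetical stand-in for the SageMaker hook's LogState
    TAILING = 1
    COMPLETE = 2

for status_code in (200, 403, 500):
    assert (not status_code == 200) == (status_code != 200)

for job_already_completed in (True, False):
    before = LogState.TAILING if (not job_already_completed) else LogState.COMPLETE
    after = LogState.COMPLETE if job_already_completed else LogState.TAILING
    assert before is after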

airflow/providers/amazon/aws/transfers/google_api_to_s3.py

Lines changed: 1 addition & 2 deletions
@@ -155,13 +155,12 @@ def _retrieve_data_from_google_api(self) -> dict:
             api_version=self.google_api_service_version,
             impersonation_chain=self.google_impersonation_chain,
         )
-        google_api_response = google_discovery_api_hook.query(
+        return google_discovery_api_hook.query(
             endpoint=self.google_api_endpoint_path,
             data=self.google_api_endpoint_params,
             paginate=self.google_api_pagination,
             num_retries=self.google_api_num_retries,
         )
-        return google_api_response
 
     def _load_data_to_s3(self, data: dict) -> None:
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)

airflow/providers/amazon/aws/transfers/mysql_to_s3.py

Lines changed: 2 additions & 3 deletions
@@ -66,8 +66,7 @@ def __init__(
             if "header" not in pd_kwargs:
                 pd_kwargs["header"] = header
             kwargs["pd_kwargs"] = {**kwargs.get('pd_kwargs', {}), **pd_kwargs}
-        else:
-            if pd_csv_kwargs is not None:
-                raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'")
+        elif pd_csv_kwargs is not None:
+            raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'")
 
         super().__init__(sql_conn_id=mysql_conn_id, **kwargs)
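elif folds the nested if into the else branch, removing one indentation level without changing which inputs raise. A reduced sketch under that assumption (file_format and pd_csv_kwargs are the operator's names, the surrounding logic is trimmed):

def check_kwargs(file_format: str, pd_csv_kwargs=None) -> None:
    if file_format == 'csv':
        pass  # csv-specific handling elided in this sketch
    elif pd_csv_kwargs is not None:
        # Same inputs raise as with the old "else:" + nested "if".
        raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'")

check_kwargs('csv', {'header': True})   # accepted
check_kwargs('parquet')                 # accepted
try:
    check_kwargs('parquet', {'header': True})
except TypeError as err:
    print(err)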

airflow/providers/amazon/aws/utils/eks_get_token.py

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,7 @@
 
 import argparse
 import json
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 
 from airflow.providers.amazon.aws.hooks.eks import EksHook
 
@@ -27,7 +27,8 @@
 
 
 def get_expiration_time():
-    token_expiration = datetime.utcnow() + timedelta(minutes=TOKEN_EXPIRATION_MINUTES)
+    token_expiration = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRATION_MINUTES)
+
     return token_expiration.strftime('%Y-%m-%dT%H:%M:%SZ')
 
 
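datetime.utcnow() returns a naive datetime, while datetime.now(timezone.utc) returns an aware one; both describe the current UTC wall-clock time, so the formatted expiry string comes out the same. A small sketch of the difference, with TOKEN_EXPIRATION_MINUTES assumed to be 14 purely for illustration:

from datetime import datetime, timedelta, timezone

TOKEN_EXPIRATION_MINUTES = 14  # assumed value, standing in for the module constant

naive = datetime.utcnow() + timedelta(minutes=TOKEN_EXPIRATION_MINUTES)
aware = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRATION_MINUTES)

assert naive.tzinfo is None           # old: no timezone attached
assert aware.tzinfo == timezone.utc   # new: explicitly UTC

# Both produce the same formatted expiry string.
print(aware.strftime('%Y-%m-%dT%H:%M:%SZ'))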

tests/providers/amazon/aws/hooks/test_cloud_formation.py

Lines changed: 1 addition & 1 deletion
@@ -102,4 +102,4 @@ def test_delete_stack(self):
 
         stacks = self.hook.get_conn().describe_stacks()['Stacks']
         matching_stacks = [x for x in stacks if x['StackName'] == stack_name]
-        assert len(matching_stacks) == 0, f'stack with name {stack_name} should not exist'
+        assert not matching_stacks, f'stack with name {stack_name} should not exist'
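An empty list is falsy, so asserting not matching_stacks is equivalent to the length check and reads as "there are no matching stacks". A minimal illustration with a hypothetical stack listing:

stacks = [{'StackName': 'other-stack'}]   # hypothetical describe_stacks() output
stack_name = 'deleted-stack'

matching_stacks = [x for x in stacks if x['StackName'] == stack_name]

# The two assertions accept and reject exactly the same lists.
assert (len(matching_stacks) == 0) == (not matching_stacks)
assert not matching_stacks, f'stack with name {stack_name} should not exist'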

tests/providers/amazon/aws/log/test_s3_task_handler.py

Lines changed: 2 additions & 3 deletions
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import contextlib
 import os
 import unittest
 from unittest import mock
@@ -75,10 +76,8 @@ def setUp(self):
 
     def tearDown(self):
         if self.s3_task_handler.handler:
-            try:
+            with contextlib.suppress(Exception):
                 os.remove(self.s3_task_handler.handler.baseFilename)
-            except Exception:
-                pass
 
     def test_hook(self):
        assert isinstance(self.s3_task_handler.hook, S3Hook)
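contextlib.suppress(Exception) is the standard-library replacement for a try/except/pass block: any exception raised inside the with body is swallowed. A short sketch with a throwaway path (the filename is made up for illustration):

import contextlib
import os

missing_file = '/tmp/does-not-exist-anywhere.log'  # hypothetical path

# Before: explicit try/except with an empty handler.
try:
    os.remove(missing_file)
except Exception:
    pass

# After: same effect, no empty except block.
with contextlib.suppress(Exception):
    os.remove(missing_file)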

tests/providers/amazon/aws/operators/test_athena.py

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@ def setUp(self):
             'start_date': DEFAULT_DATE,
         }
 
-        self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args, schedule_interval='@once')
+        self.dag = DAG(f'{TEST_DAG_ID}test_schedule_dag_once', default_args=args, schedule_interval='@once')
+
         self.athena = AthenaOperator(
             task_id='test_athena_operator',
             query='SELECT * FROM TEST_TABLE',

tests/providers/amazon/aws/operators/test_dms_describe_tasks.py

Lines changed: 1 addition & 5 deletions
@@ -57,11 +57,7 @@ def setUp(self):
             "start_date": DEFAULT_DATE,
         }
 
-        self.dag = DAG(
-            TEST_DAG_ID + "test_schedule_dag_once",
-            default_args=args,
-            schedule_interval="@once",
-        )
+        self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", default_args=args, schedule_interval="@once")
 
     def test_init(self):
         dms_operator = DmsDescribeTasksOperator(

Comments (0)