Content-Length: 456318 | pFad | https://github.com/apache/airflow/commit/1982c3fdca1f04cfc41fc5b5e285d8f01c6b76ab

2F Run Dataflow for ML Engine summary in venv (#7809) · apache/airflow@1982c3f · GitHub
Skip to content

Commit 1982c3f

Browse files
authored
Run Dataflow for ML Engine summary in venv (#7809)
1 parent 0c6af43 commit 1982c3f

File tree

3 files changed

+28
-22
lines changed

3 files changed

+28
-22
lines changed

airflow/providers/google/cloud/utils/mlengine_operator_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,20 +225,22 @@ def validate_err_and_count(summary):
225225
metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode()
226226
evaluate_summary = DataflowCreatePythonJobOperator(
227227
task_id=(task_prefix + "-summary"),
228-
py_options=["-m"],
229-
py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary",
228+
py_file=os.path.join(os.path.dirname(__file__), 'mlengine_prediction_summary.py'),
230229
dataflow_default_options=dataflow_options,
231230
options={
232231
"prediction_path": prediction_path,
233232
"metric_fn_encoded": metric_fn_encoded,
234233
"metric_keys": ','.join(metric_keys)
235234
},
236235
py_interpreter=py_interpreter,
236+
py_requirements=[
237+
'apache-beam[gcp]>=2.14.0'
238+
],
237239
dag=dag)
238240
evaluate_summary.set_upstream(evaluate_prediction)
239241

240-
def apply_validate_fn(*args, **kwargs):
241-
prediction_path = kwargs["templates_dict"]["prediction_path"]
242+
def apply_validate_fn(*args, templates_dict, **kwargs):
243+
prediction_path = templates_dict["prediction_path"]
242244
scheme, bucket, obj, _, _ = urlsplit(prediction_path)
243245
if scheme != "gs" or not bucket or not obj:
244246
raise ValueError("Wrong format prediction_path: {}".format(prediction_path))

airflow/providers/google/cloud/utils/mlengine_prediction_summary.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def metric_fn(inst):
8585
import argparse
8686
import base64
8787
import json
88+
import logging
8889
import os
8990

9091
import apache_beam as beam
@@ -156,23 +157,24 @@ def run(argv=None):
156157
raise ValueError("--metric_fn_encoded must be an encoded callable.")
157158
metric_keys = known_args.metric_keys.split(",")
158159

159-
with beam.Pipeline(
160-
options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe:
161-
# This is apache-beam ptransform's convention
160+
with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe:
162161
# pylint: disable=no-value-for-parameter
163-
_ = (pipe
164-
| "ReadPredictionResult" >> beam.io.ReadFromText(
165-
os.path.join(known_args.prediction_path,
166-
"prediction.results-*-of-*"),
167-
coder=JsonCoder())
168-
| "Summary" >> MakeSummary(metric_fn, metric_keys)
169-
| "Write" >> beam.io.WriteToText(
170-
os.path.join(known_args.prediction_path,
171-
"prediction.summary.json"),
172-
shard_name_template='', # without trailing -NNNNN-of-NNNNN.
173-
coder=JsonCoder()))
174-
# pylint: enable=no-value-for-parameter
162+
prediction_result_pattern = os.path.join(known_args.prediction_path, "prediction.results-*-of-*")
163+
prediction_summary_path = os.path.join(known_args.prediction_path, "prediction.summary.json")
164+
# This is apache-beam ptransform's convention
165+
_ = (
166+
pipe | "ReadPredictionResult" >> beam.io.ReadFromText(
167+
prediction_result_pattern, coder=JsonCoder())
168+
| "Summary" >> MakeSummary(metric_fn, metric_keys)
169+
| "Write" >> beam.io.WriteToText(
170+
prediction_summary_path,
171+
shard_name_template='', # without trailing -NNNNN-of-NNNNN.
172+
coder=JsonCoder())
173+
)
175174

176175

177176
if __name__ == "__main__":
177+
# Dataflow does not print anything on the screen by default. Good practice says to configure the logger
178+
# to be able to track the progress. This code is run in a separate process, so it's safe.
179+
logging.getLogger().setLevel(logging.INFO)
178180
run()

tests/providers/google/cloud/operators/test_mlengine_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import unittest
2020
from unittest.mock import ANY, patch
2121

22+
import mock
23+
2224
from airflow.exceptions import AirflowException
2325
from airflow.models.dag import DAG
2426
from airflow.providers.google.cloud.utils import mlengine_operator_utils
@@ -110,10 +112,10 @@ def test_successful_run(self):
110112
'metric_keys': 'err',
111113
'metric_fn_encoded': self.metric_fn_encoded,
112114
},
113-
dataflow='airflow.providers.google.cloud.utils.mlengine_prediction_summary',
114-
py_options=['-m'],
115+
dataflow=mock.ANY,
116+
py_options=[],
117+
py_requirements=['apache-beam[gcp]>=2.14.0'],
115118
py_interpreter='python3',
116-
py_requirements=[],
117119
py_system_site_packages=False,
118120
on_new_job_id_callback=ANY
119121
)

0 commit comments

Comments
 (0)








ApplySandwichStrip

pFad - (p)hone/(F)rame/(a)nonymizer/(d)eclutterfier!      Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

Fetched URL: https://github.com/apache/airflow/commit/1982c3fdca1f04cfc41fc5b5e285d8f01c6b76ab

Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy