Skip to content

Commit 1982c3f

Browse files
authored
Run Dataflow for ML Engine summary in venv (#7809)
1 parent 0c6af43 commit 1982c3f

File tree

3 files changed

+28
-22
lines changed

3 files changed

+28
-22
lines changed

airflow/providers/google/cloud/utils/mlengine_operator_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,20 +225,22 @@ def validate_err_and_count(summary):
225225
metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode()
226226
evaluate_summary = DataflowCreatePythonJobOperator(
227227
task_id=(task_prefix + "-summary"),
228-
py_options=["-m"],
229-
py_file="airflow.providers.google.cloud.utils.mlengine_prediction_summary",
228+
py_file=os.path.join(os.path.dirname(__file__), 'mlengine_prediction_summary.py'),
230229
dataflow_default_options=dataflow_options,
231230
options={
232231
"prediction_path": prediction_path,
233232
"metric_fn_encoded": metric_fn_encoded,
234233
"metric_keys": ','.join(metric_keys)
235234
},
236235
py_interpreter=py_interpreter,
236+
py_requirements=[
237+
'apache-beam[gcp]>=2.14.0'
238+
],
237239
dag=dag)
238240
evaluate_summary.set_upstream(evaluate_prediction)
239241

240-
def apply_validate_fn(*args, **kwargs):
241-
prediction_path = kwargs["templates_dict"]["prediction_path"]
242+
def apply_validate_fn(*args, templates_dict, **kwargs):
243+
prediction_path = templates_dict["prediction_path"]
242244
scheme, bucket, obj, _, _ = urlsplit(prediction_path)
243245
if scheme != "gs" or not bucket or not obj:
244246
raise ValueError("Wrong format prediction_path: {}".format(prediction_path))

airflow/providers/google/cloud/utils/mlengine_prediction_summary.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def metric_fn(inst):
8585
import argparse
8686
import base64
8787
import json
88+
import logging
8889
import os
8990

9091
import apache_beam as beam
@@ -156,23 +157,24 @@ def run(argv=None):
156157
raise ValueError("--metric_fn_encoded must be an encoded callable.")
157158
metric_keys = known_args.metric_keys.split(",")
158159

159-
with beam.Pipeline(
160-
options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe:
161-
# This is apache-beam ptransform's convention
160+
with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as pipe:
162161
# pylint: disable=no-value-for-parameter
163-
_ = (pipe
164-
| "ReadPredictionResult" >> beam.io.ReadFromText(
165-
os.path.join(known_args.prediction_path,
166-
"prediction.results-*-of-*"),
167-
coder=JsonCoder())
168-
| "Summary" >> MakeSummary(metric_fn, metric_keys)
169-
| "Write" >> beam.io.WriteToText(
170-
os.path.join(known_args.prediction_path,
171-
"prediction.summary.json"),
172-
shard_name_template='', # without trailing -NNNNN-of-NNNNN.
173-
coder=JsonCoder()))
174-
# pylint: enable=no-value-for-parameter
162+
prediction_result_pattern = os.path.join(known_args.prediction_path, "prediction.results-*-of-*")
163+
prediction_summary_path = os.path.join(known_args.prediction_path, "prediction.summary.json")
164+
# This is apache-beam ptransform's convention
165+
_ = (
166+
pipe | "ReadPredictionResult" >> beam.io.ReadFromText(
167+
prediction_result_pattern, coder=JsonCoder())
168+
| "Summary" >> MakeSummary(metric_fn, metric_keys)
169+
| "Write" >> beam.io.WriteToText(
170+
prediction_summary_path,
171+
shard_name_template='', # without trailing -NNNNN-of-NNNNN.
172+
coder=JsonCoder())
173+
)
175174

176175

177176
if __name__ == "__main__":
177+
# Dataflow does not print anything on the screen by default. Good practice says to configure the logger
178+
# to be able to track the progress. This code is run in a separate process, so it's safe.
179+
logging.getLogger().setLevel(logging.INFO)
178180
run()

tests/providers/google/cloud/operators/test_mlengine_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import unittest
2020
from unittest.mock import ANY, patch
2121

22+
import mock
23+
2224
from airflow.exceptions import AirflowException
2325
from airflow.models.dag import DAG
2426
from airflow.providers.google.cloud.utils import mlengine_operator_utils
@@ -110,10 +112,10 @@ def test_successful_run(self):
110112
'metric_keys': 'err',
111113
'metric_fn_encoded': self.metric_fn_encoded,
112114
},
113-
dataflow='airflow.providers.google.cloud.utils.mlengine_prediction_summary',
114-
py_options=['-m'],
115+
dataflow=mock.ANY,
116+
py_options=[],
117+
py_requirements=['apache-beam[gcp]>=2.14.0'],
115118
py_interpreter='python3',
116-
py_requirements=[],
117119
py_system_site_packages=False,
118120
on_new_job_id_callback=ANY
119121
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy