Skip to content

Commit

Permalink
Use Beam YAML main.py (#1709)
Browse files Browse the repository at this point in the history
Signed-off-by: Jeffrey Kinard <[email protected]>
  • Loading branch information
Polber authored Jul 16, 2024
1 parent 6deb448 commit 580805c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 80 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/java-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ on:
paths:
- '**.java'
- '**.xml'
# Include python files and Dockerfiles used for YAML and xlang templates.
- '**.py'
- 'plugins/core-plugin/src/main/resources/**'
# Include relevant GitHub Action files for running these checks.
# This will make it easier to verify action changes don't break anything.
- '.github/actions/setup-env/*'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ RUN python -m venv /venv \
&& rm -rf /usr/local/lib/python$PY_VERSION/site-packages \
&& mv /venv/lib/python$PY_VERSION/site-packages /usr/local/lib/python$PY_VERSION/

# Cache provider environments for faster startup and expansion time
RUN python -m apache_beam.yaml.cache_provider_artifacts


#============================================================#
# Create Distroless xlang image compatible with YamlTemplate #
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public interface YAMLTemplate {
name = "yaml_pipeline",
optional = true,
description = "Input YAML pipeline spec.",
hiddenUi = true,
helpText = "A yaml description of the pipeline to run.")
String getYamlPipeline();

Expand Down
83 changes: 3 additions & 80 deletions python/src/main/python/yaml-template/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,10 @@
import argparse
import json

import jinja2
import yaml

import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
from apache_beam.typehints.schemas import LogicalType
from apache_beam.typehints.schemas import MillisInstant
from apache_beam.yaml import cache_provider_artifacts
from apache_beam.yaml import yaml_transform

# Workaround for https://github.com/apache/beam/issues/28151.
LogicalType.register_logical_type(MillisInstant)
from apache_beam.yaml import main


# TODO(polber) - remove _preparse_jinja_flags with Beam 2.58.0
def _preparse_jinja_flags(argv):
"""Promotes any flags to --jinja_variables based on --jinja_variable_flags.
This is to facilitate tools (such as dataflow templates) that must pass
Expand Down Expand Up @@ -69,78 +59,11 @@ def _preparse_jinja_flags(argv):
return pipeline_args


def _configure_parser(argv):
parser = argparse.ArgumentParser()
parser.add_argument(
'--yaml_pipeline',
'--pipeline_spec',
help='A yaml description of the pipeline to run.')
parser.add_argument(
'--yaml_pipeline_file',
'--pipeline_spec_file',
help='A file containing a yaml description of the pipeline to run.')
parser.add_argument(
'--json_schema_validation',
default='generic',
help='none: do no pipeline validation against the schema; '
'generic: validate the pipeline shape, but not individual transforms; '
'per_transform: also validate the config of known transforms')
parser.add_argument(
'--jinja_variables',
default=None,
type=json.loads,
help='A json dict of variables used when invoking the jinja preprocessor '
'on the provided yaml pipeline.')
return parser.parse_known_args(argv)


def _pipeline_spec_from_args(known_args):
if known_args.yaml_pipeline_file and known_args.yaml_pipeline:
raise ValueError(
"Exactly one of yaml_pipeline or yaml_pipeline_file must be set.")
elif known_args.yaml_pipeline_file:
with FileSystems.open(known_args.yaml_pipeline_file) as fin:
pipeline_yaml = fin.read().decode()
elif known_args.yaml_pipeline:
pipeline_yaml = known_args.yaml_pipeline
else:
raise ValueError(
"Exactly one of yaml_pipeline or yaml_pipeline_file must be set.")

return pipeline_yaml


class _BeamFileIOLoader(jinja2.BaseLoader):
def get_source(self, environment, path):
source = FileSystems.open(path).read().decode()
return source, path, lambda: True


def run(argv=None):
argv = _preparse_jinja_flags(argv)
known_args, pipeline_args = _configure_parser(argv)
pipeline_yaml = ( # keep formatting
jinja2.Environment(
undefined=jinja2.StrictUndefined, loader=_BeamFileIOLoader())
.from_string(_pipeline_spec_from_args(known_args))
.render(**known_args.jinja_variables or {}))
pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)

with beam.Pipeline( # linebreak for better yapf formatting
options=beam.options.pipeline_options.PipelineOptions(
pipeline_args,
pickle_library='cloudpickle',
**yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
'options', {}))),
display_data={'yaml': pipeline_yaml}) as p:
print("Building pipeline...")
yaml_transform.expand_pipeline(
p, pipeline_spec, validate_schema=known_args.json_schema_validation)
print("Running pipeline...")
main.run(argv=_preparse_jinja_flags(argv))


if __name__ == '__main__':
import logging
logging.getLogger().setLevel(logging.INFO)
cache_provider_artifacts.cache_provider_artifacts()
run()

0 comments on commit 580805c

Please sign in to comment.