Skip to content

Commit

Permalink
Merge pull request #32141 Add basic testing for yaml join docs.
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Aug 19, 2024
2 parents 6b4a7a5 + 7f1c7f4 commit 38dfbd4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
39 changes: 33 additions & 6 deletions sdks/python/apache_beam/yaml/readme_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def expand(self, inputs):

def guess_name_and_type(expr):
expr = expr.strip().replace('`', '')
if expr.endswith('*'):
return 'unknown', str
parts = expr.split()
if len(parts) >= 2 and parts[-2].lower() == 'as':
name = parts[-1]
Expand Down Expand Up @@ -87,7 +89,7 @@ def guess_name_and_type(expr):
return name, typ

if m.group(1) == '*':
return inputs['PCOLLECTION'] | beam.Filter(lambda _: True)
return next(iter(inputs.values())) | beam.Filter(lambda _: True)
else:
output_schema = [
guess_name_and_type(expr) for expr in m.group(1).split(',')
Expand Down Expand Up @@ -269,6 +271,22 @@ def test(self):

def parse_test_methods(markdown_lines):
# pylint: disable=too-many-nested-blocks

def extract_inputs(input_spec):
if not input_spec:
return set()
elif isinstance(input_spec, str):
return set([input_spec.split('.')[0]])
elif isinstance(input_spec, list):
return set.union(*[extract_inputs(v) for v in input_spec])
elif isinstance(input_spec, dict):
return set.union(*[extract_inputs(v) for v in input_spec.values()])
else:
raise ValueError("Misformed inputs: " + input_spec)

def extract_name(input_spec):
return input_spec.get('name', input_spec.get('type'))

code_lines = None
for ix, line in enumerate(markdown_lines):
line = line.rstrip()
Expand All @@ -280,17 +298,23 @@ def parse_test_methods(markdown_lines):
else:
if code_lines:
if code_lines[0].startswith('- type:'):
is_chain = not any('input:' in line for line in code_lines)
specs = yaml.load('\n'.join(code_lines), Loader=SafeLoader)
is_chain = not any('input' in spec for spec in specs)
if is_chain:
undefined_inputs = set(['input'])
else:
undefined_inputs = set.union(
*[extract_inputs(spec.get('input')) for spec in specs]) - set(
extract_name(spec) for spec in specs)
# Treat this as a fragment of a larger pipeline.
# pylint: disable=not-an-iterable
code_lines = [
'pipeline:',
' type: chain' if is_chain else '',
' transforms:',
' - type: ReadFromCsv',
' name: input',
' config:',
' path: whatever',
] + [
' - {type: ReadFromCsv, name: "%s", config: {path: x}}' %
undefined_input for undefined_input in undefined_inputs
] + [' ' + line for line in code_lines]
if code_lines[0] == 'pipeline:':
yaml_pipeline = '\n'.join(code_lines)
Expand Down Expand Up @@ -329,6 +353,9 @@ def createTestSuite(name, path):
InlinePythonTest = createTestSuite(
'InlinePythonTest', os.path.join(YAML_DOCS_DIR, 'yaml-inline-python.md'))

JoinTest = createTestSuite(
'JoinTest', os.path.join(YAML_DOCS_DIR, 'yaml-join.md'))

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--render_dir', default=None)
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def _is_connected(edge_list, expected_node_count):
def _SqlJoinTransform(
pcolls,
sql_transform_constructor,
type: Union[str, Dict[str, List]],
*,
equalities: Union[str, List[Dict[str, str]]],
type: Union[str, Dict[str, List]] = 'inner',
fields: Optional[Dict[str, Any]] = None):
"""Joins two or more inputs using a specified condition.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ inputs, one can use the following shorthand syntax:
input2: Second Input
input3: Third Input
config:
equalities: col
equalities: col1
```

## Join Types
Expand Down

0 comments on commit 38dfbd4

Please sign in to comment.