Skip to content

Commit

Permalink
Allow ib.collect(...) to take multiple PCollections.
Browse files Browse the repository at this point in the history
It is often more efficient to do a single run computing multiple
collections at a time than to do multiple runs.
  • Loading branch information
robertwb committed Sep 5, 2024
1 parent 50a6cd2 commit d5dc200
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 38 deletions.
96 changes: 58 additions & 38 deletions sdks/python/apache_beam/runners/interactive/interactive_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,11 +876,12 @@ def show(

@progress_indicated
def collect(
pcoll,
*pcolls,
n='inf',
duration='inf',
include_window_info=False,
force_compute=False):
force_compute=False,
force_tuple=False):
"""Materializes the elements from a PCollection into a Dataframe.
This reads each element from file and reads only the amount that it needs
Expand All @@ -889,13 +890,16 @@ def collect(
it is assumed to be infinite.
Args:
pcolls: PCollections to compute.
n: (optional) max number of elements to visualize. Default 'inf'.
duration: (optional) max duration of elements to read in integer seconds or
a string duration. Default 'inf'.
include_window_info: (optional) if True, appends the windowing information
to each row. Default False.
force_compute: (optional) if True, forces recomputation rather than using
cached PCollections
force_tuple: (optional) if True, return a 1-tuple or results rather than
the bare results if only one PCollection is computed
For example::
Expand All @@ -906,17 +910,27 @@ def collect(
# Run the pipeline and bring the PCollection into memory as a Dataframe.
in_memory_square = head(square, n=5)
"""
# Remember the element type so we can make an informed decision on how to
# collect the result in elements_to_df.
if isinstance(pcoll, DeferredBase):
# Get the proxy so we can get the output shape of the DataFrame.
pcoll, element_type = deferred_df_to_pcollection(pcoll)
watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
else:
element_type = pcoll.element_type
if len(pcolls) == 0:
return ()

def as_pcollection(pcoll_or_df):
if isinstance(pcoll_or_df, DeferredBase):
# Get the proxy so we can get the output shape of the DataFrame.
pcoll, element_type = deferred_df_to_pcollection(pcoll_or_df)
watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
return pcoll, element_type
elif isinstance(pcoll_or_df, beam.pvalue.PCollection):
return pcoll_or_df, pcoll_or_df.element_type
else:
raise TypeError(f'{pcoll} is not an apache_beam.pvalue.PCollection.')

assert isinstance(pcoll, beam.pvalue.PCollection), (
'{} is not an apache_beam.pvalue.PCollection.'.format(pcoll))
pcolls_with_element_types = [as_pcollection(p) for p in pcolls]
pcolls_to_element_types = dict(pcolls_with_element_types)
pcolls = [pcoll for pcoll, _ in pcolls_with_element_types]
pipelines = set(pcoll.pipeline for pcoll in pcolls)
if len(pipelines) != 1:
raise ValueError('All PCollections must belong to the same pipeline.')
pipeline, = pipelines

if isinstance(n, str):
assert n == 'inf', (
Expand All @@ -935,45 +949,51 @@ def collect(
if duration == 'inf':
duration = float('inf')

user_pipeline = ie.current_env().user_pipeline(pcoll.pipeline)
user_pipeline = ie.current_env().user_pipeline(pipeline)
# Possibly collecting a PCollection defined in a local scope that is not
# explicitly watched. Ad hoc watch it though it's a little late.
if not user_pipeline:
watch({'anonymous_pipeline_{}'.format(id(pcoll.pipeline)): pcoll.pipeline})
user_pipeline = pcoll.pipeline
watch({'anonymous_pipeline_{}'.format(id(pipeline)): pipeline})
user_pipeline = pipeline
recording_manager = ie.current_env().get_recording_manager(
user_pipeline, create_if_absent=True)

# If already computed, directly read the stream and return.
if pcoll in ie.current_env().computed_pcollections and not force_compute:
pcoll_name = find_pcoll_name(pcoll)
elements = list(
recording_manager.read(pcoll_name, pcoll, n, duration).read())
return elements_to_df(
elements,
include_window_info=include_window_info,
element_type=element_type)

recording = recording_manager.record([pcoll],
max_n=n,
max_duration=duration,
force_compute=force_compute)

try:
elements = list(recording.stream(pcoll).read())
except KeyboardInterrupt:
recording.cancel()
return pd.DataFrame()
computed = {}
for pcoll in pcolls_to_element_types.keys():
if pcoll in ie.current_env().computed_pcollections and not force_compute:
pcoll_name = find_pcoll_name(pcoll)
computed[pcoll] = list(
recording_manager.read(pcoll_name, pcoll, n, duration).read())

uncomputed = set(pcolls) - set(computed.keys())
if uncomputed:
recording = recording_manager.record(
uncomputed, max_n=n, max_duration=duration, force_compute=force_compute)

try:
for pcoll in uncomputed:
computed[pcoll] = list(recording.stream(pcoll).read())
except KeyboardInterrupt:
recording.cancel()

if n == float('inf'):
n = None

# Collecting DataFrames may have a length > n, so slice again to be sure. Note
# that array[:None] returns everything.
return elements_to_df(
elements,
include_window_info=include_window_info,
element_type=element_type)[:n]
empty = pd.DataFrame()
result_tuple = tuple(
elements_to_df(
computed[pcoll],
include_window_info=include_window_info,
element_type=pcolls_to_element_types[pcoll])[:n] if pcoll in
computed else empty for pcoll in pcolls)

if len(result_tuple) == 1 and not force_tuple:
return result_tuple[0]
else:
return result_tuple


@progress_indicated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,50 @@ def test_basic(self):
self.assertEqual(set(collected2[0]), set(['A', 'B', 'C']))
self.assertEqual(count_side_effects('a'), 2)

@unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
def test_multiple_collect(self):
clear_side_effect()
p = beam.Pipeline(direct_runner.DirectRunner())

# Initial collection runs the pipeline.
pcollA = p | 'A' >> beam.Create(['a']) | 'As' >> beam.Map(cause_side_effect)
pcollB = p | 'B' >> beam.Create(['b']) | 'Bs' >> beam.Map(cause_side_effect)
collectedA, collectedB = ib.collect(pcollA, pcollB)
self.assertEqual(set(collectedA[0]), set(['a']))
self.assertEqual(set(collectedB[0]), set(['b']))
self.assertEqual(count_side_effects('a'), 1)
self.assertEqual(count_side_effects('b'), 1)

# Collecting the PCollection again uses the cache.
collectedA, collectedB = ib.collect(pcollA, pcollB)
self.assertEqual(set(collectedA[0]), set(['a']))
self.assertEqual(set(collectedB[0]), set(['b']))
self.assertEqual(count_side_effects('a'), 1)
self.assertEqual(count_side_effects('b'), 1)

# Using the PCollection uses the cache.
pcollAA = pcollA | beam.Map(
lambda x: 2 * x) | 'AAs' >> beam.Map(cause_side_effect)
collectedA, collectedB, collectedAA = ib.collect(pcollA, pcollB, pcollAA)
self.assertEqual(set(collectedA[0]), set(['a']))
self.assertEqual(set(collectedB[0]), set(['b']))
self.assertEqual(set(collectedAA[0]), set(['aa']))
self.assertEqual(count_side_effects('a'), 1)
self.assertEqual(count_side_effects('b'), 1)
self.assertEqual(count_side_effects('aa'), 1)

# Duplicates are only computed once.
pcollBB = pcollB | beam.Map(
lambda x: 2 * x) | 'BBs' >> beam.Map(cause_side_effect)
collectedAA, collectedAAagain, collectedBB, collectedBBagain = ib.collect(
pcollAA, pcollAA, pcollBB, pcollBB)
self.assertEqual(set(collectedAA[0]), set(['aa']))
self.assertEqual(set(collectedAAagain[0]), set(['aa']))
self.assertEqual(set(collectedBB[0]), set(['bb']))
self.assertEqual(set(collectedBBagain[0]), set(['bb']))
self.assertEqual(count_side_effects('aa'), 1)
self.assertEqual(count_side_effects('bb'), 1)

@unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
def test_wordcount(self):
class WordExtractingDoFn(beam.DoFn):
Expand Down

0 comments on commit d5dc200

Please sign in to comment.