diff --git a/thicket/__init__.py b/thicket/__init__.py
index d5fd06e1..f0f076c5 100644
--- a/thicket/__init__.py
+++ b/thicket/__init__.py
@@ -6,6 +6,7 @@
 # make flake8 unused names in this file.
 # flake8: noqa: F401
 
+from .ensemble import Ensemble
 from .thicket import Thicket
 from .thicket import InvalidFilter
 from .thicket import EmptyMetadataTable
diff --git a/thicket/ensemble.py b/thicket/ensemble.py
new file mode 100644
index 00000000..62f38e5a
--- /dev/null
+++ b/thicket/ensemble.py
@@ -0,0 +1,400 @@
+# Copyright 2022 Lawrence Livermore National Security, LLC and other
+# Thicket Project Developers. See the top-level LICENSE file for details.
+#
+# SPDX-License-Identifier: MIT
+
+from collections import OrderedDict
+
+from hatchet import GraphFrame
+import numpy as np
+import pandas as pd
+
+import thicket.helpers as helpers
+from .utils import verify_sorted_profile, verify_thicket_structures
+
+
+class Ensemble:
+    """Operations pertaining to ensembling."""
+
+    @staticmethod
+    def _unify(thickets, inplace=False):
+        """Create a union graph from a list of thickets and sync their DataFrames.
+
+        Arguments:
+            thickets (list): list of Thicket objects
+            inplace (bool): whether to modify the original Thicket objects or operate on copies
+
+        Returns:
+            (tuple): tuple containing:
+                (hatchet.Graph): unified graph
+                (list): list of Thicket objects
+        """
+        _thickets = thickets
+        if not inplace:
+            _thickets = [th.deepcopy() for th in thickets]
+        # Unify graphs if "self" and "other" do not have the same graph
+        union_graph = _thickets[0].graph
+        for i in range(len(_thickets) - 1):
+            # Check for same graph id (fast); if not, check for equality (slow)
+            if (
+                _thickets[i].graph is _thickets[i + 1].graph
+                or _thickets[i].graph == _thickets[i + 1].graph
+            ):
+                continue
+            else:
+                union_graph = union_graph.union(_thickets[i + 1].graph)
+        for i in range(len(_thickets)):
+            # Set all graphs to the union graph
+            _thickets[i].graph = union_graph
+            # Necessary to change dataframe hatchet id's to match the nodes in the graph
+            helpers._sync_nodes_frame(union_graph, _thickets[i].dataframe)
+            # For tree diff, dataframes need to be sorted.
+            _thickets[i].dataframe.sort_index(inplace=True)
+        return union_graph, _thickets
+
+    @staticmethod
+    def _columns(
+        thickets,
+        headers=None,
+        metadata_key=None,
+    ):
+        """Concatenate Thicket attributes horizontally. For DataFrames, this implies
+        expanding in the column direction. A new column multi-index will be created
+        with columns under separate indexer headers.
+
+        Arguments:
+            thickets (list): List of Thicket objects to concatenate
+            headers (list): List of headers to use for the new columnar multi-index
+            metadata_key (str): Name of the column from the metadata tables to replace the
+                'profile' index. If no argument is provided, it is assumed that there is no
+                profile-wise relationship between the thickets.
+ + Returns: + (Thicket): New ensembled Thicket object + """ + + def _check_structures(): + """Check that the structures of the thicket objects are valid for the incoming operations.""" + # Required/expected format of the data + for th in thickets: + verify_thicket_structures(th.dataframe, index=["node", "profile"]) + verify_thicket_structures(th.statsframe.dataframe, index=["node"]) + verify_thicket_structures(th.metadata, index=["profile"]) + # Check for metadata_key in metadata + if metadata_key: + for th in thickets: + verify_thicket_structures(th.metadata, columns=[metadata_key]) + # Check length of profiles match + for i in range(len(thickets) - 1): + if len(thickets[i].profile) != len(thickets[i + 1].profile): + raise ValueError( + "Length of all thicket profiles must match. {} != {}".format( + len(thickets[i].profile), len(thickets[i + 1].profile) + ) + ) + # Ensure all thickets profiles are sorted. Must be true when metadata_key=None to + # guarantee performance data table and metadata table match up. + if metadata_key is None: + for th in thickets: + verify_sorted_profile(th.dataframe) + verify_sorted_profile(th.metadata) + + def _create_multiindex_columns(df, upper_idx_name): + """Helper function to create multi-index column names from a dataframe's + current columns. + + Arguments: + df (DataFrame): source dataframe + upper_idx_name (String): name of the newly added index in the multi-index. + Prepended before each column as a tuple. + + Returns: + (list): list of new indicies generated from the source dataframe + """ + new_idx = [] + for column in df.columns: + new_tuple = (upper_idx_name, column) + new_idx.append(new_tuple) + return new_idx + + def _handle_metadata(): + """Handle operations to create new concatenated columnar axis metadata table.""" + # Update index to reflect performance data table index + for i in range(len(thickets_cp)): + thickets_cp[i].metadata.reset_index(drop=True, inplace=True) + if metadata_key is None: + for i in range(len(thickets_cp)): + thickets_cp[i].metadata.index.set_names("profile", inplace=True) + else: + for i in range(len(thickets_cp)): + thickets_cp[i].metadata.set_index(metadata_key, inplace=True) + thickets_cp[i].metadata.sort_index(inplace=True) + + # Create multi-index columns + for i in range(len(thickets_cp)): + thickets_cp[i].metadata.columns = pd.MultiIndex.from_tuples( + _create_multiindex_columns(thickets_cp[i].metadata, headers[i]) + ) + + # Concat metadata together + combined_th.metadata = pd.concat( + [thickets_cp[i].metadata for i in range(len(thickets_cp))], + axis="columns", + ) + + def _handle_misc(): + """Misceallaneous Thicket object operations.""" + for i in range(1, len(thickets_cp)): + combined_th.profile += thickets_cp[i].profile # Update "profile" object + combined_th.profile_mapping.update( + thickets_cp[i].profile_mapping + ) # Update "profile_mapping" object + combined_th.profile = [new_mappings[prf] for prf in combined_th.profile] + profile_mapping_cp = combined_th.profile_mapping.copy() + for k, v in profile_mapping_cp.items(): + combined_th.profile_mapping[ + new_mappings[k] + ] = combined_th.profile_mapping.pop(k) + combined_th.performance_cols = helpers._get_perf_columns( + combined_th.dataframe + ) + + def _handle_perfdata(): + """Handle operations to create new concatenated columnar axis performance data table. 
+ + Returns: + (dict): dictionary mapping old profiles to new profiles + """ + # Create header list if not provided + nonlocal headers + if headers is None: + headers = [i for i in range(len(thickets))] + + # Update index to reflect performance data table index + new_mappings = {} # Dictionary mapping old profiles to new profiles + if metadata_key is None: # Create index from scratch + new_profiles = [i for i in range(len(thickets_cp[0].profile))] + for i in range(len(thickets_cp)): + thickets_cp[i].metadata["new_profiles"] = new_profiles + thickets_cp[i].add_column_from_metadata_to_ensemble( + "new_profiles", drop=True + ) + thickets_cp[i].dataframe.reset_index(level="profile", inplace=True) + new_mappings.update( + pd.Series( + thickets_cp[i] + .dataframe["new_profiles"] + .map(lambda x: (x, headers[i])) + .values, + index=thickets_cp[i].dataframe["profile"], + ).to_dict() + ) + thickets_cp[i].dataframe.drop("profile", axis=1, inplace=True) + thickets_cp[i].dataframe.set_index( + "new_profiles", append=True, inplace=True + ) + thickets_cp[i].dataframe.index.rename( + "profile", level="new_profiles", inplace=True + ) + else: # Change second-level index to be from metadata's "metadata_key" column + for i in range(len(thickets_cp)): + thickets_cp[i].add_column_from_metadata_to_ensemble(metadata_key) + thickets_cp[i].dataframe.reset_index(level="profile", inplace=True) + new_mappings.update( + pd.Series( + thickets_cp[i] + .dataframe[metadata_key] + .map(lambda x: (x, headers[i])) + .values, + index=thickets_cp[i].dataframe["profile"], + ).to_dict() + ) + thickets_cp[i].dataframe.drop("profile", axis=1, inplace=True) + thickets_cp[i].dataframe.set_index( + metadata_key, append=True, inplace=True + ) + thickets_cp[i].dataframe.sort_index(inplace=True) + + # Create tuple columns + new_columns = [ + _create_multiindex_columns(th.dataframe, headers[i]) + for i, th in enumerate(thickets_cp) + ] + # Clear old metrics (non-tuple) + combined_th.exc_metrics.clear() + combined_th.inc_metrics.clear() + # Update inc/exc metrics + for i in range(len(new_columns)): + for col_tuple in new_columns[i]: + if col_tuple[1] in thickets_cp[i].exc_metrics: + combined_th.exc_metrics.append(col_tuple) + if col_tuple[1] in thickets_cp[i].inc_metrics: + combined_th.inc_metrics.append(col_tuple) + # Update columns + for i in range(len(thickets_cp)): + thickets_cp[i].dataframe.columns = pd.MultiIndex.from_tuples( + new_columns[i] + ) + + # Concat performance data table together + combined_th.dataframe = pd.concat( + [thickets_cp[i].dataframe for i in range(len(thickets_cp))], + axis="columns", + ) + + # Extract "name" columns to upper level + nodes = list(set(combined_th.dataframe.reset_index()["node"])) + for node in nodes: + combined_th.dataframe.loc[node, "name"] = node.frame["name"] + combined_th.dataframe.drop( + columns=[(headers[i], "name") for i in range(len(headers))], + inplace=True, + ) + + # Sort DataFrame + combined_th.dataframe.sort_index(inplace=True) + + return new_mappings + + def _handle_statsframe(): + """Handle operations to create new concatenated columnar axis aggregated statistics table.""" + # Clear aggregated statistics table + combined_th.statsframe = GraphFrame( + graph=combined_th.graph, + dataframe=helpers._new_statsframe_df( + combined_th.dataframe, multiindex=True + ), + ) + + # Step 0A: Pre-check of data structures + _check_structures() + # Step 0B: Variable Initialization + combined_th = thickets[0].deepcopy() + thickets_cp = [th.deepcopy() for th in thickets] + + # Step 1: Unify the 
thickets + union_graph, _thickets = Ensemble._unify(thickets_cp) + combined_th.graph = union_graph + thickets_cp = _thickets + + # Step 2A: Handle performance data tables + new_mappings = _handle_perfdata() + # Step 2B: Handle metadata tables + _handle_metadata() + # Step 2C: Handle statistics table + _handle_statsframe() + # Step 2D: Handle other Thicket objects. + _handle_misc() + + return combined_th + + @staticmethod + def _index(thickets, superthicket=False): + """Unify a list of thickets into a single thicket + + Arguments: + superthicket (bool): whether the result is a superthicket + + Returns: + unify_graph (hatchet.Graph): unified graph, + unify_df (DataFrame): unified dataframe, + unify_exc_metrics (list): exclusive metrics, + unify_inc_metrics (list): inclusive metrics, + unify_metadata (DataFrame): unified metadata, + unify_profile (list): profiles, + unify_profile_mapping (dict): profile mapping + """ + + def _fill_perfdata(perfdata, fill_value=np.nan): + # Fill missing rows in dataframe with NaN's + perfdata = perfdata.reindex( + pd.MultiIndex.from_product(perfdata.index.levels), fill_value=fill_value + ) + # Replace "NaN" with "None" in columns of string type + for col in perfdata.columns: + if pd.api.types.is_string_dtype(perfdata[col].dtype): + perfdata[col].replace({fill_value: None}, inplace=True) + + return perfdata + + def _superthicket_metadata(metadata): + """Aggregate data in Metadata""" + + def _agg_to_set(obj): + """Aggregate values in 'obj' into a set to remove duplicates.""" + if len(obj) <= 1: + return obj + else: + _set = set(obj) + # If len == 1 just use the value, otherwise return the set + if len(_set) == 1: + return _set.pop() + else: + return _set + + # Rename index to "thicket" + metadata.index.rename("thicket", inplace=True) + # Execute aggregation + metadata = metadata.groupby("thicket").agg(_agg_to_set) + + # Add missing indicies to thickets + helpers._resolve_missing_indicies(thickets) + + # Initialize attributes + unify_graph = None + unify_df = pd.DataFrame() + unify_inc_metrics = [] + unify_exc_metrics = [] + unify_metadata = pd.DataFrame() + unify_profile = [] + unify_profile_mapping = OrderedDict() + + # Unification + unify_graph, thickets = Ensemble._unify(thickets) + for th in thickets: + # Extend metrics + unify_inc_metrics.extend(th.inc_metrics) + unify_exc_metrics.extend(th.exc_metrics) + # Extend metadata + if len(th.metadata) > 0: + curr_meta = th.metadata.copy() + unify_metadata = pd.concat([curr_meta, unify_metadata]) + # Extend profile + if th.profile is not None: + unify_profile.extend(th.profile) + # Extend profile mapping + if th.profile_mapping is not None: + unify_profile_mapping.update(th.profile_mapping) + # Extend dataframe + unify_df = pd.concat([th.dataframe, unify_df]) + # Sort by keys + unify_profile_mapping = OrderedDict(sorted(unify_profile_mapping.items())) + + # Insert missing rows in dataframe + unify_df = _fill_perfdata(unify_df) + + # Metadata-specific operations + if superthicket: + _superthicket_metadata(unify_metadata) + + # Sort PerfData + unify_df.sort_index(inplace=True) + # Sort Metadata + unify_metadata.sort_index(inplace=True) + + # Remove duplicates in metrics + unify_inc_metrics = list(set(unify_inc_metrics)) + unify_exc_metrics = list(set(unify_exc_metrics)) + + # Workaround for graph/df node id mismatch. 
+ helpers._sync_nodes(unify_graph, unify_df) + + unify_parts = ( + unify_graph, + unify_df, + unify_exc_metrics, + unify_inc_metrics, + unify_metadata, + unify_profile, + unify_profile_mapping, + ) + return unify_parts diff --git a/thicket/helpers.py b/thicket/helpers.py index 07fee911..290d8753 100644 --- a/thicket/helpers.py +++ b/thicket/helpers.py @@ -126,6 +126,8 @@ def _resolve_missing_indicies(th_list): def _sync_nodes(gh, df): """Set the node objects to be equal in both the graph and the dataframe. + Operations: (n tree nodes) X (m df nodes) X (m) + id(graph_node) == id(df_node) after this function for nodes with equivalent hatchet nid's. """ diff --git a/thicket/tests/conftest.py b/thicket/tests/conftest.py index d7379096..d09f3333 100644 --- a/thicket/tests/conftest.py +++ b/thicket/tests/conftest.py @@ -12,8 +12,8 @@ @pytest.fixture -def columnar_join_thicket(mpi_scaling_cali, rajaperf_basecuda_xl_cali): - """Generator for 'columnar_join' thicket. +def thicket_axis_columns(mpi_scaling_cali, rajaperf_basecuda_xl_cali): + """Generator for 'concat_thickets(axis="columns")' thicket. Arguments: mpi_scaling_cali (list): List of Caliper files for MPI scaling study. @@ -21,7 +21,7 @@ def columnar_join_thicket(mpi_scaling_cali, rajaperf_basecuda_xl_cali): Returns: list: List of original thickets, list of deepcopies of original thickets, and - columnar-joined thicket. + column-joined thicket. """ th_mpi_1 = Thicket.from_caliperreader(mpi_scaling_cali[0:2]) th_mpi_2 = Thicket.from_caliperreader(mpi_scaling_cali[2:4]) @@ -39,29 +39,29 @@ def columnar_join_thicket(mpi_scaling_cali, rajaperf_basecuda_xl_cali): th_mpi_2_deep = th_mpi_2.deepcopy() th_cuda128_deep = th_cuda128.deepcopy() - thicket_list = [th_mpi_1, th_mpi_2, th_cuda128] - thicket_list_cp = [th_mpi_1_deep, th_mpi_2_deep, th_cuda128_deep] + thickets = [th_mpi_1, th_mpi_2, th_cuda128] + thickets_cp = [th_mpi_1_deep, th_mpi_2_deep, th_cuda128_deep] - combined_th = Thicket.columnar_join( - thicket_list=thicket_list, - header_list=["MPI1", "MPI2", "Cuda128"], - column_name="ProblemSize", + combined_th = Thicket.concat_thickets( + thickets=thickets, + axis="columns", + headers=["MPI1", "MPI2", "Cuda128"], + metadata_key="ProblemSize", ) - return thicket_list, thicket_list_cp, combined_th + return thickets, thickets_cp, combined_th @pytest.fixture -def stats_columnar_join_thicket(rajaperf_basecuda_xl_cali): - """Generator for 'columnar_join' thicket for test_stats.py. +def stats_thicket_axis_columns(rajaperf_basecuda_xl_cali): + """Generator for 'concat_thickets(axis="columns")' thicket for test_stats.py. Arguments: - mpi_scaling_cali (list): List of Caliper files for MPI scaling study. rajaperf_basecuda_xl_cali (list): List of Caliper files for base cuda variant. Returns: list: List of original thickets, list of deepcopies of original thickets, and - columnar-joined thicket. + column-joined thicket. 
""" th_cuda128_1 = Thicket.from_caliperreader(rajaperf_basecuda_xl_cali[0:4]) th_cuda128_2 = Thicket.from_caliperreader(rajaperf_basecuda_xl_cali[5:9]) @@ -69,15 +69,16 @@ def stats_columnar_join_thicket(rajaperf_basecuda_xl_cali): # To check later if modifications were unexpectedly made th_cuda128_1_deep = th_cuda128_1.deepcopy() th_cuda128_2_deep = th_cuda128_2.deepcopy() - thicket_list = [th_cuda128_1, th_cuda128_2] - thicket_list_cp = [th_cuda128_1_deep, th_cuda128_2_deep] + thickets = [th_cuda128_1, th_cuda128_2] + thickets_cp = [th_cuda128_1_deep, th_cuda128_2_deep] - combined_th = Thicket.columnar_join( - thicket_list=thicket_list, - header_list=["Cuda 1", "Cuda 2"], + combined_th = Thicket.concat_thickets( + thickets=thickets, + axis="columns", + headers=["Cuda 1", "Cuda 2"], ) - return thicket_list, thicket_list_cp, combined_th + return thickets, thickets_cp, combined_th @pytest.fixture diff --git a/thicket/tests/test_columnar_join.py b/thicket/tests/test_concat_thickets.py similarity index 64% rename from thicket/tests/test_columnar_join.py rename to thicket/tests/test_concat_thickets.py index 82da873c..aad0a8ae 100644 --- a/thicket/tests/test_columnar_join.py +++ b/thicket/tests/test_concat_thickets.py @@ -12,40 +12,58 @@ from test_filter_metadata import filter_multiple_and from test_filter_stats import check_filter_stats from test_query import check_query +from thicket import Thicket -def test_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_concat_thickets_index(mpi_scaling_cali): + th_27 = Thicket.from_caliperreader(mpi_scaling_cali[0]) + th_64 = Thicket.from_caliperreader(mpi_scaling_cali[1]) + + tk = Thicket.concat_thickets([th_27, th_64]) + + # Check dataframe shape + tk.dataframe.shape == (90, 7) + + # Check that the two Thickets are equivalent + assert tk + + # Check specific values. Row order can vary so use "sum" to check + node = tk.dataframe.index.get_level_values("node")[8] + assert sum(tk.dataframe.loc[node, "Min time/rank"]) == 0.000453 + + +def test_concat_thickets_columns(thicket_axis_columns): + thickets, thickets_cp, combined_th = thicket_axis_columns # Check no original objects modified - for i in range(len(thicket_list)): - assert thicket_list[i].dataframe.equals(thicket_list_cp[i].dataframe) - assert thicket_list[i].metadata.equals(thicket_list_cp[i].metadata) + for i in range(len(thickets)): + assert thickets[i].dataframe.equals(thickets_cp[i].dataframe) + assert thickets[i].metadata.equals(thickets_cp[i].metadata) # Check dataframe shape. Should be columnar-joined assert combined_th.dataframe.shape[0] <= sum( - [th.dataframe.shape[0] for th in thicket_list] + [th.dataframe.shape[0] for th in thickets] ) # Rows. Should be <= because some rows will exist across multiple thickets. assert ( combined_th.dataframe.shape[1] - == sum([th.dataframe.shape[1] for th in thicket_list]) - len(thicket_list) + 1 + == sum([th.dataframe.shape[1] for th in thickets]) - len(thickets) + 1 ) # Columns. (-1) for each name column removed, (+1) singular name column created. # Check metadata shape. Should be columnar-joined assert combined_th.metadata.shape[0] == max( - [th.metadata.shape[0] for th in thicket_list] + [th.metadata.shape[0] for th in thickets] ) # Rows. Should be max because all rows should exist in all thickets. assert combined_th.metadata.shape[1] == sum( - [th.metadata.shape[1] for th in thicket_list] + [th.metadata.shape[1] for th in thickets] ) - len( - thicket_list + thickets ) # Columns. 
(-1) Since we added an additional column "ProblemSize". # Check profiles - assert len(combined_th.profile) == sum([len(th.profile) for th in thicket_list]) + assert len(combined_th.profile) == sum([len(th.profile) for th in thickets]) # Check profile_mapping assert len(combined_th.profile_mapping) == sum( - [len(th.profile_mapping) for th in thicket_list] + [len(th.profile_mapping) for th in thickets] ) # PerfData and StatsFrame nodes should be in the same order. @@ -55,8 +73,8 @@ def test_columnar_join(columnar_join_thicket): ).all() -def test_filter_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_filter_concat_thickets_columns(thicket_axis_columns): + thickets, thickets_cp, combined_th = thicket_axis_columns # columns and corresponding values to filter by columns_values = { ("MPI1", "mpi.world.size"): [27], @@ -67,8 +85,8 @@ def test_filter_columnar_join(columnar_join_thicket): filter_multiple_and(combined_th, columns_values) -def test_filter_stats_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_filter_stats_concat_thickets_columns(thicket_axis_columns): + thickets, thickets_cp, combined_th = thicket_axis_columns # columns and corresponding values to filter by columns_values = { ("test", "test_string_column"): ["less than 20"], @@ -86,8 +104,8 @@ def test_filter_stats_columnar_join(columnar_join_thicket): check_filter_stats(combined_th, columns_values) -def test_query_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_query_concat_thickets_columns(thicket_axis_columns): + thickets, thickets_cp, combined_th = thicket_axis_columns # test arguments hnids = [0, 1, 2, 3, 5, 6, 8, 9] query = ( diff --git a/thicket/tests/test_groupby.py b/thicket/tests/test_groupby.py index 1163ade7..850efce0 100644 --- a/thicket/tests/test_groupby.py +++ b/thicket/tests/test_groupby.py @@ -6,7 +6,7 @@ import pytest from thicket import Thicket, EmptyMetadataTable -from test_columnar_join import test_columnar_join +from test_concat_thickets import test_concat_thickets_columns from utils import check_identity @@ -89,7 +89,7 @@ def test_groupby(example_cali): check_groupby(th, columns_values) -def test_groupby_columnar_join(example_cali): +def test_groupby_concat_thickets_columns(example_cali): """Tests case where the Sub-Thickets of a groupby are used in a columnar join""" # example thicket th = Thicket.from_caliperreader(example_cali) @@ -106,23 +106,24 @@ def test_groupby_columnar_join(example_cali): th_list[2].metadata[selected_column] = problem_size th_list[3].metadata[selected_column] = problem_size - thicket_list = [th_list[0], th_list[1], th_list[2], th_list[3]] - thicket_list_cp = [ + thickets = [th_list[0], th_list[1], th_list[2], th_list[3]] + thickets_cp = [ th_list[0].deepcopy(), th_list[1].deepcopy(), th_list[2].deepcopy(), th_list[3].deepcopy(), ] - combined_th = Thicket.columnar_join( - thicket_list=thicket_list, - column_name=selected_column, + combined_th = Thicket.concat_thickets( + thickets=thickets, + axis="columns", + metadata_key=selected_column, ) - test_columnar_join((thicket_list, thicket_list_cp, combined_th)) + test_concat_thickets_columns((thickets, thickets_cp, combined_th)) -def test_groupby_columnar_join_subthickets(example_cali): +def test_groupby_concat_thickets_columns_subthickets(example_cali): """Tests case where some specific Sub-Thickets of a groupby are used in a columnar 
join""" # example thicket th = Thicket.from_caliperreader(example_cali) @@ -137,15 +138,16 @@ def test_groupby_columnar_join_subthickets(example_cali): th_list[0].metadata[selected_column] = problem_size th_list[1].metadata[selected_column] = problem_size - thicket_list = [th_list[0], th_list[1]] - thicket_list_cp = [ + thickets = [th_list[0], th_list[1]] + thickets_cp = [ th_list[0].deepcopy(), th_list[1].deepcopy(), ] - combined_th = Thicket.columnar_join( - thicket_list=thicket_list, - column_name=selected_column, + combined_th = Thicket.concat_thickets( + thickets=thickets, + axis="columns", + metadata_key=selected_column, ) - test_columnar_join((thicket_list, thicket_list_cp, combined_th)) + test_concat_thickets_columns((thickets, thickets_cp, combined_th)) diff --git a/thicket/tests/test_intersection.py b/thicket/tests/test_intersection.py index af8a6007..98c34961 100644 --- a/thicket/tests/test_intersection.py +++ b/thicket/tests/test_intersection.py @@ -12,6 +12,11 @@ def test_intersection(example_cali): intersected_th = th_ens.intersection() + intersected_th_other = th.from_caliperreader(example_cali, intersection=True) + + # Check other methodology + assert len(intersected_th.graph) == len(intersected_th_other.graph) + # Check original and intersected thickets assert len(th_ens.dataframe) == 344 assert len(intersected_th.dataframe) == 4 diff --git a/thicket/tests/test_stats.py b/thicket/tests/test_stats.py index 360eb89a..5d68f17e 100644 --- a/thicket/tests/test_stats.py +++ b/thicket/tests/test_stats.py @@ -25,8 +25,8 @@ def test_mean(example_cali): assert "Min time/rank_mean" in th_ens.statsframe.show_metric_columns() -def test_mean_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_mean_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = combined_th.dataframe.columns.levels[0][0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values @@ -63,8 +63,8 @@ def test_median(example_cali): assert "Min time/rank_median" in th_ens.statsframe.show_metric_columns() -def test_median_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_median_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = combined_th.dataframe.columns.levels[0][0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values @@ -101,8 +101,8 @@ def test_minimum(example_cali): assert "Min time/rank_min" in th_ens.statsframe.show_metric_columns() -def test_minimum_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_minimum_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = combined_th.dataframe.columns.levels[0][0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values @@ -139,8 +139,8 @@ def test_maximum(example_cali): assert "Min time/rank_max" in th_ens.statsframe.show_metric_columns() -def test_maximum_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_maximum_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = 
combined_th.dataframe.columns.levels[0][0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values @@ -177,8 +177,8 @@ def test_std(example_cali): assert "Min time/rank_std" in th_ens.statsframe.show_metric_columns() -def test_std_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_std_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = combined_th.dataframe.columns.levels[0][0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values @@ -216,8 +216,8 @@ def test_percentiles(example_cali): assert "Min time/rank_percentiles" in th_ens.statsframe.show_metric_columns() -def test_percentiles_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_percentiles_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = combined_th.dataframe.columns.levels[0][0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values @@ -260,8 +260,8 @@ def test_variance(example_cali): assert "Min time/rank_var" in th_ens.statsframe.show_metric_columns() -def test_variance_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_variance_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = combined_th.dataframe.columns.levels[0][0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values @@ -303,9 +303,9 @@ def test_normality(rajaperf_basecuda_xl_cali): assert "Min time/rank_normality" in th_ens.statsframe.show_metric_columns() -def test_normality_columnar_join(columnar_join_thicket, stats_columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket - sthicket_list, sthicket_list_cp, scombined_th = stats_columnar_join_thicket +def test_normality_columnar_join(thicket_axis_columns, stats_thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns + sthicket_list, sthicket_list_cp, scombined_th = stats_thicket_axis_columns # new data must be added before uncommenting, need 3 or more datapoints # idx = combined_th.dataframe.columns.levels[0][0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( @@ -352,8 +352,8 @@ def test_correlation(rajaperf_basecuda_xl_cali): ) -def test_correlation_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_correlation_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = combined_th.dataframe.columns.levels[0] assert sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values @@ -458,8 +458,8 @@ def test_boxplot(example_cali): ) -def test_boxplot_columnar_join(columnar_join_thicket): - thicket_list, thicket_list_cp, combined_th = columnar_join_thicket +def test_boxplot_columnar_join(thicket_axis_columns): + thicket_list, thicket_list_cp, combined_th = thicket_axis_columns idx = combined_th.dataframe.columns.levels[0][0] assert 
sorted(combined_th.dataframe.index.get_level_values(0).unique()) == sorted( combined_th.statsframe.dataframe.index.values diff --git a/thicket/tests/test_thicket.py b/thicket/tests/test_thicket.py index 5f6ccaa5..98e56b86 100644 --- a/thicket/tests/test_thicket.py +++ b/thicket/tests/test_thicket.py @@ -76,7 +76,7 @@ def _test_multiindex(): """Test statsframe when headers are multiindexed.""" th1 = Thicket.from_caliperreader(example_cali[0]) th2 = Thicket.from_caliperreader(example_cali[1]) - th_cj = Thicket.columnar_join([th1, th2]) + th_cj = Thicket.concat_thickets([th1, th2], axis="columns") # Check column format assert ("name", "") in th_cj.statsframe.dataframe.columns @@ -145,24 +145,6 @@ def test_thicketize_graphframe(example_cali): assert ht1.dataframe.equals(th1.dataframe) -def test_unify_ensemble(mpi_scaling_cali): - th_27 = Thicket.from_caliperreader(mpi_scaling_cali[0]) - th_64 = Thicket.from_caliperreader(mpi_scaling_cali[1]) - - th_listwise = Thicket.unify_ensemble([th_27, th_64]) - th_pairwise = Thicket.unify_ensemble([th_27, th_64], pairwise=True) - - # Check dataframe shape - th_listwise.dataframe.shape == (90, 7) - - # Check that the two Thickets are equivalent - assert th_listwise == th_pairwise - - # Check specific values. Row order can vary so use "sum" to check - node = th_listwise.dataframe.index.get_level_values("node")[8] - assert sum(th_listwise.dataframe.loc[node, "Min time/rank"]) == 0.000453 - - def test_unique_metadata_base_cuda(rajaperf_basecuda_xl_cali): t_ens = Thicket.from_caliperreader(rajaperf_basecuda_xl_cali) diff --git a/thicket/thicket.py b/thicket/thicket.py index 643650aa..f10d989e 100644 --- a/thicket/thicket.py +++ b/thicket/thicket.py @@ -16,8 +16,8 @@ from hatchet import GraphFrame from hatchet.query import AbstractQuery, QueryMatcher +from thicket.ensemble import Ensemble import thicket.helpers as helpers -from .utils import verify_sorted_profile from .utils import verify_thicket_structures from .external.console import ThicketRenderer @@ -249,12 +249,90 @@ def reader_dispatch(func, intersection=False, *args, **kwargs): + "' is not a valid type to be read from." ) - # Perform unify ensemble - thicket_object = Thicket.unify_ensemble(ens_list) + # Perform ensembling operation + calltree = "union" if intersection: - thicket_object = thicket_object.intersection() + calltree = "intersection" + thicket_object = Thicket.concat_thickets( + thickets=ens_list, + axis="index", + calltree=calltree, + ) + return thicket_object + @staticmethod + def concat_thickets(thickets, axis="index", calltree="union", **kwargs): + """Concatenate thickets together on index or columns. The calltree can either be unioned or + intersected which will affect the other structures. + + Arguments: + thickets (list): list of thicket objects + axis (str): axis to concatenate on -> "index" or "column" + calltree (str): calltree to use -> "union" or "intersection" + + valid kwargs: + if axis="index": + superthicket (bool): Whether the result is a superthicket + if axis="columns": + headers (list): List of headers to use for the new columnar multi-index. + metadata_key (str): Name of the column from the metadata tables to replace the 'profile' + index. If no argument is provided, it is assumed that there is no profile-wise + relationship between the thickets. 
+ + Returns: + (thicket): concatenated thicket + """ + + def _index(thickets, superthicket=False): + thicket_parts = Ensemble._index( + thickets=thickets, superthicket=superthicket + ) + + return Thicket( + graph=thicket_parts[0], + dataframe=thicket_parts[1], + exc_metrics=thicket_parts[2], + inc_metrics=thicket_parts[3], + metadata=thicket_parts[4], + profile=thicket_parts[5], + profile_mapping=thicket_parts[6], + ) + + def _columns(thickets, headers=None, metadata_key=None): + combined_thicket = Ensemble._columns( + thickets=thickets, headers=headers, metadata_key=metadata_key + ) + + return combined_thicket + + if calltree not in ["union", "intersection"]: + raise ValueError("calltree must be 'union' or 'intersection'") + + if axis == "index": + ct = _index(thickets, **kwargs) + elif axis == "columns": + ct = _columns(thickets, **kwargs) + else: + raise ValueError("axis must be 'index' or 'columns'") + + if calltree == "intersection": + ct = ct.intersection() + + return ct + + @staticmethod + def columnar_join(thicket_list, header_list=None, metadata_key=None): + raise ValueError( + "columnar_join is deprecated. Use 'concat_thickets(axis='columns'...)' instead." + ) + + @staticmethod + def unify_ensemble(th_list, superthicket=False): + raise ValueError( + "unify_ensemble is deprecated. Use 'concat_thickets(axis='index'...)' instead." + ) + @staticmethod def from_json(json_thicket): # deserialize the json @@ -300,272 +378,37 @@ def from_json(json_thicket): # make and return thicket? return th - @staticmethod - def columnar_join( - thicket_list, - header_list=None, - column_name=None, - ): - """Join Thickets column-wise. New column multi-index will be created with - columns under separate indexer headers. - - Arguments: - thicket_list (list): List of Thickets to join - header_list (list): List of headers to use for the new columnar multi-index - column_name (str): Name of the column from the metadata table to join on. If - no argument is provided, it is assumed that there is no profile-wise - relationship between self and other. - - Returns: - (Thicket): New Thicket object with joined columns - """ - - def _create_multiindex_columns(df, upper_idx_name): - """Helper function to create multi-index column names from a dataframe's - current columns. - - Arguments: - df (DataFrame): source dataframe - upper_idx_name (String): name of the newly added index in the multi-index. - Prepended before each column as a tuple. - - Returns: - (list): list of new indicies generated from the source dataframe - """ - new_idx = [] - for column in df.columns: - new_tuple = (upper_idx_name, column) - new_idx.append(new_tuple) - return new_idx - - ### - # Step 0A: Pre-check of data structures - ### - # Required/expected format of the data - for th in thicket_list: - verify_thicket_structures(th.dataframe, index=["node", "profile"]) - verify_thicket_structures(th.statsframe.dataframe, index=["node"]) - verify_thicket_structures(th.metadata, index=["profile"]) - # Check for column_name in metadata - if column_name: - for th in thicket_list: - verify_thicket_structures(th.metadata, columns=[column_name]) - # Check length of profiles match - for i in range(len(thicket_list) - 1): - if len(thicket_list[i].profile) != len(thicket_list[i + 1].profile): - raise ValueError( - "Length of all thicket profiles must match. {} != {}".format( - len(thicket_list[i].profile), len(thicket_list[i + 1].profile) - ) - ) - # Ensure all thickets profiles are sorted. 
Must be true when column_name=None to - # guarantee performance data table and metadata table match up. - if column_name is None: - for th in thicket_list: - verify_sorted_profile(th.dataframe) - verify_sorted_profile(th.metadata) - - ### - # Step 0B: Variable Initialization - ### - # Initialize combined thicket - combined_th = thicket_list[0].deepcopy() - # Use copies to be non-destructive - thicket_list_cp = [th.deepcopy() for th in thicket_list] - - ### - # Step 1: Unify the thickets - ### - # Unify graphs if "self" and "other" do not have the same graph - union_graph = thicket_list_cp[0].graph - for i in range(len(thicket_list_cp) - 1): - if thicket_list_cp[i].graph != thicket_list_cp[i + 1].graph: - union_graph = union_graph.union(thicket_list_cp[i + 1].graph) - combined_th.graph = union_graph - for i in range(len(thicket_list_cp)): - # Set all graphs to the union graph - thicket_list_cp[i].graph = union_graph - # Necessary to change dataframe hatchet id's to match the nodes in the graph - helpers._sync_nodes_frame(union_graph, thicket_list_cp[i].dataframe) - # For tree diff. dataframes need to be sorted. - thicket_list_cp[i].dataframe.sort_index(inplace=True) - - ### - # Step 2: Join "self" & "other" performance data table - ### - # Create header list if not provided - if header_list is None: - header_list = [i for i in range(len(thicket_list))] - - # Update index to reflect performance data table index - new_mappings = {} # Dictionary mapping old profiles to new profiles - if column_name is None: # Create index from scratch - new_profiles = [i for i in range(len(thicket_list_cp[0].profile))] - for i in range(len(thicket_list_cp)): - thicket_list_cp[i].metadata["new_profiles"] = new_profiles - thicket_list_cp[i].add_column_from_metadata_to_ensemble( - "new_profiles", drop=True - ) - thicket_list_cp[i].dataframe.reset_index(level="profile", inplace=True) - new_mappings.update( - pd.Series( - thicket_list_cp[i] - .dataframe["new_profiles"] - .map(lambda x: (x, header_list[i])) - .values, - index=thicket_list_cp[i].dataframe["profile"], - ).to_dict() - ) - thicket_list_cp[i].dataframe.drop("profile", axis=1, inplace=True) - thicket_list_cp[i].dataframe.set_index( - "new_profiles", append=True, inplace=True - ) - thicket_list_cp[i].dataframe.index.rename( - "profile", level="new_profiles", inplace=True - ) - else: # Change second-level index to be from metadata's "column_name" column - for i in range(len(thicket_list_cp)): - thicket_list_cp[i].add_column_from_metadata_to_ensemble(column_name) - thicket_list_cp[i].dataframe.reset_index(level="profile", inplace=True) - new_mappings.update( - pd.Series( - thicket_list_cp[i] - .dataframe[column_name] - .map(lambda x: (x, header_list[i])) - .values, - index=thicket_list_cp[i].dataframe["profile"], - ).to_dict() - ) - thicket_list_cp[i].dataframe.drop("profile", axis=1, inplace=True) - thicket_list_cp[i].dataframe.set_index( - column_name, append=True, inplace=True - ) - thicket_list_cp[i].dataframe.sort_index(inplace=True) - - # Create tuple columns - new_columns = [ - _create_multiindex_columns(th.dataframe, header_list[i]) - for i, th in enumerate(thicket_list_cp) - ] - # Clear old metrics (non-tuple) - combined_th.exc_metrics.clear() - combined_th.inc_metrics.clear() - # Update inc/exc metrics - for i in range(len(new_columns)): - for col_tuple in new_columns[i]: - if col_tuple[1] in thicket_list_cp[i].exc_metrics: - combined_th.exc_metrics.append(col_tuple) - if col_tuple[1] in thicket_list_cp[i].inc_metrics: - 
combined_th.inc_metrics.append(col_tuple) - # Update columns - for i in range(len(thicket_list_cp)): - thicket_list_cp[i].dataframe.columns = pd.MultiIndex.from_tuples( - new_columns[i] - ) - - # Concat performance data table together - combined_th.dataframe = pd.concat( - [thicket_list_cp[i].dataframe for i in range(len(thicket_list_cp))], - axis="columns", - join="outer", - ) - - # Extract "name" columns to upper level - nodes = list(set(combined_th.dataframe.reset_index()["node"])) - for node in nodes: - combined_th.dataframe.loc[node, "name"] = node.frame["name"] - combined_th.dataframe.drop( - columns=[(header_list[i], "name") for i in range(len(header_list))], - inplace=True, - ) - - # Sort DataFrame - combined_th.dataframe.sort_index(inplace=True) - - ### - # Step 3: Join "self" & "other" metadata table - ### - # Update index to reflect performance data table index - for i in range(len(thicket_list_cp)): - thicket_list_cp[i].metadata.reset_index(drop=True, inplace=True) - if column_name is None: - for i in range(len(thicket_list_cp)): - thicket_list_cp[i].metadata.index.set_names("profile", inplace=True) - else: - for i in range(len(thicket_list_cp)): - thicket_list_cp[i].metadata.set_index(column_name, inplace=True) - thicket_list_cp[i].metadata.sort_index(inplace=True) - - # Create multi-index columns - for i in range(len(thicket_list_cp)): - thicket_list_cp[i].metadata.columns = pd.MultiIndex.from_tuples( - _create_multiindex_columns(thicket_list_cp[i].metadata, header_list[i]) - ) - - # Concat metadata together - combined_th.metadata = pd.concat( - [thicket_list_cp[i].metadata for i in range(len(thicket_list_cp))], - axis="columns", - join="outer", - ) - - ### - # Step 4: Update other Thicket objects. - ### - for i in range(1, len(thicket_list_cp)): - combined_th.profile += thicket_list_cp[i].profile # Update "profile" object - combined_th.profile_mapping.update( - thicket_list_cp[i].profile_mapping - ) # Update "profile_mapping" object - combined_th.profile = [new_mappings[prf] for prf in combined_th.profile] - profile_mapping_cp = combined_th.profile_mapping.copy() - for k, v in profile_mapping_cp.items(): - combined_th.profile_mapping[ - new_mappings[k] - ] = combined_th.profile_mapping.pop(k) - - # Clear aggregated statistics table - combined_th.statsframe = GraphFrame( - graph=combined_th.graph, - dataframe=helpers._new_statsframe_df( - combined_th.dataframe, multiindex=True - ), - ) - combined_th.performance_cols = helpers._get_perf_columns(combined_th.dataframe) - - return combined_th - def add_column_from_metadata_to_ensemble( - self, column_name, overwrite=False, drop=False + self, metadata_key, overwrite=False, drop=False ): """Add a column from the metadata table to the performance data table. Arguments: - column_name (str): Name of the column from the metadata table + metadata_key (str): Name of the column from the metadata table overwrite (bool): Determines overriding behavior in performance data table drop (bool): Whether to drop the column from the metadata table afterwards """ # Add warning if column already exists in performance data table - if column_name in self.dataframe.columns: + if metadata_key in self.dataframe.columns: # Drop column to overwrite, otherwise warn and return if overwrite: - self.dataframe.drop(column_name, axis=1, inplace=True) + self.dataframe.drop(metadata_key, axis=1, inplace=True) else: warnings.warn( "Column " - + column_name + + metadata_key + " already exists. Set 'overwrite=True' to force update the column." 
) return # Add the column to the performance data table self.dataframe = self.dataframe.join( - self.metadata[column_name], on=self.dataframe.index.names[1] + self.metadata[metadata_key], on=self.dataframe.index.names[1] ) # Drop column if drop: - self.metadata.drop(column_name, axis=1, inplace=True) + self.metadata.drop(metadata_key, axis=1, inplace=True) def squash(self, update_inc_cols=True): """Rewrite the Graph to include only nodes present in the performance @@ -753,218 +596,6 @@ def tree( max_value=max_value, ) - def unify_pair(self, other): - """Unify two Thicket's graphs and dataframes""" - # Check for the same object. Cheap operation since no graph walkthrough. - if self.graph is other.graph: - print("same graph (object)") - return self.graph - - # Check for the same graph structure. Need to walk through graphs *but should - # still be less expensive then performing the rest of this function.* - if self.graph == other.graph: - print("same graph (structure)") - return self.graph - - print("different graph") - - node_map = {} - union_graph = self.graph.union(other.graph, node_map) - - self_index_names = self.dataframe.index.names - other_index_names = other.dataframe.index.names - - self.dataframe.reset_index(inplace=True) - other.dataframe.reset_index(inplace=True) - - self.dataframe["node"] = self.dataframe["node"].apply(lambda x: node_map[id(x)]) - other.dataframe["node"] = other.dataframe["node"].apply( - lambda x: node_map[id(x)] - ) - - self.dataframe.set_index(self_index_names, inplace=True) - other.dataframe.set_index(other_index_names, inplace=True) - - self.graph = union_graph - other.graph = union_graph - - return union_graph - - @staticmethod - def unify_pairwise(th_list, debug=False): - """Unifies two thickets graphs and dataframes. - - Ensure self and other have the same graph and same node IDs. This may change the - node IDs in the dataframe. - - Update the graphs in the graphframe if they differ. - - Arguments: - th_list (list): list of Thicket objects - debug (bool): print debug statements - - Returns: - union_graph (Graph): unified graph - """ - union_graph = th_list[0].graph - for i in range(len(th_list)): - for j in range(i + 1, len(th_list)): - if debug: - print("Unifying (" + str(i) + ", " + str(j) + "...") - union_graph = th_list[i].unify_pair(th_list[j]) - return union_graph - - @staticmethod - def unify_listwise(th_list, debug=False): - """Unify a list of Thicket's graphs and dataframes - - Arguments: - th_list (list): list of Thicket objects - debug (bool): print debug statements - - Returns: - union_graph (Graph): unified graph - """ - # variable to keep track of case where all graphs are the same - same_graphs = True - - # GRAPH UNIFICATION - union_graph = th_list[0].graph - for i in range(1, len(th_list)): # n-1 unions - # Check to skip unnecessary computation. apply short circuiting with 'or'. - if union_graph is th_list[i].graph or union_graph == th_list[i].graph: - if debug: - print("Union Graph == thicket[" + str(i) + "].graph") - else: - if debug: - print("Unifying (Union Graph, " + str(i) + ")") - same_graphs = False - # Unify graph with current thickets graph - union_graph = union_graph.union(th_list[i].graph) - - # If the graphs were all the same in the first place then there is no need to - # apply any node mappings. - if same_graphs: - return union_graph - - # DATAFRAME MAPPING UPDATE - for i in range(len(th_list)): # n ops - node_map = {} - # Create a node map from current thickets graph to the union graph. 
This is - # only valid once the union graph is complete. - union_graph.union(th_list[i].graph, node_map) - names = th_list[i].dataframe.index.names - th_list[i].dataframe.reset_index(inplace=True) - - # Apply node_map mapping - th_list[i].dataframe["node"] = ( - th_list[i].dataframe["node"].apply(lambda node: node_map[id(node)]) - ) - th_list[i].dataframe.set_index(names, inplace=True, drop=True) - - # After this point the graph and dataframe in each thicket is out of sync. - # We could update the graph element in thicket to be the union graph but if the - # user prints out the graph how do we annotate nodes only contained in one - # thicket. - return union_graph - - @staticmethod - def unify_ensemble(th_list, pairwise=False, superthicket=False): - """Unify a list of thickets into a single thicket - - Arguments: - th_list (list): list of thickets - pairwise (bool): use the pairwise implementation of unify (use if having - issues) - superthicket (bool): whether the result is a superthicket - - Returns: - (thicket): unified thicket - """ - unify_graph = None - if pairwise: - unify_graph = Thicket.unify_pairwise(th_list) - else: - unify_graph = Thicket.unify_listwise(th_list) - - helpers._resolve_missing_indicies(th_list) - - # Unify dataframe - unify_df = pd.DataFrame() - unify_inc_metrics = [] - unify_exc_metrics = [] - unify_metadata = pd.DataFrame() - unify_profile = [] - unify_profile_mapping = {} - - # Unification loop - for i, th in enumerate(th_list): - unify_inc_metrics.extend(th.inc_metrics) - unify_exc_metrics.extend(th.exc_metrics) - if len(th.metadata) > 0: - curr_meta = th.metadata.copy() - unify_metadata = pd.concat([curr_meta, unify_metadata]) - if th.profile is not None: - unify_profile.extend(th.profile) - if th.profile_mapping is not None: - unify_profile_mapping.update(th.profile_mapping) - unify_df = pd.concat([th.dataframe, unify_df]) - - # Fill missing rows in dataframe with NaN's - fill_value = np.nan - unify_df = unify_df.reindex( - pd.MultiIndex.from_product(unify_df.index.levels), fill_value=fill_value - ) - # Replace NaN with None in string columns - for col in unify_df.columns: - if pd.api.types.is_string_dtype(unify_df[col].dtype): - unify_df[col].replace({fill_value: None}, inplace=True) - - # Operations specific to a superthicket - if superthicket: - unify_metadata.index.rename("thicket", inplace=True) - - # Process to aggregate rows of thickets with the same name. - def _agg_function(obj): - """Aggregate values in 'obj' into a set to remove duplicates.""" - if len(obj) <= 1: - return obj - else: - _set = set(obj) - if len(_set) == 1: - return _set.pop() - else: - return _set - - unify_metadata = unify_metadata.groupby("thicket").agg(_agg_function) - - # Have metadata index match performance data table index - unify_metadata.sort_index(inplace=True) - - # Sort by hatchet node id - unify_df.sort_index(inplace=True) - - unify_inc_metrics = list(set(unify_inc_metrics)) - unify_exc_metrics = list(set(unify_exc_metrics)) - - # Workaround for graph/df node id mismatch. 
-        # (n tree nodes) X (m df nodes) X (m)
-        helpers._sync_nodes(unify_graph, unify_df)
-
-        # Mutate into OrderedDict to sort profile hashes
-        unify_profile_mapping = OrderedDict(sorted(unify_profile_mapping.items()))
-
-        unify_th = Thicket(
-            graph=unify_graph,
-            dataframe=unify_df,
-            exc_metrics=unify_exc_metrics,
-            inc_metrics=unify_inc_metrics,
-            metadata=unify_metadata,
-            profile=unify_profile,
-            profile_mapping=unify_profile_mapping,
-        )
-        return unify_th
-
     @staticmethod
     def make_superthicket(th_list, profiles_from_meta=None):
         """Convert a list of thickets into a 'superthicket'.
@@ -1036,7 +667,7 @@ def make_superthicket(th_list, profiles_from_meta=None):
             # Append copy to list
             th_copy_list.append(th_copy)
 
-        return Thicket.unify_ensemble(th_copy_list, superthicket=True)
+        return Thicket.concat_thickets(th_copy_list, superthicket=True)
 
     def to_json(self, ensemble=True, metadata=True, stats=True):
         jsonified_thicket = {}
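
Usage sketch (editorial addition, not part of the patch): the snippet below shows how the new Thicket.concat_thickets entry point covers the two workflows that previously required unify_ensemble and columnar_join. The Caliper file paths and the "A"/"B" headers are placeholders; "ProblemSize" is the metadata column used by the test fixtures in this diff and is assumed to exist in both thickets. The function and argument names (from_caliperreader, concat_thickets, axis, calltree, headers, metadata_key) are taken from the API introduced here.

    from thicket import Thicket

    # Placeholder Caliper profiles; substitute real .cali paths.
    th_a = Thicket.from_caliperreader("profile_a.cali")
    th_b = Thicket.from_caliperreader("profile_b.cali")

    # Row-wise concatenation (replaces Thicket.unify_ensemble([th_a, th_b])).
    # calltree="union" keeps every call-tree node; "intersection" keeps only
    # nodes present in all thickets.
    tk_index = Thicket.concat_thickets([th_a, th_b], axis="index", calltree="union")

    # Column-wise concatenation (replaces Thicket.columnar_join(...)).
    # headers label the new top-level column index; metadata_key optionally
    # replaces the 'profile' index with a metadata column shared by all thickets.
    tk_columns = Thicket.concat_thickets(
        thickets=[th_a, th_b],
        axis="columns",
        headers=["A", "B"],
        metadata_key="ProblemSize",
    )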