From 6f525474c5d64837c8192c1f7491a735212e42b4 Mon Sep 17 00:00:00 2001
From: Wenjun Si
Date: Thu, 5 May 2022 20:22:21 +0800
Subject: [PATCH] Simplify agg with single return

---
 mars/dataframe/groupby/aggregation.py | 50 +++++++++++++++++++++++----
 1 file changed, 43 insertions(+), 7 deletions(-)

diff --git a/mars/dataframe/groupby/aggregation.py b/mars/dataframe/groupby/aggregation.py
index e173929ba2..6ae14d75f0 100644
--- a/mars/dataframe/groupby/aggregation.py
+++ b/mars/dataframe/groupby/aggregation.py
@@ -695,7 +695,28 @@ def _pack_inputs(agg_funcs: List[ReductionAggStep], in_data):
         return out_dict
 
     @staticmethod
-    def _do_custom_agg(op, custom_reduction, *input_objs):
+    def _do_custom_agg_single(op, custom_reduction, input_obj):
+        if op.stage == OperandStage.map:
+            if custom_reduction.pre_with_agg:
+                apply_fun = custom_reduction.pre
+            else:
+
+                def apply_fun(obj):
+                    return custom_reduction.agg(custom_reduction.pre(obj))
+
+        elif op.stage == OperandStage.agg:
+
+            def apply_fun(obj):
+                return custom_reduction.post(custom_reduction.agg(obj))
+
+        else:
+            apply_fun = custom_reduction.agg
+
+        res = input_obj.apply(apply_fun)
+        return (res,)
+
+    @staticmethod
+    def _do_custom_agg_multiple(op, custom_reduction, *input_objs):
         xdf = cudf if op.gpu else pd
         results = []
         out = op.outputs[0]
@@ -763,6 +784,13 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
         concat_result = tuple(xdf.concat(parts) for parts in zip(*results))
         return concat_result
 
+    @classmethod
+    def _do_custom_agg(cls, op, custom_reduction, *input_objs, output_limit: int = 1):
+        if output_limit == 1:
+            return cls._do_custom_agg_single(op, custom_reduction, input_objs[0])
+        else:
+            return cls._do_custom_agg_multiple(op, custom_reduction, *input_objs)
+
     @staticmethod
     def _do_predefined_agg(input_obj, agg_func, single_func=False, **kwds):
         ndim = getattr(input_obj, "ndim", None) or input_obj.obj.ndim
@@ -857,12 +885,16 @@ def _wrapped_func(col):
             _agg_func_name,
             custom_reduction,
             _output_key,
-            _output_limit,
+            output_limit,
             kwds,
         ) in op.agg_funcs:
             input_obj = ret_map_groupbys[input_key]
             if map_func_name == "custom_reduction":
-                agg_dfs.extend(cls._do_custom_agg(op, custom_reduction, input_obj))
+                agg_dfs.extend(
+                    cls._do_custom_agg(
+                        op, custom_reduction, input_obj, output_limit=output_limit
+                    )
+                )
             else:
                 single_func = map_func_name == op.raw_func
                 agg_dfs.append(
@@ -903,12 +935,16 @@ def _execute_combine(cls, ctx, op: "DataFrameGroupByAgg"):
             agg_func_name,
             custom_reduction,
             output_key,
-            _output_limit,
+            output_limit,
             kwds,
         ) in op.agg_funcs:
             input_obj = in_data_dict[output_key]
             if agg_func_name == "custom_reduction":
-                combines.extend(cls._do_custom_agg(op, custom_reduction, *input_obj))
+                combines.extend(
+                    cls._do_custom_agg(
+                        op, custom_reduction, *input_obj, output_limit=output_limit
+                    )
+                )
             else:
                 combines.append(
                     cls._do_predefined_agg(input_obj, agg_func_name, **kwds)
@@ -943,7 +979,7 @@ def _execute_agg(cls, ctx, op: "DataFrameGroupByAgg"):
             agg_func_name,
             custom_reduction,
             output_key,
-            _output_limit,
+            output_limit,
             kwds,
         ) in op.agg_funcs:
             if agg_func_name == "custom_reduction":
@@ -951,7 +987,7 @@ def _execute_agg(cls, ctx, op: "DataFrameGroupByAgg"):
                     cls._get_grouped(op, o, ctx) for o in in_data_dict[output_key]
                 )
                 in_data_dict[output_key] = cls._do_custom_agg(
-                    op, custom_reduction, *input_obj
+                    op, custom_reduction, *input_obj, output_limit=output_limit
                )[0]
             else:
                 input_obj = cls._get_grouped(op, in_data_dict[output_key], ctx)
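
Note (illustration, not part of the patch): when a custom reduction produces a
single output, the new _do_custom_agg_single path reduces to applying the
reduction's pre/agg/post callbacks through a plain GroupBy.apply, choosing one
callable per execution stage, instead of the chunk-rebuilding multi-output
path. A minimal pandas sketch of that dispatch follows; CustomSum, run_stage
and the bare stage strings are hypothetical stand-ins, not Mars APIs.

    # Sketch of the per-stage callable selection used by the single-output path.
    import pandas as pd


    class CustomSum:
        pre_with_agg = False

        def pre(self, obj):   # map stage: pick the value column
            return obj["v"]

        def agg(self, obj):   # map/combine/agg: partial sum
            return obj.sum()

        def post(self, obj):  # agg stage: finalize (no-op here)
            return obj


    def run_stage(grouped, reduction, stage):
        if stage == "map":
            if reduction.pre_with_agg:
                apply_fun = reduction.pre
            else:
                def apply_fun(obj):
                    return reduction.agg(reduction.pre(obj))
        elif stage == "agg":
            def apply_fun(obj):
                return reduction.post(reduction.agg(obj))
        else:  # combine stage
            apply_fun = reduction.agg
        return grouped.apply(apply_fun)


    df = pd.DataFrame({"k": ["a", "a", "b"], "v": [1.0, 2.0, 3.0]})
    print(run_stage(df.groupby("k"), CustomSum(), "map"))
    # k
    # a    3.0
    # b    3.0
    # dtype: float64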