Skip to content

Commit

Permalink
Simplify agg with single return
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed May 5, 2022
1 parent ee31659 commit 6f52547
Showing 1 changed file with 43 additions and 7 deletions.
50 changes: 43 additions & 7 deletions mars/dataframe/groupby/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,28 @@ def _pack_inputs(agg_funcs: List[ReductionAggStep], in_data):
return out_dict

@staticmethod
def _do_custom_agg(op, custom_reduction, *input_objs):
def _do_custom_agg_single(op, custom_reduction, input_obj):
if op.stage == OperandStage.map:
if custom_reduction.pre_with_agg:
apply_fun = custom_reduction.pre
else:

def apply_fun(obj):
return custom_reduction.agg(custom_reduction.pre(obj))

elif op.stage == OperandStage.agg:

def apply_fun(obj):
return custom_reduction.post(custom_reduction.agg(obj))

else:
apply_fun = custom_reduction.agg

res = input_obj.apply(apply_fun)
return (res,)

@staticmethod
def _do_custom_agg_multiple(op, custom_reduction, *input_objs):
xdf = cudf if op.gpu else pd
results = []
out = op.outputs[0]
Expand Down Expand Up @@ -763,6 +784,13 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
concat_result = tuple(xdf.concat(parts) for parts in zip(*results))
return concat_result

@classmethod
def _do_custom_agg(cls, op, custom_reduction, *input_objs, output_limit: int = 1):
if output_limit == 1:
return cls._do_custom_agg_single(op, custom_reduction, input_objs[0])
else:
return cls._do_custom_agg_multiple(op, custom_reduction, *input_objs)

@staticmethod
def _do_predefined_agg(input_obj, agg_func, single_func=False, **kwds):
ndim = getattr(input_obj, "ndim", None) or input_obj.obj.ndim
Expand Down Expand Up @@ -857,12 +885,16 @@ def _wrapped_func(col):
_agg_func_name,
custom_reduction,
_output_key,
_output_limit,
output_limit,
kwds,
) in op.agg_funcs:
input_obj = ret_map_groupbys[input_key]
if map_func_name == "custom_reduction":
agg_dfs.extend(cls._do_custom_agg(op, custom_reduction, input_obj))
agg_dfs.extend(
cls._do_custom_agg(
op, custom_reduction, input_obj, output_limit=output_limit
)
)
else:
single_func = map_func_name == op.raw_func
agg_dfs.append(
Expand Down Expand Up @@ -903,12 +935,16 @@ def _execute_combine(cls, ctx, op: "DataFrameGroupByAgg"):
agg_func_name,
custom_reduction,
output_key,
_output_limit,
output_limit,
kwds,
) in op.agg_funcs:
input_obj = in_data_dict[output_key]
if agg_func_name == "custom_reduction":
combines.extend(cls._do_custom_agg(op, custom_reduction, *input_obj))
combines.extend(
cls._do_custom_agg(
op, custom_reduction, *input_obj, output_limit=output_limit
)
)
else:
combines.append(
cls._do_predefined_agg(input_obj, agg_func_name, **kwds)
Expand Down Expand Up @@ -943,15 +979,15 @@ def _execute_agg(cls, ctx, op: "DataFrameGroupByAgg"):
agg_func_name,
custom_reduction,
output_key,
_output_limit,
output_limit,
kwds,
) in op.agg_funcs:
if agg_func_name == "custom_reduction":
input_obj = tuple(
cls._get_grouped(op, o, ctx) for o in in_data_dict[output_key]
)
in_data_dict[output_key] = cls._do_custom_agg(
op, custom_reduction, *input_obj
op, custom_reduction, *input_obj, output_limit=output_limit
)[0]
else:
input_obj = cls._get_grouped(op, in_data_dict[output_key], ctx)
Expand Down

0 comments on commit 6f52547

Please sign in to comment.