diff --git a/agate/aggregations/mean.py b/agate/aggregations/mean.py index 689b674f..ac0de983 100644 --- a/agate/aggregations/mean.py +++ b/agate/aggregations/mean.py @@ -3,7 +3,7 @@ from agate.aggregations.base import Aggregation from agate.aggregations.has_nulls import HasNulls from agate.aggregations.sum import Sum -from agate.data_types import Number +from agate.data_types import Number, TimeDelta from agate.exceptions import DataTypeError from agate.warns import warn_null_calculation @@ -20,13 +20,16 @@ def __init__(self, column_name): self._sum = Sum(column_name) def get_aggregate_data_type(self, table): - return Number() + column = table.columns[self._column_name] + + if isinstance(column.data_type, (Number, TimeDelta)): + return column.data_type def validate(self, table): column = table.columns[self._column_name] - if not isinstance(column.data_type, Number): - raise DataTypeError('Mean can only be applied to columns containing Number data.') + if not isinstance(column.data_type, (Number, TimeDelta)): + raise DataTypeError('Sum can only be applied to columns containing Number or TimeDelta data.') has_nulls = HasNulls(self._column_name).run(table)