Skip to content

Commit

Permalink
replace eval dataloader with train dataloader if eval_dataloader is N…
Browse files Browse the repository at this point in the history
…one (PaddlePaddle#1163)

* replace eval dataloader with train dataloader if eval_dataloader is None

* update
  • Loading branch information
ceci3 authored Jun 10, 2022
1 parent 6a16182 commit 769c28f
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions paddleslim/auto_compression/compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def __init__(self,
If set to None, will choose a strategy automatically. Default: None.
target_speedup(float, optional): target speedup ratio by the way of auto compress. Default: None.
eval_callback(function, optional): eval function, define by yourself to return the metric of the inference program, can be used to judge the metric of compressed model. The documents of how to write eval function is `https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/auto-compression/custom_function.rst`_ . ``eval_callback`` and ``eval_dataloader`` cannot be None at the same time. Dafault: None.
eval_dataloader(paddle.io.Dataloader, optional): The
Generator or Dataloader provides eval data, and it could
return a batch every time. ``eval_callback`` and ``eval_dataloader`` cannot be None at the same time. Dafault: None.
eval_dataloader(paddle.io.Dataloader, optional): The Generator or Dataloader provides eval data, and it could
return a batch every time. If eval_dataloader is None, will take first 5000 sample from train_dataloader
as eval_dataloader, and the metric of eval_dataloader for reference only. Dafault: None.
deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'.
"""
self.model_dir = model_dir
Expand All @@ -116,7 +116,10 @@ def __init__(self,
self.train_dataloader = train_dataloader
self.target_speedup = target_speedup
self.eval_function = eval_callback
self.eval_dataloader = eval_dataloader if eval_dataloader is not None else train_dataloader

if eval_dataloader is None:
eval_dataloader = self._get_eval_dataloader(train_dataloader)
self.eval_dataloader = eval_dataloader

paddle.enable_static()

Expand Down Expand Up @@ -152,6 +155,17 @@ def __init__(self,
self.train_config = create_train_config(self.strategy_config,
self.model_type)

def _get_eval_dataloader(self, train_dataloader):
def _gen():
len_loader = len(list(train_dataloader()))
### max eval_dataloader is 5000 if use train_dataloader as eval_dataloader
slice_len = min(5000, len_loader)
ret = list(itertools.islice(train_dataloader(), slice_len))
for i in ret:
yield i

return _gen

def _prepare_envs(self):
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
Expand Down

0 comments on commit 769c28f

Please sign in to comment.