diff --git a/python/paddle/trainer_config_helpers/optimizers.py b/python/paddle/trainer_config_helpers/optimizers.py
index a53ebe160be3b..3656995256e74 100644
--- a/python/paddle/trainer_config_helpers/optimizers.py
+++ b/python/paddle/trainer_config_helpers/optimizers.py
@@ -16,12 +16,15 @@
     default_gradient_clipping_threshold, default_momentum
 
 from .default_decorators import wrap_param_default
+import collections
+import cStringIO
 
 __all__ = [
     'Optimizer', 'BaseSGDOptimizer', 'MomentumOptimizer', 'AdamaxOptimizer',
     'AdamOptimizer', 'AdaGradOptimizer', 'RMSPropOptimizer',
     'DecayedAdaGradOptimizer', 'AdaDeltaOptimizer', 'BaseRegularization',
-    'L2Regularization', 'settings', 'ModelAverage'
+    'L2Regularization', 'settings', 'ModelAverage', 'PolyLRS', 'ConstantLRS',
+    'ExpLRS', 'DiscreteExpLRS', 'LinearLRS', 'ManualLRS', 'PassManualLRS'
 ]
 
 
@@ -351,15 +354,141 @@ def __extends__(dict1, dict2):
     return dict1
 
 
+class BaseLRS(Optimizer):
+    def __init__(self, a, b, scheduler_name):
+        self.__a__ = float(a)
+        self.__b__ = float(b)
+        self.__scheduler_name__ = scheduler_name
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': self.__scheduler_name__,
+            'learning_rate_decay_a': self.__a__,
+            'learning_rate_decay_b': self.__b__
+        }
+
+
+class PolyLRS(BaseLRS):
+    """
+    Poly Learning Rate Scheduler.
+
+    lr = learning_rate * pow(1 + a * num_samples_processed, -b)
+    """
+
+    def __init__(self, a, b):
+        super(PolyLRS, self).__init__(a=a, b=b, scheduler_name='poly')
+
+
+class ConstantLRS(Optimizer):
+    """
+    Constant Learning Rate Scheduler. The learning rate will not be changed.
+    """
+
+    def to_setting_kwargs(self):
+        return {'learning_rate_schedule': 'constant'}
+
+
+class ExpLRS(BaseLRS):
+    """
+    Exp Learning Rate Scheduler.
+
+    lr = learning_rate * pow(a, num_samples_processed / b)
+    """
+
+    def __init__(self, a, b):
+        super(ExpLRS, self).__init__(a=a, b=b, scheduler_name='exp')
+
+
+class DiscreteExpLRS(BaseLRS):
+    """
+    Discrete Exp Learning Rate Scheduler.
+
+    lr = learning_rate * pow(a, floor(num_samples_processed / b))
+    """
+
+    def __init__(self, a, b):
+        super(DiscreteExpLRS, self).__init__(a=a, b=b, scheduler_name='discexp')
+
+
+class LinearLRS(BaseLRS):
+    """
+    Linear Learning Rate Scheduler.
+
+    lr = max(learning_rate - a, b)
+    """
+
+    def __init__(self, a, b):
+        super(LinearLRS, self).__init__(a=a, b=b, scheduler_name='linear')
+
+
+class ManualLRS(Optimizer):
+    """
+    Specify the learning rate by explicitly passing all learning rates.
+
+    :param learning_rates: list of learning rates. Each item contains two
+                           fields. The first is an int value used as a segment
+                           boundary; the second is the learning rate.
+
+                           The real learning rate is:
+
+                           if seg_{i-1} <= num_samples_processed <= seg_i:
+                               return lr_{i}
+
+    :type learning_rates: list of list. Each element should be (int, float).
+    """
+
+    def __init__(self, learning_rates):
+        assert isinstance(learning_rates, collections.Sequence)
+        # cStringIO objects do not support the `with` statement in Python 2,
+        # so the buffer is managed manually here.
+        buf = cStringIO.StringIO()
+        for i, each in enumerate(learning_rates):
+            assert isinstance(each, collections.Sequence)
+            assert len(each) == 2
+            buf.write("{0}:{1:.5f}".format(int(each[0]), float(each[1])))
+            if i + 1 != len(learning_rates):  # not at end
+                buf.write(",")
+        self.__args__ = buf.getvalue()
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': 'manual',
+            'learning_rate_args': self.__args__
+        }
+
+
+class PassManualLRS(ManualLRS):
+    """
+    Pass Manual Learning Rate Scheduler.
+
+    Basically the same as the manual learning rate scheduler, except that the
+    pass manual LRS uses the pass number as the segment number.
+
+    The real learning rate is:
+
+    if seg_{i-1} <= pass_id <= seg_i:
+        return lr_{i}
+    """
+
+    def __init__(self, learning_rates):
+        super(PassManualLRS, self).__init__(learning_rates=learning_rates)
+
+    def to_setting_kwargs(self):
+        return {
+            'learning_rate_schedule': 'pass_manual',
+            'learning_rate_args': self.__args__
+        }
+
+
 @wrap_param_default(
     ['learning_method'], default_factory=lambda _: MomentumOptimizer())
 @wrap_param_default(
     ['regularization'], default_factory=lambda _: BaseRegularization())
+@wrap_param_default(
+    ['learning_rate_schedule'], default_factory=lambda _: ConstantLRS())
 def settings(batch_size,
              learning_rate=1e-3,
              learning_rate_decay_a=0.,
              learning_rate_decay_b=0.,
-             learning_rate_schedule='poly',
+             learning_rate_schedule=None,
              learning_rate_args='',
              learning_method=None,
              regularization=None,
@@ -396,6 +525,19 @@ def settings(batch_size,
                                         value larger than some value, will be clipped.
     :type gradient_clipping_threshold: float
+
+    :param learning_rate_schedule: A Learning Rate Scheduler object or a
+                                   string. It is recommended to pass an LRS
+                                   object. If learning_rate_schedule is set
+                                   as a string, learning_rate_decay_a,
+                                   learning_rate_decay_b and
+                                   learning_rate_args must be set manually.
+
+                                   Check LRS.to_setting_kwargs to figure out
+                                   how to set these arguments.
+    :type learning_rate_schedule: basestring|Optimizer
+    :param learning_rate_decay_a: See learning_rate_schedule.
+    :param learning_rate_decay_b: See learning_rate_schedule.
+    :param learning_rate_args: See learning_rate_schedule.
     """
     if isinstance(regularization, BaseRegularization):
         regularization = [regularization]
@@ -406,15 +548,24 @@ def settings(batch_size,
     else:
         algorithm = 'owlqn'
 
-    args = [
-        'batch_size', 'learning_rate', 'learning_rate_decay_a',
-        'learning_rate_decay_b', 'learning_rate_schedule', 'learning_rate_args'
-    ]
+    args = ['batch_size', 'learning_rate']
     kwargs = dict()
     kwargs['algorithm'] = algorithm
+
     for arg in args:
         kwargs[arg] = locals()[arg]
 
+    if isinstance(learning_rate_schedule, Optimizer):
+        kwargs = __extends__(kwargs, learning_rate_schedule.to_setting_kwargs())
+    elif isinstance(learning_rate_schedule, basestring):
+        for arg in [
+                'learning_rate_decay_a', 'learning_rate_decay_b',
+                'learning_rate_schedule', 'learning_rate_args'
+        ]:
+            kwargs[arg] = locals()[arg]
+    else:
+        raise RuntimeWarning("Unexpected branch")
+
     kwargs = __extends__(kwargs, learning_method.to_setting_kwargs())
     learning_method.extra_settings()
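A minimal usage sketch of the API this diff introduces, assuming the classes are exported as listed in __all__ above; the batch size, learning rates and the a/b/segment values are illustrative only, not recommendations.

    # Trainer config sketch (Python 2, same style as the module above).
    # These are two independent examples; a real config would call
    # settings() only once.
    from paddle.trainer_config_helpers import *

    # Pass an LRS object instead of the old string-based arguments;
    # settings() calls to_setting_kwargs() internally to fill in
    # learning_rate_schedule / learning_rate_decay_a / learning_rate_decay_b.
    settings(
        batch_size=128,
        learning_rate=1e-3,
        learning_rate_schedule=PolyLRS(a=0.1, b=0.5),
        learning_method=MomentumOptimizer(momentum=0.9))

    # ManualLRS serializes (segment, lr) pairs into learning_rate_args,
    # here "10000:0.01000,20000:0.00100".
    settings(
        batch_size=128,
        learning_rate=1e-2,
        learning_rate_schedule=ManualLRS([(10000, 1e-2), (20000, 1e-3)]))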