diff --git a/README.md b/README.md index e9a6ff7..c54ae60 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,16 @@ Then we marked out each utterance with our emotions classifier that predicts one To mark-up your own corpus with emotions you can use, for example, [DeepMoji tool](https://github.com/bfelbo/DeepMoji) or any other emotions classifier that you have. +#### Initalizing model weights from file +For some tools (for example [`tools/train.py`](tools/train.py)) you can specify the path to model's initialization weights via `--init_weights` argument. + +The weights may come from a trained CakeChat model or from a model with a different architecture. +In the latter case some parameters of Cakechat model may be left without initialization: +a parameter will be initialized with a saved value if the parameter's name and shape are +identical to the saved parameter, otherwise the parameter will keep its default initialization weights. + +See `load_weights` function for the details. + ### Training your own model 1. Put your training text corpus to [`data/corpora_processed/`](data/corpora_processed/). diff --git a/cakechat/config.py b/cakechat/config.py index 2010027..a0410b9 100644 --- a/cakechat/config.py +++ b/cakechat/config.py @@ -68,7 +68,7 @@ LEARNING_RATE = 1.0 # Learning rate for the chosen optimizer (currently using Adadelta, see model.py) # model params -NN_MODEL_PREFIX = 'cakechat' # Specify prefix to be prepended to model's name +NN_MODEL_PREFIX = 'cakechat_v1.3' # Specify prefix to be prepended to model's name # predictions params MAX_PREDICTIONS_LENGTH = 40 # Max. number of tokens which can be generated on the prediction step diff --git a/cakechat/dialog_model/model.py b/cakechat/dialog_model/model.py index 14a16cd..6f589ac 100644 --- a/cakechat/dialog_model/model.py +++ b/cakechat/dialog_model/model.py @@ -622,20 +622,54 @@ def is_reverse_model(self): return self._is_reverse_model def load_weights(self): - with open(self.model_load_path, 'rb') as f: - loaded_file = np.load(f) - # Just using .values() would't work here because we need to keep the order of elements - ordered_params = [loaded_file['arr_%d' % i] for i in xrange(len(loaded_file.files))] - set_all_param_values(self._net['dist'], ordered_params) + _logger.info('\nLoading saved weights from file:\n{}\n'.format(self.model_load_path)) + saved_var_name_to_var = OrderedDict(np.load(self.model_load_path)) - def save_model(self, save_model_path): - ensure_dir(os.path.dirname(save_model_path)) - ordered_params = get_all_param_values(self._net['dist']) + var_name_to_var = OrderedDict([(v.name, v) for v in get_all_params(self._net['dist'])]) + initialized_vars, missing_vars, mismatched_vars = [], [], [] - with open(save_model_path, 'wb') as f: - np.savez(f, *ordered_params) + for var_name, var in var_name_to_var.iteritems(): + if var_name not in saved_var_name_to_var: + missing_vars.append(var_name) + continue - _logger.info('\nSaved model:\n{}\n'.format(save_model_path)) + default_var_value = var.get_value() + saved_var_value = saved_var_name_to_var[var_name] + + if default_var_value.shape != saved_var_value.shape: + mismatched_vars.append((var_name, default_var_value.shape, saved_var_value.shape)) + continue + + # Checks passed, set parameter value + var.set_value(saved_var_value) + initialized_vars.append(var_name) + del saved_var_name_to_var[var_name] + + laconic_logger.info('\nRestored saved params:') + for var_name in initialized_vars: + laconic_logger.info('\t' + var_name) + + laconic_logger.warning('\nMissing saved params:') + for var_name in missing_vars: + laconic_logger.warning('\t' + var_name) + + laconic_logger.warning('\nShapes-mismatched params (saved -> current):') + for var_name, default_shape, saved_shape in mismatched_vars: + laconic_logger.warning('\t{0:<40} {1:<12} -> {2:<12}'.format(var_name, saved_shape, default_shape)) + + laconic_logger.warning('\nUnused saved params:') + for var_name in saved_var_name_to_var: + laconic_logger.warning('\t' + var_name) + + laconic_logger.info('') + + def save_model(self, save_path): + all_params = get_all_params(self._net['dist']) + with open(save_path, 'wb') as f: + params = {v.name: v.get_value() for v in all_params} + np.savez(f, **params) + + _logger.info('\nSaved model:\n{}\n'.format(save_path)) @staticmethod def delete_model(delete_path): @@ -684,7 +718,6 @@ def get_nn_model(index_to_token, index_to_condition, model_init_path=None, w2v_m model_exists = resolver.resolve() if model_exists: - _logger.info('\nLoading weights from file:\n{}\n'.format(model.model_load_path)) model.load_weights() elif model_init_path: raise FileNotFoundException('Can\'t initialize model from file:\n{}\n'.format(model_init_path))