From 0a10947a48b0e8f36ef933ef191dfef0f30be284 Mon Sep 17 00:00:00 2001 From: Aran Nayebi Date: Fri, 9 Aug 2019 15:23:54 -0700 Subject: [PATCH 1/3] Added option to restore local vars --- tfutils/db_interface.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tfutils/db_interface.py b/tfutils/db_interface.py index 07e0bdb..8bf1c91 100644 --- a/tfutils/db_interface.py +++ b/tfutils/db_interface.py @@ -454,6 +454,11 @@ def initialize(self): ckpt_filename = None if ckpt_filename is not None: + # initialize local and global variables + init_op_global = tf.global_variables_initializer() + self.sess.run(init_op_global) + init_op_local = tf.local_variables_initializer() + self.sess.run(init_op_local) # Determine which vars should be restored from the specified checkpoint. restore_vars = self.get_restore_vars(ckpt_filename) restore_names = [name for name, var in restore_vars.items()] From 4ded8a76a4a1cf482ce216a4d256bc3643efba04 Mon Sep 17 00:00:00 2001 From: Aran Nayebi Date: Tue, 13 Aug 2019 18:05:03 -0700 Subject: [PATCH 2/3] Having load params dict act as a modifier rather than complete replacement --- tfutils/db_interface.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tfutils/db_interface.py b/tfutils/db_interface.py index 07e0bdb..fe15bc9 100644 --- a/tfutils/db_interface.py +++ b/tfutils/db_interface.py @@ -457,13 +457,12 @@ def initialize(self): # Determine which vars should be restored from the specified checkpoint. restore_vars = self.get_restore_vars(ckpt_filename) restore_names = [name for name, var in restore_vars.items()] - # remap the actually restored names + # remap the actually restored names to the new ones if self.load_param_dict: - new_restore_names = [] - for each_old_name in restore_names: - new_restore_names.append( - self.load_param_dict[each_old_name]) - restore_names = new_restore_names + for each_old_name in self.load_param_dict.keys(): + if each_old_name in restore_names: + restore_names.remove(each_old_name) + restore_names.append(self.load_param_dict[each_old_name]) # Actually load the vars. log.info('Restored Vars (in ckpt, in graph):\n' @@ -528,17 +527,14 @@ def get_restore_vars(self, save_file): # Specify which vars are to be restored vs. reinitialized. all_vars = self.var_list - if not self.load_param_dict: - restore_vars = { - name: var for name, var in all_vars.items() \ - if name in var_shapes} - else: + restore_vars = { + name: var for name, var in all_vars.items() \ + if name in var_shapes} + if self.load_param_dict: # associate checkpoint names with actual variables - load_var_dict = {} for ckpt_var_name, curr_var_name in self.load_param_dict.items(): if curr_var_name in all_vars: - load_var_dict[ckpt_var_name] = all_vars[curr_var_name] - restore_vars = load_var_dict + restore_vars[ckpt_var_name] = all_vars[curr_var_name] restore_vars = self.filter_var_list(restore_vars) From 407557c39b68e049d186ca0dfa2c3f5dabd88dd9 Mon Sep 17 00:00:00 2001 From: Aran Nayebi Date: Thu, 15 Aug 2019 11:31:50 -0700 Subject: [PATCH 3/3] Updated tests to reflect that load_param_dict keeps all variables but remaps the specified subset --- tfutils/tests/test_dbinterface.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tfutils/tests/test_dbinterface.py b/tfutils/tests/test_dbinterface.py index d0740df..511104d 100644 --- a/tfutils/tests/test_dbinterface.py +++ b/tfutils/tests/test_dbinterface.py @@ -131,8 +131,9 @@ def test_get_restore_vars(self): self.log.info('restore_vars:') for name, var in restore_vars.items(): - self.log.info('(name, var.name): ({}, {})'.format(name, var.name)) - self.assertEqual(var.op.name, mapping[name]) + if name in mapping.keys(): + self.log.info('(name, var.name): ({}, {})'.format(name, var.name)) + self.assertEqual(var.op.name, mapping[name]) def test_filter_var_list(self):