diff --git a/tfutils/db_interface.py b/tfutils/db_interface.py index 07e0bdb..07b4fbd 100644 --- a/tfutils/db_interface.py +++ b/tfutils/db_interface.py @@ -454,16 +454,20 @@ def initialize(self): ckpt_filename = None if ckpt_filename is not None: + # initialize local and global variables + init_op_global = tf.global_variables_initializer() + self.sess.run(init_op_global) + init_op_local = tf.local_variables_initializer() + self.sess.run(init_op_local) # Determine which vars should be restored from the specified checkpoint. restore_vars = self.get_restore_vars(ckpt_filename) restore_names = [name for name, var in restore_vars.items()] - # remap the actually restored names + # remap the actually restored names to the new ones if self.load_param_dict: - new_restore_names = [] - for each_old_name in restore_names: - new_restore_names.append( - self.load_param_dict[each_old_name]) - restore_names = new_restore_names + for each_old_name in self.load_param_dict.keys(): + if each_old_name in restore_names: + restore_names.remove(each_old_name) + restore_names.append(self.load_param_dict[each_old_name]) # Actually load the vars. log.info('Restored Vars (in ckpt, in graph):\n' @@ -528,17 +532,14 @@ def get_restore_vars(self, save_file): # Specify which vars are to be restored vs. reinitialized. all_vars = self.var_list - if not self.load_param_dict: - restore_vars = { - name: var for name, var in all_vars.items() \ - if name in var_shapes} - else: + restore_vars = { + name: var for name, var in all_vars.items() \ + if name in var_shapes} + if self.load_param_dict: # associate checkpoint names with actual variables - load_var_dict = {} for ckpt_var_name, curr_var_name in self.load_param_dict.items(): if curr_var_name in all_vars: - load_var_dict[ckpt_var_name] = all_vars[curr_var_name] - restore_vars = load_var_dict + restore_vars[ckpt_var_name] = all_vars[curr_var_name] restore_vars = self.filter_var_list(restore_vars) diff --git a/tfutils/tests/test_dbinterface.py b/tfutils/tests/test_dbinterface.py index d0740df..511104d 100644 --- a/tfutils/tests/test_dbinterface.py +++ b/tfutils/tests/test_dbinterface.py @@ -131,8 +131,9 @@ def test_get_restore_vars(self): self.log.info('restore_vars:') for name, var in restore_vars.items(): - self.log.info('(name, var.name): ({}, {})'.format(name, var.name)) - self.assertEqual(var.op.name, mapping[name]) + if name in mapping.keys(): + self.log.info('(name, var.name): ({}, {})'.format(name, var.name)) + self.assertEqual(var.op.name, mapping[name]) def test_filter_var_list(self):