Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to restore local vars during validation #141

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions tfutils/db_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,16 +454,20 @@ def initialize(self):
ckpt_filename = None

if ckpt_filename is not None:
# initialize local and global variables
init_op_global = tf.global_variables_initializer()
self.sess.run(init_op_global)
init_op_local = tf.local_variables_initializer()
self.sess.run(init_op_local)
# Determine which vars should be restored from the specified checkpoint.
restore_vars = self.get_restore_vars(ckpt_filename)
restore_names = [name for name, var in restore_vars.items()]
# remap the actually restored names
# remap the actually restored names to the new ones
if self.load_param_dict:
new_restore_names = []
for each_old_name in restore_names:
new_restore_names.append(
self.load_param_dict[each_old_name])
restore_names = new_restore_names
for each_old_name in self.load_param_dict.keys():
if each_old_name in restore_names:
restore_names.remove(each_old_name)
restore_names.append(self.load_param_dict[each_old_name])

# Actually load the vars.
log.info('Restored Vars (in ckpt, in graph):\n'
Expand Down Expand Up @@ -528,17 +532,14 @@ def get_restore_vars(self, save_file):

# Specify which vars are to be restored vs. reinitialized.
all_vars = self.var_list
if not self.load_param_dict:
restore_vars = {
name: var for name, var in all_vars.items() \
if name in var_shapes}
else:
restore_vars = {
name: var for name, var in all_vars.items() \
if name in var_shapes}
if self.load_param_dict:
# associate checkpoint names with actual variables
load_var_dict = {}
for ckpt_var_name, curr_var_name in self.load_param_dict.items():
if curr_var_name in all_vars:
load_var_dict[ckpt_var_name] = all_vars[curr_var_name]
restore_vars = load_var_dict
restore_vars[ckpt_var_name] = all_vars[curr_var_name]

restore_vars = self.filter_var_list(restore_vars)

Expand Down
5 changes: 3 additions & 2 deletions tfutils/tests/test_dbinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ def test_get_restore_vars(self):

self.log.info('restore_vars:')
for name, var in restore_vars.items():
self.log.info('(name, var.name): ({}, {})'.format(name, var.name))
self.assertEqual(var.op.name, mapping[name])
if name in mapping.keys():
self.log.info('(name, var.name): ({}, {})'.format(name, var.name))
self.assertEqual(var.op.name, mapping[name])

def test_filter_var_list(self):

Expand Down