Skip to content

Commit

Permalink
adding --dry-run flag. closes #98
Browse files (browse the repository at this point in the history)
  • Loading branch information
JaimieMurdock committed Feb 10, 2016
1 parent 0227c41 commit 538a869
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
11 changes: 6 additions & 5 deletions topicexplorer/__main__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -85,11 +85,12 @@ def main():

elif args.func == 'train':
train.main(args)

print "\nTIP: launch the topic explorer with:"
print " vsm launch", args.config_file
print " or the notebook server with:"
print " vsm notebook", args.config_file

if not args.dry_run:
print "\nTIP: launch the topic explorer with:"
print " vsm launch", args.config_file
print " or the notebook server with:"
print " vsm notebook", args.config_file

elif args.func == 'launch':
launch.main(args)
Expand Down
29 changes: 16 additions & 13 deletions topicexplorer/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,7 @@
from topicexplorer.lib.util import bool_prompt, int_prompt, is_valid_filepath

def build_models(corpus, corpus_filename, model_path, context_type, krange,
n_iterations=200, n_proc=1, seed=None):
n_iterations=200, n_proc=1, seed=None, dry_run=False):

basefilename = os.path.basename(corpus_filename).replace('.npz','')
basefilename += "-LDA-K%s-%s-%d.npz" % ('{0}', context_type, n_iterations)
Expand All @@ -29,13 +29,14 @@ def build_models(corpus, corpus_filename, model_path, context_type, krange,
else:
seeds = None

for k in krange:
print "Training model for k={0} Topics with {1} Processes"\
.format(k, n_proc)
m = LDA(corpus, context_type, K=k, multiprocessing=(n_proc > 1),
seed_or_seeds=seeds)
m.train(n_iterations=n_iterations)
m.save(basefilename.format(k))
if not dry_run:
for k in krange:
print "Training model for k={0} Topics with {1} Processes"\
.format(k, n_proc)
m = LDA(corpus, context_type, K=k, multiprocessing=(n_proc > 1),
seed_or_seeds=seeds)
m.train(n_iterations=n_iterations)
m.save(basefilename.format(k))

return basefilename

Expand Down Expand Up @@ -159,21 +160,21 @@ def main(args):
print " vsm train %s --iter %d --context-type %s -k %s\n" %\
(args.config_file, args.iter, args.context_type,
' '.join(map(str, args.k)))

model_pattern = build_models(corpus, corpus_filename, model_path,
args.context_type, args.k,
n_iterations=args.iter,
n_proc=args.processes, seed=args.seed)

n_proc=args.processes, seed=args.seed,
dry_run=args.dry_run)
config.set("main", "model_pattern", model_pattern)
if args.context_type:
# test for presence, since continuing doesn't require context_type
config.set("main", "context_type", args.context_type)
args.k.sort()
config.set("main", "topics", str(args.k))

with open(args.config_file, "wb") as configfh:
config.write(configfh)
if not args.dry_run:
with open(args.config_file, "wb") as configfh:
config.write(configfh)

def populate_parser(parser):
parser.add_argument("config_file", help="Path to Config",
Expand All @@ -188,6 +189,8 @@ def populate_parser(parser):
help="K values to train upon", type=int)
parser.add_argument('--iter', type=int,
help="Number of training iterations")
parser.add_argument('--dry-run', dest='dry_run', action='store_true',
help="Run code without training models")


if __name__ == '__main__':
Expand Down

0 comments on commit 538a869

Please sign in to comment.