From 538a86943c201bdbc158a0b1d884388f932030c3 Mon Sep 17 00:00:00 2001 From: Jaimie Murdock Date: Tue, 9 Feb 2016 19:26:19 -0500 Subject: [PATCH] adding --dry-run flag. closes #98 --- topicexplorer/__main__.py | 11 ++++++----- topicexplorer/train.py | 29 ++++++++++++++++------------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/topicexplorer/__main__.py b/topicexplorer/__main__.py index 999fcd7..9764514 100755 --- a/topicexplorer/__main__.py +++ b/topicexplorer/__main__.py @@ -85,11 +85,12 @@ def main(): elif args.func == 'train': train.main(args) - - print "\nTIP: launch the topic explorer with:" - print " vsm launch", args.config_file - print " or the notebook server with:" - print " vsm notebook", args.config_file + + if not args.dry_run: + print "\nTIP: launch the topic explorer with:" + print " vsm launch", args.config_file + print " or the notebook server with:" + print " vsm notebook", args.config_file elif args.func == 'launch': launch.main(args) diff --git a/topicexplorer/train.py b/topicexplorer/train.py index ced6328..45ca7fa 100644 --- a/topicexplorer/train.py +++ b/topicexplorer/train.py @@ -10,7 +10,7 @@ from topicexplorer.lib.util import bool_prompt, int_prompt, is_valid_filepath def build_models(corpus, corpus_filename, model_path, context_type, krange, - n_iterations=200, n_proc=1, seed=None): + n_iterations=200, n_proc=1, seed=None, dry_run=False): basefilename = os.path.basename(corpus_filename).replace('.npz','') basefilename += "-LDA-K%s-%s-%d.npz" % ('{0}', context_type, n_iterations) @@ -29,13 +29,14 @@ def build_models(corpus, corpus_filename, model_path, context_type, krange, else: seeds = None - for k in krange: - print "Training model for k={0} Topics with {1} Processes"\ - .format(k, n_proc) - m = LDA(corpus, context_type, K=k, multiprocessing=(n_proc > 1), - seed_or_seeds=seeds) - m.train(n_iterations=n_iterations) - m.save(basefilename.format(k)) + if not dry_run: + for k in krange: + print "Training model for k={0} Topics with {1} Processes"\ + .format(k, n_proc) + m = LDA(corpus, context_type, K=k, multiprocessing=(n_proc > 1), + seed_or_seeds=seeds) + m.train(n_iterations=n_iterations) + m.save(basefilename.format(k)) return basefilename @@ -159,12 +160,11 @@ def main(args): print " vsm train %s --iter %d --context-type %s -k %s\n" %\ (args.config_file, args.iter, args.context_type, ' '.join(map(str, args.k))) - model_pattern = build_models(corpus, corpus_filename, model_path, args.context_type, args.k, n_iterations=args.iter, - n_proc=args.processes, seed=args.seed) - + n_proc=args.processes, seed=args.seed, + dry_run=args.dry_run) config.set("main", "model_pattern", model_pattern) if args.context_type: # test for presence, since continuing doesn't require context_type @@ -172,8 +172,9 @@ def main(args): args.k.sort() config.set("main", "topics", str(args.k)) - with open(args.config_file, "wb") as configfh: - config.write(configfh) + if not args.dry_run: + with open(args.config_file, "wb") as configfh: + config.write(configfh) def populate_parser(parser): parser.add_argument("config_file", help="Path to Config", @@ -188,6 +189,8 @@ def populate_parser(parser): help="K values to train upon", type=int) parser.add_argument('--iter', type=int, help="Number of training iterations") + parser.add_argument('--dry-run', dest='dry_run', action='store_true', + help="Run code without training models") if __name__ == '__main__':