forked from broadinstitute/viral-ngs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
taxon_filter.py
executable file
·857 lines (709 loc) · 34.4 KB
/
taxon_filter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
#!/usr/bin/env python
'''This script contains a number of utilities for filtering NGS reads based
on membership or non-membership in a species / genus / taxonomic grouping.
'''
from __future__ import print_function
__author__ = "[email protected], [email protected]," \
__commands__ = []
import argparse
import concurrent.futures
import contextlib
import glob
import logging
import math
import os
import shutil
import subprocess
import tempfile

from Bio import SeqIO
import pysam

import util.cmd
import util.file
import util.misc
import tools
import tools.blast
import tools.last
import tools.prinseq
import tools.bmtagger
import tools.bwa
import tools.picard
import tools.samtools
from util.file import mkstempfname
import read_utils
log = logging.getLogger(__name__)
# =======================
# *** deplete_human ***
# =======================
def parser_deplete(parser=None):
    """Register the command-line arguments of the ``deplete`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call; the previous default was a single
            parser instance built at definition time and shared by every
            call that omitted the argument.

    Returns:
        The parser, with ``main_deplete`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inBam', help='Input BAM file.')
    parser.add_argument('revertBam', nargs='?', help='Output BAM: read markup reverted with Picard.')
    parser.add_argument('bwaBam', help='Output BAM: depleted of reads with BWA.')
    parser.add_argument('bmtaggerBam', help='Output BAM: depleted of reads with BMTagger.')
    parser.add_argument('rmdupBam', help='Output BAM: bmtaggerBam run through M-Vicuna duplicate removal.')
    parser.add_argument(
        'blastnBam', help='Output BAM: rmdupBam run through another depletion of reads with BLASTN.'
    )
    parser.add_argument(
        '--bwaDbs',
        nargs='*',
        default=(),
        # The help text used to say "blast" here, copied from --blastDbs.
        help='Reference databases for bwa to deplete from input.'
    )
    parser.add_argument(
        '--bmtaggerDbs',
        nargs='*',
        default=(),
        help=(
            'Reference databases to deplete from input.\n'
            'For each db, requires prior creation of db.bitmask by bmtool,\n'
            'and db.srprism.idx, db.srprism.map, etc. by srprism mkindex.'
        )
    )
    parser.add_argument(
        '--blastDbs',
        nargs='*',
        default=(),
        help='Reference databases for blast to deplete from input.'
    )
    parser.add_argument('--srprismMemory', dest="srprism_memory", type=int, default=7168, help='Memory for srprism.')
    parser.add_argument("--chunkSize", type=int, default=1000000, help='blastn chunk size (default: %(default)s)')
    parser.add_argument(
        '--JVMmemory',
        default=tools.picard.FilterSamReadsTool.jvmMemDefault,
        help='JVM virtual memory size for Picard FilterSamReads (default: %(default)s)'
    )
    # Adds the options shared by the commands that may revert an aligned BAM
    # (main_deplete reads clear_tags, tags_to_clear and do_not_sanitize).
    parser = read_utils.parser_revert_sam_common(parser)
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, main_deplete)
    return parser
def main_deplete(args):
    ''' Run the entire depletion pipeline: bwa, bmtagger, mvicuna, blastn.
    '''
    # Stages, each reading the previous stage's output BAM:
    #   inBam -> (revert if aligned) -> bwaBam -> bmtaggerBam -> rmdupBam -> blastnBam
    # Returns 0 on completion.

    # At least one depletion database has to be supplied.
    assert len(args.bmtaggerDbs) + len(args.blastDbs) + len(args.bwaDbs) > 0

    # only RevertSam if inBam is already aligned
    # Most of the time the input will be unaligned
    # so we can save time if we can skip RevertSam in the unaligned case
    #
    # via the SAM/BAM spec, if the file is aligned, an SQ line should be present
    # in the header. Using pysam, we can check this if header['SQ'])>0
    #   https://samtools.github.io/hts-specs/SAMv1.pdf
    with read_utils.revert_bam_if_aligned(
        args.inBam,
        revert_bam=args.revertBam,
        clear_tags=args.clear_tags,
        tags_to_clear=args.tags_to_clear,
        picardOptions=['MAX_DISCARD_FRACTION=0.5'],
        JVMmemory=args.JVMmemory,
        sanitize=not args.do_not_sanitize
    ) as bamToDeplete:
        multi_db_deplete_bam(
            bamToDeplete,
            args.bwaDbs,
            deplete_bwa_bam,
            args.bwaBam,
            threads=args.threads
        )

    def bmtagger_wrapper(inBam, db, outBam, JVMmemory=None):
        # Bind the srprism memory setting from the command line.
        return deplete_bmtagger_bam(inBam, db, outBam, srprism_memory=args.srprism_memory, JVMmemory=JVMmemory)

    multi_db_deplete_bam(
        args.bwaBam,
        args.bmtaggerDbs,
        bmtagger_wrapper,
        args.bmtaggerBam,
        JVMmemory=args.JVMmemory
    )

    # The former `os.unlink(revertBamOut)` cleanup was removed: `revertBamOut`
    # is not defined anywhere in this function, so the statement raised
    # NameError whenever revertBam was omitted.
    # NOTE(review): this assumes read_utils.revert_bam_if_aligned disposes of
    # any temporary reverted BAM when its `with` block exits - confirm.

    read_utils.rmdup_mvicuna_bam(args.bmtaggerBam, args.rmdupBam, JVMmemory=args.JVMmemory)
    multi_db_deplete_bam(
        args.rmdupBam,
        args.blastDbs,
        deplete_blastn_bam,
        args.blastnBam,
        chunkSize=args.chunkSize,
        threads=args.threads,
        JVMmemory=args.JVMmemory
    )
    return 0
__commands__.append(('deplete', parser_deplete))
def parser_deplete_human(parser=None):
    """Register the arguments of the legacy ``deplete_human`` command.

    The arguments are exactly those of ``deplete``; only the attached entry
    point differs.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``main_deplete_human`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser = parser_deplete(parser)
    # parser_deplete attaches main_deplete; this call attaches the legacy wrapper afterwards.
    util.cmd.attach_main(parser, main_deplete_human)
    return parser
def main_deplete_human(args):
    ''' A wrapper around 'deplete'; deprecated but preserved for legacy compatibility.
    '''
    # Delegates to main_deplete with the parsed arguments unchanged; the
    # delegate's return value is not propagated.
    main_deplete(args)
__commands__.append(('deplete_human', parser_deplete_human))
# =======================
# *** filter_lastal ***
# =======================
def filter_lastal_bam(
    inBam,
    db,
    outBam,
    max_gapless_alignments_per_position=1,
    min_length_for_initial_matches=5,
    max_length_for_initial_matches=50,
    max_initial_matches_per_position=100,
    JVMmemory=None, threads=None
):
    ''' Restrict input reads to those that align to the given
        reference database using LASTAL.
    '''
    # inBam:  reads to filter.
    # db:     LAST database; if Lastdb reports it is not indexed, an index is
    #         built in a scratch directory that lives for the rest of this call.
    # outBam: destination for the reads that produced a LASTAL hit.
    # The four numeric parameters are passed positionally to Lastal.get_hits;
    # JVMmemory goes to Picard FilterSamReads and threads to lastal.
    with util.file.tmp_dir('-lastdb') as scratch_dir:
        indexer = tools.last.Lastdb()
        if not indexer.is_indexed(db):
            db = indexer.build_database(db, os.path.join(scratch_dir, 'lastdb'))

        with util.file.tempfname('.read_ids.txt') as hitList:
            # Write the ids of the reads with a lastal hit, one per line.
            with open(hitList, 'wt') as outf:
                outf.writelines(
                    read_id + '\n'
                    for read_id in tools.last.Lastal().get_hits(
                        inBam, db,
                        max_gapless_alignments_per_position,
                        min_length_for_initial_matches,
                        max_length_for_initial_matches,
                        max_initial_matches_per_position,
                        threads=threads
                    )
                )

            # Filter the original BAM against the keep list.
            tools.picard.FilterSamReadsTool().execute(inBam, False, hitList, outBam, JVMmemory=JVMmemory)
def parser_filter_lastal_bam(parser=None):
    """Register the command-line arguments of the ``filter_lastal_bam`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``filter_lastal_bam`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument("inBam", help="Input reads")
    parser.add_argument("db", help="Database of taxa we keep")
    parser.add_argument("outBam", help="Output reads, filtered to refDb")
    parser.add_argument(
        '-n',
        dest="max_gapless_alignments_per_position",
        help='maximum gapless alignments per query position (default: %(default)s)',
        type=int,
        default=1
    )
    parser.add_argument(
        '-l',
        dest="min_length_for_initial_matches",
        help='minimum length for initial matches (default: %(default)s)',
        type=int,
        default=5
    )
    parser.add_argument(
        '-L',
        dest="max_length_for_initial_matches",
        help='maximum length for initial matches (default: %(default)s)',
        type=int,
        default=50
    )
    parser.add_argument(
        '-m',
        dest="max_initial_matches_per_position",
        help='maximum initial matches per query position (default: %(default)s)',
        type=int,
        default=100
    )
    parser.add_argument(
        '--JVMmemory',
        default=tools.picard.FilterSamReadsTool.jvmMemDefault,
        help='JVM virtual memory size (default: %(default)s)'
    )
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    # The dest names above match filter_lastal_bam's parameter names.
    util.cmd.attach_main(parser, filter_lastal_bam, split_args=True)
    return parser
__commands__.append(('filter_lastal_bam', parser_filter_lastal_bam))
# ==============================
# *** deplete_bmtagger_bam ***
# ==============================
def deplete_bmtagger_bam(inBam, db, outBam, srprism_memory=7168, JVMmemory=None):
    """
    Use bmtagger to partition the input reads into ones that match at least one
        of the databases and ones that don't match any of the databases.
    inBam: paired-end input reads in BAM format.
    db: bmtagger expects files
        db.bitmask created by bmtool, and
        db.srprism.idx, db.srprism.map, etc. created by srprism mkindex
    outBam: the output BAM files to hold the unmatched reads.
    srprism_memory: srprism memory in megabytes.
    JVMmemory: passed through to Picard FilterSamReads.

    Side effect: the directories holding the bmtagger and blastn executables
    are prepended to os.environ['PATH'] when not already present.
    """
    bmtaggerPath = tools.bmtagger.BmtaggerShTool().install_and_get_path()

    # bmtagger calls several executables in the same directory, and blastn;
    # make sure they are accessible through $PATH
    blastnPath = tools.blast.BlastnTool().install_and_get_path()
    path = os.environ['PATH'].split(os.pathsep)
    for tool_path in (bmtaggerPath, blastnPath):
        tool_dir = os.path.dirname(tool_path)
        if tool_dir not in path:
            path = [tool_dir] + path
    os.environ['PATH'] = os.pathsep.join(path)

    # Receives the bmtagger output (-o), which is then handed to
    # FilterSamReads as the read list. It used to be left behind after the
    # call; it is now removed once filtering is done, as deplete_blastn_bam
    # does for its hits file.
    matchesFile = mkstempfname('.txt')
    try:
        with util.file.tempfname('.1.fastq') as inReads1:
            tools.samtools.SamtoolsTool().bam2fq(inBam, inReads1)

            with util.file.tempfname('.bmtagger.conf') as bmtaggerConf:
                with open(bmtaggerConf, 'w') as f:
                    # Default srprismopts: "-b 100000000 -n 5 -R 0 -r 1 -M 7168"
                    print('srprismopts="-b 100000000 -n 5 -R 0 -r 1 -M {srprism_memory} --paired false"'.format(srprism_memory=srprism_memory), file=f)

                with extract_build_or_use_database(db, bmtagger_build_db, 'bitmask', tmp_suffix="-bmtagger", db_prefix="bmtagger") as (db_prefix, tempDir):
                    cmdline = [
                        bmtaggerPath, '-b', db_prefix + '.bitmask', '-C', bmtaggerConf, '-x', db_prefix + '.srprism', '-T', tempDir, '-q1',
                        '-1', inReads1, '-o', matchesFile
                    ]
                    log.debug(' '.join(cmdline))
                    util.misc.run_and_print(cmdline, check=True)

        tools.picard.FilterSamReadsTool().execute(inBam, True, matchesFile, outBam, JVMmemory=JVMmemory)
    finally:
        if os.path.exists(matchesFile):
            os.unlink(matchesFile)
def parser_deplete_bam_bmtagger(parser=None):
    """Register the command-line arguments of the ``deplete_bam_bmtagger`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``main_deplete_bam_bmtagger`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inBam', help='Input BAM file.')
    parser.add_argument(
        'refDbs',
        nargs='+',
        help=(
            'Reference databases (one or more) to deplete from input.\n'
            'For each db, requires prior creation of db.bitmask by bmtool,\n'
            'and db.srprism.idx, db.srprism.map, etc. by srprism mkindex.'
        )
    )
    parser.add_argument('outBam', help='Output BAM file.')
    parser.add_argument('--srprismMemory', dest="srprism_memory", type=int, default=7168, help='Memory for srprism.')
    parser.add_argument(
        '--JVMmemory',
        default=tools.picard.FilterSamReadsTool.jvmMemDefault,
        help='JVM virtual memory size (default: %(default)s)'
    )
    # Options shared by the commands that may revert an aligned BAM first.
    parser = read_utils.parser_revert_sam_common(parser)
    util.cmd.common_args(parser, (('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, main_deplete_bam_bmtagger)
    return parser
def main_deplete_bam_bmtagger(args):
    '''Use bmtagger to deplete input reads against several databases.'''

    def bmtagger_wrapper(inBam, db, outBam, JVMmemory=None):
        # Bind the srprism memory setting from the command line.
        return deplete_bmtagger_bam(inBam, db, outBam, srprism_memory=args.srprism_memory, JVMmemory=JVMmemory)

    with read_utils.revert_bam_if_aligned(
        args.inBam,
        clear_tags=args.clear_tags,
        tags_to_clear=args.tags_to_clear,
        picardOptions=['MAX_DISCARD_FRACTION=0.5'],
        JVMmemory=args.JVMmemory,
        sanitize=not args.do_not_sanitize
    ) as bamToDeplete:
        # Deplete the BAM yielded by the context manager. This call used to
        # be given args.inBam, which discarded the revert step entirely; the
        # blastn and bwa entry points already pass bamToDeplete.
        multi_db_deplete_bam(
            bamToDeplete,
            args.refDbs,
            bmtagger_wrapper,
            args.outBam,
            JVMmemory=args.JVMmemory
        )
__commands__.append(('deplete_bam_bmtagger', parser_deplete_bam_bmtagger))
def multi_db_deplete_bam(inBam, refDbs, deplete_method, outBam, **kwargs):
    """Deplete a BAM against several databases in turn, writing the result to outBam.

    Args:
        inBam: input BAM; never modified or removed.
        refDbs: sequence of database locations. If there is more than one and
            every entry is an existing plain file that is not a
            .tar/.tgz/.zip package, the entries are treated as unbuilt FASTA
            files and concatenated so deplete_method runs only once.
        deplete_method: callable invoked as
            ``deplete_method(in_bam, db, out_bam, **kwargs)``.
        outBam: destination; receives a copy of the last BAM produced, or of
            inBam when no depletion step ran.
        **kwargs: forwarded unchanged to deplete_method.

    Databases are skipped once the working BAM is empty. Intermediate BAMs
    and the merged FASTA are temporary and are removed before returning,
    including the last intermediate BAM, which used to be left behind.
    """
    def _is_prebuilt(db):
        return (
            not os.path.exists(db)     # indexed db prefix
            or os.path.isdir(db)       # indexed db in directory
            or (os.path.isfile(db) and ('.tar' in db or '.tgz' in db or '.zip' in db))  # packaged indexed db
        )

    tmpDb = None
    tmpBamIn = inBam
    try:
        if len(refDbs) > 1 and not any(_is_prebuilt(db) for db in refDbs):
            # this is a scenario where all refDbs are unbuilt fasta
            # files. we can simplify and speed up execution by
            # concatenating them all and running deplete_method
            # just once
            tmpDb = mkstempfname('.fasta')
            merge_compressed_files(refDbs, tmpDb, sep='\n')
            refDbs = [tmpDb]

        samtools = tools.samtools.SamtoolsTool()
        for db in refDbs:
            if not samtools.isEmpty(tmpBamIn):
                tmpBamOut = mkstempfname('.bam')
                deplete_method(tmpBamIn, db, tmpBamOut, **kwargs)
                if tmpBamIn != inBam:
                    os.unlink(tmpBamIn)
                tmpBamIn = tmpBamOut
        shutil.copyfile(tmpBamIn, outBam)
    finally:
        # tmpBamIn is a temp file exactly when it is no longer the caller's inBam.
        if tmpBamIn != inBam and os.path.exists(tmpBamIn):
            os.unlink(tmpBamIn)
        if tmpDb and os.path.exists(tmpDb):
            os.unlink(tmpDb)
# ========================
# *** deplete_blastn ***
# ========================
def _run_blastn_chunk(db, input_fasta, out_hits, blast_threads):
    """ Run blastn on one fasta chunk and write the ids of the hits, one per
        line, to out_hits. Intended to be submitted in parallel by
        blastn_chunked_fasta.
    """
    blastn = tools.blast.BlastnTool()
    with util.file.open_or_gzopen(out_hits, 'wt') as outf:
        outf.writelines(
            read_id + '\n'
            for read_id in blastn.get_hits_fasta(input_fasta, db, threads=blast_threads)
        )
def blastn_chunked_fasta(fasta, db, out_hits, chunkSize=1000000, threads=None):
    """
    Helper function: blastn a fasta file, overcoming apparent memory leaks on
    an input with many query sequences, by splitting it into multiple chunks
    and running a new blastn process on each chunk.

    Args:
        fasta: query sequences in FASTA format.
        db: blast database prefix handed to each blastn process.
        out_hits: file that receives the concatenated hit ids of all chunks;
            created empty when the input holds no sequences.
        chunkSize: upper bound on sequences per chunk for a single thread; it
            is divided by the thread count and never drops below
            MIN_CHUNK_SIZE.
        threads: requested parallelism, normalised by
            util.misc.sanitize_thread_count.

    Returns:
        None; the result is the content of out_hits.

    Raises:
        Any exception raised inside a blastn worker is re-raised here instead
        of being silently dropped.
    """
    # the lower bound of how small a fasta chunk can be.
    # too small and the overhead of spawning a new blast process
    # will be detrimental relative to actual computation time
    MIN_CHUNK_SIZE = 20000

    # just in case blast is not installed, install it once, not many times in parallel!
    tools.blast.BlastnTool().install()

    # clamp threadcount to number of CPU cores
    threads = util.misc.sanitize_thread_count(threads)

    # determine size of input data; records in fasta file
    number_of_reads = util.file.fasta_length(fasta)
    log.debug("number of reads in fasta file %s" % number_of_reads)
    if number_of_reads == 0:
        util.file.make_empty(out_hits)
        # Nothing to search. Without this return there are zero chunks, and
        # the per-chunk thread computation below divides by zero.
        return

    # divide (max, single-thread) chunksize by thread count
    # to find the absolute max chunk size per thread
    chunk_max_size_per_thread = chunkSize // threads

    # find the chunk size if evenly divided among blast threads
    reads_per_thread = number_of_reads // threads

    # use the smaller of the two chunk sizes so we can run more copies of blast in parallel
    chunkSize = min(reads_per_thread, chunk_max_size_per_thread)

    # if the chunk size is too small, impose a sensible size
    chunkSize = max(chunkSize, MIN_CHUNK_SIZE)

    log.debug("chunk_max_size_per_thread %s" % chunk_max_size_per_thread)

    # adjust chunk size so we don't have a small fraction
    # of a chunk running in its own blast process
    # if the size of the last chunk is <80% the size of the others,
    # decrease the chunk size until the last chunk is 80%
    # this is bounded by the MIN_CHUNK_SIZE
    # A remainder of zero means the reads divide evenly, so the last chunk is
    # full size and nothing needs adjusting. Integer arithmetic is used so the
    # test does not depend on the division semantics of the interpreter.
    while chunkSize > MIN_CHUNK_SIZE and 0 < (number_of_reads % chunkSize) < 0.8 * chunkSize:
        chunkSize = chunkSize - 1

    log.debug("blastn chunk size %s" % chunkSize)
    log.debug("number of chunks to create %s" % (number_of_reads / chunkSize))
    log.debug("blastn parallel instances %s" % threads)

    input_fastas = []
    hits_files = []
    try:
        # chunk the input file. This is a sequential operation
        with open(fasta, "rt") as fastaFile:
            record_iter = SeqIO.parse(fastaFile, "fasta")
            for batch in util.misc.batch_iterator(record_iter, chunkSize):
                chunk_fasta = mkstempfname('.fasta')
                input_fastas.append(chunk_fasta)
                with open(chunk_fasta, "wt") as handle:
                    SeqIO.write(batch, handle, "fasta")
                batch = None

        num_chunks = len(input_fastas)
        log.debug("number of chunk files to be processed by blastn %d" % num_chunks)

        # run blastn on each of the fasta input chunks
        for _ in input_fastas:
            hits_files.append(mkstempfname('.hits.txt'))
        with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as executor:
            # If we have so few chunks that there are cpus left over,
            # divide extra cpus evenly among chunks where possible
            # rounding to 1 if there are more chunks than extra threads.
            # Then double up this number to better maximize CPU usage.
            cpus_leftover = threads - num_chunks
            blast_threads = 2 * max(1, int(cpus_leftover / num_chunks))
            futures = [
                executor.submit(_run_blastn_chunk, db, chunk_fasta, chunk_hits, blast_threads)
                for chunk_fasta, chunk_hits in zip(input_fastas, hits_files)
            ]
            # Future.result() re-raises whatever the worker raised; without
            # this a failed chunk would just contribute no hits.
            for future in futures:
                future.result()

        # merge results
        util.file.cat(out_hits, hits_files)
    finally:
        # clean up the per-chunk temp files whether or not blastn succeeded
        for tmp_path in input_fastas + hits_files:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
def deplete_blastn_bam(inBam, db, outBam, threads=None, chunkSize=1000000, JVMmemory=None):
    """Use blastn to remove reads that match at least one of the databases.

    inBam: reads to deplete.
    db: blast database location, resolved by extract_build_or_use_database
        (expects files with the 'nin' extension).
    outBam: output BAM without the reads that produced a blastn hit.
    threads: parallelism passed to blastn.
    chunkSize: when truthy, reads are converted to FASTA and searched in
        chunks by blastn_chunked_fasta; when falsy, a single blastn run reads
        the BAM directly.
    JVMmemory: passed through to Picard FilterSamReads.
    """
    hit_ids_path = mkstempfname('.blast_hits.txt')

    with extract_build_or_use_database(db, blastn_build_db, 'nin', tmp_suffix="-blastn_db_unpack", db_prefix="blastn") as (db_prefix, tempDir):
        if not chunkSize:
            ## pipe tools together and run blastn multithreaded
            searcher = tools.blast.BlastnTool()
            with open(hit_ids_path, 'wt') as outf:
                outf.writelines(
                    read_id + '\n'
                    for read_id in searcher.get_hits_bam(inBam, db_prefix, threads=threads)
                )
        else:
            ## chunk up input and perform blastn in several parallel threads
            with util.file.tempfname('.fasta') as reads_fasta:
                tools.samtools.SamtoolsTool().bam2fa(inBam, reads_fasta)
                log.info("running blastn on %s against %s", inBam, db)
                blastn_chunked_fasta(reads_fasta, db_prefix, hit_ids_path, chunkSize, threads)

    # Deplete BAM of hits, then drop the temporary id list
    tools.picard.FilterSamReadsTool().execute(inBam, True, hit_ids_path, outBam, JVMmemory=JVMmemory)
    os.unlink(hit_ids_path)
def parser_deplete_blastn_bam(parser=None):
    """Register the command-line arguments of the ``deplete_blastn_bam`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``main_deplete_blastn_bam`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inBam', help='Input BAM file.')
    parser.add_argument('refDbs', nargs='+', help='One or more reference databases for blast. '
                        'An ephemeral database will be created if a fasta file is provided.')
    parser.add_argument('outBam', help='Output BAM file with matching reads removed.')
    parser.add_argument("--chunkSize", type=int, default=1000000, help='FASTA chunk size (default: %(default)s)')
    parser.add_argument(
        '--JVMmemory',
        default=tools.picard.FilterSamReadsTool.jvmMemDefault,
        help='JVM virtual memory size (default: %(default)s)'
    )
    # Options shared by the commands that may revert an aligned BAM first.
    parser = read_utils.parser_revert_sam_common(parser)
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, main_deplete_blastn_bam)
    return parser
def main_deplete_blastn_bam(args):
    '''Use blastn to remove reads that match at least one of the specified databases.'''

    def blastn_with_chunk_size(inBam, db, outBam, threads, JVMmemory=None):
        # Bind the chunk size from the command line; the remaining arguments
        # come from multi_db_deplete_bam.
        return deplete_blastn_bam(inBam, db, outBam, threads=threads, chunkSize=args.chunkSize, JVMmemory=JVMmemory)

    revert_options = dict(
        clear_tags=args.clear_tags,
        tags_to_clear=args.tags_to_clear,
        picardOptions=['MAX_DISCARD_FRACTION=0.5'],
        JVMmemory=args.JVMmemory,
        sanitize=not args.do_not_sanitize,
    )
    # The BAM yielded by the context manager is the one that gets depleted.
    with read_utils.revert_bam_if_aligned(args.inBam, **revert_options) as bamToDeplete:
        multi_db_deplete_bam(
            bamToDeplete,
            args.refDbs,
            blastn_with_chunk_size,
            args.outBam,
            threads=args.threads,
            JVMmemory=args.JVMmemory
        )
    return 0
__commands__.append(('deplete_blastn_bam', parser_deplete_blastn_bam))
@contextlib.contextmanager
def extract_build_or_use_database(db, db_build_command, db_extension_to_expect, tmp_suffix='db_unpack', db_prefix="db"):
    '''
    Context manager that turns a database reference into a usable prefix.

    db: one of
        - a path that does not exist: taken to be the prefix of a set of index files;
        - a directory holding index files;
        - a FASTA file (.fasta/.fa, optionally .gz or .lz4), which is indexed
          into a scratch directory with db_build_command;
        - any other file, which is unpacked with util.file.extract_tarball.
    db_build_command: callable with the signature
        db_build_command(inputFasta, outputDirectory, outputFilePrefix);
        it has to cope with the compressed FASTA variants itself.
    db_extension_to_expect = file extension, sans dot prefix
    tmp_suffix: suffix passed to util.file.tmp_dir for the scratch directory.
    db_prefix: file prefix used when an index is built from FASTA.

    Yields (db_prefix, scratch_directory). The scratch directory comes from
    util.file.tmp_dir and is only entered for the duration of the with-block.
    Raises Exception when a file or directory was given but no file with the
    expected extension is found.
    '''
    fasta_suffixes = ('.fasta', '.fasta.gz', '.fasta.lz4', '.fa', '.fa.gz', '.fa.lz4')
    dotted_ext = '.{ext}'.format(ext=db_extension_to_expect)

    with util.file.tmp_dir(tmp_suffix) as tempDbDir:
        if not os.path.exists(db):
            # this is simply a prefix to a bunch of files, not an actual file
            if db.endswith('.'):
                db_prefix = db.rsplit('.', 1)[0]
            else:
                db_prefix = db
        else:
            if not os.path.isfile(db):
                # this is a directory
                db_dir = db
            elif db.endswith(fasta_suffixes):
                # this is an unindexed fasta file, we will need to index it
                db_build_command(db, tempDbDir, db_prefix)
                db_dir = tempDbDir
            else:
                # this is a tarball with prebuilt indexes
                db_dir = util.file.extract_tarball(db, tempDbDir)

            # this directory should have a .{ext} file, where {ext} is specific to the type of db
            hits = glob.glob(os.path.join(db_dir, '*' + dotted_ext))
            if not hits:
                raise Exception("The blast database does not appear to a *.{ext} file.".format(ext=db_extension_to_expect))
            if len(hits) == 1:
                # remove the '.extension'
                db_prefix = hits[0][:-len(dotted_ext)]
            else:
                # remove extension and split-db prefix
                db_prefix = os.path.commonprefix(hits).rsplit('.', 1)[0]

        yield (db_prefix, tempDbDir)
# ========================
# *** deplete_bwa ***
# ========================
def deplete_bwa_bam(inBam, db, outBam, threads=None, clear_tags=True, tags_to_clear=None, JVMmemory=None):
    """Use bwa to remove reads from an unaligned bam that match at least one of the databases.

    inBam: unaligned input reads.
    db: bwa database location, resolved by extract_build_or_use_database
        (expects files with the 'bwt' extension).
    outBam: output BAM, written by Picard RevertSam with SORT_ORDER=queryname.
    threads: parallelism for bwa, normalised by util.misc.sanitize_thread_count.
    clear_tags / tags_to_clear: when clear_tags is true, each tag listed is
        passed to RevertSam as ATTRIBUTE_TO_CLEAR.
    JVMmemory: passed through to the bwa alignment helper and to RevertSam.
    """
    tags_to_clear = tags_to_clear or []
    threads = util.misc.sanitize_thread_count(threads)

    with extract_build_or_use_database(db, bwa_build_db, 'bwt', tmp_suffix="-bwa_db_unpack", db_prefix="bwa") as (db_prefix, tempDbDir):
        with util.file.tempfname('.aligned.sam') as aligned_sam:
            tools.bwa.Bwa().align_mem_bam(inBam, db_prefix, aligned_sam, threads=threads, should_index=False, JVMmemory=JVMmemory)

            with util.file.tempfname('.filtered.sam') as filtered_sam:
                # filter proper pairs (samtools view -F0x2; -h keeps the header)
                tools.samtools.SamtoolsTool().view(['-h', '-F0x2'], aligned_sam, filtered_sam)

                if clear_tags:
                    clear_options = ["ATTRIBUTE_TO_CLEAR={}".format(tag) for tag in tags_to_clear]
                else:
                    clear_options = []
                tools.picard.RevertSamTool().execute(
                    filtered_sam,
                    outBam,
                    picardOptions=['SORT_ORDER=queryname'] + clear_options,
                    JVMmemory=JVMmemory
                )
            # TODO: consider using Bwa().mem() so the input bam is not broken out by read group
            # TODO: pipe bwa input directly to samtools process (need to use Bwa().mem() directly, )
            #       with Popen to background bwa process
def parser_deplete_bwa_bam(parser=None):
    """Register the command-line arguments of the ``deplete_bwa_bam`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``main_deplete_bwa_bam`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inBam', help='Input BAM file.')
    parser.add_argument('refDbs', nargs='+', help='One or more reference databases for bwa. '
                        'An ephemeral database will be created if a fasta file is provided.')
    # Help text typo fixed: "Ouput" -> "Output".
    parser.add_argument('outBam', help='Output BAM file with matching reads removed.')
    # NOTE(review): main_deplete_bwa_bam reads args.JVMmemory, which is not
    # added in this function; presumably parser_revert_sam_common supplies
    # it - confirm.
    parser = read_utils.parser_revert_sam_common(parser)
    util.cmd.common_args(parser, (('threads', None), ('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, main_deplete_bwa_bam)
    return parser
def main_deplete_bwa_bam(args):
    '''Use BWA to remove reads that match at least one of the specified databases.'''
    revert_options = dict(
        clear_tags=args.clear_tags,
        tags_to_clear=args.tags_to_clear,
        picardOptions=['MAX_DISCARD_FRACTION=0.5'],
        JVMmemory=args.JVMmemory,
        sanitize=not args.do_not_sanitize,
    )
    # The BAM yielded by the context manager is the one that gets depleted;
    # the keyword arguments are forwarded to deplete_bwa_bam for every database.
    with read_utils.revert_bam_if_aligned(args.inBam, **revert_options) as bamToDeplete:
        multi_db_deplete_bam(
            bamToDeplete,
            args.refDbs,
            deplete_bwa_bam,
            args.outBam,
            threads=args.threads,
            clear_tags=args.clear_tags,
            tags_to_clear=args.tags_to_clear,
            JVMmemory=args.JVMmemory
        )
    return 0
__commands__.append(('deplete_bwa_bam', parser_deplete_bwa_bam))
# ========================
# *** lastal_build_db ***
# ========================
def lastal_build_db(inputFasta, outputDirectory, outputFilePrefix):
    ''' build a database for use with last based on an input fasta file '''
    # inputFasta:       FASTA file to index.
    # outputDirectory:  directory that receives the database files.
    # outputFilePrefix: database file prefix; when empty or None, the input
    #                   file's base name without its last extension is used.
    outPrefix = outputFilePrefix or os.path.splitext(os.path.basename(inputFasta))[0]
    destination = os.path.join(outputDirectory, outPrefix)
    tools.last.Lastdb().build_database(inputFasta, destination)
def parser_lastal_build_db(parser=None):
    """Register the command-line arguments of the ``lastal_build_db`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``lastal_build_db`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inputFasta', help='Location of the input FASTA file')
    # outputDirectory is a required positional with no default, so the former
    # "(default is cwd: %(default)s)" suffix was wrong and has been dropped.
    parser.add_argument('outputDirectory', help='Location for the output files')
    parser.add_argument(
        '--outputFilePrefix',
        help='Prefix for the output file name (default: inputFasta name, sans ".fasta" extension)'
    )
    util.cmd.common_args(parser, (('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, lastal_build_db, split_args=True)
    return parser
__commands__.append(('lastal_build_db', parser_lastal_build_db))
# ================================
# *** merge_compressed_files ***
# ================================
def merge_compressed_files(inFiles, outFile, sep=''):
    ''' Take a collection of input text files, possibly compressed,
        and concatenate into a single output text file.

        inFiles: iterable of paths. Names ending in .gz, .lz4 or .bz2 are
            decompressed by an external pigz, lz4 or lbzip2 process writing
            straight to the output; anything else is copied as text.
        outFile: destination path, opened with util.file.open_or_gzopen.
        sep: text written between consecutive inputs (nothing when empty).

        Raises subprocess.CalledProcessError if a decompressor exits non-zero.

        TO DO: if we made util.file.open_or_gzopen more multilingual,
        we wouldn't need this.
    '''
    decompressors = (
        ('.gz', ['pigz', '-d']),
        ('.lz4', ['lz4', '-d']),
        ('.bz2', ['lbzip2', '-d']),
    )
    with util.file.open_or_gzopen(outFile, 'wt') as outf:
        for index, infname in enumerate(inFiles):
            if index and sep:
                outf.write(sep)

            decompressor = None
            for extension, command in decompressors:
                if infname.endswith(extension):
                    decompressor = command
                    break

            if decompressor is None:
                with open(infname, 'rt') as inf:
                    for line in inf:
                        outf.write(line)
            else:
                # The child process writes to the file descriptor directly,
                # bypassing Python's buffer. Flush first so text already
                # written (previous files, the separator) reaches the file
                # before the decompressed data instead of after it.
                outf.flush()
                with open(infname, 'rb') as inf:
                    subprocess.check_call(decompressor, stdin=inf, stdout=outf)
# ========================
# *** bwa_build_db ***
# ========================
def bwa_build_db(inputFasta, outputDirectory, outputFilePrefix):
    """ Create a database for use with bwa from an input reference FASTA file
    """
    # inputFasta:       FASTA file, optionally compressed (.gz via pigz, .lz4
    #                   via lz4); compressed input is expanded to a temp file.
    # outputDirectory:  directory for the index files; created if missing.
    # outputFilePrefix: index file prefix; when empty or None it is derived
    #                   from the input file name.
    if inputFasta.endswith('.gz'):
        decompressor = ['pigz', '-dc']
    elif inputFasta.endswith('.lz4'):
        decompressor = ['lz4', '-d']
    else:
        decompressor = None

    # Derive the default prefix from the name the caller supplied. It used to
    # be computed after inputFasta had been replaced by the temp file, so
    # compressed input produced an index named after a random temp file.
    source_name = os.path.basename(inputFasta)
    if decompressor is not None:
        source_name = os.path.splitext(source_name)[0]  # drop .gz / .lz4
    outPrefix = outputFilePrefix or os.path.splitext(source_name)[0]

    new_fasta = None
    try:
        if decompressor is not None:
            new_fasta = util.file.mkstempfname('.fasta')
            with open(inputFasta, 'rb') as inf, open(new_fasta, 'wb') as outf:
                subprocess.check_call(decompressor, stdin=inf, stdout=outf)
            inputFasta = new_fasta

        # make the output path if it does not exist
        util.file.mkdir_p(outputDirectory)

        tools.bwa.Bwa().index(inputFasta, output=os.path.join(outputDirectory, outPrefix))
    finally:
        # remove the expanded copy even when decompression or indexing fails
        if new_fasta is not None and os.path.exists(new_fasta):
            os.unlink(new_fasta)
def parser_bwa_build_db(parser=None):
    """Register the command-line arguments of the ``bwa_build_db`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``bwa_build_db`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inputFasta', help='Location of the input FASTA file')
    parser.add_argument('outputDirectory', help='Location for the output files')
    parser.add_argument(
        '--outputFilePrefix',
        help='Prefix for the output file name (default: inputFasta name, sans ".fasta" extension)'
    )
    util.cmd.common_args(parser, (('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, bwa_build_db, split_args=True)
    return parser
__commands__.append(('bwa_build_db', parser_bwa_build_db))
# ========================
# *** blastn_build_db ***
# ========================
def blastn_build_db(inputFasta, outputDirectory, outputFilePrefix):
    """ Create a database for use with blastn from an input reference FASTA file
    """
    # inputFasta:       FASTA file, optionally compressed (.gz via pigz, .lz4
    #                   via lz4); compressed input is expanded to a temp file.
    # outputDirectory:  directory that receives the database files.
    # outputFilePrefix: database file prefix; when empty or None it is derived
    #                   from the input file name.
    if inputFasta.endswith('.gz'):
        decompressor = ['pigz', '-dc']
    elif inputFasta.endswith('.lz4'):
        decompressor = ['lz4', '-d']
    else:
        decompressor = None

    # Derive the default prefix from the name the caller supplied. It used to
    # be computed after inputFasta had been replaced by the temp file, so
    # compressed input produced a database named after a random temp file.
    source_name = os.path.basename(inputFasta)
    if decompressor is not None:
        source_name = os.path.splitext(source_name)[0]  # drop .gz / .lz4
    outPrefix = outputFilePrefix or os.path.splitext(source_name)[0]

    new_fasta = None
    try:
        if decompressor is not None:
            new_fasta = util.file.mkstempfname('.fasta')
            with open(inputFasta, 'rb') as inf, open(new_fasta, 'wb') as outf:
                subprocess.check_call(decompressor, stdin=inf, stdout=outf)
            inputFasta = new_fasta

        tools.blast.MakeblastdbTool().build_database(inputFasta, os.path.join(outputDirectory, outPrefix))
    finally:
        # remove the expanded copy even when decompression or the build fails
        if new_fasta is not None and os.path.exists(new_fasta):
            os.unlink(new_fasta)
def parser_blastn_build_db(parser=None):
    """Register the command-line arguments of the ``blastn_build_db`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``blastn_build_db`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inputFasta', help='Location of the input FASTA file')
    parser.add_argument('outputDirectory', help='Location for the output files')
    parser.add_argument(
        '--outputFilePrefix',
        help='Prefix for the output file name (default: inputFasta name, sans ".fasta" extension)'
    )
    util.cmd.common_args(parser, (('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, blastn_build_db, split_args=True)
    return parser
__commands__.append(('blastn_build_db', parser_blastn_build_db))
# ========================
# *** bmtagger_build_db ***
# ========================
def bmtagger_build_db(inputFasta, outputDirectory, outputFilePrefix, word_size=18):
    """ Create a database for use with Bmtagger from an input FASTA file.
    """
    # inputFasta:       FASTA file, optionally compressed (.gz via pigz, .lz4
    #                   via lz4); compressed input is expanded to a temp file.
    # outputDirectory:  directory that receives <prefix>.bitmask and the
    #                   <prefix>.srprism database.
    # outputFilePrefix: file prefix; when empty or None it is derived from
    #                   the input file name.
    # word_size:        passed to the bmtool database build.
    if inputFasta.endswith('.gz'):
        decompressor = ['pigz', '-dc']
    elif inputFasta.endswith('.lz4'):
        decompressor = ['lz4', '-d']
    else:
        decompressor = None

    # Derive the default prefix from the name the caller supplied. It used to
    # be computed after inputFasta had been replaced by the temp file, so
    # compressed input produced databases named after a random temp file.
    source_name = os.path.basename(inputFasta)
    if decompressor is not None:
        source_name = os.path.splitext(source_name)[0]  # drop .gz / .lz4
    outPrefix = outputFilePrefix or os.path.splitext(source_name)[0]

    new_fasta = None
    try:
        if decompressor is not None:
            new_fasta = util.file.mkstempfname('.fasta')
            log.debug("cat {} | {} > {}".format(inputFasta, ' '.join(decompressor), new_fasta))
            with open(inputFasta, 'rb') as inf, open(new_fasta, 'wb') as outf:
                subprocess.check_call(decompressor, stdin=inf, stdout=outf)
            inputFasta = new_fasta

        log.debug("building bmtagger and srprism databases on {}".format(os.path.join(outputDirectory, outPrefix)))
        tools.bmtagger.BmtoolTool().build_database(
            inputFasta, os.path.join(outputDirectory, outPrefix + ".bitmask"), word_size=word_size
        )
        tools.bmtagger.SrprismTool().build_database(
            inputFasta, os.path.join(outputDirectory, outPrefix + ".srprism")
        )
    finally:
        # remove the expanded copy even when decompression or a build fails
        if new_fasta is not None and os.path.exists(new_fasta):
            os.unlink(new_fasta)
def parser_bmtagger_build_db(parser=None):
    """Register the command-line arguments of the ``bmtagger_build_db`` command.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend. When omitted, a new
            parser is created for this call rather than reusing one instance
            built at definition time.

    Returns:
        The parser, with ``bmtagger_build_db`` attached as its entry point.
    """
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('inputFasta', help='Location of the input FASTA file')
    parser.add_argument(
        'outputDirectory',
        help='Location for the output files (Where *.bitmask and *.srprism files will be stored)'
    )
    parser.add_argument(
        '--outputFilePrefix',
        help='Prefix for the output file name (default: inputFasta name, sans ".fasta" extension)'
    )
    parser.add_argument(
        '--word_size',
        type=int,
        default=18,
        help='Database word size (default: %(default)s)'
    )
    util.cmd.common_args(parser, (('loglevel', None), ('version', None), ('tmp_dir', None)))
    util.cmd.attach_main(parser, bmtagger_build_db, split_args=True)
    return parser
__commands__.append(('bmtagger_build_db', parser_bmtagger_build_db))
# ========================
def full_parser():
    """Build the top-level parser from the registered commands and the module docstring."""
    commands, description = __commands__, __doc__
    return util.cmd.make_parser(commands, description)
if __name__ == '__main__':
util.cmd.main_argparse(__commands__, __doc__)