Build a new model for the RNAsamba tool to use for predicting coding vs. noncoding RNAs #1

Merged · 35 commits · Feb 14, 2024

Changes from 23 commits

Commits
4379976
add ensembl genome links for rnasamba model building
taylorreiter Feb 2, 2024
5c0cde3
check in cluster all and dumb split code
taylorreiter Feb 5, 2024
28477ff
prelim processing into sets finished (still need to refactor and add …
taylorreiter Feb 6, 2024
bf20a00
add summary stats
taylorreiter Feb 6, 2024
d20938a
swap out wildcard names to be more descriptive
taylorreiter Feb 6, 2024
6a759e8
simplify output of set creation
taylorreiter Feb 6, 2024
b0aba84
add rule to build rnasamba model
taylorreiter Feb 6, 2024
d2ef62a
bump snakemake version and add pandas to dev
taylorreiter Feb 7, 2024
7eca248
add early stopping epochs, typos
taylorreiter Feb 7, 2024
e56261b
add rules and code to assess accuracy of new RNAsamba model
taylorreiter Feb 8, 2024
198fff9
missing eof new line
taylorreiter Feb 8, 2024
f50f19c
rm comments
taylorreiter Feb 8, 2024
6989b64
add in comparison to existing human model to show improvement
taylorreiter Feb 8, 2024
3dcdf2b
diversify snakefmt file endings for CI
taylorreiter Feb 9, 2024
aa15c46
update snakefmt file endings and run linting locally
taylorreiter Feb 9, 2024
12985a8
Apply suggestions from code review
taylorreiter Feb 12, 2024
b31317f
indentation
taylorreiter Feb 12, 2024
93f1666
Merge branch 'ter/build-rnasamba-model' of github.com:Arcadia-Science…
taylorreiter Feb 12, 2024
eaee034
sample with replacement to augment noncoding to coding numbers
taylorreiter Feb 13, 2024
0a3c232
add new test data set links
taylorreiter Feb 13, 2024
8972760
update train and test sets to be different species
taylorreiter Feb 13, 2024
55cc055
update pthas
taylorreiter Feb 13, 2024
aace13f
linting
taylorreiter Feb 13, 2024
b54f0cb
try update rnasamba env for gpu
taylorreiter Feb 13, 2024
6879d04
deal with rnasamba install for gpu
taylorreiter Feb 13, 2024
a073c13
update file pointers for data processing
taylorreiter Feb 13, 2024
98c1a98
linting
taylorreiter Feb 13, 2024
43e2eae
fix typos
taylorreiter Feb 13, 2024
f614cbd
add a benchmark for model building
taylorreiter Feb 13, 2024
1a7d863
missing new line eof
taylorreiter Feb 13, 2024
3e24ad5
woops filepath typo
taylorreiter Feb 14, 2024
f97ade0
clean up versions around rnasamba env
taylorreiter Feb 14, 2024
ee3c279
clean up class weights comment
taylorreiter Feb 14, 2024
dcb8b2a
add note about order of arguments
taylorreiter Feb 14, 2024
1287e5c
change snakefile name since we were able to keep rnasamba build commands
taylorreiter Feb 14, 2024
284 changes: 284 additions & 0 deletions curate_datasets.snakefile
@@ -0,0 +1,284 @@
import pandas as pd
import os
import re

metadata = pd.read_csv("inputs/models/datasets/train_data_links.tsv", sep="\t")
GENOMES = metadata["genome"].unique().tolist()
RNA_TYPES = ["cdna", "ncrna"] # inherits names from ensembl
VALIDATION_TYPES = [
"mRNAs",
"ncRNAs",
] # inherits names from https://github.com/cbl-nabi/RNAChallenge
CODING_TYPES = ["coding", "noncoding"]
DATASET_TYPES = ["train", "test", "validation"]
MODEL_TYPES = ["eukaryote", "human"]


rule all:
input:
"outputs/models/datasets/3_stats/set_summary.tsv",
expand(
"outputs/models/build/rnasamba/1_evaluation/{model_type}/accuracy_metrics_{dataset_type}.tsv",
model_type=MODEL_TYPES,
dataset_type=DATASET_TYPES,
),


rule download_ensembl_data:
"""
Download ensembl cDNA and ncRNA files.
Ensembl annotates protein coding and non-coding RNA transcripts in their files.
This information will be used to separate protein coding from non-coding RNAs to build an RNAsamba model.
Note this download renames genome files from their names on ensembl to make them simpler to point to.
See example transformations below:
- cdna/Homo_sapiens.GRCh38.cdna.all.fa.gz -> cdna/Homo_sapiens.GRCh38.cdna.fa.gz (dropped "all.")
- ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz -> ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz (no change)
"""
output:
"inputs/models/datasets/ensembl/{rna_type}/{genome}.{rna_type}.fa.gz",
run:
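# look up this genome's row in the metadata table, then build the download URL from the root URL and the cDNA or ncRNA suffix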
genome_df = metadata.loc[(metadata["genome"] == wildcards.genome)]
root_url = genome_df["root_url"].values[0]
if wildcards.rna_type == "cdna":
suffix = genome_df["cdna_suffix"].values[0]
else:
suffix = genome_df["ncrna_suffix"].values[0]

url = root_url + suffix
shell("curl -JLo {output} {url}")


rule extract_protein_coding_orfs_from_cdna:
"""
Ensembl cDNA files consist of transcript sequences for actual and putative genes, including pseudogenes, NMD transcripts, and the like.
Transcripts in the cDNA files have headers like: >TRANSCRIPT_ID SEQTYPE LOCATION GENE_ID GENE_BIOTYPE TRANSCRIPT_BIOTYPE, where the gene_biotype and transcript_biotype both contain information about whether the gene is coding or not.
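An illustrative header with placeholder IDs: ">TRANSCRIPT_ID cdna chromosome:GRCh38:1:100:200:1 gene:GENE_ID gene_biotype:protein_coding transcript_biotype:protein_coding"; the seqkit command below keeps only records whose header contains transcript_biotype:protein_coding.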
"""
input:
"inputs/models/datasets/ensembl/cdna/{genome}.cdna.fa.gz",
output:
"outputs/models/datasets/0_coding/{genome}.fa.gz",
conda:
"envs/seqkit.yml"
shell:
"""
seqkit grep --use-regexp --by-name --pattern "transcript_biotype:protein_coding" -o {output} {input}
"""


rule download_validation_data:
output:
"inputs/models/datasets/validation/rnachallenge/{validation_type}.fa.gz",
shell:
"""
curl -JL https://raw.githubusercontent.com/cbl-nabi/RNAChallenge/main/RNAchallenge/{wildcards.validation_type}.fa | gzip > {output}
"""


rule combine_sequences:
input:
coding=expand("outputs/models/datasets/0_coding/{genome}.fa.gz", genome=GENOMES),
noncoding=expand(
"inputs/models/datasets/ensembl/ncrna/{genome}.ncrna.fa.gz", genome=GENOMES
),
validation=expand(
"inputs/models/datasets/validation/rnachallenge/{validation_type}.fa.gz",
validation_type=VALIDATION_TYPES,
),
output:
"outputs/models/datasets/1_homology_reduction/all_sequences.fa",
shell:
"""
cat {input} | gunzip > {output}
"""


rule grab_all_sequence_names_and_lengths:
input:
"outputs/models/datasets/1_homology_reduction/all_sequences.fa",
output:
"outputs/models/datasets/1_homology_reduction/all_sequences.fa.seqkit.fai",
conda:
"envs/seqkit.yml"
shell:
"""
seqkit faidx -f {input}
"""


rule reduce_sequence_homology:
"""
To reduce leakage between the training and testing sets, cluster sequences at 80% sequence identity.
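Per the MMseqs2 documentation, --min-seq-id 0.8 sets the identity threshold, --cov-mode 1 requires coverage of the target sequence, and --cluster-mode 2 selects greedy incremental (CD-HIT-like) clustering.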
"""
input:
"outputs/models/datasets/1_homology_reduction/all_sequences.fa",
output:
"outputs/models/datasets/1_homology_reduction/clustered_sequences_rep_seq.fasta",
"outputs/models/datasets/1_homology_reduction/clustered_sequences_cluster.tsv",
params:
prefix="outputs/models/datasets/1_homology_reduction/clustered_sequences",
conda:
"envs/mmseqs2.yml"
shell:
"""
mmseqs easy-cluster {input} {params.prefix} tmp_mmseqs2 --min-seq-id 0.8 --cov-mode 1 --cluster-mode 2
"""


rule grab_validation_set_names_and_lengths:
"""
The train/test data set sequences are identifiable by the genome information in the header, which is consistently formatted by Ensembl.
The same is not true for the validation data.
This rule grabs the validation sequence header names so they can be separated from the train/test sets.
"""
input:
"inputs/models/datasets/validation/rnachallenge/{validation_type}.fa.gz",
output:
validation="inputs/models/datasets/validation/rnachallenge/{validation_type}.fa",
validation_fai="inputs/models/datasets/validation/rnachallenge/{validation_type}.fa.seqkit.fai",
conda:
"envs/seqkit.yml"
shell:
"""
cat {input} | gunzip > {output.validation}
seqkit faidx -f {output.validation}
"""


rule process_sequences_into_nonoverlapping_sets:
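"""
Use the cluster membership table and the validation sequence names to assign sequences to non-overlapping train, test, and validation sets, writing one list of sequence names per coding type and dataset type (see scripts/process_sequences_into_nonoverlapping_sets.R for the assignment logic).
"""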
input:
all_fai="outputs/models/datasets/1_homology_reduction/all_sequences.fa.seqkit.fai",
validation_fai=expand(
"inputs/models/datasets/validation/rnachallenge/{validation_type}.fa.seqkit.fai",
validation_type=VALIDATION_TYPES,
),
clusters="outputs/models/datasets/1_homology_reduction/clustered_sequences_cluster.tsv",
output:
expand(
"outputs/models/datasets/2_sequence_sets/{coding_type}_{dataset_type}.txt",
coding_type=CODING_TYPES,
dataset_type=DATASET_TYPES,
),
conda:
"envs/tidyverse.yml"
script:
"scripts/process_sequences_into_nonoverlapping_sets.R"


rule filter_sequence_sets:
input:
fa="outputs/models/datasets/1_homology_reduction/clustered_sequences_rep_seq.fasta",
names="outputs/models/datasets/2_sequence_sets/{coding_type}_{dataset_type}.txt",
output:
"outputs/models/datasets/2_sequence_sets/{coding_type}_{dataset_type}.fa",
conda:
"envs/seqtk.yml"
shell:
"""
seqtk subseq {input.fa} {input.names} > {output}
"""


##################################################################
## Build RNAsamba model
##################################################################


rule build_rnasamba_model:
"""
Build a new rnasamba model from the training data curated above.
The --early_stopping parameter reduces training time and can help avoid overfitting.
It sets the number of epochs to wait after the lowest validation loss before training stops.
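Note: rnasamba train takes the output model path followed by the coding FASTA and then the noncoding FASTA; expand() preserves the order of CODING_TYPES, so {input[0]} is the coding set and {input[1]} is the noncoding set.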
"""
input:
expand(
"outputs/models/datasets/2_sequence_sets/{coding_type}_train.fa",
coding_type=CODING_TYPES,
),
output:
"outputs/models/build/rnasamba/0_model/eukaryote_rnasamba.hdf5",
conda:
"envs/rnasamba.yml"
shell:
"""
rnasamba train --early_stopping 5 --verbose 2 {output} {input[0]} {input[1]}
"""


rule assess_rnasamba_model:
input:
model="outputs/models/build/rnasamba/0_model/{model_type}_rnasamba.hdf5",
faa="outputs/models/datasets/2_sequence_sets/{coding_type}_{dataset_type}.fa",
output:
faa="outputs/models/build/rnasamba/1_evaluation/{model_type}/{coding_type}_{dataset_type}.fa",
predictions="outputs/models/build/rnasamba/1_evaluation/{model_type}/{coding_type}_{dataset_type}.tsv",
benchmark:
"benchmarks/models/build/rnasamba/1_evaluation/{model_type}/{coding_type}_{dataset_type}.tsv"
conda:
"envs/rnasamba.yml"
shell:
"""
rnasamba classify --protein_fasta {output.faa} {output.predictions} {input.faa} {input.model}
"""


rule calculate_rnasamba_model_accuracy:
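"""
Combine the coding and noncoding predictions for a dataset and compute a confusion matrix and accuracy metrics with caret (see scripts/calculate_rnasamba_model_accuracy.R).
"""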
input:
expand(
"outputs/models/build/rnasamba/1_evaluation/{{model_type}}/{coding_type}_{{dataset_type}}.tsv",
coding_type=CODING_TYPES,
),
output:
freq="outputs/models/build/build/1_evaluation/{model_type}/confusionmatrix_{dataset_type}.tsv",
metrics="outputs/models/build/rnasamba/1_evaluation/{model_type}/accuracy_metrics_{dataset_type}.tsv",
conda:
"envs/caret.yml"
script:
"scripts/calculate_rnasamba_model_accuracy.R"


rule download_rnasamba_human_model:
"""
Download the pretrained human model distributed with RNAsamba to compare whether the new model performs better or worse.
It's saved under a new name so we can use a wildcard to run rnasamba classify and to calculate model accuracy.
"""
output:
"outputs/models/build/rnasamba/0_model/human_rnasamba.hdf5",
shell:
"""
curl -JLo {output} https://github.com/apcamargo/RNAsamba/raw/master/data/full_length_weights.hdf5
"""


##################################################################
## Get sequence statistics
##################################################################


rule get_sequence_descriptors:
input:
"outputs/models/datasets/2_sequence_sets/{coding_type}_{dataset_type}.fa",
output:
"outputs/models/datasets/2_sequence_sets/{coding_type}_{dataset_type}.fa.seqkit.fai",
conda:
"envs/seqkit.yml"
shell:
"""
seqkit faidx -f {input}
"""


rule calculate_sequence_statistics:
input:
expand(
"outputs/models/datasets/2_sequence_sets/{coding_type}_{dataset_type}.fa.seqkit.fai",
coding_type=CODING_TYPES,
dataset_type=DATASET_TYPES,
),
output:
set_summary="outputs/models/datasets/3_stats/set_summary.tsv",
set_length_summary="outputs/models/datasets/3_stats/set_length_summary.tsv",
set_length_genome_summary="outputs/models/datasets/3_stats/set_length_genome_summary.tsv",
conda:
"envs/tidyverse.yml"
script:
"scripts/calculate_sequence_statistics.R"
7 changes: 7 additions & 0 deletions envs/caret.yml
@@ -0,0 +1,7 @@
channels:
- conda-forge
- bioconda
- defaults
dependencies:
- r-tidyverse=2.0.0
- r-caret=6.0_94
3 changes: 2 additions & 1 deletion envs/dev.yml
@@ -8,4 +8,5 @@ dependencies:
- python=3.12.0
- ruff=0.1.6
- snakefmt=0.8.5
- snakemake-minimal=8.0.1
- snakemake-minimal=8.4.2
- pandas=2.2.0
6 changes: 6 additions & 0 deletions envs/mmseqs2.yml
@@ -0,0 +1,6 @@
channels:
- conda-forge
- bioconda
- defaults
dependencies:
- mmseqs2=15.6f452
6 changes: 6 additions & 0 deletions envs/rnasamba.yml
@@ -0,0 +1,6 @@
channels:
- conda-forge
- bioconda
- defaults
dependencies:
- rnasamba=0.2.5
6 changes: 6 additions & 0 deletions envs/seqkit.yml
@@ -0,0 +1,6 @@
channels:
- conda-forge
- bioconda
- defaults
dependencies:
- seqkit=2.6.1
6 changes: 6 additions & 0 deletions envs/seqtk.yml
@@ -0,0 +1,6 @@
channels:
- conda-forge
- bioconda
- defaults
dependencies:
- seqtk=1.4
6 changes: 6 additions & 0 deletions envs/tidyverse.yml
@@ -0,0 +1,6 @@
channels:
- conda-forge
- bioconda
- defaults
dependencies:
- r-tidyverse=2.0.0
17 changes: 17 additions & 0 deletions inputs/models/datasets/train_data_links.tsv
@@ -0,0 +1,17 @@
organism root_url genome cdna_suffix ncrna_suffix genome_abbreviation set_name
human https://ftp.ensembl.org/pub/release-111/fasta/homo_sapiens/ Homo_sapiens.GRCh38 cdna/Homo_sapiens.GRCh38.cdna.all.fa.gz ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz GRCh38 train
yeast https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi/release-58/fasta/saccharomyces_cerevisiae/ Saccharomyces_cerevisiae.R64-1-1 cdna/Saccharomyces_cerevisiae.R64-1-1.cdna.all.fa.gz ncrna/Saccharomyces_cerevisiae.R64-1-1.ncrna.fa.gz R64-1-1 train
worm https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/caenorhabditis_elegans/ Caenorhabditis_elegans.WBcel235 cdna/Caenorhabditis_elegans.WBcel235.cdna.all.fa.gz ncrna/Caenorhabditis_elegans.WBcel235.ncrna.fa.gz WBcel235 train
arabadopsis https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/arabidopsis_thaliana/ Arabidopsis_thaliana.TAIR10 cdna/Arabidopsis_thaliana.TAIR10.cdna.all.fa.gz ncrna/Arabidopsis_thaliana.TAIR10.ncrna.fa.gz TAIR10 train
drosophila https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/drosophila_melanogaster/ Drosophila_melanogaster.BDGP6.46 cdna/Drosophila_melanogaster.BDGP6.46.cdna.all.fa.gz ncrna/Drosophila_melanogaster.BDGP6.46.ncrna.fa.gz BDGP6.46 train
dictyostelium_discoideum https://ftp.ensemblgenomes.ebi.ac.uk/pub/protists/release-58/fasta/dictyostelium_discoideum/ Dictyostelium_discoideum.dicty_2.7 cdna/Dictyostelium_discoideum.dicty_2.7.cdna.all.fa.gz ncrna/Dictyostelium_discoideum.dicty_2.7.ncrna.fa.gz dicty_2.7 train
mouse https://ftp.ensembl.org/pub/release-111/fasta/mus_musculus/ Mus_musculus.GRCm39 cdna/Mus_musculus.GRCm39.cdna.all.fa.gz ncrna/Mus_musculus.GRCm39.ncrna.fa.gz GRCm39 train
zebrafish https://ftp.ensembl.org/pub/release-111/fasta/danio_rerio/ Danio_rerio.GRCz11 cdna/Danio_rerio.GRCz11.cdna.all.fa.gz ncrna/Danio_rerio.GRCz11.ncrna.fa.gz GRCz11 train
chicken https://ftp.ensembl.org/pub/release-111/fasta/gallus_gallus/ Gallus_gallus.bGalGal1.mat.broiler.GRCg7b cdna/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.cdna.all.fa.gz ncrna/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.ncrna.fa.gz bGalGal1.mat.broiler.GRCg7b test
rice https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/oryza_indica/ Oryza_indica.ASM465v1 cdna/Oryza_indica.ASM465v1.cdna.all.fa.gz ncrna/Oryza_indica.ASM465v1.ncrna.fa.gz ASM465v1 test
maize https://ftp.ensemblgenomes.ebi.ac.uk/pub/plants/release-58/fasta/zea_mays/ Zea_mays.Zm-B73-REFERENCE-NAM-5.0 cdna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.cdna.all.fa.gz ncrna/Zea_mays.Zm-B73-REFERENCE-NAM-5.0.ncrna.fa.gz Zm-B73-REFERENCE-NAM-5.0 test
frog https://ftp.ensembl.org/pub/release-111/fasta/xenopus_tropicalis/ Xenopus_tropicalis.UCB_Xtro_10.0 cdna/Xenopus_tropicalis.UCB_Xtro_10.0.cdna.all.fa.gz ncrna/Xenopus_tropicalis.UCB_Xtro_10.0.ncrna.fa.gz UCB_Xtro_10.0 test
rat https://ftp.ensembl.org/pub/release-111/fasta/rattus_norvegicus/ Rattus_norvegicus.mRatBN7.2 cdna/Rattus_norvegicus.mRatBN7.2.cdna.all.fa.gz ncrna/Rattus_norvegicus.mRatBN7.2.ncrna.fa.gz mRatBN7 test
honeybee https://ftp.ensemblgenomes.ebi.ac.uk/pub/metazoa/release-58/fasta/apis_mellifera/ Apis_mellifera.Amel_HAv3.1 cdna/Apis_mellifera.Amel_HAv3.1.cdna.all.fa.gz ncrna/Apis_mellifera.Amel_HAv3.1.ncrna.fa.gz Amel_HAv3.1 test
fission_yeast https://ftp.ensemblgenomes.ebi.ac.uk/pub/fungi/release-58/fasta/schizosaccharomyces_pombe/ Schizosaccharomyces_pombe.ASM294v2 cdna/Schizosaccharomyces_pombe.ASM294v2.cdna.all.fa.gz ncrna/Schizosaccharomyces_pombe.ASM294v2.ncrna.fa.gz ASM294v2 test
tetrahymena https://ftp.ensemblgenomes.ebi.ac.uk/pub/protists/release-58/fasta/tetrahymena_thermophila/ Tetrahymena_thermophila.JCVI-TTA1-2.2 cdna/Tetrahymena_thermophila.JCVI-TTA1-2.2.cdna.all.fa.gz ncrna/Tetrahymena_thermophila.JCVI-TTA1-2.2.ncrna.fa.gz JCVI-TTA1-2.2 test
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -59,5 +59,5 @@ no-lines-before = ["future", "standard-library"]
[tool.snakefmt]
# the line length here should match the one used by ruff
line_length = 100
include = 'Snakefile|Snakefile_*'
include = 'Snakefile|Snakefile_*|\.snakefile$|\.smk$'
exclude = 'dev'
29 changes: 29 additions & 0 deletions scripts/calculate_rnasamba_model_accuracy.R
@@ -0,0 +1,29 @@
library(tidyverse)
library(caret)

files <- unlist(snakemake@input)

# read in and format model results for each dataset type
model_predictions <- files %>%
set_names() %>%
map_dfr(read_tsv, .id = "tmp") %>%
mutate(tmp = gsub(".tsv", "", basename(tmp))) %>%
separate(tmp, into = c("coding_type", "dataset_type"), sep = "_", remove = T) %>%
mutate(coding_type = as.factor(coding_type),
classification = as.factor(classification))

# determine model performance
model_confusion_matrix <- confusionMatrix(data = model_predictions$classification,
reference = model_predictions$coding_type)

# convert the model performance metrics into a data frame
model_confusion_matrix_df <- bind_rows(data.frame(value = model_confusion_matrix$overall),
data.frame(value = model_confusion_matrix$byClass)) %>%
rownames_to_column("metric") %>%
mutate(dataset_type = model_predictions$dataset_type[1]) # set dataset_type as a column even though it will also be in the file name

# convert the confusion matrix into a data frame
model_confusion_matrix_freq_df <- model_confusion_matrix %>% as.table() %>% data.frame()

write_tsv(model_confusion_matrix_df, snakemake@output[['metrics']])
write_tsv(model_confusion_matrix_freq_df, snakemake@output[['freq']])