Commit 442832f

Fibertools ONT ml code from hackathon

anupamajha1 committed Oct 22, 2023
1 parent 760cffc commit 442832f
Showing 8 changed files with 2,296 additions and 0 deletions.
142 changes: 142 additions & 0 deletions m6a_calling/autocorr_method.py
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 21 09:51:51 2023
@author: morgan hamm
"""

import argparse

import numpy as np
import pandas as pd
from scipy import signal

# import seaborn as sns

# Local debugging stub; superseded by parser.parse_args() under __main__.
# args = argparse.Namespace(npz_file='/home/morgan/Documents/grad_school/misc_code/hackathon/merged_00_100p_20k_autocorr_input_5M_set2.npz',
#                           invert_ml=True, ml_cutoff=0.938, dorado_cutoff=0.95,
#                           n_sites=2000, output_file='/home/morgan/Documents/grad_school/misc_code/hackathon/test_out.npz')


# =============================================================================
# stack = []
# for i in range(2000):
#     subset = preds[preds['id_hash'] == pos_ids[i]]
#     if subset.iloc[-1]['pos'] > 820:
#         autocorr = auto_corr(subset, score_col="dorado", cutoff=0.37)
#         if autocorr is not None:
#             stack.append(autocorr)
#
# temp = np.stack(stack)
# temp2 = np.nansum(temp, axis=0) / temp.shape[0]
#
# sns.lineplot(lags, temp2)
# =============================================================================
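
# auto_corr: within a single read, mark the positions whose score clears
# `cutoff` inside a big_w_len window centered on the read midpoint, then
# correlate the first w_len bases of that indicator vector against the full
# window. Peaks in the result correspond to recurring spacings between calls.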

def auto_corr(subset, w_len=400, big_w_len=800, score_col='ml', cutoff=0.938):
    # Center a window of big_w_len bases on the midpoint of the read.
    w_start = int((subset.iloc[-1]['pos'] / 2) - (big_w_len / 2))

    # Keep only the calls inside the window whose score clears the cutoff.
    filt = subset[(subset['pos'] >= w_start) &
                  (subset['pos'] < w_start + big_w_len) &
                  (subset[score_col] >= cutoff)]

    if len(filt) == 0:
        return None

    # Indicator vector over the window, scaled by the number of passing calls.
    big_window = np.zeros(big_w_len, dtype=float)
    for i, row in filt.iterrows():
        big_window[int(row['pos']) - w_start] = 1 / len(filt)
    little_window = big_window[0:w_len]

    if sum(little_window) == 0:
        return None

    # "valid" keeps only the lags where little_window fully overlaps.
    autocorr = signal.correlate(big_window, little_window, "valid")
    # lags = signal.correlation_lags(big_w_len, w_len, "valid")
    return autocorr


# auto_corr_n: average the per-read autocorrelation over up to n_sites reads
# that are long enough (last position > 820) to hold the analysis window.
def auto_corr_n(preds, score_col, n_sites, cutoff):
    stack = []
    read_ids = np.unique(preds['id_hash'])
    # Guard against requesting more reads than the data contains.
    for i in range(min(n_sites, len(read_ids))):
        subset = preds[preds['id_hash'] == read_ids[i]]
        if subset.iloc[-1]['pos'] > 820:
            autocorr = auto_corr(subset, score_col=score_col, cutoff=cutoff)
            if autocorr is not None:
                stack.append(autocorr)
    print("stack: ", len(stack))
    if len(stack) > 0:
        all_out = np.stack(stack)
        # Mean across reads (NaNs treated as zero by nansum).
        return np.nansum(all_out, axis=0) / float(all_out.shape[0])
    return stack



def main(args):
    data = np.load(args.npz_file)
    preds = pd.DataFrame(data['preds'],
                         columns=['id_hash', 'pos', 'label', 'dorado', 'ml'])

    if args.invert_ml:
        preds['ml'] = 1 - preds['ml']

    preds_neg = preds[preds['label'] == 0]
    preds_pos = preds[preds['label'] == 1]

    print("lab1_ml_data")
    lab1_ml_data = auto_corr_n(preds_pos, score_col='ml',
                               n_sites=args.n_sites, cutoff=args.ml_cutoff)

    print("lab0_ml_data")
    lab0_ml_data = auto_corr_n(preds_neg, score_col='ml',
                               n_sites=args.n_sites, cutoff=args.ml_cutoff)

    print("lab1_dorado_data")
    lab1_dorado_data = auto_corr_n(preds_pos, score_col='dorado',
                                   n_sites=args.n_sites, cutoff=args.dorado_cutoff)

    print("lab0_dorado_data")
    lab0_dorado_data = auto_corr_n(preds_neg, score_col='dorado',
                                   n_sites=args.n_sites, cutoff=args.dorado_cutoff)

    npz_struct = {'lab1_ml_data': lab1_ml_data,
                  'lab0_ml_data': lab0_ml_data,
                  'lab1_dorado_data': lab1_dorado_data,
                  'lab0_dorado_data': lab0_dorado_data}

    np.savez(args.output_file, **npz_struct)

# ---------------------


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='compute autocorrelation of m6A calls from per-site ML/Dorado scores')
    parser.add_argument('npz_file',
                        help='npz file with ML calls for all As in a set of fibers')
    parser.add_argument('-i', '--invert_ml', action='store_true',
                        help='flip the ML score to 1 - ML')
    parser.add_argument('-m', '--ml_cutoff', type=float, default=0.938,
                        help='cutoff to use for ML results')
    parser.add_argument('-d', '--dorado_cutoff', type=float, default=0.95,
                        help='cutoff to use for Dorado results')
    parser.add_argument('-n', '--n_sites', type=int, default=5000,
                        help='number of sites or number of fibers to look at')
    parser.add_argument('-o', '--output_file', type=str, default='output.npz',
                        help='output file name prefix')
    args = parser.parse_args()
    main(args)
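
# Example invocation (illustrative file names, not part of the repo):
#   python autocorr_method.py ml_calls.npz --invert_ml -n 2000 -o autocorr_out.npz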
145 changes: 145 additions & 0 deletions m6a_calling/m6a_autocorrelation.py
@@ -0,0 +1,145 @@
import argparse
import configparser
import pickle

import numpy as np
import pandas as pd
import torch

from m6a_cnn import M6ANet
from m6a_semi_supervised_cnn import tdc, count_pos_neg, make_one_hot_encoded


def find_window(score, precision_score_table):
    # Return the index j of the score bin whose edges bracket `score`.
    for j in range(1, len(precision_score_table) - 1):
        if score >= precision_score_table[j, 1]:
            if score <= precision_score_table[j + 1, 1]:
                return j
    return 0

def convert_cnn_score_to_int(precision_score_table, float_scores):
    # Vectorized bin lookup for every float score.
    vfind_window = np.vectorize(find_window, excluded=['precision_score_table'])
    uint_score = vfind_window(score=float_scores,
                              precision_score_table=precision_score_table)
    return uint_score
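
# Illustrative example (edge values assumed, not taken from the repo): if
# column 1 of the table holds ascending score edges 0.0, 0.1, ..., 1.0, a CNN
# score of 0.42 is bracketed by the edges at indices 4 and 5 and is therefore
# stored as the integer 4, packing float scores into a small-integer range.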

"""
def convert_cnn_score_to_int(precision_score_table, float_scores):
uint_score = np.zeros((float_scores.shape))
for i, score in enumerate(float_scores):
if i % 100000 == 0:
print(i)
for j in range(1, len(precision_score_table)-1, 1):
if score >= precision_score_table[j, 1]:
if score <= precision_score_table[j+1, 1]:
uint_score[i] = j
#print(uint_score[i], precision_score_table[j, 1], score, precision_score_table[j+1, 1])
break
return uint_score
"""

# make_ont_predictions_255: load a trained M6ANet checkpoint, score every
# candidate site in data_npz, quantize the CNN scores to integer bins via the
# precision table, and save rows of (read id hash, position, label, Dorado
# score, ML integer score) plus the read-id mapping.
def make_ont_predictions_255(best_sup_save_model,
                             data_npz,
                             output_obj,
                             precision_score_tsv,
                             device="cuda"):
    # Skip the header row and parse the precision/score table as floats.
    precision_score_table = np.loadtxt(precision_score_tsv,
                                       delimiter="\t",
                                       dtype=str)
    precision_score_table = np.array(precision_score_table[1:, :], dtype=float)
    print(f"precision_score_table: {precision_score_table[0]}")

    # Load the supervised model for transfer learning.
    model = M6ANet()
    with open(best_sup_save_model, "rb") as fp:
        model.load_state_dict(pickle.load(fp))
    model = model.to(device)

    val_data = np.load(data_npz)
    X_val = np.array(val_data['features'], dtype=float)
    print(f"X_val: {X_val.shape}")

    # Column 7 of channel 5 is taken as the Dorado score; channels 4 and 5
    # are rescaled from 0-255 to 0-1 before prediction.
    dorado_score = X_val[:, 5, 7]
    X_val[:, 4, :] = X_val[:, 4, :] / 255.0
    X_val[:, 5, :] = X_val[:, 5, :] / 255.0
    y_val = np.array(val_data['labels'], dtype=int)
    read_ids = val_data['read_ids']

    # Map each read id to a small integer so it fits in the output float array.
    read_ids_unique = np.unique(read_ids)
    read_idx_dict = dict()
    for i, read in enumerate(read_ids_unique):
        read_idx_dict[read] = i

    read_id_hashes = np.zeros(read_ids.shape)
    for i, read in enumerate(read_ids):
        read_id_hashes[i] = read_idx_dict[read]

    positions = val_data['positions']
    # Convert labels to one-hot encoding.
    y_val_ohe = make_one_hot_encoded(y_val)

    # Convert data to tensors.
    X_val = torch.tensor(X_val).float()
    y_val_ohe = torch.tensor(y_val_ohe).float()

    preds_y = model.predict(X_val, device=device)
    total_len = len(preds_y)

    preds_y = preds_y[:, 0].numpy()
    preds_y_uint = convert_cnn_score_to_int(precision_score_table, preds_y)

    # Trim everything to the number of predictions and add a column axis.
    read_id_hashes = read_id_hashes[0:total_len][:, np.newaxis]
    positions = positions[0:total_len][:, np.newaxis]
    y_val = y_val[0:total_len][:, np.newaxis]
    preds_y_uint = preds_y_uint[0:total_len][:, np.newaxis]
    dorado_score = dorado_score[0:total_len][:, np.newaxis]

    print(f"read_ids: {read_id_hashes.shape}")
    print(f"positions: {positions.shape}")
    print(f"y_val: {y_val.shape}")
    print(f"preds_y_uint: {preds_y_uint.shape}")
    print(f"dorado_score: {dorado_score.shape}")

    # Columns: id_hash, pos, label, dorado, ml (matches autocorr_method.py).
    output_arr = np.concatenate((read_id_hashes, positions, y_val,
                                 dorado_score, preds_y_uint), axis=1)
    output_arr = np.array(output_arr, dtype=float)
    np.savez(output_obj, preds=output_arr)

    # Save the read-id -> integer mapping alongside the predictions.
    with open(f"{output_obj}_dict.pkl", 'wb') as f:
        pickle.dump(read_idx_dict, f)
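
# The blocks below re-run prediction for several trained checkpoints; only the
# final call (HG002 NAPA data) is active, the earlier ones are commented out.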





best_sup_save_model="../models/m6A_ONT_semi_supervised_cnn_5M_set2.best.torch.pickle"
data_npz="/net/gs/vol4/shared/public/hackathon_2023/Stergachis_lab/data/all_sites_npz/merged_00_100p_20k.npz"
output_obj="../results/merged_00_100p_20k_autocorr_input_5M_set2_0_255.npz"
precision_score_tsv="../results/semi_ONT_score_precision_5M_set2.tsv"
#make_ont_predictions_255(best_sup_save_model, data_npz, output_obj, precision_score_tsv)

best_sup_save_model="../models/m6A_ONT_semi_supervised_cnn_5M_set3.best.torch.pickle"
data_npz="/net/gs/vol4/shared/public/hackathon_2023/Stergachis_lab/data/all_sites_npz/merged_00_100p_20k.npz"
output_obj="../results/merged_00_100p_20k_autocorr_input_5M_set3_0_255.npz"
precision_score_tsv="../results/semi_ONT_score_precision_5M_set3.tsv"
#make_ont_predictions_255(best_sup_save_model, data_npz, output_obj, precision_score_tsv)

best_sup_save_model="../models/m6A_ONT_semi_supervised_cnn_5M_set3_run2.best.torch.pickle"
data_npz="/net/gs/vol4/shared/public/hackathon_2023/Stergachis_lab/data/all_sites_npz/merged_00_100p_20k.npz"
output_obj="../results/merged_00_100p_20k_autocorr_input_5M_set3_run2_0_255.npz"
precision_score_tsv="../results/semi_ONT_score_precision_5M_set3_run2.tsv"
#make_ont_predictions_255(best_sup_save_model, data_npz, output_obj, precision_score_tsv)

best_sup_save_model="../models/m6A_ONT_semi_supervised_cnn_5M_set3_run2.best.torch.pickle"
data_npz="/net/gs/vol4/shared/public/hackathon_2023/Stergachis_lab/data/NAPA_raw/HG002_2_NAPA_00.npz"
output_obj="../results/HG002_2_NAPA_00_autocorr_input_5M_set3_run2_0_255.npz"
precision_score_tsv="../results/semi_ONT_score_precision_5M_set3_run2.tsv"
make_ont_predictions_255(best_sup_save_model, data_npz, output_obj, precision_score_tsv)
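
# Downstream, the saved predictions can be reloaded like this (sketch):
#   data = np.load("../results/HG002_2_NAPA_00_autocorr_input_5M_set3_run2_0_255.npz")
#   preds = data["preds"]  # columns: id_hash, pos, label, dorado, ml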