# Snakefile for the downstream analysis pipeline
import os
import re
import ast
import json
import yaml

#### FUNCTIONS ####

def absoluteFilePaths(directory):
    """Return the paths of all .pth checkpoint files in `directory`."""
    files = os.listdir(directory)
    files = [file for file in files if file.endswith('.pth')]
    files = [directory + "/" + file for file in files]
    return files

def extract_model_name(model_path):
    file_names = [file_name.split('/')[-1] for file_name in model_path]
    file_names = [file_name.split('.')[0] for file_name in file_names]
    return file_names
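
# Illustrative behaviour of the two helpers above (paths are hypothetical):
#   absoluteFilePaths("models")
#       -> ["models/checkpoint0100.pth", "models/checkpoint0200.pth"]
#   extract_model_name(["models/checkpoint0100.pth"])
#       -> ["checkpoint0100"]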

def get_channel_name_combi(channel_combi_num, channel_dict):
    """Map a combination of channel indices (e.g. '01') to '_name0_name1'."""
    name_of_channel_combi = ""
    for channel_number in iter(str(channel_combi_num)):
        name_of_channel_combi = "_".join([name_of_channel_combi, channel_dict[int(channel_number)]])
    return name_of_channel_combi

def get_channel_number_combi(channel_names, channel_dict):
    """Invert get_channel_name_combi: '_name0_name1' -> '01'."""
    channel_combi = ""
    for channel_name in channel_names.split('_'):
        for key, value in channel_dict.items():
            if value == channel_name:
                channel_combi = "".join([channel_combi, str(key)])
    return channel_combi

def get_channel_name_combi_list(selected_channels, channel_dict):
    channel_names = []
    for channel_combi in selected_channels:
        channel_names.append(get_channel_name_combi(channel_combi, channel_dict))
    return channel_names
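
# Sketch of the channel helpers with an illustrative channel_dict (the real
# mapping comes from run_config.yaml):
#   channel_dict = {0: 'DAPI', 1: 'GFP'}
#   get_channel_name_combi('01', channel_dict)          -> '_DAPI_GFP'
#   get_channel_number_combi('_DAPI_GFP', channel_dict) -> '01'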

def load_norm_per_channel(config):
    """Return the per-channel normalisation stats, either taken directly from
    the config or parsed from a JSON file with 'mean' and 'std' entries."""
    if config['compute_cls_features']['normalize']:
        if not config['compute_cls_features']['parse_mean_std_from_file']:
            return config['compute_cls_features']['norm_per_channel']
        else:
            with open(config['compute_cls_features']['mean_std_file_location']) as f:
                norm_per_channel_json = json.load(f)
            norm_per_channel = str([norm_per_channel_json['mean'], norm_per_channel_json['std']])
            return norm_per_channel
    else:
        return None
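
# The mean/std file referenced above is assumed to be JSON shaped like
#   {"mean": [0.12, 0.34], "std": [0.05, 0.06]}
# (one value per image channel; only the 'mean' and 'std' keys are read).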

def save_config_file(config, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    with open(f"{save_dir}/run_config_dump.json", "w") as f:
        json.dump(config, f)
    with open(f"{save_dir}/run_config_dump.yaml", "w") as f:
        yaml.dump(config, f)

#### CONFIG FILE ####
configfile: "run_config.yaml"
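
# For orientation, a minimal run_config.yaml compatible with the keys read
# below might look like this (all values are illustrative assumptions):
#
#   meta:
#     name_of_run: "downstream_run1"
#     output_dir: "results"
#     seed: 42
#     selected_channel_combination_per_run: ["01"]
#     channel_dict: {0: 'DAPI', 1: 'GFP'}
#   compute_cls_features:
#     pretrained_weights: "models"
#     normalize: true
#     parse_mean_std_from_file: false   # if true, also set mean_std_file_location
#     norm_per_channel: "[[0.5, 0.5], [0.2, 0.2]]"
#     num_gpus: 1
#     num_workers: 4
#   attention_visualisation: {}         # keys are forwarded to visualise_attention.py
#   umap_eval:
#     topometry_plots: false
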
#### PARSING ####
name_of_run = config['meta']['name_of_run']
sk_save_dir = config['meta']['output_dir']
save_dir_downstream_run = sk_save_dir+"/"+name_of_run
selected_channels = config['meta']['selected_channel_combination_per_run']
channel_dict = config['meta']['channel_dict']
models_repo = absoluteFilePaths(config['compute_cls_features']['pretrained_weights'])
model_name = extract_model_name(models_repo)
print('model list:')
for model in model_name:
    print(model)
print('channel list:', selected_channels)
save_config_file(config, save_dir_downstream_run)
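
# Typical invocation (assuming Snakemake is installed and the config above is
# in place; adjust core and GPU counts to your hardware):
#   snakemake -s only_downstream_snakefile --cores 4 --resources gpus=1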

##### RULES #####

# 'rule all' lists the final targets; Snakemake derives every upstream job
# (label extraction, feature extraction, plots) from these paths.
rule all:
    input:
        expand("{save_dir_downstream_run}/kNN/global_kNN.txt", save_dir_downstream_run=save_dir_downstream_run),
        expand("{save_dir_downstream_run}/embedding_plots/channel{channel_names}_model_{model_name}_umap.png", save_dir_downstream_run=save_dir_downstream_run, channel_names=get_channel_name_combi_list(selected_channels, channel_dict), model_name=model_name),
        expand("{save_dir_downstream_run}/attention_images/channel{channel_names}_model_{model_name}/run_log.txt", save_dir_downstream_run=save_dir_downstream_run, channel_names=get_channel_name_combi_list(selected_channels, channel_dict), model_name=model_name),

rule visualise_attention:
    output:
        expand("{save_dir_downstream_run}/attention_images/channel{channel_names}_model_{model_name}/run_log.txt", save_dir_downstream_run="{save_dir_downstream_run}", channel_names="{channel_names}", model_name="{model_name}")
    input:
        path_to_model = config['compute_cls_features']['pretrained_weights']+"/{model_name}.pth",
    params:
        script_params = {**config['meta'], **config['compute_cls_features'], **config['attention_visualisation']},
        selected_channel_indices = lambda wildcards: get_channel_number_combi(wildcards.channel_names, channel_dict),
    resources:
        mem_mb = 4000,
        cores = 4,
    shell:
        'python pyscripts/visualise_attention.py --selected_channels {params.selected_channel_indices} --pretrained_weights {input.path_to_model} --parse_params """{params.script_params}"""'

rule plot_in_2D:
    input:
        features = expand("{save_dir_downstream_run}/CLS_features/channel{channel_names}_model_{model_name}_features.csv", save_dir_downstream_run="{save_dir_downstream_run}", channel_names="{channel_names}", model_name="{model_name}"),
        class_labels = "{save_dir_downstream_run}/CLS_features/class_labels.csv"
    output:
        "{save_dir_downstream_run}/embedding_plots/channel{channel_names}_model_{model_name}_umap.png"
    resources:
        mem_mb = 4000,
        cores = 4,
    params:
        scDINO_full_pipeline = False,
        topometry_plots = config['umap_eval']['topometry_plots']
    script:
        'pyscripts/plot_in_2D.py'

rule calc_global_kNN:
    input:
        features = expand("{save_dir_downstream_run}/CLS_features/channel{channel_names}_model_{model_name}_features.csv", save_dir_downstream_run=save_dir_downstream_run, channel_names=get_channel_name_combi_list(selected_channels, channel_dict), model_name=model_name),
        class_labels = "{save_dir_downstream_run}/CLS_features/class_labels.csv"
    output:
        "{save_dir_downstream_run}/kNN/global_kNN.txt"
    resources:
        mem_mb = 4000,
        cores = 4,
    params:
        scDINO_full_pipeline = False,
        run_names = expand("channel{channel_names}_model_{model_name}", channel_names=get_channel_name_combi_list(selected_channels, channel_dict), model_name=model_name),
        save_dir = save_dir_downstream_run,
        seed = config['meta']['seed']
    script:
        'pyscripts/global_kNN.py'

rule extract_labels:
    input:
        path_to_model = models_repo[0]  # the labels are model-independent, so the first checkpoint is used arbitrarily
    output:
        labels = expand("{save_dir_downstream_run}/CLS_features/{file_name}.csv", file_name=['class_labels', 'image_paths'], save_dir_downstream_run="{save_dir_downstream_run}")
    params:
        script_params = {**config['meta'], **config['compute_cls_features']},
        num_gpus = config['compute_cls_features']['num_gpus'],
        selected_channel_indices = selected_channels[0],
    resources:
        gpus = config['compute_cls_features']['num_gpus'],
        mem_mb = 4000,
        cores = 4,
    shell:
        'python -m torch.distributed.launch --nproc_per_node {params.num_gpus} pyscripts/extract_image_labels.py --selected_channels {params.selected_channel_indices} --pretrained_weights {input.path_to_model} --parse_params """{params.script_params}"""'

rule compute_CLS_features:
    input:
        path_to_model = config['compute_cls_features']['pretrained_weights']+"/{model_name}.pth",
    output:
        expand("{save_dir_downstream_run}/CLS_features/channel{channel_names}_model_{model_name}_features.csv", save_dir_downstream_run="{save_dir_downstream_run}", channel_names="{channel_names}", model_name="{model_name}")
    params:
        script_params = {**config['meta'], **config['compute_cls_features']},
        num_gpus = config['compute_cls_features']['num_gpus'],
        selected_channel_indices = lambda wildcards: get_channel_number_combi(wildcards.channel_names, channel_dict),
        norm_per_channel = load_norm_per_channel(config)
    resources:
        gpus = config['compute_cls_features']['num_gpus'],
        mem_mb = config['compute_cls_features']['num_workers']*4000,
        cores = config['compute_cls_features']['num_workers']
    shell:
        'python -m torch.distributed.launch --nproc_per_node {params.num_gpus} pyscripts/compute_CLS_features.py --selected_channels {params.selected_channel_indices} --pretrained_weights {input.path_to_model} --norm_per_channel """{params.norm_per_channel}""" --parse_params """{params.script_params}"""'
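
# Note: torch.distributed.launch is deprecated in recent PyTorch releases in
# favour of torchrun. If your installation has removed it, the GPU rules above
# would presumably need their shell commands ported, e.g.
#   torchrun --nproc_per_node {params.num_gpus} pyscripts/compute_CLS_features.py ...
# (the called scripts may then also need to read LOCAL_RANK from the environment).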