#########################
# global_vars.py
# Purpose: Sets up global variables to be used throughout
#########################
import argparse
import logging
import os
import warnings

# Silence Python warnings and TensorFlow logging before TF is imported
warnings.filterwarnings("ignore")
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.get_logger().setLevel(logging.ERROR)


def dir_name_fn(args):
    """Builds and creates the directories used to store computed weights, outputs and figures."""
    dir_name = 'weights/%s/model_%s/%s/k%s_E%s_B%s_C%1.0e_lr%.1e' % (
        args.dataset, args.model_num, args.optimizer, args.k, args.E, args.B, args.C, args.eta)
    output_file_name = 'output'
    output_dir_name = 'output_files/%s/model_%s/%s/k%s_E%s_B%s_C%1.0e_lr%.1e' % (
        args.dataset, args.model_num, args.optimizer, args.k, args.E, args.B, args.C, args.eta)
    figures_dir_name = 'figures/%s/model_%s/%s/k%s_E%s_B%s_C%1.0e_lr%.1e' % (
        args.dataset, args.model_num, args.optimizer, args.k, args.E, args.B, args.C, args.eta)
    interpret_figs_dir_name = 'interpret_figs/%s/model_%s/%s/k%s_E%s_B%s_C%1.0e_lr%.1e' % (
        args.dataset, args.model_num, args.optimizer, args.k, args.E, args.B, args.C, args.eta)

    # Non-default gradient aggregation rules get their own directories
    if args.gar != 'avg':
        dir_name = dir_name + '_' + args.gar
        output_file_name = output_file_name + '_' + args.gar
        output_dir_name = output_dir_name + '_' + args.gar
        figures_dir_name = figures_dir_name + '_' + args.gar
        interpret_figs_dir_name = interpret_figs_dir_name + '_' + args.gar

    if args.lr_reduce:
        dir_name += '_lrr'
        output_dir_name += '_lrr'
        figures_dir_name += '_lrr'

    if args.steps is not None:
        dir_name += '_steps' + str(args.steps)
        output_dir_name += '_steps' + str(args.steps)
        figures_dir_name += '_steps' + str(args.steps)

    if args.mal:
        # Encode the malicious objective and strategy in the directory names
        if 'multiple' in args.mal_obj:
            args.mal_obj = args.mal_obj + str(args.mal_num)
        if 'dist' in args.mal_strat:
            args.mal_strat += '_rho' + '{:.2E}'.format(args.rho)
        if args.E != args.mal_E:
            args.mal_strat += '_ext' + str(args.mal_E)
        if args.mal_delay > 0:
            args.mal_strat += '_del' + str(args.mal_delay)
        if args.ls != 1:
            args.mal_strat += '_ls' + str(args.ls)
        if 'data_poison' in args.mal_strat:
            args.mal_strat += '_reps' + str(args.data_rep)
        if 'no_boost' not in args.mal_strat and 'data_poison' not in args.mal_strat:
            args.mal_strat += '_boost' + str(args.mal_boost)
        output_file_name += '_mal_' + args.mal_obj + '_' + args.mal_strat
        dir_name += '_mal_' + args.mal_obj + '_' + args.mal_strat

    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    if not os.path.exists(output_dir_name):
        os.makedirs(output_dir_name)
    if not os.path.exists(figures_dir_name):
        os.makedirs(figures_dir_name)
    if not os.path.exists(interpret_figs_dir_name):
        os.makedirs(interpret_figs_dir_name)

    dir_name += '/'
    output_dir_name += '/'
    figures_dir_name += '/'
    interpret_figs_dir_name += '/'

    return dir_name, output_dir_name, output_file_name, figures_dir_name, interpret_figs_dir_name
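
# Example (illustrative only): with the defaults defined in init() below
# (MNIST, model 0, adam, k=10, E=5, B=100, C=1.0, eta=1e-3), and ignoring the
# malicious-agent suffix, the weights directory resolves to roughly
#     weights/MNIST/model_0/adam/k10_E5_B100_C1e+00_lr1.0e-03/
# with matching sub-directories under output_files/, figures/ and interpret_figs/.
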
def init():
    """Parses the run's arguments and sets up the global configuration."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default='MNIST',
                        help="dataset to be used")
    parser.add_argument("--model_num", type=int, default=0,
                        help="model to be used")
    parser.add_argument("--optimizer", default='adam',
                        help="optimizer to be used")
    parser.add_argument("--eta", type=float, default=1e-3,
                        help="learning rate")
    parser.add_argument("--k", type=int, default=10, help="number of agents")
    parser.add_argument("--C", type=float, default=1.0,
                        help="fraction of agents per time step")
    parser.add_argument("--E", type=int, default=5,
                        help="epochs for each agent")
    parser.add_argument("--steps", type=int, default=None,
                        help="GD steps per agent")
    parser.add_argument("--T", type=int, default=80, help="max time steps")
    parser.add_argument("--B", type=int, default=100, help="agent batch size")
    parser.add_argument("--train", default=True, action='store_true')
    parser.add_argument("--lr_reduce", action='store_true')
    parser.add_argument("--mal", default=True, action='store_true')
    parser.add_argument("--mal_obj", default='single',
                        help='Objective for malicious agent')
    parser.add_argument("--mal_strat", default='asyncFL',
                        help='Strategy for malicious agent')
    parser.add_argument("--mal_num", type=int, default=1,
                        help='Objective for simultaneous targeting')
    parser.add_argument("--mal_delay", type=int, default=0,
                        help='Delay for wait till converge')
    parser.add_argument("--mal_boost", type=float, default=10,
                        help='Boosting factor for alternating minimization attack')
    parser.add_argument("--mal_E", type=float, default=5,
                        help='Benign training epochs for malicious agent')
    parser.add_argument("--ls", type=int, default=1,
                        help='Training steps for each malicious step')
    parser.add_argument("--gar", type=str, default='avg',
                        help='Gradient Aggregation Rule')
    parser.add_argument("--rho", type=float, default=1e-4,
                        help='Weighting factor for distance constraints')
    parser.add_argument("--data_rep", type=float, default=10,
                        help='Data repetitions for data poisoning')
    parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
                        help='GPUs to run on')

    global args
    args = parser.parse_args()

    # A single-agent run only makes sense for the benign case
    if args.k == 1:
        assert not args.mal

    if args.mal:
        global mal_agent_index
        mal_agent_index = args.k - 1

    # Moving rate of 1.0 leads to a full overwrite of the global weights
    global moving_rate

    global gpu_ids
    if args.gpu_ids is not None:
        gpu_ids = args.gpu_ids
    else:
        gpu_ids = [0]
    global num_gpus
    num_gpus = len(gpu_ids)

    global max_agents_per_gpu
    global IMAGE_ROWS, IMAGE_COLS, NUM_CHANNELS, NUM_CLASSES, BATCH_SIZE
    global max_acc

    # Dataset-specific constants
    if 'MNIST' in args.dataset:
        IMAGE_ROWS = 28
        IMAGE_COLS = 28
        NUM_CHANNELS = 1
        NUM_CLASSES = 10
        BATCH_SIZE = 100
        if args.dataset == 'MNIST':
            max_acc = 100.0
        elif args.dataset == 'fMNIST':
            max_acc = 90.0
        max_agents_per_gpu = 6
        mem_frac = 0.05
        moving_rate = 1.0
    elif args.dataset == 'census':
        global DATA_DIM
        DATA_DIM = 104
        BATCH_SIZE = 50
        NUM_CLASSES = 2
        max_acc = 85.0
        max_agents_per_gpu = 6
        mem_frac = 0.05
        moving_rate = 1.0
    elif args.dataset == 'CIFAR-10':
        IMAGE_COLS = 32
        IMAGE_ROWS = 32
        NUM_CHANNELS = 3
        NUM_CLASSES = 10
        BATCH_SIZE = 100
        max_acc = 90.0
        max_agents_per_gpu = 6
        mem_frac = 0.05
        moving_rate = 1.0

    if max_agents_per_gpu < 1:
        max_agents_per_gpu = 1

    global gpu_options
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=mem_frac)

    global dir_name, output_dir_name, output_file_name, figures_dir_name, interpret_figs_dir_name
    dir_name, output_dir_name, output_file_name, figures_dir_name, interpret_figs_dir_name = dir_name_fn(
        args)

    return args
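

# Minimal self-check sketch (an assumption, not part of the original training
# pipeline): running this module directly parses the default arguments and
# prints the resolved configuration, which is a quick way to inspect the
# directory layout. Note that init() also creates the output directories as a
# side effect of calling dir_name_fn().
if __name__ == '__main__':
    run_args = init()
    print('dataset: %s, agents: %d, GPUs: %s' % (run_args.dataset, run_args.k, gpu_ids))
    print('weights dir: %s' % dir_name)
    print('output dir: %s' % output_dir_name)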