-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
69 lines (54 loc) · 2.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from DataFidelities.MRIClass import MRIClass
from Regularizers.robjects import *
from iterAlgs import *
from util import *
import settings
import matplotlib.pyplot as plt
import scipy.io as spio
import numpy as np
import os
####################################################
####              HYPER-PARAMETERS                ###
####################################################
# Run-time options (gpu_id, data_name, sigma, num_iter, ...) come from the
# project's `settings` module.
opt = settings.opt
# os.environ only accepts str values; str() keeps this working when gpu_id
# is parsed as an int (it is a no-op when gpu_id is already a str).
# NOTE(review): this only takes effect if none of the modules imported above
# has already created a CUDA context — confirm they do not touch the GPU at
# import time.
os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_id)
# Echo every option so the run's configuration appears in the log.
for arg in vars(opt):
    print(arg, ':', getattr(opt, arg))
####################################################
####              DATA PREPARATION                ###
####################################################
# prepare workspace: seed NumPy's global RNG so any later draws from
# np.random in this run are reproducible
np.random.seed(0)


def evaluateSNR(x, xhat):
    """Return the SNR of ``xhat`` relative to the reference ``x``, in dB.

    Computed as ``20 * log10(||x|| / ||x - xhat||)`` where both norms are
    Euclidean norms over all elements (arrays are flattened in Fortran
    order, as in the original implementation). Array-likes such as nested
    lists are accepted.

    A perfect reconstruction (``xhat == x``) makes the denominator zero,
    so NumPy yields ``inf`` (emitting a divide-by-zero RuntimeWarning).
    """
    x = np.asarray(x)
    xhat = np.asarray(xhat)
    signal_norm = np.linalg.norm(x.flatten('F'))
    error_norm = np.linalg.norm(x.flatten('F') - xhat.flatten('F'))
    return 20 * np.log10(signal_norm / error_norm)
# load image
data_mat = spio.loadmat('data/{}'.format(opt.data_name), squeeze_me=True)
x = data_mat['img']
# normalize x to the range [0, 1]; a constant image has zero range, so it is
# mapped to all zeros instead of dividing by zero (which would produce NaNs)
x_min, x_max = x.min(), x.max()
x_range = x_max - x_min
if x_range > 0:
    x = (x - x_min) / x_range
else:
    x = np.zeros_like(x, dtype=float)
imgSize = np.array(x.shape)
# measure: build the sampling mask from numLines and take the forward
# measurement of x through MRIClass
numLines = 36  # other values tried: 60, 48
mask = MRIClass.genMask(imgSize, numLines)
# NOTE(review): masksum is not used anywhere else in this script
masksum = mask.sum()
y = MRIClass.fmult(x, mask)
# save a baseline image: inverse FFT of the masked spectrum of x
x_fft = cal_fft(x)
subsample_x_fft = x_fft * mask
ifft_subsampled_x = np.array(cal_ifft(subsample_x_fft))
# plt.imsave writes the array straight to disk and needs no open figure.
# The line count in the filename follows numLines instead of a hard-coded 36.
# NOTE(review): imsave needs real-valued data — confirm cal_ifft returns a
# real (or magnitude) image rather than a complex one.
plt.imsave('IFFT_{}lines_MRI_Knee_{}.jpg'.format(numLines, opt.data_name),
           ifft_subsampled_x, cmap='gray')
# prepare ground truth xtrue
xtrue = x
####################################################
####            NETWORK INITIALIZATION            ###
####################################################
# Network object sized to the image; it is used as the regularizer object
# (rObj) passed to APGM below.
MultiHeadRNNClass = MultiHeadRNN(x.shape, sigma=opt.sigma)
####################################################
####                      PnP                     ###
####################################################
# Data-fidelity object built from the measurements y and the sampling mask.
mriObj = MRIClass(y, mask)
rObj = MultiHeadRNNClass
# optimize with APGM (no acceleration, full-batch updates; mini_batch is
# presumably ignored when stochastic=False — confirm in apgmEst).
# The save path tracks numLines instead of hard-coding "36".
recon, out = apgmEst(mriObj, rObj, numIter=opt.num_iter, step=1.,
                     accelerate=False, stochastic=False, mini_batch=20,
                     verbose=True, is_save=True,
                     save_path='line{}_apgm_result_{}_sigma{}'.format(
                         numLines, opt.data_name, opt.sigma),
                     xtrue=xtrue)