Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create application to predict, and fix translator #30

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,23 @@ CNN + MLP occured overfitting to the training data.

Relational networks shows far better results in relational questions and non-relation questions.

## Application Demo
retrain experiments:

![](./readme_img/binary_relational_acc.png)
![](./readme_img/non_relational_acc.png)
![](./readme_img/ternary_relational_acc.png)


You can randomly generate and move 2D shaped objects and edit text to ask questions.

$ python application.py

<img src="./readme_img/relational-network-application.gif" width="1600">


## Contributions

[@gngdb](https://github.com/gngdb) speeds up the model by 10 times.

[@neural022](https://github.com/neural022) and [@hhhlll21qq](https://github.com/hhhlll21qq) build application.
21 changes: 21 additions & 0 deletions RN_1_log.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
epoch,train_acc_ternary,train_acc_rel,train_acc_norel,train_acc_ternary,test_acc_rel,test_acc_norel
1,51.555355976485956,50.55825440888308,52.975996080992815,54.13306451612903,61.29032258064516,55.69556451612903
2,53.33830013063357,66.2342831482691,55.337606139777925,52.368951612903224,72.88306451612904,56.703629032258064
3,53.53833278902678,69.31539843239713,56.32654310907903,51.965725806451616,72.22782258064517,58.71975806451613
4,53.70672762900065,69.78996570868713,58.261552906597,52.167338709677416,73.38709677419355,61.44153225806452
5,53.97615937295885,70.14104343566297,60.37924559111692,52.318548387096776,72.53024193548387,63.457661290322584
6,54.60687459177009,70.68807152188113,61.68456074461137,53.024193548387096,71.52217741935483,64.21370967741936
7,55.12634715871979,70.62377531025473,62.19382756368387,52.671370967741936,72.58064516129032,63.608870967741936
8,55.39782005225343,71.11058948399739,62.6500244937949,53.931451612903224,74.24395161290323,63.76008064516129
9,55.49477465708687,71.87704114957545,63.30013063357283,54.48588709677419,74.04233870967742,66.22983870967742
10,55.601935009797515,72.61491672109732,66.2832707380797,53.42741935483871,74.49596774193549,71.27016129032258
11,55.741753755715216,73.1078543435663,77.11258981058133,53.78024193548387,74.44556451612904,86.18951612903226
12,55.741753755715216,73.47015839320706,93.81633736120183,54.435483870967744,76.00806451612904,98.03427419354838
13,55.940765839320704,76.68394839973874,98.29564010450686,54.939516129032256,80.54435483870968,98.5383064516129
14,56.174477465708684,82.49101894186806,99.06107119529719,53.881048387096776,84.77822580645162,97.88306451612904
15,56.41125081645983,84.75771554539517,99.32744121489223,55.39314516129032,85.38306451612904,98.7399193548387
16,56.56841933376878,86.46513716525146,99.5009389288047,55.292338709677416,85.6350806451613,98.94153225806451
17,56.68068256041803,87.47448563030699,99.63463422599608,56.30040322580645,87.3991935483871,98.99193548387096
18,57.09911822338341,88.2388961463096,99.67851894186806,55.645161290322584,87.3991935483871,99.14314516129032
19,57.25730731548008,88.99208033964729,99.7305682560418,56.149193548387096,88.20564516129032,99.19354838709677
20,57.507348138471585,89.61258981058133,99.7601649248857,56.09879032258065,88.10483870967742,98.89112903225806
Expand Down
54 changes: 54 additions & 0 deletions application.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 18 10:29:41 2022

@author: helen, george chen(neural022)
"""

import sys
import cv2
from PyQt5 import QtWidgets
from utils import Ui_MainWindow, RNPredictor
from PIL import ImageQt,Image
from numpy import asarray
import time


class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
def __init__(self, parent=None):
super(MainWindow, self).__init__(parent=parent)
self.setupUi(self)
self.rn_predictor = RNPredictor()

self.pushButton.clicked.connect(self.pushButton_clicked)
self.pushButton_2.clicked.connect(self.label.pushButton2_clicked)

def pushButton_clicked(self):#OK
label_image = ImageQt.fromqpixmap(self.label.grab())
image_RGB = Image.new("RGB", label_image.size, (255, 255, 255))
image_RGB.paste(label_image, mask=label_image.split()[3])
image_RGB = image_RGB.resize((75, 75), Image.ANTIALIAS)
image_RGB.save('image_RGB.jpg', 'JPEG', quality=100)
img_rgb_array = asarray(image_RGB)
img_bgr_array = cv2.cvtColor(img_rgb_array, cv2.COLOR_RGB2BGR)
# print(img_bgr_array.shape)#(75,75,3)
time.sleep(0.5)
# Question ComboBox
# Answer label4
if self.comboBox.currentText() != '':
self.question = self.comboBox.currentText()
print('Question:', self.question)
question = self.rn_predictor.tokenize(self.question)
self.answer = self.rn_predictor.predict((img_bgr_array/255, question))
print('Answer:', self.answer)
self.label4.setText(self.answer)


if __name__ == "__main__":
app = QtWidgets.QApplication([])
app.setStyle('Fusion')
w = MainWindow()
w.show()

sys.exit(app.exec())

Binary file added best_model/epoch_RN_20.pth
Binary file not shown.
113 changes: 94 additions & 19 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,28 +1,103 @@
name: RN3
channels:
- pytorch
- defaults
dependencies:
- ca-certificates
- certifi
- libcxx
- libcxxabi
- libedit
- libffi
- ncurses
- openssl
- pip=20.0.2=py38_1
- python=3.8.1
- readline
- setuptools
- sqlite=3.31.1
- tk
- wheel
- xz
- zlib
- blas=1.0=mkl
- ca-certificates=2021.10.26=haa95532_2
- certifi=2021.10.8=py38haa95532_0
- cudatoolkit=11.3.1=h59b6b97_2
- freetype=2.10.4=hd328e21_0
- intel-openmp=2021.4.0=haa95532_3556
- jpeg=9d=h2bbff1b_0
- libpng=1.6.37=h2a8f88b_0
- libtiff=4.2.0=hd0e1b90_0
- libuv=1.40.0=he774522_0
- libwebp=1.2.0=h2bbff1b_0
- lz4-c=1.9.3=h2bbff1b_1
- mkl=2021.4.0=haa95532_640
- mkl-service=2.4.0=py38h2bbff1b_0
- mkl_fft=1.3.1=py38h277e83a_0
- mkl_random=1.2.2=py38hf11a4ad_0
- numpy-base=1.21.2=py38h0829f74_0
- olefile=0.46=pyhd3eb1b0_0
- openssl=1.1.1l=h2bbff1b_0
- pillow=8.4.0=py38hd45dc43_0
- pip=21.2.2=py38haa95532_0
- python=3.8.1=h5fd99cc_8_cpython
- pytorch=1.10.1=py3.8_cuda11.3_cudnn8_0
- pytorch-mutex=1.0=cuda
- setuptools=58.0.4=py38haa95532_0
- six=1.16.0=pyhd3eb1b0_0
- sqlite=3.31.1=h2a8f88b_1
- tk=8.6.11=h2bbff1b_0
- torchaudio=0.10.1=py38_cu113
- torchvision=0.11.2=py38_cu113
- typing_extensions=3.10.0.2=pyh06a4308_0
- vc=14.2=h21ff451_1
- vs2015_runtime=14.27.29016=h5e58377_2
- wheel=0.37.0=pyhd3eb1b0_1
- wincertstore=0.2=py38haa95532_2
- xz=5.2.5=h62dcd97_0
- zlib=1.2.11=h8cc25b3_4
- zstd=1.4.9=h19a0ad4_0
- pip:
- absl-py==1.0.0
- altgraph==0.17.2
- cachetools==4.2.4
- charset-normalizer==2.0.9
- click==7.1.2
- colorama==0.4.4
- filelock==3.4.2
- future==0.18.2
- google-auth==1.35.0
- google-auth-oauthlib==0.4.6
- grpcio==1.43.0
- huggingface-hub==0.4.0
- idna==3.3
- importlib-metadata==4.10.0
- joblib==1.1.0
- markdown==3.3.6
- numpy==1.18.1
- oauthlib==3.1.1
- opencv-python==4.2.0.32
- torch==1.4.0
- packaging==21.3
- pandas==1.3.5
- pefile==2021.9.3
- protobuf==3.19.1
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pyinstaller==4.8
- pyinstaller-hooks-contrib==2021.5
- pyparsing==3.0.6
- pyqt5==5.15.6
- pyqt5-plugins==5.15.4.2.2
- pyqt5-qt5==5.15.2
- pyqt5-sip==12.9.0
- python-dateutil==2.8.2
- python-dotenv==0.19.2
- pytz==2021.3
- pywin32-ctypes==0.2.0
- pyyaml==6.0
- qt5-applications==5.15.2.2.2
- qt5-tools==5.15.2.1.2
- regex==2021.11.10
- requests==2.27.0
- requests-oauthlib==1.3.0
- rsa==4.8
- sacremoses==0.0.47
- scikit-learn==1.0.2
- scipy==1.7.3
- sklearn==0.0
- tensorboard==2.2.0
- tensorboard-plugin-wit==1.6.0.post2
prefix: ~/miniconda3/envs/RN3
- threadpoolctl==3.0.0
- tokenizers==0.10.3
- torch-tb-profiler==0.3.1
- tqdm==4.62.3
- transformers==4.15.0
- typing-extensions==4.0.1
- urllib3==1.26.7
- werkzeug==2.0.2
- zipp==3.7.0
prefix: D:\Coding\Anaconda3\envs\RN3
Binary file added image_RGB.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
51 changes: 22 additions & 29 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,24 @@

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Relational-Network sort-of-CLVR Example')
parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN',
help='resume from model stored')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=20, metavar='N',
help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
help='learning rate (default: 0.0001)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--resume', type=str,
help='resume from model stored')
parser.add_argument('--relation-type', type=str, default='binary',
help='what kind of relations to learn. options: binary, ternary (default: binary)')
parser.add_argument('--model', type=str, choices=['RN', 'CNN_MLP'], default='RN', help='resume from model stored')
parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', help='learning rate (default: 0.0001)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')
parser.add_argument('--resume', type=str, help='resume from model stored')
parser.add_argument('--relation-type', type=str, default='binary', help='what kind of relations to learn. options: binary, ternary (default: binary)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
args.cuda = True

torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)

''''''
summary_writer = SummaryWriter()

if args.model=='CNN_MLP':
Expand All @@ -59,6 +51,7 @@
input_qst = torch.FloatTensor(bs, 18)
label = torch.LongTensor(bs)

# print("gpu:", args.cuda)
if args.cuda:
model.cuda()
input_img = input_img.cuda()
Expand All @@ -70,6 +63,8 @@
label = Variable(label)

def tensor_data(data, i):

# [ batch_size*i : batch_size*i+1 ]
img = torch.from_numpy(np.asarray(data[0][bs*i:bs*(i+1)]))
qst = torch.from_numpy(np.asarray(data[1][bs*i:bs*(i+1)]))
ans = torch.from_numpy(np.asarray(data[2][bs*i:bs*(i+1)]))
Expand All @@ -78,7 +73,6 @@ def tensor_data(data, i):
input_qst.data.resize_(qst.size()).copy_(qst)
label.data.resize_(ans.size()).copy_(ans)


def cvt_data_axis(data):
img = [e[0] for e in data]
qst = [e[1] for e in data]
Expand All @@ -97,6 +91,7 @@ def train(epoch, ternary, rel, norel):
random.shuffle(rel)
random.shuffle(norel)


ternary = cvt_data_axis(ternary)
rel = cvt_data_axis(rel)
norel = cvt_data_axis(norel)
Expand Down Expand Up @@ -224,6 +219,7 @@ def load_data():
filename = os.path.join(dirs,'sort-of-clevr.pickle')
with open(filename, 'rb') as f:
train_datasets, test_datasets = pickle.load(f)

ternary_train = []
ternary_test = []
rel_train = []
Expand Down Expand Up @@ -255,6 +251,7 @@ def load_data():

ternary_train, ternary_test, rel_train, rel_test, norel_train, norel_test = load_data()


try:
os.makedirs(model_dirs)
except:
Expand All @@ -276,11 +273,7 @@ def load_data():
print(f"Training {args.model} {f'({args.relation_type})' if args.model == 'RN' else ''} model...")

for epoch in range(1, args.epochs + 1):
train_acc_ternary, train_acc_binary, train_acc_unary = train(
epoch, ternary_train, rel_train, norel_train)
test_acc_ternary, test_acc_binary, test_acc_unary = test(
epoch, ternary_test, rel_test, norel_test)

csv_writer.writerow([epoch, train_acc_ternary, train_acc_binary,
train_acc_unary, test_acc_ternary, test_acc_binary, test_acc_unary])
model.save_model(epoch)
train_acc_ternary, train_acc_binary, train_acc_unary = train(epoch, ternary_train, rel_train, norel_train)
test_acc_ternary, test_acc_binary, test_acc_unary = test(epoch, ternary_test, rel_test, norel_test)
csv_writer.writerow([epoch, train_acc_ternary, train_acc_binary, train_acc_unary, test_acc_ternary, test_acc_binary, test_acc_unary])
model.save_model(epoch)
Loading