-
Notifications
You must be signed in to change notification settings - Fork 12
/
plot_trained_results.py
74 lines (59 loc) · 2.12 KB
/
plot_trained_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed October 2020
@author: juanjosealcaraz
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import os
# test results: every run's histories are averaged over the index window [START, END)
START = 40000
END = 49500
SPAN = END - START  # length of the averaging window (currently unused)
# one entry per subplot; titles, scenarios and prbs_values are used positionally together
titles = ['Scenario 1', 'Scenario 2', 'Scenario 3']
scenarios = [0, 1, 2]
# per-scenario divisor applied to the mean of the 'resources' history
# (presumably the number of PRBs available in each scenario — confirm)
prbs_values = [200, 150, 100]
# result sub-directory names and their matching legend labels, in plotting order
algo_names = ['A2C', 'PPO1', 'PPO2', 'TRPO', 'SAC', 'TD3', 'NAF', 'KBRL_97', 'KBRL_99']
labels = ['A2C', 'PPO1', 'PPO2', 'TRPO', 'SAC', 'TD3', 'NAF', 'KBRL 0.97', 'KBRL 0.99']
def mean_confidence_radius(data, confidence=0.95):
    """Summarise a sample by its mean and confidence-interval half-width.

    The half-width is the standard error of the mean (``scipy.stats.sem``,
    which uses ``ddof=1``) scaled by the two-sided Student-t quantile with
    ``len(data) - 1`` degrees of freedom.

    Args:
        data: sequence of numeric observations.
        confidence: two-sided confidence level (0.95 by default).

    Returns:
        ``(mean, radius)`` so that the interval is ``mean ± radius``.
        The radius is NaN when ``data`` has fewer than two elements, and
        the mean is NaN as well when ``data`` is empty.
    """
    samples = np.array(data) * 1.0  # promote integer input to floating point
    dof = len(samples) - 1
    upper_tail = (1 + confidence) / 2.0  # two-sided: put alpha/2 in each tail
    radius = stats.sem(samples) * stats.t.ppf(upper_tail, dof)
    return np.mean(samples), radius
def _collect_run_means(path, prbs):
    """Return per-run mean SLA violations and normalised resource occupation.

    Scans *path* for ``.npz`` files, each expected to contain a
    ``'violation'`` and a ``'resources'`` array. Runs whose ``'violation'``
    history is shorter than ``END`` are skipped; for the rest, both histories
    are averaged over ``[START:END]`` and the resource mean is divided by
    *prbs*.

    Args:
        path: directory holding the ``.npz`` result files of one algorithm
            in one scenario.
        prbs: divisor for the resource mean (presumably the number of PRBs
            available in the scenario — confirm).

    Returns:
        A ``(violations, resources)`` pair of lists with one entry per
        accepted run.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    violations = []
    resources = []
    # sorted() makes the accumulation order, and hence the float result,
    # reproducible regardless of filesystem listing order
    for filename in sorted(os.listdir(path)):
        if not filename.endswith(".npz"):
            continue
        # the object returned by np.load for an .npz archive keeps the file
        # open until closed; the with-block releases it once both arrays
        # have been read into memory
        with np.load(os.path.join(path, filename)) as histories:
            run_violations = histories['violation']
            run_resources = histories['resources']
        if len(run_violations) < END:
            continue  # run did not reach the end of the averaging window
        violations.append(run_violations[START:END].mean())
        resources.append(run_resources[START:END].mean() / prbs)
    return violations, resources


# subplot: one panel per scenario
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 3.5), constrained_layout=True)
for ax, j, title, prbs in zip(axs, scenarios, titles, prbs_values):
    ax.set_title(title)
    # iterate over algorithms
    for algo, label in zip(algo_names, labels):
        path = os.path.join('.', 'results', 'scenario_{}'.format(j), algo)
        violations, resources = _collect_run_means(path, prbs)
        if not violations:
            # mean_confidence_radius returns NaNs for an empty list, so no
            # point is drawn; report it instead of failing silently
            print('warning: no run reaching step {} in {}'.format(END, path))
        v, v_h = mean_confidence_radius(violations)
        r, r_h = mean_confidence_radius(resources)
        # errorbar is still called without data so that every algorithm
        # consumes one colour-cycle entry and colours stay aligned across
        # subplots (only the first subplot carries a legend)
        ax.errorbar(r, v, xerr=r_h, yerr=v_h, fmt='o', label=label)
    ax.set_xlim((0.4, 1.))
    ax.set_ylim((0., 1.))
    ax.set_xlabel('Resource occupation')
    ax.set_ylabel('SLA violations per stage')
    ax.grid()
# single legend on the first subplot
axs[0].legend(loc='upper left')
# savefig does not create missing directories
os.makedirs('./figures', exist_ok=True)
fig.savefig('./figures/trained_figure.png', format='png')