Added asynchronous plotting to TensorFlow #126
base: integration
Conversation
sandbox/rocky/tf/plotter/plotter.py
Outdated
def init_plot(env, policy, session):
    global process, queue, sess
    sess = session
Why not use tf.get_default_session() instead?
sandbox/rocky/tf/plotter/plotter.py
Outdated
def init_worker():
    global process, queue
    queue = Queue()
    process = Thread(target=_worker_start)
If it's an instance of Thread, then the object should be named thread.
If I interrupt the training with the plotter on, will the program stop as well, or will it wait for the plotter? If the program still waits for the plot thread, you may need to set daemon=True when creating the thread.
I just tested. The program still waits when you don't set daemon=True. Setting daemon=True solves this issue.
Thanks for letting me know!
sandbox/rocky/tf/plotter/plotter.py
Outdated
# Only fetch the last message of each type
while True:
    try:
        msg = queue.get_nowait()
This generates busy waiting if there's nothing in the queue.
This should be fine for now because something will be placed into the queue for each iteration, and upon completion the thread will close.
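The pattern under discussion, in isolation (a minimal sketch, not the PR's code): get_nowait() only spins until the queue is drained, at which point Empty is raised and the inner loop exits.

```python
import queue

q = queue.Queue()
q.put(("update", "env-1", "policy-1"))
q.put(("update", "env-2", "policy-2"))

# Drain the queue, keeping only the last message of each type.
msgs = {}
while True:
    try:
        msg = q.get_nowait()
        msgs[msg[0]] = msg[1:]
    except queue.Empty:
        break  # queue drained; stop rather than spin forever
```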
sandbox/rocky/tf/plotter/plotter.py
Outdated
process = None
queue = None
sess = None
If you pass sess as a parameter to _worker_start using the Thread constructor, there's no need to have this as a global variable.
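A minimal sketch of that suggestion, using a stand-in object for the TensorFlow session:

```python
import threading

received = []

def _worker_start(sess):
    # The worker receives the session as an argument; no global needed.
    received.append(sess)

session = object()  # stand-in for a tf.Session
thread = threading.Thread(target=_worker_start, args=(session,))
thread.start()
thread.join()
```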
sandbox/rocky/tf/plotter/plotter.py
Outdated
__all__ = ['init_worker', 'init_plot', 'update_plot']

process = None
Encapsulate the variables and the methods in this file within a class to avoid having global variables.
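A rough sketch of the suggested refactor (class and attribute names are assumptions; the actual PR may differ): the module-level process, queue, and sess become instance attributes.

```python
import threading
from queue import Queue

class Plotter:
    def __init__(self, sess=None):
        self.sess = sess      # was: global sess
        self.queue = Queue()  # was: global queue
        self.worker_thread = threading.Thread(target=self._start_worker)
        self.worker_thread.daemon = True

    def _start_worker(self):
        pass  # plotting loop elided in this sketch

    def start(self):
        self.worker_thread.start()

plotter = Plotter()
plotter.start()
```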
sandbox/rocky/tf/plotter/plotter.py
Outdated
import atexit
from multiprocessing import Process
import numpy as np
import pickle
Pickle is not needed here. Double-check these imports.
@@ -119,7 +122,8 @@ def train(self, sess=None):
                logger.log("Optimizing policy...")
                self.optimize_policy(itr, samples_data)
                logger.log("Saving snapshot...")
-               params = self.get_itr_snapshot(itr, samples_data)  # , **kwargs)
+               params = self.get_itr_snapshot(itr,
+                                              samples_data)  # , **kwargs)
Please remove vestigial comment.
sandbox/rocky/tf/plotter/__init__.py
Outdated
@@ -0,0 +1 @@
+from sandbox.rocky.tf.plotter import *
Please don't use any * imports. Every imported symbol should be specified individually. Not all symbols are part of the public API (so those should not be imported).
Note that PEP8 says that the public API for a module is the set of symbols listed in the __all__ list defined in the module. We don't enforce this yet, but you can see that the plotter authors actually took advantage of it, so you should only import APIs from the __all__ list.
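The point can be demonstrated with a throwaway module built at runtime (fake_plotter and its symbols are hypothetical stand-ins for the real plotter module):

```python
import sys
import types

# Build a stand-in module that declares its public API via __all__,
# the way the plotter module does.
mod = types.ModuleType("fake_plotter")
mod.__all__ = ["init_worker", "init_plot", "update_plot"]
for name in mod.__all__:
    setattr(mod, name, lambda *args, **kwargs: None)
mod._private_helper = lambda: None  # not in __all__, so not public API
sys.modules["fake_plotter"] = mod

# Import each public symbol explicitly instead of using a * import:
from fake_plotter import init_plot, init_worker, update_plot
```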
sandbox/rocky/tf/plotter/plotter.py
Outdated
        self.queue.join()
        self.thread.join()

    def shutdown(self):
Check that the worker thread dies correctly at the end of the simulation and when user interrupts the simulation with keyboard interruption.
sandbox/rocky/tf/plotter/plotter.py
Outdated
while True:
    msgs = {}
    # Only fetch the last message of each type
    while True:
It seems the GIL is now putting both this worker thread and the main thread on the same processor, so this busy waiting will impact performance. A suggestion on how to avoid busy waiting can be found in my attempt to solve this problem here. Feel free to use the code there.
I tested it and performance is much better now. Also, the process exits correctly after normal and interrupted execution.
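The adopted approach can be sketched as follows (function and variable names are assumptions): block on the first get() so the idle worker sleeps, then drain the rest without blocking, keeping only the last message of each type.

```python
import queue

def fetch_latest(q):
    msgs = {}
    msg = q.get()            # blocks while the queue is empty: no busy waiting
    msgs[msg[0]] = msg[1:]
    while not q.empty():     # drain anything that accumulated meanwhile
        msg = q.get()
        msgs[msg[0]] = msg[1:]
    return msgs

q = queue.Queue()
q.put(("update", 1))
q.put(("update", 2))
latest = fetch_latest(q)
```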
@@ -158,3 +165,6 @@ def get_itr_snapshot(self, itr, samples_data):
    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError

+   def update_plot(self):
+       if self.plot:
This guard is redundant since the same check is already made on line 136.
I think we should take the opportunity to clean this up a little bit. We can punt on improving the parallelism. Please leave a TODO/GitHub issue to figure out how to do this cross-platform and use multiprocessing.
sandbox/rocky/tf/plotter/plotter.py
Outdated
from multiprocessing import Process
import numpy as np
import platform
from queue import Empty, Queue
PEP8: import grouping
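For reference, a PEP8-style grouping of the imports that appear in this file's snippets: standard-library imports first, then third-party, with the groups separated by a blank line.

```python
# Standard library
import atexit
import platform
from multiprocessing import Process
from queue import Empty, Queue
from threading import Thread

# Third party
import numpy as np
```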
sandbox/rocky/tf/plotter/plotter.py
Outdated
    msg = self.queue.get_nowait()
    msgs[msg[0]] = msg[1:]

if 'stop' in msgs:
Let's replace the strings with enums.
sandbox/rocky/tf/plotter/plotter.py
Outdated
# Only fetch the last message of each type
while not self.queue.empty():
    msg = self.queue.get()
    msgs[msg[0]] = msg[1:]
msg[0] and msg[1:] is very difficult to read. What if we used a namedtuple instead?
from collections import namedtuple
import enum
from enum import Enum

class Op(Enum):
    STOP = enum.auto()
    UPDATE = enum.auto()
    DEMO = enum.auto()

Message = namedtuple("Message", ["op", "args", "kwargs"])

class Plotter:
    def _start_worker(self):
        while True:
            if initial_rollout:
                msg = self.queue.get()
            msgs[msg.op] = msg
            if Op.STOP in msgs:
                break
            elif Op.DEMO in msgs:
                env, policy = msgs[Op.DEMO].args

    def update_plot(self, policy, max_length=np.inf):
        if self.worker_thread.is_alive():
            self.queue.put(
                Message(op=Op.DEMO,
                        args=(policy.get_param_values(), max_length),
                        kwargs={}))
            self.queue.task_done()
sandbox/rocky/tf/plotter/plotter.py
Outdated
if 'stop' in msgs:
    break
elif 'update' in msgs:
Shouldn't this be if and not elif?
sandbox/rocky/tf/plotter/plotter.py
Outdated
    break
elif 'update' in msgs:
    env, policy = msgs['update']
elif 'demo' in msgs:
Shouldn't this be if and not elif?
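The reviewers' point, sketched in isolation (handler names are illustrative): with elif, at most one message type is handled per pass, so a batch containing both an update and a demo message would silently drop one of them. Independent if statements handle every type present, with 'stop' still short-circuiting.

```python
def handle(msgs, handled):
    if 'stop' in msgs:
        return False          # stop wins; skip everything else
    if 'update' in msgs:      # independent if: both branches can run
        handled.append('update')
    if 'demo' in msgs:
        handled.append('demo')
    return True

handled = []
keep_going = handle({'update': ('env', 'policy'), 'demo': ('policy', 100)},
                    handled)
```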
max_length = None
initial_rollout = True
try:
    with self.sess.as_default(), self.sess.graph.as_default():
Can you add a comment here explaining that the worker processes all messages in the queue per loop, not one message per loop?
Please reopen this PR against https://github.com/rlworkgroup/garage
Unfortunately, async plotting for TensorFlow requires the use of threading.Thread instead of multiprocessing.Process. Thus, on Linux machines, there will be a slight delay when you run the algorithm during which it seems like nothing is happening, but once the window loads, everything works perfectly and asynchronously.