
Added asynchronous plotting to TensorFlow #126

Open · wants to merge 7 commits into base: integration
Conversation

@jonashen (Collaborator) commented Jun 7, 2018

Unfortunately, async plotting for TensorFlow requires the use of threading.Thread instead of multiprocessing.Process. As a result, on Linux machines there is a slight delay when you run the algorithm during which it seems like nothing is happening, but once the window loads, everything works perfectly and asynchronously.
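Since the worker has to be a thread here, the queue-driven plotting loop can be sketched roughly like this (a minimal standalone sketch; names such as _plot_worker and plot_queue are illustrative, not from the PR):

```python
import queue
import threading

plot_queue = queue.Queue()
seen = []

def _plot_worker(q):
    # Consume plotting requests until a None sentinel arrives.
    while True:
        item = q.get()
        if item is None:
            break
        seen.append(item)  # a real worker would render the frame here

# daemon=True lets the interpreter exit even if the worker
# is still blocked on q.get() when training is interrupted.
worker = threading.Thread(target=_plot_worker, args=(plot_queue,), daemon=True)
worker.start()

plot_queue.put({"itr": 0})
plot_queue.put(None)  # sentinel: ask the worker to shut down
worker.join()
```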


def init_plot(env, policy, session):
    global process, queue, sess
    sess = session

Why not use tf.get_default_session() instead?

def init_worker():
    global process, queue
    queue = Queue()
    process = Thread(target=_worker_start)

If it's an instance of Thread, then the object should be named thread.

Collaborator

If I interrupt the training with the plotter on, will the program stop as well, or does it wait for the plotter? If the program still waits for the plot thread, you may need to set daemon=True when creating the thread.

Collaborator

I just tested. The program still waits when you don't set daemon=True; setting daemon=True solves this issue.
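The difference is easy to see with a standalone sketch (not from this PR): a child interpreter that starts a blocked thread with daemon=True exits immediately instead of waiting for the thread.

```python
import subprocess
import sys
import textwrap

# The child starts a thread blocked in sleep; because the thread is
# a daemon, the interpreter exits without waiting for it to finish.
child = textwrap.dedent("""
    import threading, time
    t = threading.Thread(target=time.sleep, args=(60,), daemon=True)
    t.start()
    print("main done")
""")

result = subprocess.run([sys.executable, "-c", child],
                        capture_output=True, text=True, timeout=15)
```

With daemon=False the same child would hang on interpreter shutdown until the 60-second sleep finished.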

Collaborator Author

Thanks for letting me know!

# Only fetch the last message of each type
while True:
    try:
        msg = queue.get_nowait()

This generates busy waiting if there's nothing in the queue.

Collaborator Author

This should be fine for now because something will be placed into the queue on each iteration, and upon completion the thread will close.
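In isolation, the get_nowait drain pattern being discussed looks like this (a self-contained sketch with dummy messages):

```python
import queue

q = queue.Queue()
for i in range(5):
    q.put(("update", i))

msgs = {}
# Drain everything currently queued, keeping only the last
# message of each type; stop as soon as the queue is empty.
while True:
    try:
        msg = q.get_nowait()
    except queue.Empty:
        break
    msgs[msg[0]] = msg[1:]
```

Note that the loop only spins while messages are available; the busy-waiting concern arises when this pattern is wrapped in an outer loop that never blocks.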


process = None
queue = None
sess = None

If you pass sess as a parameter to _worker_start using the Thread constructor, there's no need to have this as a global variable.
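Passing the session through the Thread constructor could look like this (fake_sess stands in for the real tf.Session; names are illustrative):

```python
import threading

received = []

def _worker_start(sess):
    # The worker receives the session as an argument instead of
    # reading a module-level global.
    received.append(sess)

fake_sess = object()  # stand-in for a tf.Session
thread = threading.Thread(target=_worker_start, args=(fake_sess,), daemon=True)
thread.start()
thread.join()
```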


__all__ = ['init_worker', 'init_plot', 'update_plot']

process = None

Encapsulate the variables and the methods in this file within a class to avoid having global variables.
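A minimal sketch of that encapsulation (method and attribute names are hypothetical, mirroring the module-level functions in the PR):

```python
import queue
import threading

class Plotter:
    """Bundles the worker thread, its queue, and the session,
    so the module needs no global state."""

    def __init__(self, sess=None):
        self.queue = queue.Queue()
        self.thread = None
        self.sess = sess

    def init_worker(self):
        self.thread = threading.Thread(target=self._worker_start, daemon=True)
        self.thread.start()

    def _worker_start(self):
        # Block until messages arrive; a None sentinel stops the worker.
        while True:
            msg = self.queue.get()
            if msg is None:
                break

    def shutdown(self):
        self.queue.put(None)
        self.thread.join()
```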

import atexit
from multiprocessing import Process
import numpy as np
import pickle

pickle is not needed here. Double-check these imports.

@@ -119,7 +122,8 @@ def train(self, sess=None):
logger.log("Optimizing policy...")
self.optimize_policy(itr, samples_data)
logger.log("Saving snapshot...")
params = self.get_itr_snapshot(itr, samples_data) # , **kwargs)
params = self.get_itr_snapshot(itr,
samples_data) # , **kwargs)
Owner

Please remove vestigial comment.

@@ -0,0 +1 @@
from sandbox.rocky.tf.plotter import *
Owner

Please don't use any * imports. Every imported symbol should be specified individually. Not all symbols are part of the public API, so those should not be imported.

Owner

Note that PEP 8 says the public API for a module is the set of symbols listed in the __all__ list defined in the module. We don't enforce this yet, but you can see that the plotter authors actually took advantage of it. So you should only import APIs from the __all__ list.
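The effect of __all__ on star imports can be demonstrated with a throwaway in-memory module (plotter_demo is invented for this demo, not a real module in the repo):

```python
import sys
import types

# Build a tiny module whose __all__ exposes three functions and
# hides a fourth, then star-import from it.
mod = types.ModuleType("plotter_demo")
exec(
    "__all__ = ['init_worker', 'init_plot', 'update_plot']\n"
    "def init_worker(): pass\n"
    "def init_plot(): pass\n"
    "def update_plot(): pass\n"
    "def _private_helper(): pass\n",
    mod.__dict__,
)
sys.modules["plotter_demo"] = mod

ns = {}
exec("from plotter_demo import *", ns)
public = {name for name in ns if not name.startswith("__")}
```

Only the three names in __all__ land in the importing namespace; _private_helper stays hidden.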

@ryanjulian ryanjulian added this to the Week of June 4th milestone Jun 7, 2018
    self.queue.join()
    self.thread.join()

def shutdown(self):

Check that the worker thread dies correctly at the end of the simulation, and also when the user stops the simulation with a keyboard interrupt.

while True:
    msgs = {}
    # Only fetch the last message of each type
    while True:

It seems the GIL is now putting both this worker thread and the main thread on the same processor, so this busy waiting will impact performance.
A suggestion on how to avoid busy waiting can be found in my attempt to solve this problem here.
Feel free to use the code there.
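One way to avoid the spinning, roughly in the spirit of this suggestion: block on the first message so the worker sleeps while idle, then drain whatever else has accumulated without blocking (a standalone sketch with dummy messages):

```python
import queue

q = queue.Queue()
q.put(("update", "params-1"))
q.put(("update", "params-2"))
q.put(("stop",))

msgs = {}
# Block on the first message so the worker sleeps when idle,
# then drain whatever else has accumulated without spinning.
msg = q.get()
msgs[msg[0]] = msg[1:]
while not q.empty():
    msg = q.get()
    msgs[msg[0]] = msg[1:]
```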

@ghost left a comment

I tested it and performance is much better now. Also, the process exits correctly after normal and interrupted execution.

@@ -158,3 +165,6 @@ def get_itr_snapshot(self, itr, samples_data):
    def optimize_policy(self, itr, samples_data):
        raise NotImplementedError

    def update_plot(self):
        if self.plot:

This guard is redundant since it's already used in line 136.

@jonashen jonashen requested review from ryanjulian, CatherineSue and zhanpenghe and removed request for eric-heiden and CatherineSue June 7, 2018 22:04
@ryanjulian (Owner) left a comment

I think we should take the opportunity to clean this up a little bit. We can punt on improving the parallelism. Please leave a TODO/GitHub issue to figure out how to do this cross-platform and use multiprocessing.

from multiprocessing import Process
import numpy as np
import platform
from queue import Empty, Queue
Owner

PEP8: import grouping

    msg = self.queue.get_nowait()
    msgs[msg[0]] = msg[1:]

if 'stop' in msgs:
Owner

Let's replace the strings with enums.

# Only fetch the last message of each type
while not self.queue.empty():
    msg = self.queue.get()
    msgs[msg[0]] = msg[1:]
Owner

msg[0] and msg[1:] are very difficult to read. What if we used a namedtuple instead?

Owner

from collections import namedtuple
import enum
from enum import Enum

class Op(Enum):
    STOP = enum.auto()
    UPDATE = enum.auto()
    DEMO = enum.auto()

Message = namedtuple("Message", ["op", "args", "kwargs"])

class Plotter:

    def _start_worker(self):
        while True:

            if initial_rollout:
                msg = self.queue.get()
                msgs[msg.op] = msg

            if Op.STOP in msgs:
                break
            elif Op.DEMO in msgs:
                env, policy = msgs[Op.DEMO].args

    def update_plot(self, policy, max_length=np.inf):
        if self.worker_thread.is_alive():
            self.queue.put(
                Message(op=Op.DEMO,
                        args=(policy.get_param_values(), max_length),
                        kwargs=None))
            self.queue.task_done()


if 'stop' in msgs:
    break
elif 'update' in msgs:
Owner

shouldn't this be if and not elif?

    break
elif 'update' in msgs:
    env, policy = msgs['update']
elif 'demo' in msgs:
Owner

shouldn't this be if and not elif?

max_length = None
initial_rollout = True
try:
    with self.sess.as_default(), self.sess.graph.as_default():
Owner

Can you add a comment here explaining that the worker processes all messages in the queue per loop, not one message per loop?

@ryanjulian (Owner)

Please reopen this PR against https://github.com/rlworkgroup/garage
