render_model.py
import os
import yaml
import argparse

import gymnasium as gym
from tqdm.notebook import trange
from stable_baselines3 import TD3

from model import init_gym, load_model


def render_prediction(
    model,
    prediction_steps=1000,
):
"""
Generate prediction frames using a trained RL model.
This function uses the given RL model to predict actions for a given number of steps and
generates frames for each step by rendering the environment. It returns a list of frames.
Parameters:
model (stable_baselines3.TD3): The trained RL model.
prediction_steps (int, optional): The number of steps to generate predictions and frames.
Defaults to 1000.
Returns:
list: A list of rendered frames from the environment during prediction.
"""
    vec_env = model.get_env()
    obs = vec_env.reset()
    frames = []
    for i in trange(prediction_steps):
        # Predict the action from the current observation, then perform it in the environment
        action, _states = model.predict(obs)
        obs, rewards, dones, info = vec_env.step(action)
        # Render the environment and add the frame to the frames list
        frames.append(vec_env.render())
    return frames


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Render predictions from a trained RL model.")
    parser.add_argument(
        "--config_path",
        type=str,
        default="config.yaml",
        help="config file for your model",
    )
    args = parser.parse_args()

    with open(args.config_path, "r") as file:
        config = yaml.safe_load(file)
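
    # A minimal sketch of the config.yaml layout this script expects. The keys mirror the
    # lookups used below; the values shown are placeholders, not settings taken from this repo.
    #
    # setup:
    #   path: <output directory>
    #   alias: <run name>
    # env:
    #   gym_name: <gymnasium environment id>
    #   render_mode: rgb_array   # frames are collected into a list, so an array mode is assumed
    #   walls: <env-specific wall layout>
    #   goal_size: <env-specific goal size>
    # test:
    #   prediction_steps: 1000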
    output_path = os.path.join(config["setup"]["path"], config["setup"]["alias"])

    # Load the model and generate prediction frames
    env = init_gym(
        gym_name=config["env"]["gym_name"],
        render_mode=config["env"]["render_mode"],
        video_path=os.path.join(output_path, "video"),
        logs_path=None,
        walls=config["env"]["walls"],
        goal_size=config["env"]["goal_size"],
    )
    model = load_model(env, output_path, replay_buffer=None, logger=None)
    frames = render_prediction(model, config["test"]["prediction_steps"])
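
    # The frames collected above are not written to disk by this script. As an illustrative
    # sketch (imageio is an assumption here, not a dependency declared in this file), they
    # could be saved as a GIF like so:
    #
    #   import imageio
    #   imageio.mimsave(os.path.join(output_path, "prediction.gif"), frames, fps=30)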