One confusion about DQN #404

Open
guest-oo opened this issue Sep 17, 2024 · 1 comment

Comments

@guest-oo

guest-oo commented Sep 17, 2024

def valid_agent(env_class, env_args: dict, net_dims: List[int], agent_class, actor_path: str, render_times: int = 8):
    env = build_env(env_class, env_args)

    state_dim = env_args['state_dim']
    action_dim = env_args['action_dim']
    agent = agent_class(net_dims, state_dim, action_dim, gpu_id=-1)
    actor = agent.act

    print(f"| render and load actor from: {actor_path}")
    actor.load_state_dict(th.load(actor_path, map_location=lambda storage, loc: storage))
    for i in range(render_times):
        cumulative_reward, episode_step = get_rewards_and_steps(env, actor, if_render=True)
        print(f"|{i:4}  cumulative_reward {cumulative_reward:9.3f}  episode_step {episode_step:5.0f}")

Does the above code load the trained .pth file (a state dict), i.e. the neural network parameters, and then test in the environment directly with that network (no longer using the explore rate)? Why do you choose th.save(actor.state_dict(), save_path) instead of saving the actor object itself to the .pth file?

@Yonv1943
Collaborator

Question: Does the above code load the trained .pth file (a state dict), i.e. the neural network parameters, and then test in the environment directly with that network (no longer using the explore rate)?

Answer: Yes.
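For context, during evaluation the loaded Q-network is typically used greedily, taking the action with the highest Q-value, so no epsilon-greedy exploration is involved. A minimal sketch of such a greedy step (the tensor shapes and the Q-value output convention are assumptions for illustration, not the exact ElegantRL code):

import torch as th

def greedy_action(actor, state):
    # Convert a single state to a batch of size 1.
    state_tensor = th.as_tensor(state, dtype=th.float32).unsqueeze(0)
    with th.no_grad():
        q_values = actor(state_tensor)  # assumed shape: (1, action_dim), one Q-value per action
    # Greedy choice: argmax over Q-values, no epsilon exploration.
    return int(q_values.argmax(dim=1).item())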


Question: Why do you choose th.save(actor.state_dict(), save_path) instead of saving the actor directly to the .pth file?

Answer: Saving only the model parameter weights (the state dict), rather than the entire model object (including the network structure), gives better compatibility: the saved file can be opened across different versions of PyTorch.
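A minimal sketch of the two approaches (the QNet constructor below is a hypothetical stand-in for the actual network class, used only for illustration):

# Save only the parameter weights (the state dict): small, version-robust.
th.save(actor.state_dict(), save_path)
# To load, build the network first, then copy the weights into it.
actor = QNet(net_dims, state_dim, action_dim)  # hypothetical constructor
actor.load_state_dict(th.load(save_path, map_location='cpu'))

# Save the whole module object: pickles the class itself, so loading it later
# depends on the original source layout and PyTorch version, and breaks more easily.
th.save(actor, save_path)
actor = th.load(save_path, map_location='cpu')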

Recently (2024-10-12), PyTorch added the weights_only argument to its loading function (torch.load), which has caused some compatibility issues. However, this does not affect the approach of saving only the model parameter weights.
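A short sketch of why the state-dict approach is unaffected (save_path and actor are the placeholders from the snippet above):

# A state dict contains only plain tensors, so it still loads under the stricter mode.
state_dict = th.load(save_path, map_location='cpu', weights_only=True)
actor.load_state_dict(state_dict)
# Loading a whole pickled module with weights_only=True would fail, because
# unpickling arbitrary Python classes is disallowed in that mode.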
