pymarl Source Code Walkthrough
Last updated: December 17, 2024 (evening)
Source repository
https://github.com/oxwhirl/pymarl
Implemented algorithms:
- QMIX: QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning
- COMA: Counterfactual Multi-Agent Policy Gradients
- VDN: Value-Decomposition Networks For Cooperative Multi-Agent Learning
- IQL: Independent Q-Learning
- QTRAN: QTRAN: Learning to Factorize with Transformation for Cooperative Multi-Agent Reinforcement Learning
File structure
- Only the contents of the src folder are considered
- Focus on the files involved in the QMIX algorithm (a rough layout of the relevant files is sketched below)
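The original directory tree is not reproduced here. As a rough orientation, the files touched in this walkthrough sit in src/ roughly as follows (non-exhaustive, based on the oxwhirl/pymarl repository):

src/
├── main.py                          # entry point, sets up the sacred experiment
├── run.py                           # run / run_sequential, the training loop
├── config/
│   ├── default.yaml                 # default hyperparameters
│   ├── algs/qmix.yaml               # algorithm configs (e.g. QMIX)
│   └── envs/sc2.yaml                # environment configs (e.g. SC2)
├── runners/episode_runner.py        # EpisodeRunner
├── components/episode_buffer.py     # EpisodeBatch, ReplayBuffer
├── controllers/basic_controller.py  # BasicMAC
├── learners/q_learner.py            # QLearner
└── modules/                         # agent networks and mixers (e.g. QMixer)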
Overview of the main modules
- QMIX is used as the running example
- The logging module is not covered
- Simple helper and utility functions (generic, reusable code with no real connection to the algorithm itself) are also skipped
YAML configuration files
default.yaml
use_tensorboard: True # log experiment data to TensorBoard for later analysis
save_model: True # save models so they can be evaluated later
Entry point: main.py
Create the sacred experiment
SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console
logger = get_logger() # this line makes the console print logs of the form "[DEBUG xx:xx:xx] git.cmd Popen(...)"
ex = Experiment("pymarl")
ex.logger = logger
ex.captured_out_filter = apply_backspaces_and_linefeeds # filter the captured output so that real-time output (progress bars, etc.) is suitable for writing to a file
results_path = os.path.join(dirname(dirname(abspath(__file__))), "results")
Load the experiment configuration and run the sacred experiment
if __name__ == '__main__':
    params = deepcopy(sys.argv)  # command-line arguments

    # Get the defaults from default.yaml
    with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
        try:
            config_dict = yaml.load(f, Loader=yaml.FullLoader)
        except yaml.YAMLError as exc:
            assert False, "default.yaml error: {}".format(exc)

    # Load algorithm and env base configs
    env_config = _get_config(params, "--env-config", "envs")  # load the environment (e.g. SC2) yaml config
    alg_config = _get_config(params, "--config", "algs")      # load the algorithm (e.g. QMIX) yaml config
    config_dict = {**config_dict, **env_config, **alg_config}  # merge the dictionaries
    config_dict = recursive_dict_update(config_dict, env_config)
    config_dict = recursive_dict_update(config_dict, alg_config)

    # now add all the config to sacred
    ex.add_config(config_dict)

    # Save to disk by default for sacred
    logger.info("Saving to FileStorageObserver in results/sacred.")
    file_obs_path = os.path.join(results_path, "sacred")
    ex.observers.append(FileStorageObserver.create(file_obs_path))  # attach a file observer that writes the experiment logs

    ex.run_commandline(params)
Initialise the random seed and start the experiment framework
# the main function of the sacred experiment
@ex.main
def my_main(_run, _config, _log):
    # Setting the random seed throughout the modules
    config = config_copy(_config)
    np.random.seed(config["seed"])
    th.manual_seed(config["seed"])
    config['env_args']['seed'] = config["seed"]

    # run the framework
    run(_run, config, _log)
Running the experiment: run.py
The first few lines are mostly logging-related and can be ignored for now; the core is the run_sequential function. The code after it performs some post-processing once the experiment finishes and is unrelated to the algorithm.
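The run function itself is not shown above. A minimal sketch of its structure, paraphrased from memory of pymarl's run.py (helper names such as args_sanity_check and Logger may differ slightly, and the logging setup is omitted):

def run(_run, _config, _log):
    _config = args_sanity_check(_config, _log)        # sanity-check the config
    args = SN(**_config)                              # wrap the config dict in a SimpleNamespace
    args.device = "cuda" if args.use_cuda else "cpu"

    logger = Logger(_log)                             # console / tensorboard / sacred logging
    # ... unique_token, tensorboard and sacred logging setup ...

    run_sequential(args=args, logger=logger)          # the actual experiment

    # afterwards: stop any lingering threads and exit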
The run_sequential function
The main function that actually runs the experiment. It constructs the following custom objects:
runner: the environment runner, responsible for executing the game environment.
buffer: the experience replay buffer, responsible for storing sampled data.
mac: the multi-agent controller, responsible for building the agents and selecting actions from their inputs.
learner: the learner, responsible for training the model parameters.
It then runs the experiment: training the agents, recording results, and periodically testing and saving the model.
Construct the custom objects the experiment needs
Define the environment runner (runner)
# Init runner so we can get env info
runner = r_REGISTRY[args.runner](args=args, logger=logger)
Define the sampling scheme, i.e. roughly what information each sample stored in the buffer contains
# Default/Base scheme
scheme = {
    "state": {"vshape": env_info["state_shape"]},
    "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
    "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
    "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
    "reward": {"vshape": (1,)},
    "terminated": {"vshape": (1,), "dtype": th.uint8},
}
groups = {
    "agents": args.n_agents
}
preprocess = {
    "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
}
Define the experience replay buffer (buffer)
buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                      preprocess=preprocess,
                      device="cpu" if args.buffer_cpu_only else args.device)
Define the multi-agent controller (mac)
# Setup multiagent controller here
mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)
Pass the scheme defined above, together with the mac object, to the runner
self.new_batch is an EpisodeBatch constructor with its arguments fixed via partial; each call creates a fresh object used to store the sampled data of one episode
# using EpisodeRunner.setup as the example
def setup(self, scheme, groups, preprocess, mac):
    self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                             preprocess=preprocess, device=self.args.device)
    self.mac = mac
Define the learner (learner)
# Learner
learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)
If a saved model exists, load it and continue training
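The loading code is omitted above. Roughly, run_sequential does the following when checkpoint_path is set (a paraphrased sketch; details may differ slightly from the repository):

if args.checkpoint_path != "":
    # collect the numbered sub-directories (one per saved timestep)
    timesteps = [int(name) for name in os.listdir(args.checkpoint_path)
                 if os.path.isdir(os.path.join(args.checkpoint_path, name)) and name.isdigit()]

    if args.load_step == 0:
        timestep_to_load = max(timesteps)                                        # latest checkpoint
    else:
        timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step))  # closest to load_step

    model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))
    learner.load_models(model_path)
    runner.t_env = timestep_to_load

    # if we only want to evaluate / save a replay, do that and return
    if args.evaluate or args.save_replay:
        evaluate_sequential(args, runner)
        return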
Start training
Bookkeeping variables
episode = 0  # number of episodes trained so far
last_test_T = -args.test_interval - 1  # env step of the last test, used to decide when to test again
last_log_T = 0  # env step of the last log output, used to decide when to log again
model_save_time = 0  # env step of the last model save, used to decide when to save again
start_time = time.time()  # experiment start time, used in log messages
last_time = start_time  # used to estimate the remaining time (console logging)
The while loop (the core)
The loop terminates once the number of training environment steps exceeds the configured threshold.
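For reference, the loop header in run_sequential looks roughly like this (t_max is the total number of environment steps to train for):

while runner.t_env <= args.t_max:
    ...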
Run the environment and store the data
# Run for a whole episode at a time
episode_batch = runner.run(test_mode=False)  # run one full episode
buffer.insert_episode_batch(episode_batch)
Train
if buffer.can_sample(args.batch_size):  # the buffer holds at least batch_size episodes
    episode_sample = buffer.sample(args.batch_size)  # only train once enough samples are stored

    # Truncate batch to only filled timesteps
    max_ep_t = episode_sample.max_t_filled()
    episode_sample = episode_sample[:, :max_ep_t]  # truncate the time dimension of all samples to the longest filled sequence in this batch

    if episode_sample.device != args.device:
        episode_sample.to(args.device)

    learner.train(episode_sample, runner.t_env, episode)  # train
Test
# Execute test runs once in a while
n_test_runs = max(1, args.test_nepisode // runner.batch_size)  # run n_test_runs episodes per test
if (runner.t_env - last_test_T) / args.test_interval >= 1.0:  # at least test_interval env steps have passed since the last test
    logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max))  # print the training progress
    logger.console_logger.info("Estimated time left: {}. Time passed: {}".format(
        time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time)))  # print the estimated remaining training time
    last_time = time.time()

    last_test_T = runner.t_env
    for _ in range(n_test_runs):
        runner.run(test_mode=True)  # test
Save the model
if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0):
    # save_model is enabled, and at least save_model_interval env steps have passed since the last save (or training has just started)
    model_save_time = runner.t_env
    save_path = os.path.join(args.local_results_path, "models", args.unique_token, str(runner.t_env))
    # "results/models/{}".format(unique_token)
    os.makedirs(save_path, exist_ok=True)
    logger.console_logger.info("Saving models to {}".format(save_path))

    # learner should handle saving/loading -- delegate actor save/load to mac,
    # use appropriate filenames to do critics, optimizer states
    learner.save_models(save_path)  # save the model
Print logs
if (runner.t_env - last_log_T) >= args.log_interval:  # at least log_interval env steps have passed since the last log output
    logger.log_stat("episode", episode, runner.t_env)
    logger.print_recent_stats()
    last_log_T = runner.t_env
Close the environment
runner.close_env()
logger.console_logger.info("Finished Training")
Running the game environment: episode_runner.py
pymarl ships two environment runners; see the src/runners/__init__.py file:
EpisodeRunner: runs one episode at a time (this is the class described below)
ParallelRunner: runs several episodes at a time (controlled by the batch_size_run parameter in default.yaml)
The __init__ function
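The constructor is omitted above. The gist of EpisodeRunner.__init__ is roughly the following (a paraphrased sketch, minor details omitted):

def __init__(self, args, logger):
    self.args = args
    self.logger = logger
    self.batch_size = self.args.batch_size_run  # 1 for the EpisodeRunner
    assert self.batch_size == 1

    self.env = env_REGISTRY[self.args.env](**self.args.env_args)  # build the game environment
    self.episode_limit = self.env.episode_limit
    self.t = 0      # timestep within the current episode
    self.t_env = 0  # total environment steps so far

    # containers used only for logging returns / stats
    self.train_returns, self.test_returns = [], []
    self.train_stats, self.test_stats = {}, {}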
The run function
Initialisation
self.reset()  # reset the environment
terminated = False  # whether the episode has finished
episode_return = 0  # cumulative return of the current episode
self.mac.init_hidden(batch_size=self.batch_size)  # reset the agents' hidden states, so no "memory" of the previous episode leaks into this one
Run one episode, until the game ends
Collect and store environment information -> select actions -> step the environment and store the results -> ...
while not terminated:

    pre_transition_data = {
        "state": [self.env.get_state()],
        "avail_actions": [self.env.get_avail_actions()],
        "obs": [self.env.get_obs()]
    }

    self.batch.update(pre_transition_data, ts=self.t)

    # Pass the entire batch of experiences up till now to the agents
    # Receive the actions for each agent at this timestep in a batch of size 1
    actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)

    reward, terminated, env_info = self.env.step(actions[0])  # step the environment with the agents' actions
    episode_return += reward  # accumulate the episode return

    post_transition_data = {
        "actions": actions,
        "reward": [(reward,)],
        "terminated": [(terminated != env_info.get("episode_limit", False),)],
    }

    self.batch.update(post_transition_data, ts=self.t)

    self.t += 1  # advance the timestep

last_data = {
    "state": [self.env.get_state()],
    "avail_actions": [self.env.get_avail_actions()],
    "obs": [self.env.get_obs()]
}
self.batch.update(last_data, ts=self.t)

# Select actions in the last stored state
actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)
self.batch.update({"actions": actions}, ts=self.t)
Logging-related operations (skipped)
Sampled data: episode_buffer.py
The sample data structure: the EpisodeBatch class
:sweat_smile:
The experience replay buffer: the ReplayBuffer class
Inherits from the EpisodeBatch class and adds replay-buffer functionality on top of it.
Initialisation
def __init__(self, scheme, groups, buffer_size, max_seq_length, preprocess=None, device="cpu"):
    super(ReplayBuffer, self).__init__(scheme, groups, buffer_size, max_seq_length, preprocess=preprocess, device=device)
    self.buffer_size = buffer_size  # same as self.batch_size but more explicit
    self.buffer_index = 0  # index at which the next episode will be inserted
    self.episodes_in_buffer = 0  # number of episodes currently stored in the buffer
Inserting data
# insert data cyclically, modulo self.buffer_size
def insert_episode_batch(self, ep_batch):
    if self.buffer_index + ep_batch.batch_size <= self.buffer_size:
        # if the new episodes fit before the end of the buffer, insert them directly
        self.update(ep_batch.data.transition_data,
                    slice(self.buffer_index, self.buffer_index + ep_batch.batch_size),
                    slice(0, ep_batch.max_seq_length),
                    mark_filled=False)
        self.update(ep_batch.data.episode_data,
                    slice(self.buffer_index, self.buffer_index + ep_batch.batch_size))
        self.buffer_index = (self.buffer_index + ep_batch.batch_size)  # advance the insertion index
        self.episodes_in_buffer = max(self.episodes_in_buffer, self.buffer_index)  # once the buffer is full this stays at buffer_size (e.g. 5000), while buffer_index keeps wrapping around
        self.buffer_index = self.buffer_index % self.buffer_size  # wrap the insertion index back to 0 once it reaches buffer_size
        assert self.buffer_index < self.buffer_size
    else:
        # if the new episodes would run past the end of the buffer, split them:
        # insert what fits, then insert the remainder starting from the head of the buffer
        buffer_left = self.buffer_size - self.buffer_index
        self.insert_episode_batch(ep_batch[0:buffer_left, :])
        self.insert_episode_batch(ep_batch[buffer_left:, :])
Sampling data
def sample(self, batch_size):
    assert self.can_sample(batch_size)  # make sure the replay buffer holds enough episodes
    if self.episodes_in_buffer == batch_size:
        return self[:batch_size]
    else:
        # Uniform sampling only atm
        ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False)
        return self[ep_ids]
The multi-agent controller: basic_controller.py
This multi-agent controller shares parameters between agents
The __init__ function
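The constructor is omitted above. Roughly, BasicMAC.__init__ does the following (a paraphrased sketch):

def __init__(self, scheme, groups, args):
    self.n_agents = args.n_agents
    self.args = args
    input_shape = self._get_input_shape(scheme)  # size of each agent's network input
    self._build_agents(input_shape)              # build the (parameter-shared) agent network, e.g. an RNN agent
    self.agent_output_type = args.agent_output_type  # "q" for QMIX

    self.action_selector = action_REGISTRY[args.action_selector](args)  # e.g. epsilon-greedy

    self.hidden_states = None  # RNN hidden states, reset by init_hidden()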
The select_actions function
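The body is omitted above. In essence it forwards the batch through the agents and lets the action selector (epsilon-greedy for QMIX) pick among the currently available actions. A paraphrased sketch:

def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False):
    avail_actions = ep_batch["avail_actions"][:, t_ep]                 # mask of legal actions at this step
    agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode)  # per-agent Q-values
    chosen_actions = self.action_selector.select_action(
        agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode)
    return chosen_actions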
The forward function
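The body is omitted above. For QMIX (agent_output_type == "q") it boils down to building the per-agent inputs and running them through the shared RNN agent while carrying the hidden state along. A paraphrased sketch that leaves out the policy-logits branch:

def forward(self, ep_batch, t, test_mode=False):
    agent_inputs = self._build_inputs(ep_batch, t)  # shape (batch_size * n_agents, input_shape)
    agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states)
    # for QMIX the outputs are Q-values, one per action, for every agent
    return agent_outs.view(ep_batch.batch_size, self.n_agents, -1)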
The _build_inputs function
Builds the agents' inputs (observations) at timestep t
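The body is omitted above. Conceptually each agent's input is its observation, optionally concatenated with its last action (one-hot) and its agent id (one-hot). A paraphrased sketch:

def _build_inputs(self, batch, t):
    bs = batch.batch_size
    inputs = [batch["obs"][:, t]]  # the raw observation
    if self.args.obs_last_action:  # append the previous action as a one-hot vector
        if t == 0:
            inputs.append(th.zeros_like(batch["actions_onehot"][:, t]))
        else:
            inputs.append(batch["actions_onehot"][:, t - 1])
    if self.args.obs_agent_id:     # append a one-hot agent id
        inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1))
    # flatten agents into the batch dimension and concatenate the pieces
    return th.cat([x.reshape(bs * self.n_agents, -1) for x in inputs], dim=1)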
The _get_input_shape function
Computes the dimensionality of the agent network's input
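The body is omitted above. It mirrors _build_inputs and simply adds up the sizes of the pieces. A paraphrased sketch:

def _get_input_shape(self, scheme):
    input_shape = scheme["obs"]["vshape"]                     # observation size
    if self.args.obs_last_action:
        input_shape += scheme["actions_onehot"]["vshape"][0]  # + number of actions
    if self.args.obs_agent_id:
        input_shape += self.n_agents                          # + one-hot agent id
    return input_shape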
Training the model: q_learner.py
The __init__ function
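The constructor is omitted above. The essentials are building the mixer, deep-copying the target networks, and setting up the optimiser. A paraphrased sketch (the VDN branch and error handling are left out):

def __init__(self, mac, scheme, logger, args):
    self.args, self.mac, self.logger = args, mac, logger
    self.params = list(mac.parameters())
    self.last_target_update_episode = 0

    self.mixer = None
    if args.mixer == "qmix":  # for QMIX, mix the per-agent Qs into Q_tot
        self.mixer = QMixer(args)
        self.params += list(self.mixer.parameters())
        self.target_mixer = copy.deepcopy(self.mixer)

    self.optimiser = RMSprop(params=self.params, lr=args.lr,
                             alpha=args.optim_alpha, eps=args.optim_eps)

    # target copies of the agent networks (and mixer) used for the Q-learning targets
    self.target_mac = copy.deepcopy(mac)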
The train function (the key part)
Some details are skipped.
Arrange the data
# Get the relevant quantities
rewards = batch["reward"][:, :-1]
actions = batch["actions"][:, :-1]
terminated = batch["terminated"][:, :-1].float()
mask = batch["filled"][:, :-1].float()
mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
avail_actions = batch["avail_actions"]
Compute the Q-values along the sampled trajectories
# Calculate estimated Q-Values
mac_out = []
self.mac.init_hidden(batch.batch_size)  # clear the hidden states recorded in the mac before training
for t in range(batch.max_seq_length):
    agent_outs = self.mac.forward(batch, t=t)
    mac_out.append(agent_outs)
mac_out = th.stack(mac_out, dim=1)  # Concat over time
Pick out the Q-values of the actions each agent actually took
# Pick the Q-Values for the actions taken by each agent
chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim
Compute the Q-values along the sampled trajectories with the target network (the training targets)
# Calculate the Q-Values necessary for the target
target_mac_out = []
self.target_mac.init_hidden(batch.batch_size)
for t in range(batch.max_seq_length):
    target_agent_outs = self.target_mac.forward(batch, t=t)
    target_mac_out.append(target_agent_outs)
Compute the TD error
# Max over target Q-Values
if self.args.double_q:
    # Get actions that maximise live Q (for double q-learning)
    mac_out_detach = mac_out.clone().detach()
    mac_out_detach[avail_actions == 0] = -9999999
    cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
    target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
else:
    target_max_qvals = target_mac_out.max(dim=3)[0]

# Mix
if self.mixer is not None:
    chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])
    target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:])

# Calculate 1-step Q-Learning targets
targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals

# Td-error
td_error = (chosen_action_qvals - targets.detach())

mask = mask.expand_as(td_error)

# 0-out the targets that came from padded data
masked_td_error = td_error * mask
Optimise the parameters
# Normal L2 loss, take mean over actual data
loss = (masked_td_error ** 2).sum() / mask.sum()

# Optimise
self.optimiser.zero_grad()
loss.backward()
grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
self.optimiser.step()
Update the target networks
if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
    self._update_targets()
    self.last_target_update_episode = episode_num
Testing the result locally
Note: Replays cannot be watched using the Linux version of StarCraft II. Please use either the Mac or Windows version of the StarCraft II client.
According to the official note above, replays have to be watched on Windows (or macOS).
The steps below copy the trained model back to the local machine for testing. Alternatively, run the test on the server and copy only the generated .SC2Replay file back; replay files are written to the /StarCraft II/Replays/ directory.
Copy the trained model back to the local machine using scp
scp -r <user@remote_host:/path/to/remote/folder> </path/to/local/destination>
Set the parameters in default.yaml
checkpoint_path: "results/models/qmix__2024-11-10_13-17-31"
evaluate: True
save_replay: True
Run the experiment (evaluation only, which generates the replay file)
# use the same arguments as during training
python src/main.py --config=qmix --env-config=sc2 with env_args.map_name=2s3z
Open the .SC2Replay replay file