pymarl Source Code Walkthrough

Last updated: December 17, 2024 (evening)


Source repository

https://github.com/oxwhirl/pymarl

Implemented algorithms: IQL, VDN, QMIX, QTRAN, and COMA (see the algorithm configs under config/algs in the file tree below).

File structure

  1. Only the contents of the src folder are covered
  2. The focus is on the files involved in the QMIX algorithm
src
├── components
│   ├── __init__.py
│   ├── action_selectors.py // action selection
│   ├── episode_buffer.py // data structure for sampled episodes
│   ├── epsilon_schedules.py // epsilon decay schedule
│   └── transforms.py
├── config // experiment configuration
│   ├── algs // algorithm configs
│   │   ├── coma.yaml
│   │   ├── iql_beta.yaml
│   │   ├── iql.yaml
│   │   ├── qmix_beta.yaml
│   │   ├── qmix.yaml
│   │   ├── qtran.yaml
│   │   ├── vdn_beta.yaml
│   │   └── vdn.yaml
│   ├── envs // environment configs
│   │   ├── sc2_beta.yaml
│   │   └── sc2.yaml
│   └── default.yaml // base config
├── controllers
│   ├── __init__.py
│   └── basic_controller.py // agent controller: builds the agents and selects actions
├── envs
│   ├── __init__.py
│   └── multiagentenv.py
├── learners // model training
│   ├── __init__.py
│   ├── coma_learner.py
│   ├── q_learner.py // Q-function based training, covers VDN and QMIX
│   └── qtran_learner.py
├── modules
│   ├── agents
│   │   ├── __init__.py
│   │   └── rnn_agent.py // agent network: takes observations etc. as input, outputs Q-values
│   ├── critics
│   │   ├── __init__.py
│   │   └── coma.py
│   ├── mixers // mixing networks
│   │   ├── __init__.py
│   │   ├── qmix.py // QMIX mixing network (Figure 2a of the paper)
│   │   ├── qtran.py
│   │   └── vdn.py
│   └── __init__.py
├── runners // environment execution
│   ├── __init__.py
│   ├── episode_runner.py // single-episode runner; run() plays one full episode
│   └── parallel_runner.py // runs multiple episodes in parallel
├── utils // utility functions
│   ├── dict2namedtuple.py
│   ├── logging.py
│   ├── rl_utils.py
│   └── timehelper.py
├── __init__.py
├── main.py // program entry point; sets up the sacred experiment
└── run.py // experiment runner, from launching the experiment to closing the environment

Overview of the main modules

  1. QMIX is used as the running example
  2. The logging code is not covered
  3. Simple helper/utility functions (code that has little to do with the algorithm itself and can be reused as-is) are also skipped

YAML configuration files

  1. default.yaml

    use_tensorboard: True # log experiment data with tensorboard for later analysis
    save_model: True # save models for later evaluation

Program entry point — main.py

  1. Create the sacred experiment

    SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console
    logger = get_logger() # this line makes the console print logs such as "[DEBUG xx:xx:xx] git.cmd Popen(...)"

    ex = Experiment("pymarl")
    ex.logger = logger
    ex.captured_out_filter = apply_backspaces_and_linefeeds # filter the captured output so live output (progress bars etc.) is written to file in a sensible form

    results_path = os.path.join(dirname(dirname(abspath(__file__))), "results")
  2. Load the experiment configuration and run the sacred experiment (a sketch of the _get_config helper used here is given after this list)

    if __name__ == '__main__':
        params = deepcopy(sys.argv) # command-line arguments

        # Get the defaults from default.yaml
        with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
            try:
                config_dict = yaml.load(f, Loader=yaml.FullLoader)
            except yaml.YAMLError as exc:
                assert False, "default.yaml error: {}".format(exc)

        # Load algorithm and env base configs
        env_config = _get_config(params, "--env-config", "envs") # environment (e.g. SC2) yaml config
        alg_config = _get_config(params, "--config", "algs") # algorithm (e.g. QMIX) yaml config
        config_dict = {**config_dict, **env_config, **alg_config} # shallow dict merge
        config_dict = recursive_dict_update(config_dict, env_config)
        config_dict = recursive_dict_update(config_dict, alg_config)

        # now add all the config to sacred
        ex.add_config(config_dict)

        # Save to disk by default for sacred
        logger.info("Saving to FileStorageObserver in results/sacred.")
        file_obs_path = os.path.join(results_path, "sacred")
        ex.observers.append(FileStorageObserver.create(file_obs_path)) # attach a file observer that writes the experiment logs

        ex.run_commandline(params)
  3. Initialise the random seeds and launch the experiment framework

    # main function of the sacred experiment
    @ex.main
    def my_main(_run, _config, _log):
        # Setting the random seed throughout the modules
        config = config_copy(_config)
        np.random.seed(config["seed"])
        th.manual_seed(config["seed"])
        config['env_args']['seed'] = config["seed"]

        # run the framework
        run(_run, config, _log)
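
For reference, _get_config scans the command-line params for an argument of the form --config=NAME (or --env-config=NAME), removes it, and loads the matching YAML file from config/<subfolder>/. Below is a simplified sketch of this behaviour, not the verbatim pymarl code:

import os
import yaml

def _get_config(params, arg_name, subfolder):
    # Find e.g. "--config=qmix" among the command-line params and remove it
    config_name = None
    for i, v in enumerate(params):
        if v.split("=")[0] == arg_name:
            config_name = v.split("=")[1]
            del params[i]
            break

    if config_name is None:
        return {}

    # Load config/<subfolder>/<config_name>.yaml, e.g. config/algs/qmix.yaml
    path = os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name))
    with open(path, "r") as f:
        return yaml.load(f, Loader=yaml.FullLoader)
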

Experiment execution — run.py

The first few lines of run() are mostly logging-related and can be ignored for now; the core is the call to run_sequential. The code after that call is post-experiment cleanup and has nothing to do with the algorithm.

# Run and train
run_sequential(args=args, logger=logger)
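
Before run_sequential is reached, the merged config dict is turned into a namespace object (args) so that configuration values can be read as attributes such as args.batch_size or args.t_max. A toy sketch of the idea (the real run.py also sanity-checks the config first):

from types import SimpleNamespace as SN

# toy stand-in for the merged config dict (default.yaml + env yaml + alg yaml)
config = {"use_cuda": False, "batch_size": 32, "t_max": 2000000}

args = SN(**config)                               # dict keys become attributes
args.device = "cuda" if args.use_cuda else "cpu"  # pymarl's run() derives the device similarly (assuming a use_cuda key)
print(args.batch_size, args.device)               # 32 cpu
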

The run_sequential function

This is the main function of the experiment. It constructs the following custom objects:

runner — the environment runner, responsible for executing the game environment.

buffer — the experience replay buffer, responsible for storing the sampled data.

mac — the multi-agent controller, responsible for building the agents and selecting actions from their inputs.

learner — the learner, responsible for training the model parameters.

It then runs the experiment: train the agents, log the results, and periodically evaluate and save the models.

Constructing the custom objects needed for the experiment
  1. Create the environment runner (runner)

    # Init runner so we can get env info
    runner = r_REGISTRY[args.runner](args=args, logger=logger)
  2. Define the scheme of the sampled data, i.e. roughly what information each entry stored in the buffer contains (a sketch of the OneHot preprocessing is given after this list)

    # Default/Base scheme
    scheme = {
        "state": {"vshape": env_info["state_shape"]},
        "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
        "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
        "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
        "reward": {"vshape": (1,)},
        "terminated": {"vshape": (1,), "dtype": th.uint8},
    }
    groups = {
        "agents": args.n_agents
    }
    preprocess = {
        "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
    }
  3. Create the replay buffer (buffer)

    buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                          preprocess=preprocess,
                          device="cpu" if args.buffer_cpu_only else args.device)
  4. Create the multi-agent controller (mac)

    # Setup multiagent controller here
    mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)
  5. Pass the scheme and related info defined above, together with the mac object, to the runner

    self.new_batch is an EpisodeBatch constructor with its arguments pre-bound (via functools.partial); each call creates a fresh EpisodeBatch object used to store the data sampled in one episode

    # using the setup function of EpisodeRunner as the example
    def setup(self, scheme, groups, preprocess, mac):
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                                 preprocess=preprocess, device=self.args.device)
        self.mac = mac
  6. Create the learner (learner)

    # Learner
    learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args)
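
The preprocess entry above means that whenever integer actions are written into the batch, a one-hot version of them is also stored under the key actions_onehot. A minimal sketch of what such a OneHot transform does (illustrative only, not the exact code from components/transforms.py):

import torch as th

class OneHot:
    """Turn integer action indices into one-hot vectors of length out_dim."""
    def __init__(self, out_dim):
        self.out_dim = out_dim

    def transform(self, tensor):
        # tensor: (..., 1) integer indices -> (..., out_dim) one-hot floats
        y_onehot = tensor.new_zeros(*tensor.shape[:-1], self.out_dim).float()
        y_onehot.scatter_(-1, tensor.long(), 1)
        return y_onehot

# Example: 3 agents, 5 actions
actions = th.tensor([[0], [3], [4]])
print(OneHot(out_dim=5).transform(actions))
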
If a saved model exists, load it and continue training
if args.checkpoint_path != "":

    timesteps = []
    timestep_to_load = 0

    if not os.path.isdir(args.checkpoint_path):
        logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path))
        return

    # Go through all files in args.checkpoint_path
    for name in os.listdir(args.checkpoint_path):
        full_name = os.path.join(args.checkpoint_path, name)
        # Check if they are dirs the names of which are numbers
        if os.path.isdir(full_name) and name.isdigit():
            timesteps.append(int(name)) # each saved model lives in a directory named after its env step

    # Decide which env step's model to load:
    # if load_step is 0, load the model with the largest env step;
    # otherwise, load the model whose env step is closest to load_step
    if args.load_step == 0:
        # choose the max timestep
        timestep_to_load = max(timesteps)
    else:
        # choose the timestep closest to load_step
        timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step))

    model_path = os.path.join(args.checkpoint_path, str(timestep_to_load))

    logger.console_logger.info("Loading model from {}".format(model_path))
    learner.load_models(model_path)
    runner.t_env = timestep_to_load # resume training from the env step at which the model was saved

    # evaluate only, no training
    if args.evaluate or args.save_replay:
        evaluate_sequential(args, runner)
        return
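
A quick illustration of the load_step logic with hypothetical checkpoint directories:

timesteps = [500000, 1000000, 1500000]  # hypothetical saved env steps

load_step = 0
print(max(timesteps))                                    # 1500000 -> load the latest model

load_step = 1200000
print(min(timesteps, key=lambda x: abs(x - load_step)))  # 1000000 -> the checkpoint closest to load_step
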
Start training
  1. Bookkeeping variables

    episode = 0 # number of episodes trained so far
    last_test_T = -args.test_interval - 1 # env step of the last test run, used to decide when to test again
    last_log_T = 0 # env step of the last log output, used to decide when to log again
    model_save_time = 0 # env step of the last model save, used to decide when to save again

    start_time = time.time() # experiment start time, used in the logs
    last_time = start_time # used to estimate the remaining time (console log)
  2. The while loop (the core)

    The loop terminates once the training env step exceeds the configured maximum (args.t_max). A condensed skeleton of the whole loop is given after this list.

    1. Run the environment and store the data

      # Run for a whole episode at a time
      episode_batch = runner.run(test_mode=False) # play one episode
      buffer.insert_episode_batch(episode_batch)
    2. Train

      if buffer.can_sample(args.batch_size): # only train once the buffer holds at least batch_size episodes
          episode_sample = buffer.sample(args.batch_size)

          # Truncate batch to only filled timesteps
          max_ep_t = episode_sample.max_t_filled()
          episode_sample = episode_sample[:, :max_ep_t] # truncate the time dimension of all samples to the longest filled sequence in this batch

          if episode_sample.device != args.device:
              episode_sample.to(args.device)

          learner.train(episode_sample, runner.t_env, episode) # train
    3. Test

      # Execute test runs once in a while
      n_test_runs = max(1, args.test_nepisode // runner.batch_size) # number of episodes per test round
      if (runner.t_env - last_test_T) / args.test_interval >= 1.0: # at least test_interval env steps have passed since the last test

          logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max)) # print the training progress
          logger.console_logger.info("Estimated time left: {}. Time passed: {}".format(
              time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time))) # print the estimated remaining training time
          last_time = time.time()

          last_test_T = runner.t_env
          for _ in range(n_test_runs):
              runner.run(test_mode=True) # test
    4. Save the model

      if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0):
          # save_model is enabled and at least save_model_interval env steps have passed since the last save (or this is the start of training)
          model_save_time = runner.t_env
          save_path = os.path.join(args.local_results_path, "models", args.unique_token, str(runner.t_env))
          # "results/models/{}".format(unique_token)
          os.makedirs(save_path, exist_ok=True)
          logger.console_logger.info("Saving models to {}".format(save_path))

          # learner should handle saving/loading -- delegate actor save/load to mac,
          # use appropriate filenames to do critics, optimizer states
          learner.save_models(save_path) # save the models
    5. Log

      if (runner.t_env - last_log_T) >= args.log_interval: # at least log_interval env steps have passed since the last log
          logger.log_stat("episode", episode, runner.t_env)
          logger.print_recent_stats()
          last_log_T = runner.t_env
  3. Close the environment

    runner.close_env()
    logger.console_logger.info("Finished Training")

Environment execution — episode_runner.py

pymarl ships two environment runners; see src/runners/__init__.py

EpisodeRunner runs one episode at a time (this is the class described below)

ParallelRunner runs multiple episodes in parallel (controlled by the batch_size_run parameter in default.yaml)

The __init__ function

self.args = args
self.logger = logger
self.batch_size = self.args.batch_size_run
assert self.batch_size == 1

self.env = env_REGISTRY[self.args.env](**self.args.env_args)
self.episode_limit = self.env.episode_limit
self.t = 0 # time step within the current episode

self.t_env = 0 # env step: total number of time steps executed over the whole experiment

self.train_returns = []
self.test_returns = []
self.train_stats = {}
self.test_stats = {}

# Log the first run
self.log_train_stats_t = -1000000

The run function

  1. Initialisation

    self.reset() # reset the environment

    terminated = False # whether the episode has ended
    episode_return = 0 # cumulative return of the current episode
    self.mac.init_hidden(batch_size=self.batch_size) # reset the agents' hidden states so that no "memory" of the previous episode is carried over
  2. Play one episode until the game ends

    Get and store the environment state —> select actions —> step the environment and store the transition —> ... Note that after the loop one extra state/avail_actions/obs entry (and one final action) is stored, which is why the batches are created with episode_limit + 1 time slots.

    while not terminated:

        pre_transition_data = {
            "state": [self.env.get_state()],
            "avail_actions": [self.env.get_avail_actions()],
            "obs": [self.env.get_obs()]
        }

        self.batch.update(pre_transition_data, ts=self.t)

        # Pass the entire batch of experiences up till now to the agents
        # Receive the actions for each agent at this timestep in a batch of size 1
        actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)

        reward, terminated, env_info = self.env.step(actions[0]) # step the environment with the agents' actions
        episode_return += reward # accumulate the episode return

        post_transition_data = {
            "actions": actions,
            "reward": [(reward,)],
            # mark the transition as terminal only if the episode did not end merely because the episode limit was reached
            "terminated": [(terminated != env_info.get("episode_limit", False),)],
        }

        self.batch.update(post_transition_data, ts=self.t)

        self.t += 1 # advance the time step

    last_data = {
        "state": [self.env.get_state()],
        "avail_actions": [self.env.get_avail_actions()],
        "obs": [self.env.get_obs()]
    }
    self.batch.update(last_data, ts=self.t)

    # Select actions in the last stored state
    actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)
    self.batch.update({"actions": actions}, ts=self.t)
  3. Logging-related operations (skipped)

Sampled data — episode_buffer.py

The sample data structure — the EpisodeBatch class

:sweat_smile:

The experience replay buffer — the ReplayBuffer class

Inherits from EpisodeBatch and adds replay-buffer functionality on top of it

  1. Initialisation

    def __init__(self, scheme, groups, buffer_size, max_seq_length, preprocess=None, device="cpu"):
        super(ReplayBuffer, self).__init__(scheme, groups, buffer_size, max_seq_length, preprocess=preprocess, device=device)
        self.buffer_size = buffer_size # same as self.batch_size but more explicit
        self.buffer_index = 0 # index at which the next episode will be inserted
        self.episodes_in_buffer = 0 # number of episodes currently stored in the buffer
  2. Inserting data (a toy illustration of the index bookkeeping is given after this list)

    # insert data cyclically, modulo self.buffer_size
    def insert_episode_batch(self, ep_batch):
        if self.buffer_index + ep_batch.batch_size <= self.buffer_size:
            # the new episodes fit before the end of the buffer: insert them directly
            self.update(ep_batch.data.transition_data,
                        slice(self.buffer_index, self.buffer_index + ep_batch.batch_size),
                        slice(0, ep_batch.max_seq_length),
                        mark_filled=False)
            self.update(ep_batch.data.episode_data,
                        slice(self.buffer_index, self.buffer_index + ep_batch.batch_size))
            self.buffer_index = (self.buffer_index + ep_batch.batch_size) # advance the insertion index
            self.episodes_in_buffer = max(self.episodes_in_buffer, self.buffer_index) # once the buffer is full, episodes_in_buffer stays at buffer_size (5000 in the QMIX config) while buffer_index keeps wrapping around
            self.buffer_index = self.buffer_index % self.buffer_size # wrap the insertion index back to 0 once it reaches buffer_size
            assert self.buffer_index < self.buffer_size
        else:
            # the new episodes do not fit: split them,
            # insert the first part at the end of the buffer and the rest at its start
            buffer_left = self.buffer_size - self.buffer_index
            self.insert_episode_batch(ep_batch[0:buffer_left, :])
            self.insert_episode_batch(ep_batch[buffer_left:, :])
  3. Sampling data

    def sample(self, batch_size):
        assert self.can_sample(batch_size) # make sure the buffer holds enough episodes
        if self.episodes_in_buffer == batch_size:
            return self[:batch_size]
        else:
            # Uniform sampling only atm
            ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False)
            return self[ep_ids]
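
A toy illustration of the index bookkeeping in insert_episode_batch, assuming a hypothetical buffer_size of 4 and one episode inserted at a time:

buffer_size = 4 # hypothetical
buffer_index = 0
episodes_in_buffer = 0

for step in range(6): # insert 6 episodes, one at a time
    buffer_index += 1
    episodes_in_buffer = max(episodes_in_buffer, buffer_index)
    buffer_index %= buffer_size
    print(step, buffer_index, episodes_in_buffer)

# buffer_index cycles 1, 2, 3, 0, 1, 2 while episodes_in_buffer saturates at 4:
# once the buffer is full, the oldest episodes are simply overwritten.
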

The agent controller — basic_controller.py

This multi-agent controller shares parameters between agents

The __init__ function

self.n_agents = args.n_agents # number of agents
self.args = args # experiment parameters
input_shape = self._get_input_shape(scheme) # input dimension of the agent model, see _get_input_shape below
self._build_agents(input_shape) # build the agent network (the RNN value network)
self.agent_output_type = args.agent_output_type

self.action_selector = action_REGISTRY[args.action_selector](args) # action selector, e.g. epsilon-greedy

self.hidden_states = None # hidden states of the RNN model
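
For intuition, an epsilon-greedy selector with action masking works roughly as follows; in pymarl, epsilon itself is annealed over t_env by a schedule from epsilon_schedules.py and is set to 0 in test mode. This is a simplified sketch, not the exact code in components/action_selectors.py:

import torch as th

def epsilon_greedy(agent_qs, avail_actions, epsilon, test_mode=False):
    """Simplified sketch; agent_qs and avail_actions have shape (batch, n_agents, n_actions)."""
    if test_mode:
        epsilon = 0.0 # behave greedily during evaluation

    # mask out unavailable actions so they can never be the argmax
    masked_qs = agent_qs.clone()
    masked_qs[avail_actions == 0] = -float("inf")

    # with probability epsilon pick a random *available* action, otherwise the greedy one
    random_actions = th.distributions.Categorical(avail_actions.float()).sample()
    pick_random = (th.rand_like(agent_qs[..., 0]) < epsilon).long()
    return pick_random * random_actions + (1 - pick_random) * masked_qs.argmax(dim=-1)
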

The select_actions function

# Only select actions for the selected batch elements in bs
avail_actions = ep_batch["avail_actions"][:, t_ep] # actions available at the current time step
agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode) # forward pass of the agent model, see forward below
chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode) # select the actions
return chosen_actions

The forward function

agent_inputs = self._build_inputs(ep_batch, t) # build the agents' inputs at time t, see _build_inputs below
avail_actions = ep_batch["avail_actions"][:, t] # actions available at the current time step
agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states)

# the code in between handles the COMA case and is omitted here

return agent_outs.view(ep_batch.batch_size, self.n_agents, -1)

The _build_inputs function

Builds the agents' inputs at time step t

# Assumes homogenous agents with flat observations.
# Other MACs might want to e.g. delegate building inputs to each agent
bs = batch.batch_size
inputs = []
inputs.append(batch["obs"][:, t]) # b1av: batch_size, time_step, agent, vshape
if self.args.obs_last_action: # the input includes each agent's previous action
    if t == 0:
        inputs.append(th.zeros_like(batch["actions_onehot"][:, t]))
    else:
        inputs.append(batch["actions_onehot"][:, t-1])
if self.args.obs_agent_id: # the input includes the agent's id (one-hot)
    inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1))

inputs = th.cat([x.reshape(bs*self.n_agents, -1) for x in inputs], dim=1)
return inputs

The _get_input_shape function

Computes the input dimension of the agent model

input_shape = scheme["obs"]["vshape"]
if self.args.obs_last_action: # the input includes the agent's previous action
    input_shape += scheme["actions_onehot"]["vshape"][0]
if self.args.obs_agent_id: # the input includes the agent's id
    input_shape += self.n_agents

return input_shape
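
A quick worked example with hypothetical sizes, matching the two functions above:

# hypothetical environment sizes, purely for illustration
obs_shape, n_actions, n_agents = 30, 9, 3

input_shape = obs_shape
input_shape += n_actions # obs_last_action: one-hot of the previous action
input_shape += n_agents  # obs_agent_id: one-hot agent id
print(input_shape)       # 42 -> length of each agent's input vector, fed to rnn_agent.py

# _build_inputs concatenates these pieces into a tensor of shape (batch_size * n_agents, 42)
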

Model training — q_learner.py

The __init__ function

self.args = args # experiment parameters
self.mac = mac # multi-agent controller
self.logger = logger # logger

self.params = list(mac.parameters())

self.last_target_update_episode = 0

self.mixer = None
if args.mixer is not None: # mixing network
    if args.mixer == "vdn":
        self.mixer = VDNMixer()
    elif args.mixer == "qmix":
        self.mixer = QMixer(args)
    else:
        raise ValueError("Mixer {} not recognised.".format(args.mixer))
    self.params += list(self.mixer.parameters())
    self.target_mixer = copy.deepcopy(self.mixer)

self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) # optimiser

# a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
self.target_mac = copy.deepcopy(mac)

self.log_stats_t = -self.args.learner_log_interval - 1
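
The QMixer built here is the monotonic mixing network from the QMIX paper (modules/mixers/qmix.py, Figure 2a): hypernetworks take the global state and generate the weights of a small mixing network, and taking the absolute value of those weights makes Q_tot monotonic in every agent's Q-value. Below is a simplified sketch of the idea, not the verbatim pymarl module (which differs in layer options and naming):

import torch as th
import torch.nn as nn
import torch.nn.functional as F

class SimpleQMixer(nn.Module):
    """Simplified sketch of a QMIX-style monotonic mixing network (illustrative only)."""
    def __init__(self, n_agents, state_dim, embed_dim=32):
        super().__init__()
        self.n_agents, self.embed_dim = n_agents, embed_dim
        # hypernetworks: map the global state to the weights/biases of the mixing network
        self.hyper_w1 = nn.Linear(state_dim, n_agents * embed_dim)
        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_w2 = nn.Linear(state_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(state_dim, embed_dim), nn.ReLU(),
                                      nn.Linear(embed_dim, 1))

    def forward(self, agent_qs, states):
        # agent_qs: (batch, T, n_agents), states: (batch, T, state_dim)
        bs, t, _ = agent_qs.shape
        agent_qs = agent_qs.reshape(-1, 1, self.n_agents)
        states = states.reshape(-1, states.shape[-1])

        # abs() keeps the mixing weights non-negative -> Q_tot is monotonic in each agent's Q
        w1 = th.abs(self.hyper_w1(states)).view(-1, self.n_agents, self.embed_dim)
        b1 = self.hyper_b1(states).view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs, w1) + b1)

        w2 = th.abs(self.hyper_w2(states)).view(-1, self.embed_dim, 1)
        b2 = self.hyper_b2(states).view(-1, 1, 1)
        q_tot = th.bmm(hidden, w2) + b2
        return q_tot.view(bs, t, 1)

This matches how the train function below calls the mixer: the per-agent chosen-action Q-values of shape (batch, T, n_agents) plus the global states go in, and a single Q_tot per time step comes out.
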

The train function (the key part)

Some details are skipped

  1. Prepare the data

    # Get the relevant quantities
    rewards = batch["reward"][:, :-1]
    actions = batch["actions"][:, :-1]
    terminated = batch["terminated"][:, :-1].float()
    mask = batch["filled"][:, :-1].float()
    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
    avail_actions = batch["avail_actions"]
  2. Compute the Q-values along the sampled trajectories

    # Calculate estimated Q-Values
    mac_out = []
    self.mac.init_hidden(batch.batch_size) # reset the mac's hidden states before training
    for t in range(batch.max_seq_length):
        agent_outs = self.mac.forward(batch, t=t)
        mac_out.append(agent_outs)
    mac_out = th.stack(mac_out, dim=1) # Concat over time
  3. Pick the Q-values of the actions each agent actually took

    # Pick the Q-Values for the actions taken by each agent
    chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim
  4. Compute the Q-values along the trajectories with the target network (for the training targets)

    # Calculate the Q-Values necessary for the target
    target_mac_out = []
    self.target_mac.init_hidden(batch.batch_size)
    for t in range(batch.max_seq_length):
        target_agent_outs = self.target_mac.forward(batch, t=t)
        target_mac_out.append(target_agent_outs)
  5. Compute the TD error

    # Max over target Q-Values
    if self.args.double_q:
        # Get actions that maximise live Q (for double q-learning)
        mac_out_detach = mac_out.clone().detach()
        mac_out_detach[avail_actions == 0] = -9999999
        cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
        target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
    else:
        target_max_qvals = target_mac_out.max(dim=3)[0]

    # Mix
    if self.mixer is not None:
        chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])
        target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:])

    # Calculate 1-step Q-Learning targets
    targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals

    # Td-error
    td_error = (chosen_action_qvals - targets.detach())

    mask = mask.expand_as(td_error)

    # 0-out the targets that came from padded data
    masked_td_error = td_error * mask
  6. Optimise the parameters

    # Normal L2 loss, take mean over actual data
    loss = (masked_td_error ** 2).sum() / mask.sum()

    # Optimise
    self.optimiser.zero_grad()
    loss.backward()
    grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
    self.optimiser.step()
  7. Update the target networks (a sketch of _update_targets is given after this list)

    if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
        self._update_targets()
        self.last_target_update_episode = episode_num
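
_update_targets is not shown above; it performs a hard update that copies the current parameters into the target networks. Roughly (a sketch of q_learner.py's method; the copy is delegated to the mac's load_state helper and the mixer's load_state_dict):

def _update_targets(self):
    # hard update: copy the live networks' parameters into the target networks
    self.target_mac.load_state(self.mac) # copies the agent (RNN) parameters
    if self.mixer is not None:
        self.target_mixer.load_state_dict(self.mixer.state_dict()) # copies the mixing-network parameters
    self.logger.console_logger.info("Updated target network")
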

Testing the result locally

Note: Replays cannot be watched using the Linux version of StarCraft II. Please use either the Mac or Windows version of the StarCraft II client.

According to this official note, replays have to be watched with the Windows (or Mac) StarCraft II client.

The steps below copy the trained model back to the local machine and run the evaluation there. Alternatively, run the evaluation on the server and only copy the generated .SC2Replay file back; replay files are written under the /StarCraft II/Replays/ directory.

  1. Copy the trained model back to the local machine with scp

    scp -r <user@remote_host:/path/to/remote/folder> </path/to/local/destination>
  2. Set the parameters in default.yaml

    checkpoint_path: "results/models/qmix__2024-11-10_13-17-31"
    evaluate: True
    save_replay: True
  3. Run the experiment (evaluation only; this generates the replay file)

    # keep the arguments the same as during training
    python src/main.py --config=qmix --env-config=sc2 with env_args.map_name=2s3z
  4. Open the .SC2Replay file in the StarCraft II client to watch the replay

