因为自己的工作需要,需要跑一下元强化学习的 PEARL (Efficient Off-policy Meta-learning via Probabilistic Context Variables) 代码做一些对比实验。
最原始的代码(https://github.com/katerakelly/oyster)。但是这个代码下载下来,环境配置比较困难。
搜索 Github 上看到这个项目(链接:https://github.com/dongminlee94/meta-learning-for-everyone),里面介绍了 MAML、RL2 和 PEARL。阅读项目介绍,感觉比较科普,因此就以这个项目做解读。理解里面的代码细节,方便做复现~
进入到 PEARL 文件夹下,输入指令 tree
即可大致了解里面的文件包含情况~
.
├── algorithm
│ ├── buffers.py
│ ├── meta_learner.py
│ ├── networks.py
│ ├── sac.py
│ └── sampler.py
├── configs
│ ├── dir_target_config.yaml
│ ├── experiment_config.yaml
│ └── vel_target_config.yaml
└── pearl_trainer.py
algorithm
文件夹下面是基本的组建:经验池 buffers.py
,元学习外环框架 meta_learner.py
,模型网络 networks.py
,内环框架 sac.py
,以及最后的批采样器 sampler.py
。
configs
下面是一些实验的配置文件:experiment_config.yaml
总的配置文件,dir_target_config.yaml
控制方向实验配置文件,vel_target_config.yaml
控制速度配置文件。
pearl_trainer.py
是整个训练脚本,运行时候直接运行这个文件即可~
导入 experiment_config.yaml
文件。
with open(os.path.join("configs", "experiment_config.yaml"), "r", encoding="utf-8") as file:experiment_config: Dict[str, Any] = yaml.load(file, Loader=yaml.FullLoader)
导入具体某个实验的配置数据,experiment_config.yaml
文件中:env_name: "dir"
,因此导入的是方向实验文件,也就是 dir_target_config.yaml
。
# 목표 보상 설정에 대한 하이퍼파라미터들 불러오기with open(os.path.join("configs", experiment_config["env_name"] + "_target_config.yaml"),"r",encoding="utf-8",) as file:env_target_config: Dict[str, Any] = yaml.load(file, Loader=yaml.FullLoader)
下面是 dir_target_config.yaml
的具体细节。具体变量含义稍后在提及~
# PEARL 방향 목표 보상 환경 설정# 일반 환경 설정
# ----------
# 메타-트레이닝에 대한 태스크들의 수
train_tasks: 2# 메타-태스팅에 대한 태스크들이 수
test_tasks: 2# Latent context의 차원의 수
latent_dim: 5# 은닉 유닛의 차원의 수
hidden_dim: 300# PEARL 환경 설정
# -------------
pearl_params:# 트레이닝에 대한 반복 수num_iterations: 1000# 매 반복에서 수집할 샘플된 태스크의 수num_sample_tasks: 5# 트레이닝 이전에 태스크마다 수집된 샘플 수num_init_samples: 2000# z ~ prior를 이용할 때 태스크마다 수집할 샘플 수num_prior_samples: 1000# 정책 네트워크의 학습에만 사용되고 인코더에는 사용되지 않는# z ~ posterior를 이용할 때 태스크마다 수집할 샘플 수num_posterior_samples: 1000# 하나의 반복마다 취할 메타-그레디언트의 수num_meta_grads: 1500# 메타-배치의 샘플 수meta_batch_size: 4# Context 배치의 샘플 수batch_size: 256# 환경에 대한 최대 스텝 수max_step: 200# 최대 버퍼 사이즈max_buffer_size: 1000000# 조기 중단 조건의 수num_stop_conditions: 3# 조기 중단 조건에서 사용되는 목표 값stop_goal: 1900# SAC 환경 설정
# -----------
sac_params:# 할인률gamma: 0.99# 인코더 손실 함수에서 사용되는 KL divergence에 대한 가중치 값kl_lambda: 0.1# 정책 네트워크와 가치 네트워크의 배치에서 사용되는 샘플 수batch_size: 256# 행동 가치 함수 네트워크의 학습률qf_lr: 0.0003# 인코더 네트워크의 학습률encoder_lr: 0.0003# 정책 네트워크의 학습률policy_lr: 0.0003
env
实例化环境,里面初始化了元训练任务和元测试任务的目标。
tasks
返回的是一个列表,输出的是一组目标不同的任务的编号。
# 멀티-태스크 환경과 샘플 태스크들 생성
env: HalfCheetahEnv = ENVS["cheetah-" + experiment_config["env_name"]](num_tasks=env_target_config["train_tasks"] + env_target_config["test_tasks"],
)
tasks: List[int] = env.get_all_task_idx()
读取配置文件的随机数信息,给必要的库设置随机数。
# 랜덤 시드 값 설정
env.reset(seed=experiment_config["seed"])
np.random.seed(experiment_config["seed"])
torch.manual_seed(experiment_config["seed"])
根据实例化的环境 env
,指出环境中的状态空间、动作空间和隐藏层维度。
隐藏层维度就是 dir_target_config.yaml
文件的 hidden_dim
变量,数值是300~
observ_dim: int = env.observation_space.shape[0]
action_dim: int = env.action_space.shape[0]
hidden_dim: int = env_target_config["hidden_dim"]
根据配置文件指定显卡。
device: torch.device = (torch.device("cuda", index=experiment_config["gpu_index"])if torch.cuda.is_available()else torch.device("cpu"))
实例化内环智能体 SAC
。实例化参数里面,根据配置信息读取数据。上下文变量的维度:latent_dim
是5,隐藏层的维度:hidden_dim
是300。编码器的输入是状态转移信息,观测维度 + 动作维度 + 奖励值(维度是1)。输出是两个上下文变量的维度,实际是一个均值变量,一个是方差变量,生成一个正态分布。最后再从配置文件中传入SAC的其他参数~
agent = SAC(observ_dim=observ_dim,action_dim=action_dim,latent_dim=env_target_config["latent_dim"],hidden_dim=hidden_dim,encoder_input_dim=observ_dim + action_dim + 1,encoder_output_dim=env_target_config["latent_dim"] * 2,device=device,**env_target_config["sac_params"],)
实例化外环训练器 MetaLearner
。输入的是元训练集 train_tasks
的任务和元测试集 test_tasks
的任务。还有就是保存和载入的模型文件的断电这些~
meta_learner = MetaLearner(env=env,env_name=experiment_config["env_name"],agent=agent,observ_dim=observ_dim,action_dim=action_dim,train_tasks=tasks[: env_target_config["train_tasks"]],test_tasks=tasks[-env_target_config["test_tasks"] :],save_exp_name=experiment_config["save_exp_name"],save_file_name=experiment_config["save_file_name"],load_exp_name=experiment_config["load_exp_name"],load_file_name=experiment_config["load_file_name"],load_ckpt_num=experiment_config["load_ckpt_num"],device=device,**env_target_config["pearl_params"],)
最后开始进行元训练~
meta_learner.meta_train()
元训练的主要代码就在这里了。
total_start_time
和 start_time
记录的是总的训练时间和每次迭代开始的时间。
self.num_iterations
就是配置文件的 num_iterations
,数值是1000。
条件块if iteration == 0
:代码在第0次迭代搜集用于训练和验证的状态转移数据。用循环变量 index
代表训练集任务 self.train_tasks
的下标,用循环变量 index
指引每个任务,对每个任务做初始化操作,然后开始搜集状态转移数据。
self.collect_train_data()
的作用是为每个任务收集若干完整的轨迹作为样本,加入到强化学习经验池子和编码器经验池子当中,这里不进行后验推断。搜集的方法 self.collect_train_data
在下一节~
代码首先遍历 self.num_sample_tasks
次,从训练任务中抽取一个任务环境,重置并清空里面的编码器经验池子。如果 self.num_prior_samples
大于0,用先验分布抽样得到的隐藏层变量然后获得一批数据,放入经验池子,此时不进行后验推断。随后,如果 self.num_prior_samples
大于0,用后验分布抽样得到的隐藏层变量然后获得一批数据,不放入经验池子,此时进行后验推断。
之后开始进行元梯度更新。在元梯度迭代次数 self.num_meta_grads
内,在 self.train_tasks
内采样self.meta_batch_size
的下标,用于指带这么多的任务。清楚这些任务的隐藏层变量z,赋值标准正态分布的采样,随后采样上下文 context_batch
和状态转移 transition_batch
数据。
随后执行 self.agent.train_model()
方法进行元训练,然后再调用 self.meta_test()
做元测试。对于早结束的情况做了一些异常的提示。
元训练过程 self.agent.train_model()
在 7. 代码流程 def train_model()
~
元测试过程 self.meta_test()
在 8. 代码流程 def meta_test()
~
def meta_train(self) -> None:# 메타-트레이닝total_start_time: float = time.time()for iteration in range(self.num_iterations):start_time: float = time.time()# 첫번째 반복단계에 한해 모든 메타-트레이닝 태스크에 대한 경로를 수집하여 리플레이 버퍼에 저장if iteration == 0:print("Collecting initial pool of data for train and eval")for index in tqdm(self.train_tasks):self.env.reset_task(index)self.collect_train_data(task_index=index,max_samples=self.num_init_samples,update_posterior=False,add_to_enc_buffer=True,)print(f"\n=============== Iteration {iteration} ===============")# 임의의 메타 트레이닝 태스크에 대한 새로운 경로를 버퍼에 저장for i in range(self.num_sample_tasks):index = np.random.randint(len(self.train_tasks))self.env.reset_task(index)self.encoder_replay_buffer.task_buffers[index].clear()# 샘플된 z ~ prior r(z)에 대한 경로 수집if self.num_prior_samples > 0:print(f"[{i + 1}/{self.num_sample_tasks}] collecting samples with prior")self.collect_train_data(task_index=index,max_samples=self.num_prior_samples,update_posterior=False,add_to_enc_buffer=True,)# 인코더는 prior r(z)로 생성된 경로 데이터만을 사용하여 학습되나,# RL 정책의 학습에는 z ~ posterior q(z|c)로 생성된 경로도 사용if self.num_posterior_samples > 0:print(f"[{i + 1}/{self.num_sample_tasks}] collecting samples with posterior")self.collect_train_data(task_index=index,max_samples=self.num_posterior_samples,update_posterior=True,add_to_enc_buffer=False,)# 샘플된 메타-배치 태스크들의 경로 데이터로 네트워크 업데이트print(f"Start meta-gradient updates of iteration {iteration}")for i in range(self.num_meta_grads):indices: np.ndarray = np.random.choice(self.train_tasks, self.meta_batch_size)# 인코더의 context와 은닉 상태 초기화self.agent.encoder.clear_z(num_tasks=len(indices))# Context 배치 샘플context_batch: torch.Tensor = self.sample_context(indices)# 경로 배치 샘플transition_batch: List[torch.Tensor] = self.sample_transition(indices)# 정책, Q-함수, 인코더 네트워크를 SAC 알고리즘에서 학습log_values: Dict[str, float] = self.agent.train_model(meta_batch_size=self.meta_batch_size,batch_size=self.batch_size,context_batch=context_batch,transition_batch=transition_batch,)# 인코더의 태스크변수 z의 Backpropagation 차단self.agent.encoder.task_z.detach()# 메타-테스트 태스크에서 학습성능 평가self.meta_test(iteration, total_start_time, start_time, log_values)if self.is_early_stopping:print(f"\n==================================================\n"f"The last {self.num_stop_conditions} meta-testing results are {self.dq}.\n"f"And early stopping condition is {self.is_early_stopping}.\n"f"Therefore, meta-training is terminated.",)break
这个函数是用来搜集状态转移数据,正如上一节提到的。
输入变量:task_index
任务索引下标,指带具体任务的编号;max_samples
:;update_posterior
:;add_to_enc_buffer
:;
代码里 self.agent.encoder.clear_z()
先把上下文z变量赋值标准正态分布的抽样信息。
编码器类在下一节 5. 编码器类 class MLPEncoder(FlattenMLP)
介绍~
self.agent.policy.is_deterministic = False
表示这部分代码的决策是通过分布采样得到而不是直接输出固定确切值。
采样器用于采样 max_samples
数量的若干条完整的轨迹。
采样器类在 6. 采样器类 class Sampler
介绍~
获得完整的轨迹后,将数据存储在第 task_index
任务的经验池子当中。
如果 add_to_enc_buffer
标记为真,那么就将轨迹放置到编码器经验池子 encoder_replay_buffer
当中;如果 update_posterior
标记为真,那么就从编码器经验池子 encoder_replay_buffer
当中采样数据进行后验推断。
def collect_train_data(self,task_index: int,max_samples: int,update_posterior: bool,add_to_enc_buffer: bool,
) -> None:# 주어진 인덱스 태스크에 대한 경로 데이터 수집self.agent.encoder.clear_z()self.agent.policy.is_deterministic = Falsecur_samples = 0while cur_samples < max_samples:trajs, num_samples = self.sampler.obtain_samples(max_samples=max_samples - cur_samples,update_posterior=update_posterior,accum_context=False,)cur_samples += num_samples# RL 리플레이 버퍼에 수집한 데이터 저장self.rl_replay_buffer.add_trajs(task_index, trajs)if add_to_enc_buffer:# 인코더 리플레이 버퍼에 수집한 데이터 저장self.encoder_replay_buffer.add_trajs(task_index, trajs)if update_posterior:# 샘플한 context에 따른 posterior 업데이트context_batch = self.sample_context(np.array([task_index]))self.agent.encoder.infer_posterior(context_batch)
这个是编码器类,在内环训练过程中,需要将状态转移信息用编码器得到上下文变量,用的就是这个类的实例~
初始化输入信息维度 input_dim
、输出信息的维度 output_dim
、上下文变量维度 latent_dim
和中间层维度 hidden_dim
,最后再声明一下显卡配置 device
。
这个类继承了 FlattenMLP
这个类,在计算时候会调用 FlattenMLP
这个类的 forward
方法,实际上就是加了一层 torch.cat()
操作,把几个张量合并在一起。
阅读源码时候, FlattenMLP
这个类继承了 MLP
这个类。 MLP
这个类定义了神经网路,通过python的继承使用语法,网络的输入层维度就是初始化输入信息维度 input_dim
、输出层维度就是输出信息的维度 output_dim
,中间层的神经元数就是中间层维度 hidden_dim
,默认3个中间层。使用 torch.functional.F.relu
激活函数。 MLP
这个类还定义了基本的前向计算 forward()
,与我们一般的神经网络差不多。
回到 MLPEncoder
这个类。初始化了隐藏层变量的均值 self.z_mean
、方差 self.z_var
和 self.task_z
均为 None
。执行了一下 self.clear_z()
操作。
接下来细说 self.clear_z()
操作:先初始化先验分布为标准正态分布,标准正态分布的维度 = 上下文变量的维度 × 任务数(默认是1),为每个任务抽样上下文变量做准备。
接下来执行 self.sample_z()
操作:首先采用 torch.unbind()
方法对均值张量和方差张量对第0维度做了切片,然后用 zip
捆绑起来,并给 mean
和 var
赋值,也就是分别为每个任务附加先验的标准正态分布。依次构建标准正态分布并填充于一个列表 dists
。再对列表里面的每个任务的标准正态分布重采样上下文变量,并堆叠起来成为一个张量。
class MLPEncoder(FlattenMLP):def __init__(self,input_dim: int,output_dim: int,latent_dim: int,hidden_dim: int,device: torch.device,) -> None:super().__init__(input_dim=input_dim, output_dim=output_dim, hidden_dim=hidden_dim)self.output_dim = output_dimself.latent_dim = latent_dimself.device = deviceself.z_mean = Noneself.z_var = Noneself.task_z = Noneself.clear_z()def clear_z(self, num_tasks: int = 1) -> None:# q(z|c)를 prior r(z)로 초기화self.z_mean = torch.zeros(num_tasks, self.latent_dim).to(self.device)self.z_var = torch.ones(num_tasks, self.latent_dim).to(self.device)# Prior r(z)에서 새로운 z를 생성self.sample_z()# 지금까지 모은 context 초기화self.context = Nonedef sample_z(self) -> None:# z ~ r(z) 또는 z ~ q(z|c) 생성dists = []for mean, var in zip(torch.unbind(self.z_mean), torch.unbind(self.z_var)):dist = torch.distributions.Normal(mean, torch.sqrt(var))dists.append(dist)sampled_z = [dist.rsample() for dist in dists]self.task_z = torch.stack(sampled_z).to(self.device)
这个类是用来采样轨迹数据的。
首先来看 def rollout
方法,这个方法的作用是获得一条智能体在环境中运动的数据(only one)。初始化了一系列列表,用于记录智能体与环境交互的各个数据。obs
获得环境的状态信息,done
是是否完成的标志,默认是 False
。智能体根据状态信息,获得动作信息;然后作用在环境中获得下一个状态、奖励和是否完成的标志。如果 accum_context
为 True
,那么就实时更新上下文信息,否则就没有。然后将状态转移数据记录到一个个列表中,更新 obs
进入下一状态。最后返回的是一个字典,记录了这些信息。
接下来看 def update_context
方法。先把状态信息 obs
、动作信息 action
和奖励 reward
改成高维度的浮点型GPU张量, 然后对这些信息做连接得到 transition
上下文变量。如果编码器的上下文信息 self.agent.encoder.context
是空 None
的话,那么得到 transition
上下文变量赋值给self.agent.encoder.context
;否则就是在原来 self.agent.encoder.context
基础上再增加当前状态转移的上下文。
最后看 def obtain_samples
方法。输入的 max_samples
表示最大采样数量、update_posterior
表示更新后验标记和 accum_context
表示是否累积上下文。这个方法的作用是:获得大小为 max_samples
的状态转移数据,这些数据由若干完整的轨迹变量 trajs
组成。然后还采样了隐藏层变量z,最后进行输出。
class Sampler:def __init__(self,env: HalfCheetahEnv,agent: SAC,max_step: int,device: torch.device,) -> None:self.env = envself.agent = agentself.max_step = max_stepself.device = devicedef obtain_samples(self,max_samples: int,update_posterior: bool,accum_context: bool = True,) -> Tuple[List[Dict[str, np.ndarray]], int]:# 최대 샘플량의 수까지 샘플들 얻기trajs = []cur_samples = 0while cur_samples < max_samples:traj = self.rollout(accum_context=accum_context)trajs.append(traj)cur_samples += len(traj["cur_obs"])self.agent.encoder.sample_z()if update_posterior:breakreturn trajs, cur_samplesdef rollout(self, accum_context: bool = True) -> Dict[str, np.ndarray]:# 최대 경로 길이까지 경로 생성_cur_obs = []_actions = []_rewards = []_next_obs = []_dones = []_infos = []obs = self.env.reset()done = Falsecur_step = 0while not (done or cur_step == self.max_step):action = self.agent.get_action(obs)next_obs, reward, done, info = self.env.step(action)# 에이전트의 현재 context 업데이트if accum_context:self.update_context(obs=obs, action=action, reward=np.array([reward]))_cur_obs.append(obs)_actions.append(action)_rewards.append(reward)_next_obs.append(next_obs)_dones.append(done)_infos.append(info["run_cost"])cur_step += 1obs = next_obsreturn dict(cur_obs=np.array(_cur_obs),actions=np.array(_actions),rewards=np.array(_rewards).reshape(-1, 1),next_obs=np.array(_next_obs),dones=np.array(_dones).reshape(-1, 1),infos=np.array(_infos),)def update_context(self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray) -> None:# 현재 context에 하나의 transition 추가obs = obs.reshape((1, 1, *obs.shape))action = action.reshape((1, 1, *action.shape))reward = reward.reshape((1, 1, *reward.shape))obs = torch.from_numpy(obs).float().to(self.device)action = torch.from_numpy(action).float().to(self.device)reward = torch.from_numpy(reward).float().to(self.device)transition = torch.cat([obs, action, reward], dim=-1).to(self.device)if self.agent.encoder.context is None:self.agent.encoder.context = transitionelse:self.agent.encoder.context = torch.cat([self.agent.encoder.context, transition], dim=1).to(self.device,)
2023-03-22-23-00