元强化学习 PEARL 项目解读
创始人
2025-06-01 18:58:57
0

文章目录

  • PEARL 项目解读
    • 1. 文件总概览
    • 2. 代码流程 pearl_trainer.py
    • 3. 代码流程 meta_learner.meta_train()
    • 4. 代码流程 def collect_train_data()
    • 5. 编码器类 class MLPEncoder(FlattenMLP)
    • 6. 采样器类 class Sampler
    • 7. 代码流程 def train_model()
    • 8. 代码流程 def meta_test()

PEARL 项目解读

因为自己的工作需要,需要跑一下元强化学习的 PEARL (Efficient Off-policy Meta-learning via Probabilistic Context Variables) 代码做一些对比实验。

最原始的代码(https://github.com/katerakelly/oyster)。但是这个代码下载下来,环境配置比较困难。

搜索 Github 上看到这个项目(链接:https://github.com/dongminlee94/meta-learning-for-everyone),里面介绍了 MAML、RL2 和 PEARL。阅读项目介绍,感觉比较科普,因此就以这个项目做解读。理解里面的代码细节,方便做复现~

1. 文件总概览

进入到 PEARL 文件夹下,输入指令 tree 即可大致了解里面的文件包含情况~

.
├── algorithm
│   ├── buffers.py
│   ├── meta_learner.py
│   ├── networks.py
│   ├── sac.py
│   └── sampler.py
├── configs
│   ├── dir_target_config.yaml
│   ├── experiment_config.yaml
│   └── vel_target_config.yaml
└── pearl_trainer.py

algorithm 文件夹下面是基本的组建:经验池 buffers.py,元学习外环框架 meta_learner.py,模型网络 networks.py,内环框架 sac.py,以及最后的批采样器 sampler.py

configs 下面是一些实验的配置文件:experiment_config.yaml 总的配置文件,dir_target_config.yaml 控制方向实验配置文件,vel_target_config.yaml 控制速度配置文件。

pearl_trainer.py 是整个训练脚本,运行时候直接运行这个文件即可~

2. 代码流程 pearl_trainer.py

导入 experiment_config.yaml 文件。

    with open(os.path.join("configs", "experiment_config.yaml"), "r", encoding="utf-8") as file:experiment_config: Dict[str, Any] = yaml.load(file, Loader=yaml.FullLoader)

导入具体某个实验的配置数据,experiment_config.yaml 文件中:env_name: "dir",因此导入的是方向实验文件,也就是 dir_target_config.yaml

# 목표 보상 설정에 대한 하이퍼파라미터들 불러오기with open(os.path.join("configs", experiment_config["env_name"] + "_target_config.yaml"),"r",encoding="utf-8",) as file:env_target_config: Dict[str, Any] = yaml.load(file, Loader=yaml.FullLoader)

下面是 dir_target_config.yaml 的具体细节。具体变量含义稍后在提及~

# PEARL 방향 목표 보상 환경 설정# 일반 환경 설정
# ----------
# 메타-트레이닝에 대한 태스크들의 수
train_tasks: 2# 메타-태스팅에 대한 태스크들이 수
test_tasks: 2# Latent context의 차원의 수
latent_dim: 5# 은닉 유닛의 차원의 수
hidden_dim: 300# PEARL 환경 설정
# -------------
pearl_params:# 트레이닝에 대한 반복 수num_iterations: 1000# 매 반복에서 수집할 샘플된 태스크의 수num_sample_tasks: 5# 트레이닝 이전에 태스크마다 수집된 샘플 수num_init_samples: 2000# z ~ prior를 이용할 때 태스크마다 수집할 샘플 수num_prior_samples: 1000# 정책 네트워크의 학습에만 사용되고 인코더에는 사용되지 않는# z ~ posterior를 이용할 때 태스크마다 수집할 샘플 수num_posterior_samples: 1000# 하나의 반복마다 취할 메타-그레디언트의 수num_meta_grads: 1500# 메타-배치의 샘플 수meta_batch_size: 4# Context 배치의 샘플 수batch_size: 256# 환경에 대한 최대 스텝 수max_step: 200# 최대 버퍼 사이즈max_buffer_size: 1000000# 조기 중단 조건의 수num_stop_conditions: 3# 조기 중단 조건에서 사용되는 목표 값stop_goal: 1900# SAC 환경 설정
# -----------
sac_params:# 할인률gamma: 0.99# 인코더 손실 함수에서 사용되는 KL divergence에 대한 가중치 값kl_lambda: 0.1# 정책 네트워크와 가치 네트워크의 배치에서 사용되는 샘플 수batch_size: 256# 행동 가치 함수 네트워크의 학습률qf_lr: 0.0003# 인코더 네트워크의 학습률encoder_lr: 0.0003# 정책 네트워크의 학습률policy_lr: 0.0003

env 实例化环境,里面初始化了元训练任务和元测试任务的目标。

tasks 返回的是一个列表,输出的是一组目标不同的任务的编号。

# 멀티-태스크 환경과 샘플 태스크들 생성
env: HalfCheetahEnv = ENVS["cheetah-" + experiment_config["env_name"]](num_tasks=env_target_config["train_tasks"] + env_target_config["test_tasks"],
)
tasks: List[int] = env.get_all_task_idx()

读取配置文件的随机数信息,给必要的库设置随机数。

# 랜덤 시드 값 설정
env.reset(seed=experiment_config["seed"])
np.random.seed(experiment_config["seed"])
torch.manual_seed(experiment_config["seed"])

根据实例化的环境 env,指出环境中的状态空间、动作空间和隐藏层维度。

隐藏层维度就是 dir_target_config.yaml 文件的 hidden_dim 变量,数值是300~

observ_dim: int = env.observation_space.shape[0]
action_dim: int = env.action_space.shape[0]
hidden_dim: int = env_target_config["hidden_dim"]

根据配置文件指定显卡。

device: torch.device = (torch.device("cuda", index=experiment_config["gpu_index"])if torch.cuda.is_available()else torch.device("cpu"))

实例化内环智能体 SAC。实例化参数里面,根据配置信息读取数据。上下文变量的维度:latent_dim 是5,隐藏层的维度:hidden_dim 是300。编码器的输入是状态转移信息,观测维度 + 动作维度 + 奖励值(维度是1)。输出是两个上下文变量的维度,实际是一个均值变量,一个是方差变量,生成一个正态分布。最后再从配置文件中传入SAC的其他参数~

agent = SAC(observ_dim=observ_dim,action_dim=action_dim,latent_dim=env_target_config["latent_dim"],hidden_dim=hidden_dim,encoder_input_dim=observ_dim + action_dim + 1,encoder_output_dim=env_target_config["latent_dim"] * 2,device=device,**env_target_config["sac_params"],)

实例化外环训练器 MetaLearner。输入的是元训练集 train_tasks 的任务和元测试集 test_tasks 的任务。还有就是保存和载入的模型文件的断电这些~

meta_learner = MetaLearner(env=env,env_name=experiment_config["env_name"],agent=agent,observ_dim=observ_dim,action_dim=action_dim,train_tasks=tasks[: env_target_config["train_tasks"]],test_tasks=tasks[-env_target_config["test_tasks"] :],save_exp_name=experiment_config["save_exp_name"],save_file_name=experiment_config["save_file_name"],load_exp_name=experiment_config["load_exp_name"],load_file_name=experiment_config["load_file_name"],load_ckpt_num=experiment_config["load_ckpt_num"],device=device,**env_target_config["pearl_params"],)

最后开始进行元训练~

meta_learner.meta_train()

3. 代码流程 meta_learner.meta_train()

元训练的主要代码就在这里了。

total_start_timestart_time 记录的是总的训练时间和每次迭代开始的时间。

self.num_iterations 就是配置文件的 num_iterations ,数值是1000。

条件块if iteration == 0 :代码在第0次迭代搜集用于训练和验证的状态转移数据。用循环变量 index 代表训练集任务 self.train_tasks 的下标,用循环变量 index 指引每个任务,对每个任务做初始化操作,然后开始搜集状态转移数据。

self.collect_train_data() 的作用是为每个任务收集若干完整的轨迹作为样本,加入到强化学习经验池子和编码器经验池子当中,这里不进行后验推断。搜集的方法 self.collect_train_data 在下一节~

代码首先遍历 self.num_sample_tasks 次,从训练任务中抽取一个任务环境,重置并清空里面的编码器经验池子。如果 self.num_prior_samples 大于0,用先验分布抽样得到的隐藏层变量然后获得一批数据,放入经验池子,此时不进行后验推断。随后,如果 self.num_prior_samples 大于0,用后验分布抽样得到的隐藏层变量然后获得一批数据,不放入经验池子,此时进行后验推断。

之后开始进行元梯度更新。在元梯度迭代次数 self.num_meta_grads 内,在 self.train_tasks 内采样self.meta_batch_size 的下标,用于指带这么多的任务。清楚这些任务的隐藏层变量z,赋值标准正态分布的采样,随后采样上下文 context_batch 和状态转移 transition_batch 数据。

随后执行 self.agent.train_model() 方法进行元训练,然后再调用 self.meta_test() 做元测试。对于早结束的情况做了一些异常的提示。

元训练过程 self.agent.train_model()7. 代码流程 def train_model()

元测试过程 self.meta_test()8. 代码流程 def meta_test()

def meta_train(self) -> None:# 메타-트레이닝total_start_time: float = time.time()for iteration in range(self.num_iterations):start_time: float = time.time()# 첫번째 반복단계에 한해 모든 메타-트레이닝 태스크에 대한 경로를 수집하여 리플레이 버퍼에 저장if iteration == 0:print("Collecting initial pool of data for train and eval")for index in tqdm(self.train_tasks):self.env.reset_task(index)self.collect_train_data(task_index=index,max_samples=self.num_init_samples,update_posterior=False,add_to_enc_buffer=True,)print(f"\n=============== Iteration {iteration} ===============")# 임의의 메타 트레이닝 태스크에 대한 새로운 경로를 버퍼에 저장for i in range(self.num_sample_tasks):index = np.random.randint(len(self.train_tasks))self.env.reset_task(index)self.encoder_replay_buffer.task_buffers[index].clear()# 샘플된 z ~ prior r(z)에 대한 경로 수집if self.num_prior_samples > 0:print(f"[{i + 1}/{self.num_sample_tasks}] collecting samples with prior")self.collect_train_data(task_index=index,max_samples=self.num_prior_samples,update_posterior=False,add_to_enc_buffer=True,)# 인코더는 prior r(z)로 생성된 경로 데이터만을 사용하여 학습되나,# RL 정책의 학습에는 z ~ posterior q(z|c)로 생성된 경로도 사용if self.num_posterior_samples > 0:print(f"[{i + 1}/{self.num_sample_tasks}] collecting samples with posterior")self.collect_train_data(task_index=index,max_samples=self.num_posterior_samples,update_posterior=True,add_to_enc_buffer=False,)# 샘플된 메타-배치 태스크들의 경로 데이터로 네트워크 업데이트print(f"Start meta-gradient updates of iteration {iteration}")for i in range(self.num_meta_grads):indices: np.ndarray = np.random.choice(self.train_tasks, self.meta_batch_size)# 인코더의 context와 은닉 상태 초기화self.agent.encoder.clear_z(num_tasks=len(indices))# Context 배치 샘플context_batch: torch.Tensor = self.sample_context(indices)# 경로 배치 샘플transition_batch: List[torch.Tensor] = self.sample_transition(indices)# 정책, Q-함수, 인코더 네트워크를 SAC 알고리즘에서 학습log_values: Dict[str, float] = self.agent.train_model(meta_batch_size=self.meta_batch_size,batch_size=self.batch_size,context_batch=context_batch,transition_batch=transition_batch,)# 인코더의 태스크변수 z의 Backpropagation 차단self.agent.encoder.task_z.detach()# 메타-테스트 태스크에서 학습성능 평가self.meta_test(iteration, total_start_time, start_time, log_values)if self.is_early_stopping:print(f"\n==================================================\n"f"The last {self.num_stop_conditions} meta-testing results are {self.dq}.\n"f"And early stopping condition is {self.is_early_stopping}.\n"f"Therefore, meta-training is terminated.",)break

4. 代码流程 def collect_train_data()

这个函数是用来搜集状态转移数据,正如上一节提到的。

输入变量:task_index 任务索引下标,指带具体任务的编号;max_samples :;update_posterior:;add_to_enc_buffer:;

代码里 self.agent.encoder.clear_z() 先把上下文z变量赋值标准正态分布的抽样信息。

编码器类在下一节 5. 编码器类 class MLPEncoder(FlattenMLP) 介绍~

self.agent.policy.is_deterministic = False 表示这部分代码的决策是通过分布采样得到而不是直接输出固定确切值。

采样器用于采样 max_samples 数量的若干条完整的轨迹。

采样器类在 6. 采样器类 class Sampler 介绍~

获得完整的轨迹后,将数据存储在第 task_index 任务的经验池子当中。

如果 add_to_enc_buffer 标记为真,那么就将轨迹放置到编码器经验池子 encoder_replay_buffer 当中;如果 update_posterior 标记为真,那么就从编码器经验池子 encoder_replay_buffer 当中采样数据进行后验推断。

def collect_train_data(self,task_index: int,max_samples: int,update_posterior: bool,add_to_enc_buffer: bool,
) -> None:# 주어진 인덱스 태스크에 대한 경로 데이터 수집self.agent.encoder.clear_z()self.agent.policy.is_deterministic = Falsecur_samples = 0while cur_samples < max_samples:trajs, num_samples = self.sampler.obtain_samples(max_samples=max_samples - cur_samples,update_posterior=update_posterior,accum_context=False,)cur_samples += num_samples# RL 리플레이 버퍼에 수집한 데이터 저장self.rl_replay_buffer.add_trajs(task_index, trajs)if add_to_enc_buffer:# 인코더 리플레이 버퍼에 수집한 데이터 저장self.encoder_replay_buffer.add_trajs(task_index, trajs)if update_posterior:# 샘플한 context에 따른 posterior 업데이트context_batch = self.sample_context(np.array([task_index]))self.agent.encoder.infer_posterior(context_batch)

5. 编码器类 class MLPEncoder(FlattenMLP)

这个是编码器类,在内环训练过程中,需要将状态转移信息用编码器得到上下文变量,用的就是这个类的实例~

初始化输入信息维度 input_dim、输出信息的维度 output_dim、上下文变量维度 latent_dim 和中间层维度 hidden_dim,最后再声明一下显卡配置 device

这个类继承了 FlattenMLP 这个类,在计算时候会调用 FlattenMLP 这个类的 forward 方法,实际上就是加了一层 torch.cat() 操作,把几个张量合并在一起。

阅读源码时候, FlattenMLP 这个类继承了 MLP 这个类。 MLP 这个类定义了神经网路,通过python的继承使用语法,网络的输入层维度就是初始化输入信息维度 input_dim、输出层维度就是输出信息的维度 output_dim,中间层的神经元数就是中间层维度 hidden_dim,默认3个中间层。使用 torch.functional.F.relu 激活函数。 MLP 这个类还定义了基本的前向计算 forward(),与我们一般的神经网络差不多。

回到 MLPEncoder 这个类。初始化了隐藏层变量的均值 self.z_mean 、方差 self.z_varself.task_z 均为 None。执行了一下 self.clear_z() 操作。

接下来细说 self.clear_z() 操作:先初始化先验分布为标准正态分布,标准正态分布的维度 = 上下文变量的维度 × 任务数(默认是1),为每个任务抽样上下文变量做准备。

接下来执行 self.sample_z() 操作:首先采用 torch.unbind() 方法对均值张量和方差张量对第0维度做了切片,然后用 zip 捆绑起来,并给 meanvar 赋值,也就是分别为每个任务附加先验的标准正态分布。依次构建标准正态分布并填充于一个列表 dists 。再对列表里面的每个任务的标准正态分布重采样上下文变量,并堆叠起来成为一个张量。

class MLPEncoder(FlattenMLP):def __init__(self,input_dim: int,output_dim: int,latent_dim: int,hidden_dim: int,device: torch.device,) -> None:super().__init__(input_dim=input_dim, output_dim=output_dim, hidden_dim=hidden_dim)self.output_dim = output_dimself.latent_dim = latent_dimself.device = deviceself.z_mean = Noneself.z_var = Noneself.task_z = Noneself.clear_z()def clear_z(self, num_tasks: int = 1) -> None:# q(z|c)를 prior r(z)로 초기화self.z_mean = torch.zeros(num_tasks, self.latent_dim).to(self.device)self.z_var = torch.ones(num_tasks, self.latent_dim).to(self.device)# Prior r(z)에서 새로운 z를 생성self.sample_z()# 지금까지 모은 context 초기화self.context = Nonedef sample_z(self) -> None:# z ~ r(z) 또는 z ~ q(z|c) 생성dists = []for mean, var in zip(torch.unbind(self.z_mean), torch.unbind(self.z_var)):dist = torch.distributions.Normal(mean, torch.sqrt(var))dists.append(dist)sampled_z = [dist.rsample() for dist in dists]self.task_z = torch.stack(sampled_z).to(self.device)

6. 采样器类 class Sampler

这个类是用来采样轨迹数据的。

首先来看 def rollout 方法,这个方法的作用是获得一条智能体在环境中运动的数据(only one)。初始化了一系列列表,用于记录智能体与环境交互的各个数据。obs 获得环境的状态信息,done 是是否完成的标志,默认是 False。智能体根据状态信息,获得动作信息;然后作用在环境中获得下一个状态、奖励和是否完成的标志。如果 accum_contextTrue,那么就实时更新上下文信息,否则就没有。然后将状态转移数据记录到一个个列表中,更新 obs 进入下一状态。最后返回的是一个字典,记录了这些信息。

接下来看 def update_context 方法。先把状态信息 obs 、动作信息 action 和奖励 reward 改成高维度的浮点型GPU张量, 然后对这些信息做连接得到 transition 上下文变量。如果编码器的上下文信息 self.agent.encoder.context 是空 None 的话,那么得到 transition 上下文变量赋值给self.agent.encoder.context ;否则就是在原来 self.agent.encoder.context 基础上再增加当前状态转移的上下文。

最后看 def obtain_samples 方法。输入的 max_samples 表示最大采样数量、update_posterior 表示更新后验标记和 accum_context 表示是否累积上下文。这个方法的作用是:获得大小为 max_samples 的状态转移数据,这些数据由若干完整的轨迹变量 trajs 组成。然后还采样了隐藏层变量z,最后进行输出。

class Sampler:def __init__(self,env: HalfCheetahEnv,agent: SAC,max_step: int,device: torch.device,) -> None:self.env = envself.agent = agentself.max_step = max_stepself.device = devicedef obtain_samples(self,max_samples: int,update_posterior: bool,accum_context: bool = True,) -> Tuple[List[Dict[str, np.ndarray]], int]:# 최대 샘플량의 수까지 샘플들 얻기trajs = []cur_samples = 0while cur_samples < max_samples:traj = self.rollout(accum_context=accum_context)trajs.append(traj)cur_samples += len(traj["cur_obs"])self.agent.encoder.sample_z()if update_posterior:breakreturn trajs, cur_samplesdef rollout(self, accum_context: bool = True) -> Dict[str, np.ndarray]:# 최대 경로 길이까지 경로 생성_cur_obs = []_actions = []_rewards = []_next_obs = []_dones = []_infos = []obs = self.env.reset()done = Falsecur_step = 0while not (done or cur_step == self.max_step):action = self.agent.get_action(obs)next_obs, reward, done, info = self.env.step(action)# 에이전트의 현재 context 업데이트if accum_context:self.update_context(obs=obs, action=action, reward=np.array([reward]))_cur_obs.append(obs)_actions.append(action)_rewards.append(reward)_next_obs.append(next_obs)_dones.append(done)_infos.append(info["run_cost"])cur_step += 1obs = next_obsreturn dict(cur_obs=np.array(_cur_obs),actions=np.array(_actions),rewards=np.array(_rewards).reshape(-1, 1),next_obs=np.array(_next_obs),dones=np.array(_dones).reshape(-1, 1),infos=np.array(_infos),)def update_context(self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray) -> None:# 현재 context에 하나의 transition 추가obs = obs.reshape((1, 1, *obs.shape))action = action.reshape((1, 1, *action.shape))reward = reward.reshape((1, 1, *reward.shape))obs = torch.from_numpy(obs).float().to(self.device)action = torch.from_numpy(action).float().to(self.device)reward = torch.from_numpy(reward).float().to(self.device)transition = torch.cat([obs, action, reward], dim=-1).to(self.device)if self.agent.encoder.context is None:self.agent.encoder.context = transitionelse:self.agent.encoder.context = torch.cat([self.agent.encoder.context, transition], dim=1).to(self.device,)

7. 代码流程 def train_model()

8. 代码流程 def meta_test()


2023-03-22-23-00

相关内容

热门资讯

最新或2023(历届)学校世界... 世界艾滋病日是为了提高公众对艾滋病的认识,共同对抗艾滋病在全球传播,世界卫生组织1988年组织召开“...
最新或2023(历届)世界艾滋...   最新或2023(历届)12月1日是第28个“世界艾滋病日”。今年活动主题仍是“行动起来,向‘零’...
最新或2023(历届)世界艾滋...   最新或2023(历届)世界艾滋病日是在12月1日,星期二。下面是小编给大家整理的关于世界艾滋病日...
最新或2023(历届)关于世界...   在第28个世界艾滋病日来临之际,小编精心给大家整理了关于世界艾滋病日的黑板报资料,希望对大家有所...
最新或2023(历届)世界艾滋...   世界卫生组织将历年年12月1日定为世界艾滋病日,是因为第一个艾滋病病例是在1981年此日诊断出来...
最新或2023(历届)世界艾滋...  世界艾滋病日是全世界同艾滋病作斗争日。1988年作为全世界防治艾滋病年;每年12月1日作为世界艾滋...
最新或2023(历届)世界艾滋...  12月1号是世界艾滋病日,小编在此给大家整理了关于世界艾滋病日的黑板报资料:世界艾滋病日的历史起源...
最新或2023(历届)世界艾滋...  为提高人们对艾滋病的认识,世界卫生组织于1988年1月将每年的12月1日定为世界艾滋病日,号召世界...
最新或2023(历届)世界艾滋...   为提高人们对艾滋病的认识,世界卫生组织于1988年1月将每年的12月1日定为世界艾滋病日,号召世...
最新或2023(历届)世界艾滋...   12月1日是世界艾滋病日,这天旨在提高公众对HIV病毒引起的艾滋病在全球传播的意识。小编在此给大...
最新或2023(历届)关于一二...   最新或2023(历届)12月9日是纪念一二九爱国运动的八十周年,下面是小编给大家整理的关于一二九...
最新或2023(历届)感恩节黑...  每年11月的第四个星期四是美国传统的感恩节。在美国人的心目中,感恩节的重要性仅次于圣诞节。这是一个...
最新或2023(历届)感恩节黑...   下个星期四就是感恩节了,大家都准备好感恩节的黑板报资料了吗?一起来参考下小编给大家整理的具体内容...
最新或2023(历届)中小学纪...  一二九运动是指1935年12月9日发生在北平的一次伟大的抗日救亡运动。最新或2023(历届)的一二...
最新或2023(历届)班级纪念...  中国共产党领导的一次学生爱国运动。1935年12月9日,北平(今北京)学生数千人在中国共产党领导下...
最新或2023(历届)感恩节黑...   感恩节是美国国定假日中最地道、最美国式的节日,而且它和早期美国历史最为密切相关。感恩节就在每年1...
最新或2023(历届)感恩节黑...   每年11月的第四个星期四为感恩节(英语:Thanksgiving Day)。这是美国Day)。这...
最新或2023(历届)漂亮感恩...   感恩节(Thanksgiving Day)是美国人民独创的一个古老节日,也是美国人合家欢聚的节日...
最新或2023(历届)感恩节中...   感恩节(Thanksgiving Day)是美国人民独创的一个古老节日,也是美国人合家欢聚的节日...
最新或2023(历届)最新感恩... 11月的第四个星期四是感恩节。感恩节是美国人民独创的一个古老节日,也是美国人合家欢聚的节日,因此美国...