21- 神经网络模型_超参数搜索 (TensorFlow系列) (深度学习)
创始人
2025-05-31 04:08:15
0

知识要点

  • fetch_california_housing:加利福尼亚的房价数据,总计20640个样本,每个样本8个属性表示,以及房价作为target

  • 超参数搜索的方式: 网格搜索, 随机搜索, 遗传算法搜索, 启发式搜索

  • 超参数训练后用: gv.estimator调取最佳模型

  • 函数式添加神经网络:

    • model.add(keras.layers.Dense(layer_size, activation = 'relu'))

    • model.compile(loss = 'mse', optimizer = optimizer)    # optimizer = keras.optimizers.SGD (learning_rate)

    • sklearn_model = KerasRegressor(build_fn = build_model)

from tensorflow.keras.wrappers.scikit_learn import KerasRegressor   # 回归神经网络
# 搜索最佳学习率
def build_model(hidden_layers = 1, layer_size = 30, learning_rate = 3e-3):model = keras.models.Sequential()model.add(keras.layers.Dense(layer_size, activation = 'relu', input_shape = x_train.shape[1:]))for _ in range(hidden_layers - 1):model.add(keras.layers.Dense(layer_size, activation = 'relu'))model.add(keras.layers.Dense(1))optimizer = keras.optimizers.SGD(learning_rate)model.compile(loss = 'mse', optimizer = optimizer)# model.summary()return model
sklearn_model = KerasRegressor(build_fn = build_model)
  • callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]  # 回调函数设置

  • gv = GridSearchCV(sklearn_model, param_grid = params, n_jobs = 1, cv= 5,verbose = 1)  # 找最佳参数

  • gv.fit(x_train_scaled, y_train)


1 导包

from tensorflow import keras
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
cpu=tf.config.list_physical_devices("CPU")
tf.config.set_visible_devices(cpu)
print(tf.config.list_logical_devices())

2 导入数据

from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housinghousing = fetch_california_housing()
x_train_all, x_test, y_train_all, y_test = train_test_split(housing.data,housing.target,random_state= 7)
x_train, x_valid, y_train, y_valid = train_test_split(x_train_all, y_train_all,random_state = 11)

3 标准化处理数据

from sklearn.preprocessing import StandardScaler, MinMaxScalerscaler =StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)

4 函数式定义模型

from tensorflow.keras.wrappers.scikit_learn import KerasRegressor   # 回归神经网络
# 搜索最佳学习率
def build_model(hidden_layers = 1, layer_size = 30, learning_rate = 3e-3):model = keras.models.Sequential()model.add(keras.layers.Dense(layer_size, activation = 'relu', input_shape = x_train.shape[1:]))for _ in range(hidden_layers - 1):model.add(keras.layers.Dense(layer_size, activation = 'relu'))model.add(keras.layers.Dense(1))optimizer = keras.optimizers.SGD(learning_rate)model.compile(loss = 'mse', optimizer = optimizer)# model.summary()return model
sklearn_model = KerasRegressor(build_fn = build_model)

 

5 模型训练

callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]
history = sklearn_model.fit(x_train_scaled, y_train, epochs = 10,validation_data = (x_valid_scaled, y_valid), callbacks = callbacks)

 6 超参数搜索

超参数搜索的方式:

  • 网格搜索

    • 定义n维方格

    • 每个方格对应一组超参数

    • 一组一组参数尝试

  • 随机搜索

  • 遗传算法搜索

    • 对自然界的模拟

    • A: 初始化候选参数集合 --> 训练---> 得到模型指标作为生存概率

    • B: 选择 --> 交叉--> 变异 --> 产生下一代集合

    • C: 重新到A, 循环.

  • 启发式搜索

    • 研究热点-- AutoML的一部分

    • 使用循环神经网络来生成参数

    • 使用强化学习来进行反馈, 使用模型来训练生成参数.

# 使用sklearn 的网格搜索, 或者随机搜索
from sklearn.model_selection import GridSearchCV, RandomizedSearchCVparams = {'learning_rate' : [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2],'hidden_layers': [2, 3, 4, 5], 'layer_size': [20, 60, 100]}gv = GridSearchCV(sklearn_model, param_grid = params, n_jobs = 1, cv= 5,verbose = 1)
gv.fit(x_train_scaled, y_train)
  • 输出最佳参数
# 最佳得分
print(gv.best_score_)    # -0.47164334654808043
# 最佳参数
print(gv.best_params_)  # {'hidden_layers': 5,'layer_size': 100,'learning_rate':0.01}
# 最佳模型
print(gv.estimator)
''''''
gv.score

7 最佳参数建模

model = keras.models.Sequential()
model.add(keras.layers.Dense(100, activation = 'relu', input_shape = x_train.shape[1:]))
for _ in range(4):model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(1))
optimizer = keras.optimizers.SGD(0.01)
model.compile(loss = 'mse', optimizer = optimizer)
model.summary()

callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]
history = model.fit(x_train_scaled, y_train, epochs = 10,validation_data = (x_valid_scaled, y_valid), callbacks = callbacks)

 8 手动实现超参数搜索

  • 根据参数进行多次模型的训练, 然后记录 loss
# 搜索最佳学习率
learning_rates = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2]
histories = []
for lr in learning_rates:model = keras.models.Sequential([keras.layers.Dense(30, activation = 'relu', input_shape = x_train.shape[1:]),keras.layers.Dense(1)])optimizer = keras.optimizers.SGD(lr)model.compile(loss = 'mse', optimizer = optimizer, metrics = ['mse'])callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-2)]history = model.fit(x_train_scaled, y_train, validation_data = (x_valid_scaled, y_valid), epochs = 100, callbacks = callbacks)histories.append(history)

 

# 画图
import pandas as pd
def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize = (8, 5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()for lr, history in zip(learning_rates, histories): print(lr)plot_learning_curves(history)   

相关内容

热门资讯

SQL注入之HTTP请求头注入 Ps: 先做实验,在有操作的基础上理解原理会更清晰更深入。 一、实验 s...
最新或2023(历届)纪律教育... 【最新或2023(历届)纪律教育学习月活动总结1】  根据市委办《中共肇庆市委办公室转发〈市纪委关于...
开展最新或2023(历届)纪律... 【开展最新或2023(历届)纪律教育学习月活动总结1】  自我局“纪律教育学习月”活动启动以来,县局...
最新或2023(历届)扶贫助学... 扶贫助学活动总结【1】  “让我们荡起双桨,小船儿推开波浪…….”,英德市青塘镇中心小学的校园内传来...
开展最新或2023(历届)征信... 【开展最新或2023(历届)征信宣传活动总结1】  为进一步拓宽征信知识宣传的广度和深度,提升征信服...
Go项目(rocketmq) 文章目录简介场景技术选型rocketmq概念消息类型go-client集成CreateOrderin...
银行开展最新或2023(历届)... 【银行开展最新或2023(历届)征信宣传活动总结1】  6月14日是全国第十个“信用记录关爱日”,今...
【计算机视觉】经典的图卷积网络... 【计算机视觉】经典的图卷积网络框架(LeNet、AlexNet、VGGNet、Ince...
最新或2023(历届)全国“安... 【最新或2023(历届)全国“安全生产月”宣传咨询日活动简报1】  6月16日是全国第16个“安全生...
节能有我绿色共享系列活动总结 ... 【节能有我绿色共享系列活动总结1】  6月13日是全国低碳日,连城供电公司积极组织节能宣传周活动,在...
最新或2023(历届)白山中学... 6月29日放暑假,8月31日、9月1日学生到校报到,9月2日正式上课一、义务教育阶段中学、幼儿园6月...
最新或2023(历届)通化中学... 暑期是小学生安全事件频发的季节,分析近年来发生的暑期安全事故可以发现,安全意识淡漠、自救互救知识匮乏...
最新或2023(历届)全国土地... 最新或2023(历届)全国土地日宣传活动总结【1】  为做好第XX个全国“土地日”的宣传活动,使全社...
UWB芯片DW3000介绍二P... PHY报头:标准数据帧长度DW3000芯片的PHY头(PHR)使用前面文章调制方案定义的BPM/BP...
最新或2023(历届)齐齐哈尔... 据了解,今年我市义务教育学校暑假时间为7月20日至9月1日,普通高中暑假时间为7月20日至8月20日...
最新或2023(历届)吉林省中... 从吉林省教育厅获悉,近日省教育厅下发通知,最新或2023(历届)—最新或2023(历届)学年度吉林省...
最新或2023(历届)吉林中学... 从吉林省教育厅获悉,近日省教育厅下发通知,最新或2023(历届)—最新或2023(历届)学年度吉林省...
最新或2023(历届)宣城中学... 放假时间:  小学、初中、普通高中从最新或2023(历届)7月开始放暑假,8月31日结束,9月1日开...
最新或2023(历届)长春中学... 23日,记者从长春市教育局获悉,最新或2023(历届)—最新或2023(历届)学年度全市各学校暑假放...
第36篇:Java 泛型作用深... 泛型在日常开发当中,使用的场景非常多。尤其是在很多底层API、中间件等技术中有大量使用,如常见的RP...