博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
drop解决过拟合的情况
阅读量:5242 次
发布时间:2019-06-14

本文共 3075 字,大约阅读时间需要 10 分钟。

用到的训练数据集:sklearn数据集

可视化工具:tensorboard,这儿记录了loss值(预测值与真实值的差值),通过loss值可以判断训练的结果与真实数据是否吻合

 

过拟合:训练过程中为了追求完美而导致问题

过拟合的情况:蓝线为实际情况,在误差为10的区间,他能够表示每条数据。

       橙线为训练情况,为了追求0误差,他将每条数据都关联起来,但是如果新增一些点(+),他就不能去表示新增的点了

 

 

训练得到的值和实际测试得到的值相比,训练得到的loss更小,但它与实际不合,并不是loss值越小就越好

drop处理过拟合后:

代码:

import tensorflow as tffrom sklearn.datasets import load_digitsfrom sklearn.cross_validation import train_test_splitfrom sklearn.preprocessing import LabelBinarizer# load datadigits = load_digits()X = digits.datay = digits.targety = LabelBinarizer().fit_transform(y)   # 转换格式X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3)def add_layer(inputs, in_size, out_size, layer_name, active_function=None):    """    :param inputs:    :param in_size: 行    :param out_size: 列 , [行, 列] =矩阵    :param active_function:    :return:    """    with tf.name_scope('layer'):        with tf.name_scope('weights'):            W = tf.Variable(tf.random_normal([in_size, out_size]), name='W')  #        with tf.name_scope('bias'):            b = tf.Variable(tf.zeros([1, out_size]) + 0.1)  # b是一行数据,对应out_size列个数据        with tf.name_scope('Wx_plus_b'):            Wx_plus_b = tf.matmul(inputs, W) + b        Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob=keep_prob)        if active_function is None:            outputs = Wx_plus_b        else:            outputs = active_function(Wx_plus_b)        tf.summary.histogram(layer_name + '/outputs', outputs)  # 1.2.记录outputs值,数据直方图        return outputs# define placeholder for inputs to networkkeep_prob = tf.placeholder(tf.float32)  # 不被dropout的数量xs = tf.placeholder(tf.float32, [None, 64])  # 8*8ys = tf.placeholder(tf.float32, [None, 10])# add output layerl1 = add_layer(xs, 64, 50, 'l1', active_function=tf.nn.tanh)prediction = add_layer(l1, 50, 10, 'l2', active_function=tf.nn.softmax)# the loss between prediction and reallycross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction), reduction_indices=[1]))tf.summary.scalar('loss', cross_entropy)  # 字符串类型的标量张量,包含一个Summaryprotobuf  1.1记录标量(展示到直方图中 1.2 )# trainingtrain_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)sess = tf.Session()merged = tf.summary.merge_all()  # 2.把所有summary节点整合在一起,只需run一次,这儿只有cross_entropysess.run(tf.initialize_all_variables())train_writer = tf.summary.FileWriter('log/train', sess.graph)  # 3.写入test_writer = tf.summary.FileWriter('log/test', sess.graph)  # cmd cd到log目录下,启动 tensorboard --logdir=log\# start trainingfor i in range(500):    sess.run(train_step, feed_dict={xs: X_train, ys: y_train, keep_prob: 0.5})  # keep_prob训练时保留50%, 当这儿为1时,代表不drop任何数据,(没处理过拟合问题)    if i % 50 == 0:        # 4. record loss        train_result = sess.run(merged, feed_dict={xs: X_train, ys: y_train, keep_prob: 1})  # tensorboard记录保留100%的数据        test_result = sess.run(merged, feed_dict={xs: X_test, ys: y_test, keep_prob: 1})        train_writer.add_summary(train_result, i)        test_writer.add_summary(test_result, i)print("Record Finished !!!")

 

posted on
2018-06-22 14:43 阅读(
...) 评论(
...)

转载于:https://www.cnblogs.com/tangpg/p/9213375.html

你可能感兴趣的文章
C#编程时应注意的性能处理
查看>>
Fragment
查看>>
比较安全的获取站点更目录
查看>>
苹果开发者账号那些事儿(二)
查看>>
使用C#交互快速生成代码!
查看>>
UVA11374 Airport Express
查看>>
P1373 小a和uim之大逃离 四维dp,维护差值
查看>>
NOIP2015 运输计划 树上差分+树剖
查看>>
P3950 部落冲突 树链剖分
查看>>
读书汇总贴
查看>>
微信小程序 movable-view组件应用:可拖动悬浮框_返回首页
查看>>
MPT树详解
查看>>
空间分析开源库GEOS
查看>>
RQNOJ八月赛
查看>>
前端各种mate积累
查看>>
jQuery 1.7 发布了
查看>>
Python(软件目录结构规范)
查看>>
Windows多线程入门のCreateThread与_beginthreadex本质区别(转)
查看>>
Nginx配置文件(nginx.conf)配置详解1
查看>>
linux php编译安装
查看>>