Một tùy chọn khác để thực hiện việc này với contrib.learn
thư viện như sau, dựa trên hướng dẫn Deep MNIST trên trang web Tensorflow. Đầu tiên, giả sử bạn đã nhập các thư viện có liên quan (chẳng hạn như import tensorflow.contrib.layers as layers
), bạn có thể xác định mạng theo một phương pháp riêng:
def easier_network(x, reg):
""" A network based on tf.contrib.learn, with input `x`. """
with tf.variable_scope('EasyNet'):
out = layers.flatten(x)
out = layers.fully_connected(out,
num_outputs=200,
weights_initializer = layers.xavier_initializer(uniform=True),
weights_regularizer = layers.l2_regularizer(scale=reg),
activation_fn = tf.nn.tanh)
out = layers.fully_connected(out,
num_outputs=200,
weights_initializer = layers.xavier_initializer(uniform=True),
weights_regularizer = layers.l2_regularizer(scale=reg),
activation_fn = tf.nn.tanh)
out = layers.fully_connected(out,
num_outputs=10, # Because there are ten digits!
weights_initializer = layers.xavier_initializer(uniform=True),
weights_regularizer = layers.l2_regularizer(scale=reg),
activation_fn = None)
return out
Sau đó, trong một phương pháp chính, bạn có thể sử dụng đoạn mã sau:
def main(_):
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# Make a network with regularization
y_conv = easier_network(x, FLAGS.regu)
weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'EasyNet')
print("")
for w in weights:
shp = w.get_shape().as_list()
print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
print("")
reg_ws = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'EasyNet')
for w in reg_ws:
shp = w.get_shape().as_list()
print("- {} shape:{} size:{}".format(w.name, shp, np.prod(shp)))
print("")
# Make the loss function `loss_fn` with regularization.
cross_entropy = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
loss_fn = cross_entropy + tf.reduce_sum(reg_ws)
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss_fn)
Để làm được điều này, bạn cần làm theo hướng dẫn MNIST mà tôi đã liên kết trước đó và nhập các thư viện liên quan, nhưng đó là một bài tập hay để học TensorFlow và thật dễ dàng để xem việc chính quy ảnh hưởng đến đầu ra như thế nào. Nếu bạn áp dụng quy định làm đối số, bạn có thể thấy những điều sau:
- EasyNet/fully_connected/weights:0 shape:[784, 200] size:156800
- EasyNet/fully_connected/biases:0 shape:[200] size:200
- EasyNet/fully_connected_1/weights:0 shape:[200, 200] size:40000
- EasyNet/fully_connected_1/biases:0 shape:[200] size:200
- EasyNet/fully_connected_2/weights:0 shape:[200, 10] size:2000
- EasyNet/fully_connected_2/biases:0 shape:[10] size:10
- EasyNet/fully_connected/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
- EasyNet/fully_connected_1/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
- EasyNet/fully_connected_2/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0
Lưu ý rằng phần quy định cung cấp cho bạn ba mục, dựa trên các mục có sẵn.
Với các mức quy định 0, 0,0001, 0,01 và 1,0, tôi nhận được các giá trị độ chính xác thử nghiệm tương ứng là 0,9468, 0,9476, 0,9183 và 0,1135, cho thấy sự nguy hiểm của các điều khoản chính quy cao.
S = tf.get_variable(name='S', regularizer=tf.contrib.layers.l2_regularizer )
?