TensorFlow 2.0 入门

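以下代码使用 TensorFlow 2.0（tf.keras）在 MNIST 手写数字数据集上训练一个简单的全连接分类网络，并演示模型的保存与加载、TensorBoard 与 EarlyStopping 回调，以及训练曲线的绘制。代码如下：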

import datetime
import os

import matplotlib.pyplot as plt
import tensorflow as tf

# Load MNIST as ((train images, train labels), (test images, test labels)).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Divide pixel values by 255 so inputs lie in [0, 1]
# (assumes 8-bit pixel values, the standard MNIST encoding).
x_train = x_train / 255.0
x_test = x_test / 255.0

# Build the model layer by layer:
# 28x28 input -> flatten -> 512-unit ReLU -> 20% dropout -> 10-way softmax.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(512, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Print the network architecture.
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() only adds a stray "None" line after the table; call it directly.
model.summary()
# Example output:
# Model: "sequential"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #
# =================================================================
# flatten (Flatten)            (None, 784)               0
# _________________________________________________________________
# dense (Dense)                (None, 512)               401920
# _________________________________________________________________
# dropout (Dropout)            (None, 512)               0
# _________________________________________________________________
# dense_1 (Dense)              (None, 10)                5130
# =================================================================
# Total params: 407,050
# Trainable params: 407,050
# Non-trainable params: 0
# _________________________________________________________________


# Compile the model: set the optimizer, the loss and the reported metrics.
# sparse_categorical_crossentropy takes integer class labels (not one-hot).
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Save the compiled but still untrained model, then load it back.
# model.save() writes the whole model, not just the network structure: the
# architecture, the current (randomly initialised) weights and the compile
# configuration. Saving before fit() only demonstrates the save/load round trip.
# The HDF5 writer does not create missing parent directories, so make sure
# "data/" exists first.
os.makedirs("data", exist_ok=True)
model.save("data/model2.h5")
model = tf.keras.models.load_model("data/model2.h5")

# Write TensorBoard logs to a timestamped sub-directory so separate runs do not
# overwrite each other; histogram_freq=1 computes weight histograms every epoch.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
# Stop training once val_loss has not improved for 2 consecutive epochs.
# No trailing comma here: with one, this name would hold a 1-tuple rather
# than the callback object itself.
earlyStopping_callback = tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss')

# Train for up to 5 epochs, evaluating on (x_test, y_test) after each epoch.
# NOTE(review): the test split doubles as validation data and drives
# EarlyStopping, so the final test metrics are not from an untouched holdout.
history = model.fit(x=x_train,
                    y=y_train,
                    epochs=5,
                    validation_data=(x_test, y_test),
                    callbacks=[tensorboard_callback, earlyStopping_callback])

# Save the network structure together with the trained parameters to the same
# HDF5 file, then load it back.
tf.keras.models.save_model(model, "data/model2.h5")
model = tf.keras.models.load_model("data/model2.h5")

# Plot the training and validation curves (loss and accuracy) per epoch.
history_dict = history.history
print(history_dict.keys())
# Example output: dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])

acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']
# EarlyStopping may end training before the requested 5 epochs, so the x-axis
# comes from the number of recorded values rather than being hard-coded.
epochs = range(1, len(acc) + 1)

plt.figure(figsize=(10, 4))

# Left panel: loss.
plt.subplot(1, 2, 1)
plt.plot(epochs, loss, 'bo', label='train loss')
plt.plot(epochs, val_loss, 'b', label='val loss')
plt.title('Train and val loss')
plt.xlabel('Epochs')
plt.ylabel('loss')
plt.legend()

# Right panel: accuracy (extracted above but previously never plotted).
plt.subplot(1, 2, 2)
plt.plot(epochs, acc, 'ro', label='train accuracy')
plt.plot(epochs, val_acc, 'r', label='val accuracy')
plt.title('Train and val accuracy')
plt.xlabel('Epochs')
plt.ylabel('accuracy')
plt.legend()

plt.tight_layout()
plt.show()


（持续补充中）

个人资料
hadoop迷
等级:6
文章:30篇
访问:2.2w
排名: 13
上一篇: ROC和AUC介绍以及如何计算AUC
下一篇:tensorflow2.0模型保存为json格式
标签: tensorflow、面试题
隐藏