代码如下：
"""Train a small dense network on MNIST, log to TensorBoard, and plot history.

Steps: load and normalize MNIST, build and compile a Sequential model,
save/reload it, train with TensorBoard and EarlyStopping callbacks,
save/reload the trained model, then plot loss and accuracy per epoch.
"""
import datetime
import os

import matplotlib.pyplot as plt
import tensorflow as tf

# Load MNIST and scale pixel values from [0, 255] down to [0, 1].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Build the model: flatten each 28x28 image, one 512-unit ReLU hidden layer
# with dropout, and a 10-unit softmax output layer.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Print the network structure. model.summary() prints the table itself and
# returns None, so wrapping it in print() would emit an extra "None" line.
model.summary()
# Sample output:
# Model: "sequential"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #
# =================================================================
# flatten (Flatten)            (None, 784)               0
# _________________________________________________________________
# dense (Dense)                (None, 512)               401920
# _________________________________________________________________
# dropout (Dropout)            (None, 512)               0
# _________________________________________________________________
# dense_1 (Dense)              (None, 10)                5130
# =================================================================
# Total params: 407,050
# Trainable params: 407,050
# Non-trainable params: 0
# _________________________________________________________________

# Compile the model. y_train/y_test are integer class labels, which is what
# the "sparse" categorical cross-entropy loss expects.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# model.save() does not create missing parent directories.
os.makedirs("data", exist_ok=True)

# Save and reload the compiled but untrained model. Note: the saved file holds
# the architecture, the (still randomly initialized) weights and the compile
# configuration, not just the network structure.
model.save("data/model2.h5")
model = tf.keras.models.load_model("data/model2.h5")

# One TensorBoard log directory per run, named after the local start time.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                      histogram_freq=1)
# Stop training once val_loss has not improved for 2 epochs. There must be no
# trailing comma here: it would bind a 1-tuple instead of the callback object.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(patience=2,
                                                           monitor='val_loss')

# Train the model.
# NOTE(review): the test split doubles as validation data and drives early
# stopping, so the reported "val" metrics are not an untouched held-out score.
history = model.fit(x=x_train, y=y_train,
                    epochs=5,
                    validation_data=(x_test, y_test),
                    callbacks=[tensorboard_callback, early_stopping_callback])

# Save the architecture together with the trained parameters, then reload.
model.save("data/model2.h5")
model = tf.keras.models.load_model("data/model2.h5")

# Per-epoch train/validation loss and accuracy curves.
history_dict = history.history
print(history_dict.keys())
# output: dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']
# EarlyStopping may end training before the requested epoch count, so the
# x-axis is sized from the recorded history instead.
epochs = range(1, len(acc) + 1)

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(epochs, loss, 'bo', label='train loss')
ax_loss.plot(epochs, val_loss, 'b', label='val loss')
ax_loss.set_title('Train and val loss')
ax_loss.set_xlabel('Epochs')
ax_loss.set_ylabel('loss')
ax_loss.legend()

ax_acc.plot(epochs, acc, 'ro', label='train accuracy')
ax_acc.plot(epochs, val_acc, 'r', label='val accuracy')
ax_acc.set_title('Train and val accuracy')
ax_acc.set_xlabel('Epochs')
ax_acc.set_ylabel('accuracy')
ax_acc.legend()

fig.tight_layout()
plt.show()
（持续补充中）