TensorFlow 2.0 手写数字识别（MNIST）：训练与预测
代码如下：
# MNIST handwritten-digit recognition with TensorFlow 2.x: train a small fully
# connected classifier, save it, and use the saved model to predict test digits.
import sys
from pathlib import Path

import numpy as np
import tensorflow as tf

# Where the trained model is stored; the ".h5" suffix selects the HDF5 format.
MODEL_PATH = "data/model2.h5"

# Load data: 28x28 digit images with integer class labels 0-9.
# NOTE(review): pixels are fed unscaled (0-255). Scaling to [0, 1] (x / 255.0)
# usually trains better, but train() and predict() must change together and
# any model already saved at MODEL_PATH would have to be retrained.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)


def train(model_path=MODEL_PATH, epochs=10, batch_size=64):
    """Build, train and evaluate the classifier, then save it to *model_path*.

    Args:
        model_path: File the trained model is written to; missing parent
            directories are created.
        epochs: Number of passes over the training set.
        batch_size: Number of samples per gradient update.

    Returns:
        The trained Keras model.
    """
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    # Labels are integer class ids, hence the "sparse" cross-entropy.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    # The test split doubles as validation data here, so the reported
    # validation metrics are not an independent estimate.
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
              validation_data=(x_test, y_test))
    evaluate_result = model.evaluate(x_test, y_test)
    print(evaluate_result)  # [loss, accuracy] on the test set
    # Saves the whole model (architecture, weights and optimizer state), not
    # just the network structure. Create the target directory up front so
    # saving does not depend on it already existing.
    Path(model_path).parent.mkdir(parents=True, exist_ok=True)
    model.save(model_path)
    return model


def predict(model_path=MODEL_PATH, count=10):
    """Load the saved model and print its predictions for the first test images.

    Args:
        model_path: File previously written by ``train()``.
        count: How many images from the start of the test set to classify.

    Returns:
        NumPy array of predicted class ids, one per classified image.
    """
    model = tf.keras.models.load_model(model_path)
    images = x_test[:count]
    labels = y_test[:count]
    if images.size == 0:
        return np.empty(0, dtype=np.int64)
    # One batched predict() call instead of one call per image; the softmax
    # output has one row of 10 class probabilities per image.
    predictions = np.argmax(model.predict(images), axis=1)
    for actual, predicted in zip(labels, predictions):
        print("actual: {},predict:{},result:{}".format(actual, predicted, actual == predicted))
    return predictions


if __name__ == "__main__":
    # Pass "train" on the command line to (re)train and save the model before
    # predicting; by default only the previously saved model is used.
    if "train" in sys.argv[1:]:
        train()
    predict()