Keras教程
使用 matplotlib 展示 MNIST 训练集中的图片：
# -*- coding: UTF-8 -*-
"""Load MNIST with Keras and display the first training images in a grid."""
import matplotlib.pyplot as plt
from keras.datasets import mnist

# Location of the local MNIST archive passed to mnist.load_data(path=...).
# NOTE(review): Windows-specific absolute path; adjust it for your machine.
LOCAL_PATH = 'D:/data/keras/mnist.npz'

# The grid is GRID_SIZE x GRID_SIZE, i.e. GRID_SIZE**2 images are shown.
GRID_SIZE = 4


def main():
    """Load MNIST, print the array shapes, and plot the first training images.

    Each subplot shows one training image with its label as the x-axis label.
    """
    # Load the data.
    (x_train, y_train), (x_test, y_test) = mnist.load_data(path=LOCAL_PATH)
    print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
    # (60000, 28, 28) (60000,) (10000, 28, 28) (10000,)

    # Show the first GRID_SIZE**2 (16) images of the training set.
    count = GRID_SIZE * GRID_SIZE
    plt.figure(figsize=(6, 6))
    for i, (image, label) in enumerate(zip(x_train[:count], y_train[:count])):
        plt.subplot(GRID_SIZE, GRID_SIZE, i + 1)
        # The images are single-channel (28, 28) arrays. Without an explicit
        # cmap, matplotlib would render them in its default false-colour map.
        plt.imshow(image, cmap='gray')
        plt.xlabel(label)
    # Add vertical space between rows so the x-labels are not overlapped.
    # This applies to the whole figure, so it is done once, after the loop.
    plt.subplots_adjust(hspace=1)
    plt.show()


if __name__ == '__main__':
    main()