TensorFlow Datasets 數據集載入

TensorFlow Datasets 是一個開箱即用的數據集集合,包含數十種常用的機器學習數據集。通過簡單的幾行代碼即可將數據以 tf.data.Dataset 的格式載入。關於 tf.data.Dataset 的使用可參考 tf.data

該工具是一個獨立的Python包,可以通過:

pip install tensorflow-datasets

安裝。

在使用時,首先使用import導入該包

import tensorflow as tf
import tensorflow_datasets as tfds

然後,最基礎的用法是使用 tfds.load 方法,載入所需的數據集。例如,以下三行代碼分別載入了MNIST、貓狗分類和 tf_flowers 三個圖像分類數據集:

dataset = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)
dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)

當第一次載入特定數據集時,TensorFlow Datasets 會自動從雲端下載數據集到本地,並顯示下載進度。例如,載入MNIST數據集時,終端輸出提示如下:

Downloading and preparing dataset mnist (11.06 MiB) to C:\Users\snowkylin\tensorflow_datasets\mnist\3.0.0...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.

Dl Completed...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.93s/ file]
Dl Completed...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.73s/ file]
Dataset mnist downloaded and prepared to C:\Users\snowkylin\tensorflow_datasets\mnist\3.0.0. Subsequent calls will reuse this data.

提示

在使用 TensorFlow Datasets 時,可能需要設置代理。較爲簡易的方式是設置 HTTPS_PROXY 環境變量( 參考這裡 ),即

export HTTPS_PROXY=http://代理伺服器IP:

tfds.load 方法返回一個 tf.data.Dataset 對象。部分重要的參數如下:

  • as_supervised :若爲True,則根據數據集的特性,將數據集中的每行元素整理爲有監督的二元組 (input, label) (即「數據+標籤」)形式,否則數據集中的每行元素爲包含所有特徵的字典。

  • split:指定返回數據集的特定部分。若不指定,則返回整個數據集。一般有 tfds.Split.TRAIN (訓練集)和 tfds.Split.TEST (測試集)選項。

TensorFlow Datasets 當前支持的數據集可在 官方文檔 查看,或者也可以使用 tfds.list_builders() 查看。

當得到了 tf.data.Dataset 類型的數據集後,我們即可使用 tf.data 對數據集進行各種預處理以及讀取數據。例如:

# 使用 TessorFlow Datasets 載入「tf_flowers」數據集
dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)
# 對 dataset 進行大小調整、打散和分批次操作
dataset = dataset.map(lambda img, label: (tf.image.resize(img, [224, 224]) / 255.0, label)) \
    .shuffle(1024) \
    .batch(32)
# 疊代數據
for images, labels in dataset:
    # 對images和labels進行操作

詳細操作說明可見 本手冊的 tf.data 一節 。同時,本手冊的 分布式訓練 一章也使用了 TensorFlow Datasets 載入數據集。可以參考這些章節的示例代碼以進一步了解 TensorFlow Datasets 的使用方法。