Custom Keras Data Generators: Moving Beyond ImageDataGenerator
If you are training a standard classifier, ImageDataGenerator is fine. But if you are building a Variational Autoencoder (VAE), a Siamese network, or working with 3D medical imaging, the built-in tools quickly fall short.
When your data is too large for RAM, you have two choices: tf.data.Dataset or tf.keras.utils.Sequence. While tf.data is more powerful, Sequence is often easier to debug and ensures your model sees every sample exactly once per epoch—something a standard Python generator cannot guarantee when running in parallel.
1. The Blueprint
To create a generator, you must inherit from tf.keras.utils.Sequence. This isn't optional; it’s what allows Keras to handle multiprocessing safely without double-counting data.
import tensorflow as tf
import numpy as np
import pandas as pd

class CustomDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, data: pd.DataFrame, batch_size: int = 64, shuffle: bool = True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Internal state
        self.labels = tf.keras.utils.to_categorical(data['label'].values)
        self.images = data.drop(['label'], axis=1).values.reshape(-1, 28, 28)
        self.indices = np.arange(len(self.images))
        self.on_epoch_end()

    def __len__(self):
        """Returns the number of batches per epoch."""
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        """Generates one batch of data."""
        # 1. Identify the indices for this batch
        start = index * self.batch_size
        end = (index + 1) * self.batch_size
        batch_indices = self.indices[start:end]
        # 2. Slice the data
        X = self.images[batch_indices]
        y = self.labels[batch_indices]
        # 3. Apply normalization or augmentation here
        return X / 255.0, y

    def on_epoch_end(self):
        """Triggered after every epoch."""
        if self.shuffle:
            np.random.shuffle(self.indices)
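The batch arithmetic in __len__ and __getitem__ can be checked in isolation. Here is a minimal NumPy-only sketch (no TensorFlow required) using a toy array of 100 samples and a batch size of 32 — the numbers are invented for illustration:

```python
import numpy as np

images = np.zeros((100, 28, 28))   # toy stand-in for self.images
batch_size = 32
indices = np.arange(len(images))

# Same logic as __len__: np.ceil means the last batch may be smaller
n_batches = int(np.ceil(len(images) / batch_size))

# Same logic as __getitem__: slice the index array per batch
batch_sizes = []
for index in range(n_batches):
    batch_indices = indices[index * batch_size:(index + 1) * batch_size]
    batch_sizes.append(len(batch_indices))

print(n_batches)    # 4
print(batch_sizes)  # [32, 32, 32, 4]
```

Note that every sample lands in exactly one batch per epoch, and the final partial batch is not dropped — that is the "exactly once" guarantee mentioned earlier.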
2. Critical Logic Breakdown
Why __getitem__ is the Driver
Keras calls __getitem__ with an index argument. If your batch size is 64 and you have 640 samples, Keras will call this 10 times per epoch (indices 0 through 9).
The Trap: Do not do heavy file I/O inside __init__. If you load 50GB of images into self.images during initialization, you've defeated the purpose of a generator.
The Fix: Store only file paths in __init__, and load the actual image bytes from disk inside __getitem__.
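A minimal sketch of that fix, using plain NumPy and .npy files so it runs standalone. In a real pipeline the class would inherit from tf.keras.utils.Sequence and the paths would point at your actual image files — the temp files and names here are invented for the demo:

```python
import os
import tempfile
import numpy as np

class LazyImageGenerator:  # in practice: a tf.keras.utils.Sequence subclass
    def __init__(self, paths, batch_size=2):
        self.paths = paths          # cheap: just strings, no pixel data
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.paths) / self.batch_size))

    def __getitem__(self, index):
        batch_paths = self.paths[index * self.batch_size:(index + 1) * self.batch_size]
        # Heavy disk I/O happens here, one batch at a time
        return np.stack([np.load(p) for p in batch_paths])

# Demo: write a few fake "images" to disk, then load them lazily
tmp = tempfile.mkdtemp()
paths = []
for i in range(5):
    p = os.path.join(tmp, "img_%d.npy" % i)
    np.save(p, np.full((28, 28), i, dtype=np.float32))
    paths.append(p)

gen = LazyImageGenerator(paths, batch_size=2)
print(len(gen))      # 3
print(gen[0].shape)  # (2, 28, 28)
print(gen[2].shape)  # (1, 28, 28) -- the final partial batch
```

Memory now scales with the batch size, not the dataset size, which is the entire point of a generator.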
The on_epoch_end Advantage
Unlike simple Python generators (yield), the Sequence class has a hook for epoch-level logic. This is where you shuffle your data indices. If you don't shuffle, your model may learn the order of the data rather than its features, leading to poor generalization.
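The effect is easy to demonstrate on the index array alone: after each shuffle the order changes, but every sample index is still present exactly once — so shuffling never drops or duplicates data. A NumPy-only sketch (the seed is arbitrary, chosen only to make the demo reproducible):

```python
import numpy as np

np.random.seed(42)
indices = np.arange(100)
before = indices.copy()

np.random.shuffle(indices)  # the same call on_epoch_end makes

print(np.array_equal(indices, before))           # False: new order
print(np.array_equal(np.sort(indices), before))  # True: same samples
```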
3. Deployment: Multiprocessing
The primary reason to use Sequence is the ability to leverage multiple CPU cores to prepare the next batch while the GPU is busy training the current one.
train_gen = CustomDataGenerator(train_df, batch_size=32)
model.fit(
train_gen,
epochs=10,
    use_multiprocessing=True, # Runs __getitem__ in separate processes
    workers=4                 # Number of parallel worker processes
)
Performance Warning
If use_multiprocessing=True causes a deadlock or high memory usage, ensure your __getitem__ isn't creating massive new objects or holding onto open file handles. With multiprocessing, each worker receives its own copy of the generator, so everything your class references must be picklable and safe to use in parallel.
Conclusion
You now have a scalable pipeline. You aren't limited by RAM; you're limited only by your disk speed and CPU's ability to process batches.