Custom Keras Data Generators: Moving Beyond ImageDataGenerator

If you are training a standard image classifier, ImageDataGenerator is fine. If you are building a Variational Autoencoder (VAE), a Siamese network, or working with 3D medical imaging, the built-in tools fall short: they assume a single image input paired with a single label.

When your data is too large for RAM, you have two choices: tf.data.Dataset or tf.keras.utils.Sequence. While tf.data is more powerful, Sequence is often easier to debug, and it guarantees your model sees every sample exactly once per epoch; a plain Python generator cannot guarantee that when batches are prepared in parallel.


1. The Blueprint

To create a generator, you must inherit from tf.keras.utils.Sequence. This isn't optional: the Sequence contract (batches addressable by index via __getitem__, a known length via __len__) is what allows Keras to hand out distinct batch indices to parallel workers without double-counting data.

Python
import tensorflow as tf
import numpy as np
import pandas as pd

class CustomDataGenerator(tf.keras.utils.Sequence):
    def __init__(self, data: pd.DataFrame, batch_size: int = 64, shuffle: bool = True):
        super().__init__()  # required by newer Keras versions
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Internal state (assumes an MNIST-style CSV: a 'label' column plus 784 pixel columns)
        self.labels = tf.keras.utils.to_categorical(data['label'].values)
        self.images = data.drop(['label'], axis=1).values.reshape(-1, 28, 28)
        self.indices = np.arange(len(self.images))
        self.on_epoch_end()

    def __len__(self):
        """Returns the number of batches per epoch."""
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        """Generates one batch of data."""
        # 1. Identify the indices for this batch
        start = index * self.batch_size
        end = (index + 1) * self.batch_size
        batch_indices = self.indices[start:end]

        # 2. Slice the data
        X = self.images[batch_indices]
        y = self.labels[batch_indices]

        # 3. Apply normalization or augmentation here
        return X / 255.0, y

    def on_epoch_end(self):
        """Triggered after every epoch."""
        if self.shuffle:
            np.random.shuffle(self.indices)
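
The ceil in __len__ and the slice in __getitem__ have to agree, or the final partial batch gets dropped or duplicated. The arithmetic can be sanity-checked with plain NumPy before TensorFlow is involved; this sketch uses a hypothetical 100 samples and a batch size of 32:

```python
import numpy as np

n_samples, batch_size = 100, 32
indices = np.arange(n_samples)

# __len__: ceil division so the final, smaller batch is not dropped
n_batches = int(np.ceil(n_samples / batch_size))  # 4 batches: 32 + 32 + 32 + 4

# __getitem__: slicing past the end is safe in NumPy; the last slice is simply shorter
batches = [indices[i * batch_size:(i + 1) * batch_size] for i in range(n_batches)]

assert [len(b) for b in batches] == [32, 32, 32, 4]
# Every sample appears exactly once per epoch
assert np.array_equal(np.concatenate(batches), indices)
```

If you use floor division instead of np.ceil here, the last 4 samples silently vanish from every epoch.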

2. Critical Logic Breakdown

Why __getitem__ is the Driver

Keras calls __getitem__ with an index argument. If your batch size is 64 and you have 640 samples, Keras will call this 10 times per epoch (indices 0 through 9).

  • The Trap: Do not do heavy file I/O inside __init__. If you load 50GB of images into self.images during initialization, you’ve defeated the purpose of a generator.

  • The Fix: Store only file paths in __init__, and load the actual image bytes from disk inside __getitem__.
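
A sketch of that fix: store only paths in __init__ and do the disk reads inside __getitem__. To keep the sketch dependency-free it omits the Sequence base class and image decoding, and uses throwaway .npy files as stand-ins for your real image files:

```python
import os
import tempfile
import numpy as np

class LazyBatcher:
    """Sketch of a lazily-loading generator body (Sequence base omitted for brevity)."""
    def __init__(self, paths, batch_size=2):
        self.paths = list(paths)   # cheap: strings only, no pixel data in memory
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.paths) / self.batch_size))

    def __getitem__(self, index):
        batch_paths = self.paths[index * self.batch_size:(index + 1) * self.batch_size]
        # Heavy I/O happens here, one batch at a time
        return np.stack([np.load(p) for p in batch_paths])

# Demo with temporary .npy files standing in for real images
tmp = tempfile.mkdtemp()
paths = []
for i in range(5):
    p = os.path.join(tmp, f"img_{i}.npy")
    np.save(p, np.full((4, 4), i, dtype=np.float32))
    paths.append(p)

gen = LazyBatcher(paths, batch_size=2)
first = gen[0]   # reads only img_0 and img_1 from disk
```

The peak memory cost is now one batch, not the whole dataset, regardless of how many paths the list holds.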

The on_epoch_end Advantage

Unlike simple Python generators (yield), the Sequence class has a hook for epoch-level logic. This is where you shuffle your data indices. If you don't shuffle, every epoch serves identical batches in identical order, which can bias gradient updates toward that ordering and hurt generalization.
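
Note that the blueprint shuffles self.indices rather than the arrays themselves: permuting a shared index array keeps X and y aligned for free. A quick NumPy illustration with toy data (shapes chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.arange(10).reshape(5, 2)   # 5 samples, 2 features each
y = np.array([0, 1, 2, 3, 4])     # label i belongs to sample i

indices = np.arange(len(X))
rng.shuffle(indices)              # shuffle the index view, not the data

X_shuffled, y_shuffled = X[indices], y[indices]

# Pairing is preserved: row k of X_shuffled still carries its original label
for k in range(len(y_shuffled)):
    assert np.array_equal(X_shuffled[k], X[y_shuffled[k]])
```

Shuffling X and y independently, by contrast, would scramble the label pairing and silently destroy training.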


3. Deployment: Multiprocessing

The primary reason to use Sequence is the ability to leverage multiple CPU cores to prepare the next batch while the GPU is busy training the current one.

Python
train_gen = CustomDataGenerator(train_df, batch_size=32)

model.fit(
    train_gen,
    epochs=10,
    use_multiprocessing=True, # Runs __getitem__ in separate processes
    workers=4                 # Number of worker processes
)

Performance Warning

If use_multiprocessing=True causes a deadlock or runaway memory usage, check that __getitem__ isn't allocating massive new objects and that your class isn't holding open file handles. With multiprocessing, every attribute of your generator must be picklable so it can be copied into the worker processes.
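
The picklability requirement is easy to test up front: pickle refuses objects like open file handles, one common cause of multiprocessing crashes. A small check, using two hypothetical generator-like classes:

```python
import os
import pickle

class PathHoldingGen:
    def __init__(self):
        self.paths = ["a.npy", "b.npy"]       # plain strings: picklable

class HandleHoldingGen:
    def __init__(self):
        self.handle = open(os.devnull, "rb")  # open file handle: not picklable

# Fine: workers can receive a copy of this object
pickle.dumps(PathHoldingGen())

# Fails: pickling an open handle raises TypeError
try:
    pickle.dumps(HandleHoldingGen())
    handle_ok = True
except TypeError:
    handle_ok = False
```

Running pickle.dumps(my_generator) once before calling fit() is a cheap way to catch this class of bug early.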


Conclusion

You now have a scalable pipeline. You aren't limited by RAM; you're limited only by your disk speed and CPU's ability to process batches.
