diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index 1d98c20f53..3ab9f6ae55 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -133,7 +133,7 @@ def plot_image_gallery( ) # batch_size from within passed `tf.data.Dataset` else: batch_size = ( - images.shape[0] if len(images.shape) == 4 else 1 + np.asarray(images).shape[0] if len(images.shape) == 4 else 1 ) # batch_size from np.array or single image rows = rows or int(math.ceil(math.sqrt(batch_size)))