import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Load CIFAR-10: 50,000 train / 10,000 test 32x32 RGB images; labels are
# integer class ids with shape (n, 1).
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# Normalize pixel values from [0, 255] to [0.0, 1.0].
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# Vision Transformer hyperparameters.
image_size = 32 # CIFAR-10 images are 32x32
patch_size = 4 # 4x4 patches
num_patches = (image_size // patch_size) ** 2  # (32/4)^2 = 64 patch tokens
projection_dim = 64  # embedding width of each patch token
num_heads = 4  # attention heads in each MultiHeadAttention layer
transformer_layers = 4  # number of stacked Transformer encoder blocks
mlp_units = [128, 64]  # hidden sizes of the per-block feed-forward MLP
dropout_rate = 0.3  # shared rate for MLP and classification-head dropout
num_classes = 10  # CIFAR-10 label count
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping patches.

    Input:  float tensor of shape (batch, height, width, channels).
    Output: tensor of shape (batch, num_patches, patch_size**2 * channels).
    """

    def __init__(self, patch_size, **kwargs):
        # Forward **kwargs (name, dtype, trainable, ...) to the base Layer so
        # this behaves like any standard Keras layer and supports framework
        # features such as serialization.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # Dynamic batch size so the layer works for any batch dimension.
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        # Static last dimension: patch_size * patch_size * channels.
        patch_dims = patches.shape[-1]
        # Collapse the spatial patch grid into a single sequence axis.
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        # Required for model.save()/clone_model to reconstruct the layer.
        config = super().get_config()
        config.update({'patch_size': self.patch_size})
        return config
class PatchEncoder(layers.Layer):
    """Linearly project flattened patches and add learned position embeddings.

    Input:  (batch, num_patches, patch_dims) patch tensor.
    Output: (batch, num_patches, projection_dim) encoded token tensor.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # Forward **kwargs to the base Layer so standard layer arguments
        # (name, dtype, ...) and serialization work correctly.
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # One position index per patch; the embedding lookup broadcasts
        # across the batch dimension when added to the projection.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        # Required for model.save()/clone_model to reconstruct the layer.
        config = super().get_config()
        config.update({
            'num_patches': self.num_patches,
            'projection_dim': self.projection_dim,
        })
        return config
# Assemble the Vision Transformer with the Keras functional API.
inputs = layers.Input(shape=(image_size, image_size, 3))
# Image -> sequence of linearly projected patch tokens with learned
# positional embeddings.
patch_tokens = Patches(patch_size)(inputs)
encoded_patches = PatchEncoder(num_patches, projection_dim)(patch_tokens)
# Stack of pre-norm Transformer encoder blocks.
for _ in range(transformer_layers):
    # Self-attention sub-block (LayerNorm -> MHA -> residual add).
    normed_tokens = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    attended = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=projection_dim
    )(normed_tokens, normed_tokens)
    attention_residual = layers.Add()([attended, encoded_patches])
    # Feed-forward sub-block (LayerNorm -> MLP -> residual add).
    ff = layers.LayerNormalization(epsilon=1e-6)(attention_residual)
    for units in mlp_units:
        ff = layers.Dense(units, activation='relu')(ff)
        ff = layers.Dropout(dropout_rate)(ff)
    encoded_patches = layers.Add()([ff, attention_residual])
# Classification head: normalize, flatten all patch tokens, then classify.
representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(dropout_rate)(representation)
outputs = layers.Dense(num_classes, activation='softmax')(representation)
model = keras.Model(inputs=inputs, outputs=outputs)
# Compile: Adam with a conservative learning rate; sparse categorical
# cross-entropy because the labels are integer class ids (not one-hot),
# matching the model's softmax output.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# Train, holding out the last 20% of the training data for validation.
history = model.fit(
    x_train, y_train,
    epochs=30,
    batch_size=64,
    validation_split=0.2,
    verbose=2  # one summary line per epoch
)
# Final evaluation on the held-out test set.
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f'Test accuracy: {test_accuracy * 100:.2f}%', f'Test loss: {test_loss:.4f}')