import tensorflow as tf
from tensorflow.keras import layers, models, datasets, utils
# Load CIFAR-10 and prepare it for training:
# pixel values scaled into [0, 1], labels one-hot encoded over 10 classes.
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
y_train = utils.to_categorical(y_train, 10)
y_test = utils.to_categorical(y_test, 10)
# Define a basic residual block
class ResidualBlock(layers.Layer):
    """Basic two-convolution residual block (ResNet v1 style).

    Main path: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN.
    Shortcut: identity when ``stride == 1``; otherwise a 1x1 strided
    conv + BN projection so the spatial shape matches the main path.
    The two paths are summed and passed through a final ReLU.

    NOTE(review): the identity shortcut assumes the input channel count
    equals ``filters`` — true for every use in this file, but a channel
    change with ``stride == 1`` would fail at the addition.

    Args:
        filters: Number of output channels for both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut when != 1). Defaults to 1.
    """

    def __init__(self, filters, stride=1):
        super().__init__()
        self.conv1 = layers.Conv2D(filters, 3, strides=stride,
                                   padding='same', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.ReLU()
        self.conv2 = layers.Conv2D(filters, 3, strides=1,
                                   padding='same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        if stride != 1:
            # Projection shortcut: 1x1 conv downsamples the input so it
            # can be added to the strided main path.
            self.shortcut = models.Sequential([
                layers.Conv2D(filters, 1, strides=stride,
                              padding='same', use_bias=False),
                layers.BatchNormalization()
            ])
        else:
            # Fix: the previous code used a bare ``layers.Layer()`` as the
            # identity, which is not callable as an identity under Keras 3
            # (calling the base Layer raises NotImplementedError). An
            # explicit None sentinel marks the identity path instead.
            self.shortcut = None

    def call(self, inputs, training=False):
        """Forward pass; ``training`` is forwarded to every BN layer."""
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)
        # Fix: the old ``hasattr(self.shortcut, 'call')`` guard was always
        # True for any Layer instance, so the ``inputs`` fallback was dead
        # code. A None check selects the identity path unambiguously.
        if self.shortcut is not None:
            shortcut = self.shortcut(inputs, training=training)
        else:
            shortcut = inputs
        x = x + shortcut
        return self.relu(x)
# Build a small ResNet: a conv stem followed by four residual blocks,
# downsampling once (stride 2) when the width doubles from 64 to 128.
inputs = layers.Input(shape=(32, 32, 3))
x = layers.Conv2D(64, 3, strides=1, padding='same', use_bias=False)(inputs)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
# (filters, stride) for each residual block, applied in order.
for block_filters, block_stride in [(64, 1), (64, 1), (128, 2), (128, 1)]:
    x = ResidualBlock(block_filters, stride=block_stride)(x)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(10, activation='softmax')(x)
model = models.Model(inputs, x)

# Train with Adam and one-hot cross-entropy, validating on the test split.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
history = model.fit(
    x_train,
    y_train,
    epochs=30,
    batch_size=64,
    validation_data=(x_test, y_test),
)