import tensorflow as tf
import numpy as np
# Load example dataset (cats vs dogs) - using CIFAR-10 for demonstration.
# The CIFAR-10 test split is reused here as the validation set.
(x_train, y_train), (x_val, y_val) = tf.keras.datasets.cifar10.load_data()

# Keep only class 3 (cat) and class 5 (dog): collect the row indices whose
# label is one of those two classes.
train_filter = np.isin(y_train.ravel(), (3, 5)).nonzero()[0]
val_filter = np.isin(y_val.ravel(), (3, 5)).nonzero()[0]
x_train = x_train[train_filter]
y_train = y_train[train_filter]
x_val = x_val[val_filter]
y_val = y_val[val_filter]

# Binary targets for a sigmoid output: 1.0 = dog (class 5), 0.0 = cat.
y_train = np.equal(y_train, 5).astype(np.float32)
y_val = np.equal(y_val, 5).astype(np.float32)

# Rescale pixel intensities from [0, 255] to [0, 1] as float32.
x_train = x_train.astype(np.float32) / 255.0
x_val = x_val.astype(np.float32) / 255.0
def mixup(batch_x, batch_y, alpha=0.2, rng=None):
    """Blend a batch with a shuffled copy of itself (MixUp).

    One mixing weight ``lam ~ Beta(alpha, alpha)`` is drawn per call and
    applied to both inputs and labels::

        mixed = lam * batch + (1 - lam) * batch[perm]

    Args:
        batch_x: Array of inputs; the first axis is the batch axis.
        batch_y: Array of labels with the same leading batch size as batch_x.
        alpha: Beta-distribution concentration. Values <= 0 disable mixing
            (lam = 1) and the batch is returned unchanged.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used for
            both draws, for reproducibility. Defaults to the global
            ``numpy.random`` state.

    Returns:
        Tuple ``(mixed_x, mixed_y)`` with the same shapes as the inputs.
    """
    if alpha <= 0:
        # Beta(alpha, alpha) is undefined here; treat it as "no mixing".
        return batch_x, batch_y
    gen = np.random if rng is None else rng
    batch_size = batch_x.shape[0]
    # Use a Python float: a NumPy float64 scalar would upcast float32 batches
    # to float64 under NumPy 2's scalar promotion rules (NEP 50).
    lam = float(gen.beta(alpha, alpha))
    index = gen.permutation(batch_size)
    mixed_x = lam * batch_x + (1 - lam) * batch_x[index]
    mixed_y = lam * batch_y + (1 - lam) * batch_y[index]
    return mixed_x, mixed_y
# Input pipelines: training examples are shuffled (buffer of 1000) and
# batched; validation examples keep their order and are only batched.
batch_size = 64
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1000)
    .batch(batch_size)
)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(batch_size)
# Define model (simple CNN): two conv/pool stages feeding a small dense head
# with a single sigmoid unit (predicted probability of label 1.0).


def mixup_accuracy(y_true, y_pred):
    """Binary accuracy that tolerates soft (MixUp-blended) targets.

    Plain binary accuracy compares targets for exact equality with the
    thresholded prediction, so blended targets such as 0.73 almost never
    match. Here targets are thresholded at 0.5 first, so a blended label
    counts as its dominant class. Hard 0/1 labels (e.g. validation) are
    unchanged, so the metric equals ordinary binary accuracy on them.
    """
    hard_true = tf.cast(y_true >= 0.5, y_pred.dtype)
    return tf.keras.metrics.binary_accuracy(hard_true, y_pred)


model = tf.keras.Sequential([
    # Explicit Input layer rather than `input_shape=` on the first layer,
    # which Keras 3 warns against.
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[mixup_accuracy])
# Custom training loop: every training batch is passed through MixUp before
# the gradient step; the (un-mixed) validation set is scored once per epoch.
# NOTE(review): after MixUp the training targets are soft, so the per-step
# "Accuracy" depends on how the compiled metric treats non-binary labels;
# the validation figure uses the original hard labels.
epochs = 10
for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')

    # Training
    for step, (batch_x, batch_y) in enumerate(train_ds):
        mixed_batch = mixup(batch_x.numpy(), batch_y.numpy(), alpha=0.2)
        loss, acc = model.train_on_batch(*mixed_batch)
        if not step % 50:
            print(f'Step {step}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

    # Validation
    val_loss, val_acc = model.evaluate(val_ds, verbose=0)
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}')