import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, Layer
from tensorflow.keras.models import Model
import numpy as np
# Simple attention layer implementation
class SimpleAttention(Layer):
    """Attention pooling that collapses a sequence into a single vector.

    For inputs of shape (batch, seq_len, features), a single learned
    projection scores every timestep, a softmax over the sequence axis
    turns the scores into weights, and the output is the weighted sum of
    the timesteps with shape (batch, features).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One scoring weight per feature. Naming the variable matters:
        # unnamed weights in custom layers are ambiguous in saved
        # models/checkpoints and raise errors in some Keras versions.
        self.W = self.add_weight(
            name='attention_score_weight',
            shape=(input_shape[-1], 1),
            initializer='random_normal',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # (batch, seq_len, 1): one unnormalized score per timestep.
        scores = tf.matmul(inputs, self.W)
        # Softmax over the sequence axis so weights sum to 1 per example.
        weights = tf.nn.softmax(scores, axis=1)
        # Broadcast (batch, seq_len, 1) against (batch, seq_len, features)
        # and sum out the time axis -> (batch, features).
        return tf.reduce_sum(inputs * weights, axis=1)
# Hyperparameters for the toy sequence-classification task.
sequence_length = 10  # timesteps per example
feature_dim = 16      # features per timestep
num_classes = 2       # binary classification
# Functional-API graph: sequence input -> attention pooling -> classifier head.
inputs = Input(shape=(sequence_length, feature_dim))
pooled = SimpleAttention()(inputs)  # (batch, feature_dim)
# Dropout to reduce overfitting on the small synthetic dataset.
regularized = Dropout(0.3)(pooled)
# Deliberately small hidden layer to keep model capacity low.
hidden = Dense(32, activation='relu')(regularized)
outputs = Dense(num_classes, activation='softmax')(hidden)

model = Model(inputs=inputs, outputs=outputs)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# Synthetic stand-in data: uniform-random features, random integer labels.
n_train, n_val = 1000, 200
X_train = np.random.rand(n_train, sequence_length, feature_dim).astype(np.float32)
y_train = np.random.randint(0, num_classes, n_train)
X_val = np.random.rand(n_val, sequence_length, feature_dim).astype(np.float32)
y_val = np.random.randint(0, num_classes, n_val)
# Train, stopping early when validation loss stalls; keep the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)
history = model.fit(
    X_train,
    y_train,
    epochs=30,
    batch_size=32,
    validation_data=(X_val, y_val),
    callbacks=[early_stop],
)

# Report final accuracy on both splits (loss is unpacked but unused here).
train_loss, train_acc = model.evaluate(X_train, y_train, verbose=0)
val_loss, val_acc = model.evaluate(X_val, y_val, verbose=0)
print(f'Training accuracy: {train_acc*100:.2f}%, Validation accuracy: {val_acc*100:.2f}%')