import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, LayerNormalization, MultiHeadAttention
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
class TransformerBlock(tf.keras.layers.Layer):
    """A standard Transformer encoder block: multi-head self-attention
    followed by a position-wise feed-forward network, each wrapped in a
    residual connection + layer normalization (post-norm arrangement).

    Args:
        embed_dim: Dimensionality of the token embeddings (and of the
            attention key/query projections).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer.
        rate: Dropout rate applied after each sub-layer. Defaults to 0.1.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()
        # Store constructor arguments so the layer is serializable via
        # get_config() (required for model.save()/load_model with this
        # custom layer).
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='relu'),
            Dense(embed_dim),  # project back so the residual add is valid
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        """Apply self-attention + FFN with residuals.

        Args:
            inputs: Tensor of shape (batch, seq_len, embed_dim) —
                assumed from the residual adds; confirm against callers.
            training: Forwarded to the Dropout layers so dropout is
                active only during training.

        Returns:
            Tensor of the same shape as `inputs`.
        """
        # Self-attention: query and value are both `inputs`.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)  # residual + norm
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)  # residual + norm

    def get_config(self):
        """Return the layer config so the layer can be serialized."""
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        })
        return config
# ---- Hyperparameters ----
embed_dim = 32        # token embedding size (kept deliberately small)
num_heads = 2         # attention heads (reduced)
ff_dim = 64           # hidden width of the position-wise FFN
sequence_length = 50  # tokens per input example
vocab_size = 10000    # rows in the embedding table
num_classes = 5       # softmax output size

# ---- Functional-API model graph ----
inputs = Input(shape=(sequence_length,))
embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_dim)(inputs)
transformer_block = TransformerBlock(
    embed_dim, num_heads, ff_dim, rate=0.2
)(embedding_layer)
# Average over the sequence axis to get one vector per example.
pooling = tf.keras.layers.GlobalAveragePooling1D()(transformer_block)
dropout = Dropout(0.3)(pooling)
outputs = Dense(num_classes, activation='softmax')(dropout)

model = Model(inputs=inputs, outputs=outputs)
model.compile(
    optimizer=Adam(learning_rate=5e-4),
    loss='sparse_categorical_crossentropy',  # integer labels expected
    metrics=['accuracy'],
)

# Example training call (X_train, y_train, X_val, y_val must be defined)
# model.fit(X_train, y_train, batch_size=32, epochs=20, validation_data=(X_val, y_val),
#           callbacks=[tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)])