import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
# Toy corpus: five short sentences that make up the entire dataset.
sentences = [
    "I love machine learning",
    "Machine learning is fun",
    "I enjoy learning new things",
    "Deep learning is a branch of machine learning",
    "Natural language processing is interesting",
]
# Simple tokenizer and data preparation
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
# Learn the word -> integer-id vocabulary from the corpus.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(sentences)

# Expand every sentence into all of its prefixes of length >= 2
# (tokens[:2], tokens[:3], ..., tokens[:len]). Each prefix becomes one
# training sample.
sequences = []
for sentence in sentences:
    token_list = tokenizer.texts_to_sequences([sentence])[0]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i + 1]
        sequences.append(n_gram_sequence)

# Bring all prefixes to the length of the longest one. padding='pre' puts
# the padding in front, so the last column always holds a real token.
max_seq_len = max(len(seq) for seq in sequences)
sequences = pad_sequences(sequences, maxlen=max_seq_len, padding='pre')
import numpy as np
# Work on a NumPy array so the 2-D column slicing below is available.
sequences = np.array(sequences)

# Inputs are every column except the last; the target is the last column.
X, y = sequences[:, :-1], sequences[:, -1]

# One more than the number of indexed words: the Keras Tokenizer assigns
# ids starting at 1 and never uses 0.
vocab_size = 1 + len(tokenizer.word_index)
# Small next-token classifier, assembled layer by layer:
# embedding -> LSTM -> dropout -> softmax over the vocabulary.
model = Sequential()
model.add(Embedding(vocab_size, 10, input_length=max_seq_len - 1))
model.add(LSTM(32, return_sequences=False))
model.add(Dropout(0.3))
model.add(Dense(vocab_size, activation='softmax'))

# Targets are integer ids (not one-hot), hence the sparse loss.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# Stop once val_loss has failed to improve for 5 consecutive epochs and
# roll the model back to the weights of the best epoch.
early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Keras carves the validation_split fraction off the END of the arrays,
# before any shuffling. The samples are ordered sentence by sentence, so
# without this permutation the validation set would consist solely of
# prefixes of the last sentence(s). The seed keeps the split reproducible.
# X and y themselves are left untouched; only shuffled copies are fitted.
shuffled = np.random.default_rng(42).permutation(len(X))
history = model.fit(
    X[shuffled],
    y[shuffled],
    epochs=50,
    batch_size=4,
    validation_split=0.2,
    callbacks=[early_stop],
    verbose=0,
)
# Report the metrics of the epoch whose weights the model actually holds.
# EarlyStopping(restore_best_weights=True) rolls the model back to the epoch
# with the lowest val_loss, which is generally not the last epoch recorded
# in the history; indexing with [-1] would describe discarded weights.
val_losses = history.history['val_loss']
# list.index returns the first minimum, matching EarlyStopping's rule that
# only a strict improvement replaces the stored best epoch.
best_epoch = val_losses.index(min(val_losses))

# "final" here means the final (restored) model, not the last epoch run.
final_train_acc = history.history['accuracy'][best_epoch] * 100
final_val_acc = history.history['val_accuracy'][best_epoch] * 100
final_train_loss = history.history['loss'][best_epoch]
final_val_loss = val_losses[best_epoch]

print(f'Training accuracy: {final_train_acc:.2f}%')
print(f'Validation accuracy: {final_val_acc:.2f}%')
print(f'Training loss: {final_train_loss:.3f}')
print(f'Validation loss: {final_val_loss:.3f}')