import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score, recall_score

# End-to-end demo: train a small dense network on TF-IDF features to flag spam,
# compensating for class imbalance with per-class weights, then report
# validation accuracy and spam recall.

# Sample data (replace with real dataset).
# NOTE: the eight sentences below are repeated 100 times and the labels are
# assigned by position (every tenth index), not by content. Identical texts
# therefore end up in both the training and validation splits, so the printed
# metrics only demonstrate the pipeline; they say nothing about real-world
# performance.
texts = [
    "Free money now",
    "Hi, how are you?",
    "Win a prize",
    "Let's meet tomorrow",
    "Cheap meds available",
    "Are you coming?",
    "Congratulations, you won!",
    "Call me later",
] * 100
# 1 = spam, 0 = not spam (~10% spam)
labels = [1 if i % 10 == 0 else 0 for i in range(len(texts))]

# Split data; stratifying on the labels keeps the spam ratio equal in both
# partitions.
X_train, X_val, y_train, y_val = train_test_split(
    texts, labels, test_size=0.2, stratify=labels, random_state=42
)

# Vectorize text. The vocabulary and IDF statistics are fitted on the training
# split only and then reused for validation; both matrices are converted to
# dense arrays before being handed to the model.
vectorizer = TfidfVectorizer(max_features=1000)
X_train_vec = vectorizer.fit_transform(X_train).toarray()
X_val_vec = vectorizer.transform(X_val).toarray()

# Compute class weights. compute_class_weight returns one weight per entry of
# `classes`, in the same order, so each weight is paired with its actual class
# label rather than with its array position. The mapping stays correct even
# when the labels present in y_train are not exactly 0..k-1.
classes = np.unique(y_train)
class_weights = compute_class_weight('balanced', classes=classes, y=y_train)
class_weight_dict = dict(zip(classes.tolist(), class_weights.tolist()))

# Build model: two hidden ReLU layers with dropout after the first, and a
# single sigmoid unit producing the spam probability.
model = Sequential([
    Dense(64, activation='relu', input_shape=(X_train_vec.shape[1],)),
    Dropout(0.5),
    Dense(32, activation='relu'),
    Dense(1, activation='sigmoid'),
])
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Early stopping: halt once validation loss has not improved for 3 epochs and
# roll back to the best weights seen.
early_stop = EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True
)

# Train with class weights so the rare spam class contributes proportionally
# more to the loss.
history = model.fit(
    X_train_vec,
    np.array(y_train),
    epochs=20,
    batch_size=32,
    validation_data=(X_val_vec, np.array(y_val)),
    class_weight=class_weight_dict,
    callbacks=[early_stop],
    verbose=0,
)

# Evaluate. Probabilities above 0.5 count as spam. verbose=0 matches the silent
# training run so no progress bar is interleaved with the report below.
val_preds = (model.predict(X_val_vec, verbose=0) > 0.5).astype(int).flatten()
val_accuracy = accuracy_score(y_val, val_preds) * 100
val_recall_spam = recall_score(y_val, val_preds, pos_label=1) * 100
print(f"Validation accuracy: {val_accuracy:.2f}%")
print(f"Validation recall for spam: {val_recall_spam:.2f}%")