import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
# Build a reproducible synthetic binary classification problem
# (1000 samples, 20 features, 15 of them informative).
X, y = make_classification(
    n_samples=1000,
    n_features=20,
    n_informative=15,
    n_classes=2,
    random_state=42,
)
# Hold out 20% (200 samples) as the test set. Of the remaining 800 samples,
# only 20% (160) keep their labels; the other 640 form the unlabeled pool and
# their ground-truth labels are discarded into `_`.
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
X_labeled, X_unlabeled, y_labeled, _ = train_test_split(
    X_temp, y_temp, test_size=0.8, random_state=42
)
# Baseline: an MLP trained on the small labeled subset alone. fit() returns
# the estimator itself, so construction and training are chained.
model = MLPClassifier(
    hidden_layer_sizes=(50,),
    max_iter=200,
    random_state=42,
).fit(X_labeled, y_labeled)
# Baseline accuracy on its own training data and on the held-out test split.
train_acc, val_acc = (
    accuracy_score(y_labeled, model.predict(X_labeled)),
    accuracy_score(y_test, model.predict(X_test)),
)
# Semi-supervised learning with pseudo-labeling: for up to 5 rounds, label the
# unlabeled points the current model is confident about, move them into the
# labeled set, and retrain a fresh model on the enlarged set. Accepted
# pseudo-labels are never revisited in later rounds.
threshold = 0.9  # minimum predicted-class probability to accept a pseudo-label
for _ in range(5):
    # An earlier round may have absorbed every unlabeled point; scikit-learn
    # raises ValueError when predicting on a 0-sample array, so stop here.
    if len(X_unlabeled) == 0:
        break
    # One forward pass gives both the confidence and the pseudo-label
    # (classes_[argmax] matches what predict() would return).
    probs = model.predict_proba(X_unlabeled)
    max_probs = probs.max(axis=1)
    pseudo_labels = model.classes_[probs.argmax(axis=1)]
    # Select confident predictions; stop early if there are none.
    confident_mask = max_probs >= threshold
    if not confident_mask.any():
        break
    X_pseudo = X_unlabeled[confident_mask]
    y_pseudo = pseudo_labels[confident_mask]
    # Append the pseudo-labeled rows after the existing labeled rows.
    X_combined = np.vstack((X_labeled, X_pseudo))
    y_combined = np.hstack((y_labeled, y_pseudo))
    # Remove the newly pseudo-labeled points from the unlabeled pool.
    X_unlabeled = X_unlabeled[~confident_mask]
    # Retrain from scratch with the same hyperparameters and seed.
    model = MLPClassifier(hidden_layer_sizes=(50,), max_iter=200, random_state=42)
    model.fit(X_combined, y_combined)
    # The enlarged set becomes the labeled set for the next round.
    X_labeled, y_labeled = X_combined, y_combined
# Final evaluation of the last trained model.
# NOTE: after the pseudo-labeling loop, X_labeled / y_labeled include any
# accepted pseudo-labeled rows, whose targets are earlier models' own
# predictions. The final training accuracy is therefore not directly
# comparable with the initial one. Both "validation" figures are computed on
# the held-out test split (X_test / y_test).
final_train_acc, final_val_acc = (
    accuracy_score(y_labeled, model.predict(X_labeled)),
    accuracy_score(y_test, model.predict(X_test)),
)
print(f"Initial training accuracy: {train_acc:.2f}")
print(f"Initial validation accuracy: {val_acc:.2f}")
print(f"Final training accuracy: {final_train_acc:.2f}")
print(f"Final validation accuracy: {final_val_acc:.2f}")