import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers
# Synthetic dataset: query/document embedding pairs with binary relevance labels.
# Embeddings are drawn at random here; swap in real encoder outputs in practice.
np.random.seed(42)
num_samples = 1000
embedding_dim = 50
# Random query and document vectors (float32 to match Keras defaults).
query_embeddings = np.random.rand(num_samples, embedding_dim).astype(np.float32)
doc_embeddings = np.random.rand(num_samples, embedding_dim).astype(np.float32)
# Relevance target per pair: 1 = relevant, 0 = not relevant.
labels = np.random.randint(0, 2, size=(num_samples, 1)).astype(np.float32)
# Model input = query vector concatenated with its document vector.
inputs = np.hstack((query_embeddings, doc_embeddings))
# 80/20 train/validation split (data is already i.i.d. random, no shuffle needed).
split = int(num_samples * 0.8)
X_train, y_train = inputs[:split], labels[:split]
X_val, y_val = inputs[split:], labels[split:]
# Feed-forward scorer: maps a concatenated (query, doc) vector to a
# predicted relevance probability in [0, 1].
model = models.Sequential()
model.add(layers.Input(shape=(embedding_dim * 2,)))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.3))  # regularization between the dense layers
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))  # pointwise relevance score
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.001),
    loss=losses.BinaryCrossentropy(),
    metrics=['accuracy'],
)
# Fit silently for 10 epochs; `history` keeps the per-epoch loss/accuracy curves.
history = model.fit(
    X_train,
    y_train,
    epochs=10,
    batch_size=32,
    validation_data=(X_val, y_val),
    verbose=0,
)
# Evaluate ranking quality (precision@k / recall@k) on the validation set.
# Validation pairs are grouped into pseudo-queries of K consecutive candidates.
def _precision_recall_at_k(scores, relevance, k):
    """Compute (precision@k, recall@k) for one query's candidate list.

    scores    : 1-D array of predicted relevance scores.
    relevance : 1-D array of binary ground-truth labels, same length as scores.
    k         : cutoff rank (number of top-scored candidates counted as retrieved).
    """
    # Rank candidates by descending predicted score and keep only the top k —
    # the original code summed all candidates, which made the sort a no-op.
    top_k = relevance[np.argsort(scores)[::-1]][:k]
    retrieved_relevant = np.sum(top_k)
    total_relevant = np.sum(relevance)
    precision = retrieved_relevant / k
    # Guard against queries that have no relevant documents at all.
    recall = retrieved_relevant / total_relevant if total_relevant > 0 else 0
    return precision, recall

val_preds = model.predict(X_val).flatten()
K = 5  # candidates per simulated query, also the metric cutoff
# NOTE(review): because K equals the candidate-list size, every candidate is
# "retrieved", so the ranking cannot affect precision@5, and recall@5 is 1
# whenever a query has any relevant doc. For a meaningful re-ranking
# evaluation, group more than K candidates per query.
num_val_queries = len(X_val) // K
precision_at_5 = []
recall_at_5 = []
for i in range(num_val_queries):
    group = slice(i * K, (i + 1) * K)
    p, r = _precision_recall_at_k(val_preds[group], y_val[group].flatten(), K)
    precision_at_5.append(p)
    recall_at_5.append(r)
avg_precision_at_5 = np.mean(precision_at_5) * 100
avg_recall_at_5 = np.mean(recall_at_5) * 100
print(f"Precision@5 after re-ranking: {avg_precision_at_5:.2f}%")
print(f"Recall@5 after re-ranking: {avg_recall_at_5:.2f}%")