import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
import numpy as np
# ---------------------------------------------------------------------------
# Dummy data generation for the MLM (masked language modeling) and NSP
# (next sentence prediction) objectives.
# ---------------------------------------------------------------------------
vocab_size = 30522
seq_length = 128
batch_size = 32
num_batches = 100

# Random token ids for the MLM input. Ids are drawn from [1, vocab_size) so
# that 0 can serve as the "ignore" label below: with randint(0, ...) a masked
# token whose true id happened to be 0 would be silently excluded from the
# loss, because the training step treats label 0 as "not masked".
X_mlm = np.random.randint(1, vocab_size, size=(num_batches * batch_size, seq_length))

# MLM labels: the original token id at masked positions, the ignore label 0
# everywhere else (assigned directly instead of the former -100 -> 0 dance).
Y_mlm = np.copy(X_mlm)
mask_positions = np.random.rand(*X_mlm.shape) < 0.15  # ~15% of tokens masked
Y_mlm[~mask_positions] = 0   # 0 == ignore; loss is computed on masked tokens only
X_mlm[mask_positions] = 103  # Replace masked inputs with the [MASK] token id

# NSP labels: one 0/1 label per example (is the second sentence the true next one?).
Y_nsp = np.random.randint(0, 2, size=(num_batches * batch_size, 1))
# ---------------------------------------------------------------------------
# Simple BERT-like model with an MLM head and an NSP head.
# ---------------------------------------------------------------------------
input_ids = Input(shape=(seq_length,), dtype=tf.int32, name='input_ids')
# Embedding lookup instead of the former one_hot + Dense: materializing a
# one-hot tensor of shape (batch, seq_length, vocab_size=30522) costs roughly
# 500 MB per batch in float32, while Embedding performs the equivalent
# projection as a direct row lookup into its weight matrix.
embedding = tf.keras.layers.Embedding(vocab_size, 128)(input_ids)
sequence_output = Dense(128, activation='relu')(embedding)
# MLM head: per-position logits over the full vocabulary.
mlm_logits = Dense(vocab_size)(sequence_output)
# NSP head: mean-pool the sequence, then 2-way classification logits.
pooled_output = tf.reduce_mean(sequence_output, axis=1)
nsp_logits = Dense(2)(pooled_output)
model = Model(inputs=input_ids, outputs=[mlm_logits, nsp_logits])

# reduction='none' keeps the per-token MLM loss so the training step can
# average over masked positions only.
loss_fn_mlm = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
loss_fn_nsp = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
@tf.function
def train_step(x, y_mlm, y_nsp):
    """Run one optimization step over a batch and report losses/accuracies.

    Args:
        x: (batch, seq_length) int token ids; masked positions hold the
            [MASK] id.
        y_mlm: (batch, seq_length) int labels; 0 marks positions excluded
            from the MLM loss, non-zero entries are the original token ids.
        y_nsp: (batch, 1) int NSP labels in {0, 1}.

    Returns:
        Tuple of scalar tensors:
        (total_loss, mlm_loss, nsp_loss, mlm_acc, nsp_acc).
    """
    with tf.GradientTape() as tape:
        mlm_pred, nsp_pred = model(x, training=True)
        # Only masked positions (label != 0) contribute to the MLM loss.
        mask = tf.not_equal(y_mlm, 0)
        mlm_loss_all = loss_fn_mlm(y_mlm, mlm_pred)
        num_masked = tf.reduce_sum(tf.cast(mask, tf.float32))
        # divide_no_nan guards the (rare) batch with zero masked tokens,
        # which would otherwise yield NaN loss and NaN gradients.
        mlm_loss = tf.math.divide_no_nan(
            tf.reduce_sum(tf.boolean_mask(mlm_loss_all, mask)), num_masked)
        nsp_loss = loss_fn_nsp(y_nsp, nsp_pred)
        total_loss = mlm_loss + nsp_loss
    grads = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Accuracies. Cast labels to int64 explicitly: tf.argmax returns int64,
    # while numpy integer labels are platform-dependent (int32 on Windows),
    # which would make tf.equal raise a dtype mismatch error.
    masked_argmax = tf.boolean_mask(tf.argmax(mlm_pred, axis=-1), mask)
    masked_y_mlm = tf.cast(tf.boolean_mask(y_mlm, mask), tf.int64)
    mlm_correct = tf.reduce_sum(
        tf.cast(tf.equal(masked_argmax, masked_y_mlm), tf.float32))
    # Same zero-masked-tokens guard as for the loss.
    mlm_acc = tf.math.divide_no_nan(mlm_correct, num_masked)
    nsp_labels = tf.cast(tf.reshape(y_nsp, [-1]), tf.int64)
    nsp_acc = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(nsp_pred, axis=-1), nsp_labels), tf.float32))
    return total_loss, mlm_loss, nsp_loss, mlm_acc, nsp_acc
# Training loop: carve consecutive batches out of the pre-generated arrays
# and log progress every tenth step.
for step in range(num_batches):
    rows = slice(step * batch_size, (step + 1) * batch_size)
    metrics = train_step(X_mlm[rows], Y_mlm[rows], Y_nsp[rows])
    total_loss, mlm_loss, nsp_loss, mlm_acc, nsp_acc = metrics
    if step % 10 == 0:
        print(f"Batch {step}: Total Loss={total_loss:.3f}, MLM Acc={mlm_acc:.3f}, NSP Acc={nsp_acc:.3f}")