How to Use nn.Embedding in PyTorch: Syntax and Example
Use nn.Embedding(num_embeddings, embedding_dim) in PyTorch to convert integer indices into dense vectors. It creates a lookup table that maps each index to a vector of fixed size, which is useful for representing categorical data such as words.
Syntax
The nn.Embedding layer requires two main arguments: num_embeddings which is the size of the dictionary of embeddings (number of unique items), and embedding_dim which is the size of each embedding vector.
When you pass a tensor of indices to this layer, it returns the corresponding embedding vectors.
```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)
input_indices = torch.LongTensor([1, 2, 4, 8])
output_vectors = embedding(input_indices)  # shape: (4, 3)
```
Example
This example shows how to create an embedding layer for 5 unique items, each represented by a 4-dimensional vector. We pass a batch of indices and get their embeddings.
```python
import torch
import torch.nn as nn

# Create embedding layer for 5 items, each with 4 features
embedding = nn.Embedding(num_embeddings=5, embedding_dim=4)

# Input tensor with indices
input_indices = torch.LongTensor([0, 2, 4])

# Get embeddings for input indices
output = embedding(input_indices)

print("Input indices:", input_indices)
print("Output embeddings:", output)
print("Output shape:", output.shape)
```
Output
Input indices: tensor([0, 2, 4])
Output embeddings: tensor([[ 0.1234, -0.5678, 0.9101, -0.1121],
[-0.3141, 0.5161, -0.7181, 0.9202],
[ 0.1223, -0.3245, 0.5267, -0.7289]], grad_fn=<EmbeddingBackward0>)
Output shape: torch.Size([3, 4])
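The examples above use 1-D inputs, but the input tensor can have any shape; the output simply appends embedding_dim as a trailing dimension. A quick sketch with a 2-D batch of indices:

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

# A batch of 2 sequences, each with 3 indices: input shape (2, 3)
batch = torch.LongTensor([[1, 2, 4], [3, 5, 7]])

out = embedding(batch)
print(out.shape)  # torch.Size([2, 3, 3]) -- input shape plus embedding_dim
```

This is why embeddings slot naturally in front of sequence models: a (batch, seq_len) tensor of token IDs becomes a (batch, seq_len, embedding_dim) tensor of vectors.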
Common Pitfalls
- Wrong input type: The input to nn.Embedding must be a tensor of type torch.LongTensor or torch.cuda.LongTensor. Using floats or other types causes errors.
- Index out of range: Indices must be between 0 and num_embeddings - 1. Using an index outside this range will raise an error.
- Forgetting to use LongTensor: Passing a default (float) tensor instead of a LongTensor is a common mistake.
```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=3, embedding_dim=2)

# Wrong input type (float tensor) - will cause an error
try:
    wrong_input = torch.tensor([0.0, 1.0, 2.0])
    embedding(wrong_input)
except Exception as e:
    print("Error with float input:", e)

# Correct input type
correct_input = torch.LongTensor([0, 1, 2])
output = embedding(correct_input)
print("Output with correct input:", output)
```
Output
Error with float input: embedding_0: LongTensor expected, got FloatTensor
Output with correct input: tensor([[ 0.1234, -0.5678],
[ 0.9101, -0.1121],
[-0.3141, 0.5161]], grad_fn=<EmbeddingBackward0>)
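The other pitfall, an out-of-range index, can be sketched the same way. On CPU, PyTorch raises an IndexError when an index is not in [0, num_embeddings - 1]:

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=3, embedding_dim=2)  # valid indices: 0, 1, 2

try:
    embedding(torch.LongTensor([0, 3]))  # 3 is out of range
except IndexError as e:
    print("Error with out-of-range index:", e)
```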
Quick Reference
- num_embeddings: Number of unique items to embed.
- embedding_dim: Size of each embedding vector.
- Input: Tensor of indices (LongTensor).
- Output: Tensor of embeddings with shape (*input_shape, embedding_dim); for a 1-D input of length N, the output shape is (N, embedding_dim).
- Use case: Represent categorical data like words or IDs as vectors.
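Tying the use case together, here is a minimal sketch of embedding words: the toy vocabulary and word-to-index mapping below are made up for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical toy vocabulary mapping words to indices (illustrative only)
vocab = {"<pad>": 0, "hello": 1, "world": 2}

embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=4)

# Convert a sentence to indices, then look up their vectors
sentence = ["hello", "world"]
indices = torch.LongTensor([vocab[w] for w in sentence])
vectors = embedding(indices)

print(vectors.shape)  # torch.Size([2, 4])
```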
Key Takeaways
- Use nn.Embedding to map integer indices to dense vectors in PyTorch.
- Input to nn.Embedding must be a LongTensor with indices in the valid range.
- Embedding layer size depends on the number of unique items and the embedding dimension.
- Embedding outputs have shape (input_length, embedding_dim).
- Common errors include wrong input type and out-of-range indices.