import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# Synthetic placeholder data: swap these tensors for a real point-cloud
# dataset. Every sample is a cloud of 1024 points with 3 coordinates each, and
# labels are integers in [0, 10), i.e. 10 classes.
# NOTE: the four tensors are drawn from the global RNG in this exact order.
X_train = torch.randn(1000, 1024, 3)
y_train = torch.randint(0, 10, (1000,))
X_val = torch.randn(200, 1024, 3)
y_val = torch.randint(0, 10, (200,))

# Pair each feature tensor with its label tensor.
train_dataset, val_dataset = (
    TensorDataset(features, labels)
    for features, labels in ((X_train, y_train), (X_val, y_val))
)

# Mini-batches of 32; only the training split is reshuffled on every pass.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32)
class SimplePointNet(nn.Module):
    """Minimal PointNet-style classifier for unordered point clouds.

    Each point is lifted independently to a 256-dim feature by shared 1x1
    convolutions (in_channels -> 64 -> 128 -> 256, each followed by
    BatchNorm and ReLU). The per-point features are max-pooled over the point
    axis into one global descriptor. A small MLP head (dropout, 256 -> 128 with
    BatchNorm/ReLU, then 128 -> num_classes) maps that descriptor to logits.
    The only operation that mixes points is the max, so the output does not
    depend on the order of the points.

    Args:
        num_classes: Number of output logits. Defaults to 10.
        in_channels: Number of coordinates/features per point. Defaults to 3.
        dropout: Dropout probability applied to the pooled global feature.
            Defaults to 0.3.

    Note:
        ``bn4`` normalizes a ``(batch, 128)`` tensor. In training mode, a batch
        containing a single sample therefore raises an error. Use
        ``drop_last=True`` on the training DataLoader if that can happen.
    """

    def __init__(self, num_classes=10, in_channels=3, dropout=0.3):
        super().__init__()
        # Layers are created in the same order as before, so parameter
        # initialization under a fixed seed is unchanged.
        self.conv1 = nn.Conv1d(in_channels, 64, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.bn2 = nn.BatchNorm1d(128)
        self.conv3 = nn.Conv1d(128, 256, 1)
        self.bn3 = nn.BatchNorm1d(256)
        self.dropout = nn.Dropout(p=dropout)
        self.fc1 = nn.Linear(256, 128)
        self.bn4 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Point-cloud tensor of shape
                ``(batch, num_points, in_channels)``. Any ``num_points`` works.
        """
        # Conv1d expects channels first: (batch, in_channels, num_points).
        x = x.transpose(1, 2)
        x = nn.functional.relu(self.bn1(self.conv1(x)))
        x = nn.functional.relu(self.bn2(self.conv2(x)))
        x = nn.functional.relu(self.bn3(self.conv3(x)))
        # Symmetric max over the point axis gives an order-invariant global
        # feature of shape (batch, 256).
        x = x.max(dim=2).values
        x = self.dropout(x)
        x = nn.functional.relu(self.bn4(self.fc1(x)))
        return self.fc2(x)
model = SimplePointNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

NUM_EPOCHS = 30


def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in eval mode with gradients
    disabled. This lets the same accumulation code serve both training and
    validation.

    Raises:
        ValueError: if ``loader`` yields no samples, because the per-sample
            averages would be undefined.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            if training:
                optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if training:
                loss.backward()
                optimizer.step()
            # criterion averages over the batch. Re-weight by batch size so the
            # epoch loss is a true per-sample mean even if the last batch is short.
            batch_size = data.size(0)
            total_loss += loss.item() * batch_size
            correct += (output.argmax(dim=1) == target).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total, correct / total * 100


# Training loop: one training pass followed by one validation pass per epoch.
for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = _run_epoch(model, val_loader, criterion)
    print(
        f"Epoch {epoch+1}: Train loss {train_loss:.3f}, Train acc {train_acc:.1f}%, "
        f"Val loss {val_loss:.3f}, Val acc {val_acc:.1f}%"
    )