import os
import json
import joblib
from datetime import datetime
class SimpleModelRegistry:
    """Minimal file-based registry of versioned, joblib-serialized models.

    Each registered model is written to
    ``<registry_dir>/model_v<version>.joblib`` and described by one entry in
    ``<registry_dir>/metadata.json``. That file is a JSON list of dicts with
    the keys ``version``, ``model_path``, ``metrics`` and ``registered_at``.

    No locking is performed. ``register_model`` is a read-modify-write of the
    metadata file, so concurrent writers can lose each other's entries.
    """

    def __init__(self, registry_dir='model_registry'):
        """Create the registry directory and an empty metadata file if missing.

        Args:
            registry_dir: Directory that holds the model files and
                ``metadata.json``. Created if it does not exist.
        """
        self.registry_dir = registry_dir
        os.makedirs(self.registry_dir, exist_ok=True)
        self.metadata_file = os.path.join(self.registry_dir, 'metadata.json')
        if not os.path.exists(self.metadata_file):
            self._atomic_write(json.dumps([]))

    def _read_metadata(self):
        """Return the list of metadata entries currently stored on disk."""
        with open(self.metadata_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _atomic_write(self, text):
        """Replace the metadata file's contents with ``text`` atomically.

        The text goes to a sibling temp file, which is then swapped in with
        ``os.replace``. A crash or error part-way through therefore never
        leaves a truncated ``metadata.json`` behind.
        """
        tmp_path = self.metadata_file + '.tmp'
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(text)
            os.replace(tmp_path, self.metadata_file)
        except OSError:
            # Don't leave a half-written temp file lying around.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def register_model(self, model, version, metrics):
        """Persist ``model`` under ``version`` and record it in the metadata.

        Re-registering an existing version overwrites its model file. The
        previous metadata entry for that version is therefore replaced rather
        than kept next to the new one.

        Args:
            model: Any object joblib can serialize.
            version: Version identifier. It is embedded in the file name and
                compared with ``max`` by ``load_model`` to pick the latest.
            metrics: JSON-serializable dict of evaluation metrics.

        Raises:
            TypeError: If the new entry (e.g. ``metrics``) is not
                JSON-serializable. Raised before anything is written to disk.
        """
        model_path = os.path.join(self.registry_dir, f'model_v{version}.joblib')
        # The model file for this version is about to be overwritten, so any
        # old entry's metrics would no longer describe it.
        metadata = [m for m in self._read_metadata() if m['version'] != version]
        metadata.append({
            'version': version,
            'model_path': model_path,
            'metrics': metrics,
            'registered_at': datetime.now().isoformat(),
        })
        # Serialize before touching disk. Values json cannot encode (e.g.
        # numpy integer / float32 scalars) then fail fast instead of
        # corrupting metadata.json or leaving an orphaned model file.
        payload = json.dumps(metadata, indent=2)
        joblib.dump(model, model_path)
        self._atomic_write(payload)
        print(f'Model version {version} registered successfully.')

    def load_model(self, version=None):
        """Load a registered model.

        Args:
            version: Version to load. ``None`` loads the entry whose
                ``version`` is largest under ``max``. Note that string
                versions compare lexicographically.

        Returns:
            The deserialized model. Returns ``None``, after printing a
            message, if no models are registered or ``version`` is not found.
        """
        metadata = self._read_metadata()
        if not metadata:
            print('No models registered yet.')
            return None
        if version is None:
            entry = max(metadata, key=lambda m: m['version'])
            version = entry['version']
        else:
            entry = next((m for m in metadata if m['version'] == version), None)
            if entry is None:
                print(f'Model version {version} not found.')
                return None
        # joblib.load unpickles the file and can execute arbitrary code.
        # Only load from a registry directory you trust.
        model = joblib.load(entry['model_path'])
        print(f'Model version {version} loaded.')
        return model
# Example usage: train a small classifier, register it, then reload the
# latest registered version and re-score it to show the round trip.
def _run_example():
    """Train, register, reload and evaluate a decision tree on iris."""
    from sklearn.datasets import load_iris
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier

    iris = load_iris()
    features_train, features_test, labels_train, labels_test = train_test_split(
        iris.data, iris.target, test_size=0.3, random_state=42
    )

    classifier = DecisionTreeClassifier(random_state=42)
    classifier.fit(features_train, labels_train)
    original_accuracy = accuracy_score(
        labels_test, classifier.predict(features_test)
    )

    registry = SimpleModelRegistry()
    registry.register_model(
        classifier, version=1, metrics={'accuracy': original_accuracy}
    )

    restored = registry.load_model()  # no version given -> latest registered
    restored_accuracy = accuracy_score(
        labels_test, restored.predict(features_test)
    )
    print(f'Loaded model accuracy: {restored_accuracy:.2f}')


if __name__ == '__main__':
    _run_example()