Embedding Analysis
This guide covers analyzing and interpreting embeddings generated by PhenoCLR models. Learn how to extract, visualize, and evaluate the quality of learned representations.
Overview
Embedding analysis helps you understand which features the model has learned and how well they capture biological patterns. It typically includes dimensionality reduction, clustering, and correlating the embeddings with biological metadata.
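As a minimal illustration of dimensionality reduction and metadata correlation, the sketch below projects an embedding matrix onto its top principal components using a plain NumPy SVD, then computes the Pearson correlation between each component and a numeric per-image metadata value. This is not PhenoCLR code: `pca_project`, `correlate_with_metadata`, the toy embeddings, and the `dose` variable are hypothetical. In practice, a library implementation such as scikit-learn's PCA is likely the more convenient choice.

```python
import numpy as np


def pca_project(embeddings: np.ndarray, n_components: int = 2):
    """Project row-wise embeddings onto their top principal components.

    Returns the projected coordinates and the fraction of total variance
    explained by each retained component.
    """
    # Center each feature dimension; PCA is defined on centered data.
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Thin SVD: rows of vt are the principal axes, sorted by singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    coords = centered @ vt[:n_components].T
    explained = (s**2) / np.sum(s**2)
    return coords, explained[:n_components]


def correlate_with_metadata(coords: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Pearson correlation between each projected component and a numeric
    metadata variable (one value per image)."""
    return np.array(
        [np.corrcoef(coords[:, i], values)[0, 1] for i in range(coords.shape[1])]
    )


# Toy data (hypothetical): 100 eight-dimensional embeddings in which
# dimension 0 carries most of the variance, plus a made-up "dose" value
# per image that tracks that dimension.
rng = np.random.default_rng(0)
emb = rng.normal(size=(100, 8))
emb[:, 0] *= 5.0
dose = emb[:, 0] + rng.normal(scale=0.1, size=100)

coords, explained = pca_project(emb, n_components=2)
r = correlate_with_metadata(coords, dose)
print("explained variance ratio:", explained.round(3))
print("correlation with dose per component:", r.round(2))
```

The sign of a principal component is arbitrary, so it is the magnitude of each correlation, not its sign, that indicates how strongly a component tracks the metadata.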
Extracting Embeddings
Basic Embedding Extraction
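```python
import torch
from sklearn.preprocessing import normalize  # assumed source of normalize: row-wise L2 by default
from tqdm import tqdm


def generate_embeddings(model, dataloader):
    """Generate L2-normalized representations for all images in the
    dataloader with the given model.

    Expects batches of (images, labels, filenames) and a model that exposes
    a `backbone` module and a `device` attribute (as a LightningModule does).
    """
    model.eval()  # disable dropout and use running BatchNorm statistics
    embeddings = []
    filenames = []
    with torch.no_grad():
        for img, _, fnames in tqdm(dataloader, desc="Generating embeddings", unit="batch"):
            img = img.to(model.device)
            emb = model.backbone(img).flatten(start_dim=1)
            # Move to CPU immediately to free GPU memory
            embeddings.append(emb.cpu())
            filenames.extend(fnames)
            # Release cached GPU memory after each batch (skipped without CUDA)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    embeddings = torch.cat(embeddings, 0)
    # Scale each row to unit L2 norm; the result is a NumPy array
    embeddings = normalize(embeddings.numpy())
    return embeddings, filenames
```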
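Because every row is L2-normalized before it is returned, the cosine similarity between two images reduces to a plain dot product of their embeddings. A minimal NumPy sketch of a nearest-neighbor lookup built on that property follows; `l2_normalize` (written to mirror the default behavior of scikit-learn's `normalize`) and `nearest_neighbors` are hypothetical helpers, and the vectors and filenames are made up for illustration.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm (eps guards against all-zero rows)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


def nearest_neighbors(embeddings: np.ndarray, query_idx: int, k: int = 3):
    """Return indices and cosine similarities of the k images most similar
    to embeddings[query_idx]. Assumes rows are already L2-normalized, so a
    dot product equals cosine similarity."""
    sims = embeddings @ embeddings[query_idx]  # new array, inputs untouched
    sims[query_idx] = -np.inf  # exclude the query image itself
    order = np.argsort(-sims)[:k]  # highest similarity first
    return order, sims[order]


# Toy example with hypothetical embeddings and filenames.
raw = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
filenames = ["a.png", "b.png", "c.png", "d.png"]
emb = l2_normalize(raw)
idx, sims = nearest_neighbors(emb, query_idx=0, k=2)
print([filenames[i] for i in idx], sims.round(3))
```

For large datasets, an approximate nearest-neighbor index would usually replace the brute-force matrix product, but the normalization logic stays the same.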
The generate_embeddings function is taken from the training script. If a model has been trained but its embeddings.csv file was not generated, use one of the generate_embeddings*.py scripts to create it.
For example:

```bash
python embeddings/generate_embeddings.py --model_path /path/to/model.ckpt --config /path/to/config.yaml --data_dir /path/to/data --output_dir /path/to/output
```
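The column layout of embeddings.csv is not documented in this guide. If you need to write or read such a file yourself, a sketch using the standard `csv` module might look like the following. It assumes one row per image, with the filename first and one column per embedding dimension; that layout, the header names, and both helper functions are assumptions, so check the generate_embeddings*.py scripts for the format they actually produce.

```python
import csv
import os
import tempfile

import numpy as np


def save_embeddings_csv(path, embeddings, filenames):
    """Write one row per image: filename, then its embedding values.
    NOTE: this column layout is an assumption for illustration only."""
    embeddings = np.asarray(embeddings)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["filename"] + [f"dim_{i}" for i in range(embeddings.shape[1])])
        for name, row in zip(filenames, embeddings):
            # tolist() yields Python floats, whose str() form round-trips exactly
            writer.writerow([name, *row.tolist()])


def load_embeddings_csv(path):
    """Read a file written by save_embeddings_csv back into
    (embeddings array, filenames list)."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        rows = list(reader)
    filenames = [r[0] for r in rows]
    embeddings = np.array([[float(v) for v in r[1:]] for r in rows])
    return embeddings, filenames


# Round-trip a tiny, made-up set of normalized embeddings.
emb = np.array([[0.6, 0.8], [1.0, 0.0]])
names = ["img_001.png", "img_002.png"]
with tempfile.TemporaryDirectory() as tmp_dir:
    csv_path = os.path.join(tmp_dir, "embeddings.csv")
    save_embeddings_csv(csv_path, emb, names)
    loaded, loaded_names = load_embeddings_csv(csv_path)
print(loaded_names, loaded.shape)
```

Keeping the filename in the same row as its vector means the two cannot drift out of alignment when the file is sorted or filtered later.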
Next Steps
- Visualization Tools: Advanced visualization techniques
- Statistical Metrics: Detailed statistical analysis
- Examples: Practical analysis examples