Embedding Analysis

This guide covers analyzing and interpreting embeddings generated by PhenoCLR models. Learn how to extract, visualize, and evaluate the quality of learned representations.

Overview

Embedding analysis helps you understand which features the model has learned and how well they capture biological patterns. It typically involves dimensionality reduction, clustering analysis, and correlation of the embeddings with biological metadata.
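As a rough illustration of the dimensionality reduction and metadata correlation steps, the sketch below projects embeddings onto their top principal components with a plain NumPy SVD and checks how strongly the first component tracks a binary label. The data is synthetic, and the names pca_project and labels are hypothetical stand-ins, not part of PhenoCLR.

```python
import numpy as np

def pca_project(embeddings: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Project row-wise embeddings onto their top principal components."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T

# Synthetic stand-in data: two groups of 50 "embeddings" in 16 dimensions
rng = np.random.default_rng(0)
group_a = rng.normal(loc=0.0, scale=0.1, size=(50, 16))
group_b = rng.normal(loc=1.0, scale=0.1, size=(50, 16))
embeddings = np.vstack([group_a, group_b])

# Hypothetical binary metadata, e.g. control vs. treated
labels = np.array([0] * 50 + [1] * 50)

coords = pca_project(embeddings, n_components=2)

# Absolute Pearson correlation between the first component and the label
pc1_label_corr = abs(np.corrcoef(coords[:, 0], labels)[0, 1])
print(coords.shape, round(float(pc1_label_corr), 3))
```

On real embeddings the same projection can be plotted and colored by any metadata column; a high correlation suggests that the dominant axis of variation reflects that variable.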

Extracting Embeddings

Basic Embedding Extraction

import torch
import numpy as np
from pathlib import Path
from torch.nn.functional import normalize
from tqdm import tqdm

def generate_embeddings(model, dataloader):
    """Generate L2-normalized representations for all images in the
    dataloader with the given model.

    Returns a (num_images, embedding_dim) CPU tensor and the matching
    list of filenames.
    """
    embeddings = []
    filenames = []
    with torch.no_grad():
        # Each batch is unpacked as (images, <unused>, filenames)
        for img, _, fnames in tqdm(dataloader, desc="Generating embeddings", unit="batch"):
            img = img.to(model.device)
            emb = model.backbone(img).flatten(start_dim=1)
            # Move to CPU immediately to free GPU memory
            embeddings.append(emb.cpu())
            filenames.extend(fnames)
            # Clear GPU cache after each batch
            torch.cuda.empty_cache()

    embeddings = torch.cat(embeddings, 0)
    # Row-wise L2 normalization; torch.nn.functional.normalize is assumed here
    embeddings = normalize(embeddings, p=2, dim=1)
    return embeddings, filenames

The generate_embeddings function above is taken from the training script. If a model has been trained but the embeddings.csv file was not generated, you can use the generate_embeddings*.py scripts to create it.

Example:

python embeddings/generate_embeddings.py --model_path /path/to/model.ckpt --config /path/to/config.yaml --data_dir /path/to/data --output_dir /path/to/output
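If you need to write or read an embeddings file yourself rather than through the script, a minimal sketch follows. The layout of the embeddings.csv produced by PhenoCLR is not shown in this guide, so the column names (filename, dim_0, dim_1, ...) and the helpers save_embeddings_csv and load_embeddings_csv are assumptions for illustration only.

```python
import csv
import numpy as np

def save_embeddings_csv(embeddings, filenames, path):
    """Write one row per image: the filename followed by the embedding values.

    Assumed layout, not necessarily the one used by the PhenoCLR scripts.
    """
    embeddings = np.asarray(embeddings, dtype=np.float32)
    if embeddings.shape[0] != len(filenames):
        raise ValueError("embeddings and filenames must have the same length")
    header = ["filename"] + [f"dim_{i}" for i in range(embeddings.shape[1])]
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for name, row in zip(filenames, embeddings):
            writer.writerow([name, *row.tolist()])

def load_embeddings_csv(path):
    """Read a file written by save_embeddings_csv back into an array and a list."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        rows = list(reader)
    filenames = [r[0] for r in rows]
    embeddings = np.array([[float(v) for v in r[1:]] for r in rows], dtype=np.float32)
    return embeddings, filenames
```

A CPU tensor returned by generate_embeddings can be passed directly, since np.asarray accepts it. After loading, a quick sanity check is that np.linalg.norm(embeddings, axis=1) is close to 1 for every row, which should hold if the embeddings were L2-normalized before saving.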

Next Steps