224 changes: 224 additions & 0 deletions BLOGPOST.md
@@ -0,0 +1,224 @@
# DataMetaMap: Why Compare Datasets? A Method-Driven Blogpost

Understanding **dataset similarity** is a hidden key to transfer learning. If we can measure how "close" one dataset is to another, we can make smarter choices about which model to pre-train on—saving time and boosting performance.

But how do you embed entire datasets into a shared vector space? Our new library, **DataMetaMap**, implements four powerful, research-backed approaches. Below, we walk through the **key methodological insight** behind each one.

---

## 1. Maximum Mean Discrepancy (MMD) — A Classical Kernel View

**Based on:** *A Kernel Two-Sample Test* (Gretton et al., 2012; see also the review in arXiv:2208.11726)

### Core Idea

MMD answers a fundamental question: *Are two datasets sampled from the same distribution?* Unlike deep learning approaches, MMD works without training—it directly computes a distance between distributions using kernel functions.

### Mathematical Formulation

Let $P$ and $Q$ be two probability distributions. Given samples $X = \{x_1, ..., x_m\} \sim P$ and $Y = \{y_1, ..., y_n\} \sim Q$, the squared MMD is:

$$\text{MMD}^2(P, Q) = \left\| \mathbb{E}_{x \sim P}[\phi(x)] - \mathbb{E}_{y \sim Q}[\phi(y)] \right\|^2_{\mathcal{H}}$$

where $\phi$ maps data into a Reproducing Kernel Hilbert Space (RKHS) $\mathcal{H}$. Using the kernel trick $k(x, x') = \langle \phi(x), \phi(x') \rangle_{\mathcal{H}}$, we get:

$$\text{MMD}^2 = \mathbb{E}_{x, x' \sim P}[k(x, x')] - 2\mathbb{E}_{x \sim P, y \sim Q}[k(x, y)] + \mathbb{E}_{y, y' \sim Q}[k(y, y')]$$

In practice, we use the unbiased empirical estimate:

$$\widehat{\text{MMD}}^2 = \frac{1}{m(m-1)}\sum_{i \neq j} k(x_i, x_j) - \frac{2}{mn}\sum_{i,j} k(x_i, y_j) + \frac{1}{n(n-1)}\sum_{i \neq j} k(y_i, y_j)$$

### Key Observations

- **No training required** — MMD works directly on raw features or neural network representations
- **Choice of kernel matters** — RBF (Gaussian) kernels with bandwidth selection are standard; DataMetaMap supports multiple kernels
- **Computational cost** — the $O((m+n)^2)$ kernel evaluations keep it practical only for moderate sample sizes

### How to Use in DataMetaMap

Pass two datasets to the MMD embedder to obtain a scalar distance. To build an embedding, we compute MMD distances from a dataset to a set of reference datasets and use the resulting distance vector.
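
The estimator itself fits in a few lines. Below is a minimal, self-contained sketch of the unbiased $\widehat{\text{MMD}}^2$ with an RBF kernel; the function names and the fixed bandwidth are illustrative, not DataMetaMap's API (in practice the median pairwise distance is a common bandwidth heuristic).

```python
import torch

def rbf_kernel(a, b, bandwidth):
    # k(x, y) = exp(-||x - y||^2 / (2 * bandwidth^2))
    sq_dists = torch.cdist(a, b) ** 2
    return torch.exp(-sq_dists / (2 * bandwidth ** 2))

def mmd2_unbiased(x, y, bandwidth=1.0):
    """Unbiased MMD^2 between samples x of shape (m, d) and y of shape (n, d)."""
    m, n = x.shape[0], y.shape[0]
    k_xx = rbf_kernel(x, x, bandwidth)
    k_yy = rbf_kernel(y, y, bandwidth)
    k_xy = rbf_kernel(x, y, bandwidth)
    # Diagonals are dropped so the within-sample averages stay unbiased
    term_xx = (k_xx.sum() - k_xx.diagonal().sum()) / (m * (m - 1))
    term_yy = (k_yy.sum() - k_yy.diagonal().sum()) / (n * (n - 1))
    return (term_xx - 2.0 * k_xy.mean() + term_yy).item()

x = torch.randn(500, 8)          # sample from P
y = torch.randn(400, 8) + 0.5    # sample from a mean-shifted Q
print(mmd2_unbiased(x, y))       # clearly positive under distribution shift
```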

**Best for:** Quick baseline comparisons, detecting dataset shift, benchmarking other methods.

---

## 2. Task2Vec — Embedding Tasks via Fisher Information

**Based on:** *Task2Vec: Task Embedding for Meta-Learning* (Achille et al., arXiv:1905.11063)

### Core Idea

Every dataset defines a "task" for a neural network. The Fisher Information Matrix (FIM) tells us which parameters are most important for that task. By computing the diagonal of the FIM, Task2Vec creates a vector that captures the task's geometry.

### Mathematical Formulation

For a model with parameters $\theta$ and a dataset $\mathcal{D} = \{(x_i, y_i)\}$ with loss $\mathcal{L}(x, y; \theta)$, the Fisher Information Matrix is:

$$F(\theta) = \mathbb{E}_{x, y \sim \mathcal{D}}\left[ \nabla_\theta \log p(y|x; \theta) \nabla_\theta \log p(y|x; \theta)^\top \right]$$

Computing the full $F$ (one row and column per parameter) is prohibitive for modern networks, so Task2Vec uses the **diagonal approximation**:

$$f_k = \mathbb{E}_{x, y \sim \mathcal{D}}\left[ \left( \frac{\partial \log p(y|x; \theta)}{\partial \theta_k} \right)^2 \right]$$

The task embedding is then:

$$z_{\text{task}} = \text{diag}(F) \quad \text{or} \quad z_{\text{task}} = \log \text{diag}(F)$$

After fine-tuning the classifier head of the probe (reference) network on $\mathcal{D}$, we estimate these per-parameter importances by Monte Carlo, sampling $y$ from the model's predictive distribution $p(y|x; \theta)$.

### Key Observations

- **Reference network dependent** — Different architectures produce different similarity judgments
- **Fine-tuning is required** — Each dataset needs adaptation of the base model
- **Embedding dimensionality** — Equals the number of retained network parameters (typically millions), often reduced via PCA or by skipping early layers
- **Log-transform** helps stabilize high-variance Fisher entries

### How to Use in DataMetaMap

1. Choose a reference network (e.g., ResNet-18 pretrained on ImageNet)
2. For each dataset, fine-tune for a few epochs
3. Compute the diagonal Fisher Information matrix (a sketch of this step follows the list)
4. Return the flattened vector (optionally log-transformed)
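
In DataMetaMap these steps sit behind a single call (the tests in this PR construct `Task2Vec(model, skip_layers=6, max_samples=200)` and call `.embed(dataset)`). The core of step 3 looks roughly like the sketch below, a Monte Carlo estimate of the diagonal Fisher for a generic PyTorch classifier; it is a conceptual illustration, not the library's exact implementation.

```python
import torch
import torch.nn.functional as F

def diagonal_fisher(model, dataset, max_samples=200, device="cpu"):
    """Monte Carlo estimate of diag(F) = E[(d log p(y|x)/d theta)^2].
    Labels y are sampled from the model's own predictive distribution."""
    model.to(device).eval()
    grad2 = [torch.zeros_like(p) for p in model.parameters()]
    n = min(max_samples, len(dataset))
    for i in range(n):
        x, _ = dataset[i]                          # the data label is not used
        logits = model(x.unsqueeze(0).to(device))
        y = torch.distributions.Categorical(logits=logits).sample()
        loss = F.cross_entropy(logits, y)          # -log p(y|x)
        model.zero_grad()
        loss.backward()
        for g2, p in zip(grad2, model.parameters()):
            if p.grad is not None:
                g2 += p.grad.detach() ** 2         # accumulate squared gradients
    diag = torch.cat([g.flatten() for g in grad2]) / n
    return torch.log(diag + 1e-8)                  # log-transform for stability
```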

**Best for:** Comparing classification tasks when you have a good reference model.

---

## 3. Dataset2Vec — Learning Dataset Representations

**Based on:** *Dataset2Vec: Learning Dataset Meta-Features* (Jomaa et al., arXiv:1902.03545)

### Core Idea

Why compute Fisher matrices when we can learn to embed datasets directly? Dataset2Vec is a **meta-learning** approach: train a neural network that encodes any dataset (as a set of labeled examples) into a fixed-size vector, then optimize this encoder to predict something useful (like relative task similarity).

### Mathematical Formulation

Dataset2Vec processes a dataset as an unordered set:

$$z_{\mathcal{D}} = f_{\text{pool}} \left( \{ g(x_i, y_i) \mid (x_i, y_i) \in \mathcal{D} \} \right)$$

where:
- $g$ is a per-example encoder (typically a small MLP processing the concatenated input and one-hot label)
- $f_{\text{pool}}$ is a permutation-invariant pooling function (sum, mean, or max)
- The output $z_{\mathcal{D}}$ is a $d$-dimensional vector (e.g., $d=128$)

The encoder is meta-trained with a ranking objective: given triplets of datasets $\mathcal{D}_a, \mathcal{D}_b, \mathcal{D}_c$ where $\mathcal{D}_a$ is more similar to $\mathcal{D}_b$ than to $\mathcal{D}_c$ (in terms of downstream transfer performance), we minimize the margin ranking loss:

$$\mathcal{L} = \max\left(0, \|z_a - z_b\|^2 - \|z_a - z_c\|^2 + \alpha\right)$$
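
A minimal sketch of both ingredients, assuming plain feature vectors and integer labels; the `SetEncoder` name, layer sizes, and random batches are illustrative assumptions, not the library's implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SetEncoder(nn.Module):
    """Minimal Dataset2Vec-style encoder: per-example MLP g + mean pooling."""
    def __init__(self, in_dim, n_classes, emb_dim=128):
        super().__init__()
        self.n_classes = n_classes
        self.g = nn.Sequential(                    # per-example encoder g
            nn.Linear(in_dim + n_classes, 256), nn.ReLU(),
            nn.Linear(256, emb_dim),
        )

    def forward(self, x, y):
        y_onehot = F.one_hot(y, self.n_classes).float()
        per_example = self.g(torch.cat([x, y_onehot], dim=1))
        return per_example.mean(dim=0)             # permutation-invariant pooling

encoder = SetEncoder(in_dim=32, n_classes=10)
opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)
margin = 1.0

def triplet_step(batch_a, batch_b, batch_c):
    # Each batch is an (x, y) pair sampled from one dataset
    z_a = encoder(*batch_a)
    z_b = encoder(*batch_b)
    z_c = encoder(*batch_c)
    loss = F.relu((z_a - z_b).pow(2).sum() - (z_a - z_c).pow(2).sum() + margin)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

make = lambda: (torch.randn(64, 32), torch.randint(0, 10, (64,)))
print(triplet_step(make(), make(), make()))
```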

### Key Observations

- **Once trained, inference is fast** — No per-dataset fine-tuning or Fisher computation
- **Meta-training requires many datasets** — Typically hundreds or thousands
- **Permutation invariance** ensures the embedding doesn't depend on data order
- **Generalization potential** — Can embed datasets not seen during meta-training

### How to Use in DataMetaMap

Our library includes:
- Pre-trained Dataset2Vec models on standard benchmarks (e.g., Meta-Dataset)
- Ability to train your own meta-encoder on custom dataset collections
- Support for various pooling strategies and per-example encoders

**Best for:** Large-scale dataset retrieval when you have many datasets and can afford meta-training.

---

## 4. Wasserstein Task Embedding — Optimal Transport Between Datasets

**Based on:** *Wasserstein Task Embedding for Meta-Learning* (Lee et al., arXiv:1605.09522)

### Core Idea

Instead of comparing datasets through a model, compare them directly using **optimal transport**. The Wasserstein distance measures how much "mass" you must move to transform one probability distribution into another. This geometric viewpoint respects the underlying feature space structure.

### Mathematical Formulation

For two probability distributions $\mu$ and $\nu$ on $\mathbb{R}^d$, the $p$-Wasserstein distance is:

$$W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{\mathbb{R}^d \times \mathbb{R}^d} \|x - y\|^p d\gamma(x, y) \right)^{1/p}$$

where $\Gamma(\mu, \nu)$ is the set of all couplings (joint distributions) with marginals $\mu$ and $\nu$.

For empirical distributions (our datasets), we solve:
- **1D case** (after projecting features, with equal sample sizes): $W_1(\hat{\mu}, \hat{\nu}) = \frac{1}{n} \sum_{i=1}^n |X_{(i)} - Y_{(i)}|$, where $X_{(i)}, Y_{(i)}$ are the sorted samples
- **High-dimensional case**: Use the entropy-regularized Sinkhorn algorithm, which costs $O(n^2)$ per iteration
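
As a quick sanity check of the 1D formula, a sketch with synthetic data; `scipy.stats.wasserstein_distance` computes the same quantity and also handles unequal sample sizes:

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
x, y = rng.normal(0.0, 1.0, 1000), rng.normal(0.5, 1.0, 1000)

w1_sorted = np.abs(np.sort(x) - np.sort(y)).mean()  # closed form, equal sizes
w1_scipy = wasserstein_distance(x, y)               # general 1D solver
print(w1_sorted, w1_scipy)  # both approximately 0.5, the mean shift
```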

To create an **embedding**, Wasserstein Task Embedding computes distances to $K$ reference distributions:

$$z_{\mathcal{D}} = [W(\mathcal{D}, R_1), W(\mathcal{D}, R_2), ..., W(\mathcal{D}, R_K)]$$

Reference distributions can be:
- Randomly sampled subsets from a large meta-dataset
- Prototypical distributions (e.g., Gaussian with different covariances)
- Other datasets in your collection

### Key Observations

- **No training required** — Works directly on features (e.g., penultimate layer of a frozen network)
- **Handles different dataset sizes** — Couplings with non-uniform weights let optimal transport compare empirical measures with $n_1 \neq n_2$ directly
- **Computational cost** — roughly $O(n^3 \log n)$ for exact optimal transport (network simplex), $O(n^2)$ per Sinkhorn iteration
- **Choice of ground distance** — Euclidean is standard, but any metric works (e.g., cosine distance for embeddings)

### How to Use in DataMetaMap

1. Extract features for all examples using a frozen pre-trained network
2. Choose reference distributions (e.g., 50 random datasets from a meta-collection)
3. For each dataset, compute the Wasserstein distance to each reference (see the sketch after this list)
4. Return the $K$-dimensional distance vector as the embedding
5. Optionally apply dimensionality reduction (PCA) if $K$ is large
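
Steps 2–4 in a self-contained sketch, using sliced Wasserstein (the average of 1D distances over random projections) as a cheap stand-in for the exact high-dimensional distance; the function names and synthetic features below are illustrative assumptions:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def sliced_w1(a, b, n_proj=50, seed=0):
    """Sliced 1-Wasserstein: average W1 over random 1D projections."""
    rng = np.random.default_rng(seed)
    dirs = rng.normal(size=(n_proj, a.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)  # unit directions
    return float(np.mean([wasserstein_distance(a @ u, b @ u) for u in dirs]))

def wasserstein_embedding(features, references):
    """Embed one dataset as its vector of distances to K reference sets."""
    return np.array([sliced_w1(features, ref) for ref in references])

# Synthetic "features" (n_examples x feature_dim); sizes may differ per set
rng = np.random.default_rng(1)
refs = [rng.normal(loc=i, size=(300, 16)) for i in range(5)]
target = rng.normal(loc=2.0, size=(450, 16))
z = wasserstein_embedding(target, refs)
print(z.round(2))  # smallest entry at the reference with the matching mean
```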

**Best for:** Comparing datasets with imbalanced classes, different sizes, or when you want a geometry-aware metric.

---

## What DataMetaMap Does

Our library implements all four methods **in a unified PyTorch interface**:

- **Unified API** — Same `fit()` and `transform()` pattern across all embedders
- **Flexible feature extraction** — Raw data, pre-trained features, or learned representations
- **Reference management** — For MMD and Wasserstein methods, handle reference dataset selection
- **Visualization tools** — PCA, t-SNE, and UMAP projections of dataset embeddings
- **Similarity search** — Find nearest datasets to a target
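
A hypothetical usage sketch of that pattern. The `MMDEmbedder` name and import path are assumptions for illustration, not the confirmed API; the repo's demos show the real calls:

```python
import numpy as np
from data_meta_map import MMDEmbedder  # assumed import path, for illustration

# reference_datasets / target_dataset / candidate_datasets are placeholders
# for your own dataset objects
embedder = MMDEmbedder(kernel="rbf")
embedder.fit(reference_datasets)          # store the reference datasets
z = embedder.transform(target_dataset)    # distance vector to the references

# Similarity search on top of the embeddings
Z = np.stack([embedder.transform(d) for d in candidate_datasets])
nearest = np.argsort(np.linalg.norm(Z - z, axis=1))[:5]
```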

**The sketches above are minimal illustrations; the repo contains ready-to-run demos for every method.**

---

## Method Comparison at a Glance

| Method | Training Required | Inference Speed | Dimensionality | Handles Different Sizes | Geometric Interpretation |
|--------|:----------------:|:---------------:|:--------------:|:-----------------------:|:------------------------:|
| MMD | None | Medium (quadratic) | n_refs | Yes | RKHS distance |
| Task2Vec | Per-dataset fine-tuning | Slow (per dataset) | # Parameters | N/A (fixed network) | Fisher information |
| Dataset2Vec | Meta-training (once) | Fast | Fixed (e.g., 128) | Yes | Learned similarity |
| Wasserstein | None | Slow (quadratic) | n_refs | Yes | Optimal transport |

---

## Key Insight Across All Methods

Despite their different mathematical origins (kernel methods, Fisher information, learned encoders, optimal transport), **all four approaches reduce to the same operation**: mapping a dataset to a vector where Euclidean distance correlates with transfer learning performance. DataMetaMap lets you compare which method works best for your domain.

---

## Practical Recommendations from Our Observations

- **For quick baselines** → Start with MMD on pre-trained features
- **When you have a strong reference model** → Try Task2Vec with few-shot fine-tuning
- **When you have many datasets for training** → Train a Dataset2Vec meta-encoder
- **When dataset sizes vary greatly** → Wasserstein embedding is your best bet
- **When computational budget is high** → Ensemble multiple methods

---

## References

- Gretton et al. (2012) – *A Kernel Two-Sample Test.* JMLR 13:723–773; see also the review in arXiv:2208.11726
- Achille et al. (2019) – *Task2Vec: Task Embedding for Meta-Learning.* arXiv:1905.11063
- Jomaa et al. (2019) – *Dataset2Vec: Learning Dataset Meta-Features.* arXiv:1902.03545
- Lee et al. (2016) – *Wasserstein Task Embedding for Meta-Learning.* arXiv:1605.09522

**Our repo:** [DataMetaMap](https://github.com/intsystems/DataMetaMap)
6 changes: 3 additions & 3 deletions README.md
@@ -45,9 +45,9 @@ It includes multiple dataset embedding algorithms implemented on top of PyTorch:

## 📬 Assets

1. [Technical Meeting 1 - Presentation](assets/BMM_technical_1.pdf)
2. [Blog Post](https://data-meta-map.hashnode.dev/metamap-your-compass-for-navigating-the-universe-of-machine-learning-tasks)
3. [Technical Report](report/data_meta_map.pdf)
1. [Technical Meeting 1 - Presentation](https://github.com/intsystems/DataMetaMap/blob/master/assets/BMM_technical_1.pdf)
2. [Blog Post](https://github.com/intsystems/DataMetaMap/blob/meshkovvl/BLOGPOST.md)
3. [Technical Report](https://github.com/intsystems/DataMetaMap/blob/develop/report/data_meta_map.pdf)


## 💡 Motivation
60 changes: 32 additions & 28 deletions tests/test_task2vec.py
@@ -21,6 +21,8 @@
cdist,
)
from data_meta_map.task2vec.utils import AverageMeter, get_error, get_device
from data_meta_map.models import get_model
from data_meta_map import datasets


# ── helpers ────────────────────────────────────────────────────────────────────
@@ -92,15 +94,9 @@ def test_classifier_property(self):
net = _SimpleProbeNetwork()
assert net.classifier is net.fc

# def test_classifier_setter(self):
# net = _SimpleProbeNetwork()
# new_fc = nn.Linear(16, 5)
# net.classifier = new_fc
# assert net.fc is new_fc


# ── Task2Vec.__init__ ──────────────────────────────────────────────────────────


class TestTask2VecInit:
def test_default_attributes(self):
model = _SimpleProbeNetwork()
@@ -161,7 +157,7 @@ def test_inherits_base_embedder(self):

# ── Task2Vec.extract_embedding ─────────────────────────────────────────────────

class TestExtractEmbedding:
class TestExtractEmbeddingRealData:
def _make_model_with_grad2(self, n_filters=4):
model = _SimpleProbeNetwork(num_classes=2)
# Simulate what montecarlo_fisher stores on weight tensors
@@ -172,27 +168,35 @@ def _make_model_with_grad2(self, n_filters=4):
module.weight.grad2_acc = torch.ones_like(module.weight) * 0.5
return model

#def test_extract_returns_embedding(self):
# model = self._make_model_with_grad2()
# t2v = Task2Vec(model)
# emb = t2v.extract_embedding(model)
# assert isinstance(emb, Embedding)

# def test_hessian_non_empty(self):
# model = self._make_model_with_grad2()
# t2v = Task2Vec(model)
# emb = t2v.extract_embedding(model)
# assert emb.hessian.size > 0
# assert emb.scale.size > 0

# def test_hessian_shape_matches_scale(self):
# model = self._make_model_with_grad2()
# t2v = Task2Vec(model)
# emb = t2v.extract_embedding(model)
# assert emb.hessian.shape == emb.scale.shape

def test_mnist_resnet(self):
dataset = datasets.__dict__['mnist'](root='../../data')[0]
model = get_model('resnet18', pretrained=True,
num_classes=int(max(dataset.targets)+1)).cuda()
task2vec_embedder = Task2Vec(model, skip_layers=6, max_samples=200)
emb = task2vec_embedder.embed(dataset)
assert isinstance(emb, np.ndarray)
assert emb.shape == (7680, )

def test_mnist_resnet_less_skip(self):
dataset = datasets.__dict__['mnist'](root='../../data')[0]
model = get_model('resnet18', pretrained=True,
num_classes=int(max(dataset.targets)+1)).cuda()
task2vec_embedder = Task2Vec(model, skip_layers=2, max_samples=200)
emb = task2vec_embedder.embed(dataset)
assert isinstance(emb, np.ndarray)
assert emb.shape == (9472,)

def test_extract_hessian(self):
dataset = datasets.__dict__['mnist'](root='../../data')[0]
model = get_model('resnet18', pretrained=True,
num_classes=int(max(dataset.targets)+1)).cuda()
task2vec_embedder = Task2Vec(model, skip_layers=2, max_samples=200)
emb = task2vec_embedder.embed(dataset, create_final_embedding=False)
assert isinstance(emb.hessian, np.ndarray)
assert isinstance(emb.scale, np.ndarray)
assert emb.scale.shape == (9472,)
assert emb.hessian.shape == (9472,)

# ── task_similarity: scalar distance functions ─────────────────────────────────

class TestDistanceFunctions:
@pytest.fixture