Custom Entity Metric

Learn how to implement a custom entity-level distance metric by subclassing EntityMetric.

An entity metric computes a scalar distance between two individual entities (atomic observations in a sequence). It is the basic building block used by most sequence-level metrics.

Minimal contract

  1. Declare a SETTINGS_CLASS dataclass (use tanat_utils.settings_dataclass).

  2. Implement validate_entity(ent_a, ent_b): raise TypeError / KeyError when the entities are incompatible with the metric.

  3. Implement _compute(ent_a, ent_b): return a non-negative float.

The public __call__ in the base class invokes validate_entity then _compute automatically.
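
The validate-then-compute flow described above can be sketched without tanat at all. The block below is a minimal illustration under stated assumptions: SketchEntityMetric and AbsDiff are hypothetical stand-ins operating on plain dicts, not tanat's actual base class or entity type.

```python
from abc import ABC, abstractmethod


class SketchEntityMetric(ABC):
    """Hypothetical stand-in for the base class; not tanat's implementation."""

    def __call__(self, ent_a, ent_b) -> float:
        # Validate first so _compute can assume well-formed inputs.
        self.validate_entity(ent_a, ent_b)
        return self._compute(ent_a, ent_b)

    @abstractmethod
    def validate_entity(self, ent_a, ent_b=None) -> None:
        """Raise TypeError / KeyError for incompatible entities."""

    @abstractmethod
    def _compute(self, ent_a, ent_b) -> float:
        """Return a non-negative distance."""


class AbsDiff(SketchEntityMetric):
    """Toy metric over plain dicts holding a numeric 'score' key."""

    def validate_entity(self, ent_a, ent_b=None) -> None:
        for ent in (e for e in (ent_a, ent_b) if e is not None):
            if not isinstance(ent, dict):
                raise TypeError(f"expected dict, got {type(ent).__name__}")
            if "score" not in ent:
                raise KeyError("score")

    def _compute(self, ent_a, ent_b) -> float:
        return abs(float(ent_a["score"]) - float(ent_b["score"]))


toy = AbsDiff()
print(toy({"score": 28.0}, {"score": 73.0}))  # 45.0
```

Keeping validation separate from computation means _compute stays a short arithmetic function, while every failure mode surfaces from one place.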

Setup

import polars as pl
from tanat_utils import settings_dataclass as dataclass

from tanat import build_events
from tanat.dataset import simulate_events
from tanat.metric.entity.base import EntityMetric

Data

A small event pool with a numeric score feature and a categorical status feature.

raw = simulate_events(n_ids=20, features=["score", "status"], seed=0)
pool = build_events(temporal_data=raw, id_column="id", time_column="time")
pool.cast_features({"score": pl.Float32})
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (20 sequences · 138 entities · 0.00s)
print(pool)
┌────────────────────────────────────────────────┐
│           EventSequencePool Summary            │
└────────────────────────────────────────────────┘

Overview
─────────────────────────
  Sequences          20
  Store              /home/runner/.tanat/_quick_event_bd06d362
  id_column          id

Time Index
─────────────────────────
  Type               Datetime(time_unit='us', time_zone=None) [2000-01-03 17:54:05.937674 → 2024-11-15 14:02:43.302235]
  Columns            ['time']
  t0                 position=0, anchor=None

Entity Features (2)
─────────────────────────
  • score               Numerical [1.0 → 100.0]
  • status              String [len 1 → 1]

Define a custom entity metric

We implement a normalised absolute difference on a numeric feature:

    dist(a, b) = |a.score - b.score| / scale
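
As a quick sanity check of the formula, here is a plain-Python sketch with no tanat dependency. The helper name score_diff is hypothetical and exists only for this illustration; the two input values are chosen arbitrarily within the 1 to 100 score range shown in the pool summary.

```python
def score_diff(score_a: float, score_b: float, scale: float = 1.0) -> float:
    """Normalised absolute difference: |a - b| / scale."""
    return abs(score_a - score_b) / scale


# With scale=1.0 the result is the raw absolute difference.
print(score_diff(28.0, 73.0))               # 45.0
# A scale of 100 keeps distances below 1 for scores between 1 and 100.
print(score_diff(28.0, 73.0, scale=100.0))  # 0.45
```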

@dataclass
class ScoreDiffSettings:
    """Settings for :class:`ScoreDiffEntityMetric`.

    Args:
        entity_feature: Numeric feature to compare.
        scale: Normalisation constant (``1.0`` means raw absolute diff).
    """

    entity_feature: str = "score"
    scale: float = 1.0


class ScoreDiffEntityMetric(EntityMetric, register_name="score_diff"):
    """Normalised absolute difference on a numeric entity feature."""

    SETTINGS_CLASS = ScoreDiffSettings

    def __init__(self, entity_feature: str = "score", scale: float = 1.0) -> None:
        super().__init__(
            settings=ScoreDiffSettings(entity_feature=entity_feature, scale=scale)
        )

    # ------------------------------------------------------------------
    # Required implementations
    # ------------------------------------------------------------------

    def validate_entity(self, ent_a, ent_b=None) -> None:
        """Check that both entities expose the expected numeric feature."""
        self._validate_entity_instance(ent_a, ent_b)
        feat = self.settings.entity_feature
        for ent in (e for e in (ent_a, ent_b) if e is not None):
            if feat not in ent.data():
                raise KeyError(
                    f"Feature {feat!r} not found in entity. "
                    f"Available: {list(ent.data().keys())}"
                )

    def _compute(self, ent_a, ent_b) -> float:
        """Return normalised absolute difference of the configured feature."""
        feat = self.settings.entity_feature
        val_a = float(ent_a[feat])
        val_b = float(ent_b[feat])
        return abs(val_a - val_b) / self.settings.scale
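
Because _compute divides by scale, a zero scale would raise ZeroDivisionError and a negative one would return negative values, which breaks the non-negative contract. One possible guard is to validate the setting when it is constructed. The sketch below uses the standard-library dataclass; GuardedScoreDiffSettings is a hypothetical name, and whether tanat_utils.settings_dataclass honours __post_init__ in the same way is an assumption not verified here.

```python
from dataclasses import dataclass


@dataclass
class GuardedScoreDiffSettings:
    """Hypothetical variant of ScoreDiffSettings that rejects bad scales."""

    entity_feature: str = "score"
    scale: float = 1.0

    def __post_init__(self) -> None:
        # "not scale > 0" also rejects NaN, which "scale <= 0" would let through.
        if not self.scale > 0:
            raise ValueError(f"scale must be > 0, got {self.scale!r}")


print(GuardedScoreDiffSettings(scale=100.0))
try:
    GuardedScoreDiffSettings(scale=0.0)
except ValueError as exc:
    print(f"rejected: {exc}")
```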

Instantiate and inspect

metric = ScoreDiffEntityMetric(entity_feature="score", scale=100.0)
print(metric)
ScoreDiffEntityMetric(settings=ScoreDiffSettings(entity_feature='score', scale=100.0))

Compute a distance between two entities

ids = pool.unique_ids
ent_a = pool[ids[0]][0]
ent_b = pool[ids[1]][0]

print(f"score A : {ent_a['score']}")
print(f"score B : {ent_b['score']}")
print(f"distance: {metric(ent_a, ent_b):.4f}")
score A : 28.0
score B : 73.0
distance: 0.4500

Compute distances over several pairs

print("Sample pairwise distances")
print("-" * 40)
for i in range(5):
    ea = pool[ids[i]][0]
    eb = pool[ids[i + 1]][0]
    d = metric(ea, eb)
    print(f"  {ea['score']:6.1f}  vs  {eb['score']:6.1f}  ->  {d:.4f}")
Sample pairwise distances
----------------------------------------
    28.0  vs    73.0  ->  0.4500
    73.0  vs    30.0  ->  0.4300
    30.0  vs     1.0  ->  0.2900
     1.0  vs    62.0  ->  0.6100
    62.0  vs    81.0  ->  0.1900

Use the custom metric inside a sequence metric

Any EntityMetric can be passed as the entity_metric argument to sequence-level metrics such as LinearPairwiseSequenceMetric.
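
To give an intuition for how a sequence metric might consume an entity metric, here is a rough sketch that pairs entities by position and averages their distances. It is not tanat's implementation: linear_pairwise and score_dist are hypothetical names, sequences are plain lists of scores, and ignoring the unmatched tail of the longer sequence is an assumption made for this illustration.

```python
from statistics import mean
from typing import Callable, Sequence


def linear_pairwise(
    seq_a: Sequence[float],
    seq_b: Sequence[float],
    entity_dist: Callable[[float, float], float],
) -> float:
    """Mean entity distance over position-aligned pairs (sketch only)."""
    # zip stops at the shorter sequence, so trailing entities of the
    # longer one are ignored (an assumption of this sketch).
    pairs = list(zip(seq_a, seq_b))
    if not pairs:
        raise ValueError("both sequences must be non-empty")
    return mean([entity_dist(a, b) for a, b in pairs])


def score_dist(a: float, b: float) -> float:
    """Entity-level distance: absolute difference scaled by 100."""
    return abs(a - b) / 100.0


# Pairs: (25, 50) -> 0.25 and (100, 25) -> 0.75; the third score is unmatched.
print(linear_pairwise([25.0, 100.0, 60.0], [50.0, 25.0], score_dist))  # 0.5
```

The entity-level function is passed in as a plain callable, which mirrors the way the custom metric is handed to the sequence metric through entity_metric.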

from tanat.metric.sequence import LinearPairwiseSequenceMetric

lp = LinearPairwiseSequenceMetric(entity_metric=metric)
print(lp)
LinearPairwiseSequenceMetric(settings=LinearPairwiseSettings(entity_metric=ScoreDiffEntityMetric(settings=ScoreDiffSettings(entity_feature='score', scale=100.0)), agg_fun='mean', padding_penalty=None))
seq_a = pool[ids[0]]
seq_b = pool[ids[1]]
dist = lp(seq_a, seq_b)
print(f"LinearPairwise distance: {dist:.4f}")
LinearPairwise distance: 0.3025
dm = lp.compute_matrix(pool)
dm.to_frame().head()
┌─ LinearPairwiseSequenceMetric
│
│ Pairs: 100%|██████████| 400/400 [00:10<00:00, 37.90it/s]
│
└─ Done (20 sequences · 10.56s)
           1         2         3         4         5         6         7         8         9        10        11        12        13        14        15        16        17        18        19        20
1   0.000000  0.302500  0.378571     0.268     0.348  0.333333  0.190000  0.243333     0.320  0.264444  0.245000  0.397778  0.218571  0.571429  0.351111  0.432500  0.285000  0.375714  0.268571  0.395556
2   0.302500  0.000000  0.377143     0.482     0.182  0.140000  0.210000  0.270000     0.225  0.285000  0.330000  0.292500  0.257143  0.272857  0.315000  0.282500  0.305000  0.125714  0.361429  0.392500
3   0.378571  0.377143  0.000000     0.210     0.334  0.353333  0.363333  0.343333     0.235  0.284286  0.311429  0.287143  0.388571  0.318571  0.427143  0.432857  0.475714  0.348571  0.404286  0.324286
4   0.268000  0.482000  0.210000     0.000     0.352  0.420000  0.363333  0.343333     0.260  0.244000  0.210000  0.326000  0.384000  0.440000  0.432000  0.482000  0.416000  0.432000  0.308000  0.352000
5   0.348000  0.182000  0.334000     0.352     0.000  0.140000  0.176667  0.210000     0.175  0.220000  0.258000  0.150000  0.316000  0.244000  0.224000  0.222000  0.192000  0.224000  0.232000  0.224000

