Custom Entity Metric
Learn how to implement a custom entity-level distance metric by subclassing
EntityMetric.
An entity metric computes a scalar distance between two individual entities (atomic observations in a sequence). It is the basic building block used by most sequence-level metrics.
Minimal contract
1. Declare a SETTINGS_CLASS dataclass (use tanat_utils.settings_dataclass).
2. Implement validate_entity(ent_a, ent_b): raise TypeError / KeyError when the entities are incompatible with the metric.
3. Implement _compute(ent_a, ent_b): return a non-negative float.
The base class's public __call__ runs validate_entity and then _compute automatically, so calling metric(ent_a, ent_b) performs both steps.
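The validate-then-compute ordering can be pictured with a small, self-contained sketch. MiniEntityMetric and AbsDiff are hypothetical names used only for illustration, and entities are modelled as plain dicts. tanat's real EntityMetric takes a settings object and performs extra checks (such as _validate_entity_instance), so treat this as a model of the call flow, not the library's implementation.

```python
from dataclasses import dataclass


class MiniEntityMetric:
    """Hypothetical stand-in for a base class: __call__ validates, then computes."""

    def __call__(self, ent_a, ent_b) -> float:
        self.validate_entity(ent_a, ent_b)  # fail fast on incompatible inputs
        return self._compute(ent_a, ent_b)  # only reached if validation passed

    def validate_entity(self, ent_a, ent_b=None) -> None:
        raise NotImplementedError

    def _compute(self, ent_a, ent_b) -> float:
        raise NotImplementedError


@dataclass
class AbsDiff(MiniEntityMetric):
    """Hypothetical subclass: absolute difference on one feature of dict entities."""

    feature: str = "score"

    def validate_entity(self, ent_a, ent_b=None) -> None:
        for ent in (e for e in (ent_a, ent_b) if e is not None):
            if self.feature not in ent:
                raise KeyError(f"Feature {self.feature!r} not found")

    def _compute(self, ent_a, ent_b) -> float:
        return abs(float(ent_a[self.feature]) - float(ent_b[self.feature]))


metric = AbsDiff()
print(metric({"score": 28.0}, {"score": 73.0}))  # 45.0
try:
    metric({"score": 28.0}, {"status": "A"})  # missing feature: _compute never runs
except KeyError as err:
    print("rejected:", err)
```

Subclasses only fill in the two hooks; the ordering lives in one place, which is why _compute can assume its inputs are already valid.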
Setup
import polars as pl
from tanat_utils import settings_dataclass as dataclass
from tanat import build_events
from tanat.dataset import simulate_events
from tanat.metric.entity.base import EntityMetric
Data
A small event pool with a numeric score feature and a categorical status feature.
raw = simulate_events(n_ids=20, features=["score", "status"], seed=0)
pool = build_events(temporal_data=raw, id_column="id", time_column="time")
pool.cast_features({"score": pl.Float32})
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (20 sequences · 138 entities · 0.00s)
print(pool)
┌────────────────────────────────────────────────┐
│ EventSequencePool Summary │
└────────────────────────────────────────────────┘
Overview
─────────────────────────
Sequences 20
Store /home/runner/.tanat/_quick_event_bd06d362
id_column id
Time Index
─────────────────────────
Type Datetime(time_unit='us', time_zone=None) [2000-01-03 17:54:05.937674 → 2024-11-15 14:02:43.302235]
Columns ['time']
t0 position=0, anchor=None
Entity Features (2)
─────────────────────────
• score Numerical [1.0 → 100.0]
• status String [len 1 → 1]
Define a custom entity metric
We implement a normalised absolute difference on a numeric feature.
dist(a, b) = |a.score - b.score| / scale
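Because the formula is an absolute difference divided by a positive constant, it behaves as a metric on the score values: it is symmetric, zero for equal scores, and satisfies the triangle inequality. On entities it is only a pseudometric, since two different entities with the same score are at distance 0. The sketch below checks these properties on a handful of illustrative scores, using a hypothetical plain-float helper score_diff rather than the tanat class.

```python
import itertools
import math


def score_diff(a: float, b: float, scale: float = 1.0) -> float:
    """dist(a, b) = |a - b| / scale, assuming scale > 0."""
    return abs(a - b) / scale


scores = [28.0, 73.0, 30.0, 1.0, 62.0, 81.0]
scale = 100.0

print(score_diff(28.0, 73.0, scale))  # 0.45

for a, b, c in itertools.product(scores, repeat=3):
    d_ab = score_diff(a, b, scale)
    assert math.isclose(d_ab, score_diff(b, a, scale))  # symmetry
    assert score_diff(a, a, scale) == 0.0  # zero distance to itself
    # triangle inequality, with a small tolerance for float rounding
    assert d_ab <= score_diff(a, c, scale) + score_diff(c, b, scale) + 1e-12
```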
@dataclass
class ScoreDiffSettings:
    """Settings for :class:`ScoreDiffEntityMetric`.

    Args:
        entity_feature: Numeric feature to compare.
        scale: Normalisation constant (``1.0`` means raw absolute diff).
    """

    entity_feature: str = "score"
    scale: float = 1.0


class ScoreDiffEntityMetric(EntityMetric, register_name="score_diff"):
    """Normalised absolute difference on a numeric entity feature."""

    SETTINGS_CLASS = ScoreDiffSettings

    def __init__(self, entity_feature: str = "score", scale: float = 1.0) -> None:
        super().__init__(
            settings=ScoreDiffSettings(entity_feature=entity_feature, scale=scale)
        )

    # ------------------------------------------------------------------
    # Required implementations
    # ------------------------------------------------------------------

    def validate_entity(self, ent_a, ent_b=None) -> None:
        """Check that both entities expose the expected numeric feature."""
        self._validate_entity_instance(ent_a, ent_b)
        feat = self.settings.entity_feature
        for ent in (e for e in (ent_a, ent_b) if e is not None):
            if feat not in ent.data():
                raise KeyError(
                    f"Feature {feat!r} not found in entity. "
                    f"Available: {list(ent.data().keys())}"
                )

    def _compute(self, ent_a, ent_b) -> float:
        """Return normalised absolute difference of the configured feature."""
        feat = self.settings.entity_feature
        val_a = float(ent_a[feat])
        val_b = float(ent_b[feat])
        return abs(val_a - val_b) / self.settings.scale
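ScoreDiffEntityMetric passes scale straight into the division in _compute, so scale=0.0 only fails later, with a ZeroDivisionError on the first distance, and a negative scale silently produces negative distances, breaking the non-negative contract. One way to fail early is a __post_init__ check, sketched here with the standard-library dataclasses module. CheckedScoreDiffSettings is a hypothetical name, and whether tanat_utils.settings_dataclass runs __post_init__ the same way is an assumption to verify.

```python
from dataclasses import dataclass


@dataclass
class CheckedScoreDiffSettings:
    """Hypothetical variant of ScoreDiffSettings that rejects unusable scales."""

    entity_feature: str = "score"
    scale: float = 1.0

    def __post_init__(self) -> None:
        # "not > 0" rejects zero, negatives and NaN in a single test.
        if not self.scale > 0:
            raise ValueError(f"scale must be positive, got {self.scale!r}")


print(CheckedScoreDiffSettings(scale=100.0))
try:
    CheckedScoreDiffSettings(scale=0.0)
except ValueError as err:
    print(err)  # scale must be positive, got 0.0
```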
Instantiate and inspect
metric = ScoreDiffEntityMetric(entity_feature="score", scale=100.0)
print(metric)
ScoreDiffEntityMetric(settings=ScoreDiffSettings(entity_feature='score', scale=100.0))
Compute a distance between two entities
ids = pool.unique_ids
ent_a = pool[ids[0]][0]
ent_b = pool[ids[1]][0]
print(f"score A : {ent_a['score']}")
print(f"score B : {ent_b['score']}")
print(f"distance: {metric(ent_a, ent_b):.4f}")
score A : 28.0
score B : 73.0
distance: 0.4500
Compute distances over several pairs
print("Sample pairwise distances")
print("-" * 40)
for i in range(5):
    ea = pool[ids[i]][0]
    eb = pool[ids[i + 1]][0]
    d = metric(ea, eb)
    print(f" {ea['score']:6.1f} vs {eb['score']:6.1f} -> {d:.4f}")
Sample pairwise distances
----------------------------------------
28.0 vs 73.0 -> 0.4500
73.0 vs 30.0 -> 0.4300
30.0 vs 1.0 -> 0.2900
1.0 vs 62.0 -> 0.6100
62.0 vs 81.0 -> 0.1900
Use the custom metric inside a sequence metric
Any EntityMetric can be passed as the
entity_metric argument to sequence-level metrics such as
LinearPairwiseSequenceMetric.
from tanat.metric.sequence import LinearPairwiseSequenceMetric
lp = LinearPairwiseSequenceMetric(entity_metric=metric)
print(lp)
LinearPairwiseSequenceMetric(settings=LinearPairwiseSettings(entity_metric=ScoreDiffEntityMetric(settings=ScoreDiffSettings(entity_feature='score', scale=100.0)), agg_fun='mean', padding_penalty=None))
seq_a = pool[ids[0]]
seq_b = pool[ids[1]]
dist = lp(seq_a, seq_b)
print(f"LinearPairwise distance: {dist:.4f}")
LinearPairwise distance: 0.3025
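The printed settings show agg_fun='mean', which suggests the sequence distance averages entity distances over aligned positions. The sketch below assumes exactly that: entities are paired by position, entities beyond the shorter sequence are dropped (a guess; padding_penalty presumably governs that case in tanat), and the per-pair distances are averaged. linear_pairwise is a hypothetical helper working on plain score lists, and it is not claimed to reproduce the 0.3025 above.

```python
from statistics import mean
from typing import Callable, Sequence


def linear_pairwise(
    seq_a: Sequence[float],
    seq_b: Sequence[float],
    entity_dist: Callable[[float, float], float],
) -> float:
    """Hypothetical linear pairwise distance: align by position, then average.

    zip() stops at the shorter sequence, so trailing entities of the longer
    one are ignored here; tanat's padding_penalty may treat them differently.
    """
    dists = [entity_dist(a, b) for a, b in zip(seq_a, seq_b)]
    if not dists:
        raise ValueError("sequences share no aligned positions")
    return mean(dists)  # mirrors agg_fun='mean' (assumed semantics)


def score_diff(a: float, b: float) -> float:
    # Same formula as ScoreDiffEntityMetric with scale=100.0, on raw floats.
    return abs(a - b) / 100.0


print(linear_pairwise([28.0, 50.0, 10.0], [73.0, 40.0], score_diff))
```

Here only two positions align, giving entity distances 0.45 and 0.10, so the result is their mean, about 0.275.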
dm = lp.compute_matrix(pool)
dm.to_frame().head()
┌─ LinearPairwiseSequenceMetric
│
│ Pairs: 100%|██████████| 400/400 [00:10<00:00, 37.90it/s]
│
└─ Done (20 sequences · 10.56s)
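The progress bar counts 400 pairs for 20 sequences, that is every ordered pair including the diagonal (20 × 20). When a metric is symmetric and gives 0 for identical inputs, the same matrix can be filled from n(n-1)/2 evaluations, 190 here. Whether compute_matrix already exploits this is not stated in this example; the sketch below is a generic, hypothetical distance_matrix helper under that symmetry assumption, not tanat's implementation.

```python
from itertools import combinations
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def distance_matrix(items: Sequence[T], dist: Callable[[T, T], float]) -> List[List[float]]:
    """Full n x n matrix, evaluating dist only on the upper triangle.

    Assumes dist is symmetric and dist(x, x) == 0.
    """
    n = len(items)
    matrix = [[0.0] * n for _ in range(n)]  # diagonal stays 0
    for i, j in combinations(range(n), 2):  # n * (n - 1) / 2 calls
        d = dist(items[i], items[j])
        matrix[i][j] = matrix[j][i] = d  # mirror into the lower triangle
    return matrix


scores = [28.0, 73.0, 30.0]
m = distance_matrix(scores, lambda a, b: abs(a - b) / 100.0)
for row in m:
    print(["%.2f" % d for d in row])
```

Halving the number of calls matters most when each call is expensive, as with sequence metrics that compare many entities per pair.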