Custom Sequence Metric
Learn how to implement a custom sequence-level distance metric by subclassing
SequenceMetric.
A sequence metric computes a scalar distance between two sequences and can
produce a full pairwise DistanceMatrix over a pool.
Minimal contract
1. Declare a SETTINGS_CLASS dataclass (use tanat_utils.settings_dataclass).
2. Implement _compute(seq_a, seq_b): return a non-negative float.
3. Implement validate_composition(seq_a, seq_b): raise early if the metric is incompatible with the given sequences (feature mismatch, wrong type, …).
The base class handles __call__, compute_matrix, and
compute_cross_matrix automatically.
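The split between the two hooks and the inherited entry points is a template-method pattern. The sketch below imitates it with a hypothetical ToyMetricBase class that is not part of tanat. It only illustrates how a base class could route __call__ and compute_matrix through validate_composition and _compute; the real SequenceMetric may be organised differently.

```python
from abc import ABC, abstractmethod


class ToyMetricBase(ABC):
    """Hypothetical stand-in for a metric base class (not tanat's SequenceMetric)."""

    @abstractmethod
    def validate_composition(self, seq_a, seq_b=None) -> None:
        """Raise if the metric cannot handle the given sequences."""

    @abstractmethod
    def _compute(self, seq_a, seq_b) -> float:
        """Return a non-negative distance between two sequences."""

    def __call__(self, seq_a, seq_b) -> float:
        # Validate first so incompatible inputs fail before any computation.
        self.validate_composition(seq_a, seq_b)
        return self._compute(seq_a, seq_b)

    def compute_matrix(self, sequences):
        # Naive all-pairs loop: n * n calls routed through __call__.
        return [[self(a, b) for b in sequences] for a in sequences]


class ToyLengthDiff(ToyMetricBase):
    """Subclass that supplies only the two required hooks."""

    def validate_composition(self, seq_a, seq_b=None) -> None:
        for seq in (seq_a, seq_b):
            if seq is not None and not hasattr(seq, "__len__"):
                raise TypeError("sequences must define a length")

    def _compute(self, seq_a, seq_b) -> float:
        return float(abs(len(seq_a) - len(seq_b)))


toy_metric = ToyLengthDiff()
# Plain lists stand in for sequences; only their lengths (3, 1, 2) matter.
toy_matrix = toy_metric.compute_matrix([[1, 2, 3], [1], [1, 2]])
```

The subclass never defines __call__ or compute_matrix, yet both work as soon as the two hooks exist, which is the behaviour the contract above describes.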
Note
For performance-critical cases you can additionally override
_compute_matrix_impl with a vectorised or Numba-backed kernel.
This is not required for a minimal working implementation.
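To give an idea of what a vectorised kernel could look like for a length-based distance, here is a standalone NumPy sketch. The function name length_diff_matrix is hypothetical, and the signature that _compute_matrix_impl expects is not shown on this page, so this is not a drop-in override, only the array arithmetic such an override might rely on.

```python
import numpy as np


def length_diff_matrix(lengths, normalize=False):
    """Pairwise |l_i - l_j| for a 1-D collection of sequence lengths."""
    lens = np.asarray(lengths, dtype=float)
    # Broadcasting a column against a row yields the full n x n matrix at once.
    diff = np.abs(lens[:, None] - lens[None, :])
    if not normalize:
        return diff
    denom = np.maximum(lens[:, None], lens[None, :])
    out = np.zeros_like(diff)
    # Divide only where the longer sequence is non-empty; other cells stay 0.
    np.divide(diff, denom, out=out, where=denom > 0)
    return out


example = length_diff_matrix([3, 9, 6], normalize=True)
```

The lengths are read once and every pair is handled by array operations, which avoids one Python-level call per pair.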
Setup
from tanat_utils import settings_dataclass as dataclass
from tanat import build_states
from tanat.dataset import simulate_states
from tanat.metric.sequence.base import SequenceMetric
Data
raw = simulate_states(n_ids=30, features=["status"], seed=42)
pool = build_states(
    temporal_data=raw, id_column="id", start_column="start", end_column="end"
)
┌─ State SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (30 sequences · 207 entities · 0.00s)
print(pool)
┌────────────────────────────────────────────────┐
│ StateSequencePool Summary │
└────────────────────────────────────────────────┘
Overview
─────────────────────────
Sequences 30
Store /home/runner/.tanat/_quick_state_d2506d21
id_column id
Time Index
─────────────────────────
Type Datetime(time_unit='us', time_zone=None) [2000-12-07 15:38:06.403275 → 2025-01-01 00:00:00]
Columns ['start', 'end']
t0 position=0, anchor=start
Entity Features (1)
─────────────────────────
• status Numerical [1 → 100]
Define a custom sequence metric
We implement a length-difference metric: the distance between two sequences is the absolute difference of their lengths (number of entities). This intentionally simple metric requires no entity-level building block and illustrates the minimal implementation pattern.
@dataclass
class LengthDiffSettings:
    """Settings for :class:`LengthDiffSequenceMetric`.

    Args:
        normalize: If ``True``, divide by the length of the longer sequence
            so that the result is in ``[0, 1]``.
    """

    normalize: bool = False


class LengthDiffSequenceMetric(SequenceMetric, register_name="length_diff"):
    """Distance metric based on the difference in sequence lengths.

    ``dist(seq_a, seq_b) = |len(seq_a) - len(seq_b)|``

    With ``normalize=True``:
    ``dist = |len(seq_a) - len(seq_b)| / max(len(seq_a), len(seq_b))``
    """

    SETTINGS_CLASS = LengthDiffSettings

    def __init__(self, normalize: bool = False) -> None:
        super().__init__(settings=LengthDiffSettings(normalize=normalize))

    # ------------------------------------------------------------------
    # Required implementations
    # ------------------------------------------------------------------

    def validate_composition(self, seq_a, seq_b=None) -> None:
        """No entity-feature constraint: any sequence type is accepted."""
        # Length is always available, nothing to validate here.

    def _compute(self, seq_a, seq_b) -> float:
        """Absolute (or normalised) difference of sequence lengths."""
        la, lb = len(seq_a), len(seq_b)
        diff = abs(la - lb)
        if self.settings.normalize:
            denom = max(la, lb)
            return float(diff / denom) if denom > 0 else 0.0
        return float(diff)
Instantiate and inspect
metric = LengthDiffSequenceMetric(normalize=True)
print(metric)
LengthDiffSequenceMetric(settings=LengthDiffSettings(normalize=True))
Compute a distance between two sequences
ids = pool.unique_ids
seq_a = pool[ids[0]]
seq_b = pool[ids[1]]
print(f"len(seq_a) = {len(seq_a)}, len(seq_b) = {len(seq_b)}")
print(f"distance = {metric(seq_a, seq_b):.4f}")
len(seq_a) = 3, len(seq_b) = 9
distance = 0.6667
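The printed value follows directly from the normalised formula, |3 - 9| / max(3, 9) = 6 / 9. A quick check in plain Python, independent of tanat:

```python
len_a, len_b = 3, 9  # lengths reported in the output above

diff = abs(len_a - len_b)            # absolute length difference: 6
distance = diff / max(len_a, len_b)  # normalised by the longer sequence: 6 / 9

print(f"distance = {distance:.4f}")  # prints "distance = 0.6667"
```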
Compute a full pairwise distance matrix
compute_matrix is inherited from the base class and works out of the
box once _compute is implemented.
dm = metric.compute_matrix(pool)
dm.to_frame().head()
┌─ LengthDiffSequenceMetric
│
│ Pairs: 100%|██████████| 900/900 [00:01<00:00, 714.41it/s]
│
└─ Done (30 sequences · 1.26s)
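The progress bar counts 900 pairs, that is 30 × 30 ordered pairs. The length difference is symmetric and zero on the diagonal, so only n(n-1)/2 of those values are distinct (435 for 30 sequences). The sketch below fills a matrix from the unordered pairs alone, using plain strings as hypothetical stand-ins for sequences. Whether tanat applies this shortcut internally is not stated here.

```python
from itertools import combinations

# Hypothetical stand-ins for sequences; only len() matters for this metric.
toy_pool = ["abc", "a", "ab", "abcd"]


def length_diff(a, b):
    """Absolute difference of lengths, as in the unnormalised metric."""
    return float(abs(len(a) - len(b)))


n = len(toy_pool)
matrix = [[0.0] * n for _ in range(n)]  # diagonal stays at 0.0

# Visit each unordered pair once and mirror the value across the diagonal.
for i, j in combinations(range(n), 2):
    matrix[i][j] = matrix[j][i] = length_diff(toy_pool[i], toy_pool[j])

n_unique_pairs = n * (n - 1) // 2
```

This shortcut is only valid for metrics that are symmetric with a zero self-distance. A custom _compute that breaks either property needs the full n × n evaluation.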