Custom Sequence Metric

Learn how to implement a custom sequence-level distance metric by subclassing SequenceMetric.

A sequence metric computes a scalar distance between two sequences and can produce a full pairwise DistanceMatrix over a pool.

Minimal contract

  1. Declare a SETTINGS_CLASS dataclass (use tanat_utils.settings_dataclass).

  2. Implement _compute(seq_a, seq_b): return a non-negative float.

  3. Implement validate_composition(seq_a, seq_b): raise early if the metric is incompatible with the given sequences (feature mismatch, wrong type, …).

The base class handles __call__, compute_matrix, and compute_cross_matrix automatically.
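To make the division of labour concrete, here is a toy, self-contained sketch of the template-method pattern this contract implies. It is not tanat's actual base class: `ToySequenceMetric`, `ToyLengthDiff`, `toy_matrix` and the use of plain Python lists as "sequences" are hypothetical stand-ins. The sketch assumes only that `__call__` validates before computing and that a matrix can be filled pair by pair from `_compute`.

```python
from abc import ABC, abstractmethod


class ToySequenceMetric(ABC):
    """Hypothetical stand-in for a sequence-metric base class (not tanat's code)."""

    @abstractmethod
    def validate_composition(self, seq_a, seq_b=None) -> None:
        """Raise if the metric cannot handle these sequences."""

    @abstractmethod
    def _compute(self, seq_a, seq_b) -> float:
        """Return a non-negative distance between two sequences."""

    def __call__(self, seq_a, seq_b) -> float:
        # Validate first, so incompatible inputs fail before any computation.
        self.validate_composition(seq_a, seq_b)
        return self._compute(seq_a, seq_b)

    def toy_matrix(self, sequences) -> list[list[float]]:
        # Naive pairwise loop: len(sequences) ** 2 calls, returned as nested lists.
        return [[self(a, b) for b in sequences] for a in sequences]


class ToyLengthDiff(ToySequenceMetric):
    def validate_composition(self, seq_a, seq_b=None) -> None:
        # Raise early with a clear message instead of failing inside _compute.
        for seq in (seq_a, seq_b):
            if seq is not None and not hasattr(seq, "__len__"):
                raise TypeError(f"{type(seq).__name__} has no length")

    def _compute(self, seq_a, seq_b) -> float:
        return float(abs(len(seq_a) - len(seq_b)))


toy = ToyLengthDiff()
print(toy.toy_matrix([["a"], ["a", "b", "c"], []]))
# [[0.0, 2.0, 1.0], [2.0, 0.0, 3.0], [1.0, 3.0, 0.0]]
```

In this toy, calling `toy(1, [1])` raises `TypeError` from `validate_composition` before `_compute` ever runs, which is the "raise early" behaviour the contract asks for.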

Note

For performance-critical cases you can additionally override _compute_matrix_impl with a vectorised or Numba-backed kernel. This is not required for a minimal working implementation.
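For illustration only: a vectorised kernel replaces the per-pair Python loop with whole-array operations. The sketch below assumes each sequence has already been reduced to a single number (for instance its length) and uses NumPy broadcasting. `pairwise_abs_diff` is a hypothetical name, and the sketch does not reproduce the signature that `_compute_matrix_impl` expects, which this page does not show.

```python
import numpy as np


def pairwise_abs_diff(values: np.ndarray, normalize: bool = False) -> np.ndarray:
    """Hypothetical vectorised kernel: |v_i - v_j| for all pairs at once."""
    v = np.asarray(values, dtype=float)
    # Broadcasting a column (n, 1) against a row (1, n) yields an (n, n) array.
    diff = np.abs(v[:, None] - v[None, :])
    if normalize:
        denom = np.maximum(v[:, None], v[None, :])
        # Where both values are 0, leave the distance at 0 (avoids 0 / 0).
        diff = np.divide(diff, denom, out=np.zeros_like(diff), where=denom > 0)
    return diff


print(pairwise_abs_diff(np.array([3, 9, 8]), normalize=True).round(4))
```

The whole matrix comes from a handful of array operations, so the cost no longer scales with the number of Python-level calls; a Numba kernel would be the alternative when a metric cannot be expressed through broadcasting.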

Setup

from tanat_utils import settings_dataclass as dataclass

from tanat import build_states
from tanat.dataset import simulate_states
from tanat.metric.sequence.base import SequenceMetric

Data

raw = simulate_states(n_ids=30, features=["status"], seed=42)
pool = build_states(
    temporal_data=raw, id_column="id", start_column="start", end_column="end"
)
┌─ State SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (30 sequences · 207 entities · 0.00s)
print(pool)
┌────────────────────────────────────────────────┐
│           StateSequencePool Summary            │
└────────────────────────────────────────────────┘

Overview
─────────────────────────
  Sequences          30
  Store              /home/runner/.tanat/_quick_state_d2506d21
  id_column          id

Time Index
─────────────────────────
  Type               Datetime(time_unit='us', time_zone=None) [2000-12-07 15:38:06.403275 → 2025-01-01 00:00:00]
  Columns            ['start', 'end']
  t0                 position=0, anchor=start

Entity Features (1)
─────────────────────────
  • status              Numerical [1 → 100]

Define a custom sequence metric

We implement a length-difference metric: the distance between two sequences is the absolute difference of their lengths (number of entities). This intentionally simple metric requires no entity-level building block and illustrates the minimal implementation pattern.

@dataclass
class LengthDiffSettings:
    """Settings for :class:`LengthDiffSequenceMetric`.

    Args:
        normalize: If ``True``, divide by the length of the longer sequence
            so that the result is in ``[0, 1]``.
    """

    normalize: bool = False


class LengthDiffSequenceMetric(SequenceMetric, register_name="length_diff"):
    """Distance metric based on the difference in sequence lengths.

    ``dist(seq_a, seq_b) = |len(seq_a) - len(seq_b)|``

    With ``normalize=True``:
    ``dist = |len(seq_a) - len(seq_b)| / max(len(seq_a), len(seq_b))``
    """

    SETTINGS_CLASS = LengthDiffSettings

    def __init__(self, normalize: bool = False) -> None:
        super().__init__(settings=LengthDiffSettings(normalize=normalize))

    # ------------------------------------------------------------------
    # Required implementations
    # ------------------------------------------------------------------

    def validate_composition(self, seq_a, seq_b=None) -> None:
        """No entity-feature constraint: any sequence type is accepted."""
        # Length is always available, nothing to validate here.

    def _compute(self, seq_a, seq_b) -> float:
        """Absolute (or normalised) difference of sequence lengths."""
        la, lb = len(seq_a), len(seq_b)
        diff = abs(la - lb)
        if self.settings.normalize:
            denom = max(la, lb)
            return float(diff / denom) if denom > 0 else 0.0
        return float(diff)
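Because `_compute` depends only on the two lengths, its arithmetic can be checked in isolation. The sketch below copies the same logic into a plain function (the name `length_diff` is ours, not part of tanat) and asserts the properties the contract asks for. Note that the result is only a pseudo-metric: two different sequences of equal length are at distance 0.

```python
def length_diff(la: int, lb: int, normalize: bool = False) -> float:
    """Same arithmetic as LengthDiffSequenceMetric._compute, applied to raw lengths."""
    diff = abs(la - lb)
    if normalize:
        denom = max(la, lb)
        return float(diff / denom) if denom > 0 else 0.0
    return float(diff)


lengths = [0, 3, 8, 9]
for a in lengths:
    for b in lengths:
        d = length_diff(a, b, normalize=True)
        assert d >= 0.0                                # non-negative, as _compute must be
        assert d == length_diff(b, a, normalize=True)  # symmetric in its arguments
        assert d <= 1.0                                # normalised range is [0, 1]
    assert length_diff(a, a) == 0.0                    # zero self-distance

print(length_diff(3, 9), round(length_diff(3, 9, normalize=True), 4))
# 6.0 0.6667
```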

Instantiate and inspect

metric = LengthDiffSequenceMetric(normalize=True)
print(metric)
LengthDiffSequenceMetric(settings=LengthDiffSettings(normalize=True))

Compute a distance between two sequences

ids = pool.unique_ids
seq_a = pool[ids[0]]
seq_b = pool[ids[1]]

print(f"len(seq_a) = {len(seq_a)},  len(seq_b) = {len(seq_b)}")
print(f"distance   = {metric(seq_a, seq_b):.4f}")
len(seq_a) = 3,  len(seq_b) = 9
distance   = 0.6667
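With `normalize=True`, the printed value is the length gap divided by the longer of the two lengths:

```latex
d(\mathrm{seq}_a, \mathrm{seq}_b) = \frac{|3 - 9|}{\max(3, 9)} = \frac{6}{9} \approx 0.6667
```

Without normalisation the same pair would be at distance 6.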

Compute a full pairwise distance matrix

compute_matrix is inherited from the base class and works out of the box once _compute is implemented.

dm = metric.compute_matrix(pool)
dm.to_frame().head()
┌─ LengthDiffSequenceMetric
│
│ Pairs: 100%|██████████| 900/900 [00:01<00:00, 714.41it/s]
│
└─ Done (30 sequences · 1.26s)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
1 0.000000 0.666667 0.625000 0.500000 0.500000 0.666667 0.000000 0.625000 0.250000 0.000000 0.571429 0.7 0.625000 0.666667 0.625000 0.666667 0.571429 0.250000 0.666667 0.500000 0.571429 0.400000 0.250000 0.7 0.666667 0.625000 0.500000 0.666667 0.571429 0.500000
2 0.666667 0.000000 0.111111 0.333333 0.333333 0.000000 0.666667 0.111111 0.555556 0.666667 0.222222 0.1 0.111111 0.000000 0.111111 0.000000 0.222222 0.555556 0.000000 0.333333 0.222222 0.444444 0.555556 0.1 0.000000 0.111111 0.333333 0.000000 0.222222 0.333333
3 0.625000 0.111111 0.000000 0.250000 0.250000 0.111111 0.625000 0.000000 0.500000 0.625000 0.125000 0.2 0.000000 0.111111 0.000000 0.111111 0.125000 0.500000 0.111111 0.250000 0.125000 0.375000 0.500000 0.2 0.111111 0.000000 0.250000 0.111111 0.125000 0.250000
4 0.500000 0.333333 0.250000 0.000000 0.000000 0.333333 0.500000 0.250000 0.333333 0.500000 0.142857 0.4 0.250000 0.333333 0.250000 0.333333 0.142857 0.333333 0.333333 0.000000 0.142857 0.166667 0.333333 0.4 0.333333 0.250000 0.000000 0.333333 0.142857 0.000000
5 0.500000 0.333333 0.250000 0.000000 0.000000 0.333333 0.500000 0.250000 0.333333 0.500000 0.142857 0.4 0.250000 0.333333 0.250000 0.333333 0.142857 0.333333 0.333333 0.000000 0.142857 0.166667 0.333333 0.4 0.333333 0.250000 0.000000 0.333333 0.142857 0.000000
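The progress bar counts 900 = 30 × 30 pairs for 30 sequences. For a symmetric distance with zero self-distance, such as this one, only the n(n-1)/2 = 435 strict upper-triangle pairs carry new information. The sketch below shows that idea in plain Python; `symmetric_matrix` and `norm_len_diff` are hypothetical helpers, and whether tanat's base class already exploits symmetry internally is not stated on this page. The shortcut is only valid when the metric is symmetric and has a zero diagonal.

```python
from itertools import combinations
from typing import Callable, Sequence


def symmetric_matrix(
    items: Sequence, dist: Callable[[object, object], float]
) -> tuple[list[list[float]], int]:
    """Hypothetical helper: fill an n x n matrix from the strict upper triangle."""
    n = len(items)
    matrix = [[0.0] * n for _ in range(n)]  # diagonal stays 0.0 (self-distance)
    calls = 0
    for i, j in combinations(range(n), 2):  # only pairs with i < j
        d = dist(items[i], items[j])
        matrix[i][j] = matrix[j][i] = d  # mirror into the lower triangle
        calls += 1
    return matrix, calls


def norm_len_diff(a, b) -> float:
    # Same normalised length difference as LengthDiffSequenceMetric(normalize=True).
    denom = max(len(a), len(b))
    return abs(len(a) - len(b)) / denom if denom > 0 else 0.0


seqs = [list(range(k)) for k in (3, 9, 8)]
m, calls = symmetric_matrix(seqs, norm_len_diff)
print(calls)  # 3 evaluations instead of 3 * 3 = 9
print([[round(x, 4) for x in row] for row in m])
# [[0.0, 0.6667, 0.625], [0.6667, 0.0, 0.1111], [0.625, 0.1111, 0.0]]
```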


Total running time of the script: (0 minutes 1.286 seconds)
