Custom Trajectory Metric

Learn how to implement a custom trajectory-level distance metric by subclassing TrajectoryMetric.

A trajectory metric computes a scalar distance between two Trajectory objects and can produce a full pairwise DistanceMatrix over a TrajectoryPool.

Minimal contract

  1. Set SETTINGS_CLASS to a settings dataclass, created with the tanat_utils.settings_dataclass decorator.

  2. Implement _compute(traj_a, traj_b), which returns a non-negative float.

The base class handles __call__, compute_matrix, and compute_cross_matrix automatically.
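This division of labour follows the template-method pattern. The sketch below is a simplified stand-in and not tanat's actual base class: `SketchMetric`, `LengthGap`, and the list-based pool are hypothetical names, used only to illustrate how a base class can derive `__call__` and a pairwise matrix from a single `_compute` hook.

```python
from abc import ABC, abstractmethod


class SketchMetric(ABC):
    """Hypothetical stand-in for a metric base class (not tanat's real one)."""

    @abstractmethod
    def _compute(self, a, b) -> float:
        """Subclasses return a non-negative distance between two items."""

    def __call__(self, a, b) -> float:
        # Calling the metric delegates to the subclass hook.
        return self._compute(a, b)

    def compute_matrix(self, items):
        # Full pairwise matrix as nested lists: n x n calls to _compute.
        return [[self._compute(a, b) for b in items] for a in items]


class LengthGap(SketchMetric):
    """Toy subclass: absolute difference of two sequence lengths."""

    def _compute(self, a, b) -> float:
        return float(abs(len(a) - len(b)))


metric_sketch = LengthGap()
print(metric_sketch("abc", "abcde"))               # 2.0
print(metric_sketch.compute_matrix(["a", "abc"]))  # [[0.0, 2.0], [2.0, 0.0]]
```

The subclass only states how two items differ; everything about iterating over pairs lives in the base class, which is why a new metric needs so little code.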

Setup

from tanat_utils import settings_dataclass as dataclass

from tanat import build_events, build_states
from tanat.dataset import simulate_events, simulate_states
from tanat.trajectory.shortcuts import build_trajectories
from tanat.metric.trajectory.base import TrajectoryMetric

Data

We build a trajectory pool that combines an event sequence ("visits") and a state sequence ("status") for the same set of individuals.

N_IDS = 20
SEED = 42

events_df = simulate_events(n_ids=N_IDS, features=["score"], seed=SEED)
states_df = simulate_states(n_ids=N_IDS, features=["phase"], seed=SEED)

event_pool = build_events(temporal_data=events_df, id_column="id", time_column="time")
state_pool = build_states(
    temporal_data=states_df, id_column="id", start_column="start", end_column="end"
)

tpool = build_trajectories({"visits": event_pool, "status": state_pool})
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (20 sequences · 136 entities · 0.00s)
┌─ State SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (20 sequences · 136 entities · 0.00s)
┌─ TrajectoryStore
│
│ Step 1/2: Linking pools: visits, status
│
│ Step 2/2: Building trajectory index & metadata
│
└─ Done (20 trajectories · 2 pool(s) · 0.00s)
print(tpool)
┌────────────────────────────────────────────────┐
│             TrajectoryPool Summary             │
└────────────────────────────────────────────────┘

Overview
─────────────────────────
  Trajectories       20
  Store              /home/runner/.tanat/_quick_trajectory_682e0903
  id_column          id

Time Index
─────────────────────────
  Type               Datetime(time_unit='us', time_zone=None) [2000-05-07 06:23:17.895952 → 2025-01-01 00:00:00]
  t0                 position=0, anchor=start

Sequences (2)
─────────────────────────
  • visits              EventSequencePool(n=20, entity_features=1, static_features=0, store='/home/runner/.tanat/_quick_event_298ccdb1')
  • status              StateSequencePool(n=20, entity_features=1, static_features=0, store='/home/runner/.tanat/_quick_state_12c9fb86')

Define a custom trajectory metric

We implement a total-length-diff metric: the distance between two trajectories is the sum, over the sequence aliases they share, of the absolute difference in sequence length. Aliases present in only one trajectory are ignored. The metric is deliberately simple; its purpose is to show the minimal pattern.

@dataclass
class TotalLengthDiffSettings:
    """Settings for :class:`TotalLengthDiffTrajectoryMetric`.

    Args:
        normalize: If ``True``, divide each per-alias difference by the length
            of the longer sub-sequence before summing.
    """

    normalize: bool = False


class TotalLengthDiffTrajectoryMetric(
    TrajectoryMetric, register_name="total_length_diff"
):
    """Trajectory distance as the sum of per-alias sequence-length differences.

    For each alias present in both trajectories:
    ``diff_alias = |len(alias_a) - len(alias_b)|``

    The final distance is the sum (or normalised sum) of those values.
    """

    SETTINGS_CLASS = TotalLengthDiffSettings

    def __init__(self, normalize: bool = False) -> None:
        super().__init__(settings=TotalLengthDiffSettings(normalize=normalize))

    # ------------------------------------------------------------------
    # Required implementation
    # ------------------------------------------------------------------

    def _compute(self, traj_a, traj_b) -> float:
        """Sum of per-alias absolute length differences."""
        total = 0.0
        for alias in traj_a:
            if alias not in traj_b:
                continue
            la = len(traj_a[alias])
            lb = len(traj_b[alias])
            diff = abs(la - lb)
            if self.settings.normalize:
                denom = max(la, lb)
                diff = (diff / denom) if denom > 0 else 0.0
            total += diff
        return float(total)

Instantiate and inspect

metric = TotalLengthDiffTrajectoryMetric(normalize=True)
print(metric)
TotalLengthDiffTrajectoryMetric(settings=TotalLengthDiffSettings(normalize=True))

Compute a distance between two trajectories

ids = tpool.unique_ids
traj_a = tpool[ids[0]]
traj_b = tpool[ids[1]]

print(
    f"Trajectory A | visits: {len(traj_a['visits'])}, status: {len(traj_a['status'])}"
)
print(
    f"Trajectory B | visits: {len(traj_b['visits'])}, status: {len(traj_b['status'])}"
)
print(f"Distance: {metric(traj_a, traj_b):.4f}")
Trajectory A | visits: 3, status: 3
Trajectory B | visits: 9, status: 9
Distance: 1.3333
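The printed distance can be checked by hand. With normalize=True, each alias contributes |3 - 9| / max(3, 9) = 6/9, and there are two aliases, so the total is 12/9 ≈ 1.3333. The snippet below repeats that arithmetic on plain dictionaries of lengths; `length_diff` is a hypothetical helper that mirrors the loop in `_compute` and is not part of tanat.

```python
def length_diff(lengths_a, lengths_b, normalize=False):
    """Sum per-alias |la - lb| over aliases present in both dicts."""
    total = 0.0
    for alias, la in lengths_a.items():
        if alias not in lengths_b:
            continue  # aliases missing from one side are skipped
        lb = lengths_b[alias]
        diff = abs(la - lb)
        if normalize:
            denom = max(la, lb)
            diff = diff / denom if denom > 0 else 0.0
        total += diff
    return float(total)


# Lengths taken from the printed output for trajectories A and B.
len_a = {"visits": 3, "status": 3}
len_b = {"visits": 9, "status": 9}

print(f"{length_diff(len_a, len_b, normalize=True):.4f}")  # 1.3333
print(length_diff(len_a, len_b))                           # 12.0
```

With normalization each alias contributes a value in [0, 1], so the distance is bounded by the number of shared aliases (2 here); without it, the raw length gaps are summed.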

Compute a full pairwise distance matrix

compute_matrix is inherited from the base class and works out of the box once _compute is implemented.

dm = metric.compute_matrix(tpool)
dm.to_frame().head()
┌─ TotalLengthDiffTrajectoryMetric
│

│ Pairs: 100%|██████████| 400/400 [00:01<00:00, 330.61it/s]
│
└─ Done (20 trajectories · 1.21s)
           1         2         3         4         5         6         7         8         9        10        11        12        13        14        15        16        17        18        19        20
1   0.000000  1.333333  1.250000  1.000000  1.000000  1.333333  0.000000  1.250000  0.500000  0.000000  1.142857       1.4  1.250000  1.333333  1.250000  1.333333  1.142857  0.500000  1.333333  1.000000
2   1.333333  0.000000  0.222222  0.666667  0.666667  0.000000  1.333333  0.222222  1.111111  1.333333  0.444444       0.2  0.222222  0.000000  0.222222  0.000000  0.444444  1.111111  0.000000  0.666667
3   1.250000  0.222222  0.000000  0.500000  0.500000  0.222222  1.250000  0.000000  1.000000  1.250000  0.250000       0.4  0.000000  0.222222  0.000000  0.222222  0.250000  1.000000  0.222222  0.500000
4   1.000000  0.666667  0.500000  0.000000  0.000000  0.666667  1.000000  0.500000  0.666667  1.000000  0.285714       0.8  0.500000  0.666667  0.500000  0.666667  0.285714  0.666667  0.666667  0.000000
5   1.000000  0.666667  0.500000  0.000000  0.000000  0.666667  1.000000  0.500000  0.666667  1.000000  0.285714       0.8  0.500000  0.666667  0.500000  0.666667  0.285714  0.666667  0.666667  0.000000
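As a quick sanity check, a distance matrix from this metric should be symmetric with a zero diagonal. The sketch below verifies that on the top-left 5×5 block of the printed frame, copied by hand into nested lists; it deliberately does not use the DistanceMatrix API. It also shows that distinct trajectories can sit at distance zero, because the metric compares only sequence lengths.

```python
# Top-left 5x5 block of the printed frame (ids 1 to 5), copied by hand.
block = [
    [0.000000, 1.333333, 1.250000, 1.000000, 1.000000],
    [1.333333, 0.000000, 0.222222, 0.666667, 0.666667],
    [1.250000, 0.222222, 0.000000, 0.500000, 0.500000],
    [1.000000, 0.666667, 0.500000, 0.000000, 0.000000],
    [1.000000, 0.666667, 0.500000, 0.000000, 0.000000],
]

n = len(block)
symmetric = all(block[i][j] == block[j][i] for i in range(n) for j in range(n))
zero_diag = all(block[i][i] == 0.0 for i in range(n))

# Off-diagonal zeros: distinct ids whose sequence lengths match exactly.
zero_pairs = [
    (i + 1, j + 1)
    for i in range(n)
    for j in range(i + 1, n)
    if block[i][j] == 0.0
]

print(symmetric, zero_diag, zero_pairs)  # True True [(4, 5)]
```

Trajectories 4 and 5 are at distance zero without being the same trajectory, so this metric does not separate all distinct inputs; that is worth keeping in mind before feeding the matrix to a method that assumes distinct points have positive distance.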


