Custom Trajectory Metric
Learn how to implement a custom trajectory-level distance metric by subclassing
TrajectoryMetric.
A trajectory metric computes a scalar distance between two
Trajectory objects and can produce a
full pairwise DistanceMatrix over a
TrajectoryPool.
Minimal contract
1. Declare a SETTINGS_CLASS dataclass (use tanat_utils.settings_dataclass).
2. Implement _compute(traj_a, traj_b): return a non-negative float.
The base class handles __call__, compute_matrix, and
compute_cross_matrix automatically.
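The split between base class and subclass can be illustrated with a small stand-alone sketch. It does not use tanat: SketchMetric and AbsDiffMetric are hypothetical names, the non-negativity check is this sketch's own addition, and the real TrajectoryMetric may be implemented differently. The only point is the pattern, in which the public entry point lives in the base class and the subclass supplies _compute.

```python
from abc import ABC, abstractmethod


class SketchMetric(ABC):
    """Hypothetical stand-in for a metric base class (not tanat's code)."""

    def __call__(self, a, b) -> float:
        # Public entry point: delegate to the subclass hook, then
        # enforce the "non-negative float" part of the contract.
        value = float(self._compute(a, b))
        if value < 0:
            raise ValueError("_compute must return a non-negative value")
        return value

    @abstractmethod
    def _compute(self, a, b) -> float:
        """Subclass hook: return the distance between a and b."""


class AbsDiffMetric(SketchMetric):
    """Toy subclass: distance between two numbers."""

    def _compute(self, a, b) -> float:
        return abs(a - b)


metric_sketch = AbsDiffMetric()
print(metric_sketch(3, 9))  # 6.0
```

Because callers only ever go through the base-class entry point, a subclass that implements the single hook gets every derived operation without further code.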
Setup
from tanat_utils import settings_dataclass as dataclass
from tanat import build_events, build_states
from tanat.dataset import simulate_events, simulate_states
from tanat.trajectory.shortcuts import build_trajectories
from tanat.metric.trajectory.base import TrajectoryMetric
Data
A trajectory pool combining an event sequence ("visits") and a state
sequence ("status") for the same set of individuals.
N_IDS = 20
SEED = 42
events_df = simulate_events(n_ids=N_IDS, features=["score"], seed=SEED)
states_df = simulate_states(n_ids=N_IDS, features=["phase"], seed=SEED)
event_pool = build_events(temporal_data=events_df, id_column="id", time_column="time")
state_pool = build_states(
    temporal_data=states_df, id_column="id", start_column="start", end_column="end"
)
tpool = build_trajectories({"visits": event_pool, "status": state_pool})
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (20 sequences · 136 entities · 0.00s)
┌─ State SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (20 sequences · 136 entities · 0.00s)
┌─ TrajectoryStore
│
│ Step 1/2: Linking pools: visits, status
│
│ Step 2/2: Building trajectory index & metadata
│
└─ Done (20 trajectories · 2 pool(s) · 0.00s)
print(tpool)
┌────────────────────────────────────────────────┐
│ TrajectoryPool Summary │
└────────────────────────────────────────────────┘
Overview
─────────────────────────
Trajectories 20
Store /home/runner/.tanat/_quick_trajectory_682e0903
id_column id
Time Index
─────────────────────────
Type Datetime(time_unit='us', time_zone=None) [2000-05-07 06:23:17.895952 → 2025-01-01 00:00:00]
t0 position=0, anchor=start
Sequences (2)
─────────────────────────
• visits EventSequencePool(n=20, entity_features=1, static_features=0, store='/home/runner/.tanat/_quick_event_298ccdb1')
• status StateSequencePool(n=20, entity_features=1, static_features=0, store='/home/runner/.tanat/_quick_state_12c9fb86')
Define a custom trajectory metric
We implement a total-length-diff metric: the distance between two trajectories is the sum, over every sequence alias present in both, of the absolute difference in sequence length. The metric is deliberately simple; its purpose is to show the minimal pattern.
@dataclass
class TotalLengthDiffSettings:
    """Settings for :class:`TotalLengthDiffTrajectoryMetric`.

    Args:
        normalize: If ``True``, divide each per-alias difference by the length
            of the longer sub-sequence before summing.
    """

    normalize: bool = False


class TotalLengthDiffTrajectoryMetric(
    TrajectoryMetric, register_name="total_length_diff"
):
    """Trajectory distance as the sum of per-alias sequence-length differences.

    For each alias present in both trajectories:
        ``diff_alias = |len(alias_a) - len(alias_b)|``

    The final distance is the sum (or normalised sum) of those values.
    """

    SETTINGS_CLASS = TotalLengthDiffSettings

    def __init__(self, normalize: bool = False) -> None:
        super().__init__(settings=TotalLengthDiffSettings(normalize=normalize))

    # ------------------------------------------------------------------
    # Required implementation
    # ------------------------------------------------------------------

    def _compute(self, traj_a, traj_b) -> float:
        """Sum of per-alias absolute length differences."""
        total = 0.0
        for alias in traj_a:
            # Aliases missing from either trajectory do not contribute.
            if alias not in traj_b:
                continue
            la = len(traj_a[alias])
            lb = len(traj_b[alias])
            diff = abs(la - lb)
            if self.settings.normalize:
                # Guard against two empty sub-sequences (0 / 0).
                denom = max(la, lb)
                diff = (diff / denom) if denom > 0 else 0.0
            total += diff
        return float(total)
Instantiate and inspect
metric = TotalLengthDiffTrajectoryMetric(normalize=True)
print(metric)
TotalLengthDiffTrajectoryMetric(settings=TotalLengthDiffSettings(normalize=True))
Compute a distance between two trajectories
ids = tpool.unique_ids
traj_a = tpool[ids[0]]
traj_b = tpool[ids[1]]
print(
    f"Trajectory A | visits: {len(traj_a['visits'])}, status: {len(traj_a['status'])}"
)
print(
    f"Trajectory B | visits: {len(traj_b['visits'])}, status: {len(traj_b['status'])}"
)
print(f"Distance: {metric(traj_a, traj_b):.4f}")
Trajectory A | visits: 3, status: 3
Trajectory B | visits: 9, status: 9
Distance: 1.3333
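The printed distance can be checked by hand. The sketch below reproduces the normalised computation from the two printed length pairs, using plain dictionaries in place of Trajectory objects; length_diff is a hypothetical helper written for this check, not part of tanat.

```python
def length_diff(lengths_a, lengths_b, normalize=False):
    """Sum per-alias |la - lb| over aliases present in both dicts."""
    total = 0.0
    for alias, la in lengths_a.items():
        if alias not in lengths_b:
            continue  # aliases missing from either side are ignored
        lb = lengths_b[alias]
        diff = abs(la - lb)
        if normalize:
            denom = max(la, lb)
            diff = diff / denom if denom > 0 else 0.0
        total += diff
    return total


# Lengths taken from the printed output above.
a = {"visits": 3, "status": 3}
b = {"visits": 9, "status": 9}

print(f"{length_diff(a, b, normalize=True):.4f}")  # 1.3333 (6/9 + 6/9)
print(length_diff(a, b))  # 12.0 (6 + 6)
```

With normalisation each alias contributes at most 1, so the distance is bounded above by the number of shared aliases (2 here).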
Compute a full pairwise distance matrix
compute_matrix is inherited from the base class and works out of the
box once _compute is implemented.
dm = metric.compute_matrix(tpool)
dm.to_frame().head()
┌─ TotalLengthDiffTrajectoryMetric
│
│ Pairs: 100%|██████████| 400/400 [00:01<00:00, 330.61it/s]
│
└─ Done (20 trajectories · 1.21s)
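The progress bar reports 400 pairs for 20 trajectories, which is consistent with evaluating every ordered pair, including each trajectory against itself. A minimal stand-alone sketch of such a pairwise loop follows; pairwise_matrix is a hypothetical helper operating on plain Python objects, and tanat's actual compute_matrix may be implemented differently.

```python
from itertools import product


def pairwise_matrix(items, distance):
    """Evaluate distance on every ordered pair; return (matrix, n_pairs)."""
    n = len(items)
    matrix = [[0.0] * n for _ in range(n)]
    n_pairs = 0
    for i, j in product(range(n), repeat=2):
        matrix[i][j] = float(distance(items[i], items[j]))
        n_pairs += 1
    return matrix, n_pairs


# Toy data: sequence lengths stand in for trajectories.
lengths = [3, 9, 5]
matrix, n_pairs = pairwise_matrix(lengths, lambda a, b: abs(a - b))

print(n_pairs)    # 9 ordered pairs for 3 items
print(matrix[0])  # [0.0, 6.0, 2.0]
```

For a symmetric distance such as the length difference used here, the resulting matrix is symmetric with a zero diagonal.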