Trajectory Metrics: Aggregation Trajectory Metric#

This example demonstrates a trajectory-level metric, AggregationTrajectoryMetric. It applies a sequence metric to each sequence type in a trajectory (e.g., states, events, or intervals) and aggregates the per-sequence distances into a single trajectory distance.

Setup#

import polars as pl
import matplotlib.pyplot as plt

from tanat import build_states, build_events, build_trajectories
from tanat.dataset import simulate_trajectories
from tanat.metric.entity import HammingEntityMetric
from tanat.metric.sequence import (
    EditSequenceMetric,
    LCSSequenceMetric,
)
from tanat.metric import AggregationTrajectoryMetric

Generate synthetic trajectory data#

N_TRAJ = 100
SEED = 42

raw = simulate_trajectories(
    sequences={
        "states": {
            "type": "state",
            "n_ids": N_TRAJ,
            "seq_length_range": (3, 8),
            "features": ["score", "status"],
        },
        "events": {
            "type": "event",
            "n_ids": N_TRAJ,
            "seq_length_range": (2, 6),
            "features": ["score", "status"],
        },
    },
    shared_ids=True,
    seed=SEED,
)

# Build pools for each sequence type
states_pool = build_states(
    temporal_data=raw["states"],
    id_column="id",
    start_column="start",
    end_column="end",
)
events_pool = build_events(
    temporal_data=raw["events"],
    id_column="id",
    time_column="time",
)

# Build trajectory pool
traj_pool = build_trajectories(pools={"states": states_pool, "events": events_pool})
┌─ State SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (100 sequences · 531 entities · 0.01s)
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (100 sequences · 383 entities · 0.00s)
┌─ TrajectoryStore
│
│ Step 1/2: Linking pools: states, events
│
│ Step 2/2: Building trajectory index & metadata
│
└─ Done (100 trajectories · 2 pool(s) · 0.00s)
# Cast features to categorical
for sp in traj_pool.sequence_pools.values():
    sp.cast_features({"status": pl.Categorical})

print(traj_pool)
┌────────────────────────────────────────────────┐
│             TrajectoryPool Summary             │
└────────────────────────────────────────────────┘

Overview
─────────────────────────
  Trajectories       100
  Store              /home/runner/.tanat/_quick_trajectory_4056a21b
  id_column          id

Time Index
─────────────────────────
  Type               Datetime(time_unit='us', time_zone=None) [2000-01-04 17:05:08.109495 → 2025-01-01 00:00:00]
  t0                 position=0, anchor=start

Sequences (2)
─────────────────────────
  • states              StateSequencePool(n=100, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_state_98c2ffcd')
  • events              EventSequencePool(n=100, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_event_ad5c8635')

Define trajectory metric#

hamming = HammingEntityMetric(entity_feature="status")

# Per-alias metrics: "events" uses LCS; any alias not listed in
# sequence_metrics (here "states") falls back to default_metric.
agg = AggregationTrajectoryMetric(
    default_metric=EditSequenceMetric(entity_metric=hamming, normalize=True),
    sequence_metrics={
        "events": LCSSequenceMetric(entity_metric=hamming, mode="normalized"),
    },
    agg_fun="mean",
)
print(agg)
AggregationTrajectoryMetric(settings=AggregationSettings(default_metric=EditSequenceMetric(settings=EditSettings(entity_metric=HammingEntityMetric(settings=HammingSettings(entity_feature='status', cost=None, mismatch_cost=1.0)), indel_cost=1.0, normalize=True)), sequence_metrics={'events': LCSSequenceMetric(settings=LCSSettings(entity_metric=HammingEntityMetric(settings=HammingSettings(entity_feature='status', cost=None, mismatch_cost=1.0)), equality_threshold=0.0, mode='normalized'))}, agg_fun='mean', weights=None))
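
The aggregation step can be pictured with a small standalone sketch. It is not tanat's implementation: the `aggregate` helper, the per-alias distance values, and the weighted-mean interpretation of `weights` are all assumptions made for illustration. It only shows how one distance per alias could be reduced to a single number with a mean.

```python
from statistics import fmean
from typing import Dict, Optional


def aggregate(
    per_alias: Dict[str, float],
    weights: Optional[Dict[str, float]] = None,
) -> float:
    """Reduce per-alias distances to one value (hypothetical helper).

    With weights=None this is a plain mean; otherwise a weighted mean.
    The weighted form is an assumption, not tanat's documented behavior.
    """
    if weights is None:
        return fmean(per_alias.values())
    total = sum(weights[alias] for alias in per_alias)
    return sum(weights[alias] * d for alias, d in per_alias.items()) / total


# Made-up per-alias distances for one pair of trajectories.
per_alias = {"states": 0.5, "events": 1.0}

unweighted = aggregate(per_alias)
weighted = aggregate(per_alias, {"states": 3.0, "events": 1.0})
print(unweighted, weighted)
```

With equal treatment of both aliases, each sequence type contributes half of the final distance; giving "states" a larger weight pulls the result toward the states distance.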

Compute distance between a single pair#

traj_ids = traj_pool.unique_ids
traj_a = traj_pool[traj_ids[0]]
traj_b = traj_pool[traj_ids[1]]

dist = agg(traj_a, traj_b)
print(f"\nDistance between {traj_ids[0]} and {traj_ids[1]}: {dist:.4f}")
Distance between 1 and 2: 0.8750
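
To give an intuition for what a normalized edit distance over a categorical feature measures, here is a minimal pure-Python sketch. It is an illustration under stated assumptions, not the code behind EditSequenceMetric: the function names and the status labels are hypothetical, the unit costs mirror the `indel_cost=1.0` and `mismatch_cost=1.0` values shown in the printed settings, and dividing by the longer sequence length is only one possible normalization.

```python
def edit_distance(a, b, indel_cost=1.0, mismatch_cost=1.0):
    """Edit distance with a Hamming-style substitution cost (0 if equal)."""
    # prev[j] holds the cost of aligning the processed prefix of a with b[:j].
    prev = [j * indel_cost for j in range(len(b) + 1)]
    for i, x in enumerate(a, start=1):
        curr = [i * indel_cost]
        for j, y in enumerate(b, start=1):
            substitute = prev[j - 1] + (0.0 if x == y else mismatch_cost)
            delete = prev[j] + indel_cost
            insert = curr[j - 1] + indel_cost
            curr.append(min(substitute, delete, insert))
        prev = curr
    return prev[-1]


def normalized_edit_distance(a, b):
    """Divide by the longer length (an assumed normalization scheme)."""
    longest = max(len(a), len(b))
    return 0.0 if longest == 0 else edit_distance(a, b) / longest


# Hypothetical status sequences for two trajectories.
seq_a = ["A", "P", "A"]
seq_b = ["A", "A", "C", "A"]

raw = edit_distance(seq_a, seq_b)
norm = normalized_edit_distance(seq_a, seq_b)
print(raw, norm)
```

Under this normalization the result stays in [0, 1] when both costs are 1, which makes it comparable with other normalized sequence distances before they are averaged.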

Compute full pairwise distance matrix#

matrix = agg.compute_matrix(traj_pool)
print(f"\nDistance matrix shape: {matrix.shape}")
# Mean over strictly positive entries (skips the zero diagonal and any zero-distance pairs)
dists = matrix.to_numpy()
print(f"Mean distance: {dists[dists > 0].mean():.4f}")
┌─ AggregationTrajectoryMetric
│
│   ┌─ EditSequenceMetric
│   │
│   │ Chunks: 100%|██████████| 1/1 [00:00<00:00, 1110.19it/s]
│   │
│   └─ Done (100 sequences · 0.00s)
│
│   ┌─ LCSSequenceMetric
│   │
│   │ Chunks: 100%|██████████| 1/1 [00:00<00:00, 1203.88it/s]
│   │
│   └─ Done (100 sequences · 0.00s)
│
└─ Done (100 trajectories · 0.02s)

Distance matrix shape: (100, 100)
Mean distance: 0.6860
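
The mean above is taken over strictly positive entries, which drops the diagonal but also any pair of distinct trajectories whose distance happens to be 0. The sketch below contrasts that mask with an explicit off-diagonal mask on a small made-up matrix. It uses only NumPy; the values are invented, and nothing here is claimed about the matrix tanat actually returns.

```python
import numpy as np

# Toy 3x3 distance matrix (made-up values); trajectories 1 and 2 are identical.
arr = np.array(
    [
        [0.0, 0.5, 1.0],
        [0.5, 0.0, 0.0],
        [1.0, 0.0, 0.0],
    ]
)

# Mask selecting every entry except the diagonal.
off_diag = ~np.eye(arr.shape[0], dtype=bool)

mean_positive = arr[arr > 0].mean()    # ignores the zero-distance pair
mean_off_diag = arr[off_diag].mean()   # keeps it in the average

# Basic sanity checks one might run on any distance matrix.
is_symmetric = np.allclose(arr, arr.T)
zero_diagonal = np.allclose(np.diag(arr), 0.0)

print(mean_positive, mean_off_diag, is_symmetric, zero_diagonal)
```

The two means differ whenever distinct trajectories sit at distance 0, so the off-diagonal mask is the safer choice when identical trajectories are expected in the pool.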

Visualize trajectory distances#

fig, ax = plt.subplots(figsize=(8, 6))
arr = matrix.to_numpy()
im = ax.imshow(arr, cmap="viridis", aspect="auto")
ax.set_title(
    "Trajectory distances\n(Edit distance for states, LCS for events)",
    fontsize=12,
    fontweight="bold",
)
ax.set_xlabel("Trajectory index")
ax.set_ylabel("Trajectory index")
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Distance")
plt.tight_layout()
plt.show()
Figure: heatmap of the pairwise trajectory distance matrix (Edit distance for states, LCS for events).

Total running time of the script: (0 minutes 0.413 seconds)
