Note
Go to the end to download the full example code.
Trajectory Metrics: Aggregation Trajectory Metric
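This example demonstrates trajectory-level metrics using
AggregationTrajectoryMetric, which computes a distance between
trajectories that contain multiple sequence types (e.g., states, events,
or intervals). A sequence metric is applied to each sequence type, and the
resulting per-type distances are combined with an aggregation function
(the mean in this example).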
Setup
import polars as pl
import matplotlib.pyplot as plt
from tanat import build_states, build_events, build_trajectories
from tanat.dataset import simulate_trajectories
from tanat.metric.entity import HammingEntityMetric
from tanat.metric.sequence import (
EditSequenceMetric,
LCSSequenceMetric,
)
from tanat.metric import AggregationTrajectoryMetric
Generate synthetic trajectory data
# Number of ids to simulate per sequence type, and the seed handed to the
# simulator.
N_TRAJ = 100
SEED = 42

# One specification per sequence alias. Both aliases carry the same two
# features; they differ in their type and in their length range
# (presumably the min/max number of entities per sequence -- confirm against
# the simulate_trajectories documentation).
sequence_specs = {
    "states": {
        "type": "state",
        "n_ids": N_TRAJ,
        "seq_length_range": (3, 8),
        "features": ["score", "status"],
    },
    "events": {
        "type": "event",
        "n_ids": N_TRAJ,
        "seq_length_range": (2, 6),
        "features": ["score", "status"],
    },
}

# shared_ids=True presumably makes both sequence types use the same ids so
# they can be linked into trajectories below -- confirm.
raw = simulate_trajectories(
    sequences=sequence_specs,
    shared_ids=True,
    seed=SEED,
)

# Build pools for each sequence type
# States are described by a start and an end column.
states_pool = build_states(
    temporal_data=raw["states"],
    id_column="id",
    start_column="start",
    end_column="end",
)
# Events are described by a single time column.
events_pool = build_events(
    temporal_data=raw["events"],
    id_column="id",
    time_column="time",
)

# Build trajectory pool
# The dict keys ("states", "events") are the aliases used later to pick a
# metric per sequence type.
traj_pool = build_trajectories(
    pools={
        "states": states_pool,
        "events": events_pool,
    }
)
┌─ State SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (100 sequences · 531 entities · 0.01s)
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (100 sequences · 383 entities · 0.00s)
┌─ TrajectoryStore
│
│ Step 1/2: Linking pools: states, events
│
│ Step 2/2: Building trajectory index & metadata
│
└─ Done (100 trajectories · 2 pool(s) · 0.00s)
# Cast features to categorical
# The same cast is applied to every sequence pool linked to the trajectory
# pool, so "status" has the same dtype in each of them.
status_cast = {"status": pl.Categorical}
for sp in traj_pool.sequence_pools.values():
    sp.cast_features(status_cast)

# Print a summary of the trajectory pool.
print(traj_pool)
┌────────────────────────────────────────────────┐
│ TrajectoryPool Summary │
└────────────────────────────────────────────────┘
Overview
─────────────────────────
Trajectories 100
Store /home/runner/.tanat/_quick_trajectory_4056a21b
id_column id
Time Index
─────────────────────────
Type Datetime(time_unit='us', time_zone=None) [2000-01-04 17:05:08.109495 → 2025-01-01 00:00:00]
t0 position=0, anchor=start
Sequences (2)
─────────────────────────
• states StateSequencePool(n=100, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_state_98c2ffcd')
• events EventSequencePool(n=100, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_event_ad5c8635')
Define trajectory metric
# Entity-level metric that compares entities on their "status" feature.
hamming = HammingEntityMetric(entity_feature="status")

# Use different metrics per alias (sequence type)
# Both sequence metrics reuse the same entity metric.
edit_metric = EditSequenceMetric(entity_metric=hamming, normalize=True)
lcs_metric = LCSSequenceMetric(entity_metric=hamming, mode="normalized")

# "events" gets the LCS metric explicitly; aliases without an entry in
# sequence_metrics (here "states") presumably fall back to default_metric.
# The per-alias distances are then combined with agg_fun.
agg = AggregationTrajectoryMetric(
    default_metric=edit_metric,
    sequence_metrics={"events": lcs_metric},
    agg_fun="mean",
)
print(agg)
AggregationTrajectoryMetric(settings=AggregationSettings(default_metric=EditSequenceMetric(settings=EditSettings(entity_metric=HammingEntityMetric(settings=HammingSettings(entity_feature='status', cost=None, mismatch_cost=1.0)), indel_cost=1.0, normalize=True)), sequence_metrics={'events': LCSSequenceMetric(settings=LCSSettings(entity_metric=HammingEntityMetric(settings=HammingSettings(entity_feature='status', cost=None, mismatch_cost=1.0)), equality_threshold=0.0, mode='normalized'))}, agg_fun='mean', weights=None))
Compute distance between a single pair
# Take the first two trajectory ids and look the trajectories up by id.
traj_ids = traj_pool.unique_ids
traj_a, traj_b = traj_pool[traj_ids[0]], traj_pool[traj_ids[1]]

# The metric object is callable on a pair of trajectories.
dist = agg(traj_a, traj_b)
print(f"\nDistance between {traj_ids[0]} and {traj_ids[1]}: {dist:.4f}")
Distance between 1 and 2: 0.8750
Compute full pairwise distance matrix
# Compute the distance between every pair of trajectories in the pool.
matrix = agg.compute_matrix(traj_pool)
print(f"\nDistance matrix shape: {matrix.shape}")

# Convert to a NumPy array once instead of twice inside the f-string.
dist_values = matrix.to_numpy()
# Average only the strictly positive entries. This leaves out exact-zero
# distances: presumably the self-distances on the diagonal, but also any
# off-diagonal pair whose distance is exactly 0.
positive_distances = dist_values[dist_values > 0]
print(f"Mean distance: {positive_distances.mean():.4f}")
┌─ AggregationTrajectoryMetric
│
│ ┌─ EditSequenceMetric
│ │
│ │ Chunks: 0%| | 0/1 [00:00<?, ?it/s]
│ │ Chunks: 100%|██████████| 1/1 [00:00<00:00, 1110.19it/s]
│ │
│ └─ Done (100 sequences · 0.00s)
│
│ ┌─ LCSSequenceMetric
│ │
│ │ Chunks: 0%| | 0/1 [00:00<?, ?it/s]
│ │ Chunks: 100%|██████████| 1/1 [00:00<00:00, 1203.88it/s]
│ │
│ └─ Done (100 sequences · 0.00s)
│
└─ Done (100 trajectories · 0.02s)
Distance matrix shape: (100, 100)
Mean distance: 0.6860
Visualize trajectory distances
# Draw the pairwise distance matrix as a heat map; rows and columns are both
# labelled as trajectory indices.
arr = matrix.to_numpy()
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(arr, cmap="viridis", aspect="auto")

title_text = "Trajectory distances\n(Edit distance for states, LCS for events)"
ax.set_title(title_text, fontsize=12, fontweight="bold")
ax.set(xlabel="Trajectory index", ylabel="Trajectory index")

# Colour scale legend for the distance values.
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Distance")

fig.tight_layout()
plt.show()

Total running time of the script: (0 minutes 0.413 seconds)