Note
Go to the end to download the full example code.
Trajectory-Level Zeroing
Align a TrajectoryPool to a common reference
date (T0) using any of the four built-in strategies.
At the trajectory level, T0 is computed once (from a reference sub-pool
when `on=` is provided) and then shared across every child object: sub-pools,
trajectories, and individual sequences all return the same `t0` value.
Each object still computes its own t0_nearest_rank from its own temporal grid.
| Strategy | Requires `on=` | Description |
|---|---|---|
| `position` | yes | T0 = temporal value at row index |
| `direct` | no | T0 = scalar or per-id value |
| `feature` | no | T0 = trajectory-level static feature column |
| `query` | yes | T0 = first/last row matching a Polars expression in the reference sub-pool |
After set_t0, inspect the results with tpool.t0_data(), traj.t0,
and traj.t0_nearest_rank.
See Zeroing & Alignment for the complete reference and Sequence Level Zeroing for the sequence-level equivalent.
Imports
import warnings
from datetime import datetime, timedelta
from itertools import islice

import polars as pl

from tanat import build_events, build_intervals, build_states, build_trajectories
from tanat.dataset import simulate_static, simulate_trajectories
Simulate data
Generate three temporal pools (interval, event, state) sharing the same IDs, plus a trajectory-level static DataFrame.
N_IDS = 40
SEED = 42

# Three temporal pools keyed by alias, each simulated over the same N_IDS ids.
data = simulate_trajectories(
    sequences={
        "admissions": dict(type="interval", n_ids=N_IDS, features=["value", "status"]),
        "labs": dict(type="event", n_ids=N_IDS, features=["value", "status"]),
        "phases": dict(type="state", n_ids=N_IDS, features=["phase_name", "score"]),
    },
    seed=SEED,
)
# Trajectory-level static table over the same ids, simulated with the same seed.
static = simulate_static(n_ids=N_IDS, features=["age", "group"], seed=SEED)

# One shape summary line per DataFrame; the static table is reported last.
for name, frame in [*data.items(), ("static", static)]:
    print(f"{name:12s}: {frame.shape[0]:4d} rows, columns={frame.columns}")
admissions : 264 rows, columns=Index(['id', 'start', 'end', 'value', 'status'], dtype='object')
labs : 259 rows, columns=Index(['id', 'time', 'value', 'status'], dtype='object')
phases : 253 rows, columns=Index(['id', 'start', 'end', 'phase_name', 'score'], dtype='object')
static : 40 rows, columns=Index(['id', 'age', 'group'], dtype='object')
Build the TrajectoryPool
# Interval and state pools share the same id/start/end column layout.
_SPAN_COLUMNS = {"id_column": "id", "start_column": "start", "end_column": "end"}

admissions_pool = build_intervals(temporal_data=data["admissions"], **_SPAN_COLUMNS)
labs_pool = build_events(temporal_data=data["labs"], id_column="id", time_column="time")
phases_pool = build_states(temporal_data=data["phases"], **_SPAN_COLUMNS)

# Link the three sequence pools by alias and attach the trajectory-level statics.
tpool = build_trajectories(
    pools={"admissions": admissions_pool, "labs": labs_pool, "phases": phases_pool},
    static_data=static,
    id_column="id",
)
print(tpool)
┌─ Interval SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (40 sequences · 264 entities · 0.00s)
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (40 sequences · 259 entities · 0.00s)
┌─ State SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (40 sequences · 253 entities · 0.00s)
┌─ TrajectoryStore
│
│ Step 1/2: Linking pools: admissions, labs, phases
│
│ Step 2/2: Building trajectory index & metadata
│
└─ Done (40 trajectories · 3 pool(s) · 0.01s)
┌────────────────────────────────────────────────┐
│ TrajectoryPool Summary │
└────────────────────────────────────────────────┘
Overview
─────────────────────────
Trajectories 40
Store /home/runner/.tanat/_quick_trajectory_b82febf4
id_column id
Time Index
─────────────────────────
Type Datetime(time_unit='us', time_zone=None) [2000-01-05 21:36:54.558222 → 2025-01-01 00:00:00]
t0 position=0, anchor=start
Sequences (3)
─────────────────────────
• admissions IntervalSequencePool(n=40, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_interval_83bf9e6b')
• labs EventSequencePool(n=40, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_event_52795831')
• phases StateSequencePool(n=40, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_state_5e2cf141')
Static Features (2)
─────────────────────────
• age Numerical [7 → 98]
• group String [len 1 → 1]
Strategy 1 - position
`set_t0(position=N, on="alias")` selects row N in the reference
sub-pool.
`on=` is required: the row index depends on the target sub-pool. `anchor=` controls which end of the interval is used (interval/state pools only).
# First row of admissions, start of interval. The printed label is rendered
# from the very kwargs passed to set_t0, so the two always agree.
position_kwargs = {"position": 0, "anchor": "start", "on": "admissions"}
tpool.set_t0(**position_kwargs)
print(", ".join(f"{key}={value!r}" for key, value in position_kwargs.items()))
tpool.t0_data().head()
position=0, anchor='start', on='admissions'
Last lab event (for an event pool, `anchor` is ignored):
# Negative positions count from the end: -1 selects each trajectory's last lab.
last_lab_kwargs = {"position": -1, "on": "labs"}
tpool.set_t0(**last_lab_kwargs)
print(", ".join(f"{key}={value!r}" for key, value in last_lab_kwargs.items()))
tpool.t0_data().head()
position=-1, on='labs'
Strategy 2 - direct
set_t0(direct=value) assigns the same T0 to every trajectory.
set_t0(direct={id: value, ...}) assigns a per-ID T0; absent IDs
receive _T0_ = null.
on= is not required: the value is provided directly.
Scalar: same T0 for everyone
# A single scalar: every trajectory receives the same T0.
shared_t0 = datetime(2010, 6, 1)
tpool.set_t0(direct=shared_t0)
print("direct scalar")
tpool.t0_data().head()
direct scalar
Per-ID dict (only 3 IDs mapped, rest get null)
# Map only the first few trajectory ids to an explicit T0. Every other id is
# absent from the dict and therefore receives _T0_ = null (with a UserWarning).
mapped_dates = [datetime(2010, 1, 15), datetime(2011, 3, 20), datetime(2012, 7, 10)]
first_ids = tpool.unique_ids[: len(mapped_dates)]
per_id_map = dict(zip(first_ids, mapped_dates))
tpool.set_t0(direct=per_id_map)
# Derive the counts from the mapping instead of hard-coding 3, so the message
# stays correct if the mapping or N_IDS changes.
n_mapped = len(per_id_map)
print(f"direct per-id ({n_mapped} IDs mapped, {len(tpool) - n_mapped} get null)")
tpool.t0_data().head(6)
/home/runner/work/TanaT/TanaT/src/tanat/trajectory/pool.py:851: UserWarning: 37 sequence(s) received _t0 = null (no valid row found): [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
setter.compute_from_trajectory(self, on=on)
direct per-id (3 IDs mapped, 37 get null)
Strategy 3 - feature
set_t0(feature="col") reads T0 from a trajectory-level static
feature. The dtype must match the pool’s temporal dtype.
on= is not required.
Add a custom datetime static feature
# One synthetic admission date per trajectory, 45 days apart from 2008-01-01.
ids = tpool.unique_ids
first_admission = datetime(2008, 1, 1)
spacing = timedelta(days=45)
admission_dates = [first_admission + k * spacing for k in range(len(ids))]
# Cast explicitly to microsecond precision so the feature's dtype matches the
# pool's temporal dtype (Datetime, time_unit='us'), as this strategy requires.
date_df = pl.DataFrame(
    {"id": ids, "admission_date": pl.Series(admission_dates).cast(pl.Datetime("us"))}
)
tpool.add_static_features(date_df)
tpool.set_t0(feature="admission_date")
print("feature='admission_date'")
tpool.t0_data().head()
feature='admission_date'
Strategy 4 - query
set_t0(query=expr, on="alias") picks the first (or last) row matching
the expression in the reference sub-pool.
on= is required: the expression refers to columns in the target
sub-pool.
# First lab with status matching a known category.
# Series.unique() returns values in an unspecified order unless
# maintain_order=True, so taking [0] without it could pick a different
# category on every run. First-appearance order makes the choice reproducible.
status_values = (
    labs_pool.temporal_data(fmt="polars")["status"]
    .unique(maintain_order=True)
    .to_list()
)
ref_status = status_values[0]
tpool.set_t0(
    query=pl.col("status") == ref_status,
    use_first=True,
    on="labs",
)
print(f"query: first lab with status=={ref_status!r}, on='labs'")
tpool.t0_data().head()
/home/runner/work/TanaT/TanaT/src/tanat/trajectory/pool.py:851: UserWarning: 20 sequence(s) received _t0 = null (no valid row found): [1, 4, 7, 8, 13, 15, 17, 18, 19, 20, 21, 27, 29, 32, 33, 34, 35, 36, 37, 39]
setter.compute_from_trajectory(self, on=on)
query: first lab with status=='C', on='labs'
# Last admission matching a status (interval pool, anchor='end').
# Series.unique() returns values in an unspecified order unless
# maintain_order=True, so taking [0] without it would not select a
# reproducible status. First-appearance order makes the choice deterministic.
adm_status_values = (
    admissions_pool.temporal_data(fmt="polars")["status"]
    .unique(maintain_order=True)
    .to_list()
)
ref_adm_status = adm_status_values[0]
tpool.set_t0(
    query=pl.col("status") == ref_adm_status,
    anchor="end",
    use_first=False,
    on="admissions",
)
print(f"query: last admission with status=={ref_adm_status!r}, anchor='end'")
tpool.t0_data().head()
/home/runner/work/TanaT/TanaT/src/tanat/trajectory/pool.py:851: UserWarning: 9 sequence(s) received _t0 = null (no valid row found): [10, 11, 18, 21, 24, 34, 36, 39, 40]
setter.compute_from_trajectory(self, on=on)
query: last admission with status=='D', anchor='end'
Trajectory-level properties
After set_t0, every Trajectory
exposes two read-only properties:
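| Property | Type | Description |
|---|---|---|
| `t0` | scalar or `None` | T0 for this trajectory |
| `t0_nearest_rank` | `dict` (alias → rank, or `None` when T0 is null) | Per-alias floor index, computed from each sub-pool's own temporal grid |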
# Re-anchor on the first admission, then inspect one trajectory's read-only
# t0 properties, printed as aligned "label : value" lines.
tpool.set_t0(position=0, anchor="start", on="admissions")
traj = tpool[tpool.unique_ids[0]]
for label, attr in (
    ("id", "id_value"),
    ("t0", "t0"),
    ("t0_nearest_rank", "t0_nearest_rank"),
):
    print(f"{label:<16}: {getattr(traj, attr)}")
id : 1
t0 : 2001-08-22 12:56:21.376599
t0_nearest_rank : {'admissions': 0, 'labs': 0, 'phases': 0}
# Iterate over the first few trajectories. islice stops after six items
# instead of materializing every trajectory in the pool via list(tpool).
from itertools import islice  # also provided by the file-level imports

for traj in islice(tpool, 6):
    ranks = ", ".join(
        f"{alias}={rank}" for alias, rank in traj.t0_nearest_rank.items()
    )
    print(f" id={traj.id_value!r:<6} t0={str(traj.t0):<30} ranks=[{ranks}]")
id=1 t0=2001-08-22 12:56:21.376599 ranks=[admissions=0, labs=0, phases=0]
id=2 t0=2000-08-01 14:08:17.503482 ranks=[admissions=0, labs=0, phases=0]
id=3 t0=2000-01-07 14:22:15.347103 ranks=[admissions=0, labs=0, phases=0]
id=4 t0=2000-02-23 04:33:57.263646 ranks=[admissions=0, labs=0, phases=0]
id=5 t0=2001-12-31 04:05:28.022580 ranks=[admissions=0, labs=0, phases=0]
id=6 t0=2002-02-07 05:50:58.019626 ranks=[admissions=0, labs=0, phases=0]
t0_data() column structure
tpool.t0_data() returns one row per trajectory with:
- `_T0_`: the T0 value (identical across all sub-pools)
- `<alias>_T0_NEAREST_RANK_`: per-alias floor index (each sub-pool gets its own column because the temporal grids differ)
# Wide T0 summary as a Polars DataFrame: one row per trajectory.
df = tpool.t0_data(fmt="polars")
print(f"Columns: {df.columns}")
print(f"Rows   : {df.height} (one per trajectory)")
df.head()
Columns: ['id', '_T0_', 'admissions_T0_NEAREST_RANK_', 'labs_T0_NEAREST_RANK_', 'phases_T0_NEAREST_RANK_']
Rows : 40 (one per trajectory)
Null handling
Trajectories with _T0_ = null arise in the same situations as for
sequence pools (out-of-range position, missing dict key, null feature,
no query match). The trajectory is not dropped; traj.t0 returns
None and all nearest-rank values are None.
# A query with no matching lab row leaves every trajectory with a null T0.
# Only the expected UserWarning about null T0 values is silenced here; any
# other warning category raised by set_t0 still surfaces.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    tpool.set_t0(query=pl.col("value") > 1e9, on="labs")  # no match

null_trajs = [t.id_value for t in tpool if t.t0 is None]
print(f"{len(null_trajs)}/{len(tpool)} trajectory(ies) with t0 = None")
40/40 trajectory(ies) with t0 = None