Trajectory-Level Zeroing

Align a TrajectoryPool to a common reference date (T0) using any of the four built-in strategies.

At the trajectory level, T0 is computed once per trajectory (from the reference sub-pool named by on=, when that argument is given) and then shared with every child object: sub-pools, trajectories, and individual sequences all return the same t0 value for a given ID. Each object still computes its own t0_nearest_rank on its own temporal grid.
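The split between one shared t0 and a per-grid t0_nearest_rank can be pictured in plain Python. This is a minimal sketch, not TanaT's implementation. It assumes each sub-pool's grid is a sorted list of timestamps and treats the nearest rank as a floor index (the last position whose timestamp is <= T0). It also clamps that index to 0 when T0 precedes the first timestamp. The clamping rule and the helper name `floor_rank` are assumptions made for illustration.

```python
from bisect import bisect_right
from datetime import datetime


def floor_rank(grid, t0):
    """Index of the last timestamp <= t0 in a sorted grid (assumed semantics).

    Returns None for a null T0 or an empty grid; clamps to 0 when t0
    precedes the first timestamp (an assumption made for this sketch).
    """
    if t0 is None or not grid:
        return None
    return max(bisect_right(grid, t0) - 1, 0)


# One trajectory: a single shared T0, three sub-pools with different grids.
t0 = datetime(2010, 6, 1)
grids = {
    "admissions": [datetime(2005, 1, 1), datetime(2009, 3, 1), datetime(2012, 8, 1)],
    "labs": [datetime(2010, 6, 1), datetime(2015, 2, 1)],
    "phases": [datetime(2011, 1, 1)],
}
ranks = {alias: floor_rank(grid, t0) for alias, grid in grids.items()}
print(ranks)  # {'admissions': 1, 'labs': 0, 'phases': 0}
```

The same t0 feeds every alias. Only the grid changes, which is why the ranks differ per sub-pool.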

Strategy    on= required?   Description
position    yes             T0 = temporal value at row index N in the reference sub-pool
direct      no              T0 = scalar or per-ID dict (no sub-pool lookup needed)
feature     no              T0 = trajectory-level static feature column
query       yes             T0 = first/last row matching a Polars expression in the reference sub-pool

After set_t0, inspect the results with tpool.t0_data(), traj.t0, and traj.t0_nearest_rank.

See Zeroing & Alignment for the complete reference and Sequence Level Zeroing for the sequence-level equivalent.

Imports

import polars as pl
import warnings
from datetime import datetime, timedelta

from tanat import build_events, build_intervals, build_states, build_trajectories
from tanat.dataset import simulate_trajectories, simulate_static

Simulate data

Generate three temporal datasets (interval, event, state) that share the same IDs, plus a trajectory-level static DataFrame.

N_IDS = 40
SEED = 42

data = simulate_trajectories(
    sequences={
        "admissions": {
            "type": "interval",
            "n_ids": N_IDS,
            "features": ["value", "status"],
        },
        "labs": {"type": "event", "n_ids": N_IDS, "features": ["value", "status"]},
        "phases": {
            "type": "state",
            "n_ids": N_IDS,
            "features": ["phase_name", "score"],
        },
    },
    seed=SEED,
)
static = simulate_static(n_ids=N_IDS, features=["age", "group"], seed=SEED)

for alias, df in data.items():
    print(f"{alias:12s}: {df.shape[0]:4d} rows, columns={df.columns}")
print(f"{'static':12s}: {static.shape[0]:4d} rows, columns={static.columns}")
admissions  :  264 rows, columns=Index(['id', 'start', 'end', 'value', 'status'], dtype='object')
labs        :  259 rows, columns=Index(['id', 'time', 'value', 'status'], dtype='object')
phases      :  253 rows, columns=Index(['id', 'start', 'end', 'phase_name', 'score'], dtype='object')
static      :   40 rows, columns=Index(['id', 'age', 'group'], dtype='object')

Build the TrajectoryPool

admissions_pool = build_intervals(
    temporal_data=data["admissions"],
    id_column="id",
    start_column="start",
    end_column="end",
)
labs_pool = build_events(
    temporal_data=data["labs"],
    id_column="id",
    time_column="time",
)
phases_pool = build_states(
    temporal_data=data["phases"],
    id_column="id",
    start_column="start",
    end_column="end",
)

tpool = build_trajectories(
    pools={
        "admissions": admissions_pool,
        "labs": labs_pool,
        "phases": phases_pool,
    },
    static_data=static,
    id_column="id",
)
print(tpool)
┌─ Interval SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (40 sequences · 264 entities · 0.00s)
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (40 sequences · 259 entities · 0.00s)
┌─ State SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity & time index features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (40 sequences · 253 entities · 0.00s)
┌─ TrajectoryStore
│
│ Step 1/2: Linking pools: admissions, labs, phases
│
│ Step 2/2: Building trajectory index & metadata
│
└─ Done (40 trajectories · 3 pool(s) · 0.01s)
┌────────────────────────────────────────────────┐
│             TrajectoryPool Summary             │
└────────────────────────────────────────────────┘

Overview
─────────────────────────
  Trajectories       40
  Store              /home/runner/.tanat/_quick_trajectory_b82febf4
  id_column          id

Time Index
─────────────────────────
  Type               Datetime(time_unit='us', time_zone=None) [2000-01-05 21:36:54.558222 → 2025-01-01 00:00:00]
  t0                 position=0, anchor=start

Sequences (3)
─────────────────────────
  • admissions          IntervalSequencePool(n=40, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_interval_83bf9e6b')
  • labs                EventSequencePool(n=40, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_event_52795831')
  • phases              StateSequencePool(n=40, entity_features=2, static_features=0, store='/home/runner/.tanat/_quick_state_5e2cf141')

Static Features (2)
─────────────────────────
  • age                 Numerical [7 → 98]
  • group               String [len 1 → 1]

Strategy 1 - position

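set_t0(position=N, on="alias") selects the row at index N of each trajectory's sequence in the reference sub-pool. Indices are 0-based, and negative values count from the end (-1 is the last row).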

  • on= is required: the row index depends on the target sub-pool.

  • anchor= controls which end of the interval is used (interval/state pools only).
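The position lookup can be modelled roughly on Python list indexing: 0 is the first row, -1 is the last, and an index outside the sequence gives a null T0. The sketch below is illustrative only. The `Row` tuple and the `t0_from_position` helper are hypothetical. It also assumes event rows carry their single timestamp in `start` with `end=None`, so anchor is consulted only for rows that have an end.

```python
from datetime import datetime
from typing import NamedTuple, Optional


class Row(NamedTuple):
    """Hypothetical row model: events have end=None, intervals/states have both."""

    start: datetime
    end: Optional[datetime] = None


def t0_from_position(rows, position, anchor="start"):
    """Pick row `position` (Python-style, negatives count from the end)."""
    try:
        row = rows[position]
    except IndexError:
        return None  # out-of-range position -> null T0
    if row.end is None:  # event row: anchor is ignored
        return row.start
    return row.end if anchor == "end" else row.start


admissions = [
    Row(datetime(2001, 8, 22), datetime(2001, 9, 1)),
    Row(datetime(2004, 1, 10), datetime(2004, 1, 21)),
]
labs = [Row(datetime(2002, 5, 3)), Row(datetime(2021, 3, 11))]

print(t0_from_position(admissions, 0, anchor="start"))  # 2001-08-22 00:00:00
print(t0_from_position(labs, -1, anchor="end"))  # 2021-03-11 00:00:00 (anchor ignored)
print(t0_from_position(labs, 5))  # None
```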

# First row of admissions, start of interval
tpool.set_t0(position=0, anchor="start", on="admissions")
print("position=0, anchor='start', on='admissions'")
tpool.t0_data().head()
position=0, anchor='start', on='admissions'
   id  _T0_                        admissions_T0_NEAREST_RANK_  labs_T0_NEAREST_RANK_  phases_T0_NEAREST_RANK_
0  1   2001-08-22 12:56:21.376599  0                            0                      0
1  2   2000-08-01 14:08:17.503482  0                            0                      0
2  3   2000-01-07 14:22:15.347103  0                            0                      0
3  4   2000-02-23 04:33:57.263646  0                            0                      0
4  5   2001-12-31 04:05:28.022580  0                            0                      0


Last lab event (event pool: anchor is ignored)

tpool.set_t0(position=-1, on="labs")
print("position=-1, on='labs'")
tpool.t0_data().head()
position=-1, on='labs'
   id  _T0_                        admissions_T0_NEAREST_RANK_  labs_T0_NEAREST_RANK_  phases_T0_NEAREST_RANK_
0  1   2021-03-11 13:36:27.157962  7                            2                      2
1  2   2022-01-29 12:43:42.168713  9                            4                      8
2  3   2021-01-23 14:06:23.295715  2                            7                      2
3  4   2022-10-26 13:22:59.548499  7                            6                      7
4  5   2013-05-16 22:13:10.705679  1                            2                      8


Strategy 2 - direct

set_t0(direct=value) assigns the same T0 to every trajectory. set_t0(direct={id: value, ...}) assigns a per-ID T0. IDs absent from the dict receive _T0_ = null, and a UserWarning lists them.

on= is not required: the value is provided directly.
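In concept, direct needs no sub-pool lookup. A scalar is broadcast to every ID, and a dict is looked up per ID, with unmapped IDs falling back to None. The standard-library sketch below uses a hypothetical helper, `resolve_direct`, which is not a TanaT function. It mimics the UserWarning from the per-ID example by collecting the IDs left without a T0.

```python
import warnings
from collections.abc import Mapping
from datetime import datetime


def resolve_direct(ids, direct):
    """Broadcast a scalar T0 to every id, or look up a per-id mapping."""
    if isinstance(direct, Mapping):
        t0 = {i: direct.get(i) for i in ids}  # unmapped ids -> None
    else:
        t0 = {i: direct for i in ids}
    missing = [i for i, value in t0.items() if value is None]
    if missing:
        warnings.warn(f"{len(missing)} id(s) received t0 = None: {missing}")
    return t0


ids = [1, 2, 3, 4, 5]
print(resolve_direct(ids, datetime(2010, 6, 1))[4])  # 2010-06-01 00:00:00
per_id = resolve_direct(ids, {1: datetime(2010, 1, 15), 2: datetime(2011, 3, 20)})
print(per_id[3])  # None (a UserWarning lists ids 3, 4 and 5)
```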

Scalar: same T0 for everyone

tpool.set_t0(direct=datetime(2010, 6, 1))
print("direct scalar")
tpool.t0_data().head()
direct scalar
   id  _T0_                        admissions_T0_NEAREST_RANK_  labs_T0_NEAREST_RANK_  phases_T0_NEAREST_RANK_
0  1   2010-06-01 00:00:00         4                            0                      2
1  2   2010-06-01 00:00:00         5                            1                      8
2  3   2010-06-01 00:00:00         2                            3                      2
3  4   2010-06-01 00:00:00         2                            3                      7
4  5   2010-06-01 00:00:00         1                            0                      8


Per-ID dict (only 3 IDs mapped, rest get null)

first_ids = tpool.unique_ids[:3]
per_id_map = {
    first_ids[0]: datetime(2010, 1, 15),
    first_ids[1]: datetime(2011, 3, 20),
    first_ids[2]: datetime(2012, 7, 10),
}
tpool.set_t0(direct=per_id_map)
print(f"direct per-id (3 IDs mapped, {len(tpool) - 3} get null)")
tpool.t0_data().head(6)
/home/runner/work/TanaT/TanaT/src/tanat/trajectory/pool.py:851: UserWarning: 37 sequence(s) received _t0 = null (no valid row found): [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
  setter.compute_from_trajectory(self, on=on)
direct per-id (3 IDs mapped, 37 get null)
   id  _T0_                        admissions_T0_NEAREST_RANK_  labs_T0_NEAREST_RANK_  phases_T0_NEAREST_RANK_
0  1   2010-01-15 00:00:00         4                            0                      2
1  2   2011-03-20 00:00:00         5                            1                      8
2  3   2012-07-10 00:00:00         2                            3                      2
3  4   <NA>                        <NA>                         <NA>                   <NA>
4  5   <NA>                        <NA>                         <NA>                   <NA>
5  6   <NA>                        <NA>                         <NA>                   <NA>


Strategy 3 - feature

set_t0(feature="col") reads T0 from a trajectory-level static feature. The feature's dtype must match the pool's temporal dtype. Here that is Datetime(time_unit='us'), as reported in the pool summary.

on= is not required.
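The dtype requirement can be pictured as a type check on each feature value before it becomes a T0. This is a minimal sketch under stated assumptions. Plain dicts stand in for TanaT's static data, and `t0_from_feature` with its `isinstance` check is hypothetical. The library may validate dtypes differently, for example at the Polars schema level.

```python
from datetime import datetime, timedelta


def t0_from_feature(static_rows, column, time_type=datetime):
    """Read T0 per id from a static column, rejecting values of the wrong type."""
    t0 = {}
    for row in static_rows:
        value = row.get(column)
        if value is not None and not isinstance(value, time_type):
            raise TypeError(
                f"{column!r} holds {type(value).__name__}, "
                f"expected {time_type.__name__}"
            )
        t0[row["id"]] = value  # a missing/null feature value gives a null T0
    return t0


# Toy data built like the tutorial's: one date every 45 days from 2008-01-01.
static_rows = [
    {"id": i + 1, "age": 40 + i,
     "admission_date": datetime(2008, 1, 1) + timedelta(days=i * 45)}
    for i in range(3)
]
t0 = t0_from_feature(static_rows, "admission_date")
print(t0[3])  # 2008-03-31 00:00:00

try:
    t0_from_feature(static_rows, "age")  # int column: wrong type for a datetime grid
except TypeError as exc:
    print(exc)  # 'age' holds int, expected datetime
```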

Add a custom datetime static feature

ids = tpool.unique_ids
admission_dates = [
    datetime(2008, 1, 1) + timedelta(days=i * 45) for i in range(len(ids))
]
date_df = pl.DataFrame(
    {
        "id": ids,
        "admission_date": pl.Series(admission_dates).cast(pl.Datetime("us")),
    }
)
tpool.add_static_features(date_df)

tpool.set_t0(feature="admission_date")
print("feature='admission_date'")
tpool.t0_data().head()
feature='admission_date'
   id  _T0_                        admissions_T0_NEAREST_RANK_  labs_T0_NEAREST_RANK_  phases_T0_NEAREST_RANK_
0  1   2008-01-01 00:00:00         2                            0                      2
1  2   2008-02-15 00:00:00         5                            1                      8
2  3   2008-03-31 00:00:00         1                            2                      2
3  4   2008-05-15 00:00:00         1                            3                      7
4  5   2008-06-29 00:00:00         1                            0                      8


Strategy 4 - query

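set_t0(query=expr, on="alias") picks the first matching row (use_first=True) or the last one (use_first=False) in the reference sub-pool. For interval and state pools, anchor= selects the start or end of that row.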

on= is required: the expression refers to columns in the target sub-pool.
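The query strategy can be reduced to a small model: filter the reference sequence, keep the first or last match, then apply anchor= to interval and state rows. In this sketch a Python callable stands in for the Polars expression, and rows are plain dicts. Both are simplifications, and `t0_from_query` is a hypothetical name rather than part of TanaT.

```python
from datetime import datetime


def t0_from_query(rows, predicate, use_first=True, anchor="start"):
    """Return the anchor time of the first (or last) row matching `predicate`."""
    matches = [row for row in rows if predicate(row)]
    if not matches:
        return None  # no match -> null T0
    row = matches[0] if use_first else matches[-1]
    if row.get("end") is None:  # event row: single timestamp, anchor ignored
        return row["start"]
    return row["end"] if anchor == "end" else row["start"]


admissions = [  # toy interval rows, already sorted by start
    {"start": datetime(2003, 2, 1), "end": datetime(2003, 2, 9), "status": "D"},
    {"start": datetime(2009, 10, 2), "end": datetime(2009, 10, 20), "status": "D"},
    {"start": datetime(2015, 6, 5), "end": datetime(2015, 6, 7), "status": "A"},
]
last_d_end = t0_from_query(
    admissions, lambda row: row["status"] == "D", use_first=False, anchor="end"
)
print(last_d_end)  # 2009-10-20 00:00:00
print(t0_from_query(admissions, lambda row: row["status"] == "Z"))  # None
```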

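# First lab with a given status. Series.unique() does not guarantee order,
# so name the category explicitly instead of relying on status_values[0].
status_values = labs_pool.temporal_data(fmt="polars")["status"].unique().to_list()
ref_status = "C"
assert ref_status in status_values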

tpool.set_t0(
    query=pl.col("status") == ref_status,
    use_first=True,
    on="labs",
)
print(f"query: first lab with status=={ref_status!r}, on='labs'")
tpool.t0_data().head()
/home/runner/work/TanaT/TanaT/src/tanat/trajectory/pool.py:851: UserWarning: 20 sequence(s) received _t0 = null (no valid row found): [1, 4, 7, 8, 13, 15, 17, 18, 19, 20, 21, 27, 29, 32, 33, 34, 35, 36, 37, 39]
  setter.compute_from_trajectory(self, on=on)
query: first lab with status=='C', on='labs'
   id  _T0_                        admissions_T0_NEAREST_RANK_  labs_T0_NEAREST_RANK_  phases_T0_NEAREST_RANK_
0  1   <NA>                        <NA>                         <NA>                   <NA>
1  2   2011-03-22 18:02:46.084166  5                            2                      8
2  3   2021-01-23 14:06:23.295715  2                            7                      2
3  4   <NA>                        <NA>                         <NA>                   <NA>
4  5   2011-09-27 10:30:06.092538  1                            1                      8


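# Last admission matching a status (interval pool, anchor='end').
# Series.unique() does not guarantee order, so name the category explicitly.
adm_status_values = (
    admissions_pool.temporal_data(fmt="polars")["status"].unique().to_list()
)
ref_adm_status = "D"
assert ref_adm_status in adm_status_values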

tpool.set_t0(
    query=pl.col("status") == ref_adm_status,
    anchor="end",
    use_first=False,
    on="admissions",
)
print(f"query: last admission with status=={ref_adm_status!r}, anchor='end'")
tpool.t0_data().head()
/home/runner/work/TanaT/TanaT/src/tanat/trajectory/pool.py:851: UserWarning: 9 sequence(s) received _t0 = null (no valid row found): [10, 11, 18, 21, 24, 34, 36, 39, 40]
  setter.compute_from_trajectory(self, on=on)
query: last admission with status=='D', anchor='end'
   id  _T0_                        admissions_T0_NEAREST_RANK_  labs_T0_NEAREST_RANK_  phases_T0_NEAREST_RANK_
0  1   2024-01-19 09:16:32.239342  9                            2                      2
1  2   2020-11-06 13:09:56.043064  8                            3                      8
2  3   2009-10-21 10:08:27.706922  2                            3                      2
3  4   2009-07-13 10:55:48.123113  2                            3                      7
4  5   2004-01-21 02:25:46.545494  1                            0                      8


Trajectory-level properties

After set_t0, every Trajectory exposes two read-only properties:

Property               Type                    Description
traj.t0                scalar | None           T0 for this trajectory
traj.t0_nearest_rank   dict[str, int | None]   Per-alias floor index: {"admissions": 0, "labs": 2, ...}

tpool.set_t0(position=0, anchor="start", on="admissions")

traj = tpool[tpool.unique_ids[0]]
print(f"id               : {traj.id_value}")
print(f"t0               : {traj.t0}")
print(f"t0_nearest_rank  : {traj.t0_nearest_rank}")
id               : 1
t0               : 2001-08-22 12:56:21.376599
t0_nearest_rank  : {'admissions': 0, 'labs': 0, 'phases': 0}
# Iterate over the first few trajectories
for traj in list(tpool)[:6]:
    ranks = ", ".join(f"{alias}={rank}" for alias, rank in traj.t0_nearest_rank.items())
    print(f"  id={traj.id_value!r:<6}  t0={str(traj.t0):<30}  ranks=[{ranks}]")
id=1       t0=2001-08-22 12:56:21.376599      ranks=[admissions=0, labs=0, phases=0]
id=2       t0=2000-08-01 14:08:17.503482      ranks=[admissions=0, labs=0, phases=0]
id=3       t0=2000-01-07 14:22:15.347103      ranks=[admissions=0, labs=0, phases=0]
id=4       t0=2000-02-23 04:33:57.263646      ranks=[admissions=0, labs=0, phases=0]
id=5       t0=2001-12-31 04:05:28.022580      ranks=[admissions=0, labs=0, phases=0]
id=6       t0=2002-02-07 05:50:58.019626      ranks=[admissions=0, labs=0, phases=0]

t0_data() column structure

tpool.t0_data() returns one row per trajectory with:

  • _T0_: the T0 value (identical across all sub-pools)

  • <alias>_T0_NEAREST_RANK_: per-alias floor index (each sub-pool gets its own column because the temporal grids differ)

df = tpool.t0_data(fmt="polars")
print("Columns:", df.columns)
print(f"Rows   : {len(df)} (one per trajectory)")
df.head()
Columns: ['id', '_T0_', 'admissions_T0_NEAREST_RANK_', 'labs_T0_NEAREST_RANK_', 'phases_T0_NEAREST_RANK_']
Rows   : 40 (one per trajectory)
shape: (5, 5)
id   _T0_                        admissions_T0_NEAREST_RANK_  labs_T0_NEAREST_RANK_  phases_T0_NEAREST_RANK_
i64  datetime[μs]                u32                          u32                    u32
1    2001-08-22 12:56:21.376599  0                            0                      0
2    2000-08-01 14:08:17.503482  0                            0                      0
3    2000-01-07 14:22:15.347103  0                            0                      0
4    2000-02-23 04:33:57.263646  0                            0                      0
5    2001-12-31 04:05:28.022580  0                            0                      0


Null handling

Trajectories with _T0_ = null arise in the same situations as for sequence pools (out-of-range position, missing dict key, null feature, no query match). The trajectory is not dropped; traj.t0 returns None and all nearest-rank values are None.
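Here is a minimal standard-library sketch of how a null T0 propagates through a t0_data()-style row. The trajectory keeps its row, _T0_ stays None, and every <alias>_T0_NEAREST_RANK_ column becomes None. The helper name `t0_row`, the dict-based grids, and the floor-index rule (last timestamp <= T0, clamped to 0) are assumptions for illustration, not TanaT internals.

```python
from bisect import bisect_right
from datetime import datetime


def t0_row(traj_id, t0, grids):
    """Build one t0_data()-style row: id, _T0_, and one rank column per alias."""
    row = {"id": traj_id, "_T0_": t0}
    for alias, grid in grids.items():
        if t0 is None or not grid:
            rank = None  # null T0 (or empty grid) -> null rank; the row is kept
        else:
            rank = max(bisect_right(grid, t0) - 1, 0)
        row[f"{alias}_T0_NEAREST_RANK_"] = rank
    return row


grids = {
    "admissions": [datetime(2001, 8, 22), datetime(2004, 1, 10)],
    "labs": [datetime(2002, 5, 3)],
}
rows = [
    t0_row(1, datetime(2003, 1, 1), grids),
    t0_row(2, None, grids),  # e.g. no query match for this trajectory
]
null_ids = [row["id"] for row in rows if row["_T0_"] is None]
print(rows[1])  # {'id': 2, '_T0_': None, 'admissions_T0_NEAREST_RANK_': None, 'labs_T0_NEAREST_RANK_': None}
print(null_ids)  # [2]
```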

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    tpool.set_t0(query=pl.col("value") > 1e9, on="labs")  # no match

null_trajs = [t.id_value for t in tpool if t.t0 is None]
print(f"{len(null_trajs)}/{len(tpool)} trajectory(ies) with t0 = None")
40/40 trajectory(ies) with t0 = None

T0 is shared across all children

A single tpool.set_t0(...) call is enough. Every object you retrieve from the pool (a sub-pool, a trajectory, or an individual sequence) returns the same t0 value for a given ID. t0_nearest_rank still varies, because each pool computes its floor index on its own temporal grid.

tpool.set_t0(position=0, anchor="start", on="admissions")

uid = tpool.unique_ids[0]
traj = tpool[uid]
seq = traj["labs"]

print(f"traj.t0                        : {traj.t0}")
print(f"traj['labs'].t0                : {seq.t0}")
print(f"same value                     : {seq.t0 == traj.t0}")
print(f"seq.t0_nearest_rank (labs grid): {seq.t0_nearest_rank}")
traj.t0                        : 2001-08-22 12:56:21.376599
traj['labs'].t0                : 2001-08-22 12:56:21.376599
same value                     : True
seq.t0_nearest_rank (labs grid): 0

Sub-pools expose the same T0 via t0_data()

labs_t0 = tpool.sequence_pools["labs"].t0_data(fmt="polars")
print("Columns:", labs_t0.columns)
labs_t0.head()
Columns: ['id', '_T0_', '_T0_NEAREST_RANK_']
shape: (5, 3)
id   _T0_                        _T0_NEAREST_RANK_
i64  datetime[μs]                u32
1    2001-08-22 12:56:21.376599  0
2    2000-08-01 14:08:17.503482  0
3    2000-01-07 14:22:15.347103  0
4    2000-02-23 04:33:57.263646  0
5    2001-12-31 04:05:28.022580  0

