Federated learning with Fed-BioMed
Scenario : Several hospitals hold patient observation records. They want to train a shared model without ever sharing their raw data.
This tutorial walks through the full pipeline in two stages:
Stage A · Local prototype
Generate a synthetic multi-centre dataset with FedPatientJourney
Build a TanaT event pool
Convert patient trajectories into temporal tensors
Train a tiny PyTorch model locally (just to validate the pipeline)
Stage B · Federated deployment
Move the preprocessing into a Fed-BioMed TrainingPlan
Register each centre's data on its own node
Run a federated Experiment with FedAverage
Stage A · Local prototype
1 · Imports
import polars as pl
import torch
import torch.nn as nn
from tanat_synthea import FedPatientJourney
from tanat.sequence.shortcuts import build_events
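2 · Generate synthetic patient journeys
FedPatientJourney.demo() configures 3 centres (US / FR / GB, each with a population of 50 patients) and runs Synthea to produce realistic synthetic patient records. Synthea also exports patients who die during the simulation, so a centre can end up with more patient records than its configured population (centre 0 below holds 60).
# Generates 3 centres (population of 50 each)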
fpj = FedPatientJourney.demo()
fpj.generate()
┌─ FedPatientJourney -- Synthea generation
│ Centers : [0, 1, 2]
│ Output : data/demo
│
│ ┌─ Massachusetts General Hospital
│ │ Country : US
│ │ Region : Massachusetts
│ │ City : (entire region)
│ │ Population : 50 patients
│ │ Seed : 42
│ │ Ages : 20-80
│ │
│ │ Step 1/2: Running Synthea
│ │
│ │ Step 2/2: Processing CSV files
│ │
│ └─ Done (24.24s)
│
│ ┌─ CHU de Lyon
│ │ Country : FR
│ │ Region : Auvergne-Rhone-Alpes
│ │ City : (entire region)
│ │ Population : 50 patients
│ │ Seed : 123
│ │ Ages : 40-85
│ │
│ │ Step 1/2: Running Synthea
│ │
│ │ Step 2/2: Processing CSV files
│ │
│ └─ Done (25.88s)
│
│ ┌─ University College London Hospital
│ │ Country : GB
│ │ Region : Greater London
│ │ City : (entire region)
│ │ Population : 50 patients
│ │ Seed : 777
│ │ Ages : 20-60
│ │
│ │ Step 1/2: Running Synthea
│ │
│ │ Step 2/2: Processing CSV files
│ │
│ └─ Done (18.18s)
│
└─ Done (datasets saved to data/demo · 68.30s)
print(fpj)
┌────────────────────────────────────────────────┐
│ FedPatientJourney Summary │
└────────────────────────────────────────────────┘
Overview
─────────────────────────
centers_total 3
centers_available 3 / 3
Centers (3)
─────────────────────────
• Massachusetts General Hospitalus/Massachusetts · pop=50 · ready
• CHU de Lyon fr/Auvergne-Rhone-Alpes · pop=50 · ready
• University College London Hospitalgb/Greater London · pop=50 · ready
3 · Inspect one centre
In a real federated setup each centre lives on a separate node. Here we work with centre 0 locally to validate the full pipeline before federating.
center = fpj.center(0)
static_data = center.static_data()
observation_data = center.temporal_data(record_types=["observation"])
print("Static data :", static_data.shape)
print("Observations :", observation_data.shape)
Static data : (60, 30)
Observations : (57708, 12)
4 · Build a TanaT event pool
build_events converts the raw observation table into a structured
EventSequencePool indexed by patient and time.
pool = build_events(
observation_data,
id_column="PATIENT",
time_column="DATE",
static_data=static_data,
store_name="observations",
)
┌─ Event SequenceStore
│
│ Step 1/4: Sorting & preparing data
│
│ Step 2/4: Building sequence index
│
│ Step 3/4: Writing entity, time index & static features
│
│ Step 4/4: Computing & writing metadata
│
└─ Done (60 sequences · 57,708 entities · 0.03s)
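Conceptually, the build step above sorts events by patient and time and records where each patient's sequence starts and stops in one contiguous table. The plain-Python sketch below illustrates that idea on a toy table. It is not TanaT's implementation: the `(patient, timestamp, code)` tuple layout and the `build_sequence_index` helper are hypothetical.

```python
from datetime import datetime
from itertools import groupby
from operator import itemgetter

# Toy observation rows: (patient id, timestamp, observation code).
rows = [
    ("p2", datetime(2017, 1, 23), "QALY"),
    ("p1", datetime(2016, 1, 23), "DALY"),
    ("p1", datetime(2016, 1, 23), "QALY"),
    ("p2", datetime(2016, 5, 2), "QOLS"),
    ("p1", datetime(2015, 3, 1), "QOLS"),
]


def build_sequence_index(rows):
    """Sort rows by (patient, time) and map each patient to a (start, stop) slice."""
    ordered = sorted(rows, key=itemgetter(0, 1))  # stable: ties keep input order
    index = {}
    position = 0
    for patient, group in groupby(ordered, key=itemgetter(0)):
        length = sum(1 for _ in group)
        index[patient] = (position, position + length)
        position += length
    return ordered, index


ordered, index = build_sequence_index(rows)
for patient, (start, stop) in index.items():
    print(patient, [code for _, _, code in ordered[start:stop]])
```

Storing each patient as a contiguous slice means one patient's history can be read without scanning the whole table.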
print(pool)
┌────────────────────────────────────────────────┐
│ EventSequencePool Summary │
└────────────────────────────────────────────────┘
Overview
─────────────────────────
Sequences 60
Store /home/runner/.tanat_workspace/building_pools_tutorial/observations
id_column PATIENT
Time Index
─────────────────────────
Type Datetime(time_unit='ns', time_zone='UTC') [1953-07-29 22:28:12+00:00 → 2026-06-11 07:23:51+00:00]
Columns ['DATE']
t0 position=0, anchor=None
Entity Features (10)
─────────────────────────
• CATEGORY String [len 4 → 14]
• CENTER_ID Numerical [0 → 0]
• CENTER_NAME String [len 30 → 30]
• CODE String [len 4 → 9]
• DESCRIPTION String [len 4 → 138]
• ENCOUNTER String [len 36 → 36]
• RECORD_TYPE String [len 11 → 11]
• TYPE String [len 4 → 7]
• UNITS String [len 1 → 16]
• VALUE String [len 1 → 101]
Static Features (29)
─────────────────────────
• ADDRESS String [len 13 → 29]
• BIRTHDATE Datetime(time_unit='ns', time_zone='UTC') [1947-09-07 00:00:00+00:00 → 2003-09-28 00:00:00+00:00]
• BIRTHPLACE String [len 18 → 36]
• CENTER_ID Numerical [0 → 0]
• CENTER_NAME String [len 30 → 30]
• CITY String [len 5 → 12]
• COUNTY String [len 12 → 17]
• DEATHDATE Datetime(time_unit='ns', time_zone='UTC') [1962-05-20 00:00:00+00:00 → 2025-11-05 00:00:00+00:00]
• DRIVERS String [len 9 → 9]
• ETHNICITY String [len 8 → 11]
... and 19 more
Temporal data (one row per observation event)
print(pool.temporal_data()[["PATIENT", "DATE", "CODE"]].head())
PATIENT DATE CODE
0 080b069b-5108-46b6-ecef-6aacd3b9ef3f 2016-01-23 12:33:17+00:00 QALY
1 080b069b-5108-46b6-ecef-6aacd3b9ef3f 2016-01-23 12:33:17+00:00 DALY
2 080b069b-5108-46b6-ecef-6aacd3b9ef3f 2016-01-23 12:33:17+00:00 QOLS
3 080b069b-5108-46b6-ecef-6aacd3b9ef3f 2017-01-23 12:33:17+00:00 QALY
4 080b069b-5108-46b6-ecef-6aacd3b9ef3f 2017-01-23 12:33:17+00:00 QOLS
Static data (one row per patient)
print(pool.static_data()[["PATIENT", "GENDER", "BIRTHDATE"]].head())
PATIENT GENDER BIRTHDATE
0 080b069b-5108-46b6-ecef-6aacd3b9ef3f M 1966-01-23 00:00:00+00:00
1 08917fa9-c6ca-7569-27d5-b655cb688998 F 1997-01-18 00:00:00+00:00
2 12834141-3f83-d76b-7d7c-151862dd1150 M 1953-07-29 00:00:00+00:00
3 26aa011e-0d75-d1a3-3c5a-560a9450cc9d M 1981-02-26 00:00:00+00:00
4 2c167999-289d-95fa-5c27-7d8b8d35348c M 1983-04-21 00:00:00+00:00
5 · Convert events into a temporal tensor
to_tensor() projects each patient's history onto a shared time axis split into fixed-width bins (set by bin_size) and returns a 3-tuple (tensor, patient_ids, feature_names).
Tensor dimensions:
N : patients
M : time bins (one per bin_size interval, optionally capped with max_bins)
K : one-hot encoded observation codes
We keep only the top 30 most frequent codes before converting to a tensor. This keeps the tutorial manageable and avoids an overly sparse OHE vocabulary while preserving the most common observation patterns.
TOP_K = 30
top_codes = (
pool.temporal_data()
.groupby("CODE")
.size()
.sort_values(ascending=False)
.head(TOP_K)
.index.tolist()
)
print(f"Retaining {TOP_K} codes out of {pool.temporal_data()['CODE'].nunique()}")
print("Top codes:", top_codes[:10], "...")
Retaining 30 codes out of 236
Top codes: ['72514-3', '29463-7', '8480-6', '8462-4', '9279-1', '8867-4', '72166-2', '8302-2', '39156-5', '71802-3'] ...
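The pandas chain above is a frequency count followed by a cut-off. The sketch below expresses the same selection with `collections.Counter`. It also computes the share of events that the retained codes cover, which is the quantity `filter_entities` reports next. The codes and counts are made up, and `top_k_codes` is a hypothetical helper. `Counter.most_common` breaks ties by first occurrence, so when several codes share the count at the cut-off, the exact top-K set depends on row order.

```python
from collections import Counter


def top_k_codes(codes, k):
    """Return the k most frequent codes and the fraction of events they cover."""
    counts = Counter(codes)
    top = [code for code, _ in counts.most_common(k)]
    kept = sum(counts[code] for code in top)
    return top, kept / len(codes)


# Toy event stream: one entry per observation event.
codes = ["8480-6"] * 5 + ["8462-4"] * 4 + ["29463-7"] * 3 + ["QALY"] * 2 + ["DALY"]
top, coverage = top_k_codes(codes, k=2)
print(top)
print(f"{coverage:.1%} of events kept")
```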
from tanat.criterion import EntityCriterion
pool.filter_entities(
EntityCriterion(query=pl.col("CODE").is_in(top_codes)),
inplace=True,
)
[filter_entities] EntityCriterion → 25,267 / 57,708 entities (43.8%) · 0 IDs affected
EventSequencePool(n=60, entity_features=10, static_features=29, store='/home/runner/.tanat_workspace/building_pools_tutorial/observations')
BIN_SIZE = "360D" # one bin = 360 days
pool.cast_features({"CODE": pl.Categorical})
tensor, patient_ids, feature_names = pool.to_tensor(
features="CODE",
bin_size=BIN_SIZE,
fill_value=0,
ohe=True,
)
N, M, K = tensor.shape
print(f"Tensor shape : {tensor.shape}")
print(f"Patients : {N}")
print(f"Time bins : {M}")
print(f"Features : {K}")
print(f"Sparsity : {(tensor == 0).mean():.1%} zero cells")
print(f"Non-zero cells : {(tensor != 0).sum()}")
print(f"Patients with events : {(tensor.sum(axis=(1, 2)) != 0).sum()} / {N}")
Tensor shape : (60, 74, 30)
Patients : 60
Time bins : 74
Features : 30
Sparsity : 99.6% zero cells
Non-zero cells : 501
Patients with events : 60 / 60
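To make the (N, M, K) layout concrete, the sketch below builds a one-hot tensor by hand with NumPy. Each event lands in bin `floor((t - t_min) / bin_size)` and sets the cell for its code. This illustrates the idea rather than reproducing TanaT's `to_tensor()` internals. In particular, the anchoring rule (bins counted from the earliest timestamp) and the helpers `n_bins` and `one_hot_tensor` are assumptions. Under that rule, the time-index range printed in the pool summary (1953-07-29 22:28:12 to 2026-06-11 07:23:51) spans about 26,614 days. That gives 74 bins of 360 days, consistent with M = 74.

```python
from datetime import datetime, timedelta

import numpy as np


def n_bins(t_min, t_max, bin_size):
    """Bins needed so that the latest timestamp falls inside the last bin."""
    return (t_max - t_min) // bin_size + 1


def one_hot_tensor(events, patients, vocab, t_min, bin_size, m):
    """Turn (patient, timestamp, code) events into an (N, M, K) 0/1 array."""
    row = {p: i for i, p in enumerate(patients)}
    col = {c: j for j, c in enumerate(vocab)}
    tensor = np.zeros((len(patients), m, len(vocab)), dtype=np.int8)
    for patient, t, code in events:
        if code not in col:
            continue  # codes outside the vocabulary are dropped
        b = (t - t_min) // bin_size  # index of the bin containing t
        tensor[row[patient], b, col[code]] = 1
    return tensor


bin_size = timedelta(days=360)
events = [
    ("p1", datetime(2015, 1, 10), "8480-6"),
    ("p1", datetime(2016, 3, 1), "8462-4"),
    ("p2", datetime(2015, 6, 1), "8480-6"),
    ("p2", datetime(2015, 6, 2), "RARE"),  # not in the vocabulary
]
vocab = ["8480-6", "8462-4"]
times = [t for _, t, _ in events]
t_min = min(times)
m = n_bins(t_min, max(times), bin_size)
x = one_hot_tensor(events, ["p1", "p2"], vocab, t_min, bin_size, m)
print(x.shape, f"zero cells: {(x == 0).mean():.1%}")

# Same bin-count rule applied to the pool's printed time-index range.
m_pool = n_bins(datetime(1953, 7, 29, 22, 28, 12), datetime(2026, 6, 11, 7, 23, 51), bin_size)
print(m_pool)
```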
6 · Build a learning-ready dataset
We collapse the time axis by averaging across bins, a quick way to turn the 3D tensor into a flat (N, K) feature matrix.
This step is intentionally simplified: a real model could consume the full (N, M, K) tensor.
X = tensor.mean(axis=1) # (N, M, K) → (N, K)
# Binary target: male = 1, female = 0
y_df = pool.apply(
exprs=(pl.col("GENDER") == "M").cast(pl.Int8).alias("target"),
is_static=True,
)
y = y_df["target"].to_numpy()
print("X shape :", X.shape)
print("y shape :", y.shape)
print(f"Class balance : {y.mean():.0%} positive")
X shape : (60, 30)
y shape : (60,)
Class balance : 57% positive
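Averaging over all M bins divides each code's total by M, empty bins included. With 74 mostly empty bins, the resulting features are therefore small fractions. Summing (how many bins contain the code) and taking the maximum (whether the code was ever observed) are common alternatives. The toy NumPy comparison below is a sketch of how the three reductions differ on a 0/1 tensor. It is not a recommendation from TanaT.

```python
import numpy as np

# Toy one-hot tensor: 2 patients, 4 time bins, 3 codes.
tensor = np.zeros((2, 4, 3), dtype=np.float32)
tensor[0, 0, 0] = 1  # patient 0: code 0 in bins 0 and 2
tensor[0, 2, 0] = 1
tensor[1, 3, 2] = 1  # patient 1: code 2 in bin 3 only

mean_feat = tensor.mean(axis=1)   # share of bins containing each code
count_feat = tensor.sum(axis=1)   # number of bins containing each code
ever_feat = tensor.max(axis=1)    # 1 if the code appears at least once

print(mean_feat)
print(count_feat)
print(ever_feat)
```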
7 · Train a local model
A small MLP (one hidden layer of 32 units), trained full-batch for 10 epochs, only to verify that the tensor feeds cleanly into a standard PyTorch loop.
X_t = torch.tensor(X, dtype=torch.float32)
y_t = torch.tensor(y, dtype=torch.float32)
model = nn.Sequential(
nn.Linear(K, 32),
nn.ReLU(),
nn.Linear(32, 1),
nn.Sigmoid(),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()
for epoch in range(10):
    optimizer.zero_grad()
    pred = model(X_t).squeeze()
    loss = loss_fn(pred, y_t)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1:2d} loss={loss.item():.4f}")
Epoch 1 loss=0.6881
Epoch 2 loss=0.6880
Epoch 3 loss=0.6878
Epoch 4 loss=0.6877
Epoch 5 loss=0.6875
Epoch 6 loss=0.6874
Epoch 7 loss=0.6873
Epoch 8 loss=0.6871
Epoch 9 loss=0.6870
Epoch 10 loss=0.6869
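A useful reference point for these numbers is the loss of a model that ignores X and always predicts the training positive rate. Its BCE equals the binary entropy of the class prior. Only 34 positives out of 60 is consistent with the printed 57% balance, and that prior gives about 0.684 nats. The final loss of 0.6869 is therefore still marginally above the prior-only baseline, which is unsurprising for a 10-epoch smoke test at lr=1e-3. The helpers `prior_bce` and `majority_accuracy` below are hypothetical, and the label list is reconstructed from the class balance rather than read from the pool.

```python
import math


def prior_bce(labels):
    """BCE (in nats) of a constant predictor that always outputs the positive rate."""
    p = sum(labels) / len(labels)
    if p in (0.0, 1.0):
        return 0.0
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))


def majority_accuracy(labels):
    """Accuracy of always predicting the most frequent class."""
    p = sum(labels) / len(labels)
    return max(p, 1 - p)


# 34 positives out of 60 matches the 57% class balance printed above.
labels = [1] * 34 + [0] * 26
print(f"prior-only BCE    : {prior_bce(labels):.4f}")
print(f"majority accuracy : {majority_accuracy(labels):.1%}")
```

Any model worth keeping should beat both numbers on held-out patients, not just on the training set.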
Stage B · Federated deployment
This stage shows how the same preprocessing pipeline is executed on each node locally, while only model weights are shared.
Tip
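Key concept: what travels on the wire? In Fed-BioMed the researcher sends the training plan and training arguments to the nodes, and the nodes send back updated model parameters together with metadata such as their sample counts. Raw data never leaves the hospital.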
┌───────────────────────────────────────┐
│ RESEARCHER │
│ 1) defines the model │
│ 2) aggregates weights (FedAvg) │
└───────────────────────────────────────┘
▲ ▲ ▲
│ │ │
weights only: raw data stays local
│ │ │
┌─────────┘ ┌───┘ └──┐
▼ ▼ ▼
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
│ Node 0 · US │ │ Node 1 · FR │ │ Node 2 · GB │
│ local data │ │ local data │ │ local data │
│ + │ │ + │ │ + │
│local training │ │local training │ │local training │
└───────────────┘ └───────────────┘ └───────────────┘
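The aggregation step in the diagram is FedAvg. The researcher replaces the global parameters with an average of the node updates, weighted by each node's number of training samples. The sketch below shows that rule on plain NumPy arrays. The parameter names, sample counts and the `fed_avg` helper are made up for illustration; in the tutorial, Fed-BioMed's `FedAverage` performs this step on the real model state.

```python
import numpy as np


def fed_avg(updates):
    """updates: list of (params dict, n_samples) -> sample-weighted average params."""
    total = sum(n for _, n in updates)
    keys = updates[0][0].keys()
    return {
        k: sum(params[k] * (n / total) for params, n in updates)
        for k in keys
    }


# Three hypothetical nodes with different local sample counts.
node_updates = [
    ({"w": np.array([1.0, 0.0]), "b": np.array([0.0])}, 60),
    ({"w": np.array([0.0, 1.0]), "b": np.array([1.0])}, 20),
    ({"w": np.array([1.0, 1.0]), "b": np.array([2.0])}, 20),
]
global_params = fed_avg(node_updates)
print(global_params)
```

Weighting by sample count keeps a small centre from pulling the global model as hard as a large one.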
Note
Prerequisites:
pip install fedbiomed tanat tanat-synthea
# one terminal per node
fedbiomed node -p fbm-node-0 start
fedbiomed node -p fbm-node-1 start
fedbiomed node -p fbm-node-2 start
# researcher (this notebook)
fedbiomed researcher start
8 · Register data on each node
On each node terminal, declare the local data with a shared tag. The tag is how the Researcher discovers which nodes hold compatible data.
# Node 0 terminal
fedbiomed node -p fbm-node-0 dataset add \
--name "tanat_observations" \
--tags tanat-obs \
--data-type csv \
--path /path/to/center_0/
# Node 1 terminal (same tag, different path)
fedbiomed node -p fbm-node-1 dataset add \
--name "tanat_observations" \
--tags tanat-obs \
--data-type csv \
--path /path/to/center_1/
# Node 2 terminal
fedbiomed node -p fbm-node-2 dataset add \
--name "tanat_observations" \
--tags tanat-obs \
--data-type csv \
--path /path/to/center_2/
9 · Define the TrainingPlan
A TorchTrainingPlan wraps:
init_model: the PyTorch model (identical to Stage A)
init_optimizer: the optimiser
init_dependencies: imports that must be available on each node
training_data: the TanaT preprocessing pipeline from Stage A
training_step: the loss computation
Here the new part is simple: each node reads its own local data, applies the same preprocessing, and returns a training-ready dataset. Raw data never leaves the node; only model weights are aggregated.
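Once preprocessing runs independently on each node, one detail matters. FedAvg averages parameters column by column, so every node must map the same code to the same input column. If each node derived its one-hot vocabulary from its own data, the columns would match neither in number nor in meaning. The sketch below shows the fix, assuming a vocabulary agreed in advance (here a hypothetical `SHARED_VOCAB`) and plain Python lists instead of TanaT tensors. `encode` is also a hypothetical helper.

```python
SHARED_VOCAB = ["8480-6", "8462-4", "29463-7"]  # hypothetical agreed code list


def encode(patient_codes, vocab):
    """Map a patient's set of codes onto fixed columns of the shared vocabulary."""
    col = {code: j for j, code in enumerate(vocab)}
    row = [0] * len(vocab)
    for code in patient_codes:
        if code in col:  # codes unknown to the vocabulary are dropped
            row[col[code]] = 1
    return row


# Two nodes observe different code sets, yet produce aligned feature vectors.
node_us = encode({"29463-7", "8480-6"}, SHARED_VOCAB)
node_fr = encode({"8462-4", "LOCAL-ONLY"}, SHARED_VOCAB)
print(node_us)
print(node_fr)
```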
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from tanat.criterion import EntityCriterion
from tanat.sequence.shortcuts import build_events
class TanaTObsTrainingPlan(TorchTrainingPlan):
    # ------------------------------------------------------------------
    # Model (same MLP as Stage A)
    # ------------------------------------------------------------------
    def init_model(self, model_args: dict):
        in_dim = model_args["in_features"]
        class MLP(nn.Module):
            def __init__(self, in_dim):
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(in_dim, 32),
                    nn.ReLU(),
                    nn.Linear(32, 1),
                    nn.Sigmoid(),
                )
            def forward(self, x):
                return self.net(x).squeeze(-1)
        return MLP(in_dim)
    # ------------------------------------------------------------------
    # Optimiser (the parameters belong to the wrapped model)
    # ------------------------------------------------------------------
    def init_optimizer(self, optimizer_args: dict):
        return torch.optim.Adam(
            self.model().parameters(),
            lr=optimizer_args.get("lr", 1e-3),
        )
    # ------------------------------------------------------------------
    # Node-side dependencies (every import used by the methods below)
    # ------------------------------------------------------------------
    def init_dependencies(self):
        return [
            "import polars as pl",
            "from tanat.criterion import EntityCriterion",
            "from tanat.sequence.shortcuts import build_events",
        ]
    # ------------------------------------------------------------------
    # Data loading (runs locally on each node: TanaT pipeline from Stage A)
    # ------------------------------------------------------------------
    def training_data(self):
        # self.dataset_path is the path registered on this node (a string,
        # not a centre object). The CSV file names are hypothetical: adapt
        # them to the files your centre actually exports.
        path = self.dataset_path
        static_data = pl.read_csv(f"{path}/static.csv", try_parse_dates=True)
        observation_data = pl.read_csv(f"{path}/observations.csv", try_parse_dates=True)
        pool = build_events(
            observation_data,
            id_column="PATIENT",
            time_column="DATE",
            static_data=static_data,
            store_name="observations",
        )
        # Shared code vocabulary sent by the researcher, so that every node
        # builds the same K columns as Stage A (assumes model_args() is
        # exposed by your Fed-BioMed version).
        codes = self.model_args()["codes"]
        pool.filter_entities(
            EntityCriterion(query=pl.col("CODE").is_in(codes)),
            inplace=True,
        )
        pool.cast_features({"CODE": pl.Categorical})
        tensor, _, _ = pool.to_tensor(
            features="CODE",
            bin_size="360D",  # same binning as Stage A
            fill_value=0,
            ohe=True,
        )
        X = tensor.mean(axis=1)  # (N, M, K) → (N, K)
        # Fail loudly if this node produced a different number of code columns
        # (assumes to_tensor orders OHE columns consistently across nodes).
        if X.shape[1] != len(codes):
            raise ValueError(f"expected {len(codes)} code columns, got {X.shape[1]}")
        y_df = pool.apply(
            exprs=(pl.col("GENDER") == "M").cast(pl.Int8).alias("target"),
            is_static=True,
        )
        y = y_df["target"].to_numpy()
        # Pass NumPy arrays, cast to float32 to match the model weights.
        return DataManager(
            dataset=X.astype("float32"),
            target=y.astype("float32"),
        )
    # ------------------------------------------------------------------
    # Loss
    # ------------------------------------------------------------------
    def training_step(self, data, target):
        return F.binary_cross_entropy(self.model().forward(data), target)
10 · Run the federated experiment
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage
exp = Experiment(
    tags=["tanat-obs"],
    # K and top_codes come from Stage A; the shared code list fixes K on every node
    model_args={"in_features": K, "codes": top_codes},
    training_plan_class=TanaTObsTrainingPlan,
    training_args={
        "loader_args": {"batch_size": 16},
        "optimizer_args": {"lr": 1e-3},
        "epochs": 5,
    },
    round_limit=10,
    aggregator=FedAverage(),
)
exp.run()
# Inspect results
final_model = exp.training_plan().model()
final_model.eval()
# Loss per round (averaged over nodes). Depending on the Fed-BioMed version,
# a training reply may not carry a "loss" entry; such rounds are skipped.
for round_id, replies in exp.training_replies().items():
    losses = [r["loss"] for r in replies.values() if "loss" in r]
    if losses:
        print(f"Round {round_id:2d} avg loss = {sum(losses) / len(losses):.4f}")