diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index 20cd592..a07ea4f 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -402,14 +402,14 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓
### 9.6 Drug Switching Sankey chart (Tab 5)
-- [ ] Create `dash_app/callbacks/sankey.py`:
+- [x] Create `dash_app/callbacks/sankey.py`:
- Build Plotly Sankey diagram from `get_drug_transitions()` data
- Left nodes = 1st-line drugs, middle = 2nd-line, right = 3rd-line
- Link width = patient count, colour by drug or directorate
- Uses `parse_pathway_drugs()` to extract drug transitions from `ids` column
-- [ ] Create figure function in `src/visualization/`
-- [ ] Wire into tab switching
-- **Checkpoint**: Sankey tab renders real drug transition flows
+- [x] Create figure function in `src/visualization/`
+- [x] Wire into tab switching
+- **Checkpoint**: Sankey tab renders real drug transition flows ✓
### 9.7 Dosing Interval Comparison chart (Tab 6)
- [ ] Create `dash_app/callbacks/dosing.py`:
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index 5d6a088..1b56d59 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -158,6 +158,31 @@ def _render_cost_waterfall(app_state, title):
return create_cost_waterfall_figure(data, title)
+def _render_sankey(app_state, title):
+ """Build the Sankey diagram from current filter state."""
+ from dash_app.data.queries import get_drug_transitions
+ from visualization.plotly_generator import create_sankey_figure
+
+ filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
+ chart_type = (app_state or {}).get("chart_type", "directory")
+
+ selected_dirs = (app_state or {}).get("selected_directorates") or []
+ selected_trusts = (app_state or {}).get("selected_trusts") or []
+ directory = selected_dirs[0] if len(selected_dirs) == 1 else None
+ trust = selected_trusts[0] if len(selected_trusts) == 1 else None
+
+ try:
+ data = get_drug_transitions(filter_id, chart_type, directory, trust)
+ except Exception:
+ log.exception("Failed to load drug transition data")
+ return _empty_figure("Failed to load drug transition data.")
+
+ if not data.get("nodes") or not data.get("links"):
+ return _empty_figure("No drug switching data available.\nTry adjusting your filters.")
+
+ return create_sankey_figure(data, title)
+
+
def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -280,6 +305,9 @@ def register_chart_callbacks(app):
elif active_tab == "cost-waterfall":
fig = _render_cost_waterfall(app_state, title)
+ elif active_tab == "sankey":
+ fig = _render_sankey(app_state, title)
+
else:
# Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index 8be52cd..dee7d8d 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -703,6 +703,137 @@ def create_cost_waterfall_figure(
return fig
+def create_sankey_figure(
+ data: dict,
+ title: str = "",
+) -> go.Figure:
+ """Create Sankey diagram showing drug switching flows between treatment lines.
+
+ Args:
+ data: Dict from get_drug_transitions() with keys:
+ nodes: [{name}] — drug names with ordinal suffixes (e.g., "ADALIMUMAB (1st)")
+ links: [{source_idx, target_idx, patients}] — transitions between drugs
+ title: Chart title suffix (filter description).
+
+ Returns:
+ Plotly Figure with Sankey diagram.
+ """
+ import re
+
+ nodes = data.get("nodes", [])
+ links = data.get("links", [])
+
+ if not nodes or not links:
+ return go.Figure()
+
+ # NHS colour palette — one colour per unique base drug name
+ nhs_colours = [
+ "#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
+ "#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
+ "#8A1538", "#330072", "#DA291C", "#00A499", "#425563",
+ ]
+
+ # Extract base drug name (strip ordinal suffix) for colour consistency
+ def base_drug(name: str) -> str:
+ return re.sub(r"\s*\(\d+(?:st|nd|rd|th)\)\s*$", "", name)
+
+ unique_bases = []
+ for n in nodes:
+ b = base_drug(n["name"])
+ if b not in unique_bases:
+ unique_bases.append(b)
+ base_colour_map = {b: nhs_colours[i % len(nhs_colours)] for i, b in enumerate(unique_bases)}
+
+ # Node colours — same drug gets same colour regardless of treatment line
+ node_colours = [base_colour_map[base_drug(n["name"])] for n in nodes]
+
+ # Node labels — format nicely
+ node_labels = [n["name"] for n in nodes]
+
+ # Link colours — use source node colour at 40% opacity for visual clarity
+ def hex_to_rgba(hex_colour: str, alpha: float) -> str:
+ h = hex_colour.lstrip("#")
+ r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
+ return f"rgba({r},{g},{b},{alpha})"
+
+ link_colours = [
+ hex_to_rgba(node_colours[link["source_idx"]], 0.35)
+ for link in links
+ ]
+
+ # Build hover text for links
+ link_hovers = [
+ f"{node_labels[link['source_idx']]} → {node_labels[link['target_idx']]}"
+ f"
Patients: {link['patients']:,}"
+ for link in links
+ ]
+
+ # Compute total patients per node for node hover
+ node_patients = [0] * len(nodes)
+ for link in links:
+ node_patients[link["source_idx"]] += link["patients"]
+ # For terminal nodes (no outgoing), use incoming total
+ node_incoming = [0] * len(nodes)
+ for link in links:
+ node_incoming[link["target_idx"]] += link["patients"]
+ node_hover = []
+ for i, n in enumerate(nodes):
+ out_p = node_patients[i]
+ in_p = node_incoming[i]
+ total = max(out_p, in_p)
+ node_hover.append(f"{n['name']}
Patients: {total:,}")
+
+ fig = go.Figure(
+ go.Sankey(
+ arrangement="snap",
+ node=dict(
+ pad=20,
+ thickness=25,
+ line=dict(color="#FFFFFF", width=1),
+ label=node_labels,
+ color=node_colours,
+ customdata=node_hover,
+ hovertemplate="%{customdata}",
+ ),
+ link=dict(
+ source=[link["source_idx"] for link in links],
+ target=[link["target_idx"] for link in links],
+ value=[link["patients"] for link in links],
+ color=link_colours,
+ customdata=link_hovers,
+ hovertemplate="%{customdata}",
+ ),
+ )
+ )
+
+ chart_title = "Drug Switching Flows"
+ if title:
+ chart_title = f"{chart_title} — {title}"
+
+ fig.update_layout(
+ title=dict(
+ text=chart_title,
+ font=dict(
+ family="Source Sans 3, system-ui, sans-serif",
+ size=18,
+ color="#003087",
+ ),
+ x=0.5,
+ xanchor="center",
+ ),
+ font=dict(
+ family="Source Sans 3, system-ui, sans-serif",
+ size=12,
+ ),
+ paper_bgcolor="rgba(0,0,0,0)",
+ plot_bgcolor="rgba(0,0,0,0)",
+ margin=dict(t=60, l=30, r=30, b=30),
+ height=max(500, len(unique_bases) * 35 + 200),
+ )
+
+ return fig
+
+
def save_figure_html(
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
) -> str: