feat: add Drug Switching Sankey diagram (Task 9.6)

This commit is contained in:
Andrew Charlwood
2026-02-06 19:50:43 +00:00
parent ba39f94c32
commit 4ffcdf4268
3 changed files with 163 additions and 4 deletions
+4 -4
View File
@@ -402,14 +402,14 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓ - **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓
### 9.6 Drug Switching Sankey chart (Tab 5) ### 9.6 Drug Switching Sankey chart (Tab 5)
- [ ] Create `dash_app/callbacks/sankey.py`: - [x] Create `dash_app/callbacks/sankey.py`:
- Build Plotly Sankey diagram from `get_drug_transitions()` data - Build Plotly Sankey diagram from `get_drug_transitions()` data
- Left nodes = 1st-line drugs, middle = 2nd-line, right = 3rd-line - Left nodes = 1st-line drugs, middle = 2nd-line, right = 3rd-line
- Link width = patient count, colour by drug or directorate - Link width = patient count, colour by drug or directorate
- Uses `parse_pathway_drugs()` to extract drug transitions from `ids` column - Uses `parse_pathway_drugs()` to extract drug transitions from `ids` column
- [ ] Create figure function in `src/visualization/` - [x] Create figure function in `src/visualization/`
- [ ] Wire into tab switching - [x] Wire into tab switching
- **Checkpoint**: Sankey tab renders real drug transition flows - **Checkpoint**: Sankey tab renders real drug transition flows
### 9.7 Dosing Interval Comparison chart (Tab 6) ### 9.7 Dosing Interval Comparison chart (Tab 6)
- [ ] Create `dash_app/callbacks/dosing.py`: - [ ] Create `dash_app/callbacks/dosing.py`:
+28
View File
@@ -158,6 +158,31 @@ def _render_cost_waterfall(app_state, title):
return create_cost_waterfall_figure(data, title) return create_cost_waterfall_figure(data, title)
def _render_sankey(app_state, title):
"""Build the Sankey diagram from current filter state."""
from dash_app.data.queries import get_drug_transitions
from visualization.plotly_generator import create_sankey_figure
filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
chart_type = (app_state or {}).get("chart_type", "directory")
selected_dirs = (app_state or {}).get("selected_directorates") or []
selected_trusts = (app_state or {}).get("selected_trusts") or []
directory = selected_dirs[0] if len(selected_dirs) == 1 else None
trust = selected_trusts[0] if len(selected_trusts) == 1 else None
try:
data = get_drug_transitions(filter_id, chart_type, directory, trust)
except Exception:
log.exception("Failed to load drug transition data")
return _empty_figure("Failed to load drug transition data.")
if not data.get("nodes") or not data.get("links"):
return _empty_figure("No drug switching data available.\nTry adjusting your filters.")
return create_sankey_figure(data, title)
def register_chart_callbacks(app): def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks.""" """Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -280,6 +305,9 @@ def register_chart_callbacks(app):
elif active_tab == "cost-waterfall": elif active_tab == "cost-waterfall":
fig = _render_cost_waterfall(app_state, title) fig = _render_cost_waterfall(app_state, title)
elif active_tab == "sankey":
fig = _render_sankey(app_state, title)
else: else:
# Placeholder for charts not yet implemented # Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
+131
View File
@@ -703,6 +703,137 @@ def create_cost_waterfall_figure(
return fig return fig
def create_sankey_figure(
data: dict,
title: str = "",
) -> go.Figure:
"""Create Sankey diagram showing drug switching flows between treatment lines.
Args:
data: Dict from get_drug_transitions() with keys:
nodes: [{name}] — drug names with ordinal suffixes (e.g., "ADALIMUMAB (1st)")
links: [{source_idx, target_idx, patients}] — transitions between drugs
title: Chart title suffix (filter description).
Returns:
Plotly Figure with Sankey diagram.
"""
import re
nodes = data.get("nodes", [])
links = data.get("links", [])
if not nodes or not links:
return go.Figure()
# NHS colour palette — one colour per unique base drug name
nhs_colours = [
"#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
"#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
"#8A1538", "#330072", "#DA291C", "#00A499", "#425563",
]
# Extract base drug name (strip ordinal suffix) for colour consistency
def base_drug(name: str) -> str:
return re.sub(r"\s*\(\d+(?:st|nd|rd|th)\)\s*$", "", name)
unique_bases = []
for n in nodes:
b = base_drug(n["name"])
if b not in unique_bases:
unique_bases.append(b)
base_colour_map = {b: nhs_colours[i % len(nhs_colours)] for i, b in enumerate(unique_bases)}
# Node colours — same drug gets same colour regardless of treatment line
node_colours = [base_colour_map[base_drug(n["name"])] for n in nodes]
# Node labels — format nicely
node_labels = [n["name"] for n in nodes]
# Link colours — use source node colour at 40% opacity for visual clarity
def hex_to_rgba(hex_colour: str, alpha: float) -> str:
h = hex_colour.lstrip("#")
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
return f"rgba({r},{g},{b},{alpha})"
link_colours = [
hex_to_rgba(node_colours[link["source_idx"]], 0.35)
for link in links
]
# Build hover text for links
link_hovers = [
f"{node_labels[link['source_idx']]}{node_labels[link['target_idx']]}"
f"<br>Patients: {link['patients']:,}"
for link in links
]
# Compute total patients per node for node hover
node_patients = [0] * len(nodes)
for link in links:
node_patients[link["source_idx"]] += link["patients"]
# For terminal nodes (no outgoing), use incoming total
node_incoming = [0] * len(nodes)
for link in links:
node_incoming[link["target_idx"]] += link["patients"]
node_hover = []
for i, n in enumerate(nodes):
out_p = node_patients[i]
in_p = node_incoming[i]
total = max(out_p, in_p)
node_hover.append(f"<b>{n['name']}</b><br>Patients: {total:,}")
fig = go.Figure(
go.Sankey(
arrangement="snap",
node=dict(
pad=20,
thickness=25,
line=dict(color="#FFFFFF", width=1),
label=node_labels,
color=node_colours,
customdata=node_hover,
hovertemplate="%{customdata}<extra></extra>",
),
link=dict(
source=[link["source_idx"] for link in links],
target=[link["target_idx"] for link in links],
value=[link["patients"] for link in links],
color=link_colours,
customdata=link_hovers,
hovertemplate="%{customdata}<extra></extra>",
),
)
)
chart_title = "Drug Switching Flows"
if title:
chart_title = f"{chart_title}{title}"
fig.update_layout(
title=dict(
text=chart_title,
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=18,
color="#003087",
),
x=0.5,
xanchor="center",
),
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=12,
),
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
margin=dict(t=60, l=30, r=30, b=30),
height=max(500, len(unique_bases) * 35 + 200),
)
return fig
def save_figure_html( def save_figure_html(
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
) -> str: ) -> str: