feat: add Drug Switching Sankey diagram (Task 9.6)
This commit is contained in:
@@ -402,14 +402,14 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
|
||||
- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓
|
||||
|
||||
### 9.6 Drug Switching Sankey chart (Tab 5)
|
||||
- [ ] Create `dash_app/callbacks/sankey.py`:
|
||||
- [x] Create `dash_app/callbacks/sankey.py`:
|
||||
- Build Plotly Sankey diagram from `get_drug_transitions()` data
|
||||
- Left nodes = 1st-line drugs, middle = 2nd-line, right = 3rd-line
|
||||
- Link width = patient count, colour by drug or directorate
|
||||
- Uses `parse_pathway_drugs()` to extract drug transitions from `ids` column
|
||||
- [ ] Create figure function in `src/visualization/`
|
||||
- [ ] Wire into tab switching
|
||||
- **Checkpoint**: Sankey tab renders real drug transition flows
|
||||
- [x] Create figure function in `src/visualization/`
|
||||
- [x] Wire into tab switching
|
||||
- **Checkpoint**: Sankey tab renders real drug transition flows ✓
|
||||
|
||||
### 9.7 Dosing Interval Comparison chart (Tab 6)
|
||||
- [ ] Create `dash_app/callbacks/dosing.py`:
|
||||
|
||||
@@ -158,6 +158,31 @@ def _render_cost_waterfall(app_state, title):
|
||||
return create_cost_waterfall_figure(data, title)
|
||||
|
||||
|
||||
def _render_sankey(app_state, title):
|
||||
"""Build the Sankey diagram from current filter state."""
|
||||
from dash_app.data.queries import get_drug_transitions
|
||||
from visualization.plotly_generator import create_sankey_figure
|
||||
|
||||
filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
|
||||
chart_type = (app_state or {}).get("chart_type", "directory")
|
||||
|
||||
selected_dirs = (app_state or {}).get("selected_directorates") or []
|
||||
selected_trusts = (app_state or {}).get("selected_trusts") or []
|
||||
directory = selected_dirs[0] if len(selected_dirs) == 1 else None
|
||||
trust = selected_trusts[0] if len(selected_trusts) == 1 else None
|
||||
|
||||
try:
|
||||
data = get_drug_transitions(filter_id, chart_type, directory, trust)
|
||||
except Exception:
|
||||
log.exception("Failed to load drug transition data")
|
||||
return _empty_figure("Failed to load drug transition data.")
|
||||
|
||||
if not data.get("nodes") or not data.get("links"):
|
||||
return _empty_figure("No drug switching data available.\nTry adjusting your filters.")
|
||||
|
||||
return create_sankey_figure(data, title)
|
||||
|
||||
|
||||
def register_chart_callbacks(app):
|
||||
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
|
||||
|
||||
@@ -280,6 +305,9 @@ def register_chart_callbacks(app):
|
||||
elif active_tab == "cost-waterfall":
|
||||
fig = _render_cost_waterfall(app_state, title)
|
||||
|
||||
elif active_tab == "sankey":
|
||||
fig = _render_sankey(app_state, title)
|
||||
|
||||
else:
|
||||
# Placeholder for charts not yet implemented
|
||||
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
|
||||
|
||||
@@ -703,6 +703,137 @@ def create_cost_waterfall_figure(
|
||||
return fig
|
||||
|
||||
|
||||
def create_sankey_figure(
|
||||
data: dict,
|
||||
title: str = "",
|
||||
) -> go.Figure:
|
||||
"""Create Sankey diagram showing drug switching flows between treatment lines.
|
||||
|
||||
Args:
|
||||
data: Dict from get_drug_transitions() with keys:
|
||||
nodes: [{name}] — drug names with ordinal suffixes (e.g., "ADALIMUMAB (1st)")
|
||||
links: [{source_idx, target_idx, patients}] — transitions between drugs
|
||||
title: Chart title suffix (filter description).
|
||||
|
||||
Returns:
|
||||
Plotly Figure with Sankey diagram.
|
||||
"""
|
||||
import re
|
||||
|
||||
nodes = data.get("nodes", [])
|
||||
links = data.get("links", [])
|
||||
|
||||
if not nodes or not links:
|
||||
return go.Figure()
|
||||
|
||||
# NHS colour palette — one colour per unique base drug name
|
||||
nhs_colours = [
|
||||
"#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
|
||||
"#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
|
||||
"#8A1538", "#330072", "#DA291C", "#00A499", "#425563",
|
||||
]
|
||||
|
||||
# Extract base drug name (strip ordinal suffix) for colour consistency
|
||||
def base_drug(name: str) -> str:
|
||||
return re.sub(r"\s*\(\d+(?:st|nd|rd|th)\)\s*$", "", name)
|
||||
|
||||
unique_bases = []
|
||||
for n in nodes:
|
||||
b = base_drug(n["name"])
|
||||
if b not in unique_bases:
|
||||
unique_bases.append(b)
|
||||
base_colour_map = {b: nhs_colours[i % len(nhs_colours)] for i, b in enumerate(unique_bases)}
|
||||
|
||||
# Node colours — same drug gets same colour regardless of treatment line
|
||||
node_colours = [base_colour_map[base_drug(n["name"])] for n in nodes]
|
||||
|
||||
# Node labels — format nicely
|
||||
node_labels = [n["name"] for n in nodes]
|
||||
|
||||
# Link colours — use source node colour at 40% opacity for visual clarity
|
||||
def hex_to_rgba(hex_colour: str, alpha: float) -> str:
|
||||
h = hex_colour.lstrip("#")
|
||||
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
|
||||
return f"rgba({r},{g},{b},{alpha})"
|
||||
|
||||
link_colours = [
|
||||
hex_to_rgba(node_colours[link["source_idx"]], 0.35)
|
||||
for link in links
|
||||
]
|
||||
|
||||
# Build hover text for links
|
||||
link_hovers = [
|
||||
f"{node_labels[link['source_idx']]} → {node_labels[link['target_idx']]}"
|
||||
f"<br>Patients: {link['patients']:,}"
|
||||
for link in links
|
||||
]
|
||||
|
||||
# Compute total patients per node for node hover
|
||||
node_patients = [0] * len(nodes)
|
||||
for link in links:
|
||||
node_patients[link["source_idx"]] += link["patients"]
|
||||
# For terminal nodes (no outgoing), use incoming total
|
||||
node_incoming = [0] * len(nodes)
|
||||
for link in links:
|
||||
node_incoming[link["target_idx"]] += link["patients"]
|
||||
node_hover = []
|
||||
for i, n in enumerate(nodes):
|
||||
out_p = node_patients[i]
|
||||
in_p = node_incoming[i]
|
||||
total = max(out_p, in_p)
|
||||
node_hover.append(f"<b>{n['name']}</b><br>Patients: {total:,}")
|
||||
|
||||
fig = go.Figure(
|
||||
go.Sankey(
|
||||
arrangement="snap",
|
||||
node=dict(
|
||||
pad=20,
|
||||
thickness=25,
|
||||
line=dict(color="#FFFFFF", width=1),
|
||||
label=node_labels,
|
||||
color=node_colours,
|
||||
customdata=node_hover,
|
||||
hovertemplate="%{customdata}<extra></extra>",
|
||||
),
|
||||
link=dict(
|
||||
source=[link["source_idx"] for link in links],
|
||||
target=[link["target_idx"] for link in links],
|
||||
value=[link["patients"] for link in links],
|
||||
color=link_colours,
|
||||
customdata=link_hovers,
|
||||
hovertemplate="%{customdata}<extra></extra>",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
chart_title = "Drug Switching Flows"
|
||||
if title:
|
||||
chart_title = f"{chart_title} — {title}"
|
||||
|
||||
fig.update_layout(
|
||||
title=dict(
|
||||
text=chart_title,
|
||||
font=dict(
|
||||
family="Source Sans 3, system-ui, sans-serif",
|
||||
size=18,
|
||||
color="#003087",
|
||||
),
|
||||
x=0.5,
|
||||
xanchor="center",
|
||||
),
|
||||
font=dict(
|
||||
family="Source Sans 3, system-ui, sans-serif",
|
||||
size=12,
|
||||
),
|
||||
paper_bgcolor="rgba(0,0,0,0)",
|
||||
plot_bgcolor="rgba(0,0,0,0)",
|
||||
margin=dict(t=60, l=30, r=30, b=30),
|
||||
height=max(500, len(unique_bases) * 35 + 200),
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def save_figure_html(
|
||||
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
|
||||
) -> str:
|
||||
|
||||
Reference in New Issue
Block a user