feat: add Drug Switching Sankey diagram (Task 9.6)
This commit is contained in:
@@ -402,14 +402,14 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
|
|||||||
- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓
|
- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓
|
||||||
|
|
||||||
### 9.6 Drug Switching Sankey chart (Tab 5)
|
### 9.6 Drug Switching Sankey chart (Tab 5)
|
||||||
- [ ] Create `dash_app/callbacks/sankey.py`:
|
- [x] Create `dash_app/callbacks/sankey.py`:
|
||||||
- Build Plotly Sankey diagram from `get_drug_transitions()` data
|
- Build Plotly Sankey diagram from `get_drug_transitions()` data
|
||||||
- Left nodes = 1st-line drugs, middle = 2nd-line, right = 3rd-line
|
- Left nodes = 1st-line drugs, middle = 2nd-line, right = 3rd-line
|
||||||
- Link width = patient count, colour by drug or directorate
|
- Link width = patient count, colour by drug or directorate
|
||||||
- Uses `parse_pathway_drugs()` to extract drug transitions from `ids` column
|
- Uses `parse_pathway_drugs()` to extract drug transitions from `ids` column
|
||||||
- [ ] Create figure function in `src/visualization/`
|
- [x] Create figure function in `src/visualization/`
|
||||||
- [ ] Wire into tab switching
|
- [x] Wire into tab switching
|
||||||
- **Checkpoint**: Sankey tab renders real drug transition flows
|
- **Checkpoint**: Sankey tab renders real drug transition flows ✓
|
||||||
|
|
||||||
### 9.7 Dosing Interval Comparison chart (Tab 6)
|
### 9.7 Dosing Interval Comparison chart (Tab 6)
|
||||||
- [ ] Create `dash_app/callbacks/dosing.py`:
|
- [ ] Create `dash_app/callbacks/dosing.py`:
|
||||||
|
|||||||
@@ -158,6 +158,31 @@ def _render_cost_waterfall(app_state, title):
|
|||||||
return create_cost_waterfall_figure(data, title)
|
return create_cost_waterfall_figure(data, title)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_sankey(app_state, title):
|
||||||
|
"""Build the Sankey diagram from current filter state."""
|
||||||
|
from dash_app.data.queries import get_drug_transitions
|
||||||
|
from visualization.plotly_generator import create_sankey_figure
|
||||||
|
|
||||||
|
filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
|
||||||
|
chart_type = (app_state or {}).get("chart_type", "directory")
|
||||||
|
|
||||||
|
selected_dirs = (app_state or {}).get("selected_directorates") or []
|
||||||
|
selected_trusts = (app_state or {}).get("selected_trusts") or []
|
||||||
|
directory = selected_dirs[0] if len(selected_dirs) == 1 else None
|
||||||
|
trust = selected_trusts[0] if len(selected_trusts) == 1 else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = get_drug_transitions(filter_id, chart_type, directory, trust)
|
||||||
|
except Exception:
|
||||||
|
log.exception("Failed to load drug transition data")
|
||||||
|
return _empty_figure("Failed to load drug transition data.")
|
||||||
|
|
||||||
|
if not data.get("nodes") or not data.get("links"):
|
||||||
|
return _empty_figure("No drug switching data available.\nTry adjusting your filters.")
|
||||||
|
|
||||||
|
return create_sankey_figure(data, title)
|
||||||
|
|
||||||
|
|
||||||
def register_chart_callbacks(app):
|
def register_chart_callbacks(app):
|
||||||
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
|
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
|
||||||
|
|
||||||
@@ -280,6 +305,9 @@ def register_chart_callbacks(app):
|
|||||||
elif active_tab == "cost-waterfall":
|
elif active_tab == "cost-waterfall":
|
||||||
fig = _render_cost_waterfall(app_state, title)
|
fig = _render_cost_waterfall(app_state, title)
|
||||||
|
|
||||||
|
elif active_tab == "sankey":
|
||||||
|
fig = _render_sankey(app_state, title)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Placeholder for charts not yet implemented
|
# Placeholder for charts not yet implemented
|
||||||
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
|
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
|
||||||
|
|||||||
@@ -703,6 +703,137 @@ def create_cost_waterfall_figure(
|
|||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_sankey_figure(
|
||||||
|
data: dict,
|
||||||
|
title: str = "",
|
||||||
|
) -> go.Figure:
|
||||||
|
"""Create Sankey diagram showing drug switching flows between treatment lines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dict from get_drug_transitions() with keys:
|
||||||
|
nodes: [{name}] — drug names with ordinal suffixes (e.g., "ADALIMUMAB (1st)")
|
||||||
|
links: [{source_idx, target_idx, patients}] — transitions between drugs
|
||||||
|
title: Chart title suffix (filter description).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Plotly Figure with Sankey diagram.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
nodes = data.get("nodes", [])
|
||||||
|
links = data.get("links", [])
|
||||||
|
|
||||||
|
if not nodes or not links:
|
||||||
|
return go.Figure()
|
||||||
|
|
||||||
|
# NHS colour palette — one colour per unique base drug name
|
||||||
|
nhs_colours = [
|
||||||
|
"#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
|
||||||
|
"#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
|
||||||
|
"#8A1538", "#330072", "#DA291C", "#00A499", "#425563",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Extract base drug name (strip ordinal suffix) for colour consistency
|
||||||
|
def base_drug(name: str) -> str:
|
||||||
|
return re.sub(r"\s*\(\d+(?:st|nd|rd|th)\)\s*$", "", name)
|
||||||
|
|
||||||
|
unique_bases = []
|
||||||
|
for n in nodes:
|
||||||
|
b = base_drug(n["name"])
|
||||||
|
if b not in unique_bases:
|
||||||
|
unique_bases.append(b)
|
||||||
|
base_colour_map = {b: nhs_colours[i % len(nhs_colours)] for i, b in enumerate(unique_bases)}
|
||||||
|
|
||||||
|
# Node colours — same drug gets same colour regardless of treatment line
|
||||||
|
node_colours = [base_colour_map[base_drug(n["name"])] for n in nodes]
|
||||||
|
|
||||||
|
# Node labels — format nicely
|
||||||
|
node_labels = [n["name"] for n in nodes]
|
||||||
|
|
||||||
|
# Link colours — use source node colour at 40% opacity for visual clarity
|
||||||
|
def hex_to_rgba(hex_colour: str, alpha: float) -> str:
|
||||||
|
h = hex_colour.lstrip("#")
|
||||||
|
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
|
||||||
|
return f"rgba({r},{g},{b},{alpha})"
|
||||||
|
|
||||||
|
link_colours = [
|
||||||
|
hex_to_rgba(node_colours[link["source_idx"]], 0.35)
|
||||||
|
for link in links
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build hover text for links
|
||||||
|
link_hovers = [
|
||||||
|
f"{node_labels[link['source_idx']]} → {node_labels[link['target_idx']]}"
|
||||||
|
f"<br>Patients: {link['patients']:,}"
|
||||||
|
for link in links
|
||||||
|
]
|
||||||
|
|
||||||
|
# Compute total patients per node for node hover
|
||||||
|
node_patients = [0] * len(nodes)
|
||||||
|
for link in links:
|
||||||
|
node_patients[link["source_idx"]] += link["patients"]
|
||||||
|
# For terminal nodes (no outgoing), use incoming total
|
||||||
|
node_incoming = [0] * len(nodes)
|
||||||
|
for link in links:
|
||||||
|
node_incoming[link["target_idx"]] += link["patients"]
|
||||||
|
node_hover = []
|
||||||
|
for i, n in enumerate(nodes):
|
||||||
|
out_p = node_patients[i]
|
||||||
|
in_p = node_incoming[i]
|
||||||
|
total = max(out_p, in_p)
|
||||||
|
node_hover.append(f"<b>{n['name']}</b><br>Patients: {total:,}")
|
||||||
|
|
||||||
|
fig = go.Figure(
|
||||||
|
go.Sankey(
|
||||||
|
arrangement="snap",
|
||||||
|
node=dict(
|
||||||
|
pad=20,
|
||||||
|
thickness=25,
|
||||||
|
line=dict(color="#FFFFFF", width=1),
|
||||||
|
label=node_labels,
|
||||||
|
color=node_colours,
|
||||||
|
customdata=node_hover,
|
||||||
|
hovertemplate="%{customdata}<extra></extra>",
|
||||||
|
),
|
||||||
|
link=dict(
|
||||||
|
source=[link["source_idx"] for link in links],
|
||||||
|
target=[link["target_idx"] for link in links],
|
||||||
|
value=[link["patients"] for link in links],
|
||||||
|
color=link_colours,
|
||||||
|
customdata=link_hovers,
|
||||||
|
hovertemplate="%{customdata}<extra></extra>",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chart_title = "Drug Switching Flows"
|
||||||
|
if title:
|
||||||
|
chart_title = f"{chart_title} — {title}"
|
||||||
|
|
||||||
|
fig.update_layout(
|
||||||
|
title=dict(
|
||||||
|
text=chart_title,
|
||||||
|
font=dict(
|
||||||
|
family="Source Sans 3, system-ui, sans-serif",
|
||||||
|
size=18,
|
||||||
|
color="#003087",
|
||||||
|
),
|
||||||
|
x=0.5,
|
||||||
|
xanchor="center",
|
||||||
|
),
|
||||||
|
font=dict(
|
||||||
|
family="Source Sans 3, system-ui, sans-serif",
|
||||||
|
size=12,
|
||||||
|
),
|
||||||
|
paper_bgcolor="rgba(0,0,0,0)",
|
||||||
|
plot_bgcolor="rgba(0,0,0,0)",
|
||||||
|
margin=dict(t=60, l=30, r=30, b=30),
|
||||||
|
height=max(500, len(unique_bases) * 35 + 200),
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def save_figure_html(
|
def save_figure_html(
|
||||||
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
|
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user