diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index 20cd592..a07ea4f 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -402,14 +402,14 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_ - **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓ ### 9.6 Drug Switching Sankey chart (Tab 5) -- [ ] Create `dash_app/callbacks/sankey.py`: +- [x] Create `dash_app/callbacks/sankey.py`: - Build Plotly Sankey diagram from `get_drug_transitions()` data - Left nodes = 1st-line drugs, middle = 2nd-line, right = 3rd-line - Link width = patient count, colour by drug or directorate - Uses `parse_pathway_drugs()` to extract drug transitions from `ids` column -- [ ] Create figure function in `src/visualization/` -- [ ] Wire into tab switching -- **Checkpoint**: Sankey tab renders real drug transition flows +- [x] Create figure function in `src/visualization/` +- [x] Wire into tab switching +- **Checkpoint**: Sankey tab renders real drug transition flows ✓ ### 9.7 Dosing Interval Comparison chart (Tab 6) - [ ] Create `dash_app/callbacks/dosing.py`: diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py index 5d6a088..1b56d59 100644 --- a/dash_app/callbacks/chart.py +++ b/dash_app/callbacks/chart.py @@ -158,6 +158,31 @@ def _render_cost_waterfall(app_state, title): return create_cost_waterfall_figure(data, title) +def _render_sankey(app_state, title): + """Build the Sankey diagram from current filter state.""" + from dash_app.data.queries import get_drug_transitions + from visualization.plotly_generator import create_sankey_figure + + filter_id = (app_state or {}).get("date_filter_id", "all_6mo") + chart_type = (app_state or {}).get("chart_type", "directory") + + selected_dirs = (app_state or {}).get("selected_directorates") or [] + selected_trusts = (app_state or {}).get("selected_trusts") or [] + directory = selected_dirs[0] if len(selected_dirs) == 1 else None + trust = selected_trusts[0] if len(selected_trusts) == 1 else None + + try: + data = get_drug_transitions(filter_id, chart_type, directory, trust) + except Exception: + log.exception("Failed to load drug transition data") + return _empty_figure("Failed to load drug transition data.") + + if not data.get("nodes") or not data.get("links"): + return _empty_figure("No drug switching data available.\nTry adjusting your filters.") + + return create_sankey_figure(data, title) + + def register_chart_callbacks(app): """Register tab switching, pathway data loading, and chart rendering callbacks.""" @@ -280,6 +305,9 @@ def register_chart_callbacks(app): elif active_tab == "cost-waterfall": fig = _render_cost_waterfall(app_state, title) + elif active_tab == "sankey": + fig = _render_sankey(app_state, title) + else: # Placeholder for charts not yet implemented tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py index 8be52cd..dee7d8d 100644 --- a/src/visualization/plotly_generator.py +++ b/src/visualization/plotly_generator.py @@ -703,6 +703,137 @@ def create_cost_waterfall_figure( return fig +def create_sankey_figure( + data: dict, + title: str = "", +) -> go.Figure: + """Create Sankey diagram showing drug switching flows between treatment lines. + + Args: + data: Dict from get_drug_transitions() with keys: + nodes: [{name}] — drug names with ordinal suffixes (e.g., "ADALIMUMAB (1st)") + links: [{source_idx, target_idx, patients}] — transitions between drugs + title: Chart title suffix (filter description). + + Returns: + Plotly Figure with Sankey diagram. + """ + import re + + nodes = data.get("nodes", []) + links = data.get("links", []) + + if not nodes or not links: + return go.Figure() + + # NHS colour palette — one colour per unique base drug name + nhs_colours = [ + "#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5", + "#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573", + "#8A1538", "#330072", "#DA291C", "#00A499", "#425563", + ] + + # Extract base drug name (strip ordinal suffix) for colour consistency + def base_drug(name: str) -> str: + return re.sub(r"\s*\(\d+(?:st|nd|rd|th)\)\s*$", "", name) + + unique_bases = [] + for n in nodes: + b = base_drug(n["name"]) + if b not in unique_bases: + unique_bases.append(b) + base_colour_map = {b: nhs_colours[i % len(nhs_colours)] for i, b in enumerate(unique_bases)} + + # Node colours — same drug gets same colour regardless of treatment line + node_colours = [base_colour_map[base_drug(n["name"])] for n in nodes] + + # Node labels — format nicely + node_labels = [n["name"] for n in nodes] + + # Link colours — use source node colour at 40% opacity for visual clarity + def hex_to_rgba(hex_colour: str, alpha: float) -> str: + h = hex_colour.lstrip("#") + r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16) + return f"rgba({r},{g},{b},{alpha})" + + link_colours = [ + hex_to_rgba(node_colours[link["source_idx"]], 0.35) + for link in links + ] + + # Build hover text for links + link_hovers = [ + f"{node_labels[link['source_idx']]} → {node_labels[link['target_idx']]}" + f"
Patients: {link['patients']:,}" + for link in links + ] + + # Compute total patients per node for node hover + node_patients = [0] * len(nodes) + for link in links: + node_patients[link["source_idx"]] += link["patients"] + # For terminal nodes (no outgoing), use incoming total + node_incoming = [0] * len(nodes) + for link in links: + node_incoming[link["target_idx"]] += link["patients"] + node_hover = [] + for i, n in enumerate(nodes): + out_p = node_patients[i] + in_p = node_incoming[i] + total = max(out_p, in_p) + node_hover.append(f"{n['name']}
Patients: {total:,}") + + fig = go.Figure( + go.Sankey( + arrangement="snap", + node=dict( + pad=20, + thickness=25, + line=dict(color="#FFFFFF", width=1), + label=node_labels, + color=node_colours, + customdata=node_hover, + hovertemplate="%{customdata}", + ), + link=dict( + source=[link["source_idx"] for link in links], + target=[link["target_idx"] for link in links], + value=[link["patients"] for link in links], + color=link_colours, + customdata=link_hovers, + hovertemplate="%{customdata}", + ), + ) + ) + + chart_title = "Drug Switching Flows" + if title: + chart_title = f"{chart_title} — {title}" + + fig.update_layout( + title=dict( + text=chart_title, + font=dict( + family="Source Sans 3, system-ui, sans-serif", + size=18, + color="#003087", + ), + x=0.5, + xanchor="center", + ), + font=dict( + family="Source Sans 3, system-ui, sans-serif", + size=12, + ), + paper_bgcolor="rgba(0,0,0,0)", + plot_bgcolor="rgba(0,0,0,0)", + margin=dict(t=60, l=30, r=30, b=30), + height=max(500, len(unique_bases) * 35 + 200), + ) + + return fig + + def save_figure_html( fig: go.Figure, save_dir: str, title: str, open_browser: bool = False ) -> str: