From 73a8d1a49f2d4c578dc49c403feec6d068d808a9 Mon Sep 17 00:00:00 2001 From: Andrew Charlwood Date: Fri, 6 Feb 2026 19:44:37 +0000 Subject: [PATCH] feat: add Cost Waterfall bar chart (Task 9.5) --- IMPLEMENTATION_PLAN.md | 8 +- dash_app/callbacks/chart.py | 26 +++++ src/visualization/plotly_generator.py | 144 ++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 4 deletions(-) diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index dda8a8a..20cd592 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -393,13 +393,13 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_ - **Checkpoint**: Cost Effectiveness tab renders with lollipop dots and retention annotations ✓ ### 9.5 Cost Waterfall chart (Tab 4) -- [ ] Create `dash_app/callbacks/cost_waterfall.py`: +- [x] Create `dash_app/callbacks/cost_waterfall.py`: - Build Plotly waterfall chart from `get_cost_waterfall()` data - Each bar = one directorate's average cost_pp_pa, sorted highest to lowest - NHS colours, responds to chart_type toggle, date filter, trust filter -- [ ] Create figure function in `src/visualization/` -- [ ] Wire into tab switching -- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters +- [x] Create figure function in `src/visualization/` +- [x] Wire into tab switching +- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓ ### 9.6 Drug Switching Sankey chart (Tab 5) - [ ] Create `dash_app/callbacks/sankey.py`: diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py index dd8f754..5d6a088 100644 --- a/dash_app/callbacks/chart.py +++ b/dash_app/callbacks/chart.py @@ -135,6 +135,29 @@ def _render_cost_effectiveness(app_state, chart_data, title): return create_cost_effectiveness_figure(data, retention, title) +def _render_cost_waterfall(app_state, title): + """Build the cost waterfall figure from current filter state.""" + from dash_app.data.queries import get_cost_waterfall + from visualization.plotly_generator import create_cost_waterfall_figure + + filter_id = (app_state or {}).get("date_filter_id", "all_6mo") + chart_type = (app_state or {}).get("chart_type", "directory") + + selected_trusts = (app_state or {}).get("selected_trusts") or [] + trust = selected_trusts[0] if len(selected_trusts) == 1 else None + + try: + data = get_cost_waterfall(filter_id, chart_type, trust) + except Exception: + log.exception("Failed to load cost waterfall data") + return _empty_figure("Failed to load cost waterfall data.") + + if not data: + return _empty_figure("No cost waterfall data available.\nTry adjusting your filters.") + + return create_cost_waterfall_figure(data, title) + + def register_chart_callbacks(app): """Register tab switching, pathway data loading, and chart rendering callbacks.""" @@ -254,6 +277,9 @@ def register_chart_callbacks(app): elif active_tab == "cost-effectiveness": fig = _render_cost_effectiveness(app_state, chart_data, title) + elif active_tab == "cost-waterfall": + fig = _render_cost_waterfall(app_state, title) + else: # Placeholder for charts not yet implemented tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py index bbee0ce..8be52cd 100644 --- a/src/visualization/plotly_generator.py +++ b/src/visualization/plotly_generator.py @@ -559,6 +559,150 @@ def create_cost_effectiveness_figure( return fig +def create_cost_waterfall_figure( + data: list[dict], + title: str = "", +) -> go.Figure: + """Create waterfall chart showing cost per patient by directorate/indication. + + Args: + data: List of dicts from get_cost_waterfall() with keys: + directory, patients, total_cost, cost_pp. + Sorted by cost_pp desc. + title: Chart title suffix (filter description). + + Returns: + Plotly Figure with waterfall bars and total. + """ + if not data: + return go.Figure() + + labels = [d["directory"] for d in data] + cost_pp_values = [d["cost_pp"] for d in data] + patients_list = [d["patients"] for d in data] + total_costs = [d["total_cost"] for d in data] + + # NHS colour palette for bars + nhs_colours = [ + "#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5", + "#4FC3F7", "#009639", "#ED8B00", "#768692", "#425563", + "#DA291C", "#7C2855", + ] + + # Assign colours cycling through palette + bar_colours = [nhs_colours[i % len(nhs_colours)] for i in range(len(data))] + + hover_texts = [] + for d in data: + hover_texts.append( + f"{d['directory']}
" + f"Cost per patient: £{d['cost_pp']:,.0f}
" + f"Patients: {d['patients']:,}
" + f"Total cost: £{d['total_cost']:,.0f}" + ) + + # Use a standard bar chart (not go.Waterfall) for cleaner control + # Each bar shows cost_pp for a directorate, sorted highest to lowest + fig = go.Figure() + + fig.add_trace( + go.Bar( + x=labels, + y=cost_pp_values, + marker=dict( + color=bar_colours, + line=dict(color="#FFFFFF", width=1), + ), + hovertemplate="%{customdata}", + customdata=hover_texts, + text=[f"£{v:,.0f}" for v in cost_pp_values], + textposition="outside", + textfont=dict(size=11, color="#425563"), + ) + ) + + # Add patient count annotations below each bar + for i, (label, pts) in enumerate(zip(labels, patients_list)): + fig.add_annotation( + x=label, + y=0, + text=f"n={pts:,}", + showarrow=False, + yshift=-18, + font=dict(size=10, color="#768692", family="Source Sans 3"), + ) + + # Grand total line + if cost_pp_values: + total_patients = sum(patients_list) + total_cost = sum(total_costs) + weighted_avg = total_cost / total_patients if total_patients else 0 + fig.add_hline( + y=weighted_avg, + line_dash="dash", + line_color="#DA291C", + line_width=1.5, + annotation_text=f"Weighted avg: £{weighted_avg:,.0f}", + annotation_position="top right", + annotation_font=dict( + size=11, color="#DA291C", family="Source Sans 3" + ), + ) + + display_title = ( + f"Cost per Patient by Directorate — {title}" if title + else "Cost per Patient by Directorate" + ) + + fig.update_layout( + title=dict( + text=display_title, + font=dict( + family="Source Sans 3, system-ui, sans-serif", + size=18, + color="#1E293B", + ), + x=0.5, + xanchor="center", + ), + xaxis=dict( + title="", + tickangle=-45 if len(data) > 6 else 0, + tickfont=dict(size=11), + automargin=True, + ), + yaxis=dict( + title="£ per patient", + tickprefix="£", + tickformat=",", + gridcolor="#E2E8F0", + zeroline=True, + zerolinecolor="#CBD5E1", + ), + margin=dict(t=60, l=8, r=24, b=40), + paper_bgcolor="rgba(0,0,0,0)", + plot_bgcolor="rgba(0,0,0,0)", + autosize=True, + showlegend=False, + hoverlabel=dict( + bgcolor="#FFFFFF", + bordercolor="#CBD5E1", + font=dict( + family="Source Sans 3, system-ui, sans-serif", + size=13, + color="#1E293B", + ), + ), + font=dict( + family="Source Sans 3, system-ui, sans-serif", + ), + height=max(450, 500), + bargap=0.25, + ) + + return fig + + def save_figure_html( fig: go.Figure, save_dir: str, title: str, open_browser: bool = False ) -> str: