diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index dda8a8a..20cd592 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -393,13 +393,13 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: Cost Effectiveness tab renders with lollipop dots and retention annotations ✓
### 9.5 Cost Waterfall chart (Tab 4)
-- [ ] Create `dash_app/callbacks/cost_waterfall.py`:
+- [x] Create `dash_app/callbacks/cost_waterfall.py`:
- Build Plotly waterfall chart from `get_cost_waterfall()` data
- Each bar = one directorate's average cost_pp_pa, sorted highest to lowest
- NHS colours, responds to chart_type toggle, date filter, trust filter
-- [ ] Create figure function in `src/visualization/`
-- [ ] Wire into tab switching
-- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters
+- [x] Create figure function in `src/visualization/`
+- [x] Wire into tab switching
+- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters ✓
### 9.6 Drug Switching Sankey chart (Tab 5)
- [ ] Create `dash_app/callbacks/sankey.py`:
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index dd8f754..5d6a088 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -135,6 +135,29 @@ def _render_cost_effectiveness(app_state, chart_data, title):
return create_cost_effectiveness_figure(data, retention, title)
+def _render_cost_waterfall(app_state, title):
+ """Build the cost waterfall figure from current filter state."""
+ from dash_app.data.queries import get_cost_waterfall
+ from visualization.plotly_generator import create_cost_waterfall_figure
+
+ filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
+ chart_type = (app_state or {}).get("chart_type", "directory")
+
+ selected_trusts = (app_state or {}).get("selected_trusts") or []
+ trust = selected_trusts[0] if len(selected_trusts) == 1 else None
+
+ try:
+ data = get_cost_waterfall(filter_id, chart_type, trust)
+ except Exception:
+ log.exception("Failed to load cost waterfall data")
+ return _empty_figure("Failed to load cost waterfall data.")
+
+ if not data:
+ return _empty_figure("No cost waterfall data available.\nTry adjusting your filters.")
+
+ return create_cost_waterfall_figure(data, title)
+
+
def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -254,6 +277,9 @@ def register_chart_callbacks(app):
elif active_tab == "cost-effectiveness":
fig = _render_cost_effectiveness(app_state, chart_data, title)
+ elif active_tab == "cost-waterfall":
+ fig = _render_cost_waterfall(app_state, title)
+
else:
# Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index bbee0ce..8be52cd 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -559,6 +559,150 @@ def create_cost_effectiveness_figure(
return fig
+def create_cost_waterfall_figure(
+ data: list[dict],
+ title: str = "",
+) -> go.Figure:
+ """Create waterfall chart showing cost per patient by directorate/indication.
+
+ Args:
+ data: List of dicts from get_cost_waterfall() with keys:
+ directory, patients, total_cost, cost_pp.
+ Sorted by cost_pp desc.
+ title: Chart title suffix (filter description).
+
+ Returns:
+ Plotly Figure with waterfall bars and total.
+ """
+ if not data:
+ return go.Figure()
+
+ labels = [d["directory"] for d in data]
+ cost_pp_values = [d["cost_pp"] for d in data]
+ patients_list = [d["patients"] for d in data]
+ total_costs = [d["total_cost"] for d in data]
+
+ # NHS colour palette for bars
+ nhs_colours = [
+ "#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
+ "#4FC3F7", "#009639", "#ED8B00", "#768692", "#425563",
+ "#DA291C", "#7C2855",
+ ]
+
+ # Assign colours cycling through palette
+ bar_colours = [nhs_colours[i % len(nhs_colours)] for i in range(len(data))]
+
+ hover_texts = []
+ for d in data:
+ hover_texts.append(
+ f"{d['directory']}
"
+ f"Cost per patient: £{d['cost_pp']:,.0f}
"
+ f"Patients: {d['patients']:,}
"
+ f"Total cost: £{d['total_cost']:,.0f}"
+ )
+
+ # Use a standard bar chart (not go.Waterfall) for cleaner control
+ # Each bar shows cost_pp for a directorate, sorted highest to lowest
+ fig = go.Figure()
+
+ fig.add_trace(
+ go.Bar(
+ x=labels,
+ y=cost_pp_values,
+ marker=dict(
+ color=bar_colours,
+ line=dict(color="#FFFFFF", width=1),
+ ),
+ hovertemplate="%{customdata}",
+ customdata=hover_texts,
+ text=[f"£{v:,.0f}" for v in cost_pp_values],
+ textposition="outside",
+ textfont=dict(size=11, color="#425563"),
+ )
+ )
+
+ # Add patient count annotations below each bar
+ for i, (label, pts) in enumerate(zip(labels, patients_list)):
+ fig.add_annotation(
+ x=label,
+ y=0,
+ text=f"n={pts:,}",
+ showarrow=False,
+ yshift=-18,
+ font=dict(size=10, color="#768692", family="Source Sans 3"),
+ )
+
+ # Grand total line
+ if cost_pp_values:
+ total_patients = sum(patients_list)
+ total_cost = sum(total_costs)
+ weighted_avg = total_cost / total_patients if total_patients else 0
+ fig.add_hline(
+ y=weighted_avg,
+ line_dash="dash",
+ line_color="#DA291C",
+ line_width=1.5,
+ annotation_text=f"Weighted avg: £{weighted_avg:,.0f}",
+ annotation_position="top right",
+ annotation_font=dict(
+ size=11, color="#DA291C", family="Source Sans 3"
+ ),
+ )
+
+ display_title = (
+ f"Cost per Patient by Directorate — {title}" if title
+ else "Cost per Patient by Directorate"
+ )
+
+ fig.update_layout(
+ title=dict(
+ text=display_title,
+ font=dict(
+ family="Source Sans 3, system-ui, sans-serif",
+ size=18,
+ color="#1E293B",
+ ),
+ x=0.5,
+ xanchor="center",
+ ),
+ xaxis=dict(
+ title="",
+ tickangle=-45 if len(data) > 6 else 0,
+ tickfont=dict(size=11),
+ automargin=True,
+ ),
+ yaxis=dict(
+ title="£ per patient",
+ tickprefix="£",
+ tickformat=",",
+ gridcolor="#E2E8F0",
+ zeroline=True,
+ zerolinecolor="#CBD5E1",
+ ),
+ margin=dict(t=60, l=8, r=24, b=40),
+ paper_bgcolor="rgba(0,0,0,0)",
+ plot_bgcolor="rgba(0,0,0,0)",
+ autosize=True,
+ showlegend=False,
+ hoverlabel=dict(
+ bgcolor="#FFFFFF",
+ bordercolor="#CBD5E1",
+ font=dict(
+ family="Source Sans 3, system-ui, sans-serif",
+ size=13,
+ color="#1E293B",
+ ),
+ ),
+ font=dict(
+ family="Source Sans 3, system-ui, sans-serif",
+ ),
+ height=max(450, 500),
+ bargap=0.25,
+ )
+
+ return fig
+
+
def save_figure_html(
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
) -> str: