feat: add Cost Waterfall bar chart (Task 9.5)

This commit is contained in:
Andrew Charlwood
2026-02-06 19:44:37 +00:00
parent 47e4aa4df2
commit 73a8d1a49f
3 changed files with 174 additions and 4 deletions
+4 -4
View File
@@ -393,13 +393,13 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: Cost Effectiveness tab renders with lollipop dots and retention annotations ✓ - **Checkpoint**: Cost Effectiveness tab renders with lollipop dots and retention annotations ✓
### 9.5 Cost Waterfall chart (Tab 4) ### 9.5 Cost Waterfall chart (Tab 4)
- [ ] Create `dash_app/callbacks/cost_waterfall.py`: - [x] Create `dash_app/callbacks/cost_waterfall.py`:
- Build Plotly waterfall chart from `get_cost_waterfall()` data - Build Plotly waterfall chart from `get_cost_waterfall()` data
- Each bar = one directorate's average cost_pp_pa, sorted highest to lowest - Each bar = one directorate's average cost_pp_pa, sorted highest to lowest
- NHS colours, responds to chart_type toggle, date filter, trust filter - NHS colours, responds to chart_type toggle, date filter, trust filter
- [ ] Create figure function in `src/visualization/` - [x] Create figure function in `src/visualization/`
- [ ] Wire into tab switching - [x] Wire into tab switching
- **Checkpoint**: Cost Waterfall tab renders real data, responds to filters - **Checkpoint**: Cost Waterfall tab renders real data, responds to filters
### 9.6 Drug Switching Sankey chart (Tab 5) ### 9.6 Drug Switching Sankey chart (Tab 5)
- [ ] Create `dash_app/callbacks/sankey.py`: - [ ] Create `dash_app/callbacks/sankey.py`:
+26
View File
@@ -135,6 +135,29 @@ def _render_cost_effectiveness(app_state, chart_data, title):
return create_cost_effectiveness_figure(data, retention, title) return create_cost_effectiveness_figure(data, retention, title)
def _render_cost_waterfall(app_state, title):
"""Build the cost waterfall figure from current filter state."""
from dash_app.data.queries import get_cost_waterfall
from visualization.plotly_generator import create_cost_waterfall_figure
filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
chart_type = (app_state or {}).get("chart_type", "directory")
selected_trusts = (app_state or {}).get("selected_trusts") or []
trust = selected_trusts[0] if len(selected_trusts) == 1 else None
try:
data = get_cost_waterfall(filter_id, chart_type, trust)
except Exception:
log.exception("Failed to load cost waterfall data")
return _empty_figure("Failed to load cost waterfall data.")
if not data:
return _empty_figure("No cost waterfall data available.\nTry adjusting your filters.")
return create_cost_waterfall_figure(data, title)
def register_chart_callbacks(app): def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks.""" """Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -254,6 +277,9 @@ def register_chart_callbacks(app):
elif active_tab == "cost-effectiveness": elif active_tab == "cost-effectiveness":
fig = _render_cost_effectiveness(app_state, chart_data, title) fig = _render_cost_effectiveness(app_state, chart_data, title)
elif active_tab == "cost-waterfall":
fig = _render_cost_waterfall(app_state, title)
else: else:
# Placeholder for charts not yet implemented # Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
+144
View File
@@ -559,6 +559,150 @@ def create_cost_effectiveness_figure(
return fig return fig
def create_cost_waterfall_figure(
data: list[dict],
title: str = "",
) -> go.Figure:
"""Create waterfall chart showing cost per patient by directorate/indication.
Args:
data: List of dicts from get_cost_waterfall() with keys:
directory, patients, total_cost, cost_pp.
Sorted by cost_pp desc.
title: Chart title suffix (filter description).
Returns:
Plotly Figure with waterfall bars and total.
"""
if not data:
return go.Figure()
labels = [d["directory"] for d in data]
cost_pp_values = [d["cost_pp"] for d in data]
patients_list = [d["patients"] for d in data]
total_costs = [d["total_cost"] for d in data]
# NHS colour palette for bars
nhs_colours = [
"#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
"#4FC3F7", "#009639", "#ED8B00", "#768692", "#425563",
"#DA291C", "#7C2855",
]
# Assign colours cycling through palette
bar_colours = [nhs_colours[i % len(nhs_colours)] for i in range(len(data))]
hover_texts = []
for d in data:
hover_texts.append(
f"<b>{d['directory']}</b><br>"
f"Cost per patient: £{d['cost_pp']:,.0f}<br>"
f"Patients: {d['patients']:,}<br>"
f"Total cost: £{d['total_cost']:,.0f}"
)
# Use a standard bar chart (not go.Waterfall) for cleaner control
# Each bar shows cost_pp for a directorate, sorted highest to lowest
fig = go.Figure()
fig.add_trace(
go.Bar(
x=labels,
y=cost_pp_values,
marker=dict(
color=bar_colours,
line=dict(color="#FFFFFF", width=1),
),
hovertemplate="%{customdata}<extra></extra>",
customdata=hover_texts,
text=[f"£{v:,.0f}" for v in cost_pp_values],
textposition="outside",
textfont=dict(size=11, color="#425563"),
)
)
# Add patient count annotations below each bar
for i, (label, pts) in enumerate(zip(labels, patients_list)):
fig.add_annotation(
x=label,
y=0,
text=f"n={pts:,}",
showarrow=False,
yshift=-18,
font=dict(size=10, color="#768692", family="Source Sans 3"),
)
# Grand total line
if cost_pp_values:
total_patients = sum(patients_list)
total_cost = sum(total_costs)
weighted_avg = total_cost / total_patients if total_patients else 0
fig.add_hline(
y=weighted_avg,
line_dash="dash",
line_color="#DA291C",
line_width=1.5,
annotation_text=f"Weighted avg: £{weighted_avg:,.0f}",
annotation_position="top right",
annotation_font=dict(
size=11, color="#DA291C", family="Source Sans 3"
),
)
display_title = (
f"Cost per Patient by Directorate — {title}" if title
else "Cost per Patient by Directorate"
)
fig.update_layout(
title=dict(
text=display_title,
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=18,
color="#1E293B",
),
x=0.5,
xanchor="center",
),
xaxis=dict(
title="",
tickangle=-45 if len(data) > 6 else 0,
tickfont=dict(size=11),
automargin=True,
),
yaxis=dict(
title="£ per patient",
tickprefix="£",
tickformat=",",
gridcolor="#E2E8F0",
zeroline=True,
zerolinecolor="#CBD5E1",
),
margin=dict(t=60, l=8, r=24, b=40),
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
autosize=True,
showlegend=False,
hoverlabel=dict(
bgcolor="#FFFFFF",
bordercolor="#CBD5E1",
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=13,
color="#1E293B",
),
),
font=dict(
family="Source Sans 3, system-ui, sans-serif",
),
height=max(450, 500),
bargap=0.25,
)
return fig
def save_figure_html( def save_figure_html(
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
) -> str: ) -> str: