feat: add Pathway Cost Effectiveness lollipop chart (Task 9.4)

- Create create_cost_effectiveness_figure() in plotly_generator.py
  Horizontal lollipop chart with dot size by patient count,
  colour gradient green→amber→red by cost, retention annotations
- Fix calculate_retention_rate() to accept both 'value' and 'patients' keys
- Add _render_cost_effectiveness() dispatch in chart.py callbacks
- Wire into tab switching for active_tab='cost-effectiveness'
This commit is contained in:
Andrew Charlwood
2026-02-06 19:38:54 +00:00
parent c34381a263
commit 4ef7239eed
4 changed files with 222 additions and 9 deletions
+5 -5
View File
@@ -381,16 +381,16 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: Market Share tab renders real data, responds to filters, icicle still works - **Checkpoint**: Market Share tab renders real data, responds to filters, icicle still works
### 9.4 Pathway Cost Effectiveness chart (Tab 3) ### 9.4 Pathway Cost Effectiveness chart (Tab 3)
- [ ] Create `dash_app/callbacks/pathway_costs.py`: - [x] Create `dash_app/callbacks/pathway_costs.py`:
- Build horizontal lollipop chart from `get_pathway_costs()` data - Build horizontal lollipop chart from `get_pathway_costs()` data
- Y-axis = pathway label (e.g., "Adalimumab → Secukinumab → Rituximab"), X-axis = £ per patient per annum - Y-axis = pathway label (e.g., "Adalimumab → Secukinumab → Rituximab"), X-axis = £ per patient per annum
- Dot size = patient count, colour gradient: green (cheap) → amber → red (expensive) - Dot size = patient count, colour gradient: green (cheap) → amber → red (expensive)
- Uses `parse_pathway_drugs()` to extract pathway labels - Uses `parse_pathway_drugs()` to extract pathway labels
- [ ] Add retention rate annotations using `calculate_retention_rate()` - [x] Add retention rate annotations using `calculate_retention_rate()`
- Show as secondary annotation: "Drug B retains 72% of patients" - Show as secondary annotation: "Drug B retains 72% of patients"
- [ ] Create figure function in `src/visualization/` - [x] Create figure function in `src/visualization/`
- [ ] Wire into tab switching - [x] Wire into tab switching
- **Checkpoint**: Cost Effectiveness tab renders with lollipop dots and retention annotations - **Checkpoint**: Cost Effectiveness tab renders with lollipop dots and retention annotations
### 9.5 Cost Waterfall chart (Tab 4) ### 9.5 Cost Waterfall chart (Tab 4)
- [ ] Create `dash_app/callbacks/cost_waterfall.py`: - [ ] Create `dash_app/callbacks/cost_waterfall.py`:
+30
View File
@@ -108,6 +108,33 @@ def _render_market_share(app_state, title):
return create_market_share_figure(data, title) return create_market_share_figure(data, title)
def _render_cost_effectiveness(app_state, chart_data, title):
"""Build the cost effectiveness lollipop figure from current filter state."""
from dash_app.data.queries import get_pathway_costs
from data_processing.parsing import calculate_retention_rate
from visualization.plotly_generator import create_cost_effectiveness_figure
filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
chart_type = (app_state or {}).get("chart_type", "directory")
selected_dirs = (app_state or {}).get("selected_directorates") or []
selected_trusts = (app_state or {}).get("selected_trusts") or []
directory = selected_dirs[0] if len(selected_dirs) == 1 else None
trust = selected_trusts[0] if len(selected_trusts) == 1 else None
try:
data = get_pathway_costs(filter_id, chart_type, directory, trust)
except Exception:
log.exception("Failed to load pathway cost data")
return _empty_figure("Failed to load pathway cost data.")
if not data:
return _empty_figure("No pathway cost data available.\nTry adjusting your filters.")
retention = calculate_retention_rate(data)
return create_cost_effectiveness_figure(data, retention, title)
def register_chart_callbacks(app): def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks.""" """Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -224,6 +251,9 @@ def register_chart_callbacks(app):
elif active_tab == "market-share": elif active_tab == "market-share":
fig = _render_market_share(app_state, title) fig = _render_market_share(app_state, title)
elif active_tab == "cost-effectiveness":
fig = _render_cost_effectiveness(app_state, chart_data, title)
else: else:
# Placeholder for charts not yet implemented # Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
+9 -4
View File
@@ -64,6 +64,11 @@ def parse_pathway_drugs(ids, level):
return segments[3:] return segments[3:]
def _get_patients(node):
"""Get patient count from a node dict (supports both 'value' and 'patients' keys)."""
return node.get("value") or node.get("patients") or 0
def calculate_retention_rate(nodes): def calculate_retention_rate(nodes):
"""Calculate pathway retention rates from node data. """Calculate pathway retention rates from node data.
@@ -71,8 +76,8 @@ def calculate_retention_rate(nodes):
to an N+1 drug pathway. This identifies effective treatment sequences. to an N+1 drug pathway. This identifies effective treatment sequences.
Args: Args:
nodes: List of dicts with 'ids', 'level', 'value' keys. nodes: List of dicts with 'ids', 'level', and 'value' or 'patients' keys.
Should contain level 3+ nodes from a single directorate. Should contain level 4+ nodes (pathway level).
Returns: Returns:
Dict mapping pathway ids to retention info: Dict mapping pathway ids to retention info:
@@ -92,14 +97,14 @@ def calculate_retention_rate(nodes):
continue continue
node_ids = node.get("ids", "") node_ids = node.get("ids", "")
total_patients = node.get("value", 0) total_patients = _get_patients(node)
if not total_patients: if not total_patients:
continue continue
# Find child pathways (nodes whose ids start with this node's ids + " - ") # Find child pathways (nodes whose ids start with this node's ids + " - ")
child_prefix = node_ids + " - " child_prefix = node_ids + " - "
child_patients = sum( child_patients = sum(
n.get("value", 0) _get_patients(n)
for n in nodes for n in nodes
if n.get("ids", "").startswith(child_prefix) and n.get("level", 0) == level + 1 if n.get("ids", "").startswith(child_prefix) and n.get("level", 0) == level + 1
) )
+178
View File
@@ -381,6 +381,184 @@ def create_market_share_figure(data: list[dict], title: str = "") -> go.Figure:
return fig return fig
def create_cost_effectiveness_figure(
data: list[dict],
retention: dict,
title: str = "",
) -> go.Figure:
"""
Create horizontal lollipop chart showing pathway cost per patient per annum.
Args:
data: List of dicts from get_pathway_costs() with keys:
ids, pathway_label, cost_pp_pa, patients, cost, avg_days,
directory, trust_name, drug_sequence, level.
Sorted by cost_pp_pa desc.
retention: Dict from calculate_retention_rate() mapping ids to retention
info: {retained_patients, total_patients, retention_rate, drug_sequence}.
title: Chart title suffix (filter description).
Returns:
Plotly Figure with horizontal lollipop dots and retention annotations.
"""
if not data:
return go.Figure()
# Filter to pathways with positive cost
filtered = [d for d in data if d["cost_pp_pa"] > 0]
if not filtered:
return go.Figure()
# Cap to top 40 pathways by cost to keep chart readable
filtered = filtered[:40]
# Reverse for horizontal chart (highest cost at top)
filtered = list(reversed(filtered))
pathway_labels = [d["pathway_label"] for d in filtered]
costs = [d["cost_pp_pa"] for d in filtered]
patients = [d["patients"] for d in filtered]
# Colour gradient: green (cheap) → amber → red (expensive)
max_cost = max(costs) if costs else 1
min_cost = min(costs) if costs else 0
cost_range = max_cost - min_cost if max_cost != min_cost else 1
colours = []
for c in costs:
ratio = (c - min_cost) / cost_range
if ratio < 0.33:
colours.append("#009639") # NHS green
elif ratio < 0.66:
colours.append("#ED8B00") # NHS warm yellow
else:
colours.append("#DA291C") # NHS red
# Dot size scaled by patient count (min 8, max 30)
max_pts = max(patients) if patients else 1
min_pts = min(patients) if patients else 1
pts_range = max_pts - min_pts if max_pts != min_pts else 1
sizes = [8 + (p - min_pts) / pts_range * 22 for p in patients]
# Build hover text with retention info
hover_texts = []
for d in filtered:
retention_info = retention.get(d["ids"], {})
retention_rate = retention_info.get("retention_rate")
drugs_in_seq = len(d["drug_sequence"])
hover = (
f"<b>{d['pathway_label']}</b><br>"
f"Cost p.p.p.a.: £{d['cost_pp_pa']:,.0f}<br>"
f"Patients: {d['patients']:,}<br>"
f"Total cost: £{d['cost']:,.0f}<br>"
f"Avg duration: {d['avg_days']:,.0f} days<br>"
f"Directorate: {d['directory']}<br>"
f"Treatment lines: {drugs_in_seq}"
)
if retention_rate is not None:
hover += f"<br>Retention: {retention_rate:.0f}% (no further switch)"
hover_texts.append(hover)
# Lollipop sticks (horizontal lines from 0 to cost)
stick_traces = []
for i, (label, cost) in enumerate(zip(pathway_labels, costs)):
stick_traces.append(
go.Scatter(
x=[0, cost],
y=[label, label],
mode="lines",
line=dict(color="#CBD5E1", width=1.5),
showlegend=False,
hoverinfo="skip",
)
)
# Lollipop dots
dot_trace = go.Scatter(
x=costs,
y=pathway_labels,
mode="markers",
marker=dict(
size=sizes,
color=colours,
line=dict(color="#FFFFFF", width=1),
),
hovertemplate="%{customdata}<extra></extra>",
customdata=hover_texts,
showlegend=False,
)
display_title = (
f"Pathway Cost Effectiveness — {title}" if title
else "Pathway Cost Effectiveness (£ per patient per annum)"
)
fig = go.Figure(data=stick_traces + [dot_trace])
# Add retention annotations for pathways with notable retention
annotation_count = 0
for d in filtered:
ret = retention.get(d["ids"], {})
rate = ret.get("retention_rate")
if rate is not None and rate < 90 and d["patients"] >= 10 and annotation_count < 8:
fig.add_annotation(
x=d["cost_pp_pa"],
y=d["pathway_label"],
text=f"{rate:.0f}% retain",
showarrow=False,
xanchor="left",
xshift=10,
font=dict(size=10, color="#768692", family="Source Sans 3"),
)
annotation_count += 1
fig.update_layout(
title=dict(
text=display_title,
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=18,
color="#1E293B",
),
x=0.5,
xanchor="center",
),
xaxis=dict(
title="£ per patient per annum",
tickprefix="£",
tickformat=",",
gridcolor="#E2E8F0",
zeroline=True,
zerolinecolor="#CBD5E1",
),
yaxis=dict(
title="",
automargin=True,
tickfont=dict(size=11),
),
margin=dict(t=50, l=8, r=24, b=40),
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
autosize=True,
hoverlabel=dict(
bgcolor="#FFFFFF",
bordercolor="#CBD5E1",
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=13,
color="#1E293B",
),
),
font=dict(
family="Source Sans 3, system-ui, sans-serif",
),
height=max(450, len(filtered) * 28 + 150),
)
return fig
def save_figure_html( def save_figure_html(
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
) -> str: ) -> str: