feat: add Treatment Duration bar chart (Task 9.9)

This commit is contained in:
Andrew Charlwood
2026-02-06 20:12:01 +00:00
parent 1b134c46a2
commit 965fc8c3d2
3 changed files with 183 additions and 3 deletions
+149
View File
@@ -1324,3 +1324,152 @@ def create_heatmap_figure(
)
return fig
def create_duration_figure(
data: list[dict],
title: str = "",
show_directory: bool = False,
) -> go.Figure:
"""Create horizontal bar chart showing average treatment duration by drug.
Args:
data: List of dicts from get_treatment_durations() with keys:
drug, directory, avg_days, patients.
Sorted by avg_days desc.
title: Chart title suffix (filter description).
show_directory: If True, include directory in label (for overview mode).
Returns:
Plotly Figure with horizontal bars coloured by patient count.
"""
if not data:
return go.Figure()
# When not showing directory breakdown, aggregate same drug across directorates
if not show_directory:
agg = {}
for d in data:
drug = d["drug"]
pts = d["patients"]
days = d["avg_days"]
if drug not in agg:
agg[drug] = {"drug": drug, "total_weighted": 0.0, "total_pts": 0}
agg[drug]["total_weighted"] += days * pts
agg[drug]["total_pts"] += pts
data = []
for v in agg.values():
if v["total_pts"] > 0:
data.append({
"drug": v["drug"],
"avg_days": round(v["total_weighted"] / v["total_pts"], 1),
"patients": v["total_pts"],
})
data.sort(key=lambda x: -x["avg_days"])
# Cap at 40 entries for readability (keep top by patient count, then re-sort by days)
if len(data) > 40:
data.sort(key=lambda x: -x["patients"])
data = data[:40]
data.sort(key=lambda x: -x["avg_days"])
# Build labels
if show_directory:
labels = [f"{d['drug']} ({d['directory']})" for d in data]
else:
labels = [d["drug"] for d in data]
days_values = [d["avg_days"] for d in data]
patients_list = [d["patients"] for d in data]
# Colour gradient by patient count: light for few → dark NHS blue for many
max_pts = max(patients_list) if patients_list else 1
min_pts = min(patients_list) if patients_list else 0
pt_range = max_pts - min_pts if max_pts > min_pts else 1
bar_colours = []
for pts in patients_list:
t = (pts - min_pts) / pt_range
r = int(0x41 + (0x00 - 0x41) * t)
g = int(0xB6 + (0x30 - 0xB6) * t)
b = int(0xE6 + (0x87 - 0xE6) * t)
bar_colours.append(f"rgb({r},{g},{b})")
hover_texts = []
for d in data:
years = d["avg_days"] / 365.25
hover_texts.append(
f"<b>{d['drug']}</b><br>"
f"Avg duration: {d['avg_days']:,.0f} days ({years:.1f} years)<br>"
f"Patients: {d['patients']:,}"
)
fig = go.Figure()
fig.add_trace(
go.Bar(
y=labels,
x=days_values,
orientation="h",
marker=dict(
color=bar_colours,
line=dict(color="#FFFFFF", width=1),
),
hovertemplate="%{customdata}<extra></extra>",
customdata=hover_texts,
text=[f"{v:,.0f}d" for v in days_values],
textposition="outside",
textfont=dict(size=10, color="#425563"),
)
)
for i, pts in enumerate(patients_list):
fig.add_annotation(
x=days_values[i],
y=labels[i],
text=f"n={pts:,}",
showarrow=False,
xshift=45,
font=dict(size=9, color="#768692", family="Source Sans 3"),
)
chart_title = "Treatment Duration by Drug"
if title:
chart_title += f"<br><span style='font-size:13px;color:#768692'>{title}</span>"
n_bars = len(data)
fig_height = max(400, 40 + n_bars * 28)
fig.update_layout(
title=dict(
text=chart_title,
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=18,
color="#003087",
),
x=0.5,
xanchor="center",
),
xaxis=dict(
title="Average Duration (days)",
titlefont=dict(size=13, color="#425563"),
tickfont=dict(size=11, color="#425563"),
gridcolor="rgba(0,0,0,0.06)",
zeroline=True,
zerolinecolor="rgba(0,0,0,0.1)",
),
yaxis=dict(
title="",
tickfont=dict(size=11, color="#425563"),
autorange="reversed",
),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
font=dict(family="Source Sans 3, system-ui, sans-serif"),
margin=dict(t=60, l=200, r=80, b=50),
height=fig_height,
showlegend=False,
)
return fig