fix: adaptive legends + _base_layout for 4 chart functions (Task A.3)

- Add _smart_legend(n_items) and _smart_legend_margin(n_items) helpers
  - >15 items: vertical right legend with extra right margin
  - <=15 items: horizontal legend with dynamic bottom margin
- Apply _base_layout() + _smart_legend() to:
  - create_market_share_figure() — DRUG_PALETTE, adaptive legend
  - create_trust_market_share_figure() — DRUG_PALETTE, adaptive legend
  - create_dosing_figure() — DRUG_PALETTE, legend adapts to trace count
  - create_trust_duration_figure() — TRUST_PALETTE, fixed l=200 margin
- Replace all local nhs_colours lists with module-level palettes
- Net reduction of 48 lines via DRY layout code
This commit is contained in:
Andrew Charlwood
2026-02-07 02:43:35 +00:00
parent ffa29bd10c
commit 90de24c72d
2 changed files with 104 additions and 151 deletions
+94 -142
View File
@@ -46,6 +46,56 @@ DRUG_PALETTE = [
]
def _smart_legend(n_items: int, legend_title: str = "") -> dict:
"""Return a legend dict that adapts to the number of items.
- >15 items: vertical legend to the right of the chart
- ≤15 items: horizontal legend below the chart with dynamic bottom margin
Returns a dict suitable for ``legend=...`` inside ``fig.update_layout()``.
The caller should also set bottom margin accordingly — use
``_smart_legend_margin_b(n_items)`` for that.
"""
base = dict(
font=dict(family=CHART_FONT_FAMILY, size=11),
)
if legend_title:
base["title"] = legend_title
if n_items > 15:
base.update(
orientation="v",
x=1.02,
y=1,
xanchor="left",
yanchor="top",
)
else:
base.update(
orientation="h",
yanchor="top",
y=-0.12,
xanchor="center",
x=0.5,
)
return base
def _smart_legend_margin(n_items: int) -> dict:
"""Return margin dict with bottom margin adapted to legend size.
- >15 items: vertical right legend needs extra right margin (r=140)
but minimal bottom margin (b=40).
- ≤15 items: horizontal legend needs bottom margin scaled to
estimated row count (~6 items per row at font size 11).
"""
if n_items > 15:
return dict(r=140, b=40)
else:
rows = max(1, (n_items + 5) // 6) # ~6 items per row
return dict(b=max(60, rows * 28 + 30), r=24)
def _base_layout(title: str, **overrides) -> dict:
"""Return a dict of shared Plotly layout properties.
@@ -321,13 +371,6 @@ def create_market_share_figure(data: list[dict], title: str = "") -> go.Figure:
if not data:
return go.Figure()
# NHS blue palette for different drugs
nhs_colours = [
"#003087", "#005EB8", "#0072CE", "#1E88E5", "#41B6E6",
"#4FC3F7", "#768692", "#AE2573", "#006747", "#ED8B00",
"#8A1538", "#330072", "#009639", "#DA291C", "#00A499",
]
# Collect unique directorates (in order — already sorted by total patients desc)
seen_dirs = []
for d in data:
@@ -341,7 +384,7 @@ def create_market_share_figure(data: list[dict], title: str = "") -> go.Figure:
seen_drugs.append(d["drug"])
# Build one trace per drug
drug_colour_map = {drug: nhs_colours[i % len(nhs_colours)] for i, drug in enumerate(seen_drugs)}
drug_colour_map = {drug: DRUG_PALETTE[i % len(DRUG_PALETTE)] for i, drug in enumerate(seen_drugs)}
# Build a lookup: (directory, drug) -> row
lookup = {(d["directory"], d["drug"]): d for d in data}
@@ -384,60 +427,26 @@ def create_market_share_figure(data: list[dict], title: str = "") -> go.Figure:
display_title = f"First-Line Drug Market Share — {title}" if title else "First-Line Drug Market Share"
n_drugs = len(seen_drugs)
legend_margins = _smart_legend_margin(n_drugs)
fig = go.Figure(data=traces)
fig.update_layout(
layout = _base_layout(display_title)
layout.update(
barmode="stack",
title=dict(
text=display_title,
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=18,
color="#1E293B",
),
x=0.5,
xanchor="center",
),
xaxis=dict(
title="% of patients",
ticksuffix="%",
range=[0, 105],
gridcolor="#E2E8F0",
gridcolor=GRID_COLOR,
zeroline=False,
),
yaxis=dict(
title="",
automargin=True,
),
legend=dict(
title="Drug",
orientation="h",
yanchor="top",
y=-0.15,
xanchor="center",
x=0.5,
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=11,
),
),
margin=dict(t=50, l=8, r=24, b=100),
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
autosize=True,
hoverlabel=dict(
bgcolor="#FFFFFF",
bordercolor="#CBD5E1",
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=13,
color="#1E293B",
),
),
font=dict(
family="Source Sans 3, system-ui, sans-serif",
),
yaxis=dict(title="", automargin=True),
legend=_smart_legend(n_drugs, legend_title="Drug"),
margin=dict(t=50, l=8, **legend_margins),
height=max(400, len(seen_dirs) * 60 + 200),
)
fig.update_layout(**layout)
return fig
@@ -919,36 +928,22 @@ def create_dosing_figure(
if not data:
return go.Figure()
nhs_colours = [
"#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
"#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
"#8A1538", "#330072", "#DA291C", "#00A499", "#425563",
]
if group_by == "trust":
# Single-drug mode: compare trusts, group bars by directory
fig = _dosing_by_trust(data, nhs_colours)
chart_title = f"Dosing Intervals by Trust"
fig = _dosing_by_trust(data, DRUG_PALETTE)
chart_title = "Dosing Intervals by Trust"
else:
# Overview mode: weighted average per drug
fig = _dosing_by_drug(data, nhs_colours)
fig = _dosing_by_drug(data, DRUG_PALETTE)
chart_title = "Dosing Interval Overview"
if title:
chart_title = f"{chart_title}{title}"
n_rows = len(fig.data[0].y) if fig.data else 10
fig.update_layout(
title=dict(
text=chart_title,
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=18,
color="#003087",
),
x=0.5,
xanchor="center",
),
n_legend = sum(1 for t in fig.data if t.showlegend is not False)
legend_margins = _smart_legend_margin(n_legend)
layout = _base_layout(chart_title)
layout.update(
xaxis=dict(
title="Weekly Interval (weeks between doses)",
titlefont=dict(size=13, color="#425563"),
@@ -956,30 +951,15 @@ def create_dosing_figure(
zeroline=True,
zerolinecolor="rgba(66,85,99,0.2)",
),
yaxis=dict(
automargin=True,
tickfont=dict(size=11),
),
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=12,
),
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(0,0,0,0)",
margin=dict(t=60, l=20, r=40, b=60),
yaxis=dict(automargin=True, tickfont=dict(size=11)),
margin=dict(t=60, l=20, **legend_margins),
height=max(450, n_rows * 40 + 150),
bargap=0.15,
bargroupgap=0.05,
showlegend=True,
legend=dict(
orientation="h",
yanchor="top",
y=-0.12,
xanchor="center",
x=0.5,
font=dict(size=11),
),
legend=_smart_legend(n_legend),
)
fig.update_layout(**layout)
return fig
@@ -1571,12 +1551,6 @@ def create_trust_market_share_figure(
if not data:
return go.Figure()
nhs_colours = [
"#003087", "#005EB8", "#0072CE", "#1E88E5", "#41B6E6",
"#4FC3F7", "#768692", "#AE2573", "#006747", "#ED8B00",
"#8A1538", "#330072", "#009639", "#DA291C", "#00A499",
]
seen_trusts = []
for d in data:
t = d["trust_name"]
@@ -1588,7 +1562,7 @@ def create_trust_market_share_figure(
if d["drug"] not in seen_drugs:
seen_drugs.append(d["drug"])
drug_colour_map = {drug: nhs_colours[i % len(nhs_colours)] for i, drug in enumerate(seen_drugs)}
drug_colour_map = {drug: DRUG_PALETTE[i % len(DRUG_PALETTE)] for i, drug in enumerate(seen_drugs)}
lookup = {(d["trust_name"], d["drug"]): d for d in data}
def short_trust(name):
@@ -1611,8 +1585,8 @@ def create_trust_market_share_figure(
f"{short_trust(trust)}<br>"
f"Patients: {row['patients']:,}<br>"
f"Share: {row['proportion']:.1%}<br>"
f"Cost: \u00a3{row['cost']:,.0f}<br>"
f"Cost p.p.p.a: \u00a3{row['cost_pp_pa']:,.0f}"
f"Cost: £{row['cost']:,.0f}<br>"
f"Cost p.p.p.a: £{row['cost_pp_pa']:,.0f}"
)
else:
x_vals.append(0)
@@ -1625,32 +1599,21 @@ def create_trust_market_share_figure(
customdata=hover_texts,
))
display_title = f"Drug Market Share by Trust \u2014 {title}" if title else "Drug Market Share by Trust"
display_title = f"Drug Market Share by Trust {title}" if title else "Drug Market Share by Trust"
n_drugs = len(seen_drugs)
legend_margins = _smart_legend_margin(n_drugs)
fig = go.Figure(data=traces)
fig.update_layout(
layout = _base_layout(display_title)
layout.update(
barmode="stack",
title=dict(
text=display_title,
font=dict(family="Source Sans 3, system-ui, sans-serif", size=16, color="#1E293B"),
x=0.5, xanchor="center",
),
xaxis=dict(title="% of patients", ticksuffix="%", range=[0, 105], gridcolor="#E2E8F0", zeroline=False),
xaxis=dict(title="% of patients", ticksuffix="%", range=[0, 105], gridcolor=GRID_COLOR, zeroline=False),
yaxis=dict(title="", automargin=True),
legend=dict(
title="Drug", orientation="h", yanchor="top", y=-0.15,
xanchor="center", x=0.5, font=dict(family="Source Sans 3", size=11),
),
margin=dict(t=50, l=8, r=24, b=100),
paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)",
autosize=True,
hoverlabel=dict(
bgcolor="#FFFFFF", bordercolor="#CBD5E1",
font=dict(family="Source Sans 3, system-ui, sans-serif", size=13, color="#1E293B"),
),
font=dict(family="Source Sans 3, system-ui, sans-serif"),
legend=_smart_legend(n_drugs, legend_title="Drug"),
margin=dict(t=50, l=8, **legend_margins),
height=max(300, len(seen_trusts) * 60 + 200),
)
fig.update_layout(**layout)
return fig
@@ -1800,11 +1763,6 @@ def create_trust_duration_figure(
if not data:
return go.Figure()
nhs_colours = [
"#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
"#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
]
seen_drugs = []
for d in data:
if d["drug"] not in seen_drugs:
@@ -1819,7 +1777,7 @@ def create_trust_duration_figure(
def short_trust(name):
return name.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")
trust_colour_map = {t: nhs_colours[i % len(nhs_colours)] for i, t in enumerate(seen_trusts)}
trust_colour_map = {t: TRUST_PALETTE[i % len(TRUST_PALETTE)] for i, t in enumerate(seen_trusts)}
lookup = {(d["drug"], d["trust_name"]): d for d in data}
display_drugs = list(reversed(seen_drugs))
@@ -1852,30 +1810,24 @@ def create_trust_duration_figure(
customdata=hover_texts,
))
display_title = f"Treatment Duration by Trust \u2014 {title}" if title else "Treatment Duration by Trust"
display_title = f"Treatment Duration by Trust {title}" if title else "Treatment Duration by Trust"
n_trusts = len(seen_trusts)
legend_margins = _smart_legend_margin(n_trusts)
fig = go.Figure(data=traces)
fig.update_layout(
layout = _base_layout(display_title)
layout.update(
barmode="group",
title=dict(
text=display_title,
font=dict(family="Source Sans 3, system-ui, sans-serif", size=16, color="#003087"),
x=0.5, xanchor="center",
),
xaxis=dict(
title="Average Duration (days)", titlefont=dict(size=13, color="#425563"),
gridcolor="rgba(0,0,0,0.06)", zeroline=True, zerolinecolor="rgba(0,0,0,0.1)",
),
yaxis=dict(title="", automargin=True, tickfont=dict(size=11, color="#425563")),
legend=dict(
title="Trust", orientation="h", yanchor="top", y=-0.12,
xanchor="center", x=0.5, font=dict(size=11),
),
plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)",
font=dict(family="Source Sans 3, system-ui, sans-serif"),
margin=dict(t=60, l=200, r=40, b=100),
legend=_smart_legend(n_trusts, legend_title="Trust"),
margin=dict(t=60, l=8, **legend_margins),
height=max(350, len(seen_drugs) * 35 + 200),
bargap=0.15, bargroupgap=0.05,
)
fig.update_layout(**layout)
return fig