ea6b9065bf
- Add 3 new visualization functions to plotly_generator.py: create_trust_market_share_figure, create_trust_heatmap_figure, create_trust_duration_figure - Replace 6 placeholder callbacks in trust_comparison.py with real implementations using trust-comparison queries + figure builders - Cost Waterfall reuses existing figure function via key mapping - Dosing reuses existing create_dosing_figure with group_by="trust" - Cost Effectiveness reuses existing function scoped to directorate - All 6 charts respond to date filter and chart type toggle - Validated with both directory (RHEUMATOLOGY) and indication (asthma)
1782 lines
58 KiB
Python
1782 lines
58 KiB
Python
"""
|
||
Plotly chart generation for patient pathway analysis.
|
||
|
||
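This module contains functions for creating interactive Plotly charts of
patient treatment pathways: icicle charts of the hierarchical pathway data
(Trust → Directory → Drug → Pathway), plus first-line market-share bar charts,
cost-per-patient and cost-effectiveness charts, drug-switching Sankey diagrams
and dosing-interval comparisons.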
|
||
"""
|
||
|
||
import webbrowser
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import plotly.graph_objects as go
|
||
|
||
from core.logging_config import get_logger
|
||
|
||
# Module-level logger, obtained from the project's core.logging_config helper.
logger = get_logger(__name__)
|
||
|
||
|
||
def create_icicle_figure(ice_df: pd.DataFrame, title: str) -> go.Figure:
    """
    Create Plotly icicle figure from prepared DataFrame.

    This function generates an interactive icicle chart showing patient pathway
    hierarchies with custom data including costs, dates, and treatment durations.
    The DataFrame is copied before sorting, so the caller's frame is not modified.

    Args:
        ice_df: DataFrame with columns:
            - parents: Parent node in hierarchy
            - ids: Unique identifier for each node
            - labels: Display label for each node
            - value: Number of patients
            - colour: Color value for visualization
            - cost: Total cost
            - costpp: Cost per patient
            - cost_pp_pa: Cost per patient per annum
            - First seen: First intervention date
            - Last seen: Last intervention date
            - First seen (Parent): Earliest date in parent group
            - Last seen (Parent): Latest date in parent group
            - average_spacing: Formatted string with dosing information
            - avg_days: Average treatment duration (not read by this function)
        title: Chart title suffix, appended to the fixed ICS heading.

    Returns:
        Plotly Figure object ready for display or export
    """
    ice_df = ice_df.copy()
    ice_df.sort_values(by=["labels"], ascending=True, inplace=True, ignore_index=True)

    def _as_display_text(column: str) -> list[str]:
        # Missing datetimes stringify as "NaT"; show "N/A" for every date column
        # so all four date fields in the hover read the same way.
        return ice_df[column].astype(str).replace("NaT", "N/A").to_list()

    first_seen = _as_display_text("First seen")
    last_seen = _as_display_text("Last seen")
    first_seen_parent = _as_display_text("First seen (Parent)")
    last_seen_parent = _as_display_text("Last seen (Parent)")
    average_spacing = ice_df["average_spacing"].astype(str).to_list()

    # Column order matters: the templates below address these by index.
    customdata = np.stack(
        (
            ice_df["value"],  # [0]
            ice_df["colour"],  # [1]
            ice_df["cost"],  # [2]
            ice_df["costpp"],  # [3]
            first_seen,  # [4]
            last_seen,  # [5]
            first_seen_parent,  # [6]
            last_seen_parent,  # [7]
            average_spacing,  # [8]
            ice_df["cost_pp_pa"],  # [9]
        ),
        axis=1,
    )

    fig = go.Figure(
        go.Icicle(
            labels=ice_df["labels"],
            ids=ice_df["ids"],
            parents=ice_df["parents"],
            customdata=customdata,
            values=ice_df["value"],
            branchvalues="total",
            marker=dict(colors=ice_df["colour"], colorscale="Viridis"),
            maxdepth=3,
            texttemplate="<b>%{label}</b> "
            "<br><b>Total patients:</b> %{customdata[0]} (including children/further treatments)"
            "<br><b>First seen:</b> %{customdata[4]}"
            "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
            "<br><b>Average treatment duration:</b> %{customdata[8]}"
            "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
            "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
            "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}",
            hovertemplate="<b>%{label}</b>"
            "<br><b>Total patients:</b> %{customdata[0]} - %{customdata[1]:.3p} of patients in level"
            "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
            "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
            "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}"
            "<br><b>First seen:</b> %{customdata[4]}"
            "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
            # Trailing space separates the label from its value in the tooltip.
            "<br><b>Average treatment duration:</b> "
            "%{customdata[8]}"
            "<extra></extra>",
        )
    )
    # Keep the label-sorted order instead of letting Plotly reorder by value.
    fig.update_traces(sort=False)
    fig.update_layout(
        margin=dict(t=60, l=1, r=1, b=60),
        title=f"Norfolk & Waveney ICS high-cost drug patient pathways - {title}",
        title_x=0.5,
        hoverlabel=dict(font_size=16),
    )

    return fig
|
||
|
||
|
||
def create_icicle_from_nodes(nodes: list[dict], title: str = "") -> go.Figure:
    """
    Create Plotly icicle figure from a list of pathway node dicts.

    This is the dict-based entry point used by the Dash app. The nodes list
    comes directly from the chart-data dcc.Store (JSON-serialized dicts with
    underscore keys matching SQLite column names).

    Missing keys and NULL (``None``) values fall back to neutral defaults:
    0 for numeric fields, "N/A" for dates and "" for text. This means Plotly's
    number formats in the templates always receive a number.

    Args:
        nodes: List of dicts with keys: parents, ids, labels, value, cost,
            costpp, cost_pp_pa, colour, first_seen, last_seen,
            first_seen_parent, last_seen_parent, average_spacing
        title: Chart title (e.g. "By Directory | All years / Last 6 months")

    Returns:
        Plotly Figure object ready for dcc.Graph (empty Figure if no nodes)
    """
    if not nodes:
        return go.Figure()

    parents = [d.get("parents", "") for d in nodes]
    ids = [d.get("ids", "") for d in nodes]
    labels = [d.get("labels", "") for d in nodes]
    # `or` also catches SQLite NULLs (present key with None value), which
    # dict.get's default alone would pass through.
    values = [d.get("value") or 0 for d in nodes]
    colours = [d.get("colour") or 0.0 for d in nodes]

    costs = [d.get("cost") or 0.0 for d in nodes]
    costpp = [d.get("costpp") or 0.0 for d in nodes]
    first_seen = [d.get("first_seen", "N/A") or "N/A" for d in nodes]
    last_seen = [d.get("last_seen", "N/A") or "N/A" for d in nodes]
    first_seen_parent = [d.get("first_seen_parent", "N/A") or "N/A" for d in nodes]
    last_seen_parent = [d.get("last_seen_parent", "N/A") or "N/A" for d in nodes]
    average_spacing = [d.get("average_spacing", "") or "" for d in nodes]
    cost_pp_pa = [d.get("cost_pp_pa", 0.0) or 0.0 for d in nodes]

    # Column order matters: the templates below address these by index.
    customdata = list(zip(
        values,  # [0]
        colours,  # [1]
        costs,  # [2]
        costpp,  # [3]
        first_seen,  # [4]
        last_seen,  # [5]
        first_seen_parent,  # [6]
        last_seen_parent,  # [7]
        average_spacing,  # [8]
        cost_pp_pa,  # [9]
    ))

    # NHS blue gradient (Heritage Blue → Primary Blue → Vibrant Blue → Sky Blue → Pale Blue)
    colorscale = [
        [0.0, "#003087"],
        [0.25, "#0066CC"],
        [0.5, "#1E88E5"],
        [0.75, "#4FC3F7"],
        [1.0, "#E3F2FD"],
    ]

    fig = go.Figure(
        go.Icicle(
            labels=labels,
            ids=ids,
            parents=parents,
            values=values,
            branchvalues="total",
            marker=dict(
                colors=colours,
                colorscale=colorscale,
                line=dict(width=1, color="#FFFFFF"),
            ),
            maxdepth=3,
            customdata=customdata,
            texttemplate=(
                "<b>%{label}</b> "
                "<br><b>Total patients:</b> %{customdata[0]} (including children/further treatments)"
                "<br><b>First seen:</b> %{customdata[4]}"
                "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
                "<br><b>Average treatment duration:</b> %{customdata[8]}"
                "<br><b>Total cost:</b> \u00a3%{customdata[2]:.3~s}"
                "<br><b>Average cost per patient:</b> \u00a3%{customdata[3]:.3~s}"
                "<br><b>Average cost per patient per annum:</b> \u00a3%{customdata[9]:.3~s}"
            ),
            hovertemplate=(
                "<b>%{label}</b>"
                "<br><b>Total patients:</b> %{customdata[0]} - %{customdata[1]:.3p} of patients in level"
                "<br><b>Total cost:</b> \u00a3%{customdata[2]:.3~s}"
                "<br><b>Average cost per patient:</b> \u00a3%{customdata[3]:.3~s}"
                "<br><b>Average cost per patient per annum:</b> \u00a3%{customdata[9]:.3~s}"
                "<br><b>First seen:</b> %{customdata[4]}"
                "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
                # Trailing space separates the label from its value in the tooltip.
                "<br><b>Average treatment duration:</b> "
                "%{customdata[8]}"
                "<extra></extra>"
            ),
            textfont=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=12,
            ),
        )
    )

    display_title = f"Patient Pathways \u2014 {title}" if title else "Patient Pathways"

    fig.update_layout(
        title=dict(
            text=display_title,
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=18,
                color="#1E293B",
            ),
            x=0.5,
            xanchor="center",
        ),
        margin=dict(t=40, l=8, r=8, b=24),
        hoverlabel=dict(
            bgcolor="#FFFFFF",
            bordercolor="#CBD5E1",
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=14,
                color="#1E293B",
            ),
        ),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        autosize=True,
        clickmode="event+select",
    )

    # Preserve the node order supplied by the caller.
    fig.update_traces(sort=False)

    return fig
|
||
|
||
|
||
def create_market_share_figure(data: list[dict], title: str = "") -> go.Figure:
    """
    Create a horizontal stacked bar chart of first-line drug market share by directorate.

    Each directorate is one row. Each drug is one trace whose segment length is
    that drug's share of the directorate's patients (``proportion`` x 100).

    Args:
        data: List of dicts from get_drug_market_share() with keys:
            directory, drug, patients, proportion, cost, cost_pp_pa
            Sorted by directory total patients desc, drugs desc within.
        title: Chart title suffix (filter description)

    Returns:
        Plotly Figure with one stacked horizontal bar per directorate, or an
        empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    # NHS palette for different drugs (cycled when there are more drugs than colours)
    nhs_colours = [
        "#003087", "#005EB8", "#0072CE", "#1E88E5", "#41B6E6",
        "#4FC3F7", "#768692", "#AE2573", "#006747", "#ED8B00",
        "#8A1538", "#330072", "#009639", "#DA291C", "#00A499",
    ]

    # Unique directorates and drugs in first-encountered order. dict.fromkeys
    # preserves insertion order with O(1) membership, avoiding repeated list scans.
    seen_dirs = list(dict.fromkeys(d["directory"] for d in data))
    seen_drugs = list(dict.fromkeys(d["drug"] for d in data))

    drug_colour_map = {drug: nhs_colours[i % len(nhs_colours)] for i, drug in enumerate(seen_drugs)}

    # Build a lookup: (directory, drug) -> row
    lookup = {(d["directory"], d["drug"]): d for d in data}

    # Plotly draws the first category of a horizontal bar chart at the bottom,
    # so reverse to put the directorate with the most patients at the top.
    display_dirs = list(reversed(seen_dirs))

    # Build one trace per drug. Every trace spans all directorates so the rows line up.
    traces = []
    for drug in seen_drugs:
        x_vals = []
        hover_texts = []
        for directory in display_dirs:
            row = lookup.get((directory, drug))
            if row:
                x_vals.append(row["proportion"] * 100)
                hover_texts.append(
                    f"<b>{drug}</b><br>"
                    f"{directory}<br>"
                    f"Patients: {row['patients']:,}<br>"
                    f"Share: {row['proportion']:.1%}<br>"
                    f"Cost: £{row['cost']:,.0f}<br>"
                    f"Cost p.p.p.a: £{row['cost_pp_pa']:,.0f}"
                )
            else:
                # Drug not used in this directorate: zero-width placeholder segment.
                x_vals.append(0)
                hover_texts.append("")

        traces.append(go.Bar(
            name=drug,
            y=list(display_dirs),
            x=x_vals,
            orientation="h",
            marker_color=drug_colour_map[drug],
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_texts,
        ))

    display_title = f"First-Line Drug Market Share — {title}" if title else "First-Line Drug Market Share"

    fig = go.Figure(data=traces)
    fig.update_layout(
        barmode="stack",
        title=dict(
            text=display_title,
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=18,
                color="#1E293B",
            ),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(
            title="% of patients",
            ticksuffix="%",
            range=[0, 105],
            gridcolor="#E2E8F0",
            zeroline=False,
        ),
        yaxis=dict(
            title="",
            automargin=True,
        ),
        legend=dict(
            title="Drug",
            orientation="h",
            yanchor="top",
            y=-0.15,
            xanchor="center",
            x=0.5,
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=11,
            ),
        ),
        margin=dict(t=50, l=8, r=24, b=100),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        autosize=True,
        hoverlabel=dict(
            bgcolor="#FFFFFF",
            bordercolor="#CBD5E1",
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=13,
                color="#1E293B",
            ),
        ),
        font=dict(
            family="Source Sans 3, system-ui, sans-serif",
        ),
        height=max(400, len(seen_dirs) * 60 + 200),
    )

    return fig
|
||
|
||
|
||
def create_cost_effectiveness_figure(
    data: list[dict],
    retention: dict,
    title: str = "",
) -> go.Figure:
    """
    Create horizontal lollipop chart showing pathway cost per patient per annum.

    Only pathways with a positive ``cost_pp_pa`` are plotted, capped at the 40
    most expensive. Dot colour bands the cost into green / amber / red thirds of
    the plotted range. Dot size scales linearly with patient count (8-30 px).

    Args:
        data: List of dicts from get_pathway_costs() with keys:
            ids, pathway_label, cost_pp_pa, patients, cost, avg_days,
            directory, trust_name, drug_sequence, level.
            Expected sorted by cost_pp_pa desc; the function sorts anyway.
        retention: Dict from calculate_retention_rate() mapping ids to retention
            info: {retained_patients, total_patients, retention_rate, drug_sequence}.
        title: Chart title suffix (filter description).

    Returns:
        Plotly Figure with horizontal lollipop dots and retention annotations,
        or an empty Figure when no pathway has a positive cost.
    """
    if not data:
        return go.Figure()

    # Filter to pathways with positive cost (a NULL cost is treated as zero).
    filtered = [d for d in data if (d["cost_pp_pa"] or 0) > 0]
    if not filtered:
        return go.Figure()

    # Cap to the top 40 pathways by cost to keep the chart readable. Sort
    # explicitly rather than trusting caller order. sorted() is stable (also
    # with reverse=True), so already-sorted input keeps its order.
    filtered = sorted(filtered, key=lambda d: d["cost_pp_pa"], reverse=True)[:40]

    # Reverse for horizontal chart: Plotly draws the first category at the bottom,
    # so this puts the highest cost at the top.
    filtered.reverse()

    # Plotly merges identical category strings onto a single row. The same drug
    # sequence can appear under several directorates or trusts, so key each row
    # by position and show the pathway label through ticktext instead.
    row_keys = [f"row-{i}" for i in range(len(filtered))]
    pathway_labels = [d["pathway_label"] for d in filtered]
    costs = [d["cost_pp_pa"] for d in filtered]
    patients = [d["patients"] for d in filtered]

    # Colour gradient: green (cheap) → amber → red (expensive)
    max_cost = max(costs)
    min_cost = min(costs)
    cost_range = max_cost - min_cost if max_cost != min_cost else 1

    colours = []
    for c in costs:
        ratio = (c - min_cost) / cost_range
        if ratio < 0.33:
            colours.append("#009639")  # NHS green
        elif ratio < 0.66:
            colours.append("#ED8B00")  # NHS warm yellow
        else:
            colours.append("#DA291C")  # NHS red

    # Dot size scaled by patient count (min 8, max 30)
    max_pts = max(patients)
    min_pts = min(patients)
    pts_range = max_pts - min_pts if max_pts != min_pts else 1
    sizes = [8 + (p - min_pts) / pts_range * 22 for p in patients]

    # Build hover text with retention info; keep each rate for the annotations.
    hover_texts = []
    retention_rates = []
    for d in filtered:
        retention_rate = retention.get(d["ids"], {}).get("retention_rate")
        retention_rates.append(retention_rate)
        drugs_in_seq = len(d["drug_sequence"])

        hover = (
            f"<b>{d['pathway_label']}</b><br>"
            f"Cost p.p.p.a.: £{d['cost_pp_pa']:,.0f}<br>"
            f"Patients: {d['patients']:,}<br>"
            f"Total cost: £{d['cost']:,.0f}<br>"
            f"Avg duration: {d['avg_days']:,.0f} days<br>"
            f"Directorate: {d['directory']}<br>"
            f"Treatment lines: {drugs_in_seq}"
        )
        if retention_rate is not None:
            hover += f"<br>Retention: {retention_rate:.0f}% (no further switch)"
        hover_texts.append(hover)

    # Lollipop sticks (horizontal lines from 0 to cost)
    stick_traces = [
        go.Scatter(
            x=[0, cost],
            y=[key, key],
            mode="lines",
            line=dict(color="#CBD5E1", width=1.5),
            showlegend=False,
            hoverinfo="skip",
        )
        for key, cost in zip(row_keys, costs)
    ]

    # Lollipop dots
    dot_trace = go.Scatter(
        x=costs,
        y=row_keys,
        mode="markers",
        marker=dict(
            size=sizes,
            color=colours,
            line=dict(color="#FFFFFF", width=1),
        ),
        hovertemplate="%{customdata}<extra></extra>",
        customdata=hover_texts,
        showlegend=False,
    )

    display_title = (
        f"Pathway Cost Effectiveness — {title}" if title
        else "Pathway Cost Effectiveness (£ per patient per annum)"
    )

    fig = go.Figure(data=stick_traces + [dot_trace])

    # Annotate up to 8 pathways (>= 10 patients) whose retention is below 90%.
    annotation_count = 0
    for key, d, rate in zip(row_keys, filtered, retention_rates):
        if rate is not None and rate < 90 and d["patients"] >= 10 and annotation_count < 8:
            fig.add_annotation(
                x=d["cost_pp_pa"],
                y=key,
                text=f"{rate:.0f}% retain",
                showarrow=False,
                xanchor="left",
                xshift=10,
                font=dict(size=10, color="#768692", family="Source Sans 3"),
            )
            annotation_count += 1

    fig.update_layout(
        title=dict(
            text=display_title,
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=18,
                color="#1E293B",
            ),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(
            title="£ per patient per annum",
            tickprefix="£",
            tickformat=",",
            gridcolor="#E2E8F0",
            zeroline=True,
            zerolinecolor="#CBD5E1",
        ),
        yaxis=dict(
            title="",
            automargin=True,
            tickfont=dict(size=11),
            # Unique row keys on the axis, human-readable pathway labels as ticks.
            type="category",
            tickmode="array",
            tickvals=row_keys,
            ticktext=pathway_labels,
        ),
        margin=dict(t=50, l=8, r=24, b=40),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        autosize=True,
        hoverlabel=dict(
            bgcolor="#FFFFFF",
            bordercolor="#CBD5E1",
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=13,
                color="#1E293B",
            ),
        ),
        font=dict(
            family="Source Sans 3, system-ui, sans-serif",
        ),
        height=max(450, len(filtered) * 28 + 150),
    )

    return fig
|
||
|
||
|
||
def create_cost_waterfall_figure(
    data: list[dict],
    title: str = "",
) -> go.Figure:
    """Create a bar chart of cost per patient by directorate/indication.

    Despite the name, this draws a plain bar chart rather than a ``go.Waterfall``.
    Each bar is one directorate's ``cost_pp``, annotated with its patient count.
    A dashed red line marks the patient-weighted average cost per patient
    (sum of total_cost / sum of patients).

    Args:
        data: List of dicts from get_cost_waterfall() with keys:
            directory, patients, total_cost, cost_pp.
            Sorted by cost_pp desc; bars are drawn in the given order.
        title: Chart title suffix (filter description).

    Returns:
        Plotly Figure with one bar per directorate and a weighted-average line,
        or an empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    labels = [d["directory"] for d in data]
    cost_pp_values = [d["cost_pp"] for d in data]
    patients_list = [d["patients"] for d in data]
    total_costs = [d["total_cost"] for d in data]

    # NHS colour palette for bars
    nhs_colours = [
        "#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
        "#4FC3F7", "#009639", "#ED8B00", "#768692", "#425563",
        "#DA291C", "#7C2855",
    ]

    # Assign colours cycling through palette
    bar_colours = [nhs_colours[i % len(nhs_colours)] for i in range(len(data))]

    hover_texts = [
        f"<b>{d['directory']}</b><br>"
        f"Cost per patient: £{d['cost_pp']:,.0f}<br>"
        f"Patients: {d['patients']:,}<br>"
        f"Total cost: £{d['total_cost']:,.0f}"
        for d in data
    ]

    # Use a standard bar chart (not go.Waterfall) for cleaner control
    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            x=labels,
            y=cost_pp_values,
            marker=dict(
                color=bar_colours,
                line=dict(color="#FFFFFF", width=1),
            ),
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_texts,
            text=[f"£{v:,.0f}" for v in cost_pp_values],
            textposition="outside",
            textfont=dict(size=11, color="#425563"),
        )
    )

    # Add patient count annotations below each bar
    for label, pts in zip(labels, patients_list):
        fig.add_annotation(
            x=label,
            y=0,
            text=f"n={pts:,}",
            showarrow=False,
            yshift=-18,
            font=dict(size=10, color="#768692", family="Source Sans 3"),
        )

    # Patient-weighted average line (data is non-empty here, so always drawn)
    total_patients = sum(patients_list)
    total_cost = sum(total_costs)
    weighted_avg = total_cost / total_patients if total_patients else 0
    fig.add_hline(
        y=weighted_avg,
        line_dash="dash",
        line_color="#DA291C",
        line_width=1.5,
        annotation_text=f"Weighted avg: £{weighted_avg:,.0f}",
        annotation_position="top right",
        annotation_font=dict(
            size=11, color="#DA291C", family="Source Sans 3"
        ),
    )

    display_title = (
        f"Cost per Patient by Directorate — {title}" if title
        else "Cost per Patient by Directorate"
    )

    fig.update_layout(
        title=dict(
            text=display_title,
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=18,
                color="#1E293B",
            ),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(
            title="",
            # Angle long category labels once there are more than a handful of bars.
            tickangle=-45 if len(data) > 6 else 0,
            tickfont=dict(size=11),
            automargin=True,
        ),
        yaxis=dict(
            title="£ per patient",
            tickprefix="£",
            tickformat=",",
            gridcolor="#E2E8F0",
            zeroline=True,
            zerolinecolor="#CBD5E1",
        ),
        margin=dict(t=60, l=8, r=24, b=40),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        autosize=True,
        showlegend=False,
        hoverlabel=dict(
            bgcolor="#FFFFFF",
            bordercolor="#CBD5E1",
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=13,
                color="#1E293B",
            ),
        ),
        font=dict(
            family="Source Sans 3, system-ui, sans-serif",
        ),
        # Fixed height; the previous max(450, 500) always evaluated to 500.
        height=500,
        bargap=0.25,
    )

    return fig
|
||
|
||
|
||
def create_sankey_figure(
    data: dict,
    title: str = "",
) -> go.Figure:
    """Build a Sankey diagram of drug-to-drug switching between treatment lines.

    Nodes that share a base drug name (ordinal suffix removed) share a colour.
    Each link is tinted with its source node's colour at 35% opacity.

    Args:
        data: Mapping from get_drug_transitions() containing:
            nodes: list of {"name": ...}; names carry an ordinal suffix,
                e.g. "ADALIMUMAB (1st)".
            links: list of {"source_idx", "target_idx", "patients"}, where the
                indices point into ``nodes``.
        title: Optional suffix appended to the chart title.

    Returns:
        Sankey Figure, or an empty Figure when there are no nodes or no links.
    """
    import re

    node_rows = data.get("nodes", [])
    link_rows = data.get("links", [])
    if not (node_rows and link_rows):
        return go.Figure()

    palette = [
        "#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
        "#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
        "#8A1538", "#330072", "#DA291C", "#00A499", "#425563",
    ]

    # Strip a trailing " (1st)" / " (2nd)" / ... so every line of a drug maps to one colour.
    ordinal_suffix = re.compile(r"\s*\(\d+(?:st|nd|rd|th)\)\s*$")
    bases = [ordinal_suffix.sub("", row["name"]) for row in node_rows]

    # dict.fromkeys keeps first-seen order, so colours are assigned in node order.
    colour_for_base = {
        base: palette[pos % len(palette)]
        for pos, base in enumerate(dict.fromkeys(bases))
    }
    node_colours = [colour_for_base[base] for base in bases]
    node_labels = [row["name"] for row in node_rows]

    def _translucent(hex_code: str, alpha: float) -> str:
        digits = hex_code.lstrip("#")
        red, green, blue = (int(digits[k:k + 2], 16) for k in (0, 2, 4))
        return f"rgba({red},{green},{blue},{alpha})"

    sources = [link["source_idx"] for link in link_rows]
    targets = [link["target_idx"] for link in link_rows]
    flows = [link["patients"] for link in link_rows]

    link_colours = [_translucent(node_colours[src], 0.35) for src in sources]
    link_hovers = [
        f"{node_labels[src]} → {node_labels[dst]}<br>Patients: {count:,}"
        for src, dst, count in zip(sources, targets, flows)
    ]

    # A node's displayed total is the larger of its outgoing and incoming flow,
    # which covers both first-line (no incoming) and terminal (no outgoing) nodes.
    outgoing = [0] * len(node_rows)
    incoming = [0] * len(node_rows)
    for src, dst, count in zip(sources, targets, flows):
        outgoing[src] += count
        incoming[dst] += count
    node_hover = [
        f"<b>{row['name']}</b><br>Patients: {max(out_total, in_total):,}"
        for row, out_total, in_total in zip(node_rows, outgoing, incoming)
    ]

    sankey = go.Sankey(
        arrangement="snap",
        node=dict(
            pad=20,
            thickness=25,
            line=dict(color="#FFFFFF", width=1),
            label=node_labels,
            color=node_colours,
            customdata=node_hover,
            hovertemplate="%{customdata}<extra></extra>",
        ),
        link=dict(
            source=sources,
            target=targets,
            value=flows,
            color=link_colours,
            customdata=link_hovers,
            hovertemplate="%{customdata}<extra></extra>",
        ),
    )
    fig = go.Figure(sankey)

    heading = f"Drug Switching Flows — {title}" if title else "Drug Switching Flows"
    base_font = "Source Sans 3, system-ui, sans-serif"

    fig.update_layout(
        title=dict(
            text=heading,
            font=dict(family=base_font, size=18, color="#003087"),
            x=0.5,
            xanchor="center",
        ),
        font=dict(family=base_font, size=12),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(t=60, l=30, r=30, b=30),
        height=max(500, len(colour_for_base) * 35 + 200),
    )

    return fig
|
||
|
||
|
||
def create_dosing_figure(
    data: list[dict],
    title: str = "",
    group_by: str = "drug",
) -> go.Figure:
    """Create dosing interval comparison chart.

    Shows weekly dosing intervals as horizontal bars, grouped either by drug
    (overview mode) or by trust (single-drug comparison mode).

    Args:
        data: List of dicts from get_dosing_intervals() with keys:
            drug, trust_name, directory, weekly_interval, dose_count,
            total_weeks, patients.
        title: Chart title suffix (filter description).
        group_by: "drug" for drug-level overview (default),
            "trust" for per-trust comparison of a single drug.
            Any value other than "trust" selects the drug overview.

    Returns:
        Plotly Figure with horizontal bars: one bar per drug in overview mode,
        or one row per trust with a bar per directorate in trust mode. An empty
        Figure is returned when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    nhs_colours = [
        "#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
        "#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
        "#8A1538", "#330072", "#DA291C", "#00A499", "#425563",
    ]

    if group_by == "trust":
        # Single-drug mode: compare trusts, group bars by directory
        fig = _dosing_by_trust(data, nhs_colours)
        chart_title = "Dosing Intervals by Trust"
    else:
        # Overview mode: weighted average per drug
        fig = _dosing_by_drug(data, nhs_colours)
        chart_title = "Dosing Interval Overview"

    if title:
        chart_title = f"{chart_title} — {title}"

    # Size the chart by distinct rows across ALL traces. In trust mode each
    # directorate trace may cover a different subset of trusts, so the first
    # trace alone can undercount.
    row_labels = {label for trace in fig.data for label in (trace.y or ())}
    n_rows = len(row_labels) if row_labels else 10

    fig.update_layout(
        title=dict(
            text=chart_title,
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=18,
                color="#003087",
            ),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(
            # Nested title.font instead of the deprecated `titlefont`, which
            # Plotly 6 removed (it raises ValueError there).
            title=dict(
                text="Weekly Interval (weeks between doses)",
                font=dict(size=13, color="#425563"),
            ),
            gridcolor="rgba(66,85,99,0.1)",
            zeroline=True,
            zerolinecolor="rgba(66,85,99,0.2)",
        ),
        yaxis=dict(
            automargin=True,
            tickfont=dict(size=11),
        ),
        font=dict(
            family="Source Sans 3, system-ui, sans-serif",
            size=12,
        ),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(t=60, l=20, r=40, b=60),
        height=max(450, n_rows * 40 + 150),
        bargap=0.15,
        bargroupgap=0.05,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.12,
            xanchor="center",
            x=0.5,
            font=dict(size=11),
        ),
    )

    return fig
|
||
|
||
|
||
def _dosing_by_drug(data: list[dict], colours: list[str]) -> go.Figure:
    """Build the dosing overview: one horizontal bar per drug.

    Rows for the same drug (across all trusts and directorates) are pooled.
    Each bar shows the patient-weighted mean ``weekly_interval``; the hover
    text adds the weighted mean dose count and treatment length in weeks.
    A missing patient count is treated as zero weight.

    Args:
        data: Dosing rows with keys drug, weekly_interval, dose_count,
            total_weeks and patients.
        colours: Not used by this builder (bars are shaded on a blue gradient
            by interval length); accepted so it is called like _dosing_by_trust.

    Returns:
        Figure with a single horizontal bar trace plus "n=" patient-count
        annotations to the right of the bars.
    """
    # Aggregate patient-weighted sums per drug
    drug_agg = {}
    for d in data:
        drug = d["drug"]
        pts = d["patients"] or 0
        if drug not in drug_agg:
            drug_agg[drug] = {"weighted_sum": 0.0, "total_patients": 0,
                              "dose_count_ws": 0.0, "total_weeks_ws": 0.0}
        drug_agg[drug]["weighted_sum"] += d["weekly_interval"] * pts
        drug_agg[drug]["total_patients"] += pts
        drug_agg[drug]["dose_count_ws"] += d["dose_count"] * pts
        drug_agg[drug]["total_weeks_ws"] += d["total_weeks"] * pts

    # Sort ASCENDING by total patients. Plotly draws the first category of a
    # horizontal bar chart at the bottom, so the most-treated drug ends up on top.
    drugs_sorted = sorted(
        drug_agg.items(),
        key=lambda x: x[1]["total_patients"],
    )

    drug_names = [d[0] for d in drugs_sorted]
    intervals = []
    patients_list = []
    hover_texts = []

    for drug, agg in drugs_sorted:
        tp = agg["total_patients"]
        avg_interval = agg["weighted_sum"] / tp if tp > 0 else 0
        avg_doses = agg["dose_count_ws"] / tp if tp > 0 else 0
        avg_weeks = agg["total_weeks_ws"] / tp if tp > 0 else 0
        intervals.append(round(avg_interval, 1))
        patients_list.append(tp)
        hover_texts.append(
            f"<b>{drug}</b><br>"
            f"Avg interval: {avg_interval:.1f} weeks<br>"
            f"Avg doses: {avg_doses:.1f}<br>"
            f"Avg treatment: {avg_weeks:.0f} weeks<br>"
            f"Patients: {tp:,}"
        )

    # Colour bars by interval: lower = more frequent dosing = NHS blue, higher = lighter
    max_interval = max(intervals) if intervals else 1
    bar_colours = []
    for iv in intervals:
        ratio = iv / max_interval if max_interval > 0 else 0
        # Interpolate NHS blue (#005EB8) to light blue (#41B6E6)
        r = int(0x00 + (0x41 - 0x00) * ratio)
        g = int(0x5E + (0xB6 - 0x5E) * ratio)
        b = int(0xB8 + (0xE6 - 0xB8) * ratio)
        bar_colours.append(f"rgb({r},{g},{b})")

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=drug_names,
        x=intervals,
        orientation="h",
        marker=dict(color=bar_colours, line=dict(color="#FFFFFF", width=0.5)),
        text=[f"{iv}w" for iv in intervals],
        textposition="outside",
        textfont=dict(size=10, color="#425563"),
        customdata=list(zip(hover_texts, patients_list)),
        hovertemplate="%{customdata[0]}<extra></extra>",
        name="Weighted Avg Interval",
        showlegend=False,
    ))

    # Patient-count annotations in a column just right of the longest bar
    # (x position is loop-invariant, so compute it once).
    label_x = max(intervals) * 1.15 if intervals else 10
    for drug, pts in zip(drug_names, patients_list):
        fig.add_annotation(
            x=label_x,
            y=drug,
            text=f"n={pts:,}",
            showarrow=False,
            font=dict(size=9, color="#768692"),
            xanchor="left",
        )

    return fig
|
||
|
||
|
||
def _dosing_by_trust(data: list[dict], colours: list[str]) -> go.Figure:
    """Build the per-trust dosing comparison as grouped horizontal bars.

    Each directorate becomes one bar trace (one legend entry). Trusts form the
    rows, labelled with " NHS FOUNDATION TRUST" and " HOSPITALS" removed. A
    bar's length is the patient-weighted mean ``weekly_interval`` for that
    trust/directorate pair. Pairs with no rows or zero total patients are
    omitted, and a directorate with no surviving pairs gets no trace.

    Args:
        data: Dosing rows with keys trust_name, directory, weekly_interval,
            dose_count, total_weeks and patients.
        colours: Palette cycled across directorates.

    Returns:
        Figure with ``barmode="group"``.
    """
    from collections import defaultdict

    grouped = defaultdict(list)
    for row in data:
        grouped[(row["trust_name"], row["directory"])].append(row)

    trust_names = sorted({row["trust_name"] for row in data})
    directorates = sorted({row["directory"] for row in data})

    fig = go.Figure()

    for idx, directorate in enumerate(directorates):
        row_labels = []
        bar_lengths = []
        hovers = []

        for trust in trust_names:
            entries = grouped.get((trust, directorate))
            if not entries:
                continue

            # Weighted mean across entries (normally one per trust+directory).
            interval_total = sum(e["weekly_interval"] * (e["patients"] or 0) for e in entries)
            patient_total = sum(e["patients"] or 0 for e in entries)
            if patient_total == 0:
                continue
            mean_interval = interval_total / patient_total
            mean_doses = sum(e["dose_count"] * (e["patients"] or 0) for e in entries) / patient_total
            mean_weeks = sum(e["total_weeks"] * (e["patients"] or 0) for e in entries) / patient_total

            label = trust.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")
            row_labels.append(label)
            bar_lengths.append(round(mean_interval, 1))
            hovers.append(
                f"<b>{label}</b><br>"
                f"Directorate: {directorate}<br>"
                f"Interval: {mean_interval:.1f} weeks<br>"
                f"Avg doses: {mean_doses:.1f}<br>"
                f"Treatment: {mean_weeks:.0f} weeks<br>"
                f"Patients: {patient_total:,}"
            )

        if not row_labels:
            continue

        fig.add_trace(go.Bar(
            y=row_labels,
            x=bar_lengths,
            orientation="h",
            name=directorate,
            marker=dict(color=colours[idx % len(colours)]),
            customdata=hovers,
            hovertemplate="%{customdata}<extra></extra>",
        ))

    fig.update_layout(barmode="group")
    return fig
|
||
|
||
|
||
def save_figure_html(
    fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
) -> str:
    """
    Write a Plotly figure to ``<save_dir>/<title>.html``.

    The path is built by plain string formatting with a "/" separator; the
    directory is not created here.

    Args:
        fig: Plotly Figure object, written with ``fig.write_html``.
        save_dir: Directory to save the HTML file in.
        title: Base filename (``.html`` is appended).
        open_browser: If True, pass the written file to open_figure_in_browser().

    Returns:
        The path the figure was written to.
    """
    html_path = f"{save_dir}/{title}.html"
    fig.write_html(html_path)
    logger.info(f"Success! File saved to {html_path}")

    if not open_browser:
        return html_path

    open_figure_in_browser(html_path)
    return html_path
|
||
|
||
|
||
def open_figure_in_browser(filepath: str) -> None:
    """
    Open a local HTML file in a new tab of the default browser.

    The path is resolved to an absolute path and converted to a proper
    ``file://`` URI. This means relative paths (e.g. ``"output/chart.html"``),
    paths containing spaces and Windows drive paths all produce a valid URL.

    Args:
        filepath: Path to the HTML file, absolute or relative to the current
            working directory.
    """
    from pathlib import Path

    # Path.as_uri() needs an absolute path and percent-encodes characters such
    # as spaces. The previous "file:///" + filepath pointed relative paths at
    # the filesystem root and produced "file:////..." for absolute POSIX paths.
    uri = Path(filepath).resolve().as_uri()
    webbrowser.open_new_tab(uri)
|
||
|
||
|
||
def figure_legacy(ice_df: pd.DataFrame, dir_string: str, save_dir: str) -> None:
    """
    Create, save and open an icicle figure (legacy interface).

    This function keeps backward compatibility with the original figure()
    signature. It builds the icicle chart, writes it to
    ``<save_dir>/<dir_string>.html`` via save_figure_html(), and opens the
    file in a new browser tab.

    Args:
        ice_df: DataFrame with chart data. Columns read: ``labels``, ``ids``,
            ``parents``, ``value``, ``colour``, ``cost``, ``costpp``,
            ``cost_pp_pa``, ``average_spacing``, ``First seen``,
            ``Last seen``, ``First seen (Parent)`` and ``Last seen (Parent)``.
            The caller's DataFrame is not modified.
        dir_string: Title string (used for filename and chart title).
        save_dir: Directory to save the HTML file.

    Note:
        This function is provided for backward compatibility.
        New code should use create_icicle_figure() + save_figure_html() instead.
    """
    # sort_values returns a new frame, so the caller's DataFrame is untouched.
    ice_df = ice_df.sort_values(by=["labels"], ascending=True, ignore_index=True)

    def _display_dates(column: str) -> list[str]:
        # Missing dates stringify as "NaT"; show "N/A" instead, for the child
        # and parent date columns alike.
        return ice_df[column].astype(str).replace("NaT", "N/A").to_list()

    first_seen = _display_dates("First seen")
    last_seen = _display_dates("Last seen")
    first_seen_parent = _display_dates("First seen (Parent)")
    last_seen_parent = _display_dates("Last seen (Parent)")
    average_spacing = ice_df.average_spacing.astype(str).to_list()

    # customdata column indices used by the templates below:
    # 0 value, 1 colour, 2 cost, 3 costpp, 4 first seen, 5 last seen,
    # 6 first seen (parent), 7 last seen (parent), 8 average spacing,
    # 9 cost per patient per annum.
    customdata = np.stack(
        (
            ice_df.value,
            ice_df.colour,
            ice_df.cost,
            ice_df.costpp,
            first_seen,
            last_seen,
            first_seen_parent,
            last_seen_parent,
            average_spacing,
            ice_df.cost_pp_pa,
        ),
        axis=1,
    )

    fig = go.Figure(
        go.Icicle(
            labels=ice_df.labels,
            ids=ice_df.ids,
            parents=ice_df.parents,
            customdata=customdata,
            values=ice_df.value,
            branchvalues="total",
            marker=dict(colors=ice_df.colour, colorscale="Viridis"),
            maxdepth=3,
            texttemplate="<b>%{label}</b> "
            "<br><b>Total patients:</b> %{customdata[0]} (including children/further treatments)"
            "<br><b>First seen:</b> %{customdata[4]}"
            "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
            "<br><b>Average treatment duration:</b> %{customdata[8]}"
            "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
            "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
            "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}",
            hovertemplate="<b>%{label}</b>"
            "<br><b>Total patients:</b> %{customdata[0]} - %{customdata[1]:.3p} of patients in level"
            "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
            "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
            "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}"
            "<br><b>First seen:</b> %{customdata[4]}"
            "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
            # Space after the label, matching every other line of the tooltip.
            "<br><b>Average treatment duration:</b> "
            "%{customdata[8]}"
            "<extra></extra>",
        )
    )
    fig.update_traces(sort=False)
    fig.update_layout(
        margin=dict(t=60, l=1, r=1, b=60),
        title=f"Norfolk & Waveney ICS high-cost drug patient pathways - {dir_string}",
        title_x=0.5,
        hoverlabel=dict(font_size=16),
    )

    # Writes <save_dir>/<dir_string>.html, logs the path and opens it in a browser tab.
    save_figure_html(fig, save_dir, dir_string, open_browser=True)
|
||
|
||
|
||
def create_heatmap_figure(
    data: dict,
    title: str = "",
    metric: str = "patients",
) -> go.Figure:
    """Build an annotated directorate × drug heatmap.

    Rows are directorates and columns are drugs (only the first 25 entries of
    ``data["drugs"]``). Cell colour encodes ``metric``.

    Args:
        data: Mapping from get_drug_directory_matrix() with keys
            ``directories`` (list), ``drugs`` (list) and ``matrix``
            (``{directory: {drug: {patients, cost, cost_pp_pa}}}``).
        title: Optional filter description appended to the chart title.
        metric: Colour metric — "patients", "cost" or "cost_pp_pa". An
            unrecognised metric is labelled "Patients". A cell lacking the
            metric key falls back to its patient count.

    Returns:
        Plotly Figure with a single Heatmap trace, or an empty Figure when
        there are no directorates or no drugs.
    """
    row_names = data.get("directories", [])
    col_names = data.get("drugs", [])
    cells_by_row = data.get("matrix", {})

    if not (row_names and col_names):
        return go.Figure()

    # Only the first 25 drugs are shown, so column labels stay legible.
    col_names = col_names[:25]

    metric_label = {
        "patients": "Patients",
        "cost": "Total Cost (£)",
        "cost_pp_pa": "Cost per Patient p.a. (£)",
    }.get(metric, "Patients")

    z_grid = []
    hover_grid = []
    for directorate in row_names:
        row_cells = cells_by_row.get(directorate, {})
        z_row = []
        hover_row = []
        for drug in col_names:
            cell = row_cells.get(drug)
            if not cell:
                z_row.append(0)
                hover_row.append(f"<b>{drug}</b><br>{directorate}<br>No patients")
                continue
            # Falsy values (None, 0) are plotted as 0.
            z_row.append(cell.get(metric, cell.get("patients", 0)) or 0)
            hover_row.append(
                f"<b>{drug}</b><br>"
                f"{directorate}<br>"
                f"Patients: {cell.get('patients', 0):,}<br>"
                f"Total cost: £{cell.get('cost', 0):,.0f}<br>"
                f"Cost p.a.: £{cell.get('cost_pp_pa', 0):,.0f}"
            )
        z_grid.append(z_row)
        hover_grid.append(hover_row)

    # NHS-blue ramp; the tiny first stop keeps zero cells visibly paler.
    nhs_blues = [
        [0.0, "#F0F4F8"],
        [0.01, "#E3F2FD"],
        [0.1, "#90CAF9"],
        [0.3, "#42A5F5"],
        [0.5, "#1E88E5"],
        [0.7, "#0066CC"],
        [1.0, "#003087"],
    ]

    heatmap = go.Heatmap(
        z=z_grid,
        x=col_names,
        y=row_names,
        colorscale=nhs_blues,
        hovertext=hover_grid,
        hovertemplate="%{hovertext}<extra></extra>",
        colorbar=dict(
            title=dict(text=metric_label, font=dict(size=12, color="#425563")),
            thickness=15,
            len=0.8,
        ),
        xgap=2,
        ygap=2,
    )
    fig = go.Figure(data=heatmap)

    base_title = f"Directorate × Drug — {metric_label}"
    chart_title = f"{base_title} — {title}" if title else base_title

    font_family = "Source Sans 3, system-ui, sans-serif"
    axis_colour = "#425563"

    fig.update_layout(
        title=dict(
            text=chart_title,
            font=dict(family=font_family, size=18, color="#003087"),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(
            title="",
            tickfont=dict(size=11, color=axis_colour),
            tickangle=-45,
            side="bottom",
        ),
        yaxis=dict(
            title="",
            tickfont=dict(size=12, color=axis_colour),
            autorange="reversed",
        ),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(family=font_family),
        margin=dict(t=60, l=200, r=80, b=120),
        # Grow with the grid so cells keep a readable size.
        width=max(700, 80 + len(col_names) * 55),
        height=max(400, 80 + len(row_names) * 40),
    )

    return fig
|
||
|
||
|
||
def create_duration_figure(
    data: list[dict],
    title: str = "",
    show_directory: bool = False,
) -> go.Figure:
    """Create horizontal bar chart showing average treatment duration by drug.

    Args:
        data: List of dicts from get_treatment_durations() with keys:
            drug, directory, avg_days, patients.
            Sorted by avg_days desc. The list is not modified.
        title: Chart title suffix (filter description).
        show_directory: If True, draw one bar per input row labelled
            "drug (directory)" (overview mode). If False, merge rows for the
            same drug into a patient-weighted mean across directorates.

    Returns:
        Plotly Figure with horizontal bars coloured by patient count, or an
        empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    if show_directory:
        # Shallow copy: the cap below re-sorts and must not reorder the caller's list.
        rows = list(data)
    else:
        # Aggregate the same drug across directorates (patient-weighted mean days).
        agg = {}
        for d in data:
            drug = d["drug"]
            pts = d["patients"]
            days = d["avg_days"]
            if drug not in agg:
                agg[drug] = {"drug": drug, "total_weighted": 0.0, "total_pts": 0}
            agg[drug]["total_weighted"] += days * pts
            agg[drug]["total_pts"] += pts
        rows = [
            {
                "drug": v["drug"],
                "avg_days": round(v["total_weighted"] / v["total_pts"], 1),
                "patients": v["total_pts"],
            }
            for v in agg.values()
            if v["total_pts"] > 0
        ]
        rows.sort(key=lambda x: -x["avg_days"])

    # Cap the bar count for readability: keep the highest patient counts,
    # then restore longest-duration-first ordering.
    max_bars = 40
    if len(rows) > max_bars:
        rows = sorted(rows, key=lambda x: -x["patients"])[:max_bars]
        rows.sort(key=lambda x: -x["avg_days"])

    if show_directory:
        labels = [f"{d['drug']} ({d['directory']})" for d in rows]
    else:
        labels = [d["drug"] for d in rows]

    days_values = [d["avg_days"] for d in rows]
    patients_list = [d["patients"] for d in rows]

    # Colour gradient by patient count: light blue (#41B6E6) for the fewest
    # patients through to dark NHS blue (#003087) for the most.
    max_pts = max(patients_list) if patients_list else 1
    min_pts = min(patients_list) if patients_list else 0
    pt_range = max_pts - min_pts if max_pts > min_pts else 1

    bar_colours = []
    for pts in patients_list:
        t = (pts - min_pts) / pt_range
        r = int(0x41 + (0x00 - 0x41) * t)
        g = int(0xB6 + (0x30 - 0xB6) * t)
        b = int(0xE6 + (0x87 - 0xE6) * t)
        bar_colours.append(f"rgb({r},{g},{b})")

    hover_texts = []
    for d in rows:
        years = d["avg_days"] / 365.25
        hover_texts.append(
            f"<b>{d['drug']}</b><br>"
            f"Avg duration: {d['avg_days']:,.0f} days ({years:.1f} years)<br>"
            f"Patients: {d['patients']:,}"
        )

    fig = go.Figure()
    fig.add_trace(
        go.Bar(
            y=labels,
            x=days_values,
            orientation="h",
            marker=dict(
                color=bar_colours,
                line=dict(color="#FFFFFF", width=1),
            ),
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_texts,
            text=[f"{v:,.0f}d" for v in days_values],
            textposition="outside",
            textfont=dict(size=10, color="#425563"),
        )
    )

    # Patient counts sit just beyond each bar's outside text label.
    for label, days, pts in zip(labels, days_values, patients_list):
        fig.add_annotation(
            x=days,
            y=label,
            text=f"n={pts:,}",
            showarrow=False,
            xshift=45,
            font=dict(size=9, color="#768692", family="Source Sans 3"),
        )

    chart_title = "Treatment Duration by Drug"
    if title:
        chart_title += f"<br><span style='font-size:13px;color:#768692'>{title}</span>"

    fig_height = max(400, 40 + len(rows) * 28)

    fig.update_layout(
        title=dict(
            text=chart_title,
            font=dict(
                family="Source Sans 3, system-ui, sans-serif",
                size=18,
                color="#003087",
            ),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(
            # ``titlefont`` is deprecated and rejected by newer Plotly releases;
            # ``title.font`` is the supported form.
            title=dict(
                text="Average Duration (days)",
                font=dict(size=13, color="#425563"),
            ),
            tickfont=dict(size=11, color="#425563"),
            gridcolor="rgba(0,0,0,0.06)",
            zeroline=True,
            zerolinecolor="rgba(0,0,0,0.1)",
        ),
        yaxis=dict(
            title="",
            tickfont=dict(size=11, color="#425563"),
            autorange="reversed",
        ),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(family="Source Sans 3, system-ui, sans-serif"),
        margin=dict(t=60, l=200, r=80, b=50),
        height=fig_height,
        showlegend=False,
    )

    return fig
|
||
|
||
|
||
# --- Trust Comparison chart functions ---
|
||
|
||
|
||
def create_trust_market_share_figure(
    data: list[dict],
    title: str = "",
) -> go.Figure:
    """Plot each trust's drug mix as a horizontal stacked bar.

    Unlike create_market_share_figure (which groups by directorate), this has
    one bar per trust within a single directorate. It is used by the Trust
    Comparison dashboard. Each drug is its own trace, and segment length is
    the drug's share of the trust's patients, in percent.

    Args:
        data: Rows from get_trust_market_share() with keys trust_name, drug,
            patients, proportion (a fraction; multiplied by 100 for the x
            axis), cost and cost_pp_pa.
        title: Optional suffix appended to the chart title after an em dash.

    Returns:
        Plotly Figure, or an empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    palette = [
        "#003087", "#005EB8", "#0072CE", "#1E88E5", "#41B6E6",
        "#4FC3F7", "#768692", "#AE2573", "#006747", "#ED8B00",
        "#8A1538", "#330072", "#009639", "#DA291C", "#00A499",
    ]

    # Ordered de-duplication: dict keys keep first-seen order.
    trust_order = list(dict.fromkeys(row["trust_name"] for row in data))
    drug_order = list(dict.fromkeys(row["drug"] for row in data))

    colour_by_drug = {drug: palette[pos % len(palette)] for pos, drug in enumerate(drug_order)}
    row_by_pair = {(row["trust_name"], row["drug"]): row for row in data}

    def abbreviate(name):
        # Drop common suffixes so trust labels fit on the axis.
        return name.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")

    # Reversed so the first-seen trust is plotted at the top of the y axis.
    plotted_trusts = trust_order[::-1]

    traces = []
    for drug in drug_order:
        shares = []
        tooltips = []
        for trust in plotted_trusts:
            row = row_by_pair.get((trust, drug))
            if not row:
                # Drug not used at this trust: zero-width segment, blank hover.
                shares.append(0)
                tooltips.append("")
                continue
            shares.append(row["proportion"] * 100)
            tooltips.append(
                f"<b>{drug}</b><br>"
                f"{abbreviate(trust)}<br>"
                f"Patients: {row['patients']:,}<br>"
                f"Share: {row['proportion']:.1%}<br>"
                f"Cost: \u00a3{row['cost']:,.0f}<br>"
                f"Cost p.p.p.a: \u00a3{row['cost_pp_pa']:,.0f}"
            )

        traces.append(go.Bar(
            name=drug,
            orientation="h",
            y=[abbreviate(trust) for trust in plotted_trusts],
            x=shares,
            marker_color=colour_by_drug[drug],
            customdata=tooltips,
            hovertemplate="%{customdata}<extra></extra>",
        ))

    heading = "Drug Market Share by Trust"
    if title:
        heading = f"{heading} \u2014 {title}"

    font_family = "Source Sans 3, system-ui, sans-serif"

    fig = go.Figure(data=traces)
    fig.update_layout(
        barmode="stack",
        title=dict(
            text=heading,
            font=dict(family=font_family, size=16, color="#1E293B"),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(
            title="% of patients",
            ticksuffix="%",
            range=[0, 105],
            gridcolor="#E2E8F0",
            zeroline=False,
        ),
        yaxis=dict(title="", automargin=True),
        legend=dict(
            title="Drug",
            orientation="h",
            yanchor="top",
            y=-0.15,
            xanchor="center",
            x=0.5,
            font=dict(family="Source Sans 3", size=11),
        ),
        margin=dict(t=50, l=8, r=24, b=100),
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        autosize=True,
        hoverlabel=dict(
            bgcolor="#FFFFFF",
            bordercolor="#CBD5E1",
            font=dict(family=font_family, size=13, color="#1E293B"),
        ),
        font=dict(family=font_family),
        height=max(300, len(trust_order) * 60 + 200),
    )

    return fig
|
||
|
||
|
||
def create_trust_heatmap_figure(
    data: dict,
    title: str = "",
    metric: str = "patients",
) -> go.Figure:
    """Build a trust × drug heatmap for a single directorate.

    Rows are trusts (labels shortened by dropping common suffixes) and columns
    are drugs (only the first 25 entries of ``data["drugs"]``).

    Args:
        data: Mapping from get_trust_heatmap() with keys ``trusts`` (list),
            ``drugs`` (list) and ``matrix``
            (``{trust_name: {drug: {patients, cost, cost_pp_pa}}}``).
        title: Optional suffix appended to the chart title.
        metric: Colour metric — "patients", "cost" or "cost_pp_pa". An
            unrecognised metric is labelled "Patients". A cell lacking the
            metric key falls back to its patient count.

    Returns:
        Plotly Figure with a single Heatmap trace, or an empty Figure when
        there are no trusts or no drugs.
    """
    trust_names = data.get("trusts", [])
    drug_names = data.get("drugs", [])
    cells = data.get("matrix", {})

    if not (trust_names and drug_names):
        return go.Figure()

    # Only the first 25 drugs are shown, so column labels stay legible.
    drug_names = drug_names[:25]

    metric_label = {
        "patients": "Patients",
        "cost": "Total Cost (\u00a3)",
        "cost_pp_pa": "Cost per Patient p.a. (\u00a3)",
    }.get(metric, "Patients")

    def abbreviate(name):
        return name.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")

    z_grid = []
    hover_grid = []
    for trust in trust_names:
        trust_cells = cells.get(trust, {})
        label = abbreviate(trust)
        z_row = []
        hover_row = []
        for drug in drug_names:
            cell = trust_cells.get(drug)
            if not cell:
                z_row.append(0)
                hover_row.append(f"<b>{drug}</b><br>{label}<br>No patients")
                continue
            # Falsy values (None, 0) are plotted as 0.
            z_row.append(cell.get(metric, cell.get("patients", 0)) or 0)
            hover_row.append(
                f"<b>{drug}</b><br>"
                f"{label}<br>"
                f"Patients: {cell.get('patients', 0):,}<br>"
                f"Total cost: \u00a3{cell.get('cost', 0):,.0f}<br>"
                f"Cost p.a.: \u00a3{cell.get('cost_pp_pa', 0):,.0f}"
            )
        z_grid.append(z_row)
        hover_grid.append(hover_row)

    nhs_blues = [
        [0.0, "#F0F4F8"], [0.01, "#E3F2FD"], [0.1, "#90CAF9"],
        [0.3, "#42A5F5"], [0.5, "#1E88E5"], [0.7, "#0066CC"], [1.0, "#003087"],
    ]

    heatmap = go.Heatmap(
        z=z_grid,
        x=drug_names,
        y=[abbreviate(trust) for trust in trust_names],
        colorscale=nhs_blues,
        hovertext=hover_grid,
        hovertemplate="%{hovertext}<extra></extra>",
        colorbar=dict(
            title=dict(text=metric_label, font=dict(size=12, color="#425563")),
            thickness=15,
            len=0.8,
        ),
        xgap=2,
        ygap=2,
    )
    fig = go.Figure(data=heatmap)

    heading = f"Trust \u00d7 Drug \u2014 {metric_label}"
    if title:
        heading = f"{heading} \u2014 {title}"

    font_family = "Source Sans 3, system-ui, sans-serif"

    fig.update_layout(
        title=dict(
            text=heading,
            font=dict(family=font_family, size=16, color="#003087"),
            x=0.5,
            xanchor="center",
        ),
        xaxis=dict(title="", tickfont=dict(size=11, color="#425563"), tickangle=-45, side="bottom"),
        yaxis=dict(title="", tickfont=dict(size=12, color="#425563"), autorange="reversed"),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(family=font_family),
        margin=dict(t=60, l=200, r=80, b=120),
        # Grow with the grid so cells keep a readable size.
        width=max(700, 80 + len(drug_names) * 55),
        height=max(300, 80 + len(trust_names) * 50),
    )

    return fig
|
||
|
||
|
||
def create_trust_duration_figure(
    data: list[dict],
    title: str = "",
) -> go.Figure:
    """Create grouped horizontal bar chart showing drug durations by trust.

    There is one bar trace per trust (in first-seen order), and each trace has
    one bar per drug. Drug/trust pairs absent from ``data`` are drawn as
    zero-length bars with an empty hover text.

    Args:
        data: List of dicts from get_trust_durations() with keys:
            drug, trust_name, avg_days, patients.
        title: Chart title suffix.

    Returns:
        Plotly Figure, or an empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    nhs_colours = [
        "#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
        "#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
    ]

    # Ordered de-duplication in O(n): dict keys keep first-seen order.
    seen_drugs = list(dict.fromkeys(d["drug"] for d in data))
    seen_trusts = list(dict.fromkeys(d["trust_name"] for d in data))

    def short_trust(name):
        # Drop common suffixes so legend/hover labels stay short.
        return name.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")

    trust_colour_map = {t: nhs_colours[i % len(nhs_colours)] for i, t in enumerate(seen_trusts)}
    lookup = {(d["drug"], d["trust_name"]): d for d in data}

    # Reversed so the first-seen drug is plotted at the top of the y axis.
    display_drugs = list(reversed(seen_drugs))

    traces = []
    for trust in seen_trusts:
        x_vals = []
        hover_texts = []
        for drug in display_drugs:
            row = lookup.get((drug, trust))
            if row:
                years = row["avg_days"] / 365.25
                x_vals.append(row["avg_days"])
                hover_texts.append(
                    f"<b>{drug}</b><br>"
                    f"{short_trust(trust)}<br>"
                    f"Avg duration: {row['avg_days']:,.0f} days ({years:.1f} yrs)<br>"
                    f"Patients: {row['patients']:,}"
                )
            else:
                x_vals.append(0)
                hover_texts.append("")

        traces.append(go.Bar(
            name=short_trust(trust), y=list(display_drugs), x=x_vals, orientation="h",
            marker_color=trust_colour_map[trust],
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_texts,
        ))

    display_title = f"Treatment Duration by Trust \u2014 {title}" if title else "Treatment Duration by Trust"

    fig = go.Figure(data=traces)
    fig.update_layout(
        barmode="group",
        title=dict(
            text=display_title,
            font=dict(family="Source Sans 3, system-ui, sans-serif", size=16, color="#003087"),
            x=0.5, xanchor="center",
        ),
        xaxis=dict(
            # ``titlefont`` is deprecated and rejected by newer Plotly releases;
            # ``title.font`` is the supported form.
            title=dict(text="Average Duration (days)", font=dict(size=13, color="#425563")),
            gridcolor="rgba(0,0,0,0.06)", zeroline=True, zerolinecolor="rgba(0,0,0,0.1)",
        ),
        yaxis=dict(title="", automargin=True, tickfont=dict(size=11, color="#425563")),
        legend=dict(
            title="Trust", orientation="h", yanchor="top", y=-0.12,
            xanchor="center", x=0.5, font=dict(size=11),
        ),
        plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)",
        font=dict(family="Source Sans 3, system-ui, sans-serif"),
        margin=dict(t=60, l=200, r=40, b=100),
        height=max(350, len(seen_drugs) * 35 + 200),
        bargap=0.15, bargroupgap=0.05,
    )

    return fig
|