2398 lines
77 KiB
Python
2398 lines
77 KiB
Python
"""
|
||
Plotly chart generation for patient pathway analysis.
|
||
|
||
This module contains functions for creating interactive icicle charts
|
||
that visualize patient treatment pathways. The charts display hierarchical
|
||
data: Trust → Directory → Drug → Pathway.
|
||
"""
|
||
|
||
import webbrowser
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import plotly.graph_objects as go
|
||
|
||
from core.logging_config import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# ---------------------------------------------------------------------------
# Shared styling constants
# ---------------------------------------------------------------------------

# Font stack used for titles, legends, hover labels and annotations.
CHART_FONT_FAMILY = "Source Sans 3, system-ui, sans-serif"
# Font size of the chart title (see _base_layout).
CHART_TITLE_SIZE = 18
# Title text colour; _base_layout also reuses it for hover-label text.
CHART_TITLE_COLOR = "#1E293B"
# Axis gridline colour.
GRID_COLOR = "#E2E8F0"
# Colour of small secondary annotations such as "n=…" patient counts.
ANNOTATION_COLOR = "#768692"

# 7 maximally-distinct colours for trust-comparison charts
TRUST_PALETTE = [
    "#005EB8",  # NHS Blue
    "#DA291C",  # Red
    "#009639",  # Green
    "#ED8B00",  # Orange
    "#7C2855",  # Plum
    "#00A499",  # Teal
    "#330072",  # Purple
]

# 15 distinct colours for drug-level charts.
# Callers index this with ``i % len(DRUG_PALETTE)``, so colours repeat once
# more than 15 series are plotted.
DRUG_PALETTE = [
    "#005EB8", "#DA291C", "#009639", "#ED8B00", "#7C2855",
    "#00A499", "#330072", "#E06666", "#6FA8DC", "#93C47D",
    "#F6B26B", "#8E7CC3", "#C27BA0", "#76A5AF", "#FFD966",
]
|
||
|
||
|
||
def _smart_legend(n_items: int, legend_title: str = "") -> dict:
    """Return a legend dict that adapts to the number of items.

    - >15 items: vertical legend to the right of the chart
    - ≤15 items: horizontal legend centred below the chart

    The dict only positions the legend; it does not reserve space for it.
    Callers should also apply the matching margins returned by
    ``_smart_legend_margin(n_items)``.

    Args:
        n_items: Number of legend entries that will be shown.
        legend_title: Optional legend title; omitted from the dict when empty.

    Returns:
        Dict suitable for ``legend=...`` inside ``fig.update_layout()``.
    """
    legend = {"font": {"family": CHART_FONT_FAMILY, "size": 11}}
    if legend_title:
        legend["title"] = legend_title

    if n_items > 15:
        # Too many entries for rows under the plot: stack them at the right edge.
        legend.update(
            orientation="v",
            x=1.02,
            y=1,
            xanchor="left",
            yanchor="top",
        )
    else:
        # Negative y places the legend below the plotting area.
        legend.update(
            orientation="h",
            yanchor="top",
            y=-0.12,
            xanchor="center",
            x=0.5,
        )
    return legend
|
||
|
||
|
||
def _smart_legend_margin(n_items: int) -> dict:
|
||
"""Return margin dict with bottom margin adapted to legend size.
|
||
|
||
- >15 items: vertical right legend needs extra right margin (r=140)
|
||
but minimal bottom margin (b=40).
|
||
- ≤15 items: horizontal legend needs bottom margin scaled to
|
||
estimated row count (~6 items per row at font size 11).
|
||
"""
|
||
if n_items > 15:
|
||
return dict(r=140, b=40)
|
||
else:
|
||
rows = max(1, (n_items + 5) // 6) # ~6 items per row
|
||
return dict(b=max(60, rows * 28 + 30), r=24)
|
||
|
||
|
||
def _base_layout(title: str, **overrides) -> dict:
    """Build the layout properties shared by every chart in this module.

    Chart functions call this for consistent styling and then layer their
    own settings on top.

    Args:
        title: Display title for the chart.
        **overrides: Any key accepted by ``fig.update_layout()``. Each
            override replaces the base entry of the same name wholesale
            (the merge is shallow), so e.g. passing ``xaxis=...`` discards
            the base ``automargin`` setting unless it is repeated.

    Returns:
        Dict ready to be unpacked into ``fig.update_layout(**layout)``.
    """
    title_font = {
        "family": CHART_FONT_FAMILY,
        "size": CHART_TITLE_SIZE,
        "color": CHART_TITLE_COLOR,
    }
    hover_font = {
        "family": CHART_FONT_FAMILY,
        "size": 13,
        "color": CHART_TITLE_COLOR,
    }
    shared = {
        # Centred title.
        "title": {
            "text": title,
            "font": title_font,
            "x": 0.5,
            "xanchor": "center",
        },
        "hoverlabel": {
            "bgcolor": "#FFFFFF",
            "bordercolor": "#CBD5E1",
            "font": hover_font,
        },
        # Fully transparent backgrounds.
        "paper_bgcolor": "rgba(0,0,0,0)",
        "plot_bgcolor": "rgba(0,0,0,0)",
        "autosize": True,
        "font": {"family": CHART_FONT_FAMILY},
        "xaxis": {"automargin": True},
        "yaxis": {"automargin": True},
    }
    return {**shared, **overrides}
|
||
|
||
|
||
def create_icicle_figure(ice_df: pd.DataFrame, title: str) -> go.Figure:
    """
    Create Plotly icicle figure from prepared DataFrame.

    Generates an interactive icicle chart of the patient pathway hierarchy.
    Each node carries custom data (patients, cost, dates, dosing text) that
    is rendered in the node text and in the hover label.

    Args:
        ice_df: DataFrame with columns:
            - parents: Parent node in hierarchy
            - ids: Unique identifier for each node
            - labels: Display label for each node
            - value: Number of patients
            - colour: Color value for visualization
            - cost: Total cost
            - costpp: Cost per patient
            - cost_pp_pa: Cost per patient per annum
            - First seen: First intervention date
            - Last seen: Last intervention date
            - First seen (Parent): Earliest date in parent group
            - Last seen (Parent): Latest date in parent group
            - average_spacing: Formatted string with dosing information
            The frame is not modified; a sorted copy is used internally.
        title: Chart title suffix, appended to the fixed ICS heading.

    Returns:
        Plotly Figure object ready for display or export
    """
    # sort_values returns a new frame, so the caller's DataFrame is untouched.
    ice_df = ice_df.sort_values(by=["labels"], ascending=True, ignore_index=True)

    def _date_strings(column: str) -> list:
        """Stringify a date column, showing missing dates (NaT) as 'N/A'."""
        return ice_df[column].astype(str).replace("NaT", "N/A").to_list()

    first_seen = _date_strings("First seen")
    last_seen = _date_strings("Last seen")
    # The parent columns get the same NaT handling: customdata[7] is what the
    # "Last seen" line of the templates actually displays.
    first_seen_parent = _date_strings("First seen (Parent)")
    last_seen_parent = _date_strings("Last seen (Parent)")
    average_spacing = ice_df.average_spacing.astype(str).to_list()

    # Column order defines the customdata[...] indices used by the templates.
    # NOTE(review): stacking numeric columns with string lists typically makes
    # NumPy promote the whole array to strings, so numbers reach the templates
    # as numeric strings — confirm this is acceptable for the d3 formats below.
    customdata = np.stack(
        (
            ice_df.value,       # [0]
            ice_df.colour,      # [1]
            ice_df.cost,        # [2]
            ice_df.costpp,      # [3]
            first_seen,         # [4]
            last_seen,          # [5]
            first_seen_parent,  # [6]
            last_seen_parent,   # [7]
            average_spacing,    # [8]
            ice_df.cost_pp_pa,  # [9]
        ),
        axis=1,
    )

    text_template = (
        "<b>%{label}</b> "
        "<br><b>Total patients:</b> %{customdata[0]} (including children/further treatments)"
        "<br><b>First seen:</b> %{customdata[4]}"
        "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
        "<br><b>Average treatment duration:</b> %{customdata[8]}"
        "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
        "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
        "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}"
    )
    hover_template = (
        "<b>%{label}</b>"
        "<br><b>Total patients:</b> %{customdata[0]} - %{customdata[1]:.3p} of patients in level"
        "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
        "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
        "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}"
        "<br><b>First seen:</b> %{customdata[4]}"
        "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
        # A space after the label keeps the value from running into it.
        "<br><b>Average treatment duration:</b> %{customdata[8]}"
        "<extra></extra>"
    )

    fig = go.Figure(
        go.Icicle(
            labels=ice_df.labels,
            ids=ice_df.ids,
            parents=ice_df.parents,
            customdata=customdata,
            values=ice_df.value,
            branchvalues="total",
            marker=dict(colors=ice_df.colour, colorscale="Viridis"),
            maxdepth=3,
            texttemplate=text_template,
            hovertemplate=hover_template,
        )
    )
    # Keep the label ordering established above instead of Plotly's own sort.
    fig.update_traces(sort=False)
    fig.update_layout(
        margin=dict(t=60, l=1, r=1, b=60),
        title=f"Norfolk & Waveney ICS high-cost drug patient pathways - {title}",
        title_x=0.5,
        hoverlabel=dict(font_size=16),
    )

    return fig
|
||
|
||
|
||
def create_icicle_from_nodes(nodes: list[dict], title: str = "") -> go.Figure:
    """
    Create Plotly icicle figure from a list of pathway node dicts.

    This is the dict-based entry point used by the Dash app. The nodes list
    comes directly from the chart-data dcc.Store (JSON-serialized dicts with
    underscore keys matching SQLite column names).

    Args:
        nodes: List of dicts with keys: parents, ids, labels, value, cost,
            costpp, cost_pp_pa, colour, first_seen, last_seen,
            first_seen_parent, last_seen_parent, average_spacing.
            Missing keys and explicit ``None`` values fall back to a neutral
            default (0 for numbers, "N/A" for dates, "" for other text).
        title: Chart title (e.g. "By Directory | All years / Last 6 months")

    Returns:
        Plotly Figure object ready for dcc.Graph; an empty Figure when
        ``nodes`` is empty.
    """
    if not nodes:
        return go.Figure()

    def _column(key: str, default):
        """Collect one field across all nodes, replacing missing/None values."""
        # ``or default`` also covers keys that are present but null, which a
        # plain ``.get(key, default)`` would pass through as None.
        return [node.get(key, default) or default for node in nodes]

    parents = _column("parents", "")
    ids = _column("ids", "")
    labels = _column("labels", "")
    values = _column("value", 0)
    colours = _column("colour", 0.0)

    costs = _column("cost", 0.0)
    costpp = _column("costpp", 0.0)
    first_seen = _column("first_seen", "N/A")
    last_seen = _column("last_seen", "N/A")
    first_seen_parent = _column("first_seen_parent", "N/A")
    last_seen_parent = _column("last_seen_parent", "N/A")
    average_spacing = _column("average_spacing", "")
    cost_pp_pa = _column("cost_pp_pa", 0.0)

    # Positions here define the customdata[...] indices used by the templates.
    customdata = list(zip(
        values,             # [0]
        colours,            # [1]
        costs,              # [2]
        costpp,             # [3]
        first_seen,         # [4]
        last_seen,          # [5]
        first_seen_parent,  # [6]
        last_seen_parent,   # [7]
        average_spacing,    # [8]
        cost_pp_pa,         # [9]
    ))

    # NHS blue gradient (Heritage Blue → Primary Blue → Vibrant Blue → Sky Blue → Pale Blue)
    colorscale = [
        [0.0, "#003087"],
        [0.25, "#0066CC"],
        [0.5, "#1E88E5"],
        [0.75, "#4FC3F7"],
        [1.0, "#E3F2FD"],
    ]

    text_template = (
        "<b>%{label}</b> "
        "<br><b>Total patients:</b> %{customdata[0]} (including children/further treatments)"
        "<br><b>First seen:</b> %{customdata[4]}"
        "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
        "<br><b>Average treatment duration:</b> %{customdata[8]}"
        "<br><b>Total cost:</b> \u00a3%{customdata[2]:.3~s}"
        "<br><b>Average cost per patient:</b> \u00a3%{customdata[3]:.3~s}"
        "<br><b>Average cost per patient per annum:</b> \u00a3%{customdata[9]:.3~s}"
    )
    hover_template = (
        "<b>%{label}</b>"
        "<br><b>Total patients:</b> %{customdata[0]} - %{customdata[1]:.3p} of patients in level"
        "<br><b>Total cost:</b> \u00a3%{customdata[2]:.3~s}"
        "<br><b>Average cost per patient:</b> \u00a3%{customdata[3]:.3~s}"
        "<br><b>Average cost per patient per annum:</b> \u00a3%{customdata[9]:.3~s}"
        "<br><b>First seen:</b> %{customdata[4]}"
        "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
        # A space after the label keeps the value from running into it.
        "<br><b>Average treatment duration:</b> %{customdata[8]}"
        "<extra></extra>"
    )

    fig = go.Figure(
        go.Icicle(
            labels=labels,
            ids=ids,
            parents=parents,
            values=values,
            branchvalues="total",
            marker=dict(
                colors=colours,
                colorscale=colorscale,
                line=dict(width=1, color="#FFFFFF"),
            ),
            maxdepth=3,
            customdata=customdata,
            texttemplate=text_template,
            hovertemplate=hover_template,
            textfont=dict(
                family=CHART_FONT_FAMILY,
                size=12,
            ),
        )
    )

    display_title = f"Patient Pathways \u2014 {title}" if title else "Patient Pathways"

    layout = _base_layout(
        display_title,
        margin=dict(t=40, l=8, r=8, b=24),
        height=700,
        # The override replaces the base hoverlabel wholesale, so the whole
        # dict is restated just to raise the font size to 14.
        hoverlabel=dict(
            bgcolor="#FFFFFF",
            bordercolor="#CBD5E1",
            font=dict(
                family=CHART_FONT_FAMILY,
                size=14,
                color=CHART_TITLE_COLOR,
            ),
        ),
        clickmode="event+select",
    )
    fig.update_layout(**layout)

    # Preserve the incoming node order instead of Plotly's own sort.
    fig.update_traces(sort=False)

    return fig
|
||
|
||
|
||
def create_market_share_figure(data: list[dict], title: str = "") -> go.Figure:
    """
    Create horizontal stacked bar chart showing first-line drug market share by directorate.

    One row is drawn per directorate and one bar segment per drug; segment
    length is the drug's share of patients (as a percentage).

    Args:
        data: List of dicts from get_drug_market_share() with keys:
            directory, drug, patients, proportion, cost, cost_pp_pa
            Sorted by directory total patients desc, drugs desc within.
        title: Chart title suffix (filter description)

    Returns:
        Plotly Figure with one horizontal stacked bar per directorate; an
        empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    # Ordered de-duplication (dicts preserve insertion order). Directorates
    # arrive sorted by total patients desc; drugs keep first-seen order.
    directories = list(dict.fromkeys(row["directory"] for row in data))
    drugs = list(dict.fromkeys(row["drug"] for row in data))

    drug_colour_map = {
        drug: DRUG_PALETTE[i % len(DRUG_PALETTE)] for i, drug in enumerate(drugs)
    }

    # Lookup: (directory, drug) -> row
    lookup = {(row["directory"], row["drug"]): row for row in data}

    # Reverse directory order so highest total is at the top of horizontal chart
    display_dirs = list(reversed(directories))

    # One trace per drug, with a value for every directorate so the stacks
    # line up; absent combinations contribute a zero-length segment.
    traces = []
    for drug in drugs:
        x_vals = []
        hover_texts = []
        for directory in display_dirs:
            row = lookup.get((directory, drug))
            if row:
                x_vals.append(row["proportion"] * 100)
                hover_texts.append(
                    f"<b>{drug}</b><br>"
                    f"{directory}<br>"
                    f"Patients: {row['patients']:,}<br>"
                    f"Share: {row['proportion']:.1%}<br>"
                    f"Cost: £{row['cost']:,.0f}<br>"
                    f"Cost p.p.p.a: £{row['cost_pp_pa']:,.0f}"
                )
            else:
                x_vals.append(0)
                hover_texts.append("")

        traces.append(go.Bar(
            name=drug,
            y=list(display_dirs),
            x=x_vals,
            orientation="h",
            marker_color=drug_colour_map[drug],
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_texts,
        ))

    display_title = f"First-Line Drug Market Share — {title}" if title else "First-Line Drug Market Share"

    n_drugs = len(drugs)
    legend_margins = _smart_legend_margin(n_drugs)

    fig = go.Figure(data=traces)
    layout = _base_layout(display_title)
    layout.update(
        barmode="stack",
        xaxis=dict(
            title="% of patients",
            ticksuffix="%",
            # Slight headroom past 100% so the full bar is not flush with the edge.
            range=[0, 105],
            gridcolor=GRID_COLOR,
            zeroline=False,
        ),
        yaxis=dict(title="", automargin=True),
        legend=_smart_legend(n_drugs, legend_title="Drug"),
        margin=dict(t=50, l=8, **legend_margins),
        height=max(600, len(directories) * 60 + 200),
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_cost_effectiveness_figure(
    data: list[dict],
    retention: dict,
    title: str = "",
) -> go.Figure:
    """
    Create horizontal lollipop chart showing pathway cost per patient per annum.

    Args:
        data: List of dicts from get_pathway_costs() with keys:
            ids, pathway_label, cost_pp_pa, patients, cost, avg_days,
            directory, trust_name, drug_sequence, level.
            Sorted by cost_pp_pa desc. Rows whose cost_pp_pa is missing
            (None) or not positive are dropped; only the first 40 remaining
            rows are plotted.
        retention: Dict from calculate_retention_rate() mapping ids to retention
            info: {retained_patients, total_patients, retention_rate, drug_sequence}.
        title: Chart title suffix (filter description).

    Returns:
        Plotly Figure with horizontal lollipop dots and retention annotations;
        an empty Figure when no row has a positive cost.
    """
    if not data:
        return go.Figure()

    # Keep pathways with positive cost; ``or 0`` treats a None cost as "no
    # cost" rather than raising on the comparison. Then cap to the first 40
    # rows (the highest costs, given the documented input order) so the chart
    # stays readable.
    ranked = [d for d in data if (d["cost_pp_pa"] or 0) > 0][:40]
    if not ranked:
        return go.Figure()

    # Reverse for horizontal chart (highest cost at top)
    ranked.reverse()

    pathway_labels = [d["pathway_label"] for d in ranked]
    costs = [d["cost_pp_pa"] for d in ranked]
    patients = [d["patients"] for d in ranked]
    # Looked up once per row; reused for hover text and annotations.
    retention_rates = [
        retention.get(d["ids"], {}).get("retention_rate") for d in ranked
    ]

    # Colour gradient: green (cheap) → amber → red (expensive)
    max_cost = max(costs)
    min_cost = min(costs)
    # A range of 1 avoids dividing by zero when every cost is identical.
    cost_range = max_cost - min_cost if max_cost != min_cost else 1

    def _lerp_color(ratio: float) -> str:
        """Smooth green→amber→red gradient via linear RGB interpolation."""
        green = (0x00, 0x96, 0x39)
        amber = (0xED, 0x8B, 0x00)
        red = (0xDA, 0x29, 0x1C)
        ratio = max(0.0, min(1.0, ratio))
        if ratio <= 0.5:
            t = ratio / 0.5
            c1, c2 = green, amber
        else:
            t = (ratio - 0.5) / 0.5
            c1, c2 = amber, red
        r = int(c1[0] + (c2[0] - c1[0]) * t)
        g = int(c1[1] + (c2[1] - c1[1]) * t)
        b = int(c1[2] + (c2[2] - c1[2]) * t)
        return f"rgb({r},{g},{b})"

    colours = [_lerp_color((c - min_cost) / cost_range) for c in costs]

    # Dot size scaled by patient count (min 8, max 30)
    max_pts = max(patients)
    min_pts = min(patients)
    pts_range = max_pts - min_pts if max_pts != min_pts else 1
    sizes = [8 + (p - min_pts) / pts_range * 22 for p in patients]

    # Build hover text with retention info
    hover_texts = []
    for d, retention_rate in zip(ranked, retention_rates):
        # NOTE(review): assumes drug_sequence is a list of drugs, so len()
        # counts treatment lines — confirm against get_pathway_costs().
        drugs_in_seq = len(d["drug_sequence"])

        hover = (
            f"<b>{d['pathway_label']}</b><br>"
            f"Cost p.p.p.a.: £{d['cost_pp_pa']:,.0f}<br>"
            f"Patients: {d['patients']:,}<br>"
            f"Total cost: £{d['cost']:,.0f}<br>"
            f"Avg duration: {d['avg_days']:,.0f} days<br>"
            f"Directorate: {d['directory']}<br>"
            f"Treatment lines: {drugs_in_seq}"
        )
        if retention_rate is not None:
            hover += f"<br>Retention: {retention_rate:.0f}% (no further switch)"
        hover_texts.append(hover)

    # Lollipop sticks (horizontal lines from 0 to cost), one trace per pathway
    stick_traces = [
        go.Scatter(
            x=[0, cost],
            y=[label, label],
            mode="lines",
            line=dict(color="#CBD5E1", width=1.5),
            showlegend=False,
            hoverinfo="skip",
        )
        for label, cost in zip(pathway_labels, costs)
    ]

    # Lollipop dots
    dot_trace = go.Scatter(
        x=costs,
        y=pathway_labels,
        mode="markers",
        marker=dict(
            size=sizes,
            color=colours,
            line=dict(color="#FFFFFF", width=1),
        ),
        hovertemplate="%{customdata}<extra></extra>",
        customdata=hover_texts,
        showlegend=False,
    )

    display_title = (
        f"Pathway Cost Effectiveness — {title}" if title
        else "Pathway Cost Effectiveness (£ per patient per annum)"
    )

    fig = go.Figure(data=stick_traces + [dot_trace])

    # Annotate pathways with retention below 90% and at least 10 patients,
    # limited to the first 8 matches in plotting order to avoid clutter.
    annotation_count = 0
    for d, rate in zip(ranked, retention_rates):
        if annotation_count >= 8:
            break
        if rate is not None and rate < 90 and d["patients"] >= 10:
            fig.add_annotation(
                x=d["cost_pp_pa"],
                y=d["pathway_label"],
                text=f"{rate:.0f}% retain",
                showarrow=False,
                xanchor="left",
                xshift=10,
                font=dict(size=10, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY),
            )
            annotation_count += 1

    layout = _base_layout(display_title)
    layout.update(
        xaxis=dict(
            title="£ per patient per annum",
            tickprefix="£",
            tickformat=",",
            gridcolor=GRID_COLOR,
            zeroline=True,
            zerolinecolor="#CBD5E1",
        ),
        yaxis=dict(
            title="",
            automargin=True,
            tickfont=dict(size=11),
        ),
        margin=dict(t=50, l=8, r=80, b=40),
        height=max(600, len(ranked) * 28 + 150),
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_cost_waterfall_figure(
    data: list[dict],
    title: str = "",
    is_trust_comparison: bool = False,
) -> go.Figure:
    """Create bar chart showing cost per patient by directorate/indication.

    Despite the name, this draws a plain vertical bar chart (not a
    ``go.Waterfall``), plus a dashed line at the patient-weighted average
    cost per patient.

    Args:
        data: List of dicts from get_cost_waterfall() with keys:
            directory, patients, total_cost, cost_pp.
            Sorted by cost_pp desc.
        title: Chart title suffix (filter description).
        is_trust_comparison: When True, bars are coloured from
            ``TRUST_PALETTE``; otherwise from ``DRUG_PALETTE``.

    Returns:
        Plotly Figure with one bar per entry and a weighted-average line;
        an empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    labels = [d["directory"] for d in data]
    cost_pp_values = [d["cost_pp"] for d in data]
    patients_list = [d["patients"] for d in data]
    total_costs = [d["total_cost"] for d in data]

    palette = TRUST_PALETTE if is_trust_comparison else DRUG_PALETTE
    # Cycle through the palette when there are more bars than colours.
    bar_colours = [palette[i % len(palette)] for i in range(len(data))]

    hover_texts = [
        f"<b>{d['directory']}</b><br>"
        f"Cost per patient: £{d['cost_pp']:,.0f}<br>"
        f"Patients: {d['patients']:,}<br>"
        f"Total cost: £{d['total_cost']:,.0f}"
        for d in data
    ]

    # Use a standard bar chart (not go.Waterfall) for cleaner control
    # Each bar shows cost_pp for a directorate, sorted highest to lowest
    fig = go.Figure()

    fig.add_trace(
        go.Bar(
            x=labels,
            y=cost_pp_values,
            marker=dict(
                color=bar_colours,
                line=dict(color="#FFFFFF", width=1),
            ),
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_texts,
            text=[f"£{v:,.0f}" for v in cost_pp_values],
            textposition="outside",
            textfont=dict(size=11, color="#425563"),
        )
    )

    # Add patient count annotations below each bar
    for label, pts in zip(labels, patients_list):
        fig.add_annotation(
            x=label,
            y=0,
            text=f"n={pts:,}",
            showarrow=False,
            yshift=-18,
            font=dict(size=10, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY),
        )

    # Grand total line: total cost over total patients, i.e. the
    # patient-weighted average of the per-entry costs.
    total_patients = sum(patients_list)
    total_cost = sum(total_costs)
    weighted_avg = total_cost / total_patients if total_patients else 0
    fig.add_hline(
        y=weighted_avg,
        line_dash="dash",
        line_color="#DA291C",
        line_width=1.5,
        annotation_text=f"Weighted avg: £{weighted_avg:,.0f}",
        annotation_position="top right",
        annotation_font=dict(
            size=11, color="#DA291C", family=CHART_FONT_FAMILY
        ),
    )

    display_title = (
        f"Cost per Patient by Directorate — {title}" if title
        else "Cost per Patient by Directorate"
    )

    layout = _base_layout(display_title)
    layout.update(
        xaxis=dict(
            title="",
            # Slant the labels once there are too many to sit side by side.
            tickangle=-45 if len(data) > 6 else 0,
            tickfont=dict(size=11),
            automargin=True,
        ),
        yaxis=dict(
            title="£ per patient",
            tickprefix="£",
            tickformat=",",
            gridcolor=GRID_COLOR,
            zeroline=True,
            zerolinecolor="#CBD5E1",
            automargin=True,
        ),
        margin=dict(t=60, l=8, r=24, b=80),
        height=max(600, len(data) * 50 + 200),
        showlegend=False,
        bargap=0.25,
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_sankey_figure(
    data: dict,
    title: str = "",
) -> go.Figure:
    """Create Sankey diagram showing drug switching flows between treatment lines.

    Args:
        data: Dict from get_drug_transitions() with keys:
            nodes: [{name}] — drug names with ordinal suffixes (e.g., "ADALIMUMAB (1st)")
            links: [{source_idx, target_idx, patients}] — transitions between drugs;
                the ``*_idx`` values index into ``nodes``.
        title: Chart title suffix (filter description).

    Returns:
        Plotly Figure with Sankey diagram; an empty Figure when there are
        no nodes or no links.
    """
    import re

    nodes = data.get("nodes", [])
    links = data.get("links", [])

    if not nodes or not links:
        return go.Figure()

    # Extract base drug name (strip ordinal suffix) for colour consistency
    ordinal_suffix = re.compile(r"\s*\(\d+(?:st|nd|rd|th)\)\s*$")
    base_names = [ordinal_suffix.sub("", n["name"]) for n in nodes]

    # Ordered de-duplication so colours are assigned in first-seen order.
    unique_bases = list(dict.fromkeys(base_names))
    base_colour_map = {
        b: DRUG_PALETTE[i % len(DRUG_PALETTE)] for i, b in enumerate(unique_bases)
    }

    # Node colours — same drug gets same colour regardless of treatment line
    node_colours = [base_colour_map[b] for b in base_names]

    node_labels = [n["name"] for n in nodes]

    def hex_to_rgba(hex_colour: str, alpha: float) -> str:
        """Convert a ``#RRGGBB`` colour to an ``rgba(r,g,b,alpha)`` string."""
        h = hex_colour.lstrip("#")
        r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
        return f"rgba({r},{g},{b},{alpha})"

    # Link colours — source node colour at 35% opacity so overlapping flows
    # stay distinguishable.
    link_colours = [
        hex_to_rgba(node_colours[link["source_idx"]], 0.35)
        for link in links
    ]

    # Build hover text for links
    link_hovers = [
        f"{node_labels[link['source_idx']]} → {node_labels[link['target_idx']]}"
        f"<br>Patients: {link['patients']:,}"
        for link in links
    ]

    # Total patients flowing out of and into each node.
    node_outgoing = [0] * len(nodes)
    node_incoming = [0] * len(nodes)
    for link in links:
        node_outgoing[link["source_idx"]] += link["patients"]
        node_incoming[link["target_idx"]] += link["patients"]

    # Show the larger of the two totals, so terminal nodes (no outgoing
    # links) report their incoming patients and first-line nodes their
    # outgoing patients.
    node_hover = [
        f"<b>{n['name']}</b><br>Patients: {max(out_p, in_p):,}"
        for n, out_p, in_p in zip(nodes, node_outgoing, node_incoming)
    ]

    fig = go.Figure(
        go.Sankey(
            arrangement="freeform",
            node=dict(
                pad=25,
                thickness=25,
                line=dict(color="#FFFFFF", width=1),
                label=node_labels,
                color=node_colours,
                customdata=node_hover,
                hovertemplate="%{customdata}<extra></extra>",
            ),
            link=dict(
                source=[link["source_idx"] for link in links],
                target=[link["target_idx"] for link in links],
                value=[link["patients"] for link in links],
                color=link_colours,
                customdata=link_hovers,
                hovertemplate="%{customdata}<extra></extra>",
            ),
        )
    )

    chart_title = "Drug Switching Flows"
    if title:
        chart_title = f"{chart_title} — {title}"

    layout = _base_layout(chart_title)
    layout.update(
        font=dict(family=CHART_FONT_FAMILY, size=12),
        margin=dict(t=60, l=30, r=30, b=30),
        height=max(600, len(unique_bases) * 35 + 200),
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_dosing_figure(
    data: list[dict],
    title: str = "",
    group_by: str = "drug",
) -> go.Figure:
    """Create dosing interval comparison chart.

    Shows weekly dosing intervals as horizontal bars, either aggregated per
    drug (overview mode) or grouped per trust (single-drug comparison mode).

    Args:
        data: List of dicts from get_dosing_intervals() with keys:
            drug, trust_name, directory, weekly_interval, dose_count,
            total_weeks, patients.
        title: Chart title suffix (filter description).
        group_by: "trust" for per-trust comparison of a single drug; any
            other value (default "drug") gives the drug-level overview.

    Returns:
        Plotly Figure with horizontal bars; an empty Figure when ``data``
        is empty.
    """
    if not data:
        return go.Figure()

    if group_by == "trust":
        fig = _dosing_by_trust(data, DRUG_PALETTE)
        chart_title = "Dosing Intervals by Trust"
    else:
        fig = _dosing_by_drug(data, DRUG_PALETTE)
        chart_title = "Dosing Interval Overview"

    if title:
        chart_title = f"{chart_title} — {title}"

    # Size the chart by the number of distinct rows across ALL traces: in
    # trust mode each trace only lists the trusts that have data for its
    # directorate, so the first trace alone can under-count the rows.
    row_labels = {
        label
        for trace in fig.data
        if trace.y is not None
        for label in trace.y
    }
    n_rows = len(row_labels) if fig.data else 10
    n_legend = sum(1 for t in fig.data if t.showlegend is not False)
    legend_margins = _smart_legend_margin(n_legend)

    layout = _base_layout(chart_title)
    layout.update(
        xaxis=dict(
            # Title font is set via ``title.font``; the legacy ``titlefont``
            # shorthand is deprecated and rejected by newer Plotly releases.
            title=dict(
                text="Weekly Interval (weeks between doses)",
                font=dict(size=13, color="#425563"),
            ),
            gridcolor="rgba(66,85,99,0.1)",
            zeroline=True,
            zerolinecolor="rgba(66,85,99,0.2)",
        ),
        yaxis=dict(automargin=True, tickfont=dict(size=11)),
        margin=dict(t=60, l=20, **legend_margins),
        height=max(600, n_rows * 40 + 150),
        bargap=0.15,
        bargroupgap=0.05,
        showlegend=True,
        legend=_smart_legend(n_legend),
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def _dosing_by_drug(data: list[dict], colours: list[str]) -> go.Figure:
    """Build dosing overview: one bar per drug showing its average weekly interval.

    Rows for the same drug are collapsed into patient-weighted averages of
    weekly_interval, dose_count and total_weeks.

    Args:
        data: Dosing rows with keys drug, patients, weekly_interval,
            dose_count and total_weeks. A ``None`` patient count is treated
            as 0 (the row then carries no weight).
        colours: Unused here — bars are coloured from the Viridis scale by
            interval. Kept so the signature matches ``_dosing_by_trust``.

    Returns:
        Figure with a single horizontal bar trace (hidden from the legend)
        and an "n=…" patient-count annotation per drug.
    """
    import plotly.colors as pc

    # Accumulate patient-weighted sums per drug.
    drug_agg = {}
    for row in data:
        weight = row["patients"] or 0
        agg = drug_agg.setdefault(
            row["drug"],
            {"weighted_sum": 0.0, "total_patients": 0,
             "dose_count_ws": 0.0, "total_weeks_ws": 0.0},
        )
        agg["weighted_sum"] += row["weekly_interval"] * weight
        agg["total_patients"] += weight
        agg["dose_count_ws"] += row["dose_count"] * weight
        agg["total_weeks_ws"] += row["total_weeks"] * weight

    # Sort ascending by total patients: a horizontal bar chart lists its
    # categories bottom-up, so the drug with the most patients ends up on top.
    drugs_sorted = sorted(
        drug_agg.items(),
        key=lambda item: item[1]["total_patients"],
    )

    drug_names = [name for name, _ in drugs_sorted]
    intervals = []
    patients_list = []
    hover_texts = []

    for drug, agg in drugs_sorted:
        tp = agg["total_patients"]
        # Drugs with no counted patients fall back to 0 instead of dividing by zero.
        avg_interval = agg["weighted_sum"] / tp if tp > 0 else 0
        avg_doses = agg["dose_count_ws"] / tp if tp > 0 else 0
        avg_weeks = agg["total_weeks_ws"] / tp if tp > 0 else 0
        intervals.append(round(avg_interval, 1))
        patients_list.append(tp)
        hover_texts.append(
            f"<b>{drug}</b><br>"
            f"Avg interval: {avg_interval:.1f} weeks<br>"
            f"Avg doses: {avg_doses:.1f}<br>"
            f"Avg treatment: {avg_weeks:.0f} weeks<br>"
            f"Patients: {tp:,}"
        )

    # Colour bars by interval relative to the longest one, sampled from the
    # Viridis scale (lower = more frequent dosing, higher = less frequent).
    max_interval = max(intervals) if intervals else 1
    ratios = [iv / max_interval if max_interval > 0 else 0 for iv in intervals]
    bar_colours = pc.sample_colorscale("Viridis", ratios)

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=drug_names,
        x=intervals,
        orientation="h",
        marker=dict(color=bar_colours, line=dict(color="#FFFFFF", width=0.5)),
        text=[f"{iv}w" for iv in intervals],
        textposition="outside",
        textfont=dict(size=10, color="#425563"),
        customdata=list(zip(hover_texts, patients_list)),
        hovertemplate="%{customdata[0]}<extra></extra>",
        name="Weighted Avg Interval",
        showlegend=False,
    ))

    # Patient-count annotations in a column to the right of the longest bar.
    annotation_x = max_interval * 1.15
    for drug, pts in zip(drug_names, patients_list):
        fig.add_annotation(
            x=annotation_x,
            y=drug,
            text=f"n={pts:,}",
            showarrow=False,
            font=dict(size=9, color="#768692"),
            xanchor="left",
        )

    return fig
|
||
|
||
|
||
def _dosing_by_trust(data: list[dict], colours: list[str]) -> go.Figure:
    """Build the per-trust comparison: one row per trust, one bar per directorate.

    Each bar is the patient-weighted average weekly interval for a
    (trust, directorate) pair. Pairs with no rows or no counted patients
    are left out.

    Args:
        data: Dosing rows with keys trust_name, directory, weekly_interval,
            dose_count, total_weeks and patients (``None`` counts as 0).
        colours: Palette cycled through by directorate index.

    Returns:
        Figure with one horizontal bar trace per directorate that has data,
        laid out with ``barmode="group"``.
    """
    from collections import defaultdict

    # Bucket the rows by (trust, directorate).
    rows_by_pair = defaultdict(list)
    for row in data:
        rows_by_pair[(row["trust_name"], row["directory"])].append(row)

    trust_names = sorted({row["trust_name"] for row in data})
    directory_names = sorted({row["directory"] for row in data})

    fig = go.Figure()

    for idx, directory in enumerate(directory_names):
        bar_labels = []
        bar_values = []
        bar_hovers = []

        for trust in trust_names:
            group = rows_by_pair.get((trust, directory))
            if not group:
                continue

            # Patient-weighted averages; several rows per pair are not
            # expected at level 3 but are averaged if they do occur.
            interval_sum = sum(
                entry["weekly_interval"] * (entry["patients"] or 0) for entry in group
            )
            patient_total = sum(entry["patients"] or 0 for entry in group)
            if patient_total == 0:
                continue
            mean_interval = interval_sum / patient_total
            mean_doses = sum(
                entry["dose_count"] * (entry["patients"] or 0) for entry in group
            ) / patient_total
            mean_weeks = sum(
                entry["total_weeks"] * (entry["patients"] or 0) for entry in group
            ) / patient_total

            # Drop boilerplate from the trust name so axis labels stay short.
            trust_label = trust.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")
            bar_labels.append(trust_label)
            bar_values.append(round(mean_interval, 1))
            bar_hovers.append(
                f"<b>{trust_label}</b><br>"
                f"Directorate: {directory}<br>"
                f"Interval: {mean_interval:.1f} weeks<br>"
                f"Avg doses: {mean_doses:.1f}<br>"
                f"Treatment: {mean_weeks:.0f} weeks<br>"
                f"Patients: {patient_total:,}"
            )

        if not bar_labels:
            # No trust has data for this directorate: emit no trace for it.
            continue

        fig.add_trace(go.Bar(
            y=bar_labels,
            x=bar_values,
            orientation="h",
            name=directory,
            marker=dict(color=colours[idx % len(colours)]),
            customdata=bar_hovers,
            hovertemplate="%{customdata}<extra></extra>",
        ))

    fig.update_layout(barmode="group")
    return fig
|
||
|
||
|
||
def save_figure_html(
    fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
) -> str:
    """Write a Plotly figure to ``<save_dir>/<title>.html``.

    Args:
        fig: Plotly Figure object to export.
        save_dir: Directory the HTML file is written into.
        title: Base name of the file (``.html`` is appended).
        open_browser: When True, the saved file is also opened via
            ``open_figure_in_browser``.

    Returns:
        The path of the written HTML file.
    """
    html_path = f"{save_dir}/{title}.html"

    fig.write_html(html_path)
    logger.info(f"Success! File saved to {html_path}")

    if not open_browser:
        return html_path

    open_figure_in_browser(html_path)
    return html_path
|
||
|
||
|
||
def open_figure_in_browser(filepath: str) -> None:
    """Open an HTML file in a new tab of the default browser.

    The path is resolved to an absolute location and converted to a proper
    ``file://`` URI. Plain string concatenation (``"file:///" + path``)
    yields a malformed URI for POSIX absolute paths (four slashes) and for
    relative paths, and leaves spaces and other special characters unescaped.

    Args:
        filepath: Path to the HTML file (absolute or relative).
    """
    from pathlib import Path

    # as_uri() requires an absolute path, hence resolve() first.
    file_uri = Path(filepath).resolve().as_uri()
    webbrowser.open_new_tab(file_uri)
|
||
|
||
|
||
def figure_legacy(ice_df: pd.DataFrame, dir_string: str, save_dir: str) -> None:
    """Create, save and display an icicle figure (legacy interface).

    Kept for backward compatibility with the original ``figure()`` signature:
    it builds the icicle chart, writes it to ``<save_dir>/<dir_string>.html``
    and opens the file in the browser.

    Args:
        ice_df: DataFrame with chart data. The columns read here are
            ``labels``, ``ids``, ``parents``, ``value``, ``colour``, ``cost``,
            ``costpp``, ``cost_pp_pa``, ``average_spacing``, ``First seen``,
            ``Last seen``, ``First seen (Parent)`` and ``Last seen (Parent)``.
        dir_string: Title string (used for the filename and the chart title).
        save_dir: Directory to save the HTML file.

    Note:
        New code should use create_icicle_figure() + save_figure_html() instead.
    """
    # sort_values returns a new frame, so the caller's DataFrame is untouched.
    ice_df = ice_df.sort_values(by=["labels"], ascending=True, ignore_index=True)

    # Dates are rendered as strings; missing own-node dates display as "N/A".
    first_seen = ice_df["First seen"].astype(str).replace("NaT", "N/A").to_list()
    last_seen = ice_df["Last seen"].astype(str).replace("NaT", "N/A").to_list()
    first_seen_parent = ice_df["First seen (Parent)"].astype(str).to_list()
    last_seen_parent = ice_df["Last seen (Parent)"].astype(str).to_list()
    average_spacing = ice_df["average_spacing"].astype(str).to_list()

    # Column order matters: the templates below address these by index.
    #   0 value, 1 colour, 2 cost, 3 costpp, 4 first seen, 5 last seen,
    #   6 first seen (parent), 7 last seen (parent), 8 average spacing,
    #   9 cost per patient per annum
    customdata = np.stack(
        (
            ice_df["value"],
            ice_df["colour"],
            ice_df["cost"],
            ice_df["costpp"],
            first_seen,
            last_seen,
            first_seen_parent,
            last_seen_parent,
            average_spacing,
            ice_df["cost_pp_pa"],
        ),
        axis=1,
    )

    fig = go.Figure(
        go.Icicle(
            labels=ice_df["labels"],
            ids=ice_df["ids"],
            parents=ice_df["parents"],
            customdata=customdata,
            values=ice_df["value"],
            branchvalues="total",
            marker=dict(colors=ice_df["colour"], colorscale="Viridis"),
            maxdepth=3,
            texttemplate="<b>%{label}</b> "
            "<br><b>Total patients:</b> %{customdata[0]} (including children/further treatments)"
            "<br><b>First seen:</b> %{customdata[4]}"
            "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
            "<br><b>Average treatment duration:</b> %{customdata[8]}"
            "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
            "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
            "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}",
            hovertemplate="<b>%{label}</b>"
            "<br><b>Total patients:</b> %{customdata[0]} - %{customdata[1]:.3p} of patients in level"
            "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
            "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
            "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}"
            "<br><b>First seen:</b> %{customdata[4]}"
            "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
            # A space after the label keeps this line consistent with texttemplate.
            "<br><b>Average treatment duration:</b> %{customdata[8]}"
            "<extra></extra>",
        )
    )
    # Keep the alphabetical row order established above.
    fig.update_traces(sort=False)
    fig.update_layout(
        margin=dict(t=60, l=1, r=1, b=60),
        title=f"Norfolk & Waveney ICS high-cost drug patient pathways - {dir_string}",
        title_x=0.5,
        hoverlabel=dict(font_size=16),
    )

    # Reuse the shared save/log/open helper rather than duplicating it here.
    save_figure_html(fig, save_dir, dir_string, open_browser=True)
|
||
|
||
|
||
def create_heatmap_figure(
    data: dict,
    title: str = "",
    metric: str = "patients",
) -> go.Figure:
    """Create a directorate × drug heatmap chart.

    Args:
        data: Dict from get_drug_directory_matrix() with keys:
            directories (list), drugs (list),
            matrix ({dir: {drug: {patients, cost, cost_pp_pa}}}).
        title: Chart title suffix (filter description).
        metric: Colour metric — "patients", "cost", or "cost_pp_pa". Any other
            value is labelled "Patients" and annotated with patient counts.

    Returns:
        Plotly Figure with annotated heatmap. An empty Figure is returned when
        there are no directories or no drugs.
    """
    directories = data.get("directories", [])
    drugs = data.get("drugs", [])
    matrix = data.get("matrix", {})

    if not directories or not drugs:
        return go.Figure()

    # Cap columns for readability. Only the first `max_drugs` entries are
    # kept; the subtitle calls these the "top" drugs, so the incoming list is
    # presumably already ranked — confirm against the query that builds it.
    max_drugs = 25
    total_drug_count = len(drugs)
    drugs = drugs[:max_drugs]
    capped = total_drug_count > max_drugs

    metric_labels = {
        "patients": "Patients",
        "cost": "Total Cost (£)",
        "cost_pp_pa": "Cost per Patient p.a. (£)",
    }
    metric_label = metric_labels.get(metric, "Patients")

    # Build 2D arrays for z-values, hover text, and cell annotations
    z_values = []
    hover_texts = []
    text_values = []

    for d in directories:
        row_z = []
        row_hover = []
        row_text = []
        dir_data = matrix.get(d) or {}
        for drug in drugs:
            cell = dir_data.get(drug)
            if not cell:
                row_z.append(0)
                row_hover.append(f"<b>{drug}</b><br>{d}<br>No patients")
                row_text.append("")
                continue

            # Treat missing or None figures as 0 so the number formatting
            # below cannot raise TypeError on a partially populated cell.
            patients = cell.get("patients") or 0
            cost = cell.get("cost") or 0
            cpp = cell.get("cost_pp_pa") or 0
            val = cell.get(metric, patients) or 0

            row_z.append(val)
            row_hover.append(
                f"<b>{drug}</b><br>"
                f"{d}<br>"
                f"Patients: {patients:,}<br>"
                f"Total cost: £{cost:,.0f}<br>"
                f"Cost p.a.: £{cpp:,.0f}"
            )
            # Cell annotation text, formatted per metric
            if metric == "cost":
                row_text.append(f"£{cost / 1000:.0f}k" if cost >= 1000 else f"£{cost:.0f}")
            elif metric == "cost_pp_pa":
                row_text.append(f"£{cpp:,.0f}")
            else:
                row_text.append(f"{patients:,}")
        z_values.append(row_z)
        hover_texts.append(row_hover)
        text_values.append(row_text)

    # Linear 5-stop NHS blue colorscale
    colorscale = [
        [0.0, "#E3F2FD"],
        [0.25, "#90CAF9"],
        [0.5, "#42A5F5"],
        [0.75, "#1E88E5"],
        [1.0, "#003087"],
    ]

    # Narrower cell gaps once the grid gets wide.
    n_drugs = len(drugs)
    gap = 1 if n_drugs > 15 else 2

    fig = go.Figure(
        data=go.Heatmap(
            z=z_values,
            x=drugs,
            y=directories,
            colorscale=colorscale,
            zmin=0,
            text=text_values,
            texttemplate="%{text}",
            textfont=dict(size=10),
            hovertext=hover_texts,
            hovertemplate="%{hovertext}<extra></extra>",
            colorbar=dict(
                title=dict(
                    text=metric_label,
                    font=dict(size=12, color="#425563"),
                ),
                thickness=15,
                len=0.8,
            ),
            xgap=gap,
            ygap=gap,
        )
    )

    chart_title = f"Directorate × Drug — {metric_label}"
    if title:
        chart_title = f"{chart_title} — {title}"

    n_dirs = len(directories)
    fig_height = max(600, 80 + n_dirs * 40)

    layout = _base_layout(chart_title)
    layout.update(
        xaxis=dict(
            title="",
            tickfont=dict(size=11, color="#425563"),
            tickangle=-45,
            side="bottom",
            automargin=True,
        ),
        yaxis=dict(
            title="",
            tickfont=dict(size=12, color="#425563"),
            autorange="reversed",
            automargin=True,
        ),
        margin=dict(t=60, l=8, r=80, b=120),
        height=fig_height,
    )
    fig.update_layout(**layout)

    # Add subtitle when drug cap is reached
    if capped:
        fig.add_annotation(
            text=f"Showing top {max_drugs} of {total_drug_count} drugs",
            xref="paper", yref="paper",
            x=0.5, y=1.02,
            showarrow=False,
            font=dict(size=12, color=ANNOTATION_COLOR),
        )

    return fig
|
||
|
||
|
||
def create_duration_figure(
    data: list[dict],
    title: str = "",
    show_directory: bool = False,
) -> go.Figure:
    """Create horizontal bar chart showing average treatment duration by drug.

    The caller's ``data`` list is never reordered or modified; all sorting is
    done on new lists.

    Args:
        data: List of dicts from get_treatment_durations() with keys:
            drug, directory, avg_days, patients.
        title: Chart title suffix (filter description).
        show_directory: If True, keep one bar per input row and include the
            directory in its label (for overview mode). If False, rows for
            the same drug are merged into a patient-weighted average.

    Returns:
        Plotly Figure with horizontal bars coloured by patient count, longest
        duration first. An empty Figure is returned when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    # When not showing directory breakdown, aggregate same drug across directorates
    if not show_directory:
        agg = {}
        for d in data:
            drug = d["drug"]
            pts = d["patients"]
            days = d["avg_days"]
            if drug not in agg:
                agg[drug] = {"drug": drug, "total_weighted": 0.0, "total_pts": 0}
            agg[drug]["total_weighted"] += days * pts
            agg[drug]["total_pts"] += pts
        # Drugs with no patients are dropped (and would divide by zero).
        data = [
            {
                "drug": v["drug"],
                "avg_days": round(v["total_weighted"] / v["total_pts"], 1),
                "patients": v["total_pts"],
            }
            for v in agg.values()
            if v["total_pts"] > 0
        ]
        data.sort(key=lambda x: -x["avg_days"])

    # Cap at 40 entries for readability (keep top by patient count, then
    # re-sort by days). sorted() is used instead of list.sort() because in
    # show_directory mode `data` is still the caller's own list.
    if len(data) > 40:
        largest = sorted(data, key=lambda x: -x["patients"])[:40]
        data = sorted(largest, key=lambda x: -x["avg_days"])

    # Build labels
    if show_directory:
        labels = [f"{d['drug']} ({d['directory']})" for d in data]
    else:
        labels = [d["drug"] for d in data]

    days_values = [d["avg_days"] for d in data]
    patients_list = [d["patients"] for d in data]

    # Colour gradient by patient count: light for few → dark NHS blue for many
    max_pts = max(patients_list) if patients_list else 1
    min_pts = min(patients_list) if patients_list else 0
    # Avoid dividing by zero when every bar has the same patient count.
    pt_range = max_pts - min_pts if max_pts > min_pts else 1

    bar_colours = []
    for pts in patients_list:
        # Linear interpolation from #41B6E6 (t=0) to #003087 (t=1).
        t = (pts - min_pts) / pt_range
        r = int(0x41 + (0x00 - 0x41) * t)
        g = int(0xB6 + (0x30 - 0xB6) * t)
        b = int(0xE6 + (0x87 - 0xE6) * t)
        bar_colours.append(f"rgb({r},{g},{b})")

    hover_texts = []
    for d in data:
        years = d["avg_days"] / 365.25
        hover_texts.append(
            f"<b>{d['drug']}</b><br>"
            f"Avg duration: {d['avg_days']:,.0f} days ({years:.1f} years)<br>"
            f"Patients: {d['patients']:,}"
        )

    fig = go.Figure()

    fig.add_trace(
        go.Bar(
            y=labels,
            x=days_values,
            orientation="h",
            marker=dict(
                color=bar_colours,
                line=dict(color="#FFFFFF", width=1),
            ),
            hovertemplate="%{customdata}<extra></extra>",
            customdata=hover_texts,
            text=[f"{v:,.0f}d" for v in days_values],
            textposition="outside",
            textfont=dict(size=10, color="#425563"),
        )
    )

    # Patient-count note placed just beyond the end of each bar.
    for label, days, pts in zip(labels, days_values, patients_list):
        fig.add_annotation(
            x=days,
            y=label,
            text=f"n={pts:,}",
            showarrow=False,
            xshift=45,
            font=dict(size=9, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY),
        )

    chart_title = "Treatment Duration by Drug"
    if title:
        chart_title += f"<br><span style='font-size:13px;color:{ANNOTATION_COLOR}'>{title}</span>"

    n_bars = len(data)
    fig_height = max(600, 40 + n_bars * 28)

    layout = _base_layout(chart_title)
    layout.update(
        xaxis=dict(
            # The axis `titlefont` attribute is deprecated and rejected by
            # newer Plotly releases; the nested title.font form is equivalent.
            title=dict(
                text="Average Duration (days)",
                font=dict(size=13, color="#425563"),
            ),
            tickfont=dict(size=11, color="#425563"),
            gridcolor="rgba(0,0,0,0.06)",
            zeroline=True,
            zerolinecolor="rgba(0,0,0,0.1)",
        ),
        yaxis=dict(
            title="",
            tickfont=dict(size=11, color="#425563"),
            automargin=True,
            autorange="reversed",
        ),
        margin=dict(t=60, l=8, r=100, b=50),
        height=fig_height,
        showlegend=False,
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
# --- Trust Comparison chart functions ---
|
||
|
||
|
||
def create_trust_market_share_figure(
    data: list[dict],
    title: str = "",
) -> go.Figure:
    """Build a stacked horizontal bar chart of drug market share per trust.

    Each trust is one bar and each drug one stacked segment. This is the
    per-trust counterpart of create_market_share_figure (which groups by
    directorate) and is used by the Trust Comparison dashboard.

    Args:
        data: List of dicts from get_trust_market_share() with keys:
            trust_name, drug, patients, proportion, cost, cost_pp_pa.
        title: Chart title suffix.

    Returns:
        Plotly Figure, or an empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    # Unique trusts / drugs in order of first appearance.
    trust_order = list(dict.fromkeys(row["trust_name"] for row in data))
    drug_order = list(dict.fromkeys(row["drug"] for row in data))

    palette_size = len(DRUG_PALETTE)
    colour_for = {
        drug: DRUG_PALETTE[pos % palette_size] for pos, drug in enumerate(drug_order)
    }
    rows_by_key = {(row["trust_name"], row["drug"]): row for row in data}

    def _abbreviate(name):
        # Strip boilerplate suffixes for compact axis labels.
        return name.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")

    # Reversed so the first trust ends up at the top of the horizontal chart.
    trusts_bottom_up = trust_order[::-1]

    bars = []
    for drug in drug_order:
        bar_labels = []
        shares = []
        hovers = []
        for trust in trusts_bottom_up:
            bar_labels.append(_abbreviate(trust))
            row = rows_by_key.get((trust, drug))
            if not row:
                # Drug not used at this trust: zero-width segment.
                shares.append(0)
                hovers.append("")
                continue
            shares.append(row["proportion"] * 100)
            hovers.append(
                "<br>".join(
                    (
                        f"<b>{drug}</b>",
                        f"{_abbreviate(trust)}",
                        f"Patients: {row['patients']:,}",
                        f"Share: {row['proportion']:.1%}",
                        f"Cost: £{row['cost']:,.0f}",
                        f"Cost p.p.p.a: £{row['cost_pp_pa']:,.0f}",
                    )
                )
            )

        bars.append(
            go.Bar(
                name=drug,
                y=bar_labels,
                x=shares,
                orientation="h",
                marker_color=colour_for[drug],
                hovertemplate="%{customdata}<extra></extra>",
                customdata=hovers,
            )
        )

    if title:
        display_title = f"Drug Market Share by Trust — {title}"
    else:
        display_title = "Drug Market Share by Trust"

    drug_count = len(drug_order)
    # Extra margin keys supplied by the legend helper for this item count.
    extra_margins = _smart_legend_margin(drug_count)

    fig = go.Figure(data=bars)
    layout = _base_layout(display_title)
    layout.update(
        barmode="stack",
        xaxis=dict(
            title="% of patients",
            ticksuffix="%",
            range=[0, 105],
            gridcolor=GRID_COLOR,
            zeroline=False,
        ),
        yaxis=dict(title="", automargin=True),
        legend=_smart_legend(drug_count, legend_title="Drug"),
        margin=dict(t=50, l=8, **extra_margins),
        height=max(300, len(trust_order) * 60 + 200),
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_trust_heatmap_figure(
    data: dict,
    title: str = "",
    metric: str = "patients",
) -> go.Figure:
    """Create a trust x drug heatmap for a single directorate.

    Args:
        data: Dict from get_trust_heatmap() with keys:
            trusts (list), drugs (list),
            matrix ({trust_name: {drug: {patients, cost, cost_pp_pa}}}).
        title: Chart title suffix.
        metric: Colour metric — "patients", "cost", or "cost_pp_pa". Any other
            value is labelled "Patients" and annotated with patient counts.

    Returns:
        Plotly Figure with annotated heatmap. An empty Figure is returned when
        there are no trusts or no drugs.
    """
    trusts = data.get("trusts", [])
    drugs = data.get("drugs", [])
    matrix = data.get("matrix", {})

    if not trusts or not drugs:
        return go.Figure()

    # Cap columns for readability; only the first `max_drugs` entries are
    # kept (the subtitle calls them the "top" drugs, so the list is
    # presumably pre-ranked — confirm against the query that builds it).
    max_drugs = 25
    total_drug_count = len(drugs)
    drugs = drugs[:max_drugs]
    capped = total_drug_count > max_drugs

    metric_labels = {
        "patients": "Patients",
        "cost": "Total Cost (£)",
        "cost_pp_pa": "Cost per Patient p.a. (£)",
    }
    metric_label = metric_labels.get(metric, "Patients")

    def short_trust(name):
        # Strip boilerplate suffixes for compact axis / hover labels.
        return name.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")

    z_values = []
    hover_texts = []
    text_values = []

    for t in trusts:
        row_z = []
        row_hover = []
        row_text = []
        trust_data = matrix.get(t) or {}
        trust_label = short_trust(t)
        for drug in drugs:
            cell = trust_data.get(drug)
            if not cell:
                row_z.append(0)
                row_hover.append(f"<b>{drug}</b><br>{trust_label}<br>No patients")
                row_text.append("")
                continue

            # Treat missing or None figures as 0 so the number formatting
            # below cannot raise TypeError on a partially populated cell.
            patients = cell.get("patients") or 0
            cost = cell.get("cost") or 0
            cpp = cell.get("cost_pp_pa") or 0
            val = cell.get(metric, patients) or 0

            row_z.append(val)
            row_hover.append(
                f"<b>{drug}</b><br>"
                f"{trust_label}<br>"
                f"Patients: {patients:,}<br>"
                f"Total cost: £{cost:,.0f}<br>"
                f"Cost p.a.: £{cpp:,.0f}"
            )
            # Cell annotation text, formatted per metric
            if metric == "cost":
                row_text.append(f"£{cost / 1000:.0f}k" if cost >= 1000 else f"£{cost:.0f}")
            elif metric == "cost_pp_pa":
                row_text.append(f"£{cpp:,.0f}")
            else:
                row_text.append(f"{patients:,}")
        z_values.append(row_z)
        hover_texts.append(row_hover)
        text_values.append(row_text)

    # Linear 5-stop NHS blue colorscale
    colorscale = [
        [0.0, "#E3F2FD"],
        [0.25, "#90CAF9"],
        [0.5, "#42A5F5"],
        [0.75, "#1E88E5"],
        [1.0, "#003087"],
    ]

    display_trusts = [short_trust(t) for t in trusts]
    # Narrower cell gaps once the grid gets wide.
    n_drugs = len(drugs)
    gap = 1 if n_drugs > 15 else 2

    fig = go.Figure(
        data=go.Heatmap(
            z=z_values,
            x=drugs,
            y=display_trusts,
            colorscale=colorscale,
            zmin=0,
            text=text_values,
            texttemplate="%{text}",
            textfont=dict(size=10),
            hovertext=hover_texts,
            hovertemplate="%{hovertext}<extra></extra>",
            colorbar=dict(
                title=dict(text=metric_label, font=dict(size=12, color="#425563")),
                thickness=15,
                len=0.8,
            ),
            xgap=gap,
            ygap=gap,
        )
    )

    chart_title = f"Trust × Drug — {metric_label}"
    if title:
        chart_title = f"{chart_title} — {title}"

    n_trusts = len(trusts)

    layout = _base_layout(chart_title)
    layout.update(
        xaxis=dict(
            title="",
            tickfont=dict(size=11, color="#425563"),
            tickangle=-45,
            side="bottom",
            automargin=True,
        ),
        yaxis=dict(
            title="",
            tickfont=dict(size=12, color="#425563"),
            autorange="reversed",
            automargin=True,
        ),
        margin=dict(t=60, l=8, r=80, b=120),
        height=max(400, 80 + n_trusts * 50),
    )
    fig.update_layout(**layout)

    # Add subtitle when drug cap is reached
    if capped:
        fig.add_annotation(
            text=f"Showing top {max_drugs} of {total_drug_count} drugs",
            xref="paper", yref="paper",
            x=0.5, y=1.02,
            showarrow=False,
            font=dict(size=12, color=ANNOTATION_COLOR),
        )

    return fig
|
||
|
||
|
||
def create_trust_duration_figure(
    data: list[dict],
    title: str = "",
) -> go.Figure:
    """Create grouped horizontal bar chart showing drug durations by trust.

    Each drug is one group on the y-axis and each trust one bar trace within
    the group. Drug/trust combinations with no row are drawn as zero-length
    bars with an empty hover text.

    Args:
        data: List of dicts from get_trust_durations() with keys:
            drug, trust_name, avg_days, patients.
        title: Chart title suffix.

    Returns:
        Plotly Figure, or an empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    # Unique drugs / trusts in order of first appearance.
    seen_drugs = list(dict.fromkeys(d["drug"] for d in data))
    seen_trusts = list(dict.fromkeys(d["trust_name"] for d in data))

    def short_trust(name):
        # Strip boilerplate suffixes for compact legend / hover labels.
        return name.replace(" NHS FOUNDATION TRUST", "").replace(" HOSPITALS", "")

    trust_colour_map = {
        t: TRUST_PALETTE[i % len(TRUST_PALETTE)] for i, t in enumerate(seen_trusts)
    }
    lookup = {(d["drug"], d["trust_name"]): d for d in data}

    # Reversed so the first drug ends up at the top of the horizontal chart.
    display_drugs = list(reversed(seen_drugs))

    traces = []
    for trust in seen_trusts:
        y_vals = []
        x_vals = []
        hover_texts = []
        for drug in display_drugs:
            row = lookup.get((drug, trust))
            y_vals.append(drug)
            if row:
                years = row["avg_days"] / 365.25
                x_vals.append(row["avg_days"])
                hover_texts.append(
                    f"<b>{drug}</b><br>"
                    f"{short_trust(trust)}<br>"
                    f"Avg duration: {row['avg_days']:,.0f} days ({years:.1f} yrs)<br>"
                    f"Patients: {row['patients']:,}"
                )
            else:
                x_vals.append(0)
                hover_texts.append("")

        traces.append(
            go.Bar(
                name=short_trust(trust),
                y=y_vals,
                x=x_vals,
                orientation="h",
                marker_color=trust_colour_map[trust],
                hovertemplate="%{customdata}<extra></extra>",
                customdata=hover_texts,
            )
        )

    display_title = f"Treatment Duration by Trust — {title}" if title else "Treatment Duration by Trust"
    n_trusts = len(seen_trusts)
    # Extra margin keys supplied by the legend helper for this item count.
    legend_margins = _smart_legend_margin(n_trusts)

    fig = go.Figure(data=traces)
    layout = _base_layout(display_title)
    layout.update(
        barmode="group",
        xaxis=dict(
            # The axis `titlefont` attribute is deprecated and rejected by
            # newer Plotly releases; the nested title.font form is equivalent.
            title=dict(
                text="Average Duration (days)",
                font=dict(size=13, color="#425563"),
            ),
            gridcolor="rgba(0,0,0,0.06)",
            zeroline=True,
            zerolinecolor="rgba(0,0,0,0.1)",
        ),
        yaxis=dict(title="", automargin=True, tickfont=dict(size=11, color="#425563")),
        legend=_smart_legend(n_trusts, legend_title="Trust"),
        margin=dict(t=60, l=8, **legend_margins),
        height=max(350, len(seen_drugs) * 35 + 200),
        bargap=0.15,
        bargroupgap=0.05,
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_retention_funnel_figure(
    data: list[dict],
    title: str = "",
) -> go.Figure:
    """Build a funnel chart of patient retention across treatment-line depth.

    Args:
        data: One dict per funnel stage with keys: depth, label, patients, pct.
            Only ``label``, ``patients`` and ``pct`` are read here.
        title: Chart title from filter state, appended to the base title.

    Returns:
        Plotly Figure with a single go.Funnel trace, or an empty Figure when
        ``data`` is empty.
    """
    if not data:
        return go.Figure()

    heading = "Treatment Retention"
    if title:
        heading = f"{heading} — {title}"

    stage_count = len(data)

    # NHS blue gradient, darkest first; stages beyond the fifth share the
    # palest tint.
    gradient = [
        "#003087",  # NHS Heritage Blue (1st drug)
        "#005EB8",  # NHS Blue
        "#1E88E5",  # Mid blue
        "#42A5F5",  # Light blue
        "#90CAF9",  # Pale blue
    ]
    stage_colours = gradient[:stage_count]
    stage_colours += ["#E3F2FD"] * (stage_count - len(stage_colours))

    stage_labels = [stage["label"] for stage in data]
    stage_patients = [stage["patients"] for stage in data]
    stage_text = [
        f"{stage['patients']:,} patients ({stage['pct']}%)" for stage in data
    ]

    funnel = go.Funnel(
        y=stage_labels,
        x=stage_patients,
        text=stage_text,
        textposition="inside",
        textfont=dict(family=CHART_FONT_FAMILY, size=14, color="white"),
        marker=dict(color=stage_colours),
        connector=dict(line=dict(color=GRID_COLOR, width=1)),
        hovertemplate=(
            "<b>%{y}</b><br>"
            "Patients: %{x:,}<br>"
            "%{text}<extra></extra>"
        ),
    )
    fig = go.Figure(funnel)

    layout = _base_layout(heading)
    layout.update(
        margin=dict(t=60, l=8, r=8, b=40),
        yaxis=dict(automargin=True),
        height=max(600, stage_count * 80 + 120),
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_pathway_depth_figure(
    data: list[dict],
    title: str = "",
) -> go.Figure:
    """Build a horizontal bar chart of patients who stopped at each treatment depth.

    Args:
        data: One dict per depth with keys: depth, label, patients, pct.
            Only ``label``, ``patients`` and ``pct`` are read here.
        title: Chart title from filter state, appended to the base title.

    Returns:
        Plotly Figure with a single horizontal go.Bar trace, or an empty
        Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    heading = "Pathway Depth Distribution"
    if title:
        heading = f"{heading} — {title}"

    depth_count = len(data)

    # NHS blue gradient, darkest first; depths beyond the fifth share the
    # palest tint.
    gradient = ["#003087", "#005EB8", "#1E88E5", "#42A5F5", "#90CAF9"]
    depth_colours = gradient[:depth_count]
    depth_colours += ["#E3F2FD"] * (depth_count - len(depth_colours))

    depth_labels = [entry["label"] for entry in data]
    depth_patients = [entry["patients"] for entry in data]
    depth_text = [f"{entry['patients']:,} ({entry['pct']}%)" for entry in data]

    bar = go.Bar(
        y=depth_labels,
        x=depth_patients,
        orientation="h",
        text=depth_text,
        textposition="auto",
        textfont=dict(family=CHART_FONT_FAMILY, size=13),
        marker=dict(color=depth_colours),
        hovertemplate=(
            "<b>%{y}</b><br>"
            "Patients: %{x:,}<br>"
            "<extra></extra>"
        ),
    )
    fig = go.Figure(bar)

    layout = _base_layout(heading)
    layout.update(
        margin=dict(t=60, l=8, r=24, b=40),
        # Reversed so the first entry is drawn at the top.
        yaxis=dict(
            automargin=True,
            autorange="reversed",
            title="",
        ),
        xaxis=dict(
            title="Patients",
            gridcolor=GRID_COLOR,
        ),
        height=max(600, depth_count * 70 + 120),
        bargap=0.3,
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_duration_cost_scatter_figure(
    data: list[dict],
    title: str = "",
) -> go.Figure:
    """Build a Duration vs Cost scatter plot from drug-level rows.

    Every row becomes one marker: x is ``avg_days``, y is ``cost_pp_pa``,
    marker size follows ``patients`` and colour follows ``directory`` (one
    trace per directory, so the legend groups by directory). Dashed lines at
    the median of each axis split the plot into four quadrants.

    Args:
        data: Rows with keys ``drug``, ``directory``, ``avg_days``,
            ``cost_pp_pa`` and ``patients``.
        title: Chart title suffix.

    Returns:
        Plotly Figure, or an empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    import statistics

    heading = "Duration vs Cost"
    if title:
        heading = f"{heading} — {title}"

    # Group rows per directory, keeping their input order inside each group.
    rows_by_directory: dict = {}
    for row in data:
        rows_by_directory.setdefault(row["directory"], []).append(row)
    directory_names = sorted(rows_by_directory)

    palette_size = len(DRUG_PALETTE)
    colour_for = {
        name: DRUG_PALETTE[pos % palette_size]
        for pos, name in enumerate(directory_names)
    }

    # One shared maximum so marker sizes are comparable across directories.
    largest_cohort = max((row["patients"] for row in data), default=1) or 1

    def _marker_size(patient_count):
        # Linear scale from 8 to 40, clamped at both ends.
        return max(8, min(40, 8 + 32 * (patient_count / largest_cohort)))

    fig = go.Figure()
    for name in directory_names:
        members = rows_by_directory[name]
        fig.add_trace(
            go.Scatter(
                x=[row["avg_days"] for row in members],
                y=[row["cost_pp_pa"] for row in members],
                mode="markers",
                name=name,
                marker=dict(
                    size=[_marker_size(row["patients"]) for row in members],
                    color=colour_for[name],
                    opacity=0.75,
                    line=dict(width=1, color="white"),
                ),
                text=[row["drug"] for row in members],
                customdata=[
                    [row["patients"], row["directory"], row["avg_days"], row["cost_pp_pa"]]
                    for row in members
                ],
                hovertemplate=(
                    "<b>%{text}</b><br>"
                    "Directory: %{customdata[1]}<br>"
                    "Avg duration: %{customdata[2]} days<br>"
                    "Cost p.a.: £%{customdata[3]:,.0f}<br>"
                    "Patients: %{customdata[0]:,}<br>"
                    "<extra></extra>"
                ),
            )
        )

    # Quadrant guides at the median of each axis, over all rows.
    median_days = statistics.median([row["avg_days"] for row in data])
    median_cost = statistics.median([row["cost_pp_pa"] for row in data])
    guide_font = dict(size=10, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY)

    fig.add_hline(
        y=median_cost,
        line_dash="dash",
        line_color=ANNOTATION_COLOR,
        line_width=1,
        annotation_text=f"Median £{median_cost:,.0f}",
        annotation_position="top left",
        annotation_font=dict(guide_font),
    )
    fig.add_vline(
        x=median_days,
        line_dash="dash",
        line_color=ANNOTATION_COLOR,
        line_width=1,
        annotation_text=f"Median {median_days:.0f} days",
        annotation_position="top right",
        annotation_font=dict(guide_font),
    )

    directory_count = len(directory_names)
    legend_spec = _smart_legend(directory_count, "Directory")
    # Extra margin keys supplied by the legend helper for this item count.
    extra_margins = _smart_legend_margin(directory_count)

    layout = _base_layout(heading)
    layout.update(
        margin=dict(t=60, l=8, **extra_margins),
        height=600,
        xaxis=dict(
            title="Average Treatment Duration (days)",
            gridcolor=GRID_COLOR,
            zeroline=False,
        ),
        yaxis=dict(
            title="Cost per Patient per Annum (£)",
            gridcolor=GRID_COLOR,
            automargin=True,
            zeroline=False,
        ),
        legend=legend_spec,
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_drug_network_figure(data: dict, title: str = "") -> go.Figure:
    """Build a circular network graph of drugs and the links between them.

    Every node is placed on the unit circle in input order. Marker size grows
    with the node's ``total_patients`` (clamped to 12–50), while each edge's
    line width (0.5–6) and opacity (0.15–0.7) grow with its ``patients`` count
    relative to the largest edge.

    Args:
        data: Dict with optional keys ``nodes`` (dicts carrying ``name`` and
            ``total_patients``) and ``edges`` (dicts carrying ``source``,
            ``target`` and ``patients``).
        title: Optional suffix appended to the "Drug Network" chart title.

    Returns:
        Plotly Figure; an empty Figure when there are no nodes.
    """
    import math

    node_list = data.get("nodes", [])
    edge_list = data.get("edges", [])
    if not node_list:
        return go.Figure()

    if title:
        display_title = f"Drug Network — {title}"
    else:
        display_title = "Drug Network"

    node_count = len(node_list)
    labels = [node["name"] for node in node_list]
    patient_counts = [node["total_patients"] for node in node_list]
    # If a name occurs twice, the later position wins (plain dict semantics).
    index_of = {label: pos for pos, label in enumerate(labels)}

    # Evenly spaced points on the unit circle, one per node.
    xs = []
    ys = []
    for pos in range(node_count):
        theta = 2 * math.pi * pos / node_count
        xs.append(math.cos(theta))
        ys.append(math.sin(theta))

    fig = go.Figure()

    # One trace per edge, because line width/colour are per-trace properties.
    # The trailing ``or 1`` guards against dividing by a zero maximum.
    busiest_edge = max((link["patients"] for link in edge_list), default=1) or 1
    for link in edge_list:
        start = index_of.get(link["source"])
        end = index_of.get(link["target"])
        if start is None or end is None:
            # Edge refers to a drug that is not among the nodes.
            continue

        share = link["patients"] / busiest_edge
        line_width = max(0.5, min(6, 0.5 + 5.5 * share))
        alpha = max(0.15, min(0.7, 0.15 + 0.55 * share))

        edge_trace = go.Scatter(
            x=[xs[start], xs[end]],
            y=[ys[start], ys[end]],
            mode="lines",
            line=dict(width=line_width, color=f"rgba(0,94,184,{alpha})"),
            hoverinfo="skip",
            showlegend=False,
        )
        fig.add_trace(edge_trace)

    # Nodes are added last so the markers are drawn over the edge lines.
    largest_node = max(patient_counts, default=1) or 1
    marker_sizes = [
        max(12, min(50, 12 + 38 * (count / largest_node)))
        for count in patient_counts
    ]
    marker_colors = [
        DRUG_PALETTE[pos % len(DRUG_PALETTE)] for pos in range(node_count)
    ]
    node_hover = (
        "<b>%{text}</b><br>"
        "Patients: %{customdata[0]:,}<br>"
        "<extra></extra>"
    )

    node_trace = go.Scatter(
        x=xs,
        y=ys,
        mode="markers+text",
        marker=dict(
            size=marker_sizes,
            color=marker_colors,
            line=dict(width=1.5, color="white"),
        ),
        text=labels,
        textposition="top center",
        textfont=dict(size=9, family=CHART_FONT_FAMILY),
        customdata=[[count] for count in patient_counts],
        hovertemplate=node_hover,
        showlegend=False,
    )
    fig.add_trace(node_trace)

    layout = _base_layout(display_title)
    layout.update(
        margin=dict(t=60, l=24, r=24, b=24),
        height=600,
        # Tie the x scale to the y scale so the circle is not stretched.
        xaxis=dict(visible=False, scaleanchor="y", scaleratio=1),
        yaxis=dict(visible=False),
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_drug_timeline_figure(data: list[dict], title: str = "") -> go.Figure:
    """Create a Gantt-style timeline showing when each drug cohort was active.

    Each horizontal bar spans from first_seen to last_seen for a drug,
    grouped by directory, with color indicating directory and text showing
    patient count.

    The caller's ``data`` is treated as read-only: parsed dates and the
    derived duration are stored on private shallow copies of each row, and
    sorting is done on a new list. (Previously the helper keys ``_fs``,
    ``_ls`` and ``_duration_days`` were written into the caller's dicts and
    the caller's list was re-sorted in place.)

    Args:
        data: List of dicts with keys: drug, directory, first_seen, last_seen,
            patients, cost_pp_pa. ``first_seen`` / ``last_seen`` must be
            strings accepted by ``datetime.fromisoformat``.
        title: Chart title.

    Returns:
        Plotly Figure with horizontal bars; an empty Figure when ``data`` is
        empty.
    """
    if not data:
        return go.Figure()

    from datetime import datetime

    display_title = title or "Drug Timeline"

    # Parse dates into shallow copies so the input rows are never modified.
    rows = []
    for source_row in data:
        first_seen = datetime.fromisoformat(source_row["first_seen"])
        last_seen = datetime.fromisoformat(source_row["last_seen"])
        rows.append({
            **source_row,
            "_fs": first_seen,
            "_ls": last_seen,
            # At least one day so a same-day cohort still gets a visible bar.
            "_duration_days": max(1, (last_seen - first_seen).days),
        })

    # Sort: by directory alphabetically, then by first_seen ascending
    rows.sort(key=lambda r: (r["directory"], r["_fs"]))

    # Assign colors by directory (first-appearance order after sorting)
    directories = list(dict.fromkeys(r["directory"] for r in rows))
    dir_colors = {
        name: DRUG_PALETTE[i % len(DRUG_PALETTE)]
        for i, name in enumerate(directories)
    }

    # Build y-axis labels: "Drug (Directory)" for multi-directory views, just "Drug" for single
    single_directory = len(directories) == 1
    if single_directory:
        y_labels = [r["drug"] for r in rows]
    else:
        y_labels = [f"{r['drug']} ({r['directory']})" for r in rows]

    # One trace per bar; legendgroup ties bars of a directory together and
    # only the first bar of each directory contributes a legend entry.
    fig = go.Figure()
    dir_legend_shown = set()

    for r, y_label in zip(rows, y_labels):
        directory = r["directory"]
        show_legend = directory not in dir_legend_shown
        dir_legend_shown.add(directory)

        # The bar starts at `base` (a datetime) and its length `x` is given in
        # milliseconds — see the duration_ms conversion below.
        duration_ms = r["_duration_days"] * 86_400_000  # days → milliseconds
        patients = r["patients"]
        cost = r["cost_pp_pa"]

        fig.add_trace(
            go.Bar(
                y=[y_label],
                x=[duration_ms],
                base=[r["_fs"]],
                orientation="h",
                marker=dict(
                    color=dir_colors[directory],
                    line=dict(width=0),
                ),
                name=directory,
                legendgroup=directory,
                showlegend=show_legend,
                text=f"{patients:,}",
                textposition="inside",
                textfont=dict(color="white", size=10),
                hovertemplate=(
                    f"<b>{r['drug']}</b><br>"
                    f"Directory: {directory}<br>"
                    f"First seen: {r['_fs'].strftime('%b %Y')}<br>"
                    f"Last seen: {r['_ls'].strftime('%b %Y')}<br>"
                    f"Duration: {r['_duration_days']:,} days<br>"
                    f"Patients: {patients:,}<br>"
                    f"Cost p.a.: £{cost:,.0f}"
                    "<extra></extra>"
                ),
            )
        )

    # Layout: grow the figure with the number of bars, never below 600px.
    n_bars = len(rows)
    bar_height = 28
    dynamic_height = max(600, n_bars * bar_height + 120)

    n_dirs = len(directories)
    legend_margins = _smart_legend_margin(n_dirs)
    legend = _smart_legend(n_dirs, legend_title="Directory")

    layout = _base_layout(display_title)
    layout.update(
        xaxis=dict(
            type="date",
            gridcolor=GRID_COLOR,
            dtick="M6",
            tickformat="%b\n%Y",
        ),
        yaxis=dict(
            automargin=True,
            autorange="reversed",
            tickfont=dict(size=11),
        ),
        barmode="overlay",
        height=dynamic_height,
        margin=dict(t=60, l=8, **legend_margins),
        legend=legend,
        bargap=0.3,
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_dosing_distribution_figure(
    data: list[dict], title: str = ""
) -> go.Figure:
    """Build a horizontal bar chart of average administered doses per drug.

    One bar is drawn per input row, coloured by the row's directory. Rows are
    ordered by ``avg_doses`` ascending before plotting.

    Args:
        data: list of dicts with keys: drug, directory, avg_doses, patients
        title: chart title suffix

    Returns:
        Plotly Figure; an empty Figure when ``data`` is empty.
    """
    if not data:
        return go.Figure()

    if title:
        display_title = f"Average Administered Doses — {title}"
    else:
        display_title = "Average Administered Doses"

    # Colour per directory, assigned in alphabetical directory order.
    directories = sorted({entry["directory"] for entry in data})
    palette_size = len(DRUG_PALETTE)
    colour_for = {
        directory: DRUG_PALETTE[pos % palette_size]
        for pos, directory in enumerate(directories)
    }
    n_dirs = len(directories)

    # Ascending by avg_doses (stable for ties).
    # NOTE(review): presumably so the largest bar ends up at the top of the
    # horizontal chart — confirm against the rendered output.
    ordered = sorted(data, key=lambda entry: entry["avg_doses"])

    fig = go.Figure()

    # One trace per bar; only the first bar of each directory gets a legend
    # entry, and legendgroup links the rest to it.
    directories_in_legend = set()
    for entry in ordered:
        directory = entry["directory"]
        first_of_directory = directory not in directories_in_legend
        directories_in_legend.add(directory)

        # Directory is appended to the label only when several are present.
        if n_dirs == 1:
            bar_label = entry["drug"]
        else:
            bar_label = f"{entry['drug']} ({entry['directory']})"

        hover = (
            f"<b>{entry['drug']}</b><br>"
            f"Directory: {directory}<br>"
            f"Avg doses: {entry['avg_doses']:.1f}<br>"
            f"Patients: {entry['patients']:,}"
            "<extra></extra>"
        )

        bar = go.Bar(
            y=[bar_label],
            x=[entry["avg_doses"]],
            orientation="h",
            marker_color=colour_for[directory],
            name=directory,
            showlegend=first_of_directory,
            legendgroup=directory,
            text=[f"{entry['avg_doses']:.0f}"],
            textposition="inside",
            textfont=dict(color="white", size=11),
            hovertemplate=hover,
        )
        fig.add_trace(bar)

    # Figure height grows with the number of bars, never below 600px.
    bar_height = 24
    dynamic_height = max(600, len(ordered) * bar_height + 120)

    legend_margins = _smart_legend_margin(n_dirs)
    legend = _smart_legend(n_dirs, legend_title="Directory")

    layout = _base_layout(display_title)
    layout.update(
        xaxis=dict(
            title="Average Doses Administered",
            gridcolor=GRID_COLOR,
            zeroline=False,
        ),
        yaxis=dict(
            automargin=True,
            tickfont=dict(size=11),
        ),
        barmode="overlay",
        height=dynamic_height,
        margin=dict(t=60, l=8, **legend_margins),
        legend=legend,
        bargap=0.3,
    )
    fig.update_layout(**layout)

    return fig
|
||
|
||
|
||
def create_trend_figure(
    data: list[dict],
    title: str = "",
    metric: str = "patients",
) -> go.Figure:
    """Build a multi-series line chart of values over time.

    Rows are grouped by ``name`` into one line per name; within a line the
    points are ordered by ``period_end``. Lines are added in sorted name
    order and coloured by cycling through ``DRUG_PALETTE``.

    Args:
        data: List of dicts with keys: period_end, name, value. A missing
            ``name`` falls back to "" and a missing ``value`` to 0.
        title: Chart title (defaults to "Temporal Trends").
        metric: "patients", "total_cost", or "cost_pp_pa"; only used to pick
            the y-axis label ("Value" for anything else).

    Returns:
        Plotly Figure. With no data, a figure holding only a centred
        "no trend data" annotation.
    """
    display_title = title or "Temporal Trends"

    if not data:
        # Placeholder figure telling the user how to generate the data.
        placeholder = go.Figure()
        placeholder.add_annotation(
            text="No trend data available.<br>Run <b>python -m cli.compute_trends</b> to generate.",
            xref="paper", yref="paper", x=0.5, y=0.5,
            showarrow=False,
            font=dict(size=16, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY),
        )
        placeholder.update_layout(**_base_layout(display_title))
        return placeholder

    # Bucket (period_end, value) pairs by series name.
    grouped = {}
    for record in data:
        label = record.get("name", "")
        point = (record["period_end"], record.get("value", 0))
        grouped.setdefault(label, []).append(point)

    n_series = len(grouped)
    fig = go.Figure()

    for idx, (label, pairs) in enumerate(sorted(grouped.items())):
        # Order by period only; ties keep their input order (stable sort).
        ordered = sorted(pairs, key=lambda pair: pair[0])
        periods = [period for period, _ in ordered]
        values = [value for _, value in ordered]
        colour = DRUG_PALETTE[idx % len(DRUG_PALETTE)]

        hover = (
            f"<b>{label}</b><br>"
            "Period: %{x}<br>"
            "Value: %{y:,.0f}<extra></extra>"
        )
        line_trace = go.Scatter(
            x=periods,
            y=values,
            mode="lines+markers",
            name=label,
            customdata=[label] * len(periods),
            line=dict(color=colour, width=2),
            marker=dict(color=colour, size=6),
            hovertemplate=hover,
        )
        fig.add_trace(line_trace)

    # Unknown metrics fall back to a generic axis label.
    y_label = {
        "patients": "Patients",
        "total_cost": "Total Cost (£)",
        "cost_pp_pa": "Cost per Patient p.a. (£)",
    }.get(metric, "Value")

    legend = _smart_legend(n_series)
    legend_margins = _smart_legend_margin(n_series)

    x_axis = dict(
        title="Period",
        gridcolor=GRID_COLOR,
        type="date",
        dtick="M6",
        tickformat="%b %Y",
    )
    y_axis = dict(
        title=y_label,
        gridcolor=GRID_COLOR,
        zeroline=True,
        zerolinecolor=GRID_COLOR,
    )

    layout = _base_layout(display_title)
    layout.update(
        xaxis=x_axis,
        yaxis=y_axis,
        margin=dict(t=60, l=8, **legend_margins),
        legend=legend,
        hovermode="x unified",
    )
    fig.update_layout(**layout)

    return fig
|