feat: add Drug Switching Sankey diagram (Task 9.6)
This commit is contained in:
@@ -703,6 +703,137 @@ def create_cost_waterfall_figure(
|
||||
return fig
|
||||
|
||||
|
||||
def create_sankey_figure(
|
||||
data: dict,
|
||||
title: str = "",
|
||||
) -> go.Figure:
|
||||
"""Create Sankey diagram showing drug switching flows between treatment lines.
|
||||
|
||||
Args:
|
||||
data: Dict from get_drug_transitions() with keys:
|
||||
nodes: [{name}] — drug names with ordinal suffixes (e.g., "ADALIMUMAB (1st)")
|
||||
links: [{source_idx, target_idx, patients}] — transitions between drugs
|
||||
title: Chart title suffix (filter description).
|
||||
|
||||
Returns:
|
||||
Plotly Figure with Sankey diagram.
|
||||
"""
|
||||
import re
|
||||
|
||||
nodes = data.get("nodes", [])
|
||||
links = data.get("links", [])
|
||||
|
||||
if not nodes or not links:
|
||||
return go.Figure()
|
||||
|
||||
# NHS colour palette — one colour per unique base drug name
|
||||
nhs_colours = [
|
||||
"#005EB8", "#003087", "#41B6E6", "#0066CC", "#1E88E5",
|
||||
"#4FC3F7", "#009639", "#ED8B00", "#768692", "#AE2573",
|
||||
"#8A1538", "#330072", "#DA291C", "#00A499", "#425563",
|
||||
]
|
||||
|
||||
# Extract base drug name (strip ordinal suffix) for colour consistency
|
||||
def base_drug(name: str) -> str:
|
||||
return re.sub(r"\s*\(\d+(?:st|nd|rd|th)\)\s*$", "", name)
|
||||
|
||||
unique_bases = []
|
||||
for n in nodes:
|
||||
b = base_drug(n["name"])
|
||||
if b not in unique_bases:
|
||||
unique_bases.append(b)
|
||||
base_colour_map = {b: nhs_colours[i % len(nhs_colours)] for i, b in enumerate(unique_bases)}
|
||||
|
||||
# Node colours — same drug gets same colour regardless of treatment line
|
||||
node_colours = [base_colour_map[base_drug(n["name"])] for n in nodes]
|
||||
|
||||
# Node labels — format nicely
|
||||
node_labels = [n["name"] for n in nodes]
|
||||
|
||||
# Link colours — use source node colour at 40% opacity for visual clarity
|
||||
def hex_to_rgba(hex_colour: str, alpha: float) -> str:
|
||||
h = hex_colour.lstrip("#")
|
||||
r, g, b = int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
|
||||
return f"rgba({r},{g},{b},{alpha})"
|
||||
|
||||
link_colours = [
|
||||
hex_to_rgba(node_colours[link["source_idx"]], 0.35)
|
||||
for link in links
|
||||
]
|
||||
|
||||
# Build hover text for links
|
||||
link_hovers = [
|
||||
f"{node_labels[link['source_idx']]} → {node_labels[link['target_idx']]}"
|
||||
f"<br>Patients: {link['patients']:,}"
|
||||
for link in links
|
||||
]
|
||||
|
||||
# Compute total patients per node for node hover
|
||||
node_patients = [0] * len(nodes)
|
||||
for link in links:
|
||||
node_patients[link["source_idx"]] += link["patients"]
|
||||
# For terminal nodes (no outgoing), use incoming total
|
||||
node_incoming = [0] * len(nodes)
|
||||
for link in links:
|
||||
node_incoming[link["target_idx"]] += link["patients"]
|
||||
node_hover = []
|
||||
for i, n in enumerate(nodes):
|
||||
out_p = node_patients[i]
|
||||
in_p = node_incoming[i]
|
||||
total = max(out_p, in_p)
|
||||
node_hover.append(f"<b>{n['name']}</b><br>Patients: {total:,}")
|
||||
|
||||
fig = go.Figure(
|
||||
go.Sankey(
|
||||
arrangement="snap",
|
||||
node=dict(
|
||||
pad=20,
|
||||
thickness=25,
|
||||
line=dict(color="#FFFFFF", width=1),
|
||||
label=node_labels,
|
||||
color=node_colours,
|
||||
customdata=node_hover,
|
||||
hovertemplate="%{customdata}<extra></extra>",
|
||||
),
|
||||
link=dict(
|
||||
source=[link["source_idx"] for link in links],
|
||||
target=[link["target_idx"] for link in links],
|
||||
value=[link["patients"] for link in links],
|
||||
color=link_colours,
|
||||
customdata=link_hovers,
|
||||
hovertemplate="%{customdata}<extra></extra>",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
chart_title = "Drug Switching Flows"
|
||||
if title:
|
||||
chart_title = f"{chart_title} — {title}"
|
||||
|
||||
fig.update_layout(
|
||||
title=dict(
|
||||
text=chart_title,
|
||||
font=dict(
|
||||
family="Source Sans 3, system-ui, sans-serif",
|
||||
size=18,
|
||||
color="#003087",
|
||||
),
|
||||
x=0.5,
|
||||
xanchor="center",
|
||||
),
|
||||
font=dict(
|
||||
family="Source Sans 3, system-ui, sans-serif",
|
||||
size=12,
|
||||
),
|
||||
paper_bgcolor="rgba(0,0,0,0)",
|
||||
plot_bgcolor="rgba(0,0,0,0)",
|
||||
margin=dict(t=60, l=30, r=30, b=30),
|
||||
height=max(500, len(unique_bases) * 35 + 200),
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def save_figure_html(
|
||||
fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
|
||||
) -> str:
|
||||
|
||||
Reference in New Issue
Block a user