feat: drug switching network graph tab (Task C.4)

This commit is contained in:
Andrew Charlwood
2026-02-07 03:32:14 +00:00
parent ac688c9ac0
commit 1405476818
6 changed files with 224 additions and 8 deletions
+88
View File
@@ -1285,6 +1285,94 @@ def get_duration_cost_scatter(
conn.close()
def get_drug_network(
db_path: Path,
date_filter_id: str,
chart_type: str,
directory: Optional[str] = None,
trust: Optional[str] = None,
) -> dict:
"""Build undirected drug co-occurrence network from pathway data.
Unlike get_drug_transitions() (directed, with ordinal suffixes for Sankey),
this returns plain drug names with undirected edges representing co-occurrence
in patient pathways.
Returns dict with:
nodes: [{name, total_patients}] — unique drug names sorted by patient count
edges: [{source, target, patients}] — undirected co-occurrence links
"""
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
params: list = [date_filter_id, chart_type]
if directory:
where.append("directory = ?")
params.append(directory)
if trust:
where.append("trust_name = ?")
params.append(trust)
query = f"""
SELECT value AS patients, drug_sequence
FROM pathway_nodes
WHERE {' AND '.join(where)}
"""
rows = conn.execute(query, params).fetchall()
# Also get level 3 nodes for per-drug patient totals
where_l3 = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
params_l3: list = [date_filter_id, chart_type]
if directory:
where_l3.append("directory = ?")
params_l3.append(directory)
if trust:
where_l3.append("trust_name = ?")
params_l3.append(trust)
query_l3 = f"""
SELECT labels AS drug, SUM(value) AS total_patients
FROM pathway_nodes
WHERE {' AND '.join(where_l3)}
GROUP BY labels
"""
l3_rows = conn.execute(query_l3, params_l3).fetchall()
node_patients = {r["drug"]: r["total_patients"] or 0 for r in l3_rows}
# Build undirected edges from pathway sequences
edge_agg = {}
for r in rows:
drugs = [d for d in (r["drug_sequence"] or "").split("|") if d]
patients = r["patients"] or 0
if len(drugs) < 2 or patients == 0:
continue
# Adjacent drug pairs (undirected: sort to avoid A→B and B→A duplicates)
for i in range(len(drugs) - 1):
pair = tuple(sorted([drugs[i], drugs[i + 1]]))
edge_agg[pair] = edge_agg.get(pair, 0) + patients
# Build result
nodes = [
{"name": name, "total_patients": pts}
for name, pts in sorted(node_patients.items(), key=lambda x: -x[1])
if pts > 0
]
edges = [
{"source": src, "target": tgt, "patients": pts}
for (src, tgt), pts in sorted(edge_agg.items(), key=lambda x: -x[1])
]
return {"nodes": nodes, "edges": edges}
except sqlite3.Error:
return {"nodes": [], "edges": []}
finally:
conn.close()
def get_directorate_summary(
db_path: Path,
date_filter_id: str,
+88
View File
@@ -1997,3 +1997,91 @@ def create_duration_cost_scatter_figure(
fig.update_layout(**layout)
return fig
def create_drug_network_figure(data: dict, title: str = "") -> go.Figure:
"""Create a drug co-occurrence network graph.
Nodes are drugs arranged in a circle, edges show co-occurrence in pathways.
Node size = total patients, edge width = switching flow between drugs.
"""
import math
nodes = data.get("nodes", [])
edges = data.get("edges", [])
if not nodes:
return go.Figure()
display_title = f"Drug Network — {title}" if title else "Drug Network"
# Circular layout
n = len(nodes)
node_names = [nd["name"] for nd in nodes]
node_patients = [nd["total_patients"] for nd in nodes]
name_to_idx = {nd["name"]: i for i, nd in enumerate(nodes)}
angles = [2 * math.pi * i / n for i in range(n)]
x_pos = [math.cos(a) for a in angles]
y_pos = [math.sin(a) for a in angles]
fig = go.Figure()
# Draw edges as individual traces (each gets its own width)
max_edge_patients = max((e["patients"] for e in edges), default=1) or 1
for edge in edges:
src_idx = name_to_idx.get(edge["source"])
tgt_idx = name_to_idx.get(edge["target"])
if src_idx is None or tgt_idx is None:
continue
# Scale width: min 0.5, max 6
width = max(0.5, min(6, 0.5 + 5.5 * (edge["patients"] / max_edge_patients)))
# Opacity scales with relative strength
opacity = max(0.15, min(0.7, 0.15 + 0.55 * (edge["patients"] / max_edge_patients)))
fig.add_trace(go.Scatter(
x=[x_pos[src_idx], x_pos[tgt_idx]],
y=[y_pos[src_idx], y_pos[tgt_idx]],
mode="lines",
line=dict(width=width, color=f"rgba(0,94,184,{opacity})"),
hoverinfo="skip",
showlegend=False,
))
# Draw nodes
max_patients = max(node_patients, default=1) or 1
sizes = [max(12, min(50, 12 + 38 * (p / max_patients))) for p in node_patients]
colors = [DRUG_PALETTE[i % len(DRUG_PALETTE)] for i in range(n)]
fig.add_trace(go.Scatter(
x=x_pos,
y=y_pos,
mode="markers+text",
marker=dict(
size=sizes,
color=colors,
line=dict(width=1.5, color="white"),
),
text=node_names,
textposition="top center",
textfont=dict(size=9, family=CHART_FONT_FAMILY),
customdata=[[p] for p in node_patients],
hovertemplate=(
"<b>%{text}</b><br>"
"Patients: %{customdata[0]:,}<br>"
"<extra></extra>"
),
showlegend=False,
))
layout = _base_layout(display_title)
layout.update(
margin=dict(t=60, l=24, r=24, b=24),
xaxis=dict(visible=False, scaleanchor="y", scaleratio=1),
yaxis=dict(visible=False),
height=600,
)
fig.update_layout(**layout)
return fig