feat: drug switching network graph tab (Task C.4)
This commit is contained in:
@@ -1285,6 +1285,94 @@ def get_duration_cost_scatter(
|
||||
conn.close()
|
||||
|
||||
|
||||
def get_drug_network(
|
||||
db_path: Path,
|
||||
date_filter_id: str,
|
||||
chart_type: str,
|
||||
directory: Optional[str] = None,
|
||||
trust: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Build undirected drug co-occurrence network from pathway data.
|
||||
|
||||
Unlike get_drug_transitions() (directed, with ordinal suffixes for Sankey),
|
||||
this returns plain drug names with undirected edges representing co-occurrence
|
||||
in patient pathways.
|
||||
|
||||
Returns dict with:
|
||||
nodes: [{name, total_patients}] — unique drug names sorted by patient count
|
||||
edges: [{source, target, patients}] — undirected co-occurrence links
|
||||
"""
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
|
||||
params: list = [date_filter_id, chart_type]
|
||||
|
||||
if directory:
|
||||
where.append("directory = ?")
|
||||
params.append(directory)
|
||||
if trust:
|
||||
where.append("trust_name = ?")
|
||||
params.append(trust)
|
||||
|
||||
query = f"""
|
||||
SELECT value AS patients, drug_sequence
|
||||
FROM pathway_nodes
|
||||
WHERE {' AND '.join(where)}
|
||||
"""
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
|
||||
# Also get level 3 nodes for per-drug patient totals
|
||||
where_l3 = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
|
||||
params_l3: list = [date_filter_id, chart_type]
|
||||
if directory:
|
||||
where_l3.append("directory = ?")
|
||||
params_l3.append(directory)
|
||||
if trust:
|
||||
where_l3.append("trust_name = ?")
|
||||
params_l3.append(trust)
|
||||
|
||||
query_l3 = f"""
|
||||
SELECT labels AS drug, SUM(value) AS total_patients
|
||||
FROM pathway_nodes
|
||||
WHERE {' AND '.join(where_l3)}
|
||||
GROUP BY labels
|
||||
"""
|
||||
l3_rows = conn.execute(query_l3, params_l3).fetchall()
|
||||
node_patients = {r["drug"]: r["total_patients"] or 0 for r in l3_rows}
|
||||
|
||||
# Build undirected edges from pathway sequences
|
||||
edge_agg = {}
|
||||
for r in rows:
|
||||
drugs = [d for d in (r["drug_sequence"] or "").split("|") if d]
|
||||
patients = r["patients"] or 0
|
||||
if len(drugs) < 2 or patients == 0:
|
||||
continue
|
||||
|
||||
# Adjacent drug pairs (undirected: sort to avoid A→B and B→A duplicates)
|
||||
for i in range(len(drugs) - 1):
|
||||
pair = tuple(sorted([drugs[i], drugs[i + 1]]))
|
||||
edge_agg[pair] = edge_agg.get(pair, 0) + patients
|
||||
|
||||
# Build result
|
||||
nodes = [
|
||||
{"name": name, "total_patients": pts}
|
||||
for name, pts in sorted(node_patients.items(), key=lambda x: -x[1])
|
||||
if pts > 0
|
||||
]
|
||||
|
||||
edges = [
|
||||
{"source": src, "target": tgt, "patients": pts}
|
||||
for (src, tgt), pts in sorted(edge_agg.items(), key=lambda x: -x[1])
|
||||
]
|
||||
|
||||
return {"nodes": nodes, "edges": edges}
|
||||
except sqlite3.Error:
|
||||
return {"nodes": [], "edges": []}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def get_directorate_summary(
|
||||
db_path: Path,
|
||||
date_filter_id: str,
|
||||
|
||||
@@ -1997,3 +1997,91 @@ def create_duration_cost_scatter_figure(
|
||||
fig.update_layout(**layout)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def create_drug_network_figure(data: dict, title: str = "") -> go.Figure:
|
||||
"""Create a drug co-occurrence network graph.
|
||||
|
||||
Nodes are drugs arranged in a circle, edges show co-occurrence in pathways.
|
||||
Node size = total patients, edge width = switching flow between drugs.
|
||||
"""
|
||||
import math
|
||||
|
||||
nodes = data.get("nodes", [])
|
||||
edges = data.get("edges", [])
|
||||
|
||||
if not nodes:
|
||||
return go.Figure()
|
||||
|
||||
display_title = f"Drug Network — {title}" if title else "Drug Network"
|
||||
|
||||
# Circular layout
|
||||
n = len(nodes)
|
||||
node_names = [nd["name"] for nd in nodes]
|
||||
node_patients = [nd["total_patients"] for nd in nodes]
|
||||
name_to_idx = {nd["name"]: i for i, nd in enumerate(nodes)}
|
||||
|
||||
angles = [2 * math.pi * i / n for i in range(n)]
|
||||
x_pos = [math.cos(a) for a in angles]
|
||||
y_pos = [math.sin(a) for a in angles]
|
||||
|
||||
fig = go.Figure()
|
||||
|
||||
# Draw edges as individual traces (each gets its own width)
|
||||
max_edge_patients = max((e["patients"] for e in edges), default=1) or 1
|
||||
for edge in edges:
|
||||
src_idx = name_to_idx.get(edge["source"])
|
||||
tgt_idx = name_to_idx.get(edge["target"])
|
||||
if src_idx is None or tgt_idx is None:
|
||||
continue
|
||||
|
||||
# Scale width: min 0.5, max 6
|
||||
width = max(0.5, min(6, 0.5 + 5.5 * (edge["patients"] / max_edge_patients)))
|
||||
# Opacity scales with relative strength
|
||||
opacity = max(0.15, min(0.7, 0.15 + 0.55 * (edge["patients"] / max_edge_patients)))
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=[x_pos[src_idx], x_pos[tgt_idx]],
|
||||
y=[y_pos[src_idx], y_pos[tgt_idx]],
|
||||
mode="lines",
|
||||
line=dict(width=width, color=f"rgba(0,94,184,{opacity})"),
|
||||
hoverinfo="skip",
|
||||
showlegend=False,
|
||||
))
|
||||
|
||||
# Draw nodes
|
||||
max_patients = max(node_patients, default=1) or 1
|
||||
sizes = [max(12, min(50, 12 + 38 * (p / max_patients))) for p in node_patients]
|
||||
colors = [DRUG_PALETTE[i % len(DRUG_PALETTE)] for i in range(n)]
|
||||
|
||||
fig.add_trace(go.Scatter(
|
||||
x=x_pos,
|
||||
y=y_pos,
|
||||
mode="markers+text",
|
||||
marker=dict(
|
||||
size=sizes,
|
||||
color=colors,
|
||||
line=dict(width=1.5, color="white"),
|
||||
),
|
||||
text=node_names,
|
||||
textposition="top center",
|
||||
textfont=dict(size=9, family=CHART_FONT_FAMILY),
|
||||
customdata=[[p] for p in node_patients],
|
||||
hovertemplate=(
|
||||
"<b>%{text}</b><br>"
|
||||
"Patients: %{customdata[0]:,}<br>"
|
||||
"<extra></extra>"
|
||||
),
|
||||
showlegend=False,
|
||||
))
|
||||
|
||||
layout = _base_layout(display_title)
|
||||
layout.update(
|
||||
margin=dict(t=60, l=24, r=24, b=24),
|
||||
xaxis=dict(visible=False, scaleanchor="y", scaleratio=1),
|
||||
yaxis=dict(visible=False),
|
||||
height=600,
|
||||
)
|
||||
fig.update_layout(**layout)
|
||||
|
||||
return fig
|
||||
|
||||
Reference in New Issue
Block a user