feat: drug switching network graph tab (Task C.4)

This commit is contained in:
Andrew Charlwood
2026-02-07 03:32:14 +00:00
parent ac688c9ac0
commit 1405476818
6 changed files with 224 additions and 8 deletions
+8 -8
View File
@@ -175,14 +175,14 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard.
- **Checkpoint**: Scatter tab shows drugs by duration vs cost with directorate coloring - **Checkpoint**: Scatter tab shows drugs by duration vs cost with directorate coloring
### C.4 Drug switching network graph ### C.4 Drug switching network graph
- [ ] Create modified variant of `get_drug_transitions()` in pathway_queries.py that returns undirected edges without ordinal suffixes: - [x] Create `get_drug_network()` in pathway_queries.py — undirected edges without ordinal suffixes, node patients from level 3, edge co-occurrence from level 4+
- `get_drug_network(db_path, filter_id, chart_type, directory, trust)` → `{nodes: [{name, total_patients}], edges: [{source, target, patients}]}` - [x] Add thin wrapper in `dash_app/data/queries.py`
- [ ] Add thin wrapper in `dash_app/data/queries.py` - [x] Create `create_drug_network_figure(data, title)` in `src/visualization/plotly_generator.py`:
- [ ] Create `create_drug_network_figure(data, title)` in `src/visualization/plotly_generator.py`: - Circular layout using `go.Scatter` for nodes + individual edge traces as lines
- Use `go.Scatter` for nodes (circular layout) + edges (lines) - Node size = total patients (12–50px), edge width = switching flow (0.5–6px), edge opacity scales with strength
- Node size = total patients, edge width = switching flow - `DRUG_PALETTE` for node colors, NHS Blue (`rgba(0,94,184,...)`) for edges
- `DRUG_PALETTE` for node colors - [x] Added as separate "Network" tab (7th tab: Icicle, Sankey, Heatmap, Funnel, Depth, Scatter, Network)
- [ ] Add as sub-toggle within Sankey tab (e.g., "Flow" vs "Network" toggle) or as separate "Network" tab - [x] Added `_render_network()` helper and dispatch case in `chart.py`
- **Checkpoint**: Network view shows drug switching as a graph alternative to Sankey - **Checkpoint**: Network view shows drug switching as a graph alternative to Sankey
--- ---
+28
View File
@@ -342,6 +342,31 @@ def _render_scatter(app_state, title):
return create_duration_cost_scatter_figure(data, title) return create_duration_cost_scatter_figure(data, title)
def _render_network(app_state, title):
    """Render the drug network figure for the current filter state.

    Reads ``date_filter_id``, ``chart_type``, ``selected_directorates`` and
    ``selected_trusts`` from *app_state* (which may be ``None``), queries the
    network data and hands it to ``create_drug_network_figure``.  A directorate
    or trust filter is applied only when exactly one is selected.

    Returns an empty placeholder figure when the query raises or when the
    result contains no nodes.
    """
    from dash_app.data.queries import get_drug_network
    from visualization.plotly_generator import create_drug_network_figure

    def _single_or_none(selection):
        # The query takes at most one directorate/trust; zero or several
        # selections mean "do not filter on this dimension".
        return selection[0] if len(selection) == 1 else None

    state = app_state or {}
    filter_id = state.get("date_filter_id", "all_6mo")
    chart_type = state.get("chart_type", "directory")
    directory = _single_or_none(state.get("selected_directorates") or [])
    trust = _single_or_none(state.get("selected_trusts") or [])

    try:
        data = get_drug_network(filter_id, chart_type, directory, trust)
    except Exception:
        log.exception("Failed to load drug network data")
        return _empty_figure("Failed to load drug network data.")

    if data.get("nodes"):
        return create_drug_network_figure(data, title)
    return _empty_figure("No drug network data available.\nTry adjusting your filters.")
def register_chart_callbacks(app): def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks.""" """Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -491,6 +516,9 @@ def register_chart_callbacks(app):
elif active_tab == "scatter": elif active_tab == "scatter":
fig = _render_scatter(app_state, title) fig = _render_scatter(app_state, title)
elif active_tab == "network":
fig = _render_network(app_state, title)
else: else:
# Placeholder for charts not yet implemented # Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
+1
View File
@@ -11,6 +11,7 @@ TAB_DEFINITIONS = [
("funnel", "Funnel"), ("funnel", "Funnel"),
("depth", "Depth"), ("depth", "Depth"),
("scatter", "Scatter"), ("scatter", "Scatter"),
("network", "Network"),
] ]
# Full set retained for Trust Comparison dashboard (Phase 10.8) # Full set retained for Trust Comparison dashboard (Phase 10.8)
+11
View File
@@ -27,6 +27,7 @@ from data_processing.pathway_queries import (
get_retention_funnel as _get_retention_funnel, get_retention_funnel as _get_retention_funnel,
get_pathway_depth_distribution as _get_pathway_depth_distribution, get_pathway_depth_distribution as _get_pathway_depth_distribution,
get_duration_cost_scatter as _get_duration_cost_scatter, get_duration_cost_scatter as _get_duration_cost_scatter,
get_drug_network as _get_drug_network,
) )
DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db" DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db"
@@ -216,3 +217,13 @@ def get_duration_cost_scatter(
) -> list[dict]: ) -> list[dict]:
"""Drug-level avg_days and cost_pp_pa for scatter plot.""" """Drug-level avg_days and cost_pp_pa for scatter plot."""
return _get_duration_cost_scatter(DB_PATH, date_filter_id, chart_type, directory, trust) return _get_duration_cost_scatter(DB_PATH, date_filter_id, chart_type, directory, trust)
def get_drug_network(
    date_filter_id: str = "all_6mo",
    chart_type: str = "directory",
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> dict:
    """Return the undirected drug co-occurrence network for the network graph.

    Thin wrapper that supplies the module-level ``DB_PATH`` to the
    ``pathway_queries`` implementation and forwards every filter unchanged.
    """
    return _get_drug_network(
        db_path=DB_PATH,
        date_filter_id=date_filter_id,
        chart_type=chart_type,
        directory=directory,
        trust=trust,
    )
+88
View File
@@ -1285,6 +1285,94 @@ def get_duration_cost_scatter(
conn.close() conn.close()
def _drug_network_where(
    date_filter_id: str,
    chart_type: str,
    level_condition: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> tuple:
    """Return ``(where_sql, params)`` for one ``pathway_nodes`` network query.

    ``level_condition`` is interpolated verbatim into the SQL, so it must be a
    trusted literal such as ``"level = 3"``; every other value is bound as a
    ``?`` parameter.
    """
    clauses = ["date_filter_id = ?", "chart_type = ?", level_condition]
    params: list = [date_filter_id, chart_type]
    # Optional filters are only added when a truthy value is supplied.
    if directory:
        clauses.append("directory = ?")
        params.append(directory)
    if trust:
        clauses.append("trust_name = ?")
        params.append(trust)
    return " AND ".join(clauses), params


def get_drug_network(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> dict:
    """Build an undirected drug co-occurrence network from pathway data.

    Unlike get_drug_transitions() (directed, with ordinal suffixes for Sankey),
    this returns plain drug names with undirected edges.

    Node totals come from ``level = 3`` rows: ``value`` is summed per
    ``labels`` and drugs with a non-positive total are dropped.  Edges come
    from ``level >= 4`` rows: ``drug_sequence`` is split on ``|`` and every
    adjacent pair contributes that row's ``value`` to the (sorted) pair.

    Because nodes and edges are derived from different rows, an edge endpoint
    is not guaranteed to be present in ``nodes``.

    NOTE(review): if ``value`` on a level-4 row already includes patients who
    continue to deeper levels, summing over all ``level >= 4`` rows counts
    them more than once -- confirm against how pathway_nodes is populated.

    Args:
        db_path: Location of the SQLite database.
        date_filter_id: Value matched against ``date_filter_id``.
        chart_type: Value matched against ``chart_type``.
        directory: Optional exact match on ``directory``.
        trust: Optional exact match on ``trust_name``.

    Returns:
        dict with:
            nodes: [{name, total_patients}] sorted by patient count, descending
            edges: [{source, target, patients}] sorted by patients, descending
        Both lists are empty if the query fails with ``sqlite3.Error``.
    """
    import logging

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row

    try:
        seq_where, seq_params = _drug_network_where(
            date_filter_id, chart_type, "level >= 4", directory, trust
        )
        seq_rows = conn.execute(
            f"""
            SELECT value AS patients, drug_sequence
            FROM pathway_nodes
            WHERE {seq_where}
            """,
            seq_params,
        ).fetchall()

        # Level 3 rows supply the per-drug patient totals used for node size.
        l3_where, l3_params = _drug_network_where(
            date_filter_id, chart_type, "level = 3", directory, trust
        )
        l3_rows = conn.execute(
            f"""
            SELECT labels AS drug, SUM(value) AS total_patients
            FROM pathway_nodes
            WHERE {l3_where}
            GROUP BY labels
            """,
            l3_params,
        ).fetchall()

        node_patients = {r["drug"]: r["total_patients"] or 0 for r in l3_rows}

        # Aggregate patients per unordered adjacent drug pair.
        edge_agg: dict = {}
        for row in seq_rows:
            drugs = [d for d in (row["drug_sequence"] or "").split("|") if d]
            patients = row["patients"] or 0
            if len(drugs) < 2 or patients == 0:
                continue
            for first, second in zip(drugs, drugs[1:]):
                # Sorting the pair folds A->B and B->A into one edge.
                pair = tuple(sorted((first, second)))
                edge_agg[pair] = edge_agg.get(pair, 0) + patients

        nodes = [
            {"name": name, "total_patients": pts}
            for name, pts in sorted(node_patients.items(), key=lambda x: -x[1])
            if pts > 0
        ]
        edges = [
            {"source": src, "target": tgt, "patients": pts}
            for (src, tgt), pts in sorted(edge_agg.items(), key=lambda x: -x[1])
        ]
        return {"nodes": nodes, "edges": edges}

    except sqlite3.Error:
        # Best-effort: callers get an empty network, but the failure is
        # recorded instead of disappearing silently.
        logging.getLogger(__name__).exception(
            "Drug network query failed (date_filter_id=%s, chart_type=%s)",
            date_filter_id,
            chart_type,
        )
        return {"nodes": [], "edges": []}
    finally:
        conn.close()
def get_directorate_summary( def get_directorate_summary(
db_path: Path, db_path: Path,
date_filter_id: str, date_filter_id: str,
+88
View File
@@ -1997,3 +1997,91 @@ def create_duration_cost_scatter_figure(
fig.update_layout(**layout) fig.update_layout(**layout)
return fig return fig
def create_drug_network_figure(data: dict, title: str = "") -> go.Figure:
    """Draw a drug co-occurrence network with nodes placed on a unit circle.

    Args:
        data: dict with ``nodes`` (each having ``name`` and ``total_patients``)
            and ``edges`` (each having ``source``, ``target`` and ``patients``).
        title: Optional suffix appended to the "Drug Network" chart title.

    Returns:
        A figure with one line trace per drawable edge followed by a single
        marker+text trace for the nodes.  An empty ``go.Figure`` is returned
        when there are no nodes.  Edges whose endpoints are not in ``nodes``
        are skipped.
    """
    import math

    node_list = data.get("nodes", [])
    edge_list = data.get("edges", [])
    if not node_list:
        return go.Figure()

    if title:
        display_title = f"Drug Network — {title}"
    else:
        display_title = "Drug Network"

    # Place node i at angle 2*pi*i/count on the unit circle.
    count = len(node_list)
    labels = []
    patient_counts = []
    index_of = {}
    xs = []
    ys = []
    for position, node in enumerate(node_list):
        theta = 2 * math.pi * position / count
        labels.append(node["name"])
        patient_counts.append(node["total_patients"])
        index_of[node["name"]] = position
        xs.append(math.cos(theta))
        ys.append(math.sin(theta))

    fig = go.Figure()

    # One trace per edge: a single Scatter trace has one line width, so
    # per-edge widths need separate traces.
    strongest = max((link["patients"] for link in edge_list), default=1) or 1
    for link in edge_list:
        start = index_of.get(link["source"])
        end = index_of.get(link["target"])
        if start is None or end is None:
            continue

        share = link["patients"] / strongest
        # Width is clamped to 0.5..6 and opacity to 0.15..0.7, both linear in
        # the edge's share of the strongest edge.
        line_width = max(0.5, min(6, 0.5 + 5.5 * share))
        opacity = max(0.15, min(0.7, 0.15 + 0.55 * share))

        edge_trace = go.Scatter(
            x=[xs[start], xs[end]],
            y=[ys[start], ys[end]],
            mode="lines",
            line=dict(width=line_width, color=f"rgba(0,94,184,{opacity})"),
            hoverinfo="skip",
            showlegend=False,
        )
        fig.add_trace(edge_trace)

    # Marker size is clamped to 12..50, linear in the share of the largest node.
    busiest = max(patient_counts, default=1) or 1
    marker_sizes = [
        max(12, min(50, 12 + 38 * (patients / busiest)))
        for patients in patient_counts
    ]
    marker_colors = [DRUG_PALETTE[i % len(DRUG_PALETTE)] for i in range(count)]

    node_trace = go.Scatter(
        x=xs,
        y=ys,
        mode="markers+text",
        marker=dict(
            size=marker_sizes,
            color=marker_colors,
            line=dict(width=1.5, color="white"),
        ),
        text=labels,
        textposition="top center",
        textfont=dict(size=9, family=CHART_FONT_FAMILY),
        customdata=[[patients] for patients in patient_counts],
        hovertemplate=(
            "<b>%{text}</b><br>"
            "Patients: %{customdata[0]:,}<br>"
            "<extra></extra>"
        ),
        showlegend=False,
    )
    fig.add_trace(node_trace)

    layout = _base_layout(display_title)
    layout.update(
        margin=dict(t=60, l=24, r=24, b=24),
        # Anchoring x to y with ratio 1 keeps the layout circular.
        xaxis=dict(visible=False, scaleanchor="y", scaleratio=1),
        yaxis=dict(visible=False),
        height=600,
    )
    fig.update_layout(**layout)

    return fig