feat: drug switching network graph tab (Task C.4)
This commit is contained in:
@@ -175,14 +175,14 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard.
|
|||||||
- **Checkpoint**: Scatter tab shows drugs by duration vs cost with directorate coloring
|
- **Checkpoint**: Scatter tab shows drugs by duration vs cost with directorate coloring
|
||||||
|
|
||||||
### C.4 Drug switching network graph
|
### C.4 Drug switching network graph
|
||||||
- [ ] Create modified variant of `get_drug_transitions()` in pathway_queries.py that returns undirected edges without ordinal suffixes:
|
- [x] Create `get_drug_network()` in pathway_queries.py — undirected edges without ordinal suffixes, node patients from level 3, edge co-occurrence from level 4+
|
||||||
- `get_drug_network(db_path, filter_id, chart_type, directory, trust)` → `{nodes: [{name, total_patients}], edges: [{source, target, patients}]}`
|
- [x] Add thin wrapper in `dash_app/data/queries.py`
|
||||||
- [ ] Add thin wrapper in `dash_app/data/queries.py`
|
- [x] Create `create_drug_network_figure(data, title)` in `src/visualization/plotly_generator.py`:
|
||||||
- [ ] Create `create_drug_network_figure(data, title)` in `src/visualization/plotly_generator.py`:
|
- Circular layout using `go.Scatter` for nodes + individual edge traces as lines
|
||||||
- Use `go.Scatter` for nodes (circular layout) + edges (lines)
|
- Node size = total patients (12–50px), edge width = switching flow (0.5–6px), edge opacity scales with strength
|
||||||
- Node size = total patients, edge width = switching flow
|
- `DRUG_PALETTE` for node colors, NHS Blue (`rgba(0,94,184,...)`) for edges
|
||||||
- `DRUG_PALETTE` for node colors
|
- [x] Added as separate "Network" tab (7th tab: Icicle, Sankey, Heatmap, Funnel, Depth, Scatter, Network)
|
||||||
- [ ] Add as sub-toggle within Sankey tab (e.g., "Flow" vs "Network" toggle) or as separate "Network" tab
|
- [x] Added `_render_network()` helper and dispatch case in `chart.py`
|
||||||
- **Checkpoint**: Network view shows drug switching as a graph alternative to Sankey
|
- **Checkpoint**: Network view shows drug switching as a graph alternative to Sankey
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -342,6 +342,31 @@ def _render_scatter(app_state, title):
|
|||||||
return create_duration_cost_scatter_figure(data, title)
|
return create_duration_cost_scatter_figure(data, title)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_network(app_state, title):
|
||||||
|
"""Build the drug co-occurrence network graph from current filter state."""
|
||||||
|
from dash_app.data.queries import get_drug_network
|
||||||
|
from visualization.plotly_generator import create_drug_network_figure
|
||||||
|
|
||||||
|
filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
|
||||||
|
chart_type = (app_state or {}).get("chart_type", "directory")
|
||||||
|
|
||||||
|
selected_dirs = (app_state or {}).get("selected_directorates") or []
|
||||||
|
selected_trusts = (app_state or {}).get("selected_trusts") or []
|
||||||
|
directory = selected_dirs[0] if len(selected_dirs) == 1 else None
|
||||||
|
trust = selected_trusts[0] if len(selected_trusts) == 1 else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = get_drug_network(filter_id, chart_type, directory, trust)
|
||||||
|
except Exception:
|
||||||
|
log.exception("Failed to load drug network data")
|
||||||
|
return _empty_figure("Failed to load drug network data.")
|
||||||
|
|
||||||
|
if not data.get("nodes"):
|
||||||
|
return _empty_figure("No drug network data available.\nTry adjusting your filters.")
|
||||||
|
|
||||||
|
return create_drug_network_figure(data, title)
|
||||||
|
|
||||||
|
|
||||||
def register_chart_callbacks(app):
|
def register_chart_callbacks(app):
|
||||||
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
|
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
|
||||||
|
|
||||||
@@ -491,6 +516,9 @@ def register_chart_callbacks(app):
|
|||||||
elif active_tab == "scatter":
|
elif active_tab == "scatter":
|
||||||
fig = _render_scatter(app_state, title)
|
fig = _render_scatter(app_state, title)
|
||||||
|
|
||||||
|
elif active_tab == "network":
|
||||||
|
fig = _render_network(app_state, title)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Placeholder for charts not yet implemented
|
# Placeholder for charts not yet implemented
|
||||||
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
|
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ TAB_DEFINITIONS = [
|
|||||||
("funnel", "Funnel"),
|
("funnel", "Funnel"),
|
||||||
("depth", "Depth"),
|
("depth", "Depth"),
|
||||||
("scatter", "Scatter"),
|
("scatter", "Scatter"),
|
||||||
|
("network", "Network"),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Full set retained for Trust Comparison dashboard (Phase 10.8)
|
# Full set retained for Trust Comparison dashboard (Phase 10.8)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from data_processing.pathway_queries import (
|
|||||||
get_retention_funnel as _get_retention_funnel,
|
get_retention_funnel as _get_retention_funnel,
|
||||||
get_pathway_depth_distribution as _get_pathway_depth_distribution,
|
get_pathway_depth_distribution as _get_pathway_depth_distribution,
|
||||||
get_duration_cost_scatter as _get_duration_cost_scatter,
|
get_duration_cost_scatter as _get_duration_cost_scatter,
|
||||||
|
get_drug_network as _get_drug_network,
|
||||||
)
|
)
|
||||||
|
|
||||||
DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db"
|
DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db"
|
||||||
@@ -216,3 +217,13 @@ def get_duration_cost_scatter(
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Drug-level avg_days and cost_pp_pa for scatter plot."""
|
"""Drug-level avg_days and cost_pp_pa for scatter plot."""
|
||||||
return _get_duration_cost_scatter(DB_PATH, date_filter_id, chart_type, directory, trust)
|
return _get_duration_cost_scatter(DB_PATH, date_filter_id, chart_type, directory, trust)
|
||||||
|
|
||||||
|
|
||||||
|
def get_drug_network(
|
||||||
|
date_filter_id: str = "all_6mo",
|
||||||
|
chart_type: str = "directory",
|
||||||
|
directory: Optional[str] = None,
|
||||||
|
trust: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Undirected drug co-occurrence network for network graph."""
|
||||||
|
return _get_drug_network(DB_PATH, date_filter_id, chart_type, directory, trust)
|
||||||
|
|||||||
@@ -1285,6 +1285,94 @@ def get_duration_cost_scatter(
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_drug_network(
|
||||||
|
db_path: Path,
|
||||||
|
date_filter_id: str,
|
||||||
|
chart_type: str,
|
||||||
|
directory: Optional[str] = None,
|
||||||
|
trust: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Build undirected drug co-occurrence network from pathway data.
|
||||||
|
|
||||||
|
Unlike get_drug_transitions() (directed, with ordinal suffixes for Sankey),
|
||||||
|
this returns plain drug names with undirected edges representing co-occurrence
|
||||||
|
in patient pathways.
|
||||||
|
|
||||||
|
Returns dict with:
|
||||||
|
nodes: [{name, total_patients}] — unique drug names sorted by patient count
|
||||||
|
edges: [{source, target, patients}] — undirected co-occurrence links
|
||||||
|
"""
|
||||||
|
conn = sqlite3.connect(str(db_path))
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
try:
|
||||||
|
where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
|
||||||
|
params: list = [date_filter_id, chart_type]
|
||||||
|
|
||||||
|
if directory:
|
||||||
|
where.append("directory = ?")
|
||||||
|
params.append(directory)
|
||||||
|
if trust:
|
||||||
|
where.append("trust_name = ?")
|
||||||
|
params.append(trust)
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT value AS patients, drug_sequence
|
||||||
|
FROM pathway_nodes
|
||||||
|
WHERE {' AND '.join(where)}
|
||||||
|
"""
|
||||||
|
rows = conn.execute(query, params).fetchall()
|
||||||
|
|
||||||
|
# Also get level 3 nodes for per-drug patient totals
|
||||||
|
where_l3 = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
|
||||||
|
params_l3: list = [date_filter_id, chart_type]
|
||||||
|
if directory:
|
||||||
|
where_l3.append("directory = ?")
|
||||||
|
params_l3.append(directory)
|
||||||
|
if trust:
|
||||||
|
where_l3.append("trust_name = ?")
|
||||||
|
params_l3.append(trust)
|
||||||
|
|
||||||
|
query_l3 = f"""
|
||||||
|
SELECT labels AS drug, SUM(value) AS total_patients
|
||||||
|
FROM pathway_nodes
|
||||||
|
WHERE {' AND '.join(where_l3)}
|
||||||
|
GROUP BY labels
|
||||||
|
"""
|
||||||
|
l3_rows = conn.execute(query_l3, params_l3).fetchall()
|
||||||
|
node_patients = {r["drug"]: r["total_patients"] or 0 for r in l3_rows}
|
||||||
|
|
||||||
|
# Build undirected edges from pathway sequences
|
||||||
|
edge_agg = {}
|
||||||
|
for r in rows:
|
||||||
|
drugs = [d for d in (r["drug_sequence"] or "").split("|") if d]
|
||||||
|
patients = r["patients"] or 0
|
||||||
|
if len(drugs) < 2 or patients == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Adjacent drug pairs (undirected: sort to avoid A→B and B→A duplicates)
|
||||||
|
for i in range(len(drugs) - 1):
|
||||||
|
pair = tuple(sorted([drugs[i], drugs[i + 1]]))
|
||||||
|
edge_agg[pair] = edge_agg.get(pair, 0) + patients
|
||||||
|
|
||||||
|
# Build result
|
||||||
|
nodes = [
|
||||||
|
{"name": name, "total_patients": pts}
|
||||||
|
for name, pts in sorted(node_patients.items(), key=lambda x: -x[1])
|
||||||
|
if pts > 0
|
||||||
|
]
|
||||||
|
|
||||||
|
edges = [
|
||||||
|
{"source": src, "target": tgt, "patients": pts}
|
||||||
|
for (src, tgt), pts in sorted(edge_agg.items(), key=lambda x: -x[1])
|
||||||
|
]
|
||||||
|
|
||||||
|
return {"nodes": nodes, "edges": edges}
|
||||||
|
except sqlite3.Error:
|
||||||
|
return {"nodes": [], "edges": []}
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def get_directorate_summary(
|
def get_directorate_summary(
|
||||||
db_path: Path,
|
db_path: Path,
|
||||||
date_filter_id: str,
|
date_filter_id: str,
|
||||||
|
|||||||
@@ -1997,3 +1997,91 @@ def create_duration_cost_scatter_figure(
|
|||||||
fig.update_layout(**layout)
|
fig.update_layout(**layout)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def create_drug_network_figure(data: dict, title: str = "") -> go.Figure:
|
||||||
|
"""Create a drug co-occurrence network graph.
|
||||||
|
|
||||||
|
Nodes are drugs arranged in a circle, edges show co-occurrence in pathways.
|
||||||
|
Node size = total patients, edge width = switching flow between drugs.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
nodes = data.get("nodes", [])
|
||||||
|
edges = data.get("edges", [])
|
||||||
|
|
||||||
|
if not nodes:
|
||||||
|
return go.Figure()
|
||||||
|
|
||||||
|
display_title = f"Drug Network — {title}" if title else "Drug Network"
|
||||||
|
|
||||||
|
# Circular layout
|
||||||
|
n = len(nodes)
|
||||||
|
node_names = [nd["name"] for nd in nodes]
|
||||||
|
node_patients = [nd["total_patients"] for nd in nodes]
|
||||||
|
name_to_idx = {nd["name"]: i for i, nd in enumerate(nodes)}
|
||||||
|
|
||||||
|
angles = [2 * math.pi * i / n for i in range(n)]
|
||||||
|
x_pos = [math.cos(a) for a in angles]
|
||||||
|
y_pos = [math.sin(a) for a in angles]
|
||||||
|
|
||||||
|
fig = go.Figure()
|
||||||
|
|
||||||
|
# Draw edges as individual traces (each gets its own width)
|
||||||
|
max_edge_patients = max((e["patients"] for e in edges), default=1) or 1
|
||||||
|
for edge in edges:
|
||||||
|
src_idx = name_to_idx.get(edge["source"])
|
||||||
|
tgt_idx = name_to_idx.get(edge["target"])
|
||||||
|
if src_idx is None or tgt_idx is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Scale width: min 0.5, max 6
|
||||||
|
width = max(0.5, min(6, 0.5 + 5.5 * (edge["patients"] / max_edge_patients)))
|
||||||
|
# Opacity scales with relative strength
|
||||||
|
opacity = max(0.15, min(0.7, 0.15 + 0.55 * (edge["patients"] / max_edge_patients)))
|
||||||
|
|
||||||
|
fig.add_trace(go.Scatter(
|
||||||
|
x=[x_pos[src_idx], x_pos[tgt_idx]],
|
||||||
|
y=[y_pos[src_idx], y_pos[tgt_idx]],
|
||||||
|
mode="lines",
|
||||||
|
line=dict(width=width, color=f"rgba(0,94,184,{opacity})"),
|
||||||
|
hoverinfo="skip",
|
||||||
|
showlegend=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Draw nodes
|
||||||
|
max_patients = max(node_patients, default=1) or 1
|
||||||
|
sizes = [max(12, min(50, 12 + 38 * (p / max_patients))) for p in node_patients]
|
||||||
|
colors = [DRUG_PALETTE[i % len(DRUG_PALETTE)] for i in range(n)]
|
||||||
|
|
||||||
|
fig.add_trace(go.Scatter(
|
||||||
|
x=x_pos,
|
||||||
|
y=y_pos,
|
||||||
|
mode="markers+text",
|
||||||
|
marker=dict(
|
||||||
|
size=sizes,
|
||||||
|
color=colors,
|
||||||
|
line=dict(width=1.5, color="white"),
|
||||||
|
),
|
||||||
|
text=node_names,
|
||||||
|
textposition="top center",
|
||||||
|
textfont=dict(size=9, family=CHART_FONT_FAMILY),
|
||||||
|
customdata=[[p] for p in node_patients],
|
||||||
|
hovertemplate=(
|
||||||
|
"<b>%{text}</b><br>"
|
||||||
|
"Patients: %{customdata[0]:,}<br>"
|
||||||
|
"<extra></extra>"
|
||||||
|
),
|
||||||
|
showlegend=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
layout = _base_layout(display_title)
|
||||||
|
layout.update(
|
||||||
|
margin=dict(t=60, l=24, r=24, b=24),
|
||||||
|
xaxis=dict(visible=False, scaleanchor="y", scaleratio=1),
|
||||||
|
yaxis=dict(visible=False),
|
||||||
|
height=600,
|
||||||
|
)
|
||||||
|
fig.update_layout(**layout)
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|||||||
Reference in New Issue
Block a user