diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index b63a2e3..fcb3c64 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -175,14 +175,14 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard.
- **Checkpoint**: Scatter tab shows drugs by duration vs cost with directorate coloring
### C.4 Drug switching network graph
-- [ ] Create modified variant of `get_drug_transitions()` in pathway_queries.py that returns undirected edges without ordinal suffixes:
- - `get_drug_network(db_path, filter_id, chart_type, directory, trust)` → `{nodes: [{name, total_patients}], edges: [{source, target, patients}]}`
-- [ ] Add thin wrapper in `dash_app/data/queries.py`
-- [ ] Create `create_drug_network_figure(data, title)` in `src/visualization/plotly_generator.py`:
- - Use `go.Scatter` for nodes (circular layout) + edges (lines)
- - Node size = total patients, edge width = switching flow
- - `DRUG_PALETTE` for node colors
-- [ ] Add as sub-toggle within Sankey tab (e.g., "Flow" vs "Network" toggle) or as separate "Network" tab
+- [x] Create `get_drug_network()` in pathway_queries.py — undirected edges without ordinal suffixes, node patients from level 3, edge co-occurrence from level 4+
+- [x] Add thin wrapper in `dash_app/data/queries.py`
+- [x] Create `create_drug_network_figure(data, title)` in `src/visualization/plotly_generator.py`:
+ - Circular layout using `go.Scatter` for nodes + individual edge traces as lines
+ - Node size = total patients (12–50px), edge width = switching flow (0.5–6px), edge opacity scales with strength
+ - `DRUG_PALETTE` for node colors, NHS Blue (`rgba(0,94,184,...)`) for edges
+- [x] Added as separate "Network" tab (7th tab: Icicle, Sankey, Heatmap, Funnel, Depth, Scatter, Network)
+- [x] Added `_render_network()` helper and dispatch case in `chart.py`
- **Checkpoint**: Network view shows drug switching as a graph alternative to Sankey
---
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index 3644494..eb6b2bf 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -342,6 +342,31 @@ def _render_scatter(app_state, title):
return create_duration_cost_scatter_figure(data, title)
+def _render_network(app_state, title):
+ """Build the drug co-occurrence network graph from current filter state."""
+ from dash_app.data.queries import get_drug_network
+ from visualization.plotly_generator import create_drug_network_figure
+
+ filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
+ chart_type = (app_state or {}).get("chart_type", "directory")
+
+ selected_dirs = (app_state or {}).get("selected_directorates") or []
+ selected_trusts = (app_state or {}).get("selected_trusts") or []
+ directory = selected_dirs[0] if len(selected_dirs) == 1 else None
+ trust = selected_trusts[0] if len(selected_trusts) == 1 else None
+
+ try:
+ data = get_drug_network(filter_id, chart_type, directory, trust)
+ except Exception:
+ log.exception("Failed to load drug network data")
+ return _empty_figure("Failed to load drug network data.")
+
+ if not data.get("nodes"):
+ return _empty_figure("No drug network data available.\nTry adjusting your filters.")
+
+ return create_drug_network_figure(data, title)
+
+
def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -491,6 +516,9 @@ def register_chart_callbacks(app):
elif active_tab == "scatter":
fig = _render_scatter(app_state, title)
+ elif active_tab == "network":
+ fig = _render_network(app_state, title)
+
else:
# Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
diff --git a/dash_app/components/chart_card.py b/dash_app/components/chart_card.py
index d348cc7..8e8c28a 100644
--- a/dash_app/components/chart_card.py
+++ b/dash_app/components/chart_card.py
@@ -11,6 +11,7 @@ TAB_DEFINITIONS = [
("funnel", "Funnel"),
("depth", "Depth"),
("scatter", "Scatter"),
+ ("network", "Network"),
]
# Full set retained for Trust Comparison dashboard (Phase 10.8)
diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py
index d3e9758..fc968d7 100644
--- a/dash_app/data/queries.py
+++ b/dash_app/data/queries.py
@@ -27,6 +27,7 @@ from data_processing.pathway_queries import (
get_retention_funnel as _get_retention_funnel,
get_pathway_depth_distribution as _get_pathway_depth_distribution,
get_duration_cost_scatter as _get_duration_cost_scatter,
+ get_drug_network as _get_drug_network,
)
DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db"
@@ -216,3 +217,13 @@ def get_duration_cost_scatter(
) -> list[dict]:
"""Drug-level avg_days and cost_pp_pa for scatter plot."""
return _get_duration_cost_scatter(DB_PATH, date_filter_id, chart_type, directory, trust)
+
+
+def get_drug_network(
+ date_filter_id: str = "all_6mo",
+ chart_type: str = "directory",
+ directory: Optional[str] = None,
+ trust: Optional[str] = None,
+) -> dict:
+ """Undirected drug co-occurrence network for network graph."""
+ return _get_drug_network(DB_PATH, date_filter_id, chart_type, directory, trust)
diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py
index 40bec58..6b16ae4 100644
--- a/src/data_processing/pathway_queries.py
+++ b/src/data_processing/pathway_queries.py
@@ -1285,6 +1285,94 @@ def get_duration_cost_scatter(
conn.close()
+def get_drug_network(
+ db_path: Path,
+ date_filter_id: str,
+ chart_type: str,
+ directory: Optional[str] = None,
+ trust: Optional[str] = None,
+) -> dict:
+ """Build undirected drug co-occurrence network from pathway data.
+
+ Unlike get_drug_transitions() (directed, with ordinal suffixes for Sankey),
+ this returns plain drug names with undirected edges representing co-occurrence
+ in patient pathways.
+
+ Returns dict with:
+ nodes: [{name, total_patients}] — unique drug names sorted by patient count
+ edges: [{source, target, patients}] — undirected co-occurrence links
+ """
+ conn = sqlite3.connect(str(db_path))
+ conn.row_factory = sqlite3.Row
+ try:
+ where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
+ params: list = [date_filter_id, chart_type]
+
+ if directory:
+ where.append("directory = ?")
+ params.append(directory)
+ if trust:
+ where.append("trust_name = ?")
+ params.append(trust)
+
+ query = f"""
+ SELECT value AS patients, drug_sequence
+ FROM pathway_nodes
+ WHERE {' AND '.join(where)}
+ """
+ rows = conn.execute(query, params).fetchall()
+
+ # Also get level 3 nodes for per-drug patient totals
+ where_l3 = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
+ params_l3: list = [date_filter_id, chart_type]
+ if directory:
+ where_l3.append("directory = ?")
+ params_l3.append(directory)
+ if trust:
+ where_l3.append("trust_name = ?")
+ params_l3.append(trust)
+
+ query_l3 = f"""
+ SELECT labels AS drug, SUM(value) AS total_patients
+ FROM pathway_nodes
+ WHERE {' AND '.join(where_l3)}
+ GROUP BY labels
+ """
+ l3_rows = conn.execute(query_l3, params_l3).fetchall()
+ node_patients = {r["drug"]: r["total_patients"] or 0 for r in l3_rows}
+
+ # Build undirected edges from pathway sequences
+ edge_agg = {}
+ for r in rows:
+ drugs = [d for d in (r["drug_sequence"] or "").split("|") if d]
+ patients = r["patients"] or 0
+ if len(drugs) < 2 or patients == 0:
+ continue
+
+ # Adjacent drug pairs (undirected: sort to avoid A→B and B→A duplicates)
+ for i in range(len(drugs) - 1):
+ pair = tuple(sorted([drugs[i], drugs[i + 1]]))
+ edge_agg[pair] = edge_agg.get(pair, 0) + patients
+
+ # Build result
+ nodes = [
+ {"name": name, "total_patients": pts}
+ for name, pts in sorted(node_patients.items(), key=lambda x: -x[1])
+ if pts > 0
+ ]
+
+ edges = [
+ {"source": src, "target": tgt, "patients": pts}
+ for (src, tgt), pts in sorted(edge_agg.items(), key=lambda x: -x[1])
+ ]
+
+ return {"nodes": nodes, "edges": edges}
+ except sqlite3.Error:
+ return {"nodes": [], "edges": []}
+ finally:
+ conn.close()
+
+
def get_directorate_summary(
db_path: Path,
date_filter_id: str,
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index 0cb88a4..96d58ad 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -1997,3 +1997,91 @@ def create_duration_cost_scatter_figure(
fig.update_layout(**layout)
return fig
+
+
+def create_drug_network_figure(data: dict, title: str = "") -> go.Figure:
+ """Create a drug co-occurrence network graph.
+
+ Nodes are drugs arranged in a circle, edges show co-occurrence in pathways.
+ Node size = total patients, edge width = switching flow between drugs.
+ """
+ import math
+
+ nodes = data.get("nodes", [])
+ edges = data.get("edges", [])
+
+ if not nodes:
+ return go.Figure()
+
+ display_title = f"Drug Network — {title}" if title else "Drug Network"
+
+ # Circular layout
+ n = len(nodes)
+ node_names = [nd["name"] for nd in nodes]
+ node_patients = [nd["total_patients"] for nd in nodes]
+ name_to_idx = {nd["name"]: i for i, nd in enumerate(nodes)}
+
+ angles = [2 * math.pi * i / n for i in range(n)]
+ x_pos = [math.cos(a) for a in angles]
+ y_pos = [math.sin(a) for a in angles]
+
+ fig = go.Figure()
+
+ # Draw edges as individual traces (each gets its own width)
+ max_edge_patients = max((e["patients"] for e in edges), default=1) or 1
+ for edge in edges:
+ src_idx = name_to_idx.get(edge["source"])
+ tgt_idx = name_to_idx.get(edge["target"])
+ if src_idx is None or tgt_idx is None:
+ continue
+
+ # Scale width: min 0.5, max 6
+ width = max(0.5, min(6, 0.5 + 5.5 * (edge["patients"] / max_edge_patients)))
+ # Opacity scales with relative strength
+ opacity = max(0.15, min(0.7, 0.15 + 0.55 * (edge["patients"] / max_edge_patients)))
+
+ fig.add_trace(go.Scatter(
+ x=[x_pos[src_idx], x_pos[tgt_idx]],
+ y=[y_pos[src_idx], y_pos[tgt_idx]],
+ mode="lines",
+ line=dict(width=width, color=f"rgba(0,94,184,{opacity})"),
+ hoverinfo="skip",
+ showlegend=False,
+ ))
+
+ # Draw nodes
+ max_patients = max(node_patients, default=1) or 1
+ sizes = [max(12, min(50, 12 + 38 * (p / max_patients))) for p in node_patients]
+ colors = [DRUG_PALETTE[i % len(DRUG_PALETTE)] for i in range(n)]
+
+ fig.add_trace(go.Scatter(
+ x=x_pos,
+ y=y_pos,
+ mode="markers+text",
+ marker=dict(
+ size=sizes,
+ color=colors,
+ line=dict(width=1.5, color="white"),
+ ),
+ text=node_names,
+ textposition="top center",
+ textfont=dict(size=9, family=CHART_FONT_FAMILY),
+ customdata=[[p] for p in node_patients],
+ hovertemplate=(
+ "%{text}
"
+ "Patients: %{customdata[0]:,}
"
+ ""
+ ),
+ showlegend=False,
+ ))
+
+ layout = _base_layout(display_title)
+ layout.update(
+ margin=dict(t=60, l=24, r=24, b=24),
+ xaxis=dict(visible=False, scaleanchor="y", scaleratio=1),
+ yaxis=dict(visible=False),
+ height=600,
+ )
+ fig.update_layout(**layout)
+
+ return fig