diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index b63a2e3..fcb3c64 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -175,14 +175,14 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard. - **Checkpoint**: Scatter tab shows drugs by duration vs cost with directorate coloring ### C.4 Drug switching network graph -- [ ] Create modified variant of `get_drug_transitions()` in pathway_queries.py that returns undirected edges without ordinal suffixes: - - `get_drug_network(db_path, filter_id, chart_type, directory, trust)` → `{nodes: [{name, total_patients}], edges: [{source, target, patients}]}` -- [ ] Add thin wrapper in `dash_app/data/queries.py` -- [ ] Create `create_drug_network_figure(data, title)` in `src/visualization/plotly_generator.py`: - - Use `go.Scatter` for nodes (circular layout) + edges (lines) - - Node size = total patients, edge width = switching flow - - `DRUG_PALETTE` for node colors -- [ ] Add as sub-toggle within Sankey tab (e.g., "Flow" vs "Network" toggle) or as separate "Network" tab +- [x] Create `get_drug_network()` in pathway_queries.py — undirected edges without ordinal suffixes, node patients from level 3, edge co-occurrence from level 4+ +- [x] Add thin wrapper in `dash_app/data/queries.py` +- [x] Create `create_drug_network_figure(data, title)` in `src/visualization/plotly_generator.py`: + - Circular layout using `go.Scatter` for nodes + individual edge traces as lines + - Node size = total patients (12–50px), edge width = switching flow (0.5–6px), edge opacity scales with strength + - `DRUG_PALETTE` for node colors, NHS Blue (`rgba(0,94,184,...)`) for edges +- [x] Added as separate "Network" tab (7th tab: Icicle, Sankey, Heatmap, Funnel, Depth, Scatter, Network) +- [x] Added `_render_network()` helper and dispatch case in `chart.py` - **Checkpoint**: Network view shows drug switching as a graph alternative to Sankey --- diff --git a/dash_app/callbacks/chart.py 
def _render_network(app_state, title):
    """Render the "Network" tab: a drug co-occurrence graph for the active filters.

    Reads the date filter, chart type and directorate/trust selections from
    ``app_state`` (a missing/empty state falls back to defaults). A directorate
    or trust is forwarded to the query only when exactly one is selected;
    otherwise that dimension is left unfiltered.

    Returns the network figure, or an empty placeholder figure when loading
    raises or the query yields no nodes.
    """
    from dash_app.data.queries import get_drug_network
    from visualization.plotly_generator import create_drug_network_figure

    state = app_state or {}

    def _sole_selection(key):
        # Only an unambiguous single selection is applied as a filter.
        chosen = state.get(key) or []
        return chosen[0] if len(chosen) == 1 else None

    filter_id = state.get("date_filter_id", "all_6mo")
    chart_type = state.get("chart_type", "directory")
    directory = _sole_selection("selected_directorates")
    trust = _sole_selection("selected_trusts")

    try:
        data = get_drug_network(filter_id, chart_type, directory, trust)
    except Exception:
        # Top-level UI boundary: log the traceback and show a placeholder.
        log.exception("Failed to load drug network data")
        return _empty_figure("Failed to load drug network data.")

    if data.get("nodes"):
        return create_drug_network_figure(data, title)
    return _empty_figure("No drug network data available.\nTry adjusting your filters.")
"Scatter"), + ("network", "Network"), ] # Full set retained for Trust Comparison dashboard (Phase 10.8) diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py index d3e9758..fc968d7 100644 --- a/dash_app/data/queries.py +++ b/dash_app/data/queries.py @@ -27,6 +27,7 @@ from data_processing.pathway_queries import ( get_retention_funnel as _get_retention_funnel, get_pathway_depth_distribution as _get_pathway_depth_distribution, get_duration_cost_scatter as _get_duration_cost_scatter, + get_drug_network as _get_drug_network, ) DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db" @@ -216,3 +217,13 @@ def get_duration_cost_scatter( ) -> list[dict]: """Drug-level avg_days and cost_pp_pa for scatter plot.""" return _get_duration_cost_scatter(DB_PATH, date_filter_id, chart_type, directory, trust) + + +def get_drug_network( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> dict: + """Undirected drug co-occurrence network for network graph.""" + return _get_drug_network(DB_PATH, date_filter_id, chart_type, directory, trust) diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py index 40bec58..6b16ae4 100644 --- a/src/data_processing/pathway_queries.py +++ b/src/data_processing/pathway_queries.py @@ -1285,6 +1285,94 @@ def get_duration_cost_scatter( conn.close() +def get_drug_network( + db_path: Path, + date_filter_id: str, + chart_type: str, + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> dict: + """Build undirected drug co-occurrence network from pathway data. + + Unlike get_drug_transitions() (directed, with ordinal suffixes for Sankey), + this returns plain drug names with undirected edges representing co-occurrence + in patient pathways. 
+ + Returns dict with: + nodes: [{name, total_patients}] — unique drug names sorted by patient count + edges: [{source, target, patients}] — undirected co-occurrence links + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"] + params: list = [date_filter_id, chart_type] + + if directory: + where.append("directory = ?") + params.append(directory) + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT value AS patients, drug_sequence + FROM pathway_nodes + WHERE {' AND '.join(where)} + """ + rows = conn.execute(query, params).fetchall() + + # Also get level 3 nodes for per-drug patient totals + where_l3 = ["date_filter_id = ?", "chart_type = ?", "level = 3"] + params_l3: list = [date_filter_id, chart_type] + if directory: + where_l3.append("directory = ?") + params_l3.append(directory) + if trust: + where_l3.append("trust_name = ?") + params_l3.append(trust) + + query_l3 = f""" + SELECT labels AS drug, SUM(value) AS total_patients + FROM pathway_nodes + WHERE {' AND '.join(where_l3)} + GROUP BY labels + """ + l3_rows = conn.execute(query_l3, params_l3).fetchall() + node_patients = {r["drug"]: r["total_patients"] or 0 for r in l3_rows} + + # Build undirected edges from pathway sequences + edge_agg = {} + for r in rows: + drugs = [d for d in (r["drug_sequence"] or "").split("|") if d] + patients = r["patients"] or 0 + if len(drugs) < 2 or patients == 0: + continue + + # Adjacent drug pairs (undirected: sort to avoid A→B and B→A duplicates) + for i in range(len(drugs) - 1): + pair = tuple(sorted([drugs[i], drugs[i + 1]])) + edge_agg[pair] = edge_agg.get(pair, 0) + patients + + # Build result + nodes = [ + {"name": name, "total_patients": pts} + for name, pts in sorted(node_patients.items(), key=lambda x: -x[1]) + if pts > 0 + ] + + edges = [ + {"source": src, "target": tgt, "patients": pts} + for (src, tgt), pts in 
def create_drug_network_figure(data: dict, title: str = "") -> go.Figure:
    """Create a drug co-occurrence network graph.

    Nodes are drugs placed evenly on a unit circle; edges are straight lines
    between them. Node size scales with total patients (12–50 px); edge width
    (0.5–6 px) and opacity (0.15–0.7) scale with the edge's patient count
    relative to the strongest drawn edge.

    Args:
        data: ``{"nodes": [{name, total_patients}], "edges": [{source, target,
            patients}]}``. Edges whose endpoints are not in ``nodes``, or that
            join a node to itself, are not drawn.
        title: Optional suffix appended to the "Drug Network" chart title.

    Returns:
        The assembled figure, or an empty ``go.Figure()`` when there are no nodes.
    """
    import math

    nodes = data.get("nodes", [])
    edges = data.get("edges", [])

    if not nodes:
        return go.Figure()

    display_title = f"Drug Network — {title}" if title else "Drug Network"

    # Circular layout: node i sits at angle 2*pi*i/n on the unit circle.
    n = len(nodes)
    node_names = [nd["name"] for nd in nodes]
    node_patients = [nd["total_patients"] for nd in nodes]
    name_to_idx = {name: i for i, name in enumerate(node_names)}

    angles = [2 * math.pi * i / n for i in range(n)]
    x_pos = [math.cos(a) for a in angles]
    y_pos = [math.sin(a) for a in angles]

    fig = go.Figure()

    # Resolve edges to node indices first so that width/opacity are scaled
    # against edges that are actually visible, not ones that get skipped.
    drawable = []
    for edge in edges:
        src_idx = name_to_idx.get(edge["source"])
        tgt_idx = name_to_idx.get(edge["target"])
        if src_idx is None or tgt_idx is None or src_idx == tgt_idx:
            continue
        drawable.append((src_idx, tgt_idx, edge["patients"]))

    max_edge_patients = max((pts for _, _, pts in drawable), default=1) or 1

    # One trace per edge: a single Scatter trace cannot vary line width.
    for src_idx, tgt_idx, pts in drawable:
        strength = pts / max_edge_patients
        width = max(0.5, min(6, 0.5 + 5.5 * strength))
        opacity = max(0.15, min(0.7, 0.15 + 0.55 * strength))

        fig.add_trace(go.Scatter(
            x=[x_pos[src_idx], x_pos[tgt_idx]],
            y=[y_pos[src_idx], y_pos[tgt_idx]],
            mode="lines",
            line=dict(width=width, color=f"rgba(0,94,184,{opacity:.3f})"),
            hoverinfo="skip",
            showlegend=False,
        ))

    # Nodes are added last so markers and labels render above the edges.
    max_patients = max(node_patients, default=1) or 1
    sizes = [max(12, min(50, 12 + 38 * (p / max_patients))) for p in node_patients]
    colors = [DRUG_PALETTE[i % len(DRUG_PALETTE)] for i in range(n)]

    fig.add_trace(go.Scatter(
        x=x_pos,
        y=y_pos,
        mode="markers+text",
        marker=dict(
            size=sizes,
            color=colors,
            line=dict(width=1.5, color="white"),
        ),
        text=node_names,
        textposition="top center",
        textfont=dict(size=9, family=CHART_FONT_FAMILY),
        customdata=[[p] for p in node_patients],
        # The empty <extra> section hides Plotly's secondary trace-name box.
        hovertemplate=(
            "<b>%{text}</b><br>"
            "Patients: %{customdata[0]:,}"
            "<extra></extra>"
        ),
        showlegend=False,
    ))

    layout = _base_layout(display_title)
    layout.update(
        margin=dict(t=60, l=24, r=24, b=24),
        # Lock the aspect ratio so the circular layout is not squashed.
        xaxis=dict(visible=False, scaleanchor="y", scaleratio=1),
        yaxis=dict(visible=False),
        height=600,
    )
    fig.update_layout(**layout)

    return fig