diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index eededa6..039ddae 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -146,15 +146,17 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard. - **Checkpoint**: Funnel tab shows retention by treatment line depth, responds to filters ### C.2 Pathway depth distribution chart -- [ ] Create `get_pathway_depth_distribution()` in `src/data_processing/pathway_queries.py`: +- [x] Create `get_pathway_depth_distribution()` in `src/data_processing/pathway_queries.py`: - Aggregate patient counts at level 3 (1-drug), level 4 (2-drug), etc. - Subtract child counts to get patients who STOPPED at each depth - - Return: `[{depth: 1, label: "1 drug only", patients: N}, ...]` -- [ ] Add thin wrapper in `dash_app/data/queries.py` -- [ ] Create `create_pathway_depth_figure(data, title)` in `src/visualization/plotly_generator.py`: + - Return: `[{depth: 1, label: "1 drug only", patients: N, pct: 80.2}, ...]` +- [x] Add thin wrapper in `dash_app/data/queries.py` +- [x] Create `create_pathway_depth_figure(data, title)` in `src/visualization/plotly_generator.py`: - Horizontal bar chart with NHS blue gradient by depth -- [ ] Add "Depth" tab to `TAB_DEFINITIONS` in `chart_card.py` -- [ ] Add `_render_depth()` helper and tab dispatch in `dash_app/callbacks/chart.py` + - Text shows "N (pct%)" inside bars + - Uses `_base_layout()` for consistent styling +- [x] Add "Depth" tab to `TAB_DEFINITIONS` in `chart_card.py` (5 tabs: Icicle, Sankey, Heatmap, Funnel, Depth) +- [x] Add `_render_depth()` helper and tab dispatch in `dash_app/callbacks/chart.py` - **Checkpoint**: Depth tab shows patient distribution by treatment line count ### C.3 Duration vs Cost scatter plot diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py index f7a087f..4d41c85 100644 --- a/dash_app/callbacks/chart.py +++ b/dash_app/callbacks/chart.py @@ -292,6 +292,31 @@ def _render_funnel(app_state, title): return create_retention_funnel_figure(data, title) +def _render_depth(app_state, title): + """Build the pathway depth distribution figure from current filter state.""" + from dash_app.data.queries import get_pathway_depth_distribution + from visualization.plotly_generator import create_pathway_depth_figure + + filter_id = (app_state or {}).get("date_filter_id", "all_6mo") + chart_type = (app_state or {}).get("chart_type", "directory") + + selected_dirs = (app_state or {}).get("selected_directorates") or [] + selected_trusts = (app_state or {}).get("selected_trusts") or [] + directory = selected_dirs[0] if len(selected_dirs) == 1 else None + trust = selected_trusts[0] if len(selected_trusts) == 1 else None + + try: + data = get_pathway_depth_distribution(filter_id, chart_type, directory, trust) + except Exception: + log.exception("Failed to load pathway depth data") + return _empty_figure("Failed to load pathway depth data.") + + if not data: + return _empty_figure("No pathway depth data available.\nTry adjusting your filters.") + + return create_pathway_depth_figure(data, title) + + def register_chart_callbacks(app): """Register tab switching, pathway data loading, and chart rendering callbacks.""" @@ -435,6 +460,9 @@ def register_chart_callbacks(app): elif active_tab == "funnel": fig = _render_funnel(app_state, title) + elif active_tab == "depth": + fig = _render_depth(app_state, title) + else: # Placeholder for charts not yet implemented tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) diff --git a/dash_app/components/chart_card.py b/dash_app/components/chart_card.py index 835c4c3..0820df7 100644 --- a/dash_app/components/chart_card.py +++ b/dash_app/components/chart_card.py @@ -9,6 +9,7 @@ TAB_DEFINITIONS = [ ("sankey", "Sankey"), ("heatmap", "Heatmap"), ("funnel", "Funnel"), + ("depth", "Depth"), ] # Full set retained for Trust Comparison dashboard (Phase 10.8) diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py index 89cff41..17d1a2e 100644 --- a/dash_app/data/queries.py +++ b/dash_app/data/queries.py @@ -25,6 +25,7 @@ from data_processing.pathway_queries import ( get_trust_durations as _get_trust_durations, get_directorate_summary as _get_directorate_summary, get_retention_funnel as _get_retention_funnel, + get_pathway_depth_distribution as _get_pathway_depth_distribution, ) DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db" @@ -194,3 +195,13 @@ def get_retention_funnel( ) -> list[dict]: """Patient retention by treatment line depth.""" return _get_retention_funnel(DB_PATH, date_filter_id, chart_type, directory, trust) + + +def get_pathway_depth_distribution( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Patients who stopped at each treatment line depth (exclusive counts).""" + return _get_pathway_depth_distribution(DB_PATH, date_filter_id, chart_type, directory, trust) diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py index 1bee9ea..071bfe1 100644 --- a/src/data_processing/pathway_queries.py +++ b/src/data_processing/pathway_queries.py @@ -1139,6 +1139,76 @@ def get_retention_funnel( conn.close() +def get_pathway_depth_distribution( + db_path: Path, + date_filter_id: str, + chart_type: str, + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Count patients who STOPPED at each treatment line depth. + + Unlike the retention funnel (cumulative), this shows exclusive counts: + patients at depth N minus patients at depth N+1 = stopped at depth N. + + Returns list of dicts sorted by depth ascending: + [{depth: 1, label: "1 drug only", patients: N, pct: 45.2}, ...] + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level >= 3"] + params: list = [date_filter_id, chart_type] + + if directory: + where.append("directory = ?") + params.append(directory) + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT level, SUM(value) AS patients + FROM pathway_nodes + WHERE {' AND '.join(where)} + GROUP BY level + ORDER BY level + """ + rows = conn.execute(query, params).fetchall() + + if not rows: + return [] + + # Build list of (depth, cumulative_patients) + levels = [] + for r in rows: + depth = r["level"] - 2 # level 3 → depth 1 + patients = r["patients"] or 0 + levels.append((depth, patients)) + + # Subtract next level to get "stopped at this depth" + total_patients = levels[0][1] if levels else 0 + result = [] + for i, (depth, patients) in enumerate(levels): + next_patients = levels[i + 1][1] if i + 1 < len(levels) else 0 + stopped = patients - next_patients + + label = f"{depth} drug{'s' if depth > 1 else ''} only" + pct = round(stopped / total_patients * 100, 1) if total_patients else 0 + result.append({ + "depth": depth, + "label": label, + "patients": stopped, + "pct": pct, + }) + + return result + except sqlite3.Error: + return [] + finally: + conn.close() + + def get_directorate_summary( db_path: Path, date_filter_id: str, diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py index 9436e74..ff3cc71 100644 --- a/src/visualization/plotly_generator.py +++ b/src/visualization/plotly_generator.py @@ -1820,3 +1820,74 @@ def create_retention_funnel_figure( fig.update_layout(**layout) return fig + + +def create_pathway_depth_figure( + data: list[dict], + title: str = "", +) -> go.Figure: + """Create a horizontal bar chart showing patients who stopped at each treatment depth. + + Args: + data: List of dicts with keys: depth, label, patients, pct + title: Chart title from filter state. + + Returns: + Plotly Figure with horizontal bar trace. + """ + if not data: + return go.Figure() + + display_title = f"Pathway Depth Distribution — {title}" if title else "Pathway Depth Distribution" + + labels = [d["label"] for d in data] + patients = [d["patients"] for d in data] + pcts = [d["pct"] for d in data] + + # NHS blue gradient: darkest for depth 1 (most patients) → lightest + bar_colors = [ + "#003087", + "#005EB8", + "#1E88E5", + "#42A5F5", + "#90CAF9", + ] + colors = bar_colors[: len(data)] + if len(colors) < len(data): + colors.extend(["#E3F2FD"] * (len(data) - len(colors))) + + fig = go.Figure( + go.Bar( + y=labels, + x=patients, + orientation="h", + text=[f"{p:,} ({pct}%)" for p, pct in zip(patients, pcts)], + textposition="auto", + textfont=dict(family=CHART_FONT_FAMILY, size=13), + marker=dict(color=colors), + hovertemplate=( + "%{y}
" + "Patients: %{x:,}
" + "" + ), + ) + ) + + layout = _base_layout(display_title) + layout.update( + margin=dict(t=60, l=8, r=24, b=40), + yaxis=dict( + automargin=True, + autorange="reversed", + title="", + ), + xaxis=dict( + title="Patients", + gridcolor=GRID_COLOR, + ), + height=max(300, len(data) * 70 + 120), + bargap=0.3, + ) + fig.update_layout(**layout) + + return fig