diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index eededa6..039ddae 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -146,15 +146,17 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard.
- **Checkpoint**: Funnel tab shows retention by treatment line depth, responds to filters
### C.2 Pathway depth distribution chart
-- [ ] Create `get_pathway_depth_distribution()` in `src/data_processing/pathway_queries.py`:
+- [x] Create `get_pathway_depth_distribution()` in `src/data_processing/pathway_queries.py`:
- Aggregate patient counts at level 3 (1-drug), level 4 (2-drug), etc.
- Subtract child counts to get patients who STOPPED at each depth
- - Return: `[{depth: 1, label: "1 drug only", patients: N}, ...]`
-- [ ] Add thin wrapper in `dash_app/data/queries.py`
-- [ ] Create `create_pathway_depth_figure(data, title)` in `src/visualization/plotly_generator.py`:
+ - Return: `[{depth: 1, label: "1 drug only", patients: N, pct: 80.2}, ...]`
+- [x] Add thin wrapper in `dash_app/data/queries.py`
+- [x] Create `create_pathway_depth_figure(data, title)` in `src/visualization/plotly_generator.py`:
- Horizontal bar chart with NHS blue gradient by depth
-- [ ] Add "Depth" tab to `TAB_DEFINITIONS` in `chart_card.py`
-- [ ] Add `_render_depth()` helper and tab dispatch in `dash_app/callbacks/chart.py`
+ - Text shows "N (pct%)" inside bars
+ - Uses `_base_layout()` for consistent styling
+- [x] Add "Depth" tab to `TAB_DEFINITIONS` in `chart_card.py` (5 tabs: Icicle, Sankey, Heatmap, Funnel, Depth)
+- [x] Add `_render_depth()` helper and tab dispatch in `dash_app/callbacks/chart.py`
- **Checkpoint**: Depth tab shows patient distribution by treatment line count
### C.3 Duration vs Cost scatter plot
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index f7a087f..4d41c85 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -292,6 +292,31 @@ def _render_funnel(app_state, title):
return create_retention_funnel_figure(data, title)
+def _render_depth(app_state, title):
+ """Build the pathway depth distribution figure from current filter state."""
+ from dash_app.data.queries import get_pathway_depth_distribution
+ from visualization.plotly_generator import create_pathway_depth_figure
+
+ filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
+ chart_type = (app_state or {}).get("chart_type", "directory")
+
+ selected_dirs = (app_state or {}).get("selected_directorates") or []
+ selected_trusts = (app_state or {}).get("selected_trusts") or []
+ directory = selected_dirs[0] if len(selected_dirs) == 1 else None
+ trust = selected_trusts[0] if len(selected_trusts) == 1 else None
+
+ try:
+ data = get_pathway_depth_distribution(filter_id, chart_type, directory, trust)
+ except Exception:
+ log.exception("Failed to load pathway depth data")
+ return _empty_figure("Failed to load pathway depth data.")
+
+ if not data:
+ return _empty_figure("No pathway depth data available.\nTry adjusting your filters.")
+
+ return create_pathway_depth_figure(data, title)
+
+
def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -435,6 +460,9 @@ def register_chart_callbacks(app):
elif active_tab == "funnel":
fig = _render_funnel(app_state, title)
+ elif active_tab == "depth":
+ fig = _render_depth(app_state, title)
+
else:
# Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
diff --git a/dash_app/components/chart_card.py b/dash_app/components/chart_card.py
index 835c4c3..0820df7 100644
--- a/dash_app/components/chart_card.py
+++ b/dash_app/components/chart_card.py
@@ -9,6 +9,7 @@ TAB_DEFINITIONS = [
("sankey", "Sankey"),
("heatmap", "Heatmap"),
("funnel", "Funnel"),
+ ("depth", "Depth"),
]
# Full set retained for Trust Comparison dashboard (Phase 10.8)
diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py
index 89cff41..17d1a2e 100644
--- a/dash_app/data/queries.py
+++ b/dash_app/data/queries.py
@@ -25,6 +25,7 @@ from data_processing.pathway_queries import (
get_trust_durations as _get_trust_durations,
get_directorate_summary as _get_directorate_summary,
get_retention_funnel as _get_retention_funnel,
+ get_pathway_depth_distribution as _get_pathway_depth_distribution,
)
DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db"
@@ -194,3 +195,13 @@ def get_retention_funnel(
) -> list[dict]:
"""Patient retention by treatment line depth."""
return _get_retention_funnel(DB_PATH, date_filter_id, chart_type, directory, trust)
+
+
+def get_pathway_depth_distribution(
+ date_filter_id: str = "all_6mo",
+ chart_type: str = "directory",
+ directory: Optional[str] = None,
+ trust: Optional[str] = None,
+) -> list[dict]:
+ """Patients who stopped at each treatment line depth (exclusive counts)."""
+ return _get_pathway_depth_distribution(DB_PATH, date_filter_id, chart_type, directory, trust)
diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py
index 1bee9ea..071bfe1 100644
--- a/src/data_processing/pathway_queries.py
+++ b/src/data_processing/pathway_queries.py
@@ -1139,6 +1139,76 @@ def get_retention_funnel(
conn.close()
+def get_pathway_depth_distribution(
+ db_path: Path,
+ date_filter_id: str,
+ chart_type: str,
+ directory: Optional[str] = None,
+ trust: Optional[str] = None,
+) -> list[dict]:
+ """Count patients who STOPPED at each treatment line depth.
+
+ Unlike the retention funnel (cumulative), this shows exclusive counts:
+ patients at depth N minus patients at depth N+1 = stopped at depth N.
+
+ Returns list of dicts sorted by depth ascending:
+ [{depth: 1, label: "1 drug only", patients: N, pct: 45.2}, ...]
+ """
+ conn = sqlite3.connect(str(db_path))
+ conn.row_factory = sqlite3.Row
+ try:
+ where = ["date_filter_id = ?", "chart_type = ?", "level >= 3"]
+ params: list = [date_filter_id, chart_type]
+
+ if directory:
+ where.append("directory = ?")
+ params.append(directory)
+ if trust:
+ where.append("trust_name = ?")
+ params.append(trust)
+
+ query = f"""
+ SELECT level, SUM(value) AS patients
+ FROM pathway_nodes
+ WHERE {' AND '.join(where)}
+ GROUP BY level
+ ORDER BY level
+ """
+ rows = conn.execute(query, params).fetchall()
+
+ if not rows:
+ return []
+
+ # Build list of (depth, cumulative_patients)
+ levels = []
+ for r in rows:
+ depth = r["level"] - 2 # level 3 → depth 1
+ patients = r["patients"] or 0
+ levels.append((depth, patients))
+
+ # Subtract next level to get "stopped at this depth"
+ total_patients = levels[0][1] if levels else 0
+ result = []
+ for i, (depth, patients) in enumerate(levels):
+ next_patients = levels[i + 1][1] if i + 1 < len(levels) else 0
+ stopped = patients - next_patients
+
+ label = f"{depth} drug{'s' if depth > 1 else ''} only"
+ pct = round(stopped / total_patients * 100, 1) if total_patients else 0
+ result.append({
+ "depth": depth,
+ "label": label,
+ "patients": stopped,
+ "pct": pct,
+ })
+
+ return result
+ except sqlite3.Error:
+ return []
+ finally:
+ conn.close()
+
+
def get_directorate_summary(
db_path: Path,
date_filter_id: str,
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index 9436e74..ff3cc71 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -1820,3 +1820,74 @@ def create_retention_funnel_figure(
fig.update_layout(**layout)
return fig
+
+
+def create_pathway_depth_figure(
+ data: list[dict],
+ title: str = "",
+) -> go.Figure:
+ """Create a horizontal bar chart showing patients who stopped at each treatment depth.
+
+ Args:
+ data: List of dicts with keys: depth, label, patients, pct
+ title: Chart title from filter state.
+
+ Returns:
+ Plotly Figure with horizontal bar trace.
+ """
+ if not data:
+ return go.Figure()
+
+ display_title = f"Pathway Depth Distribution — {title}" if title else "Pathway Depth Distribution"
+
+ labels = [d["label"] for d in data]
+ patients = [d["patients"] for d in data]
+ pcts = [d["pct"] for d in data]
+
+ # NHS blue gradient: darkest for depth 1 (most patients) → lightest
+ bar_colors = [
+ "#003087",
+ "#005EB8",
+ "#1E88E5",
+ "#42A5F5",
+ "#90CAF9",
+ ]
+ colors = bar_colors[: len(data)]
+ if len(colors) < len(data):
+ colors.extend(["#E3F2FD"] * (len(data) - len(colors)))
+
+ fig = go.Figure(
+ go.Bar(
+ y=labels,
+ x=patients,
+ orientation="h",
+ text=[f"{p:,} ({pct}%)" for p, pct in zip(patients, pcts)],
+ textposition="auto",
+ textfont=dict(family=CHART_FONT_FAMILY, size=13),
+ marker=dict(color=colors),
+ hovertemplate=(
+ "%{y}
"
+ "Patients: %{x:,}
"
+ ""
+ ),
+ )
+ )
+
+ layout = _base_layout(display_title)
+ layout.update(
+ margin=dict(t=60, l=8, r=24, b=40),
+ yaxis=dict(
+ automargin=True,
+ autorange="reversed",
+ title="",
+ ),
+ xaxis=dict(
+ title="Patients",
+ gridcolor=GRID_COLOR,
+ ),
+ height=max(300, len(data) * 70 + 120),
+ bargap=0.3,
+ )
+ fig.update_layout(**layout)
+
+ return fig