diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index ef8839a..eededa6 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -132,15 +132,17 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard.
## Phase C: New Analytics (Existing Data)
### C.1 Retention funnel chart
-- [ ] Create `get_retention_funnel()` in `src/data_processing/pathway_queries.py`:
- - Query level 4+ nodes, aggregate patient counts by treatment line depth
- - Return: `[{depth: 1, label: "1 drug", patients: N, pct: 100}, {depth: 2, ...}, ...]`
-- [ ] Add thin wrapper in `dash_app/data/queries.py`
-- [ ] Create `create_retention_funnel_figure(data, title)` in `src/visualization/plotly_generator.py`:
- - Use `go.Funnel` with NHS blue gradient
- - Show absolute patient count + percentage retained
-- [ ] Add "Funnel" tab to `TAB_DEFINITIONS` in `chart_card.py`
-- [ ] Add `_render_funnel()` helper and tab dispatch in `dash_app/callbacks/chart.py`
+- [x] Create `get_retention_funnel()` in `src/data_processing/pathway_queries.py`:
+ - Query level 3+ nodes, aggregate patient counts by treatment line depth (level 3=1st drug, 4=2nd, 5=3rd)
+ - Return: `[{depth: 1, label: "1st drug", patients: N, pct: 100.0}, ...]`
+ - Supports directory/trust filters
+- [x] Add thin wrapper in `dash_app/data/queries.py`
+- [x] Create `create_retention_funnel_figure(data, title)` in `src/visualization/plotly_generator.py`:
+ - Uses `go.Funnel` with NHS blue gradient (#003087 → #1E88E5)
+ - Shows absolute patient count + percentage retained as text inside bars
+ - Uses `_base_layout()` for consistent styling
+- [x] Add "Funnel" tab to `TAB_DEFINITIONS` in `chart_card.py` (4 tabs: Icicle, Sankey, Heatmap, Funnel)
+- [x] Add `_render_funnel()` helper and tab dispatch in `dash_app/callbacks/chart.py`
- **Checkpoint**: Funnel tab shows retention by treatment line depth, responds to filters
### C.2 Pathway depth distribution chart
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index e3f795f..f7a087f 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -267,6 +267,31 @@ def _render_duration(app_state, title):
return create_duration_figure(data, title, show_directory=show_directory)
+def _render_funnel(app_state, title):
+ """Build the retention funnel figure from current filter state."""
+ from dash_app.data.queries import get_retention_funnel
+ from visualization.plotly_generator import create_retention_funnel_figure
+
+ filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
+ chart_type = (app_state or {}).get("chart_type", "directory")
+
+ selected_dirs = (app_state or {}).get("selected_directorates") or []
+ selected_trusts = (app_state or {}).get("selected_trusts") or []
+ directory = selected_dirs[0] if len(selected_dirs) == 1 else None
+ trust = selected_trusts[0] if len(selected_trusts) == 1 else None
+
+ try:
+ data = get_retention_funnel(filter_id, chart_type, directory, trust)
+ except Exception:
+ log.exception("Failed to load retention funnel data")
+ return _empty_figure("Failed to load retention funnel data.")
+
+ if not data:
+ return _empty_figure("No retention data available.\nTry adjusting your filters.")
+
+ return create_retention_funnel_figure(data, title)
+
+
def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -407,6 +432,9 @@ def register_chart_callbacks(app):
elif active_tab == "duration":
fig = _render_duration(app_state, title)
+ elif active_tab == "funnel":
+ fig = _render_funnel(app_state, title)
+
else:
# Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
diff --git a/dash_app/components/chart_card.py b/dash_app/components/chart_card.py
index 51d65b3..835c4c3 100644
--- a/dash_app/components/chart_card.py
+++ b/dash_app/components/chart_card.py
@@ -8,6 +8,7 @@ TAB_DEFINITIONS = [
("icicle", "Icicle"),
("sankey", "Sankey"),
("heatmap", "Heatmap"),
+ ("funnel", "Funnel"),
]
# Full set retained for Trust Comparison dashboard (Phase 10.8)
diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py
index ec91508..89cff41 100644
--- a/dash_app/data/queries.py
+++ b/dash_app/data/queries.py
@@ -24,6 +24,7 @@ from data_processing.pathway_queries import (
get_trust_heatmap as _get_trust_heatmap,
get_trust_durations as _get_trust_durations,
get_directorate_summary as _get_directorate_summary,
+ get_retention_funnel as _get_retention_funnel,
)
DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db"
@@ -180,3 +181,16 @@ def get_directorate_summary(
) -> list[dict]:
"""Per-directorate summary (name, patient count, drug count) for landing cards."""
return _get_directorate_summary(DB_PATH, date_filter_id, chart_type)
+
+
+# --- Retention funnel (Phase C) ---
+
+
+def get_retention_funnel(
+ date_filter_id: str = "all_6mo",
+ chart_type: str = "directory",
+ directory: Optional[str] = None,
+ trust: Optional[str] = None,
+) -> list[dict]:
+ """Patient retention by treatment line depth."""
+ return _get_retention_funnel(DB_PATH, date_filter_id, chart_type, directory, trust)
diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py
index 8de4541..1bee9ea 100644
--- a/src/data_processing/pathway_queries.py
+++ b/src/data_processing/pathway_queries.py
@@ -1069,6 +1069,76 @@ def get_trust_durations(
# --- Directorate/indication summary for Trust Comparison landing page ---
+def get_retention_funnel(
+ db_path: Path,
+ date_filter_id: str,
+ chart_type: str,
+ directory: Optional[str] = None,
+ trust: Optional[str] = None,
+) -> list[dict]:
+ """Aggregate patient counts by treatment line depth for a retention funnel.
+
+ Level 3 = 1st drug, Level 4 = 2-drug pathway, Level 5 = 3-drug pathway, etc.
+ Returns list of dicts sorted by depth ascending:
+ [{depth: 1, label: "1 drug", patients: N, pct: 100.0}, ...]
+ """
+ conn = sqlite3.connect(str(db_path))
+ conn.row_factory = sqlite3.Row
+ try:
+ where = ["date_filter_id = ?", "chart_type = ?", "level >= 3"]
+ params: list = [date_filter_id, chart_type]
+
+ if directory:
+ where.append("directory = ?")
+ params.append(directory)
+ if trust:
+ where.append("trust_name = ?")
+ params.append(trust)
+
+ query = f"""
+ SELECT level, SUM(value) AS patients
+ FROM pathway_nodes
+ WHERE {' AND '.join(where)}
+ GROUP BY level
+ ORDER BY level
+ """
+ rows = conn.execute(query, params).fetchall()
+
+ if not rows:
+ return []
+
+ result = []
+ base_patients = 0
+ for r in rows:
+ level = r["level"]
+ patients = r["patients"] or 0
+ depth = level - 2 # level 3 → depth 1, level 4 → depth 2, etc.
+
+ if depth == 1:
+ base_patients = patients
+
+ ordinal_labels = {
+ 1: "1st drug",
+ 2: "2nd drug",
+ 3: "3rd drug",
+ }
+ label = ordinal_labels.get(depth, f"{depth}th drug")
+
+ pct = round(patients / base_patients * 100, 1) if base_patients else 0
+ result.append({
+ "depth": depth,
+ "label": label,
+ "patients": patients,
+ "pct": pct,
+ })
+
+ return result
+ except sqlite3.Error:
+ return []
+ finally:
+ conn.close()
+
+
def get_directorate_summary(
db_path: Path,
date_filter_id: str,
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index b903c9c..9436e74 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -1754,3 +1754,69 @@ def create_trust_duration_figure(
fig.update_layout(**layout)
return fig
+
+
+def create_retention_funnel_figure(
+ data: list[dict],
+ title: str = "",
+) -> go.Figure:
+ """Create a retention funnel showing patient drop-off by treatment line depth.
+
+ Args:
+ data: List of dicts with keys: depth, label, patients, pct
+ title: Chart title from filter state.
+
+ Returns:
+ Plotly Figure with go.Funnel trace.
+ """
+ if not data:
+ return go.Figure()
+
+ display_title = f"Treatment Retention — {title}" if title else "Treatment Retention"
+
+ labels = [d["label"] for d in data]
+ patients = [d["patients"] for d in data]
+ pcts = [d["pct"] for d in data]
+
+ # NHS blue gradient: darkest at top (most patients) → lightest at bottom
+ funnel_colors = [
+ "#003087", # NHS Heritage Blue (1st drug)
+ "#005EB8", # NHS Blue
+ "#1E88E5", # Mid blue
+ "#42A5F5", # Light blue
+ "#90CAF9", # Pale blue
+ ]
+ colors = funnel_colors[: len(data)]
+ if len(colors) < len(data):
+ colors.extend(["#E3F2FD"] * (len(data) - len(colors)))
+
+ text_values = [
+ f"{p:,} patients ({pct}%)" for p, pct in zip(patients, pcts)
+ ]
+
+ fig = go.Figure(
+ go.Funnel(
+ y=labels,
+ x=patients,
+ text=text_values,
+ textposition="inside",
+ textfont=dict(family=CHART_FONT_FAMILY, size=14, color="white"),
+ marker=dict(color=colors),
+ connector=dict(line=dict(color=GRID_COLOR, width=1)),
+ hovertemplate=(
+ "%{y}
"
+ "Patients: %{x:,}
"
+ "%{text}"
+ ),
+ )
+ )
+
+ layout = _base_layout(display_title)
+ layout.update(
+ margin=dict(t=60, l=8, r=8, b=40),
+ yaxis=dict(automargin=True),
+ height=max(300, len(data) * 80 + 120),
+ )
+ fig.update_layout(**layout)
+
+ return fig