diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index 039ddae..b63a2e3 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -160,16 +160,18 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard.
- **Checkpoint**: Depth tab shows patient distribution by treatment line count
### C.3 Duration vs Cost scatter plot
-- [ ] Create `get_duration_cost_scatter()` in `src/data_processing/pathway_queries.py`:
- - Query level 3 nodes for drug-level data
+- [x] Create `get_duration_cost_scatter()` in `src/data_processing/pathway_queries.py`:
+ - Query level 3 nodes for drug-level data with avg_days and cost_pp_pa
+ - Aggregate across trusts using patient-weighted averages of avg_days and cost_pp_pa
- Return: `[{drug, directory, avg_days, cost_pp_pa, patients}, ...]`
-- [ ] Add thin wrapper in `dash_app/data/queries.py`
-- [ ] Create `create_duration_cost_scatter_figure(data, title)` in `src/visualization/plotly_generator.py`:
- - Scatter: x=avg_days, y=cost_pp_pa, size=patients, color=directory
- - Add quadrant lines at median values (4 quadrants: cheap/short, cheap/long, expensive/short, expensive/long)
+- [x] Add thin wrapper in `dash_app/data/queries.py`
+- [x] Create `create_duration_cost_scatter_figure(data, title)` in `src/visualization/plotly_generator.py`:
+ - Scatter: x=avg_days, y=cost_pp_pa, size=patients (global max), color=directory
+ - One trace per directory for legend grouping using DRUG_PALETTE
+ - Quadrant lines at median values with annotations
- Hover shows drug name, directory, all values
-- [ ] Add "Scatter" tab to `TAB_DEFINITIONS` in `chart_card.py`
-- [ ] Add `_render_scatter()` helper and tab dispatch in `dash_app/callbacks/chart.py`
+- [x] Add "Scatter" tab to `TAB_DEFINITIONS` in `chart_card.py` (6 tabs: Icicle, Sankey, Heatmap, Funnel, Depth, Scatter)
+- [x] Add `_render_scatter()` helper and tab dispatch in `dash_app/callbacks/chart.py`
- **Checkpoint**: Scatter tab shows drugs by duration vs cost with directorate coloring
### C.4 Drug switching network graph
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index 4d41c85..3644494 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -317,6 +317,31 @@ def _render_depth(app_state, title):
return create_pathway_depth_figure(data, title)
+def _render_scatter(app_state, title):
+    """Render the "Scatter" tab: drug duration vs cost for the active filters.
+
+    Reads ``date_filter_id`` and ``chart_type`` from ``app_state`` (falling
+    back to ``"all_6mo"`` and ``"directory"``) and narrows the query to a
+    directorate or trust only when exactly one is selected. A placeholder
+    figure is returned when the query raises or yields no rows.
+    """
+    from dash_app.data.queries import get_duration_cost_scatter
+    from visualization.plotly_generator import create_duration_cost_scatter_figure
+
+    state = app_state or {}
+    date_filter = state.get("date_filter_id", "all_6mo")
+    view_type = state.get("chart_type", "directory")
+    directorates = state.get("selected_directorates") or []
+    trusts = state.get("selected_trusts") or []
+
+    # A single selection becomes a filter; none or several means "no filter".
+    only_directory = directorates[0] if len(directorates) == 1 else None
+    only_trust = trusts[0] if len(trusts) == 1 else None
+
+    try:
+        points = get_duration_cost_scatter(
+            date_filter, view_type, only_directory, only_trust
+        )
+    except Exception:
+        log.exception("Failed to load duration/cost scatter data")
+        return _empty_figure("Failed to load scatter data.")
+
+    if points:
+        return create_duration_cost_scatter_figure(points, title)
+    return _empty_figure("No duration/cost data available.\nTry adjusting your filters.")
+
+
def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -463,6 +488,9 @@ def register_chart_callbacks(app):
elif active_tab == "depth":
fig = _render_depth(app_state, title)
+ elif active_tab == "scatter":
+ fig = _render_scatter(app_state, title)
+
else:
# Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
diff --git a/dash_app/components/chart_card.py b/dash_app/components/chart_card.py
index 0820df7..d348cc7 100644
--- a/dash_app/components/chart_card.py
+++ b/dash_app/components/chart_card.py
@@ -10,6 +10,7 @@ TAB_DEFINITIONS = [
("heatmap", "Heatmap"),
("funnel", "Funnel"),
("depth", "Depth"),
+ ("scatter", "Scatter"),
]
# Full set retained for Trust Comparison dashboard (Phase 10.8)
diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py
index 17d1a2e..d3e9758 100644
--- a/dash_app/data/queries.py
+++ b/dash_app/data/queries.py
@@ -26,6 +26,7 @@ from data_processing.pathway_queries import (
get_directorate_summary as _get_directorate_summary,
get_retention_funnel as _get_retention_funnel,
get_pathway_depth_distribution as _get_pathway_depth_distribution,
+ get_duration_cost_scatter as _get_duration_cost_scatter,
)
DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db"
@@ -205,3 +206,13 @@ def get_pathway_depth_distribution(
) -> list[dict]:
"""Patients who stopped at each treatment line depth (exclusive counts)."""
return _get_pathway_depth_distribution(DB_PATH, date_filter_id, chart_type, directory, trust)
+
+
+def get_duration_cost_scatter(
+    date_filter_id: str = "all_6mo",
+    chart_type: str = "directory",
+    directory: Optional[str] = None,
+    trust: Optional[str] = None,
+) -> list[dict]:
+    """Fetch drug-level duration/cost points from the app's pathways database.
+
+    Thin wrapper that binds ``DB_PATH`` to the shared query of the same name;
+    each returned dict carries ``avg_days`` and ``cost_pp_pa`` for one drug.
+    """
+    return _get_duration_cost_scatter(
+        DB_PATH,
+        date_filter_id,
+        chart_type,
+        directory=directory,
+        trust=trust,
+    )
diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py
index 071bfe1..40bec58 100644
--- a/src/data_processing/pathway_queries.py
+++ b/src/data_processing/pathway_queries.py
@@ -1209,6 +1209,82 @@ def get_pathway_depth_distribution(
conn.close()
+def get_duration_cost_scatter(
+    db_path: Path,
+    date_filter_id: str,
+    chart_type: str,
+    directory: Optional[str] = None,
+    trust: Optional[str] = None,
+) -> list[dict]:
+    """Return level-3 drug nodes as duration/cost points for the scatter plot.
+
+    Rows missing ``avg_days`` or ``cost_pp_pa`` are excluded by the query.
+    Rows sharing a (directory, drug) pair -- e.g. one per trust -- are merged
+    into a single point whose ``avg_days`` and ``cost_pp_pa`` are
+    patient-weighted means.
+
+    Args:
+        db_path: Path to the SQLite database holding ``pathway_nodes``.
+        date_filter_id: Value matched against the ``date_filter_id`` column.
+        chart_type: Value matched against the ``chart_type`` column.
+        directory: Optional exact match on the ``directory`` column.
+        trust: Optional exact match on the ``trust_name`` column.
+
+    Returns:
+        A list of ``{drug, directory, avg_days, cost_pp_pa, patients}`` dicts,
+        with ``avg_days`` rounded to one decimal place and ``cost_pp_pa`` to a
+        whole number. An empty list is returned when nothing matches or when
+        the query fails; a failure is logged rather than raised.
+    """
+    import logging
+
+    conn = sqlite3.connect(str(db_path))
+    conn.row_factory = sqlite3.Row
+    try:
+        # Only fixed clause text is interpolated; every value is bound via "?".
+        where = [
+            "date_filter_id = ?",
+            "chart_type = ?",
+            "level = 3",
+            "avg_days IS NOT NULL",
+            "cost_pp_pa IS NOT NULL",
+        ]
+        params: list = [date_filter_id, chart_type]
+
+        if directory:
+            where.append("directory = ?")
+            params.append(directory)
+        if trust:
+            where.append("trust_name = ?")
+            params.append(trust)
+
+        query = f"""
+            SELECT labels AS drug, directory,
+                   value AS patients, avg_days, cost_pp_pa
+            FROM pathway_nodes
+            WHERE {' AND '.join(where)}
+        """
+        rows = conn.execute(query, params).fetchall()
+
+        # Accumulate patient-weighted sums per (directory, drug).
+        totals: dict = {}
+        for row in rows:
+            patients = row["patients"] or 0
+            days = _safe_float(row["avg_days"])
+            cost = _safe_float(row["cost_pp_pa"])
+            # Skip rows that would add no weight (no patients) or that
+            # report a zero duration.
+            if patients == 0 or days == 0:
+                continue
+
+            dir_name = row["directory"] or ""
+            bucket = totals.setdefault(
+                (dir_name, row["drug"]),
+                {
+                    "drug": row["drug"],
+                    "directory": dir_name,
+                    "weighted_days": 0.0,
+                    "weighted_cost": 0.0,
+                    "total_patients": 0,
+                },
+            )
+            bucket["weighted_days"] += days * patients
+            bucket["weighted_cost"] += cost * patients
+            bucket["total_patients"] += patients
+
+        # Divide the weighted sums by the patient total to get the means.
+        return [
+            {
+                "drug": bucket["drug"],
+                "directory": bucket["directory"],
+                "avg_days": round(
+                    bucket["weighted_days"] / bucket["total_patients"], 1
+                ),
+                "cost_pp_pa": round(
+                    bucket["weighted_cost"] / bucket["total_patients"], 0
+                ),
+                "patients": bucket["total_patients"],
+            }
+            for bucket in totals.values()
+            if bucket["total_patients"] > 0
+        ]
+    except sqlite3.Error:
+        # Keep the best-effort contract (callers get an empty list), but
+        # record the failure so it is distinguishable from "no data".
+        logging.getLogger(__name__).exception(
+            "Failed to query duration/cost scatter data"
+        )
+        return []
+    finally:
+        conn.close()
+
+
def get_directorate_summary(
db_path: Path,
date_filter_id: str,
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index ff3cc71..0cb88a4 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -1891,3 +1891,109 @@ def create_pathway_depth_figure(
fig.update_layout(**layout)
return fig
+
+
+def create_duration_cost_scatter_figure(
+    data: list[dict],
+    title: str = "",
+) -> go.Figure:
+    """Create a Duration vs Cost scatter plot from drug-level data.
+
+    Each point is one drug within a directory: x is the average treatment
+    duration in days, y the annualised cost per patient, marker size the
+    patient count and marker colour the directory. Dashed lines at the median
+    duration and median cost split the plot into four quadrants.
+
+    Args:
+        data: Dicts with the keys ``drug``, ``directory``, ``avg_days``,
+            ``cost_pp_pa`` and ``patients``.
+        title: Optional suffix appended to the "Duration vs Cost" chart title.
+
+    Returns:
+        The populated figure, or an empty ``go.Figure()`` when ``data`` is
+        empty.
+    """
+    if not data:
+        return go.Figure()
+
+    import statistics
+
+    display_title = f"Duration vs Cost — {title}" if title else "Duration vs Cost"
+
+    # Group rows by directory in one pass, keeping their input order.
+    by_directory: dict = {}
+    for point in data:
+        by_directory.setdefault(point["directory"], []).append(point)
+
+    # Sorted names give each directory a stable palette slot; the palette
+    # wraps around when there are more directories than colours.
+    directories = sorted(by_directory)
+    dir_colors = {
+        name: DRUG_PALETTE[i % len(DRUG_PALETTE)]
+        for i, name in enumerate(directories)
+    }
+
+    # Global max patients so marker sizes are comparable across traces.
+    # The trailing "or 1" avoids dividing by zero if every count is 0.
+    global_max_p = max((d["patients"] for d in data), default=1) or 1
+
+    # One trace per directory so the legend groups points by directory.
+    fig = go.Figure()
+    for directory in directories:
+        subset = by_directory[directory]
+
+        # Marker diameter scales linearly with patients, clamped to 8..40.
+        sizes = [
+            max(8, min(40, 8 + 32 * (d["patients"] / global_max_p)))
+            for d in subset
+        ]
+
+        fig.add_trace(go.Scatter(
+            x=[d["avg_days"] for d in subset],
+            y=[d["cost_pp_pa"] for d in subset],
+            mode="markers",
+            name=directory,
+            marker=dict(
+                size=sizes,
+                color=dir_colors[directory],
+                opacity=0.75,
+                line=dict(width=1, color="white"),
+            ),
+            text=[d["drug"] for d in subset],
+            customdata=[
+                [d["patients"], d["directory"], d["avg_days"], d["cost_pp_pa"]]
+                for d in subset
+            ],
+            # <br> separates the hover lines; <extra></extra> hides the
+            # secondary box that would otherwise repeat the trace name.
+            hovertemplate=(
+                "%{text}<br>"
+                "Directory: %{customdata[1]}<br>"
+                "Avg duration: %{customdata[2]} days<br>"
+                "Cost p.a.: £%{customdata[3]:,.0f}<br>"
+                "Patients: %{customdata[0]:,}<br>"
+                "<extra></extra>"
+            ),
+        ))
+
+    # Quadrant lines at the median of all points (not per directory).
+    med_days = statistics.median(d["avg_days"] for d in data)
+    med_cost = statistics.median(d["cost_pp_pa"] for d in data)
+
+    fig.add_hline(
+        y=med_cost, line_dash="dash", line_color=ANNOTATION_COLOR,
+        line_width=1,
+        annotation_text=f"Median £{med_cost:,.0f}",
+        annotation_position="top left",
+        annotation_font=dict(size=10, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY),
+    )
+    fig.add_vline(
+        x=med_days, line_dash="dash", line_color=ANNOTATION_COLOR,
+        line_width=1,
+        annotation_text=f"Median {med_days:.0f} days",
+        annotation_position="top right",
+        annotation_font=dict(size=10, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY),
+    )
+
+    # Legend placement and margins depend on the number of directories.
+    # NOTE(review): _smart_legend_margin presumably returns the remaining
+    # margin keys (it is unpacked next to t and l) -- confirm.
+    n_dirs = len(directories)
+    legend = _smart_legend(n_dirs, "Directory")
+    legend_margins = _smart_legend_margin(n_dirs)
+
+    layout = _base_layout(display_title)
+    layout.update(
+        margin=dict(t=60, l=8, **legend_margins),
+        xaxis=dict(
+            title="Average Treatment Duration (days)",
+            gridcolor=GRID_COLOR,
+            zeroline=False,
+        ),
+        yaxis=dict(
+            title="Cost per Patient per Annum (£)",
+            gridcolor=GRID_COLOR,
+            automargin=True,
+            zeroline=False,
+        ),
+        legend=legend,
+        height=550,
+    )
+    fig.update_layout(**layout)
+
+    return fig