diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index 039ddae..b63a2e3 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -160,16 +160,18 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard. - **Checkpoint**: Depth tab shows patient distribution by treatment line count ### C.3 Duration vs Cost scatter plot -- [ ] Create `get_duration_cost_scatter()` in `src/data_processing/pathway_queries.py`: - - Query level 3 nodes for drug-level data +- [x] Create `get_duration_cost_scatter()` in `src/data_processing/pathway_queries.py`: + - Query level 3 nodes for drug-level data with avg_days and cost_pp_pa + - Aggregates across trusts using weighted averages - Return: `[{drug, directory, avg_days, cost_pp_pa, patients}, ...]` -- [ ] Add thin wrapper in `dash_app/data/queries.py` -- [ ] Create `create_duration_cost_scatter_figure(data, title)` in `src/visualization/plotly_generator.py`: - - Scatter: x=avg_days, y=cost_pp_pa, size=patients, color=directory - - Add quadrant lines at median values (4 quadrants: cheap/short, cheap/long, expensive/short, expensive/long) +- [x] Add thin wrapper in `dash_app/data/queries.py` +- [x] Create `create_duration_cost_scatter_figure(data, title)` in `src/visualization/plotly_generator.py`: + - Scatter: x=avg_days, y=cost_pp_pa, size=patients (global max), color=directory + - One trace per directory for legend grouping using DRUG_PALETTE + - Quadrant lines at median values with annotations - Hover shows drug name, directory, all values -- [ ] Add "Scatter" tab to `TAB_DEFINITIONS` in `chart_card.py` -- [ ] Add `_render_scatter()` helper and tab dispatch in `dash_app/callbacks/chart.py` +- [x] Add "Scatter" tab to `TAB_DEFINITIONS` in `chart_card.py` (6 tabs: Icicle, Sankey, Heatmap, Funnel, Depth, Scatter) +- [x] Add `_render_scatter()` helper and tab dispatch in `dash_app/callbacks/chart.py` - **Checkpoint**: Scatter tab shows drugs by duration vs cost with directorate coloring ### C.4 Drug switching network graph diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py index 4d41c85..3644494 100644 --- a/dash_app/callbacks/chart.py +++ b/dash_app/callbacks/chart.py @@ -317,6 +317,31 @@ def _render_depth(app_state, title): return create_pathway_depth_figure(data, title) +def _render_scatter(app_state, title): + """Build the duration vs cost scatter plot from current filter state.""" + from dash_app.data.queries import get_duration_cost_scatter + from visualization.plotly_generator import create_duration_cost_scatter_figure + + filter_id = (app_state or {}).get("date_filter_id", "all_6mo") + chart_type = (app_state or {}).get("chart_type", "directory") + + selected_dirs = (app_state or {}).get("selected_directorates") or [] + selected_trusts = (app_state or {}).get("selected_trusts") or [] + directory = selected_dirs[0] if len(selected_dirs) == 1 else None + trust = selected_trusts[0] if len(selected_trusts) == 1 else None + + try: + data = get_duration_cost_scatter(filter_id, chart_type, directory, trust) + except Exception: + log.exception("Failed to load duration/cost scatter data") + return _empty_figure("Failed to load scatter data.") + + if not data: + return _empty_figure("No duration/cost data available.\nTry adjusting your filters.") + + return create_duration_cost_scatter_figure(data, title) + + def register_chart_callbacks(app): """Register tab switching, pathway data loading, and chart rendering callbacks.""" @@ -463,6 +488,9 @@ def register_chart_callbacks(app): elif active_tab == "depth": fig = _render_depth(app_state, title) + elif active_tab == "scatter": + fig = _render_scatter(app_state, title) + else: # Placeholder for charts not yet implemented tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) diff --git a/dash_app/components/chart_card.py b/dash_app/components/chart_card.py index 0820df7..d348cc7 100644 --- a/dash_app/components/chart_card.py +++ b/dash_app/components/chart_card.py @@ -10,6 +10,7 @@ TAB_DEFINITIONS = [ ("heatmap", "Heatmap"), ("funnel", "Funnel"), ("depth", "Depth"), + ("scatter", "Scatter"), ] # Full set retained for Trust Comparison dashboard (Phase 10.8) diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py index 17d1a2e..d3e9758 100644 --- a/dash_app/data/queries.py +++ b/dash_app/data/queries.py @@ -26,6 +26,7 @@ from data_processing.pathway_queries import ( get_directorate_summary as _get_directorate_summary, get_retention_funnel as _get_retention_funnel, get_pathway_depth_distribution as _get_pathway_depth_distribution, + get_duration_cost_scatter as _get_duration_cost_scatter, ) DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db" @@ -205,3 +206,13 @@ def get_pathway_depth_distribution( ) -> list[dict]: """Patients who stopped at each treatment line depth (exclusive counts).""" return _get_pathway_depth_distribution(DB_PATH, date_filter_id, chart_type, directory, trust) + + +def get_duration_cost_scatter( + date_filter_id: str = "all_6mo", + chart_type: str = "directory", + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Drug-level avg_days and cost_pp_pa for scatter plot.""" + return _get_duration_cost_scatter(DB_PATH, date_filter_id, chart_type, directory, trust) diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py index 071bfe1..40bec58 100644 --- a/src/data_processing/pathway_queries.py +++ b/src/data_processing/pathway_queries.py @@ -1209,6 +1209,82 @@ def get_pathway_depth_distribution( conn.close() +def get_duration_cost_scatter( + db_path: Path, + date_filter_id: str, + chart_type: str, + directory: Optional[str] = None, + trust: Optional[str] = None, +) -> list[dict]: + """Level 3 drug nodes with avg_days and cost_pp_pa for scatter plot. + + Returns list of dicts: [{drug, directory, avg_days, cost_pp_pa, patients}] + Excludes nodes missing avg_days or cost_pp_pa. Aggregates across trusts + using weighted averages. + """ + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + try: + where = ["date_filter_id = ?", "chart_type = ?", "level = 3", + "avg_days IS NOT NULL", "cost_pp_pa IS NOT NULL"] + params: list = [date_filter_id, chart_type] + + if directory: + where.append("directory = ?") + params.append(directory) + if trust: + where.append("trust_name = ?") + params.append(trust) + + query = f""" + SELECT labels AS drug, directory, + value AS patients, avg_days, cost_pp_pa + FROM pathway_nodes + WHERE {' AND '.join(where)} + """ + rows = conn.execute(query, params).fetchall() + + # Aggregate across trusts: weighted average of avg_days and cost_pp_pa + agg = {} + for r in rows: + key = (r["directory"] or "", r["drug"]) + patients = r["patients"] or 0 + days = _safe_float(r["avg_days"]) + cost = _safe_float(r["cost_pp_pa"]) + if patients == 0 or days == 0: + continue + + if key not in agg: + agg[key] = { + "drug": r["drug"], + "directory": r["directory"] or "", + "weighted_days": 0.0, + "weighted_cost": 0.0, + "total_patients": 0, + } + agg[key]["weighted_days"] += days * patients + agg[key]["weighted_cost"] += cost * patients + agg[key]["total_patients"] += patients + + result = [] + for v in agg.values(): + tp = v["total_patients"] + if tp > 0: + result.append({ + "drug": v["drug"], + "directory": v["directory"], + "avg_days": round(v["weighted_days"] / tp, 1), + "cost_pp_pa": round(v["weighted_cost"] / tp, 0), + "patients": tp, + }) + + return result + except sqlite3.Error: + return [] + finally: + conn.close() + + def get_directorate_summary( db_path: Path, date_filter_id: str, diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py index ff3cc71..0cb88a4 100644 --- a/src/visualization/plotly_generator.py +++ b/src/visualization/plotly_generator.py @@ -1891,3 +1891,109 @@ def create_pathway_depth_figure( fig.update_layout(**layout) return fig + + +def create_duration_cost_scatter_figure( + data: list[dict], + title: str = "", +) -> go.Figure: + """Create a Duration vs Cost scatter plot from drug-level data. + + Each point represents a drug (within a directory). x=avg treatment days, + y=annualised cost per patient, size=patient count, color=directory. + Quadrant lines at median values divide into 4 regions. + """ + if not data: + return go.Figure() + + import statistics + + display_title = f"Duration vs Cost — {title}" if title else "Duration vs Cost" + + # Assign colors by directory + directories = sorted(set(d["directory"] for d in data)) + dir_colors = { + d: DRUG_PALETTE[i % len(DRUG_PALETTE)] + for i, d in enumerate(directories) + } + + # Global max patients for consistent sizing across directories + global_max_p = max((d["patients"] for d in data), default=1) or 1 + + # Build one trace per directory for legend grouping + fig = go.Figure() + for directory in directories: + subset = [d for d in data if d["directory"] == directory] + patients = [d["patients"] for d in subset] + + # Scale marker size: min 8, max 40, relative to global max + sizes = [max(8, min(40, 8 + 32 * (p / global_max_p))) for p in patients] + + fig.add_trace(go.Scatter( + x=[d["avg_days"] for d in subset], + y=[d["cost_pp_pa"] for d in subset], + mode="markers", + name=directory, + marker=dict( + size=sizes, + color=dir_colors[directory], + opacity=0.75, + line=dict(width=1, color="white"), + ), + text=[d["drug"] for d in subset], + customdata=[[d["patients"], d["directory"], d["avg_days"], d["cost_pp_pa"]] for d in subset], + hovertemplate=( + "%{text}
" + "Directory: %{customdata[1]}
" + "Avg duration: %{customdata[2]} days
" + "Cost p.a.: £%{customdata[3]:,.0f}
" + "Patients: %{customdata[0]:,}
" + "" + ), + )) + + # Quadrant lines at median values + all_days = [d["avg_days"] for d in data] + all_costs = [d["cost_pp_pa"] for d in data] + med_days = statistics.median(all_days) + med_cost = statistics.median(all_costs) + + fig.add_hline( + y=med_cost, line_dash="dash", line_color=ANNOTATION_COLOR, + line_width=1, + annotation_text=f"Median £{med_cost:,.0f}", + annotation_position="top left", + annotation_font=dict(size=10, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY), + ) + fig.add_vline( + x=med_days, line_dash="dash", line_color=ANNOTATION_COLOR, + line_width=1, + annotation_text=f"Median {med_days:.0f} days", + annotation_position="top right", + annotation_font=dict(size=10, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY), + ) + + n_dirs = len(directories) + legend = _smart_legend(n_dirs, "Directory") + legend_margins = _smart_legend_margin(n_dirs) + + layout = _base_layout(display_title) + layout.update( + margin=dict(t=60, l=8, **legend_margins), + xaxis=dict( + title="Average Treatment Duration (days)", + gridcolor=GRID_COLOR, + zeroline=False, + ), + yaxis=dict( + title="Cost per Patient per Annum (£)", + gridcolor=GRID_COLOR, + automargin=True, + zeroline=False, + ), + legend=legend, + height=550, + ) + fig.update_layout(**layout) + + return fig