From 0af76e68e0472477f74fd0c3a9cf3b30e7c30241 Mon Sep 17 00:00:00 2001 From: Andrew Charlwood Date: Fri, 6 Feb 2026 20:04:19 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20add=20Directorate=20=C3=97=20Drug=20Hea?= =?UTF-8?q?tmap=20chart=20(Task=209.8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- IMPLEMENTATION_PLAN.md | 8 +- dash_app/callbacks/chart.py | 26 +++++ src/visualization/plotly_generator.py | 140 ++++++++++++++++++++++++++ 3 files changed, 170 insertions(+), 4 deletions(-) diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index 24505a0..05188d6 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -421,14 +421,14 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_ - **Checkpoint**: Dosing tab renders real data with parsed interval numbers ✓ ### 9.8 Directorate × Drug Heatmap chart (Tab 7) -- [ ] Create `dash_app/callbacks/heatmap.py`: +- [x] Create `dash_app/callbacks/heatmap.py`: - Build Plotly heatmap from `get_drug_directory_matrix()` data - Rows = directorates (sorted by total patients), columns = drugs (sorted by frequency) - Cell colour = patient count or cost, hover shows details - Toggle between patient count / cost / cost_pp_pa colouring (additional control in tab) -- [ ] Create figure function in `src/visualization/` -- [ ] Wire into tab switching -- **Checkpoint**: Heatmap tab renders matrix with correct colour mapping +- [x] Create figure function in `src/visualization/` +- [x] Wire into tab switching +- **Checkpoint**: Heatmap tab renders matrix with correct colour mapping ✓ ### 9.9 Treatment Duration chart (Tab 8) - [ ] Create `dash_app/callbacks/duration.py`: diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py index e1ca566..74d1943 100644 --- a/dash_app/callbacks/chart.py +++ b/dash_app/callbacks/chart.py @@ -210,6 +210,29 @@ def _render_dosing(app_state, title): return create_dosing_figure(data, title, group_by) +def _render_heatmap(app_state, title): + """Build the directorate × drug heatmap from current filter state.""" + from dash_app.data.queries import get_drug_directory_matrix + from visualization.plotly_generator import create_heatmap_figure + + filter_id = (app_state or {}).get("date_filter_id", "all_6mo") + chart_type = (app_state or {}).get("chart_type", "directory") + + selected_trusts = (app_state or {}).get("selected_trusts") or [] + trust = selected_trusts[0] if len(selected_trusts) == 1 else None + + try: + data = get_drug_directory_matrix(filter_id, chart_type, trust) + except Exception: + log.exception("Failed to load heatmap data") + return _empty_figure("Failed to load heatmap data.") + + if not data.get("directories") or not data.get("drugs"): + return _empty_figure("No heatmap data available.\nTry adjusting your filters.") + + return create_heatmap_figure(data, title, metric="patients") + + def register_chart_callbacks(app): """Register tab switching, pathway data loading, and chart rendering callbacks.""" @@ -338,6 +361,9 @@ def register_chart_callbacks(app): elif active_tab == "dosing": fig = _render_dosing(app_state, title) + elif active_tab == "heatmap": + fig = _render_heatmap(app_state, title) + else: # Placeholder for charts not yet implemented tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py index ba05b5b..7316180 100644 --- a/src/visualization/plotly_generator.py +++ b/src/visualization/plotly_generator.py @@ -1184,3 +1184,143 @@ def figure_legacy(ice_df: pd.DataFrame, dir_string: str, save_dir: str) -> None: fig.write_html(filepath) logger.info(f"Success! File saved to {filepath}") webbrowser.open_new_tab("file:///" + filepath) + + +def create_heatmap_figure( + data: dict, + title: str = "", + metric: str = "patients", +) -> go.Figure: + """Create a directorate × drug heatmap chart. + + Args: + data: Dict from get_drug_directory_matrix() with keys: + directories (list), drugs (list), + matrix ({dir: {drug: {patients, cost, cost_pp_pa}}}). + title: Chart title suffix (filter description). + metric: Colour metric — "patients", "cost", or "cost_pp_pa". + + Returns: + Plotly Figure with annotated heatmap. + """ + directories = data.get("directories", []) + drugs = data.get("drugs", []) + matrix = data.get("matrix", {}) + + if not directories or not drugs: + return go.Figure() + + # Cap columns to top 25 drugs for readability + max_drugs = 25 + drugs = drugs[:max_drugs] + + metric_labels = { + "patients": "Patients", + "cost": "Total Cost (£)", + "cost_pp_pa": "Cost per Patient p.a. (£)", + } + metric_label = metric_labels.get(metric, "Patients") + + # Build 2D arrays for z-values and hover text + z_values = [] + hover_texts = [] + + for d in directories: + row_z = [] + row_hover = [] + dir_data = matrix.get(d, {}) + for drug in drugs: + cell = dir_data.get(drug) + if cell: + val = cell.get(metric, cell.get("patients", 0)) + patients = cell.get("patients", 0) + cost = cell.get("cost", 0) + cpp = cell.get("cost_pp_pa", 0) + row_z.append(val if val else 0) + row_hover.append( + f"{drug}
" + f"{d}
" + f"Patients: {patients:,}
" + f"Total cost: £{cost:,.0f}
" + f"Cost p.a.: £{cpp:,.0f}" + ) + else: + row_z.append(0) + row_hover.append( + f"{drug}
{d}
No patients" + ) + z_values.append(row_z) + hover_texts.append(row_hover) + + # NHS blue colorscale for the heatmap + colorscale = [ + [0.0, "#F0F4F8"], + [0.01, "#E3F2FD"], + [0.1, "#90CAF9"], + [0.3, "#42A5F5"], + [0.5, "#1E88E5"], + [0.7, "#0066CC"], + [1.0, "#003087"], + ] + + fig = go.Figure( + data=go.Heatmap( + z=z_values, + x=drugs, + y=directories, + colorscale=colorscale, + hovertext=hover_texts, + hovertemplate="%{hovertext}", + colorbar=dict( + title=dict( + text=metric_label, + font=dict(size=12, color="#425563"), + ), + thickness=15, + len=0.8, + ), + xgap=2, + ygap=2, + ) + ) + + chart_title = f"Directorate × Drug — {metric_label}" + if title: + chart_title = f"{chart_title} — {title}" + + n_drugs = len(drugs) + n_dirs = len(directories) + fig_width = max(700, 80 + n_drugs * 55) + fig_height = max(400, 80 + n_dirs * 40) + + fig.update_layout( + title=dict( + text=chart_title, + font=dict( + family="Source Sans 3, system-ui, sans-serif", + size=18, + color="#003087", + ), + x=0.5, + xanchor="center", + ), + xaxis=dict( + title="", + tickfont=dict(size=11, color="#425563"), + tickangle=-45, + side="bottom", + ), + yaxis=dict( + title="", + tickfont=dict(size=12, color="#425563"), + autorange="reversed", + ), + plot_bgcolor="rgba(0,0,0,0)", + paper_bgcolor="rgba(0,0,0,0)", + font=dict(family="Source Sans 3, system-ui, sans-serif"), + margin=dict(t=60, l=200, r=80, b=120), + width=fig_width, + height=fig_height, + ) + + return fig