diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index 24505a0..05188d6 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -421,14 +421,14 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: Dosing tab renders real data with parsed interval numbers ✓
### 9.8 Directorate × Drug Heatmap chart (Tab 7)
-- [ ] Create `dash_app/callbacks/heatmap.py`:
+- [x] Create `dash_app/callbacks/heatmap.py`:
- Build Plotly heatmap from `get_drug_directory_matrix()` data
- Rows = directorates (sorted by total patients), columns = drugs (sorted by frequency)
- Cell colour = patient count or cost, hover shows details
- Toggle between patient count / cost / cost_pp_pa colouring (additional control in tab)
-- [ ] Create figure function in `src/visualization/`
-- [ ] Wire into tab switching
-- **Checkpoint**: Heatmap tab renders matrix with correct colour mapping
+- [x] Create figure function in `src/visualization/`
+- [x] Wire into tab switching
+- **Checkpoint**: Heatmap tab renders matrix with correct colour mapping ✓
### 9.9 Treatment Duration chart (Tab 8)
- [ ] Create `dash_app/callbacks/duration.py`:
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index e1ca566..74d1943 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -210,6 +210,29 @@ def _render_dosing(app_state, title):
return create_dosing_figure(data, title, group_by)
+def _render_heatmap(app_state, title):
+ """Build the directorate × drug heatmap from current filter state."""
+ from dash_app.data.queries import get_drug_directory_matrix
+ from visualization.plotly_generator import create_heatmap_figure
+
+ filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
+ chart_type = (app_state or {}).get("chart_type", "directory")
+
+ selected_trusts = (app_state or {}).get("selected_trusts") or []
+ trust = selected_trusts[0] if len(selected_trusts) == 1 else None
+
+ try:
+ data = get_drug_directory_matrix(filter_id, chart_type, trust)
+ except Exception:
+ log.exception("Failed to load heatmap data")
+ return _empty_figure("Failed to load heatmap data.")
+
+ if not data.get("directories") or not data.get("drugs"):
+ return _empty_figure("No heatmap data available.\nTry adjusting your filters.")
+
+ return create_heatmap_figure(data, title, metric="patients")
+
+
def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -338,6 +361,9 @@ def register_chart_callbacks(app):
elif active_tab == "dosing":
fig = _render_dosing(app_state, title)
+ elif active_tab == "heatmap":
+ fig = _render_heatmap(app_state, title)
+
else:
# Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index ba05b5b..7316180 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -1184,3 +1184,143 @@ def figure_legacy(ice_df: pd.DataFrame, dir_string: str, save_dir: str) -> None:
fig.write_html(filepath)
logger.info(f"Success! File saved to {filepath}")
webbrowser.open_new_tab("file:///" + filepath)
+
+
+def create_heatmap_figure(
+ data: dict,
+ title: str = "",
+ metric: str = "patients",
+) -> go.Figure:
+ """Create a directorate × drug heatmap chart.
+
+ Args:
+ data: Dict from get_drug_directory_matrix() with keys:
+ directories (list), drugs (list),
+ matrix ({dir: {drug: {patients, cost, cost_pp_pa}}}).
+ title: Chart title suffix (filter description).
+ metric: Colour metric — "patients", "cost", or "cost_pp_pa".
+
+ Returns:
+ Plotly Figure with annotated heatmap.
+ """
+ directories = data.get("directories", [])
+ drugs = data.get("drugs", [])
+ matrix = data.get("matrix", {})
+
+ if not directories or not drugs:
+ return go.Figure()
+
+ # Cap columns to top 25 drugs for readability
+ max_drugs = 25
+ drugs = drugs[:max_drugs]
+
+ metric_labels = {
+ "patients": "Patients",
+ "cost": "Total Cost (£)",
+ "cost_pp_pa": "Cost per Patient p.a. (£)",
+ }
+ metric_label = metric_labels.get(metric, "Patients")
+
+ # Build 2D arrays for z-values and hover text
+ z_values = []
+ hover_texts = []
+
+ for d in directories:
+ row_z = []
+ row_hover = []
+ dir_data = matrix.get(d, {})
+ for drug in drugs:
+ cell = dir_data.get(drug)
+ if cell:
+ val = cell.get(metric, cell.get("patients", 0))
+ patients = cell.get("patients", 0)
+ cost = cell.get("cost", 0)
+ cpp = cell.get("cost_pp_pa", 0)
+ row_z.append(val if val else 0)
+ row_hover.append(
+ f"{drug}
"
+ f"{d}
"
+ f"Patients: {patients:,}
"
+ f"Total cost: £{cost:,.0f}
"
+ f"Cost p.a.: £{cpp:,.0f}"
+ )
+ else:
+ row_z.append(0)
+ row_hover.append(
+ f"{drug}
{d}
No patients"
+ )
+ z_values.append(row_z)
+ hover_texts.append(row_hover)
+
+ # NHS blue colorscale for the heatmap
+ colorscale = [
+ [0.0, "#F0F4F8"],
+ [0.01, "#E3F2FD"],
+ [0.1, "#90CAF9"],
+ [0.3, "#42A5F5"],
+ [0.5, "#1E88E5"],
+ [0.7, "#0066CC"],
+ [1.0, "#003087"],
+ ]
+
+ fig = go.Figure(
+ data=go.Heatmap(
+ z=z_values,
+ x=drugs,
+ y=directories,
+ colorscale=colorscale,
+ hovertext=hover_texts,
+ hovertemplate="%{hovertext}",
+ colorbar=dict(
+ title=dict(
+ text=metric_label,
+ font=dict(size=12, color="#425563"),
+ ),
+ thickness=15,
+ len=0.8,
+ ),
+ xgap=2,
+ ygap=2,
+ )
+ )
+
+ chart_title = f"Directorate × Drug — {metric_label}"
+ if title:
+ chart_title = f"{chart_title} — {title}"
+
+ n_drugs = len(drugs)
+ n_dirs = len(directories)
+ fig_width = max(700, 80 + n_drugs * 55)
+ fig_height = max(400, 80 + n_dirs * 40)
+
+ fig.update_layout(
+ title=dict(
+ text=chart_title,
+ font=dict(
+ family="Source Sans 3, system-ui, sans-serif",
+ size=18,
+ color="#003087",
+ ),
+ x=0.5,
+ xanchor="center",
+ ),
+ xaxis=dict(
+ title="",
+ tickfont=dict(size=11, color="#425563"),
+ tickangle=-45,
+ side="bottom",
+ ),
+ yaxis=dict(
+ title="",
+ tickfont=dict(size=12, color="#425563"),
+ autorange="reversed",
+ ),
+ plot_bgcolor="rgba(0,0,0,0)",
+ paper_bgcolor="rgba(0,0,0,0)",
+ font=dict(family="Source Sans 3, system-ui, sans-serif"),
+ margin=dict(t=60, l=200, r=80, b=120),
+ width=fig_width,
+ height=fig_height,
+ )
+
+ return fig