feat: add Directorate × Drug Heatmap chart (Task 9.8)

This commit is contained in:
Andrew Charlwood
2026-02-06 20:04:19 +00:00
parent c1e11a6dd7
commit 0af76e68e0
3 changed files with 170 additions and 4 deletions
+4 -4
View File
@@ -421,14 +421,14 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: Dosing tab renders real data with parsed interval numbers ✓ - **Checkpoint**: Dosing tab renders real data with parsed interval numbers ✓
### 9.8 Directorate × Drug Heatmap chart (Tab 7) ### 9.8 Directorate × Drug Heatmap chart (Tab 7)
- [ ] Create `dash_app/callbacks/heatmap.py`: - [x] Create `dash_app/callbacks/heatmap.py`:
- Build Plotly heatmap from `get_drug_directory_matrix()` data - Build Plotly heatmap from `get_drug_directory_matrix()` data
- Rows = directorates (sorted by total patients), columns = drugs (sorted by frequency) - Rows = directorates (sorted by total patients), columns = drugs (sorted by frequency)
- Cell colour = patient count or cost, hover shows details - Cell colour = patient count or cost, hover shows details
- Toggle between patient count / cost / cost_pp_pa colouring (additional control in tab) - Toggle between patient count / cost / cost_pp_pa colouring (additional control in tab)
- [ ] Create figure function in `src/visualization/` - [x] Create figure function in `src/visualization/`
- [ ] Wire into tab switching - [x] Wire into tab switching
- **Checkpoint**: Heatmap tab renders matrix with correct colour mapping - **Checkpoint**: Heatmap tab renders matrix with correct colour mapping
### 9.9 Treatment Duration chart (Tab 8) ### 9.9 Treatment Duration chart (Tab 8)
- [ ] Create `dash_app/callbacks/duration.py`: - [ ] Create `dash_app/callbacks/duration.py`:
+26
View File
@@ -210,6 +210,29 @@ def _render_dosing(app_state, title):
return create_dosing_figure(data, title, group_by) return create_dosing_figure(data, title, group_by)
def _render_heatmap(app_state, title):
"""Build the directorate × drug heatmap from current filter state."""
from dash_app.data.queries import get_drug_directory_matrix
from visualization.plotly_generator import create_heatmap_figure
filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
chart_type = (app_state or {}).get("chart_type", "directory")
selected_trusts = (app_state or {}).get("selected_trusts") or []
trust = selected_trusts[0] if len(selected_trusts) == 1 else None
try:
data = get_drug_directory_matrix(filter_id, chart_type, trust)
except Exception:
log.exception("Failed to load heatmap data")
return _empty_figure("Failed to load heatmap data.")
if not data.get("directories") or not data.get("drugs"):
return _empty_figure("No heatmap data available.\nTry adjusting your filters.")
return create_heatmap_figure(data, title, metric="patients")
def register_chart_callbacks(app): def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks.""" """Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -338,6 +361,9 @@ def register_chart_callbacks(app):
elif active_tab == "dosing": elif active_tab == "dosing":
fig = _render_dosing(app_state, title) fig = _render_dosing(app_state, title)
elif active_tab == "heatmap":
fig = _render_heatmap(app_state, title)
else: else:
# Placeholder for charts not yet implemented # Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
+140
View File
@@ -1184,3 +1184,143 @@ def figure_legacy(ice_df: pd.DataFrame, dir_string: str, save_dir: str) -> None:
fig.write_html(filepath) fig.write_html(filepath)
logger.info(f"Success! File saved to {filepath}") logger.info(f"Success! File saved to {filepath}")
webbrowser.open_new_tab("file:///" + filepath) webbrowser.open_new_tab("file:///" + filepath)
def create_heatmap_figure(
data: dict,
title: str = "",
metric: str = "patients",
) -> go.Figure:
"""Create a directorate × drug heatmap chart.
Args:
data: Dict from get_drug_directory_matrix() with keys:
directories (list), drugs (list),
matrix ({dir: {drug: {patients, cost, cost_pp_pa}}}).
title: Chart title suffix (filter description).
metric: Colour metric — "patients", "cost", or "cost_pp_pa".
Returns:
Plotly Figure with annotated heatmap.
"""
directories = data.get("directories", [])
drugs = data.get("drugs", [])
matrix = data.get("matrix", {})
if not directories or not drugs:
return go.Figure()
# Cap columns to top 25 drugs for readability
max_drugs = 25
drugs = drugs[:max_drugs]
metric_labels = {
"patients": "Patients",
"cost": "Total Cost (£)",
"cost_pp_pa": "Cost per Patient p.a. (£)",
}
metric_label = metric_labels.get(metric, "Patients")
# Build 2D arrays for z-values and hover text
z_values = []
hover_texts = []
for d in directories:
row_z = []
row_hover = []
dir_data = matrix.get(d, {})
for drug in drugs:
cell = dir_data.get(drug)
if cell:
val = cell.get(metric, cell.get("patients", 0))
patients = cell.get("patients", 0)
cost = cell.get("cost", 0)
cpp = cell.get("cost_pp_pa", 0)
row_z.append(val if val else 0)
row_hover.append(
f"<b>{drug}</b><br>"
f"{d}<br>"
f"Patients: {patients:,}<br>"
f"Total cost: £{cost:,.0f}<br>"
f"Cost p.a.: £{cpp:,.0f}"
)
else:
row_z.append(0)
row_hover.append(
f"<b>{drug}</b><br>{d}<br>No patients"
)
z_values.append(row_z)
hover_texts.append(row_hover)
# NHS blue colorscale for the heatmap
colorscale = [
[0.0, "#F0F4F8"],
[0.01, "#E3F2FD"],
[0.1, "#90CAF9"],
[0.3, "#42A5F5"],
[0.5, "#1E88E5"],
[0.7, "#0066CC"],
[1.0, "#003087"],
]
fig = go.Figure(
data=go.Heatmap(
z=z_values,
x=drugs,
y=directories,
colorscale=colorscale,
hovertext=hover_texts,
hovertemplate="%{hovertext}<extra></extra>",
colorbar=dict(
title=dict(
text=metric_label,
font=dict(size=12, color="#425563"),
),
thickness=15,
len=0.8,
),
xgap=2,
ygap=2,
)
)
chart_title = f"Directorate × Drug — {metric_label}"
if title:
chart_title = f"{chart_title}{title}"
n_drugs = len(drugs)
n_dirs = len(directories)
fig_width = max(700, 80 + n_drugs * 55)
fig_height = max(400, 80 + n_dirs * 40)
fig.update_layout(
title=dict(
text=chart_title,
font=dict(
family="Source Sans 3, system-ui, sans-serif",
size=18,
color="#003087",
),
x=0.5,
xanchor="center",
),
xaxis=dict(
title="",
tickfont=dict(size=11, color="#425563"),
tickangle=-45,
side="bottom",
),
yaxis=dict(
title="",
tickfont=dict(size=12, color="#425563"),
autorange="reversed",
),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
font=dict(family="Source Sans 3, system-ui, sans-serif"),
margin=dict(t=60, l=200, r=80, b=120),
width=fig_width,
height=fig_height,
)
return fig