diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index 05188d6..92dfb3c 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -431,12 +431,12 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_
- **Checkpoint**: Heatmap tab renders matrix with correct colour mapping ✓
### 9.9 Treatment Duration chart (Tab 8)
-- [ ] Create `dash_app/callbacks/duration.py`:
+- [x] Create `dash_app/callbacks/duration.py`:
- Build horizontal bar chart from `get_treatment_durations()` data
- Y-axis = drug, X-axis = average days, colour intensity by patient count
- Directorate filter drives which drugs are shown
-- [ ] Create figure function in `src/visualization/`
-- [ ] Wire into tab switching
+- [x] Create figure function in `src/visualization/`
+- [x] Wire into tab switching
- **Checkpoint**: Duration tab renders real data, responds to directorate filter
### 9.10 Final integration + polish
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index 74d1943..910d5e2 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -233,6 +233,34 @@ def _render_heatmap(app_state, title):
return create_heatmap_figure(data, title, metric="patients")
+def _render_duration(app_state, title):
+ """Build the treatment duration horizontal bar chart from current filter state."""
+ from dash_app.data.queries import get_treatment_durations
+ from visualization.plotly_generator import create_duration_figure
+
+ filter_id = (app_state or {}).get("date_filter_id", "all_6mo")
+ chart_type = (app_state or {}).get("chart_type", "directory")
+
+ selected_dirs = (app_state or {}).get("selected_directorates") or []
+ directory = selected_dirs[0] if len(selected_dirs) == 1 else None
+
+ selected_trusts = (app_state or {}).get("selected_trusts") or []
+ trust = selected_trusts[0] if len(selected_trusts) == 1 else None
+
+ try:
+ data = get_treatment_durations(filter_id, chart_type, directory, trust)
+ except Exception:
+ log.exception("Failed to load treatment duration data")
+ return _empty_figure("Failed to load treatment duration data.")
+
+ if not data:
+ return _empty_figure("No treatment duration data available.\nTry adjusting your filters.")
+
+ # Show directory breakdown when no specific directory is filtered
+ show_directory = directory is None and chart_type == "indication"
+ return create_duration_figure(data, title, show_directory=show_directory)
+
+
def register_chart_callbacks(app):
"""Register tab switching, pathway data loading, and chart rendering callbacks."""
@@ -364,6 +392,9 @@ def register_chart_callbacks(app):
elif active_tab == "heatmap":
fig = _render_heatmap(app_state, title)
+ elif active_tab == "duration":
+ fig = _render_duration(app_state, title)
+
else:
# Placeholder for charts not yet implemented
tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index 7316180..a15792b 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -1324,3 +1324,152 @@ def create_heatmap_figure(
)
return fig
+
+
+def create_duration_figure(
+ data: list[dict],
+ title: str = "",
+ show_directory: bool = False,
+) -> go.Figure:
+ """Create horizontal bar chart showing average treatment duration by drug.
+
+ Args:
+ data: List of dicts from get_treatment_durations() with keys:
+ drug, directory, avg_days, patients.
+ Sorted by avg_days desc.
+ title: Chart title suffix (filter description).
+ show_directory: If True, include directory in label (for overview mode).
+
+ Returns:
+ Plotly Figure with horizontal bars coloured by patient count.
+ """
+ if not data:
+ return go.Figure()
+
+ # When not showing directory breakdown, aggregate same drug across directorates
+ if not show_directory:
+ agg = {}
+ for d in data:
+ drug = d["drug"]
+ pts = d["patients"]
+ days = d["avg_days"]
+ if drug not in agg:
+ agg[drug] = {"drug": drug, "total_weighted": 0.0, "total_pts": 0}
+ agg[drug]["total_weighted"] += days * pts
+ agg[drug]["total_pts"] += pts
+ data = []
+ for v in agg.values():
+ if v["total_pts"] > 0:
+ data.append({
+ "drug": v["drug"],
+ "avg_days": round(v["total_weighted"] / v["total_pts"], 1),
+ "patients": v["total_pts"],
+ })
+ data.sort(key=lambda x: -x["avg_days"])
+
+ # Cap at 40 entries for readability (keep top by patient count, then re-sort by days)
+ if len(data) > 40:
+ data.sort(key=lambda x: -x["patients"])
+ data = data[:40]
+ data.sort(key=lambda x: -x["avg_days"])
+
+ # Build labels
+ if show_directory:
+ labels = [f"{d['drug']} ({d['directory']})" for d in data]
+ else:
+ labels = [d["drug"] for d in data]
+
+ days_values = [d["avg_days"] for d in data]
+ patients_list = [d["patients"] for d in data]
+
+ # Colour gradient by patient count: light for few → dark NHS blue for many
+ max_pts = max(patients_list) if patients_list else 1
+ min_pts = min(patients_list) if patients_list else 0
+ pt_range = max_pts - min_pts if max_pts > min_pts else 1
+
+ bar_colours = []
+ for pts in patients_list:
+ t = (pts - min_pts) / pt_range
+ r = int(0x41 + (0x00 - 0x41) * t)
+ g = int(0xB6 + (0x30 - 0xB6) * t)
+ b = int(0xE6 + (0x87 - 0xE6) * t)
+ bar_colours.append(f"rgb({r},{g},{b})")
+
+ hover_texts = []
+ for d in data:
+ years = d["avg_days"] / 365.25
+ hover_texts.append(
+ f"{d['drug']}
"
+ f"Avg duration: {d['avg_days']:,.0f} days ({years:.1f} years)
"
+ f"Patients: {d['patients']:,}"
+ )
+
+ fig = go.Figure()
+
+ fig.add_trace(
+ go.Bar(
+ y=labels,
+ x=days_values,
+ orientation="h",
+ marker=dict(
+ color=bar_colours,
+ line=dict(color="#FFFFFF", width=1),
+ ),
+ hovertemplate="%{customdata}",
+ customdata=hover_texts,
+ text=[f"{v:,.0f}d" for v in days_values],
+ textposition="outside",
+ textfont=dict(size=10, color="#425563"),
+ )
+ )
+
+ for i, pts in enumerate(patients_list):
+ fig.add_annotation(
+ x=days_values[i],
+ y=labels[i],
+ text=f"n={pts:,}",
+ showarrow=False,
+ xshift=45,
+ font=dict(size=9, color="#768692", family="Source Sans 3"),
+ )
+
+ chart_title = "Treatment Duration by Drug"
+ if title:
+ chart_title += f"
{title}"
+
+ n_bars = len(data)
+ fig_height = max(400, 40 + n_bars * 28)
+
+ fig.update_layout(
+ title=dict(
+ text=chart_title,
+ font=dict(
+ family="Source Sans 3, system-ui, sans-serif",
+ size=18,
+ color="#003087",
+ ),
+ x=0.5,
+ xanchor="center",
+ ),
+ xaxis=dict(
+ title="Average Duration (days)",
+ titlefont=dict(size=13, color="#425563"),
+ tickfont=dict(size=11, color="#425563"),
+ gridcolor="rgba(0,0,0,0.06)",
+ zeroline=True,
+ zerolinecolor="rgba(0,0,0,0.1)",
+ ),
+ yaxis=dict(
+ title="",
+ tickfont=dict(size=11, color="#425563"),
+ autorange="reversed",
+ ),
+ plot_bgcolor="rgba(0,0,0,0)",
+ paper_bgcolor="rgba(0,0,0,0)",
+ font=dict(family="Source Sans 3, system-ui, sans-serif"),
+ margin=dict(t=60, l=200, r=80, b=50),
+ height=fig_height,
+ showlegend=False,
+ )
+
+ return fig