diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index 05188d6..92dfb3c 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -431,12 +431,12 @@ Drawer selection → update_drug_selection → app-state store → load_pathway_ - **Checkpoint**: Heatmap tab renders matrix with correct colour mapping ✓ ### 9.9 Treatment Duration chart (Tab 8) -- [ ] Create `dash_app/callbacks/duration.py`: +- [x] Create `dash_app/callbacks/duration.py`: - Build horizontal bar chart from `get_treatment_durations()` data - Y-axis = drug, X-axis = average days, colour intensity by patient count - Directorate filter drives which drugs are shown -- [ ] Create figure function in `src/visualization/` -- [ ] Wire into tab switching +- [x] Create figure function in `src/visualization/` +- [x] Wire into tab switching - **Checkpoint**: Duration tab renders real data, responds to directorate filter ### 9.10 Final integration + polish diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py index 74d1943..910d5e2 100644 --- a/dash_app/callbacks/chart.py +++ b/dash_app/callbacks/chart.py @@ -233,6 +233,34 @@ def _render_heatmap(app_state, title): return create_heatmap_figure(data, title, metric="patients") +def _render_duration(app_state, title): + """Build the treatment duration horizontal bar chart from current filter state.""" + from dash_app.data.queries import get_treatment_durations + from visualization.plotly_generator import create_duration_figure + + filter_id = (app_state or {}).get("date_filter_id", "all_6mo") + chart_type = (app_state or {}).get("chart_type", "directory") + + selected_dirs = (app_state or {}).get("selected_directorates") or [] + directory = selected_dirs[0] if len(selected_dirs) == 1 else None + + selected_trusts = (app_state or {}).get("selected_trusts") or [] + trust = selected_trusts[0] if len(selected_trusts) == 1 else None + + try: + data = get_treatment_durations(filter_id, chart_type, directory, trust) + except Exception: + log.exception("Failed to load treatment duration data") + return _empty_figure("Failed to load treatment duration data.") + + if not data: + return _empty_figure("No treatment duration data available.\nTry adjusting your filters.") + + # Show directory breakdown when no specific directory is filtered + show_directory = directory is None and chart_type == "indication" + return create_duration_figure(data, title, show_directory=show_directory) + + def register_chart_callbacks(app): """Register tab switching, pathway data loading, and chart rendering callbacks.""" @@ -364,6 +392,9 @@ def register_chart_callbacks(app): elif active_tab == "heatmap": fig = _render_heatmap(app_state, title) + elif active_tab == "duration": + fig = _render_duration(app_state, title) + else: # Placeholder for charts not yet implemented tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab) diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py index 7316180..a15792b 100644 --- a/src/visualization/plotly_generator.py +++ b/src/visualization/plotly_generator.py @@ -1324,3 +1324,152 @@ def create_heatmap_figure( ) return fig + + +def create_duration_figure( + data: list[dict], + title: str = "", + show_directory: bool = False, +) -> go.Figure: + """Create horizontal bar chart showing average treatment duration by drug. + + Args: + data: List of dicts from get_treatment_durations() with keys: + drug, directory, avg_days, patients. + Sorted by avg_days desc. + title: Chart title suffix (filter description). + show_directory: If True, include directory in label (for overview mode). + + Returns: + Plotly Figure with horizontal bars coloured by patient count. + """ + if not data: + return go.Figure() + + # When not showing directory breakdown, aggregate same drug across directorates + if not show_directory: + agg = {} + for d in data: + drug = d["drug"] + pts = d["patients"] + days = d["avg_days"] + if drug not in agg: + agg[drug] = {"drug": drug, "total_weighted": 0.0, "total_pts": 0} + agg[drug]["total_weighted"] += days * pts + agg[drug]["total_pts"] += pts + data = [] + for v in agg.values(): + if v["total_pts"] > 0: + data.append({ + "drug": v["drug"], + "avg_days": round(v["total_weighted"] / v["total_pts"], 1), + "patients": v["total_pts"], + }) + data.sort(key=lambda x: -x["avg_days"]) + + # Cap at 40 entries for readability (keep top by patient count, then re-sort by days) + if len(data) > 40: + data.sort(key=lambda x: -x["patients"]) + data = data[:40] + data.sort(key=lambda x: -x["avg_days"]) + + # Build labels + if show_directory: + labels = [f"{d['drug']} ({d['directory']})" for d in data] + else: + labels = [d["drug"] for d in data] + + days_values = [d["avg_days"] for d in data] + patients_list = [d["patients"] for d in data] + + # Colour gradient by patient count: light for few → dark NHS blue for many + max_pts = max(patients_list) if patients_list else 1 + min_pts = min(patients_list) if patients_list else 0 + pt_range = max_pts - min_pts if max_pts > min_pts else 1 + + bar_colours = [] + for pts in patients_list: + t = (pts - min_pts) / pt_range + r = int(0x41 + (0x00 - 0x41) * t) + g = int(0xB6 + (0x30 - 0xB6) * t) + b = int(0xE6 + (0x87 - 0xE6) * t) + bar_colours.append(f"rgb({r},{g},{b})") + + hover_texts = [] + for d in data: + years = d["avg_days"] / 365.25 + hover_texts.append( + f"{d['drug']}
" + f"Avg duration: {d['avg_days']:,.0f} days ({years:.1f} years)
" + f"Patients: {d['patients']:,}" + ) + + fig = go.Figure() + + fig.add_trace( + go.Bar( + y=labels, + x=days_values, + orientation="h", + marker=dict( + color=bar_colours, + line=dict(color="#FFFFFF", width=1), + ), + hovertemplate="%{customdata}", + customdata=hover_texts, + text=[f"{v:,.0f}d" for v in days_values], + textposition="outside", + textfont=dict(size=10, color="#425563"), + ) + ) + + for i, pts in enumerate(patients_list): + fig.add_annotation( + x=days_values[i], + y=labels[i], + text=f"n={pts:,}", + showarrow=False, + xshift=45, + font=dict(size=9, color="#768692", family="Source Sans 3"), + ) + + chart_title = "Treatment Duration by Drug" + if title: + chart_title += f"
{title}" + + n_bars = len(data) + fig_height = max(400, 40 + n_bars * 28) + + fig.update_layout( + title=dict( + text=chart_title, + font=dict( + family="Source Sans 3, system-ui, sans-serif", + size=18, + color="#003087", + ), + x=0.5, + xanchor="center", + ), + xaxis=dict( + title="Average Duration (days)", + titlefont=dict(size=13, color="#425563"), + tickfont=dict(size=11, color="#425563"), + gridcolor="rgba(0,0,0,0.06)", + zeroline=True, + zerolinecolor="rgba(0,0,0,0.1)", + ), + yaxis=dict( + title="", + tickfont=dict(size=11, color="#425563"), + autorange="reversed", + ), + plot_bgcolor="rgba(0,0,0,0)", + paper_bgcolor="rgba(0,0,0,0)", + font=dict(family="Source Sans 3, system-ui, sans-serif"), + margin=dict(t=60, l=200, r=80, b=50), + height=fig_height, + showlegend=False, + ) + + return fig