From d0404aa18aaab18f496aea9034b999494ff09c80 Mon Sep 17 00:00:00 2001 From: Andrew Charlwood Date: Sat, 7 Feb 2026 18:24:34 +0000 Subject: [PATCH] feat: temporal trends CLI script + Dash tab (Task D.1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit D.1a: cli/compute_trends.py — standalone CLI that imports existing pipeline functions to replay pathway computation for ~10 historical 6-month endpoints. Creates pathway_trends table via CREATE TABLE IF NOT EXISTS. D.1b: Trends tab (10th PP tab) with metric toggle (patients/cost/cost_pp_pa). Query gracefully returns empty when table doesn't exist, figure shows instruction message to run compute_trends. --- IMPLEMENTATION_PLAN.md | 48 ++-- dash_app/callbacks/chart.py | 43 ++- dash_app/components/chart_card.py | 18 ++ dash_app/data/queries.py | 10 + src/cli/compute_trends.py | 345 +++++++++++++++++++++++++ src/data_processing/pathway_queries.py | 85 ++++++ src/visualization/plotly_generator.py | 86 ++++++ 7 files changed, 603 insertions(+), 32 deletions(-) create mode 100644 src/cli/compute_trends.py diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md index 5e76fe6..1b97728 100644 --- a/IMPLEMENTATION_PLAN.md +++ b/IMPLEMENTATION_PLAN.md @@ -189,21 +189,28 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard. 
## Phase D: New Analytics (Backend Work) -### D.1 Temporal trend analysis -- [B] **BLOCKED**: Requires modifying guardrail-protected files (`schema.py`, `reference_data.py`, `refresh_pathways.py`) + needs ≥2 refresh cycles for meaningful data -- [ ] Design `pathway_trends` table schema in `src/data_processing/schema.py`: - - Columns: `snapshot_date, chart_type, directory, drug, patients, cost, cost_pp_pa` - - Stores quarterly aggregates from each refresh -- [ ] Add migration for `pathway_trends` table in `data_processing/reference_data.py` -- [ ] Extend `cli/refresh_pathways.py` to compute and insert trend data after main refresh -- [ ] Create `get_trend_data()` query in `pathway_queries.py` -- [ ] Add thin wrapper in `dash_app/data/queries.py` -- [ ] Create `create_trend_figure(data, title, metric)` in plotly_generator.py: - - Line chart: x=date, y=metric, one line per drug (or directory) - - Metric selector: patients / cost / cost_pp_pa -- [ ] Add "Trends" tab to `TAB_DEFINITIONS` in `chart_card.py` -- [ ] Add callback wiring -- **Checkpoint**: Trends tab shows drug usage over time (requires at least 2 refresh cycles for meaningful data) +### D.1 Temporal trend analysis (historical snapshots approach) +- [x] **D.1a — Create `cli/compute_trends.py` CLI script**: + - Creates `pathway_trends` table via `CREATE TABLE IF NOT EXISTS` (no schema.py changes): + ``` + pathway_trends(period_end TEXT, drug TEXT, directory TEXT, patients INTEGER, + total_cost REAL, cost_pp_pa REAL, PRIMARY KEY(period_end, drug, directory)) + ``` + - Imports existing `fetch_and_transform_data()` and `process_pathway_for_date_filter()` from `pathway_pipeline.py` — does NOT modify them + - Fetches all activity data once from Snowflake + - Loops over 6-month historical endpoints (2021-06-30 through 2025-12-31, ~10 periods) + - For each endpoint: calls `process_pathway_for_date_filter()` with `max_date=endpoint` using `all_6mo` config + - Extracts level 3 summary stats (drug, directory, patients, 
cost, cost_pp_pa) from resulting DataFrame + - Inserts aggregated rows into `pathway_trends` table + - Run separately: `python -m cli.compute_trends` (not part of main refresh) +- [x] **D.1b — Add Trends tab to Dash** (standard 6-step pattern): + 1. Create `get_trend_data(db_path, metric, directory, drug)` in `pathway_queries.py` — query `pathway_trends` table, return time-series data + 2. Add thin wrapper in `dash_app/data/queries.py` + 3. Create `create_trend_figure(data, title, metric)` in `plotly_generator.py` — line chart: x=period_end, y=metric, one line per drug (or per directory). Uses `_base_layout()` + `_smart_legend()`. Add `dmc.SegmentedControl` for metric toggle (patients / cost / cost_pp_pa) + 4. Add "Trends" tab to `TAB_DEFINITIONS` in `chart_card.py` + 5. Add `_render_trends()` helper + dispatch case in `chart.py` + 6. Handle empty state: if `pathway_trends` table doesn't exist or is empty, show "Run `python -m cli.compute_trends` to generate trend data" message +- **Checkpoint**: Trends tab shows drug/directory trends over 10 historical periods, responds to filters. Empty state handled gracefully if trends not yet computed. ### D.2 Average administered doses analysis - [x] Create `get_dosing_distribution()` query in `pathway_queries.py`: @@ -233,14 +240,6 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard. - [x] Add `_render_timeline()` helper + dispatch case in `chart.py` - **Checkpoint**: Timeline tab shows when each drug cohort was active -### D.4 NICE TA compliance dashboard -- [B] **BLOCKED**: `data/ta-recommendations.xlsx` does not exist (source data missing). 
Also requires schema + migration (guardrail-protected files)
-- [ ] Parse `data/ta-recommendations.xlsx` into a reference table
-- [ ] Create schema and migration for TA compliance reference data
-- [ ] Create compliance scoring: cross-reference pathway data with TA recommendations
-- [ ] Create `create_ta_compliance_figure(data, title)` — traffic-light matrix
-- [ ] Add "Compliance" tab or separate Trust Comparison sub-view
-- **Checkpoint**: Compliance matrix shows alignment with NICE guidance
 
 ---
 
@@ -269,10 +268,9 @@ Comprehensive review and improvement of all Plotly charts in the Dash dashboard.
 - [x] `python run_dash.py` starts cleanly
 
 ### Phase D
--- [B] Temporal trends — BLOCKED (requires guardrail-protected file changes + ≥2 refresh cycles)
+- [x] Temporal trends computed via historical snapshots (CLI script + Dash tab)
 - [x] Dose distribution shows average administered doses per drug
 - [x] Drug timeline shows Gantt-style cohort activity
--- [B] NICE TA compliance — BLOCKED (source data file missing + requires guardrail-protected file changes)
 - [x] `python run_dash.py` starts cleanly
 
 ---
diff --git a/dash_app/callbacks/chart.py b/dash_app/callbacks/chart.py
index 8f52dda..8242b6f 100644
--- a/dash_app/callbacks/chart.py
+++ b/dash_app/callbacks/chart.py
@@ -417,6 +417,25 @@ def _render_doses(app_state, title):
     return create_dosing_distribution_figure(data, title)
 
 
+def _render_trends(app_state, title, metric="patients"):
+    """Render the Trends tab line chart for the selected metric."""
+    from dash_app.data.queries import get_trend_data
+    from visualization.plotly_generator import create_trend_figure
+
+    state = app_state or {}
+    chosen_dirs = state.get("selected_directorates") or []
+    chosen_drugs = state.get("selected_drugs") or []
+    # A filter is forwarded only when exactly one value is selected
+    directory = chosen_dirs[0] if len(chosen_dirs) == 1 else None
+    drug = chosen_drugs[0] if len(chosen_drugs) == 1 else None
+    try:
+        rows = get_trend_data(metric=metric, directory=directory, drug=drug)
+    except Exception:
+        log.exception("Failed to load trend data")
+        return _empty_figure("Failed to load trend data.")
+    return create_trend_figure(rows, title, metric=metric)
+
+
 def register_chart_callbacks(app):
     """Register tab switching, pathway data loading, and chart rendering callbacks."""
 
@@ -496,18 +515,21 @@ def register_chart_callbacks(app):
         Output("pathway-chart", "figure"),
         Output("chart-subtitle", "children"),
         Output("heatmap-metric-wrapper", "style"),
+        Output("trends-metric-wrapper", "style"),
         Input("chart-data", "data"),
         Input("active-tab", "data"),
         Input("app-state", "data"),
         Input("heatmap-metric-toggle", "value"),
+        Input("trends-metric-toggle", "value"),
     )
-    def update_chart(chart_data, active_tab, app_state, heatmap_metric):
+    def update_chart(chart_data, active_tab, app_state, heatmap_metric, trends_metric):
         """Render the active tab's chart from chart-data nodes."""
         active_tab = active_tab or "icicle"
         chart_type = (app_state or {}).get("chart_type", "directory")
 
-        # Show/hide heatmap metric toggle based on active tab
-        toggle_style = {} if active_tab == "heatmap" else {"display": "none"}
+        # Show/hide metric toggles based on active tab
+        heatmap_toggle_style = {} if active_tab == "heatmap" else {"display": "none"}
+        trends_toggle_style = {} if active_tab == "trends" else {"display": "none"}
 
         if chart_type == "indication":
             subtitle = "Trust \u2192 Indication \u2192 Drug \u2192 Patient Pathway"
@@ -515,17 +537,24 @@ def register_chart_callbacks(app):
             subtitle = "Trust \u2192 Directorate \u2192 Drug \u2192 Patient Pathway"
 
         if not chart_data:
-            return no_update, no_update, toggle_style
+            return no_update, no_update, heatmap_toggle_style, trends_toggle_style
 
         error_msg = chart_data.get("error")
         if error_msg:
-            return _empty_figure(error_msg), subtitle, toggle_style
+            return _empty_figure(error_msg), subtitle, heatmap_toggle_style, trends_toggle_style
+
+        # Trends tab doesn't depend on chart-data nodes
+        if active_tab == "trends":
+            title = _generate_chart_title(app_state) if app_state else ""
+            metric = trends_metric or "patients"
+            fig = _render_trends(app_state, title, metric=metric)
+            return fig, subtitle, heatmap_toggle_style, trends_toggle_style
 
         if not chart_data.get("nodes"):
             return _empty_figure(
                 "No matching pathways found.\n"
                 "Try adjusting your filters."
-            ), subtitle, toggle_style
+            ), subtitle, heatmap_toggle_style, trends_toggle_style
 
         # Lazy rendering — only compute the active tab's chart
         title = _generate_chart_title(app_state) if app_state else ""
@@ -580,4 +609,4 @@ def register_chart_callbacks(app):
             tab_label = dict(TAB_DEFINITIONS).get(active_tab, active_tab)
             fig = _empty_figure(f"{tab_label} chart — coming soon")
 
-        return fig, subtitle, toggle_style
+        return fig, subtitle, heatmap_toggle_style, trends_toggle_style
diff --git a/dash_app/components/chart_card.py b/dash_app/components/chart_card.py
index 09eafaa..33fc9d0 100644
--- a/dash_app/components/chart_card.py
+++ b/dash_app/components/chart_card.py
@@ -14,6 +14,7 @@ TAB_DEFINITIONS = [
     ("network", "Network"),
     ("timeline", "Timeline"),
     ("doses", "Doses"),
+    ("trends", "Trends"),
 ]
 
 # Full set retained for Trust Comparison dashboard (Phase 10.8)
@@ -94,6 +95,23 @@ def make_chart_card():
                             ),
                         ],
                     ),
+                    # Trends metric toggle — visible only when trends tab active
+                    html.Div(
+                        id="trends-metric-wrapper",
+                        style={"display": "none"},
+                        children=[
+                            dmc.SegmentedControl(
+                                id="trends-metric-toggle",
+                                data=[
+                                    {"value": "patients", "label": "Patients"},
+                                    {"value": "total_cost", "label": "Cost"},
+                                    {"value": "cost_pp_pa", "label": "Cost p.a."},
+                                ],
+                                value="patients",
+                                size="xs",
+                            ),
+                        ],
+                    ),
                 ],
             ),
             # Chart area with loading spinner
diff --git a/dash_app/data/queries.py b/dash_app/data/queries.py
index cbdf315..ce7a549 100644
--- a/dash_app/data/queries.py
+++ b/dash_app/data/queries.py
@@ -30,6 +30,7 @@ from data_processing.pathway_queries import (
     get_drug_network as _get_drug_network,
     get_drug_timeline as _get_drug_timeline,
get_dosing_distribution as _get_dosing_distribution, + get_trend_data as _get_trend_data, ) DB_PATH = Path(__file__).resolve().parents[2] / "data" / "pathways.db" @@ -249,3 +250,12 @@ def get_dosing_distribution( ) -> list[dict]: """Average administered dose counts per drug.""" return _get_dosing_distribution(DB_PATH, date_filter_id, chart_type, directory, trust) + + +def get_trend_data( + metric: str = "patients", + directory: Optional[str] = None, + drug: Optional[str] = None, +) -> list[dict]: + """Time-series trend data from pathway_trends table.""" + return _get_trend_data(DB_PATH, metric, directory, drug) diff --git a/src/cli/compute_trends.py b/src/cli/compute_trends.py new file mode 100644 index 0000000..91b1ddf --- /dev/null +++ b/src/cli/compute_trends.py @@ -0,0 +1,345 @@ +""" +CLI command for computing historical trend snapshots. + +This command fetches all activity data from Snowflake once, then replays the +pathway computation for ~10 historical 6-month endpoints (2021-06-30 through +2025-12-31). For each period, level-3 node summaries (drug × directory) are +extracted and stored in a `pathway_trends` table in pathways.db. + +The Dash "Trends" tab then queries this table to show how drug patient counts, +costs, and cost-per-patient have changed over time. + +Usage: + python -m cli.compute_trends + python -m cli.compute_trends --start 2022-01-01 --end 2025-06-30 + python -m cli.compute_trends --interval 12 # 12-month steps + python -m cli.compute_trends --dry-run -v + +Run `python -m cli.compute_trends --help` for full options. 
+""" + +import argparse +import sqlite3 +import sys +import time +from datetime import date, timedelta +from pathlib import Path +from typing import Optional + +# Ensure src/ is on sys.path when run as `python -m cli.compute_trends` +_src_dir = str(Path(__file__).resolve().parent.parent) +if _src_dir not in sys.path: + sys.path.insert(0, _src_dir) + +from core import PathConfig, default_paths +from core.logging_config import get_logger, setup_logging +from data_processing.pathway_pipeline import ( + DateFilterConfig, + fetch_and_transform_data, + process_pathway_for_date_filter, + extract_denormalized_fields, +) + +logger = get_logger(__name__) + +# Use the all_6mo config: all years initiated, last seen within 6 months +TREND_FILTER_CONFIG = DateFilterConfig( + id="all_6mo", initiated_years=None, last_seen_months=6 +) + +CREATE_TRENDS_TABLE = """ +CREATE TABLE IF NOT EXISTS pathway_trends ( + period_end TEXT NOT NULL, + drug TEXT NOT NULL, + directory TEXT NOT NULL, + patients INTEGER NOT NULL, + total_cost REAL NOT NULL, + cost_pp_pa REAL, + PRIMARY KEY (period_end, drug, directory) +) +""" + + +def generate_period_endpoints( + start: date, + end: date, + interval_months: int = 6, +) -> list[date]: + """Generate period end-dates from start to end at interval_months steps.""" + endpoints = [] + current = start + while current <= end: + endpoints.append(current) + # Advance by interval_months + month = current.month + interval_months + year = current.year + (month - 1) // 12 + month = ((month - 1) % 12) + 1 + # Use last day of the target month or keep day if valid + import calendar + max_day = calendar.monthrange(year, month)[1] + day = min(current.day, max_day) + current = date(year, month, day) + return endpoints + + +def extract_level3_summaries(ice_df) -> list[dict]: + """Extract level-3 (drug) node summaries from ice_df DataFrame. 
+ + Returns list of dicts with: drug, directory, patients, total_cost, cost_pp_pa + """ + import pandas as pd + + level3 = ice_df[ice_df["level"] == 3].copy() + if level3.empty: + return [] + + # Extract denormalized fields to get drug and directory + level3 = extract_denormalized_fields(level3) + + rows = [] + for _, row in level3.iterrows(): + drug_seq = row.get("drug_sequence", "") + directory = row.get("directory", "") + if not drug_seq or not directory: + continue + + cost_pp_pa = row.get("cost_pp_pa") + try: + cost_pp_pa = float(cost_pp_pa) if pd.notna(cost_pp_pa) and cost_pp_pa != "" else None + except (ValueError, TypeError): + cost_pp_pa = None + + rows.append({ + "drug": drug_seq, + "directory": directory, + "patients": int(row.get("value", 0)), + "total_cost": float(row.get("cost", 0)), + "cost_pp_pa": cost_pp_pa, + }) + + return rows + + +def compute_trends( + start: date = date(2021, 6, 30), + end: date = date(2025, 12, 31), + interval_months: int = 6, + minimum_patients: int = 5, + db_path: Optional[Path] = None, + paths: Optional[PathConfig] = None, + dry_run: bool = False, +) -> tuple[bool, str]: + """ + Main function: fetch data, replay pathway computation for each period, store summaries. 
+ + Args: + start: First period endpoint + end: Last period endpoint + interval_months: Months between endpoints + minimum_patients: Min patients for pathway inclusion + db_path: Path to pathways.db (uses default if None) + paths: PathConfig for reference files + dry_run: If True, compute but don't write to DB + + Returns: + (success, message) tuple + """ + if paths is None: + paths = default_paths + + if db_path is None: + db_path = paths.data_dir / "pathways.db" + + endpoints = generate_period_endpoints(start, end, interval_months) + logger.info(f"Will compute trends for {len(endpoints)} periods: " + f"{endpoints[0].isoformat()} to {endpoints[-1].isoformat()}") + + # Load default filters (same as refresh_pathways) + from cli.refresh_pathways import get_default_filters + trust_filter, drug_filter, directory_filter = get_default_filters(paths) + + if not drug_filter: + return False, "No drugs found in default filters" + + logger.info(f"Filters: {len(trust_filter)} trusts, {len(drug_filter)} drugs, " + f"{len(directory_filter)} directories") + + start_time = time.time() + + # Step 1: Fetch all activity data from Snowflake (one-time) + logger.info("Step 1: Fetching all activity data from Snowflake...") + df = fetch_and_transform_data(paths=paths) + + if df.empty: + return False, "No data returned from Snowflake" + + logger.info(f"Fetched {len(df)} records") + + # Step 2: Create trends table + if not dry_run: + conn = sqlite3.connect(str(db_path)) + conn.execute(CREATE_TRENDS_TABLE) + conn.commit() + logger.info("Created pathway_trends table (if not exists)") + else: + conn = None + + # Step 3: Process each historical endpoint + total_rows = 0 + period_stats = [] + + for i, endpoint in enumerate(endpoints, 1): + logger.info(f"Period {i}/{len(endpoints)}: computing pathways as of {endpoint.isoformat()}...") + + ice_df = process_pathway_for_date_filter( + df=df, + config=TREND_FILTER_CONFIG, + trust_filter=trust_filter, + drug_filter=drug_filter, + 
directory_filter=directory_filter, + minimum_patients=minimum_patients, + max_date=endpoint, + paths=paths, + ) + + if ice_df is None: + logger.warning(f" No data for period ending {endpoint.isoformat()}") + period_stats.append((endpoint, 0)) + continue + + summaries = extract_level3_summaries(ice_df) + period_stats.append((endpoint, len(summaries))) + total_rows += len(summaries) + + logger.info(f" {len(summaries)} drug×directory rows for {endpoint.isoformat()}") + + if not dry_run and conn and summaries: + # Insert/replace rows for this period + conn.executemany( + "INSERT OR REPLACE INTO pathway_trends " + "(period_end, drug, directory, patients, total_cost, cost_pp_pa) " + "VALUES (?, ?, ?, ?, ?, ?)", + [ + ( + endpoint.isoformat(), + s["drug"], + s["directory"], + s["patients"], + s["total_cost"], + s["cost_pp_pa"], + ) + for s in summaries + ], + ) + conn.commit() + + if conn: + conn.close() + + elapsed = time.time() - start_time + + # Summary + logger.info("") + logger.info("=" * 50) + logger.info(f"Trend computation complete in {elapsed:.1f}s") + logger.info(f"Periods processed: {len(endpoints)}") + logger.info(f"Total rows: {total_rows}") + for ep, count in period_stats: + logger.info(f" {ep.isoformat()}: {count} rows") + if dry_run: + logger.info("(DRY RUN — no data written)") + logger.info("=" * 50) + + return True, f"Computed {total_rows} trend rows across {len(endpoints)} periods in {elapsed:.1f}s" + + +def main() -> int: + """CLI entry point.""" + parser = argparse.ArgumentParser( + description="Compute historical trend snapshots for pathway analysis", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Default: 6-month intervals from 2021-06-30 to 2025-12-31 + python -m cli.compute_trends + + # Custom date range + python -m cli.compute_trends --start 2022-01-01 --end 2025-06-30 + + # 12-month intervals + python -m cli.compute_trends --interval 12 + + # Dry run + python -m cli.compute_trends --dry-run -v + """, + ) + + 
+    parser.add_argument(
+        "--start",
+        type=str,
+        default="2021-06-30",
+        help="First period endpoint (ISO date, default: 2021-06-30)",
+    )
+    parser.add_argument(
+        "--end",
+        type=str,
+        default="2025-12-31",
+        help="Last period endpoint (ISO date, default: 2025-12-31)",
+    )
+    parser.add_argument(
+        "--interval",
+        type=int,
+        default=6,
+        help="Months between endpoints (default: 6)",
+    )
+    parser.add_argument(
+        "--minimum-patients",
+        type=int,
+        default=5,
+        help="Min patients per pathway (default: 5)",
+    )
+    parser.add_argument(
+        "--db-path",
+        type=str,
+        default=None,
+        help="Path to pathways.db (default: data/pathways.db)",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Compute but don't write to database",
+    )
+    parser.add_argument(
+        "--verbose", "-v",
+        action="store_true",
+        help="Enable verbose logging",
+    )
+
+    args = parser.parse_args()
+
+    import logging
+    setup_logging(level=logging.DEBUG if args.verbose else logging.INFO)
+
+    start_date = date.fromisoformat(args.start)
+    end_date = date.fromisoformat(args.end)
+    db_path_arg = Path(args.db_path) if args.db_path else None
+
+    success, message = compute_trends(
+        start=start_date,
+        end=end_date,
+        interval_months=args.interval,
+        minimum_patients=args.minimum_patients,
+        db_path=db_path_arg,
+        dry_run=args.dry_run,
+    )
+
+    if success:
+        print(f"\n[OK] {message}")
+        return 0
+    else:
+        print(f"\n[FAILED] {message}", file=sys.stderr)
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/src/data_processing/pathway_queries.py b/src/data_processing/pathway_queries.py
index d72db83..baa01ba 100644
--- a/src/data_processing/pathway_queries.py
+++ b/src/data_processing/pathway_queries.py
@@ -1584,3 +1584,88 @@ def get_directorate_summary(
         return []
     finally:
         conn.close()
+
+
+def get_trend_data(
+    db_path: Path,
+    metric: str = "patients",
+    directory: Optional[str] = None,
+    drug: Optional[str] = None,
+) -> list[dict]:
+    """
+    Query pathway_trends table for time-series data.
+
+    Rows are grouped per period and per drug (one line per drug, aggregated
+    across directories). When a drug filter is given the grouping switches
+    to directory, i.e. one line per directory for that drug.
+
+    Args:
+        db_path: Path to pathways.db
+        metric: "patients", "total_cost", or "cost_pp_pa"; any other value
+            falls back to "patients"
+        directory: Optional directory filter
+        drug: Optional drug filter
+
+    Returns:
+        List of dicts: [{period_end, name, value}, ...] ordered by
+        period_end, then name.
+        Empty list if pathway_trends table doesn't exist, has no matching
+        data, or the query fails (the failure is logged).
+    """
+    import logging
+
+    # Only fixed SQL fragments from this mapping are interpolated into the
+    # statement; filter values are always bound as parameters.
+    aggregates = {
+        "patients": "SUM(patients)",
+        "total_cost": "SUM(total_cost)",
+        # Patient-weighted mean. Rows whose cost_pp_pa is NULL are excluded
+        # from the denominator too, otherwise they pull the mean towards 0.
+        "cost_pp_pa": (
+            "SUM(cost_pp_pa * patients) / "
+            "NULLIF(SUM(CASE WHEN cost_pp_pa IS NOT NULL THEN patients END), 0)"
+        ),
+    }
+    agg = aggregates.get(metric, aggregates["patients"])
+
+    # One line per directory for a specific drug, otherwise one line per drug
+    group_col = "directory" if drug else "drug"
+
+    where_clauses = []
+    params = []
+    if directory:
+        where_clauses.append("directory = ?")
+        params.append(directory)
+    if drug:
+        where_clauses.append("drug = ?")
+        params.append(drug)
+    where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
+
+    sql = f"""
+        SELECT period_end, {group_col} AS name, {agg} AS value
+        FROM pathway_trends
+        WHERE {where_sql}
+        GROUP BY period_end, {group_col}
+        ORDER BY period_end, {group_col}
+    """
+
+    conn = sqlite3.connect(str(db_path))
+    conn.row_factory = sqlite3.Row
+
+    try:
+        # Check if the table exists
+        cursor = conn.cursor()
+        cursor.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name='pathway_trends'"
+        )
+        if not cursor.fetchone():
+            return []
+
+        cursor.execute(sql, params)
+        return [dict(row) for row in cursor.fetchall()]
+
+    except sqlite3.Error:
+        logging.getLogger(__name__).exception("Failed to query pathway_trends")
+        return []
+    finally:
+        conn.close()
diff --git a/src/visualization/plotly_generator.py b/src/visualization/plotly_generator.py
index ad34f09..6ad0bed 100644
--- a/src/visualization/plotly_generator.py
+++ b/src/visualization/plotly_generator.py
@@ -2297,3 +2297,106 @@ def create_dosing_distribution_figure(
     fig.update_layout(**layout)
 
     return fig
+
+
+def create_trend_figure(
+    data: list[dict],
+    title: str = "",
+    metric: str = "patients",
+) -> go.Figure:
+    """Create a line chart showing trends over time from pathway_trends data.
+
+    One trace is drawn per distinct ``name``. The x-axis is categorical and
+    its order is pinned to the sorted set of periods, so a series that only
+    starts in a later period cannot scramble the chronological order.
+
+    Args:
+        data: List of dicts with keys: period_end, name, value
+        title: Chart title
+        metric: "patients", "total_cost", or "cost_pp_pa" (for y-axis label)
+
+    Returns:
+        The trend figure, or a placeholder figure telling the user how to
+        generate the data when ``data`` is empty.
+    """
+    display_title = title or "Temporal Trends"
+
+    if not data:
+        fig = go.Figure()
+        fig.add_annotation(
+            # Plotly text uses HTML-style <br> for line breaks
+            text=(
+                "No trend data available.<br>"
+                "Run <b>python -m cli.compute_trends</b> to generate."
+            ),
+            xref="paper", yref="paper", x=0.5, y=0.5,
+            showarrow=False,
+            font=dict(size=16, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY),
+        )
+        fig.update_layout(**_base_layout(display_title))
+        return fig
+
+    # Group data by name (drug or directory)
+    from collections import defaultdict
+    series = defaultdict(lambda: {"periods": [], "values": []})
+    for row in data:
+        name = row.get("name", "")
+        series[name]["periods"].append(row["period_end"])
+        series[name]["values"].append(row.get("value", 0))
+
+    # ISO-formatted dates sort chronologically as plain strings
+    all_periods = sorted({row["period_end"] for row in data})
+
+    n_series = len(series)
+    fig = go.Figure()
+
+    for i, (name, s) in enumerate(sorted(series.items())):
+        colour = DRUG_PALETTE[i % len(DRUG_PALETTE)]
+        fig.add_trace(go.Scatter(
+            x=s["periods"],
+            y=s["values"],
+            mode="lines+markers",
+            name=name,
+            line=dict(color=colour, width=2),
+            marker=dict(color=colour, size=6),
+            hovertemplate=(
+                f"<b>{name}</b><br>"
+                "Period: %{x}<br>"
+                "Value: %{y:,.0f}"
+            ),
+        ))
+
+    metric_labels = {
+        "patients": "Patients",
+        "total_cost": "Total Cost (£)",
+        "cost_pp_pa": "Cost per Patient p.a. (£)",
+    }
+    y_label = metric_labels.get(metric, "Value")
+
+    legend = _smart_legend(n_series)
+    legend_margins = _smart_legend_margin(n_series)
+
+    layout = _base_layout(display_title)
+    layout.update(
+        xaxis=dict(
+            title="Period",
+            gridcolor=GRID_COLOR,
+            type="category",
+            # Without an explicit order, categories appear in first-seen order
+            categoryorder="array",
+            categoryarray=all_periods,
+        ),
+        yaxis=dict(
+            title=y_label,
+            gridcolor=GRID_COLOR,
+            zeroline=True,
+            zerolinecolor=GRID_COLOR,
+        ),
+        height=500,
+        margin=dict(t=60, l=8, **legend_margins),
+        legend=legend,
+        hovermode="x unified",
+    )
+    fig.update_layout(**layout)
+
+    return fig