feat: temporal trends CLI script + Dash tab (Task D.1)

D.1a: cli/compute_trends.py — standalone CLI that imports existing pipeline
functions to replay pathway computation for ~10 historical 6-month endpoints.
Creates pathway_trends table via CREATE TABLE IF NOT EXISTS.

D.1b: Trends tab (10th PP tab) with metric toggle (patients/cost/cost_pp_pa).
Query gracefully returns empty when table doesn't exist, figure shows
instruction message to run compute_trends.
This commit is contained in:
Andrew Charlwood
2026-02-07 18:24:34 +00:00
parent d892c9fba8
commit d0404aa18a
7 changed files with 603 additions and 32 deletions
+345
View File
@@ -0,0 +1,345 @@
"""
CLI command for computing historical trend snapshots.
This command fetches all activity data from Snowflake once, then replays the
pathway computation for ~10 historical 6-month endpoints (2021-06-30 through
2025-12-31). For each period, level-3 node summaries (drug × directory) are
extracted and stored in a `pathway_trends` table in pathways.db.
The Dash "Trends" tab then queries this table to show how drug patient counts,
costs, and cost-per-patient have changed over time.
Usage:
python -m cli.compute_trends
python -m cli.compute_trends --start 2022-01-01 --end 2025-06-30
python -m cli.compute_trends --interval 12 # 12-month steps
python -m cli.compute_trends --dry-run -v
Run `python -m cli.compute_trends --help` for full options.
"""
import argparse
import sqlite3
import sys
import time
from datetime import date, timedelta
from pathlib import Path
from typing import Optional
# Ensure src/ is on sys.path when run as `python -m cli.compute_trends`
_src_dir = str(Path(__file__).resolve().parent.parent)
if _src_dir not in sys.path:
sys.path.insert(0, _src_dir)
from core import PathConfig, default_paths
from core.logging_config import get_logger, setup_logging
from data_processing.pathway_pipeline import (
DateFilterConfig,
fetch_and_transform_data,
process_pathway_for_date_filter,
extract_denormalized_fields,
)
logger = get_logger(__name__)
# Use the all_6mo config: all years initiated, last seen within 6 months
TREND_FILTER_CONFIG = DateFilterConfig(
id="all_6mo", initiated_years=None, last_seen_months=6
)
CREATE_TRENDS_TABLE = """
CREATE TABLE IF NOT EXISTS pathway_trends (
period_end TEXT NOT NULL,
drug TEXT NOT NULL,
directory TEXT NOT NULL,
patients INTEGER NOT NULL,
total_cost REAL NOT NULL,
cost_pp_pa REAL,
PRIMARY KEY (period_end, drug, directory)
)
"""
def generate_period_endpoints(
start: date,
end: date,
interval_months: int = 6,
) -> list[date]:
"""Generate period end-dates from start to end at interval_months steps."""
endpoints = []
current = start
while current <= end:
endpoints.append(current)
# Advance by interval_months
month = current.month + interval_months
year = current.year + (month - 1) // 12
month = ((month - 1) % 12) + 1
# Use last day of the target month or keep day if valid
import calendar
max_day = calendar.monthrange(year, month)[1]
day = min(current.day, max_day)
current = date(year, month, day)
return endpoints
def extract_level3_summaries(ice_df) -> list[dict]:
"""Extract level-3 (drug) node summaries from ice_df DataFrame.
Returns list of dicts with: drug, directory, patients, total_cost, cost_pp_pa
"""
import pandas as pd
level3 = ice_df[ice_df["level"] == 3].copy()
if level3.empty:
return []
# Extract denormalized fields to get drug and directory
level3 = extract_denormalized_fields(level3)
rows = []
for _, row in level3.iterrows():
drug_seq = row.get("drug_sequence", "")
directory = row.get("directory", "")
if not drug_seq or not directory:
continue
cost_pp_pa = row.get("cost_pp_pa")
try:
cost_pp_pa = float(cost_pp_pa) if pd.notna(cost_pp_pa) and cost_pp_pa != "" else None
except (ValueError, TypeError):
cost_pp_pa = None
rows.append({
"drug": drug_seq,
"directory": directory,
"patients": int(row.get("value", 0)),
"total_cost": float(row.get("cost", 0)),
"cost_pp_pa": cost_pp_pa,
})
return rows
def compute_trends(
start: date = date(2021, 6, 30),
end: date = date(2025, 12, 31),
interval_months: int = 6,
minimum_patients: int = 5,
db_path: Optional[Path] = None,
paths: Optional[PathConfig] = None,
dry_run: bool = False,
) -> tuple[bool, str]:
"""
Main function: fetch data, replay pathway computation for each period, store summaries.
Args:
start: First period endpoint
end: Last period endpoint
interval_months: Months between endpoints
minimum_patients: Min patients for pathway inclusion
db_path: Path to pathways.db (uses default if None)
paths: PathConfig for reference files
dry_run: If True, compute but don't write to DB
Returns:
(success, message) tuple
"""
if paths is None:
paths = default_paths
if db_path is None:
db_path = paths.data_dir / "pathways.db"
endpoints = generate_period_endpoints(start, end, interval_months)
logger.info(f"Will compute trends for {len(endpoints)} periods: "
f"{endpoints[0].isoformat()} to {endpoints[-1].isoformat()}")
# Load default filters (same as refresh_pathways)
from cli.refresh_pathways import get_default_filters
trust_filter, drug_filter, directory_filter = get_default_filters(paths)
if not drug_filter:
return False, "No drugs found in default filters"
logger.info(f"Filters: {len(trust_filter)} trusts, {len(drug_filter)} drugs, "
f"{len(directory_filter)} directories")
start_time = time.time()
# Step 1: Fetch all activity data from Snowflake (one-time)
logger.info("Step 1: Fetching all activity data from Snowflake...")
df = fetch_and_transform_data(paths=paths)
if df.empty:
return False, "No data returned from Snowflake"
logger.info(f"Fetched {len(df)} records")
# Step 2: Create trends table
if not dry_run:
conn = sqlite3.connect(str(db_path))
conn.execute(CREATE_TRENDS_TABLE)
conn.commit()
logger.info("Created pathway_trends table (if not exists)")
else:
conn = None
# Step 3: Process each historical endpoint
total_rows = 0
period_stats = []
for i, endpoint in enumerate(endpoints, 1):
logger.info(f"Period {i}/{len(endpoints)}: computing pathways as of {endpoint.isoformat()}...")
ice_df = process_pathway_for_date_filter(
df=df,
config=TREND_FILTER_CONFIG,
trust_filter=trust_filter,
drug_filter=drug_filter,
directory_filter=directory_filter,
minimum_patients=minimum_patients,
max_date=endpoint,
paths=paths,
)
if ice_df is None:
logger.warning(f" No data for period ending {endpoint.isoformat()}")
period_stats.append((endpoint, 0))
continue
summaries = extract_level3_summaries(ice_df)
period_stats.append((endpoint, len(summaries)))
total_rows += len(summaries)
logger.info(f" {len(summaries)} drug×directory rows for {endpoint.isoformat()}")
if not dry_run and conn and summaries:
# Insert/replace rows for this period
conn.executemany(
"INSERT OR REPLACE INTO pathway_trends "
"(period_end, drug, directory, patients, total_cost, cost_pp_pa) "
"VALUES (?, ?, ?, ?, ?, ?)",
[
(
endpoint.isoformat(),
s["drug"],
s["directory"],
s["patients"],
s["total_cost"],
s["cost_pp_pa"],
)
for s in summaries
],
)
conn.commit()
if conn:
conn.close()
elapsed = time.time() - start_time
# Summary
logger.info("")
logger.info("=" * 50)
logger.info(f"Trend computation complete in {elapsed:.1f}s")
logger.info(f"Periods processed: {len(endpoints)}")
logger.info(f"Total rows: {total_rows}")
for ep, count in period_stats:
logger.info(f" {ep.isoformat()}: {count} rows")
if dry_run:
logger.info("(DRY RUN — no data written)")
logger.info("=" * 50)
return True, f"Computed {total_rows} trend rows across {len(endpoints)} periods in {elapsed:.1f}s"
def main() -> int:
"""CLI entry point."""
parser = argparse.ArgumentParser(
description="Compute historical trend snapshots for pathway analysis",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Default: 6-month intervals from 2021-06-30 to 2025-12-31
python -m cli.compute_trends
# Custom date range
python -m cli.compute_trends --start 2022-01-01 --end 2025-06-30
# 12-month intervals
python -m cli.compute_trends --interval 12
# Dry run
python -m cli.compute_trends --dry-run -v
""",
)
parser.add_argument(
"--start",
type=str,
default="2021-06-30",
help="First period endpoint (ISO date, default: 2021-06-30)",
)
parser.add_argument(
"--end",
type=str,
default="2025-12-31",
help="Last period endpoint (ISO date, default: 2025-12-31)",
)
parser.add_argument(
"--interval",
type=int,
default=6,
help="Months between endpoints (default: 6)",
)
parser.add_argument(
"--minimum-patients",
type=int,
default=5,
help="Min patients per pathway (default: 5)",
)
parser.add_argument(
"--db-path",
type=str,
default=None,
help="Path to pathways.db (default: data/pathways.db)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Compute but don't write to database",
)
parser.add_argument(
"--verbose", "-v",
action="store_true",
help="Enable verbose logging",
)
args = parser.parse_args()
import logging
setup_logging(level=logging.DEBUG if args.verbose else logging.INFO)
start_date = date.fromisoformat(args.start)
end_date = date.fromisoformat(args.end)
db_path_arg = Path(args.db_path) if args.db_path else None
success, message = compute_trends(
start=start_date,
end=end_date,
interval_months=args.interval,
minimum_patients=args.minimum_patients,
db_path=db_path_arg,
dry_run=args.dry_run,
)
if success:
print(f"\n[OK] {message}")
return 0
else:
print(f"\n[FAILED] {message}", file=sys.stderr)
return 1
if __name__ == "__main__":
sys.exit(main())
+85
View File
@@ -1584,3 +1584,88 @@ def get_directorate_summary(
return []
finally:
conn.close()
def get_trend_data(
db_path: Path,
metric: str = "patients",
directory: Optional[str] = None,
drug: Optional[str] = None,
) -> list[dict]:
"""
Query pathway_trends table for time-series data.
Returns list of dicts with: period_end, name (drug or directory), value.
Groups by drug (one line per drug) unless aggregating by directory.
Args:
db_path: Path to pathways.db
metric: "patients", "total_cost", or "cost_pp_pa"
directory: Optional directory filter
drug: Optional drug filter
Returns:
List of dicts: [{period_end, name, value}, ...]
Empty list if pathway_trends table doesn't exist or has no data.
"""
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
try:
# Check if the table exists
cursor = conn.cursor()
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='pathway_trends'"
)
if not cursor.fetchone():
return []
valid_metrics = {"patients", "total_cost", "cost_pp_pa"}
if metric not in valid_metrics:
metric = "patients"
# Build query — group by drug (one line per drug over time)
where_clauses = []
params = []
if directory:
where_clauses.append("directory = ?")
params.append(directory)
if drug:
where_clauses.append("drug = ?")
params.append(drug)
where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
# Aggregate across directories per drug per period (or per directory if filtering by drug)
if drug:
# One line per directory for a specific drug
group_col = "directory"
else:
# One line per drug (aggregate across directories)
group_col = "drug"
if metric == "cost_pp_pa":
# Weighted average for cost_pp_pa
agg = "SUM(cost_pp_pa * patients) / NULLIF(SUM(patients), 0)"
elif metric == "total_cost":
agg = "SUM(total_cost)"
else:
agg = "SUM(patients)"
sql = f"""
SELECT period_end, {group_col} AS name, {agg} AS value
FROM pathway_trends
WHERE {where_sql}
GROUP BY period_end, {group_col}
ORDER BY period_end, {group_col}
"""
cursor.execute(sql, params)
rows = [dict(row) for row in cursor.fetchall()]
return rows
except sqlite3.Error:
return []
finally:
conn.close()
+86
View File
@@ -2297,3 +2297,89 @@ def create_dosing_distribution_figure(
fig.update_layout(**layout)
return fig
def create_trend_figure(
data: list[dict],
title: str = "",
metric: str = "patients",
) -> go.Figure:
"""Create a line chart showing trends over time from pathway_trends data.
Args:
data: List of dicts with keys: period_end, name, value
title: Chart title
metric: "patients", "total_cost", or "cost_pp_pa" (for y-axis label)
"""
if not data:
fig = go.Figure()
fig.add_annotation(
text="No trend data available.<br>Run <b>python -m cli.compute_trends</b> to generate.",
xref="paper", yref="paper", x=0.5, y=0.5,
showarrow=False,
font=dict(size=16, color=ANNOTATION_COLOR, family=CHART_FONT_FAMILY),
)
layout = _base_layout(title or "Temporal Trends")
fig.update_layout(**layout)
return fig
display_title = title or "Temporal Trends"
# Group data by name (drug or directory)
from collections import defaultdict
series = defaultdict(lambda: {"periods": [], "values": []})
for row in data:
name = row.get("name", "")
series[name]["periods"].append(row["period_end"])
series[name]["values"].append(row.get("value", 0))
n_series = len(series)
fig = go.Figure()
for i, (name, s) in enumerate(sorted(series.items())):
colour = DRUG_PALETTE[i % len(DRUG_PALETTE)]
fig.add_trace(go.Scatter(
x=s["periods"],
y=s["values"],
mode="lines+markers",
name=name,
line=dict(color=colour, width=2),
marker=dict(color=colour, size=6),
hovertemplate=(
f"<b>{name}</b><br>"
"Period: %{x}<br>"
"Value: %{y:,.0f}<extra></extra>"
),
))
metric_labels = {
"patients": "Patients",
"total_cost": "Total Cost (£)",
"cost_pp_pa": "Cost per Patient p.a. (£)",
}
y_label = metric_labels.get(metric, "Value")
legend = _smart_legend(n_series)
legend_margins = _smart_legend_margin(n_series)
layout = _base_layout(display_title)
layout.update(
xaxis=dict(
title="Period",
gridcolor=GRID_COLOR,
type="category",
),
yaxis=dict(
title=y_label,
gridcolor=GRID_COLOR,
zeroline=True,
zerolinecolor=GRID_COLOR,
),
height=500,
margin=dict(t=60, l=8, **legend_margins),
legend=legend,
hovermode="x unified",
)
fig.update_layout(**layout)
return fig