Restructured src to more logical hierarchy
This commit is contained in:
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
CLI command for computing historical trend snapshots.
|
||||
|
||||
This command fetches all activity data from Snowflake once, then replays the
|
||||
pathway computation for ~10 historical 6-month endpoints (2021-06-30 through
|
||||
2025-12-31). For each period, level-3 node summaries (drug × directory) are
|
||||
extracted and stored in a `pathway_trends` table in pathways.db.
|
||||
|
||||
The Dash "Trends" tab then queries this table to show how drug patient counts,
|
||||
costs, and cost-per-patient have changed over time.
|
||||
|
||||
Usage:
|
||||
python -m cli.compute_trends
|
||||
python -m cli.compute_trends --start 2022-01-01 --end 2025-06-30
|
||||
python -m cli.compute_trends --interval 12 # 12-month steps
|
||||
python -m cli.compute_trends --dry-run -v
|
||||
|
||||
Run `python -m cli.compute_trends --help` for full options.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
from datetime import date, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Running this file as `python -m cli.compute_trends` needs the src/ directory
# (two levels above this file) on sys.path so the sibling top-level packages
# imported below resolve. It is prepended only if not already present.
_package_dir = Path(__file__).resolve().parent
_src_dir = str(_package_dir.parent)
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger, setup_logging
|
||||
from data_processing.pathway_pipeline import (
|
||||
DateFilterConfig,
|
||||
fetch_and_transform_data,
|
||||
process_pathway_for_date_filter,
|
||||
extract_denormalized_fields,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)

# Every trend snapshot is computed with the "all_6mo" date filter:
# no restriction on initiation year, last seen within 6 months.
TREND_FILTER_CONFIG = DateFilterConfig(
    id="all_6mo",
    initiated_years=None,
    last_seen_months=6,
)
|
||||
|
||||
# DDL for the trend snapshot table: one row per (period_end, drug, directory).
# period_end is stored as text; cost_pp_pa is the only nullable column.
# IF NOT EXISTS makes the statement safe to run on every invocation.
CREATE_TRENDS_TABLE = """
CREATE TABLE IF NOT EXISTS pathway_trends (
    period_end TEXT NOT NULL,
    drug TEXT NOT NULL,
    directory TEXT NOT NULL,
    patients INTEGER NOT NULL,
    total_cost REAL NOT NULL,
    cost_pp_pa REAL,
    PRIMARY KEY (period_end, drug, directory)
)
"""
|
||||
|
||||
|
||||
def generate_period_endpoints(
    start: date,
    end: date,
    interval_months: int = 6,
) -> list[date]:
    """Return period end-dates from ``start`` to ``end`` in ``interval_months`` steps.

    Every endpoint is derived directly from ``start`` (``start`` plus
    ``k * interval_months`` months) instead of from the previous endpoint.
    The day-of-month of ``start`` is kept whenever the target month has that
    day; otherwise the last day of the target month is used. Because the
    clamp is applied against ``start`` each time, a short month does not
    shift later endpoints (2021-08-31 -> 2022-02-28 -> 2022-08-31).

    Note: there is no month-end snapping. A start of 2021-06-30 with 6-month
    steps yields 2021-12-30, not 2021-12-31.

    Args:
        start: First endpoint; always included when ``start <= end``.
        end: Inclusive upper bound for generated endpoints.
        interval_months: Number of months between endpoints; must be >= 1.

    Returns:
        Endpoints in ascending order; empty when ``start`` is after ``end``.

    Raises:
        ValueError: If ``interval_months`` is not positive (a zero or
            negative step could never reach ``end``).
    """
    import calendar

    if interval_months <= 0:
        raise ValueError(
            f"interval_months must be a positive integer, got {interval_months}"
        )

    endpoints: list[date] = []
    # Work with a zero-based absolute month index so year rollover is a divmod.
    first_month_index = start.year * 12 + (start.month - 1)
    step = 0
    while True:
        year, month_zero_based = divmod(first_month_index + step * interval_months, 12)
        if year > date.max.year:
            # Next step is not representable as a date; nothing more to add.
            break
        month = month_zero_based + 1
        last_day = calendar.monthrange(year, month)[1]
        candidate = date(year, month, min(start.day, last_day))
        if candidate > end:
            break
        endpoints.append(candidate)
        step += 1

    return endpoints
|
||||
|
||||
|
||||
def extract_level3_summaries(ice_df) -> list[dict]:
    """Build one summary dict per level-3 (drug) node in ``ice_df``.

    Rows with ``level == 3`` are selected and passed through
    ``extract_denormalized_fields``; the result is read for its
    ``drug_sequence`` and ``directory`` values. Rows where either of those
    is falsy are skipped.

    Returns:
        A list of dicts with keys ``drug``, ``directory``, ``patients``
        (int of ``value``), ``total_cost`` (float of ``cost``) and
        ``cost_pp_pa`` (float, or None when missing, empty or unparseable).
        Empty list when there are no level-3 rows.
    """
    import pandas as pd

    drug_nodes = ice_df[ice_df["level"] == 3].copy()
    if drug_nodes.empty:
        return []

    # Adds the denormalized columns read below (drug_sequence, directory).
    drug_nodes = extract_denormalized_fields(drug_nodes)

    def _optional_float(raw):
        # Missing, empty-string and unparseable values all map to None.
        try:
            if pd.notna(raw) and raw != "":
                return float(raw)
        except (ValueError, TypeError):
            pass
        return None

    summaries = []
    for _, node in drug_nodes.iterrows():
        drug = node.get("drug_sequence", "")
        directory = node.get("directory", "")
        if drug and directory:
            per_patient_cost = _optional_float(node.get("cost_pp_pa"))
            summaries.append(
                {
                    "drug": drug,
                    "directory": directory,
                    "patients": int(node.get("value", 0)),
                    "total_cost": float(node.get("cost", 0)),
                    "cost_pp_pa": per_patient_cost,
                }
            )

    return summaries
|
||||
|
||||
|
||||
def compute_trends(
    start: date = date(2021, 6, 30),
    end: date = date(2025, 12, 31),
    interval_months: int = 6,
    minimum_patients: int = 5,
    db_path: Optional[Path] = None,
    paths: Optional[PathConfig] = None,
    dry_run: bool = False,
) -> tuple[bool, str]:
    """
    Main function: fetch data, replay pathway computation for each period, store summaries.

    The activity data is fetched once via ``fetch_and_transform_data``; each
    period endpoint is then passed as ``max_date`` to
    ``process_pathway_for_date_filter`` and the resulting level-3 summaries
    are written to the ``pathway_trends`` table (committed per period).

    Args:
        start: First period endpoint
        end: Last period endpoint
        interval_months: Months between endpoints (must be >= 1)
        minimum_patients: Min patients for pathway inclusion
        db_path: Path to pathways.db (uses ``paths.data_dir / "pathways.db"`` if None)
        paths: PathConfig for reference files (uses ``default_paths`` if None)
        dry_run: If True, compute but don't write to DB

    Returns:
        (success, message) tuple. Invalid arguments, an empty endpoint list,
        missing drug filters and an empty fetch are reported as
        ``(False, reason)`` rather than raised.
    """
    # Reject bad arguments before doing any expensive work; a non-positive
    # interval can never advance from start towards end.
    if interval_months <= 0:
        return False, f"interval_months must be a positive integer, got {interval_months}"

    if paths is None:
        paths = default_paths

    if db_path is None:
        db_path = paths.data_dir / "pathways.db"

    endpoints = generate_period_endpoints(start, end, interval_months)
    if not endpoints:
        # Happens when start is after end; indexing endpoints[0] would raise.
        return False, (
            f"No period endpoints between {start.isoformat()} and {end.isoformat()}"
        )

    logger.info(f"Will compute trends for {len(endpoints)} periods: "
                f"{endpoints[0].isoformat()} to {endpoints[-1].isoformat()}")

    # Load default filters (same as refresh_pathways)
    from cli.refresh_pathways import get_default_filters
    trust_filter, drug_filter, directory_filter = get_default_filters(paths)

    if not drug_filter:
        return False, "No drugs found in default filters"

    logger.info(f"Filters: {len(trust_filter)} trusts, {len(drug_filter)} drugs, "
                f"{len(directory_filter)} directories")

    start_time = time.time()

    # Step 1: Fetch all activity data from Snowflake (one-time)
    logger.info("Step 1: Fetching all activity data from Snowflake...")
    df = fetch_and_transform_data(paths=paths)

    if df.empty:
        return False, "No data returned from Snowflake"

    logger.info(f"Fetched {len(df)} records")

    total_rows = 0
    period_stats = []

    # conn stays None for a dry run; the finally block guarantees the
    # connection is closed even if a period fails part-way through.
    conn: Optional[sqlite3.Connection] = None
    try:
        # Step 2: Create trends table
        if not dry_run:
            conn = sqlite3.connect(str(db_path))
            conn.execute(CREATE_TRENDS_TABLE)
            conn.commit()
            logger.info("Created pathway_trends table (if not exists)")

        # Step 3: Process each historical endpoint
        for i, endpoint in enumerate(endpoints, 1):
            logger.info(f"Period {i}/{len(endpoints)}: computing pathways as of {endpoint.isoformat()}...")

            ice_df = process_pathway_for_date_filter(
                df=df,
                config=TREND_FILTER_CONFIG,
                trust_filter=trust_filter,
                drug_filter=drug_filter,
                directory_filter=directory_filter,
                minimum_patients=minimum_patients,
                max_date=endpoint,
                paths=paths,
            )

            if ice_df is None:
                logger.warning(f"  No data for period ending {endpoint.isoformat()}")
                period_stats.append((endpoint, 0))
                continue

            summaries = extract_level3_summaries(ice_df)
            period_stats.append((endpoint, len(summaries)))
            total_rows += len(summaries)

            logger.info(f"  {len(summaries)} drug×directory rows for {endpoint.isoformat()}")

            if conn is not None and summaries:
                # Insert/replace rows for this period.
                # NOTE(review): INSERT OR REPLACE only overwrites rows with the
                # same (period_end, drug, directory) key; rows stored by an
                # earlier run whose key no longer appears are left in place.
                period_key = endpoint.isoformat()
                conn.executemany(
                    "INSERT OR REPLACE INTO pathway_trends "
                    "(period_end, drug, directory, patients, total_cost, cost_pp_pa) "
                    "VALUES (?, ?, ?, ?, ?, ?)",
                    [
                        (
                            period_key,
                            s["drug"],
                            s["directory"],
                            s["patients"],
                            s["total_cost"],
                            s["cost_pp_pa"],
                        )
                        for s in summaries
                    ],
                )
                conn.commit()
    finally:
        if conn is not None:
            conn.close()

    elapsed = time.time() - start_time

    # Summary
    logger.info("")
    logger.info("=" * 50)
    logger.info(f"Trend computation complete in {elapsed:.1f}s")
    logger.info(f"Periods processed: {len(endpoints)}")
    logger.info(f"Total rows: {total_rows}")
    for ep, count in period_stats:
        logger.info(f"  {ep.isoformat()}: {count} rows")
    if dry_run:
        logger.info("(DRY RUN — no data written)")
    logger.info("=" * 50)

    return True, f"Computed {total_rows} trend rows across {len(endpoints)} periods in {elapsed:.1f}s"
|
||||
|
||||
|
||||
def main(argv: Optional[list[str]] = None) -> int:
    """CLI entry point.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        Process exit code: 0 when ``compute_trends`` reports success, 1 when
        it reports failure. Invalid command-line values are reported through
        argparse, which exits with status 2.
    """

    def _iso_date(text: str) -> date:
        # argparse turns ArgumentTypeError into a usage error (exit status 2)
        # instead of letting a ValueError traceback reach the user.
        try:
            return date.fromisoformat(text)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid ISO date (expected YYYY-MM-DD): {text!r}"
            ) from None

    def _positive_int(text: str) -> int:
        # A zero or negative step can never advance from --start to --end.
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid integer: {text!r}") from None
        if value <= 0:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    parser = argparse.ArgumentParser(
        description="Compute historical trend snapshots for pathway analysis",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Default: 6-month intervals from 2021-06-30 to 2025-12-31
  python -m cli.compute_trends

  # Custom date range
  python -m cli.compute_trends --start 2022-01-01 --end 2025-06-30

  # 12-month intervals
  python -m cli.compute_trends --interval 12

  # Dry run
  python -m cli.compute_trends --dry-run -v
""",
    )

    # String defaults are run through ``type`` by argparse, so args.start and
    # args.end are always ``date`` objects after parsing.
    parser.add_argument(
        "--start",
        type=_iso_date,
        default="2021-06-30",
        help="First period endpoint (ISO date, default: 2021-06-30)",
    )
    parser.add_argument(
        "--end",
        type=_iso_date,
        default="2025-12-31",
        help="Last period endpoint (ISO date, default: 2025-12-31)",
    )
    parser.add_argument(
        "--interval",
        type=_positive_int,
        default=6,
        help="Months between endpoints (default: 6)",
    )
    parser.add_argument(
        "--minimum-patients",
        type=int,
        default=5,
        help="Min patients per pathway (default: 5)",
    )
    parser.add_argument(
        "--db-path",
        type=str,
        default=None,
        help="Path to pathways.db (default: data/pathways.db)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Compute but don't write to database",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args(argv)

    # An inverted range produces no endpoints at all; fail fast with a usage
    # error before logging is configured or any data is fetched.
    if args.start > args.end:
        parser.error(
            f"--start ({args.start.isoformat()}) must not be after "
            f"--end ({args.end.isoformat()})"
        )

    import logging
    setup_logging(level=logging.DEBUG if args.verbose else logging.INFO)

    db_path_arg = Path(args.db_path) if args.db_path else None

    success, message = compute_trends(
        start=args.start,
        end=args.end,
        interval_months=args.interval,
        minimum_patients=args.minimum_patients,
        db_path=db_path_arg,
        dry_run=args.dry_run,
    )

    if success:
        print(f"\n[OK] {message}")
        return 0

    print(f"\n[FAILED] {message}", file=sys.stderr)
    return 1
|
||||
|
||||
|
||||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user