Files
HighCostDrugsDemo/cli/refresh_pathways.py
T
Andrew Charlwood adc1dbfc58 feat: complete Task 2.2 - test refresh pipeline with Snowflake data
Tested full refresh pipeline end-to-end with real Snowflake data:
- Fixed trust filter to read Name column from defaultTrusts.csv
- Fixed Decimal type handling in calculate_cost_per_patient_per_annum
- Fixed array handling in convert_to_records for average_administered
- Added required reference CSV files to data/ directory
- Configured Snowflake connection (account, warehouse, user)

Results:
- Snowflake fetch: 656,695 records in ~7s
- Transformations: 519,848 records after UPID/drug/directory
- Pathway nodes: 293 for all_6mo (8 trusts, 14 directories)
- Total processing time: ~6.2 minutes
2026-02-05 00:20:12 +00:00

487 lines
15 KiB
Python

"""
CLI command for refreshing pathway data from Snowflake.
This command fetches activity data from Snowflake, processes it through the
pathway pipeline for all 6 date filter combinations, and stores the results
in the SQLite pathway_nodes table.
Usage:
python -m cli.refresh_pathways
python -m cli.refresh_pathways --minimum-patients 10
python -m cli.refresh_pathways --provider-codes RGT,RM1
python -m cli.refresh_pathways --dry-run
Run `python -m cli.refresh_pathways --help` for full options.
"""
import argparse
import json
import sqlite3
import sys
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional
from core import PathConfig, default_paths
from core.logging_config import get_logger, setup_logging
from data_processing.database import DatabaseManager, DatabaseConfig
from data_processing.schema import (
clear_pathway_nodes,
get_pathway_table_counts,
verify_pathway_tables_exist,
create_pathway_tables,
)
from data_processing.pathway_pipeline import (
DATE_FILTER_CONFIGS,
fetch_and_transform_data,
process_all_date_filters,
)
# Module-level logger shared by the functions in this CLI module.
# main() calls setup_logging() before any refresh work is started.
logger = get_logger(__name__)
def get_default_filters(paths: PathConfig) -> tuple[list[str], list[str], list[str]]:
    """
    Read the default trust, drug and directory filters from the reference CSVs.

    Each reference file is optional:
    - A file that does not exist contributes an empty list.
    - A file that cannot be read or parsed is reported with a warning and also
      contributes an empty list.

    Extraction rules:
    - trusts: the "Name" column, or the first column if "Name" is absent.
    - drugs: the first column of rows whose "Include" value equals 1, or every
      row's first column if there is no "Include" column.
    - directories: the first column.
    Missing (NaN) cells are dropped in every case.

    Args:
        paths: PathConfig providing ``default_trusts_csv``, ``include_csv`` and
            ``directory_list_csv``.

    Returns:
        Tuple of (trust_filter, drug_filter, directory_filter)
    """
    import pandas as pd

    def _first_column(frame):
        return frame.iloc[:, 0]

    def _trust_names(frame):
        # The "Name" column holds trust names; fall back to the first column.
        names = frame['Name'] if 'Name' in frame.columns else _first_column(frame)
        return names.dropna().tolist()

    def _included_drugs(frame):
        # Keep only rows flagged Include == 1 when the flag column exists.
        if 'Include' in frame.columns:
            frame = frame[frame['Include'] == 1]
        return _first_column(frame).dropna().tolist()

    def _directory_names(frame):
        return _first_column(frame).dropna().tolist()

    # (csv location, extractor, label used in log messages), in load order.
    sources = (
        (paths.default_trusts_csv, _trust_names, "trusts"),
        (paths.include_csv, _included_drugs, "drugs"),
        (paths.directory_list_csv, _directory_names, "directories"),
    )

    loaded = []
    for csv_path, extract, label in sources:
        values = []
        if csv_path.exists():
            try:
                values = extract(pd.read_csv(csv_path))
                logger.info(f"Loaded {len(values)} default {label}")
            except Exception as e:
                logger.warning(f"Could not load default {label}: {e}")
        loaded.append(values)

    trust_filter, drug_filter, directory_filter = loaded
    return trust_filter, drug_filter, directory_filter
def insert_pathway_records(
    conn: sqlite3.Connection,
    records: list[dict],
) -> int:
    """
    Bulk-write pathway node records into the pathway_nodes table.

    Uses INSERT OR REPLACE. A record that conflicts with an existing row on a
    uniqueness constraint (if the schema defines one) replaces that row.

    Only the known pathway_nodes columns are read from each record. Missing
    keys are written as NULL and extra keys are ignored. Nothing is committed
    here, so the caller controls the transaction.

    Args:
        conn: SQLite connection
        records: List of record dicts from convert_to_records()

    Returns:
        Number of records inserted, as reported by ``cursor.rowcount``
        (0 when ``records`` is empty).
    """
    if not records:
        return 0

    # Column order matching pathway_nodes schema
    ordered_columns = (
        'date_filter_id', 'parents', 'ids', 'labels', 'level',
        'value', 'cost', 'costpp', 'cost_pp_pa', 'colour',
        'first_seen', 'last_seen', 'first_seen_parent', 'last_seen_parent',
        'average_spacing', 'average_administered', 'avg_days',
        'trust_name', 'directory', 'drug_sequence', 'data_refresh_id',
    )
    column_names = ', '.join(ordered_columns)
    placeholders = ', '.join('?' * len(ordered_columns))
    insert_sql = f"""
        INSERT OR REPLACE INTO pathway_nodes ({column_names})
        VALUES ({placeholders})
    """

    # Project each dict onto the column order; absent keys become None (NULL).
    rows = [tuple(map(record.get, ordered_columns)) for record in records]
    return conn.executemany(insert_sql, rows).rowcount
def log_refresh_start(
    conn: sqlite3.Connection,
    refresh_id: str,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
) -> None:
    """
    Insert a 'running' row for ``refresh_id`` into pathway_refresh_log and commit.

    ``started_at`` is stored as ``datetime.now().isoformat()``, which is naive
    local time. The optional Snowflake query date bounds are stored as given
    (NULL when omitted).
    """
    started_at = datetime.now().isoformat()
    params = (refresh_id, started_at, date_from, date_to)
    conn.execute(
        """
        INSERT INTO pathway_refresh_log
        (refresh_id, started_at, status, snowflake_query_date_from, snowflake_query_date_to)
        VALUES (?, ?, 'running', ?, ?)
        """,
        params,
    )
    conn.commit()
def log_refresh_complete(
    conn: sqlite3.Connection,
    refresh_id: str,
    record_count: int,
    date_filter_counts: dict[str, int],
    duration_seconds: float,
) -> None:
    """
    Mark ``refresh_id`` as 'completed' in pathway_refresh_log and commit.

    Records the completion time (``datetime.now().isoformat()``), the total
    record count, the per-date-filter counts serialised with ``json.dumps``,
    and the processing duration. Rows for other refresh ids are untouched.
    """
    completed_at = datetime.now().isoformat()
    serialised_counts = json.dumps(date_filter_counts)
    params = (
        completed_at,
        record_count,
        serialised_counts,
        duration_seconds,
        refresh_id,
    )
    conn.execute(
        """
        UPDATE pathway_refresh_log
        SET completed_at = ?,
        status = 'completed',
        record_count = ?,
        date_filter_counts = ?,
        processing_duration_seconds = ?
        WHERE refresh_id = ?
        """,
        params,
    )
    conn.commit()
def log_refresh_failed(
    conn: sqlite3.Connection,
    refresh_id: str,
    error_message: str,
    duration_seconds: float,
) -> None:
    """
    Mark ``refresh_id`` as 'failed' in pathway_refresh_log and commit.

    Stores the completion time (``datetime.now().isoformat()``), the error
    message and the elapsed duration. If no row matches ``refresh_id`` the
    UPDATE simply affects nothing.
    """
    completed_at = datetime.now().isoformat()
    params = (completed_at, error_message, duration_seconds, refresh_id)
    conn.execute(
        """
        UPDATE pathway_refresh_log
        SET completed_at = ?,
        status = 'failed',
        error_message = ?,
        processing_duration_seconds = ?
        WHERE refresh_id = ?
        """,
        params,
    )
    conn.commit()
def _log_refresh_header(
    minimum_patients: int,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    provider_codes: Optional[list[str]],
    database: object,
    dry_run: bool,
) -> None:
    """Log the banner describing the configuration of a refresh run."""
    logger.info("=" * 60)
    logger.info("Pathway Data Refresh Starting")
    logger.info("=" * 60)
    logger.info(f"Minimum patients: {minimum_patients}")
    logger.info(f"Trust filter: {len(trust_filter)} trusts")
    logger.info(f"Drug filter: {len(drug_filter)} drugs")
    logger.info(f"Directory filter: {len(directory_filter)} directories")
    logger.info(f"Provider codes: {provider_codes or 'All'}")
    logger.info(f"Database: {database}")
    logger.info(f"Dry run: {dry_run}")
    logger.info("=" * 60)


def _record_refresh_failure(
    db_manager: "DatabaseManager",
    refresh_id: str,
    error_message: str,
    elapsed: float,
    dry_run: bool,
) -> None:
    """
    Best-effort: mark ``refresh_id`` as failed in pathway_refresh_log.

    Dry runs never insert a log row (log_refresh_start is skipped for them), so
    there is nothing to update and the database is left untouched.

    Errors while writing the log are reported as a warning rather than raised,
    so they never mask the failure being recorded.
    """
    if dry_run:
        return
    try:
        with db_manager.get_connection() as conn:
            log_refresh_failed(conn, refresh_id, error_message, elapsed)
    except Exception as log_error:
        logger.warning(
            f"Could not record failure of refresh {refresh_id} in pathway_refresh_log: {log_error}"
        )


def _replace_pathway_nodes(
    db_manager: "DatabaseManager",
    results: dict[str, list[dict]],
) -> None:
    """Clear pathway_nodes and insert every non-empty result set inside one get_transaction() block."""
    with db_manager.get_transaction() as conn:
        # Clear all existing pathway nodes
        deleted = clear_pathway_nodes(conn)
        logger.info(f"Cleared {deleted} existing pathway nodes")
        # Insert new records for each date filter
        for filter_id, records in results.items():
            if not records:
                continue
            insert_pathway_records(conn, records)
            logger.info(f" Inserted {len(records)} records for {filter_id}")


def refresh_pathways(
    minimum_patients: int = 5,
    provider_codes: Optional[list[str]] = None,
    trust_filter: Optional[list[str]] = None,
    drug_filter: Optional[list[str]] = None,
    directory_filter: Optional[list[str]] = None,
    db_path: Optional[Path] = None,
    paths: Optional[PathConfig] = None,
    dry_run: bool = False,
) -> tuple[bool, str, dict]:
    """
    Main refresh function that orchestrates the full pipeline.

    Steps:
    1. Ensure the pathway tables exist, creating any that are missing.
    2. Fetch and transform activity data from Snowflake.
    3. Build pathway nodes for every date filter.
    4. Unless ``dry_run``, replace the contents of pathway_nodes and record
       the run in pathway_refresh_log.

    Args:
        minimum_patients: Minimum patients to include a pathway
        provider_codes: List of provider codes to filter Snowflake query
        trust_filter: List of trust names to include in pathways.
            None loads the defaults from the reference CSVs; an explicit
            empty list is kept as-is. The same applies to drug_filter and
            directory_filter.
        drug_filter: List of drug names to include in pathways.
            Must be non-empty once defaults are applied.
        directory_filter: List of directories to include in pathways
        db_path: Path to SQLite database (uses default if None)
        paths: PathConfig for file paths (``default_paths`` if None)
        dry_run: If True, process data but neither insert records nor write
            to pathway_refresh_log. Missing tables are still created.

    Returns:
        Tuple of (success: bool, message: str, stats: dict).
        ``stats`` holds ``refresh_id``, ``date_filter_counts``,
        ``total_records`` and ``snowflake_rows``. It is an empty dict when no
        drug filter is available.
    """
    if paths is None:
        paths = default_paths

    # Set up database connection
    if db_path:
        db_config = DatabaseConfig(db_path=db_path)
    else:
        db_config = DatabaseConfig(data_dir=paths.data_dir)
    db_manager = DatabaseManager(db_config)

    # Defaults only replace filters the caller left as None.
    default_trusts, default_drugs, default_dirs = get_default_filters(paths)
    if trust_filter is None:
        trust_filter = default_trusts
    if drug_filter is None:
        drug_filter = default_drugs
    if directory_filter is None:
        directory_filter = default_dirs

    # Without any drugs there is nothing to build pathways from.
    if not drug_filter:
        return False, "No drugs specified and could not load defaults", {}

    _log_refresh_header(
        minimum_patients=minimum_patients,
        trust_filter=trust_filter,
        drug_filter=drug_filter,
        directory_filter=directory_filter,
        provider_codes=provider_codes,
        database=db_manager.db_path,
        dry_run=dry_run,
    )

    # monotonic() so wall-clock adjustments cannot skew the recorded duration.
    start_time = time.monotonic()
    refresh_id = str(uuid.uuid4())[:8]
    stats = {
        "refresh_id": refresh_id,
        "date_filter_counts": {},
        "total_records": 0,
        "snowflake_rows": 0,
    }

    try:
        # Verify database and tables
        with db_manager.get_connection() as conn:
            missing_tables = verify_pathway_tables_exist(conn)
            if missing_tables:
                logger.info(f"Creating missing tables: {missing_tables}")
                create_pathway_tables(conn)
            # Log refresh start (dry runs leave no trace in the refresh log)
            if not dry_run:
                log_refresh_start(conn, refresh_id)

        # Step 1: Fetch data from Snowflake
        logger.info("")
        logger.info("Step 1/4: Fetching data from Snowflake...")
        df = fetch_and_transform_data(
            provider_codes=provider_codes,
            paths=paths,
        )
        if df.empty:
            msg = "No data returned from Snowflake"
            logger.error(msg)
            _record_refresh_failure(
                db_manager, refresh_id, msg, time.monotonic() - start_time, dry_run
            )
            return False, msg, stats
        stats["snowflake_rows"] = len(df)
        logger.info(f"Fetched {len(df)} records from Snowflake")

        # Step 2: Process all date filters
        logger.info("")
        logger.info("Step 2/4: Processing pathway data for 6 date filter combinations...")
        results = process_all_date_filters(
            df=df,
            trust_filter=trust_filter,
            drug_filter=drug_filter,
            directory_filter=directory_filter,
            minimum_patients=minimum_patients,
            refresh_id=refresh_id,
            paths=paths,
        )

        # Count records per filter
        for filter_id, records in results.items():
            stats["date_filter_counts"][filter_id] = len(records)
            stats["total_records"] += len(records)
        logger.info(f"Processed {stats['total_records']} total pathway nodes")
        for filter_id, count in stats["date_filter_counts"].items():
            logger.info(f" {filter_id}: {count} nodes")

        if dry_run:
            logger.info("")
            logger.info("DRY RUN - Skipping database insertion")
            return True, f"Dry run complete: {stats['total_records']} records would be inserted", stats

        # Step 3: Clear existing data and insert new records
        logger.info("")
        logger.info("Step 3/4: Clearing existing pathway data and inserting new records...")
        _replace_pathway_nodes(db_manager, results)

        # Step 4: Log completion
        logger.info("")
        logger.info("Step 4/4: Logging refresh completion...")
        elapsed = time.monotonic() - start_time
        with db_manager.get_connection() as conn:
            log_refresh_complete(
                conn=conn,
                refresh_id=refresh_id,
                record_count=stats["total_records"],
                date_filter_counts=stats["date_filter_counts"],
                duration_seconds=elapsed,
            )
            # Verify final counts
            counts = get_pathway_table_counts(conn)
        logger.info(f"Final table counts: {counts}")

        logger.info("")
        logger.info("=" * 60)
        logger.info(f"Refresh completed successfully in {elapsed:.1f} seconds")
        logger.info(f"Total records: {stats['total_records']}")
        logger.info(f"Refresh ID: {refresh_id}")
        logger.info("=" * 60)
        return True, f"Refresh complete: {stats['total_records']} records in {elapsed:.1f}s", stats

    except Exception as e:
        elapsed = time.monotonic() - start_time
        error_msg = f"Refresh failed: {e}"
        logger.error(error_msg, exc_info=True)
        # Best-effort: failures while recording the error are logged, not raised.
        _record_refresh_failure(db_manager, refresh_id, str(e), elapsed, dry_run)
        return False, error_msg, stats
def main(argv: Optional[list[str]] = None) -> int:
    """
    CLI entry point.

    Args:
        argv: Arguments to parse, excluding the program name. When None,
            argparse reads ``sys.argv[1:]``, as before.

    Returns:
        Process exit code: 0 on success, 1 on failure. Invalid arguments exit
        via argparse with status 2.
    """
    import logging

    parser = argparse.ArgumentParser(
        description="Refresh pathway data from Snowflake",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Basic refresh with defaults
python -m cli.refresh_pathways
# Refresh with custom minimum patients
python -m cli.refresh_pathways --minimum-patients 10
# Refresh specific providers only
python -m cli.refresh_pathways --provider-codes RGT,RM1
# Dry run to see what would be processed
python -m cli.refresh_pathways --dry-run
# Verbose output
python -m cli.refresh_pathways --verbose
"""
    )
    parser.add_argument(
        "--minimum-patients",
        type=int,
        default=5,
        help="Minimum patients to include a pathway (default: 5)"
    )
    parser.add_argument(
        "--provider-codes",
        type=str,
        default=None,
        help="Comma-separated list of provider codes to filter (default: all)"
    )
    parser.add_argument(
        "--db-path",
        type=str,
        default=None,
        help="Path to SQLite database (default: data/pathways.db)"
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Process data but don't insert into database"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )
    args = parser.parse_args(argv)

    # Parse provider codes. Blank entries from stray or trailing commas
    # (e.g. "RGT,,RM1,") are dropped rather than sent on as '' codes.
    provider_codes = None
    if args.provider_codes:
        provider_codes = [
            code.strip() for code in args.provider_codes.split(",") if code.strip()
        ]
        if not provider_codes:
            # A filter was requested but is entirely blank; refuse rather than guess.
            parser.error("--provider-codes must contain at least one non-empty provider code")

    # Parse db path
    db_path = Path(args.db_path) if args.db_path else None

    # Configure logging
    log_level = logging.DEBUG if args.verbose else logging.INFO
    setup_logging(level=log_level)

    # Run the refresh
    success, message, _stats = refresh_pathways(
        minimum_patients=args.minimum_patients,
        provider_codes=provider_codes,
        db_path=db_path,
        dry_run=args.dry_run,
    )
    if success:
        print(f"\n[OK] {message}")
        return 0
    print(f"\n[FAILED] {message}", file=sys.stderr)
    return 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())