From 092fdbba5a423de63e4c4458966b58395f952aee Mon Sep 17 00:00:00 2001
From: Andrew Charlwood
Date: Wed, 4 Feb 2026 23:30:11 +0000
Subject: [PATCH] feat: add CLI refresh command for pathway data (Task 2.1)

Add cli/refresh_pathways.py with:
- refresh_pathways() main function for full pipeline orchestration
- insert_pathway_records() for SQLite insertion
- log_refresh_start/complete/failed() for refresh tracking
- CLI with --minimum-patients, --provider-codes, --dry-run, --verbose

Uses existing pipeline functions:
- fetch_and_transform_data() from pathway_pipeline.py
- process_all_date_filters() for 6 date filter combinations
- Schema helpers from data_processing/schema.py
---
 IMPLEMENTATION_PLAN.md  |  19 +-
 cli/__init__.py         |   6 +
 cli/refresh_pathways.py | 482 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 498 insertions(+), 9 deletions(-)
 create mode 100644 cli/__init__.py
 create mode 100644 cli/refresh_pathways.py

diff --git a/IMPLEMENTATION_PLAN.md b/IMPLEMENTATION_PLAN.md
index a5bc894..d8f675e 100644
--- a/IMPLEMENTATION_PLAN.md
+++ b/IMPLEMENTATION_PLAN.md
@@ -69,18 +69,19 @@ cd pathways_app && timeout 60 python -m reflex run 2>&1 | head -30
 ## Phase 2: CLI Refresh Command
 
 ### 2.1 Create Refresh Command
-- [ ] Create `cli/refresh_pathways.py` with:
-  - DATE_FILTER_CONFIGS constant (6 combinations)
-  - `compute_date_ranges(config, max_date)` - Calculate actual dates from config
+- [x] Create `cli/refresh_pathways.py` with:
+  - Uses DATE_FILTER_CONFIGS and compute_date_ranges from pathway_pipeline.py
   - `refresh_pathways(minimum_patients, provider_codes, ...)` main function
-- [ ] Implement refresh flow:
-  1. Fetch ALL data from Snowflake (full date range)
-  2. Apply transformations (UPID, drug names, directory)
-  3. Clear existing pathway_nodes
+  - `insert_pathway_records()` for SQLite insertion
+  - `log_refresh_start/complete/failed()` for refresh tracking
+- [x] Implement refresh flow:
+  1. Fetch ALL data from Snowflake (full date range) via fetch_and_transform_data()
+  2. Apply transformations (UPID, drug names, directory) - handled by pipeline
+  3. Clear existing pathway_nodes via clear_pathway_nodes()
   4. For each of 6 date filter configs: filter → process → insert
   5. Update pathway_refresh_log
-- [ ] Add CLI argument parsing (--minimum-patients, --provider-codes, etc.)
-- [ ] Verify: `python -m cli.refresh_pathways --help`
+- [x] Add CLI argument parsing (--minimum-patients, --provider-codes, --dry-run, --verbose)
+- [x] Verify: `python -m cli.refresh_pathways --help`
 
 ### 2.2 Test Refresh Pipeline
 - [ ] Run refresh with Snowflake data
diff --git a/cli/__init__.py b/cli/__init__.py
new file mode 100644
index 0000000..6158958
--- /dev/null
+++ b/cli/__init__.py
@@ -0,0 +1,6 @@
+"""
+CLI commands for NHS High-Cost Drug Patient Pathway Analysis Tool.
+
+Available commands:
+    python -m cli.refresh_pathways - Refresh pathway data from Snowflake
+"""
diff --git a/cli/refresh_pathways.py b/cli/refresh_pathways.py
new file mode 100644
index 0000000..14c93c5
--- /dev/null
+++ b/cli/refresh_pathways.py
@@ -0,0 +1,482 @@
+"""
+CLI command for refreshing pathway data from Snowflake.
+
+This command fetches activity data from Snowflake, processes it through the
+pathway pipeline for all 6 date filter combinations, and stores the results
+in the SQLite pathway_nodes table.
+
+Usage:
+    python -m cli.refresh_pathways
+    python -m cli.refresh_pathways --minimum-patients 10
+    python -m cli.refresh_pathways --provider-codes RGT,RM1
+    python -m cli.refresh_pathways --dry-run
+
+Run `python -m cli.refresh_pathways --help` for full options.
+"""
+
+import argparse
+import json
+import sqlite3
+import sys
+import time
+import uuid
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+
+from core import PathConfig, default_paths
+from core.logging_config import get_logger, setup_logging
+from data_processing.database import DatabaseManager, DatabaseConfig
+from data_processing.schema import (
+    clear_pathway_nodes,
+    get_pathway_table_counts,
+    verify_pathway_tables_exist,
+    create_pathway_tables,
+)
+from data_processing.pathway_pipeline import (
+    DATE_FILTER_CONFIGS,
+    fetch_and_transform_data,
+    process_all_date_filters,
+)
+
+logger = get_logger(__name__)
+
+
+def get_default_filters(paths: PathConfig) -> tuple[list[str], list[str], list[str]]:
+    """
+    Load default filter values from reference files.
+
+    Returns:
+        Tuple of (trust_filter, drug_filter, directory_filter)
+    """
+    import pandas as pd
+
+    # Load default trusts
+    trust_filter = []
+    if paths.default_trusts_csv.exists():
+        try:
+            trusts_df = pd.read_csv(paths.default_trusts_csv)
+            # Assume first column contains trust names
+            trust_filter = trusts_df.iloc[:, 0].dropna().tolist()
+            logger.info(f"Loaded {len(trust_filter)} default trusts")
+        except Exception as e:
+            logger.warning(f"Could not load default trusts: {e}")
+
+    # Load default drugs (Include=1 in include.csv)
+    drug_filter = []
+    if paths.include_csv.exists():
+        try:
+            drugs_df = pd.read_csv(paths.include_csv)
+            if 'Include' in drugs_df.columns:
+                drug_filter = drugs_df[drugs_df['Include'] == 1].iloc[:, 0].dropna().tolist()
+            else:
+                # Assume first column contains drug names if no Include column
+                drug_filter = drugs_df.iloc[:, 0].dropna().tolist()
+            logger.info(f"Loaded {len(drug_filter)} default drugs")
+        except Exception as e:
+            logger.warning(f"Could not load default drugs: {e}")
+
+    # Load default directories
+    directory_filter = []
+    if paths.directory_list_csv.exists():
+        try:
+            dirs_df = pd.read_csv(paths.directory_list_csv)
+            # Assume first column contains directory names
+            directory_filter = dirs_df.iloc[:, 0].dropna().tolist()
+            logger.info(f"Loaded {len(directory_filter)} default directories")
+        except Exception as e:
+            logger.warning(f"Could not load default directories: {e}")
+
+    return trust_filter, drug_filter, directory_filter
+
+
+def insert_pathway_records(
+    conn: sqlite3.Connection,
+    records: list[dict],
+) -> int:
+    """
+    Insert pathway records into pathway_nodes table.
+
+    Uses INSERT OR REPLACE to handle updates to existing records.
+
+    Args:
+        conn: SQLite connection
+        records: List of record dicts from convert_to_records()
+
+    Returns:
+        Number of rows written, as reported by cursor.rowcount (0 for no records)
+    """
+    if not records:
+        return 0
+
+    # Column order matching pathway_nodes schema
+    columns = [
+        'date_filter_id', 'parents', 'ids', 'labels', 'level',
+        'value', 'cost', 'costpp', 'cost_pp_pa', 'colour',
+        'first_seen', 'last_seen', 'first_seen_parent', 'last_seen_parent',
+        'average_spacing', 'average_administered', 'avg_days',
+        'trust_name', 'directory', 'drug_sequence', 'data_refresh_id'
+    ]
+
+    placeholders = ', '.join(['?' for _ in columns])
+    column_names = ', '.join(columns)
+
+    insert_sql = f"""
+        INSERT OR REPLACE INTO pathway_nodes ({column_names})
+        VALUES ({placeholders})
+    """
+
+    # Convert records to tuples in column order
+    rows = []
+    for record in records:
+        row = tuple(record.get(col) for col in columns)
+        rows.append(row)
+
+    cursor = conn.executemany(insert_sql, rows)
+    return cursor.rowcount
+
+
+def log_refresh_start(
+    conn: sqlite3.Connection,
+    refresh_id: str,
+    date_from: Optional[str] = None,
+    date_to: Optional[str] = None,
+) -> None:
+    """Log the start of a refresh operation."""
+    conn.execute("""
+        INSERT INTO pathway_refresh_log
+        (refresh_id, started_at, status, snowflake_query_date_from, snowflake_query_date_to)
+        VALUES (?, ?, 'running', ?, ?)
+    """, (refresh_id, datetime.now().isoformat(), date_from, date_to))
+    conn.commit()
+
+
+def log_refresh_complete(
+    conn: sqlite3.Connection,
+    refresh_id: str,
+    record_count: int,
+    date_filter_counts: dict[str, int],
+    duration_seconds: float,
+) -> None:
+    """Log the successful completion of a refresh operation."""
+    conn.execute("""
+        UPDATE pathway_refresh_log
+        SET completed_at = ?,
+            status = 'completed',
+            record_count = ?,
+            date_filter_counts = ?,
+            processing_duration_seconds = ?
+        WHERE refresh_id = ?
+    """, (
+        datetime.now().isoformat(),
+        record_count,
+        json.dumps(date_filter_counts),
+        duration_seconds,
+        refresh_id,
+    ))
+    conn.commit()
+
+
+def log_refresh_failed(
+    conn: sqlite3.Connection,
+    refresh_id: str,
+    error_message: str,
+    duration_seconds: float,
+) -> None:
+    """Log a failed refresh operation."""
+    conn.execute("""
+        UPDATE pathway_refresh_log
+        SET completed_at = ?,
+            status = 'failed',
+            error_message = ?,
+            processing_duration_seconds = ?
+        WHERE refresh_id = ?
+    """, (
+        datetime.now().isoformat(),
+        error_message,
+        duration_seconds,
+        refresh_id,
+    ))
+    conn.commit()
+
+
+def refresh_pathways(
+    minimum_patients: int = 5,
+    provider_codes: Optional[list[str]] = None,
+    trust_filter: Optional[list[str]] = None,
+    drug_filter: Optional[list[str]] = None,
+    directory_filter: Optional[list[str]] = None,
+    db_path: Optional[Path] = None,
+    paths: Optional[PathConfig] = None,
+    dry_run: bool = False,
+) -> tuple[bool, str, dict]:
+    """
+    Main refresh function that orchestrates the full pipeline.
+
+    Args:
+        minimum_patients: Minimum patients to include a pathway
+        provider_codes: List of provider codes to filter Snowflake query
+        trust_filter: List of trust names to include in pathways
+        drug_filter: List of drug names to include in pathways
+        directory_filter: List of directories to include in pathways
+        db_path: Path to SQLite database (uses default if None)
+        paths: PathConfig for file paths
+        dry_run: If True, don't actually insert records
+
+    Returns:
+        Tuple of (success: bool, message: str, stats: dict)
+    """
+    if paths is None:
+        paths = default_paths
+
+    # Set up database connection
+    if db_path:
+        db_config = DatabaseConfig(db_path=db_path)
+    else:
+        db_config = DatabaseConfig(data_dir=paths.data_dir)
+
+    db_manager = DatabaseManager(db_config)
+
+    # Load default filters if not provided
+    default_trusts, default_drugs, default_dirs = get_default_filters(paths)
+
+    if trust_filter is None:
+        trust_filter = default_trusts
+    if drug_filter is None:
+        drug_filter = default_drugs
+    if directory_filter is None:
+        directory_filter = default_dirs
+
+    # Ensure we have some filters
+    if not drug_filter:
+        return False, "No drugs specified and could not load defaults", {}
+
+    logger.info("=" * 60)
+    logger.info("Pathway Data Refresh Starting")
+    logger.info("=" * 60)
+    logger.info(f"Minimum patients: {minimum_patients}")
+    logger.info(f"Trust filter: {len(trust_filter)} trusts")
+    logger.info(f"Drug filter: {len(drug_filter)} drugs")
+    logger.info(f"Directory filter: {len(directory_filter)} directories")
+    logger.info(f"Provider codes: {provider_codes or 'All'}")
+    logger.info(f"Database: {db_manager.db_path}")
+    logger.info(f"Dry run: {dry_run}")
+    logger.info("=" * 60)
+
+    start_time = time.time()
+    refresh_id = str(uuid.uuid4())[:8]
+    stats = {
+        "refresh_id": refresh_id,
+        "date_filter_counts": {},
+        "total_records": 0,
+        "snowflake_rows": 0,
+    }
+
+    try:
+        # Verify database and tables
+        with db_manager.get_connection() as conn:
+            missing_tables = verify_pathway_tables_exist(conn)
+            if missing_tables:
+                logger.info(f"Creating missing tables: {missing_tables}")
+                create_pathway_tables(conn)
+
+            # Log refresh start
+            if not dry_run:
+                log_refresh_start(conn, refresh_id)
+
+        # Step 1: Fetch data from Snowflake
+        logger.info("")
+        logger.info("Step 1/4: Fetching data from Snowflake...")
+        df = fetch_and_transform_data(
+            provider_codes=provider_codes,
+            paths=paths,
+        )
+
+        if df.empty:
+            msg = "No data returned from Snowflake"
+            logger.error(msg)
+            if not dry_run:
+                with db_manager.get_connection() as conn:
+                    log_refresh_failed(conn, refresh_id, msg, time.time() - start_time)
+            return False, msg, stats
+
+        stats["snowflake_rows"] = len(df)
+        logger.info(f"Fetched {len(df)} records from Snowflake")
+
+        # Step 2: Process all date filters
+        logger.info("")
+        logger.info("Step 2/4: Processing pathway data for 6 date filter combinations...")
+
+        results = process_all_date_filters(
+            df=df,
+            trust_filter=trust_filter,
+            drug_filter=drug_filter,
+            directory_filter=directory_filter,
+            minimum_patients=minimum_patients,
+            refresh_id=refresh_id,
+            paths=paths,
+        )
+
+        # Count records per filter
+        for filter_id, records in results.items():
+            stats["date_filter_counts"][filter_id] = len(records)
+            stats["total_records"] += len(records)
+
+        logger.info(f"Processed {stats['total_records']} total pathway nodes")
+        for filter_id, count in stats["date_filter_counts"].items():
+            logger.info(f" {filter_id}: {count} nodes")
+
+        if dry_run:
+            logger.info("")
+            logger.info("DRY RUN - Skipping database insertion")
+            return True, f"Dry run complete: {stats['total_records']} records would be inserted", stats
+
+        # Step 3: Clear existing data and insert new records
+        logger.info("")
+        logger.info("Step 3/4: Clearing existing pathway data and inserting new records...")
+
+        with db_manager.get_transaction() as conn:
+            # Clear all existing pathway nodes
+            deleted = clear_pathway_nodes(conn)
+            logger.info(f"Cleared {deleted} existing pathway nodes")
+
+            # Insert new records for each date filter
+            total_inserted = 0
+            for filter_id, records in results.items():
+                if records:
+                    inserted = insert_pathway_records(conn, records)
+                    total_inserted += inserted
+                    logger.info(f" Inserted {inserted} records for {filter_id}")
+
+        # Step 4: Log completion
+        logger.info("")
+        logger.info("Step 4/4: Logging refresh completion...")
+
+        elapsed = time.time() - start_time
+
+        with db_manager.get_connection() as conn:
+            log_refresh_complete(
+                conn=conn,
+                refresh_id=refresh_id,
+                record_count=stats["total_records"],
+                date_filter_counts=stats["date_filter_counts"],
+                duration_seconds=elapsed,
+            )
+
+            # Verify final counts
+            counts = get_pathway_table_counts(conn)
+            logger.info(f"Final table counts: {counts}")
+
+        logger.info("")
+        logger.info("=" * 60)
+        logger.info(f"Refresh completed successfully in {elapsed:.1f} seconds")
+        logger.info(f"Total records: {stats['total_records']}")
+        logger.info(f"Refresh ID: {refresh_id}")
+        logger.info("=" * 60)
+
+        return True, f"Refresh complete: {stats['total_records']} records in {elapsed:.1f}s", stats
+
+    except Exception as e:
+        elapsed = time.time() - start_time
+        error_msg = f"Refresh failed: {e}"
+        logger.error(error_msg, exc_info=True)
+
+        try:
+            with db_manager.get_connection() as conn:
+                log_refresh_failed(conn, refresh_id, str(e), elapsed)
+        except Exception:
+            logger.warning("Could not record refresh failure in pathway_refresh_log", exc_info=True)
+
+        return False, error_msg, stats
+
+
+def main() -> int:
+    """CLI entry point."""
+    parser = argparse.ArgumentParser(
+        description="Refresh pathway data from Snowflake",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  # Basic refresh with defaults
+  python -m cli.refresh_pathways
+
+  # Refresh with custom minimum patients
+  python -m cli.refresh_pathways --minimum-patients 10
+
+  # Refresh specific providers only
+  python -m cli.refresh_pathways --provider-codes RGT,RM1
+
+  # Dry run to see what would be processed
+  python -m cli.refresh_pathways --dry-run
+
+  # Verbose output
+  python -m cli.refresh_pathways --verbose
+        """
+    )
+
+    parser.add_argument(
+        "--minimum-patients",
+        type=int,
+        default=5,
+        help="Minimum patients to include a pathway (default: 5)"
+    )
+
+    parser.add_argument(
+        "--provider-codes",
+        type=str,
+        default=None,
+        help="Comma-separated list of provider codes to filter (default: all)"
+    )
+
+    parser.add_argument(
+        "--db-path",
+        type=str,
+        default=None,
+        help="Path to SQLite database (default: data/pathways.db)"
+    )
+
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Process data but don't insert into database"
+    )
+
+    parser.add_argument(
+        "--verbose", "-v",
+        action="store_true",
+        help="Enable verbose logging"
+    )
+
+    args = parser.parse_args()
+
+    # Configure logging
+    import logging
+    log_level = logging.DEBUG if args.verbose else logging.INFO
+    setup_logging(level=log_level)
+
+    # Parse provider codes (blank entries such as a trailing comma are dropped)
+    provider_codes = None
+    if args.provider_codes:
+        provider_codes = [code.strip() for code in args.provider_codes.split(",") if code.strip()] or None
+
+    # Parse db path
+    db_path = Path(args.db_path) if args.db_path else None
+
+    # Run the refresh
+    success, message, stats = refresh_pathways(
+        minimum_patients=args.minimum_patients,
+        provider_codes=provider_codes,
+        db_path=db_path,
+        dry_run=args.dry_run,
+    )
+
+    if success:
+        print(f"\n✓ {message}")
+        return 0
+    else:
+        print(f"\n✗ {message}", file=sys.stderr)
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())