Initial commit before Ralph loop
This commit is contained in:
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Data loader abstractions for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides a unified interface for loading patient intervention data from:
|
||||
- CSV/Parquet files (current behavior)
|
||||
- SQLite database (new, faster approach)
|
||||
- Snowflake (future, direct from warehouse)
|
||||
|
||||
The DataLoader ABC defines the contract for all loader implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadResult:
|
||||
"""Result of a data load operation.
|
||||
|
||||
Attributes:
|
||||
df: The loaded DataFrame with processed patient intervention data
|
||||
source: Description of the data source (e.g., "csv:/path/to/file.csv", "sqlite:fact_interventions")
|
||||
row_count: Number of rows loaded
|
||||
columns: List of column names in the DataFrame
|
||||
load_time_seconds: Time taken to load the data
|
||||
"""
|
||||
df: pd.DataFrame
|
||||
source: str
|
||||
row_count: int
|
||||
columns: list[str] = field(default_factory=list)
|
||||
load_time_seconds: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.columns:
|
||||
self.columns = list(self.df.columns)
|
||||
|
||||
|
||||
# Expected columns in a processed DataFrame
|
||||
# These are the columns that generate_graph() expects to receive
|
||||
REQUIRED_COLUMNS = [
|
||||
"UPID", # Unique Patient ID (Provider Code prefix + PersonKey)
|
||||
"Drug Name", # Standardized drug name
|
||||
"Intervention Date", # Date of intervention
|
||||
"Price Actual", # Cost of intervention
|
||||
"OrganisationName", # NHS Trust name
|
||||
"Directory", # Medical specialty/directory
|
||||
"Provider Code", # NHS provider code
|
||||
"PersonKey", # Patient identifier within provider
|
||||
]
|
||||
|
||||
# Additional columns that are useful but not strictly required
|
||||
OPTIONAL_COLUMNS = [
|
||||
"UPIDTreatment", # UPID + Drug Name combo (created by generate_graph)
|
||||
"Treatment Function Code", # NHS treatment function code
|
||||
"Additional Detail 1",
|
||||
"Additional Detail 2",
|
||||
"Additional Detail 3",
|
||||
"Additional Detail 4",
|
||||
"Additional Detail 5",
|
||||
]
|
||||
|
||||
|
||||
class DataLoader(ABC):
|
||||
"""Abstract base class for data loaders.
|
||||
|
||||
All data loaders must implement the load() method which returns
|
||||
a DataFrame ready for use by generate_graph().
|
||||
|
||||
The returned DataFrame must contain REQUIRED_COLUMNS at minimum.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> LoadResult:
|
||||
"""Load and process patient intervention data.
|
||||
|
||||
Returns:
|
||||
LoadResult containing the processed DataFrame and metadata.
|
||||
The DataFrame must contain all REQUIRED_COLUMNS.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the data source doesn't exist
|
||||
ValueError: If the data is malformed or missing required columns
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_source(self) -> tuple[bool, str]:
|
||||
"""Check if the data source is valid and accessible.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, message).
|
||||
If is_valid is False, message explains the issue.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def source_description(self) -> str:
|
||||
"""Human-readable description of the data source."""
|
||||
pass
|
||||
|
||||
def validate_dataframe(self, df: pd.DataFrame) -> tuple[bool, list[str]]:
|
||||
"""Validate that a DataFrame has all required columns.
|
||||
|
||||
Args:
|
||||
df: DataFrame to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, missing_columns).
|
||||
If is_valid is False, missing_columns lists what's missing.
|
||||
"""
|
||||
missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
|
||||
return len(missing) == 0, missing
|
||||
|
||||
|
||||
class FileDataLoader(DataLoader):
|
||||
"""Loads data from CSV or Parquet files.
|
||||
|
||||
This replicates the current behavior of dashboard_gui.main():
|
||||
1. Read CSV or Parquet file
|
||||
2. Apply patient_id() transformation
|
||||
3. Convert dates
|
||||
4. Apply drug_names() standardization
|
||||
5. Clean organization names
|
||||
6. Apply department_identification()
|
||||
|
||||
Args:
|
||||
file_path: Path to the CSV or Parquet file
|
||||
paths: PathConfig for reference data file locations (uses default_paths if None)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: Path | str,
|
||||
paths: Optional[PathConfig] = None,
|
||||
):
|
||||
self.file_path = Path(file_path)
|
||||
self.paths = paths or default_paths
|
||||
|
||||
def validate_source(self) -> tuple[bool, str]:
|
||||
"""Check if the file exists and has a supported extension."""
|
||||
if not self.file_path.exists():
|
||||
return False, f"File not found: {self.file_path}"
|
||||
|
||||
ext = self.file_path.suffix.lower()
|
||||
if ext not in ('.csv', '.parquet'):
|
||||
return False, f"Unsupported file type: {ext}. Must be .csv or .parquet"
|
||||
|
||||
return True, "OK"
|
||||
|
||||
@property
|
||||
def source_description(self) -> str:
|
||||
return f"file:{self.file_path}"
|
||||
|
||||
def load(self) -> LoadResult:
|
||||
"""Load and process data from CSV or Parquet file.
|
||||
|
||||
Applies the same transformation pipeline as the original
|
||||
dashboard_gui.main() function.
|
||||
"""
|
||||
import time
|
||||
from tools import data
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Validate source before loading
|
||||
is_valid, msg = self.validate_source()
|
||||
if not is_valid:
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
# Read file based on extension
|
||||
ext = self.file_path.suffix.lower()
|
||||
logger.info(f"Reading {ext} file: {self.file_path}")
|
||||
|
||||
if ext == '.csv':
|
||||
df_raw = pd.read_csv(self.file_path, low_memory=False)
|
||||
else: # .parquet
|
||||
df_raw = pd.read_parquet(self.file_path)
|
||||
|
||||
logger.info(f"File read successfully. {len(df_raw)} rows.")
|
||||
|
||||
# Apply transformations (same as dashboard_gui.main())
|
||||
df = data.patient_id(df_raw)
|
||||
logger.info("Patient ID processing complete.")
|
||||
|
||||
df['Intervention Date'] = pd.to_datetime(df['Intervention Date'], format="%Y-%m-%d")
|
||||
logger.info("Date conversion complete.")
|
||||
|
||||
# Preserve original drug name before standardization (for SQLite storage)
|
||||
df['Drug Name Raw'] = df['Drug Name'].copy()
|
||||
|
||||
df = data.drug_names(df, self.paths)
|
||||
logger.info("Drug name processing complete.")
|
||||
|
||||
df['OrganisationName'] = df['OrganisationName'].str.replace(',', '')
|
||||
logger.info("Organisation name cleaning complete.")
|
||||
|
||||
df = data.department_identification(df, self.paths)
|
||||
logger.info("Department identification complete.")
|
||||
|
||||
# Validate result
|
||||
is_valid, missing = self.validate_dataframe(df)
|
||||
if not is_valid:
|
||||
raise ValueError(f"Processed DataFrame missing required columns: {missing}")
|
||||
|
||||
load_time = time.time() - start_time
|
||||
logger.info(f"Data loading complete. {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
return LoadResult(
|
||||
df=df,
|
||||
source=self.source_description,
|
||||
row_count=len(df),
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
|
||||
|
||||
class SQLiteDataLoader(DataLoader):
|
||||
"""Loads data from SQLite fact_interventions table.
|
||||
|
||||
This provides faster loading by reading pre-processed data from SQLite
|
||||
instead of re-processing CSV files each time.
|
||||
|
||||
The SQLite database must have been populated by the migration scripts.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database (uses default if None)
|
||||
date_range: Optional tuple of (start_date, end_date) to filter data
|
||||
trusts: Optional list of trust names to filter
|
||||
drugs: Optional list of drug names to filter
|
||||
directories: Optional list of directories to filter
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: Optional[Path | str] = None,
|
||||
date_range: Optional[tuple[date, date]] = None,
|
||||
trusts: Optional[list[str]] = None,
|
||||
drugs: Optional[list[str]] = None,
|
||||
directories: Optional[list[str]] = None,
|
||||
):
|
||||
from data_processing.database import default_db_config
|
||||
|
||||
self.db_path = Path(db_path) if db_path else Path(default_db_config.db_path)
|
||||
self.date_range = date_range
|
||||
self.trusts = trusts
|
||||
self.drugs = drugs
|
||||
self.directories = directories
|
||||
|
||||
def validate_source(self) -> tuple[bool, str]:
|
||||
"""Check if the database exists and has the fact_interventions table."""
|
||||
if not self.db_path.exists():
|
||||
return False, f"Database not found: {self.db_path}"
|
||||
|
||||
# Check if fact_interventions table exists
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
|
||||
config = DatabaseConfig(db_path=self.db_path)
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
if not manager.table_exists("fact_interventions"):
|
||||
return False, "fact_interventions table not found in database"
|
||||
|
||||
count = manager.get_table_count("fact_interventions")
|
||||
if count == 0:
|
||||
return False, "fact_interventions table is empty"
|
||||
|
||||
return True, f"OK ({count:,} rows available)"
|
||||
|
||||
@property
|
||||
def source_description(self) -> str:
|
||||
return f"sqlite:{self.db_path}"
|
||||
|
||||
def load(self) -> LoadResult:
|
||||
"""Load data from SQLite fact_interventions table.
|
||||
|
||||
Maps SQLite column names to the expected DataFrame column names.
|
||||
Applies optional filters for date range, trusts, drugs, directories.
|
||||
"""
|
||||
import time
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Validate source
|
||||
is_valid, msg = self.validate_source()
|
||||
if not is_valid:
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
logger.info(f"Loading data from SQLite: {self.db_path}")
|
||||
|
||||
# Build query with optional filters
|
||||
query = """
|
||||
SELECT
|
||||
upid AS "UPID",
|
||||
provider_code AS "Provider Code",
|
||||
person_key AS "PersonKey",
|
||||
drug_name_std AS "Drug Name",
|
||||
intervention_date AS "Intervention Date",
|
||||
price_actual AS "Price Actual",
|
||||
org_name AS "OrganisationName",
|
||||
directory AS "Directory",
|
||||
treatment_function_code AS "Treatment Function Code",
|
||||
additional_detail_1 AS "Additional Detail 1",
|
||||
additional_detail_2 AS "Additional Detail 2",
|
||||
additional_detail_3 AS "Additional Detail 3",
|
||||
additional_detail_4 AS "Additional Detail 4",
|
||||
additional_detail_5 AS "Additional Detail 5"
|
||||
FROM fact_interventions
|
||||
WHERE 1=1
|
||||
"""
|
||||
params = []
|
||||
|
||||
if self.date_range:
|
||||
start, end = self.date_range
|
||||
query += " AND intervention_date >= ? AND intervention_date < ?"
|
||||
params.extend([str(start), str(end)])
|
||||
|
||||
if self.trusts:
|
||||
placeholders = ','.join('?' * len(self.trusts))
|
||||
query += f" AND org_name IN ({placeholders})"
|
||||
params.extend(self.trusts)
|
||||
|
||||
if self.drugs:
|
||||
placeholders = ','.join('?' * len(self.drugs))
|
||||
query += f" AND drug_name_std IN ({placeholders})"
|
||||
params.extend(self.drugs)
|
||||
|
||||
if self.directories:
|
||||
placeholders = ','.join('?' * len(self.directories))
|
||||
query += f" AND directory IN ({placeholders})"
|
||||
params.extend(self.directories)
|
||||
|
||||
# Execute query
|
||||
config = DatabaseConfig(db_path=self.db_path)
|
||||
manager = DatabaseManager(config)
|
||||
|
||||
with manager.get_connection() as conn:
|
||||
df = pd.read_sql_query(query, conn, params=params)
|
||||
|
||||
# Convert intervention_date to datetime
|
||||
df['Intervention Date'] = pd.to_datetime(df['Intervention Date'])
|
||||
|
||||
logger.info(f"Loaded {len(df)} rows from SQLite")
|
||||
|
||||
# Validate result
|
||||
is_valid, missing = self.validate_dataframe(df)
|
||||
if not is_valid:
|
||||
raise ValueError(f"SQLite data missing required columns: {missing}")
|
||||
|
||||
load_time = time.time() - start_time
|
||||
logger.info(f"SQLite data loading complete. {len(df)} rows in {load_time:.2f}s")
|
||||
|
||||
return LoadResult(
|
||||
df=df,
|
||||
source=self.source_description,
|
||||
row_count=len(df),
|
||||
load_time_seconds=load_time,
|
||||
)
|
||||
|
||||
|
||||
def get_loader(
|
||||
source: str | Path,
|
||||
paths: Optional[PathConfig] = None,
|
||||
**kwargs
|
||||
) -> DataLoader:
|
||||
"""Factory function to create the appropriate DataLoader.
|
||||
|
||||
Args:
|
||||
source: Either a file path (CSV/Parquet) or "sqlite" for database
|
||||
paths: PathConfig for reference data (used by FileDataLoader)
|
||||
**kwargs: Additional arguments passed to the loader constructor
|
||||
|
||||
Returns:
|
||||
Appropriate DataLoader instance
|
||||
|
||||
Examples:
|
||||
>>> loader = get_loader("data/activity.csv")
|
||||
>>> loader = get_loader("data/activity.parquet")
|
||||
>>> loader = get_loader("sqlite")
|
||||
>>> loader = get_loader("sqlite", date_range=(date(2024, 1, 1), date(2024, 12, 31)))
|
||||
"""
|
||||
source_str = str(source).lower()
|
||||
|
||||
if source_str == "sqlite":
|
||||
return SQLiteDataLoader(**kwargs)
|
||||
|
||||
# Assume it's a file path
|
||||
path = Path(source)
|
||||
return FileDataLoader(file_path=path, paths=paths)
|
||||
Reference in New Issue
Block a user