Initial commit before Ralph loop
This commit is contained in:
@@ -0,0 +1,581 @@
|
||||
"""
|
||||
Diagnosis lookup module for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides functions to validate patient indications by checking GP diagnosis records
|
||||
against SNOMED cluster codes. Uses the drug-to-cluster mapping from
|
||||
drug_indication_clusters.csv and queries Snowflake for SNOMED codes and GP records.
|
||||
|
||||
Key workflow:
|
||||
1. Get drug's valid indication clusters from local mapping
|
||||
2. Get all SNOMED codes for those clusters from Snowflake
|
||||
3. Check if patient has any of those SNOMED codes in GP records
|
||||
4. Report indication validation status
|
||||
|
||||
IMPORTANT: HCD activity data indication codes are UNRELIABLE. This module uses
|
||||
GP/Primary Care data (PrimaryCareClinicalCoding) as the authoritative source.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Callable, Any, cast
|
||||
import csv
|
||||
|
||||
from core.logging_config import get_logger
|
||||
from data_processing.database import DatabaseManager, default_db_manager
|
||||
from data_processing.snowflake_connector import (
|
||||
SnowflakeConnector,
|
||||
get_connector,
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
SNOWFLAKE_AVAILABLE,
|
||||
)
|
||||
from data_processing.cache import get_cache, is_cache_enabled
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClusterSnomedCodes:
|
||||
"""SNOMED codes for a clinical coding cluster."""
|
||||
cluster_id: str
|
||||
cluster_description: str
|
||||
snomed_codes: list[str] = field(default_factory=list)
|
||||
snomed_descriptions: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def code_count(self) -> int:
|
||||
return len(self.snomed_codes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndicationValidationResult:
|
||||
"""Result of validating a patient's indication for a drug."""
|
||||
patient_pseudonym: str
|
||||
drug_name: str
|
||||
has_valid_indication: bool
|
||||
matched_cluster_id: Optional[str] = None
|
||||
matched_snomed_code: Optional[str] = None
|
||||
matched_snomed_description: Optional[str] = None
|
||||
checked_clusters: list[str] = field(default_factory=list)
|
||||
total_codes_checked: int = 0
|
||||
source: str = "GP_SNOMED" # GP_SNOMED | NONE
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DrugIndicationMatchRate:
|
||||
"""Match rate statistics for a drug's indication validation."""
|
||||
drug_name: str
|
||||
total_patients: int
|
||||
patients_with_indication: int
|
||||
patients_without_indication: int
|
||||
match_rate: float # 0.0 to 1.0
|
||||
clusters_checked: list[str] = field(default_factory=list)
|
||||
sample_unmatched: list[str] = field(default_factory=list) # Sample patient IDs
|
||||
|
||||
|
||||
def get_drug_clusters(
|
||||
drug_name: str,
|
||||
db_manager: Optional[DatabaseManager] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get all SNOMED cluster mappings for a drug from local SQLite.
|
||||
|
||||
Args:
|
||||
drug_name: Drug name to look up (case-insensitive)
|
||||
db_manager: Optional DatabaseManager (defaults to default_db_manager)
|
||||
|
||||
Returns:
|
||||
List of dicts with keys: drug_name, indication, cluster_id,
|
||||
cluster_description, nice_ta_reference
|
||||
"""
|
||||
if db_manager is None:
|
||||
db_manager = default_db_manager
|
||||
|
||||
query = """
|
||||
SELECT drug_name, indication, cluster_id, cluster_description, nice_ta_reference
|
||||
FROM ref_drug_indication_clusters
|
||||
WHERE UPPER(drug_name) = UPPER(?)
|
||||
ORDER BY indication, cluster_id
|
||||
"""
|
||||
|
||||
try:
|
||||
with db_manager.get_connection() as conn:
|
||||
cursor = conn.execute(query, (drug_name,))
|
||||
rows = cursor.fetchall()
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
results.append({
|
||||
"drug_name": row["drug_name"],
|
||||
"indication": row["indication"],
|
||||
"cluster_id": row["cluster_id"],
|
||||
"cluster_description": row["cluster_description"],
|
||||
"nice_ta_reference": row["nice_ta_reference"],
|
||||
})
|
||||
|
||||
logger.debug(f"Found {len(results)} cluster mappings for drug '{drug_name}'")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting clusters for drug '{drug_name}': {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_drug_cluster_ids(
|
||||
drug_name: str,
|
||||
db_manager: Optional[DatabaseManager] = None
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get unique cluster IDs for a drug.
|
||||
|
||||
Args:
|
||||
drug_name: Drug name to look up
|
||||
db_manager: Optional DatabaseManager
|
||||
|
||||
Returns:
|
||||
List of unique cluster IDs
|
||||
"""
|
||||
clusters = get_drug_clusters(drug_name, db_manager)
|
||||
return list(set(c["cluster_id"] for c in clusters))
|
||||
|
||||
|
||||
def get_cluster_snomed_codes(
|
||||
cluster_id: str,
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
use_cache: bool = True,
|
||||
) -> ClusterSnomedCodes:
|
||||
"""
|
||||
Get all SNOMED codes for a cluster from Snowflake.
|
||||
|
||||
Queries the ClinicalCodingClusterSnomedCodes table to get all SNOMED codes
|
||||
that belong to the specified cluster.
|
||||
|
||||
Args:
|
||||
cluster_id: Cluster ID to look up (e.g., 'RARTH_COD', 'PSORIASIS_COD')
|
||||
connector: Optional SnowflakeConnector (defaults to singleton)
|
||||
use_cache: Whether to use cached results (default True)
|
||||
|
||||
Returns:
|
||||
ClusterSnomedCodes with list of SNOMED codes and descriptions
|
||||
"""
|
||||
if not SNOWFLAKE_AVAILABLE:
|
||||
logger.warning("Snowflake connector not available")
|
||||
return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
|
||||
|
||||
if not is_snowflake_configured():
|
||||
logger.warning("Snowflake not configured - cannot get cluster codes")
|
||||
return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
|
||||
|
||||
# Check cache first
|
||||
cache_key = f"cluster_snomed_{cluster_id}"
|
||||
if use_cache and is_cache_enabled():
|
||||
cache = get_cache()
|
||||
cached = cache.get(cache_key)
|
||||
if cached is not None and len(cached) > 0:
|
||||
logger.debug(f"Using cached SNOMED codes for cluster '{cluster_id}'")
|
||||
cached_dict = cached[0] # First element is our data dict
|
||||
return ClusterSnomedCodes(
|
||||
cluster_id=cluster_id,
|
||||
cluster_description=str(cached_dict.get("description", "")),
|
||||
snomed_codes=list(cached_dict.get("codes", [])),
|
||||
snomed_descriptions=dict(cached_dict.get("descriptions", {})),
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
connector = get_connector()
|
||||
|
||||
query = '''
|
||||
SELECT DISTINCT
|
||||
"Cluster_ID",
|
||||
"Cluster_Description",
|
||||
"SNOMEDCode",
|
||||
"SNOMEDDescription"
|
||||
FROM DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes"
|
||||
WHERE "Cluster_ID" = %s
|
||||
ORDER BY "SNOMEDCode"
|
||||
'''
|
||||
|
||||
try:
|
||||
results = connector.execute_dict(query, (cluster_id,))
|
||||
|
||||
if not results:
|
||||
logger.warning(f"No SNOMED codes found for cluster '{cluster_id}'")
|
||||
return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
|
||||
|
||||
codes = []
|
||||
descriptions = {}
|
||||
description = results[0].get("Cluster_Description", "") if results else ""
|
||||
|
||||
for row in results:
|
||||
code = row.get("SNOMEDCode")
|
||||
if code:
|
||||
codes.append(code)
|
||||
descriptions[code] = row.get("SNOMEDDescription", "")
|
||||
|
||||
logger.info(f"Found {len(codes)} SNOMED codes for cluster '{cluster_id}'")
|
||||
|
||||
# Cache the results (using query-based cache with fake params)
|
||||
if use_cache and is_cache_enabled():
|
||||
cache = get_cache()
|
||||
cache_data = [{
|
||||
"description": description,
|
||||
"codes": codes,
|
||||
"descriptions": descriptions,
|
||||
}]
|
||||
cache.set(cache_key, None, cache_data) # type: ignore[arg-type]
|
||||
|
||||
return ClusterSnomedCodes(
|
||||
cluster_id=cluster_id,
|
||||
cluster_description=description,
|
||||
snomed_codes=codes,
|
||||
snomed_descriptions=descriptions,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting SNOMED codes for cluster '{cluster_id}': {e}")
|
||||
return ClusterSnomedCodes(cluster_id=cluster_id, cluster_description="")
|
||||
|
||||
|
||||
def patient_has_indication(
|
||||
patient_pseudonym: str,
|
||||
cluster_ids: list[str],
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
before_date: Optional[date] = None,
|
||||
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
|
||||
"""
|
||||
Check if a patient has any SNOMED codes from the specified clusters in GP records.
|
||||
|
||||
Args:
|
||||
patient_pseudonym: Patient's pseudonymised NHS number
|
||||
cluster_ids: List of cluster IDs to check against
|
||||
connector: Optional SnowflakeConnector
|
||||
before_date: Optional date - only check diagnoses before this date
|
||||
|
||||
Returns:
|
||||
Tuple of (has_indication, matched_cluster_id, matched_snomed_code, matched_description)
|
||||
"""
|
||||
if not SNOWFLAKE_AVAILABLE or not is_snowflake_configured():
|
||||
return False, None, None, None
|
||||
|
||||
if not cluster_ids:
|
||||
return False, None, None, None
|
||||
|
||||
if connector is None:
|
||||
connector = get_connector()
|
||||
|
||||
# Build placeholders for cluster IDs
|
||||
placeholders = ", ".join(["%s"] * len(cluster_ids))
|
||||
|
||||
# Query to check if patient has any matching SNOMED code
|
||||
query = f'''
|
||||
SELECT
|
||||
pc."SNOMEDCode",
|
||||
cc."Cluster_ID",
|
||||
cc."SNOMEDDescription"
|
||||
FROM DATA_HUB.PHM."PrimaryCareClinicalCoding" pc
|
||||
INNER JOIN DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes" cc
|
||||
ON pc."SNOMEDCode" = cc."SNOMEDCode"
|
||||
WHERE pc."PatientPseudonym" = %s
|
||||
AND cc."Cluster_ID" IN ({placeholders})
|
||||
'''
|
||||
|
||||
params = [patient_pseudonym] + cluster_ids
|
||||
|
||||
if before_date:
|
||||
query += ' AND pc."EventDateTime" < %s'
|
||||
params.append(before_date.isoformat())
|
||||
|
||||
query += ' LIMIT 1'
|
||||
|
||||
try:
|
||||
results = connector.execute_dict(query, tuple(params))
|
||||
|
||||
if results:
|
||||
row = results[0]
|
||||
return (
|
||||
True,
|
||||
row.get("Cluster_ID"),
|
||||
row.get("SNOMEDCode"),
|
||||
row.get("SNOMEDDescription"),
|
||||
)
|
||||
|
||||
return False, None, None, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking indication for patient '{patient_pseudonym}': {e}")
|
||||
return False, None, None, None
|
||||
|
||||
|
||||
def validate_indication(
|
||||
patient_pseudonym: str,
|
||||
drug_name: str,
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
before_date: Optional[date] = None,
|
||||
) -> IndicationValidationResult:
|
||||
"""
|
||||
Validate that a patient has an appropriate indication for a drug.
|
||||
|
||||
Full validation workflow:
|
||||
1. Get drug's valid indication clusters from local mapping
|
||||
2. Check if patient has any matching SNOMED codes in GP records
|
||||
3. Return detailed validation result
|
||||
|
||||
Args:
|
||||
patient_pseudonym: Patient's pseudonymised NHS number
|
||||
drug_name: Drug name to validate indication for
|
||||
connector: Optional SnowflakeConnector
|
||||
db_manager: Optional DatabaseManager
|
||||
before_date: Optional date - only check diagnoses before this date
|
||||
|
||||
Returns:
|
||||
IndicationValidationResult with validation details
|
||||
"""
|
||||
result = IndicationValidationResult(
|
||||
patient_pseudonym=patient_pseudonym,
|
||||
drug_name=drug_name,
|
||||
has_valid_indication=False,
|
||||
)
|
||||
|
||||
# Step 1: Get drug's cluster mappings
|
||||
cluster_ids = get_drug_cluster_ids(drug_name, db_manager)
|
||||
|
||||
if not cluster_ids:
|
||||
result.error_message = f"No cluster mappings found for drug '{drug_name}'"
|
||||
result.source = "NONE"
|
||||
return result
|
||||
|
||||
result.checked_clusters = cluster_ids
|
||||
|
||||
# Step 2: Check Snowflake availability
|
||||
if not SNOWFLAKE_AVAILABLE:
|
||||
result.error_message = "Snowflake connector not installed"
|
||||
result.source = "NONE"
|
||||
return result
|
||||
|
||||
if not is_snowflake_configured():
|
||||
result.error_message = "Snowflake not configured"
|
||||
result.source = "NONE"
|
||||
return result
|
||||
|
||||
# Step 3: Check patient GP records
|
||||
has_indication, matched_cluster, matched_code, matched_desc = patient_has_indication(
|
||||
patient_pseudonym=patient_pseudonym,
|
||||
cluster_ids=cluster_ids,
|
||||
connector=connector,
|
||||
before_date=before_date,
|
||||
)
|
||||
|
||||
result.has_valid_indication = has_indication
|
||||
result.matched_cluster_id = matched_cluster
|
||||
result.matched_snomed_code = matched_code
|
||||
result.matched_snomed_description = matched_desc
|
||||
result.source = "GP_SNOMED" if has_indication else "NONE"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_indication_match_rate(
|
||||
drug_name: str,
|
||||
patient_pseudonyms: list[str],
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
sample_unmatched_count: int = 10,
|
||||
) -> DrugIndicationMatchRate:
|
||||
"""
|
||||
Calculate indication match rate for a drug across a list of patients.
|
||||
|
||||
Args:
|
||||
drug_name: Drug name to check
|
||||
patient_pseudonyms: List of patient pseudonymised NHS numbers
|
||||
connector: Optional SnowflakeConnector
|
||||
db_manager: Optional DatabaseManager
|
||||
sample_unmatched_count: Number of unmatched patient IDs to include in sample
|
||||
|
||||
Returns:
|
||||
DrugIndicationMatchRate with match statistics
|
||||
"""
|
||||
if connector is None and SNOWFLAKE_AVAILABLE and is_snowflake_configured():
|
||||
connector = get_connector()
|
||||
|
||||
cluster_ids = get_drug_cluster_ids(drug_name, db_manager)
|
||||
|
||||
total = len(patient_pseudonyms)
|
||||
matched = 0
|
||||
unmatched = 0
|
||||
sample_unmatched: list[str] = []
|
||||
|
||||
if not cluster_ids:
|
||||
logger.warning(f"No cluster mappings for drug '{drug_name}' - all patients will be unmatched")
|
||||
return DrugIndicationMatchRate(
|
||||
drug_name=drug_name,
|
||||
total_patients=total,
|
||||
patients_with_indication=0,
|
||||
patients_without_indication=total,
|
||||
match_rate=0.0,
|
||||
clusters_checked=[],
|
||||
sample_unmatched=patient_pseudonyms[:sample_unmatched_count],
|
||||
)
|
||||
|
||||
for i, pseudonym in enumerate(patient_pseudonyms):
|
||||
if i > 0 and i % 100 == 0:
|
||||
logger.info(f"Validating indications: {i}/{total} ({100*i/total:.1f}%)")
|
||||
|
||||
has_indication, _, _, _ = patient_has_indication(
|
||||
patient_pseudonym=pseudonym,
|
||||
cluster_ids=cluster_ids,
|
||||
connector=connector,
|
||||
)
|
||||
|
||||
if has_indication:
|
||||
matched += 1
|
||||
else:
|
||||
unmatched += 1
|
||||
if len(sample_unmatched) < sample_unmatched_count:
|
||||
sample_unmatched.append(pseudonym)
|
||||
|
||||
match_rate = matched / total if total > 0 else 0.0
|
||||
|
||||
logger.info(f"Indication match rate for '{drug_name}': {100*match_rate:.1f}% ({matched}/{total})")
|
||||
|
||||
return DrugIndicationMatchRate(
|
||||
drug_name=drug_name,
|
||||
total_patients=total,
|
||||
patients_with_indication=matched,
|
||||
patients_without_indication=unmatched,
|
||||
match_rate=match_rate,
|
||||
clusters_checked=cluster_ids,
|
||||
sample_unmatched=sample_unmatched,
|
||||
)
|
||||
|
||||
|
||||
def batch_validate_indications(
|
||||
patient_drug_pairs: list[tuple[str, str]],
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
db_manager: Optional[DatabaseManager] = None,
|
||||
progress_callback: Optional[Callable[[int, int], None]] = None,
|
||||
) -> list[IndicationValidationResult]:
|
||||
"""
|
||||
Validate indications for multiple patient-drug pairs efficiently.
|
||||
|
||||
Args:
|
||||
patient_drug_pairs: List of (patient_pseudonym, drug_name) tuples
|
||||
connector: Optional SnowflakeConnector
|
||||
db_manager: Optional DatabaseManager
|
||||
progress_callback: Optional callback(current, total) for progress updates
|
||||
|
||||
Returns:
|
||||
List of IndicationValidationResult for each pair
|
||||
"""
|
||||
results = []
|
||||
total = len(patient_drug_pairs)
|
||||
|
||||
# Cache cluster lookups by drug
|
||||
drug_clusters_cache = {}
|
||||
|
||||
for i, (pseudonym, drug_name) in enumerate(patient_drug_pairs):
|
||||
if progress_callback:
|
||||
progress_callback(i + 1, total)
|
||||
|
||||
# Get clusters from cache or lookup
|
||||
drug_upper = drug_name.upper()
|
||||
if drug_upper not in drug_clusters_cache:
|
||||
drug_clusters_cache[drug_upper] = get_drug_cluster_ids(drug_name, db_manager)
|
||||
|
||||
cluster_ids = drug_clusters_cache[drug_upper]
|
||||
|
||||
if not cluster_ids:
|
||||
results.append(IndicationValidationResult(
|
||||
patient_pseudonym=pseudonym,
|
||||
drug_name=drug_name,
|
||||
has_valid_indication=False,
|
||||
source="NONE",
|
||||
error_message=f"No cluster mappings for drug '{drug_name}'",
|
||||
))
|
||||
continue
|
||||
|
||||
# Check patient indication
|
||||
has_indication, matched_cluster, matched_code, matched_desc = patient_has_indication(
|
||||
patient_pseudonym=pseudonym,
|
||||
cluster_ids=cluster_ids,
|
||||
connector=connector,
|
||||
)
|
||||
|
||||
results.append(IndicationValidationResult(
|
||||
patient_pseudonym=pseudonym,
|
||||
drug_name=drug_name,
|
||||
has_valid_indication=has_indication,
|
||||
matched_cluster_id=matched_cluster,
|
||||
matched_snomed_code=matched_code,
|
||||
matched_snomed_description=matched_desc,
|
||||
checked_clusters=cluster_ids,
|
||||
source="GP_SNOMED" if has_indication else "NONE",
|
||||
))
|
||||
|
||||
matched_count = sum(1 for r in results if r.has_valid_indication)
|
||||
logger.info(f"Batch validation complete: {matched_count}/{total} ({100*matched_count/total:.1f}%) with valid indications")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_available_clusters(
|
||||
connector: Optional[SnowflakeConnector] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get list of all available SNOMED clusters from Snowflake.
|
||||
|
||||
Returns:
|
||||
List of dicts with cluster_id, cluster_description, code_count
|
||||
"""
|
||||
if not SNOWFLAKE_AVAILABLE or not is_snowflake_configured():
|
||||
logger.warning("Snowflake not available - cannot list clusters")
|
||||
return []
|
||||
|
||||
if connector is None:
|
||||
connector = get_connector()
|
||||
|
||||
query = '''
|
||||
SELECT
|
||||
"Cluster_ID",
|
||||
"Cluster_Description",
|
||||
COUNT(DISTINCT "SNOMEDCode") as code_count
|
||||
FROM DATA_HUB.PHM."ClinicalCodingClusterSnomedCodes"
|
||||
GROUP BY "Cluster_ID", "Cluster_Description"
|
||||
ORDER BY "Cluster_ID"
|
||||
'''
|
||||
|
||||
try:
|
||||
results = connector.execute_dict(query)
|
||||
|
||||
clusters = []
|
||||
for row in results:
|
||||
clusters.append({
|
||||
"cluster_id": row.get("Cluster_ID"),
|
||||
"cluster_description": row.get("Cluster_Description"),
|
||||
"code_count": row.get("code_count", 0),
|
||||
})
|
||||
|
||||
logger.info(f"Found {len(clusters)} available SNOMED clusters")
|
||||
return clusters
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting available clusters: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# Export public API
|
||||
__all__ = [
|
||||
"ClusterSnomedCodes",
|
||||
"IndicationValidationResult",
|
||||
"DrugIndicationMatchRate",
|
||||
"get_drug_clusters",
|
||||
"get_drug_cluster_ids",
|
||||
"get_cluster_snomed_codes",
|
||||
"patient_has_indication",
|
||||
"validate_indication",
|
||||
"get_indication_match_rate",
|
||||
"batch_validate_indications",
|
||||
"get_available_clusters",
|
||||
]
|
||||
Reference in New Issue
Block a user