refactor: reorganize repository to src/ layout
Move 6 packages (core, config, data_processing, analysis, visualization, cli) into src/ to reduce root clutter. Merge tools/data.py into data_processing/transforms.py. Move docs to docs/. Path resolution via .pth file (setup_dev.py), pytest pythonpath config, and sys.path bootstrap in rxconfig.py and CLI entry points. Clean up pyproject.toml deps (remove stale pins, add snowflake-connector-python). Fix tomllib import for Python 3.10 compatibility. All 113 tests pass.
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
# Analysis Package
|
||||
|
||||
Four-step pathway analysis pipeline refactored from original 267-line `generate_graph()` function.
|
||||
|
||||
## Module: pathway_analyzer.py
|
||||
|
||||
**Main entry points:**
|
||||
- `generate_icicle_chart(df, filters)` — Directory charts (Trust → Directory → Drug → Pathway)
|
||||
- `generate_icicle_chart_indication(df, indication_df, filters)` — Indication charts using Search_Term hierarchy
|
||||
|
||||
**Pipeline steps:**
|
||||
1. `prepare_data()` — Map provider codes to trust names and filter by trusts/drugs/directories. **MUST use `df.copy()`** to prevent mutation.
|
||||
2. `calculate_statistics()` — Compute frequency, cost, duration stats and apply the initiation-date / last-seen date filters
|
||||
3. `build_hierarchy()` — Create Trust → Directory/Indication → Drug → Pathway structure
|
||||
4. `prepare_chart_data()` — Format data for Plotly icicle chart
|
||||
|
||||
**Note on modified UPIDs:**
|
||||
For drug-aware indication matching, UPIDs are formatted as `{original}|{search_term}`. The hierarchy-building functions treat UPID as opaque — pipe delimiters work transparently without code changes.
|
||||
|
||||
## Module: statistics.py
|
||||
|
||||
Statistical calculation helper functions (frequency, cost, duration, per-patient metrics).
|
||||
|
||||
Called by `calculate_statistics()` during pipeline execution.
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Analysis package for patient pathway processing.
|
||||
|
||||
This package contains refactored functions from the original generate_graph() pipeline:
|
||||
- pathway_analyzer: Main analysis pipeline with prepare_data, calculate_statistics, build_hierarchy
|
||||
- statistics: Statistical calculation functions (costs, frequencies, durations)
|
||||
"""
|
||||
|
||||
from analysis.pathway_analyzer import (
    prepare_data,
    calculate_statistics,
    build_hierarchy,
    prepare_chart_data,
    generate_icicle_chart,
    generate_icicle_chart_indication,
)

from analysis.statistics import (
    count_consecutive_values,
    calculate_drug_costs,
    calculate_dosing_frequency,
    calculate_drug_frequency_row,
    calculate_cost_per_patient_per_annum,
    calculate_treatment_duration,
    calculate_pathway_proportion,
    aggregate_patient_costs,
    aggregate_drug_frequencies,
    format_treatment_statistics,
    remove_nan_values,
)

# Names re-exported as the package's public API. Both chart entry points
# (directory-based and indication-based) are exposed here.
__all__ = [
    # Pathway analysis pipeline
    "prepare_data",
    "calculate_statistics",
    "build_hierarchy",
    "prepare_chart_data",
    "generate_icicle_chart",
    "generate_icicle_chart_indication",
    # Statistical calculations
    "count_consecutive_values",
    "calculate_drug_costs",
    "calculate_dosing_frequency",
    "calculate_drug_frequency_row",
    "calculate_cost_per_patient_per_annum",
    "calculate_treatment_duration",
    "calculate_pathway_proportion",
    "aggregate_patient_costs",
    "aggregate_drug_frequencies",
    "format_treatment_statistics",
    "remove_nan_values",
]
|
||||
@@ -0,0 +1,943 @@
|
||||
"""
|
||||
Patient pathway analysis pipeline.
|
||||
|
||||
This module contains functions extracted from the original generate_graph() function
|
||||
to improve maintainability and testability. The functions follow this pipeline:
|
||||
|
||||
1. prepare_data() - Apply filters, create composite keys, load reference data
|
||||
2. calculate_statistics() - Calculate patient costs, drug frequencies, treatment durations
|
||||
3. build_hierarchy() - Build the Trust → Directory → Drug → Pathway hierarchy
|
||||
4. prepare_chart_data() - Finalize data for Plotly icicle chart
|
||||
|
||||
The generate_icicle_chart() function orchestrates the full pipeline.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
from analysis.statistics import (
|
||||
count_consecutive_values,
|
||||
calculate_drug_costs,
|
||||
calculate_dosing_frequency,
|
||||
calculate_cost_per_patient_per_annum,
|
||||
remove_nan_values,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def prepare_data(
    df: pd.DataFrame,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    paths: Optional[PathConfig] = None,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Filter the intervention data and load the organisation-code lookup.

    The input frame is copied first, so the caller's DataFrame is never
    modified. On the copy, a ``UPIDTreatment`` key (``UPID`` + ``Drug Name``)
    is added and ``Provider Code`` is replaced by the trust name looked up in
    the organisation-codes CSV. Rows are then kept only when trust name, drug
    name and directory are all present in the respective filter lists.

    Args:
        df: DataFrame with processed patient intervention data
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        paths: PathConfig for file paths (uses default if None)

    Returns:
        Tuple of (filtered_df, org_codes_df, directory_df), or
        (None, None, None) when no row survives the filters.
    """
    resolved_paths = default_paths if paths is None else paths

    # Work on a private copy: the Provider Code mapping below is destructive,
    # so running it twice on the same frame would turn trust names into NaN.
    working = df.copy()
    working["UPIDTreatment"] = working["UPID"] + working["Drug Name"]

    org_codes = pd.read_csv(resolved_paths.org_codes_csv, index_col=1)
    working["Provider Code"] = working["Provider Code"].map(org_codes["Name"])

    trust_ok = working["Provider Code"].isin(trust_filter)
    drug_ok = working["Drug Name"].isin(drug_filter)
    directory_ok = working["Directory"].isin(directory_filter)
    working = working[trust_ok & drug_ok & directory_ok]

    if len(working) == 0:
        logger.warning("No data found for selected filters.")
        return None, None, None

    # One directory per patient: the first one encountered for each UPID.
    upid_directory = working[["UPID", "Directory"]].drop_duplicates("UPID")
    directory_df = upid_directory.set_index("UPID")

    logger.info("Filtering unrelated interventions")
    return working, org_codes, directory_df
|
||||
|
||||
|
||||
def _count_list_values(x):
    """Return ``count_consecutive_values(x)``; thin adapter for use with ``Series.apply``."""
    counts = count_consecutive_values(x)
    return counts
|
||||
|
||||
|
||||
def _sum_list_values(x):
    """Pass a row's ``Drug Name`` and ``Price Actual`` entries to ``calculate_drug_costs``."""
    drug_entry = x["Drug Name"]
    price_entry = x["Price Actual"]
    return calculate_drug_costs(drug_entry, price_entry)
|
||||
|
||||
|
||||
def _start_date_drug(start_dates_df: pd.DataFrame, x: pd.Series) -> list:
|
||||
"""Get start dates for each drug in a patient's treatment."""
|
||||
drug_count = x.notnull().sum()
|
||||
date_string = []
|
||||
for d in range(drug_count):
|
||||
UPID_date_var = str(x.name) + str(x[d])
|
||||
date = start_dates_df.loc[UPID_date_var, "Intervention Date"]
|
||||
date_string.append(date)
|
||||
return date_string
|
||||
|
||||
|
||||
def _end_date_drug(end_dates_df: pd.DataFrame, x: pd.Series) -> list:
|
||||
"""Get end dates for each drug in a patient's treatment."""
|
||||
drug_count = x.notnull().sum()
|
||||
date_string = []
|
||||
for d in range(drug_count - 1):
|
||||
UPID_date_var = str(x.name) + str(x[d])
|
||||
date = end_dates_df.loc[UPID_date_var, "Intervention Date"]
|
||||
date_string.append(date)
|
||||
return date_string
|
||||
|
||||
|
||||
def _drug_frequency_average(x: pd.Series) -> list[float]:
|
||||
"""Calculate average dosing frequency for each drug."""
|
||||
drug_count = x.index.str.contains("drug_").sum()
|
||||
freq = []
|
||||
for d in range(drug_count):
|
||||
freq_val = x.get(f"freq_{d}", 0)
|
||||
if pd.isna(freq_val):
|
||||
freq_val = 0
|
||||
else:
|
||||
freq_val = int(freq_val)
|
||||
|
||||
if freq_val > 1:
|
||||
start_date = x.get(f"start_date_{d}")
|
||||
end_date = x.get(f"end_date_{d}")
|
||||
if pd.notna(start_date) and pd.notna(end_date):
|
||||
freq_calc = calculate_dosing_frequency(freq_val, start_date, end_date)
|
||||
else:
|
||||
freq_calc = 0.0
|
||||
else:
|
||||
freq_calc = 0.0
|
||||
freq.append(freq_calc)
|
||||
return freq
|
||||
|
||||
|
||||
def _drop_duplicate_treatments(df: pd.DataFrame, ascending: bool) -> pd.DataFrame:
|
||||
"""Drop duplicate treatments keeping first/last based on date sort order."""
|
||||
df_sorted = df.sort_values(by=["Intervention Date"], ascending=ascending)
|
||||
df_treatment_steps = df_sorted.drop_duplicates(subset="UPIDTreatment", keep="first")
|
||||
if not ascending:
|
||||
df_treatment_steps = df_treatment_steps.sort_values(by=["Intervention Date"], ascending=True)
|
||||
return df_treatment_steps
|
||||
|
||||
|
||||
def calculate_statistics(
    df: pd.DataFrame,
    start_date: str,
    end_date: str,
    last_seen_date: str,
    title: str,
) -> tuple[pd.DataFrame, pd.DataFrame, str]:
    """
    Calculate patient statistics: costs, drug frequencies, treatment durations.

    Builds one row per UPID holding the patient's drugs (``drug_<i>``), the
    first/last date of each drug (``start_date_<i>`` / ``end_date_<i>``), the
    administration counts (``freq_<i>``), the dosing spacing
    (``spacing_<i>``), the per-drug cost (``total_cost_drug_<i>``), the total
    cost and the first/last seen dates. Patients are then filtered on first
    seen (``start_date <= first seen < end_date``) and on last seen
    (``last seen > last_seen_date``).

    Args:
        df: Filtered DataFrame from prepare_data()
        start_date: Start date for patient initiation filter
        end_date: End date for patient initiation filter
        last_seen_date: Filter for patients last seen after this date
        title: Chart title (auto-generated if empty)

    Returns:
        Tuple of (patient_info_df, date_df, final_title) or (None, None, "")
        if no patient remains after the date filters.
    """

    def _lists_per_patient(frame: pd.DataFrame, column: str) -> pd.DataFrame:
        """Collect ``column`` into one list per UPID (result indexed by UPID)."""
        return frame.groupby("UPID").agg({column: lambda values: list(values)})

    def _unwrap(lists: pd.Series, index: pd.Index, prefix: str) -> pd.DataFrame:
        """Spread a Series of lists into ``<prefix>0``, ``<prefix>1``, ... columns."""
        return pd.DataFrame(lists.values.tolist(), index=index).add_prefix(prefix)

    def _by_treatment(ascending: bool) -> pd.DataFrame:
        """Earliest (ascending) or latest date per UPIDTreatment key."""
        return (
            df[["UPIDTreatment", "Intervention Date"]]
            .sort_values(by=["Intervention Date"], ascending=ascending)
            .drop_duplicates(subset="UPIDTreatment")
            .set_index("UPIDTreatment")
        )

    total_costs = pd.DataFrame(df[["UPID", "Price Actual"]].groupby("UPID").sum())
    total_costs.rename(columns={"Price Actual": "Total cost"}, inplace=True)

    df_end_dates = _drop_duplicate_treatments(df, False)
    df1_unique = _drop_duplicate_treatments(df, True)
    logger.info("Identifying unique patients and interventions used")

    # Per-patient administration counts and per-drug cost totals. The order
    # matters: "Drug Name" is replaced by its counts before the cost helper
    # reads the row.
    df_drug_freq = _lists_per_patient(df, "Drug Name")
    df_drug_cost = _lists_per_patient(df, "Price Actual")
    df_drug_freq["Price Actual"] = df_drug_freq.index.map(df_drug_cost["Price Actual"])
    df_drug_freq["Drug Name"] = df_drug_freq["Drug Name"].apply(_count_list_values)
    df_drug_freq["Drug cost total"] = df_drug_freq.apply(lambda x: _sum_list_values(x), axis=1)

    df_drugs = _lists_per_patient(df1_unique, "Drug Name")
    df_dates = _lists_per_patient(df1_unique, "Intervention Date")
    df_end_dates_grouped = _lists_per_patient(df_end_dates, "Intervention Date")

    logger.info(
        "Calculating each unique patient's intervention average frequency, cost and duration of each intervention"
    )

    df_dates_unwrapped = _unwrap(df_dates["Intervention Date"], df_dates.index, "date_")
    df_end_dates_unwrapped = _unwrap(
        df_end_dates_grouped["Intervention Date"], df_end_dates_grouped.index, "date_end_"
    )
    df_drugs_unwrapped = _unwrap(df_drugs["Drug Name"], df_drugs.index, "drug_")
    df_freq_unwrapped = _unwrap(df_drug_freq["Drug Name"], df_drug_freq.index, "freq_")

    start_dates = _by_treatment(ascending=True)
    end_dates = _by_treatment(ascending=False)

    # Both look-ups run while the frame still holds only drug_<i> columns,
    # because _start_date_drug treats every non-null cell as a drug name.
    # The same helper serves end dates; only the lookup table differs.
    start_lists = df_drugs_unwrapped.apply(lambda x: _start_date_drug(start_dates, x), axis=1)
    end_lists = df_drugs_unwrapped.apply(lambda x: _start_date_drug(end_dates, x), axis=1)
    df_start_dates_unwrapped = _unwrap(start_lists, df_drugs_unwrapped.index, "start_date_")
    df_end_dates_unwrapped_2 = _unwrap(end_lists, df_drugs_unwrapped.index, "end_date_")

    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_start_dates_unwrapped, left_index=True, right_index=True
    )
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_end_dates_unwrapped_2, left_index=True, right_index=True
    )

    # NOTE(review): the freq/cost frames are re-indexed positionally with the
    # drug frame's index; this assumes both group-bys list UPIDs in the same
    # order — confirm.
    df_freq_for_merge = _unwrap(df_drug_freq["Drug Name"], df_drugs_unwrapped.index, "freq_")
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_freq_for_merge, left_index=True, right_index=True
    )

    spacing_lists = df_drugs_unwrapped.apply(lambda x: _drug_frequency_average(x), axis=1)
    df_spacing_unwrapped = _unwrap(spacing_lists, df_drugs_unwrapped.index, "spacing_")
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_spacing_unwrapped, left_index=True, right_index=True
    )

    df_cost_unwrapped = _unwrap(
        df_drug_freq["Drug cost total"], df_drugs_unwrapped.index, "total_cost_drug_"
    )
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_cost_unwrapped, left_index=True, right_index=True
    )

    df_drugs_unwrapped.insert(0, "First seen", df_dates_unwrapped.min(axis=1))
    df_drugs_unwrapped.insert(1, "Last seen", df_end_dates_unwrapped.max(axis=1))

    patient_info = df.drop_duplicates(subset="UPID", keep="first").set_index("UPID")
    patient_info = pd.merge(patient_info, df_drugs_unwrapped, left_index=True, right_index=True)
    # NOTE(review): df_drugs_unwrapped already carries freq_<i> columns, so
    # this merge yields suffixed duplicates (freq_<i>_x / freq_<i>_y). Kept
    # as-is because downstream consumers of these columns are not visible.
    patient_info = pd.merge(patient_info, df_freq_unwrapped, left_index=True, right_index=True)
    patient_info = pd.merge(patient_info, total_costs, left_index=True, right_index=True)

    patient_info = patient_info[
        (patient_info["First seen"] >= str(start_date))
        & (patient_info["First seen"] < str(end_date))
    ]

    if title == "":
        title = f"Patients initiated from {start_date} to {end_date}"

    # Explicit copy: the frame is modified below, and assigning into a
    # boolean-filtered slice is chained assignment.
    patient_info = patient_info[patient_info["Last seen"] > str(last_seen_date)].copy()

    patient_info["drug_0"] = patient_info["drug_0"].replace("N/A", np.nan)
    patient_info.dropna(subset=["drug_0"], inplace=True)

    if len(patient_info) == 0:
        logger.warning("No patients remaining after date filters.")
        return None, None, ""

    patient_info["Days treated"] = patient_info["Last seen"] - patient_info["First seen"]
    date_df = patient_info[["First seen", "Last seen", "Days treated"]]

    return patient_info, date_df, title
|
||||
|
||||
|
||||
def _row_function(row: pd.Series) -> str:
|
||||
"""Build composite parent-label-id string for hierarchy."""
|
||||
ids = ""
|
||||
parents = "N&WICS"
|
||||
count = row.count()
|
||||
for c in range(count):
|
||||
v = row[c]
|
||||
if type(v) != str:
|
||||
v = row[c + 1]
|
||||
if c == count - 1:
|
||||
ids = parents + " - " + v
|
||||
continue
|
||||
parents += " - " + v
|
||||
label = row[count - 1]
|
||||
value = parents + "," + label + "," + ids
|
||||
return value
|
||||
|
||||
|
||||
def _remove_nan_string(y) -> list:
    """Return ``remove_nan_values(y)``; thin adapter for use with ``Series.apply``."""
    cleaned = remove_nan_values(y)
    return cleaned
|
||||
|
||||
|
||||
def _list_to_string(x: pd.Series) -> str:
|
||||
"""Format drug statistics into readable string."""
|
||||
list_parts = x.ids.split(" - ")
|
||||
drug_list = list_parts[len(list_parts) - len(x.average_cost) :]
|
||||
ret_string = ""
|
||||
for y in range(len(x.average_cost)):
|
||||
if (
|
||||
(round(x.average_spacing[y], 0) > 1)
|
||||
and (round(x.average_administered[y], 0) > 2.5)
|
||||
and (int(x.value) > 0)
|
||||
):
|
||||
string = (
|
||||
f"<br><b>{drug_list[y]}</b><br>On average given "
|
||||
f"{round(x.average_administered[y], 1)} times with a "
|
||||
f"{round(int(x.average_spacing[y]) / 7, 1)} weekly interval ("
|
||||
f"{round((int(x.average_spacing[y]) / 7) * round(x.average_administered[y], 1), 0)} weeks total treatment length)"
|
||||
)
|
||||
else:
|
||||
string = (
|
||||
f"<br><b>{drug_list[y]}</b><br>On average given "
|
||||
f"{round(x.average_administered[y], 1)} times with a "
|
||||
f"{round(int(x.average_spacing[y]) / 7, 1)} weekly interval ("
|
||||
f"{round((int(x.average_spacing[y]) / 7) * round(x.average_administered[y], 1), 0)} weeks total treatment length)"
|
||||
)
|
||||
ret_string += string
|
||||
return ret_string
|
||||
|
||||
|
||||
def _min_max_treatment_dates(ice_df: pd.DataFrame, row: pd.Series) -> str:
|
||||
"""Get min/max dates for a pathway."""
|
||||
ids = row["ids"]
|
||||
min_max = ice_df[ice_df["ids"].str.contains(ids, regex=False)]
|
||||
if len(min_max) == 0:
|
||||
return "N/A,N/A"
|
||||
|
||||
# Handle NaT (Not a Time) values
|
||||
first_seen_min = min_max["First seen"].min()
|
||||
last_seen_max = min_max["Last seen"].max()
|
||||
|
||||
if pd.isna(first_seen_min):
|
||||
min_date = "N/A"
|
||||
else:
|
||||
min_date = str(first_seen_min.strftime("%Y-%m-%d"))
|
||||
|
||||
if pd.isna(last_seen_max):
|
||||
max_date = "N/A"
|
||||
else:
|
||||
max_date = str(last_seen_max.strftime("%Y-%m-%d"))
|
||||
|
||||
return f"{min_date},{max_date}"
|
||||
|
||||
|
||||
def _cost_pp_pa(x: pd.Series) -> str:
    """
    Render the cost per patient per annum for one row as text.

    Passes the row's ``costpp`` and ``avg_days`` to
    ``calculate_cost_per_patient_per_annum`` and returns the result rounded
    to two decimals, or ``"N/A"`` when that helper returns ``None``.
    """
    annual_cost = calculate_cost_per_patient_per_annum(x["costpp"], x["avg_days"])
    if annual_cost is None:
        return "N/A"
    return str(round(annual_cost, 2))
|
||||
|
||||
|
||||
def build_hierarchy(
    patient_info: pd.DataFrame,
    date_df: pd.DataFrame,
    df: pd.DataFrame,
    org_codes: pd.DataFrame,
    directory_df: pd.DataFrame,
    total_costs: pd.DataFrame,
    df_drugs_unwrapped: pd.DataFrame,
) -> pd.DataFrame:
    """
    Build the hierarchical structure for the icicle chart.

    Every patient is turned into a path ``N&WICS - trust - directory -
    drug ...``; identical paths are grouped into leaf nodes, missing
    intermediate nodes are created, and patient counts and costs are rolled
    up from the deepest nodes towards the root.

    Args:
        patient_info: DataFrame with calculated patient statistics; its
            ``drug_<i>`` columns must be contiguous and are read by position
        date_df: DataFrame with first/last seen dates and days treated
        df: Original filtered DataFrame (not read by this function)
        org_codes: Organization codes lookup (``Name`` column)
        directory_df: Directory assignments by UPID
        total_costs: Total costs by UPID (``Total cost`` column)
        df_drugs_unwrapped: Drug data with dates and frequencies unwrapped

    Returns:
        DataFrame with parents, ids, labels, value, cost, colour and the
        per-node statistics columns for the icicle chart.
    """
    number_of_drugs = np.count_nonzero(patient_info.columns.str.startswith("drug_"))
    final_drug_index = patient_info.columns.to_list().index("drug_" + str(number_of_drugs - 1))

    upid_drugs_df = patient_info.iloc[
        :, (final_drug_index - number_of_drugs + 1) : final_drug_index + 1
    ]
    upid_drugs_df = upid_drugs_df.copy()

    # The first three characters of the UPID are used as the provider code.
    upid_drugs_df.insert(0, "Trust", upid_drugs_df.index.str[:3])
    upid_drugs_df.insert(1, "Directory", upid_drugs_df.index)

    upid_drugs_df["Trust"] = upid_drugs_df["Trust"].map(org_codes["Name"])
    upid_drugs_df["Directory"] = upid_drugs_df["Directory"].map(directory_df["Directory"])

    # "value" is the "<parents>,<label>,<ids>" string from _row_function.
    upid_drugs_df["value"] = upid_drugs_df.apply(lambda x: _row_function(x), axis=1)
    upid_drugs_df = pd.merge(upid_drugs_df, date_df, left_index=True, right_index=True)

    upid_drugs_df["ids"] = upid_drugs_df["value"].str.split(",").str[2]

    avg_treatment_dfs = pd.DataFrame(
        upid_drugs_df.groupby("ids", as_index=False)["Days treated"].mean()
    ).set_index("ids")
    value_dfs = pd.DataFrame(
        upid_drugs_df.groupby("value", as_index=False).size()
    ).reset_index()
    first_seen_treatment_dfs = pd.DataFrame(
        upid_drugs_df.groupby("ids", as_index=False)["First seen"].min()
    ).set_index("ids")
    last_seen_treatment_dfs = pd.DataFrame(
        upid_drugs_df.groupby("ids", as_index=False)["Last seen"].max()
    ).set_index("ids")

    upid_drugs_df["Cost"] = upid_drugs_df.index.map(total_costs["Total cost"])
    cost_dfs = pd.DataFrame(
        upid_drugs_df.groupby("value", as_index=False)["Cost"].sum()
    ).set_index("value", drop=True)

    upid_drugs_df = pd.merge(upid_drugs_df, df_drugs_unwrapped, left_index=True, right_index=True)

    spacing_average = pd.DataFrame(
        upid_drugs_df.groupby("value", as_index=False)[
            [col for col in upid_drugs_df.columns if "spacing_" in col]
        ].mean()
    ).set_index("value", drop=True)
    spacing_average = spacing_average.round()
    spacing_average["combined"] = spacing_average.values.tolist()
    spacing_average["ids"] = spacing_average.index
    spacing_average["ids"] = spacing_average["ids"].str.split(",").str[2]
    spacing_average.set_index("ids", inplace=True)

    cost_average = pd.DataFrame(
        upid_drugs_df.groupby("value", as_index=False)[
            [col for col in upid_drugs_df.columns if "total_cost_drug_" in col]
        ].mean()
    ).set_index("value", drop=True)
    cost_average = cost_average.round(2)
    cost_average["combined"] = cost_average.values.tolist()
    cost_average["ids"] = cost_average.index
    cost_average["ids"] = cost_average["ids"].str.split(",").str[2]
    cost_average.set_index("ids", inplace=True)

    freq_average = pd.DataFrame(
        upid_drugs_df.groupby("ids", as_index=False)[
            [col for col in upid_drugs_df.columns if "freq_" in col]
        ].mean()
    ).set_index("ids", drop=True)
    freq_average["combined"] = freq_average.values.tolist()

    # Negative pathway totals are clamped to zero. Done through the public
    # API: mutating the result of the private _get_numeric_data() only works
    # while it returns a view, which copy-on-write does not guarantee.
    cost_dfs["Cost"] = cost_dfs["Cost"].clip(lower=0)

    value_dfs["Cost"] = value_dfs["value"].map(cost_dfs["Cost"])

    ice_df = pd.DataFrame()
    ice_df[["parents", "labels", "ids"]] = value_dfs["value"].str.split(",", expand=True)

    ice_df["average_administered"] = ice_df["ids"].map(freq_average["combined"])
    ice_df["cost"] = value_dfs["Cost"]
    ice_df["value"] = value_dfs["size"]

    ice_df["average_cost"] = ice_df["ids"].map(cost_average["combined"])
    ice_df["average_cost"] = ice_df["average_cost"].apply(_remove_nan_string)

    ice_df["average_spacing"] = ice_df["ids"].map(spacing_average["combined"])
    ice_df["average_spacing"] = ice_df["average_spacing"].apply(_remove_nan_string)
    ice_df["average_spacing"] = ice_df.apply(lambda x: _list_to_string(x), axis=1)
    ice_df["average_spacing"] = ice_df["average_spacing"].str.replace("nan", "N/A")

    logger.info("Building graph dataframe structure.")

    new_row = pd.DataFrame(
        {"parents": "", "ids": "N&WICS", "labels": "N&WICS", "value": 0, "cost": 0}, index=[0]
    )
    ice_df = pd.concat(objs=[ice_df, new_row], ignore_index=True, axis=0)

    # Create the intermediate nodes: any parent path that is not yet a node
    # gets a zero-valued row, repeated until only the root's empty parent is
    # left unresolved.
    l_df = pd.DataFrame()
    ice_df2 = pd.DataFrame()
    # Membership must be tested against the id *values*; `x in series` would
    # test the integer index instead.
    known_ids = set(ice_df["ids"])
    l3 = [x for x in ice_df.parents.unique() if x not in known_ids]
    while len(l3) > 1:
        for missing_id in l3:
            # Split on the full " - " separator so that a hyphen inside a
            # name is not mistaken for a level boundary.
            z = missing_id.rfind(" - ")
            if z > 0:
                l_dict = {
                    "parents": missing_id[:z],
                    "ids": missing_id,
                    "value": 0,
                    "labels": missing_id[z + 3 :],
                    "cost": 0,
                }
                l_df = pd.concat([l_df, pd.DataFrame(l_dict, index=[0])], ignore_index=True)
        ice_df2 = pd.concat([ice_df, l_df], ignore_index=True)
        l3 = [x for x in ice_df2.parents.unique() if x not in ice_df2.ids.unique()]
    if len(ice_df2) > 0:
        # Explicit copies below: the frame is modified after each filter.
        ice_df = ice_df2.drop_duplicates("ids").copy()

    ice_df["level"] = ice_df["ids"].str.count("-")
    ice_df = ice_df[~ice_df["labels"].isin(["COST", "CHARGE", "N/A"])].copy()
    ice_df = ice_df.sort_values(by=["level"], ascending=False, ignore_index=True)

    # Roll counts and costs up the tree. Rows are ordered deepest first, so a
    # node has received all of its children's totals before it is added to
    # its own parent.
    for index, row in ice_df.iterrows():
        lookup_index = ice_df.index[ice_df["ids"] == row["parents"]]
        ice_df.loc[lookup_index, "value"] = (
            ice_df.loc[lookup_index, "value"] + ice_df.loc[index, "value"]
        )
        ice_df.loc[lookup_index, "cost"] = (
            ice_df.loc[lookup_index, "cost"] + ice_df.loc[index, "cost"]
        )

    # colour = share of the node's patients among all nodes with the same parent.
    colour_df = pd.DataFrame(ice_df.groupby(["parents"])["value"].sum())
    ice_df["colour"] = ice_df["parents"].map(colour_df["value"])
    ice_df["colour"] = ice_df["value"] / ice_df["colour"]

    ice_df["costpp"] = ice_df["cost"] / ice_df["value"]
    ice_df["avg_days"] = ice_df["ids"].map(avg_treatment_dfs["Days treated"])
    ice_df["First seen"] = ice_df["ids"].map(first_seen_treatment_dfs["First seen"])
    ice_df["Last seen"] = ice_df["ids"].map(last_seen_treatment_dfs["Last seen"])

    ice_df["dates"] = ice_df.apply(lambda x: _min_max_treatment_dates(ice_df, x), axis=1)
    ice_df[["First seen (Parent)", "Last seen (Parent)"]] = ice_df["dates"].str.split(
        ",", expand=True
    )

    ice_df["First seen"] = pd.to_datetime(ice_df["First seen"])
    ice_df["Last seen"] = pd.to_datetime(ice_df["Last seen"])
    ice_df["cost_pp_pa"] = ice_df.apply(lambda x: _cost_pp_pa(x), axis=1)

    return ice_df
|
||||
|
||||
|
||||
def prepare_chart_data(
    ice_df: pd.DataFrame,
    minimum_num_patients: int,
) -> pd.DataFrame:
    """
    Drop hierarchy nodes that fall below the patient threshold.

    Args:
        ice_df: DataFrame from build_hierarchy()
        minimum_num_patients: Minimum number of patients to include a pathway

    Returns:
        Rows of ``ice_df`` whose ``value`` is at least
        ``minimum_num_patients``.
    """
    meets_threshold = ice_df["value"] >= minimum_num_patients
    chart_rows = ice_df[meets_threshold]
    logger.info("Generating graph.")
    return chart_rows
|
||||
|
||||
|
||||
def generate_icicle_chart(
    df: pd.DataFrame,
    start_date: str,
    end_date: str,
    last_seen_date: str,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    minimum_num_patients: int,
    title: str = "",
    paths: Optional[PathConfig] = None,
) -> tuple[pd.DataFrame, str]:
    """
    Generate icicle chart data using the refactored pipeline.

    This is the main entry point that orchestrates the full analysis
    pipeline: prepare_data() -> calculate_statistics() -> build_hierarchy()
    -> prepare_chart_data().

    Args:
        df: DataFrame with processed patient intervention data
        start_date: Start date for patient initiation filter
        end_date: End date for patient initiation filter
        last_seen_date: Filter for patients last seen after this date
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        minimum_num_patients: Minimum number of patients to include a pathway
        title: Chart title (auto-generated if empty)
        paths: PathConfig for file paths (uses default if None)

    Returns:
        Tuple of (ice_df for chart, final_title) or (None, "") if no data
    """

    def _lists_per_patient(frame: pd.DataFrame, column: str) -> pd.DataFrame:
        """Collect ``column`` into one list per UPID (result indexed by UPID)."""
        return frame.groupby("UPID").agg({column: lambda values: list(values)})

    def _unwrap(lists: pd.Series, index: pd.Index, prefix: str) -> pd.DataFrame:
        """Spread a Series of lists into ``<prefix>0``, ``<prefix>1``, ... columns."""
        return pd.DataFrame(lists.values.tolist(), index=index).add_prefix(prefix)

    if paths is None:
        paths = default_paths

    result = prepare_data(df, trust_filter, drug_filter, directory_filter, paths)
    if result[0] is None:
        return None, ""
    filtered_df, org_codes, directory_df = result

    total_costs = pd.DataFrame(filtered_df[["UPID", "Price Actual"]].groupby("UPID").sum())
    total_costs.rename(columns={"Price Actual": "Total cost"}, inplace=True)

    result = calculate_statistics(filtered_df, start_date, end_date, last_seen_date, title)
    if result[0] is None:
        return None, ""
    patient_info, date_df, final_title = result

    # calculate_statistics() does not return its per-drug detail frame, so the
    # drug / date / frequency / spacing / cost columns that build_hierarchy()
    # needs are derived again here from the same filtered data.
    df_drug_freq = _lists_per_patient(filtered_df, "Drug Name")
    df_drug_cost = _lists_per_patient(filtered_df, "Price Actual")
    df_drug_freq["Price Actual"] = df_drug_freq.index.map(df_drug_cost["Price Actual"])
    # Order matters: "Drug Name" holds counts by the time costs are summed.
    df_drug_freq["Drug Name"] = df_drug_freq["Drug Name"].apply(_count_list_values)
    df_drug_freq["Drug cost total"] = df_drug_freq.apply(lambda x: _sum_list_values(x), axis=1)

    df1_unique = _drop_duplicate_treatments(filtered_df, True)
    df_drugs = _lists_per_patient(df1_unique, "Drug Name")
    df_drugs_unwrapped = _unwrap(df_drugs["Drug Name"], df_drugs.index, "drug_")

    treatment_dates = filtered_df[["UPIDTreatment", "Intervention Date"]]
    start_dates = (
        treatment_dates.sort_values(by=["Intervention Date"], ascending=True)
        .drop_duplicates(subset="UPIDTreatment")
        .set_index("UPIDTreatment")
    )
    end_dates = (
        treatment_dates.sort_values(by=["Intervention Date"], ascending=False)
        .drop_duplicates(subset="UPIDTreatment")
        .set_index("UPIDTreatment")
    )

    # Both look-ups run while the frame still holds only drug_<i> columns,
    # because _start_date_drug treats every non-null cell as a drug name.
    start_lists = df_drugs_unwrapped.apply(lambda x: _start_date_drug(start_dates, x), axis=1)
    end_lists = df_drugs_unwrapped.apply(lambda x: _start_date_drug(end_dates, x), axis=1)
    df_start_dates_unwrapped = _unwrap(start_lists, df_drugs_unwrapped.index, "start_date_")
    df_end_dates_unwrapped_2 = _unwrap(end_lists, df_drugs_unwrapped.index, "end_date_")

    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_start_dates_unwrapped, left_index=True, right_index=True
    )
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_end_dates_unwrapped_2, left_index=True, right_index=True
    )

    # NOTE(review): the freq/cost frames are re-indexed positionally with the
    # drug frame's index; this assumes both group-bys list UPIDs in the same
    # order — confirm.
    df_freq_for_merge = _unwrap(df_drug_freq["Drug Name"], df_drugs_unwrapped.index, "freq_")
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_freq_for_merge, left_index=True, right_index=True
    )

    spacing_lists = df_drugs_unwrapped.apply(lambda x: _drug_frequency_average(x), axis=1)
    df_spacing_unwrapped = _unwrap(spacing_lists, df_drugs_unwrapped.index, "spacing_")
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_spacing_unwrapped, left_index=True, right_index=True
    )

    df_cost_unwrapped = _unwrap(
        df_drug_freq["Drug cost total"], df_drugs_unwrapped.index, "total_cost_drug_"
    )
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_cost_unwrapped, left_index=True, right_index=True
    )

    ice_df = build_hierarchy(
        patient_info,
        date_df,
        filtered_df,
        org_codes,
        directory_df,
        total_costs,
        df_drugs_unwrapped,
    )

    ice_df = prepare_chart_data(ice_df, minimum_num_patients)

    return ice_df, final_title
|
||||
|
||||
|
||||
def generate_icicle_chart_indication(
    df: pd.DataFrame,
    indication_df: Optional[pd.DataFrame],
    start_date: str,
    end_date: str,
    last_seen_date: str,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    minimum_num_patients: int,
    title: str = "",
    paths: Optional[PathConfig] = None,
) -> tuple[Optional[pd.DataFrame], str]:
    """
    Generate icicle chart data with indication-based grouping.

    This is a variant of generate_icicle_chart() that passes an indication
    mapping to build_hierarchy() in place of the directory mapping. The
    mapping's 'Indication_Group' (or 'indication_group') column is renamed to
    'Directory' so that build_hierarchy() can consume it unchanged.

    The caller is expected to supply, for patients without a GP diagnosis
    match, a fallback directorate value carrying a "(no GP dx)" suffix; this
    function does not create those values itself.

    Hierarchy: Trust → Indication_Group → Drug → Pathway

    Args:
        df: DataFrame with processed patient intervention data
        indication_df: DataFrame mapping UPID → Indication_Group.
            'UPID' may be either the index or a column (a column is moved to
            the index here). If None or empty, a warning is logged and the
            directory mapping returned by prepare_data() is used instead.
        start_date: Start date for patient initiation filter
        end_date: End date for patient initiation filter
        last_seen_date: Filter for patients last seen after this date
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        minimum_num_patients: Minimum number of patients to include a pathway
        title: Chart title, forwarded to calculate_statistics()
        paths: PathConfig for file paths (uses default_paths if None)

    Returns:
        Tuple of (ice_df for chart, final_title), or (None, "") when either
        prepare_data() or calculate_statistics() reports no data.
    """
    if paths is None:
        paths = default_paths

    # Prepare data - use standard prepare_data function
    result = prepare_data(df, trust_filter, drug_filter, directory_filter, paths)
    if result[0] is None:
        # prepare_data() signals "nothing to chart" with a None first element.
        return None, ""
    filtered_df, org_codes, directory_df = result

    # For indication charts, we replace directory_df with indication_df
    # First, ensure indication_df has the correct format (UPID as index)
    if indication_df is not None and not indication_df.empty:
        if 'UPID' in indication_df.columns:
            indication_df = indication_df.set_index('UPID')
        # Rename column for compatibility with build_hierarchy()
        # (both capitalisations of the column name are accepted)
        if 'Indication_Group' in indication_df.columns:
            indication_df = indication_df.rename(columns={'Indication_Group': 'Directory'})
        elif 'indication_group' in indication_df.columns:
            indication_df = indication_df.rename(columns={'indication_group': 'Directory'})
    else:
        # Fall back to directory if no indication data provided
        logger.warning("No indication data provided, falling back to directory grouping")
        indication_df = directory_df

    # Total spend per patient: sum of "Price Actual" per UPID, exposed as "Total cost".
    cost_df = filtered_df[["UPID", "Price Actual"]]
    total_costs = pd.DataFrame(cost_df.groupby("UPID").sum())
    total_costs.rename(columns={"Price Actual": "Total cost"}, inplace=True)

    result = calculate_statistics(filtered_df, start_date, end_date, last_seen_date, title)
    if result[0] is None:
        return None, ""
    patient_info, date_df, final_title = result

    # One row per UPID holding the list of every drug name / price recorded
    # for that patient (duplicates included).
    df_drug_freq = (
        filtered_df.groupby("UPID")
        .agg({"Drug Name": lambda x: list(x)})
        .reset_index()
        .set_index("UPID")
    )
    df_drug_cost = (
        filtered_df.groupby("UPID")
        .agg({"Price Actual": lambda x: list(x)})
        .reset_index()
        .set_index("UPID")
    )
    df_drug_freq["Price Actual"] = df_drug_freq.index.map(df_drug_cost["Price Actual"])
    # NOTE(review): _count_list_values / _sum_list_values are defined elsewhere;
    # presumably they turn the name list into per-drug counts and split the
    # price list into per-drug totals - confirm against the helpers.
    df_drug_freq["Drug Name"] = df_drug_freq["Drug Name"].apply(_count_list_values)
    df_drug_freq["Drug cost total"] = df_drug_freq.apply(lambda x: _sum_list_values(x), axis=1)

    # Same per-UPID list aggregation, but on the output of
    # _drop_duplicate_treatments() (helper defined elsewhere).
    df1_unique = _drop_duplicate_treatments(filtered_df, True)
    df_drugs = (
        df1_unique.groupby("UPID")
        .agg({"Drug Name": lambda x: list(x)})
        .reset_index()
        .set_index("UPID")
    )
    df_dates = (
        df1_unique.groupby("UPID")
        .agg({"Intervention Date": lambda x: list(x)})
        .reset_index()
        .set_index("UPID")
    )

    # Expand each per-patient list into positional columns (date_0, date_1, ...
    # and drug_0, drug_1, ...).
    # NOTE(review): df_dates_unwrapped is not referenced again in this function.
    df_dates_unwrapped = pd.DataFrame(
        df_dates["Intervention Date"].values.tolist(), index=df_dates.index
    ).add_prefix("date_")
    df_drugs_unwrapped = pd.DataFrame(
        df_drugs["Drug Name"].values.tolist(), index=df_drugs.index
    ).add_prefix("drug_")

    # Earliest intervention date per UPIDTreatment: sort ascending and keep the
    # first row of each key.
    start_dates = (
        filtered_df[["UPIDTreatment", "Intervention Date"]]
        .sort_values(by=["Intervention Date"], ascending=True)
        .drop_duplicates(subset="UPIDTreatment")
        .set_index("UPIDTreatment")
    )
    # Latest intervention date per UPIDTreatment: same, sorted descending.
    end_dates = (
        filtered_df[["UPIDTreatment", "Intervention Date"]]
        .sort_values(by=["Intervention Date"], ascending=False)
        .drop_duplicates(subset="UPIDTreatment")
        .set_index("UPIDTreatment")
    )

    # The temporary "start_dates" column holds one list per patient; it is
    # expanded into start_date_N columns and then dropped.
    df_drugs_unwrapped["start_dates"] = df_drugs_unwrapped.apply(
        lambda x: _start_date_drug(start_dates, x), axis=1
    )
    df_start_dates_unwrapped = pd.DataFrame(
        df_drugs_unwrapped["start_dates"].values.tolist(), index=df_drugs_unwrapped.index
    ).add_prefix("start_date_")
    df_drugs_unwrapped.drop(["start_dates"], inplace=True, axis=1)

    # The same lookup helper is reused for end dates; only the lookup table differs.
    df_drugs_unwrapped["end_dates"] = df_drugs_unwrapped.apply(
        lambda x: _start_date_drug(end_dates, x), axis=1
    )
    df_end_dates_unwrapped_2 = pd.DataFrame(
        df_drugs_unwrapped["end_dates"].values.tolist(), index=df_drugs_unwrapped.index
    ).add_prefix("end_date_")
    df_drugs_unwrapped.drop(["end_dates"], inplace=True, axis=1)

    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_start_dates_unwrapped, left_index=True, right_index=True
    )
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_end_dates_unwrapped_2, left_index=True, right_index=True
    )

    # NOTE(review): the values come from df_drug_freq but are labelled with
    # df_drugs_unwrapped.index, so this relies on both frames having their
    # rows in the same UPID order - confirm.
    df_freq_for_merge = pd.DataFrame(
        df_drug_freq["Drug Name"].values.tolist(), index=df_drugs_unwrapped.index
    ).add_prefix("freq_")
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_freq_for_merge, left_index=True, right_index=True
    )
    # Temporary "frequency" column: one list per patient, expanded into
    # spacing_N columns below and dropped once the cost columns are merged.
    df_drugs_unwrapped["frequency"] = df_drugs_unwrapped.apply(
        lambda x: _drug_frequency_average(x), axis=1
    )

    df_spacing_unwrapped = pd.DataFrame(
        df_drugs_unwrapped["frequency"].values.tolist(), index=df_drugs_unwrapped.index
    ).add_prefix("spacing_")
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_spacing_unwrapped, left_index=True, right_index=True
    )

    # Same positional labelling caveat as for df_freq_for_merge above.
    df_cost_unwrapped = pd.DataFrame(
        df_drug_freq["Drug cost total"].values.tolist(), index=df_drugs_unwrapped.index
    ).add_prefix("total_cost_drug_")
    df_drugs_unwrapped = pd.merge(
        df_drugs_unwrapped, df_cost_unwrapped, left_index=True, right_index=True
    )
    df_drugs_unwrapped.drop(["frequency"], inplace=True, axis=1)

    # Build hierarchy with indication_df instead of directory_df
    ice_df = build_hierarchy(
        patient_info,
        date_df,
        filtered_df,
        org_codes,
        indication_df,  # Use indication mapping instead of directory
        total_costs,
        df_drugs_unwrapped,
    )

    ice_df = prepare_chart_data(ice_df, minimum_num_patients)

    return ice_df, final_title
|
||||
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Statistical calculation functions for patient pathway analysis.
|
||||
|
||||
This module contains functions for calculating:
|
||||
- Drug frequency counts and averages
|
||||
- Cost aggregations (total, per patient, per annum)
|
||||
- Treatment duration calculations
|
||||
- Dosing interval calculations
|
||||
|
||||
These functions are extracted from the analysis pipeline to enable:
|
||||
- Independent testing
|
||||
- Reuse across different analysis contexts
|
||||
- Clearer separation of concerns
|
||||
"""
|
||||
|
||||
from itertools import groupby
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def count_consecutive_values(values: list) -> list[int]:
    """
    Count how often each distinct value occurs, ordered by sorted value.

    The input does not have to be pre-sorted: it is sorted here, and each run
    of equal neighbours in the sorted sequence is reduced to its length.
    Used to count how many times each drug was administered.

    Args:
        values: List of mutually comparable values (typically drug names)

    Returns:
        One count per distinct value, in ascending order of the values

    Example:
        count_consecutive_values(['A', 'A', 'B', 'A']) returns [3, 1]
        ('A' occurs three times, 'B' once).
    """
    run_lengths = []
    for _, run in groupby(sorted(values)):
        run_lengths.append(sum(1 for _ in run))
    return run_lengths
|
||||
|
||||
|
||||
def calculate_drug_costs(drug_counts: list[int], prices: list[float]) -> list[float]:
    """
    Sum the prices belonging to each drug.

    The price list is cut into consecutive slices whose lengths are given by
    ``drug_counts``; each slice is summed and converted to float. A slice
    that runs past the end of ``prices`` simply sums whatever is there.

    Args:
        drug_counts: Administration count per drug (from count_consecutive_values)
        prices: Individual administration prices (Price Actual values), laid
            out in the same drug order as ``drug_counts``

    Returns:
        List of total costs, one entry per element of ``drug_counts``

    Example:
        calculate_drug_costs([3, 2], [100, 100, 100, 200, 200])
        returns [300.0, 400.0].
    """
    # Running slice boundaries: boundaries[i] is where drug i's prices start.
    boundaries = [0]
    for count in drug_counts:
        boundaries.append(boundaries[-1] + count)

    return [
        float(sum(prices[lower:upper]))
        for lower, upper in zip(boundaries, boundaries[1:])
    ]
|
||||
|
||||
|
||||
def calculate_dosing_frequency(
    freq: int,
    start_date: pd.Timestamp,
    end_date: pd.Timestamp,
) -> float:
    """
    Return the mean gap, in days, between consecutive administrations.

    The elapsed time between the first and last dose is divided by the
    number of gaps (``freq - 1``).

    Args:
        freq: Number of administrations
        start_date: First administration date
        end_date: Last administration date

    Returns:
        Average days between administrations. 0.0 is returned when there is
        at most one dose, or when ``end_date`` is not after ``start_date``.

    Example:
        With start 2024-01-01, end 2024-01-22 and freq 4 the result is 7.0
        (21 days spread over 3 gaps).
    """
    if freq <= 1:
        return 0.0

    elapsed_days = (end_date - start_date) / np.timedelta64(1, "D")
    return 0.0 if elapsed_days <= 0 else elapsed_days / (freq - 1)
|
||||
|
||||
|
||||
def calculate_drug_frequency_row(row: pd.Series) -> list[float]:
    """
    Calculate average dosing frequency for each drug in a patient's treatment.

    Used with DataFrame.apply() on rows containing drug_*, freq_*,
    start_date_*, end_date_* columns, where the numeric suffix identifies the
    drug's position in the patient's sequence.

    Only labels that *start with* ``drug_`` are counted as drug columns, so
    other columns whose names merely contain that text (for example
    ``total_cost_drug_0``) do not inflate the result.

    Args:
        row: Series with drug names, frequencies, start dates, and end dates

    Returns:
        List of average dosing intervals (days), one per ``drug_N`` column.
        An entry is 0.0 when the drug has at most one administration or when
        its frequency, start date or end date is missing.
    """
    # Count drug positions by prefix; non-string labels are ignored.
    drug_count = sum(
        1
        for label in row.index
        if isinstance(label, str) and label.startswith("drug_")
    )

    frequencies = []
    for position in range(drug_count):
        freq = row.get(f"freq_{position}", 0)
        freq = 0 if freq is None or pd.isna(freq) else int(freq)

        start_date = row.get(f"start_date_{position}")
        end_date = row.get(f"end_date_{position}")

        # A spacing needs at least two doses and both boundary dates.
        if freq > 1 and pd.notna(start_date) and pd.notna(end_date):
            interval = calculate_dosing_frequency(freq, start_date, end_date)
        else:
            interval = 0.0

        frequencies.append(interval)

    return frequencies
|
||||
|
||||
|
||||
def calculate_cost_per_patient_per_annum(
    total_cost: float,
    days_treated: Optional[pd.Timedelta],
) -> Optional[float]:
    """
    Calculate annualized cost per patient.

    Normalizes costs to a per-year basis (365 days) to enable comparison
    across patients with different treatment durations.

    Args:
        total_cost: Total cost for the patient (can be Decimal or float)
        days_treated: Treatment duration, either as a timedelta or as a plain
            number of days

    Returns:
        Annualized cost, or None if days_treated is None, NaN/NaT, zero or
        negative

    Example:
        calculate_cost_per_patient_per_annum(5000, pd.Timedelta(days=182.5))
        returns 10000.0 (half a year of treatment, so the annual cost is 2x).
    """
    if days_treated is None or pd.isna(days_treated):
        return None

    # Plain numbers are already a day count. They must be recognised
    # explicitly: int and float also define __truediv__, and dividing them by
    # a numpy timedelta raises TypeError.
    if isinstance(days_treated, (int, float, np.integer, np.floating)):
        days = float(days_treated)
    else:
        days = days_treated / np.timedelta64(1, "D")

    if days <= 0:
        return None

    # Convert total_cost to float to handle Decimal from Snowflake
    return float(total_cost) / (days / 365)
|
||||
|
||||
|
||||
def calculate_treatment_duration(
    first_seen: pd.Timestamp,
    last_seen: pd.Timestamp,
) -> pd.Timedelta:
    """
    Return the time elapsed between the first and the last treatment date.

    Args:
        first_seen: Date of first treatment
        last_seen: Date of last treatment

    Returns:
        ``last_seen - first_seen`` as a timedelta (negative if the arguments
        are in reverse chronological order)
    """
    elapsed = last_seen - first_seen
    return elapsed
|
||||
|
||||
|
||||
def calculate_pathway_proportion(value: int, parent_value: int) -> float:
    """
    Return ``value`` as a fraction of ``parent_value`` for color scaling.

    Used to determine color intensity in the icicle chart based on what
    share of the parent category this pathway represents.

    Args:
        value: Patient count for this pathway
        parent_value: Total patient count for the parent category

    Returns:
        ``value / parent_value``, or 0.0 when ``parent_value`` is zero or
        negative (avoids division by zero)
    """
    return 0.0 if parent_value <= 0 else value / parent_value
|
||||
|
||||
|
||||
def aggregate_patient_costs(df: pd.DataFrame) -> pd.DataFrame:
    """
    Sum "Price Actual" for every patient (UPID).

    Only the UPID and Price Actual columns are used; any other columns in
    ``df`` are ignored.

    Args:
        df: DataFrame with UPID and Price Actual columns

    Returns:
        DataFrame indexed by UPID with a single "Total cost" column
    """
    return (
        df[["UPID", "Price Actual"]]
        .groupby("UPID")
        .sum()
        .rename(columns={"Price Actual": "Total cost"})
    )
|
||||
|
||||
|
||||
def aggregate_drug_frequencies(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build per-patient drug administration counts.

    Rows are grouped by UPID and each patient's drug names are reduced with
    count_consecutive_values(), giving one list of counts per patient.

    Args:
        df: DataFrame with UPID and Drug Name columns

    Returns:
        DataFrame indexed by UPID whose "Drug Name" column holds the list of
        counts for that patient
    """
    def _administration_counts(drug_names):
        return count_consecutive_values(list(drug_names))

    per_patient = df.groupby("UPID").agg({"Drug Name": _administration_counts})
    return per_patient.reset_index().set_index("UPID")
|
||||
|
||||
|
||||
def calculate_average_spacing_for_pathway(
    upid_drugs_df: pd.DataFrame,
    pathway_value: str,
) -> list[float]:
    """
    Average the dosing spacing of all patients on one treatment pathway.

    Rows whose "value" column equals ``pathway_value`` are selected and the
    mean of every ``spacing_*`` column is taken across them.

    Args:
        upid_drugs_df: DataFrame with a "value" column (pathway identifier)
            and spacing_* columns
        pathway_value: Pathway identifier string

    Returns:
        One mean per spacing_* column, rounded to whole days; a column with
        no usable data yields 0.0. Empty list if no row matches the pathway.
    """
    spacing_columns = [
        name for name in upid_drugs_df.columns if name.startswith("spacing_")
    ]
    matching_rows = upid_drugs_df[upid_drugs_df["value"] == pathway_value]
    if len(matching_rows) == 0:
        return []

    rounded_means = []
    for column_mean in matching_rows[spacing_columns].mean().tolist():
        if pd.notna(column_mean):
            rounded_means.append(round(column_mean, 0))
        else:
            rounded_means.append(0.0)
    return rounded_means
|
||||
|
||||
|
||||
def format_treatment_statistics(
    drug_names: list[str],
    average_administered: list[float],
    average_spacing: list[float],
    average_cost: list[float],
) -> str:
    """
    Build the HTML hover text describing each drug in a treatment sequence.

    For every drug the text states the average number of administrations,
    the dosing interval in weeks (spacing in days / 7) and the total
    treatment length in weeks (interval x administrations). Statistics lists
    shorter than ``drug_names`` are padded with 0.

    Args:
        drug_names: List of drug names in treatment sequence
        average_administered: Average number of administrations per drug
        average_spacing: Average days between doses per drug
        average_cost: Average cost per drug (accepted but not used in the
            generated text)

    Returns:
        HTML-formatted string for chart hover text; empty string when
        ``drug_names`` is empty
    """
    def _value_at(stats, position):
        # Missing trailing statistics are treated as 0.
        return stats[position] if position < len(stats) else 0

    fragments = []
    for position, drug_name in enumerate(drug_names):
        admin_count = _value_at(average_administered, position)
        spacing_days = _value_at(average_spacing, position)

        # Convert to weeks
        if spacing_days > 0:
            spacing_weeks = spacing_days / 7
        else:
            spacing_weeks = 0

        if admin_count > 0:
            total_weeks = spacing_weeks * admin_count
        else:
            total_weeks = 0

        fragments.append(
            f"<br><b>{drug_name}</b><br>On average given "
            f"{round(admin_count, 1)} times with a "
            f"{round(spacing_weeks, 1)} weekly interval ("
            f"{round(total_weeks, 0)} weeks total treatment length)"
        )

    return "".join(fragments)
|
||||
|
||||
|
||||
def remove_nan_values(values: list) -> list:
    """
    Drop every element whose string form is "nan" (case-insensitive).

    Used to clean up aggregated statistics. Because the comparison is made on
    ``str(x)``, both the text 'nan' / 'NaN' and float NaN values are removed.

    Args:
        values: List potentially containing NaN entries

    Returns:
        New list with the remaining elements in their original order
    """
    kept = []
    for item in values:
        if str(item).lower() == "nan":
            continue
        kept.append(item)
    return kept
|
||||
@@ -0,0 +1,27 @@
|
||||
# CLI Package
|
||||
|
||||
Command-line interface for pathway data refresh operations.
|
||||
|
||||
## refresh_pathways.py
|
||||
|
||||
Main CLI module for refreshing pre-computed pathway data from Snowflake to SQLite.
|
||||
|
||||
**Key Functions:**
|
||||
- `refresh_pathways()` — Orchestrates full pipeline: fetch from Snowflake, transform via data_processing/transforms.py, generate pathway charts, insert to SQLite
|
||||
- `insert_pathway_records()` — Bulk inserts using parameterized queries with `INSERT OR REPLACE` (handles overwrites via UNIQUE constraint)
|
||||
- `log_refresh_start()`, `log_refresh_complete()`, `log_refresh_failed()` — Tracks refresh status in pathway_refresh_log table
|
||||
- `get_default_filters()` — Loads available trusts, drugs, directories from CSV files
|
||||
|
||||
**CLI Arguments:**
|
||||
- `--chart-type [all|directory|indication]` — Which pathway types to refresh (default: all)
|
||||
- `--dry-run` — Test without database changes
|
||||
- `--minimum-patients N` — Pathway nodes with <N patients filtered out (default: 5)
|
||||
- `-v, --verbose` — Enable debug logging
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
python -m cli.refresh_pathways --chart-type all
|
||||
python -m cli.refresh_pathways --chart-type indication --dry-run -v
|
||||
```
|
||||
|
||||
**Note:** Module uses sys.path bootstrap at top to enable `python -m cli.refresh_pathways` from project root.
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
CLI commands for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Available commands:
|
||||
python -m cli.refresh_pathways - Refresh pathway data from Snowflake
|
||||
"""
|
||||
@@ -0,0 +1,662 @@
|
||||
"""
|
||||
CLI command for refreshing pathway data from Snowflake.
|
||||
|
||||
This command fetches activity data from Snowflake, processes it through the
|
||||
pathway pipeline for all 6 date filter combinations, and stores the results
|
||||
in the SQLite pathway_nodes table. Supports two chart types:
|
||||
- "directory": Trust → Directory → Drug → Pathway (default)
|
||||
- "indication": Trust → Search_Term → Drug → Pathway (requires GP diagnosis lookup)
|
||||
|
||||
Usage:
|
||||
python -m cli.refresh_pathways
|
||||
python -m cli.refresh_pathways --minimum-patients 10
|
||||
python -m cli.refresh_pathways --provider-codes RGT,RM1
|
||||
python -m cli.refresh_pathways --chart-type all
|
||||
python -m cli.refresh_pathways --chart-type directory
|
||||
python -m cli.refresh_pathways --dry-run
|
||||
|
||||
Run `python -m cli.refresh_pathways --help` for full options.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Ensure src/ is on sys.path when run as `python -m cli.refresh_pathways`
|
||||
_src_dir = str(Path(__file__).resolve().parent.parent)
|
||||
if _src_dir not in sys.path:
|
||||
sys.path.insert(0, _src_dir)
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger, setup_logging
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
from data_processing.schema import (
|
||||
clear_pathway_nodes,
|
||||
get_pathway_table_counts,
|
||||
verify_pathway_tables_exist,
|
||||
create_pathway_tables,
|
||||
)
|
||||
from data_processing.pathway_pipeline import (
|
||||
ChartType,
|
||||
DATE_FILTER_CONFIGS,
|
||||
fetch_and_transform_data,
|
||||
process_all_date_filters,
|
||||
process_pathway_for_date_filter,
|
||||
process_indication_pathway_for_date_filter,
|
||||
extract_denormalized_fields,
|
||||
extract_indication_fields,
|
||||
convert_to_records,
|
||||
)
|
||||
from data_processing.diagnosis_lookup import (
|
||||
assign_drug_indications,
|
||||
get_patient_indication_groups,
|
||||
load_drug_indication_mapping,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_default_filters(paths: PathConfig) -> tuple[list[str], list[str], list[str]]:
    """
    Load default filter values from reference files.

    Each reference CSV is optional and loading is best-effort: a missing
    file, or any error while reading it, leaves that filter as an empty list
    (errors are logged as warnings).

    Args:
        paths: PathConfig supplying default_trusts_csv, include_csv and
            directory_list_csv

    Returns:
        Tuple of (trust_filter, drug_filter, directory_filter)
    """
    import pandas as pd

    def _load(csv_path, label, select):
        # Read one reference CSV and reduce it to a list via `select`.
        values = []
        if csv_path.exists():
            try:
                values = select(pd.read_csv(csv_path))
                logger.info(f"Loaded {len(values)} default {label}")
            except Exception as e:
                logger.warning(f"Could not load default {label}: {e}")
        return values

    def _trust_names(frame):
        # Use the "Name" column which contains trust names;
        # fall back to the first column if there is no Name column.
        column = frame['Name'] if 'Name' in frame.columns else frame.iloc[:, 0]
        return column.dropna().tolist()

    def _included_drugs(frame):
        # Keep only rows flagged Include=1; without an Include column the
        # first column is assumed to contain drug names.
        if 'Include' in frame.columns:
            frame = frame[frame['Include'] == 1]
        return frame.iloc[:, 0].dropna().tolist()

    def _first_column(frame):
        # Assume first column contains directory names
        return frame.iloc[:, 0].dropna().tolist()

    trust_filter = _load(paths.default_trusts_csv, "trusts", _trust_names)
    drug_filter = _load(paths.include_csv, "drugs", _included_drugs)
    directory_filter = _load(paths.directory_list_csv, "directories", _first_column)

    return trust_filter, drug_filter, directory_filter
|
||||
|
||||
|
||||
def insert_pathway_records(
    conn: sqlite3.Connection,
    records: list[dict],
) -> int:
    """
    Write pathway records to the pathway_nodes table.

    Rows are written with INSERT OR REPLACE in a single executemany() call.
    Keys missing from a record are stored as NULL and keys that are not
    pathway_nodes columns are ignored. This function does not commit.

    Args:
        conn: SQLite connection
        records: List of record dicts from convert_to_records()

    Returns:
        The cursor's rowcount after the bulk insert (0 for an empty list)
    """
    if not records:
        return 0

    # Column order matching pathway_nodes schema (includes chart_type)
    columns = [
        'date_filter_id', 'chart_type', 'parents', 'ids', 'labels', 'level',
        'value', 'cost', 'costpp', 'cost_pp_pa', 'colour',
        'first_seen', 'last_seen', 'first_seen_parent', 'last_seen_parent',
        'average_spacing', 'average_administered', 'avg_days',
        'trust_name', 'directory', 'drug_sequence', 'data_refresh_id'
    ]

    column_names = ', '.join(columns)
    placeholders = ', '.join('?' * len(columns))

    insert_sql = f"""
        INSERT OR REPLACE INTO pathway_nodes ({column_names})
        VALUES ({placeholders})
    """

    # One positional tuple per record, in column order
    rows = [
        tuple(record.get(column) for column in columns)
        for record in records
    ]

    return conn.executemany(insert_sql, rows).rowcount
|
||||
|
||||
|
||||
def log_refresh_start(
    conn: sqlite3.Connection,
    refresh_id: str,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
) -> None:
    """
    Record the start of a refresh run in pathway_refresh_log.

    Inserts a row with status 'running' and the current local time (ISO
    format) as started_at, then commits.

    Args:
        conn: SQLite connection
        refresh_id: Identifier of this refresh run
        date_from: Lower bound of the Snowflake query window, if any
        date_to: Upper bound of the Snowflake query window, if any
    """
    started_at = datetime.now().isoformat()
    parameters = (refresh_id, started_at, date_from, date_to)
    conn.execute("""
        INSERT INTO pathway_refresh_log
        (refresh_id, started_at, status, snowflake_query_date_from, snowflake_query_date_to)
        VALUES (?, ?, 'running', ?, ?)
    """, parameters)
    conn.commit()
|
||||
|
||||
|
||||
def log_refresh_complete(
    conn: sqlite3.Connection,
    refresh_id: str,
    record_count: int,
    date_filter_counts: dict[str, int],
    duration_seconds: float,
    source_row_count: Optional[int] = None,
) -> None:
    """
    Mark a refresh run as successfully completed in pathway_refresh_log.

    Updates the row identified by ``refresh_id``: sets status 'completed',
    the completion time (current local time, ISO format) and the run
    statistics, then commits.

    Args:
        conn: SQLite connection
        refresh_id: Identifier of the refresh run to update
        record_count: Total number of records produced
        date_filter_counts: Record count per date filter; stored as JSON text
        duration_seconds: Processing time in seconds
        source_row_count: Number of source rows fetched, if known
    """
    completed_at = datetime.now().isoformat()
    counts_json = json.dumps(date_filter_counts)
    parameters = (
        completed_at,
        record_count,
        counts_json,
        duration_seconds,
        source_row_count,
        refresh_id,
    )
    conn.execute("""
        UPDATE pathway_refresh_log
        SET completed_at = ?,
            status = 'completed',
            record_count = ?,
            date_filter_counts = ?,
            processing_duration_seconds = ?,
            source_row_count = ?
        WHERE refresh_id = ?
    """, parameters)
    conn.commit()
|
||||
|
||||
|
||||
def log_refresh_failed(
    conn: sqlite3.Connection,
    refresh_id: str,
    error_message: str,
    duration_seconds: float,
) -> None:
    """
    Mark a refresh run as failed in pathway_refresh_log.

    Updates the row identified by ``refresh_id``: sets status 'failed', the
    completion time (current local time, ISO format), the error message and
    the elapsed time, then commits.

    Args:
        conn: SQLite connection
        refresh_id: Identifier of the refresh run to update
        error_message: Description of the failure
        duration_seconds: Processing time in seconds before the failure
    """
    failed_at = datetime.now().isoformat()
    parameters = (failed_at, error_message, duration_seconds, refresh_id)
    conn.execute("""
        UPDATE pathway_refresh_log
        SET completed_at = ?,
            status = 'failed',
            error_message = ?,
            processing_duration_seconds = ?
        WHERE refresh_id = ?
    """, parameters)
    conn.commit()
|
||||
|
||||
|
||||
def refresh_pathways(
|
||||
minimum_patients: int = 5,
|
||||
provider_codes: Optional[list[str]] = None,
|
||||
trust_filter: Optional[list[str]] = None,
|
||||
drug_filter: Optional[list[str]] = None,
|
||||
directory_filter: Optional[list[str]] = None,
|
||||
db_path: Optional[Path] = None,
|
||||
paths: Optional[PathConfig] = None,
|
||||
dry_run: bool = False,
|
||||
chart_type: str = "directory",
|
||||
) -> tuple[bool, str, dict]:
|
||||
"""
|
||||
Main refresh function that orchestrates the full pipeline.
|
||||
|
||||
Args:
|
||||
minimum_patients: Minimum patients to include a pathway
|
||||
provider_codes: List of provider codes to filter Snowflake query
|
||||
trust_filter: List of trust names to include in pathways
|
||||
drug_filter: List of drug names to include in pathways
|
||||
directory_filter: List of directories to include in pathways
|
||||
db_path: Path to SQLite database (uses default if None)
|
||||
paths: PathConfig for file paths
|
||||
dry_run: If True, don't actually insert records
|
||||
chart_type: Which chart type to process: "directory", "indication", or "all"
|
||||
|
||||
Returns:
|
||||
Tuple of (success: bool, message: str, stats: dict)
|
||||
"""
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
# Set up database connection
|
||||
if db_path:
|
||||
db_config = DatabaseConfig(db_path=db_path)
|
||||
else:
|
||||
db_config = DatabaseConfig(data_dir=paths.data_dir)
|
||||
|
||||
db_manager = DatabaseManager(db_config)
|
||||
|
||||
# Load default filters if not provided
|
||||
default_trusts, default_drugs, default_dirs = get_default_filters(paths)
|
||||
|
||||
if trust_filter is None:
|
||||
trust_filter = default_trusts
|
||||
if drug_filter is None:
|
||||
drug_filter = default_drugs
|
||||
if directory_filter is None:
|
||||
directory_filter = default_dirs
|
||||
|
||||
# Ensure we have some filters
|
||||
if not drug_filter:
|
||||
return False, "No drugs specified and could not load defaults", {}
|
||||
|
||||
# Determine which chart types to process
|
||||
if chart_type == "all":
|
||||
chart_types_to_process: list[ChartType] = ["directory", "indication"]
|
||||
else:
|
||||
chart_types_to_process = [chart_type] # type: ignore
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Pathway Data Refresh Starting")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Minimum patients: {minimum_patients}")
|
||||
logger.info(f"Trust filter: {len(trust_filter)} trusts")
|
||||
logger.info(f"Drug filter: {len(drug_filter)} drugs")
|
||||
logger.info(f"Directory filter: {len(directory_filter)} directories")
|
||||
logger.info(f"Provider codes: {provider_codes or 'All'}")
|
||||
logger.info(f"Chart type(s): {', '.join(chart_types_to_process)}")
|
||||
logger.info(f"Database: {db_manager.db_path}")
|
||||
logger.info(f"Dry run: {dry_run}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
start_time = time.time()
|
||||
refresh_id = str(uuid.uuid4())[:8]
|
||||
stats = {
|
||||
"refresh_id": refresh_id,
|
||||
"date_filter_counts": {},
|
||||
"total_records": 0,
|
||||
"snowflake_rows": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
# Verify database and tables
|
||||
with db_manager.get_connection() as conn:
|
||||
missing_tables = verify_pathway_tables_exist(conn)
|
||||
if missing_tables:
|
||||
logger.info(f"Creating missing tables: {missing_tables}")
|
||||
create_pathway_tables(conn)
|
||||
|
||||
# Log refresh start
|
||||
if not dry_run:
|
||||
log_refresh_start(conn, refresh_id)
|
||||
|
||||
# Step 1: Fetch data from Snowflake
|
||||
logger.info("")
|
||||
logger.info("Step 1/4: Fetching data from Snowflake...")
|
||||
df = fetch_and_transform_data(
|
||||
provider_codes=provider_codes,
|
||||
paths=paths,
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
msg = "No data returned from Snowflake"
|
||||
logger.error(msg)
|
||||
with db_manager.get_connection() as conn:
|
||||
log_refresh_failed(conn, refresh_id, msg, time.time() - start_time)
|
||||
return False, msg, stats
|
||||
|
||||
stats["snowflake_rows"] = len(df)
|
||||
logger.info(f"Fetched {len(df)} records from Snowflake")
|
||||
|
||||
# Step 2: Process all date filters for each chart type
|
||||
num_date_filters = len(DATE_FILTER_CONFIGS)
|
||||
num_chart_types = len(chart_types_to_process)
|
||||
total_datasets = num_date_filters * num_chart_types
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"Step 2/4: Processing pathway data for {total_datasets} datasets "
|
||||
f"({num_date_filters} date filters x {num_chart_types} chart types)...")
|
||||
|
||||
# Store results keyed by "date_filter_id:chart_type"
|
||||
results: dict[str, list[dict]] = {}
|
||||
|
||||
for current_chart_type in chart_types_to_process:
|
||||
logger.info("")
|
||||
logger.info(f"Processing chart type: {current_chart_type}")
|
||||
|
||||
if current_chart_type == "directory":
|
||||
# Use existing process_all_date_filters for directory charts
|
||||
dir_results = process_all_date_filters(
|
||||
df=df,
|
||||
trust_filter=trust_filter,
|
||||
drug_filter=drug_filter,
|
||||
directory_filter=directory_filter,
|
||||
minimum_patients=minimum_patients,
|
||||
refresh_id=refresh_id,
|
||||
paths=paths,
|
||||
)
|
||||
# Add results with chart_type suffix
|
||||
for filter_id, records in dir_results.items():
|
||||
# Records already have chart_type set by convert_to_records
|
||||
results[f"{filter_id}:directory"] = records
|
||||
|
||||
elif current_chart_type == "indication":
|
||||
# For indication charts, use drug-aware matching:
|
||||
# 1. Get ALL GP diagnosis matches per patient (with code_frequency)
|
||||
# 2. Cross-reference with drug-to-Search_Term mapping from DimSearchTerm.csv
|
||||
# 3. Assign each drug to its matched indication via modified UPIDs
|
||||
logger.info("Building drug-aware indication groups...")
|
||||
|
||||
# Check Snowflake availability
|
||||
from data_processing.snowflake_connector import get_connector, is_snowflake_available
|
||||
|
||||
if not is_snowflake_available():
|
||||
logger.warning("Snowflake not available - cannot process indication charts")
|
||||
for config in DATE_FILTER_CONFIGS:
|
||||
results[f"{config.id}:indication"] = []
|
||||
continue
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
connector = get_connector()
|
||||
|
||||
if 'PseudoNHSNoLinked' not in df.columns:
|
||||
logger.error("DataFrame missing 'PseudoNHSNoLinked' column - cannot lookup GP records")
|
||||
for config in DATE_FILTER_CONFIGS:
|
||||
results[f"{config.id}:indication"] = []
|
||||
continue
|
||||
|
||||
# Step 1: Load drug-to-Search_Term mapping from DimSearchTerm.csv
|
||||
_, search_term_to_fragments = load_drug_indication_mapping()
|
||||
logger.info(f"Loaded drug mapping: {len(search_term_to_fragments)} Search_Terms")
|
||||
|
||||
# Step 2: Get ALL GP diagnosis matches per patient (with code_frequency)
|
||||
patient_pseudonyms = df['PseudoNHSNoLinked'].dropna().unique().tolist()
|
||||
logger.info(f"Looking up GP diagnoses for {len(patient_pseudonyms)} unique patients...")
|
||||
|
||||
# Restrict GP codes to HCD data window (reduces noise from old diagnoses)
|
||||
earliest_hcd_date = df['Intervention Date'].min()
|
||||
if pd.notna(earliest_hcd_date):
|
||||
earliest_hcd_date_str = pd.Timestamp(earliest_hcd_date).strftime('%Y-%m-%d')
|
||||
logger.info(f"Restricting GP codes to HCD window: >= {earliest_hcd_date_str}")
|
||||
else:
|
||||
earliest_hcd_date_str = None
|
||||
|
||||
gp_matches_df = get_patient_indication_groups(
|
||||
patient_pseudonyms=patient_pseudonyms,
|
||||
connector=connector,
|
||||
batch_size=5000,
|
||||
earliest_hcd_date=earliest_hcd_date_str,
|
||||
)
|
||||
|
||||
# Step 3: Assign drug-aware indications using cross-referencing
|
||||
# This replaces the old per-patient approach with per-drug matching
|
||||
modified_df, indication_df = assign_drug_indications(
|
||||
df=df,
|
||||
gp_matches_df=gp_matches_df,
|
||||
search_term_to_fragments=search_term_to_fragments,
|
||||
)
|
||||
|
||||
logger.info(f"Drug-aware indication matching complete. "
|
||||
f"Modified UPIDs: {modified_df['UPID'].nunique()}, "
|
||||
f"Indication groups: {len(indication_df)}")
|
||||
|
||||
if indication_df.empty:
|
||||
logger.warning("Empty indication_df - skipping indication charts")
|
||||
for config in DATE_FILTER_CONFIGS:
|
||||
results[f"{config.id}:indication"] = []
|
||||
else:
|
||||
# Process each date filter with drug-aware indication grouping
|
||||
# Use modified_df (with indication-aware UPIDs) instead of original df
|
||||
for config in DATE_FILTER_CONFIGS:
|
||||
logger.info(f"Processing indication pathway for {config.id}")
|
||||
|
||||
ice_df = process_indication_pathway_for_date_filter(
|
||||
df=modified_df,
|
||||
indication_df=indication_df,
|
||||
config=config,
|
||||
trust_filter=trust_filter,
|
||||
drug_filter=drug_filter,
|
||||
directory_filter=directory_filter,
|
||||
minimum_patients=minimum_patients,
|
||||
paths=paths,
|
||||
)
|
||||
|
||||
if ice_df is None:
|
||||
logger.warning(f"No indication pathway data for {config.id}")
|
||||
results[f"{config.id}:indication"] = []
|
||||
continue
|
||||
|
||||
# Extract denormalized fields (using indication variant)
|
||||
ice_df = extract_indication_fields(ice_df)
|
||||
|
||||
# Convert to records with chart_type="indication"
|
||||
records = convert_to_records(ice_df, config.id, refresh_id, chart_type="indication")
|
||||
results[f"{config.id}:indication"] = records
|
||||
|
||||
logger.info(f"Completed {config.id}:indication: {len(records)} nodes")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing indication charts: {e}")
|
||||
logger.exception(e)
|
||||
for config in DATE_FILTER_CONFIGS:
|
||||
results[f"{config.id}:indication"] = []
|
||||
|
||||
# Count records per filter and chart type
|
||||
stats["chart_type_counts"] = {}
|
||||
for key, records in results.items():
|
||||
stats["date_filter_counts"][key] = len(records)
|
||||
stats["total_records"] += len(records)
|
||||
# Also track by chart type
|
||||
_, ct = key.split(":")
|
||||
stats["chart_type_counts"][ct] = stats["chart_type_counts"].get(ct, 0) + len(records)
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"Processed {stats['total_records']} total pathway nodes")
|
||||
for chart_type_name, count in stats.get("chart_type_counts", {}).items():
|
||||
logger.info(f" {chart_type_name}: {count} nodes total")
|
||||
for key, count in sorted(stats["date_filter_counts"].items()):
|
||||
if count > 0:
|
||||
logger.info(f" {key}: {count} nodes")
|
||||
|
||||
if dry_run:
|
||||
logger.info("")
|
||||
logger.info("DRY RUN - Skipping database insertion")
|
||||
elapsed = time.time() - start_time
|
||||
return True, f"Dry run complete: {stats['total_records']} records would be inserted", stats
|
||||
|
||||
# Step 3: Clear existing data and insert new records
|
||||
logger.info("")
|
||||
logger.info("Step 3/4: Clearing existing pathway data and inserting new records...")
|
||||
|
||||
with db_manager.get_transaction() as conn:
|
||||
# Clear all existing pathway nodes
|
||||
deleted = clear_pathway_nodes(conn)
|
||||
logger.info(f"Cleared {deleted} existing pathway nodes")
|
||||
|
||||
# Insert new records for each date filter + chart type combination
|
||||
total_inserted = 0
|
||||
for key, records in results.items():
|
||||
if records:
|
||||
inserted = insert_pathway_records(conn, records)
|
||||
total_inserted += len(records)
|
||||
logger.info(f" Inserted {len(records)} records for {key}")
|
||||
|
||||
# Step 4: Log completion
|
||||
logger.info("")
|
||||
logger.info("Step 4/4: Logging refresh completion...")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
with db_manager.get_connection() as conn:
|
||||
log_refresh_complete(
|
||||
conn=conn,
|
||||
refresh_id=refresh_id,
|
||||
record_count=stats["total_records"],
|
||||
date_filter_counts=stats["date_filter_counts"],
|
||||
duration_seconds=elapsed,
|
||||
source_row_count=stats.get("snowflake_rows"),
|
||||
)
|
||||
|
||||
# Verify final counts
|
||||
counts = get_pathway_table_counts(conn)
|
||||
logger.info(f"Final table counts: {counts}")
|
||||
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Refresh completed successfully in {elapsed:.1f} seconds")
|
||||
logger.info(f"Total records: {stats['total_records']}")
|
||||
logger.info(f"Refresh ID: {refresh_id}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
return True, f"Refresh complete: {stats['total_records']} records in {elapsed:.1f}s", stats
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - start_time
|
||||
error_msg = f"Refresh failed: {e}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
|
||||
try:
|
||||
with db_manager.get_connection() as conn:
|
||||
log_refresh_failed(conn, refresh_id, str(e), elapsed)
|
||||
except Exception:
|
||||
pass # Don't fail the error handling
|
||||
|
||||
return False, error_msg, stats
|
||||
|
||||
|
||||
def main() -> int:
    """CLI entry point.

    Parses command-line arguments, configures logging, runs
    ``refresh_pathways`` with the parsed options and prints the outcome.

    Returns:
        Process exit code: 0 when the refresh reports success, 1 otherwise.
        Invalid arguments terminate via ``parser.error`` (argparse exits
        with status 2).
    """
    parser = argparse.ArgumentParser(
        description="Refresh pathway data from Snowflake",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Basic refresh with defaults (directory chart only)
    python -m cli.refresh_pathways

    # Refresh both chart types (directory and indication)
    python -m cli.refresh_pathways --chart-type all

    # Refresh only indication-based charts
    python -m cli.refresh_pathways --chart-type indication

    # Refresh with custom minimum patients
    python -m cli.refresh_pathways --minimum-patients 10

    # Refresh specific providers only
    python -m cli.refresh_pathways --provider-codes RGT,RM1

    # Dry run to see what would be processed
    python -m cli.refresh_pathways --dry-run

    # Verbose output
    python -m cli.refresh_pathways --verbose
"""
    )

    parser.add_argument(
        "--minimum-patients",
        type=int,
        default=5,
        help="Minimum patients to include a pathway (default: 5)"
    )

    parser.add_argument(
        "--provider-codes",
        type=str,
        default=None,
        help="Comma-separated list of provider codes to filter (default: all)"
    )

    parser.add_argument(
        "--db-path",
        type=str,
        default=None,
        help="Path to SQLite database (default: data/pathways.db)"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Process data but don't insert into database"
    )

    parser.add_argument(
        "--chart-type",
        type=str,
        choices=["directory", "indication", "all"],
        default="directory",
        help="Chart type to process: 'directory' (default), 'indication', or 'all'"
    )

    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )

    args = parser.parse_args()

    # Configure logging
    import logging
    log_level = logging.DEBUG if args.verbose else logging.INFO
    setup_logging(level=log_level)

    # Parse provider codes. Empty entries produced by stray or trailing commas
    # (e.g. "RGT,,RM1,") are dropped so "" is never passed on as a provider code.
    provider_codes = None
    if args.provider_codes:
        stripped = (code.strip() for code in args.provider_codes.split(","))
        provider_codes = [code for code in stripped if code]
        if not provider_codes:
            # A value such as "," names no provider; fail loudly rather than
            # silently running with an empty or unfiltered provider list.
            parser.error("--provider-codes did not contain any provider codes")

    # Parse db path
    db_path = Path(args.db_path) if args.db_path else None

    # Run the refresh
    success, message, stats = refresh_pathways(
        minimum_patients=args.minimum_patients,
        provider_codes=provider_codes,
        db_path=db_path,
        dry_run=args.dry_run,
        chart_type=args.chart_type,
    )

    if success:
        print(f"\n[OK] {message}")
        return 0

    print(f"\n[FAILED] {message}", file=sys.stderr)
    return 1
|
||||
|
||||
|
||||
# Script entry: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
@@ -0,0 +1,23 @@
|
||||
# Config Package
|
||||
|
||||
Snowflake configuration management with dataclass hierarchy and TOML loading.
|
||||
|
||||
## Modules
|
||||
|
||||
**__init__.py** - Configuration dataclass hierarchy:
|
||||
- `ConnectionConfig`, `TimeoutConfig`, `CacheConfig`, `QueryConfig` — Settings containers
|
||||
- `TableReference` — Snowflake object reference with `fully_qualified_name` property
|
||||
- `TablesConfig` — Common table references (activity, patient, medication, organization)
|
||||
- `SnowflakeConfig` — Root config aggregating all above + `validate()` and `is_configured` property
|
||||
- `load_snowflake_config(config_path=None)` — Load from TOML, default `config/snowflake.toml`
|
||||
- `get_snowflake_config()` — Cached singleton access
|
||||
- `reload_snowflake_config()` — Force reload from disk
|
||||
|
||||
**snowflake.toml** — Snowflake connection settings (co-located with loader)
|
||||
|
||||
## Key Details
|
||||
|
||||
- Uses `tomllib` (Python 3.11+) with `tomli` fallback for 3.10
|
||||
- Missing config file returns default SnowflakeConfig (no error)
|
||||
- All dataclasses have sensible defaults (DATA_HUB.DWH, 24h cache TTL, etc.)
|
||||
- The loaded config is cached at module level (not re-read on each access); call `reload_snowflake_config()` to refresh it from disk
|
||||
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Configuration module for Patient Pathway Analysis.
|
||||
|
||||
This module provides access to configuration settings loaded from TOML files.
|
||||
Primary configuration file: config/snowflake.toml
|
||||
|
||||
Usage:
|
||||
from config import load_snowflake_config, SnowflakeConfig
|
||||
|
||||
config = load_snowflake_config()
|
||||
print(config.connection.account)
|
||||
print(config.cache.ttl_seconds)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
try:
|
||||
import tomllib
|
||||
except ModuleNotFoundError:
|
||||
import tomli as tomllib
|
||||
|
||||
|
||||
@dataclass
class ConnectionConfig:
    """Connection parameters for a Snowflake session (all fields have defaults)."""

    account: str = ""  # Snowflake account identifier; empty means "not configured"
    warehouse: str = "ANALYST_WH"  # default warehouse used for queries
    database: str = "DATA_HUB"  # default database
    schema: str = "DWH"  # default schema
    authenticator: str = "externalbrowser"  # authentication method
    user: str = ""  # user principal; may be left empty
    role: str = ""  # role to use; may be left empty
|
||||
|
||||
|
||||
@dataclass
class TimeoutConfig:
    """Timeouts, in seconds, applied to Snowflake operations."""

    connection_timeout: int = 30  # network/connection timeout
    query_timeout: int = 300  # query execution timeout
    login_timeout: int = 120  # login (SSO browser auth) timeout
|
||||
|
||||
|
||||
@dataclass
class CacheConfig:
    """Settings controlling the cache of Snowflake query results."""

    enabled: bool = True  # whether result caching is on
    directory: str = "data/cache"  # cache location
    ttl_seconds: int = 86400  # 24 hours
    ttl_current_data_seconds: int = 3600  # 1 hour
    max_size_mb: int = 500  # cache size limit in MB
|
||||
|
||||
|
||||
@dataclass
class TableReference:
    """Identifies a Snowflake table or view by database, schema and object name."""

    database: str = ""
    schema: str = ""
    view: str = ""
    table: str = ""
    key_columns: list = field(default_factory=list)

    @property
    def fully_qualified_name(self) -> str:
        """Return the double-quoted, dot-separated name of the table/view.

        ``table`` takes precedence over ``view``. The result is ``""`` when
        neither is set. The database is only included when a schema is also
        present; without a schema only the bare object name is returned.
        """
        target = self.table or self.view
        if not target:
            return ""

        if not self.schema:
            segments = [target]
        elif self.database:
            segments = [self.database, self.schema, target]
        else:
            segments = [self.schema, target]

        return ".".join(f'"{segment}"' for segment in segments)
|
||||
|
||||
|
||||
@dataclass
class TablesConfig:
    """Named references to the commonly used tables/views.

    Each field defaults to an empty ``TableReference``.
    """

    activity: TableReference = field(default_factory=TableReference)
    patient: TableReference = field(default_factory=TableReference)
    medication: TableReference = field(default_factory=TableReference)
    organization: TableReference = field(default_factory=TableReference)
|
||||
|
||||
|
||||
@dataclass
class QueryConfig:
    """Defaults governing how queries are executed."""

    quote_identifiers: bool = True  # whether identifiers are double-quoted
    test_limit: int = 20  # row limit for test queries
    max_rows: int = 100000  # maximum rows to fetch in a single query
    chunk_size: int = 10000  # chunk size for large result sets
|
||||
|
||||
|
||||
@dataclass
class SnowflakeConfig:
    """Root configuration object grouping every Snowflake settings section."""

    connection: ConnectionConfig = field(default_factory=ConnectionConfig)
    timeouts: TimeoutConfig = field(default_factory=TimeoutConfig)
    cache: CacheConfig = field(default_factory=CacheConfig)
    tables: TablesConfig = field(default_factory=TablesConfig)
    query: QueryConfig = field(default_factory=QueryConfig)

    def validate(self) -> list[str]:
        """
        Check the configuration for problems.

        Returns:
            Error messages, in a fixed order, one per failed check
            (an empty list means the configuration is valid).
        """
        conn = self.connection
        allowed_authenticators = ("externalbrowser", "snowflake", "oauth", "okta")

        # Each entry pairs "this check failed" with the message to report.
        checks = (
            (not conn.account,
             "Snowflake account is not configured (connection.account)"),
            (not conn.warehouse,
             "Snowflake warehouse is not configured (connection.warehouse)"),
            (conn.authenticator not in allowed_authenticators,
             f"Invalid authenticator: {conn.authenticator}"),
            (self.cache.ttl_seconds < 0,
             "Cache TTL must be non-negative"),
            (self.query.max_rows < 1,
             "max_rows must be at least 1"),
        )
        return [message for failed, message in checks if failed]

    @property
    def is_configured(self) -> bool:
        """Whether the minimum required setting (the account) has been provided."""
        return True if self.connection.account else False
|
||||
|
||||
|
||||
def _parse_table_reference(data: dict) -> TableReference:
    """Build a TableReference from one ``[tables.<name>]`` mapping of TOML data.

    Missing string keys default to ``""`` and a missing ``key_columns``
    defaults to an empty list.
    """
    text_keys = ("database", "schema", "view", "table")
    kwargs = {key: data.get(key, "") for key in text_keys}
    kwargs["key_columns"] = data.get("key_columns", [])
    return TableReference(**kwargs)
|
||||
|
||||
|
||||
def load_snowflake_config(config_path: Optional[Path] = None) -> SnowflakeConfig:
    """
    Load Snowflake configuration from a TOML file.

    Args:
        config_path: Path to the TOML config file (a plain string path is
            also accepted). Defaults to ``snowflake.toml`` in the same
            directory as this module.

    Returns:
        SnowflakeConfig dataclass with all settings. Sections or keys absent
        from the file fall back to built-in defaults. If the file does not
        exist, a default ``SnowflakeConfig()`` is returned rather than an
        error being raised.

    Raises:
        tomllib.TOMLDecodeError: If the file exists but the TOML is invalid.
    """
    if config_path is None:
        # Default to snowflake.toml located next to this module
        config_path = Path(__file__).parent / "snowflake.toml"
    else:
        # Accept str as well as Path; a bare str has no .exists()
        config_path = Path(config_path)

    if not config_path.exists():
        # Return default config if file doesn't exist
        return SnowflakeConfig()

    with open(config_path, "rb") as f:
        data = tomllib.load(f)

    # Parse connection settings
    conn_data = data.get("connection", {})
    connection = ConnectionConfig(
        account=conn_data.get("account", ""),
        warehouse=conn_data.get("warehouse", "ANALYST_WH"),
        database=conn_data.get("database", "DATA_HUB"),
        schema=conn_data.get("schema", "DWH"),
        authenticator=conn_data.get("authenticator", "externalbrowser"),
        user=conn_data.get("user", ""),
        role=conn_data.get("role", ""),
    )

    # Parse timeout settings
    timeout_data = data.get("timeouts", {})
    timeouts = TimeoutConfig(
        # NOTE(review): this fallback (600) differs from TimeoutConfig's own
        # default for connection_timeout (30) — confirm which is intended.
        connection_timeout=timeout_data.get("connection_timeout", 600),
        query_timeout=timeout_data.get("query_timeout", 300),
        login_timeout=timeout_data.get("login_timeout", 120),
    )

    # Parse cache settings
    cache_data = data.get("cache", {})
    cache = CacheConfig(
        enabled=cache_data.get("enabled", True),
        directory=cache_data.get("directory", "data/cache"),
        ttl_seconds=cache_data.get("ttl_seconds", 86400),
        ttl_current_data_seconds=cache_data.get("ttl_current_data_seconds", 3600),
        max_size_mb=cache_data.get("max_size_mb", 500),
    )

    # Parse table references
    tables_data = data.get("tables", {})
    tables = TablesConfig(
        activity=_parse_table_reference(tables_data.get("activity", {})),
        patient=_parse_table_reference(tables_data.get("patient", {})),
        medication=_parse_table_reference(tables_data.get("medication", {})),
        organization=_parse_table_reference(tables_data.get("organization", {})),
    )

    # Parse query settings
    query_data = data.get("query", {})
    query = QueryConfig(
        quote_identifiers=query_data.get("quote_identifiers", True),
        test_limit=query_data.get("test_limit", 20),
        max_rows=query_data.get("max_rows", 100000),
        chunk_size=query_data.get("chunk_size", 10000),
    )

    return SnowflakeConfig(
        connection=connection,
        timeouts=timeouts,
        cache=cache,
        tables=tables,
        query=query,
    )
|
||||
|
||||
|
||||
# Module-level cached config. Stays None until get_snowflake_config() or
# reload_snowflake_config() stores a loaded SnowflakeConfig here.
_cached_config: Optional[SnowflakeConfig] = None
|
||||
|
||||
|
||||
def get_snowflake_config() -> SnowflakeConfig:
    """
    Return the shared Snowflake configuration, loading it on first use.

    The first call loads the config via load_snowflake_config() and stores it
    in the module-level cache; later calls return that same object.

    Returns:
        SnowflakeConfig dataclass with all settings.
    """
    global _cached_config
    if _cached_config is not None:
        return _cached_config

    _cached_config = load_snowflake_config()
    return _cached_config
|
||||
|
||||
|
||||
def reload_snowflake_config() -> SnowflakeConfig:
    """
    Load the Snowflake configuration again and replace the cached copy.

    Returns:
        The freshly loaded SnowflakeConfig, which is also what subsequent
        get_snowflake_config() calls will return.
    """
    global _cached_config
    fresh = load_snowflake_config()
    _cached_config = fresh
    return fresh
|
||||
|
||||
|
||||
# Export public API: the config dataclasses plus the load / cached-get /
# reload functions defined in this module.
__all__ = [
    "SnowflakeConfig",
    "ConnectionConfig",
    "TimeoutConfig",
    "CacheConfig",
    "TableReference",
    "TablesConfig",
    "QueryConfig",
    "load_snowflake_config",
    "get_snowflake_config",
    "reload_snowflake_config",
]
|
||||
@@ -0,0 +1,129 @@
|
||||
# Snowflake Configuration for NHS Patient Pathway Analysis
|
||||
#
|
||||
# This file contains connection settings for the Snowflake data warehouse.
|
||||
# IMPORTANT: This file should NOT be committed to version control if it contains
|
||||
# sensitive information. However, with externalbrowser auth, no passwords are stored.
|
||||
#
|
||||
# For NHS SSO authentication, the 'externalbrowser' authenticator opens a browser
|
||||
# window for authentication via NHS identity management.
|
||||
|
||||
[connection]
|
||||
# Snowflake account identifier (e.g., "xy12345.uk-south.azure")
|
||||
# Ask your Snowflake administrator for the correct account name
|
||||
account = "ZK91403.uk-south.azure"
|
||||
|
||||
# Default warehouse to use for queries
|
||||
# Common options: ANALYST_WH, COMPUTE_WH
|
||||
warehouse = "WH__XSMALL"
|
||||
|
||||
# Default database for queries
|
||||
# DATA_HUB is the primary analyst-curated data warehouse
|
||||
database = "DATA_HUB"
|
||||
|
||||
# Default schema (optional, can be overridden per query)
|
||||
schema = "DWH"
|
||||
|
||||
# Authentication method
|
||||
# "externalbrowser" opens browser for NHS SSO (required for NHS environments)
|
||||
# Other options: "snowflake" (username/password), "oauth", "okta"
|
||||
authenticator = "externalbrowser"
|
||||
|
||||
# User principal (email address for externalbrowser auth)
|
||||
# Leave empty to use current Windows user or prompt
|
||||
user = "ANDREW.CHARLWOOD@NHS.NET"
|
||||
|
||||
# Role to use (optional, uses default role if empty)
|
||||
role = ""
|
||||
|
||||
[timeouts]
|
||||
# Network timeout in seconds (how long client waits for Snowflake response)
|
||||
# Must be high enough for GP record lookups which can take 30-60s per batch
|
||||
connection_timeout = 600
|
||||
|
||||
# Query execution timeout in seconds (for long-running queries)
|
||||
# Set to 0 for no timeout
|
||||
query_timeout = 300
|
||||
|
||||
# Login timeout in seconds (for SSO browser auth)
|
||||
login_timeout = 120
|
||||
|
||||
[cache]
|
||||
# Enable result caching
|
||||
enabled = true
|
||||
|
||||
# Cache directory (relative to project root or absolute path)
|
||||
# Defaults to data/cache/ if not specified
|
||||
directory = "data/cache"
|
||||
|
||||
# Time-to-live for cached results in seconds
|
||||
# 24 hours for historical data (86400 seconds)
|
||||
ttl_seconds = 86400
|
||||
|
||||
# TTL for data that includes today's date (shorter)
|
||||
ttl_current_data_seconds = 3600
|
||||
|
||||
# Maximum cache size in MB (oldest entries removed when exceeded)
|
||||
max_size_mb = 500
|
||||
|
||||
[databases]
|
||||
# Quick reference for database purposes (read-only documentation)
|
||||
# DATA_HUB = "Analyst-curated data warehouse - primary source for most queries"
|
||||
# PRIMARY_CARE = "Raw extracts from EMIS and TPP clinical systems"
|
||||
# NATIONAL = "NHS England national datasets (SUS, ECDS, MHSDS, etc.)"
|
||||
# FACTS_AND_DIMENSIONS_ALL_DATA = "External reference data (BNF, SNOMED, QOF clusters)"
|
||||
# REPORTING_DATASETS_ICB = "Reporting outputs and analyst workspaces"
|
||||
|
||||
# Tables commonly used for high-cost drug analysis
|
||||
[tables.activity]
|
||||
# Main activity data source (high-cost drug interventions)
|
||||
# Acute__Conmon__PatientLevelDrugs contains patient-level high-cost drug data
|
||||
database = "DATA_HUB"
|
||||
schema = "CDM"
|
||||
table = "Acute__Conmon__PatientLevelDrugs"
|
||||
key_columns = [
|
||||
"PseudoNHSNoLinked", # Pseudonymised NHS number for patient linking
|
||||
"ProviderCode", # NHS provider code (e.g., RM1, RGP)
|
||||
"LocalPatientID", # Local patient identifier within provider
|
||||
"InterventionDate", # Date of drug intervention
|
||||
"DrugName", # Drug name (raw, needs standardization)
|
||||
"DrugSNOMEDCode", # SNOMED code for drug
|
||||
"PriceActual", # Actual cost of intervention
|
||||
"TreatmentFunctionCode", # NHS treatment function code
|
||||
"TreatmentFunctionDesc", # Treatment function description
|
||||
"AdditionalDetail1", # Additional details (used for directory identification)
|
||||
]
|
||||
|
||||
[tables.patient]
|
||||
# Patient demographics
|
||||
database = "DATA_HUB"
|
||||
schema = "DWH"
|
||||
view = "DimPerson"
|
||||
key_columns = ["PatientPseudonym", "PersonKey", "CurrentGeneralPractice"]
|
||||
|
||||
[tables.medication]
|
||||
# Medication reference data
|
||||
database = "DATA_HUB"
|
||||
schema = "DWH"
|
||||
view = "DimMedicineAndDevice"
|
||||
key_columns = ["ProductSnomedCode", "TherapeuticMoietySnomedCode", "ProductDescription"]
|
||||
|
||||
[tables.organization]
|
||||
# NHS organizations and GP practices
|
||||
database = "DATA_HUB"
|
||||
schema = "DWH"
|
||||
view = "DimOrganisationAndSite"
|
||||
key_columns = ["SiteCode", "OrganisationName"]
|
||||
|
||||
[query]
|
||||
# Default query behaviors
|
||||
# Always double-quote identifiers for case-sensitivity
|
||||
quote_identifiers = true
|
||||
|
||||
# Default row limit for test queries
|
||||
test_limit = 20
|
||||
|
||||
# Maximum rows to fetch in a single query (prevents runaway queries)
|
||||
max_rows = 100000
|
||||
|
||||
# Chunk size for large result sets
|
||||
chunk_size = 10000
|
||||
@@ -0,0 +1,25 @@
|
||||
# core/ — Foundation Layer
|
||||
|
||||
Configuration, state models, and logging setup.
|
||||
|
||||
## Modules
|
||||
|
||||
**config.py** — `PathConfig` dataclass encapsulating all file paths (data dir, images, CSVs, fonts).
|
||||
- `validate()` method checks existence of required directories and files
|
||||
- `default_paths` module instance resolves from `Path.cwd()` (not package location)
|
||||
- Critical: CWD must be project root for relative paths to work
|
||||
|
||||
**models.py** — `AnalysisFilters` dataclass for UI filter state (dates, drugs, trusts, directories).
|
||||
|
||||
**logging_config.py** — Structured logging with file + console output.
|
||||
- `setup_logging()` initializes handlers
|
||||
- `get_logger(name)` returns configured logger
|
||||
|
||||
**__init__.py** — Re-exports `PathConfig`, `default_paths`, `AnalysisFilters`, `setup_logging`, and `get_logger` for easy importing.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from core import PathConfig, default_paths, AnalysisFilters
|
||||
default_paths.validate() # Verify config on startup
|
||||
```
|
||||
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Core module for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains configuration, models, and shared utilities used across the application.
|
||||
"""
|
||||
|
||||
from core.config import PathConfig, default_paths
|
||||
from core.models import AnalysisFilters
|
||||
from core.logging_config import setup_logging, get_logger
|
||||
|
||||
# Public API of the package: the names re-exported by the imports above.
__all__ = [
    "PathConfig",
    "default_paths",
    "AnalysisFilters",
    "setup_logging",
    "get_logger",
]
|
||||
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Configuration module for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains PathConfig dataclass for centralizing all file path references.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class PathConfig:
|
||||
"""
|
||||
Centralizes all file paths used across the application.
|
||||
|
||||
Provides a single source of truth for file locations, making it easier to:
|
||||
- Change the data directory location
|
||||
- Support different environments (development, production)
|
||||
- Validate that required files exist
|
||||
|
||||
Attributes:
|
||||
base_dir: Root directory of the application (defaults to current working directory)
|
||||
data_dir: Directory containing reference data files
|
||||
images_dir: Directory containing UI assets and fonts
|
||||
"""
|
||||
|
||||
base_dir: Path = field(default_factory=Path.cwd)
|
||||
_data_dir: Optional[Path] = field(default=None, repr=False)
|
||||
_images_dir: Optional[Path] = field(default=None, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set default subdirectories relative to base_dir if not provided."""
|
||||
if self._data_dir is None:
|
||||
self._data_dir = self.base_dir / "data"
|
||||
if self._images_dir is None:
|
||||
self._images_dir = self.base_dir / "images"
|
||||
|
||||
@property
|
||||
def data_dir(self) -> Path:
|
||||
"""Directory containing reference data files."""
|
||||
# _data_dir is always set after __post_init__
|
||||
assert self._data_dir is not None
|
||||
return self._data_dir
|
||||
|
||||
@property
|
||||
def images_dir(self) -> Path:
|
||||
"""Directory containing UI assets and fonts."""
|
||||
# _images_dir is always set after __post_init__
|
||||
assert self._images_dir is not None
|
||||
return self._images_dir
|
||||
|
||||
# Reference data files (read-only lookups)
|
||||
@property
|
||||
def drugnames_csv(self) -> Path:
|
||||
"""Drug name standardization mapping."""
|
||||
return self.data_dir / "drugnames.csv"
|
||||
|
||||
@property
|
||||
def directory_list_csv(self) -> Path:
|
||||
"""Medical specialties/directories list."""
|
||||
return self.data_dir / "directory_list.csv"
|
||||
|
||||
@property
|
||||
def treatment_function_codes_csv(self) -> Path:
|
||||
"""NHS treatment function code mappings."""
|
||||
return self.data_dir / "treatment_function_codes.csv"
|
||||
|
||||
@property
|
||||
def drug_directory_list_csv(self) -> Path:
|
||||
"""Valid drug-to-directory mappings (pipe-separated)."""
|
||||
return self.data_dir / "drug_directory_list.csv"
|
||||
|
||||
@property
|
||||
def org_codes_csv(self) -> Path:
|
||||
"""Provider code to organization name mapping."""
|
||||
return self.data_dir / "org_codes.csv"
|
||||
|
||||
@property
|
||||
def include_csv(self) -> Path:
|
||||
"""Drug filter list with default selections."""
|
||||
return self.data_dir / "include.csv"
|
||||
|
||||
@property
|
||||
def default_trusts_csv(self) -> Path:
|
||||
"""NHS Trust list for filter."""
|
||||
return self.data_dir / "defaultTrusts.csv"
|
||||
|
||||
# Output/diagnostic files
|
||||
@property
|
||||
def na_directory_rows_csv(self) -> Path:
|
||||
"""Exported rows with unresolved Directory for diagnostics."""
|
||||
return self.data_dir / "na_directory_rows.csv"
|
||||
|
||||
@property
|
||||
def ta_recommendations_xlsx(self) -> Path:
|
||||
"""NICE TA recommendations (downloaded from web)."""
|
||||
return self.data_dir / "ta-recommendations.xlsx"
|
||||
|
||||
# UI assets
|
||||
@property
|
||||
def font_medium(self) -> Path:
|
||||
"""AvenirLTStd-Medium font file."""
|
||||
return self.images_dir / "AvenirLTStd-Medium.ttf"
|
||||
|
||||
@property
|
||||
def font_roman(self) -> Path:
|
||||
"""AvenirLTStd-Roman font file."""
|
||||
return self.images_dir / "AvenirLTStd-Roman.ttf"
|
||||
|
||||
@property
|
||||
def logo_ico(self) -> Path:
|
||||
"""Application icon."""
|
||||
return self.images_dir / "logo.ico"
|
||||
|
||||
@property
|
||||
def logo_png(self) -> Path:
|
||||
"""Application logo."""
|
||||
return self.images_dir / "logo.png"
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""
|
||||
Validate that required files and directories exist.
|
||||
|
||||
Returns:
|
||||
List of error messages. Empty list means all validations passed.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# Check directories exist
|
||||
if not self.data_dir.exists():
|
||||
errors.append(f"Data directory not found: {self.data_dir}")
|
||||
if not self.images_dir.exists():
|
||||
errors.append(f"Images directory not found: {self.images_dir}")
|
||||
|
||||
# Check required reference files
|
||||
required_files = [
|
||||
(self.drugnames_csv, "Drug names mapping"),
|
||||
(self.directory_list_csv, "Directory list"),
|
||||
(self.treatment_function_codes_csv, "Treatment function codes"),
|
||||
(self.drug_directory_list_csv, "Drug-directory mapping"),
|
||||
(self.org_codes_csv, "Organization codes"),
|
||||
(self.include_csv, "Drug include list"),
|
||||
(self.default_trusts_csv, "Default trusts"),
|
||||
]
|
||||
|
||||
for file_path, description in required_files:
|
||||
if not file_path.exists():
|
||||
errors.append(f"{description} not found: {file_path}")
|
||||
|
||||
return errors
|
||||
|
||||
def validate_fonts(self) -> list[str]:
    """
    Check that the font files used in GUI mode are present.

    Returns:
        One message per missing font file; an empty list when both exist.
    """
    fonts = (
        ("Medium font", self.font_medium),
        ("Roman font", self.font_roman),
    )
    return [
        f"{label} not found: {font_path}"
        for label, font_path in fonts
        if not font_path.exists()
    ]
|
||||
|
||||
def as_legacy_paths(self) -> dict[str, str]:
    """
    Return paths as strings with './' prefix for backwards compatibility.

    Each reference-file path is made relative to ``base_dir`` and rendered
    with forward slashes, so the result is ``./<relative path>`` on every
    platform (a plain ``str()`` of the relative path would use backslashes
    on Windows and produce ``./data\\file.csv``).

    Returns:
        Dictionary mapping path names to legacy-format string paths.

    Raises:
        ValueError: If one of the paths is not located under ``base_dir``
            (raised by ``Path.relative_to``).
    """
    legacy_names = (
        "drugnames_csv",
        "directory_list_csv",
        "treatment_function_codes_csv",
        "drug_directory_list_csv",
        "org_codes_csv",
        "include_csv",
        "default_trusts_csv",
        "na_directory_rows_csv",
        "ta_recommendations_xlsx",
    )
    base = self.base_dir
    return {
        name: f"./{getattr(self, name).relative_to(base).as_posix()}"
        for name in legacy_names
    }
|
||||
|
||||
|
||||
# Default instance for application-wide use
|
||||
# Shared PathConfig, built with PathConfig's default arguments when this module is imported.
default_paths = PathConfig()
|
||||
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Logging configuration for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides structured logging setup with console and optional file handlers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Default log format: timestamp, level, logger name, message
DEFAULT_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
# strftime pattern applied to %(asctime)s, e.g. 2024-01-31 13:45:00
DEFAULT_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

# Simplified format for console output (used when redirecting to GUI):
# the message text only, with no timestamp, level or logger name
SIMPLE_FORMAT = "%(message)s"
|
||||
|
||||
|
||||
def setup_logging(
    level: int = logging.INFO,
    log_dir: Optional[Path] = None,
    console: bool = True,
    file_logging: bool = False,
    simple_console: bool = False,
) -> logging.Logger:
    """
    Configure application-wide logging on the "pathways" logger.

    The function may be called repeatedly: handlers installed by an earlier
    call are detached and closed before the new ones are added, so repeated
    initialisation neither duplicates output nor leaks open log files.

    Args:
        level: Logging level for the logger and its handlers (default: INFO)
        log_dir: Directory for log files (default: ./logs/). Created if missing.
        console: Whether to log to console/stdout (default: True)
        file_logging: Whether to log to a timestamped file (default: False)
        simple_console: Use simplified format for console (just message, no timestamp)

    Returns:
        The "pathways" logger, parent of all loggers handed out by get_logger()

    Raises:
        OSError: If the log directory or log file cannot be created.

    Usage:
        # Basic setup - console only
        logger = setup_logging()

        # With file logging
        logger = setup_logging(file_logging=True)

        # Debug mode
        logger = setup_logging(level=logging.DEBUG)

        # GUI mode - simple format for stdout capture
        logger = setup_logging(simple_console=True)
    """
    root_logger = logging.getLogger("pathways")

    # Detach AND close the previous handlers. Clearing the list alone would
    # leave the file of any earlier FileHandler open until interpreter exit.
    # Closing a StreamHandler does not close sys.stdout.
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    root_logger.setLevel(level)

    # Console handler
    if console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(level)

        if simple_console:
            console_format = logging.Formatter(SIMPLE_FORMAT)
        else:
            console_format = logging.Formatter(DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT)

        console_handler.setFormatter(console_format)
        root_logger.addHandler(console_handler)

    # File handler
    if file_logging:
        # Path() also accepts a plain string, so callers passing "logs" work too.
        log_dir = Path("./logs") if log_dir is None else Path(log_dir)
        log_dir.mkdir(parents=True, exist_ok=True)

        # One file per call, named after the time of initialisation.
        log_filename = f"pathways_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
        log_path = log_dir / log_filename

        file_handler = logging.FileHandler(log_path, encoding="utf-8")
        file_handler.setLevel(level)
        file_handler.setFormatter(
            logging.Formatter(DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT)
        )
        root_logger.addHandler(file_handler)

    return root_logger
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
    """
    Get a logger for a specific module.

    Names are placed under the "pathways" namespace. A name that already is
    "pathways" or starts with "pathways." is used unchanged, so the root
    application logger is not turned into "pathways.pathways".

    Args:
        name: Module name (typically __name__)

    Returns:
        The "pathways" logger itself or a logger below it

    Usage:
        from core.logging_config import get_logger
        logger = get_logger(__name__)
        logger.info("Processing started")
        logger.error("Something went wrong")
    """
    if name == "pathways" or name.startswith("pathways."):
        return logging.getLogger(name)
    return logging.getLogger(f"pathways.{name}")
|
||||
|
||||
|
||||
# Module-level loggers for common components. get_logger() prefixes each
# name, so these are "pathways.data", "pathways.dashboard" and "pathways.gui".
data_logger = get_logger("data")
dashboard_logger = get_logger("dashboard")
gui_logger = get_logger("gui")
|
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Data models for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains dataclasses for encapsulating application state and filter parameters.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class AnalysisFilters:
    """
    Bundle of every filter setting used by the analysis pipeline.

    Stands in for the individual parameters passed to generate_graph() and
    the global state kept by the GUI, giving:
    - Type safety for filter values
    - Validation of filter combinations
    - Easy serialization for caching/persistence
    - Clear interface between GUI and analysis engine

    Attributes:
        start_date: Patient initiated start date (treatment pathway start)
        end_date: Patient initiated end date (treatment pathway start cutoff)
        last_seen_date: Minimum last seen date (filters out patients not seen recently)
        trusts: List of NHS Trust names to include (empty = all)
        drugs: List of drug names to include (empty = all)
        directories: List of medical directories/specialties to include (empty = all)
        custom_title: Optional custom title for the graph (blank = auto-generated)
        minimum_patients: Minimum number of patients for a pathway to be included
        output_dir: Directory where output files should be saved
    """

    start_date: date
    end_date: date
    last_seen_date: date
    trusts: list[str] = field(default_factory=list)
    drugs: list[str] = field(default_factory=list)
    directories: list[str] = field(default_factory=list)
    custom_title: str = ""
    minimum_patients: int = 0
    output_dir: Optional[Path] = None

    def validate(self) -> list[str]:
        """
        Check the filter values for logical consistency.

        Empty trust/drug/directory lists are valid (they mean "include all")
        and are therefore never reported.

        Returns:
            One message per problem found; an empty list when all checks pass.
        """
        problems: list[str] = []

        # Date range checks
        ends_before_start = self.end_date < self.start_date
        if ends_before_start:
            problems.append(
                f"End date ({self.end_date}) cannot be before start date ({self.start_date})"
            )

        last_seen_too_late = self.last_seen_date > self.end_date
        if last_seen_too_late:
            problems.append(
                f"Last seen date ({self.last_seen_date}) is after end date ({self.end_date}), "
                "which would exclude all patients"
            )

        # Patient threshold check
        if self.minimum_patients < 0:
            problems.append(
                f"Minimum patients ({self.minimum_patients}) cannot be negative"
            )

        # Output directory check (only when one was supplied)
        target_dir = self.output_dir
        if target_dir is not None and not target_dir.exists():
            problems.append(f"Output directory does not exist: {target_dir}")

        return problems

    @property
    def has_trust_filter(self) -> bool:
        """True when at least one trust is selected."""
        return bool(self.trusts)

    @property
    def has_drug_filter(self) -> bool:
        """True when at least one drug is selected."""
        return bool(self.drugs)

    @property
    def has_directory_filter(self) -> bool:
        """True when at least one directory is selected."""
        return bool(self.directories)

    @property
    def title(self) -> str:
        """
        Display title for the graph.

        The custom title wins when it is non-empty; otherwise a title is
        generated from the date range.
        """
        return (
            self.custom_title
            or f"Patients initiated from {self.start_date} to {self.end_date}"
        )

    def summary(self) -> str:
        """
        Build a multi-line, human-readable description of the filters.

        Intended for logging and for display in the GUI.
        """

        def describe(label: str, selected: list[str]) -> str:
            # Either "<label>: N selected" or "<label>: All" for an empty list.
            if selected:
                return f"{label}: {len(selected)} selected"
            return f"{label}: All"

        parts = [
            f"Date range: {self.start_date} to {self.end_date}",
            f"Last seen after: {self.last_seen_date}",
            f"Minimum patients: {self.minimum_patients}",
            describe("Trusts", self.trusts),
            describe("Drugs", self.drugs),
            describe("Directories", self.directories),
        ]

        if self.custom_title:
            parts.append(f"Custom title: {self.custom_title}")

        return "\n".join(parts)
|
||||
@@ -0,0 +1,42 @@
|
||||
# data_processing Package
|
||||
|
||||
Data layer for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
## Core Responsibilities
|
||||
|
||||
**Data Pipeline:** `Snowflake → Transforms → Pathway Generation → SQLite`
|
||||
|
||||
## Key Modules
|
||||
|
||||
**transforms.py** — Core data transformations (moved from tools/data.py):
|
||||
- `patient_id()` — Creates the UPID by concatenating the first 3 characters of the Provider Code with the PersonKey
|
||||
- `drug_names()` — Standardizes drug names via drugnames.csv lookup
|
||||
- `department_identification()` — 5-level fallback chain for directory assignment
|
||||
|
||||
**pathway_pipeline.py** — Pipeline orchestration:
|
||||
- Processes 6 date filter combinations × 2 chart types (directory + indication)
|
||||
- `fetch_and_transform_data()` — Snowflake fetch + UPID/drug/directory transforms
|
||||
- `process_pathway_for_date_filter()` — Directory charts using `generate_icicle_chart()`
|
||||
- `process_indication_pathway_for_date_filter()` — Indication charts using `generate_icicle_chart_indication()`
|
||||
- `insert_pathway_records()` — SQLite insertion with parameterized queries
|
||||
|
||||
**diagnosis_lookup.py** — GP diagnosis matching:
|
||||
- `get_patient_indication_groups()` — Batch queries Snowflake (500 patients at a time)
|
||||
- Embeds ~148 Search_Term → Cluster_ID mappings as SQL CTE
|
||||
- Returns most recent match per patient via `QUALIFY ROW_NUMBER()`
|
||||
|
||||
**database.py** — SQLite connection pooling and transaction management
|
||||
|
||||
**schema.py** — SQL schema definitions (reference tables + pathway_nodes)
|
||||
|
||||
**snowflake_connector.py** — Snowflake SSO integration via externalbrowser authenticator
|
||||
|
||||
**cache.py** — Query result caching with TTL-based invalidation
|
||||
|
||||
## Import Pattern
|
||||
|
||||
All imports use package names directly:
|
||||
```python
|
||||
from data_processing.transforms import patient_id, drug_names, department_identification
|
||||
from data_processing.pathway_pipeline import process_all_date_filters
|
||||
```
|
||||
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Data processing module for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains SQLite database management, data loaders, and Snowflake integration.
|
||||
Handles the migration from CSV-based storage to SQLite for improved performance.
|
||||
|
||||
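Submodules:
    database: SQLite connection management
    schema: Table schema definitions and create/drop/verify helpers
    reference_data: Reference data migration functions
    loader: Data loading abstractions (CSV, SQLite, Snowflake)
    snowflake_connector: Snowflake integration with SSO authentication
    cache: Query result caching
    data_source: Data source management with fallback chain
    diagnosis_lookup: Diagnosis lookup (GP diagnosis validation)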
|
||||
"""
|
||||
|
||||
from data_processing.database import (
|
||||
DatabaseConfig,
|
||||
DatabaseManager,
|
||||
default_db_config,
|
||||
default_db_manager,
|
||||
)
|
||||
from data_processing.schema import (
|
||||
# Reference table schemas
|
||||
REF_DRUG_NAMES_SCHEMA,
|
||||
REF_ORGANIZATIONS_SCHEMA,
|
||||
REF_DIRECTORIES_SCHEMA,
|
||||
REF_DRUG_DIRECTORY_MAP_SCHEMA,
|
||||
REF_DRUG_INDICATION_CLUSTERS_SCHEMA,
|
||||
REFERENCE_TABLES_SCHEMA,
|
||||
# Combined schema
|
||||
ALL_TABLES_SCHEMA,
|
||||
# Reference table functions
|
||||
create_reference_tables,
|
||||
drop_reference_tables,
|
||||
get_reference_table_counts,
|
||||
verify_reference_tables_exist,
|
||||
# Combined functions
|
||||
create_all_tables,
|
||||
drop_all_tables,
|
||||
get_all_table_counts,
|
||||
verify_all_tables_exist,
|
||||
)
|
||||
|
||||
# Reference data migration functions
|
||||
from data_processing.reference_data import (
|
||||
MigrationResult,
|
||||
migrate_drug_names,
|
||||
get_drug_name_counts,
|
||||
verify_drug_names_migration,
|
||||
migrate_organizations,
|
||||
get_organization_counts,
|
||||
verify_organizations_migration,
|
||||
migrate_directories,
|
||||
get_directory_counts,
|
||||
verify_directories_migration,
|
||||
migrate_drug_directory_map,
|
||||
get_drug_directory_map_counts,
|
||||
verify_drug_directory_map_migration,
|
||||
migrate_drug_indication_clusters,
|
||||
get_drug_indication_cluster_counts,
|
||||
verify_drug_indication_clusters_migration,
|
||||
)
|
||||
|
||||
# Data loader abstractions
|
||||
from data_processing.loader import (
|
||||
DataLoader,
|
||||
FileDataLoader,
|
||||
LoadResult,
|
||||
get_loader,
|
||||
REQUIRED_COLUMNS,
|
||||
OPTIONAL_COLUMNS,
|
||||
)
|
||||
|
||||
# Snowflake connector
|
||||
from data_processing.snowflake_connector import (
|
||||
SnowflakeConnector,
|
||||
SnowflakeConnectionError,
|
||||
SnowflakeNotConfiguredError,
|
||||
SnowflakeNotAvailableError,
|
||||
ConnectionInfo,
|
||||
get_connector,
|
||||
reset_connector,
|
||||
is_snowflake_available,
|
||||
is_snowflake_configured,
|
||||
SNOWFLAKE_AVAILABLE,
|
||||
)
|
||||
|
||||
# Query result caching
|
||||
from data_processing.cache import (
|
||||
QueryCache,
|
||||
CacheEntry,
|
||||
CacheStats,
|
||||
get_cache,
|
||||
reset_cache,
|
||||
is_cache_enabled,
|
||||
)
|
||||
|
||||
# Data source management with fallback chain
|
||||
from data_processing.data_source import (
|
||||
DataSourceType,
|
||||
DataSourceResult,
|
||||
SourceStatus,
|
||||
DataSourceManager,
|
||||
get_data_source_manager,
|
||||
get_data,
|
||||
reset_data_source_manager,
|
||||
)
|
||||
|
||||
# Diagnosis lookup (GP diagnosis validation)
|
||||
from data_processing.diagnosis_lookup import (
|
||||
ClusterSnomedCodes,
|
||||
IndicationValidationResult,
|
||||
DrugIndicationMatchRate,
|
||||
get_drug_clusters,
|
||||
get_drug_cluster_ids,
|
||||
get_cluster_snomed_codes,
|
||||
patient_has_indication,
|
||||
validate_indication,
|
||||
get_indication_match_rate,
|
||||
batch_validate_indications,
|
||||
get_available_clusters,
|
||||
)
|
||||
|
||||
# Public API of the package: every name imported above is re-exported here,
# grouped in the same order as the import statements.
__all__ = [
    # Database management
    "DatabaseConfig",
    "DatabaseManager",
    "default_db_config",
    "default_db_manager",
    # Reference table schemas
    "REF_DRUG_NAMES_SCHEMA",
    "REF_ORGANIZATIONS_SCHEMA",
    "REF_DIRECTORIES_SCHEMA",
    "REF_DRUG_DIRECTORY_MAP_SCHEMA",
    "REF_DRUG_INDICATION_CLUSTERS_SCHEMA",
    "REFERENCE_TABLES_SCHEMA",
    # Combined schema
    "ALL_TABLES_SCHEMA",
    # Reference table functions
    "create_reference_tables",
    "drop_reference_tables",
    "get_reference_table_counts",
    "verify_reference_tables_exist",
    # Combined functions
    "create_all_tables",
    "drop_all_tables",
    "get_all_table_counts",
    "verify_all_tables_exist",
    # Reference data migration
    "MigrationResult",
    "migrate_drug_names",
    "get_drug_name_counts",
    "verify_drug_names_migration",
    "migrate_organizations",
    "get_organization_counts",
    "verify_organizations_migration",
    "migrate_directories",
    "get_directory_counts",
    "verify_directories_migration",
    "migrate_drug_directory_map",
    "get_drug_directory_map_counts",
    "verify_drug_directory_map_migration",
    "migrate_drug_indication_clusters",
    "get_drug_indication_cluster_counts",
    "verify_drug_indication_clusters_migration",
    # Data loader abstractions
    "DataLoader",
    "FileDataLoader",
    "LoadResult",
    "get_loader",
    "REQUIRED_COLUMNS",
    "OPTIONAL_COLUMNS",
    # Snowflake connector
    "SnowflakeConnector",
    "SnowflakeConnectionError",
    "SnowflakeNotConfiguredError",
    "SnowflakeNotAvailableError",
    "ConnectionInfo",
    "get_connector",
    "reset_connector",
    "is_snowflake_available",
    "is_snowflake_configured",
    "SNOWFLAKE_AVAILABLE",
    # Query result caching
    "QueryCache",
    "CacheEntry",
    "CacheStats",
    "get_cache",
    "reset_cache",
    "is_cache_enabled",
    # Data source management with fallback chain
    "DataSourceType",
    "DataSourceResult",
    "SourceStatus",
    "DataSourceManager",
    "get_data_source_manager",
    "get_data",
    "reset_data_source_manager",
    # Diagnosis lookup
    "ClusterSnomedCodes",
    "IndicationValidationResult",
    "DrugIndicationMatchRate",
    "get_drug_clusters",
    "get_drug_cluster_ids",
    "get_cluster_snomed_codes",
    "patient_has_indication",
    "validate_indication",
    "get_indication_match_rate",
    "batch_validate_indications",
    "get_available_clusters",
]
|
||||
@@ -0,0 +1,553 @@
|
||||
"""
|
||||
Query result caching module for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides file-based caching for Snowflake query results with TTL-based invalidation.
|
||||
Supports different TTLs for historical data vs data including the current date.
|
||||
|
||||
Cache keys are generated from query hashes. Results are stored as compressed JSON.
|
||||
|
||||
Usage:
|
||||
from data_processing.cache import QueryCache, get_cache
|
||||
|
||||
cache = get_cache()
|
||||
|
||||
# Check for cached result
|
||||
result = cache.get(query, params)
|
||||
if result is None:
|
||||
# Execute query and cache result
|
||||
result = execute_query(query, params)
|
||||
cache.set(query, params, result, includes_current_data=False)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, date
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
import gzip
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from config import get_snowflake_config, CacheConfig
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """Metadata for a cached query result."""
    # Key derived from the query text and parameters; also the file name stem.
    cache_key: str
    # Truncated SHA-256 hex digest of the query text alone (parameters excluded).
    query_hash: str
    # When the entry was written.
    created_at: datetime
    # Point in time after which the entry counts as expired.
    expires_at: datetime
    # Whether the entry was stored with the "current data" TTL flag set.
    includes_current_data: bool
    # Number of rows in the cached result.
    row_count: int
    # Size of the compressed data file in bytes.
    file_size_bytes: int
    # Location of the compressed data file.
    file_path: Path
|
||||
|
||||
|
||||
@dataclass
class CacheStats:
    """Statistics about the cache."""
    # Whether caching is switched on in the configuration.
    enabled: bool
    # Directory holding the cache files.
    cache_dir: Path
    # Number of entries whose metadata could be read.
    total_entries: int
    # Combined size of the files in the cache directory, in MB.
    total_size_mb: float
    # Configured size limit, in MB.
    max_size_mb: int
    # Creation time of the oldest entry; None when there are no entries.
    oldest_entry: Optional[datetime]
    # Creation time of the newest entry; None when there are no entries.
    newest_entry: Optional[datetime]
    # Hits counted by this QueryCache instance (in-memory only).
    hit_count: int
    # Misses counted by this QueryCache instance (in-memory only).
    miss_count: int
|
||||
|
||||
|
||||
class QueryCache:
|
||||
"""
|
||||
File-based cache for Snowflake query results.
|
||||
|
||||
Results are stored as gzipped JSON files with TTL-based expiration.
|
||||
Supports different TTLs for historical vs current data.
|
||||
|
||||
Attributes:
|
||||
config: CacheConfig with cache settings
|
||||
cache_dir: Path to cache directory
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[CacheConfig] = None, base_path: Optional[Path] = None):
    """
    Set up the cache from a configuration object.

    Args:
        config: Cache settings. When omitted, the ``cache`` section of
            get_snowflake_config() is used.
        base_path: Directory that a relative cache directory is resolved
            against. Defaults to the current working directory.
    """
    self._config = get_snowflake_config().cache if config is None else config
    self._base_path = base_path or Path.cwd()

    # An absolute configured directory is used as-is; a relative one is
    # anchored at the base path.
    configured_dir = Path(self._config.directory)
    if configured_dir.is_absolute():
        self._cache_dir = configured_dir
    else:
        self._cache_dir = self._base_path / configured_dir

    # Hit/miss counters live in memory only and start from zero.
    self._hit_count = 0
    self._miss_count = 0

    # The directory is only created when caching is switched on.
    if self._config.enabled:
        self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@property
def config(self) -> CacheConfig:
    """The CacheConfig this cache was created with."""
    active_config = self._config
    return active_config
|
||||
|
||||
@property
def cache_dir(self) -> Path:
    """The resolved directory in which cache files are stored."""
    resolved_dir = self._cache_dir
    return resolved_dir
|
||||
|
||||
@property
def is_enabled(self) -> bool:
    """Whether the configuration has caching switched on."""
    enabled_flag = self._config.enabled
    return enabled_flag
|
||||
|
||||
def _generate_cache_key(self, query: str, params: Optional[tuple] = None) -> str:
    """
    Generate a cache key from query and parameters.

    The key is the first 32 hex characters of the SHA-256 digest of the
    whitespace-normalized query followed by the parameters.

    The query's letter case is preserved: string literals and quoted
    identifiers can be case-sensitive, so two queries that differ only in
    case may return different rows and must not share a cache entry.
    Parameters are encoded with ``repr`` of the whole tuple so that value
    boundaries and types stay distinct (``("a|b",)`` vs ``("a", "b")``,
    ``1`` vs ``"1"``).

    NOTE: keys differ from those produced by earlier versions that
    lower-cased the query, so entries written by them are no longer found
    and age out through the normal expiry/size clean-up.

    Args:
        query: SQL query string
        params: Optional query parameters; None and an empty tuple give the
            same key

    Returns:
        32-character hexadecimal cache key
    """
    # Collapse runs of whitespace so formatting differences share an entry.
    key_content = " ".join(query.split())

    if params:
        # NUL cannot be produced by the whitespace normalization above, so
        # it cleanly separates the query text from the parameter encoding.
        key_content += "\x00" + repr(tuple(params))

    digest = hashlib.sha256(key_content.encode("utf-8")).hexdigest()
    return digest[:32]  # Use first 32 chars for readability
|
||||
|
||||
def _get_cache_file_path(self, cache_key: str) -> Path:
    """Return where the compressed data file for *cache_key* lives."""
    data_file_name = f"{cache_key}.json.gz"
    return self._cache_dir.joinpath(data_file_name)
|
||||
|
||||
def _get_meta_file_path(self, cache_key: str) -> Path:
    """Return where the metadata file for *cache_key* lives."""
    meta_file_name = f"{cache_key}.meta.json"
    return self._cache_dir.joinpath(meta_file_name)
|
||||
|
||||
def _is_expired(self, meta: dict) -> bool:
    """Tell whether the entry described by *meta* is past its ``expires_at`` time."""
    expiry = datetime.fromisoformat(meta["expires_at"])
    return expiry < datetime.now()
|
||||
|
||||
def get(
    self,
    query: str,
    params: Optional[tuple] = None,
    check_expiry: bool = True
) -> Optional[list[dict]]:
    """
    Get a cached query result.

    Any entry that cannot be read back (malformed metadata, truncated or
    corrupt gzip data, invalid JSON) is treated as a miss and deleted, so a
    damaged cache file never makes the caller fail.

    Args:
        query: SQL query string
        params: Optional query parameters
        check_expiry: If True, returns None for expired entries

    Returns:
        Cached result as list of dicts, or None if not cached/expired
    """
    # Local import: zlib.error is raised by gzip on a corrupt deflate stream.
    import zlib

    if not self.is_enabled:
        self._miss_count += 1
        return None

    cache_key = self._generate_cache_key(query, params)
    cache_file = self._get_cache_file_path(cache_key)
    meta_file = self._get_meta_file_path(cache_key)

    # Both files are required; a lone data or metadata file is a miss.
    if not cache_file.exists() or not meta_file.exists():
        self._miss_count += 1
        logger.debug(f"Cache miss (not found): {cache_key}")
        return None

    try:
        with open(meta_file, "r", encoding="utf-8") as f:
            meta = json.load(f)

        if check_expiry and self._is_expired(meta):
            self._miss_count += 1
            logger.debug(f"Cache miss (expired): {cache_key}")
            return None

        with gzip.open(cache_file, "rt", encoding="utf-8") as f:
            data = json.load(f)

        # Read before counting the hit so that metadata without a row
        # count is handled as a corrupted entry, not as hit plus miss.
        row_count = meta["row_count"]

    except (ValueError, KeyError, TypeError, EOFError, OSError, zlib.error) as e:
        # ValueError covers invalid JSON, undecodable bytes and a malformed
        # expiry timestamp; EOFError / zlib.error cover truncated or
        # corrupt gzip data; TypeError covers metadata that is not a dict.
        logger.warning(f"Cache read error for {cache_key}: {e}")
        self._miss_count += 1
        # Clean up corrupted entry
        self._delete_entry(cache_key)
        return None

    self._hit_count += 1
    logger.info(f"Cache hit: {cache_key} ({row_count} rows)")
    return data
|
||||
|
||||
def set(
    self,
    query: str,
    params: Optional[tuple],
    data: list[dict],
    includes_current_data: bool = False,
    custom_ttl_seconds: Optional[int] = None
) -> Optional[CacheEntry]:
    """
    Cache a query result.

    The result is written as gzipped JSON next to a small JSON metadata
    file. If either write fails, whatever was written for this key is
    removed again so that no truncated data file (possibly paired with
    metadata from an earlier write) is left behind.

    Args:
        query: SQL query string
        params: Optional query parameters
        data: Query result as list of dicts
        includes_current_data: If True, uses shorter TTL for current data
        custom_ttl_seconds: Optional custom TTL (overrides config)

    Returns:
        CacheEntry with metadata, or None if caching disabled/failed
    """
    if not self.is_enabled:
        return None

    cache_key = self._generate_cache_key(query, params)
    cache_file = self._get_cache_file_path(cache_key)
    meta_file = self._get_meta_file_path(cache_key)

    # TTL precedence: explicit override, then current-data TTL, then default.
    if custom_ttl_seconds is not None:
        ttl = custom_ttl_seconds
    elif includes_current_data:
        ttl = self._config.ttl_current_data_seconds
    else:
        ttl = self._config.ttl_seconds

    now = datetime.now()
    expires_at = datetime.fromtimestamp(now.timestamp() + ttl)

    try:
        # Write compressed data; values json cannot encode are stringified.
        with gzip.open(cache_file, "wt", encoding="utf-8", compresslevel=6) as f:
            json.dump(data, f, default=str)

        file_size = cache_file.stat().st_size

        # Write metadata
        meta = {
            "cache_key": cache_key,
            "query_hash": hashlib.sha256(query.encode()).hexdigest()[:16],
            "created_at": now.isoformat(),
            "expires_at": expires_at.isoformat(),
            "includes_current_data": includes_current_data,
            "row_count": len(data),
            "file_size_bytes": file_size,
            "ttl_seconds": ttl,
        }

        with open(meta_file, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)

    except (OSError, TypeError, ValueError) as e:
        # TypeError: e.g. non-string dict keys; ValueError: e.g. circular
        # references. Neither should make the caller's query fail.
        logger.error(f"Failed to cache result: {e}")
        try:
            self._delete_entry(cache_key)
        except OSError as cleanup_error:
            logger.warning(
                f"Failed to clean up partial cache entry {cache_key}: {cleanup_error}"
            )
        return None

    logger.info(f"Cached {len(data)} rows as {cache_key} (expires in {ttl}s)")

    # Enforce the size limit separately: a failure here must not delete
    # the entry that was just written successfully.
    try:
        self._enforce_size_limit()
    except (OSError, TypeError) as e:
        logger.error(f"Failed to cache result: {e}")
        return None

    return CacheEntry(
        cache_key=cache_key,
        query_hash=str(meta["query_hash"]),
        created_at=now,
        expires_at=expires_at,
        includes_current_data=includes_current_data,
        row_count=len(data),
        file_size_bytes=file_size,
        file_path=cache_file,
    )
|
||||
|
||||
def invalidate(self, query: str, params: Optional[tuple] = None) -> bool:
    """
    Drop the cache entry belonging to one query/parameter combination.

    Args:
        query: SQL query string
        params: Optional query parameters

    Returns:
        True if entry was deleted, False if not found
    """
    return self._delete_entry(self._generate_cache_key(query, params))
|
||||
|
||||
def _delete_entry(self, cache_key: str) -> bool:
    """
    Remove the data file and the metadata file of one cache entry.

    Returns:
        True when at least one of the two files existed and was removed.
    """
    entry_files = (
        self._get_cache_file_path(cache_key),
        self._get_meta_file_path(cache_key),
    )

    removed_any = False
    for entry_file in entry_files:
        if entry_file.exists():
            entry_file.unlink()
            removed_any = True

    if removed_any:
        logger.debug(f"Deleted cache entry: {cache_key}")

    return removed_any
|
||||
|
||||
def clear(self) -> int:
    """
    Clear all cache entries.

    Deletes every ``*.json*`` file in the cache directory and resets the
    in-memory hit/miss counters. Files that cannot be deleted are logged
    and skipped.

    Returns:
        Number of entries deleted. An entry is counted once no matter
        whether one or both of its files (data and metadata) were present,
        so orphaned files are counted too.
    """
    if not self._cache_dir.exists():
        return 0

    removed_names = []
    for file in self._cache_dir.glob("*.json*"):
        try:
            file.unlink()
        except OSError as e:
            logger.warning(f"Failed to delete {file}: {e}")
        else:
            removed_names.append(file.name)

    # Reset stats
    self._hit_count = 0
    self._miss_count = 0

    logger.info(f"Cleared {len(removed_names)} cache files")

    # "<key>.json.gz" and "<key>.meta.json" share the text before the first
    # dot; counting distinct keys avoids the under-count a plain
    # "files // 2" gives when an entry has only one of its two files.
    return len({name.split(".", 1)[0] for name in removed_names})
|
||||
|
||||
def clear_expired(self) -> int:
    """
    Remove expired cache entries.

    An entry whose metadata cannot be read or is malformed (invalid JSON,
    missing or unparseable ``expires_at``, not a JSON object) is treated
    as expired and removed as well.

    Returns:
        Number of expired entries deleted
    """
    if not self._cache_dir.exists():
        return 0

    count = 0
    for meta_file in self._cache_dir.glob("*.meta.json"):
        # "<key>.meta.json" -> stem "<key>.meta" -> "<key>"
        cache_key = meta_file.stem.replace(".meta", "")

        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
            expired = self._is_expired(meta)
        except (OSError, ValueError, KeyError, TypeError):
            # ValueError includes json.JSONDecodeError and a malformed
            # timestamp; KeyError / TypeError mean the expected field or
            # structure is missing. All of these mark corrupted metadata.
            expired = True

        if expired:
            self._delete_entry(cache_key)
            count += 1

    logger.info(f"Cleared {count} expired cache entries")
    return count
|
||||
|
||||
def _get_total_size_mb(self) -> float:
    """Add up the sizes of the regular files directly inside the cache directory, in MB."""
    if not self._cache_dir.exists():
        return 0.0

    byte_total = 0
    for candidate in self._cache_dir.glob("*"):
        if candidate.is_file():
            byte_total += candidate.stat().st_size

    bytes_per_mb = 1024 * 1024
    return byte_total / bytes_per_mb
|
||||
|
||||
def _enforce_size_limit(self) -> int:
    """
    Enforce cache size limit by removing oldest entries.

    Does nothing while the cache is at or below ``max_size_mb``. Otherwise
    entries are removed oldest-first (by ``created_at``) until the recorded
    sizes of the removed entries bring the cache down to about 90% of the
    limit. Entries whose metadata is unreadable or malformed are deleted
    outright and are not included in the returned count.

    Returns:
        Number of entries removed
    """
    max_size_mb = self._config.max_size_mb
    current_size_mb = self._get_total_size_mb()

    if current_size_mb <= max_size_mb:
        return 0

    # Collect (cache_key, created_at, recorded size) for every readable entry.
    entries = []
    for meta_file in self._cache_dir.glob("*.meta.json"):
        cache_key = meta_file.stem.replace(".meta", "")
        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)
            entries.append((
                cache_key,
                datetime.fromisoformat(meta["created_at"]),
                meta.get("file_size_bytes", 0)
            ))
        except (OSError, ValueError, KeyError, TypeError):
            # ValueError covers invalid JSON as well as an unparseable
            # created_at; TypeError covers metadata that is not a dict.
            # Clean up corrupted entry
            self._delete_entry(cache_key)

    # Sort by creation time (oldest first)
    entries.sort(key=lambda entry: entry[1])

    # Remove oldest entries until under limit
    removed = 0
    size_to_remove_bytes = (current_size_mb - max_size_mb * 0.9) * 1024 * 1024  # Target 90% of limit
    removed_bytes = 0

    for cache_key, _created_at, file_size in entries:
        if removed_bytes >= size_to_remove_bytes:
            break

        self._delete_entry(cache_key)
        removed_bytes += file_size
        removed += 1

    logger.info(f"Removed {removed} cache entries to enforce size limit")
    return removed
|
||||
|
||||
def get_stats(self) -> CacheStats:
    """Get cache statistics.

    Scans the ``*.meta.json`` files in the cache directory and reports how
    many carry a readable ``created_at`` timestamp, the oldest and newest
    of those timestamps, the total on-disk size, and the hit/miss
    counters. Metadata files that are unreadable or malformed are skipped
    so that a single bad file does not make the whole call fail.

    Returns:
        CacheStats snapshot. When the cache directory does not exist the
        entry count is 0, the size is 0.0 and both timestamps are None.
    """
    created_times = []
    total_size_mb = 0.0

    if self._cache_dir.exists():
        for meta_file in self._cache_dir.glob("*.meta.json"):
            try:
                with open(meta_file, "r", encoding="utf-8") as f:
                    meta = json.load(f)
                created_times.append(datetime.fromisoformat(meta["created_at"]))
            except (OSError, ValueError, KeyError, TypeError):
                # ValueError covers invalid JSON, undecodable bytes and a
                # malformed ISO timestamp; TypeError covers metadata that
                # is not a mapping or a non-string timestamp.
                continue
        total_size_mb = self._get_total_size_mb()

    return CacheStats(
        enabled=self.is_enabled,
        cache_dir=self._cache_dir,
        total_entries=len(created_times),
        total_size_mb=total_size_mb,
        max_size_mb=self._config.max_size_mb,
        oldest_entry=min(created_times) if created_times else None,
        newest_entry=max(created_times) if created_times else None,
        hit_count=self._hit_count,
        miss_count=self._miss_count,
    )
|
||||
|
||||
def list_entries(self) -> list[CacheEntry]:
    """List all cache entries with metadata.

    Each ``*.meta.json`` file in the cache directory is turned into a
    CacheEntry. ``cache_key``, ``created_at`` and ``expires_at`` are
    required; the remaining fields fall back to defaults when absent.
    Files that cannot be read or parsed (I/O error, invalid JSON, missing
    key, malformed timestamp, metadata that is not a mapping) are skipped.

    Returns:
        Entries sorted by creation time, newest first. Empty when the
        cache directory does not exist.
    """
    if not self._cache_dir.exists():
        return []

    entries = []
    for meta_file in self._cache_dir.glob("*.meta.json"):
        try:
            with open(meta_file, "r", encoding="utf-8") as f:
                meta = json.load(f)

            cache_key = meta["cache_key"]
            entries.append(CacheEntry(
                cache_key=cache_key,
                query_hash=meta.get("query_hash", ""),
                created_at=datetime.fromisoformat(meta["created_at"]),
                expires_at=datetime.fromisoformat(meta["expires_at"]),
                includes_current_data=meta.get("includes_current_data", False),
                row_count=meta.get("row_count", 0),
                file_size_bytes=meta.get("file_size_bytes", 0),
                file_path=self._get_cache_file_path(cache_key),
            ))
        except (OSError, ValueError, KeyError, TypeError):
            # ValueError covers invalid JSON, undecodable bytes and
            # malformed ISO timestamps; TypeError covers metadata that is
            # not a mapping or non-string timestamps.
            continue

    # Sort by creation time (newest first)
    entries.sort(key=lambda entry: entry.created_at, reverse=True)
    return entries
|
||||
|
||||
|
||||
# Module-level singleton
_default_cache: Optional[QueryCache] = None


def get_cache(config: Optional[CacheConfig] = None) -> QueryCache:
    """
    Return a QueryCache, lazily creating the shared default on first use.

    Args:
        config: Optional CacheConfig. When given, a new QueryCache is built
            from it on every call and the shared default is left untouched.
            When None, the module-level default cache is returned, being
            created first if it does not exist yet.

    Returns:
        QueryCache instance
    """
    global _default_cache

    if config is None:
        if _default_cache is None:
            _default_cache = QueryCache()
        return _default_cache

    # An explicit config always yields a fresh, non-shared cache.
    return QueryCache(config)
|
||||
|
||||
|
||||
def reset_cache() -> None:
    """Forget the shared default cache by setting the module-level singleton to None."""
    global _default_cache

    _default_cache = None
|
||||
|
||||
|
||||
def is_cache_enabled() -> bool:
    """Report the ``cache.enabled`` flag of the object returned by ``get_snowflake_config()``."""
    return get_snowflake_config().cache.enabled
|
||||
|
||||
|
||||
# Export public API
__all__ = [
    # Cache types
    "QueryCache",
    "CacheEntry",
    "CacheStats",
    # Module-level helpers
    "get_cache",
    "reset_cache",
    "is_cache_enabled",
]
|
||||
@@ -0,0 +1,932 @@
|
||||
"""
|
||||
Unified data access layer with fallback chain for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides a high-level interface that automatically selects the best available data source:
|
||||
1. Cache - Returns cached results if valid and not expired
|
||||
2. Snowflake - Queries Snowflake warehouse if configured and connected
|
||||
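3. Local - Falls back to a local CSV/Parquet file (the SQLite raw-data step is kept in the chain for interface compatibility but currently always returns no data)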
|
||||
|
||||
The fallback chain handles connection errors, missing configurations, and
|
||||
unavailable services gracefully, always attempting to provide data from
|
||||
some source.
|
||||
|
||||
Usage:
|
||||
from data_processing.data_source import DataSourceManager, get_data
|
||||
|
||||
# Simple usage with automatic source selection
|
||||
result = get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 12, 31),
|
||||
trusts=["TRUST A", "TRUST B"],
|
||||
)
|
||||
|
||||
# Or with explicit source preference
|
||||
manager = DataSourceManager()
|
||||
result = manager.get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 12, 31),
|
||||
preferred_source="snowflake",
|
||||
)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Callable
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataSourceType(Enum):
    """Identifiers for the places data can be loaded from."""

    CACHE = "cache"          # previously cached query results
    SNOWFLAKE = "snowflake"  # Snowflake warehouse
    SQLITE = "sqlite"        # local SQLite database
    FILE = "file"            # local CSV/Parquet file
|
||||
|
||||
|
||||
@dataclass
class DataSourceResult:
    """Outcome of loading data from one of the data sources.

    Attributes:
        df: DataFrame holding the loaded patient intervention data
        source_type: The data source that produced ``df``
        source_detail: Extra detail about the source (e.g., file path, query hash)
        row_count: Number of rows returned; filled in from ``df`` when left at 0
        cached: True when the result was served from the cache
        from_fallback: True when a fallback source was used
        load_time_seconds: Time taken to load the data
        warnings: Warnings generated while loading
    """
    df: pd.DataFrame
    source_type: DataSourceType
    source_detail: str = ""
    row_count: int = 0
    cached: bool = False
    from_fallback: bool = False
    load_time_seconds: float = 0.0
    warnings: list[str] = field(default_factory=list)

    def __post_init__(self):
        # An explicitly supplied non-zero count is trusted as-is; without a
        # frame there is nothing to count.
        if self.df is None or self.row_count != 0:
            return
        self.row_count = len(self.df)
|
||||
|
||||
|
||||
@dataclass
class SourceStatus:
    """Availability report for a single data source.

    Attributes:
        source_type: The data source this report is about
        available: True when the source is available
        configured: True when the source is properly configured
        message: Human-readable explanation of the current state
        last_checked: Time at which the status was determined
    """
    source_type: DataSourceType
    available: bool = False
    configured: bool = False
    message: str = ""
    last_checked: Optional[datetime] = None
|
||||
|
||||
|
||||
class DataSourceManager:
|
||||
"""
|
||||
Manages data access with automatic fallback between sources.
|
||||
|
||||
The manager attempts to retrieve data from sources in order of preference:
|
||||
1. Cache (if enabled and has valid cached data)
|
||||
2. Snowflake (if configured and connected)
|
||||
3. SQLite (if database exists with data)
|
||||
4. Local files (CSV/Parquet)
|
||||
|
||||
Attributes:
|
||||
cache_enabled: Whether to use caching
|
||||
local_file_path: Path to local CSV/Parquet file (optional fallback)
|
||||
sqlite_db_path: Path to SQLite database (optional)
|
||||
|
||||
Example:
|
||||
manager = DataSourceManager()
|
||||
|
||||
# Check what sources are available
|
||||
status = manager.check_all_sources()
|
||||
for s in status:
|
||||
print(f"{s.source_type.value}: {s.message}")
|
||||
|
||||
# Get data with automatic fallback
|
||||
result = manager.get_data(
|
||||
start_date=date(2024, 1, 1),
|
||||
end_date=date(2024, 6, 30),
|
||||
)
|
||||
print(f"Got {result.row_count} rows from {result.source_type.value}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_enabled: bool = True,
|
||||
local_file_path: Optional[Path | str] = None,
|
||||
sqlite_db_path: Optional[Path | str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the data source manager.
|
||||
|
||||
Args:
|
||||
cache_enabled: Whether to check cache before querying (default True)
|
||||
local_file_path: Path to local CSV/Parquet file for file fallback
|
||||
sqlite_db_path: Path to SQLite database (uses default if None)
|
||||
"""
|
||||
self._cache_enabled = cache_enabled
|
||||
self._local_file_path = Path(local_file_path) if local_file_path else None
|
||||
self._sqlite_db_path = Path(sqlite_db_path) if sqlite_db_path else None
|
||||
self._source_status: dict[DataSourceType, SourceStatus] = {}
|
||||
|
||||
@property
def cache_enabled(self) -> bool:
    """Whether this manager is currently set to use the cache."""
    return self._cache_enabled

@cache_enabled.setter
def cache_enabled(self, value: bool):
    """Turn cache use on or off for this manager."""
    self._cache_enabled = value
|
||||
|
||||
def _check_cache_status(self) -> SourceStatus:
    """Report whether the query cache can be used.

    The cache counts as usable when ``is_cache_enabled()`` is true and its
    statistics can be read. Any exception raised during the check is
    turned into an unavailable status carrying the error text.

    Returns:
        SourceStatus for DataSourceType.CACHE, stamped with the current time
    """
    def report(usable: bool, message: str) -> SourceStatus:
        # For the cache, `available` and `configured` always agree.
        return SourceStatus(
            source_type=DataSourceType.CACHE,
            available=usable,
            configured=usable,
            message=message,
            last_checked=datetime.now(),
        )

    try:
        from data_processing.cache import is_cache_enabled, get_cache

        if not is_cache_enabled():
            return report(False, "Cache disabled in configuration")

        stats = get_cache().get_stats()
        return report(
            True,
            f"Cache enabled ({stats.total_entries} entries, {stats.total_size_mb:.1f}MB)",
        )
    except Exception as e:
        return report(False, f"Cache error: {e}")
|
||||
|
||||
def _check_snowflake_status(self) -> SourceStatus:
    """Report whether Snowflake can be used as a data source.

    Outcomes:
        - ``is_snowflake_available()`` is false: not available, not configured
        - available but ``is_snowflake_configured()`` is false: available,
          not configured
        - both true: available and configured
        - any exception during the check: not available, not configured,
          with the error text in the message

    Returns:
        SourceStatus for DataSourceType.SNOWFLAKE, stamped with the current time
    """
    def report(available: bool, configured: bool, message: str) -> SourceStatus:
        return SourceStatus(
            source_type=DataSourceType.SNOWFLAKE,
            available=available,
            configured=configured,
            message=message,
            last_checked=datetime.now(),
        )

    try:
        from data_processing.snowflake_connector import (
            is_snowflake_available,
            is_snowflake_configured,
        )

        if not is_snowflake_available():
            return report(False, False, "snowflake-connector-python not installed")

        if not is_snowflake_configured():
            return report(
                True,
                False,
                "Snowflake account not configured in config/snowflake.toml",
            )

        return report(True, True, "Snowflake configured and ready")
    except Exception as e:
        return report(False, False, f"Snowflake error: {e}")
|
||||
|
||||
def _check_sqlite_status(self) -> SourceStatus:
    """Report whether the SQLite database holds pathway data.

    The database path is ``self._sqlite_db_path`` when set, otherwise the
    path from ``default_db_config``. The source is reported as available
    only when the database file exists and its ``pathway_nodes`` table
    exists and is non-empty. Any exception during the check yields an
    unavailable, unconfigured status carrying the error text.

    Returns:
        SourceStatus for DataSourceType.SQLITE, stamped with the current time
    """
    def report(available: bool, configured: bool, message: str) -> SourceStatus:
        return SourceStatus(
            source_type=DataSourceType.SQLITE,
            available=available,
            configured=configured,
            message=message,
            last_checked=datetime.now(),
        )

    try:
        from data_processing.database import default_db_config

        db_path = self._sqlite_db_path or Path(default_db_config.db_path)
        if not db_path.exists():
            return report(False, True, f"Database not found: {db_path}")

        # Imported only once the file is known to exist.
        from data_processing.database import DatabaseManager, DatabaseConfig

        manager = DatabaseManager(DatabaseConfig(db_path=db_path))

        if not manager.table_exists("pathway_nodes"):
            return report(False, True, "pathway_nodes table not found")

        count = manager.get_table_count("pathway_nodes")
        if count == 0:
            return report(False, True, "pathway_nodes table is empty")

        return report(True, True, f"SQLite database ready ({count:,} pathway nodes)")
    except Exception as e:
        return report(False, False, f"SQLite error: {e}")
|
||||
|
||||
def _check_file_status(self) -> SourceStatus:
    """Report whether the configured local data file exists.

    Outcomes:
        - no path configured: not available, not configured
        - path configured but missing on disk: not available, configured
        - path exists: available and configured, with its size in the message

    Returns:
        SourceStatus for DataSourceType.FILE, stamped with the current time
    """
    def report(available: bool, configured: bool, message: str) -> SourceStatus:
        return SourceStatus(
            source_type=DataSourceType.FILE,
            available=available,
            configured=configured,
            message=message,
            last_checked=datetime.now(),
        )

    path = self._local_file_path
    if path is None:
        return report(False, False, "No local file path configured")

    if not path.exists():
        return report(False, True, f"File not found: {path}")

    size_mb = path.stat().st_size / (1024 * 1024)
    return report(True, True, f"Local file ready: {path.name} ({size_mb:.1f}MB)")
|
||||
|
||||
def check_source_status(self, source_type: DataSourceType) -> SourceStatus:
    """
    Run the availability check that belongs to one data source.

    Args:
        source_type: The type of source to check

    Returns:
        SourceStatus with current availability information. A value that
        matches none of the known source types yields an unavailable
        status with an "Unknown source type" message.
    """
    checks = (
        (DataSourceType.CACHE, self._check_cache_status),
        (DataSourceType.SNOWFLAKE, self._check_snowflake_status),
        (DataSourceType.SQLITE, self._check_sqlite_status),
        (DataSourceType.FILE, self._check_file_status),
    )
    for known_type, run_check in checks:
        if source_type == known_type:
            return run_check()

    return SourceStatus(
        source_type=source_type,
        available=False,
        configured=False,
        message=f"Unknown source type: {source_type}",
        last_checked=datetime.now(),
    )
|
||||
|
||||
def check_all_sources(self) -> list[SourceStatus]:
    """
    Check every data source and remember the results.

    Each status is also stored in ``self._source_status`` under its
    source type.

    Returns:
        List of SourceStatus, one per DataSourceType member, in enum order
    """
    for kind in DataSourceType:
        self._source_status[kind] = self.check_source_status(kind)
    return [self._source_status[kind] for kind in DataSourceType]
|
||||
|
||||
def _build_cache_key_params(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> tuple[str, tuple]:
    """Build a canonical query string and parameter tuple for the given filters.

    The pair is used as the cache lookup key by the callers in this class.
    List filters are sorted so the same selection in a different order
    produces the same key; filters that are None or empty are left out.

    Returns:
        ``(query, params)`` where ``query`` is a SELECT over
        ``activity_data`` with one ``?`` placeholder per value and
        ``params`` holds the values as strings, in placeholder order.
    """
    clauses: list[str] = []
    values: list[str] = []

    if start_date:
        clauses.append("start_date >= ?")
        values.append(str(start_date))
    if end_date:
        clauses.append("end_date <= ?")
        values.append(str(end_date))

    list_filters = (("trust", trusts), ("drug", drugs), ("directory", directories))
    for column, selected in list_filters:
        if not selected:
            continue
        marks = ",".join("?" for _ in selected)
        clauses.append(f"{column} IN ({marks})")
        values.extend(sorted(selected))

    query = "SELECT * FROM activity_data"
    if clauses:
        query = query + " WHERE " + " AND ".join(clauses)
    return query, tuple(values)
|
||||
|
||||
def _try_cache(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """Look up a previously cached result for these filters.

    Returns:
        A DataSourceResult marked ``cached=True`` on a hit. None when
        caching is switched off on this manager or on the cache itself,
        on a miss, or when the lookup raises (the error is logged as a
        warning).
    """
    if not self._cache_enabled:
        return None

    try:
        from data_processing.cache import get_cache

        cache = get_cache()
        if not cache.is_enabled:
            return None

        query, params = self._build_cache_key_params(
            start_date, end_date, trusts, drugs, directories
        )

        cached_rows = cache.get(query, params)
        if cached_rows is None:
            logger.debug("Cache miss")
            return None

        frame = pd.DataFrame(cached_rows)

        # Dates are stored as strings by _cache_result; restore the dtype.
        date_column = 'Intervention Date'
        if date_column in frame.columns:
            frame[date_column] = pd.to_datetime(frame[date_column])

        row_total = len(frame)
        logger.info(f"Cache hit: {row_total} rows")

        return DataSourceResult(
            df=frame,
            source_type=DataSourceType.CACHE,
            source_detail=f"cache_key={query[:50]}...",
            row_count=row_total,
            cached=True,
            from_fallback=False,
        )
    except Exception as e:
        logger.warning(f"Cache lookup failed: {e}")
        return None
|
||||
|
||||
def _try_snowflake(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Optional[DataSourceResult]:
    """Fetch activity data from Snowflake and shape it like the local sources.

    Only the date range is passed to the connector (``provider_codes`` is
    always None); the trust, drug and directory filters are applied
    locally after the ``patient_id``, ``drug_names`` and
    ``department_identification`` transforms. ``progress_callback`` is
    accepted but not used by this method.

    Returns:
        A DataSourceResult on success. None when the connector is not
        installed, Snowflake is not configured, no rows come back, or any
        exception is raised (logged as a warning).
    """
    import time

    try:
        from data_processing.snowflake_connector import (
            is_snowflake_available,
            is_snowflake_configured,
            get_connector,
            SnowflakeConnectionError,
        )

        if not is_snowflake_available():
            logger.debug("Snowflake connector not installed")
            return None

        if not is_snowflake_configured():
            logger.debug("Snowflake not configured")
            return None

        connector = get_connector()
        logger.info("Fetching data from Snowflake...")
        started = time.time()

        # Note: provider_codes filter not directly supported yet - would
        # need trust name to code mapping.
        rows = connector.fetch_activity_data(
            start_date=start_date,
            end_date=end_date,
            provider_codes=None,  # TODO: map trust names to provider codes if needed
        )

        if not rows:
            logger.warning("Snowflake returned no data")
            return None

        df = pd.DataFrame(rows)
        # The reported load time covers the fetch and DataFrame build only.
        load_time = time.time() - started

        logger.info(f"Snowflake loaded {len(df)} rows in {load_time:.2f}s")

        # Apply local transformations to match expected format
        from data_processing.transforms import patient_id, drug_names, department_identification
        from core import default_paths

        df = patient_id(df)
        df = drug_names(df, paths=default_paths)
        df = department_identification(df, paths=default_paths)

        # Each filter applies only when values were requested and the
        # column is present.
        local_filters = (
            (trusts, 'OrganisationName'),
            (drugs, 'Drug Name'),
            (directories, 'Directory'),
        )
        for wanted, column in local_filters:
            if wanted and column in df.columns:
                df = df[df[column].isin(wanted)]

        return DataSourceResult(
            df=df,
            source_type=DataSourceType.SNOWFLAKE,
            source_detail="DATA_HUB.CDM.Acute__Conmon__PatientLevelDrugs",
            row_count=len(df),
            cached=False,
            from_fallback=False,
            load_time_seconds=load_time,
        )

    except Exception as e:
        logger.warning(f"Snowflake query failed: {e}")
        return None
|
||||
|
||||
def _try_sqlite(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """Placeholder for the SQLite raw-data fallback; always returns None.

    Raw intervention data is no longer stored in SQLite; the app uses
    pre-computed pathway_nodes via load_pathway_data() instead. The method
    is kept so the fallback chain keeps its shape. All arguments are
    ignored.

    Returns:
        Always None (after logging a debug message).
    """
    logger.debug("SQLite raw data fallback skipped (fact_interventions removed)")
    return None
|
||||
|
||||
def _try_file(
    self,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
) -> Optional[DataSourceResult]:
    """Load the configured local file and filter it in memory.

    The whole file is loaded first; filters are applied afterwards.
    ``start_date`` is inclusive and ``end_date`` exclusive, both compared
    against the 'Intervention Date' column. A filter is skipped when its
    column is absent.

    Returns:
        A DataSourceResult marked ``from_fallback=True`` on success. None
        when no file is configured, the loader rejects the source, or any
        exception is raised (logged as a warning).
    """
    import time

    if self._local_file_path is None:
        logger.debug("No local file configured")
        return None

    try:
        from data_processing.loader import FileDataLoader

        loader = FileDataLoader(file_path=self._local_file_path)

        usable, reason = loader.validate_source()
        if not usable:
            logger.debug(f"Local file not available: {reason}")
            return None

        started = time.time()
        df = loader.load().df

        date_column = 'Intervention Date'
        if start_date and date_column in df.columns:
            df = df[df[date_column] >= pd.Timestamp(start_date)]
        if end_date and date_column in df.columns:
            df = df[df[date_column] < pd.Timestamp(end_date)]

        value_filters = (
            (trusts, 'OrganisationName'),
            (drugs, 'Drug Name'),
            (directories, 'Directory'),
        )
        for wanted, column in value_filters:
            if wanted and column in df.columns:
                df = df[df[column].isin(wanted)]

        load_time = time.time() - started

        logger.info(f"File loaded and filtered: {len(df)} rows in {load_time:.2f}s")

        return DataSourceResult(
            df=df,
            source_type=DataSourceType.FILE,
            source_detail=str(self._local_file_path),
            row_count=len(df),
            cached=False,
            from_fallback=True,
            load_time_seconds=load_time,
        )
    except Exception as e:
        logger.warning(f"File load failed: {e}")
        return None
|
||||
|
||||
def get_data(
    self,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    preferred_source: Optional[str] = None,
    skip_cache: bool = False,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> DataSourceResult:
    """
    Get patient intervention data from the best available source.

    The fallback chain is: Cache → Snowflake → SQLite → File

    When ``preferred_source`` names a source, that source is tried first
    and its result is returned as-is. If it yields nothing, a warning is
    recorded, the standard chain runs without trying that same source a
    second time, and the warning is attached to whichever result is
    finally returned.

    Args:
        start_date: Optional start date for filtering (inclusive)
        end_date: Optional end date for filtering (exclusive)
        trusts: Optional list of trust names to filter
        drugs: Optional list of drug names to filter
        directories: Optional list of directories to filter
        preferred_source: Optional preferred source ("snowflake", "sqlite",
            "file"; case-insensitive). "cache" is accepted and simply runs
            the standard chain. Any other value is ignored with a warning
            on the result.
        skip_cache: If True, bypass the cache lookup (a fresh Snowflake
            result is still written to the cache when caching is enabled)
        progress_callback: Optional callback(current, total) for progress updates

    Returns:
        DataSourceResult with the loaded data and metadata

    Raises:
        ValueError: If no data source is available or all sources fail
    """
    import time

    start_time = time.time()
    warnings: list[str] = []
    filters = (start_date, end_date, trusts, drugs, directories)

    def finalize(result: DataSourceResult, fallback: bool = False) -> DataSourceResult:
        # Stamp the total elapsed time and carry over any warnings
        # collected so far, whichever source ended up serving the data.
        if fallback:
            result.from_fallback = True
        result.load_time_seconds = time.time() - start_time
        if warnings:
            result.warnings.extend(warnings)
        return result

    # If a preferred source is specified, try that first.
    failed_preferred = None
    if preferred_source:
        preferred = preferred_source.lower()
        attempts = {
            "snowflake": lambda: self._try_snowflake(*filters, progress_callback),
            "sqlite": lambda: self._try_sqlite(*filters),
            "file": lambda: self._try_file(*filters),
        }
        attempt = attempts.get(preferred)
        if attempt is not None:
            result = attempt()
            if result:
                return finalize(result)
            failed_preferred = preferred
            warnings.append(f"Preferred source '{preferred}' unavailable")
        elif preferred != "cache":
            # Surface typos instead of silently ignoring the preference.
            warnings.append(f"Unknown preferred source '{preferred_source}' ignored")

    # Standard fallback chain: cache → snowflake → sqlite → file.
    # A source that just failed as the preferred source is not retried:
    # the arguments are identical, so the attempt would only repeat the
    # same failure.

    # 1. Try cache first (unless skipped)
    if not skip_cache:
        result = self._try_cache(*filters)
        if result:
            return finalize(result)

    # 2. Try Snowflake
    if failed_preferred != "snowflake":
        result = self._try_snowflake(*filters, progress_callback)
        if result:
            # Cache the result for future queries
            if self._cache_enabled:
                self._cache_result(
                    result.df,
                    *filters,
                    includes_current_data=end_date is None or end_date >= date.today(),
                )
            return finalize(result)

    # 3. Try SQLite (marked as fallback since Snowflake wasn't used)
    if failed_preferred != "sqlite":
        result = self._try_sqlite(*filters)
        if result:
            return finalize(result, fallback=True)

    # 4. Try local file
    if failed_preferred != "file":
        result = self._try_file(*filters)
        if result:
            return finalize(result, fallback=True)

    # All sources failed
    source_status = self.check_all_sources()
    status_msg = "; ".join(
        f"{s.source_type.value}: {s.message}" for s in source_status
    )
    raise ValueError(f"No data source available. Status: {status_msg}")
|
||||
|
||||
def _cache_result(
    self,
    df: pd.DataFrame,
    start_date: Optional[date],
    end_date: Optional[date],
    trusts: Optional[list[str]],
    drugs: Optional[list[str]],
    directories: Optional[list[str]],
    includes_current_data: bool = False,
) -> bool:
    """Store a query result in the cache under the key for these filters.

    The frame is copied, datetime columns are converted to strings, and
    the rows are stored as a list of dicts. The caller's frame is not
    modified.

    Returns:
        True when the cache reports an entry was stored. False when the
        cache is disabled, nothing was stored, or an exception was raised
        (logged as a warning).
    """
    try:
        from data_processing.cache import get_cache

        cache = get_cache()
        if not cache.is_enabled:
            return False

        query, params = self._build_cache_key_params(
            start_date, end_date, trusts, drugs, directories
        )

        # Work on a copy; datetime columns become strings for JSON
        # serialization.
        serialisable = df.copy()
        for column in serialisable.columns:
            if pd.api.types.is_datetime64_any_dtype(serialisable[column]):
                serialisable[column] = serialisable[column].astype(str)

        records = serialisable.to_dict(orient='records')

        entry = cache.set(
            query, params, records,
            includes_current_data=includes_current_data
        )
        if not entry:
            return False

        logger.info(f"Cached {len(records)} rows (key={entry.cache_key[:16]}...)")
        return True

    except Exception as e:
        logger.warning(f"Failed to cache result: {e}")
        return False
|
||||
|
||||
def clear_cache(self) -> int:
    """
    Remove everything from the default query cache.

    Returns:
        Number of cache entries cleared, as reported by the cache; 0 when
        clearing raises (the error is logged as a warning)
    """
    try:
        from data_processing.cache import get_cache

        return get_cache().clear()
    except Exception as e:
        logger.warning(f"Failed to clear cache: {e}")
        return 0
|
||||
|
||||
def refresh_from_snowflake(
    self,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> DataSourceResult:
    """
    Query Snowflake directly, skipping the cache lookup and all fallbacks.

    Unlike get_data(), this never falls back to another source: it raises
    when Snowflake cannot serve the request. A successful result is
    written to the cache when caching is enabled on this manager.

    Args:
        start_date: Optional start date for filtering
        end_date: Optional end date for filtering
        trusts: Optional list of trust names
        drugs: Optional list of drug names
        directories: Optional list of directories
        progress_callback: Optional progress callback

    Returns:
        DataSourceResult from Snowflake

    Raises:
        ValueError: If Snowflake is not available, not configured, or the
            query fails
    """
    from data_processing.snowflake_connector import (
        is_snowflake_available,
        is_snowflake_configured,
    )

    if not is_snowflake_available():
        raise ValueError("Snowflake connector not installed")
    if not is_snowflake_configured():
        raise ValueError("Snowflake not configured - edit config/snowflake.toml")

    filters = (start_date, end_date, trusts, drugs, directories)
    fresh = self._try_snowflake(*filters, progress_callback)
    if fresh is None:
        raise ValueError("Snowflake query failed - check logs for details")

    if self._cache_enabled:
        # An open-ended range, or one reaching today or later, is flagged
        # as including current data.
        still_current = end_date is None or end_date >= date.today()
        self._cache_result(fresh.df, *filters, includes_current_data=still_current)

    return fresh
|
||||
|
||||
|
||||
# Module-level singleton and convenience functions
|
||||
_default_manager: Optional[DataSourceManager] = None
|
||||
|
||||
|
||||
def get_data_source_manager(
    cache_enabled: bool = True,
    local_file_path: Optional[Path | str] = None,
    sqlite_db_path: Optional[Path | str] = None,
) -> DataSourceManager:
    """
    Return a DataSourceManager, reusing the module-level default when possible.

    When either path argument is truthy a brand-new manager is built and
    returned without touching the shared default. Otherwise the shared default
    is returned, being created on the first such call.

    Args:
        cache_enabled: Whether to enable caching. Only used when a manager is
            constructed; an already-created default manager is returned as-is.
        local_file_path: Optional path to local CSV/Parquet file
        sqlite_db_path: Optional path to SQLite database

    Returns:
        DataSourceManager instance
    """
    global _default_manager

    wants_custom_paths = bool(local_file_path or sqlite_db_path)
    if not wants_custom_paths:
        # Shared default: build it lazily, then keep handing out the same object.
        if _default_manager is None:
            _default_manager = DataSourceManager(cache_enabled=cache_enabled)
        return _default_manager

    # Custom locations never replace the shared default.
    return DataSourceManager(
        cache_enabled=cache_enabled,
        local_file_path=local_file_path,
        sqlite_db_path=sqlite_db_path,
    )
|
||||
|
||||
|
||||
def get_data(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    trusts: Optional[list[str]] = None,
    drugs: Optional[list[str]] = None,
    directories: Optional[list[str]] = None,
    preferred_source: Optional[str] = None,
    skip_cache: bool = False,
) -> DataSourceResult:
    """
    Fetch data through the default manager (module-level shortcut).

    Every argument is passed by keyword, unchanged, to the ``get_data`` method
    of the manager returned by ``get_data_source_manager()``.

    Args:
        start_date: Optional start date for filtering
        end_date: Optional end date for filtering
        trusts: Optional list of trust names
        drugs: Optional list of drug names
        directories: Optional list of directories
        preferred_source: Optional preferred source
        skip_cache: If True, bypass cache

    Returns:
        DataSourceResult with loaded data
    """
    request = {
        "start_date": start_date,
        "end_date": end_date,
        "trusts": trusts,
        "drugs": drugs,
        "directories": directories,
        "preferred_source": preferred_source,
        "skip_cache": skip_cache,
    }
    return get_data_source_manager().get_data(**request)
|
||||
|
||||
|
||||
def reset_data_source_manager() -> None:
    """
    Discard the shared default manager.

    The next ``get_data_source_manager()`` call made without custom paths
    builds a fresh instance.
    """
    global _default_manager
    _default_manager = None
|
||||
|
||||
|
||||
# Public API of this module (the names exported by a star-import).
__all__ = [
    "DataSourceType", "DataSourceResult", "SourceStatus", "DataSourceManager",
    "get_data_source_manager", "get_data", "reset_data_source_manager",
]
|
||||
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
SQLite database connection management for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides connection management, schema initialization, and common database operations.
|
||||
Uses context manager pattern for safe resource handling.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional, Generator, Literal
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
# Module-level logger, requested from the project's get_logger helper under this module's name.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseConfig:
    """
    Where the SQLite database lives and how connections to it are opened.

    Attributes:
        db_path: Path to the SQLite database file
        timeout: Connection timeout in seconds (default: 30)
        isolation_level: Transaction isolation level (default: None for autocommit)
    """

    DEFAULT_DB_NAME = "pathways.db"

    def __init__(
        self,
        db_path: Optional[Path] = None,
        data_dir: Optional[Path] = None,
        timeout: float = 30.0,
        isolation_level: Optional[Literal['DEFERRED', 'EXCLUSIVE', 'IMMEDIATE']] = None
    ):
        """
        Build a configuration, resolving the database file location.

        Args:
            db_path: Full path to database file. Takes precedence over data_dir.
            data_dir: Directory to place DEFAULT_DB_NAME in. Defaults to ./data/
            timeout: Connection timeout in seconds.
            isolation_level: Transaction isolation level. None = autocommit.
        """
        if db_path is None:
            # No explicit file: use the default file name inside data_dir (or ./data).
            base_dir = Path("./data") if data_dir is None else Path(data_dir)
            resolved = base_dir / self.DEFAULT_DB_NAME
        else:
            resolved = Path(db_path)

        self.db_path = resolved
        self.timeout = timeout
        self.isolation_level = isolation_level

    def validate(self) -> list[str]:
        """
        Check that the directory which should hold the database file exists.

        Returns:
            List of error messages. Empty list means configuration is valid.
        """
        directory = self.db_path.parent
        if directory.exists():
            return []
        return [f"Database directory does not exist: {directory}"]
|
||||
|
||||
|
||||
class DatabaseManager:
    """
    Manages SQLite database connections and operations.

    Provides context managers for safe connection handling and helper methods
    for common database operations. Every helper opens its own short-lived
    connection and closes it before returning.

    Usage:
        db_manager = DatabaseManager()

        # Using context manager (recommended)
        with db_manager.get_connection() as conn:
            cursor = conn.execute("SELECT * FROM ref_drug_names")
            results = cursor.fetchall()

        # Or get a managed connection for longer operations
        conn = db_manager.connect()
        try:
            # ... do work ...
        finally:
            conn.close()
    """

    def __init__(self, config: Optional["DatabaseConfig"] = None):
        """
        Initialize the database manager.

        Args:
            config: Database configuration. If None, uses default configuration.
        """
        self.config = config or DatabaseConfig()
        # Not read or written by any other method of this class.
        self._connection: Optional[sqlite3.Connection] = None

    @property
    def db_path(self) -> Path:
        """Path to the SQLite database file."""
        return self.config.db_path

    @property
    def exists(self) -> bool:
        """Check if the database file exists."""
        return self.db_path.exists()

    def _open_connection(self, isolation_level: Optional[str]) -> sqlite3.Connection:
        """Open a connection with foreign keys enabled and sqlite3.Row rows.

        If the post-connect setup fails (for example the file is not a SQLite
        database), the half-initialised connection is closed before the error
        propagates so the handle is not leaked.
        """
        conn = sqlite3.connect(
            str(self.db_path),
            timeout=self.config.timeout,
            isolation_level=isolation_level,
        )
        try:
            # Enable foreign key support
            conn.execute("PRAGMA foreign_keys = ON")
            # Return rows as sqlite3.Row for dict-like access
            conn.row_factory = sqlite3.Row
        except Exception:
            conn.close()
            raise
        return conn

    def connect(self) -> sqlite3.Connection:
        """
        Create a new database connection using the configured isolation level.

        Returns:
            sqlite3.Connection: New database connection.

        Note:
            The caller is responsible for closing the connection.
            Consider using get_connection() context manager instead.
        """
        return self._open_connection(self.config.isolation_level)

    @contextmanager
    def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
        """
        Context manager for database connections.

        Rolls back and re-raises if the body raises; always closes the
        connection. It does not commit: callers commit explicitly when the
        configured isolation level requires it.

        Yields:
            sqlite3.Connection: Database connection.

        Example:
            with db_manager.get_connection() as conn:
                conn.execute("INSERT INTO table VALUES (?)", (value,))
                conn.commit()
        """
        conn = self.connect()
        try:
            yield conn
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    @contextmanager
    def get_transaction(self) -> Generator[sqlite3.Connection, None, None]:
        """
        Context manager for transactional operations.

        Automatically commits on success, rolls back on exception. The
        connection always uses DEFERRED isolation, regardless of the
        configured isolation level.

        Yields:
            sqlite3.Connection: Database connection in transaction mode.

        Example:
            with db_manager.get_transaction() as conn:
                conn.execute("INSERT INTO table VALUES (?)", (value1,))
                conn.execute("INSERT INTO other_table VALUES (?)", (value2,))
                # Auto-commits if no exception
        """
        conn = self._open_connection("DEFERRED")  # Explicit transaction mode
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def execute_script(self, sql_script: str) -> None:
        """
        Execute a SQL script (multiple statements).

        Args:
            sql_script: SQL script containing one or more statements.
        """
        with self.get_connection() as conn:
            conn.executescript(sql_script)
            logger.info("Executed SQL script successfully")

    def table_exists(self, table_name: str) -> bool:
        """
        Check if a table exists in the database.

        Args:
            table_name: Name of the table to check.

        Returns:
            True if the table exists, False otherwise.
        """
        with self.get_connection() as conn:
            cursor = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,)
            )
            return cursor.fetchone() is not None

    def get_table_count(self, table_name: str) -> int:
        """
        Get the row count for a table.

        Args:
            table_name: Bare (unquoted, unqualified) name of the table.

        Returns:
            Number of rows in the table.

        Raises:
            sqlite3.OperationalError: If the table does not exist.
        """
        # Identifiers cannot be bound as SQL parameters, so quote the name
        # instead: wrap it in double quotes and double any embedded quote.
        # This keeps reserved words and names with spaces working and stops
        # the name from being interpreted as SQL.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        with self.get_connection() as conn:
            cursor = conn.execute(f"SELECT COUNT(*) FROM {quoted_name}")
            result = cursor.fetchone()
            return result[0] if result else 0
|
||||
|
||||
|
||||
# Default instance for application-wide use
|
||||
# Built with no arguments, so db_path falls back to the relative path ./data/pathways.db.
default_db_config = DatabaseConfig()
|
||||
# Constructing the manager only stores the config; no connection is opened here.
default_db_manager = DatabaseManager(default_db_config)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Data loader abstractions for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides a unified interface for loading patient intervention data from:
|
||||
- CSV/Parquet files (current behavior)
|
||||
- SQLite database (new, faster approach)
|
||||
- Snowflake (future, direct from warehouse)
|
||||
|
||||
The DataLoader ABC defines the contract for all loader implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
|
||||
# Module-level logger, requested from the project's get_logger helper under this module's name.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class LoadResult:
    """Outcome of one data load: the DataFrame plus metadata about the load.

    Attributes:
        df: The loaded DataFrame with processed patient intervention data
        source: Description of the data source (e.g., "file:/path/to/file.csv")
        row_count: Number of rows loaded
        columns: List of column names in the DataFrame; filled in from ``df``
            when left empty
        load_time_seconds: Time taken to load the data
    """
    df: pd.DataFrame
    source: str
    row_count: int
    columns: list[str] = field(default_factory=list)
    load_time_seconds: float = 0.0

    def __post_init__(self):
        # A caller-supplied, non-empty column list is kept untouched.
        if self.columns:
            return
        self.columns = list(self.df.columns)
|
||||
|
||||
|
||||
# Columns every processed DataFrame must carry; generate_graph() expects them.
REQUIRED_COLUMNS = [
    # Patient identity
    "UPID",               # Unique Patient ID (Provider Code prefix + PersonKey)
    # Intervention details
    "Drug Name",          # Standardized drug name
    "Intervention Date",  # Date of intervention
    "Price Actual",       # Cost of intervention
    # Organisation / specialty
    "OrganisationName",   # NHS Trust name
    "Directory",          # Medical specialty/directory
    # Raw identifiers
    "Provider Code",      # NHS provider code
    "PersonKey",          # Patient identifier within provider
]

# Useful extras that are not strictly required.
OPTIONAL_COLUMNS = [
    "UPIDTreatment",            # UPID + Drug Name combo (created by generate_graph)
    "Treatment Function Code",  # NHS treatment function code
    *(f"Additional Detail {n}" for n in range(1, 6)),
]
|
||||
|
||||
|
||||
class DataLoader(ABC):
    """Contract shared by every data loader.

    Concrete loaders implement load(), validate_source() and
    source_description. load() must hand back a DataFrame that is ready for
    generate_graph(), i.e. one containing at least REQUIRED_COLUMNS.
    """

    @abstractmethod
    def load(self) -> LoadResult:
        """Load and process patient intervention data.

        Returns:
            LoadResult containing the processed DataFrame and metadata.
            The DataFrame must contain all REQUIRED_COLUMNS.

        Raises:
            FileNotFoundError: If the data source doesn't exist
            ValueError: If the data is malformed or missing required columns
        """
        ...

    @abstractmethod
    def validate_source(self) -> tuple[bool, str]:
        """Report whether the data source is valid and accessible.

        Returns:
            Tuple of (is_valid, message).
            If is_valid is False, message explains the issue.
        """
        ...

    @property
    @abstractmethod
    def source_description(self) -> str:
        """Human-readable description of the data source."""
        ...

    def validate_dataframe(self, df: pd.DataFrame) -> tuple[bool, list[str]]:
        """Check a DataFrame against REQUIRED_COLUMNS.

        Args:
            df: DataFrame to validate

        Returns:
            Tuple of (is_valid, missing_columns). missing_columns keeps the
            order of REQUIRED_COLUMNS and is empty when is_valid is True.
        """
        present = df.columns
        absent = [name for name in REQUIRED_COLUMNS if name not in present]
        return (not absent, absent)
|
||||
|
||||
|
||||
class FileDataLoader(DataLoader):
    """Loads data from CSV or Parquet files.

    This replicates the current behavior of dashboard_gui.main():
    1. Read CSV or Parquet file
    2. Apply patient_id() transformation
    3. Convert dates
    4. Apply drug_names() standardization
    5. Clean organization names
    6. Apply department_identification()

    Args:
        file_path: Path to the CSV or Parquet file
        paths: PathConfig for reference data file locations (uses default_paths if None)
    """

    def __init__(
        self,
        file_path: Path | str,
        paths: Optional[PathConfig] = None,
    ):
        self.file_path = Path(file_path)
        self.paths = paths or default_paths

    def validate_source(self) -> tuple[bool, str]:
        """Check if the file exists and has a supported extension.

        Returns:
            (True, "OK") when the file exists and ends in .csv or .parquet
            (case-insensitive); otherwise (False, reason).
        """
        if not self.file_path.exists():
            return False, f"File not found: {self.file_path}"

        ext = self.file_path.suffix.lower()
        if ext not in ('.csv', '.parquet'):
            return False, f"Unsupported file type: {ext}. Must be .csv or .parquet"

        return True, "OK"

    @property
    def source_description(self) -> str:
        """Source label of the form ``file:<path>``."""
        return f"file:{self.file_path}"

    def load(self) -> LoadResult:
        """Load and process data from CSV or Parquet file.

        Applies the same transformation pipeline as the original
        dashboard_gui.main() function.

        Returns:
            LoadResult with the processed DataFrame, this loader's source
            description, the row count and the elapsed load time in seconds.

        Raises:
            FileNotFoundError: If validate_source() fails, i.e. the file is
                missing or its extension is not .csv/.parquet.
            ValueError: If the processed DataFrame lacks any REQUIRED_COLUMNS.
        """
        # Function-scope imports: only needed while a load is running.
        import time
        from data_processing import transforms as data

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # (or go negative) if the system clock is adjusted mid-load.
        start_time = time.perf_counter()

        # Validate source before loading
        is_valid, msg = self.validate_source()
        if not is_valid:
            # NOTE(review): an unsupported extension is also reported as
            # FileNotFoundError; kept as-is because callers may catch it.
            raise FileNotFoundError(msg)

        # Read file based on extension
        ext = self.file_path.suffix.lower()
        logger.info(f"Reading {ext} file: {self.file_path}")

        if ext == '.csv':
            df_raw = pd.read_csv(self.file_path, low_memory=False)
        else:  # .parquet
            df_raw = pd.read_parquet(self.file_path)

        logger.info(f"File read successfully. {len(df_raw)} rows.")

        # Apply transformations (same as dashboard_gui.main())
        df = data.patient_id(df_raw)
        logger.info("Patient ID processing complete.")

        df['Intervention Date'] = pd.to_datetime(df['Intervention Date'], format="%Y-%m-%d")
        logger.info("Date conversion complete.")

        # Preserve original drug name before standardization (for SQLite storage)
        df['Drug Name Raw'] = df['Drug Name'].copy()

        df = data.drug_names(df, self.paths)
        logger.info("Drug name processing complete.")

        # Literal comma removal; regex=False makes that explicit instead of
        # depending on the pandas default, which differs between versions.
        df['OrganisationName'] = df['OrganisationName'].str.replace(',', '', regex=False)
        logger.info("Organisation name cleaning complete.")

        df = data.department_identification(df, self.paths)
        logger.info("Department identification complete.")

        # Validate result
        is_valid, missing = self.validate_dataframe(df)
        if not is_valid:
            raise ValueError(f"Processed DataFrame missing required columns: {missing}")

        load_time = time.perf_counter() - start_time
        logger.info(f"Data loading complete. {len(df)} rows in {load_time:.2f}s")

        return LoadResult(
            df=df,
            source=self.source_description,
            row_count=len(df),
            load_time_seconds=load_time,
        )
|
||||
|
||||
|
||||
def get_loader(
    source: str | Path,
    paths: Optional[PathConfig] = None,
    **kwargs
) -> DataLoader:
    """Factory function to create the appropriate DataLoader.

    Only file sources are handled: the function always returns a
    FileDataLoader for ``source``.

    Args:
        source: File path (CSV/Parquet)
        paths: PathConfig for reference data (used by FileDataLoader)
        **kwargs: Accepted but currently unused; nothing is forwarded to the
            loader constructor.

    Returns:
        FileDataLoader instance for ``source``

    Examples:
        >>> loader = get_loader("data/activity.csv")
        >>> loader = get_loader("data/activity.parquet")
    """
    return FileDataLoader(file_path=Path(source), paths=paths)
|
||||
@@ -0,0 +1,469 @@
|
||||
"""
|
||||
Database migration script for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Provides functions to initialize the SQLite database schema and CLI interface
|
||||
for running migrations from the command line.
|
||||
|
||||
Usage:
|
||||
# Initialize database (creates all tables)
|
||||
python -m data_processing.migrate
|
||||
|
||||
# Drop existing tables and reinitialize
|
||||
python -m data_processing.migrate --drop-existing
|
||||
|
||||
# Show current database status
|
||||
python -m data_processing.migrate --status
|
||||
|
||||
# Migrate all reference data from CSV files
|
||||
python -m data_processing.migrate --reference-data
|
||||
|
||||
# Migrate reference data with verification
|
||||
python -m data_processing.migrate --reference-data --verify
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# When launched as `python -m data_processing.migrate`, src/ may be missing from
# sys.path. Prepend the directory two levels above this file (once) so the
# `core` and `data_processing` imports that follow can be resolved.
_src_dir = str(Path(__file__).resolve().parent.parent)
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)
|
||||
|
||||
from core.logging_config import setup_logging, get_logger
|
||||
from data_processing.database import DatabaseManager, DatabaseConfig
|
||||
from core import PathConfig, default_paths
|
||||
from data_processing.schema import (
|
||||
create_all_tables,
|
||||
drop_all_tables,
|
||||
verify_all_tables_exist,
|
||||
get_all_table_counts,
|
||||
migrate_pathway_nodes_chart_type,
|
||||
migrate_refresh_log_source_row_count,
|
||||
)
|
||||
from data_processing.reference_data import (
|
||||
MigrationResult,
|
||||
migrate_drug_names,
|
||||
migrate_organizations,
|
||||
migrate_directories,
|
||||
migrate_drug_directory_map,
|
||||
migrate_drug_indication_clusters,
|
||||
verify_drug_names_migration,
|
||||
verify_organizations_migration,
|
||||
verify_directories_migration,
|
||||
verify_drug_directory_map_migration,
|
||||
verify_drug_indication_clusters_migration,
|
||||
)
|
||||
|
||||
# Module-level logger, requested from the project's get_logger helper under this module's name.
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def initialize_database(
    db_manager: Optional[DatabaseManager] = None,
    drop_existing: bool = False,
    confirm_drop: bool = True
) -> bool:
    """
    Create the full schema, optionally dropping existing tables first.

    Steps, in order: optional confirmation prompt and table drop, table
    creation inside a transaction, the two column-adding schema migrations,
    and a final check that every expected table exists.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        drop_existing: If True, drops all existing tables before creating.
        confirm_drop: If True and drop_existing=True, prompts for confirmation.
            Set to False for non-interactive use.

    Returns:
        True if initialization succeeded, False otherwise (including when the
        user declines the drop confirmation). Errors raised while dropping
        tables or during the final verification are not caught here.
    """
    if db_manager is None:
        db_manager = DatabaseManager()

    logger.info(f"Initializing database at: {db_manager.db_path}")

    # Ask before destroying data, unless the caller opted out of the prompt.
    if drop_existing and confirm_drop:
        print("\nWARNING: This will delete ALL data from the database:")
        print(f" {db_manager.db_path}\n")
        answer = input("Are you sure you want to continue? (yes/no): ")
        if answer.lower() not in ("yes", "y"):
            print("Operation cancelled.")
            return False

    if drop_existing:
        if not db_manager.exists:
            logger.info("Database does not exist yet, nothing to drop")
        else:
            logger.warning("Dropping existing tables...")
            with db_manager.get_connection() as conn:
                drop_all_tables(conn)
                conn.commit()
            logger.info("Existing tables dropped")

    # Create all tables
    try:
        with db_manager.get_transaction() as conn:
            create_all_tables(conn)
    except Exception as e:
        logger.error(f"Failed to create tables: {e}")
        return False

    # Schema migrations, run in this order on one connection:
    #   pathway_nodes        -> add chart_type column if it doesn't exist
    #   pathway_refresh_log  -> add source_row_count column if it doesn't exist
    schema_migrations = (
        ("pathway_nodes", migrate_pathway_nodes_chart_type),
        ("pathway_refresh_log", migrate_refresh_log_source_row_count),
    )
    try:
        with db_manager.get_connection() as conn:
            for label, apply_migration in schema_migrations:
                ok, detail = apply_migration(conn)
                if not ok:
                    logger.error(f"{label} migration failed: {detail}")
                    return False
                logger.info(f"{label} migration: {detail}")
    except Exception as e:
        logger.error(f"Migration failed: {e}")
        return False

    # Verify all tables were created
    with db_manager.get_connection() as conn:
        missing = verify_all_tables_exist(conn)

    if missing:
        logger.error(f"Table creation failed. Missing tables: {missing}")
        return False

    logger.info("All tables created successfully")
    return True
|
||||
|
||||
|
||||
def migrate_all_reference_data(
    db_manager: Optional[DatabaseManager] = None,
    paths: Optional[PathConfig] = None,
    verify: bool = False
) -> tuple[bool, list[MigrationResult]]:
    """
    Run every reference data migration from CSV files to SQLite tables.

    Migrations are run in order:
    1. Drug names (drugnames.csv → ref_drug_names)
    2. Organizations (org_codes.csv → ref_organizations)
    3. Directories (directory_list.csv → ref_directories)
    4. Drug-directory mappings (drug_directory_list.csv → ref_drug_directory_map)
    5. Drug indication clusters (called with db_manager only; ``paths`` is
       not passed to this migration or its verification)

    A failed migration is recorded and skipped for verification; the remaining
    migrations still run.

    Args:
        db_manager: DatabaseManager instance. Uses default if not provided.
        paths: PathConfig instance for locating CSV files. Uses default if not provided.
        verify: If True, runs verification after each successful migration.

    Returns:
        Tuple of (all_success: bool, results: list of MigrationResult)
    """
    if db_manager is None:
        db_manager = DatabaseManager()
    if paths is None:
        paths = default_paths

    # (label, migration, verification, accepts a `paths` keyword)
    steps = (
        ("Drug names", migrate_drug_names, verify_drug_names_migration, True),
        ("Organizations", migrate_organizations, verify_organizations_migration, True),
        ("Directories", migrate_directories, verify_directories_migration, True),
        ("Drug-directory map", migrate_drug_directory_map, verify_drug_directory_map_migration, True),
        ("Drug indication clusters", migrate_drug_indication_clusters, verify_drug_indication_clusters_migration, False),
    )

    results: list[MigrationResult] = []
    all_success = True

    logger.info(f"Starting reference data migrations ({len(steps)} tables)")

    for label, run_migration, run_check, takes_paths in steps:
        logger.info(f"Migrating: {label}...")

        # Same keyword arguments go to the migration and to its verification.
        call_kwargs = {"db_manager": db_manager}
        if takes_paths:
            call_kwargs["paths"] = paths

        outcome = run_migration(**call_kwargs)  # type: ignore[operator]
        results.append(outcome)

        if not outcome.success:
            logger.error(f"Migration failed: {label} - {outcome.error_message}")
            all_success = False
            continue

        logger.info(f" {outcome}")

        if not verify:
            continue

        logger.info(f" Verifying {label}...")
        passed, detail = run_check(**call_kwargs)  # type: ignore[operator]
        if passed:
            logger.info(f" OK: {detail}")
        else:
            logger.error(f" FAILED: Verification failed: {detail}")
            all_success = False

    # Summary
    succeeded = len([entry for entry in results if entry.success])
    logger.info(f"Reference data migrations complete: {succeeded}/{len(results)} succeeded")

    return all_success, results
|
||||
|
||||
|
||||
def print_migration_summary(results: list[MigrationResult]) -> None:
    """Write a per-table report of migration results, plus a total, to stdout."""
    print("\n=== Reference Data Migration Summary ===\n")

    ok_count = 0
    for entry in results:
        if entry.success:
            ok_count += 1
            print(f"[OK] {entry.table_name}")
            print(f" Read: {entry.rows_read}, Inserted: {entry.rows_inserted}, Skipped: {entry.rows_skipped}")
        else:
            print(f"[FAILED] {entry.table_name}")
            print(f" Error: {entry.error_message}")

    print(f"\nTotal: {ok_count}/{len(results)} migrations succeeded")
    print()
|
||||
|
||||
|
||||
def create_progress_reporter(description: str = "Loading", width: int = 40):
    """
    Create a progress callback that prints a progress bar to stdout.

    The returned callback redraws the bar only when the integer percentage
    changes. It prints a trailing newline on a redraw where
    ``current >= total``.

    Args:
        description: Label to show before the progress bar.
        width: Width of the progress bar in characters.

    Returns:
        Callback function(current, total) that prints progress.
    """
    last_percent = -1  # percentage drawn most recently; -1 means nothing drawn yet

    def report_progress(current: int, total: int) -> None:
        """Print a progress bar showing current/total progress."""
        nonlocal last_percent

        if total == 0:
            percent = 100
        else:
            percent = int(100 * current / total)

        # Only update display when percentage changes (avoid excessive output)
        if percent == last_percent:
            return
        last_percent = percent

        if total > 0:
            # Clamp the fill so an over-reported (current > total) or negative
            # `current` cannot make the bar longer than `width` characters.
            filled = max(0, min(width, int(width * current / total)))
        else:
            filled = width
        bar = "=" * filled + "-" * (width - filled)

        # Use carriage return to overwrite the line
        sys.stdout.write(f"\r{description}: [{bar}] {percent:3d}% ({current:,}/{total:,})")
        sys.stdout.flush()

        # Print newline when complete
        if current >= total:
            print()

    return report_progress
|
||||
|
||||
|
||||
def get_database_status(db_manager: Optional[DatabaseManager] = None) -> dict:
    """
    Collect a snapshot of the database file and its tables.

    Args:
        db_manager: DatabaseManager instance. A default one is created when omitted.

    Returns:
        Dictionary with database status information:
        - exists: Whether the database file exists
        - path: Path to the database file
        - size_bytes: Size of database file (None if it does not exist)
        - tables: Dictionary of table names to row counts (left empty if the
          counts could not be read)
        - missing_tables: List of expected tables that don't exist
    """
    manager = DatabaseManager() if db_manager is None else db_manager

    status = {
        "exists": manager.exists,
        "path": str(manager.db_path),
        "size_bytes": None,
        "tables": {},
        "missing_tables": [],
    }

    # Nothing more to inspect when there is no file on disk.
    if not manager.exists:
        return status

    status["size_bytes"] = manager.db_path.stat().st_size

    with manager.get_connection() as conn:
        status["missing_tables"] = verify_all_tables_exist(conn)

        # Row counts are best-effort: a failure is logged, not raised.
        try:
            status["tables"] = get_all_table_counts(conn)
        except Exception as e:
            logger.warning(f"Could not get table counts: {e}")

    return status
|
||||
|
||||
|
||||
def print_database_status(db_manager: Optional[DatabaseManager] = None) -> None:
    """Render the result of get_database_status() as readable text on stdout."""
    status = get_database_status(db_manager)

    print("\n=== Database Status ===\n")
    print(f"Path: {status['path']}")
    print(f"Exists: {status['exists']}")

    if not status["exists"]:
        print("\nDatabase does not exist. Run migration to create it.")
        print()
        return

    # size_bytes may be None; treat that as zero for display.
    size_kb = (status["size_bytes"] or 0) / 1024
    print(f"Size: {size_kb:.1f} KB")

    missing_tables = status["missing_tables"]
    if missing_tables:
        print(f"\nMissing tables: {', '.join(missing_tables)}")
    else:
        print("\nAll expected tables exist.")

    table_counts = status["tables"]
    if table_counts:
        print("\nTable row counts:")
        for table in sorted(table_counts):
            print(f" {table}: {table_counts[table]:,} rows")

    print()
|
||||
|
||||
|
||||
def main(argv: Optional[list[str]] = None) -> int:
    """
    CLI entry point for database migration.

    Args:
        argv: Command-line arguments without the program name. When None
            (the default) argparse reads sys.argv[1:], so the existing
            ``sys.exit(main())`` call keeps working unchanged. Passing an
            explicit list lets the CLI be driven programmatically.

    Returns:
        Process exit code: 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(
        description="Initialize NHS Pathways Analysis SQLite database schema",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m data_processing.migrate # Initialize database
python -m data_processing.migrate --status # Show database status
python -m data_processing.migrate --drop-existing # Reset database
python -m data_processing.migrate --reference-data # Migrate reference data
python -m data_processing.migrate --reference-data --verify # With verification
python -m data_processing.migrate --db-path ./data/test.db # Custom path
"""
    )

    parser.add_argument(
        "--status",
        action="store_true",
        help="Show current database status and exit"
    )
    parser.add_argument(
        "--drop-existing",
        action="store_true",
        help="Drop all existing tables before creating (WARNING: deletes data)"
    )
    parser.add_argument(
        "--reference-data",
        action="store_true",
        help="Migrate all reference data from CSV files to SQLite tables"
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify migrated data matches CSV sources (use with --reference-data)"
    )
    parser.add_argument(
        "--db-path",
        type=Path,
        help="Path to database file (default: ./data/pathways.db)"
    )
    parser.add_argument(
        "--yes", "-y",
        action="store_true",
        help="Skip confirmation prompts (for non-interactive use)"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )
    args = parser.parse_args(argv)

    # Set up logging
    log_level = "DEBUG" if args.verbose else "INFO"
    setup_logging(level=log_level, simple_console=True)

    # Create database manager with optional custom path
    if args.db_path:
        config = DatabaseConfig(db_path=args.db_path)
        db_manager = DatabaseManager(config)
    else:
        db_manager = DatabaseManager()

    # Handle --status (read-only; runs before any validation or migration)
    if args.status:
        print_database_status(db_manager)
        return 0

    # Surface flag combinations that would otherwise be silently ignored.
    if args.verify and not args.reference_data:
        logger.warning("--verify has no effect without --reference-data; ignoring it")
    if args.drop_existing and args.reference_data:
        logger.warning("--drop-existing is ignored when --reference-data is given")

    # Validate configuration
    config_errors = db_manager.config.validate()
    if config_errors:
        for error in config_errors:
            logger.error(error)
        return 1

    # Handle --reference-data (migrate reference data from CSV to SQLite)
    if args.reference_data:
        # Only a missing database file triggers schema creation here.
        if not db_manager.exists:
            print("Database does not exist. Initializing schema first...")
            success = initialize_database(db_manager=db_manager)
            if not success:
                print("\nDatabase initialization failed. Check logs for details.")
                return 1

        # Run reference data migrations
        success, results = migrate_all_reference_data(
            db_manager=db_manager,
            paths=default_paths,
            verify=args.verify
        )

        print_migration_summary(results)
        print_database_status(db_manager)

        if success:
            print("Reference data migration completed successfully.")
            return 0
        print("Reference data migration completed with errors. Check logs for details.")
        return 1

    # Run schema migration (default behavior)
    success = initialize_database(
        db_manager=db_manager,
        drop_existing=args.drop_existing,
        confirm_drop=not args.yes
    )

    if success:
        print("\nDatabase initialized successfully.")
        print_database_status(db_manager)
        return 0
    print("\nDatabase initialization failed. Check logs for details.")
    return 1
|
||||
|
||||
|
||||
# Script entry: run the CLI and use main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
@@ -0,0 +1,642 @@
|
||||
"""
|
||||
Pathway data processing pipeline.
|
||||
|
||||
This module provides functions to:
|
||||
1. Fetch and transform raw intervention data from Snowflake
|
||||
2. Process data for each of the 6 date filter combinations
|
||||
3. Extract denormalized fields from hierarchical path strings
|
||||
4. Convert processed data to records for SQLite storage
|
||||
|
||||
The pipeline integrates with:
|
||||
- analysis/pathway_analyzer.py: generate_icicle_chart() for pathway processing
|
||||
- data_processing/snowflake_connector.py: fetch_activity_data() for data retrieval
|
||||
- data_processing/transforms.py: patient_id(), drug_names(), department_identification()
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
from typing import Optional, Literal
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
from analysis.pathway_analyzer import generate_icicle_chart, generate_icicle_chart_indication
|
||||
from data_processing.transforms import patient_id, drug_names, department_identification
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Type alias for the chart types this module handles:
# "directory" charts group pathways by directory, "indication" charts group
# them by indication (Search_Term or a fallback directorate).
ChartType = Literal["directory", "indication"]
|
||||
|
||||
|
||||
@dataclass
class DateFilterConfig:
    """
    Configuration for a date filter combination.

    Pairs an "initiated within the last N years" window with a
    "last seen within the last N months" window; compute_date_ranges()
    turns an instance into concrete date strings.
    """

    # Identifier of the combination, e.g. 'all_6mo', '1yr_12mo'
    id: str
    # Length of the initiated window in years: None for 'All', 1, or 2
    initiated_years: Optional[int]
    # Length of the last-seen window in months: 6 or 12
    last_seen_months: int
|
||||
|
||||
|
||||
# Pre-defined date filter configurations matching pathway_date_filters table.
# Every initiated window (all / 1 year / 2 years) is combined with every
# last-seen window (6 / 12 months), giving six combinations. The ids are
# spelled out literally so they can be searched for directly.
DATE_FILTER_CONFIGS = [
    DateFilterConfig(id="all_6mo", initiated_years=None, last_seen_months=6),
    DateFilterConfig(id="all_12mo", initiated_years=None, last_seen_months=12),
    DateFilterConfig(id="1yr_6mo", initiated_years=1, last_seen_months=6),
    DateFilterConfig(id="1yr_12mo", initiated_years=1, last_seen_months=12),
    DateFilterConfig(id="2yr_6mo", initiated_years=2, last_seen_months=6),
    DateFilterConfig(id="2yr_12mo", initiated_years=2, last_seen_months=12),
]
|
||||
|
||||
|
||||
def compute_date_ranges(
    config: DateFilterConfig,
    max_date: Optional[date] = None,
) -> tuple[str, str, str]:
    """
    Compute actual date strings from a date filter configuration.

    Args:
        config: DateFilterConfig with initiated_years and last_seen_months
        max_date: Reference date (defaults to today)

    Returns:
        Tuple of (start_date, end_date, last_seen_date) as ISO format strings
        - start_date: Start of initiated filter period; 2000-01-01 when
          initiated_years is None, otherwise max_date shifted back by
          initiated_years calendar years
        - end_date: End of initiated filter period (always max_date)
        - last_seen_date: Date threshold for last_seen filter, computed as
          max_date minus last_seen_months * 30 days (a month is approximated
          as 30 days, so this is not a calendar-month shift)
    """
    if max_date is None:
        max_date = date.today()

    # The initiated window always ends at the reference date.
    end_date = max_date

    if config.initiated_years is None:
        # "All years" - use a very old date
        start_date = date(2000, 1, 1)
    else:
        # Last N years from max_date
        target_year = max_date.year - config.initiated_years
        try:
            start_date = max_date.replace(year=target_year)
        except ValueError:
            # date.replace() rejects 29 February in a non-leap target year;
            # clamp to 28 February. Any other failure (e.g. a year below
            # date.min.year) is a real error and is re-raised.
            if not (max_date.month == 2 and max_date.day == 29):
                raise
            start_date = max_date.replace(year=target_year, day=28)

    # Patients must have been seen within the last N months (30-day months).
    last_seen_date = max_date - timedelta(days=config.last_seen_months * 30)

    return (
        start_date.isoformat(),
        end_date.isoformat(),
        last_seen_date.isoformat(),
    )
|
||||
|
||||
|
||||
def fetch_and_transform_data(
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    provider_codes: Optional[list[str]] = None,
    paths: Optional[PathConfig] = None,
) -> pd.DataFrame:
    """
    Pull activity data from Snowflake and run the standard transformations.

    Steps, in order:
    1. Fetch raw activity records through the Snowflake connector
    2. patient_id() - adds the UPID column
    3. drug_names() - standardizes drug names; rows whose 'Drug Name' is
       NaN afterwards are dropped
    4. department_identification() - assigns directories
    5. Coerce 'Intervention Date' to datetime

    Args:
        start_date: Optional start date forwarded to the Snowflake query
        end_date: Optional end date forwarded to the Snowflake query
        provider_codes: Optional provider codes forwarded to the query
        paths: PathConfig for file paths (default_paths when None)

    Returns:
        The transformed DataFrame, or an empty DataFrame when Snowflake
        returned no records.

    Raises:
        ImportError: If snowflake-connector-python is not installed
        Connector errors propagate unchanged (presumably
        SnowflakeConnectionError - confirm against snowflake_connector).
    """
    # Import here to avoid circular imports and handle optional dependency
    from data_processing.snowflake_connector import get_connector, is_snowflake_available

    resolved_paths = default_paths if paths is None else paths

    if not is_snowflake_available():
        raise ImportError(
            "snowflake-connector-python is not installed. "
            "Install it with: pip install snowflake-connector-python"
        )

    logger.info("Fetching activity data from Snowflake...")

    fetched_rows = get_connector().fetch_activity_data(
        start_date=start_date,
        end_date=end_date,
        provider_codes=provider_codes,
        max_rows=0,  # No limit
    )

    if not fetched_rows:
        logger.warning("No data returned from Snowflake")
        return pd.DataFrame()

    logger.info(f"Fetched {len(fetched_rows)} records from Snowflake")

    frame = pd.DataFrame(fetched_rows)

    logger.info("Applying data transformations...")

    # Step 2: UPID generation
    frame = patient_id(frame)
    logger.info(f"Generated UPID for {frame['UPID'].nunique()} unique patients")

    # Step 3: drug-name standardization, then discard unmapped rows (NaN)
    frame = drug_names(frame, resolved_paths)
    rows_before_drop = len(frame)
    frame = frame.dropna(subset=['Drug Name'])
    dropped = rows_before_drop - len(frame)
    if dropped:
        logger.info(f"Removed {dropped} rows with unmapped drug names")

    # Step 4: directory assignment
    frame = department_identification(frame, resolved_paths)
    logger.info(f"Assigned directories to {len(frame)} records")

    # Step 5: make sure the date column is a real datetime
    frame['Intervention Date'] = pd.to_datetime(frame['Intervention Date'])

    logger.info(f"Data transformation complete. Final record count: {len(frame)}")
    return frame
|
||||
|
||||
|
||||
def process_pathway_for_date_filter(
    df: pd.DataFrame,
    config: DateFilterConfig,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    minimum_patients: int = 5,
    max_date: Optional[date] = None,
    paths: Optional[PathConfig] = None,
) -> Optional[pd.DataFrame]:
    """
    Build the directory-based pathway hierarchy for one date filter.

    Resolves the filter's concrete dates with compute_date_ranges() and
    delegates the hierarchy construction to generate_icicle_chart().

    Args:
        df: Transformed DataFrame from fetch_and_transform_data()
        config: DateFilterConfig specifying the date filter combination
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        minimum_patients: Forwarded as minimum_num_patients
        max_date: Reference date for computing date ranges
        paths: PathConfig for file paths (default_paths when None)

    Returns:
        The hierarchy DataFrame (ice_df), or None when the analyzer returned
        None or an empty frame.
    """
    resolved_paths = default_paths if paths is None else paths

    window_start, window_end, seen_cutoff = compute_date_ranges(config, max_date)

    logger.info(f"Processing pathway for {config.id}")
    logger.info(f" Date range: {window_start} to {window_end}")
    logger.info(f" Last seen after: {seen_cutoff}")

    # The returned title is not needed here; only the hierarchy frame is used.
    hierarchy, _title = generate_icicle_chart(
        df=df,
        start_date=window_start,
        end_date=window_end,
        last_seen_date=seen_cutoff,
        trust_filter=trust_filter,
        drug_filter=drug_filter,
        directory_filter=directory_filter,
        minimum_num_patients=minimum_patients,
        title="",
        paths=resolved_paths,
    )

    if hierarchy is not None and len(hierarchy) > 0:
        logger.info(f"Generated {len(hierarchy)} pathway nodes for {config.id}")
        return hierarchy

    logger.warning(f"No pathway data for filter {config.id}")
    return None
|
||||
|
||||
|
||||
def process_indication_pathway_for_date_filter(
    df: pd.DataFrame,
    indication_df: pd.DataFrame,
    config: DateFilterConfig,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    minimum_patients: int = 5,
    max_date: Optional[date] = None,
    paths: Optional[PathConfig] = None,
) -> Optional[pd.DataFrame]:
    """
    Build the indication-based pathway hierarchy for one date filter.

    Mirrors process_pathway_for_date_filter() but delegates to
    generate_icicle_chart_indication(), which groups by indication
    (Search_Term from GP diagnosis) instead of directory.

    Hierarchy: Trust → Indication_Group → Drug → Pathway

    Args:
        df: Transformed DataFrame from fetch_and_transform_data()
        indication_df: DataFrame with UPID → Indication_Group mapping
            Must have columns: UPID, Indication_Group
            Indication_Group is either Search_Term or "Directory (no GP dx)"
        config: DateFilterConfig specifying the date filter combination
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        minimum_patients: Forwarded as minimum_num_patients
        max_date: Reference date for computing date ranges
        paths: PathConfig for file paths (default_paths when None)

    Returns:
        The hierarchy DataFrame (ice_df), or None when the analyzer returned
        None or an empty frame.
    """
    resolved_paths = default_paths if paths is None else paths

    window_start, window_end, seen_cutoff = compute_date_ranges(config, max_date)

    logger.info(f"Processing indication pathway for {config.id}")
    logger.info(f" Date range: {window_start} to {window_end}")
    logger.info(f" Last seen after: {seen_cutoff}")

    # The returned title is not needed here; only the hierarchy frame is used.
    hierarchy, _title = generate_icicle_chart_indication(
        df=df,
        indication_df=indication_df,
        start_date=window_start,
        end_date=window_end,
        last_seen_date=seen_cutoff,
        trust_filter=trust_filter,
        drug_filter=drug_filter,
        directory_filter=directory_filter,
        minimum_num_patients=minimum_patients,
        title="",
        paths=resolved_paths,
    )

    if hierarchy is not None and len(hierarchy) > 0:
        logger.info(f"Generated {len(hierarchy)} indication pathway nodes for {config.id}")
        return hierarchy

    logger.warning(f"No indication pathway data for filter {config.id}")
    return None
|
||||
|
||||
|
||||
def extract_denormalized_fields(ice_df: pd.DataFrame) -> pd.DataFrame:
    """
    Derive flat filter columns from the hierarchical ids column.

    Each ids value is a " - " separated path, for example:
    - "N&WICS" (root)
    - "N&WICS - NNUH" (trust level)
    - "N&WICS - NNUH - OPHTHALMOLOGY" (directory level)
    - "N&WICS - NNUH - OPHTHALMOLOGY - RANIBIZUMAB" (first drug)
    - "N&WICS - NNUH - OPHTHALMOLOGY - RANIBIZUMAB - AFLIBERCEPT" (pathway)

    Columns added (empty string when the path is too short):
    - trust_name: second path segment
    - directory: third path segment
    - drug_sequence: remaining segments joined with "|"

    Args:
        ice_df: DataFrame from generate_icicle_chart(); must have an ids column

    Returns:
        A copy of ice_df with trust_name, directory and drug_sequence added.
        The input frame is not modified.
    """
    blank = ("", "", "")

    def split_path(path: str) -> tuple[str, str, str]:
        """Return (trust, directory, drug sequence) for one ids value."""
        if not path or pd.isna(path):
            return blank

        segments = path.split(" - ")
        if len(segments) < 2:
            # Root node: only the top-level label is present.
            return blank

        trust = segments[1]
        directory = segments[2] if len(segments) >= 3 else ""
        # Slicing past the end yields [], which joins to "".
        drugs = "|".join(segments[3:])
        return (trust, directory, drugs)

    result = ice_df.copy()
    components = result['ids'].apply(split_path)

    for position, column in enumerate(('trust_name', 'directory', 'drug_sequence')):
        result[column] = components.apply(lambda triple, i=position: triple[i])

    logger.info(f"Extracted denormalized fields for {len(result)} nodes")
    logger.info(f" Unique trusts: {result['trust_name'].nunique()}")
    logger.info(f" Unique directories: {result['directory'].nunique()}")

    return result
|
||||
|
||||
|
||||
def extract_indication_fields(ice_df: pd.DataFrame) -> pd.DataFrame:
    """
    Derive flat filter columns from the ids column of an indication chart.

    Works like extract_denormalized_fields(), except the third path segment
    is a Search_Term (or a fallback directorate) rather than a Directory.

    Each ids value is a " - " separated path, for example:
    - "N&WICS" (root)
    - "N&WICS - NNUH" (trust level)
    - "N&WICS - NNUH - rheumatoid arthritis" (search_term level - matched patient)
    - "N&WICS - NNUH - RHEUMATOLOGY (no GP dx)" (fallback level - unmatched patient)
    - "N&WICS - NNUH - rheumatoid arthritis - ADALIMUMAB" (first drug)
    - "N&WICS - NNUH - rheumatoid arthritis - ADALIMUMAB - ETANERCEPT" (pathway)

    Columns added (empty string when the path is too short):
    - trust_name: second path segment
    - directory: third path segment, i.e. the search_term; stored under
      'directory' for schema compatibility with the pathway_nodes table
    - drug_sequence: remaining segments joined with "|"

    Args:
        ice_df: DataFrame from generate_icicle_chart() with indication grouping

    Returns:
        A copy of ice_df with trust_name, directory (=search_term) and
        drug_sequence added. The input frame is not modified.
    """
    blank = ("", "", "")

    def split_path(path: str) -> tuple[str, str, str]:
        """Return (trust, search_term, drug sequence) for one ids value."""
        if not path or pd.isna(path):
            return blank

        segments = path.split(" - ")
        if len(segments) < 2:
            # Root node: only the top-level label is present.
            return blank

        trust = segments[1]
        search_term = segments[2] if len(segments) >= 3 else ""
        # Slicing past the end yields [], which joins to "".
        drugs = "|".join(segments[3:])
        return (trust, search_term, drugs)

    result = ice_df.copy()
    components = result['ids'].apply(split_path)

    # The middle component (search_term) is written to 'directory'.
    for position, column in enumerate(('trust_name', 'directory', 'drug_sequence')):
        result[column] = components.apply(lambda triple, i=position: triple[i])

    logger.info(f"Extracted indication fields for {len(result)} nodes")
    logger.info(f" Unique trusts: {result['trust_name'].nunique()}")
    logger.info(f" Unique search_terms: {result['directory'].nunique()}")

    return result
|
||||
|
||||
|
||||
def convert_to_records(
    ice_df: pd.DataFrame,
    date_filter_id: str,
    refresh_id: Optional[str] = None,
    chart_type: ChartType = "directory",
) -> list[dict]:
    """
    Turn each ice_df row into a dictionary for SQLite insertion.

    Column handling (missing or null source values use the stated fallback):
    - parents, ids, labels: str, fallback ''
    - level, value: int, fallback 0
    - cost, colour: float, fallback 0.0
    - costpp: float, fallback None
    - cost_pp_pa, average_spacing: str, fallback None
    - first_seen, last_seen: from 'First seen' / 'Last seen'; isoformat()
      when the value has it, otherwise str()
    - first_seen_parent, last_seen_parent: str() of the '(Parent)' columns
    - average_administered: JSON text for lists/arrays, str() for other
      values, None when null
    - avg_days: days as float (total_seconds() / 86400 when available)
    - trust_name, directory, drug_sequence: passed through, None when null
    - date_filter_id, chart_type, data_refresh_id: taken from the arguments

    Args:
        ice_df: DataFrame from generate_icicle_chart() with denormalized fields
        date_filter_id: The date filter combination ID (e.g., 'all_6mo')
        refresh_id: Optional refresh tracking ID
        chart_type: Chart type - "directory" (default) or "indication"

    Returns:
        List of dictionaries, one per row, in row order.
    """

    def iso_or_str(value):
        """Format a seen-date: ISO when possible, str() otherwise, None if null."""
        if pd.notna(value):
            return value.isoformat() if hasattr(value, 'isoformat') else str(value)
        return None

    def cast_or(value, caster, fallback):
        """Apply caster to a non-null value; return fallback for null."""
        return caster(value) if pd.notna(value) else fallback

    def keep_or_none(value):
        """Pass a non-null value through untouched; null becomes None."""
        return value if pd.notna(value) else None

    def administered_to_text(value):
        """Serialize average_administered (list, ndarray, scalar or None)."""
        if value is None:
            return None
        try:
            if pd.isna(value):
                return None
            if isinstance(value, (list, np.ndarray)):
                return json.dumps(list(value) if hasattr(value, 'tolist') else value)
            return str(value)
        except (ValueError, TypeError):
            # pd.isna raises ValueError for arrays with >1 element
            # In that case, value is an array/list, so convert to JSON
            if hasattr(value, 'tolist'):
                return json.dumps(value.tolist())
            if isinstance(value, list):
                return json.dumps(value)
            return str(value)

    def days_as_float(value):
        """Convert avg_days to a float number of days, or None when null."""
        if pd.notna(value) and hasattr(value, 'total_seconds'):
            return float(value.total_seconds() / 86400)
        return float(value) if pd.notna(value) else None

    def to_record(node) -> dict:
        """Build the record dictionary for a single row."""
        # Dates and the serialized list are resolved before the dict is built,
        # in the same order as the columns are listed below.
        first_seen = iso_or_str(node.get('First seen'))
        last_seen = iso_or_str(node.get('Last seen'))
        first_seen_parent = cast_or(node.get('First seen (Parent)'), str, None)
        last_seen_parent = cast_or(node.get('Last seen (Parent)'), str, None)
        average_administered = administered_to_text(node.get('average_administered'))

        return {
            'date_filter_id': date_filter_id,
            'chart_type': chart_type,
            'parents': cast_or(node.get('parents'), str, ''),
            'ids': cast_or(node.get('ids'), str, ''),
            'labels': cast_or(node.get('labels'), str, ''),
            'level': cast_or(node.get('level'), int, 0),
            'value': cast_or(node.get('value'), int, 0),
            'cost': cast_or(node.get('cost'), float, 0.0),
            'costpp': cast_or(node.get('costpp'), float, None),
            'cost_pp_pa': cast_or(node.get('cost_pp_pa'), str, None),
            'colour': cast_or(node.get('colour'), float, 0.0),
            'first_seen': first_seen,
            'last_seen': last_seen,
            'first_seen_parent': first_seen_parent,
            'last_seen_parent': last_seen_parent,
            'average_spacing': cast_or(node.get('average_spacing'), str, None),
            'average_administered': average_administered,
            'avg_days': days_as_float(node.get('avg_days')),
            'trust_name': keep_or_none(node.get('trust_name')),
            'directory': keep_or_none(node.get('directory')),
            'drug_sequence': keep_or_none(node.get('drug_sequence')),
            'data_refresh_id': refresh_id,
        }

    records = [to_record(node) for _, node in ice_df.iterrows()]

    logger.info(f"Converted {len(records)} pathway nodes to records for {date_filter_id} ({chart_type})")
    return records
|
||||
|
||||
|
||||
def process_all_date_filters(
    df: pd.DataFrame,
    trust_filter: list[str],
    drug_filter: list[str],
    directory_filter: list[str],
    minimum_patients: int = 5,
    max_date: Optional[date] = None,
    refresh_id: Optional[str] = None,
    paths: Optional[PathConfig] = None,
) -> dict[str, list[dict]]:
    """
    Run the directory-chart pipeline for every entry in DATE_FILTER_CONFIGS.

    For each configuration this calls process_pathway_for_date_filter(),
    extract_denormalized_fields() and convert_to_records() in turn.
    Records are produced with convert_to_records()'s default chart_type.

    Args:
        df: Transformed DataFrame from fetch_and_transform_data()
        trust_filter: List of trust names to include
        drug_filter: List of drug names to include
        directory_filter: List of directories to include
        minimum_patients: Minimum patients to include a pathway
        max_date: Reference date for computing date ranges
        refresh_id: Optional refresh tracking ID
        paths: PathConfig for file paths (default_paths when None)

    Returns:
        Dictionary mapping date_filter_id to list of record dicts,
        e.g., {"all_6mo": [...], "all_12mo": [...], ...}. A filter that
        produced no data maps to an empty list.
    """
    resolved_paths = default_paths if paths is None else paths

    records_by_filter = {}

    for filter_config in DATE_FILTER_CONFIGS:
        filter_id = filter_config.id
        logger.info(f"Processing date filter: {filter_id}")

        hierarchy = process_pathway_for_date_filter(
            df=df,
            config=filter_config,
            trust_filter=trust_filter,
            drug_filter=drug_filter,
            directory_filter=directory_filter,
            minimum_patients=minimum_patients,
            max_date=max_date,
            paths=resolved_paths,
        )

        if hierarchy is None:
            # Keep the key present so callers see every filter id.
            logger.warning(f"Skipping {filter_id} - no data")
            records_by_filter[filter_id] = []
        else:
            flattened = extract_denormalized_fields(hierarchy)
            node_records = convert_to_records(flattened, filter_id, refresh_id)
            records_by_filter[filter_id] = node_records
            logger.info(f"Completed {filter_id}: {len(node_records)} nodes")

    grand_total = sum(map(len, records_by_filter.values()))
    logger.info(f"Total pathway nodes across all filters: {grand_total}")

    return records_by_filter
|
||||
|
||||
|
||||
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    # Types
    "ChartType",
    # Data classes and the predefined filter combinations
    "DateFilterConfig",
    "DATE_FILTER_CONFIGS",
    # Core functions
    "compute_date_ranges",
    "fetch_and_transform_data",
    # Directory chart processing
    "process_pathway_for_date_filter",
    "extract_denormalized_fields",
    # Indication chart processing
    "process_indication_pathway_for_date_filter",
    "extract_indication_fields",
    # Common utilities
    "convert_to_records",
    "process_all_date_filters",
]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,709 @@
|
||||
"""
|
||||
SQLite schema definitions for NHS High-Cost Drug Patient Pathway Analysis Tool.
|
||||
|
||||
Contains SQL strings for creating reference tables, fact tables, and indexes.
|
||||
Schema design supports:
|
||||
- Reference data from CSV files (drug names, organizations, directories)
|
||||
- Drug-directory mappings with single-valid-directory flag
|
||||
- Patient intervention facts with proper indexing
|
||||
- Cached aggregations for performance
|
||||
- File tracking for incremental updates
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import sqlite3
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Reference Table Schemas
|
||||
# =============================================================================
|
||||
|
||||
# DDL for ref_drug_names: raw drug name -> standardized drug name lookup,
# plus one index per name column.
REF_DRUG_NAMES_SCHEMA = """
-- Mapping from raw drug names (as they appear in source data) to standardized names
-- Source: data/drugnames.csv
CREATE TABLE IF NOT EXISTS ref_drug_names (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    raw_name TEXT NOT NULL UNIQUE,
    standard_name TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Index for fast lookups during data transformation
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_raw ON ref_drug_names(raw_name);
CREATE INDEX IF NOT EXISTS idx_ref_drug_names_standard ON ref_drug_names(standard_name);
"""
|
||||
|
||||
# DDL for ref_organizations: organization code -> organization name lookup.
REF_ORGANIZATIONS_SCHEMA = """
-- NHS organization codes and names
-- Source: data/org_codes.csv
CREATE TABLE IF NOT EXISTS ref_organizations (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    org_code TEXT NOT NULL UNIQUE,
    org_name TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Index for fast lookups by organization code
CREATE INDEX IF NOT EXISTS idx_ref_organizations_code ON ref_organizations(org_code);
"""
|
||||
|
||||
# DDL for ref_directories: the list of medical directories/specialties.
REF_DIRECTORIES_SCHEMA = """
-- Medical directories/specialties
-- Source: data/directory_list.csv
CREATE TABLE IF NOT EXISTS ref_directories (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    directory_name TEXT NOT NULL UNIQUE,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Index for fast lookups by directory name
CREATE INDEX IF NOT EXISTS idx_ref_directories_name ON ref_directories(directory_name);
"""
|
||||
|
||||
# DDL for ref_drug_directory_map: one row per (drug, directory) pair, with a
# flag marking drugs that have exactly one valid directory.
REF_DRUG_DIRECTORY_MAP_SCHEMA = """
-- Mapping from drug names to valid directories
-- Source: data/drug_directory_list.csv
-- A drug may map to multiple directories (one row per drug-directory pair)
-- The is_single_valid flag indicates drugs with exactly ONE valid directory,
-- which enables automatic directory assignment in department_identification()
CREATE TABLE IF NOT EXISTS ref_drug_directory_map (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    drug_name TEXT NOT NULL,
    directory_name TEXT NOT NULL,
    is_single_valid BOOLEAN NOT NULL DEFAULT 0,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    UNIQUE(drug_name, directory_name)
);

-- Index for looking up directories by drug name (most common access pattern)
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_drug ON ref_drug_directory_map(drug_name);

-- Index for reverse lookup (find drugs by directory)
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_directory ON ref_drug_directory_map(directory_name);

-- Index for quick filtering of single-valid drugs
CREATE INDEX IF NOT EXISTS idx_ref_drug_directory_map_single ON ref_drug_directory_map(is_single_valid);
"""
|
||||
|
||||
# DDL for ref_drug_indication_clusters: one row per
# (drug, indication, SNOMED cluster) combination.
REF_DRUG_INDICATION_CLUSTERS_SCHEMA = """
-- Mapping from drugs to SNOMED clusters for indication validation
-- Source: data/drug_indication_clusters.csv
-- Used to validate that patients have appropriate GP diagnoses for their prescribed drugs
-- A drug may map to multiple clusters (one row per drug-indication-cluster combination)
CREATE TABLE IF NOT EXISTS ref_drug_indication_clusters (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    drug_name TEXT NOT NULL,
    indication TEXT NOT NULL,
    cluster_id TEXT NOT NULL,
    cluster_description TEXT,
    nice_ta_reference TEXT,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    UNIQUE(drug_name, indication, cluster_id)
);

-- Index for looking up clusters by drug name (most common access pattern)
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_drug ON ref_drug_indication_clusters(drug_name);

-- Index for looking up drugs by cluster (for finding all drugs treating a condition)
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_cluster ON ref_drug_indication_clusters(cluster_id);

-- Index for looking up by indication text
CREATE INDEX IF NOT EXISTS idx_ref_drug_indication_clusters_indication ON ref_drug_indication_clusters(indication);
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pathway Data Architecture Schemas
|
||||
# =============================================================================
|
||||
|
||||
# DDL plus seed data for pathway_date_filters. The script both creates the
# table and (re)inserts the six fixed filter rows via INSERT OR REPLACE.
PATHWAY_DATE_FILTERS_SCHEMA = """
-- Stores the 6 pre-computed date filter combinations
-- Each combination represents a different initiated/last_seen date range
-- Used to efficiently query pre-computed pathway data
CREATE TABLE IF NOT EXISTS pathway_date_filters (
    id TEXT PRIMARY KEY, -- e.g., 'all_6mo', '1yr_12mo'
    initiated_label TEXT NOT NULL, -- e.g., 'All years', 'Last 1 year', 'Last 2 years'
    last_seen_label TEXT NOT NULL, -- e.g., 'Last 6 months', 'Last 12 months'
    initiated_years INTEGER, -- NULL for 'All', 1, or 2
    last_seen_months INTEGER NOT NULL, -- 6 or 12
    is_default INTEGER DEFAULT 0, -- 1 for 'all_6mo' (default selection)
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Pre-populate the 6 combinations
INSERT OR REPLACE INTO pathway_date_filters (id, initiated_label, last_seen_label, initiated_years, last_seen_months, is_default) VALUES
    ('all_6mo', 'All years', 'Last 6 months', NULL, 6, 1),
    ('all_12mo', 'All years', 'Last 12 months', NULL, 12, 0),
    ('1yr_6mo', 'Last 1 year', 'Last 6 months', 1, 6, 0),
    ('1yr_12mo', 'Last 1 year', 'Last 12 months', 1, 12, 0),
    ('2yr_6mo', 'Last 2 years', 'Last 6 months', 2, 6, 0),
    ('2yr_12mo', 'Last 2 years', 'Last 12 months', 2, 12, 0);
"""
|
||||
|
||||
# DDL for pathway_nodes (the pre-computed icicle-chart hierarchy) and the
# indexes used to filter it. Rows are unique per
# (date_filter_id, chart_type, ids).
PATHWAY_NODES_SCHEMA = """
-- Main pathway nodes table (one set per date filter + chart type combination)
-- Stores pre-computed pathway hierarchy with all visualization data
-- Designed for fast filtering by date_filter_id + chart_type + trust/directory/drug
CREATE TABLE IF NOT EXISTS pathway_nodes (
    id INTEGER PRIMARY KEY AUTOINCREMENT,

    -- Date filter combination this belongs to
    date_filter_id TEXT NOT NULL,

    -- Chart type: "directory" (Trust→Directory→Drug) or "indication" (Trust→SearchTerm→Drug)
    chart_type TEXT NOT NULL DEFAULT 'directory',

    -- Hierarchy structure (for icicle chart)
    parents TEXT NOT NULL, -- Parent node identifier
    ids TEXT NOT NULL, -- Unique node identifier (hierarchical path)
    labels TEXT NOT NULL, -- Display label
    level INTEGER NOT NULL, -- Hierarchy depth (0=root, 1=trust, 2=directory/search_term, 3+=drugs)

    -- Patient counts (accurate for this date filter combination)
    value INTEGER NOT NULL DEFAULT 0, -- Patient count

    -- Cost metrics
    cost REAL NOT NULL DEFAULT 0.0, -- Total cost
    costpp REAL, -- Cost per patient
    cost_pp_pa TEXT, -- Cost per patient per annum (formatted string)

    -- Visualization
    colour REAL NOT NULL DEFAULT 0.0, -- Color value (proportion of parent)

    -- Date ranges (for this node)
    first_seen TEXT, -- First intervention date (ISO format)
    last_seen TEXT, -- Last intervention date (ISO format)
    first_seen_parent TEXT, -- Earliest date in parent group
    last_seen_parent TEXT, -- Latest date in parent group

    -- Treatment statistics
    average_spacing TEXT, -- Formatted treatment duration string
    average_administered TEXT, -- JSON array of average doses per drug
    avg_days REAL, -- Average treatment duration in days

    -- Denormalized filter columns (for efficient WHERE clause filtering)
    trust_name TEXT, -- Extracted trust name from ids
    directory TEXT, -- Extracted directory from ids
    drug_sequence TEXT, -- Pipe-separated drug sequence from pathway

    -- Metadata
    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
    data_refresh_id TEXT, -- Links to pathway_refresh_log

    -- Unique per date filter + chart type + pathway
    UNIQUE(date_filter_id, chart_type, ids),
    FOREIGN KEY (date_filter_id) REFERENCES pathway_date_filters(id)
);

-- Indexes for efficient filtering
-- Primary filter: select by date_filter_id
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_date_filter ON pathway_nodes(date_filter_id);

-- Chart type filter: for switching between directory and indication views
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_chart_type ON pathway_nodes(date_filter_id, chart_type);

-- Level filter: often used with date_filter_id
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_level ON pathway_nodes(date_filter_id, level);

-- Trust filter: for Trust dropdown filtering
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_trust ON pathway_nodes(date_filter_id, trust_name);

-- Directory filter: for Directory dropdown filtering
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_directory ON pathway_nodes(date_filter_id, directory);

-- Drug sequence filter: for drug filtering (uses LIKE '%DRUG%')
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_drug_seq ON pathway_nodes(drug_sequence);

-- Parents filter: for finding children of a node
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_parents ON pathway_nodes(date_filter_id, parents);

-- Composite index for common filter combination
CREATE INDEX IF NOT EXISTS idx_pathway_nodes_filter_composite
    ON pathway_nodes(date_filter_id, chart_type, trust_name, directory);
"""
|
||||
|
||||
# DDL for pathway_refresh_log: one row per refresh run, with status,
# timing and row-count metadata.
PATHWAY_REFRESH_LOG_SCHEMA = """
-- Metadata table for tracking refresh status
-- Tracks when pathway data was last refreshed from Snowflake
CREATE TABLE IF NOT EXISTS pathway_refresh_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    refresh_id TEXT NOT NULL, -- Unique identifier for this refresh run
    started_at TEXT NOT NULL, -- ISO timestamp when refresh started
    completed_at TEXT, -- ISO timestamp when refresh completed (NULL if still running)
    status TEXT DEFAULT 'running', -- 'running', 'completed', 'failed'
    record_count INTEGER, -- Total pathway_nodes records created
    date_filter_counts TEXT, -- JSON: {"all_6mo": 1234, "all_12mo": 1567, ...}
    error_message TEXT, -- Error details if status='failed'
    snowflake_query_date_from TEXT, -- Start date of Snowflake query
    snowflake_query_date_to TEXT, -- End date of Snowflake query
    processing_duration_seconds REAL, -- How long the refresh took
    source_row_count INTEGER, -- Number of Snowflake rows fetched
    created_at TEXT DEFAULT CURRENT_TIMESTAMP
);

-- Index for finding latest refresh
CREATE INDEX IF NOT EXISTS idx_pathway_refresh_log_started ON pathway_refresh_log(started_at DESC);

-- Index for finding by status
CREATE INDEX IF NOT EXISTS idx_pathway_refresh_log_status ON pathway_refresh_log(status);
"""
|
||||
|
||||
# Combined pathway schema
|
||||
# Single script that concatenates the three pathway schema fragments, in the
# order: date filters, nodes, refresh log.
PATHWAY_TABLES_SCHEMA = f"""
-- Pathway Data Architecture Tables
-- Pre-computed pathway data for fast Reflex filtering

{PATHWAY_DATE_FILTERS_SCHEMA}

{PATHWAY_NODES_SCHEMA}

{PATHWAY_REFRESH_LOG_SCHEMA}
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Combined Schemas
|
||||
# =============================================================================
|
||||
|
||||
# Single script that concatenates the five reference-table schema fragments.
REFERENCE_TABLES_SCHEMA = f"""
-- Reference Tables Schema
-- Contains lookup data migrated from CSV files

{REF_DRUG_NAMES_SCHEMA}

{REF_ORGANIZATIONS_SCHEMA}

{REF_DIRECTORIES_SCHEMA}

{REF_DRUG_DIRECTORY_MAP_SCHEMA}

{REF_DRUG_INDICATION_CLUSTERS_SCHEMA}
"""
|
||||
|
||||
# Whole-database script: the reference-table script followed by the
# pathway-table script.
ALL_TABLES_SCHEMA = f"""
-- Complete Database Schema
-- Reference tables + Pathway tables

{REFERENCE_TABLES_SCHEMA}

{PATHWAY_TABLES_SCHEMA}
"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_reference_tables(conn: sqlite3.Connection) -> None:
    """
    Build every reference (lookup) table and its indexes.

    Runs the combined ``REFERENCE_TABLES_SCHEMA`` script on the connection
    and logs before and after.

    Args:
        conn: SQLite database connection.
    """
    logger.info("Creating reference tables...")
    conn.executescript(REFERENCE_TABLES_SCHEMA)
    logger.info("Reference tables created successfully")
|
||||
|
||||
|
||||
def drop_reference_tables(conn: sqlite3.Connection) -> None:
    """
    Remove the five reference tables if they are present.

    Args:
        conn: SQLite database connection.

    Warning:
        All reference data is deleted. Use with caution.
    """
    logger.warning("Dropping reference tables...")
    # IF EXISTS keeps the script safe to run on a database that has none,
    # or only some, of these tables.
    drop_script = """
        DROP TABLE IF EXISTS ref_drug_names;
        DROP TABLE IF EXISTS ref_organizations;
        DROP TABLE IF EXISTS ref_directories;
        DROP TABLE IF EXISTS ref_drug_directory_map;
        DROP TABLE IF EXISTS ref_drug_indication_clusters;
    """
    conn.executescript(drop_script)
    logger.info("Reference tables dropped")
|
||||
|
||||
|
||||
def get_reference_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in each reference table.

    Args:
        conn: SQLite database connection.

    Returns:
        Dictionary mapping table name to row count. A table that cannot be
        queried (sqlite3.OperationalError, e.g. it does not exist) is
        reported with a count of 0.
    """
    table_names = (
        "ref_drug_names",
        "ref_organizations",
        "ref_directories",
        "ref_drug_directory_map",
        "ref_drug_indication_clusters",
    )
    row_counts: dict[str, int] = {}

    for name in table_names:
        try:
            # Table names come from the fixed tuple above, never from input.
            row = conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone()
        except sqlite3.OperationalError:
            row_counts[name] = 0
            continue
        row_counts[name] = row[0] if row else 0

    return row_counts
|
||||
|
||||
|
||||
def verify_reference_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which reference tables are absent from the database.

    Each expected table is looked up by name in ``sqlite_master``.

    Args:
        conn: SQLite database connection.

    Returns:
        List of missing table names, in the order they are checked.
        Empty list means all tables exist.
    """
    expected = (
        "ref_drug_names",
        "ref_organizations",
        "ref_directories",
        "ref_drug_directory_map",
        "ref_drug_indication_clusters",
    )
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"

    return [
        name
        for name in expected
        if conn.execute(lookup_sql, (name,)).fetchone() is None
    ]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pathway Table Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_pathway_tables(conn: sqlite3.Connection) -> None:
    """
    Build the pathway data architecture tables.

    Runs the combined ``PATHWAY_TABLES_SCHEMA`` script, which creates:
    - pathway_date_filters: 6 pre-defined date filter combinations
    - pathway_nodes: Pre-computed pathway hierarchy data
    - pathway_refresh_log: Refresh tracking metadata

    Args:
        conn: SQLite database connection.
    """
    logger.info("Creating pathway tables...")
    conn.executescript(PATHWAY_TABLES_SCHEMA)
    logger.info("Pathway tables created successfully")
|
||||
|
||||
|
||||
def drop_pathway_tables(conn: sqlite3.Connection) -> None:
    """
    Remove the three pathway tables if they are present.

    Args:
        conn: SQLite database connection.

    Warning:
        All pre-computed pathway data is deleted.
    """
    logger.warning("Dropping pathway tables...")
    # pathway_nodes goes first: it holds the foreign key that references
    # pathway_date_filters.
    drop_script = """
        DROP TABLE IF EXISTS pathway_nodes;
        DROP TABLE IF EXISTS pathway_refresh_log;
        DROP TABLE IF EXISTS pathway_date_filters;
    """
    conn.executescript(drop_script)
    logger.info("Pathway tables dropped")
|
||||
|
||||
|
||||
def get_pathway_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in each pathway table.

    Args:
        conn: SQLite database connection.

    Returns:
        Dictionary mapping table name to row count. A table that cannot be
        queried (sqlite3.OperationalError, e.g. it does not exist yet) is
        reported with a count of 0.
    """
    row_counts: dict[str, int] = {}

    for name in ("pathway_date_filters", "pathway_nodes", "pathway_refresh_log"):
        try:
            # Names come from the fixed tuple above, never from input.
            row = conn.execute(f"SELECT COUNT(*) FROM {name}").fetchone()
        except sqlite3.OperationalError:
            # Table doesn't exist yet
            row_counts[name] = 0
            continue
        row_counts[name] = row[0] if row else 0

    return row_counts
|
||||
|
||||
|
||||
def verify_pathway_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which pathway tables are absent from the database.

    Each expected table is looked up by name in ``sqlite_master``.

    Args:
        conn: SQLite database connection.

    Returns:
        List of missing table names, in the order they are checked.
        Empty list means all tables exist.
    """
    expected = ("pathway_date_filters", "pathway_nodes", "pathway_refresh_log")
    lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"

    return [
        name
        for name in expected
        if conn.execute(lookup_sql, (name,)).fetchone() is None
    ]
|
||||
|
||||
|
||||
def clear_pathway_nodes(conn: sqlite3.Connection, date_filter_id: str | None = None) -> int:
    """
    Clear pathway nodes, optionally for a specific date filter.

    The delete is committed before returning.

    Args:
        conn: SQLite database connection.
        date_filter_id: If provided, only clear nodes for this date filter.
            If None, clear all pathway nodes.

    Returns:
        Number of rows deleted.
    """
    # Compare against None explicitly: a falsy-but-provided id such as ""
    # must be treated as a (non-matching) filter, not as "delete everything".
    if date_filter_id is not None:
        cursor = conn.execute(
            "DELETE FROM pathway_nodes WHERE date_filter_id = ?",
            (date_filter_id,),
        )
    else:
        cursor = conn.execute("DELETE FROM pathway_nodes")

    deleted_count = cursor.rowcount
    conn.commit()
    logger.info(f"Cleared {deleted_count} pathway nodes")
    return deleted_count
|
||||
|
||||
|
||||
def get_pathway_refresh_status(conn: sqlite3.Connection) -> dict | None:
|
||||
"""
|
||||
Get the status of the most recent pathway refresh.
|
||||
|
||||
Args:
|
||||
conn: SQLite database connection.
|
||||
|
||||
Returns:
|
||||
Dictionary with refresh status, or None if no refresh has been done.
|
||||
"""
|
||||
try:
|
||||
cursor = conn.execute("""
|
||||
SELECT refresh_id, started_at, completed_at, status, record_count,
|
||||
date_filter_counts, error_message, processing_duration_seconds
|
||||
FROM pathway_refresh_log
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 1
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return {
|
||||
"refresh_id": row[0],
|
||||
"started_at": row[1],
|
||||
"completed_at": row[2],
|
||||
"status": row[3],
|
||||
"record_count": row[4],
|
||||
"date_filter_counts": row[5],
|
||||
"error_message": row[6],
|
||||
"processing_duration_seconds": row[7],
|
||||
}
|
||||
return None
|
||||
except sqlite3.OperationalError:
|
||||
# Table doesn't exist yet
|
||||
return None
|
||||
|
||||
|
||||
def migrate_pathway_nodes_chart_type(conn: sqlite3.Connection) -> tuple[bool, str]:
    """
    Migrate pathway_nodes table to add chart_type column.

    This migration:
    1. Returns early if pathway_nodes does not exist or already has chart_type
    2. Adds chart_type as TEXT NOT NULL DEFAULT 'directory', so existing rows
       read back as 'directory'
    3. Creates the (date_filter_id, chart_type) index
    4. Rebuilds the composite filter index so that it includes chart_type

    Args:
        conn: SQLite database connection.

    Returns:
        Tuple of (success: bool, message: str). A failure while changing the
        schema is logged and returned as (False, "Migration failed: ...")
        rather than raised.
    """
    # Check if table exists
    table_row = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='pathway_nodes'"
    ).fetchone()
    if table_row is None:
        return True, "pathway_nodes table does not exist yet (will be created with chart_type column)"

    # Check if chart_type column already exists (column name is field 1 of
    # each PRAGMA table_info row)
    existing_columns = [row[1] for row in conn.execute("PRAGMA table_info(pathway_nodes)")]
    if "chart_type" in existing_columns:
        return True, "chart_type column already exists in pathway_nodes"

    logger.info("Adding chart_type column to pathway_nodes table...")
    try:
        # Add column with default value
        conn.execute("""
            ALTER TABLE pathway_nodes
            ADD COLUMN chart_type TEXT NOT NULL DEFAULT 'directory'
        """)

        # Create index for efficient filtering by chart type
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_pathway_nodes_chart_type
            ON pathway_nodes(date_filter_id, chart_type)
        """)

        # Drop and recreate the composite index so it picks up chart_type.
        # IF EXISTS tolerates a missing index without also hiding unrelated
        # operational errors (e.g. a locked database), which now reach the
        # handler below and are reported as a failed migration.
        conn.execute("DROP INDEX IF EXISTS idx_pathway_nodes_filter_composite")
        conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_pathway_nodes_filter_composite
            ON pathway_nodes(date_filter_id, chart_type, trust_name, directory)
        """)

        # NOTE(review): ADD COLUMN does not change table-level constraints, so
        # a migrated table keeps whatever UNIQUE constraint it was created
        # with. If that was UNIQUE(date_filter_id, ids), a 'directory' and an
        # 'indication' node sharing the same ids would still conflict; moving
        # to UNIQUE(date_filter_id, chart_type, ids) needs a table rebuild.
        # Confirm whether any deployed database is in that state.

        conn.commit()
        logger.info("chart_type column added successfully")

        # Count the pre-existing rows that now carry the default chart_type
        row_count = conn.execute("SELECT COUNT(*) FROM pathway_nodes").fetchone()[0]
        return True, f"Added chart_type column, {row_count} existing rows set to 'directory'"

    except Exception as e:
        logger.error(f"Failed to add chart_type column: {e}")
        return False, f"Migration failed: {e}"
|
||||
|
||||
|
||||
def migrate_refresh_log_source_row_count(conn: sqlite3.Connection) -> tuple[bool, str]:
    """Add source_row_count column to pathway_refresh_log if it doesn't exist.

    This column stores the Snowflake row count for display in the UI footer.

    Args:
        conn: SQLite database connection.

    Returns:
        Tuple of (success: bool, message: str). A missing table or an
        already-present column counts as success with nothing to do. A
        failure while altering the table is logged and returned as
        (False, "Migration failed: ...") rather than raised.
    """
    # Nothing to migrate on a fresh database: the CREATE TABLE script already
    # declares source_row_count. Without this guard the ALTER TABLE below
    # would fail on the missing table and report a spurious failure.
    table_row = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='pathway_refresh_log'"
    ).fetchone()
    if table_row is None:
        return True, "pathway_refresh_log table does not exist yet (will be created with source_row_count column)"

    # Column name is field 1 of each PRAGMA table_info row
    columns = [row[1] for row in conn.execute("PRAGMA table_info(pathway_refresh_log)")]
    if "source_row_count" in columns:
        return True, "source_row_count column already exists"

    logger.info("Adding source_row_count column to pathway_refresh_log...")
    try:
        conn.execute("""
            ALTER TABLE pathway_refresh_log
            ADD COLUMN source_row_count INTEGER
        """)
        conn.commit()
        return True, "Added source_row_count column"
    except Exception as e:
        logger.error(f"Failed to add source_row_count column: {e}")
        return False, f"Migration failed: {e}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Combined Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def create_all_tables(conn: sqlite3.Connection) -> None:
    """
    Build the complete schema (reference + pathway tables) in one script.

    Runs ``ALL_TABLES_SCHEMA`` on the connection and logs before and after.

    Args:
        conn: SQLite database connection.
    """
    logger.info("Creating all database tables...")
    conn.executescript(ALL_TABLES_SCHEMA)
    logger.info("All tables created successfully")
|
||||
|
||||
|
||||
def drop_all_tables(conn: sqlite3.Connection) -> None:
    """
    Remove every table this module manages.

    Pathway tables are dropped first, then reference tables.

    Args:
        conn: SQLite database connection.

    Warning:
        All data is deleted. Use with extreme caution.
    """
    logger.warning("Dropping all tables...")
    for drop_group in (drop_pathway_tables, drop_reference_tables):
        drop_group(conn)
    logger.info("All tables dropped")
|
||||
|
||||
|
||||
def get_all_table_counts(conn: sqlite3.Connection) -> dict[str, int]:
    """
    Count the rows in every reference and pathway table.

    Args:
        conn: SQLite database connection.

    Returns:
        Dictionary mapping table name to row count, reference tables first
        and pathway tables after them.
    """
    return {
        **get_reference_table_counts(conn),
        **get_pathway_table_counts(conn),
    }
|
||||
|
||||
|
||||
def verify_all_tables_exist(conn: sqlite3.Connection) -> list[str]:
    """
    Report which reference and pathway tables are absent.

    Args:
        conn: SQLite database connection.

    Returns:
        List of missing table names, reference tables first and pathway
        tables after them. Empty list means all tables exist.
    """
    missing_reference = verify_reference_tables_exist(conn)
    missing_pathway = verify_pathway_tables_exist(conn)
    return [*missing_reference, *missing_pathway]
|
||||
@@ -0,0 +1,797 @@
|
||||
"""
|
||||
Snowflake connector module for NHS Patient Pathway Analysis.
|
||||
|
||||
Provides connection handling with SSO browser authentication for NHS environments.
|
||||
Uses the externalbrowser authenticator which opens a browser window for NHS identity
|
||||
management authentication.
|
||||
|
||||
Usage:
|
||||
from data_processing.snowflake_connector import SnowflakeConnector, get_connector
|
||||
|
||||
# Using context manager (recommended)
|
||||
with get_connector() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT * FROM table LIMIT 10")
|
||||
results = cursor.fetchall()
|
||||
|
||||
# Manual connection management
|
||||
connector = SnowflakeConnector()
|
||||
try:
|
||||
conn = connector.connect()
|
||||
cursor = conn.cursor()
|
||||
# ... use cursor ...
|
||||
finally:
|
||||
connector.close()
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional, TYPE_CHECKING
|
||||
import time
|
||||
|
||||
# Snowflake connector is an optional dependency
|
||||
SNOWFLAKE_AVAILABLE = False
|
||||
try:
|
||||
import snowflake.connector
|
||||
from snowflake.connector import SnowflakeConnection
|
||||
from snowflake.connector.cursor import SnowflakeCursor
|
||||
SNOWFLAKE_AVAILABLE = True
|
||||
except ImportError:
|
||||
snowflake = None # type: ignore[assignment]
|
||||
|
||||
# Type hints for when snowflake is not available
|
||||
if TYPE_CHECKING:
|
||||
from snowflake.connector import SnowflakeConnection
|
||||
from snowflake.connector.cursor import SnowflakeCursor
|
||||
|
||||
from config import get_snowflake_config, SnowflakeConfig
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SnowflakeConnectionError(Exception):
    """Signals that connecting to Snowflake failed."""
|
||||
|
||||
|
||||
class SnowflakeNotConfiguredError(Exception):
    """Signals that Snowflake has not been configured (no account set)."""
|
||||
|
||||
|
||||
class SnowflakeNotAvailableError(Exception):
    """Signals that the snowflake-connector-python package is not installed."""
|
||||
|
||||
|
||||
@dataclass
class ConnectionInfo:
    """
    Snapshot of the connector's connection state.

    Every field has a default, so ``ConnectionInfo()`` describes a
    disconnected state: empty identifiers, no timestamps and a zero
    query count.
    """

    # Whether a connection is currently considered open.
    connected: bool = False

    # Connection identifiers; empty strings until populated.
    account: str = ""
    warehouse: str = ""
    database: str = ""
    schema: str = ""
    user: str = ""
    role: str = ""

    # Timestamps; None until set.
    connected_at: Optional[datetime] = None
    last_query_at: Optional[datetime] = None

    # Running query counter; starts at zero.
    query_count: int = 0
|
||||
|
||||
|
||||
class SnowflakeConnector:
|
||||
"""
|
||||
Manages Snowflake connections with SSO browser authentication.
|
||||
|
||||
This class provides connection management for NHS Snowflake access using
|
||||
the externalbrowser authenticator which triggers NHS SSO login via browser.
|
||||
|
||||
Attributes:
|
||||
config: SnowflakeConfig with connection settings
|
||||
connection_info: ConnectionInfo tracking current state
|
||||
|
||||
Example:
|
||||
connector = SnowflakeConnector()
|
||||
with connector.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT CURRENT_USER()")
|
||||
print(cursor.fetchone()[0])
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SnowflakeConfig] = None):
|
||||
"""
|
||||
Initialize the connector with configuration.
|
||||
|
||||
Args:
|
||||
config: Optional SnowflakeConfig. If not provided, loads from
|
||||
config/snowflake.toml using get_snowflake_config().
|
||||
"""
|
||||
self._config = config or get_snowflake_config()
|
||||
self._connection: Optional[SnowflakeConnection] = None
|
||||
self._connection_info = ConnectionInfo()
|
||||
|
||||
@property
|
||||
def config(self) -> SnowflakeConfig:
|
||||
"""Return the Snowflake configuration."""
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def connection_info(self) -> ConnectionInfo:
|
||||
"""Return information about the current connection state."""
|
||||
return self._connection_info
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Return True if currently connected to Snowflake."""
|
||||
return self._connection is not None and not self._connection.is_closed()
|
||||
|
||||
def _check_availability(self) -> None:
|
||||
"""Check that snowflake-connector-python is installed."""
|
||||
if not SNOWFLAKE_AVAILABLE:
|
||||
raise SnowflakeNotAvailableError(
|
||||
"snowflake-connector-python is not installed. "
|
||||
"Install it with: pip install snowflake-connector-python"
|
||||
)
|
||||
|
||||
def _check_configured(self) -> None:
|
||||
"""Check that Snowflake is configured."""
|
||||
if not self._config.is_configured:
|
||||
raise SnowflakeNotConfiguredError(
|
||||
"Snowflake account is not configured. "
|
||||
"Edit config/snowflake.toml and set connection.account"
|
||||
)
|
||||
|
||||
def connect(self) -> SnowflakeConnection:
    """
    Open a new Snowflake connection, replacing any connection already held.

    The availability and configuration checks run first. Connection
    parameters come from the config's ``connection`` and ``timeouts``
    sections; ``user`` and ``role`` are only passed when they are set.
    With the "externalbrowser" authenticator a browser window opens for
    SSO sign-in, which is expected.

    Returns:
        The newly opened connection (also stored on the connector).

    Raises:
        SnowflakeNotAvailableError: If snowflake-connector-python is missing.
        SnowflakeNotConfiguredError: If the account is not configured.
        SnowflakeConnectionError: If anything fails while connecting; the
            original exception is chained as the cause.
    """
    self._check_availability()
    self._check_configured()

    # Only one live connection per connector: drop the previous one first.
    if self._connection is not None:
        self.close()

    settings = self._config.connection
    timeouts = self._config.timeouts

    logger.info(f"Connecting to Snowflake account: {settings.account}")
    logger.info(f"Using warehouse: {settings.warehouse}, database: {settings.database}")
    logger.info(f"Authenticator: {settings.authenticator}")
    if settings.authenticator == "externalbrowser":
        logger.info("Browser window will open for NHS SSO authentication")

    started = time.time()

    try:
        connect_params = dict(
            account=settings.account,
            warehouse=settings.warehouse,
            database=settings.database,
            schema=settings.schema,
            authenticator=settings.authenticator,
            login_timeout=timeouts.login_timeout,
            network_timeout=timeouts.connection_timeout,
        )

        # user / role are optional and are only forwarded when set.
        for optional_key in ("user", "role"):
            optional_value = getattr(settings, optional_key)
            if optional_value:
                connect_params[optional_key] = optional_value

        self._connection = snowflake.connector.connect(**connect_params)

        duration = time.time() - started
        logger.info(f"Connected to Snowflake successfully in {duration:.1f}s")

        # Record what we are connected to; user and role are read back from
        # the live session rather than taken from the config.
        self._connection_info = ConnectionInfo(
            connected=True,
            account=settings.account,
            warehouse=settings.warehouse,
            database=settings.database,
            schema=settings.schema,
            user=self._get_current_user(),
            role=self._get_current_role(),
            connected_at=datetime.now(),
            query_count=0,
        )
    except Exception as exc:
        duration = time.time() - started
        logger.error(f"Failed to connect to Snowflake after {duration:.1f}s: {exc}")
        self._connection_info = ConnectionInfo(connected=False)
        raise SnowflakeConnectionError(f"Failed to connect to Snowflake: {exc}") from exc

    return self._connection
|
||||
|
||||
def close(self) -> None:
    """
    Close the held connection, if any, and reset the connection state.

    Errors raised while closing are logged as warnings and not propagated;
    the connector always ends up with no connection and a fresh
    ``ConnectionInfo(connected=False)``.
    """
    if self._connection is None:
        return

    try:
        self._connection.close()
        logger.info("Snowflake connection closed")
    except Exception as exc:
        logger.warning(f"Error closing Snowflake connection: {exc}")
    finally:
        # Forget the connection even when close() failed.
        self._connection = None
        self._connection_info = ConnectionInfo(connected=False)
|
||||
|
||||
def _get_current_user(self) -> str:
    """
    Return the user of the current session, or "" when it cannot be read.

    Runs ``SELECT CURRENT_USER()`` on a short-lived cursor. This is a
    best-effort lookup: with no connection, an empty result, or any error
    (including one raised while closing the cursor) the empty string is
    returned instead of raising.

    Returns:
        The first column of the first result row, or "".
    """
    if self._connection is None:
        return ""
    try:
        cursor = self._connection.cursor()
        try:
            cursor.execute("SELECT CURRENT_USER()")
            row = cursor.fetchone()
        finally:
            # Always release the cursor, as get_cursor() does.
            cursor.close()
        return row[0] if row else ""
    except Exception:
        # Deliberately swallowed: callers only use this for display/state.
        return ""
|
||||
|
||||
def _get_current_role(self) -> str:
    """
    Return the active role of the current session, or "" when unavailable.

    Runs ``SELECT CURRENT_ROLE()`` on a short-lived cursor. This is a
    best-effort lookup: with no connection, an empty result, or any error
    (including one raised while closing the cursor) the empty string is
    returned instead of raising.

    Returns:
        The first column of the first result row, or "".
    """
    if self._connection is None:
        return ""
    try:
        cursor = self._connection.cursor()
        try:
            cursor.execute("SELECT CURRENT_ROLE()")
            row = cursor.fetchone()
        finally:
            # Always release the cursor, as get_cursor() does.
            cursor.close()
        return row[0] if row else ""
    except Exception:
        # Deliberately swallowed: callers only use this for display/state.
        return ""
|
||||
|
||||
@contextmanager
def get_connection(self) -> Generator[SnowflakeConnection, None, None]:
    """
    Context manager that yields the connector's connection.

    Connects first when no open connection exists. Leaving the ``with``
    block does NOT close the connection: it stays open so later calls can
    reuse it. Call close() (or use the connector itself as a context
    manager) to release it.

    Yields:
        The active SnowflakeConnection.

    Example:
        connector = SnowflakeConnector()
        with connector.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT 1")
    """
    if not self.is_connected:
        self.connect()

    active = self._connection
    assert active is not None, "Connection should be established"

    # No cleanup on exit by design: the connection is kept for reuse.
    yield active
|
||||
|
||||
@contextmanager
def get_cursor(
    self,
    dict_cursor: bool = False
) -> Generator[SnowflakeCursor, None, None]:
    """
    Context manager that yields a cursor and closes it afterwards.

    Connects first when no open connection exists. When the ``with`` body
    completes without raising, ``last_query_at`` and ``query_count`` on the
    connection info are updated; the cursor is closed in every case.

    Args:
        dict_cursor: When True the cursor is created with
            snowflake.connector.DictCursor (dict-like rows).

    Yields:
        A cursor on the connector's connection.

    Example:
        connector = SnowflakeConnector()
        with connector.get_cursor() as cursor:
            cursor.execute("SELECT * FROM table LIMIT 10")
            for row in cursor:
                print(row)
    """
    if not self.is_connected:
        self.connect()

    assert self._connection is not None, "Connection should be established"

    cursor: Any = None
    try:
        # An empty argument tuple gives the driver's default cursor class.
        cursor_args = (snowflake.connector.DictCursor,) if dict_cursor else ()
        cursor = self._connection.cursor(*cursor_args)  # type: ignore[union-attr]
        yield cursor  # type: ignore[misc]

        # Reached only when the caller's block did not raise.
        info = self._connection_info
        info.last_query_at = datetime.now()
        info.query_count += 1
    finally:
        if cursor is not None:
            cursor.close()
|
||||
|
||||
def execute(
    self,
    query: str,
    params: Optional[tuple] = None,
    timeout: Optional[int] = None
) -> list[tuple]:
    """
    Run one query and return every row it produces.

    Connects first when no open connection exists. When the effective
    timeout is positive it is applied through an ALTER SESSION statement
    on the same cursor before the query itself runs.

    Args:
        query: SQL text to execute.
        params: Optional bind parameters forwarded to cursor.execute().
        timeout: Statement timeout in seconds. A falsy value (None or 0)
            falls back to the configured ``timeouts.query_timeout``.

    Returns:
        All rows returned by cursor.fetchall().

    Raises:
        Whatever connect() raises when a connection must be established,
        plus any error raised by the cursor while running the query.
    """
    if not self.is_connected:
        self.connect()

    timeout_s = timeout or self._config.timeouts.query_timeout

    with self.get_cursor() as cur:
        logger.info(f"Executing query (timeout={timeout_s}s)")
        logger.debug(f"Query: {query[:200]}...")

        if timeout_s > 0:
            cur.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {timeout_s}")

        started = time.time()
        cur.execute(query, params)
        rows = cur.fetchall()
        duration = time.time() - started

        logger.info(f"Query returned {len(rows)} rows in {duration:.2f}s")
        return rows
|
||||
|
||||
def execute_dict(
    self,
    query: str,
    params: Optional[tuple] = None,
    timeout: Optional[int] = None
) -> list[dict]:
    """
    Run one query on a dict cursor and return every row it produces.

    Behaves like execute(), except that the cursor is requested with
    ``dict_cursor=True`` so rows come back as dictionaries.

    Args:
        query: SQL text to execute.
        params: Optional bind parameters forwarded to cursor.execute().
        timeout: Statement timeout in seconds. A falsy value (None or 0)
            falls back to the configured ``timeouts.query_timeout``.

    Returns:
        All rows returned by cursor.fetchall().
    """
    if not self.is_connected:
        self.connect()

    timeout_s = timeout or self._config.timeouts.query_timeout

    with self.get_cursor(dict_cursor=True) as cur:
        logger.info(f"Executing query (timeout={timeout_s}s)")
        logger.debug(f"Query: {query[:200]}...")

        if timeout_s > 0:
            cur.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {timeout_s}")

        started = time.time()
        cur.execute(query, params)
        rows = cur.fetchall()
        duration = time.time() - started

        logger.info(f"Query returned {len(rows)} rows in {duration:.2f}s")
        return rows  # type: ignore[return-value]
|
||||
|
||||
def execute_chunked(
    self,
    query: str,
    params: Optional[tuple] = None,
    chunk_size: Optional[int] = None,
    timeout: Optional[int] = None,
    max_rows: Optional[int] = None,
) -> Generator[list[tuple], None, None]:
    """
    Run a query and yield its rows chunk by chunk.

    This is a generator, so nothing (not even connecting) happens until
    iteration starts. Rows are pulled with cursor.fetchmany(), so at most
    one chunk is held by this method at a time. The summary log line is
    only written when the result set is iterated to the end.

    Args:
        query: SQL text to execute.
        params: Optional bind parameters forwarded to cursor.execute().
        chunk_size: Rows per chunk; a falsy value falls back to the
            configured ``query.chunk_size``.
        timeout: Statement timeout in seconds; a falsy value falls back to
            the configured ``timeouts.query_timeout``.
        max_rows: Upper bound on the total rows yielded. None falls back
            to the configured ``query.max_rows``; a value that is not
            positive means no limit.

    Yields:
        Non-empty lists of rows, each at most ``chunk_size`` long.

    Example:
        for chunk in connector.execute_chunked("SELECT * FROM large_table"):
            process_chunk(chunk)
    """
    if not self.is_connected:
        self.connect()

    timeout_s = timeout or self._config.timeouts.query_timeout
    rows_per_chunk = chunk_size or self._config.query.chunk_size
    row_cap = self._config.query.max_rows if max_rows is None else max_rows

    with self.get_cursor() as cur:
        logger.info(f"Executing chunked query (chunk_size={rows_per_chunk}, timeout={timeout_s}s)")
        logger.debug(f"Query: {query[:200]}...")

        if timeout_s > 0:
            cur.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {timeout_s}")

        started = time.time()
        cur.execute(query, params)

        rows_seen = 0
        chunks_seen = 0

        while True:
            # Start from a full chunk and shrink it when the cap is close.
            fetch_size = rows_per_chunk
            if row_cap > 0:
                remaining = row_cap - rows_seen
                if remaining <= 0:
                    break
                fetch_size = min(rows_per_chunk, remaining)

            batch = cur.fetchmany(fetch_size)
            if not batch:
                break

            chunks_seen += 1
            rows_seen += len(batch)
            logger.debug(f"Chunk {chunks_seen}: {len(batch)} rows (total: {rows_seen})")
            yield batch

        duration = time.time() - started
        logger.info(f"Chunked query returned {rows_seen} rows in {chunks_seen} chunks ({duration:.2f}s)")
|
||||
|
||||
def execute_chunked_dict(
    self,
    query: str,
    params: Optional[tuple] = None,
    chunk_size: Optional[int] = None,
    timeout: Optional[int] = None,
    max_rows: Optional[int] = None,
) -> Generator[list[dict], None, None]:
    """
    Run a query on a dict cursor and yield its rows chunk by chunk.

    Same flow as execute_chunked(), except that the cursor is requested
    with ``dict_cursor=True`` so rows come back as dictionaries. Being a
    generator, it does nothing until iteration starts, and the summary log
    line is only written when the result set is iterated to the end.

    Args:
        query: SQL text to execute.
        params: Optional bind parameters forwarded to cursor.execute().
        chunk_size: Rows per chunk; a falsy value falls back to the
            configured ``query.chunk_size``.
        timeout: Statement timeout in seconds; a falsy value falls back to
            the configured ``timeouts.query_timeout``.
        max_rows: Upper bound on the total rows yielded. None falls back
            to the configured ``query.max_rows``; a value that is not
            positive means no limit.

    Yields:
        Non-empty lists of rows, each at most ``chunk_size`` long.
    """
    if not self.is_connected:
        self.connect()

    timeout_s = timeout or self._config.timeouts.query_timeout
    rows_per_chunk = chunk_size or self._config.query.chunk_size
    row_cap = self._config.query.max_rows if max_rows is None else max_rows

    with self.get_cursor(dict_cursor=True) as cur:
        logger.info(f"Executing chunked dict query (chunk_size={rows_per_chunk}, timeout={timeout_s}s)")
        logger.debug(f"Query: {query[:200]}...")

        if timeout_s > 0:
            cur.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {timeout_s}")

        started = time.time()
        cur.execute(query, params)

        rows_seen = 0
        chunks_seen = 0

        while True:
            # Start from a full chunk and shrink it when the cap is close.
            fetch_size = rows_per_chunk
            if row_cap > 0:
                remaining = row_cap - rows_seen
                if remaining <= 0:
                    break
                fetch_size = min(rows_per_chunk, remaining)

            batch = cur.fetchmany(fetch_size)
            if not batch:
                break

            chunks_seen += 1
            rows_seen += len(batch)
            logger.debug(f"Chunk {chunks_seen}: {len(batch)} rows (total: {rows_seen})")
            yield batch  # type: ignore[misc]

        duration = time.time() - started
        logger.info(f"Chunked dict query returned {rows_seen} rows in {chunks_seen} chunks ({duration:.2f}s)")
|
||||
|
||||
def execute_with_row_limit(
    self,
    query: str,
    params: Optional[tuple] = None,
    max_rows: Optional[int] = None,
    timeout: Optional[int] = None
) -> tuple[list[dict], bool]:
    """
    Run a query, return at most ``max_rows`` rows and report truncation.

    Useful for previews and pagination. One row more than the limit is
    fetched so that "more rows exist" can be detected without a second
    query; that extra row is never returned.

    Args:
        query: SQL text to execute.
        params: Optional bind parameters forwarded to cursor.execute().
        max_rows: Maximum rows to return. None falls back to the
            configured ``query.max_rows``. A value that is not positive
            means "no limit", matching execute_chunked(): all rows are
            returned and ``has_more`` is False.
        timeout: Statement timeout in seconds; a falsy value falls back to
            the configured ``timeouts.query_timeout``.

    Returns:
        Tuple ``(results, has_more)``:
            - results: rows from a dict cursor, at most ``max_rows`` long
            - has_more: True when the query produced more than ``max_rows``
    """
    if not self.is_connected:
        self.connect()

    effective_timeout = timeout or self._config.timeouts.query_timeout
    effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows

    with self.get_cursor(dict_cursor=True) as cursor:
        logger.info(f"Executing query with limit (max_rows={effective_max_rows}, timeout={effective_timeout}s)")
        logger.debug(f"Query: {query[:200]}...")

        if effective_timeout > 0:
            cursor.execute(f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {effective_timeout}")

        start_time = time.time()
        cursor.execute(query, params)

        if effective_max_rows > 0:
            # Fetch one more than max to detect if there are more rows
            results = cursor.fetchmany(effective_max_rows + 1)
            has_more = len(results) > effective_max_rows
            if has_more:
                results = results[:effective_max_rows]
        else:
            # 0 (or negative) is the "no limit" setting. Without this branch
            # the call would fetch one row, slice it away and wrongly report
            # an empty, truncated result.
            results = cursor.fetchall()
            has_more = False
        elapsed = time.time() - start_time

        logger.info(f"Query returned {len(results)} rows (has_more={has_more}) in {elapsed:.2f}s")
        return results, has_more  # type: ignore[return-value]
|
||||
|
||||
def fetch_activity_data(
    self,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    provider_codes: Optional[list[str]] = None,
    max_rows: Optional[int] = None,
    timeout: Optional[int] = None,
) -> list[dict]:
    """
    Fetch drug activity rows from Snowflake as a list of dictionaries.

    Selects from DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs", renaming
    columns to the spaced names listed below, applies the optional filters
    as bind parameters, orders by intervention date, provider code and
    pseudonymised NHS number, and gathers the chunked result into one list.

    Args:
        start_date: Optional inclusive lower bound on "InterventionDate".
        end_date: Optional inclusive upper bound on "InterventionDate".
        provider_codes: Optional provider codes for a "ProviderCode" IN filter.
        max_rows: Maximum rows to return; None falls back to the configured
            ``query.max_rows``.
        timeout: Statement timeout in seconds; a falsy value falls back to
            the configured ``timeouts.query_timeout``.

    Returns:
        One dictionary per row with the keys produced by the SELECT list:
            - PseudoNHSNoLinked
            - Provider Code
            - PersonKey
            - Drug Name
            - Intervention Date
            - Price Actual
            - OrganisationName
            - Treatment Function Code / Treatment Function Desc
            - Additional Detail 1-5 / Additional Description 1-5

    Raises:
        Whatever connect() raises when a connection must be established,
        plus any error raised while executing the query.
    """
    if not self.is_connected:
        self.connect()

    table_name = 'DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs"'

    base_query = f'''
        SELECT
            "PseudoNHSNoLinked",
            "ProviderCode" AS "Provider Code",
            "LocalPatientID" AS "PersonKey",
            "DrugName" AS "Drug Name",
            "InterventionDate" AS "Intervention Date",
            "PriceActual" AS "Price Actual",
            "ProviderName" AS "OrganisationName",
            "TreatmentFunctionCode" AS "Treatment Function Code",
            "TreatmentFunctionDesc" AS "Treatment Function Desc",
            "AdditionalDetail1" AS "Additional Detail 1",
            "AdditionalDescription1" AS "Additional Description 1",
            "AdditionalDetail2" AS "Additional Detail 2",
            "AdditionalDescription2" AS "Additional Description 2",
            "AdditionalDetail3" AS "Additional Detail 3",
            "AdditionalDescription3" AS "Additional Description 3",
            "AdditionalDetail4" AS "Additional Detail 4",
            "AdditionalDescription4" AS "Additional Description 4",
            "AdditionalDetail5" AS "Additional Detail 5",
            "AdditionalDescription5" AS "Additional Description 5"
        FROM {table_name}
        WHERE 1=1
    '''

    # Each optional filter adds one SQL fragment plus its bind values, so
    # user-supplied values never end up inside the SQL text itself.
    where_parts = []
    bind_values = []

    if start_date:
        where_parts.append(' AND "InterventionDate" >= %s')
        bind_values.append(start_date.isoformat())
    if end_date:
        where_parts.append(' AND "InterventionDate" <= %s')
        bind_values.append(end_date.isoformat())

    if provider_codes:
        placeholders = ", ".join(["%s"] * len(provider_codes))
        where_parts.append(f' AND "ProviderCode" IN ({placeholders})')
        bind_values.extend(provider_codes)

    # Fixed ordering keeps repeated fetches consistent.
    order_clause = ' ORDER BY "InterventionDate", "ProviderCode", "PseudoNHSNoLinked"'
    query = base_query + "".join(where_parts) + order_clause

    logger.info("Fetching activity data from Snowflake")
    if start_date:
        logger.info(f"  Date range: {start_date} to {end_date or 'now'}")
    if provider_codes:
        logger.info(f"  Providers: {provider_codes}")

    row_cap = self._config.query.max_rows if max_rows is None else max_rows
    timeout_s = timeout or self._config.timeouts.query_timeout

    # Stream in chunks, then collect everything into a single list.
    records = []
    chunk_stream = self.execute_chunked_dict(
        query,
        params=tuple(bind_values) if bind_values else None,
        timeout=timeout_s,
        max_rows=row_cap,
    )
    for chunk in chunk_stream:
        records.extend(chunk)
        logger.debug(f"Fetched {len(records)} rows so far...")

    logger.info(f"Fetched {len(records)} activity records from Snowflake")
    return records
|
||||
|
||||
def test_connection(self) -> tuple[bool, str]:
    """
    Try to connect and report the outcome instead of raising.

    Runs the availability check, then the configuration check, then
    connect(). On success the connection stays open and the message names
    the session's user and role.

    Returns:
        Tuple ``(success, message)``; on failure the message carries the
        reason.
    """
    # Each pre-check is paired with the one error type it reports.
    prechecks = (
        (self._check_availability, SnowflakeNotAvailableError),
        (self._check_configured, SnowflakeNotConfiguredError),
    )
    for run_check, reported_error in prechecks:
        try:
            run_check()
        except reported_error as exc:
            return False, str(exc)

    try:
        self.connect()
        current_user = self._get_current_user()
        current_role = self._get_current_role()
    except Exception as exc:
        return False, f"Connection failed: {exc}"
    return True, f"Connected as {current_user} with role {current_role}"
|
||||
|
||||
def __enter__(self) -> "SnowflakeConnector":
    """Connect on entering a ``with`` block and hand back the connector itself."""
    connector = self
    connector.connect()
    return connector
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
    """Close the connection on leaving a ``with`` block; exceptions are not suppressed."""
    # Returning None lets any exception from the with-body propagate.
    self.close()
    return None
|
||||
|
||||
|
||||
# Module-level singleton for convenience.
# Starts as None; get_connector() creates the instance on first use without
# a custom config, and reset_connector() closes it and sets this back to None.
_default_connector: Optional[SnowflakeConnector] = None
|
||||
|
||||
|
||||
def get_connector(config: Optional[SnowflakeConfig] = None) -> SnowflakeConnector:
    """
    Return a connector: the shared default one, or a new one for ``config``.

    Args:
        config: When given, a brand-new SnowflakeConnector is built from it
            on every call and the shared instance is left untouched. When
            None, the module-level default connector is returned, creating
            it on the first call.

    Returns:
        SnowflakeConnector instance.
    """
    global _default_connector

    if config is None:
        # Shared path: create the default connector lazily, then reuse it.
        if _default_connector is None:
            _default_connector = SnowflakeConnector()
        return _default_connector

    # A custom config never replaces the shared default.
    return SnowflakeConnector(config)
|
||||
|
||||
|
||||
def reset_connector() -> None:
    """Close the shared default connector, if one exists, and forget it."""
    global _default_connector

    shared = _default_connector
    if shared is None:
        return

    shared.close()
    _default_connector = None
|
||||
|
||||
|
||||
def is_snowflake_available() -> bool:
    """Report the module-level SNOWFLAKE_AVAILABLE flag."""
    available = SNOWFLAKE_AVAILABLE
    return available
|
||||
|
||||
|
||||
def is_snowflake_configured() -> bool:
    """
    Report whether the loaded Snowflake configuration says it is configured.

    Any error while loading or inspecting the configuration is treated as
    "not configured" and yields False.
    """
    try:
        return get_snowflake_config().is_configured
    except Exception:
        return False
|
||||
|
||||
|
||||
# Export public API: these are the names a star-import of this module
# exposes (connector class, error types, state record, helper functions and
# the availability flag).
__all__ = [
    "SnowflakeConnector",
    "SnowflakeConnectionError",
    "SnowflakeNotConfiguredError",
    "SnowflakeNotAvailableError",
    "ConnectionInfo",
    "get_connector",
    "reset_connector",
    "is_snowflake_available",
    "is_snowflake_configured",
    "SNOWFLAKE_AVAILABLE",
]
|
||||
@@ -0,0 +1,331 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import csv
|
||||
import urllib.request
|
||||
import io # Added for StringIO
|
||||
import re # Added for regex escape and word boundaries
|
||||
from typing import Optional
|
||||
|
||||
from core import PathConfig, default_paths
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def drug_names(df, paths: Optional[PathConfig] = None):
    """
    Standardise the "Drug Name" column to generic names, in place.

    A lookup is read from ``paths.drugnames_csv`` (two columns: raw name,
    generic name) and applied case-insensitively. Names missing from the
    lookup become NaN. A "(LEFT EYE)" / "(RIGHT EYE)" marker is then removed
    and surrounding whitespace stripped.

    Args:
        df: DataFrame with a "Drug Name" column; it is modified in place.
        paths: Path configuration; ``default_paths`` is used when None.

    Returns:
        The same DataFrame object that was passed in.

    Raises:
        ValueError: If a non-blank CSV row does not have exactly two fields.
    """
    if paths is None:
        paths = default_paths

    generic_by_name = {}
    with open(paths.drugnames_csv, 'r', newline='') as f:
        for row in csv.reader(f, delimiter=','):
            if not row:
                # csv.reader yields [] for a blank line; a stray empty line
                # in the lookup file must not abort the whole load.
                continue
            # Strict two-field unpack so malformed rows still surface.
            drug_name, generic = row
            generic_by_name[drug_name.upper()] = generic.upper()

    # Map drug names with the lookup (unmapped names become NaN)
    df["Drug Name"] = df["Drug Name"].str.upper().map(generic_by_name)

    # Remove (LEFT EYE) or (RIGHT EYE) in one pass, then trim the whitespace
    # that the removed marker leaves behind.
    df["Drug Name"] = df["Drug Name"].str.replace(r'\((?:LEFT|RIGHT) EYE\)', '', regex=True)
    df["Drug Name"] = df["Drug Name"].str.strip()
    return df
|
||||
|
||||
|
||||
def patient_id(df):
    """
    Add a "UPID" column to ``df`` in place and return the same DataFrame.

    UPID is the first three characters of "Provider Code" followed by
    "PersonKey" rendered as a string.
    """
    provider_prefix = df["Provider Code"].str[:3]
    person_key_text = df["PersonKey"].astype(str)
    df["UPID"] = provider_prefix + person_key_text
    return df
|
||||
|
||||
|
||||
def compress_csv(filepath):
    """
    Write a bz2-compressed copy of a CSV file next to the original.

    The copy is named by inserting "_bz2" before the file extension
    ("data.csv" -> "data_bz2.csv"). Only the final extension is touched, so
    a ".csv" elsewhere in the path (for example in a directory name) is left
    alone. The output path always differs from the input path, so the
    source file is never overwritten. A file without an extension gets
    "_bz2.csv".

    Args:
        filepath: Path of the CSV file to read (str or os.PathLike).

    Returns:
        Path of the compressed file, as a string.
    """
    import os  # local import: only this helper needs it

    source_path = os.fspath(filepath)
    stem, extension = os.path.splitext(source_path)
    compressed_path = f"{stem}_bz2{extension or '.csv'}"

    df = pd.read_csv(source_path)
    # The name keeps its .csv-style extension, so compression is explicit.
    df.to_csv(compressed_path, compression="bz2", index=False)
    return compressed_path
|
||||
|
||||
|
||||
def department_identification(df, paths: Optional[PathConfig] = None):
|
||||
# --- Setup ---
|
||||
if paths is None:
|
||||
paths = default_paths
|
||||
|
||||
# 1. Load directory_list.csv and prepare uppercase versions/pattern
|
||||
try:
|
||||
directory_df = pd.read_csv(paths.directory_list_csv)
|
||||
directory_list = directory_df["directory"].dropna().astype(str).tolist()
|
||||
if not directory_list:
|
||||
raise ValueError("directory_list.csv is empty or contains only NA values.")
|
||||
directory_list_upper = [d.upper() for d in directory_list]
|
||||
# Use word boundaries (\b) to avoid partial matches within words, escape special regex chars
|
||||
dir_pattern_upper = r'\b({})'.format('|'.join(map(re.escape, directory_list_upper)))
|
||||
except FileNotFoundError:
|
||||
logger.error(f"File not found: {paths.directory_list_csv}. Cannot extract directories.")
|
||||
return df
|
||||
except ValueError as e:
|
||||
logger.error(f"Error loading directory list: {e}")
|
||||
return df
|
||||
|
||||
# Simpler pattern for Primary_Source (no word boundaries)
|
||||
dir_pattern_primary_simple = r'({})'.format('|'.join(map(re.escape, directory_list_upper)))
|
||||
|
||||
# 2. Load treatment_function_codes.csv and prepare uppercase mapping
|
||||
treatment_codes = pd.read_csv(paths.treatment_function_codes_csv)
|
||||
mapping_treatment_codes = dict(treatment_codes[['Code', 'Service']].values)
|
||||
mapping_treatment_codes_upper = {k: str(v).upper() for k, v in mapping_treatment_codes.items()}
|
||||
|
||||
# 3. Load drug_directory_list.csv and parse into drug_to_valid_dirs
|
||||
drug_to_valid_dirs: dict[str, set[str]] = {}
|
||||
# Try pandas direct read - much simpler approach
|
||||
drug_dir_df = pd.read_csv(paths.drug_directory_list_csv, skipinitialspace=True)
|
||||
|
||||
# Identify the drug name column (first column) and directory column (second column)
|
||||
drug_col = drug_dir_df.columns[0]
|
||||
dir_col = drug_dir_df.columns[1]
|
||||
|
||||
# Process dataframe directly
|
||||
drug_to_valid_dirs = {}
|
||||
for _, row in drug_dir_df.iterrows():
|
||||
drug_name = str(row[drug_col]).strip().upper()
|
||||
try:
|
||||
# Directories are pipe-separated in the second column
|
||||
dirs_str = str(row[dir_col]) if not pd.isna(row[dir_col]) else ""
|
||||
dirs = {d.strip().upper() for d in dirs_str.split('|') if d.strip()}
|
||||
if drug_name and dirs and drug_name.lower() != 'nan':
|
||||
drug_to_valid_dirs[drug_name] = dirs
|
||||
except Exception:
|
||||
# Silently continue on row errors
|
||||
continue
|
||||
# 4. Create drug_to_single_dir map
|
||||
drug_to_single_dir = {
|
||||
drug: list(dirs)[0]
|
||||
for drug, dirs in drug_to_valid_dirs.items()
|
||||
if len(dirs) == 1
|
||||
}
|
||||
|
||||
# --- Data Preprocessing ---
|
||||
# Keep original extraction columns list
|
||||
additional_detail_columns = ["Additional Detail 1", "Additional Description 1", "Additional Detail 2", "Additional Description 2",
|
||||
"Additional Detail 3", "Additional Description 3", "Additional Detail 4", "Additional Description 4",
|
||||
"Additional Detail 5", "Additional Description 5", "NCDR Treatment Function Name", "Treatment Function Desc"]
|
||||
|
||||
# 6. Convert detail columns to uppercase BEFORE extraction
|
||||
for ad in additional_detail_columns:
|
||||
# Check if column exists and is object/string type before applying .str
|
||||
if ad in df.columns and pd.api.types.is_object_dtype(df[ad]):
|
||||
df[ad] = df[ad].str.upper()
|
||||
|
||||
# Original extraction loop (using original case list for extraction)
|
||||
# Extract directory from specified columns
|
||||
directory_df = pd.read_csv(paths.directory_list_csv)
|
||||
directory_list = directory_df["directory"].tolist() # Reload original case list
|
||||
|
||||
for ad in additional_detail_columns:
|
||||
try:
|
||||
# Ensure column is string type before cleaning
|
||||
if pd.api.types.is_string_dtype(df[ad]):
|
||||
# Extract directly from the uppercased string column
|
||||
extracted = df[ad].str.extract(dir_pattern_upper, expand=False)
|
||||
df.loc[extracted.index, ad] = extracted
|
||||
else:
|
||||
df[ad] = np.nan # Set non-string columns to NaN
|
||||
except AttributeError: # Skip columns that might not exist or are not string type
|
||||
df[ad] = np.nan # Ensure column exists but set to NaN if error
|
||||
except Exception as e: # Catch other potential errors during extract
|
||||
logger.error(f"Error processing column {ad}: {e}")
|
||||
df[ad] = np.nan
|
||||
|
||||
# 7. Process Treatment Function Code
|
||||
df["Treatment Function Code"].replace(np.nan, 0, inplace=True)
|
||||
# Ensure it's int type before mapping, handle potential errors
|
||||
try:
|
||||
df["Treatment Function Code"] = df["Treatment Function Code"].astype(int)
|
||||
except ValueError:
|
||||
# Handle cases where conversion to int fails (e.g., non-numeric values)
|
||||
# Try coercing errors to NaN, then fillna with 0
|
||||
df["Treatment Function Code"] = pd.to_numeric(df["Treatment Function Code"], errors='coerce').fillna(0).astype(int)
|
||||
|
||||
df["Treatment Function Code"] = df["Treatment Function Code"].map(mapping_treatment_codes_upper)
|
||||
df.rename(columns={'Treatment Function Code': 'Fallback_Source'}, inplace=True)
|
||||
|
||||
# Apply replacements before combining
|
||||
df.replace('MEDICAL OPHTHALMOLOGY', 'OPHTHALMOLOGY', inplace=True)
|
||||
|
||||
# --- Single Directory Assignment ---
|
||||
# 8. Apply single directory override
|
||||
# Ensure Drug Name is suitable for mapping (already done in drug_names func)
|
||||
df['Directory'] = df['Drug Name'].map(drug_to_single_dir)
|
||||
|
||||
# Initialize Directory_Source column - track which fallback level was used
|
||||
df['Directory_Source'] = pd.NA
|
||||
# Mark rows where single valid directory was assigned
|
||||
df.loc[df['Directory'].notna(), 'Directory_Source'] = 'SINGLE_VALID_DIR'
|
||||
|
||||
# --- Prepare Fallback Logic ---
|
||||
# 9. Create Primary source from Additional Detail 1
|
||||
if 'Additional Detail 1' in df.columns:
|
||||
df['Primary_Source'] = df['Additional Detail 1'].astype(pd.StringDtype())
|
||||
df['Primary_Source'] = df['Primary_Source'].str.upper() # Apply upper to strings
|
||||
else:
|
||||
df['Primary_Source'] = pd.NA # Use pd.NA for StringDtype
|
||||
|
||||
# Extract actual directory name using the pattern
|
||||
try:
|
||||
# Use simpler pattern for primary source
|
||||
df['Extracted_Primary_Dir'] = df['Primary_Source'].str.extract(dir_pattern_primary_simple, expand=False, flags=re.IGNORECASE)
|
||||
df['Extracted_Fallback_Dir'] = df['Fallback_Source'].str.extract(dir_pattern_upper, expand=False, flags=re.IGNORECASE)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during directory extraction: {e}")
|
||||
# Assign NA columns if extraction fails
|
||||
df['Extracted_Primary_Dir'] = pd.NA
|
||||
df['Extracted_Fallback_Dir'] = pd.NA
|
||||
|
||||
# Strip potential whitespace from extracted directories
|
||||
if 'Extracted_Primary_Dir' in df.columns:
|
||||
df['Extracted_Primary_Dir'] = df['Extracted_Primary_Dir'].str.strip()
|
||||
if 'Extracted_Fallback_Dir' in df.columns:
|
||||
df['Extracted_Fallback_Dir'] = df['Extracted_Fallback_Dir'].str.strip()
|
||||
|
||||
# 10. Combine sources, prioritizing Primary_Source
|
||||
# Combine EXTRACTED directories
|
||||
df['Primary_Directory'] = df['Extracted_Primary_Dir'].fillna(df['Extracted_Fallback_Dir'])
|
||||
|
||||
# Track extraction source for Directory_Source column
|
||||
# Rows where we have Extracted_Primary_Dir will use EXTRACTED_PRIMARY
|
||||
# Rows where we only have Extracted_Fallback_Dir will use EXTRACTED_FALLBACK
|
||||
df['_extracted_source'] = pd.NA
|
||||
df.loc[df['Extracted_Primary_Dir'].notna(), '_extracted_source'] = 'EXTRACTED_PRIMARY'
|
||||
df.loc[(df['Extracted_Primary_Dir'].isna()) & (df['Extracted_Fallback_Dir'].notna()), '_extracted_source'] = 'EXTRACTED_FALLBACK'
|
||||
|
||||
# 11. Clean up intermediate columns
|
||||
df.drop(columns=['Primary_Source', 'Fallback_Source', 'Extracted_Primary_Dir', 'Extracted_Fallback_Dir'], inplace=True, errors='ignore')
|
||||
|
||||
# --- Identify Rows Needing Calculation ---
|
||||
# 12. Filter rows where Directory is not yet assigned
|
||||
df_to_process = df[df['Directory'].isnull()].copy()
|
||||
|
||||
# --- Calculate Most Frequent Valid Directory ---
|
||||
# 13. Drop rows without a potential primary directory
|
||||
df_to_process.dropna(subset=['Primary_Directory'], inplace=True)
|
||||
|
||||
# 14. Group and count potential directories
|
||||
if not df_to_process.empty:
|
||||
df_counts = df_to_process.groupby(['UPID', 'Drug Name', 'Primary_Directory'], observed=True)['Primary_Directory'].count().reset_index(name='count')
|
||||
|
||||
# 15. Sort by count descending
|
||||
df_counts.sort_values(['UPID', 'Drug Name', 'count'], ascending=[True, True, False], inplace=True)
|
||||
|
||||
# 16. Define helper function
|
||||
def find_first_valid_dir(group, drug_map):
|
||||
drug_name = group['Drug Name'].iloc[0]
|
||||
valid_dirs = drug_map.get(drug_name, set())
|
||||
|
||||
if not valid_dirs:
|
||||
return np.nan
|
||||
|
||||
for dir_candidate in group['Primary_Directory']:
|
||||
# Skip NA values
|
||||
if pd.isna(dir_candidate):
|
||||
continue
|
||||
|
||||
# Check if valid directory for this drug
|
||||
if isinstance(dir_candidate, str) and dir_candidate in valid_dirs:
|
||||
return dir_candidate
|
||||
|
||||
return np.nan # No valid directory found in the group
|
||||
|
||||
# 17. Group by UPID and Drug Name
|
||||
valid_groups = df_counts.groupby(['UPID', 'Drug Name'], observed=True, group_keys=False)
|
||||
|
||||
# 18. Apply helper function to find the best valid directory
|
||||
calculated_dirs = valid_groups.apply(lambda grp: find_first_valid_dir(grp, drug_to_valid_dirs))
|
||||
|
||||
# 19. Reset index to get UPID, Drug Name columns
|
||||
final_mapping = calculated_dirs.reset_index()
|
||||
|
||||
# 20. Rename the resulting column
|
||||
final_mapping.columns = ['UPID', 'Drug Name', 'Calculated_Directory']
|
||||
|
||||
# --- Merge Results and Finalize ---
|
||||
# 21. Merge calculated directories back to the main DataFrame
|
||||
df = pd.merge(df, final_mapping, on=['UPID', 'Drug Name'], how='left')
|
||||
|
||||
# 22. Fill NaN Directories with the calculated ones and track source
|
||||
# Find rows that will be filled from Calculated_Directory
|
||||
rows_to_fill = df['Directory'].isna() & df['Calculated_Directory'].notna()
|
||||
# For these rows, set Directory_Source based on _extracted_source (where the calculated dir came from)
|
||||
# The "calculated" directory is still derived from extraction, just via frequency analysis
|
||||
df.loc[rows_to_fill, 'Directory_Source'] = df.loc[rows_to_fill, '_extracted_source'].fillna('CALCULATED_MOST_FREQ')
|
||||
# Replace with the actual value of _extracted_source or fall back to CALCULATED_MOST_FREQ
|
||||
# Actually, let's simplify: if we're using the calculated most frequent directory, that's CALCULATED_MOST_FREQ
|
||||
df.loc[rows_to_fill, 'Directory_Source'] = 'CALCULATED_MOST_FREQ'
|
||||
|
||||
df['Directory'].fillna(df['Calculated_Directory'], inplace=True)
|
||||
|
||||
# 23. Drop temporary columns
|
||||
df.drop(columns=['Calculated_Directory', 'Primary_Directory', '_extracted_source'], inplace=True, errors='ignore')
|
||||
|
||||
else:
|
||||
# If df_to_process was empty, still need to drop temporary columns
|
||||
df.drop(columns=['Primary_Directory', '_extracted_source'], inplace=True, errors='ignore')
|
||||
|
||||
# 24. Drop rows with missing UPID (original logic)
|
||||
df['UPID'].replace('', np.nan, inplace=True) # Ensure empty strings are NaN
|
||||
df_orig = df.copy() # Save before dropna for future reference if needed
|
||||
df.dropna(subset=['UPID'], inplace=True)
|
||||
|
||||
# 25. Export rows with NA Directory to CSV for analysis (keep this for diagnostics)
|
||||
na_directory_rows = df[df['Directory'].isna()].copy()
|
||||
|
||||
# Export to CSV if there are any NA Directory rows
|
||||
if len(na_directory_rows) > 0:
|
||||
na_directory_rows.to_csv(paths.na_directory_rows_csv, index=False)
|
||||
|
||||
# 26. FALLBACK MECHANISM 1: Infer directory based on same UPID
|
||||
# Create a mapping of most frequent directory per UPID (only for UPIDs with a directory)
|
||||
if len(df[df['Directory'].isna()]) > 0:
|
||||
# First get valid directories per UPID
|
||||
valid_upid_dirs = df[df['Directory'].notna()].groupby('UPID')['Directory'].agg(
|
||||
lambda x: x.value_counts().index[0] if len(x.value_counts()) > 0 else None
|
||||
).to_dict()
|
||||
|
||||
# Apply UPID-based inference and track source
|
||||
for idx in df[df['Directory'].isna()].index:
|
||||
upid = df.loc[idx, 'UPID']
|
||||
if upid in valid_upid_dirs and valid_upid_dirs[upid] is not None:
|
||||
df.loc[idx, 'Directory'] = valid_upid_dirs[upid]
|
||||
df.loc[idx, 'Directory_Source'] = 'UPID_INFERENCE'
|
||||
|
||||
# 27. FALLBACK MECHANISM 2: Label remaining NA as "Undefined"
|
||||
# Track rows that will be marked as Undefined
|
||||
rows_undefined = df['Directory'].isna()
|
||||
df.loc[rows_undefined, 'Directory_Source'] = 'UNDEFINED'
|
||||
# Fill remaining NA directories with "Undefined"
|
||||
df['Directory'].fillna("Undefined", inplace=True)
|
||||
|
||||
# 28. Return the processed DataFrame
|
||||
return df
|
||||
|
||||
|
||||
|
||||
def ta_list_get(paths: Optional[PathConfig] = None):
    """
    Download the NICE technology-appraisal spreadsheet and return drug TAs.

    The spreadsheet is fetched from the NICE website, written to
    ``paths.ta_recommendations_xlsx`` and then read back with pandas.

    Args:
        paths: Path configuration providing ``ta_recommendations_xlsx``.
            When None, the module-level ``default_paths`` is used.

    Returns:
        DataFrame with the de-duplicated columns "TA ID" (rewritten as
        "NICE TA<number>") and "Indication", restricted to rows whose
        categorisation is "Recommended" or "Optimised" and whose technology
        type is "Pharmaceutical".
    """
    if paths is None:
        paths = default_paths

    link = "https://www.nice.org.uk/Media/Default/About/what-we-do/NICE-guidance/NICE-technology-appraisals/TA%20recommendations.xlsx"
    urllib.request.urlretrieve(link, paths.ta_recommendations_xlsx)
    ta_db = pd.read_excel(paths.ta_recommendations_xlsx, index_col=0)

    # Keep only recommended/optimised pharmaceutical appraisals.
    is_recommended = ta_db["Categorisation (for specific recommendation)"].isin(
        ["Recommended", "Optimised"]
    )
    is_pharmaceutical = ta_db["Technology type"] == "Pharmaceutical"
    # Explicit copy: the column assignment below must write to an independent
    # frame, not to a filtered slice (which triggers SettingWithCopyWarning).
    ta_db = ta_db[is_recommended & is_pharmaceutical].copy()

    # Reduce IDs such as "TA001" to their integer part, then re-prefix them.
    ta_number = ta_db["TA ID"].str.replace(r"\D+", "", regex=True).astype(int)
    ta_db["TA ID"] = "NICE TA" + ta_number.astype(str)

    return ta_db[["TA ID", "Indication"]].drop_duplicates()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
# visualization/ - Plotly Chart Generation
|
||||
|
||||
## Module: plotly_generator.py
|
||||
|
||||
Generates interactive Plotly icicle charts for patient pathway hierarchies.
|
||||
|
||||
### Key Functions
|
||||
|
||||
**create_icicle_figure(ice_df, title)**
|
||||
- Builds Plotly icicle figure from pre-processed DataFrame
|
||||
- Uses 10-field customdata: value, colour, cost, costpp, first/last seen (node + parent), average_spacing, cost_pp_pa
|
||||
- Node colors come from the `colour` column (shown in the hover text as the share of patients in the level), mapped through Plotly's Viridis colorscale
|
||||
- Hover template shows full treatment pathway and statistics
|
||||
|
||||
**save_figure_html(fig, save_dir, title, open_browser=False)**
|
||||
- Exports interactive HTML file with embedded Plotly.js
|
||||
|
||||
**open_figure_in_browser(filepath)**
|
||||
- Opens HTML file in default browser
|
||||
|
||||
### Data Requirements
|
||||
|
||||
Input DataFrame must have columns: parents, ids, labels, value, colour, cost, costpp, cost_pp_pa, First seen, Last seen, First seen (Parent), Last seen (Parent), average_spacing, avg_days
|
||||
|
||||
### Output
|
||||
|
||||
Interactive icicle chart showing Trust → Directory/Indication → Drug → Pathway hierarchy with rich tooltips.
|
||||
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Visualization package for patient pathway charts.
|
||||
|
||||
This package contains functions for generating interactive Plotly visualizations:
|
||||
- plotly_generator: Create icicle charts for patient pathway analysis
|
||||
"""
|
||||
|
||||
from visualization.plotly_generator import (
|
||||
create_icicle_figure,
|
||||
save_figure_html,
|
||||
open_figure_in_browser,
|
||||
)
|
||||
|
||||
# Names re-exported from visualization.plotly_generator as this package's public API.
__all__ = [
    "create_icicle_figure",
    "save_figure_html",
    "open_figure_in_browser",
]
|
||||
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
Plotly chart generation for patient pathway analysis.
|
||||
|
||||
This module contains functions for creating interactive icicle charts
|
||||
that visualize patient treatment pathways. The charts display hierarchical
|
||||
data: Trust → Directory → Drug → Pathway.
|
||||
"""
|
||||
|
||||
import webbrowser
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
|
||||
from core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_icicle_figure(ice_df: pd.DataFrame, title: str) -> go.Figure:
    """
    Create Plotly icicle figure from prepared DataFrame.

    The input is not modified: a copy sorted by ``labels`` is used to build
    the trace, and Plotly's own sorting is switched off so that order is kept.

    Args:
        ice_df: DataFrame with columns:
            - parents: Parent node in hierarchy
            - ids: Unique identifier for each node
            - labels: Display label for each node
            - value: Number of patients
            - colour: Color value for visualization
            - cost: Total cost
            - costpp: Cost per patient
            - cost_pp_pa: Cost per patient per annum
            - First seen: First intervention date
            - Last seen: Last intervention date
            - First seen (Parent): Earliest date in parent group
            - Last seen (Parent): Latest date in parent group
            - average_spacing: Formatted string with dosing information
            An ``avg_days`` column may be present but is not read here.
        title: Chart title suffix (appended to the fixed ICS heading)

    Returns:
        Plotly Figure object ready for display or export
    """
    ice_df = ice_df.sort_values(by=["labels"], ascending=True, ignore_index=True)

    def _date_strings(column: str) -> list:
        # Missing dates stringify as "NaT"; show "N/A" in the chart instead.
        return ice_df[column].astype(str).replace("NaT", "N/A").to_list()

    first_seen = _date_strings("First seen")
    last_seen = _date_strings("Last seen")
    # The parent columns get the same treatment: customdata[7] is the value
    # actually displayed as "Last seen", so a raw "NaT" must not leak through.
    first_seen_parent = _date_strings("First seen (Parent)")
    last_seen_parent = _date_strings("Last seen (Parent)")
    average_spacing = ice_df.average_spacing.astype(str).to_list()

    # Column order defines the customdata[i] indices used by the templates.
    customdata = np.stack(
        (
            ice_df.value,  # 0
            ice_df.colour,  # 1
            ice_df.cost,  # 2
            ice_df.costpp,  # 3
            first_seen,  # 4
            last_seen,  # 5
            first_seen_parent,  # 6
            last_seen_parent,  # 7
            average_spacing,  # 8
            ice_df.cost_pp_pa,  # 9
        ),
        axis=1,
    )

    text_template = (
        "<b>%{label}</b> "
        "<br><b>Total patients:</b> %{customdata[0]} (including children/further treatments)"
        "<br><b>First seen:</b> %{customdata[4]}"
        "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
        "<br><b>Average treatment duration:</b> %{customdata[8]}"
        "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
        "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
        "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}"
    )
    hover_template = (
        "<b>%{label}</b>"
        "<br><b>Total patients:</b> %{customdata[0]} - %{customdata[1]:.3p} of patients in level"
        "<br><b>Total cost:</b> £%{customdata[2]:.3~s}"
        "<br><b>Average cost per patient:</b> £%{customdata[3]:.3~s}"
        "<br><b>Average cost per patient per annum:</b> £%{customdata[9]:.3~s}"
        "<br><b>First seen:</b> %{customdata[4]}"
        "<br><b>Last seen (including further treatments):</b> %{customdata[7]}"
        # Space after the label added to match the text template above.
        "<br><b>Average treatment duration:</b> %{customdata[8]}"
        "<extra></extra>"
    )

    fig = go.Figure(
        go.Icicle(
            labels=ice_df.labels,
            ids=ice_df.ids,
            parents=ice_df.parents,
            customdata=customdata,
            values=ice_df.value,
            branchvalues="total",
            marker=dict(colors=ice_df.colour, colorscale="Viridis"),
            maxdepth=3,
            texttemplate=text_template,
            hovertemplate=hover_template,
        )
    )
    # Keep the label ordering established above instead of Plotly's own sort.
    fig.update_traces(sort=False)
    fig.update_layout(
        margin=dict(t=60, l=1, r=1, b=60),
        title=f"Norfolk & Waveney ICS high-cost drug patient pathways - {title}",
        title_x=0.5,
        hoverlabel=dict(font_size=16),
    )

    return fig
|
||||
|
||||
|
||||
def save_figure_html(
    fig: go.Figure, save_dir: str, title: str, open_browser: bool = False
) -> str:
    """
    Write a Plotly figure to ``{save_dir}/{title}.html``.

    Args:
        fig: Plotly Figure object to export
        save_dir: Directory that receives the HTML file
        title: Base name of the HTML file (without extension)
        open_browser: When True, the saved file is also opened in the
            default browser

    Returns:
        Path of the HTML file that was written
    """
    output_path = "{}/{}.html".format(save_dir, title)
    fig.write_html(output_path)
    logger.info(f"Success! File saved to {output_path}")

    if not open_browser:
        return output_path

    open_figure_in_browser(output_path)
    return output_path
|
||||
|
||||
|
||||
def open_figure_in_browser(filepath: str) -> None:
    """
    Open an HTML file in a new tab of the default browser.

    The path is resolved to an absolute location and converted to a proper
    ``file://`` URI, so relative paths, POSIX absolute paths and Windows
    paths all yield a well-formed URL (plain ``"file:///" + filepath`` does
    not for the first two).

    Args:
        filepath: Path to the HTML file (relative or absolute)
    """
    # Local import: pathlib is not among this module's top-level imports.
    from pathlib import Path

    webbrowser.open_new_tab(Path(filepath).resolve().as_uri())
|
||||
|
||||
|
||||
def figure_legacy(ice_df: pd.DataFrame, dir_string: str, save_dir: str) -> None:
    """
    Create, save and open an icicle figure (legacy interface).

    Kept for backward compatibility with the original figure() signature.
    It delegates to create_icicle_figure() and save_figure_html() rather
    than duplicating their bodies, so the legacy and current entry points
    cannot drift apart. The ``avg_days`` column is no longer touched: the
    value previously derived from it was never used in the chart.

    Args:
        ice_df: DataFrame with chart data (see create_icicle_figure)
        dir_string: Title string (used for filename and chart title)
        save_dir: Directory to save the HTML file

    Note:
        New code should call create_icicle_figure() + save_figure_html()
        directly.
    """
    fig = create_icicle_figure(ice_df, dir_string)
    # open_browser=True reproduces the legacy behaviour of always opening
    # the saved file after writing it.
    save_figure_html(fig, save_dir, dir_string, open_browser=True)
|
||||
Reference in New Issue
Block a user