Files
HighCostDrugsDemo/src/data_processing/pathway_queries.py
T
Andrew Charlwood d98cd4fd69 feat: add 7 analytics chart query functions (Task 9.2)
New query functions in src/data_processing/pathway_queries.py:
- get_drug_market_share: Level 3 drug nodes grouped by directory
- get_pathway_costs: Level 4+ pathway nodes with cost_pp_pa
- get_cost_waterfall: Directorate cost per patient from level 3 aggregation
- get_drug_transitions: Sankey source/target drug transitions with ordinal line labels
- get_dosing_intervals: Parsed average_spacing by trust/directory
- get_drug_directory_matrix: Directory x drug pivot with patient/cost metrics
- get_treatment_durations: Weighted avg_days by drug within directorates

Thin wrappers added in dash_app/data/queries.py for all 7 functions.
2026-02-06 19:21:10 +00:00

810 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Shared query functions for pathway node data.
These functions extract the data loading logic from the Reflex AppState
into standalone functions that accept db_path as a parameter and return
plain JSON-serializable dicts/lists. Both Reflex and Dash can call these.
All queries are read-only SELECTs against pathways.db.
"""
import sqlite3
from pathlib import Path
from typing import Optional
def load_initial_data(db_path: Path) -> dict:
    """
    Load reference data from SQLite on app initialization.

    Extracted from AppState.load_data() (pathways_app.py lines 407-488).

    Args:
        db_path: Path to the SQLite database file.

    Returns dict with keys:
        available_drugs: sorted list of unique drug labels (level 3 nodes)
        available_directorates: sorted list of unique directorate labels
            (level 2, directory charts)
        available_indications: sorted list of unique indications from
            ref_drug_indication_clusters; a single placeholder entry when
            that table yields no usable rows
        available_trusts: sorted list of unique trust names (level 1 nodes)
        total_records: source_row_count of the most recently started
            refresh whose status is 'completed' (0 when there is none)
        total_patients: value of the level 0 node for the 'all_6mo'
            directory chart; only looked up when total_records is 0,
            otherwise 0
        last_updated: completed_at of that same refresh ("" when missing)
        error: present only when the database is missing or a SQLite
            error occurred; every other key is still present with an
            empty value
    """
    def _failure(message: str) -> dict:
        # Mirror every key of the success payload (including
        # total_patients) so callers can index the result without first
        # checking for "error".
        return {
            "available_drugs": [],
            "available_directorates": [],
            "available_indications": [],
            "available_trusts": [],
            "total_records": 0,
            "total_patients": 0,
            "last_updated": "",
            "error": message,
        }

    # sqlite3.connect() would create an empty file at a missing path.
    if not db_path.exists():
        return _failure("Database not found")

    conn = sqlite3.connect(str(db_path))
    try:
        cursor = conn.cursor()

        def _first_column(sql: str) -> list:
            """Run a single-column SELECT and return that column as a list."""
            cursor.execute(sql)
            return [row[0] for row in cursor.fetchall()]

        # Latest completed refresh metadata
        cursor.execute("""
            SELECT source_row_count, completed_at
            FROM pathway_refresh_log
            WHERE status = 'completed'
            ORDER BY started_at DESC
            LIMIT 1
        """)
        refresh_row = cursor.fetchone()
        total_records = (refresh_row[0] or 0) if refresh_row else 0
        last_updated = (refresh_row[1] or "") if refresh_row else ""

        # Unique drugs from pathway_nodes level 3
        available_drugs = _first_column("""
            SELECT DISTINCT labels
            FROM pathway_nodes
            WHERE level = 3 AND labels IS NOT NULL AND labels != ''
            ORDER BY labels
        """)

        # Unique directorates from directory chart pathway_nodes level 2
        available_directorates = _first_column("""
            SELECT DISTINCT labels
            FROM pathway_nodes
            WHERE level = 2 AND chart_type = 'directory'
            AND labels IS NOT NULL AND labels != ''
            ORDER BY labels
        """)

        # Unique indications from ref_drug_indication_clusters
        available_indications = _first_column("""
            SELECT DISTINCT indication
            FROM ref_drug_indication_clusters
            WHERE indication IS NOT NULL AND indication != ''
            ORDER BY indication
        """)
        if not available_indications:
            available_indications = ["(No indications available)"]

        # Unique trusts from pathway_nodes level 1
        available_trusts = _first_column("""
            SELECT DISTINCT trust_name
            FROM pathway_nodes
            WHERE level = 1 AND trust_name IS NOT NULL AND trust_name != ''
            ORDER BY trust_name
        """)

        # Total patients from default root node (fallback when
        # source_row_count is empty)
        total_patients = 0
        if not total_records:
            cursor.execute("""
                SELECT value FROM pathway_nodes
                WHERE level = 0 AND date_filter_id = 'all_6mo' AND chart_type = 'directory'
                LIMIT 1
            """)
            root_row = cursor.fetchone()
            total_patients = (root_row[0] or 0) if root_row else 0

        return {
            "available_drugs": available_drugs,
            "available_directorates": available_directorates,
            "available_indications": available_indications,
            "available_trusts": available_trusts,
            "total_records": total_records,
            "total_patients": total_patients,
            "last_updated": last_updated,
        }
    except sqlite3.Error as e:
        return _failure(f"Database error: {e}")
    finally:
        conn.close()
def load_pathway_nodes(
    db_path: Path,
    filter_id: str,
    chart_type: str,
    selected_drugs: Optional[list[str]] = None,
    selected_directorates: Optional[list[str]] = None,
    selected_trusts: Optional[list[str]] = None,
) -> dict:
    """
    Fetch the pre-computed chart nodes for one date filter / chart type.

    Extracted from AppState.load_pathway_data() (pathways_app.py lines 490-642).

    Args:
        db_path: Path to pathways.db
        filter_id: e.g. "all_6mo", "2yr_12mo"
        chart_type: "directory" or "indication"
        selected_drugs: optional drug names; a node at level 3+ is kept
            when its drug_sequence contains any of them as a substring
        selected_directorates: optional directorate names; applied to
            nodes at level 2+ (rows with an empty directory always pass)
        selected_trusts: optional trust names (rows with an empty
            trust_name always pass)

    Returns dict with keys:
        nodes: list of JSON-serializable dicts, ordered by level, parents, ids
        unique_patients: level 0 node value, or the sum of level 3 values
            when any of the three filters is active
        total_drugs: number of distinct names found in drug_sequence
            across level 3+ nodes
        total_cost: level 0 node cost, or the sum of level 3 costs when
            any of the three filters is active
        last_updated: completed_at of the latest completed refresh
        error: only present on failure (missing DB, no rows, SQLite error)
    """
    if not db_path.exists():
        return _empty_result("Database not found")

    def _as_float(raw) -> float:
        # NULL, 0 and '' all collapse to 0.0.
        return float(raw) if raw else 0.0

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        cursor = conn.cursor()

        # Every user-supplied value is bound as a parameter; only the
        # placeholders themselves are interpolated into the SQL text.
        conditions = ["date_filter_id = ?", "chart_type = ?"]
        bindings: list = [filter_id, chart_type]

        if selected_directorates:
            marks = ",".join("?" for _ in selected_directorates)
            conditions.append(
                f"(level < 2 OR directory IN ({marks}) OR directory IS NULL OR directory = '')"
            )
            bindings.extend(selected_directorates)

        if selected_drugs:
            like_terms = " OR ".join("drug_sequence LIKE ?" for _ in selected_drugs)
            conditions.append(f"(level < 3 OR {like_terms})")
            bindings.extend(f"%{name}%" for name in selected_drugs)

        if selected_trusts:
            marks = ",".join("?" for _ in selected_trusts)
            conditions.append(
                f"(trust_name IN ({marks}) OR trust_name IS NULL OR trust_name = '')"
            )
            bindings.extend(selected_trusts)

        cursor.execute(
            f"""
            SELECT
            parents, ids, labels, level, value,
            cost, costpp, cost_pp_pa, colour,
            first_seen, last_seen, first_seen_parent, last_seen_parent,
            average_spacing, average_administered, avg_days,
            trust_name, directory, drug_sequence
            FROM pathway_nodes
            WHERE {" AND ".join(conditions)}
            ORDER BY level, parents, ids
            """,
            bindings,
        )
        rows = cursor.fetchall()
        if not rows:
            return _empty_result(f"No pathway data for filter: {filter_id}")

        # Levels below the filtered ones are selected unconditionally, so
        # drop directorate/trust nodes left without any children; otherwise
        # the icicle chart renders them as empty boxes.
        if selected_drugs or selected_directorates:
            rows = _prune_empty_ancestors(rows)

        nodes = []
        root_patients = 0
        root_cost = 0.0
        for row in rows:
            nodes.append({
                "parents": row["parents"] or "",
                "ids": row["ids"] or "",
                "labels": row["labels"] or "",
                "value": row["value"] or 0,
                "cost": _as_float(row["cost"]),
                "costpp": _as_float(row["costpp"]),
                "colour": _as_float(row["colour"]),
                "first_seen": row["first_seen"] or "",
                "last_seen": row["last_seen"] or "",
                "first_seen_parent": row["first_seen_parent"] or "",
                "last_seen_parent": row["last_seen_parent"] or "",
                "average_spacing": row["average_spacing"] or "",
                "cost_pp_pa": row["cost_pp_pa"] or "",
            })
            if row["level"] == 0:
                root_patients = row["value"] or 0
                root_cost = _as_float(row["cost"])

        # Distinct drug names appearing anywhere in a level 3+ sequence
        drug_names = {
            token
            for row in rows
            if row["level"] >= 3 and row["drug_sequence"]
            for token in row["drug_sequence"].split("|")
            if token
        }

        # The root node carries pre-computed, unfiltered totals, so with
        # any entity filter active the KPIs come from the level 3 rows
        # that survived the filter instead.
        if selected_drugs or selected_directorates or selected_trusts:
            level3_rows = [row for row in rows if row["level"] == 3]
            root_patients = sum(row["value"] or 0 for row in level3_rows)
            root_cost = sum(_as_float(row["cost"]) for row in level3_rows)

        # Data freshness
        cursor.execute("""
            SELECT completed_at
            FROM pathway_refresh_log
            WHERE status = 'completed'
            ORDER BY completed_at DESC
            LIMIT 1
        """)
        latest = cursor.fetchone()
        last_updated = ""
        if latest and latest["completed_at"]:
            last_updated = latest["completed_at"]

        return {
            "nodes": nodes,
            "unique_patients": root_patients,
            "total_drugs": len(drug_names),
            "total_cost": root_cost,
            "last_updated": last_updated,
        }
    except sqlite3.Error as e:
        return _empty_result(f"Database error: {e}")
    finally:
        conn.close()
def _prune_empty_ancestors(rows):
    """Drop level 1-2 nodes that nothing else in the result points to.

    With drug/directorate filters active, levels 0-2 are selected
    unconditionally, which can leave directorate and trust nodes without
    any children; Plotly's icicle chart draws those as empty boxes.

    A row survives a sweep when it is the root (level 0), a drug/pathway
    node (level 3+), or its ids value is the parents value of some other
    row in the same sweep. Two sweeps are run: the first removes childless
    level 2 nodes, the second removes level 1 nodes whose only children
    were removed by the first.
    """
    def _sweep(candidates):
        # Parent references are collected from the current candidate set,
        # so the second sweep sees the effect of the first.
        parent_ids = {row["parents"] for row in candidates if row["parents"]}
        survivors = []
        for row in candidates:
            level = row["level"]
            if level == 0 or level >= 3 or row["ids"] in parent_ids:
                survivors.append(row)
        return survivors

    return _sweep(_sweep(rows))
def _empty_result(error: str = "") -> dict:
    """Build the zero-valued load_pathway_nodes payload carrying *error*."""
    payload = dict(
        nodes=[],
        unique_patients=0,
        total_drugs=0,
        total_cost=0.0,
        last_updated="",
    )
    payload["error"] = error
    return payload
# ---------------------------------------------------------------------------
# Analytics chart query functions (Phase 9)
# ---------------------------------------------------------------------------
def _safe_float(value, default=0.0):
    """Return float(value), or *default* when value is None, "", "N/A" or not convertible."""
    is_placeholder = value is None or value == "" or value == "N/A"
    if not is_placeholder:
        try:
            return float(value)
        except (ValueError, TypeError):
            pass
    return default
def get_drug_market_share(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 3 drug nodes grouped by directory with patient counts and proportions.

    Rows are summed per (directory, drug), so several trusts reporting the
    same drug in the same directory collapse into one entry.

    Args:
        db_path: Path to the SQLite database file.
        date_filter_id: value matched against pathway_nodes.date_filter_id.
        chart_type: value matched against pathway_nodes.chart_type.
        directory: optional exact match on the directory column.
        trust: optional exact match on the trust_name column.

    Returns list of dicts: [{directory, drug, patients, proportion, cost, cost_pp_pa}]
        proportion: the drug's share of its directory's summed patients (4 dp)
        cost_pp_pa: summed cost / summed patients (2 dp); computed here,
            not read from the cost_pp_pa column
    Sorted by directory total patients desc, then drug patients desc within each.
    Returns [] when the database file is missing or a SQLite error occurs.
    """
    # sqlite3.connect() creates an empty database file at a missing path;
    # bail out first, as the load_* functions in this module do.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)

        query = f"""
            SELECT labels AS drug, directory, value AS patients, cost
            FROM pathway_nodes
            WHERE {' AND '.join(where)}
            ORDER BY directory, value DESC
        """
        rows = conn.execute(query, params).fetchall()

        # Aggregate across trusts: sum patients/cost per directory+drug
        agg: dict = {}
        for r in rows:
            entry = agg.setdefault(
                (r["directory"], r["drug"]),
                {"directory": r["directory"], "drug": r["drug"],
                 "patients": 0, "cost": 0.0},
            )
            entry["patients"] += r["patients"] or 0
            entry["cost"] += float(r["cost"]) if r["cost"] else 0.0

        # Directory totals are the denominator for each drug's proportion
        dir_totals: dict = {}
        for entry in agg.values():
            dir_totals[entry["directory"]] = (
                dir_totals.get(entry["directory"], 0) + entry["patients"]
            )

        result = []
        for entry in agg.values():
            total = dir_totals.get(entry["directory"], 1)
            patients = entry["patients"]
            result.append({
                "directory": entry["directory"],
                "drug": entry["drug"],
                "patients": patients,
                "proportion": round(patients / total, 4) if total else 0,
                "cost": round(entry["cost"], 2),
                "cost_pp_pa": round(entry["cost"] / patients, 2) if patients else 0,
            })

        # Sort: directory by total patients desc, drugs by patients desc within
        result.sort(key=lambda x: (-dir_totals.get(x["directory"], 0), -x["patients"]))
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_pathway_costs(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 4+ pathway nodes with annualized cost and pathway labels.

    Args:
        db_path: Path to the SQLite database file.
        date_filter_id: value matched against pathway_nodes.date_filter_id.
        chart_type: value matched against pathway_nodes.chart_type.
        directory: optional exact match on the directory column.
        trust: optional exact match on the trust_name column.

    Returns list of dicts:
        [{ids, pathway_label, cost_pp_pa, patients, cost, avg_days,
          directory, trust_name, drug_sequence, level}]
        pathway_label: the non-empty drug_sequence entries joined with an
            arrow, e.g. "A → B"
        drug_sequence: the raw "|"-split list (may contain empty strings)
        cost_pp_pa / avg_days: 0.0 when the stored value is missing or
            not numeric (e.g. 'N/A')
    Sorted by cost_pp_pa desc; ties keep the query's patients-desc order.
    Returns [] when the database file is missing or a SQLite error occurs.
    """
    # sqlite3.connect() creates an empty database file at a missing path;
    # bail out first, as the load_* functions in this module do.
    if not db_path.exists():
        return []

    # Right arrow (U+2192) with surrounding spaces. Joining with an empty
    # string ran the drug names together into one unreadable token.
    separator = " \u2192 "

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)

        query = f"""
            SELECT ids, labels, level, value AS patients, cost, costpp,
            cost_pp_pa, avg_days, directory, drug_sequence, trust_name
            FROM pathway_nodes
            WHERE {' AND '.join(where)}
            ORDER BY value DESC
        """
        rows = conn.execute(query, params).fetchall()

        result = []
        for r in rows:
            drugs = (r["drug_sequence"] or "").split("|")
            result.append({
                "ids": r["ids"],
                "pathway_label": separator.join(d for d in drugs if d),
                "cost_pp_pa": _safe_float(r["cost_pp_pa"]),
                "patients": r["patients"] or 0,
                "cost": float(r["cost"]) if r["cost"] else 0.0,
                "avg_days": _safe_float(r["avg_days"]),
                "directory": r["directory"] or "",
                "trust_name": r["trust_name"] or "",
                "drug_sequence": drugs,
                "level": r["level"],
            })

        # list.sort is stable, so equal cost_pp_pa keeps patients-desc order
        result.sort(key=lambda x: -x["cost_pp_pa"])
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_cost_waterfall(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    trust: Optional[str] = None,
) -> list[dict]:
    """Per-directory cost metrics aggregated from level 3 nodes.

    Since level 2 cost_pp_pa is 'N/A', the per-patient figure is computed
    from the child (level 3) nodes instead: sum(cost) / sum(patients) for
    each directory. Directories whose summed patients are not > 0 are
    excluded.

    Args:
        db_path: Path to the SQLite database file.
        date_filter_id: value matched against pathway_nodes.date_filter_id.
        chart_type: value matched against pathway_nodes.chart_type.
        trust: optional exact match on the trust_name column.

    Returns list of dicts: [{directory, patients, total_cost, cost_pp}]
    Sorted by cost_pp desc; ties keep the query's total_cost-desc order.
    Returns [] when the database file is missing or a SQLite error occurs.
    """
    # sqlite3.connect() creates an empty database file at a missing path;
    # bail out first, as the load_* functions in this module do.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
        params: list = [date_filter_id, chart_type]
        if trust:
            where.append("trust_name = ?")
            params.append(trust)

        query = f"""
            SELECT directory, SUM(value) AS patients, SUM(cost) AS total_cost
            FROM pathway_nodes
            WHERE {' AND '.join(where)}
            GROUP BY directory
            HAVING patients > 0
            ORDER BY total_cost DESC
        """
        rows = conn.execute(query, params).fetchall()

        result = []
        for r in rows:
            patients = r["patients"] or 0
            total_cost = float(r["total_cost"]) if r["total_cost"] else 0.0
            result.append({
                "directory": r["directory"] or "",
                "patients": patients,
                "total_cost": round(total_cost, 2),
                "cost_pp": round(total_cost / patients, 2) if patients else 0,
            })

        result.sort(key=lambda x: -x["cost_pp"])
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_drug_transitions(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> dict:
    """Parse level 4+ nodes into source→target drug transitions for Sankey.

    Each drug at each treatment line is its own Sankey node, e.g.
    "ADALIMUMAB (1st)" and "ADALIMUMAB (2nd)" are different nodes.

    Args:
        db_path: Path to the SQLite database file.
        date_filter_id: value matched against pathway_nodes.date_filter_id.
        chart_type: value matched against pathway_nodes.chart_type.
        directory: optional exact match on the directory column.
        trust: optional exact match on the trust_name column.

    Returns dict with:
        nodes: [{name}] — unique "DRUG (Nth)" labels, sorted alphabetically
        links: [{source_idx, target_idx, patients}] — transitions between
            nodes, indices into ``nodes``, sorted by patients desc
    Returns {"nodes": [], "links": []} when the database file is missing
    or a SQLite error occurs.
    """
    # sqlite3.connect() creates an empty database file at a missing path;
    # bail out first, as the load_* functions in this module do.
    if not db_path.exists():
        return {"nodes": [], "links": []}

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)

        query = f"""
            SELECT ids, level, value AS patients, drug_sequence, directory
            FROM pathway_nodes
            WHERE {' AND '.join(where)}
            ORDER BY level, ids
        """
        rows = conn.execute(query, params).fetchall()

        # NOTE(review): every row contributes all adjacent pairs of its
        # sequence, so if a deeper node's patients are also included in
        # its parent's value, early transitions are counted more than
        # once — confirm against how pathway_nodes values are built.
        drug_line_set = set()
        link_agg: dict = {}
        for r in rows:
            drugs = [d for d in (r["drug_sequence"] or "").split("|") if d]
            patients = r["patients"] or 0
            if len(drugs) < 2 or patients == 0:
                continue
            # Only use adjacent transitions (drug[i] → drug[i+1])
            for position, (current, following) in enumerate(zip(drugs, drugs[1:]), start=1):
                src = f"{current} ({_ordinal(position)})"
                tgt = f"{following} ({_ordinal(position + 1)})"
                drug_line_set.update((src, tgt))
                link_agg[(src, tgt)] = link_agg.get((src, tgt), 0) + patients

        # Build indexed node list
        node_list = sorted(drug_line_set)
        node_idx = {name: i for i, name in enumerate(node_list)}
        nodes = [{"name": name} for name in node_list]
        links = [
            {"source_idx": node_idx[src], "target_idx": node_idx[tgt], "patients": pts}
            for (src, tgt), pts in sorted(link_agg.items(), key=lambda x: -x[1])
        ]
        return {"nodes": nodes, "links": links}
    except sqlite3.Error:
        return {"nodes": [], "links": []}
    finally:
        conn.close()
def _ordinal(n: int) -> str:
    """Format n with its English ordinal suffix: 1 -> '1st', 12 -> '12th', 23 -> '23rd'."""
    # 11, 12 and 13 (and 111, 212, ...) take "th" despite ending in 1/2/3.
    is_teen = 11 <= n % 100 <= 13
    ending = "th" if is_teen else {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{ending}"
def get_dosing_intervals(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    drug: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 3 drug nodes with parsed dosing interval data.

    Each node's average_spacing text is handed to parse_average_spacing;
    one output dict is produced per entry that parser returns, so a single
    node may yield zero, one or several dicts.

    Args:
        db_path: Path to the SQLite database file.
        date_filter_id: value matched against pathway_nodes.date_filter_id.
        chart_type: value matched against pathway_nodes.chart_type.
        drug: optional exact match on the node label.
        trust: optional exact match on the trust_name column.

    Returns list of dicts:
        [{drug, trust_name, directory, weekly_interval, dose_count, total_weeks, patients}]
        drug comes from the parsed entry's drug_name; patients is the
        node's value, repeated on every entry parsed from that node.
    Returns [] when the database file is missing or a SQLite error occurs.
    """
    # sqlite3.connect() creates an empty database file at a missing path;
    # bail out first, as the load_* functions in this module do. Checking
    # before the import also keeps this path free of the parsing module.
    if not db_path.exists():
        return []

    from data_processing.parsing import parse_average_spacing

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
                 "average_spacing IS NOT NULL", "average_spacing != ''"]
        params: list = [date_filter_id, chart_type]
        if drug:
            where.append("labels = ?")
            params.append(drug)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)

        query = f"""
            SELECT labels AS drug, trust_name, directory, value AS patients,
            average_spacing
            FROM pathway_nodes
            WHERE {' AND '.join(where)}
            ORDER BY labels, trust_name
        """
        rows = conn.execute(query, params).fetchall()

        return [
            {
                "drug": entry["drug_name"],
                "trust_name": r["trust_name"] or "",
                "directory": r["directory"] or "",
                "weekly_interval": entry["weekly_interval"],
                "dose_count": entry["dose_count"],
                "total_weeks": entry["total_weeks"],
                "patients": r["patients"] or 0,
            }
            for r in rows
            for entry in parse_average_spacing(r["average_spacing"])
        ]
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_drug_directory_matrix(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    trust: Optional[str] = None,
) -> dict:
    """Level 3 nodes pivoted as directory × drug matrix.

    Rows are summed per (directory, drug), so several trusts reporting the
    same drug in the same directory collapse into one cell.

    Args:
        db_path: Path to the SQLite database file.
        date_filter_id: value matched against pathway_nodes.date_filter_id.
        chart_type: value matched against pathway_nodes.chart_type.
        trust: optional exact match on the trust_name column.

    Returns dict with:
        directories: directory names (rows), ordered by total patients desc
        drugs: drug names (columns), ordered by total patients desc
        matrix: {directory: {drug: {patients, cost, cost_pp_pa}}}
            cost is rounded to 2 dp; cost_pp_pa is that rounded cost
            divided by patients (2 dp), computed here rather than read
            from the cost_pp_pa column
    Returns {"directories": [], "drugs": [], "matrix": {}} when the
    database file is missing or a SQLite error occurs.
    """
    # sqlite3.connect() creates an empty database file at a missing path;
    # bail out first, as the load_* functions in this module do.
    if not db_path.exists():
        return {"directories": [], "drugs": [], "matrix": {}}

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
        params: list = [date_filter_id, chart_type]
        if trust:
            where.append("trust_name = ?")
            params.append(trust)

        query = f"""
            SELECT labels AS drug, directory, value AS patients, cost
            FROM pathway_nodes
            WHERE {' AND '.join(where)}
            ORDER BY directory, labels
        """
        rows = conn.execute(query, params).fetchall()

        # Aggregate across trusts
        matrix: dict = {}
        dir_totals: dict = {}
        drug_totals: dict = {}
        for r in rows:
            d = r["directory"] or ""
            drug = r["drug"] or ""
            patients = r["patients"] or 0
            cost = float(r["cost"]) if r["cost"] else 0.0

            cell = matrix.setdefault(d, {}).setdefault(
                drug, {"patients": 0, "cost": 0.0}
            )
            cell["patients"] += patients
            cell["cost"] += cost
            dir_totals[d] = dir_totals.get(d, 0) + patients
            drug_totals[drug] = drug_totals.get(drug, 0) + patients

        # Finalize each cell: round the cost, then derive cost per patient
        for by_drug in matrix.values():
            for cell in by_drug.values():
                cell["cost"] = round(cell["cost"], 2)
                cell["cost_pp_pa"] = (
                    round(cell["cost"] / cell["patients"], 2) if cell["patients"] else 0
                )

        # Both axes are ordered by summed patients, largest first
        directories = sorted(dir_totals, key=lambda x: -dir_totals[x])
        drugs = sorted(drug_totals, key=lambda x: -drug_totals[x])
        return {"directories": directories, "drugs": drugs, "matrix": matrix}
    except sqlite3.Error:
        return {"directories": [], "drugs": [], "matrix": {}}
    finally:
        conn.close()
def get_treatment_durations(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 3 drug nodes with average treatment duration.

    Rows are combined per (directory, drug) as a patient-weighted mean of
    avg_days, so several trusts reporting the same drug in the same
    directory collapse into one entry.

    Args:
        db_path: Path to the SQLite database file.
        date_filter_id: value matched against pathway_nodes.date_filter_id.
        chart_type: value matched against pathway_nodes.chart_type.
        directory: optional exact match on the directory column.
        trust: optional exact match on the trust_name column.

    Returns list of dicts: [{drug, directory, avg_days, patients}]
        avg_days is rounded to 1 dp; patients is the summed patient count.
    Sorted by avg_days desc. Excludes nodes with no avg_days data, and
    rows whose patients or avg_days evaluate to 0.
    Returns [] when the database file is missing or a SQLite error occurs.
    """
    # sqlite3.connect() creates an empty database file at a missing path;
    # bail out first, as the load_* functions in this module do.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
                 "avg_days IS NOT NULL"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)

        query = f"""
            SELECT labels AS drug, directory, trust_name,
            value AS patients, avg_days
            FROM pathway_nodes
            WHERE {' AND '.join(where)}
            ORDER BY avg_days DESC
        """
        rows = conn.execute(query, params).fetchall()

        # Aggregate across trusts: weighted average of avg_days by patients
        agg: dict = {}
        for r in rows:
            patients = r["patients"] or 0
            days = _safe_float(r["avg_days"])
            # Rows without patients or without a usable duration would
            # only dilute the weighted mean.
            if patients == 0 or days == 0:
                continue
            dir_name = r["directory"] or ""
            bucket = agg.setdefault(
                (dir_name, r["drug"]),
                {"drug": r["drug"], "directory": dir_name,
                 "total_weighted_days": 0.0, "total_patients": 0},
            )
            bucket["total_weighted_days"] += days * patients
            bucket["total_patients"] += patients

        result = [
            {
                "drug": bucket["drug"],
                "directory": bucket["directory"],
                "avg_days": round(
                    bucket["total_weighted_days"] / bucket["total_patients"], 1
                ),
                "patients": bucket["total_patients"],
            }
            for bucket in agg.values()
            if bucket["total_patients"] > 0
        ]
        result.sort(key=lambda x: -x["avg_days"])
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()