Files

1665 lines
55 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Shared query functions for pathway node data.
These functions extract the data loading logic from the Reflex AppState
into standalone functions that accept db_path as a parameter and return
plain JSON-serializable dicts/lists. Both Reflex and Dash can call these.
All queries are read-only SELECTs against pathways.db.
"""
import sqlite3
from pathlib import Path
from typing import Optional
def load_initial_data(db_path: Path) -> dict:
    """
    Load reference data from SQLite on app initialization.

    Extracted from AppState.load_data() (pathways_app.py lines 407-488).

    Args:
        db_path: Path to pathways.db. A plain string path is accepted too.

    Returns dict with keys:
        available_drugs: sorted list of unique drug labels (level 3 nodes)
        available_directorates: sorted list of unique directorate labels
            (level 2, directory charts)
        available_indications: sorted list of unique indications from
            ref_drug_indication_clusters, or a single placeholder entry
            when that table yields none
        available_trusts: sorted list of unique trust names (level 1 nodes)
        total_records: source row count from latest completed refresh
        total_patients: patient count from default root node
        total_cost: total cost from default root node
        last_updated: completed_at value of the latest completed refresh
        error: only present when the database is missing or a query
            fails; all other keys then hold their empty/zero value
    """
    def _failure(message: str) -> dict:
        # Mirror every key of the success payload so callers can index the
        # result unconditionally; only "error" distinguishes a failure.
        return {
            "available_drugs": [],
            "available_directorates": [],
            "available_indications": [],
            "available_trusts": [],
            "total_records": 0,
            "total_patients": 0,
            "total_cost": 0.0,
            "last_updated": "",
            "error": message,
        }

    db_path = Path(db_path)
    if not db_path.exists():
        return _failure("Database not found")

    conn = sqlite3.connect(str(db_path))
    try:
        cursor = conn.cursor()

        def _first_column(sql: str) -> list:
            # Run a single-column SELECT and flatten it to a plain list.
            cursor.execute(sql)
            return [row[0] for row in cursor.fetchall()]

        # Latest completed refresh metadata ("latest" = most recently started)
        cursor.execute("""
            SELECT source_row_count, completed_at
            FROM pathway_refresh_log
            WHERE status = 'completed'
            ORDER BY started_at DESC
            LIMIT 1
        """)
        refresh_row = cursor.fetchone()
        total_records = (refresh_row[0] or 0) if refresh_row else 0
        last_updated = (refresh_row[1] or "") if refresh_row else ""

        # Unique drugs from pathway_nodes level 3
        available_drugs = _first_column("""
            SELECT DISTINCT labels
            FROM pathway_nodes
            WHERE level = 3 AND labels IS NOT NULL AND labels != ''
            ORDER BY labels
        """)

        # Unique directorates from directory chart pathway_nodes level 2
        available_directorates = _first_column("""
            SELECT DISTINCT labels
            FROM pathway_nodes
            WHERE level = 2 AND chart_type = 'directory'
              AND labels IS NOT NULL AND labels != ''
            ORDER BY labels
        """)

        # Unique indications from ref_drug_indication_clusters
        available_indications = _first_column("""
            SELECT DISTINCT indication
            FROM ref_drug_indication_clusters
            WHERE indication IS NOT NULL AND indication != ''
            ORDER BY indication
        """)
        if not available_indications:
            # Placeholder entry so the list is never empty.
            available_indications = ["(No indications available)"]

        # Unique trusts from pathway_nodes level 1
        available_trusts = _first_column("""
            SELECT DISTINCT trust_name
            FROM pathway_nodes
            WHERE level = 1 AND trust_name IS NOT NULL AND trust_name != ''
            ORDER BY trust_name
        """)

        # Total patients and cost from default root node
        total_patients = 0
        total_cost = 0.0
        cursor.execute("""
            SELECT value, cost FROM pathway_nodes
            WHERE level = 0 AND date_filter_id = 'all_6mo' AND chart_type = 'directory'
            LIMIT 1
        """)
        root_row = cursor.fetchone()
        if root_row:
            total_patients = root_row[0] or 0
            total_cost = root_row[1] or 0.0

        return {
            "available_drugs": available_drugs,
            "available_directorates": available_directorates,
            "available_indications": available_indications,
            "available_trusts": available_trusts,
            "total_records": total_records,
            "total_patients": total_patients,
            "total_cost": total_cost,
            "last_updated": last_updated,
        }
    except sqlite3.Error as e:
        return _failure(f"Database error: {e}")
    finally:
        conn.close()
def load_pathway_nodes(
    db_path: Path,
    filter_id: str,
    chart_type: str,
    selected_drugs: Optional[list[str]] = None,
    selected_directorates: Optional[list[str]] = None,
    selected_trusts: Optional[list[str]] = None,
) -> dict:
    """
    Load pre-computed pathway nodes from SQLite.

    Extracted from AppState.load_pathway_data() (pathways_app.py lines 490-642).

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        filter_id: e.g. "all_6mo", "2yr_12mo"
        chart_type: "directory" or "indication"
        selected_drugs: optional list of drug names to filter by; a level 3+
            node matches when its drug_sequence contains the name
        selected_directorates: optional list of directorate names to filter by
        selected_trusts: optional list of trust names to filter by

    Returns dict with keys:
        nodes: list of dicts (JSON-serializable) with chart node data
        unique_patients: int (root node value, or the sum of level-3 values
            when any entity filter is active)
        total_drugs: int (unique drugs across level 3+ nodes)
        total_cost: float (root node cost, or the sum of level-3 costs
            when any entity filter is active)
        last_updated: completed_at of the most recently completed refresh
        error: optional error string
    """
    db_path = Path(db_path)
    if not db_path.exists():
        return _empty_result("Database not found")

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        cursor = conn.cursor()

        # Build WHERE clause with parameterized values
        where_clauses = ["date_filter_id = ?", "chart_type = ?"]
        params: list = [filter_id, chart_type]

        if selected_directorates:
            placeholders = ",".join("?" * len(selected_directorates))
            # Levels 0-1 sit above the directorate level and always pass.
            where_clauses.append(
                f"(level < 2 OR directory IN ({placeholders}) "
                "OR directory IS NULL OR directory = '')"
            )
            params.extend(selected_directorates)

        if selected_drugs:
            drug_conditions = []
            for drug in selected_drugs:
                # Escape LIKE metacharacters so a '%' or '_' inside a drug
                # name is matched literally instead of acting as a wildcard.
                escaped = (
                    drug.replace("\\", "\\\\")
                    .replace("%", "\\%")
                    .replace("_", "\\_")
                )
                drug_conditions.append("drug_sequence LIKE ? ESCAPE '\\'")
                params.append(f"%{escaped}%")
            any_drug = " OR ".join(drug_conditions)
            # Levels 0-2 sit above the drug level and always pass.
            where_clauses.append(f"(level < 3 OR {any_drug})")

        if selected_trusts:
            placeholders = ",".join("?" * len(selected_trusts))
            where_clauses.append(
                f"(trust_name IN ({placeholders}) "
                "OR trust_name IS NULL OR trust_name = '')"
            )
            params.extend(selected_trusts)

        where_clause = " AND ".join(where_clauses)
        query = f"""
            SELECT
                parents, ids, labels, level, value,
                cost, costpp, cost_pp_pa, colour,
                first_seen, last_seen, first_seen_parent, last_seen_parent,
                average_spacing, average_administered, avg_days,
                trust_name, directory, drug_sequence
            FROM pathway_nodes
            WHERE {where_clause}
            ORDER BY level, parents, ids
        """
        cursor.execute(query, params)
        rows = cursor.fetchall()
        if not rows:
            return _empty_result(f"No pathway data for filter: {filter_id}")

        # When drug or directorate filters are active, prune ancestor nodes
        # that have no matching descendants. Without this, the icicle chart
        # shows empty directorate/trust boxes with no children.
        if selected_drugs or selected_directorates:
            rows = _prune_empty_ancestors(rows)

        has_entity_filter = bool(
            selected_drugs or selected_directorates or selected_trusts
        )

        nodes = []
        root_patients = 0
        root_cost = 0.0
        unique_drugs = set()
        for row in rows:
            # _safe_float tolerates NULL, '' and non-numeric text, which a
            # bare float() would turn into an uncaught ValueError.
            cost = _safe_float(row["cost"])
            nodes.append({
                "parents": row["parents"] or "",
                "ids": row["ids"] or "",
                "labels": row["labels"] or "",
                "value": row["value"] or 0,
                "cost": cost,
                "costpp": _safe_float(row["costpp"]),
                "colour": _safe_float(row["colour"]),
                "first_seen": row["first_seen"] or "",
                "last_seen": row["last_seen"] or "",
                "first_seen_parent": row["first_seen_parent"] or "",
                "last_seen_parent": row["last_seen_parent"] or "",
                "average_spacing": row["average_spacing"] or "",
                "cost_pp_pa": row["cost_pp_pa"] or "",
            })
            if row["level"] == 0:
                root_patients = row["value"] or 0
                root_cost = cost
            # Count unique drugs from level 3+ nodes; drug_sequence is a
            # '|'-separated list of drug names.
            if row["level"] >= 3 and row["drug_sequence"]:
                unique_drugs.update(
                    drug for drug in row["drug_sequence"].split("|") if drug
                )

        # When entity filters are active, sum level-3 drug nodes for KPIs
        # instead of using the root node's pre-computed (unfiltered) totals
        if has_entity_filter:
            level_three = [row for row in rows if row["level"] == 3]
            root_patients = sum(row["value"] or 0 for row in level_three)
            root_cost = sum(_safe_float(row["cost"]) for row in level_three)

        # Data freshness
        cursor.execute("""
            SELECT completed_at
            FROM pathway_refresh_log
            WHERE status = 'completed'
            ORDER BY completed_at DESC
            LIMIT 1
        """)
        refresh_row = cursor.fetchone()
        last_updated = (
            refresh_row["completed_at"]
            if refresh_row and refresh_row["completed_at"]
            else ""
        )

        return {
            "nodes": nodes,
            "unique_patients": root_patients,
            "total_drugs": len(unique_drugs),
            "total_cost": root_cost,
            "last_updated": last_updated,
        }
    except sqlite3.Error as e:
        return _empty_result(f"Database error: {e}")
    finally:
        conn.close()
def _prune_empty_ancestors(rows):
"""Remove ancestor nodes that have no matching descendants.
When drug/directorate filters are active, levels 0-2 are included
unconditionally. This leaves directorate and trust nodes that have no
children in the filtered result. Plotly's icicle chart shows these as
empty boxes. Prune them by keeping only nodes whose ids appear as a
parent of another kept node, or that are leaf nodes (level 3+), or
are the root (level 0).
"""
# Collect all parent references from the result set
referenced_parents = {row["parents"] for row in rows if row["parents"]}
# Keep: root (level 0), any node referenced as a parent, any leaf (level 3+)
kept = [
row for row in rows
if row["level"] == 0
or row["level"] >= 3
or row["ids"] in referenced_parents
]
# Second pass: a trust (level 1) may reference root but itself have no
# kept level-2 children. Recheck that level-1 nodes are still parents
# of something in the kept set.
kept_parents = {row["parents"] for row in kept if row["parents"]}
return [
row for row in kept
if row["level"] == 0
or row["level"] >= 3
or row["ids"] in kept_parents
]
def _empty_result(error: str = "") -> dict:
return {
"nodes": [],
"unique_patients": 0,
"total_drugs": 0,
"total_cost": 0.0,
"last_updated": "",
"error": error,
}
# ---------------------------------------------------------------------------
# Analytics chart query functions (Phase 9)
# ---------------------------------------------------------------------------
def _safe_float(value, default=0.0):
"""Convert a value to float, returning default for None/N/A/empty."""
if value is None or value == "" or value == "N/A":
return default
try:
return float(value)
except (ValueError, TypeError):
return default
def get_drug_market_share(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 3 drug nodes grouped by directory with patient counts and proportions.

    Rows are summed across trusts per (directory, drug). ``proportion`` is
    the drug's share of its directory's patients; ``cost_pp_pa`` is computed
    here as aggregated cost divided by aggregated patients.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: optional exact directory to restrict to
        trust: optional exact trust name to restrict to

    Returns list of dicts: [{directory, drug, patients, proportion, cost, cost_pp_pa}]
    Sorted by directory total patients desc, then drug patients desc within each.
    Returns [] when the database is missing or a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT labels AS drug, directory, value AS patients, cost
            FROM pathway_nodes
            WHERE {where_sql}
            ORDER BY directory, value DESC
        """
        rows = conn.execute(query, params).fetchall()

        # Aggregate across trusts: sum patients/cost per directory+drug
        agg: dict = {}
        for r in rows:
            cell = agg.setdefault(
                (r["directory"], r["drug"]),
                {"directory": r["directory"], "drug": r["drug"],
                 "patients": 0, "cost": 0.0},
            )
            cell["patients"] += r["patients"] or 0
            cell["cost"] += float(r["cost"]) if r["cost"] else 0.0

        # Patient totals per directory, the denominator for proportions
        dir_totals: dict = {}
        for v in agg.values():
            dir_totals[v["directory"]] = (
                dir_totals.get(v["directory"], 0) + v["patients"]
            )

        result = []
        for v in agg.values():
            total = dir_totals.get(v["directory"], 1)
            result.append({
                "directory": v["directory"],
                "drug": v["drug"],
                "patients": v["patients"],
                "proportion": round(v["patients"] / total, 4) if total else 0,
                "cost": round(v["cost"], 2),
                "cost_pp_pa": (
                    round(v["cost"] / v["patients"], 2) if v["patients"] else 0
                ),
            })

        # Sort: directory by total patients desc, drugs by patients desc within
        result.sort(
            key=lambda x: (-dir_totals.get(x["directory"], 0), -x["patients"])
        )
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_pathway_costs(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 4+ pathway nodes with annualized cost and pathway labels.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: optional exact directory to restrict to
        trust: optional exact trust name to restrict to

    Returns list of dicts:
        [{ids, pathway_label, cost_pp_pa, patients, cost, avg_days,
          directory, trust_name, drug_sequence, level}]
    ``drug_sequence`` is the raw '|'-split list; ``pathway_label`` is the
    non-empty drug names joined with " → ".
    Sorted by cost_pp_pa desc. Returns [] when the database is missing or
    a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT ids, labels, level, value AS patients, cost, costpp,
                   cost_pp_pa, avg_days, directory, drug_sequence, trust_name
            FROM pathway_nodes
            WHERE {where_sql}
            ORDER BY value DESC
        """
        rows = conn.execute(query, params).fetchall()

        result = []
        for r in rows:
            drugs = (r["drug_sequence"] or "").split("|")
            # Join with a visible separator: with an empty separator the
            # names fuse together and distinct pathways can collide.
            pathway_label = " → ".join(d for d in drugs if d)
            result.append({
                "ids": r["ids"],
                "pathway_label": pathway_label,
                "cost_pp_pa": _safe_float(r["cost_pp_pa"]),
                "patients": r["patients"] or 0,
                "cost": float(r["cost"]) if r["cost"] else 0.0,
                "avg_days": _safe_float(r["avg_days"]),
                "directory": r["directory"] or "",
                "trust_name": r["trust_name"] or "",
                "drug_sequence": drugs,
                "level": r["level"],
            })

        # Stable sort: rows with equal cost keep the query's patients-desc order.
        result.sort(key=lambda x: -x["cost_pp_pa"])
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_cost_waterfall(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    trust: Optional[str] = None,
) -> list[dict]:
    """Per-directorate/indication cost metrics derived from level 3 nodes.

    Since level 2 cost_pp_pa is 'N/A', we compute it from child (level 3)
    nodes: sum(cost) / sum(patients) for each directory. Directories whose
    summed patient count is not positive are left out.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        trust: optional exact trust name to restrict to

    Returns list of dicts: [{directory, patients, total_cost, cost_pp}]
    Sorted by cost_pp desc. Returns [] when the database is missing or a
    query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
        params: list = [date_filter_id, chart_type]
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT directory, SUM(value) AS patients, SUM(cost) AS total_cost
            FROM pathway_nodes
            WHERE {where_sql}
            GROUP BY directory
            HAVING patients > 0
            ORDER BY total_cost DESC
        """
        rows = conn.execute(query, params).fetchall()

        result = []
        for r in rows:
            patients = r["patients"] or 0
            total_cost = float(r["total_cost"]) if r["total_cost"] else 0.0
            result.append({
                "directory": r["directory"] or "",
                "patients": patients,
                "total_cost": round(total_cost, 2),
                "cost_pp": round(total_cost / patients, 2) if patients else 0,
            })

        # Stable sort: equal cost_pp keeps the query's total_cost-desc order.
        result.sort(key=lambda x: -x["cost_pp"])
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_drug_transitions(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> dict:
    """Parse level 4+ nodes into source→target drug transitions for Sankey.

    Each drug at each treatment line is its own Sankey node, e.g.
    "ADALIMUMAB (1st)" and "ADALIMUMAB (2nd)" are different nodes.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: optional exact directory to restrict to
        trust: optional exact trust name to restrict to

    Returns dict with:
        nodes: [{name}] — unique "DRUG (Nth)" names, sorted alphabetically
        links: [{source_idx, target_idx, patients}] — transitions between
            nodes, indices into ``nodes``, sorted by patients desc
    Both lists are empty when the database is missing or a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return {"nodes": [], "links": []}

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level >= 4"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT ids, level, value AS patients, drug_sequence, directory
            FROM pathway_nodes
            WHERE {where_sql}
            ORDER BY level, ids
        """
        rows = conn.execute(query, params).fetchall()

        # Build node list and link aggregation
        drug_line_set = set()
        link_agg: dict = {}
        for r in rows:
            drugs = [d for d in (r["drug_sequence"] or "").split("|") if d]
            patients = r["patients"] or 0
            if len(drugs) < 2 or patients == 0:
                continue
            # Only use adjacent transitions (drug[i] → drug[i+1]).
            # NOTE(review): every adjacent pair of every row is counted. If
            # a deeper node's value is already included in its parent's
            # value, early links are counted more than once — confirm.
            for i in range(len(drugs) - 1):
                src = f"{drugs[i]} ({_ordinal(i + 1)})"
                tgt = f"{drugs[i + 1]} ({_ordinal(i + 2)})"
                drug_line_set.add(src)
                drug_line_set.add(tgt)
                key = (src, tgt)
                link_agg[key] = link_agg.get(key, 0) + patients

        # Build indexed node list
        node_list = sorted(drug_line_set)
        node_idx = {name: i for i, name in enumerate(node_list)}
        nodes = [{"name": name} for name in node_list]
        links = [
            {"source_idx": node_idx[src], "target_idx": node_idx[tgt], "patients": pts}
            for (src, tgt), pts in sorted(link_agg.items(), key=lambda x: -x[1])
        ]
        return {"nodes": nodes, "links": links}
    except sqlite3.Error:
        return {"nodes": [], "links": []}
    finally:
        conn.close()
def _ordinal(n: int) -> str:
"""Return '1st', '2nd', '3rd', '4th', etc."""
if 11 <= n % 100 <= 13:
return f"{n}th"
suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
return f"{n}{suffix}"
def get_dosing_intervals(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    drug: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 3 drug nodes with parsed dosing interval data.

    Each node's average_spacing text is handed to parse_average_spacing and
    one output row is emitted per parsed entry.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        drug: optional exact drug label to restrict to
        trust: optional exact trust name to restrict to

    Returns list of dicts:
        [{drug, trust_name, directory, weekly_interval, dose_count, total_weeks, patients}]
    Returns [] when the database is missing or a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    from data_processing.parsing import parse_average_spacing

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
                 "average_spacing IS NOT NULL", "average_spacing != ''"]
        params: list = [date_filter_id, chart_type]
        if drug:
            where.append("labels = ?")
            params.append(drug)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT labels AS drug, trust_name, directory, value AS patients,
                   average_spacing
            FROM pathway_nodes
            WHERE {where_sql}
            ORDER BY labels, trust_name
        """
        rows = conn.execute(query, params).fetchall()

        result = []
        for r in rows:
            # presumably an iterable of dicts with drug_name, weekly_interval,
            # dose_count and total_weeks — those are the keys read below.
            for entry in parse_average_spacing(r["average_spacing"]):
                result.append({
                    "drug": entry["drug_name"],
                    "trust_name": r["trust_name"] or "",
                    "directory": r["directory"] or "",
                    "weekly_interval": entry["weekly_interval"],
                    "dose_count": entry["dose_count"],
                    "total_weeks": entry["total_weeks"],
                    "patients": r["patients"] or 0,
                })
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_drug_directory_matrix(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    trust: Optional[str] = None,
) -> dict:
    """Level 3 nodes pivoted as directory × drug matrix.

    Rows are summed across trusts per (directory, drug). Each cell's
    ``cost_pp_pa`` is computed here as rounded cell cost divided by cell
    patients.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        trust: optional exact trust name to restrict to

    Returns dict with:
        directories: directory names (rows), ordered by total patients desc
        drugs: drug names (columns), ordered by total patients desc
        matrix: {directory: {drug: {patients, cost, cost_pp_pa}}}
    All three are empty when the database is missing or a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return {"directories": [], "drugs": [], "matrix": {}}

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
        params: list = [date_filter_id, chart_type]
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT labels AS drug, directory, value AS patients, cost
            FROM pathway_nodes
            WHERE {where_sql}
            ORDER BY directory, labels
        """
        rows = conn.execute(query, params).fetchall()

        # Aggregate across trusts
        matrix: dict = {}
        dir_totals: dict = {}
        drug_totals: dict = {}
        for r in rows:
            d = r["directory"] or ""
            drug = r["drug"] or ""
            patients = r["patients"] or 0
            cost = float(r["cost"]) if r["cost"] else 0.0
            cell = matrix.setdefault(d, {}).setdefault(
                drug, {"patients": 0, "cost": 0.0}
            )
            cell["patients"] += patients
            cell["cost"] += cost
            dir_totals[d] = dir_totals.get(d, 0) + patients
            drug_totals[drug] = drug_totals.get(drug, 0) + patients

        # Round the summed cost, then derive cost_pp_pa from the rounded value
        for by_drug in matrix.values():
            for cell in by_drug.values():
                cell["cost"] = round(cell["cost"], 2)
                cell["cost_pp_pa"] = (
                    round(cell["cost"] / cell["patients"], 2)
                    if cell["patients"] else 0
                )

        # Sort directories and drugs by total patients desc
        directories = sorted(dir_totals, key=lambda x: -dir_totals[x])
        drugs = sorted(drug_totals, key=lambda x: -drug_totals[x])
        return {"directories": directories, "drugs": drugs, "matrix": matrix}
    except sqlite3.Error:
        return {"directories": [], "drugs": [], "matrix": {}}
    finally:
        conn.close()
def get_treatment_durations(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 3 drug nodes with average treatment duration.

    Rows are combined across trusts per (directory, drug) as a
    patient-weighted mean of avg_days. Rows with zero patients or an
    avg_days that parses to 0 are ignored.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: optional exact directory to restrict to
        trust: optional exact trust name to restrict to

    Returns list of dicts: [{drug, directory, avg_days, patients}]
    Sorted by avg_days desc. Excludes nodes with no avg_days data.
    Returns [] when the database is missing or a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
                 "avg_days IS NOT NULL"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT labels AS drug, directory, trust_name,
                   value AS patients, avg_days
            FROM pathway_nodes
            WHERE {where_sql}
            ORDER BY avg_days DESC
        """
        rows = conn.execute(query, params).fetchall()

        # Aggregate across trusts: weighted average of avg_days by patients
        agg: dict = {}
        for r in rows:
            patients = r["patients"] or 0
            days = _safe_float(r["avg_days"])
            if patients == 0 or days == 0:
                continue
            dir_name = r["directory"] or ""
            bucket = agg.setdefault(
                (dir_name, r["drug"]),
                {"drug": r["drug"], "directory": dir_name,
                 "total_weighted_days": 0.0, "total_patients": 0},
            )
            bucket["total_weighted_days"] += days * patients
            bucket["total_patients"] += patients

        result = [
            {
                "drug": v["drug"],
                "directory": v["directory"],
                "avg_days": round(v["total_weighted_days"] / v["total_patients"], 1),
                "patients": v["total_patients"],
            }
            for v in agg.values()
            if v["total_patients"] > 0
        ]
        result.sort(key=lambda x: -x["avg_days"])
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# Trust Comparison query functions (Phase 10)
#
# These functions provide per-trust breakdowns within a single directorate.
# Unlike Phase 9 functions which aggregate across trusts, these preserve
# per-trust detail for comparing drugs across trusts within a directorate.
# ---------------------------------------------------------------------------
def get_trust_market_share(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: str,
) -> list[dict]:
    """Drug market share broken down by trust within a single directorate.

    ``proportion`` is the row's share of its trust's patients in this
    directorate; ``cost_pp_pa`` is taken from the stored column (0.0 when
    it is missing or not numeric).

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: exact directory to report on

    Returns list of dicts: [{trust_name, drug, patients, proportion, cost, cost_pp_pa}]
    Sorted by trust total patients desc, then drug patients desc within each trust.
    Returns [] when the database is missing or a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        query = """
            SELECT labels AS drug, trust_name, value AS patients, cost, cost_pp_pa
            FROM pathway_nodes
            WHERE date_filter_id = ? AND chart_type = ? AND level = 3
              AND directory = ?
            ORDER BY trust_name, value DESC
        """
        rows = conn.execute(query, (date_filter_id, chart_type, directory)).fetchall()

        # Patient totals per trust, the denominator for proportions
        trust_totals: dict = {}
        for r in rows:
            trust = r["trust_name"] or ""
            trust_totals[trust] = trust_totals.get(trust, 0) + (r["patients"] or 0)

        result = []
        for r in rows:
            trust = r["trust_name"] or ""
            patients = r["patients"] or 0
            cost = float(r["cost"]) if r["cost"] else 0.0
            total = trust_totals.get(trust, 1)
            result.append({
                "trust_name": trust,
                "drug": r["drug"],
                "patients": patients,
                "proportion": round(patients / total, 4) if total else 0,
                "cost": round(cost, 2),
                "cost_pp_pa": _safe_float(r["cost_pp_pa"]),
            })

        result.sort(
            key=lambda x: (-trust_totals.get(x["trust_name"], 0), -x["patients"])
        )
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_trust_cost_waterfall(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: str,
) -> list[dict]:
    """Cost per patient by trust within a single directorate.

    Level 3 rows of the directorate are summed per trust; trusts whose
    summed patient count is not positive are left out.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: exact directory to report on

    Returns list of dicts: [{trust_name, patients, total_cost, cost_pp}]
    Sorted by cost_pp desc. Returns [] when the database is missing or a
    query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        query = """
            SELECT trust_name, SUM(value) AS patients, SUM(cost) AS total_cost
            FROM pathway_nodes
            WHERE date_filter_id = ? AND chart_type = ? AND level = 3
              AND directory = ?
            GROUP BY trust_name
            HAVING patients > 0
            ORDER BY total_cost DESC
        """
        rows = conn.execute(query, (date_filter_id, chart_type, directory)).fetchall()

        result = []
        for r in rows:
            patients = r["patients"] or 0
            total_cost = float(r["total_cost"]) if r["total_cost"] else 0.0
            result.append({
                "trust_name": r["trust_name"] or "",
                "patients": patients,
                "total_cost": round(total_cost, 2),
                "cost_pp": round(total_cost / patients, 2) if patients else 0,
            })

        # Stable sort: equal cost_pp keeps the query's total_cost-desc order.
        result.sort(key=lambda x: -x["cost_pp"])
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_trust_dosing(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: str,
) -> list[dict]:
    """Drug dosing intervals by trust within a single directorate.

    Each level 3 node's average_spacing text is handed to
    parse_average_spacing and one output row is emitted per parsed entry.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: exact directory to report on

    Returns list of dicts:
        [{drug, trust_name, weekly_interval, dose_count, total_weeks, patients}]
    Returns [] when the database is missing or a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    from data_processing.parsing import parse_average_spacing

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        query = """
            SELECT labels AS drug, trust_name, value AS patients, average_spacing
            FROM pathway_nodes
            WHERE date_filter_id = ? AND chart_type = ? AND level = 3
              AND directory = ?
              AND average_spacing IS NOT NULL AND average_spacing != ''
            ORDER BY labels, trust_name
        """
        rows = conn.execute(query, (date_filter_id, chart_type, directory)).fetchall()

        result = []
        for r in rows:
            # presumably an iterable of dicts with drug_name, weekly_interval,
            # dose_count and total_weeks — those are the keys read below.
            for entry in parse_average_spacing(r["average_spacing"]):
                result.append({
                    "drug": entry["drug_name"],
                    "trust_name": r["trust_name"] or "",
                    "weekly_interval": entry["weekly_interval"],
                    "dose_count": entry["dose_count"],
                    "total_weeks": entry["total_weeks"],
                    "patients": r["patients"] or 0,
                })
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_trust_heatmap(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: str,
) -> dict:
    """Trust x drug matrix for a single directorate.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: exact directory to report on

    Returns dict with:
        trusts: trust names (rows), ordered by total patients desc
        drugs: drug names (columns), ordered by total patients desc
        matrix: {trust_name: {drug: {patients, cost, cost_pp_pa}}}
    All three are empty when the database is missing or a query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return {"trusts": [], "drugs": [], "matrix": {}}

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        query = """
            SELECT labels AS drug, trust_name, value AS patients, cost, cost_pp_pa
            FROM pathway_nodes
            WHERE date_filter_id = ? AND chart_type = ? AND level = 3
              AND directory = ?
            ORDER BY trust_name, labels
        """
        rows = conn.execute(query, (date_filter_id, chart_type, directory)).fetchall()

        matrix: dict = {}
        trust_totals: dict = {}
        drug_totals: dict = {}
        for r in rows:
            trust = r["trust_name"] or ""
            drug = r["drug"] or ""
            patients = r["patients"] or 0
            cost = float(r["cost"]) if r["cost"] else 0.0
            # A repeated (trust, drug) row replaces the cell while still
            # adding to the totals — presumably rows are unique per
            # trust+drug within one directorate; TODO confirm.
            matrix.setdefault(trust, {})[drug] = {
                "patients": patients,
                "cost": round(cost, 2),
                "cost_pp_pa": _safe_float(r["cost_pp_pa"]),
            }
            trust_totals[trust] = trust_totals.get(trust, 0) + patients
            drug_totals[drug] = drug_totals.get(drug, 0) + patients

        trusts = sorted(trust_totals, key=lambda x: -trust_totals[x])
        drugs = sorted(drug_totals, key=lambda x: -drug_totals[x])
        return {"trusts": trusts, "drugs": drugs, "matrix": matrix}
    except sqlite3.Error:
        return {"trusts": [], "drugs": [], "matrix": {}}
    finally:
        conn.close()
def get_trust_durations(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: str,
) -> list[dict]:
    """Drug treatment durations by trust within a single directorate.

    Rows with zero patients or an avg_days that parses to 0 are skipped.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: exact directory to report on

    Returns list of dicts: [{drug, trust_name, avg_days, patients}]
    Sorted by drug name, then trust. No aggregation — each row is one
    trust+drug combination. Returns [] when the database is missing or a
    query fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        query = """
            SELECT labels AS drug, trust_name, value AS patients, avg_days
            FROM pathway_nodes
            WHERE date_filter_id = ? AND chart_type = ? AND level = 3
              AND directory = ?
              AND avg_days IS NOT NULL
            ORDER BY labels, trust_name
        """
        rows = conn.execute(query, (date_filter_id, chart_type, directory)).fetchall()

        result = []
        for r in rows:
            patients = r["patients"] or 0
            days = _safe_float(r["avg_days"])
            if patients == 0 or days == 0:
                continue
            result.append({
                "drug": r["drug"],
                "trust_name": r["trust_name"] or "",
                "avg_days": round(days, 1),
                "patients": patients,
            })
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
# --- Directorate/indication summary for Trust Comparison landing page ---
def get_retention_funnel(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Aggregate patient counts by treatment line depth for a retention funnel.

    Level 3 = 1st drug, Level 4 = 2-drug pathway, Level 5 = 3-drug pathway, etc.
    ``pct`` is each depth's patient count relative to depth 1 (0 when no
    depth-1 row has been seen).

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: optional exact directory to restrict to
        trust: optional exact trust name to restrict to

    Returns list of dicts sorted by depth ascending:
        [{depth: 1, label: "1st drug", patients: N, pct: 100.0}, ...]
    Returns [] when the database is missing, nothing matches, or a query
    fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    # Irregular ordinals; every other depth falls back to "<n>th drug".
    ordinal_labels = {
        1: "1st drug",
        2: "2nd drug",
        3: "3rd drug",
    }

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level >= 3"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT level, SUM(value) AS patients
            FROM pathway_nodes
            WHERE {where_sql}
            GROUP BY level
            ORDER BY level
        """
        rows = conn.execute(query, params).fetchall()
        if not rows:
            return []

        result = []
        base_patients = 0
        for r in rows:
            patients = r["patients"] or 0
            depth = r["level"] - 2  # level 3 → depth 1, level 4 → depth 2, etc.
            if depth == 1:
                # Depth 1 is the 100% baseline for every later depth.
                base_patients = patients
            pct = round(patients / base_patients * 100, 1) if base_patients else 0
            result.append({
                "depth": depth,
                "label": ordinal_labels.get(depth, f"{depth}th drug"),
                "patients": patients,
                "pct": pct,
            })
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_pathway_depth_distribution(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Count patients who STOPPED at each treatment line depth.

    Unlike the retention funnel (cumulative), this shows exclusive counts:
    patients at depth N minus patients at the next depth present in the
    data = stopped at depth N. ``pct`` is relative to the patient count of
    the shallowest depth returned by the query.

    Args:
        db_path: Path to pathways.db (a plain string path is accepted too)
        date_filter_id: date filter to select, e.g. "all_6mo"
        chart_type: "directory" or "indication"
        directory: optional exact directory to restrict to
        trust: optional exact trust name to restrict to

    Returns list of dicts sorted by depth ascending:
        [{depth: 1, label: "1 drug only", patients: N, pct: 45.2}, ...]
    Returns [] when the database is missing, nothing matches, or a query
    fails.
    """
    db_path = Path(db_path)
    # sqlite3.connect() creates an empty database file when the path does
    # not exist; bail out first so this read-only query never does that.
    if not db_path.exists():
        return []

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level >= 3"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        where_sql = " AND ".join(where)
        query = f"""
            SELECT level, SUM(value) AS patients
            FROM pathway_nodes
            WHERE {where_sql}
            GROUP BY level
            ORDER BY level
        """
        rows = conn.execute(query, params).fetchall()
        if not rows:
            return []

        # (depth, cumulative_patients) pairs; level 3 → depth 1
        levels = [(r["level"] - 2, r["patients"] or 0) for r in rows]
        total_patients = levels[0][1]

        result = []
        for i, (depth, patients) in enumerate(levels):
            # Subtract next level to get "stopped at this depth"; the
            # deepest level has no successor, so all its patients stopped.
            next_patients = levels[i + 1][1] if i + 1 < len(levels) else 0
            stopped = patients - next_patients
            plural = "s" if depth > 1 else ""
            result.append({
                "depth": depth,
                "label": f"{depth} drug{plural} only",
                "patients": stopped,
                "pct": (
                    round(stopped / total_patients * 100, 1)
                    if total_patients else 0
                ),
            })
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_duration_cost_scatter(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Collect level 3 drug nodes as duration-versus-cost scatter points.

    Rows with NULL avg_days or cost_pp_pa are excluded in SQL. The remaining
    rows are grouped by (directory, drug) and both metrics are averaged with
    the node's patient count as the weight, which collapses per-trust rows
    into one point per drug and directory.

    Args:
        db_path: Path to the SQLite database containing pathway_nodes.
        date_filter_id: Value matched against pathway_nodes.date_filter_id.
        chart_type: Value matched against pathway_nodes.chart_type.
        directory: Optional exact-match filter on the directory column.
        trust: Optional exact-match filter on the trust_name column.

    Returns:
        List of dicts with keys drug, directory, avg_days (1 decimal),
        cost_pp_pa (rounded to a whole number) and patients. Empty list when
        SQLite raises.
    """
    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        clauses = [
            "date_filter_id = ?",
            "chart_type = ?",
            "level = 3",
            "avg_days IS NOT NULL",
            "cost_pp_pa IS NOT NULL",
        ]
        args: list = [date_filter_id, chart_type]
        for column, wanted in (("directory", directory), ("trust_name", trust)):
            if wanted:
                clauses.append(f"{column} = ?")
                args.append(wanted)

        sql = f"""
            SELECT labels AS drug, directory,
                   value AS patients, avg_days, cost_pp_pa
            FROM pathway_nodes
            WHERE {' AND '.join(clauses)}
        """

        # Running weighted sums keyed by (directory, drug).
        buckets: dict = {}
        for row in conn.execute(sql, args).fetchall():
            weight = row["patients"] or 0
            # _safe_float lives elsewhere in this module; presumably it
            # coerces unusable values to 0.0 — confirm against its definition.
            days = _safe_float(row["avg_days"])
            cost = _safe_float(row["cost_pp_pa"])
            if weight == 0 or days == 0:
                # A point without patients or without a duration is dropped.
                continue
            group = row["directory"] or ""
            bucket = buckets.setdefault(
                (group, row["drug"]),
                {
                    "drug": row["drug"],
                    "directory": group,
                    "weighted_days": 0.0,
                    "weighted_cost": 0.0,
                    "total_patients": 0,
                },
            )
            bucket["weighted_days"] += days * weight
            bucket["weighted_cost"] += cost * weight
            bucket["total_patients"] += weight

        return [
            {
                "drug": bucket["drug"],
                "directory": bucket["directory"],
                "avg_days": round(
                    bucket["weighted_days"] / bucket["total_patients"], 1
                ),
                "cost_pp_pa": round(
                    bucket["weighted_cost"] / bucket["total_patients"], 0
                ),
                "patients": bucket["total_patients"],
            }
            for bucket in buckets.values()
            if bucket["total_patients"] > 0
        ]
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_drug_network(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> dict:
    """Assemble an undirected drug co-occurrence graph from pathway rows.

    Node sizes come from level 3 rows (value summed per label). Edges come
    from rows at level 4 and deeper: drug_sequence is split on "|" and every
    adjacent pair contributes the row's value to one edge, with the two names
    sorted so that A-B and B-A share a single entry.

    NOTE(review): every adjacent pair of every matching row is counted, so a
    pair present in both a shorter and a longer sequence is added once per
    row. If node values already include patients who continue to deeper
    levels, edge weights are over-counted — confirm against the data model.

    Args:
        db_path: Path to the SQLite database containing pathway_nodes.
        date_filter_id: Value matched against pathway_nodes.date_filter_id.
        chart_type: Value matched against pathway_nodes.chart_type.
        directory: Optional exact-match filter on the directory column.
        trust: Optional exact-match filter on the trust_name column.

    Returns:
        {"nodes": [{name, total_patients}], "edges": [{source, target,
        patients}]}, both ordered by patient count descending; nodes with no
        patients are omitted. Both lists are empty when SQLite raises.
    """
    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        # Filters common to both queries; only the level predicate differs.
        shared_clauses = ["date_filter_id = ?", "chart_type = ?"]
        shared_args: list = [date_filter_id, chart_type]
        for column, wanted in (("directory", directory), ("trust_name", trust)):
            if wanted:
                shared_clauses.append(f"{column} = ?")
                shared_args.append(wanted)
        shared_sql = " AND ".join(shared_clauses)

        sequence_rows = conn.execute(
            f"""
            SELECT value AS patients, drug_sequence
            FROM pathway_nodes
            WHERE {shared_sql} AND level >= 4
            """,
            shared_args,
        ).fetchall()
        total_rows = conn.execute(
            f"""
            SELECT labels AS drug, SUM(value) AS total_patients
            FROM pathway_nodes
            WHERE {shared_sql} AND level = 3
            GROUP BY labels
            """,
            shared_args,
        ).fetchall()

        totals_by_drug = {}
        for row in total_rows:
            totals_by_drug[row["drug"]] = row["total_patients"] or 0

        link_weights: dict = {}
        for row in sequence_rows:
            steps = [
                name for name in (row["drug_sequence"] or "").split("|") if name
            ]
            weight = row["patients"] or 0
            if weight == 0 or len(steps) < 2:
                continue
            # Pair each drug with its successor in the sequence.
            for left, right in zip(steps, steps[1:]):
                link = tuple(sorted((left, right)))
                link_weights[link] = link_weights.get(link, 0) + weight

        ranked_drugs = sorted(totals_by_drug.items(), key=lambda item: -item[1])
        ranked_links = sorted(link_weights.items(), key=lambda item: -item[1])

        nodes = []
        for drug_name, count in ranked_drugs:
            if count > 0:
                nodes.append({"name": drug_name, "total_patients": count})
        edges = []
        for (first, second), count in ranked_links:
            edges.append({"source": first, "target": second, "patients": count})
        return {"nodes": nodes, "edges": edges}
    except sqlite3.Error:
        return {"nodes": [], "edges": []}
    finally:
        conn.close()
def get_drug_timeline(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Get drug timeline data for a Gantt-style chart.

    Queries level 3 nodes and aggregates across trusts to get the earliest
    first_seen and latest last_seen per drug x directory. Rows whose
    first_seen or last_seen is NULL or empty are ignored.

    cost_pp_pa is a patient-weighted mean. Only rows that actually carry a
    cost contribute to the weight: a NULL cost_pp_pa drops out of the
    numerator, so its patients must drop out of the denominator as well or
    the mean is biased towards zero.

    Args:
        db_path: Path to the SQLite database containing pathway_nodes.
        date_filter_id: Value matched against pathway_nodes.date_filter_id.
        chart_type: Value matched against pathway_nodes.chart_type.
        directory: Optional exact-match filter on the directory column.
        trust: Optional exact-match filter on the trust_name column.

    Returns:
        List of dicts sorted by directory then first_seen:
        [{"drug": "ADALIMUMAB", "directory": "RHEUMATOLOGY",
          "first_seen": "2019-04-04T00:00:00", "last_seen": "2025-12-30T00:00:00",
          "patients": 2053, "cost_pp_pa": 2170.5}, ...]
        Groups without patients are omitted; cost_pp_pa is 0 when no row in
        the group has a cost. Empty list when SQLite raises.
    """
    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        conditions = ["date_filter_id = ?", "chart_type = ?", "level = 3"]
        params: list = [date_filter_id, chart_type]
        if directory:
            conditions.append("directory = ?")
            params.append(directory)
        if trust:
            conditions.append("trust_name = ?")
            params.append(trust)
        where = " AND ".join(conditions)

        # NULLIF turns a zero/absent costed-patient total into NULL so the
        # division yields NULL (mapped to 0 below) instead of failing.
        query = f"""
            SELECT
                labels AS drug,
                directory,
                MIN(first_seen) AS first_seen,
                MAX(last_seen) AS last_seen,
                SUM(value) AS patients,
                ROUND(
                    SUM(CAST(cost_pp_pa AS REAL) * value)
                    / NULLIF(
                        SUM(CASE WHEN cost_pp_pa IS NOT NULL THEN value END),
                        0
                    ),
                    2
                ) AS cost_pp_pa
            FROM pathway_nodes
            WHERE {where}
              AND first_seen IS NOT NULL AND first_seen != ''
              AND last_seen IS NOT NULL AND last_seen != ''
            GROUP BY labels, directory
            ORDER BY directory, first_seen
        """
        rows = conn.execute(query, params).fetchall()
        return [
            {
                "drug": r["drug"],
                "directory": r["directory"],
                "first_seen": r["first_seen"],
                "last_seen": r["last_seen"],
                "patients": r["patients"] or 0,
                "cost_pp_pa": r["cost_pp_pa"] or 0,
            }
            for r in rows
            if r["patients"] and r["patients"] > 0
        ]
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_dosing_distribution(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
    directory: Optional[str] = None,
    trust: Optional[str] = None,
) -> list[dict]:
    """Level 3 drug nodes with average administered dose counts.

    Parses the average_administered JSON array (position 0 = avg doses for
    the drug) and aggregates across trusts using averages weighted by patient
    count. Rows are skipped, rather than aborting the whole call, when the
    payload is not valid JSON, is not a non-empty array, or its first element
    is not a positive finite number.

    Args:
        db_path: Path to the SQLite database containing pathway_nodes.
        date_filter_id: Value matched against pathway_nodes.date_filter_id.
        chart_type: Value matched against pathway_nodes.chart_type.
        directory: Optional exact-match filter on the directory column.
        trust: Optional exact-match filter on the trust_name column.

    Returns:
        List of dicts sorted by avg_doses descending:
        [{drug, directory, avg_doses, patients}]
        Empty list when SQLite raises.
    """
    import json
    import math

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        where = ["date_filter_id = ?", "chart_type = ?", "level = 3",
                 "average_administered IS NOT NULL", "average_administered != ''"]
        params: list = [date_filter_id, chart_type]
        if directory:
            where.append("directory = ?")
            params.append(directory)
        if trust:
            where.append("trust_name = ?")
            params.append(trust)
        query = f"""
            SELECT labels AS drug, directory, trust_name,
                   value AS patients, average_administered
            FROM pathway_nodes
            WHERE {' AND '.join(where)}
            ORDER BY labels, directory
        """
        rows = conn.execute(query, params).fetchall()

        # Aggregate across trusts: weighted average of dose count.
        agg: dict = {}
        for r in rows:
            patients = r["patients"] or 0
            if patients == 0:
                continue
            try:
                # Bare NaN tokens are mapped to JSON null before parsing.
                arr = json.loads(r["average_administered"].replace("NaN", "null"))
            except (json.JSONDecodeError, AttributeError, TypeError):
                continue
            # Valid JSON can still be a scalar or an object; indexing those
            # would raise errors the sqlite3.Error handler does not cover.
            if not isinstance(arr, list) or not arr:
                continue
            # Position 0 is average doses for this drug.
            avg_doses = arr[0]
            if isinstance(avg_doses, bool) or not isinstance(avg_doses, (int, float)):
                continue
            if avg_doses <= 0:
                continue
            # json.loads accepts Infinity; it would poison the weighted mean.
            if isinstance(avg_doses, float) and not math.isfinite(avg_doses):
                continue
            key = (r["directory"] or "", r["drug"])
            if key not in agg:
                agg[key] = {
                    "drug": r["drug"],
                    "directory": r["directory"] or "",
                    "weighted_doses": 0.0,
                    "total_patients": 0,
                }
            agg[key]["weighted_doses"] += avg_doses * patients
            agg[key]["total_patients"] += patients

        result = []
        for v in agg.values():
            tp = v["total_patients"]
            if tp > 0:
                result.append({
                    "drug": v["drug"],
                    "directory": v["directory"],
                    "avg_doses": round(v["weighted_doses"] / tp, 1),
                    "patients": tp,
                })
        result.sort(key=lambda x: -x["avg_doses"])
        return result
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_directorate_summary(
    db_path: Path,
    date_filter_id: str,
    chart_type: str,
) -> list[dict]:
    """Summarise patients and drug variety per level 2 label for landing cards.

    Patient totals are the summed value of level 2 nodes grouped by label.
    The drug figure is the number of distinct level 3 labels whose directory
    column equals that level 2 label; labels without level 3 rows get 0.

    Args:
        db_path: Path to the SQLite database containing pathway_nodes.
        date_filter_id: Value matched against pathway_nodes.date_filter_id.
        chart_type: Value matched against pathway_nodes.chart_type.

    Returns:
        Dicts ordered by patient count descending, e.g.
        [{"name": "RHEUMATOLOGY", "patients": 847, "drugs": 12}, ...]
        Entries with no patients are omitted. Empty list when SQLite raises.
    """
    scope = (date_filter_id, chart_type)
    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        # Level 2: one row per label with its summed patient count.
        patient_rows = conn.execute(
            """
            SELECT labels AS name, SUM(value) AS patients
            FROM pathway_nodes
            WHERE date_filter_id = ? AND chart_type = ? AND level = 2
              AND labels IS NOT NULL AND labels != ''
            GROUP BY labels
            """,
            scope,
        ).fetchall()

        # Level 3: distinct drug labels per directory.
        variety_rows = conn.execute(
            """
            SELECT directory, COUNT(DISTINCT labels) AS drug_count
            FROM pathway_nodes
            WHERE date_filter_id = ? AND chart_type = ? AND level = 3
              AND directory IS NOT NULL AND directory != ''
            GROUP BY directory
            """,
            scope,
        ).fetchall()
        drugs_per_directory = {
            row["directory"]: row["drug_count"] for row in variety_rows
        }

        cards = []
        for row in patient_rows:
            headcount = row["patients"] or 0
            if headcount != 0:
                cards.append({
                    "name": row["name"],
                    "patients": headcount,
                    "drugs": drugs_per_directory.get(row["name"], 0),
                })
        return sorted(cards, key=lambda card: card["patients"], reverse=True)
    except sqlite3.Error:
        return []
    finally:
        conn.close()
def get_trend_data(
    db_path: Path,
    metric: str = "patients",
    directory: Optional[str] = None,
    drug: Optional[str] = None,
    group_by: str = "drug",
) -> list[dict]:
    """
    Query pathway_trends table for time-series data.

    Returns list of dicts with: period_end, name (drug or directory), value.

    For the "cost_pp_pa" metric the value is a patient-weighted mean. Only
    rows with a non-NULL cost_pp_pa contribute patients to the weight; a NULL
    cost is already absent from the numerator, so counting its patients in
    the denominator would pull the mean towards zero.

    Args:
        db_path: Path to pathways.db
        metric: "patients", "total_cost", or "cost_pp_pa"; any other value
            falls back to "patients"
        directory: Optional directory filter
        drug: Optional drug filter
        group_by: "directory" for one line per directory; any other value
            gives one line per drug

    Returns:
        List of dicts: [{period_end, name, value}, ...] ordered by period_end
        then name. Empty list if pathway_trends table doesn't exist, has no
        matching data, or SQLite raises.
    """
    # Fixed SQL fragments only: nothing caller-supplied is interpolated.
    aggregates = {
        "patients": "SUM(patients)",
        "total_cost": "SUM(total_cost)",
        "cost_pp_pa": (
            "SUM(cost_pp_pa * patients) / NULLIF("
            "SUM(CASE WHEN cost_pp_pa IS NOT NULL THEN patients END), 0)"
        ),
    }

    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='pathway_trends'"
        )
        if not cursor.fetchone():
            return []

        if metric not in aggregates:
            metric = "patients"
        agg = aggregates[metric]

        where_clauses = []
        params = []
        if directory:
            where_clauses.append("directory = ?")
            params.append(directory)
        if drug:
            where_clauses.append("drug = ?")
            params.append(drug)
        where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"

        # Grouping column is chosen from two literals, never from raw input.
        group_col = "directory" if group_by == "directory" else "drug"

        sql = f"""
            SELECT period_end, {group_col} AS name, {agg} AS value
            FROM pathway_trends
            WHERE {where_sql}
            GROUP BY period_end, {group_col}
            ORDER BY period_end, {group_col}
        """
        cursor.execute(sql, params)
        return [dict(row) for row in cursor.fetchall()]
    except sqlite3.Error:
        return []
    finally:
        conn.close()