Files
HighCostDrugsDemo/data_processing/snowflake_connector.py
T

798 lines
27 KiB
Python

"""
Snowflake connector module for NHS Patient Pathway Analysis.

Provides connection handling with SSO browser authentication for NHS environments.
Uses the externalbrowser authenticator, which opens a browser window for NHS identity
management authentication.

Usage:
    from data_processing.snowflake_connector import SnowflakeConnector, get_connector

    # Using context manager (recommended). The connector connects on entry
    # and closes its connection on exit.
    with get_connector() as connector:
        with connector.get_cursor() as cursor:
            cursor.execute("SELECT * FROM table LIMIT 10")
            results = cursor.fetchall()

    # Manual connection management
    connector = SnowflakeConnector()
    try:
        conn = connector.connect()
        cursor = conn.cursor()
        # ... use cursor ...
    finally:
        connector.close()
"""
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date, datetime
from pathlib import Path
from typing import Any, Generator, Optional, TYPE_CHECKING
import time
# Snowflake connector is an optional dependency
SNOWFLAKE_AVAILABLE = False
try:
import snowflake.connector
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
SNOWFLAKE_AVAILABLE = True
except ImportError:
snowflake = None # type: ignore[assignment]
# Type hints for when snowflake is not available
if TYPE_CHECKING:
from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
from config import get_snowflake_config, SnowflakeConfig
from core.logging_config import get_logger
logger = get_logger(__name__)
class SnowflakeConnectionError(Exception):
    """Signals that a connection to Snowflake could not be established."""
class SnowflakeNotConfiguredError(Exception):
    """Signals that no Snowflake account has been configured."""
class SnowflakeNotAvailableError(Exception):
    """Signals that the snowflake-connector-python package is not installed."""
@dataclass
class ConnectionInfo:
    """
    Snapshot of a connector's connection state.

    Every field has a default, so ``ConnectionInfo()`` describes the
    "not connected" state.
    """

    # Whether a connection is currently considered open
    connected: bool = False

    # Connection target
    account: str = ""
    warehouse: str = ""
    database: str = ""
    schema: str = ""

    # Identity the session is running as
    user: str = ""
    role: str = ""

    # Activity tracking
    connected_at: Optional[datetime] = None
    last_query_at: Optional[datetime] = None
    query_count: int = 0
class SnowflakeConnector:
    """
    Manages Snowflake connections with SSO browser authentication.

    This class provides connection management for NHS Snowflake access using
    the externalbrowser authenticator which triggers NHS SSO login via browser.

    One underlying connection is opened on demand and reused by every query
    helper until close() is called (or the ``with connector:`` block exits).

    Attributes:
        config: SnowflakeConfig with connection settings
        connection_info: ConnectionInfo tracking current state

    Example:
        connector = SnowflakeConnector()
        with connector.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT CURRENT_USER()")
            print(cursor.fetchone()[0])
    """

    def __init__(self, config: Optional[SnowflakeConfig] = None):
        """
        Initialize the connector with configuration.

        No connection is opened here; see connect().

        Args:
            config: Optional SnowflakeConfig. If not provided, the result of
                get_snowflake_config() is used.
        """
        self._config = config or get_snowflake_config()
        self._connection: Optional["SnowflakeConnection"] = None
        self._connection_info = ConnectionInfo()

    # ------------------------------------------------------------------
    # State accessors
    # ------------------------------------------------------------------

    @property
    def config(self) -> SnowflakeConfig:
        """Return the Snowflake configuration."""
        return self._config

    @property
    def connection_info(self) -> ConnectionInfo:
        """Return information about the current connection state."""
        return self._connection_info

    @property
    def is_connected(self) -> bool:
        """Return True if a connection exists and has not been closed."""
        return self._connection is not None and not self._connection.is_closed()

    # ------------------------------------------------------------------
    # Pre-flight checks
    # ------------------------------------------------------------------

    def _check_availability(self) -> None:
        """Raise SnowflakeNotAvailableError if the driver is not installed."""
        if not SNOWFLAKE_AVAILABLE:
            raise SnowflakeNotAvailableError(
                "snowflake-connector-python is not installed. "
                "Install it with: pip install snowflake-connector-python"
            )

    def _check_configured(self) -> None:
        """Raise SnowflakeNotConfiguredError if the config is not usable."""
        if not self._config.is_configured:
            raise SnowflakeNotConfiguredError(
                "Snowflake account is not configured. "
                "Edit config/snowflake.toml and set connection.account"
            )

    # ------------------------------------------------------------------
    # Connection lifecycle
    # ------------------------------------------------------------------

    # NOTE: the driver types are written as string annotations throughout this
    # class. They are only importable when snowflake-connector-python is
    # installed, and an unquoted annotation is evaluated when the ``def`` runs,
    # which would make the module fail to import without the optional driver.
    def connect(self) -> "SnowflakeConnection":
        """
        Establish a connection to Snowflake.

        Any connection already held by this connector is closed first. With
        the externalbrowser authenticator a browser window opens for NHS SSO
        authentication; that popup is expected and normal.

        Returns:
            Active SnowflakeConnection

        Raises:
            SnowflakeNotAvailableError: If snowflake-connector-python not installed
            SnowflakeNotConfiguredError: If account is not configured
            SnowflakeConnectionError: If connection fails
        """
        self._check_availability()
        self._check_configured()

        # Close existing connection if any
        if self._connection is not None:
            self.close()

        conn_cfg = self._config.connection
        timeout_cfg = self._config.timeouts

        logger.info(f"Connecting to Snowflake account: {conn_cfg.account}")
        logger.info(f"Using warehouse: {conn_cfg.warehouse}, database: {conn_cfg.database}")
        logger.info(f"Authenticator: {conn_cfg.authenticator}")
        if conn_cfg.authenticator == "externalbrowser":
            logger.info("Browser window will open for NHS SSO authentication")

        start_time = time.time()
        try:
            # Build connection parameters
            connect_params = {
                "account": conn_cfg.account,
                "warehouse": conn_cfg.warehouse,
                "database": conn_cfg.database,
                "schema": conn_cfg.schema,
                "authenticator": conn_cfg.authenticator,
                "login_timeout": timeout_cfg.login_timeout,
                "network_timeout": timeout_cfg.connection_timeout,
            }
            # Optional parameters (only add if set)
            if conn_cfg.user:
                connect_params["user"] = conn_cfg.user
            if conn_cfg.role:
                connect_params["role"] = conn_cfg.role

            self._connection = snowflake.connector.connect(**connect_params)
            elapsed = time.time() - start_time
            logger.info(f"Connected to Snowflake successfully in {elapsed:.1f}s")

            # Update connection info
            self._connection_info = ConnectionInfo(
                connected=True,
                account=conn_cfg.account,
                warehouse=conn_cfg.warehouse,
                database=conn_cfg.database,
                schema=conn_cfg.schema,
                user=self._get_current_user(),
                role=self._get_current_role(),
                connected_at=datetime.now(),
                query_count=0,
            )
            return self._connection
        except Exception as e:
            elapsed = time.time() - start_time
            logger.error(f"Failed to connect to Snowflake after {elapsed:.1f}s: {e}")
            self._connection_info = ConnectionInfo(connected=False)
            raise SnowflakeConnectionError(f"Failed to connect to Snowflake: {e}") from e

    def close(self) -> None:
        """
        Close the Snowflake connection if open.

        Errors raised while closing are logged as warnings, not propagated;
        the connector is always left in the "not connected" state.
        """
        if self._connection is None:
            return
        try:
            self._connection.close()
            logger.info("Snowflake connection closed")
        except Exception as e:
            logger.warning(f"Error closing Snowflake connection: {e}")
        finally:
            self._connection = None
            self._connection_info = ConnectionInfo(connected=False)

    def _ensure_connected(self) -> "SnowflakeConnection":
        """Return the live connection, opening one first if necessary."""
        if not self.is_connected:
            self.connect()
        if self._connection is None:
            # connect() either sets the connection or raises, so this is a
            # defensive guard (an explicit raise survives ``python -O``).
            raise SnowflakeConnectionError("Connection should be established")
        return self._connection

    def _fetch_session_value(self, sql: str) -> str:
        """
        Run a single-value query and return the value, or "" on any failure.

        Best-effort by design: failures are logged at debug level only, so
        they never mask a connection that was otherwise opened successfully.
        """
        if self._connection is None:
            return ""
        try:
            cursor = self._connection.cursor()
            try:
                cursor.execute(sql)
                row = cursor.fetchone()
                return row[0] if row else ""
            finally:
                # Release the cursor even when the query fails
                cursor.close()
        except Exception as exc:
            logger.debug(f"Could not evaluate {sql}: {exc}")
            return ""

    def _get_current_user(self) -> str:
        """Return the authenticated user name, or "" if unavailable."""
        return self._fetch_session_value("SELECT CURRENT_USER()")

    def _get_current_role(self) -> str:
        """Return the active role name, or "" if unavailable."""
        return self._fetch_session_value("SELECT CURRENT_ROLE()")

    # ------------------------------------------------------------------
    # Context managers
    # ------------------------------------------------------------------

    @contextmanager
    def get_connection(self) -> Generator["SnowflakeConnection", None, None]:
        """
        Context manager that yields the connector's connection.

        Connects first if not already connected. The connection is NOT closed
        when the block exits; it stays open for reuse until close() is called.

        Yields:
            Active SnowflakeConnection

        Example:
            connector = SnowflakeConnector()
            with connector.get_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT 1")
        """
        yield self._ensure_connected()

    @contextmanager
    def get_cursor(
        self,
        dict_cursor: bool = False
    ) -> Generator["SnowflakeCursor", None, None]:
        """
        Context manager that provides a cursor and closes it on exit.

        Connects first if not already connected. When the block finishes
        without an exception, connection_info.last_query_at and
        connection_info.query_count are updated.

        Args:
            dict_cursor: If True, returns cursor that yields dict-like rows

        Yields:
            SnowflakeCursor for executing queries

        Example:
            connector = SnowflakeConnector()
            with connector.get_cursor() as cursor:
                cursor.execute("SELECT * FROM table LIMIT 10")
                for row in cursor:
                    print(row)
        """
        connection = self._ensure_connected()
        if dict_cursor:
            cursor = connection.cursor(snowflake.connector.DictCursor)  # type: ignore[union-attr]
        else:
            cursor = connection.cursor()
        try:
            yield cursor
            # Only reached when the caller's block did not raise
            self._connection_info.last_query_at = datetime.now()
            self._connection_info.query_count += 1
        finally:
            cursor.close()

    # ------------------------------------------------------------------
    # Shared query plumbing
    # ------------------------------------------------------------------

    def _resolve_timeout(self, timeout: Optional[int]) -> int:
        """
        Return the explicit timeout if one was given, else the configured one.

        An explicit 0 is honoured (no statement timeout is applied) instead of
        being replaced by the configured default.
        """
        if timeout is None:
            return self._config.timeouts.query_timeout
        return timeout

    @staticmethod
    def _apply_statement_timeout(cursor: Any, timeout_seconds: int) -> None:
        """Set the session statement timeout when a positive value is given."""
        if timeout_seconds > 0:
            # int() guarantees only a whole number is interpolated into the SQL
            cursor.execute(
                f"ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS = {int(timeout_seconds)}"
            )

    def _execute_fetchall(
        self,
        query: str,
        params: Optional[tuple],
        timeout: Optional[int],
        *,
        dict_cursor: bool,
    ) -> list:
        """Run a query and return every row (tuples or dicts per dict_cursor)."""
        effective_timeout = self._resolve_timeout(timeout)
        with self.get_cursor(dict_cursor=dict_cursor) as cursor:
            logger.info(f"Executing query (timeout={effective_timeout}s)")
            logger.debug(f"Query: {query[:200]}...")
            self._apply_statement_timeout(cursor, effective_timeout)
            start_time = time.time()
            cursor.execute(query, params)
            results = cursor.fetchall()
            elapsed = time.time() - start_time
            logger.info(f"Query returned {len(results)} rows in {elapsed:.2f}s")
            return results

    def _execute_in_chunks(
        self,
        query: str,
        params: Optional[tuple],
        chunk_size: Optional[int],
        timeout: Optional[int],
        max_rows: Optional[int],
        *,
        dict_cursor: bool,
        label: str,
    ) -> Generator[list, None, None]:
        """
        Run a query and yield its rows in lists of at most chunk_size rows.

        ``label`` only affects log messages ("chunked" / "chunked dict").
        """
        effective_timeout = self._resolve_timeout(timeout)
        effective_chunk_size = chunk_size or self._config.query.chunk_size
        effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows

        with self.get_cursor(dict_cursor=dict_cursor) as cursor:
            logger.info(
                f"Executing {label} query "
                f"(chunk_size={effective_chunk_size}, timeout={effective_timeout}s)"
            )
            logger.debug(f"Query: {query[:200]}...")
            self._apply_statement_timeout(cursor, effective_timeout)
            start_time = time.time()
            cursor.execute(query, params)

            total_rows = 0
            chunk_num = 0
            while True:
                fetch_size = effective_chunk_size
                if effective_max_rows > 0:
                    # Never fetch past the row cap; stop once it is reached
                    remaining = effective_max_rows - total_rows
                    if remaining <= 0:
                        break
                    fetch_size = min(fetch_size, remaining)

                chunk = cursor.fetchmany(fetch_size)
                if not chunk:
                    break
                chunk_num += 1
                total_rows += len(chunk)
                logger.debug(f"Chunk {chunk_num}: {len(chunk)} rows (total: {total_rows})")
                yield chunk

            elapsed = time.time() - start_time
            logger.info(
                f"{label.capitalize()} query returned {total_rows} rows "
                f"in {chunk_num} chunks ({elapsed:.2f}s)"
            )

    # ------------------------------------------------------------------
    # Public query API
    # ------------------------------------------------------------------

    def execute(
        self,
        query: str,
        params: Optional[tuple] = None,
        timeout: Optional[int] = None
    ) -> list[tuple]:
        """
        Execute a query and return all results.

        Connects first if not already connected.

        Args:
            query: SQL query to execute
            params: Optional query parameters for parameterized queries
            timeout: Optional query timeout in seconds (overrides config;
                0 applies no statement timeout)

        Returns:
            List of result rows as tuples

        Raises:
            SnowflakeConnectionError: If a connection cannot be established
            Various snowflake errors for query issues
        """
        return self._execute_fetchall(query, params, timeout, dict_cursor=False)

    def execute_dict(
        self,
        query: str,
        params: Optional[tuple] = None,
        timeout: Optional[int] = None
    ) -> list[dict]:
        """
        Execute a query and return results as list of dictionaries.

        Args:
            query: SQL query to execute
            params: Optional query parameters
            timeout: Optional query timeout in seconds (overrides config;
                0 applies no statement timeout)

        Returns:
            List of result rows as dictionaries
        """
        return self._execute_fetchall(query, params, timeout, dict_cursor=True)

    def execute_chunked(
        self,
        query: str,
        params: Optional[tuple] = None,
        chunk_size: Optional[int] = None,
        timeout: Optional[int] = None,
        max_rows: Optional[int] = None,
    ) -> Generator[list[tuple], None, None]:
        """
        Execute a query and yield results in chunks for memory efficiency.

        This method is useful for large result sets that would exceed memory
        if loaded all at once. Results are yielded as chunks of rows. Being a
        generator, nothing is executed until iteration starts.

        Args:
            query: SQL query to execute
            params: Optional query parameters for parameterized queries
            chunk_size: Number of rows per chunk (default from config)
            timeout: Optional query timeout in seconds (overrides config)
            max_rows: Maximum total rows to return (default from config, 0 for no limit)

        Yields:
            List of result rows as tuples for each chunk

        Example:
            for chunk in connector.execute_chunked("SELECT * FROM large_table"):
                process_chunk(chunk)
        """
        yield from self._execute_in_chunks(
            query, params, chunk_size, timeout, max_rows,
            dict_cursor=False, label="chunked",
        )

    def execute_chunked_dict(
        self,
        query: str,
        params: Optional[tuple] = None,
        chunk_size: Optional[int] = None,
        timeout: Optional[int] = None,
        max_rows: Optional[int] = None,
    ) -> Generator[list[dict], None, None]:
        """
        Execute a query and yield dict results in chunks for memory efficiency.

        Same as execute_chunked but returns rows as dictionaries.

        Args:
            query: SQL query to execute
            params: Optional query parameters
            chunk_size: Number of rows per chunk (default from config)
            timeout: Optional query timeout in seconds
            max_rows: Maximum total rows to return (default from config, 0 for no limit)

        Yields:
            List of result rows as dictionaries for each chunk
        """
        yield from self._execute_in_chunks(
            query, params, chunk_size, timeout, max_rows,
            dict_cursor=True, label="chunked dict",
        )

    def execute_with_row_limit(
        self,
        query: str,
        params: Optional[tuple] = None,
        max_rows: Optional[int] = None,
        timeout: Optional[int] = None
    ) -> tuple[list[dict], bool]:
        """
        Execute a query with a row limit and indicate if more rows were available.

        This is useful for pagination or previewing large result sets.

        Args:
            query: SQL query to execute
            params: Optional query parameters
            max_rows: Maximum rows to return (default from config). A value
                of 0 or less means no limit, matching execute_chunked.
            timeout: Optional query timeout in seconds

        Returns:
            Tuple of (results list, has_more bool)
            - results: List of result rows as dictionaries (up to max_rows)
            - has_more: True if there were more rows than max_rows
        """
        effective_timeout = self._resolve_timeout(timeout)
        effective_max_rows = max_rows if max_rows is not None else self._config.query.max_rows

        with self.get_cursor(dict_cursor=True) as cursor:
            logger.info(
                f"Executing query with limit "
                f"(max_rows={effective_max_rows}, timeout={effective_timeout}s)"
            )
            logger.debug(f"Query: {query[:200]}...")
            self._apply_statement_timeout(cursor, effective_timeout)
            start_time = time.time()
            cursor.execute(query, params)

            if effective_max_rows > 0:
                # Fetch one more than max to detect if there are more rows
                results = cursor.fetchmany(effective_max_rows + 1)
                has_more = len(results) > effective_max_rows
                if has_more:
                    results = results[:effective_max_rows]
            else:
                # No cap requested: return everything, nothing is left over
                results = cursor.fetchall()
                has_more = False

            elapsed = time.time() - start_time
            logger.info(f"Query returned {len(results)} rows (has_more={has_more}) in {elapsed:.2f}s")
            return results, has_more  # type: ignore[return-value]

    def fetch_activity_data(
        self,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
        provider_codes: Optional[list[str]] = None,
        max_rows: Optional[int] = None,
        timeout: Optional[int] = None,
    ) -> list[dict]:
        """
        Fetch high-cost drug activity data from Snowflake.

        Queries DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs" and returns
        data in a format compatible with the existing analysis pipeline. All
        filter values are passed as bound parameters, never interpolated.

        Args:
            start_date: Optional start date for filtering (inclusive)
            end_date: Optional end date for filtering (inclusive)
            provider_codes: Optional list of provider codes to filter by
            max_rows: Maximum rows to return (default from config)
            timeout: Query timeout in seconds (default from config)

        Returns:
            List of dictionaries, ordered by intervention date, provider code
            and pseudonymised NHS number, with keys:
            - PseudoNHSNoLinked: Pseudonymised NHS number (for UPID creation)
            - Provider Code: NHS provider code
            - PersonKey: Local patient identifier
            - Drug Name: Raw drug name
            - Intervention Date: Date of intervention
            - Price Actual: Cost of intervention
            - OrganisationName: Provider organisation name
            - Treatment Function Code / Treatment Function Desc
            - Additional Detail 1-5 / Additional Description 1-5

        Raises:
            SnowflakeConnectionError: If a connection cannot be established
            Various snowflake errors for query issues
        """
        # Build the query
        table_name = 'DATA_HUB.CDM."Acute__Conmon__PatientLevelDrugs"'
        query = f'''
            SELECT
                "PseudoNHSNoLinked",
                "ProviderCode" AS "Provider Code",
                "LocalPatientID" AS "PersonKey",
                "DrugName" AS "Drug Name",
                "InterventionDate" AS "Intervention Date",
                "PriceActual" AS "Price Actual",
                "ProviderName" AS "OrganisationName",
                "TreatmentFunctionCode" AS "Treatment Function Code",
                "TreatmentFunctionDesc" AS "Treatment Function Desc",
                "AdditionalDetail1" AS "Additional Detail 1",
                "AdditionalDescription1" AS "Additional Description 1",
                "AdditionalDetail2" AS "Additional Detail 2",
                "AdditionalDescription2" AS "Additional Description 2",
                "AdditionalDetail3" AS "Additional Detail 3",
                "AdditionalDescription3" AS "Additional Description 3",
                "AdditionalDetail4" AS "Additional Detail 4",
                "AdditionalDescription4" AS "Additional Description 4",
                "AdditionalDetail5" AS "Additional Detail 5",
                "AdditionalDescription5" AS "Additional Description 5"
            FROM {table_name}
            WHERE 1=1
        '''
        params: list = []

        # Add date filters
        if start_date:
            query += ' AND "InterventionDate" >= %s'
            params.append(start_date.isoformat())
        if end_date:
            query += ' AND "InterventionDate" <= %s'
            params.append(end_date.isoformat())

        # Add provider filter (one placeholder per code)
        if provider_codes:
            placeholders = ", ".join(["%s"] * len(provider_codes))
            query += f' AND "ProviderCode" IN ({placeholders})'
            params.extend(provider_codes)

        # Add ordering for consistent results
        query += ' ORDER BY "InterventionDate", "ProviderCode", "PseudoNHSNoLinked"'

        logger.info("Fetching activity data from Snowflake")
        if start_date:
            logger.info(f" Date range: {start_date} to {end_date or 'now'}")
        if provider_codes:
            logger.info(f" Providers: {provider_codes}")

        # Execute with chunked results for large datasets; defaults for
        # timeout and max_rows are resolved by execute_chunked_dict.
        all_results: list[dict] = []
        for chunk in self.execute_chunked_dict(
            query,
            params=tuple(params) if params else None,
            timeout=timeout,
            max_rows=max_rows,
        ):
            all_results.extend(chunk)
            logger.debug(f"Fetched {len(all_results)} rows so far...")

        logger.info(f"Fetched {len(all_results)} activity records from Snowflake")
        return all_results

    def test_connection(self) -> tuple[bool, str]:
        """
        Test the Snowflake connection.

        Never raises: every failure is reported through the returned tuple.
        On success the connection is left open.

        Returns:
            Tuple of (success: bool, message: str)
        """
        try:
            self._check_availability()
            self._check_configured()
        except (SnowflakeNotAvailableError, SnowflakeNotConfiguredError) as e:
            return False, str(e)

        try:
            self.connect()
            user = self._get_current_user()
            role = self._get_current_role()
            return True, f"Connected as {user} with role {role}"
        except Exception as e:
            return False, f"Connection failed: {e}"

    def __enter__(self) -> "SnowflakeConnector":
        """Context manager entry: connect and return this connector."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit: close the connection."""
        self.close()
# Shared connector handed out to callers that do not supply their own config;
# created lazily by get_connector() and cleared by reset_connector().
_default_connector: Optional[SnowflakeConnector] = None


def get_connector(config: Optional[SnowflakeConfig] = None) -> SnowflakeConnector:
    """
    Return a Snowflake connector.

    Args:
        config: Optional configuration. When given, a brand-new connector is
            built from it on every call (it is not cached). When None, the
            module-level default connector is returned, being created on
            first use.

    Returns:
        SnowflakeConnector instance
    """
    global _default_connector

    if config is None:
        # Default path: build the shared connector once, then reuse it
        if _default_connector is None:
            _default_connector = SnowflakeConnector()
        return _default_connector

    # Custom config requested, create new connector
    return SnowflakeConnector(config)
def reset_connector() -> None:
    """Close the default connector, if one exists, and forget it."""
    global _default_connector

    if _default_connector is None:
        return

    _default_connector.close()
    _default_connector = None
def is_snowflake_available() -> bool:
    """Report whether the snowflake-connector-python import succeeded."""
    return SNOWFLAKE_AVAILABLE
def is_snowflake_configured() -> bool:
    """
    Report whether the loaded Snowflake config is marked as configured.

    Any error raised while loading or inspecting the configuration is
    treated as "not configured" and reported as False.
    """
    try:
        return get_snowflake_config().is_configured
    except Exception:
        return False
# Names exported by ``from <this module> import *``; this is the supported
# public surface of the module.
__all__ = [
    "SnowflakeConnector",
    "SnowflakeConnectionError",
    "SnowflakeNotConfiguredError",
    "SnowflakeNotAvailableError",
    "ConnectionInfo",
    "get_connector",
    "reset_connector",
    "is_snowflake_available",
    "is_snowflake_configured",
    "SNOWFLAKE_AVAILABLE",
]