Removed dead code/unneeded env setting

This commit is contained in:
2026-03-16 11:14:04 +00:00
parent 4ae150f420
commit 5dc546014a
2 changed files with 81 additions and 233 deletions
+5 -5
View File
@@ -7,7 +7,7 @@ Gives Claude (or any MCP client) the ability to explore your Snowflake account
## Features ## Features
- SSO authentication (opens browser on first connection, then reuses the session) - SSO authentication (opens browser on first connection, then reuses the session)
- Connection pooling with 30-minute timeout - Single persistent connection with 30-minute timeout
- Query caching (5 min TTL) - Query caching (5 min TTL)
- Read-only safety: only SELECT / SHOW / DESCRIBE queries are allowed - Read-only safety: only SELECT / SHOW / DESCRIBE queries are allowed
- Async query support for long-running queries - Async query support for long-running queries
@@ -37,14 +37,14 @@ Set environment variables (recommended) or edit the defaults at the top of `snow
export SNOWFLAKE_ACCOUNT="your-account-id" # e.g. "xy12345.eu-west-1" export SNOWFLAKE_ACCOUNT="your-account-id" # e.g. "xy12345.eu-west-1"
export SNOWFLAKE_USER="your.email@company.com" export SNOWFLAKE_USER="your.email@company.com"
# Optional — defaults shown # Optional
export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser (default)
export SNOWFLAKE_WAREHOUSE="" # uses account default if empty export SNOWFLAKE_WAREHOUSE="" # uses account default if empty
export SNOWFLAKE_ROLE="" # uses account default if empty export SNOWFLAKE_ROLE="" # uses account default if empty
export SNOWFLAKE_DATABASE="" # uses account default if empty
export SNOWFLAKE_SCHEMA="" # uses account default if empty
``` ```
Once connected, use `list_databases` and `list_schemas` to discover what you have access to.
On Windows (PowerShell): On Windows (PowerShell):
```powershell ```powershell
$env:SNOWFLAKE_ACCOUNT = "your-account-id" $env:SNOWFLAKE_ACCOUNT = "your-account-id"
+76 -228
View File
@@ -37,101 +37,73 @@ SNOWFLAKE_CONFIG = {
"role": os.environ.get("SNOWFLAKE_ROLE"), # None = account default "role": os.environ.get("SNOWFLAKE_ROLE"), # None = account default
} }
# Available databases/schemas to expose
DATABASES = {
"default": {
"database": os.environ.get("SNOWFLAKE_DATABASE"), # None = account default
"schema": os.environ.get("SNOWFLAKE_SCHEMA"), # None = account default
"description": "Default database and schema"
},
}
class SnowflakeConnection:
class SnowflakeConnectionPool: """Manages a single Snowflake connection with SSO authentication"""
"""Manages Snowflake connections with SSO authentication"""
def __init__(self): def __init__(self):
self.connections: Dict[str, snowflake.connector.SnowflakeConnection] = {} self._conn: Optional[snowflake.connector.SnowflakeConnection] = None
self.last_used: Dict[str, datetime] = {} self._last_used: Optional[datetime] = None
self.connection_timeout = timedelta(minutes=30) self._timeout = timedelta(minutes=30)
def get_connection(self, database: str = "default") -> snowflake.connector.SnowflakeConnection: def get(self) -> snowflake.connector.SnowflakeConnection:
"""Get or create a Snowflake connection with SSO auth""" """Get or create the Snowflake connection"""
cache_key = database
# Check if we have a valid cached connection # Check if we have a valid cached connection
if cache_key in self.connections: if self._conn and self._last_used:
conn = self.connections[cache_key] if datetime.now() - self._last_used < self._timeout:
last = self.last_used.get(cache_key, datetime.now())
# Check if connection is still valid and not timed out
if datetime.now() - last < self.connection_timeout:
try: try:
# Test the connection cursor = self._conn.cursor()
cursor = conn.cursor()
cursor.execute("SELECT 1") cursor.execute("SELECT 1")
cursor.close() cursor.close()
self.last_used[cache_key] = datetime.now() self._last_used = datetime.now()
logger.info(f"Reusing cached Snowflake connection for {database}") logger.info("Reusing cached Snowflake connection")
return conn return self._conn
except Exception: except Exception:
logger.info(f"Cached connection for {database} is stale, reconnecting...") logger.info("Cached connection is stale, reconnecting...")
try: try:
conn.close() self._conn.close()
except Exception: except Exception:
pass pass
# Create new connection # Create new connection
logger.info(f"Creating new Snowflake connection for {database}") logger.info("Creating new Snowflake connection")
logger.info("Note: First connection will open browser for SSO login") logger.info("Note: First connection will open browser for SSO login")
db_config = DATABASES.get(database, DATABASES["default"])
connect_params = { connect_params = {
"account": SNOWFLAKE_CONFIG["account"], "account": SNOWFLAKE_CONFIG["account"],
"user": SNOWFLAKE_CONFIG["user"], "user": SNOWFLAKE_CONFIG["user"],
"authenticator": SNOWFLAKE_CONFIG["authenticator"], "authenticator": SNOWFLAKE_CONFIG["authenticator"],
} }
# Add optional parameters if specified
if db_config.get("database"):
connect_params["database"] = db_config["database"]
if db_config.get("schema"):
connect_params["schema"] = db_config["schema"]
if SNOWFLAKE_CONFIG.get("warehouse"): if SNOWFLAKE_CONFIG.get("warehouse"):
connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"] connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"]
if SNOWFLAKE_CONFIG.get("role"): if SNOWFLAKE_CONFIG.get("role"):
connect_params["role"] = SNOWFLAKE_CONFIG["role"] connect_params["role"] = SNOWFLAKE_CONFIG["role"]
try: try:
conn = snowflake.connector.connect(**connect_params) self._conn = snowflake.connector.connect(**connect_params)
self._last_used = datetime.now()
# Store the connection logger.info("Successfully connected to Snowflake")
self.connections[cache_key] = conn return self._conn
self.last_used[cache_key] = datetime.now()
logger.info(f"Successfully connected to Snowflake ({database})")
return conn
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Snowflake: {e}") logger.error(f"Failed to connect to Snowflake: {e}")
raise raise
def close_all(self): def close(self):
"""Close all connections""" """Close the connection"""
for database, conn in self.connections.items(): if self._conn:
try: try:
conn.close() self._conn.close()
logger.info(f"Closed Snowflake connection to {database}") logger.info("Closed Snowflake connection")
except Exception: except Exception:
pass pass
self.connections.clear() self._conn = None
self.last_used.clear() self._last_used = None
# Global connection pool # Global connection
connection_pool = SnowflakeConnectionPool() connection = SnowflakeConnection()
# Query cache # Query cache
query_cache: Dict[str, tuple] = {} query_cache: Dict[str, tuple] = {}
@@ -141,9 +113,9 @@ CACHE_TTL = timedelta(minutes=5)
async_queries: Dict[str, Dict[str, Any]] = {} async_queries: Dict[str, Dict[str, Any]] = {}
def get_cache_key(query: str, database: str) -> str: def get_cache_key(query: str) -> str:
"""Generate cache key for query""" """Generate cache key for query"""
return hashlib.md5(f"{database}:{query}".encode()).hexdigest() return hashlib.md5(query.encode()).hexdigest()
def validate_query(query: str) -> tuple[bool, str]: def validate_query(query: str) -> tuple[bool, str]:
@@ -161,20 +133,18 @@ def validate_query(query: str) -> tuple[bool, str]:
'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE'] 'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE']
for keyword in dangerous: for keyword in dangerous:
# Check for keyword as whole word (not part of column name)
if f' {keyword} ' in f' {query_upper} ': if f' {keyword} ' in f' {query_upper} ':
return False, f"Query contains forbidden keyword: {keyword}" return False, f"Query contains forbidden keyword: {keyword}"
return True, "Query validated successfully" return True, "Query validated successfully"
async def test_connection(database: str = "default") -> Dict[str, Any]: async def test_connection() -> Dict[str, Any]:
"""Test Snowflake connection and return system information""" """Test Snowflake connection and return system information"""
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
# Get session information
cursor.execute(""" cursor.execute("""
SELECT SELECT
CURRENT_ACCOUNT() as account, CURRENT_ACCOUNT() as account,
@@ -211,14 +181,13 @@ async def test_connection(database: str = "default") -> Dict[str, Any]:
} }
async def list_tables(database: str = "default", schema: Optional[str] = None, async def list_tables(schema: Optional[str] = None,
include_row_counts: bool = False) -> Dict[str, Any]: include_row_counts: bool = False) -> Dict[str, Any]:
"""List all tables in the database/schema""" """List all tables in the database/schema"""
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
# Get current database and schema if not specified
if schema: if schema:
query = f"SHOW TABLES IN SCHEMA {schema}" query = f"SHOW TABLES IN SCHEMA {schema}"
else: else:
@@ -242,7 +211,6 @@ async def list_tables(database: str = "default", schema: Optional[str] = None,
cursor.close() cursor.close()
# Return with helpful hint if no tables found
if not tables: if not tables:
return { return {
"tables": [], "tables": [],
@@ -257,10 +225,10 @@ async def list_tables(database: str = "default", schema: Optional[str] = None,
raise raise
async def list_views(schema_name: str, database: str = "default") -> List[Dict[str, Any]]: async def list_views(schema_name: str) -> List[Dict[str, Any]]:
"""List all views in a schema with their descriptions""" """List all views in a schema with their descriptions"""
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}") cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}")
@@ -287,13 +255,12 @@ async def list_views(schema_name: str, database: str = "default") -> List[Dict[s
raise raise
async def describe_table(table_name: str, database: str = "default") -> Dict[str, Any]: async def describe_table(table_name: str) -> Dict[str, Any]:
"""Get detailed schema information for a table or view""" """Get detailed schema information for a table or view"""
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor() cursor = conn.cursor()
# Parse table name to handle quoting
parts = table_name.split('.') parts = table_name.split('.')
if len(parts) == 3: if len(parts) == 3:
db, schema, table = parts db, schema, table = parts
@@ -305,7 +272,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
schema = None schema = None
table = table_name table = table_name
# Build quoted version for case-sensitive names
if len(parts) == 3: if len(parts) == 3:
quoted_name = f'{parts[0]}.{parts[1]}."{parts[2]}"' quoted_name = f'{parts[0]}.{parts[1]}."{parts[2]}"'
elif len(parts) == 2: elif len(parts) == 2:
@@ -313,7 +279,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
else: else:
quoted_name = f'"{table_name}"' quoted_name = f'"{table_name}"'
# Try unquoted first, then quoted if that fails
rows = None rows = None
used_name = table_name used_name = table_name
for try_name in [table_name, quoted_name]: for try_name in [table_name, quoted_name]:
@@ -324,13 +289,12 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
break break
except Exception as inner_e: except Exception as inner_e:
if "does not exist" in str(inner_e).lower() and try_name == table_name: if "does not exist" in str(inner_e).lower() and try_name == table_name:
continue # Try quoted version continue
raise raise
if rows is None: if rows is None:
rows = [] rows = []
# Get column names from description using attribute access (safer for ResultMetadata)
desc_columns = [] desc_columns = []
if cursor.description: if cursor.description:
for desc in cursor.description: for desc in cursor.description:
@@ -385,11 +349,9 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
} }
async def read_data(query: str, database: str = "default", async def read_data(query: str, max_rows: int = 50000) -> Dict[str, Any]:
max_rows: int = 50000) -> Dict[str, Any]:
"""Execute a SELECT query and return results""" """Execute a SELECT query and return results"""
# Validate query
is_valid, message = validate_query(query) is_valid, message = validate_query(query)
if not is_valid: if not is_valid:
return { return {
@@ -397,8 +359,7 @@ async def read_data(query: str, database: str = "default",
"suggestion": "Use SELECT statements only. For schema information, use describe_table." "suggestion": "Use SELECT statements only. For schema information, use describe_table."
} }
# Check cache cache_key = get_cache_key(query)
cache_key = get_cache_key(query, database)
if cache_key in query_cache: if cache_key in query_cache:
cached_result, cached_time = query_cache[cache_key] cached_result, cached_time = query_cache[cache_key]
if datetime.now() - cached_time < CACHE_TTL: if datetime.now() - cached_time < CACHE_TTL:
@@ -410,23 +371,19 @@ async def read_data(query: str, database: str = "default",
} }
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
# Execute query
cursor.execute(query) cursor.execute(query)
# Get column names
columns = [desc[0] for desc in cursor.description] if cursor.description else [] columns = [desc[0] for desc in cursor.description] if cursor.description else []
# Fetch rows up to limit
rows = [] rows = []
for i, row in enumerate(cursor): for i, row in enumerate(cursor):
if i >= max_rows: if i >= max_rows:
logger.warning(f"Query returned more than {max_rows} rows, truncating") logger.warning(f"Query returned more than {max_rows} rows, truncating")
break break
# Convert row to dict with proper serialization
row_dict = {} row_dict = {}
for col in columns: for col in columns:
value = row[col] value = row[col]
@@ -446,10 +403,8 @@ async def read_data(query: str, database: str = "default",
"rows": rows, "rows": rows,
"row_count": len(rows), "row_count": len(rows),
"truncated": len(rows) == max_rows, "truncated": len(rows) == max_rows,
"database": database
} }
# Cache the result
query_cache[cache_key] = (result, datetime.now()) query_cache[cache_key] = (result, datetime.now())
return result return result
@@ -474,7 +429,7 @@ async def read_data(query: str, database: str = "default",
async def list_databases() -> List[Dict[str, Any]]: async def list_databases() -> List[Dict[str, Any]]:
"""List all accessible databases""" """List all accessible databases"""
try: try:
conn = connection_pool.get_connection("default") conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
cursor.execute("SHOW DATABASES") cursor.execute("SHOW DATABASES")
@@ -501,7 +456,7 @@ async def list_databases() -> List[Dict[str, Any]]:
async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]: async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
"""List all schemas in a database""" """List all schemas in a database"""
try: try:
conn = connection_pool.get_connection("default") conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
if database_name: if database_name:
@@ -533,14 +488,6 @@ async def get_system_health() -> Dict[str, Any]:
health_info = { health_info = {
"status": "healthy", "status": "healthy",
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"connection_pool": {
"active_connections": len(connection_pool.connections),
"databases": list(connection_pool.connections.keys()),
"last_used": {
db: last.isoformat()
for db, last in connection_pool.last_used.items()
}
},
"cache": { "cache": {
"query_cache_size": len(query_cache), "query_cache_size": len(query_cache),
"cache_ttl_minutes": CACHE_TTL.total_seconds() / 60 "cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
@@ -552,9 +499,8 @@ async def get_system_health() -> Dict[str, Any]:
} }
} }
# Test connection
try: try:
test_result = await test_connection("default") test_result = await test_connection()
health_info["connection"] = test_result health_info["connection"] = test_result
except Exception: except Exception:
health_info["status"] = "degraded" health_info["status"] = "degraded"
@@ -563,7 +509,7 @@ async def get_system_health() -> Dict[str, Any]:
return health_info return health_info
async def describe_query(query: str, database: str = "default") -> Dict[str, Any]: async def describe_query(query: str) -> Dict[str, Any]:
"""Preview query output columns without executing the full query. """Preview query output columns without executing the full query.
Uses LIMIT 0 to get column metadata efficiently.""" Uses LIMIT 0 to get column metadata efficiently."""
@@ -572,21 +518,18 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
return {"error": message} return {"error": message}
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
clean_query = query.rstrip(';').strip() clean_query = query.rstrip(';').strip()
# Remove existing LIMIT and OFFSET clauses (case insensitive)
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE) clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE) clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
# Add LIMIT 0 to get column metadata without fetching data
limited_query = f"{clean_query} LIMIT 0" limited_query = f"{clean_query} LIMIT 0"
try: try:
cursor.execute(limited_query) cursor.execute(limited_query)
except Exception: except Exception:
# Fallback: wrap in subquery if direct approach fails
limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0" limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0"
cursor.execute(limited_query) cursor.execute(limited_query)
@@ -609,7 +552,6 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
return { return {
"columns": columns, "columns": columns,
"column_count": len(columns), "column_count": len(columns),
"database": database,
"query_preview": query[:200] + "..." if len(query) > 200 else query "query_preview": query[:200] + "..." if len(query) > 200 else query
} }
@@ -621,7 +563,7 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
} }
async def execute_async(query: str, database: str = "default") -> Dict[str, Any]: async def execute_async(query: str) -> Dict[str, Any]:
"""Submit a query for asynchronous execution. Returns query ID for status tracking.""" """Submit a query for asynchronous execution. Returns query ID for status tracking."""
is_valid, message = validate_query(query) is_valid, message = validate_query(query)
@@ -629,17 +571,14 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
return {"error": message} return {"error": message}
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor() cursor = conn.cursor()
# Execute asynchronously
cursor.execute_async(query) cursor.execute_async(query)
query_id = cursor.sfqid query_id = cursor.sfqid
# Store query info for tracking
async_queries[query_id] = { async_queries[query_id] = {
"query": query[:500], "query": query[:500],
"database": database,
"submitted_at": datetime.now().isoformat(), "submitted_at": datetime.now().isoformat(),
"status": "RUNNING" "status": "RUNNING"
} }
@@ -650,7 +589,6 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
"query_id": query_id, "query_id": query_id,
"status": "SUBMITTED", "status": "SUBMITTED",
"message": "Query submitted for async execution. Use get_query_status to check progress.", "message": "Query submitted for async execution. Use get_query_status to check progress.",
"database": database
} }
except Exception as e: except Exception as e:
@@ -658,11 +596,11 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
return {"error": str(e)} return {"error": str(e)}
async def get_query_status(query_id: str, database: str = "default") -> Dict[str, Any]: async def get_query_status(query_id: str) -> Dict[str, Any]:
"""Check the status of an async query by its query ID.""" """Check the status of an async query by its query ID."""
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
status = conn.get_query_status(query_id) status = conn.get_query_status(query_id)
status_name = status.name if hasattr(status, 'name') else str(status) status_name = status.name if hasattr(status, 'name') else str(status)
@@ -670,7 +608,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
is_running = conn.is_still_running(status) is_running = conn.is_still_running(status)
is_error = conn.is_an_error(status) is_error = conn.is_an_error(status)
# Update tracked query info
if query_id in async_queries: if query_id in async_queries:
async_queries[query_id]["status"] = status_name async_queries[query_id]["status"] = status_name
@@ -682,7 +619,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
"can_fetch_results": not is_running and not is_error "can_fetch_results": not is_running and not is_error
} }
# Add original query info if available
if query_id in async_queries: if query_id in async_queries:
result["query_info"] = async_queries[query_id] result["query_info"] = async_queries[query_id]
@@ -693,14 +629,12 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
return {"error": str(e), "query_id": query_id} return {"error": str(e), "query_id": query_id}
async def get_async_results(query_id: str, database: str = "default", async def get_async_results(query_id: str, max_rows: int = 50000) -> Dict[str, Any]:
max_rows: int = 50000) -> Dict[str, Any]:
"""Retrieve results from a completed async query.""" """Retrieve results from a completed async query."""
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
# Check status first
status = conn.get_query_status_throw_if_error(query_id) status = conn.get_query_status_throw_if_error(query_id)
if conn.is_still_running(status): if conn.is_still_running(status):
@@ -710,17 +644,13 @@ async def get_async_results(query_id: str, database: str = "default",
"status": status.name if hasattr(status, 'name') else str(status) "status": status.name if hasattr(status, 'name') else str(status)
} }
# Use standard cursor for get_results_from_sfqid
cursor = conn.cursor() cursor = conn.cursor()
cursor.get_results_from_sfqid(query_id) cursor.get_results_from_sfqid(query_id)
# Fetch results first (description may not populate until after fetch)
raw_rows = cursor.fetchmany(max_rows) raw_rows = cursor.fetchmany(max_rows)
# Now get column names from description
columns = [desc[0] for desc in cursor.description] if cursor.description else [] columns = [desc[0] for desc in cursor.description] if cursor.description else []
# Convert to list of dicts
rows = [] rows = []
for row in raw_rows: for row in raw_rows:
row_dict = {} row_dict = {}
@@ -734,12 +664,10 @@ async def get_async_results(query_id: str, database: str = "default",
row_dict[col] = value row_dict[col] = value
rows.append(row_dict) rows.append(row_dict)
# Check if there are more rows
truncated = len(raw_rows) == max_rows truncated = len(raw_rows) == max_rows
cursor.close() cursor.close()
# Clean up tracking
if query_id in async_queries: if query_id in async_queries:
del async_queries[query_id] del async_queries[query_id]
@@ -756,8 +684,8 @@ async def get_async_results(query_id: str, database: str = "default",
return {"error": str(e), "query_id": query_id} return {"error": str(e), "query_id": query_id}
async def read_data_paginated(query: str, database: str = "default", async def read_data_paginated(query: str, page_size: int = 1000,
page_size: int = 1000, page: int = 1) -> Dict[str, Any]: page: int = 1) -> Dict[str, Any]:
"""Execute a query with pagination support (offset/limit).""" """Execute a query with pagination support (offset/limit)."""
is_valid, message = validate_query(query) is_valid, message = validate_query(query)
@@ -771,12 +699,10 @@ async def read_data_paginated(query: str, database: str = "default",
offset = (page - 1) * page_size offset = (page - 1) * page_size
# Strip existing LIMIT/OFFSET from query for proper pagination
clean_query = query.rstrip(';').strip() clean_query = query.rstrip(';').strip()
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE) clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE) clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
# Wrap query with pagination
paginated_query = f""" paginated_query = f"""
SELECT * FROM ( SELECT * FROM (
{clean_query} {clean_query}
@@ -784,8 +710,7 @@ async def read_data_paginated(query: str, database: str = "default",
LIMIT {page_size} OFFSET {offset} LIMIT {page_size} OFFSET {offset}
""" """
# Also get total count (cached for efficiency) count_cache_key = get_cache_key(f"COUNT:{clean_query}")
count_cache_key = get_cache_key(f"COUNT:{clean_query}", database)
total_count = None total_count = None
if count_cache_key in query_cache: if count_cache_key in query_cache:
@@ -794,10 +719,9 @@ async def read_data_paginated(query: str, database: str = "default",
total_count = cached_count total_count = cached_count
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
# Get total count if not cached
if total_count is None: if total_count is None:
count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq" count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq"
cursor.execute(count_query) cursor.execute(count_query)
@@ -805,7 +729,6 @@ async def read_data_paginated(query: str, database: str = "default",
total_count = count_row["TOTAL"] if count_row else 0 total_count = count_row["TOTAL"] if count_row else 0
query_cache[count_cache_key] = (total_count, datetime.now()) query_cache[count_cache_key] = (total_count, datetime.now())
# Execute paginated query
cursor.execute(paginated_query) cursor.execute(paginated_query)
columns = [desc[0] for desc in cursor.description] if cursor.description else [] columns = [desc[0] for desc in cursor.description] if cursor.description else []
@@ -838,7 +761,6 @@ async def read_data_paginated(query: str, database: str = "default",
"has_next": page < total_pages, "has_next": page < total_pages,
"has_previous": page > 1 "has_previous": page > 1
}, },
"database": database
} }
except Exception as e: except Exception as e:
@@ -846,8 +768,7 @@ async def read_data_paginated(query: str, database: str = "default",
return {"error": str(e)} return {"error": str(e)}
async def read_data_pandas(query: str, database: str = "default", async def read_data_pandas(query: str, max_rows: int = 50000,
max_rows: int = 50000,
output_format: str = "records") -> Dict[str, Any]: output_format: str = "records") -> Dict[str, Any]:
"""Execute a query and return results in a pandas-friendly format. """Execute a query and return results in a pandas-friendly format.
@@ -866,14 +787,13 @@ async def read_data_pandas(query: str, database: str = "default",
return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"} return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}
try: try:
conn = connection_pool.get_connection(database) conn = connection.get()
cursor = conn.cursor(DictCursor) cursor = conn.cursor(DictCursor)
cursor.execute(query) cursor.execute(query)
columns = [desc[0] for desc in cursor.description] if cursor.description else [] columns = [desc[0] for desc in cursor.description] if cursor.description else []
# Collect all rows first
all_rows = [] all_rows = []
for i, row in enumerate(cursor): for i, row in enumerate(cursor):
if i >= max_rows: if i >= max_rows:
@@ -894,19 +814,15 @@ async def read_data_pandas(query: str, database: str = "default",
cursor.close() cursor.close()
truncated = len(all_rows) == max_rows truncated = len(all_rows) == max_rows
# Format output based on requested format
if output_format == 'records': if output_format == 'records':
data = all_rows data = all_rows
elif output_format == 'columns': elif output_format == 'columns':
data = {col: [row.get(col) for row in all_rows] for col in columns} data = {col: [row.get(col) for row in all_rows] for col in columns}
elif output_format == 'split': elif output_format == 'split':
data = { data = {
"columns": columns, "columns": columns,
"data": [[row.get(col) for col in columns] for row in all_rows] "data": [[row.get(col) for col in columns] for row in all_rows]
} }
elif output_format == 'csv': elif output_format == 'csv':
output = io.StringIO() output = io.StringIO()
writer = csv.DictWriter(output, fieldnames=columns) writer = csv.DictWriter(output, fieldnames=columns)
@@ -920,7 +836,6 @@ async def read_data_pandas(query: str, database: str = "default",
"data": data, "data": data,
"row_count": len(all_rows), "row_count": len(all_rows),
"truncated": truncated, "truncated": truncated,
"database": database,
"pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format" "pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
} }
@@ -942,7 +857,7 @@ server = Server("snowflake-mcp")
@server.list_tools() @server.list_tools()
async def list_tools() -> List[Tool]: async def handle_list_tools() -> List[Tool]:
"""List available MCP tools""" """List available MCP tools"""
return [ return [
Tool( Tool(
@@ -950,13 +865,7 @@ async def list_tools() -> List[Tool]:
description="Test Snowflake connection and get session information", description="Test Snowflake connection and get session information",
inputSchema={ inputSchema={
"type": "object", "type": "object",
"properties": { "properties": {}
"database": {
"type": "string",
"description": "Database context (default: 'default')",
"default": "default"
}
}
} }
), ),
Tool( Tool(
@@ -986,14 +895,9 @@ async def list_tools() -> List[Tool]:
inputSchema={ inputSchema={
"type": "object", "type": "object",
"properties": { "properties": {
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"schema": { "schema": {
"type": "string", "type": "string",
"description": "Schema filter (optional)" "description": "Schema filter (optional, e.g. 'MY_DB.MY_SCHEMA')"
}, },
"include_row_counts": { "include_row_counts": {
"type": "boolean", "type": "boolean",
@@ -1012,11 +916,6 @@ async def list_tools() -> List[Tool]:
"schema_name": { "schema_name": {
"type": "string", "type": "string",
"description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')" "description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
} }
}, },
"required": ["schema_name"] "required": ["schema_name"]
@@ -1024,18 +923,13 @@ async def list_tools() -> List[Tool]:
), ),
Tool( Tool(
name="describe_table", name="describe_table",
description="Get detailed schema information for a table", description="Get detailed schema information for a table or view",
inputSchema={ inputSchema={
"type": "object", "type": "object",
"properties": { "properties": {
"table_name": { "table_name": {
"type": "string", "type": "string",
"description": "Table name (can include database.schema.table)" "description": "Table name (can include database.schema.table)"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
} }
}, },
"required": ["table_name"] "required": ["table_name"]
@@ -1051,11 +945,6 @@ async def list_tools() -> List[Tool]:
"type": "string", "type": "string",
"description": "SQL SELECT query to execute" "description": "SQL SELECT query to execute"
}, },
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"max_rows": { "max_rows": {
"type": "integer", "type": "integer",
"description": "Maximum rows to return", "description": "Maximum rows to return",
@@ -1082,11 +971,6 @@ async def list_tools() -> List[Tool]:
"query": { "query": {
"type": "string", "type": "string",
"description": "SQL SELECT query to analyze" "description": "SQL SELECT query to analyze"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
} }
}, },
"required": ["query"] "required": ["query"]
@@ -1101,11 +985,6 @@ async def list_tools() -> List[Tool]:
"query": { "query": {
"type": "string", "type": "string",
"description": "SQL SELECT query to execute asynchronously" "description": "SQL SELECT query to execute asynchronously"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
} }
}, },
"required": ["query"] "required": ["query"]
@@ -1120,11 +999,6 @@ async def list_tools() -> List[Tool]:
"query_id": { "query_id": {
"type": "string", "type": "string",
"description": "Snowflake query ID from execute_async" "description": "Snowflake query ID from execute_async"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
} }
}, },
"required": ["query_id"] "required": ["query_id"]
@@ -1140,11 +1014,6 @@ async def list_tools() -> List[Tool]:
"type": "string", "type": "string",
"description": "Snowflake query ID from execute_async" "description": "Snowflake query ID from execute_async"
}, },
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"max_rows": { "max_rows": {
"type": "integer", "type": "integer",
"description": "Maximum rows to return", "description": "Maximum rows to return",
@@ -1172,11 +1041,6 @@ async def list_tools() -> List[Tool]:
"type": "string", "type": "string",
"description": "SQL SELECT query to execute" "description": "SQL SELECT query to execute"
}, },
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"page_size": { "page_size": {
"type": "integer", "type": "integer",
"description": "Number of rows per page (1-10000)", "description": "Number of rows per page (1-10000)",
@@ -1201,11 +1065,6 @@ async def list_tools() -> List[Tool]:
"type": "string", "type": "string",
"description": "SQL SELECT query to execute" "description": "SQL SELECT query to execute"
}, },
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"max_rows": { "max_rows": {
"type": "integer", "type": "integer",
"description": "Maximum rows to return", "description": "Maximum rows to return",
@@ -1230,8 +1089,7 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]:
try: try:
if name == "test_connection": if name == "test_connection":
database = arguments.get("database", "default") result = await test_connection()
result = await test_connection(database)
elif name == "list_databases": elif name == "list_databases":
result = await list_databases() result = await list_databases()
@@ -1241,67 +1099,57 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]:
result = await list_schemas(database_name) result = await list_schemas(database_name)
elif name == "list_tables": elif name == "list_tables":
database = arguments.get("database", "default")
schema = arguments.get("schema") schema = arguments.get("schema")
include_counts = arguments.get("include_row_counts", False) include_counts = arguments.get("include_row_counts", False)
result = await list_tables(database, schema, include_counts) result = await list_tables(schema, include_counts)
elif name == "list_views": elif name == "list_views":
schema_name = arguments["schema_name"] schema_name = arguments["schema_name"]
database = arguments.get("database", "default") result = await list_views(schema_name)
result = await list_views(schema_name, database)
elif name == "describe_table": elif name == "describe_table":
table_name = arguments["table_name"] table_name = arguments["table_name"]
database = arguments.get("database", "default") result = await describe_table(table_name)
result = await describe_table(table_name, database)
elif name == "read_data": elif name == "read_data":
query = arguments["query"] query = arguments["query"]
database = arguments.get("database", "default")
max_rows = arguments.get("max_rows", 50000) max_rows = arguments.get("max_rows", 50000)
result = await read_data(query, database, max_rows) result = await read_data(query, max_rows)
elif name == "get_system_health": elif name == "get_system_health":
result = await get_system_health() result = await get_system_health()
elif name == "describe_query": elif name == "describe_query":
query = arguments["query"] query = arguments["query"]
database = arguments.get("database", "default") result = await describe_query(query)
result = await describe_query(query, database)
elif name == "execute_async": elif name == "execute_async":
query = arguments["query"] query = arguments["query"]
database = arguments.get("database", "default") result = await execute_async(query)
result = await execute_async(query, database)
elif name == "get_query_status": elif name == "get_query_status":
query_id = arguments["query_id"] query_id = arguments["query_id"]
database = arguments.get("database", "default") result = await get_query_status(query_id)
result = await get_query_status(query_id, database)
elif name == "get_async_results": elif name == "get_async_results":
query_id = arguments["query_id"] query_id = arguments["query_id"]
database = arguments.get("database", "default")
max_rows = arguments.get("max_rows", 50000) max_rows = arguments.get("max_rows", 50000)
result = await get_async_results(query_id, database, max_rows) result = await get_async_results(query_id, max_rows)
elif name == "list_async_queries": elif name == "list_async_queries":
result = await list_async_queries() result = await list_async_queries()
elif name == "read_data_paginated": elif name == "read_data_paginated":
query = arguments["query"] query = arguments["query"]
database = arguments.get("database", "default")
page_size = arguments.get("page_size", 1000) page_size = arguments.get("page_size", 1000)
page = arguments.get("page", 1) page = arguments.get("page", 1)
result = await read_data_paginated(query, database, page_size, page) result = await read_data_paginated(query, page_size, page)
elif name == "read_data_pandas": elif name == "read_data_pandas":
query = arguments["query"] query = arguments["query"]
database = arguments.get("database", "default")
max_rows = arguments.get("max_rows", 50000) max_rows = arguments.get("max_rows", 50000)
output_format = arguments.get("output_format", "records") output_format = arguments.get("output_format", "records")
result = await read_data_pandas(query, database, max_rows, output_format) result = await read_data_pandas(query, max_rows, output_format)
else: else:
raise ValueError(f"Unknown tool: {name}") raise ValueError(f"Unknown tool: {name}")
@@ -1339,4 +1187,4 @@ if __name__ == "__main__":
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Server stopped by user") logger.info("Server stopped by user")
finally: finally:
connection_pool.close_all() connection.close()