1191 lines
39 KiB
Python
1191 lines
39 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Snowflake MCP Server with SSO Authentication
|
|
Provides read-only Snowflake access via the Model Context Protocol.
|
|
"""
|
|
|
|
import asyncio
|
|
import csv
|
|
import io
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
from typing import Any, Optional, Dict, List
|
|
from datetime import datetime, timedelta
|
|
import hashlib
|
|
|
|
import snowflake.connector
|
|
from snowflake.connector import DictCursor
|
|
from snowflake.connector.constants import QueryStatus
|
|
|
|
from mcp.server import Server
|
|
from mcp.types import TextContent, Tool, INTERNAL_ERROR
|
|
|
|
# Configure logging
# basicConfig's default handler writes to stderr, which keeps stdout free for
# the MCP stdio protocol stream.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration — set via environment variables or edit the defaults below
# ---------------------------------------------------------------------------
# warehouse / role fall back to None (use the account default); they are only
# forwarded to snowflake.connector.connect() when set.
SNOWFLAKE_CONFIG = dict(
    account=os.environ.get("SNOWFLAKE_ACCOUNT", "YOUR_ACCOUNT"),
    user=os.environ.get("SNOWFLAKE_USER", "your.email@example.com"),
    authenticator=os.environ.get("SNOWFLAKE_AUTHENTICATOR", "externalbrowser"),
    warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE"),
    role=os.environ.get("SNOWFLAKE_ROLE"),
)
|
|
|
|
|
|
class SnowflakeConnection:
    """Manages a single, lazily created Snowflake connection with SSO authentication.

    ``get()`` opens the connection on first use. It then keeps handing out the
    same connection as long as the connection was used within the idle timeout
    and still answers a ``SELECT 1`` ping. An idle-expired or unresponsive
    connection is closed before a replacement is opened, so old sessions are
    not leaked.
    """

    def __init__(self):
        # Live connection, or None before first use / after close().
        self._conn: Optional[snowflake.connector.SnowflakeConnection] = None
        # Time the connection was last handed out; drives the idle timeout.
        self._last_used: Optional[datetime] = None
        # Idle period after which the cached connection is discarded.
        self._timeout = timedelta(minutes=30)

    def _is_alive(self) -> bool:
        """Return True if the cached connection answers ``SELECT 1``."""
        cursor = None
        try:
            cursor = self._conn.cursor()
            cursor.execute("SELECT 1")
            return True
        except Exception:
            return False
        finally:
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass  # verdict already decided; closing the probe is best-effort

    def _discard(self) -> None:
        """Best-effort close of the cached connection, then forget it."""
        if self._conn is not None:
            try:
                self._conn.close()
            except Exception as exc:
                logger.debug("Ignoring error while closing stale connection: %s", exc)
        self._conn = None
        self._last_used = None

    def get(self) -> snowflake.connector.SnowflakeConnection:
        """Return a usable Snowflake connection, creating one if needed.

        Returns:
            The cached connection if it is within the idle timeout and still
            alive. Otherwise a newly opened connection; with the default
            ``externalbrowser`` authenticator this opens a browser for SSO.

        Raises:
            Whatever ``snowflake.connector.connect`` raises when a new
            connection cannot be established (logged, then re-raised).
        """
        if self._conn is not None and self._last_used is not None:
            if datetime.now() - self._last_used < self._timeout:
                if self._is_alive():
                    self._last_used = datetime.now()
                    logger.info("Reusing cached Snowflake connection")
                    return self._conn
                logger.info("Cached connection is stale, reconnecting...")
            else:
                logger.info("Cached connection exceeded idle timeout, reconnecting...")
            # Close the old session before replacing it. Previously an
            # idle-expired connection was simply overwritten and leaked.
            self._discard()

        logger.info("Creating new Snowflake connection")
        logger.info("Note: First connection will open browser for SSO login")

        connect_params = {
            "account": SNOWFLAKE_CONFIG["account"],
            "user": SNOWFLAKE_CONFIG["user"],
            "authenticator": SNOWFLAKE_CONFIG["authenticator"],
        }
        # Optional settings are only sent when configured, so the account
        # defaults apply otherwise.
        if SNOWFLAKE_CONFIG.get("warehouse"):
            connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"]
        if SNOWFLAKE_CONFIG.get("role"):
            connect_params["role"] = SNOWFLAKE_CONFIG["role"]

        try:
            self._conn = snowflake.connector.connect(**connect_params)
        except Exception as e:
            logger.error(f"Failed to connect to Snowflake: {e}")
            raise
        self._last_used = datetime.now()
        logger.info("Successfully connected to Snowflake")
        return self._conn

    def close(self):
        """Close the connection, if any, and reset the cached state."""
        if self._conn:
            try:
                self._conn.close()
                logger.info("Closed Snowflake connection")
            except Exception:
                pass  # shutdown path: a failed close must not mask other errors
        self._conn = None
        self._last_used = None
|
|
|
|
|
|
# Global connection: one process-wide Snowflake connection shared by every
# tool handler.
connection = SnowflakeConnection()

# Query cache: get_cache_key(...) -> (cached value, time it was cached).
# Entries older than CACHE_TTL are treated as misses.
CACHE_TTL = timedelta(minutes=5)
query_cache: Dict[str, tuple] = {}

# Async query tracking: Snowflake query id -> submission metadata
# (truncated query text, submission time, last known status).
async_queries: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
|
def get_cache_key(query: str) -> str:
    """Return a stable hex digest used as the query-cache key for *query*.

    MD5 is used only as a fast fingerprint, not for security. Passing
    ``usedforsecurity=False`` keeps it working on FIPS-enabled Python builds,
    where a plain ``hashlib.md5`` call is rejected.
    """
    return hashlib.md5(query.encode("utf-8"), usedforsecurity=False).hexdigest()
|
|
|
|
|
|
def validate_query(query: str) -> tuple[bool, str]:
    """Check that *query* looks like a read-only statement.

    This is a lightweight keyword screen, not a SQL parser. The query must
    start with an allowed read-only verb or at least contain ``SELECT``. It
    must not contain any keyword that writes data, changes objects, touches
    stages or runs procedures.

    Keywords are matched as standalone tokens regardless of the surrounding
    whitespace or punctuation (tabs, newlines, ``;``, ``(``). They are not
    matched inside a longer identifier (``IS_DELETED``), when qualified
    (``t.update``) or when quoted (``"DROP"``, ``'DELETE'``).

    Returns:
        ``(True, "Query validated successfully")`` or ``(False, reason)``.
    """
    query_upper = query.upper().strip()

    # Allow SELECT, WITH (CTEs), and SHOW/DESCRIBE/LIST commands. Other leading
    # verbs are tolerated only if the text contains SELECT; the keyword scan
    # below is what rejects e.g. INSERT ... SELECT.
    allowed_starts = ('SELECT', 'WITH', 'SHOW', 'DESCRIBE', 'DESC', 'LIST')
    if not query_upper.startswith(allowed_starts) and 'SELECT' not in query_upper:
        return False, "Only SELECT queries are allowed"

    # Block dangerous keywords. COPY/PUT/REMOVE operate on stages, CALL/EXECUTE
    # run procedures or dynamic SQL, UNDROP restores dropped objects.
    dangerous = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'TRUNCATE',
                 'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE',
                 'COPY', 'PUT', 'REMOVE', 'CALL', 'EXECUTE', 'UNDROP']

    for keyword in dangerous:
        # The old check looked for " KEYWORD " with literal spaces, so e.g.
        # "DELETE\tFROM t WHERE id IN (SELECT ...)" passed as read-only.
        pattern = r"(?<![\w.$\"'])" + keyword + r"(?![\w$\"'])"
        if re.search(pattern, query_upper):
            return False, f"Query contains forbidden keyword: {keyword}"

    return True, "Query validated successfully"
|
|
|
|
|
|
async def test_connection() -> Dict[str, Any]:
    """Test the Snowflake connection and return session information.

    Failures (including SSO/login errors) are logged and reported as a
    ``{"status": "error", ...}`` dict rather than raised.

    Returns:
        On success: ``status="connected"`` plus the current session's
        account, user, role, warehouse, database, schema and version, and the
        configured ``auth_method``.
    """
    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        cursor.execute("""
            SELECT
                CURRENT_ACCOUNT() as account,
                CURRENT_USER() as user_name,
                CURRENT_ROLE() as current_role,
                CURRENT_WAREHOUSE() as warehouse,
                CURRENT_DATABASE() as database_name,
                CURRENT_SCHEMA() as schema_name,
                CURRENT_VERSION() as version
        """)

        row = cursor.fetchone()

        return {
            "status": "connected",
            "account": row["ACCOUNT"],
            "user": row["USER_NAME"],
            "role": row["CURRENT_ROLE"],
            "warehouse": row["WAREHOUSE"],
            "database": row["DATABASE_NAME"],
            "schema": row["SCHEMA_NAME"],
            "version": row["VERSION"],
            "auth_method": SNOWFLAKE_CONFIG["authenticator"],
            "message": "Connection successful"
        }

    except Exception as e:
        logger.error(f"Connection test failed: {e}")
        return {
            "status": "error",
            "error": str(e),
            "message": "If this is your first connection, a browser window should open for SSO login"
        }
    finally:
        # Close the cursor on both paths; it previously leaked when
        # execute() or fetchone() raised.
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def list_tables(schema: Optional[str] = None,
                      include_row_counts: bool = False) -> Dict[str, Any]:
    """List tables in *schema*, or in the session's current schema.

    Args:
        schema: Schema to list, e.g. ``"MY_DB.MY_SCHEMA"``. It is interpolated
            verbatim into ``SHOW TABLES IN SCHEMA ...``, so it must be a valid
            (quoted if necessary) Snowflake identifier.
        include_row_counts: When True, each entry's ``rows`` is the value
            from the SHOW output; otherwise it is None.

    Returns:
        ``{"tables": [...], "count": n}``. When nothing is found, an extra
        ``hint`` suggests that the schema may contain views instead.

    Raises:
        Any connector error, after logging it.
    """
    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        query = f"SHOW TABLES IN SCHEMA {schema}" if schema else "SHOW TABLES"
        cursor.execute(query)

        tables = []
        for row in cursor:
            created_on = row.get("created_on")
            tables.append({
                "database": row.get("database_name"),
                "schema": row.get("schema_name"),
                "name": row.get("name"),
                "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                "description": row.get("comment"),
                "kind": row.get("kind"),
                "created": str(created_on) if created_on else None,
                "rows": row.get("rows") if include_row_counts else None
            })

        if not tables:
            return {
                "tables": [],
                "count": 0,
                "hint": f"No tables found in {schema or 'current schema'}. This schema may contain VIEWS instead of tables. Use list_views to check for views."
            }

        return {"tables": tables, "count": len(tables)}

    except Exception as e:
        logger.error(f"Failed to list tables: {e}")
        raise
    finally:
        # Release the cursor on success and failure alike.
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def list_views(schema_name: str) -> List[Dict[str, Any]]:
    """List all views in *schema_name* with their descriptions.

    Args:
        schema_name: Schema to inspect, e.g. ``"MY_DB.MY_SCHEMA"``. It is
            interpolated verbatim into ``SHOW VIEWS IN SCHEMA ...``.

    Returns:
        One dict per view with name, database, schema, full_name, description
        (the view comment) and created. ``is_secure`` and ``is_materialized``
        are booleans, True when SHOW VIEWS reports the string ``"true"``.

    Raises:
        Any connector error, after logging it.
    """
    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}")

        views = []
        for row in cursor:
            created_on = row.get("created_on")
            views.append({
                "name": row.get("name"),
                "database": row.get("database_name"),
                "schema": row.get("schema_name"),
                "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                "description": row.get("comment"),
                "created": str(created_on) if created_on else None,
                "is_secure": row.get("is_secure") == "true",
                "is_materialized": row.get("is_materialized") == "true"
            })

        return views

    except Exception as e:
        logger.error(f"Failed to list views in {schema_name}: {e}")
        raise
    finally:
        # Release the cursor on success and failure alike.
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def describe_table(table_name: str) -> Dict[str, Any]:
    """Get column-level schema information for a table or view.

    First tries ``SHOW COLUMNS IN TABLE <name>`` with the name as given. If
    Snowflake reports that the object does not exist, it retries with the
    last name part double-quoted, to handle case-sensitive object names.

    Args:
        table_name: ``table``, ``schema.table`` or ``db.schema.table``.

    Returns:
        The database/schema/table parts, ``columns`` (name, type, nullable,
        default, comment, kind) and ``column_count``. ``quoted_name`` and
        ``hint`` are non-None only when the quoted retry was needed.
        On failure, returns ``{"error", "table_name", "suggestions"}``
        instead of raising.
    """
    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor()

        parts = table_name.split('.')
        db = schema = None
        if len(parts) == 3:
            db, schema, table = parts
            quoted_name = f'{parts[0]}.{parts[1]}."{parts[2]}"'
        elif len(parts) == 2:
            schema, table = parts
            quoted_name = f'{parts[0]}."{parts[1]}"'
        else:
            table = table_name
            quoted_name = f'"{table_name}"'

        rows = None
        used_name = table_name
        for try_name in (table_name, quoted_name):
            try:
                cursor.execute(f"SHOW COLUMNS IN TABLE {try_name}")
                rows = cursor.fetchall()
                used_name = try_name
                break
            except Exception as inner_e:
                # Only a "does not exist" on the as-given name earns the quoted retry.
                if "does not exist" in str(inner_e).lower() and try_name == table_name:
                    continue
                raise

        if rows is None:
            rows = []

        # Column labels of the SHOW output, upper-cased for lookup.
        desc_columns = [
            str(getattr(desc, 'name', None) or desc[0]).upper()
            for desc in (cursor.description or [])
        ]

        columns = []
        for row in rows:
            row_dict = dict(zip(desc_columns, row))
            # SHOW COLUMNS reports nullability as "null?" and the default as
            # "default". Reading only the INFORMATION_SCHEMA-style
            # IS_NULLABLE / COLUMN_DEFAULT keys made every column report
            # nullable=False and default=None; those keys remain as fallbacks.
            null_flag = row_dict.get("NULL?", row_dict.get("IS_NULLABLE"))
            columns.append({
                "name": row_dict.get("COLUMN_NAME"),
                "type": row_dict.get("DATA_TYPE"),
                "nullable": str(null_flag).strip().upper() in ("TRUE", "Y", "YES"),
                "default": row_dict.get("DEFAULT", row_dict.get("COLUMN_DEFAULT")),
                "comment": row_dict.get("COMMENT"),
                "kind": row_dict.get("KIND")
            })

        needed_quotes = used_name != table_name
        return {
            "database": db,
            "schema": schema,
            "table": table,
            "full_name": table_name,
            "quoted_name": used_name if needed_quotes else None,
            "columns": columns,
            "column_count": len(columns),
            "hint": "Use the quoted_name in queries if provided, as this object has case-sensitive naming." if needed_quotes else None
        }

    except Exception as e:
        logger.error(f"Failed to describe table {table_name}: {e}")
        error_str = str(e)
        suggestions = []

        if "does not exist" in error_str.lower():
            suggestions.append("The object may be a VIEW. Use list_views to check available views in the schema.")
            suggestions.append("Object names are case-sensitive. The name may need to be quoted.")
        elif "not authorized" in error_str.lower():
            suggestions.append("Check that your role has SELECT privileges on this object.")

        return {
            "error": error_str,
            "table_name": table_name,
            "suggestions": suggestions
        }
    finally:
        # The cursor previously leaked whenever a SHOW COLUMNS attempt raised.
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def read_data(query: str, max_rows: int = 50000) -> Dict[str, Any]:
    """Execute a read-only query and return up to *max_rows* rows.

    The query is screened by validate_query() first. Results are cached per
    (query, max_rows) for CACHE_TTL; a cached response additionally carries
    ``cached=True`` and ``cache_age_seconds``.

    Returns:
        ``{"columns", "rows", "row_count", "truncated"}``. ``truncated`` is
        True only when the result had more than *max_rows* rows. datetime and
        bytes values are stringified. Validation or execution failures
        return an ``{"error", ...}`` dict instead of raising.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {
            "error": message,
            "suggestion": "Use SELECT statements only. For schema information, use describe_table."
        }

    # max_rows is part of the key, so a result truncated for a small max_rows
    # is never served to a later call that asks for more rows.
    cache_key = get_cache_key(f"{max_rows}:{query}")
    if cache_key in query_cache:
        cached_result, cached_time = query_cache[cache_key]
        if datetime.now() - cached_time < CACHE_TTL:
            logger.info("Returning cached result for query")
            return {
                **cached_result,
                "cached": True,
                "cache_age_seconds": (datetime.now() - cached_time).total_seconds()
            }

    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        cursor.execute(query)

        columns = [desc[0] for desc in cursor.description] if cursor.description else []

        rows = []
        truncated = False
        for i, row in enumerate(cursor):
            if i >= max_rows:
                # Only reached when a row beyond the limit exists, so an
                # exact-size result is not reported as truncated.
                truncated = True
                logger.warning(f"Query returned more than {max_rows} rows, truncating")
                break

            row_dict = {}
            for col in columns:
                value = row[col]
                if isinstance(value, (datetime, bytes)):
                    row_dict[col] = str(value)
                else:
                    row_dict[col] = value  # includes None
            rows.append(row_dict)

        result = {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated,
        }

        now = datetime.now()
        # Evict expired entries so the module-level cache cannot grow without bound.
        expired = [key for key, (_, cached_at) in query_cache.items()
                   if now - cached_at >= CACHE_TTL]
        for key in expired:
            del query_cache[key]
        query_cache[cache_key] = (result, now)

        return result

    except Exception as e:
        logger.error(f"Query execution failed: {e}")
        error_str = str(e)
        suggestions = []

        if "does not exist" in error_str.lower() or "not authorized" in error_str.lower():
            suggestions.append("Object names in Snowflake are case-sensitive when created with quotes. Try wrapping table/view names in double quotes.")
            suggestions.append("Use list_views to check if the object is a view rather than a table.")
        else:
            suggestions.append("Check your query syntax. Use describe_table to see available columns.")

        return {
            "error": error_str,
            "suggestions": suggestions
        }
    finally:
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def list_databases() -> List[Dict[str, Any]]:
    """List all databases visible to the current role via ``SHOW DATABASES``.

    Returns:
        One dict per database: name, created, owner, comment, origin.

    Raises:
        Any connector error, after logging it.
    """
    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        cursor.execute("SHOW DATABASES")

        databases = []
        for row in cursor:
            created_on = row.get("created_on")
            databases.append({
                "name": row.get("name"),
                "created": str(created_on) if created_on else None,
                "owner": row.get("owner"),
                "comment": row.get("comment"),
                "origin": row.get("origin")
            })

        return databases

    except Exception as e:
        logger.error(f"Failed to list databases: {e}")
        raise
    finally:
        # Release the cursor on success and failure alike.
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
    """List schemas in *database_name*, or in the session's current database.

    Args:
        database_name: Database to inspect. It is interpolated verbatim into
            ``SHOW SCHEMAS IN DATABASE ...``.

    Returns:
        One dict per schema: name, database, created, owner.

    Raises:
        Any connector error, after logging it.
    """
    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        if database_name:
            cursor.execute(f"SHOW SCHEMAS IN DATABASE {database_name}")
        else:
            cursor.execute("SHOW SCHEMAS")

        schemas = []
        for row in cursor:
            created_on = row.get("created_on")
            schemas.append({
                "name": row.get("name"),
                "database": row.get("database_name"),
                "created": str(created_on) if created_on else None,
                "owner": row.get("owner")
            })

        return schemas

    except Exception as e:
        logger.error(f"Failed to list schemas: {e}")
        raise
    finally:
        # Release the cursor on success and failure alike.
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def get_system_health() -> Dict[str, Any]:
    """Report server health: cache statistics, auth configuration and a live
    connection test.

    ``status`` is "healthy" only when test_connection() reports
    ``status == "connected"``; otherwise it is "degraded".
    """
    health_info = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "cache": {
            "query_cache_size": len(query_cache),
            "cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
        },
        "authentication": {
            "method": SNOWFLAKE_CONFIG["authenticator"],
            "account": SNOWFLAKE_CONFIG["account"],
            "user": SNOWFLAKE_CONFIG["user"],
        }
    }

    try:
        test_result = await test_connection()
        health_info["connection"] = test_result
        # test_connection() catches its own errors and returns
        # {"status": "error", ...}, so the except clause alone never marked
        # a failed connection as degraded.
        if test_result.get("status") != "connected":
            health_info["status"] = "degraded"
    except Exception:
        health_info["status"] = "degraded"
        health_info["connection"] = {"status": "error"}

    return health_info
|
|
|
|
|
|
async def describe_query(query: str) -> Dict[str, Any]:
    """Preview a query's output columns without fetching its rows.

    Any trailing ``;`` and LIMIT/OFFSET clause are removed and ``LIMIT 0`` is
    appended. If that fails, the query is wrapped as
    ``SELECT * FROM (<query>) AS subq LIMIT 0``. Column metadata comes from
    the cursor description.

    Returns:
        ``{"columns", "column_count", "query_preview"}``, or an
        ``{"error", ...}`` dict when validation or execution fails.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        # Strip whitespace and trailing semicolons in any mix. The previous
        # rstrip(';').strip() left the ';' in "SELECT 1; ", which broke both
        # attempts below.
        clean_query = re.sub(r'[\s;]+$', '', query.strip())
        clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
        clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)

        # LIMIT goes on its own line. If it were appended after a trailing
        # "-- comment" on the same line, it would be commented out and the
        # full query would run.
        limited_query = f"{clean_query}\nLIMIT 0"

        try:
            cursor.execute(limited_query)
        except Exception:
            limited_query = f"SELECT * FROM (\n{clean_query}\n) AS subq LIMIT 0"
            cursor.execute(limited_query)

        columns = []
        for desc in cursor.description or []:
            columns.append({
                "name": desc[0],
                "type_code": desc[1],
                "display_size": desc[2],
                "internal_size": desc[3],
                "precision": desc[4],
                "scale": desc[5],
                "nullable": desc[6]
            })

        return {
            "columns": columns,
            "column_count": len(columns),
            "query_preview": query[:200] + "..." if len(query) > 200 else query
        }

    except Exception as e:
        logger.error(f"Failed to describe query: {e}")
        return {
            "error": str(e),
            "suggestion": "Check query syntax. The query must be a valid SELECT statement."
        }
    finally:
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def execute_async(query: str) -> Dict[str, Any]:
    """Submit a validated query for asynchronous execution.

    The query is submitted with ``cursor.execute_async``. It is then recorded
    in ``async_queries`` under its Snowflake query id, with the first 500
    characters of the query, the submission time and status ``"RUNNING"``.

    Returns:
        ``{"query_id", "status": "SUBMITTED", "message"}``, or
        ``{"error": ...}`` on validation or submission failure.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor()

        cursor.execute_async(query)
        query_id = cursor.sfqid

        async_queries[query_id] = {
            "query": query[:500],
            "submitted_at": datetime.now().isoformat(),
            "status": "RUNNING"
        }

        return {
            "query_id": query_id,
            "status": "SUBMITTED",
            "message": "Query submitted for async execution. Use get_query_status to check progress.",
        }

    except Exception as e:
        logger.error(f"Failed to submit async query: {e}")
        return {"error": str(e)}
    finally:
        # Release the cursor even when submission fails.
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def get_query_status(query_id: str) -> Dict[str, Any]:
    """Report the server-side status of a query, typically one from execute_async.

    If the query is tracked in ``async_queries``, its stored status is
    refreshed and its metadata is included as ``query_info``.

    Returns:
        ``{"query_id", "status", "is_running", "is_error",
        "can_fetch_results"}`` (plus ``query_info`` when tracked), or
        ``{"error", "query_id"}`` on failure.
    """
    try:
        conn = connection.get()

        status = conn.get_query_status(query_id)
        status_name = getattr(status, 'name', str(status))

        running = conn.is_still_running(status)
        failed = conn.is_an_error(status)

        tracked = async_queries.get(query_id)
        if tracked is not None:
            tracked["status"] = status_name

        result = {
            "query_id": query_id,
            "status": status_name,
            "is_running": running,
            "is_error": failed,
            # Results are fetchable once the query finished without error.
            "can_fetch_results": not (running or failed)
        }
        if tracked is not None:
            result["query_info"] = tracked

        return result

    except Exception as e:
        logger.error(f"Failed to get query status: {e}")
        return {"error": str(e), "query_id": query_id}
|
|
|
|
|
|
async def get_async_results(query_id: str, max_rows: int = 50000) -> Dict[str, Any]:
    """Retrieve results from a completed async query.

    Args:
        query_id: Snowflake query ID returned by execute_async.
        max_rows: Maximum number of rows to return. Negative values are
            treated as 0.

    Returns:
        ``{"query_id", "columns", "rows", "row_count", "truncated"}``.
        ``truncated`` is True only when more than *max_rows* rows exist.
        If the query is still running, returns an ``error``/``status`` dict.
        Any failure returns ``{"error", "query_id"}``. On success, the query
        is removed from ``async_queries``.
    """
    cursor = None
    try:
        conn = connection.get()

        # Per its name, this raises when the query ended in error; the
        # except clause below turns that into an error dict.
        status = conn.get_query_status_throw_if_error(query_id)

        if conn.is_still_running(status):
            return {
                "error": "Query is still running",
                "query_id": query_id,
                "status": status.name if hasattr(status, 'name') else str(status)
            }

        cursor = conn.cursor()
        cursor.get_results_from_sfqid(query_id)

        limit = max(max_rows, 0)
        # Fetch one row past the limit, so "truncated" reflects whether rows
        # were actually dropped rather than that the result hit the limit.
        fetched = cursor.fetchmany(limit + 1)
        truncated = len(fetched) > limit
        raw_rows = fetched[:limit]

        columns = [desc[0] for desc in cursor.description] if cursor.description else []

        rows = []
        for row in raw_rows:
            row_dict = {}
            for idx, col in enumerate(columns):
                value = row[idx] if idx < len(row) else None
                if isinstance(value, (datetime, bytes)):
                    row_dict[col] = str(value)
                else:
                    row_dict[col] = value  # includes None
            rows.append(row_dict)

        async_queries.pop(query_id, None)

        return {
            "query_id": query_id,
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated
        }

    except Exception as e:
        logger.error(f"Failed to get async results: {e}")
        return {"error": str(e), "query_id": query_id}
    finally:
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def read_data_paginated(query: str, page_size: int = 1000,
                              page: int = 1) -> Dict[str, Any]:
    """Execute a query one page at a time using LIMIT/OFFSET.

    The query (with any trailing ``;`` and LIMIT/OFFSET removed) is wrapped as
    a subquery. The total row count comes from a ``COUNT(*)`` over the same
    subquery and is cached for CACHE_TTL. Without an ORDER BY in the query,
    row order across pages is not guaranteed by SQL.

    Args:
        query: Read-only SQL, screened by validate_query().
        page_size: Rows per page, 1..10000.
        page: 1-based page number.

    Returns:
        ``{"columns", "rows", "pagination": {...}}``, or ``{"error": ...}``
        on invalid arguments or query failure.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    if page < 1:
        return {"error": "Page must be >= 1"}
    if page_size < 1 or page_size > 10000:
        return {"error": "Page size must be between 1 and 10000"}

    offset = (page - 1) * page_size

    # Strip whitespace and trailing semicolons in any mix (the previous
    # rstrip(';').strip() left the ';' in "... ; "), then drop a trailing
    # LIMIT/OFFSET so the query can be embedded as a subquery.
    clean_query = re.sub(r'[\s;]+$', '', query.strip())
    clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
    clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)

    # The user query sits on its own line(s), so a trailing "-- comment"
    # cannot swallow the closing parenthesis.
    paginated_query = f"""
        SELECT * FROM (
            {clean_query}
        ) AS paginated_result
        LIMIT {page_size} OFFSET {offset}
    """

    count_cache_key = get_cache_key(f"COUNT:{clean_query}")
    total_count = None

    if count_cache_key in query_cache:
        cached_count, cached_time = query_cache[count_cache_key]
        if datetime.now() - cached_time < CACHE_TTL:
            total_count = cached_count

    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        if total_count is None:
            count_query = f"SELECT COUNT(*) as total FROM (\n{clean_query}\n) AS count_subq"
            cursor.execute(count_query)
            count_row = cursor.fetchone()
            total_count = count_row["TOTAL"] if count_row else 0
            query_cache[count_cache_key] = (total_count, datetime.now())

        cursor.execute(paginated_query)

        columns = [desc[0] for desc in cursor.description] if cursor.description else []

        rows = []
        for row in cursor:
            row_dict = {}
            for col in columns:
                value = row[col]
                if isinstance(value, (datetime, bytes)):
                    row_dict[col] = str(value)
                else:
                    row_dict[col] = value  # includes None
            rows.append(row_dict)

        total_pages = (total_count + page_size - 1) // page_size if total_count else 0

        return {
            "columns": columns,
            "rows": rows,
            "pagination": {
                "page": page,
                "page_size": page_size,
                "total_rows": total_count,
                "total_pages": total_pages,
                "has_next": page < total_pages,
                "has_previous": page > 1
            },
        }

    except Exception as e:
        logger.error(f"Paginated query failed: {e}")
        return {"error": str(e)}
    finally:
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def read_data_pandas(query: str, max_rows: int = 50000,
                           output_format: str = "records") -> Dict[str, Any]:
    """Execute a query and return results in a pandas-friendly format.

    output_format options:
    - 'records': List of dicts (default, same as read_data)
    - 'columns': Dict with column arrays {col1: [vals], col2: [vals]}
    - 'split': Dict with columns and data arrays {columns: [...], data: [[...]]}
    - 'csv': CSV-formatted string

    datetime values are rendered with ``isoformat()``; bytes are decoded as
    UTF-8 with replacement characters. ``truncated`` is True only when the
    result had more than *max_rows* rows. Failures return ``{"error": ...}``.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    if output_format not in ('records', 'columns', 'split', 'csv'):
        return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}

    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor(DictCursor)

        cursor.execute(query)

        columns = [desc[0] for desc in cursor.description] if cursor.description else []

        all_rows = []
        truncated = False
        for i, row in enumerate(cursor):
            if i >= max_rows:
                # Reached only when a row beyond the limit exists; an
                # exact-size result is no longer reported as truncated.
                truncated = True
                break
            row_data = {}
            for col in columns:
                value = row[col]
                if value is None:
                    row_data[col] = None
                elif isinstance(value, datetime):
                    row_data[col] = value.isoformat()
                elif isinstance(value, bytes):
                    row_data[col] = value.decode('utf-8', errors='replace')
                else:
                    row_data[col] = value
            all_rows.append(row_data)

        if output_format == 'records':
            data = all_rows
        elif output_format == 'columns':
            data = {col: [row.get(col) for row in all_rows] for col in columns}
        elif output_format == 'split':
            data = {
                "columns": columns,
                "data": [[row.get(col) for col in columns] for row in all_rows]
            }
        else:  # 'csv' (validated above)
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=columns)
            writer.writeheader()
            writer.writerows(all_rows)
            data = output.getvalue()

        return {
            "format": output_format,
            "columns": columns,
            "data": data,
            "row_count": len(all_rows),
            "truncated": truncated,
            "pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
        }

    except Exception as e:
        logger.error(f"Pandas query failed: {e}")
        return {"error": str(e)}
    finally:
        if cursor is not None:
            cursor.close()
|
|
|
|
|
|
async def list_async_queries() -> Dict[str, Any]:
    """Return every async query this process is tracking, plus their count.

    ``queries`` is the live ``async_queries`` mapping (query id -> metadata),
    not a copy.
    """
    tracked = async_queries
    return dict(queries=tracked, count=len(tracked))
|
|
|
|
|
|
# Create MCP server
# Low-level MCP server instance. The tool list and the tool dispatcher are
# registered on it via the @server.list_tools() / @server.call_tool()
# decorators that follow.
server = Server("snowflake-mcp")
|
|
|
|
|
|
@server.list_tools()
async def handle_list_tools() -> List[Tool]:
    """Advertise the MCP tools this server implements.

    Each Tool's ``name`` must match a branch in call_tool(), which raises
    ``ValueError`` for unknown names. ``inputSchema`` is the JSON Schema for
    the arguments call_tool() reads; the ``default`` values here mirror the
    fallbacks used in call_tool().
    """
    return [
        # --- Connection / metadata discovery ---
        Tool(
            name="test_connection",
            description="Test Snowflake connection and get session information",
            inputSchema={
                "type": "object",
                "properties": {}
            }
        ),
        Tool(
            name="list_databases",
            description="List all accessible databases in Snowflake",
            inputSchema={
                "type": "object",
                "properties": {}
            }
        ),
        Tool(
            name="list_schemas",
            description="List all schemas in a database",
            inputSchema={
                "type": "object",
                "properties": {
                    "database_name": {
                        "type": "string",
                        "description": "Database name (optional, uses current if not specified)"
                    }
                }
            }
        ),
        Tool(
            name="list_tables",
            description="List all tables in the current database/schema",
            inputSchema={
                "type": "object",
                "properties": {
                    "schema": {
                        "type": "string",
                        "description": "Schema filter (optional, e.g. 'MY_DB.MY_SCHEMA')"
                    },
                    "include_row_counts": {
                        "type": "boolean",
                        "description": "Include row counts",
                        "default": False
                    }
                }
            }
        ),
        Tool(
            name="list_views",
            description="List all views in a schema with their descriptions",
            inputSchema={
                "type": "object",
                "properties": {
                    "schema_name": {
                        "type": "string",
                        "description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')"
                    }
                },
                "required": ["schema_name"]
            }
        ),
        Tool(
            name="describe_table",
            description="Get detailed schema information for a table or view",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Table name (can include database.schema.table)"
                    }
                },
                "required": ["table_name"]
            }
        ),
        # --- Synchronous querying ---
        Tool(
            name="read_data",
            description="Execute a SELECT query on Snowflake",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute"
                    },
                    "max_rows": {
                        "type": "integer",
                        "description": "Maximum rows to return",
                        "default": 50000
                    }
                },
                "required": ["query"]
            }
        ),
        Tool(
            name="get_system_health",
            description="Get system health and connection pool statistics",
            inputSchema={
                "type": "object",
                "properties": {}
            }
        ),
        Tool(
            name="describe_query",
            description="Preview query output columns without executing the full query. Useful for validating complex queries before running them.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to analyze"
                    }
                },
                "required": ["query"]
            }
        ),
        # --- Asynchronous querying (submit / poll / fetch) ---
        Tool(
            name="execute_async",
            description="Submit a query for asynchronous execution. Returns a query ID for status tracking. Use for long-running queries.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute asynchronously"
                    }
                },
                "required": ["query"]
            }
        ),
        Tool(
            name="get_query_status",
            description="Check the status of an async query by its query ID",
            inputSchema={
                "type": "object",
                "properties": {
                    "query_id": {
                        "type": "string",
                        "description": "Snowflake query ID from execute_async"
                    }
                },
                "required": ["query_id"]
            }
        ),
        Tool(
            name="get_async_results",
            description="Retrieve results from a completed async query",
            inputSchema={
                "type": "object",
                "properties": {
                    "query_id": {
                        "type": "string",
                        "description": "Snowflake query ID from execute_async"
                    },
                    "max_rows": {
                        "type": "integer",
                        "description": "Maximum rows to return",
                        "default": 50000
                    }
                },
                "required": ["query_id"]
            }
        ),
        Tool(
            name="list_async_queries",
            description="List all tracked async queries and their statuses",
            inputSchema={
                "type": "object",
                "properties": {}
            }
        ),
        # --- Alternative result shapes ---
        Tool(
            name="read_data_paginated",
            description="Execute a query with pagination support. Returns results page by page with total count.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute"
                    },
                    "page_size": {
                        "type": "integer",
                        "description": "Number of rows per page (1-10000)",
                        "default": 1000
                    },
                    "page": {
                        "type": "integer",
                        "description": "Page number (1-based)",
                        "default": 1
                    }
                },
                "required": ["query"]
            }
        ),
        Tool(
            name="read_data_pandas",
            description="Execute a query and return results in a pandas-friendly format. Supports multiple output formats: records, columns, split, csv.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute"
                    },
                    "max_rows": {
                        "type": "integer",
                        "description": "Maximum rows to return",
                        "default": 50000
                    },
                    "output_format": {
                        "type": "string",
                        "description": "Output format: 'records' (list of dicts), 'columns' (dict of arrays), 'split' (columns + data arrays), 'csv' (CSV string)",
                        "default": "records",
                        "enum": ["records", "columns", "split", "csv"]
                    }
                },
                "required": ["query"]
            }
        )
    ]
|
|
|
|
|
|
@server.call_tool()
async def call_tool(name: str, arguments: Any) -> List[TextContent]:
    """Dispatch an MCP tool call to its implementation and serialize the result.

    The result is returned as a single JSON text block. ``default=str``
    stringifies values JSON cannot encode natively, such as dates and
    decimals. Failures are logged with a traceback and re-raised unchanged,
    so the MCP server's call_tool wrapper can report them to the client.

    Raises:
        ValueError: for an unknown tool name.
        KeyError: when a required argument is missing.
        Any exception raised by the underlying tool implementation.
    """
    # Tolerate a missing/None arguments object for zero-argument tools.
    arguments = arguments or {}

    try:
        if name == "test_connection":
            result = await test_connection()

        elif name == "list_databases":
            result = await list_databases()

        elif name == "list_schemas":
            database_name = arguments.get("database_name")
            result = await list_schemas(database_name)

        elif name == "list_tables":
            schema = arguments.get("schema")
            include_counts = arguments.get("include_row_counts", False)
            result = await list_tables(schema, include_counts)

        elif name == "list_views":
            schema_name = arguments["schema_name"]
            result = await list_views(schema_name)

        elif name == "describe_table":
            table_name = arguments["table_name"]
            result = await describe_table(table_name)

        elif name == "read_data":
            query = arguments["query"]
            max_rows = arguments.get("max_rows", 50000)
            result = await read_data(query, max_rows)

        elif name == "get_system_health":
            result = await get_system_health()

        elif name == "describe_query":
            query = arguments["query"]
            result = await describe_query(query)

        elif name == "execute_async":
            query = arguments["query"]
            result = await execute_async(query)

        elif name == "get_query_status":
            query_id = arguments["query_id"]
            result = await get_query_status(query_id)

        elif name == "get_async_results":
            query_id = arguments["query_id"]
            max_rows = arguments.get("max_rows", 50000)
            result = await get_async_results(query_id, max_rows)

        elif name == "list_async_queries":
            result = await list_async_queries()

        elif name == "read_data_paginated":
            query = arguments["query"]
            page_size = arguments.get("page_size", 1000)
            page = arguments.get("page", 1)
            result = await read_data_paginated(query, page_size, page)

        elif name == "read_data_pandas":
            query = arguments["query"]
            max_rows = arguments.get("max_rows", 50000)
            output_format = arguments.get("output_format", "records")
            result = await read_data_pandas(query, max_rows, output_format)

        else:
            raise ValueError(f"Unknown tool: {name}")

        return [TextContent(
            type="text",
            text=json.dumps(result, indent=2, default=str)
        )]

    except Exception as e:
        logger.exception(f"Tool execution failed: {e}")
        # INTERNAL_ERROR in mcp.types is an integer JSON-RPC error code, not an
        # exception class. The previous `raise INTERNAL_ERROR(str(e))` therefore
        # failed with "'int' object is not callable" and hid the real error.
        raise
|
|
|
|
|
|
async def main():
    """Run the MCP server over stdio (stdin/stdout) using the module-level ``server``."""
    # Imported here rather than at module top; only needed when the server runs.
    from mcp.server.stdio import stdio_server

    logger.info("Starting Snowflake MCP Server with SSO Authentication")
    logger.info("Account: %s", SNOWFLAKE_CONFIG['account'])
    logger.info("User: %s", SNOWFLAKE_CONFIG['user'])
    logger.info("First connection will open browser for SSO login")

    async with stdio_server() as (read_stream, write_stream):
        init_options = server.create_initialization_options()
        await server.run(read_stream, write_stream, init_options)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve over stdio until interrupted or the client exits.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Treat Ctrl-C as a clean shutdown rather than a crash.
        logger.info("Server stopped by user")
    finally:
        # Always release the cached Snowflake session on the way out.
        connection.close()
|