Files
snowflake-mcp/snowflake_mcp_server.py
T
2026-03-16 10:59:27 +00:00

1343 lines
46 KiB
Python

#!/usr/bin/env python3
"""
Snowflake MCP Server with SSO Authentication
Provides read-only Snowflake access via the Model Context Protocol.
"""
import asyncio
import csv
import io
import json
import logging
import os
import re
from typing import Any, Optional, Dict, List
from datetime import datetime, timedelta
import hashlib
import snowflake.connector
from snowflake.connector import DictCursor
from snowflake.connector.constants import QueryStatus
from mcp.server import Server
from mcp.types import TextContent, Tool, INTERNAL_ERROR
# Configure logging: set the root logger to INFO and create the module-level
# logger that the rest of this file writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration — set via environment variables or edit the defaults below
# ---------------------------------------------------------------------------
# Connection settings. Each value is read from the environment once, at import
# time; "warehouse" and "role" stay None when unset so the account default is
# used.
SNOWFLAKE_CONFIG = dict(
    account=os.getenv("SNOWFLAKE_ACCOUNT", "YOUR_ACCOUNT"),
    user=os.getenv("SNOWFLAKE_USER", "your.email@example.com"),
    authenticator=os.getenv("SNOWFLAKE_AUTHENTICATOR", "externalbrowser"),
    warehouse=os.getenv("SNOWFLAKE_WAREHOUSE"),  # None = account default
    role=os.getenv("SNOWFLAKE_ROLE"),  # None = account default
)

# Logical database contexts exposed by the server, keyed by the name callers
# pass as the "database" argument.
DATABASES = dict(
    default=dict(
        database=os.getenv("SNOWFLAKE_DATABASE"),  # None = account default
        schema=os.getenv("SNOWFLAKE_SCHEMA"),  # None = account default
        description="Default database and schema",
    ),
)
class SnowflakeConnectionPool:
    """Cache of Snowflake connections, one per logical database name.

    A cached connection is reused while it was last used less than
    ``connection_timeout`` ago and still answers a ``SELECT 1`` probe.
    Otherwise it is closed and replaced by a fresh connection built from
    ``SNOWFLAKE_CONFIG`` and ``DATABASES``.
    """

    def __init__(self):
        # Live connections keyed by the logical database name.
        self.connections: Dict[str, snowflake.connector.SnowflakeConnection] = {}
        # Time each cached connection was last handed out or created.
        self.last_used: Dict[str, datetime] = {}
        # Idle period after which a cached connection is discarded.
        self.connection_timeout = timedelta(minutes=30)

    def _discard(self, cache_key: str) -> None:
        """Forget the cached connection for *cache_key* and close it (best effort)."""
        conn = self.connections.pop(cache_key, None)
        self.last_used.pop(cache_key, None)
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            # Closing a dead session may fail; that must not block reconnecting.
            logger.debug(f"Ignoring error while closing connection for {cache_key}: {e}")

    def get_connection(self, database: str = "default") -> snowflake.connector.SnowflakeConnection:
        """Return a usable connection for *database*, creating one if needed.

        Args:
            database: Key into ``DATABASES``; unknown keys fall back to the
                ``"default"`` entry's settings but are cached under their own key.

        Returns:
            An open Snowflake connection.

        Raises:
            Exception: Whatever ``snowflake.connector.connect`` raises; the
                error is logged and re-raised unchanged.
        """
        cache_key = database
        cached = self.connections.get(cache_key)
        if cached is not None:
            last = self.last_used.get(cache_key, datetime.now())
            if datetime.now() - last < self.connection_timeout:
                try:
                    # Probe the session; close the probe cursor even on failure.
                    cursor = cached.cursor()
                    try:
                        cursor.execute("SELECT 1")
                    finally:
                        cursor.close()
                    self.last_used[cache_key] = datetime.now()
                    logger.info(f"Reusing cached Snowflake connection for {database}")
                    return cached
                except Exception:
                    logger.info(f"Cached connection for {database} is stale, reconnecting...")
            else:
                logger.info(f"Cached connection for {database} timed out, reconnecting...")
            # Stale or timed out: close it so the old session is not leaked
            # when it is replaced below.
            self._discard(cache_key)

        # Create new connection
        logger.info(f"Creating new Snowflake connection for {database}")
        logger.info("Note: First connection will open browser for SSO login")
        db_config = DATABASES.get(database, DATABASES["default"])
        connect_params = {
            "account": SNOWFLAKE_CONFIG["account"],
            "user": SNOWFLAKE_CONFIG["user"],
            "authenticator": SNOWFLAKE_CONFIG["authenticator"],
        }
        # Optional parameters are only passed when configured, so that the
        # account defaults apply otherwise.
        if db_config.get("database"):
            connect_params["database"] = db_config["database"]
        if db_config.get("schema"):
            connect_params["schema"] = db_config["schema"]
        if SNOWFLAKE_CONFIG.get("warehouse"):
            connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"]
        if SNOWFLAKE_CONFIG.get("role"):
            connect_params["role"] = SNOWFLAKE_CONFIG["role"]
        try:
            conn = snowflake.connector.connect(**connect_params)
        except Exception as e:
            logger.error(f"Failed to connect to Snowflake: {e}")
            raise
        self.connections[cache_key] = conn
        self.last_used[cache_key] = datetime.now()
        logger.info(f"Successfully connected to Snowflake ({database})")
        return conn

    def close_all(self):
        """Close every cached connection and empty the pool."""
        for database, conn in self.connections.items():
            try:
                conn.close()
                logger.info(f"Closed Snowflake connection to {database}")
            except Exception as e:
                # Best effort: keep closing the remaining connections.
                logger.warning(f"Error closing Snowflake connection to {database}: {e}")
        self.connections.clear()
        self.last_used.clear()
# Global connection pool shared by all tool functions in this module
connection_pool = SnowflakeConnectionPool()
# Query cache: get_cache_key() digest -> (cached value, time it was stored)
query_cache: Dict[str, tuple] = {}
# Cache entries older than this are treated as expired by the readers
CACHE_TTL = timedelta(minutes=5)
# Async query tracking: Snowflake query ID -> info recorded when it was submitted
async_queries: Dict[str, Dict[str, Any]] = {}
def get_cache_key(query: str, database: str) -> str:
    """Return the hex MD5 digest of ``"<database>:<query>"`` for use as a cache key."""
    raw_key = f"{database}:{query}"
    digest = hashlib.md5(raw_key.encode())
    return digest.hexdigest()
def validate_query(query: str) -> tuple[bool, str]:
    """Check that *query* looks like a read-only statement.

    A query is accepted when it starts with an allowed read keyword (or at
    least contains SELECT) and contains none of the forbidden write/DDL
    keywords as a standalone word. String literals, quoted identifiers and
    comments are removed before the keyword scan, so values such as
    ``'DELETE'`` do not trigger a rejection, while keywords separated by
    newlines, tabs, semicolons or comments are still detected.

    This is a best-effort textual filter, not a SQL parser.

    Args:
        query: SQL text supplied by the caller.

    Returns:
        ``(True, message)`` when the query is accepted, otherwise
        ``(False, reason)``.
    """
    query_upper = query.upper().strip()
    # Allow SELECT, WITH (CTEs), and SHOW/DESCRIBE commands
    allowed_starts = ('SELECT', 'WITH', 'SHOW', 'DESCRIBE', 'DESC', 'LIST')
    if not query_upper.startswith(allowed_starts) and 'SELECT' not in query_upper:
        return False, "Only SELECT queries are allowed"

    # Blank out everything that cannot contain an executable keyword. One
    # alternation is used so the leftmost construct wins, the way a lexer
    # would read the text.
    literals_and_comments = (
        r"'(?:[^'\\]|\\.|'')*'"   # single-quoted string constant
        r'|"(?:[^"]|"")*"'        # double-quoted identifier
        r"|\$\$.*?\$\$"           # dollar-quoted string constant
        r"|--[^\n]*"              # line comment
        r"|/\*.*?\*/"             # block comment
    )
    scannable = re.sub(literals_and_comments, ' ', query_upper, flags=re.DOTALL)

    # Block dangerous keywords
    dangerous = ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'TRUNCATE',
                 'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE']
    for keyword in dangerous:
        # \b matches the keyword as a whole word regardless of the separator,
        # while names such as LAST_UPDATE or CREATED_AT do not match.
        if re.search(rf'\b{keyword}\b', scannable):
            return False, f"Query contains forbidden keyword: {keyword}"
    return True, "Query validated successfully"
async def test_connection(database: str = "default") -> Dict[str, Any]:
    """Test the Snowflake connection and report the current session context.

    Args:
        database: Logical database name passed to the connection pool.

    Returns:
        On success, a dict with ``status == "connected"`` and the session's
        account, user, role, warehouse, database, schema and version. On any
        failure, a dict with ``status == "error"`` and the error text; this
        function does not raise.
    """
    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            # Get session information
            cursor.execute("""
                SELECT
                    CURRENT_ACCOUNT() as account,
                    CURRENT_USER() as user_name,
                    CURRENT_ROLE() as current_role,
                    CURRENT_WAREHOUSE() as warehouse,
                    CURRENT_DATABASE() as database_name,
                    CURRENT_SCHEMA() as schema_name,
                    CURRENT_VERSION() as version
            """)
            row = cursor.fetchone()
        finally:
            # Release the cursor even when the query fails.
            cursor.close()
        return {
            "status": "connected",
            "account": row["ACCOUNT"],
            "user": row["USER_NAME"],
            "role": row["CURRENT_ROLE"],
            "warehouse": row["WAREHOUSE"],
            "database": row["DATABASE_NAME"],
            "schema": row["SCHEMA_NAME"],
            "version": row["VERSION"],
            "auth_method": SNOWFLAKE_CONFIG["authenticator"],
            "message": "Connection successful"
        }
    except Exception as e:
        logger.error(f"Connection test failed: {e}")
        return {
            "status": "error",
            "error": str(e),
            "message": "If this is your first connection, a browser window should open for SSO login"
        }
async def list_tables(database: str = "default", schema: Optional[str] = None,
                      include_row_counts: bool = False) -> Dict[str, Any]:
    """List the tables in a schema, or in the session's current schema.

    Args:
        database: Logical database name passed to the connection pool.
        schema: Optional schema identifier (``SCHEMA`` or ``DB.SCHEMA``,
            parts optionally double-quoted). When omitted, ``SHOW TABLES``
            runs without a schema clause.
        include_row_counts: When True, each entry carries the ``rows`` value
            reported by ``SHOW TABLES``; otherwise ``rows`` is None.

    Returns:
        ``{"tables": [...], "count": n}``, plus a ``"hint"`` entry when no
        tables were found.

    Raises:
        ValueError: If *schema* is not a plain or double-quoted identifier path.
        Exception: Any connector error, logged and re-raised.
    """
    try:
        if schema:
            # The schema is spliced into the statement text, so only accept
            # identifier paths (unquoted words or double-quoted names).
            ident = r'(?:[^\W\d][\w$]*|"(?:[^"]|"")+")'
            schema_ref = schema.strip()
            if not re.fullmatch(rf'{ident}(?:\.{ident})*', schema_ref):
                raise ValueError(f"Invalid schema identifier: {schema!r}")
            query = f"SHOW TABLES IN SCHEMA {schema_ref}"
        else:
            query = "SHOW TABLES"

        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute(query)
            tables = []
            for row in cursor:
                created_on = row.get("created_on")
                tables.append({
                    "database": row.get("database_name"),
                    "schema": row.get("schema_name"),
                    "name": row.get("name"),
                    "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                    "description": row.get("comment"),
                    "kind": row.get("kind"),
                    "created": str(created_on) if created_on else None,
                    "rows": row.get("rows") if include_row_counts else None
                })
        finally:
            # Release the cursor even when the statement fails.
            cursor.close()

        # Return with helpful hint if no tables found
        if not tables:
            return {
                "tables": [],
                "count": 0,
                "hint": f"No tables found in {schema or 'current schema'}. This schema may contain VIEWS instead of tables. Use list_views to check for views."
            }
        return {"tables": tables, "count": len(tables)}
    except Exception as e:
        logger.error(f"Failed to list tables: {e}")
        raise
async def list_views(schema_name: str, database: str = "default") -> List[Dict[str, Any]]:
    """List the views in a schema together with their comments.

    Args:
        schema_name: Schema identifier (``SCHEMA`` or ``DB.SCHEMA``, parts
            optionally double-quoted).
        database: Logical database name passed to the connection pool.

    Returns:
        One dict per view with name, database, schema, full_name, description,
        created, is_secure and is_materialized.

    Raises:
        ValueError: If *schema_name* is not a plain or double-quoted
            identifier path.
        Exception: Any connector error, logged and re-raised.
    """
    try:
        # The schema is spliced into the statement text, so only accept
        # identifier paths (unquoted words or double-quoted names).
        ident = r'(?:[^\W\d][\w$]*|"(?:[^"]|"")+")'
        schema_ref = schema_name.strip()
        if not re.fullmatch(rf'{ident}(?:\.{ident})*', schema_ref):
            raise ValueError(f"Invalid schema identifier: {schema_name!r}")

        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_ref}")
            views = []
            for row in cursor:
                created_on = row.get("created_on")
                views.append({
                    "name": row.get("name"),
                    "database": row.get("database_name"),
                    "schema": row.get("schema_name"),
                    "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                    "description": row.get("comment"),
                    "created": str(created_on) if created_on else None,
                    "is_secure": row.get("is_secure") == "true",
                    "is_materialized": row.get("is_materialized") == "true"
                })
        finally:
            # Release the cursor even when the statement fails.
            cursor.close()
        return views
    except Exception as e:
        logger.error(f"Failed to list views in {schema_name}: {e}")
        raise
async def describe_table(table_name: str, database: str = "default") -> Dict[str, Any]:
    """Return column metadata for a table or view via ``SHOW COLUMNS``.

    The name is tried as given first; if Snowflake reports that it does not
    exist, it is retried with the last component double-quoted, to reach
    objects created with case-sensitive names.

    Args:
        table_name: ``table``, ``schema.table`` or ``database.schema.table``.
        database: Logical database name passed to the connection pool.

    Returns:
        On success, a dict with the parsed name parts, the column list,
        ``column_count``, and ``quoted_name``/``hint`` when the quoted form
        was the one that worked. On failure, a dict with ``error``,
        ``table_name`` and ``suggestions``; this function does not raise.
    """
    try:
        conn = connection_pool.get_connection(database)

        # Split the name once and derive the variant with the last component
        # quoted for case-sensitive objects.
        parts = table_name.split('.')
        if len(parts) == 3:
            db, schema, table = parts
            quoted_name = f'{db}.{schema}."{table}"'
        elif len(parts) == 2:
            db = None
            schema, table = parts
            quoted_name = f'{schema}."{table}"'
        else:
            db = None
            schema = None
            table = table_name
            quoted_name = f'"{table_name}"'

        cursor = conn.cursor()
        try:
            # Try unquoted first, then quoted if that fails
            rows = None
            used_name = table_name
            for try_name in (table_name, quoted_name):
                try:
                    cursor.execute(f"SHOW COLUMNS IN TABLE {try_name}")
                    rows = cursor.fetchall()
                    used_name = try_name
                    break
                except Exception as inner_e:
                    if "does not exist" in str(inner_e).lower() and try_name == table_name:
                        continue  # Try quoted version
                    raise
            if rows is None:
                rows = []
            # Column names from the result metadata; attribute access first,
            # tuple index as a fallback.
            desc_columns = [
                str(getattr(desc, 'name', None) or desc[0]).upper()
                for desc in (cursor.description or [])
            ]
        finally:
            # Release the cursor even when both attempts fail.
            cursor.close()

        columns = []
        for row in rows:
            row_dict = dict(zip(desc_columns, row))
            # SHOW COLUMNS labels these columns "null?" and "default"; the
            # INFORMATION_SCHEMA-style names are still checked first.
            nullable_raw = row_dict.get("IS_NULLABLE")
            if nullable_raw is None:
                nullable_raw = row_dict.get("NULL?")
            default_raw = row_dict.get("COLUMN_DEFAULT")
            if default_raw is None:
                default_raw = row_dict.get("DEFAULT")
            columns.append({
                "name": row_dict.get("COLUMN_NAME"),
                "type": row_dict.get("DATA_TYPE"),
                "nullable": str(nullable_raw).strip().upper() in ("YES", "Y", "TRUE"),
                "default": default_raw,
                "comment": row_dict.get("COMMENT"),
                "kind": row_dict.get("KIND")
            })

        was_quoted = used_name != table_name
        return {
            "database": db,
            "schema": schema,
            "table": table,
            "full_name": table_name,
            "quoted_name": used_name if was_quoted else None,
            "columns": columns,
            "column_count": len(columns),
            "hint": "Use the quoted_name in queries if provided, as this object has case-sensitive naming." if was_quoted else None
        }
    except Exception as e:
        logger.error(f"Failed to describe table {table_name}: {e}")
        error_str = str(e)
        suggestions = []
        if "does not exist" in error_str.lower():
            suggestions.append("The object may be a VIEW. Use list_views to check available views in the schema.")
            suggestions.append("Object names are case-sensitive. The name may need to be quoted.")
        elif "not authorized" in error_str.lower():
            suggestions.append("Check that your role has SELECT privileges on this object.")
        return {
            "error": error_str,
            "table_name": table_name,
            "suggestions": suggestions
        }
async def read_data(query: str, database: str = "default",
                    max_rows: int = 50000) -> Dict[str, Any]:
    """Execute a validated read-only query and return its rows.

    Results are cached in ``query_cache`` for ``CACHE_TTL``. The cache key
    includes ``max_rows`` so a result truncated at a small limit is never
    returned for a call that asked for more rows.

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Logical database name passed to the connection pool.
        max_rows: Maximum number of rows to return.

    Returns:
        A dict with ``columns``, ``rows``, ``row_count``, ``truncated`` (True
        only when rows beyond ``max_rows`` were dropped) and ``database``;
        cached hits add ``cached`` and ``cache_age_seconds``. Validation and
        execution failures return a dict with ``error``; this function does
        not raise.
    """
    # Validate query
    is_valid, message = validate_query(query)
    if not is_valid:
        return {
            "error": message,
            "suggestion": "Use SELECT statements only. For schema information, use describe_table."
        }

    # Check cache
    cache_key = get_cache_key(f"max_rows={max_rows}:{query}", database)
    cached_entry = query_cache.get(cache_key)
    if cached_entry is not None:
        cached_result, cached_time = cached_entry
        cache_age = datetime.now() - cached_time
        if cache_age < CACHE_TTL:
            logger.info("Returning cached result for query")
            return {
                **cached_result,
                "cached": True,
                "cache_age_seconds": cache_age.total_seconds()
            }
        # Expired: drop it so a failed re-run does not keep it around.
        query_cache.pop(cache_key, None)

    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute(query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            rows = []
            truncated = False
            for row in cursor:
                if len(rows) >= max_rows:
                    # A further row exists beyond the limit.
                    truncated = True
                    logger.warning(f"Query returned more than {max_rows} rows, truncating")
                    break
                # datetime and bytes values are stringified for JSON output.
                rows.append({
                    col: str(row[col]) if isinstance(row[col], (datetime, bytes)) else row[col]
                    for col in columns
                })
        finally:
            # Release the cursor even when execution or fetching fails.
            cursor.close()

        result = {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated,
            "database": database
        }
        # Cache the result
        query_cache[cache_key] = (result, datetime.now())
        return result
    except Exception as e:
        logger.error(f"Query execution failed: {e}")
        error_str = str(e)
        suggestions = []
        if "does not exist" in error_str.lower() or "not authorized" in error_str.lower():
            suggestions.append("Object names in Snowflake are case-sensitive when created with quotes. Try wrapping table/view names in double quotes.")
            suggestions.append("Use list_views to check if the object is a view rather than a table.")
        else:
            suggestions.append("Check your query syntax. Use describe_table to see available columns.")
        return {
            "error": error_str,
            "suggestions": suggestions
        }
async def list_databases() -> List[Dict[str, Any]]:
    """List the databases visible to the session via ``SHOW DATABASES``.

    Returns:
        One dict per database with name, created, owner, comment and origin.

    Raises:
        Exception: Any connector error, logged and re-raised.
    """
    try:
        conn = connection_pool.get_connection("default")
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute("SHOW DATABASES")
            databases = []
            for row in cursor:
                created_on = row.get("created_on")
                databases.append({
                    "name": row.get("name"),
                    "created": str(created_on) if created_on else None,
                    "owner": row.get("owner"),
                    "comment": row.get("comment"),
                    "origin": row.get("origin")
                })
        finally:
            # Release the cursor even when the statement fails.
            cursor.close()
        return databases
    except Exception as e:
        logger.error(f"Failed to list databases: {e}")
        raise
async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
    """List the schemas in a database, or in the session's current database.

    Args:
        database_name: Optional database identifier (optionally
            double-quoted). When omitted, ``SHOW SCHEMAS`` runs without a
            database clause.

    Returns:
        One dict per schema with name, database, created and owner.

    Raises:
        ValueError: If *database_name* is not a plain or double-quoted
            identifier.
        Exception: Any connector error, logged and re-raised.
    """
    try:
        if database_name:
            # The name is spliced into the statement text, so only accept a
            # single identifier (unquoted word or double-quoted name).
            db_ref = database_name.strip()
            if not re.fullmatch(r'[^\W\d][\w$]*|"(?:[^"]|"")+"', db_ref):
                raise ValueError(f"Invalid database identifier: {database_name!r}")
            statement = f"SHOW SCHEMAS IN DATABASE {db_ref}"
        else:
            statement = "SHOW SCHEMAS"

        conn = connection_pool.get_connection("default")
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute(statement)
            schemas = []
            for row in cursor:
                created_on = row.get("created_on")
                schemas.append({
                    "name": row.get("name"),
                    "database": row.get("database_name"),
                    "created": str(created_on) if created_on else None,
                    "owner": row.get("owner")
                })
        finally:
            # Release the cursor even when the statement fails.
            cursor.close()
        return schemas
    except Exception as e:
        logger.error(f"Failed to list schemas: {e}")
        raise
async def get_system_health() -> Dict[str, Any]:
    """Report pool, cache and authentication state plus a live connection test.

    Returns:
        A dict with ``status`` (``"healthy"`` or ``"degraded"``),
        ``timestamp``, ``connection_pool``, ``cache``, ``authentication`` and
        ``connection`` (the result of ``test_connection("default")``).
        ``status`` is ``"degraded"`` whenever the connection test did not
        report ``"connected"``.
    """
    health_info = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "connection_pool": {
            "active_connections": len(connection_pool.connections),
            "databases": list(connection_pool.connections.keys()),
            "last_used": {
                db: last.isoformat()
                for db, last in connection_pool.last_used.items()
            }
        },
        "cache": {
            "query_cache_size": len(query_cache),
            "cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
        },
        "authentication": {
            "method": SNOWFLAKE_CONFIG["authenticator"],
            "account": SNOWFLAKE_CONFIG["account"],
            "user": SNOWFLAKE_CONFIG["user"],
        }
    }
    # Test connection
    try:
        test_result = await test_connection("default")
        health_info["connection"] = test_result
        # test_connection reports failures in its return value instead of
        # raising, so the status has to be derived from that value.
        if test_result.get("status") != "connected":
            health_info["status"] = "degraded"
    except Exception:
        health_info["status"] = "degraded"
        health_info["connection"] = {"status": "error"}
    return health_info
async def describe_query(query: str, database: str = "default") -> Dict[str, Any]:
    """Preview a query's output columns by running it with ``LIMIT 0``.

    Any trailing ``LIMIT``/``OFFSET`` clause is removed and ``LIMIT 0`` is
    appended; if that statement fails, the query is wrapped in a subquery
    and retried.

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Logical database name passed to the connection pool.

    Returns:
        A dict with ``columns`` (the first seven fields of each cursor
        description entry), ``column_count``, ``database`` and
        ``query_preview``; or a dict with ``error`` on validation or
        execution failure. This function does not raise.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}
    try:
        conn = connection_pool.get_connection(database)
        clean_query = query.rstrip(';').strip()
        # Remove existing LIMIT and OFFSET clauses (case insensitive)
        clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
        clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
        # Add LIMIT 0 to get column metadata without fetching data
        limited_query = f"{clean_query} LIMIT 0"

        cursor = conn.cursor(DictCursor)
        try:
            try:
                cursor.execute(limited_query)
            except Exception as first_error:
                # Fallback: wrap in subquery if direct approach fails
                logger.debug(f"Direct LIMIT 0 failed, retrying as subquery: {first_error}")
                limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0"
                cursor.execute(limited_query)
            columns = [
                {
                    "name": desc[0],
                    "type_code": desc[1],
                    "display_size": desc[2],
                    "internal_size": desc[3],
                    "precision": desc[4],
                    "scale": desc[5],
                    "nullable": desc[6]
                }
                for desc in (cursor.description or [])
            ]
        finally:
            # Release the cursor even when both attempts fail.
            cursor.close()

        return {
            "columns": columns,
            "column_count": len(columns),
            "database": database,
            "query_preview": query[:200] + "..." if len(query) > 200 else query
        }
    except Exception as e:
        logger.error(f"Failed to describe query: {e}")
        return {
            "error": str(e),
            "suggestion": "Check query syntax. The query must be a valid SELECT statement."
        }
async def execute_async(query: str, database: str = "default") -> Dict[str, Any]:
    """Submit a validated query for asynchronous execution.

    The query ID is recorded in ``async_queries`` with the first 500
    characters of the query text so it can be listed and tracked later.

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Logical database name passed to the connection pool.

    Returns:
        A dict with ``query_id``, ``status == "SUBMITTED"``, ``message`` and
        ``database``; or a dict with ``error`` on validation or submission
        failure. This function does not raise.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}
    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor()
        try:
            # Execute asynchronously
            cursor.execute_async(query)
            query_id = cursor.sfqid
        finally:
            # Release the cursor even when submission fails.
            cursor.close()

        # Store query info for tracking
        async_queries[query_id] = {
            "query": query[:500],
            "database": database,
            "submitted_at": datetime.now().isoformat(),
            "status": "RUNNING"
        }
        return {
            "query_id": query_id,
            "status": "SUBMITTED",
            "message": "Query submitted for async execution. Use get_query_status to check progress.",
            "database": database
        }
    except Exception as e:
        logger.error(f"Failed to submit async query: {e}")
        return {"error": str(e)}
async def get_query_status(query_id: str, database: str = "default") -> Dict[str, Any]:
    """Look up the current state of a previously submitted async query.

    Returns a dict with the status name, running/error flags and whether
    results can be fetched. When the ID is tracked in ``async_queries`` its
    entry is updated with the new status and attached as ``query_info``.
    Failures are returned as a dict with ``error`` and ``query_id``.
    """
    try:
        conn = connection_pool.get_connection(database)
        status = conn.get_query_status(query_id)
        try:
            status_name = status.name
        except AttributeError:
            status_name = str(status)
        still_running = conn.is_still_running(status)
        failed = conn.is_an_error(status)

        report = {
            "query_id": query_id,
            "status": status_name,
            "is_running": still_running,
            "is_error": failed,
            "can_fetch_results": not (still_running or failed)
        }

        # Refresh the tracked entry, then hand it back with the report.
        if query_id in async_queries:
            tracked = async_queries[query_id]
            tracked["status"] = status_name
            report["query_info"] = tracked
        return report
    except Exception as e:
        logger.error(f"Failed to get query status: {e}")
        return {"error": str(e), "query_id": query_id}
async def get_async_results(query_id: str, database: str = "default",
                            max_rows: int = 50000) -> Dict[str, Any]:
    """Fetch the rows of a finished async query.

    Args:
        query_id: Snowflake query ID returned by ``execute_async``.
        database: Logical database name passed to the connection pool.
        max_rows: Maximum number of rows to return.

    Returns:
        A dict with ``query_id``, ``columns``, ``rows``, ``row_count`` and
        ``truncated`` (True only when more rows exist beyond ``max_rows``).
        If the query is still running, or on any failure, a dict with
        ``error`` and ``query_id``. This function does not raise.
    """
    try:
        conn = connection_pool.get_connection(database)
        # Check status first
        status = conn.get_query_status_throw_if_error(query_id)
        if conn.is_still_running(status):
            return {
                "error": "Query is still running",
                "query_id": query_id,
                "status": status.name if hasattr(status, 'name') else str(status)
            }

        # Use standard cursor for get_results_from_sfqid
        cursor = conn.cursor()
        try:
            cursor.get_results_from_sfqid(query_id)
            # Fetch results first (description may not populate until after fetch)
            raw_rows = cursor.fetchmany(max_rows)
            # Now get column names from description
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            # Only a full batch can hide further rows; probe for one more so
            # a result of exactly max_rows rows is not reported as truncated.
            truncated = len(raw_rows) >= max_rows and cursor.fetchone() is not None
        finally:
            # Release the cursor even when fetching fails.
            cursor.close()

        # Convert to list of dicts; datetime and bytes values are stringified.
        rows = []
        for row in raw_rows:
            row_dict = {}
            for idx, col in enumerate(columns):
                value = row[idx] if idx < len(row) else None
                row_dict[col] = str(value) if isinstance(value, (datetime, bytes)) else value
            rows.append(row_dict)

        # Clean up tracking
        async_queries.pop(query_id, None)
        return {
            "query_id": query_id,
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated
        }
    except Exception as e:
        logger.error(f"Failed to get async results: {e}")
        return {"error": str(e), "query_id": query_id}
async def read_data_paginated(query: str, database: str = "default",
                              page_size: int = 1000, page: int = 1) -> Dict[str, Any]:
    """Execute a validated query one page at a time.

    Any trailing ``LIMIT``/``OFFSET`` is removed from the query, which is
    then wrapped in a subquery with this page's ``LIMIT``/``OFFSET``. The
    total row count comes from a separate ``COUNT(*)`` over the same
    subquery and is cached in ``query_cache`` for ``CACHE_TTL``.

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Logical database name passed to the connection pool.
        page_size: Rows per page, 1 to 10000.
        page: 1-based page number.

    Returns:
        A dict with ``columns``, ``rows``, ``pagination`` and ``database``;
        or a dict with ``error`` for invalid arguments or execution
        failures. This function does not raise.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}
    if page < 1:
        return {"error": "Page must be >= 1"}
    if page_size < 1 or page_size > 10000:
        return {"error": "Page size must be between 1 and 10000"}
    offset = (page - 1) * page_size

    # Strip existing LIMIT/OFFSET from query for proper pagination
    clean_query = query.rstrip(';').strip()
    clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
    clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
    # Wrap query with pagination
    paginated_query = f"""
        SELECT * FROM (
            {clean_query}
        ) AS paginated_result
        LIMIT {page_size} OFFSET {offset}
    """

    # Also get total count (cached for efficiency)
    count_cache_key = get_cache_key(f"COUNT:{clean_query}", database)
    total_count = None
    if count_cache_key in query_cache:
        cached_count, cached_time = query_cache[count_cache_key]
        if datetime.now() - cached_time < CACHE_TTL:
            total_count = cached_count

    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            # Get total count if not cached
            if total_count is None:
                count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq"
                cursor.execute(count_query)
                count_row = cursor.fetchone()
                total_count = count_row["TOTAL"] if count_row else 0
                query_cache[count_cache_key] = (total_count, datetime.now())

            # Execute paginated query
            cursor.execute(paginated_query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            # datetime and bytes values are stringified for JSON output.
            rows = [
                {
                    col: str(row[col]) if isinstance(row[col], (datetime, bytes)) else row[col]
                    for col in columns
                }
                for row in cursor
            ]
        finally:
            # Release the cursor even when either statement fails.
            cursor.close()

        total_pages = (total_count + page_size - 1) // page_size if total_count else 0
        return {
            "columns": columns,
            "rows": rows,
            "pagination": {
                "page": page,
                "page_size": page_size,
                "total_rows": total_count,
                "total_pages": total_pages,
                "has_next": page < total_pages,
                "has_previous": page > 1
            },
            "database": database
        }
    except Exception as e:
        logger.error(f"Paginated query failed: {e}")
        return {"error": str(e)}
async def read_data_pandas(query: str, database: str = "default",
                           max_rows: int = 50000,
                           output_format: str = "records") -> Dict[str, Any]:
    """Execute a validated query and return results in a pandas-friendly format.

    output_format options:
    - 'records': List of dicts (default, same as read_data)
    - 'columns': Dict with column arrays {col1: [vals], col2: [vals]}
    - 'split': Dict with columns and data arrays {columns: [...], data: [[...]]}
    - 'csv': CSV-formatted string

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Logical database name passed to the connection pool.
        max_rows: Maximum number of rows to return.
        output_format: One of the formats listed above.

    Returns:
        A dict with ``format``, ``columns``, ``data``, ``row_count``,
        ``truncated`` (True only when rows beyond ``max_rows`` were dropped),
        ``database`` and ``pandas_hint``; or a dict with ``error`` for
        invalid arguments or execution failures. This function does not raise.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}
    if output_format not in ('records', 'columns', 'split', 'csv'):
        return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}
    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute(query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            all_rows = []
            truncated = False
            for row in cursor:
                if len(all_rows) >= max_rows:
                    # A further row exists beyond the limit.
                    truncated = True
                    break
                row_data = {}
                for col in columns:
                    value = row[col]
                    if isinstance(value, datetime):
                        row_data[col] = value.isoformat()
                    elif isinstance(value, bytes):
                        row_data[col] = value.decode('utf-8', errors='replace')
                    else:
                        row_data[col] = value
                all_rows.append(row_data)
        finally:
            # Release the cursor even when execution or fetching fails.
            cursor.close()

        # Format output based on requested format
        if output_format == 'records':
            data = all_rows
        elif output_format == 'columns':
            data = {col: [row.get(col) for row in all_rows] for col in columns}
        elif output_format == 'split':
            data = {
                "columns": columns,
                "data": [[row.get(col) for col in columns] for row in all_rows]
            }
        else:  # 'csv' (validated above)
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=columns)
            writer.writeheader()
            writer.writerows(all_rows)
            data = output.getvalue()

        return {
            "format": output_format,
            "columns": columns,
            "data": data,
            "row_count": len(all_rows),
            "truncated": truncated,
            "database": database,
            "pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
        }
    except Exception as e:
        logger.error(f"Pandas query failed: {e}")
        return {"error": str(e)}
async def list_async_queries() -> Dict[str, Any]:
    """Return the tracked async queries and how many there are."""
    summary: Dict[str, Any] = {}
    summary["queries"] = async_queries
    summary["count"] = len(async_queries)
    return summary
# Create the MCP server instance; the list_tools and call_tool handlers below
# register themselves on it through its decorators.
server = Server("snowflake-mcp")
@server.list_tools()
async def list_tools() -> List[Tool]:
    """Describe every tool this server exposes, with its JSON input schema."""

    def text_prop(description, default=None):
        # String-typed property; "default" is only present when one is given.
        prop = {"type": "string", "description": description}
        if default is not None:
            prop["default"] = default
        return prop

    def int_prop(description, default):
        return {"type": "integer", "description": description, "default": default}

    def db_prop(description="Database context"):
        return text_prop(description, "default")

    def max_rows_prop():
        return int_prop("Maximum rows to return", 50000)

    def query_id_prop():
        return text_prop("Snowflake query ID from execute_async")

    def make_tool(name, description, properties=None, required=None):
        # "required" is only emitted for tools that have mandatory arguments.
        input_schema = {"type": "object", "properties": properties or {}}
        if required:
            input_schema["required"] = list(required)
        return Tool(name=name, description=description, inputSchema=input_schema)

    return [
        make_tool(
            "test_connection",
            "Test Snowflake connection and get session information",
            {"database": db_prop("Database context (default: 'default')")},
        ),
        make_tool(
            "list_databases",
            "List all accessible databases in Snowflake",
        ),
        make_tool(
            "list_schemas",
            "List all schemas in a database",
            {"database_name": text_prop("Database name (optional, uses current if not specified)")},
        ),
        make_tool(
            "list_tables",
            "List all tables in the current database/schema",
            {
                "database": db_prop(),
                "schema": text_prop("Schema filter (optional)"),
                "include_row_counts": {
                    "type": "boolean",
                    "description": "Include row counts",
                    "default": False,
                },
            },
        ),
        make_tool(
            "list_views",
            "List all views in a schema with their descriptions",
            {
                "schema_name": text_prop("Schema name (e.g., 'MY_DB.MY_SCHEMA')"),
                "database": db_prop(),
            },
            required=["schema_name"],
        ),
        make_tool(
            "describe_table",
            "Get detailed schema information for a table",
            {
                "table_name": text_prop("Table name (can include database.schema.table)"),
                "database": db_prop(),
            },
            required=["table_name"],
        ),
        make_tool(
            "read_data",
            "Execute a SELECT query on Snowflake",
            {
                "query": text_prop("SQL SELECT query to execute"),
                "database": db_prop(),
                "max_rows": max_rows_prop(),
            },
            required=["query"],
        ),
        make_tool(
            "get_system_health",
            "Get system health and connection pool statistics",
        ),
        make_tool(
            "describe_query",
            "Preview query output columns without executing the full query. Useful for validating complex queries before running them.",
            {
                "query": text_prop("SQL SELECT query to analyze"),
                "database": db_prop(),
            },
            required=["query"],
        ),
        make_tool(
            "execute_async",
            "Submit a query for asynchronous execution. Returns a query ID for status tracking. Use for long-running queries.",
            {
                "query": text_prop("SQL SELECT query to execute asynchronously"),
                "database": db_prop(),
            },
            required=["query"],
        ),
        make_tool(
            "get_query_status",
            "Check the status of an async query by its query ID",
            {
                "query_id": query_id_prop(),
                "database": db_prop(),
            },
            required=["query_id"],
        ),
        make_tool(
            "get_async_results",
            "Retrieve results from a completed async query",
            {
                "query_id": query_id_prop(),
                "database": db_prop(),
                "max_rows": max_rows_prop(),
            },
            required=["query_id"],
        ),
        make_tool(
            "list_async_queries",
            "List all tracked async queries and their statuses",
        ),
        make_tool(
            "read_data_paginated",
            "Execute a query with pagination support. Returns results page by page with total count.",
            {
                "query": text_prop("SQL SELECT query to execute"),
                "database": db_prop(),
                "page_size": int_prop("Number of rows per page (1-10000)", 1000),
                "page": int_prop("Page number (1-based)", 1),
            },
            required=["query"],
        ),
        make_tool(
            "read_data_pandas",
            "Execute a query and return results in a pandas-friendly format. Supports multiple output formats: records, columns, split, csv.",
            {
                "query": text_prop("SQL SELECT query to execute"),
                "database": db_prop(),
                "max_rows": max_rows_prop(),
                "output_format": {
                    "type": "string",
                    "description": "Output format: 'records' (list of dicts), 'columns' (dict of arrays), 'split' (columns + data arrays), 'csv' (CSV string)",
                    "default": "records",
                    "enum": ["records", "columns", "split", "csv"],
                },
            },
            required=["query"],
        ),
    ]
@server.call_tool()
async def call_tool(name: str, arguments: Any) -> List[TextContent]:
    """Dispatch an MCP tool call to the matching function in this module.

    Args:
        name: Tool name as advertised by ``list_tools``.
        arguments: Mapping of tool arguments; ``None`` is treated as empty.

    Returns:
        A single ``TextContent`` holding the tool result as indented JSON
        (values JSON cannot encode natively are stringified).

    Raises:
        ValueError: If *name* is not a known tool.
        KeyError: If a required argument is missing.
        Exception: Any error raised by the tool itself. Errors are logged and
            re-raised unchanged so the caller sees the real cause.
    """
    args = arguments or {}

    def db() -> str:
        return args.get("database", "default")

    # Each entry is a zero-argument callable, so arguments are only read for
    # the tool that was actually requested.
    handlers = {
        "test_connection": lambda: test_connection(db()),
        "list_databases": lambda: list_databases(),
        "list_schemas": lambda: list_schemas(args.get("database_name")),
        "list_tables": lambda: list_tables(
            db(), args.get("schema"), args.get("include_row_counts", False)),
        "list_views": lambda: list_views(args["schema_name"], db()),
        "describe_table": lambda: describe_table(args["table_name"], db()),
        "read_data": lambda: read_data(
            args["query"], db(), args.get("max_rows", 50000)),
        "get_system_health": lambda: get_system_health(),
        "describe_query": lambda: describe_query(args["query"], db()),
        "execute_async": lambda: execute_async(args["query"], db()),
        "get_query_status": lambda: get_query_status(args["query_id"], db()),
        "get_async_results": lambda: get_async_results(
            args["query_id"], db(), args.get("max_rows", 50000)),
        "list_async_queries": lambda: list_async_queries(),
        "read_data_paginated": lambda: read_data_paginated(
            args["query"], db(), args.get("page_size", 1000), args.get("page", 1)),
        "read_data_pandas": lambda: read_data_pandas(
            args["query"], db(), args.get("max_rows", 50000),
            args.get("output_format", "records")),
    }

    try:
        handler = handlers.get(name)
        if handler is None:
            raise ValueError(f"Unknown tool: {name}")
        result = await handler()
        return [TextContent(
            type="text",
            text=json.dumps(result, indent=2, default=str)
        )]
    except Exception as e:
        logger.error(f"Tool execution failed: {e}")
        # INTERNAL_ERROR is an integer error code, not an exception class;
        # calling it would replace the real failure with a TypeError.
        raise
async def main():
    """Log the startup banner and serve MCP requests over stdio."""
    from mcp.server.stdio import stdio_server

    logger.info("Starting Snowflake MCP Server with SSO Authentication")
    logger.info(f"Account: {SNOWFLAKE_CONFIG['account']}")
    logger.info(f"User: {SNOWFLAKE_CONFIG['user']}")
    logger.info("First connection will open browser for SSO login")

    async with stdio_server() as streams:
        incoming, outgoing = streams
        init_options = server.create_initialization_options()
        await server.run(incoming, outgoing, init_options)
# Script entry point: run the server until it exits or is interrupted.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is logged as a normal stop.
        logger.info("Server stopped by user")
    finally:
        # Close the pooled Snowflake connections however the server stopped.
        connection_pool.close_all()