Files
snowflake-mcp/snowflake_mcp_server.py

1191 lines
39 KiB
Python

#!/usr/bin/env python3
"""
Snowflake MCP Server with SSO Authentication
Provides read-only Snowflake access via the Model Context Protocol.
"""
import asyncio
import csv
import io
import json
import logging
import os
import re
from typing import Any, Optional, Dict, List
from datetime import datetime, timedelta
import hashlib
import snowflake.connector
from snowflake.connector import DictCursor
from snowflake.connector.constants import QueryStatus
from mcp.server import Server
from mcp.types import TextContent, Tool, INTERNAL_ERROR
# Logging uses basicConfig's default stream handler (stderr), which keeps
# stdout free for the stdio MCP transport used in main().
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration — set via environment variables or edit the defaults below.
# warehouse / role stay None when unset, meaning "use the account default".
# ---------------------------------------------------------------------------
SNOWFLAKE_CONFIG = dict(
    account=os.environ.get("SNOWFLAKE_ACCOUNT", "YOUR_ACCOUNT"),
    user=os.environ.get("SNOWFLAKE_USER", "your.email@example.com"),
    authenticator=os.environ.get("SNOWFLAKE_AUTHENTICATOR", "externalbrowser"),
    warehouse=os.environ.get("SNOWFLAKE_WAREHOUSE"),
    role=os.environ.get("SNOWFLAKE_ROLE"),
)
class SnowflakeConnection:
    """Lazily create, reuse and recycle a single Snowflake connection.

    The connection is opened on first use from ``SNOWFLAKE_CONFIG`` (with the
    default ``externalbrowser`` authenticator this triggers a browser SSO
    login). It is reused while it was last used less than ``timeout`` ago and
    still answers a ``SELECT 1`` probe; otherwise the old connection is closed
    and a new one is opened.
    """

    def __init__(self, timeout: timedelta = timedelta(minutes=30)):
        """
        Args:
            timeout: Idle time after which the cached connection is discarded
                and replaced. Defaults to 30 minutes.
        """
        self._conn: Optional[snowflake.connector.SnowflakeConnection] = None
        self._last_used: Optional[datetime] = None
        self._timeout = timeout

    def _is_alive(self) -> bool:
        """Return True if the cached connection can run a trivial query."""
        try:
            cursor = self._conn.cursor()
            try:
                cursor.execute("SELECT 1")
            finally:
                cursor.close()
            return True
        except Exception:
            return False

    def get(self) -> snowflake.connector.SnowflakeConnection:
        """Return a usable connection, opening a new one when necessary.

        Raises:
            Exception: whatever ``snowflake.connector.connect`` raises when a
                new connection cannot be established (logged, then re-raised).
        """
        if self._conn is not None and self._last_used is not None:
            idle = datetime.now() - self._last_used
            if idle < self._timeout:
                if self._is_alive():
                    self._last_used = datetime.now()
                    logger.info("Reusing cached Snowflake connection")
                    return self._conn
                logger.info("Cached connection is stale, reconnecting...")
            else:
                logger.info(f"Cached connection idle for {idle}, reconnecting...")

        # Release the previous connection (stale or idle-expired) before
        # replacing it, so its session is not leaked; this also clears the
        # cached state in case the reconnect below fails.
        self.close()

        logger.info("Creating new Snowflake connection")
        logger.info("Note: First connection will open browser for SSO login")
        connect_params = {
            key: SNOWFLAKE_CONFIG[key] for key in ("account", "user", "authenticator")
        }
        # Optional settings are only sent when configured, so Snowflake falls
        # back to the user's defaults otherwise.
        for optional_key in ("warehouse", "role"):
            if SNOWFLAKE_CONFIG.get(optional_key):
                connect_params[optional_key] = SNOWFLAKE_CONFIG[optional_key]

        try:
            self._conn = snowflake.connector.connect(**connect_params)
        except Exception as e:
            logger.error(f"Failed to connect to Snowflake: {e}")
            raise
        self._last_used = datetime.now()
        logger.info("Successfully connected to Snowflake")
        return self._conn

    def close(self):
        """Close the cached connection (best effort) and forget it."""
        if self._conn is not None:
            try:
                self._conn.close()
                logger.info("Closed Snowflake connection")
            except Exception as e:
                # Best effort: the session may already be gone server-side.
                logger.debug(f"Ignoring error while closing connection: {e}")
        self._conn = None
        self._last_used = None
# Process-wide Snowflake connection shared by every tool handler.
connection = SnowflakeConnection()

# Result cache: get_cache_key(...) -> (cached value, time it was stored).
# Readers treat entries older than CACHE_TTL as misses.
CACHE_TTL = timedelta(minutes=5)
query_cache: Dict[str, tuple] = {}

# Queries submitted through execute_async, keyed by Snowflake query ID.
async_queries: Dict[str, Dict[str, Any]] = {}
def get_cache_key(query: str) -> str:
    """Return a 32-character hex fingerprint of ``query`` for cache lookups.

    MD5 serves only as a fast, stable fingerprint, not for security.
    Declaring that with ``usedforsecurity=False`` keeps it working on
    FIPS-enabled OpenSSL builds, where a plain ``hashlib.md5()`` is refused.
    """
    return hashlib.md5(query.encode("utf-8"), usedforsecurity=False).hexdigest()
# Spans of SQL text that must not be scanned for keywords: string literals,
# quoted identifiers, $$-delimited bodies and comments.
_SQL_NON_CODE_RE = re.compile(
    r"""
      '(?:[^'\\]|\\.|'')*'      # single-quoted string ('' or backslash escapes)
    | "(?:[^"]|"")*"            # double-quoted identifier
    | \$\$.*?\$\$               # dollar-quoted string / procedure body
    | --[^\n]*                  # -- line comment
    | //[^\n]*                  # // line comment (Snowflake)
    | /\*.*?\*/                 # block comment
    """,
    re.VERBOSE | re.DOTALL,
)

# Statement types that only read data or metadata.
_ALLOWED_FIRST_KEYWORDS = frozenset(
    {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "LIST", "EXPLAIN"}
)

# Checked in order; the first hit is reported.
_FORBIDDEN_KEYWORDS = (
    "INSERT", "UPDATE", "DELETE", "DROP", "TRUNCATE",
    "CREATE", "ALTER", "GRANT", "REVOKE", "MERGE",
    # Also able to change data/state or run arbitrary code, e.g. an
    # anonymous procedure via "WITH p AS PROCEDURE ... CALL p()".
    "CALL", "EXECUTE", "COPY", "PUT", "REMOVE", "UNDROP",
)


def validate_query(query: str) -> tuple[bool, str]:
    """Check that ``query`` is a single read-only statement.

    String literals, quoted identifiers, $$-bodies and comments are blanked
    out first. The remaining text must then:

    * start (after optional opening parentheses) with SELECT, WITH, SHOW,
      DESCRIBE/DESC, LIST or EXPLAIN;
    * contain no ``;`` other than trailing ones (no stacked statements);
    * contain none of ``_FORBIDDEN_KEYWORDS`` as a whole word.

    This is a best-effort guard; the Snowflake role in use should itself be
    limited to read privileges.

    Returns:
        ``(True, "Query validated successfully")`` or ``(False, reason)``.
    """
    code = _SQL_NON_CODE_RE.sub(" ", query).upper()
    # Trailing semicolons are harmless; any other ';' separates statements.
    code = re.sub(r"[\s;]+$", "", code).strip()

    first = re.match(r"[\s(]*([A-Z]+)", code)
    if first is None or first.group(1) not in _ALLOWED_FIRST_KEYWORDS:
        return False, "Only SELECT queries are allowed"

    if ";" in code:
        return False, "Multiple statements are not allowed"

    # Whole-word matching catches keywords next to ';', '(' or newlines,
    # which the old space-delimited check missed.
    for keyword in _FORBIDDEN_KEYWORDS:
        if re.search(rf"\b{keyword}\b", code):
            return False, f"Query contains forbidden keyword: {keyword}"

    return True, "Query validated successfully"
async def test_connection() -> Dict[str, Any]:
    """Open (or reuse) the connection and report the current session context.

    Never raises: any failure is logged and returned as a dict with
    ``status == "error"``, the error text and an SSO hint.
    """
    try:
        conn = connection.get()
        # The context manager closes the cursor even when the query fails.
        with conn.cursor(DictCursor) as cursor:
            cursor.execute("""
                SELECT
                    CURRENT_ACCOUNT() as account,
                    CURRENT_USER() as user_name,
                    CURRENT_ROLE() as current_role,
                    CURRENT_WAREHOUSE() as warehouse,
                    CURRENT_DATABASE() as database_name,
                    CURRENT_SCHEMA() as schema_name,
                    CURRENT_VERSION() as version
            """)
            row = cursor.fetchone()
        return {
            "status": "connected",
            "account": row["ACCOUNT"],
            "user": row["USER_NAME"],
            "role": row["CURRENT_ROLE"],
            "warehouse": row["WAREHOUSE"],
            "database": row["DATABASE_NAME"],
            "schema": row["SCHEMA_NAME"],
            "version": row["VERSION"],
            "auth_method": SNOWFLAKE_CONFIG["authenticator"],
            "message": "Connection successful"
        }
    except Exception as e:
        logger.error(f"Connection test failed: {e}")
        return {
            "status": "error",
            "error": str(e),
            "message": "If this is your first connection, a browser window should open for SSO login"
        }
async def list_tables(schema: Optional[str] = None,
                      include_row_counts: bool = False) -> Dict[str, Any]:
    """List tables via ``SHOW TABLES``, optionally restricted to one schema.

    Args:
        schema: Optional ``SCHEMA`` or ``DB.SCHEMA``; each part is an unquoted
            identifier or a double-quoted one. Without it, the session's
            current schema is used.
        include_row_counts: When False, each table's ``rows`` is None.

    Returns:
        ``{"tables": [...], "count": n}``. An empty result adds a ``hint``
        suggesting the schema may hold views instead.

    Raises:
        ValueError: if ``schema`` is not a valid one- or two-part identifier.
            The name is spliced into the SQL text, so this check is what stops
            extra statements from riding along.
        Exception: any Snowflake error, after logging it.
    """
    ident = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
    if schema:
        schema = schema.strip()
        if not re.fullmatch(rf"{ident}(?:\.{ident})?", schema):
            raise ValueError(f"Invalid schema name: {schema!r}")

    query = f"SHOW TABLES IN SCHEMA {schema}" if schema else "SHOW TABLES"
    try:
        conn = connection.get()
        with conn.cursor(DictCursor) as cursor:
            cursor.execute(query)
            tables = [
                {
                    "database": row.get("database_name"),
                    "schema": row.get("schema_name"),
                    "name": row.get("name"),
                    "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                    "description": row.get("comment"),
                    "kind": row.get("kind"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "rows": row.get("rows") if include_row_counts else None
                }
                for row in cursor
            ]
        if not tables:
            return {
                "tables": [],
                "count": 0,
                "hint": f"No tables found in {schema or 'current schema'}. This schema may contain VIEWS instead of tables. Use list_views to check for views."
            }
        return {"tables": tables, "count": len(tables)}
    except Exception as e:
        logger.error(f"Failed to list tables: {e}")
        raise
async def list_views(schema_name: str) -> List[Dict[str, Any]]:
    """List the views in ``schema_name`` via ``SHOW VIEWS IN SCHEMA``.

    Args:
        schema_name: ``SCHEMA`` or ``DB.SCHEMA``; each part is an unquoted
            identifier or a double-quoted one.

    Returns:
        One dict per view: name, database, schema, full_name, description,
        created, and the ``is_secure`` / ``is_materialized`` flags. Each flag is
        True when Snowflake reports the string "true".

    Raises:
        ValueError: if ``schema_name`` is not a valid one- or two-part
            identifier (it is spliced into the SQL text).
        Exception: any Snowflake error, after logging it.
    """
    ident = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
    schema_name = schema_name.strip()
    if not re.fullmatch(rf"{ident}(?:\.{ident})?", schema_name):
        raise ValueError(f"Invalid schema name: {schema_name!r}")

    try:
        conn = connection.get()
        with conn.cursor(DictCursor) as cursor:
            cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}")
            views = [
                {
                    "name": row.get("name"),
                    "database": row.get("database_name"),
                    "schema": row.get("schema_name"),
                    "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                    "description": row.get("comment"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "is_secure": row.get("is_secure") == "true",
                    "is_materialized": row.get("is_materialized") == "true"
                }
                for row in cursor
            ]
        return views
    except Exception as e:
        logger.error(f"Failed to list views in {schema_name}: {e}")
        raise
async def describe_table(table_name: str) -> Dict[str, Any]:
    """Describe the columns of a table or view using ``SHOW COLUMNS``.

    ``table_name`` may be ``TABLE``, ``SCHEMA.TABLE`` or ``DB.SCHEMA.TABLE``,
    with each part unquoted or double-quoted. The name is first tried as
    given. If Snowflake reports that it does not exist, it is retried with the
    last part double-quoted, to reach case-sensitive object names.

    Never raises: failures (including an invalid name) are logged or
    reported as a dict with ``error``, ``table_name`` and ``suggestions``.
    """
    ident = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
    # The name is spliced into SQL text, so anything that is not a 1-3 part
    # identifier is rejected before reaching Snowflake.
    if not re.fullmatch(rf"{ident}(?:\.{ident}){{0,2}}", table_name):
        return {
            "error": f"Invalid table name: {table_name!r}",
            "table_name": table_name,
            "suggestions": [
                "Use TABLE, SCHEMA.TABLE or DB.SCHEMA.TABLE; wrap parts that "
                "contain special characters in double quotes."
            ],
        }

    # Tokenise on identifiers rather than str.split('.') so quoted parts
    # containing dots stay intact.
    parts = re.findall(ident, table_name)
    table = parts[-1]
    schema = parts[-2] if len(parts) >= 2 else None
    db = parts[-3] if len(parts) == 3 else None

    if table.startswith('"'):
        quoted_name = table_name
    else:
        quoted_name = ".".join(parts[:-1] + [f'"{table}"'])
    candidates = [table_name] if quoted_name == table_name else [table_name, quoted_name]

    cursor = None
    try:
        conn = connection.get()
        cursor = conn.cursor()
        rows: List[Any] = []
        used_name = table_name
        for attempt, candidate in enumerate(candidates, start=1):
            try:
                cursor.execute(f"SHOW COLUMNS IN TABLE {candidate}")
                rows = cursor.fetchall()
            except Exception as inner_e:
                # Only "does not exist" with another spelling left is retried.
                if attempt < len(candidates) and "does not exist" in str(inner_e).lower():
                    continue
                raise
            used_name = candidate
            break

        # Upper-cased result-set column names, used to key each row.
        header = [
            str(getattr(desc, "name", None) or desc[0]).upper()
            for desc in (cursor.description or [])
        ]

        columns = []
        for raw in rows:
            info = dict(zip(header, raw))
            # SHOW COLUMNS reports nullability as "null?" and the default as
            # "default"; IS_NULLABLE / COLUMN_DEFAULT are kept as fallbacks.
            nullable = info.get("NULL?", info.get("IS_NULLABLE"))
            columns.append({
                "name": info.get("COLUMN_NAME"),
                "type": info.get("DATA_TYPE"),
                "nullable": str(nullable).strip().upper() in ("TRUE", "YES"),
                "default": info.get("DEFAULT") or info.get("COLUMN_DEFAULT") or None,
                "comment": info.get("COMMENT"),
                "kind": info.get("KIND")
            })

        needed_quotes = used_name != table_name
        return {
            "database": db,
            "schema": schema,
            "table": table,
            "full_name": table_name,
            "quoted_name": used_name if needed_quotes else None,
            "columns": columns,
            "column_count": len(columns),
            "hint": "Use the quoted_name in queries if provided, as this object has case-sensitive naming." if needed_quotes else None
        }
    except Exception as e:
        logger.error(f"Failed to describe table {table_name}: {e}")
        error_str = str(e)
        suggestions = []
        if "does not exist" in error_str.lower():
            suggestions.append("The object may be a VIEW. Use list_views to check available views in the schema.")
            suggestions.append("Object names are case-sensitive. The name may need to be quoted.")
        elif "not authorized" in error_str.lower():
            suggestions.append("Check that your role has SELECT privileges on this object.")
        return {
            "error": error_str,
            "table_name": table_name,
            "suggestions": suggestions
        }
    finally:
        if cursor is not None:
            cursor.close()
async def read_data(query: str, max_rows: int = 50000) -> Dict[str, Any]:
    """Execute a validated read-only query and return up to ``max_rows`` rows.

    Successful results are cached in ``query_cache`` for ``CACHE_TTL``.

    * The cache key includes ``max_rows``, so calls with different limits
      never share a differently truncated result.
    * Expired entries are purged on every call, so the cache cannot grow
      without bound.

    Returns:
        ``columns``, ``rows`` (list of dicts; datetime/bytes values passed
        through ``str``), ``row_count`` and ``truncated``. ``truncated`` is
        True only when the query produced more than ``max_rows`` rows. A
        cache hit also carries ``cached`` and ``cache_age_seconds``. On
        failure an ``error`` dict with suggestions is returned instead of
        raising.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {
            "error": message,
            "suggestion": "Use SELECT statements only. For schema information, use describe_table."
        }

    # Purge expired entries; afterwards every remaining entry is fresh.
    now = datetime.now()
    expired = [key for key, (_, stored_at) in query_cache.items()
               if now - stored_at >= CACHE_TTL]
    for key in expired:
        del query_cache[key]

    # max_rows is part of the key: a result cut short for a small limit must
    # not be served to a caller asking for more rows (or vice versa).
    cache_key = get_cache_key(f"ROWS:{max_rows}:{query}")
    if cache_key in query_cache:
        cached_result, cached_time = query_cache[cache_key]
        logger.info("Returning cached result for query")
        return {
            **cached_result,
            "cached": True,
            "cache_age_seconds": (datetime.now() - cached_time).total_seconds()
        }

    try:
        conn = connection.get()
        with conn.cursor(DictCursor) as cursor:
            cursor.execute(query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            rows = []
            truncated = False
            for row in cursor:
                if len(rows) >= max_rows:
                    # A row beyond the limit exists, so the result really is cut.
                    truncated = True
                    logger.warning(f"Query returned more than {max_rows} rows, truncating")
                    break
                record = {}
                for col in columns:
                    value = row[col]
                    record[col] = str(value) if isinstance(value, (datetime, bytes)) else value
                rows.append(record)
        result = {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated,
        }
        query_cache[cache_key] = (result, datetime.now())
        return result
    except Exception as e:
        logger.error(f"Query execution failed: {e}")
        error_str = str(e)
        suggestions = []
        if "does not exist" in error_str.lower() or "not authorized" in error_str.lower():
            suggestions.append("Object names in Snowflake are case-sensitive when created with quotes. Try wrapping table/view names in double quotes.")
            suggestions.append("Use list_views to check if the object is a view rather than a table.")
        else:
            suggestions.append("Check your query syntax. Use describe_table to see available columns.")
        return {
            "error": error_str,
            "suggestions": suggestions
        }
async def list_databases() -> List[Dict[str, Any]]:
    """Return each database from ``SHOW DATABASES``.

    Each entry has the database's name, creation time (stringified, or None),
    owner, comment and origin.

    Raises:
        Exception: any Snowflake error, after logging it.
    """
    try:
        conn = connection.get()
        # The context manager closes the cursor on the error path too.
        with conn.cursor(DictCursor) as cursor:
            cursor.execute("SHOW DATABASES")
            databases = [
                {
                    "name": row.get("name"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "owner": row.get("owner"),
                    "comment": row.get("comment"),
                    "origin": row.get("origin")
                }
                for row in cursor
            ]
        return databases
    except Exception as e:
        logger.error(f"Failed to list databases: {e}")
        raise
async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
    """List schemas via ``SHOW SCHEMAS``, optionally within one database.

    Args:
        database_name: Optional single identifier (unquoted or double-quoted).
            Without it, Snowflake's default scope for ``SHOW SCHEMAS`` applies.

    Returns:
        One dict per schema with name, database, created and owner.

    Raises:
        ValueError: if ``database_name`` is not a single valid identifier
            (it is spliced into the SQL text).
        Exception: any Snowflake error, after logging it.
    """
    ident = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
    if database_name:
        database_name = database_name.strip()
        if not re.fullmatch(ident, database_name):
            raise ValueError(f"Invalid database name: {database_name!r}")

    query = f"SHOW SCHEMAS IN DATABASE {database_name}" if database_name else "SHOW SCHEMAS"
    try:
        conn = connection.get()
        with conn.cursor(DictCursor) as cursor:
            cursor.execute(query)
            schemas = [
                {
                    "name": row.get("name"),
                    "database": row.get("database_name"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "owner": row.get("owner")
                }
                for row in cursor
            ]
        return schemas
    except Exception as e:
        logger.error(f"Failed to list schemas: {e}")
        raise
async def get_system_health() -> Dict[str, Any]:
    """Summarise server health: cache stats, auth settings and a live connection test.

    ``status`` is "healthy" only when the embedded connection test reports
    ``"connected"``; any other outcome yields "degraded". ``test_connection``
    returns failures as a dict instead of raising, so the returned status has
    to be inspected rather than relying on an exception.
    """
    health_info: Dict[str, Any] = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "cache": {
            "query_cache_size": len(query_cache),
            "cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
        },
        "authentication": {
            "method": SNOWFLAKE_CONFIG["authenticator"],
            "account": SNOWFLAKE_CONFIG["account"],
            "user": SNOWFLAKE_CONFIG["user"],
        }
    }
    try:
        connection_report = await test_connection()
    except Exception as e:
        # Defensive: keep the health report usable even if the test raises.
        connection_report = {"status": "error", "error": str(e)}
    health_info["connection"] = connection_report
    if connection_report.get("status") != "connected":
        health_info["status"] = "degraded"
    return health_info
async def describe_query(query: str) -> Dict[str, Any]:
    """Return the result-set column metadata of ``query`` without fetching rows.

    Trailing semicolons and any trailing LIMIT/OFFSET are removed, then the
    query runs with ``LIMIT 0``. If Snowflake rejects that form, it is
    retried as ``SELECT * FROM (<query>) LIMIT 0``. Each column is reported
    with the seven fields of its cursor ``description`` entry.

    Returns an ``error`` dict instead of raising on failure.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    # Strip whitespace and semicolons together so "SELECT 1; " is cleaned too.
    clean_query = re.sub(r"[\s;]+$", "", query.strip())
    clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
    clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)

    try:
        conn = connection.get()
        with conn.cursor(DictCursor) as cursor:
            # LIMIT 0 goes on its own line. Appended on the same line as a
            # trailing "--" comment it would be commented out, and the full
            # query would run.
            try:
                cursor.execute(f"{clean_query}\nLIMIT 0")
            except Exception:
                cursor.execute(f"SELECT * FROM (\n{clean_query}\n) AS subq LIMIT 0")
            columns = [
                {
                    "name": desc[0],
                    "type_code": desc[1],
                    "display_size": desc[2],
                    "internal_size": desc[3],
                    "precision": desc[4],
                    "scale": desc[5],
                    "nullable": desc[6]
                }
                for desc in (cursor.description or [])
            ]
        return {
            "columns": columns,
            "column_count": len(columns),
            "query_preview": query[:200] + "..." if len(query) > 200 else query
        }
    except Exception as e:
        logger.error(f"Failed to describe query: {e}")
        return {
            "error": str(e),
            "suggestion": "Check query syntax. The query must be a valid SELECT statement."
        }
async def execute_async(query: str) -> Dict[str, Any]:
    """Submit a validated read-only query without waiting for it to finish.

    The Snowflake query ID is recorded in ``async_queries``, together with a
    preview of the first 500 characters of the query and its submit time.
    The ID is returned for use with ``get_query_status`` /
    ``get_async_results``. Errors are returned as ``{"error": ...}``.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}
    try:
        conn = connection.get()
        # The context manager closes the cursor even if submission fails.
        with conn.cursor() as cursor:
            cursor.execute_async(query)
            query_id = cursor.sfqid
        async_queries[query_id] = {
            "query": query[:500],
            "submitted_at": datetime.now().isoformat(),
            "status": "RUNNING"
        }
        return {
            "query_id": query_id,
            "status": "SUBMITTED",
            "message": "Query submitted for async execution. Use get_query_status to check progress.",
        }
    except Exception as e:
        logger.error(f"Failed to submit async query: {e}")
        return {"error": str(e)}
async def get_query_status(query_id: str) -> Dict[str, Any]:
    """Ask Snowflake for the current state of ``query_id``.

    If the ID is tracked in ``async_queries``, its stored status is refreshed
    and the tracked entry is returned as ``query_info``. Errors are logged
    and returned as ``{"error", "query_id"}`` rather than raised.
    """
    try:
        conn = connection.get()
        status = conn.get_query_status(query_id)
        status_name = status.name if hasattr(status, "name") else str(status)
        still_running = conn.is_still_running(status)
        failed = conn.is_an_error(status)

        tracked = async_queries.get(query_id)
        if tracked is not None:
            tracked["status"] = status_name

        report = {
            "query_id": query_id,
            "status": status_name,
            "is_running": still_running,
            "is_error": failed,
            "can_fetch_results": not (still_running or failed),
        }
        if tracked is not None:
            report["query_info"] = tracked
        return report
    except Exception as e:
        logger.error(f"Failed to get query status: {e}")
        return {"error": str(e), "query_id": query_id}
async def get_async_results(query_id: str, max_rows: int = 50000) -> Dict[str, Any]:
    """Fetch up to ``max_rows`` rows of a finished async query.

    Uses ``get_query_status_throw_if_error``, so a failed query surfaces as an
    error. A query that is still running yields an ``error`` dict with its
    status. After a successful fetch the ID is dropped from ``async_queries``.
    Datetime/bytes values are passed through ``str``. ``truncated`` is True
    only when more than ``max_rows`` rows were available. Errors are returned
    as ``{"error", "query_id"}``.
    """
    cursor = None
    try:
        conn = connection.get()
        status = conn.get_query_status_throw_if_error(query_id)
        if conn.is_still_running(status):
            return {
                "error": "Query is still running",
                "query_id": query_id,
                "status": status.name if hasattr(status, 'name') else str(status)
            }

        limit = max(max_rows, 0)
        cursor = conn.cursor()
        cursor.get_results_from_sfqid(query_id)
        # Fetch one extra row so truncation is detected exactly, rather than
        # guessed from len(rows) == max_rows.
        fetched = cursor.fetchmany(limit + 1)
        truncated = len(fetched) > limit
        columns = [desc[0] for desc in cursor.description] if cursor.description else []

        rows = []
        for raw in fetched[:limit]:
            record = {}
            for idx, col in enumerate(columns):
                value = raw[idx] if idx < len(raw) else None
                if isinstance(value, (datetime, bytes)):
                    value = str(value)
                record[col] = value
            rows.append(record)

        async_queries.pop(query_id, None)
        return {
            "query_id": query_id,
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated
        }
    except Exception as e:
        logger.error(f"Failed to get async results: {e}")
        return {"error": str(e), "query_id": query_id}
    finally:
        if cursor is not None:
            cursor.close()
async def read_data_paginated(query: str, page_size: int = 1000,
                              page: int = 1) -> Dict[str, Any]:
    """Return one page of a read-only query's results plus pagination info.

    The query is wrapped as a subquery and paged with LIMIT/OFFSET; any
    trailing LIMIT/OFFSET in ``query`` is removed first. The total row count
    comes from ``COUNT(*)`` over the same subquery and is cached for
    ``CACHE_TTL``. Rows are only split consistently across pages if ``query``
    has an ORDER BY.

    Args:
        query: Read-only SQL accepted by ``validate_query``.
        page_size: Rows per page, 1-10000.
        page: 1-based page number.

    Returns:
        ``columns``, ``rows`` and a ``pagination`` dict, or ``{"error": ...}``.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}
    if page < 1:
        return {"error": "Page must be >= 1"}
    if page_size < 1 or page_size > 10000:
        return {"error": "Page size must be between 1 and 10000"}
    offset = (page - 1) * page_size

    # Strip whitespace and semicolons together so "SELECT 1; " is cleaned too.
    clean_query = re.sub(r"[\s;]+$", "", query.strip())
    clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
    clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)

    # The embedded query sits on its own lines, so a trailing "--" comment
    # cannot swallow the closing parenthesis.
    paginated_query = (
        f"SELECT * FROM (\n{clean_query}\n) AS paginated_result\n"
        f"LIMIT {page_size} OFFSET {offset}"
    )
    count_query = f"SELECT COUNT(*) as total FROM (\n{clean_query}\n) AS count_subq"

    count_cache_key = get_cache_key(f"COUNT:{clean_query}")
    total_count = None
    cached = query_cache.get(count_cache_key)
    if cached is not None and datetime.now() - cached[1] < CACHE_TTL:
        total_count = cached[0]

    try:
        conn = connection.get()
        with conn.cursor(DictCursor) as cursor:
            if total_count is None:
                cursor.execute(count_query)
                count_row = cursor.fetchone()
                total_count = count_row["TOTAL"] if count_row else 0
                query_cache[count_cache_key] = (total_count, datetime.now())

            cursor.execute(paginated_query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            rows = []
            for row in cursor:
                record = {}
                for col in columns:
                    value = row[col]
                    record[col] = str(value) if isinstance(value, (datetime, bytes)) else value
                rows.append(record)

        total_pages = (total_count + page_size - 1) // page_size if total_count else 0
        return {
            "columns": columns,
            "rows": rows,
            "pagination": {
                "page": page,
                "page_size": page_size,
                "total_rows": total_count,
                "total_pages": total_pages,
                "has_next": page < total_pages,
                "has_previous": page > 1
            },
        }
    except Exception as e:
        logger.error(f"Paginated query failed: {e}")
        return {"error": str(e)}
async def read_data_pandas(query: str, max_rows: int = 50000,
                           output_format: str = "records") -> Dict[str, Any]:
    """Execute a query and return results in a pandas-friendly format.

    output_format options:
      - 'records': List of dicts (default, same as read_data)
      - 'columns': Dict with column arrays {col1: [vals], col2: [vals]}
      - 'split': Dict with columns and data arrays {columns: [...], data: [[...]]}
      - 'csv': CSV-formatted string

    Datetimes become ISO-8601 strings; bytes are decoded as UTF-8, with
    undecodable bytes replaced. ``truncated`` is True only when the query
    produced more than ``max_rows`` rows. Errors are returned as
    ``{"error": ...}``.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}
    if output_format not in ('records', 'columns', 'split', 'csv'):
        return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}

    try:
        conn = connection.get()
        with conn.cursor(DictCursor) as cursor:
            cursor.execute(query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            all_rows = []
            truncated = False
            for row in cursor:
                if len(all_rows) >= max_rows:
                    # A row beyond the limit exists, so the result really is cut.
                    truncated = True
                    break
                record = {}
                for col in columns:
                    value = row[col]
                    if isinstance(value, datetime):
                        value = value.isoformat()
                    elif isinstance(value, bytes):
                        value = value.decode('utf-8', errors='replace')
                    record[col] = value
                all_rows.append(record)

        if output_format == 'records':
            data = all_rows
        elif output_format == 'columns':
            data = {col: [row.get(col) for row in all_rows] for col in columns}
        elif output_format == 'split':
            data = {
                "columns": columns,
                "data": [[row.get(col) for col in columns] for row in all_rows]
            }
        else:  # 'csv'
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=columns)
            writer.writeheader()
            writer.writerows(all_rows)
            data = output.getvalue()

        return {
            "format": output_format,
            "columns": columns,
            "data": data,
            "row_count": len(all_rows),
            "truncated": truncated,
            "pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
        }
    except Exception as e:
        logger.error(f"Pandas query failed: {e}")
        return {"error": str(e)}
async def list_async_queries() -> Dict[str, Any]:
    """Return every query currently tracked by ``execute_async`` and their count.

    The ``queries`` value is the live ``async_queries`` mapping, not a copy.
    """
    tracked = async_queries
    return dict(queries=tracked, count=len(tracked))
# Create MCP server
server = Server("snowflake-mcp")


def _object_schema(properties: Optional[Dict[str, Any]] = None,
                   required: Optional[List[str]] = None) -> Dict[str, Any]:
    """Return a JSON-Schema ``object`` describing a tool's arguments.

    ``required`` is only emitted when non-empty.
    """
    schema: Dict[str, Any] = {"type": "object", "properties": properties or {}}
    if required:
        schema["required"] = list(required)
    return schema


def _arg(json_type: str, description: str, **extra: Any) -> Dict[str, Any]:
    """Return the JSON-Schema for one argument; ``extra`` adds keys such as ``default``."""
    return {"type": json_type, "description": description, **extra}


@server.list_tools()
async def handle_list_tools() -> List[Tool]:
    """Advertise every tool this server implements, with its argument schema."""
    return [
        Tool(
            name="test_connection",
            description="Test Snowflake connection and get session information",
            inputSchema=_object_schema(),
        ),
        Tool(
            name="list_databases",
            description="List all accessible databases in Snowflake",
            inputSchema=_object_schema(),
        ),
        Tool(
            name="list_schemas",
            description="List all schemas in a database",
            inputSchema=_object_schema({
                "database_name": _arg("string", "Database name (optional, uses current if not specified)"),
            }),
        ),
        Tool(
            name="list_tables",
            description="List all tables in the current database/schema",
            inputSchema=_object_schema({
                "schema": _arg("string", "Schema filter (optional, e.g. 'MY_DB.MY_SCHEMA')"),
                "include_row_counts": _arg("boolean", "Include row counts", default=False),
            }),
        ),
        Tool(
            name="list_views",
            description="List all views in a schema with their descriptions",
            inputSchema=_object_schema(
                {"schema_name": _arg("string", "Schema name (e.g., 'MY_DB.MY_SCHEMA')")},
                required=["schema_name"],
            ),
        ),
        Tool(
            name="describe_table",
            description="Get detailed schema information for a table or view",
            inputSchema=_object_schema(
                {"table_name": _arg("string", "Table name (can include database.schema.table)")},
                required=["table_name"],
            ),
        ),
        Tool(
            name="read_data",
            description="Execute a SELECT query on Snowflake",
            inputSchema=_object_schema(
                {
                    "query": _arg("string", "SQL SELECT query to execute"),
                    "max_rows": _arg("integer", "Maximum rows to return", default=50000),
                },
                required=["query"],
            ),
        ),
        Tool(
            name="get_system_health",
            # There is no connection pool: the tool reports a live connection
            # test, query-cache statistics and the configured authentication.
            description="Get system health: a live connection test, query cache statistics, and the configured authentication settings",
            inputSchema=_object_schema(),
        ),
        Tool(
            name="describe_query",
            description="Preview query output columns without executing the full query. Useful for validating complex queries before running them.",
            inputSchema=_object_schema(
                {"query": _arg("string", "SQL SELECT query to analyze")},
                required=["query"],
            ),
        ),
        Tool(
            name="execute_async",
            description="Submit a query for asynchronous execution. Returns a query ID for status tracking. Use for long-running queries.",
            inputSchema=_object_schema(
                {"query": _arg("string", "SQL SELECT query to execute asynchronously")},
                required=["query"],
            ),
        ),
        Tool(
            name="get_query_status",
            description="Check the status of an async query by its query ID",
            inputSchema=_object_schema(
                {"query_id": _arg("string", "Snowflake query ID from execute_async")},
                required=["query_id"],
            ),
        ),
        Tool(
            name="get_async_results",
            description="Retrieve results from a completed async query",
            inputSchema=_object_schema(
                {
                    "query_id": _arg("string", "Snowflake query ID from execute_async"),
                    "max_rows": _arg("integer", "Maximum rows to return", default=50000),
                },
                required=["query_id"],
            ),
        ),
        Tool(
            name="list_async_queries",
            description="List all tracked async queries and their statuses",
            inputSchema=_object_schema(),
        ),
        Tool(
            name="read_data_paginated",
            description="Execute a query with pagination support. Returns results page by page with total count.",
            inputSchema=_object_schema(
                {
                    "query": _arg("string", "SQL SELECT query to execute"),
                    "page_size": _arg("integer", "Number of rows per page (1-10000)", default=1000),
                    "page": _arg("integer", "Page number (1-based)", default=1),
                },
                required=["query"],
            ),
        ),
        Tool(
            name="read_data_pandas",
            description="Execute a query and return results in a pandas-friendly format. Supports multiple output formats: records, columns, split, csv.",
            inputSchema=_object_schema(
                {
                    "query": _arg("string", "SQL SELECT query to execute"),
                    "max_rows": _arg("integer", "Maximum rows to return", default=50000),
                    "output_format": _arg(
                        "string",
                        "Output format: 'records' (list of dicts), 'columns' (dict of arrays), 'split' (columns + data arrays), 'csv' (CSV string)",
                        default="records",
                        enum=["records", "columns", "split", "csv"],
                    ),
                },
                required=["query"],
            ),
        ),
    ]
@server.call_tool()
async def call_tool(name: str, arguments: Any) -> List[TextContent]:
    """Dispatch an MCP tool call to its implementation and return JSON text.

    Args:
        name: Tool name as advertised by ``handle_list_tools``.
        arguments: Tool arguments mapping; None is treated as no arguments.

    Returns:
        A single ``TextContent`` holding the handler's result as indented
        JSON (non-JSON values are converted with ``str``).

    Raises:
        ValueError: for an unknown tool name.
        Exception: whatever the handler raised (e.g. ``KeyError`` for a
            missing required argument), logged with its traceback.
    """
    args = arguments or {}
    # Each entry builds the handler coroutine lazily, so only the requested
    # tool reads its arguments.
    handlers = {
        "test_connection": lambda: test_connection(),
        "list_databases": lambda: list_databases(),
        "list_schemas": lambda: list_schemas(args.get("database_name")),
        "list_tables": lambda: list_tables(args.get("schema"), args.get("include_row_counts", False)),
        "list_views": lambda: list_views(args["schema_name"]),
        "describe_table": lambda: describe_table(args["table_name"]),
        "read_data": lambda: read_data(args["query"], args.get("max_rows", 50000)),
        "get_system_health": lambda: get_system_health(),
        "describe_query": lambda: describe_query(args["query"]),
        "execute_async": lambda: execute_async(args["query"]),
        "get_query_status": lambda: get_query_status(args["query_id"]),
        "get_async_results": lambda: get_async_results(args["query_id"], args.get("max_rows", 50000)),
        "list_async_queries": lambda: list_async_queries(),
        "read_data_paginated": lambda: read_data_paginated(
            args["query"], args.get("page_size", 1000), args.get("page", 1)
        ),
        "read_data_pandas": lambda: read_data_pandas(
            args["query"], args.get("max_rows", 50000), args.get("output_format", "records")
        ),
    }
    try:
        handler = handlers.get(name)
        if handler is None:
            raise ValueError(f"Unknown tool: {name}")
        result = await handler()
        return [TextContent(
            type="text",
            text=json.dumps(result, indent=2, default=str)
        )]
    except Exception as e:
        logger.exception(f"Tool execution failed: {e}")
        # Re-raise the original error. INTERNAL_ERROR (mcp.types) is an
        # integer JSON-RPC error code, not an exception class: calling it
        # raised TypeError and hid the real message.
        raise
async def main():
    """Log the configured identity, then serve MCP requests over stdio."""
    from mcp.server.stdio import stdio_server

    startup_messages = (
        "Starting Snowflake MCP Server with SSO Authentication",
        f"Account: {SNOWFLAKE_CONFIG['account']}",
        f"User: {SNOWFLAKE_CONFIG['user']}",
        "First connection will open browser for SSO login",
    )
    for line in startup_messages:
        logger.info(line)

    async with stdio_server() as streams:
        reader, writer = streams
        await server.run(reader, writer, server.create_initialization_options())
def _run_server() -> None:
    """Run ``main`` to completion, always closing the Snowflake connection."""
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    finally:
        connection.close()


if __name__ == "__main__":
    _run_server()