From 4ae150f4205f44757f91331245789e41ce5b1a0a Mon Sep 17 00:00:00 2001 From: A Charlwood Date: Mon, 16 Mar 2026 10:59:27 +0000 Subject: [PATCH] Initial commit --- README.md | 131 ++++ requirements.txt | 2 + snowflake_mcp_server.py | 1342 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 1475 insertions(+) create mode 100644 README.md create mode 100644 requirements.txt create mode 100644 snowflake_mcp_server.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..dca2edd --- /dev/null +++ b/README.md @@ -0,0 +1,131 @@ +# Snowflake MCP Server + +A read-only [Model Context Protocol](https://modelcontextprotocol.io/) server for Snowflake, using SSO (browser-based) authentication. + +Gives Claude (or any MCP client) the ability to explore your Snowflake account — list databases, schemas, tables, views, describe columns, and run SELECT queries. + +## Features + +- SSO authentication (opens browser on first connection, then reuses the session) +- Connection pooling with 30-minute timeout +- Query caching (5 min TTL) +- Read-only safety: only SELECT / SHOW / DESCRIBE queries are allowed +- Async query support for long-running queries +- Paginated results +- Multiple output formats (records, columns, split, CSV) + +## Quick start with uv + +[uv](https://docs.astral.sh/uv/) is the easiest way to run this without managing a virtualenv yourself. + +### 1. Install uv (if you don't have it) + +```bash +# macOS / Linux +curl -LsSf https://astral.sh/uv/install.sh | sh + +# Windows (PowerShell) +powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" +``` + +### 2. Configure your Snowflake details + +Set environment variables (recommended) or edit the defaults at the top of `snowflake_mcp_server.py`. + +```bash +# Required +export SNOWFLAKE_ACCOUNT="your-account-id" # e.g. 
"xy12345.eu-west-1" +export SNOWFLAKE_USER="your.email@company.com" + +# Optional — defaults shown +export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser +export SNOWFLAKE_WAREHOUSE="" # uses account default if empty +export SNOWFLAKE_ROLE="" # uses account default if empty +export SNOWFLAKE_DATABASE="" # uses account default if empty +export SNOWFLAKE_SCHEMA="" # uses account default if empty +``` + +On Windows (PowerShell): +```powershell +$env:SNOWFLAKE_ACCOUNT = "your-account-id" +$env:SNOWFLAKE_USER = "your.email@company.com" +``` + +### 3. Run the server + +```bash +uv run --with snowflake-connector-python --with mcp snowflake_mcp_server.py +``` + +Or, if you prefer to create a project first: + +```bash +uv init --no-package +uv add snowflake-connector-python mcp +uv run snowflake_mcp_server.py +``` + +### 4. Add to Claude Code + +**Option A — CLI (easiest):** + +```bash +claude mcp add --transport stdio snowflake \ + --env SNOWFLAKE_ACCOUNT=your-account-id \ + --env SNOWFLAKE_USER=your.email@company.com \ + --env SNOWFLAKE_ROLE=YOUR_ROLE \ + -- uv run --with snowflake-connector-python --with mcp /absolute/path/to/snowflake_mcp_server.py +``` + +**Option B — manual config:** + +Add to `.mcp.json` in your project root (shared with team) or `~/.claude.json` (personal): + +```json +{ + "mcpServers": { + "snowflake": { + "type": "stdio", + "command": "uv", + "args": [ + "run", + "--with", "snowflake-connector-python", + "--with", "mcp", + "/absolute/path/to/snowflake_mcp_server.py" + ], + "env": { + "SNOWFLAKE_ACCOUNT": "your-account-id", + "SNOWFLAKE_USER": "your.email@company.com", + "SNOWFLAKE_ROLE": "YOUR_ROLE" + } + } + } +} +``` + +Verify it's registered with `claude mcp list`. 
+ +## Available tools + +| Tool | Description | +|------|-------------| +| `test_connection` | Verify connectivity and get session info | +| `list_databases` | List all accessible databases | +| `list_schemas` | List schemas in a database | +| `list_tables` | List tables (optionally with row counts) | +| `list_views` | List views in a schema | +| `describe_table` | Get column-level schema for a table or view | +| `describe_query` | Preview output columns without running the full query | +| `read_data` | Run a SELECT query (up to 50k rows) | +| `read_data_paginated` | Paginated query results with total count | +| `read_data_pandas` | Results in records / columns / split / CSV format | +| `execute_async` | Submit a long-running query asynchronously | +| `get_query_status` | Check async query progress | +| `get_async_results` | Fetch results from a completed async query | +| `list_async_queries` | List all tracked async queries | + +## Notes + +- The first connection opens your default browser for SSO login. Subsequent calls reuse the session. +- All queries are validated to be read-only — DML and DDL are blocked. +- Snowflake object names are case-sensitive when created with quotes. The server will automatically retry with quoted names if the unquoted version fails. diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..525ea6c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +snowflake-connector-python>=3.6.0 +mcp>=1.0.0 diff --git a/snowflake_mcp_server.py b/snowflake_mcp_server.py new file mode 100644 index 0000000..2980f49 --- /dev/null +++ b/snowflake_mcp_server.py @@ -0,0 +1,1342 @@ +#!/usr/bin/env python3 +""" +Snowflake MCP Server with SSO Authentication +Provides read-only Snowflake access via the Model Context Protocol. 
class SnowflakeConnectionPool:
    """Keep one Snowflake connection per database alias and reuse it.

    Connections are opened with the account, user and authenticator taken
    from ``SNOWFLAKE_CONFIG``. A cached connection is handed out again only
    while it was last used less than ``connection_timeout`` ago and still
    answers a ``SELECT 1`` probe; otherwise it is closed and replaced.
    """

    def __init__(self):
        # alias -> open connection
        self.connections: Dict[str, snowflake.connector.SnowflakeConnection] = {}
        # alias -> time the connection was last handed out
        self.last_used: Dict[str, datetime] = {}
        self.connection_timeout = timedelta(minutes=30)

    @staticmethod
    def _is_alive(conn) -> bool:
        """Return True if ``conn`` can still execute a trivial statement."""
        try:
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT 1")
            finally:
                # Close the probe cursor even when the probe itself fails.
                cursor.close()
        except Exception:
            return False
        return True

    def _discard(self, cache_key: str) -> None:
        """Forget the cached connection for ``cache_key`` and close it (best effort)."""
        conn = self.connections.pop(cache_key, None)
        self.last_used.pop(cache_key, None)
        if conn is None:
            return
        try:
            conn.close()
        except Exception as e:
            # The connection is being thrown away anyway; record and move on.
            logger.debug(f"Ignoring error while closing stale connection for {cache_key}: {e}")

    def get_connection(self, database: str = "default") -> "snowflake.connector.SnowflakeConnection":
        """Get or create a Snowflake connection with SSO auth.

        Args:
            database: Alias looked up in ``DATABASES``; unknown aliases use the
                ``"default"`` entry's settings but are cached under their own key.

        Returns:
            An open connection for the alias.

        Raises:
            Exception: whatever ``snowflake.connector.connect`` raises; the
                error is logged and re-raised unchanged.
        """
        cache_key = database

        cached = self.connections.get(cache_key)
        if cached is not None:
            last = self.last_used.get(cache_key, datetime.now())
            within_timeout = datetime.now() - last < self.connection_timeout

            if within_timeout and self._is_alive(cached):
                self.last_used[cache_key] = datetime.now()
                logger.info(f"Reusing cached Snowflake connection for {database}")
                return cached

            # Timed out or no longer responding: close it before replacing it,
            # otherwise the old session is leaked when the entry is overwritten.
            logger.info(f"Cached connection for {database} is stale, reconnecting...")
            self._discard(cache_key)

        # Create new connection
        logger.info(f"Creating new Snowflake connection for {database}")
        logger.info("Note: First connection will open browser for SSO login")

        db_config = DATABASES.get(database, DATABASES["default"])

        connect_params = {
            "account": SNOWFLAKE_CONFIG["account"],
            "user": SNOWFLAKE_CONFIG["user"],
            "authenticator": SNOWFLAKE_CONFIG["authenticator"],
        }

        # Optional settings are only passed when configured, so that the
        # account-level defaults apply otherwise.
        if db_config.get("database"):
            connect_params["database"] = db_config["database"]
        if db_config.get("schema"):
            connect_params["schema"] = db_config["schema"]
        if SNOWFLAKE_CONFIG.get("warehouse"):
            connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"]
        if SNOWFLAKE_CONFIG.get("role"):
            connect_params["role"] = SNOWFLAKE_CONFIG["role"]

        try:
            conn = snowflake.connector.connect(**connect_params)
        except Exception as e:
            logger.error(f"Failed to connect to Snowflake: {e}")
            raise

        self.connections[cache_key] = conn
        self.last_used[cache_key] = datetime.now()

        logger.info(f"Successfully connected to Snowflake ({database})")
        return conn

    def close_all(self):
        """Close every cached connection and empty the pool.

        A failure to close one connection is logged and does not stop the
        remaining connections from being closed.
        """
        for database, conn in self.connections.items():
            try:
                conn.close()
                logger.info(f"Closed Snowflake connection to {database}")
            except Exception as e:
                logger.warning(f"Error closing Snowflake connection to {database}: {e}")
        self.connections.clear()
        self.last_used.clear()
# String literals, quoted identifiers and comments. They are blanked out
# before keyword scanning so that e.g. WHERE action = 'DELETE' is not
# mistaken for a DELETE statement.
_SQL_LITERAL_OR_COMMENT_RE = re.compile(
    r"'(?:[^']|'')*'"      # single-quoted string ('' is an escaped quote)
    r'|"(?:[^"]|"")*"'     # double-quoted identifier
    r"|--[^\n]*"           # line comment
    r"|/\*.*?\*/",         # block comment
    re.DOTALL,
)

# A statement must begin (after whitespace / opening parentheses) with one of
# these keywords.
_READ_ONLY_START_RE = re.compile(
    r"[\s(]*(?:SELECT|WITH|SHOW|DESCRIBE|DESC|LIST|EXPLAIN)\b"
)

# Checked in this order; the first keyword found is the one reported.
_FORBIDDEN_KEYWORD_RES = tuple(
    (keyword, re.compile(rf"\b{keyword}\b"))
    for keyword in ('INSERT', 'UPDATE', 'DELETE', 'DROP', 'TRUNCATE',
                    'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE')
)


def validate_query(query: str) -> tuple[bool, str]:
    """Check that ``query`` is a read-only statement.

    The query passes when, ignoring string literals, quoted identifiers and
    comments, it starts with SELECT / WITH / SHOW / DESCRIBE / DESC / LIST /
    EXPLAIN and contains none of the data-modifying keywords as a whole word.
    Keywords are matched on word boundaries, so they are detected next to
    any punctuation or whitespace (``;DROP``, newline, tab) while identifiers
    such as ``LAST_UPDATE`` or ``CREATED_ON`` are not flagged.

    Args:
        query: SQL text supplied by the caller.

    Returns:
        ``(True, "Query validated successfully")`` or ``(False, reason)``.
    """
    # Blank out literals/comments (replaced by a space to keep tokens apart).
    scannable = _SQL_LITERAL_OR_COMMENT_RE.sub(" ", query.upper())

    if not _READ_ONLY_START_RE.match(scannable):
        return False, "Only SELECT queries are allowed"

    for keyword, pattern in _FORBIDDEN_KEYWORD_RES:
        if pattern.search(scannable):
            return False, f"Query contains forbidden keyword: {keyword}"

    return True, "Query validated successfully"
# One identifier: unquoted (letters, digits, _, $) or double-quoted.
_LIST_TABLES_IDENT = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
# A schema reference: SCHEMA or DATABASE.SCHEMA.
_LIST_TABLES_SCHEMA_RE = re.compile(
    rf"{_LIST_TABLES_IDENT}(?:\.{_LIST_TABLES_IDENT})?"
)


async def list_tables(database: str = "default", schema: Optional[str] = None,
                      include_row_counts: bool = False) -> Dict[str, Any]:
    """List all tables in the database/schema.

    Args:
        database: Connection alias passed to the connection pool.
        schema: Optional ``SCHEMA`` or ``DATABASE.SCHEMA``; when omitted the
            session's current schema is used (plain ``SHOW TABLES``).
        include_row_counts: When true, each entry's ``rows`` field carries
            the value reported by ``SHOW TABLES``; otherwise it is ``None``.

    Returns:
        ``{"tables": [...], "count": n}``; when nothing is found a ``hint``
        key suggests checking for views.

    Raises:
        ValueError: ``schema`` is not a plain or double-quoted identifier
            (optionally database-qualified). It is interpolated into the SHOW
            statement, so anything else is rejected before a connection is used.
        Exception: connection/query errors are logged and re-raised.
    """
    if schema and not _LIST_TABLES_SCHEMA_RE.fullmatch(schema):
        raise ValueError(f"Invalid schema name: {schema!r}")

    query = f"SHOW TABLES IN SCHEMA {schema}" if schema else "SHOW TABLES"

    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute(query)
            tables = [
                {
                    "database": row.get("database_name"),
                    "schema": row.get("schema_name"),
                    "name": row.get("name"),
                    "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                    "description": row.get("comment"),
                    "kind": row.get("kind"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "rows": row.get("rows") if include_row_counts else None
                }
                for row in cursor
            ]
        finally:
            # Release the cursor even when execute/iteration fails.
            cursor.close()
    except Exception as e:
        logger.error(f"Failed to list tables: {e}")
        raise

    # Return with helpful hint if no tables found
    if not tables:
        return {
            "tables": [],
            "count": 0,
            "hint": f"No tables found in {schema or 'current schema'}. This schema may contain VIEWS instead of tables. Use list_views to check for views."
        }

    return {"tables": tables, "count": len(tables)}
cursor.fetchall() + used_name = try_name + break + except Exception as inner_e: + if "does not exist" in str(inner_e).lower() and try_name == table_name: + continue # Try quoted version + raise + + if rows is None: + rows = [] + + # Get column names from description using attribute access (safer for ResultMetadata) + desc_columns = [] + if cursor.description: + for desc in cursor.description: + col_name = getattr(desc, 'name', None) or desc[0] + desc_columns.append(str(col_name).upper()) + + columns = [] + for row in rows: + row_dict = {} + for i, val in enumerate(row): + if i < len(desc_columns): + row_dict[desc_columns[i]] = val + + col_info = { + "name": row_dict.get("COLUMN_NAME"), + "type": row_dict.get("DATA_TYPE"), + "nullable": row_dict.get("IS_NULLABLE") == "YES", + "default": row_dict.get("COLUMN_DEFAULT"), + "comment": row_dict.get("COMMENT"), + "kind": row_dict.get("KIND") + } + columns.append(col_info) + + cursor.close() + + return { + "database": db, + "schema": schema, + "table": table, + "full_name": table_name, + "quoted_name": used_name if used_name != table_name else None, + "columns": columns, + "column_count": len(columns), + "hint": "Use the quoted_name in queries if provided, as this object has case-sensitive naming." if used_name != table_name else None + } + + except Exception as e: + logger.error(f"Failed to describe table {table_name}: {e}") + error_str = str(e) + suggestions = [] + + if "does not exist" in error_str.lower(): + suggestions.append("The object may be a VIEW. Use list_views to check available views in the schema.") + suggestions.append("Object names are case-sensitive. 
async def read_data(query: str, database: str = "default",
                    max_rows: int = 50000) -> Dict[str, Any]:
    """Execute a SELECT query and return results.

    Results are cached for ``CACHE_TTL``. The cache key covers the query,
    the database alias and ``max_rows``, so a result that was cut off at a
    small limit is never served to a call asking for more rows.

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Connection alias passed to the connection pool.
        max_rows: Maximum number of rows to return.

    Returns:
        On success ``columns``, ``rows`` (list of dicts), ``row_count``,
        ``truncated`` (true only when the query produced more than
        ``max_rows`` rows) and ``database``; cache hits add ``cached`` and
        ``cache_age_seconds``. On failure a dict with ``error`` and
        ``suggestion``/``suggestions``. Errors are reported in the return
        value, not raised.
    """

    # Validate query
    is_valid, message = validate_query(query)
    if not is_valid:
        return {
            "error": message,
            "suggestion": "Use SELECT statements only. For schema information, use describe_table."
        }

    # Check cache (max_rows is part of the key because it shapes the result)
    cache_key = get_cache_key(f"{max_rows}:{query}", database)
    cached_entry = query_cache.get(cache_key)
    if cached_entry is not None:
        cached_result, cached_time = cached_entry
        age = datetime.now() - cached_time
        if age < CACHE_TTL:
            logger.info("Returning cached result for query")
            return {
                **cached_result,
                "cached": True,
                "cache_age_seconds": age.total_seconds()
            }
        # Expired: drop it so stale results do not pile up.
        query_cache.pop(cache_key, None)

    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute(query)

            # Get column names
            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            rows = []
            truncated = False
            for row in cursor:
                if len(rows) >= max_rows:
                    # A further row exists beyond the limit.
                    truncated = True
                    logger.warning(f"Query returned more than {max_rows} rows, truncating")
                    break

                # datetime and bytes values are stringified; everything else
                # (including None) is passed through unchanged.
                rows.append({
                    col: str(row[col]) if isinstance(row[col], (datetime, bytes)) else row[col]
                    for col in columns
                })
        finally:
            # Release the cursor even when execute/iteration fails.
            cursor.close()

        result = {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated,
            "database": database
        }

        # Cache the result
        query_cache[cache_key] = (result, datetime.now())

        return result

    except Exception as e:
        logger.error(f"Query execution failed: {e}")
        error_str = str(e)
        suggestions = []

        if "does not exist" in error_str.lower() or "not authorized" in error_str.lower():
            suggestions.append("Object names in Snowflake are case-sensitive when created with quotes. Try wrapping table/view names in double quotes.")
            suggestions.append("Use list_views to check if the object is a view rather than a table.")
        else:
            suggestions.append("Check your query syntax. Use describe_table to see available columns.")

        return {
            "error": error_str,
            "suggestions": suggestions
        }
async def describe_query(query: str, database: str = "default") -> Dict[str, Any]:
    """Preview query output columns without executing the full query.

    Any trailing ``LIMIT``/``OFFSET`` is removed and ``LIMIT 0`` appended so
    that only column metadata is produced. If that statement fails, the
    query is retried wrapped as a subquery; if the retry fails as well, the
    error from the first attempt (the one about the caller's own query) is
    the one reported.

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Connection alias passed to the connection pool.

    Returns:
        ``columns`` (one dict per entry of ``cursor.description``),
        ``column_count``, ``database`` and ``query_preview`` (first 200
        characters), or a dict with ``error`` on failure.
    """

    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            # Trim whitespace before removing the terminator, so "…; " loses
            # its semicolon too (it would be a syntax error inside a subquery).
            clean_query = query.strip().rstrip(';').strip()
            # Remove existing LIMIT and OFFSET clauses (case insensitive)
            clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
            clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)

            # Add LIMIT 0 to get column metadata without fetching data
            limited_query = f"{clean_query} LIMIT 0"

            try:
                cursor.execute(limited_query)
            except Exception as direct_error:
                # Fallback: wrap in subquery if direct approach fails
                limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0"
                try:
                    cursor.execute(limited_query)
                except Exception:
                    # Both forms failed; the first error refers to the text
                    # the caller actually wrote.
                    raise direct_error

            columns = []
            if cursor.description:
                for desc in cursor.description:
                    columns.append({
                        "name": desc[0],
                        "type_code": desc[1],
                        "display_size": desc[2],
                        "internal_size": desc[3],
                        "precision": desc[4],
                        "scale": desc[5],
                        "nullable": desc[6]
                    })
        finally:
            # Release the cursor even when an execute fails.
            cursor.close()

        return {
            "columns": columns,
            "column_count": len(columns),
            "database": database,
            "query_preview": query[:200] + "..." if len(query) > 200 else query
        }

    except Exception as e:
        logger.error(f"Failed to describe query: {e}")
        return {
            "error": str(e),
            "suggestion": "Check query syntax. The query must be a valid SELECT statement."
        }
async def get_query_status(query_id: str, database: str = "default") -> Dict[str, Any]:
    """Report where an asynchronously submitted query currently stands.

    Asks the pooled connection for the status of ``query_id``, mirrors the
    status name into the local ``async_queries`` tracking entry when one
    exists, and returns a summary. Failures are logged and returned as
    ``{"error": ..., "query_id": ...}`` instead of being raised.
    """
    try:
        connection = connection_pool.get_connection(database)
        status = connection.get_query_status(query_id)

        # Prefer the enum member name; fall back to the plain string form.
        if hasattr(status, 'name'):
            status_name = status.name
        else:
            status_name = str(status)

        still_running = connection.is_still_running(status)
        failed = connection.is_an_error(status)

        is_tracked = query_id in async_queries
        if is_tracked:
            # Keep the locally tracked record in step with the server.
            async_queries[query_id]["status"] = status_name

        summary = {
            "query_id": query_id,
            "status": status_name,
            "is_running": still_running,
            "is_error": failed,
            "can_fetch_results": not (still_running or failed)
        }

        if is_tracked:
            # Attach what was recorded when the query was submitted.
            summary["query_info"] = async_queries[query_id]

        return summary

    except Exception as e:
        logger.error(f"Failed to get query status: {e}")
        return {"error": str(e), "query_id": query_id}
async def read_data_paginated(query: str, database: str = "default",
                              page_size: int = 1000, page: int = 1) -> Dict[str, Any]:
    """Execute a query with pagination support (offset/limit).

    The query's own trailing ``LIMIT``/``OFFSET`` is removed and the query
    is wrapped in ``SELECT * FROM (...) LIMIT page_size OFFSET n``. The total
    row count is obtained with a ``COUNT(*)`` over the same query and cached
    for ``CACHE_TTL``.

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Connection alias passed to the connection pool.
        page_size: Rows per page, 1..10000.
        page: 1-based page number.

    Returns:
        ``columns``, ``rows``, ``pagination`` (page, page_size, total_rows,
        total_pages, has_next, has_previous) and ``database``; or a dict with
        ``error``. Errors are reported in the return value, not raised.
        The cheap argument checks run before the query itself is validated.
    """

    if page < 1:
        return {"error": "Page must be >= 1"}
    if page_size < 1 or page_size > 10000:
        return {"error": "Page size must be between 1 and 10000"}

    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    offset = (page - 1) * page_size

    # Strip existing LIMIT/OFFSET from query for proper pagination. Whitespace
    # is trimmed before the terminator so "…; " loses its semicolon as well
    # (it would be a syntax error inside the wrapping subquery).
    clean_query = query.strip().rstrip(';').strip()
    clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
    clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)

    # Wrap query with pagination
    paginated_query = f"""
        SELECT * FROM (
            {clean_query}
        ) AS paginated_result
        LIMIT {page_size} OFFSET {offset}
    """

    # Also get total count (cached for efficiency)
    count_cache_key = get_cache_key(f"COUNT:{clean_query}", database)
    total_count = None

    if count_cache_key in query_cache:
        cached_count, cached_time = query_cache[count_cache_key]
        if datetime.now() - cached_time < CACHE_TTL:
            total_count = cached_count

    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            # Get total count if not cached
            if total_count is None:
                count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq"
                cursor.execute(count_query)
                count_row = cursor.fetchone()
                total_count = count_row["TOTAL"] if count_row else 0
                query_cache[count_cache_key] = (total_count, datetime.now())

            # Execute paginated query
            cursor.execute(paginated_query)

            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            # datetime and bytes values are stringified; everything else
            # (including None) is passed through unchanged.
            rows = [
                {
                    col: str(row[col]) if isinstance(row[col], (datetime, bytes)) else row[col]
                    for col in columns
                }
                for row in cursor
            ]
        finally:
            # Release the cursor even when an execute fails.
            cursor.close()

        # Ceiling division; zero rows means zero pages.
        total_pages = (total_count + page_size - 1) // page_size if total_count else 0

        return {
            "columns": columns,
            "rows": rows,
            "pagination": {
                "page": page,
                "page_size": page_size,
                "total_rows": total_count,
                "total_pages": total_pages,
                "has_next": page < total_pages,
                "has_previous": page > 1
            },
            "database": database
        }

    except Exception as e:
        logger.error(f"Paginated query failed: {e}")
        return {"error": str(e)}
async def read_data_pandas(query: str, database: str = "default",
                           max_rows: int = 50000,
                           output_format: str = "records") -> Dict[str, Any]:
    """Execute a query and return results in a pandas-friendly format.

    output_format options:
    - 'records': List of dicts (default, same as read_data)
    - 'columns': Dict with column arrays {col1: [vals], col2: [vals]}
    - 'split': Dict with columns and data arrays {columns: [...], data: [[...]]}
    - 'csv': CSV-formatted string

    Args:
        query: SQL text; must pass ``validate_query``.
        database: Connection alias passed to the connection pool.
        max_rows: Maximum number of rows to return.
        output_format: One of the formats listed above; checked before the
            query is validated or executed.

    Returns:
        ``format``, ``columns``, ``data``, ``row_count``, ``truncated`` (true
        only when the query produced more than ``max_rows`` rows),
        ``database`` and ``pandas_hint``; or a dict with ``error``. Errors
        are reported in the return value, not raised.
    """

    if output_format not in ('records', 'columns', 'split', 'csv'):
        return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}

    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    try:
        conn = connection_pool.get_connection(database)
        cursor = conn.cursor(DictCursor)
        try:
            cursor.execute(query)

            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            all_rows = []
            truncated = False
            for row in cursor:
                if len(all_rows) >= max_rows:
                    # A further row exists beyond the limit.
                    truncated = True
                    break
                row_data = {}
                for col in columns:
                    value = row[col]
                    if isinstance(value, datetime):
                        row_data[col] = value.isoformat()
                    elif isinstance(value, bytes):
                        # Undecodable bytes are replaced rather than raising.
                        row_data[col] = value.decode('utf-8', errors='replace')
                    else:
                        # None and all other types pass through unchanged.
                        row_data[col] = value
                all_rows.append(row_data)
        finally:
            # Release the cursor even when execute/iteration fails.
            cursor.close()

        # Format output based on requested format
        if output_format == 'records':
            data = all_rows

        elif output_format == 'columns':
            data = {col: [row.get(col) for row in all_rows] for col in columns}

        elif output_format == 'split':
            data = {
                "columns": columns,
                "data": [[row.get(col) for col in columns] for row in all_rows]
            }

        else:  # 'csv' — the only remaining validated option
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=columns)
            writer.writeheader()
            writer.writerows(all_rows)
            data = output.getvalue()

        return {
            "format": output_format,
            "columns": columns,
            "data": data,
            "row_count": len(all_rows),
            "truncated": truncated,
            "database": database,
            "pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
        }

    except Exception as e:
        logger.error(f"Pandas query failed: {e}")
        return {"error": str(e)}
@server.list_tools()
async def list_tools() -> List[Tool]:
    """Return the catalog of tools this server advertises.

    Each entry is a ``Tool`` carrying a name, a human-readable description
    and a JSON-schema ``inputSchema`` describing its arguments. The schemas
    are assembled from small local builders so the properties shared by
    several tools ("database", "query", "max_rows", ...) are spelled once.
    Fresh dicts are created on every call; nothing is shared between tools.
    """

    def text(description, **extra):
        # String-typed property; extras (default, enum) keep their order.
        return {"type": "string", "description": description, **extra}

    def integer(description, default):
        # Integer-typed property with a default value.
        return {"type": "integer", "description": description, "default": default}

    def database(description="Database context"):
        # The "database" property carried by most tools.
        return text(description, default="default")

    def max_rows():
        # Row cap shared by the data-returning tools.
        return integer("Maximum rows to return", 50000)

    def query_id():
        # Identifier handed out by execute_async.
        return text("Snowflake query ID from execute_async")

    def schema(properties=None, required=None):
        # Object schema; "required" is only present when something is required.
        spec = {"type": "object", "properties": properties or {}}
        if required:
            spec["required"] = list(required)
        return spec

    # (name, description, inputSchema) in the order tools are advertised.
    catalog = [
        (
            "test_connection",
            "Test Snowflake connection and get session information",
            schema({"database": database("Database context (default: 'default')")}),
        ),
        (
            "list_databases",
            "List all accessible databases in Snowflake",
            schema(),
        ),
        (
            "list_schemas",
            "List all schemas in a database",
            schema({
                "database_name": text(
                    "Database name (optional, uses current if not specified)"
                ),
            }),
        ),
        (
            "list_tables",
            "List all tables in the current database/schema",
            schema({
                "database": database(),
                "schema": text("Schema filter (optional)"),
                "include_row_counts": {
                    "type": "boolean",
                    "description": "Include row counts",
                    "default": False,
                },
            }),
        ),
        (
            "list_views",
            "List all views in a schema with their descriptions",
            schema(
                {
                    "schema_name": text("Schema name (e.g., 'MY_DB.MY_SCHEMA')"),
                    "database": database(),
                },
                required=["schema_name"],
            ),
        ),
        (
            "describe_table",
            "Get detailed schema information for a table",
            schema(
                {
                    "table_name": text(
                        "Table name (can include database.schema.table)"
                    ),
                    "database": database(),
                },
                required=["table_name"],
            ),
        ),
        (
            "read_data",
            "Execute a SELECT query on Snowflake",
            schema(
                {
                    "query": text("SQL SELECT query to execute"),
                    "database": database(),
                    "max_rows": max_rows(),
                },
                required=["query"],
            ),
        ),
        (
            "get_system_health",
            "Get system health and connection pool statistics",
            schema(),
        ),
        (
            "describe_query",
            "Preview query output columns without executing the full query. Useful for validating complex queries before running them.",
            schema(
                {
                    "query": text("SQL SELECT query to analyze"),
                    "database": database(),
                },
                required=["query"],
            ),
        ),
        (
            "execute_async",
            "Submit a query for asynchronous execution. Returns a query ID for status tracking. Use for long-running queries.",
            schema(
                {
                    "query": text("SQL SELECT query to execute asynchronously"),
                    "database": database(),
                },
                required=["query"],
            ),
        ),
        (
            "get_query_status",
            "Check the status of an async query by its query ID",
            schema(
                {
                    "query_id": query_id(),
                    "database": database(),
                },
                required=["query_id"],
            ),
        ),
        (
            "get_async_results",
            "Retrieve results from a completed async query",
            schema(
                {
                    "query_id": query_id(),
                    "database": database(),
                    "max_rows": max_rows(),
                },
                required=["query_id"],
            ),
        ),
        (
            "list_async_queries",
            "List all tracked async queries and their statuses",
            schema(),
        ),
        (
            "read_data_paginated",
            "Execute a query with pagination support. Returns results page by page with total count.",
            schema(
                {
                    "query": text("SQL SELECT query to execute"),
                    "database": database(),
                    "page_size": integer("Number of rows per page (1-10000)", 1000),
                    "page": integer("Page number (1-based)", 1),
                },
                required=["query"],
            ),
        ),
        (
            "read_data_pandas",
            "Execute a query and return results in a pandas-friendly format. Supports multiple output formats: records, columns, split, csv.",
            schema(
                {
                    "query": text("SQL SELECT query to execute"),
                    "database": database(),
                    "max_rows": max_rows(),
                    "output_format": text(
                        "Output format: 'records' (list of dicts), 'columns' (dict of arrays), 'split' (columns + data arrays), 'csv' (CSV string)",
                        default="records",
                        enum=["records", "columns", "split", "csv"],
                    ),
                },
                required=["query"],
            ),
        ),
    ]

    return [
        Tool(name=tool_name, description=summary, inputSchema=input_schema)
        for tool_name, summary, input_schema in catalog
    ]
@server.call_tool()
async def call_tool(name: str, arguments: Any) -> List[TextContent]:
    """Dispatch an MCP tool call to the matching Snowflake helper coroutine.

    Args:
        name: Tool name, one of those advertised by ``list_tools``.
        arguments: Mapping of tool arguments. ``None`` (or any empty value)
            is treated as "no arguments" so parameterless tools do not crash.

    Returns:
        A single-element list whose ``TextContent`` holds the helper's result
        serialized as indented JSON (values JSON cannot encode are passed
        through ``str``).

    Raises:
        Any failure (unknown tool, missing required argument, helper error)
        is logged and re-raised as ``INTERNAL_ERROR(str(e))``, chained to the
        original exception.
    """
    # Treat a missing arguments object as empty instead of failing on .get().
    args = arguments or {}

    def required(key: str) -> Any:
        """Return args[key]; raise a descriptive error if the key is absent."""
        try:
            return args[key]
        except KeyError:
            # A bare KeyError would surface only as "'query'", which is
            # not actionable for the caller.
            raise ValueError(
                f"Missing required argument '{key}' for tool '{name}'"
            ) from None

    def database() -> Any:
        return args.get("database", "default")

    def max_rows() -> Any:
        return args.get("max_rows", 50000)

    # Each entry builds the helper coroutine lazily, so argument lookups
    # (and their errors) only happen for the tool actually requested.
    handlers = {
        "test_connection": lambda: test_connection(database()),
        "list_databases": lambda: list_databases(),
        "list_schemas": lambda: list_schemas(args.get("database_name")),
        "list_tables": lambda: list_tables(
            database(),
            args.get("schema"),
            args.get("include_row_counts", False),
        ),
        "list_views": lambda: list_views(required("schema_name"), database()),
        "describe_table": lambda: describe_table(
            required("table_name"), database()
        ),
        "read_data": lambda: read_data(required("query"), database(), max_rows()),
        "get_system_health": lambda: get_system_health(),
        "describe_query": lambda: describe_query(required("query"), database()),
        "execute_async": lambda: execute_async(required("query"), database()),
        "get_query_status": lambda: get_query_status(
            required("query_id"), database()
        ),
        "get_async_results": lambda: get_async_results(
            required("query_id"), database(), max_rows()
        ),
        "list_async_queries": lambda: list_async_queries(),
        "read_data_paginated": lambda: read_data_paginated(
            required("query"),
            database(),
            args.get("page_size", 1000),
            args.get("page", 1),
        ),
        "read_data_pandas": lambda: read_data_pandas(
            required("query"),
            database(),
            max_rows(),
            args.get("output_format", "records"),
        ),
    }

    try:
        build_call = handlers.get(name)
        if build_call is None:
            raise ValueError(f"Unknown tool: {name}")

        result = await build_call()

        return [TextContent(
            type="text",
            text=json.dumps(result, indent=2, default=str)
        )]

    except Exception as e:
        logger.error(f"Tool execution failed: {e}")
        # NOTE(review): INTERNAL_ERROR is defined outside this block. If it is
        # the integer JSON-RPC error code rather than an exception class,
        # calling it raises TypeError and hides the real failure - confirm.
        raise INTERNAL_ERROR(str(e)) from e
async def main():
    """Log the startup banner, then run the MCP server over stdio.

    The configured account and user are read from ``SNOWFLAKE_CONFIG`` for
    the banner. The server is then run on the stream pair yielded by
    ``stdio_server()`` with its default initialization options.
    """
    # Function-scope import, kept local to the entry point.
    from mcp.server.stdio import stdio_server

    logger.info("Starting Snowflake MCP Server with SSO Authentication")
    logger.info(f"Account: {SNOWFLAKE_CONFIG['account']}")
    logger.info(f"User: {SNOWFLAKE_CONFIG['user']}")
    logger.info("First connection will open browser for SSO login")

    async with stdio_server() as streams:
        reader, writer = streams
        init_options = server.create_initialization_options()
        await server.run(reader, writer, init_options)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    finally:
        # Runs on every exit path, including Ctrl-C and unexpected errors.
        connection_pool.close_all()