#!/usr/bin/env python3
"""
Snowflake MCP Server with SSO Authentication

Provides read-only Snowflake access via the Model Context Protocol.

NOTE(review): the Snowflake connector calls made by the tool coroutines are
synchronous, so each tool call blocks the event loop while it runs.
"""

import asyncio
import csv
import io
import json
import logging
import os
import re
from contextlib import closing
from typing import Any, Optional, Dict, List
from datetime import datetime, timedelta
import hashlib

import snowflake.connector
from snowflake.connector import DictCursor
from snowflake.connector.constants import QueryStatus  # noqa: F401  (kept for importers)
from mcp.server import Server
from mcp.types import TextContent, Tool, INTERNAL_ERROR  # noqa: F401  (kept for importers)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration — set via environment variables or edit the defaults below
# ---------------------------------------------------------------------------
SNOWFLAKE_CONFIG = {
    "account": os.environ.get("SNOWFLAKE_ACCOUNT", "YOUR_ACCOUNT"),
    "user": os.environ.get("SNOWFLAKE_USER", "your.email@example.com"),
    "authenticator": os.environ.get("SNOWFLAKE_AUTHENTICATOR", "externalbrowser"),
    "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),  # None = account default
    "role": os.environ.get("SNOWFLAKE_ROLE"),  # None = account default
}


class SnowflakeConnection:
    """Manage a single cached Snowflake connection authenticated via SSO."""

    def __init__(self):
        self._conn: Optional[snowflake.connector.SnowflakeConnection] = None
        self._last_used: Optional[datetime] = None
        # A cached connection older than this is discarded instead of reused.
        self._timeout = timedelta(minutes=30)

    def _discard(self) -> None:
        """Close and forget the cached connection; close errors are only logged."""
        if self._conn is None:
            return
        try:
            self._conn.close()
        except Exception as exc:  # best effort: the connection is being dropped anyway
            logger.debug("Ignoring error while closing connection: %s", exc)
        self._conn = None
        self._last_used = None

    def get(self) -> snowflake.connector.SnowflakeConnection:
        """Return a live connection, reusing the cached one when it still works.

        The cached connection is probed with ``SELECT 1``. If the probe fails,
        or the connection has been idle longer than the timeout, it is closed
        and a new connection is created.

        Raises:
            Exception: whatever ``snowflake.connector.connect`` raises when a
                new connection cannot be established.
        """
        if self._conn is not None and self._last_used is not None:
            if datetime.now() - self._last_used < self._timeout:
                try:
                    with closing(self._conn.cursor()) as cursor:
                        cursor.execute("SELECT 1")
                    self._last_used = datetime.now()
                    logger.info("Reusing cached Snowflake connection")
                    return self._conn
                except Exception:
                    logger.info("Cached connection is stale, reconnecting...")
            # Reached on timeout or on a failed probe: never leak the old connection.
            self._discard()

        # Create new connection
        logger.info("Creating new Snowflake connection")
        logger.info("Note: First connection will open browser for SSO login")

        connect_params = {
            "account": SNOWFLAKE_CONFIG["account"],
            "user": SNOWFLAKE_CONFIG["user"],
            "authenticator": SNOWFLAKE_CONFIG["authenticator"],
        }
        # Optional settings are only sent when configured so that the account
        # defaults apply otherwise.
        for optional_key in ("warehouse", "role"):
            if SNOWFLAKE_CONFIG.get(optional_key):
                connect_params[optional_key] = SNOWFLAKE_CONFIG[optional_key]

        try:
            self._conn = snowflake.connector.connect(**connect_params)
            self._last_used = datetime.now()
            logger.info("Successfully connected to Snowflake")
            return self._conn
        except Exception as e:
            logger.error(f"Failed to connect to Snowflake: {e}")
            raise

    def close(self):
        """Close the connection"""
        if self._conn:
            try:
                self._conn.close()
                logger.info("Closed Snowflake connection")
            except Exception as exc:
                logger.debug("Ignoring error while closing connection: %s", exc)
            self._conn = None
            self._last_used = None


# Global connection
connection = SnowflakeConnection()

# Query cache: key -> (value, time stored)
query_cache: Dict[str, tuple] = {}
CACHE_TTL = timedelta(minutes=5)
# Upper bound on cached entries; each entry may hold a full result set.
CACHE_MAX_ENTRIES = 128

# Async query tracking
async_queries: Dict[str, Dict[str, Any]] = {}

# String literals, quoted identifiers and comments. They are blanked out before
# keyword checks so that keywords inside them are neither flagged nor able to
# hide real statements.
_SQL_NOISE_RE = re.compile(
    "|".join(
        (
            r"'(?:[^'\\]|\\.|'')*'",  # single-quoted string ('' and \' escapes)
            r'"(?:[^"]|"")*"',  # double-quoted identifier
            r"\$\$.*?\$\$",  # dollar-quoted string
            r"--[^\n]*",  # line comment
            r"//[^\n]*",  # line comment, alternate form
            r"/\*.*?\*/",  # block comment
        )
    ),
    re.DOTALL,
)

_ALLOWED_STARTS = ('SELECT', 'WITH', 'SHOW', 'DESCRIBE', 'DESC', 'LIST')
_FORBIDDEN_KEYWORDS = (
    'INSERT', 'UPDATE', 'DELETE', 'DROP', 'TRUNCATE', 'CREATE', 'ALTER',
    'GRANT', 'REVOKE', 'MERGE', 'COPY', 'PUT', 'CALL', 'EXECUTE',
)
_FORBIDDEN_RE = re.compile(r"\b(?:" + "|".join(_FORBIDDEN_KEYWORDS) + r")\b")

# One identifier part: unquoted (letters, digits, _ and $) or double-quoted.
_IDENT_PART = r'(?:[A-Za-z_][A-Za-z0-9_$]*|"(?:[^"]|"")+")'
_QUALIFIED_NAME_RE = re.compile(_IDENT_PART + r"(?:\." + _IDENT_PART + r"){0,2}")
_IDENT_PART_RE = re.compile(_IDENT_PART)

_TRAILING_LIMIT_RE = re.compile(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', re.IGNORECASE)
_TRAILING_OFFSET_RE = re.compile(r'\s+OFFSET\s+\d+\s*$', re.IGNORECASE)


def _checked_identifier(name: Any, what: str) -> str:
    """Return *name* stripped, or raise ValueError if it is not a plain object name.

    Accepts up to three dot-separated parts, each unquoted or double-quoted.
    Used before interpolating a caller-supplied name into a SHOW statement.
    """
    if not isinstance(name, str) or not _QUALIFIED_NAME_RE.fullmatch(name.strip()):
        raise ValueError(f"Invalid {what}: {name!r}")
    return name.strip()


def _as_bool(value: Any) -> bool:
    """Interpret a boolean-like cell (True, 'true', 'YES', 'Y') as a bool."""
    if isinstance(value, bool):
        return value
    return str(value).strip().upper() in ("TRUE", "YES", "Y")


def _cache_lookup(key: str) -> Optional[tuple]:
    """Return ``(value, stored_at)`` for a fresh cache entry, else None."""
    entry = query_cache.get(key)
    if entry is None:
        return None
    if datetime.now() - entry[1] >= CACHE_TTL:
        query_cache.pop(key, None)
        return None
    return entry


def _cache_store(key: str, value: Any) -> None:
    """Store *value*, first dropping expired entries and enforcing the size cap."""
    now = datetime.now()
    expired = [k for k, (_, stored_at) in query_cache.items() if now - stored_at >= CACHE_TTL]
    for stale_key in expired:
        del query_cache[stale_key]
    while len(query_cache) >= CACHE_MAX_ENTRIES:
        oldest_key = min(query_cache, key=lambda k: query_cache[k][1])
        del query_cache[oldest_key]
    query_cache[key] = (value, now)


def _strip_trailing_limit(query: str) -> str:
    """Remove a trailing semicolon and trailing LIMIT/OFFSET clause from *query*."""
    cleaned = query.rstrip(';').strip()
    cleaned = _TRAILING_LIMIT_RE.sub('', cleaned)
    return _TRAILING_OFFSET_RE.sub('', cleaned)


def _stringify_cell(value: Any) -> Any:
    """Convert datetime/bytes cells to ``str``; pass everything else through."""
    if isinstance(value, (datetime, bytes)):
        return str(value)
    return value


def get_cache_key(query: str) -> str:
    """Generate cache key for query"""
    # MD5 is used only as a compact dictionary key, not for security.
    return hashlib.md5(query.encode()).hexdigest()


def validate_query(query: str) -> tuple[bool, str]:
    """Check that *query* looks like a read-only statement.

    String literals, quoted identifiers and comments are ignored. The
    remaining text must start with an allowed verb (or at least contain
    SELECT, e.g. a parenthesised select) and must not contain a forbidden
    keyword as a whole token, whatever whitespace or punctuation surrounds it.

    Returns:
        ``(True, message)`` when accepted, ``(False, reason)`` otherwise.
    """
    scrubbed = _SQL_NOISE_RE.sub(" ", query).upper().strip()

    # Allow SELECT, WITH (CTEs), and SHOW/DESCRIBE commands
    if not scrubbed.startswith(_ALLOWED_STARTS) and 'SELECT' not in scrubbed:
        return False, "Only SELECT queries are allowed"

    # Block dangerous keywords; report the first one in list order.
    found = set(_FORBIDDEN_RE.findall(scrubbed))
    for keyword in _FORBIDDEN_KEYWORDS:
        if keyword in found:
            return False, f"Query contains forbidden keyword: {keyword}"

    return True, "Query validated successfully"


async def test_connection() -> Dict[str, Any]:
    """Test Snowflake connection and return system information.

    Never raises: on failure a dict with ``status == "error"`` is returned.
    """
    try:
        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            cursor.execute("""
                SELECT
                    CURRENT_ACCOUNT() as account,
                    CURRENT_USER() as user_name,
                    CURRENT_ROLE() as current_role,
                    CURRENT_WAREHOUSE() as warehouse,
                    CURRENT_DATABASE() as database_name,
                    CURRENT_SCHEMA() as schema_name,
                    CURRENT_VERSION() as version
            """)
            row = cursor.fetchone()

        return {
            "status": "connected",
            "account": row["ACCOUNT"],
            "user": row["USER_NAME"],
            "role": row["CURRENT_ROLE"],
            "warehouse": row["WAREHOUSE"],
            "database": row["DATABASE_NAME"],
            "schema": row["SCHEMA_NAME"],
            "version": row["VERSION"],
            "auth_method": SNOWFLAKE_CONFIG["authenticator"],
            "message": "Connection successful"
        }
    except Exception as e:
        logger.error(f"Connection test failed: {e}")
        return {
            "status": "error",
            "error": str(e),
            "message": "If this is your first connection, a browser window should open for SSO login"
        }


async def list_tables(schema: Optional[str] = None, include_row_counts: bool = False) -> Dict[str, Any]:
    """List all tables in the database/schema.

    Args:
        schema: optional schema name (e.g. ``MY_DB.MY_SCHEMA``); validated
            before being placed in the SHOW statement.
        include_row_counts: when False the ``rows`` field of each table is None.

    Raises:
        ValueError: if *schema* is not a valid object name.
        Exception: any error raised by the connector is logged and re-raised.
    """
    try:
        if schema:
            query = f"SHOW TABLES IN SCHEMA {_checked_identifier(schema, 'schema')}"
        else:
            query = "SHOW TABLES"

        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            cursor.execute(query)
            tables = [
                {
                    "database": row.get("database_name"),
                    "schema": row.get("schema_name"),
                    "name": row.get("name"),
                    "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                    "description": row.get("comment"),
                    "kind": row.get("kind"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "rows": row.get("rows") if include_row_counts else None
                }
                for row in cursor
            ]

        if not tables:
            return {
                "tables": [],
                "count": 0,
                "hint": f"No tables found in {schema or 'current schema'}. This schema may contain VIEWS instead of tables. Use list_views to check for views."
            }

        return {"tables": tables, "count": len(tables)}
    except Exception as e:
        logger.error(f"Failed to list tables: {e}")
        raise


async def list_views(schema_name: str) -> List[Dict[str, Any]]:
    """List all views in a schema with their descriptions.

    Raises:
        ValueError: if *schema_name* is not a valid object name.
        Exception: any error raised by the connector is logged and re-raised.
    """
    try:
        checked_schema = _checked_identifier(schema_name, "schema name")
        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            cursor.execute(f"SHOW VIEWS IN SCHEMA {checked_schema}")
            return [
                {
                    "name": row.get("name"),
                    "database": row.get("database_name"),
                    "schema": row.get("schema_name"),
                    "full_name": f"{row.get('database_name')}.{row.get('schema_name')}.{row.get('name')}",
                    "description": row.get("comment"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "is_secure": row.get("is_secure") == "true",
                    "is_materialized": row.get("is_materialized") == "true"
                }
                for row in cursor
            ]
    except Exception as e:
        logger.error(f"Failed to list views in {schema_name}: {e}")
        raise


async def describe_table(table_name: str) -> Dict[str, Any]:
    """Get detailed schema information for a table or view.

    Runs ``SHOW COLUMNS IN TABLE`` for the name as given. If that reports the
    object does not exist and the last name part is unquoted, retries with the
    last part double-quoted (case-sensitive lookup).

    Never raises: failures are returned as a dict with ``error`` and
    ``suggestions`` keys.
    """
    try:
        checked_name = _checked_identifier(table_name, "table name")
        parts = _IDENT_PART_RE.findall(checked_name)

        if len(parts) == 3:
            db, schema, table = parts
        elif len(parts) == 2:
            db = None
            schema, table = parts
        else:
            db = None
            schema = None
            table = table_name

        candidates = [table_name]
        # Only add the quoted fallback when the caller did not quote already;
        # quoting twice would produce an invalid name.
        if not parts[-1].startswith('"'):
            candidates.append(".".join(parts[:-1] + [f'"{parts[-1]}"']))

        conn = connection.get()
        with closing(conn.cursor()) as cursor:
            rows = None
            used_name = table_name
            for position, try_name in enumerate(candidates):
                try:
                    cursor.execute(f"SHOW COLUMNS IN TABLE {try_name}")
                    rows = cursor.fetchall()
                    used_name = try_name
                    break
                except Exception as inner_e:
                    has_fallback = position + 1 < len(candidates)
                    if has_fallback and "does not exist" in str(inner_e).lower():
                        continue
                    raise

            if rows is None:
                rows = []

            desc_columns = []
            if cursor.description:
                for desc in cursor.description:
                    col_name = getattr(desc, 'name', None) or desc[0]
                    desc_columns.append(str(col_name).upper())

        columns = []
        for row in rows:
            # zip stops at the shorter sequence, so extra cells are ignored.
            row_dict = dict(zip(desc_columns, row))
            # SHOW COLUMNS names these columns "null?" and "default"; the
            # information_schema-style names are kept as a fallback.
            nullable_raw = row_dict.get("NULL?", row_dict.get("IS_NULLABLE"))
            columns.append({
                "name": row_dict.get("COLUMN_NAME"),
                "type": row_dict.get("DATA_TYPE"),
                "nullable": _as_bool(nullable_raw),
                "default": row_dict.get("DEFAULT", row_dict.get("COLUMN_DEFAULT")),
                "comment": row_dict.get("COMMENT"),
                "kind": row_dict.get("KIND")
            })

        needed_quoting = used_name != table_name
        return {
            "database": db,
            "schema": schema,
            "table": table,
            "full_name": table_name,
            "quoted_name": used_name if needed_quoting else None,
            "columns": columns,
            "column_count": len(columns),
            "hint": "Use the quoted_name in queries if provided, as this object has case-sensitive naming." if needed_quoting else None
        }
    except Exception as e:
        logger.error(f"Failed to describe table {table_name}: {e}")
        error_str = str(e)
        suggestions = []
        if "does not exist" in error_str.lower():
            suggestions.append("The object may be a VIEW. Use list_views to check available views in the schema.")
            suggestions.append("Object names are case-sensitive. The name may need to be quoted.")
        elif "not authorized" in error_str.lower():
            suggestions.append("Check that your role has SELECT privileges on this object.")
        return {
            "error": error_str,
            "table_name": table_name,
            "suggestions": suggestions
        }


async def read_data(query: str, max_rows: int = 50000) -> Dict[str, Any]:
    """Execute a SELECT query and return results.

    Results are cached for ``CACHE_TTL`` per (query, max_rows) pair so that a
    result truncated for a small ``max_rows`` is never served to a request
    with a larger one.

    Returns:
        Dict with ``columns``, ``rows``, ``row_count`` and ``truncated``
        (True only when rows beyond ``max_rows`` were dropped); cached hits
        add ``cached`` and ``cache_age_seconds``. On failure a dict with
        ``error`` and ``suggestions`` (or ``suggestion``) is returned.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {
            "error": message,
            "suggestion": "Use SELECT statements only. For schema information, use describe_table."
        }

    cache_key = get_cache_key(f"{max_rows}:{query}")
    cached_entry = _cache_lookup(cache_key)
    if cached_entry is not None:
        cached_result, cached_time = cached_entry
        logger.info("Returning cached result for query")
        return {
            **cached_result,
            "cached": True,
            "cache_age_seconds": (datetime.now() - cached_time).total_seconds()
        }

    try:
        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            cursor.execute(query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            rows = []
            truncated = False
            for i, row in enumerate(cursor):
                if i >= max_rows:
                    logger.warning(f"Query returned more than {max_rows} rows, truncating")
                    truncated = True
                    break
                rows.append({col: _stringify_cell(row[col]) for col in columns})

        result = {
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated,
        }
        _cache_store(cache_key, result)
        return result
    except Exception as e:
        logger.error(f"Query execution failed: {e}")
        error_str = str(e)
        suggestions = []
        if "does not exist" in error_str.lower() or "not authorized" in error_str.lower():
            suggestions.append("Object names in Snowflake are case-sensitive when created with quotes. Try wrapping table/view names in double quotes.")
            suggestions.append("Use list_views to check if the object is a view rather than a table.")
        else:
            suggestions.append("Check your query syntax. Use describe_table to see available columns.")
        return {
            "error": error_str,
            "suggestions": suggestions
        }


async def list_databases() -> List[Dict[str, Any]]:
    """List all accessible databases.

    Raises:
        Exception: any error raised by the connector is logged and re-raised.
    """
    try:
        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            cursor.execute("SHOW DATABASES")
            return [
                {
                    "name": row.get("name"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "owner": row.get("owner"),
                    "comment": row.get("comment"),
                    "origin": row.get("origin")
                }
                for row in cursor
            ]
    except Exception as e:
        logger.error(f"Failed to list databases: {e}")
        raise


async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
    """List all schemas in a database (the current one when none is given).

    Raises:
        ValueError: if *database_name* is not a valid object name.
        Exception: any error raised by the connector is logged and re-raised.
    """
    try:
        if database_name:
            statement = f"SHOW SCHEMAS IN DATABASE {_checked_identifier(database_name, 'database name')}"
        else:
            statement = "SHOW SCHEMAS"

        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            cursor.execute(statement)
            return [
                {
                    "name": row.get("name"),
                    "database": row.get("database_name"),
                    "created": str(row.get("created_on")) if row.get("created_on") else None,
                    "owner": row.get("owner")
                }
                for row in cursor
            ]
    except Exception as e:
        logger.error(f"Failed to list schemas: {e}")
        raise


async def get_system_health() -> Dict[str, Any]:
    """Get system health, cache statistics and the result of a connection test.

    ``status`` is "degraded" whenever the connection test does not report
    "connected".
    """
    health_info = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "cache": {
            "query_cache_size": len(query_cache),
            "cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
        },
        "authentication": {
            "method": SNOWFLAKE_CONFIG["authenticator"],
            "account": SNOWFLAKE_CONFIG["account"],
            "user": SNOWFLAKE_CONFIG["user"],
        }
    }

    try:
        test_result = await test_connection()
        health_info["connection"] = test_result
        # test_connection reports failures in its return value, not by raising.
        if test_result.get("status") != "connected":
            health_info["status"] = "degraded"
    except Exception:
        health_info["status"] = "degraded"
        health_info["connection"] = {"status": "error"}

    return health_info


async def describe_query(query: str) -> Dict[str, Any]:
    """Preview query output columns without executing the full query.

    Uses LIMIT 0 to get column metadata efficiently. If appending LIMIT 0
    fails, the query is retried wrapped in a subquery.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    try:
        clean_query = _strip_trailing_limit(query)
        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            try:
                cursor.execute(f"{clean_query} LIMIT 0")
            except Exception:
                cursor.execute(f"SELECT * FROM ({clean_query}) AS subq LIMIT 0")

            columns = [
                {
                    "name": desc[0],
                    "type_code": desc[1],
                    "display_size": desc[2],
                    "internal_size": desc[3],
                    "precision": desc[4],
                    "scale": desc[5],
                    "nullable": desc[6]
                }
                for desc in (cursor.description or [])
            ]

        return {
            "columns": columns,
            "column_count": len(columns),
            "query_preview": query[:200] + "..." if len(query) > 200 else query
        }
    except Exception as e:
        logger.error(f"Failed to describe query: {e}")
        return {
            "error": str(e),
            "suggestion": "Check query syntax. The query must be a valid SELECT statement."
        }


async def execute_async(query: str) -> Dict[str, Any]:
    """Submit a query for asynchronous execution. Returns query ID for status tracking."""
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    try:
        conn = connection.get()
        with closing(conn.cursor()) as cursor:
            cursor.execute_async(query)
            query_id = cursor.sfqid

        async_queries[query_id] = {
            "query": query[:500],
            "submitted_at": datetime.now().isoformat(),
            "status": "RUNNING"
        }

        return {
            "query_id": query_id,
            "status": "SUBMITTED",
            "message": "Query submitted for async execution. Use get_query_status to check progress.",
        }
    except Exception as e:
        logger.error(f"Failed to submit async query: {e}")
        return {"error": str(e)}


async def get_query_status(query_id: str) -> Dict[str, Any]:
    """Check the status of an async query by its query ID."""
    try:
        conn = connection.get()
        status = conn.get_query_status(query_id)
        status_name = status.name if hasattr(status, 'name') else str(status)

        is_running = conn.is_still_running(status)
        is_error = conn.is_an_error(status)

        result = {
            "query_id": query_id,
            "status": status_name,
            "is_running": is_running,
            "is_error": is_error,
            "can_fetch_results": not is_running and not is_error
        }

        tracked = async_queries.get(query_id)
        if tracked is not None:
            tracked["status"] = status_name
            result["query_info"] = tracked

        return result
    except Exception as e:
        logger.error(f"Failed to get query status: {e}")
        return {"error": str(e), "query_id": query_id}


async def get_async_results(query_id: str, max_rows: int = 50000) -> Dict[str, Any]:
    """Retrieve results from a completed async query.

    ``truncated`` is True only when the query produced more than ``max_rows``
    rows. After a successful fetch the query is removed from tracking.
    """
    try:
        conn = connection.get()

        status = conn.get_query_status_throw_if_error(query_id)
        if conn.is_still_running(status):
            return {
                "error": "Query is still running",
                "query_id": query_id,
                "status": status.name if hasattr(status, 'name') else str(status)
            }

        row_limit = max(max_rows, 0)
        with closing(conn.cursor()) as cursor:
            cursor.get_results_from_sfqid(query_id)
            # Fetch one extra row so truncation can be told apart from a
            # result that has exactly row_limit rows.
            raw_rows = cursor.fetchmany(row_limit + 1)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []

        truncated = len(raw_rows) > row_limit
        rows = []
        for row in raw_rows[:row_limit]:
            rows.append({
                col: _stringify_cell(row[idx] if idx < len(row) else None)
                for idx, col in enumerate(columns)
            })

        async_queries.pop(query_id, None)

        return {
            "query_id": query_id,
            "columns": columns,
            "rows": rows,
            "row_count": len(rows),
            "truncated": truncated
        }
    except Exception as e:
        logger.error(f"Failed to get async results: {e}")
        return {"error": str(e), "query_id": query_id}


async def read_data_paginated(query: str, page_size: int = 1000, page: int = 1) -> Dict[str, Any]:
    """Execute a query with pagination support (offset/limit).

    The query (minus any trailing LIMIT/OFFSET) is wrapped in a subquery and
    paged with LIMIT/OFFSET; the total row count is computed once and cached
    for ``CACHE_TTL``. Page contents are only stable across calls if the query
    itself has a deterministic ORDER BY.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    if page < 1:
        return {"error": "Page must be >= 1"}
    if page_size < 1 or page_size > 10000:
        return {"error": "Page size must be between 1 and 10000"}

    offset = (page - 1) * page_size
    clean_query = _strip_trailing_limit(query)

    paginated_query = f"""
        SELECT * FROM (
            {clean_query}
        ) AS paginated_result
        LIMIT {page_size} OFFSET {offset}
    """

    count_cache_key = get_cache_key(f"COUNT:{clean_query}")
    cached_count = _cache_lookup(count_cache_key)
    total_count = cached_count[0] if cached_count is not None else None

    try:
        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            if total_count is None:
                cursor.execute(f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq")
                count_row = cursor.fetchone()
                total_count = count_row["TOTAL"] if count_row else 0
                _cache_store(count_cache_key, total_count)

            cursor.execute(paginated_query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            rows = [{col: _stringify_cell(row[col]) for col in columns} for row in cursor]

        # Ceiling division; zero rows means zero pages.
        total_pages = (total_count + page_size - 1) // page_size if total_count else 0

        return {
            "columns": columns,
            "rows": rows,
            "pagination": {
                "page": page,
                "page_size": page_size,
                "total_rows": total_count,
                "total_pages": total_pages,
                "has_next": page < total_pages,
                "has_previous": page > 1
            },
        }
    except Exception as e:
        logger.error(f"Paginated query failed: {e}")
        return {"error": str(e)}


async def read_data_pandas(query: str, max_rows: int = 50000, output_format: str = "records") -> Dict[str, Any]:
    """Execute a query and return results in a pandas-friendly format.

    output_format options:
    - 'records': List of dicts (default, same as read_data)
    - 'columns': Dict with column arrays {col1: [vals], col2: [vals]}
    - 'split': Dict with columns and data arrays {columns: [...], data: [[...]]}
    - 'csv': CSV-formatted string

    ``truncated`` is True only when rows beyond ``max_rows`` were dropped.
    """
    is_valid, message = validate_query(query)
    if not is_valid:
        return {"error": message}

    if output_format not in ('records', 'columns', 'split', 'csv'):
        return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}

    try:
        conn = connection.get()
        with closing(conn.cursor(DictCursor)) as cursor:
            cursor.execute(query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []

            all_rows = []
            truncated = False
            for i, row in enumerate(cursor):
                if i >= max_rows:
                    truncated = True
                    break
                row_data = {}
                for col in columns:
                    value = row[col]
                    if isinstance(value, datetime):
                        row_data[col] = value.isoformat()
                    elif isinstance(value, bytes):
                        row_data[col] = value.decode('utf-8', errors='replace')
                    else:
                        row_data[col] = value
                all_rows.append(row_data)

        if output_format == 'records':
            data = all_rows
        elif output_format == 'columns':
            data = {col: [row.get(col) for row in all_rows] for col in columns}
        elif output_format == 'split':
            data = {
                "columns": columns,
                "data": [[row.get(col) for col in columns] for row in all_rows]
            }
        else:  # 'csv' — the format was validated above
            output = io.StringIO()
            writer = csv.DictWriter(output, fieldnames=columns)
            writer.writeheader()
            writer.writerows(all_rows)
            data = output.getvalue()

        return {
            "format": output_format,
            "columns": columns,
            "data": data,
            "row_count": len(all_rows),
            "truncated": truncated,
            "pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
        }
    except Exception as e:
        logger.error(f"Pandas query failed: {e}")
        return {"error": str(e)}


async def list_async_queries() -> Dict[str, Any]:
    """List all tracked async queries and their statuses."""
    return {
        "queries": async_queries,
        "count": len(async_queries)
    }


# Create MCP server
server = Server("snowflake-mcp")


@server.list_tools()
async def handle_list_tools() -> List[Tool]:
    """List available MCP tools"""
    return [
        Tool(
            name="test_connection",
            description="Test Snowflake connection and get session information",
            inputSchema={
                "type": "object",
                "properties": {}
            }
        ),
        Tool(
            name="list_databases",
            description="List all accessible databases in Snowflake",
            inputSchema={
                "type": "object",
                "properties": {}
            }
        ),
        Tool(
            name="list_schemas",
            description="List all schemas in a database",
            inputSchema={
                "type": "object",
                "properties": {
                    "database_name": {
                        "type": "string",
                        "description": "Database name (optional, uses current if not specified)"
                    }
                }
            }
        ),
        Tool(
            name="list_tables",
            description="List all tables in the current database/schema",
            inputSchema={
                "type": "object",
                "properties": {
                    "schema": {
                        "type": "string",
                        "description": "Schema filter (optional, e.g. 'MY_DB.MY_SCHEMA')"
                    },
                    "include_row_counts": {
                        "type": "boolean",
                        "description": "Include row counts",
                        "default": False
                    }
                }
            }
        ),
        Tool(
            name="list_views",
            description="List all views in a schema with their descriptions",
            inputSchema={
                "type": "object",
                "properties": {
                    "schema_name": {
                        "type": "string",
                        "description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')"
                    }
                },
                "required": ["schema_name"]
            }
        ),
        Tool(
            name="describe_table",
            description="Get detailed schema information for a table or view",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {
                        "type": "string",
                        "description": "Table name (can include database.schema.table)"
                    }
                },
                "required": ["table_name"]
            }
        ),
        Tool(
            name="read_data",
            description="Execute a SELECT query on Snowflake",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute"
                    },
                    "max_rows": {
                        "type": "integer",
                        "description": "Maximum rows to return",
                        "default": 50000
                    }
                },
                "required": ["query"]
            }
        ),
        Tool(
            name="get_system_health",
            description="Get system health and connection pool statistics",
            inputSchema={
                "type": "object",
                "properties": {}
            }
        ),
        Tool(
            name="describe_query",
            description="Preview query output columns without executing the full query. Useful for validating complex queries before running them.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to analyze"
                    }
                },
                "required": ["query"]
            }
        ),
        Tool(
            name="execute_async",
            description="Submit a query for asynchronous execution. Returns a query ID for status tracking. Use for long-running queries.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute asynchronously"
                    }
                },
                "required": ["query"]
            }
        ),
        Tool(
            name="get_query_status",
            description="Check the status of an async query by its query ID",
            inputSchema={
                "type": "object",
                "properties": {
                    "query_id": {
                        "type": "string",
                        "description": "Snowflake query ID from execute_async"
                    }
                },
                "required": ["query_id"]
            }
        ),
        Tool(
            name="get_async_results",
            description="Retrieve results from a completed async query",
            inputSchema={
                "type": "object",
                "properties": {
                    "query_id": {
                        "type": "string",
                        "description": "Snowflake query ID from execute_async"
                    },
                    "max_rows": {
                        "type": "integer",
                        "description": "Maximum rows to return",
                        "default": 50000
                    }
                },
                "required": ["query_id"]
            }
        ),
        Tool(
            name="list_async_queries",
            description="List all tracked async queries and their statuses",
            inputSchema={
                "type": "object",
                "properties": {}
            }
        ),
        Tool(
            name="read_data_paginated",
            description="Execute a query with pagination support. Returns results page by page with total count.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute"
                    },
                    "page_size": {
                        "type": "integer",
                        "description": "Number of rows per page (1-10000)",
                        "default": 1000
                    },
                    "page": {
                        "type": "integer",
                        "description": "Page number (1-based)",
                        "default": 1
                    }
                },
                "required": ["query"]
            }
        ),
        Tool(
            name="read_data_pandas",
            description="Execute a query and return results in a pandas-friendly format. Supports multiple output formats: records, columns, split, csv.",
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "SQL SELECT query to execute"
                    },
                    "max_rows": {
                        "type": "integer",
                        "description": "Maximum rows to return",
                        "default": 50000
                    },
                    "output_format": {
                        "type": "string",
                        "description": "Output format: 'records' (list of dicts), 'columns' (dict of arrays), 'split' (columns + data arrays), 'csv' (CSV string)",
                        "default": "records",
                        "enum": ["records", "columns", "split", "csv"]
                    }
                },
                "required": ["query"]
            }
        )
    ]


def _require(arguments: Dict[str, Any], key: str) -> Any:
    """Return ``arguments[key]`` or raise ValueError naming the missing argument."""
    try:
        return arguments[key]
    except KeyError:
        raise ValueError(f"Missing required argument: {key}") from None


@server.call_tool()
async def call_tool(name: str, arguments: Any) -> List[TextContent]:
    """Dispatch a tool call by name and return its result as pretty-printed JSON.

    Raises:
        ValueError: for an unknown tool name or a missing required argument.
        Exception: any error raised by the tool is logged and re-raised
            unchanged so the caller sees the original message.
    """
    # A tool invoked without arguments may arrive as None.
    arguments = arguments or {}
    try:
        if name == "test_connection":
            result = await test_connection()
        elif name == "list_databases":
            result = await list_databases()
        elif name == "list_schemas":
            result = await list_schemas(arguments.get("database_name"))
        elif name == "list_tables":
            result = await list_tables(
                arguments.get("schema"),
                arguments.get("include_row_counts", False),
            )
        elif name == "list_views":
            result = await list_views(_require(arguments, "schema_name"))
        elif name == "describe_table":
            result = await describe_table(_require(arguments, "table_name"))
        elif name == "read_data":
            result = await read_data(
                _require(arguments, "query"),
                arguments.get("max_rows", 50000),
            )
        elif name == "get_system_health":
            result = await get_system_health()
        elif name == "describe_query":
            result = await describe_query(_require(arguments, "query"))
        elif name == "execute_async":
            result = await execute_async(_require(arguments, "query"))
        elif name == "get_query_status":
            result = await get_query_status(_require(arguments, "query_id"))
        elif name == "get_async_results":
            result = await get_async_results(
                _require(arguments, "query_id"),
                arguments.get("max_rows", 50000),
            )
        elif name == "list_async_queries":
            result = await list_async_queries()
        elif name == "read_data_paginated":
            result = await read_data_paginated(
                _require(arguments, "query"),
                arguments.get("page_size", 1000),
                arguments.get("page", 1),
            )
        elif name == "read_data_pandas":
            result = await read_data_pandas(
                _require(arguments, "query"),
                arguments.get("max_rows", 50000),
                arguments.get("output_format", "records"),
            )
        else:
            raise ValueError(f"Unknown tool: {name}")

        # default=str covers values json cannot encode natively (dates, Decimal, ...).
        return [TextContent(
            type="text",
            text=json.dumps(result, indent=2, default=str)
        )]
    except Exception as e:
        logger.error(f"Tool execution failed: {e}")
        # INTERNAL_ERROR is an integer error code, not an exception class, so
        # it cannot be raised; propagate the original exception instead.
        raise


async def main():
    """Run the MCP server"""
    from mcp.server.stdio import stdio_server

    logger.info("Starting Snowflake MCP Server with SSO Authentication")
    logger.info(f"Account: {SNOWFLAKE_CONFIG['account']}")
    logger.info(f"User: {SNOWFLAKE_CONFIG['user']}")
    logger.info("First connection will open browser for SSO login")

    async with stdio_server() as (read_stream, write_stream):
        await server.run(
            read_stream,
            write_stream,
            server.create_initialization_options()
        )


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    finally:
        connection.close()