From 5dc546014adae22412a75e10122ee2dfdc2424ef Mon Sep 17 00:00:00 2001 From: A Charlwood Date: Mon, 16 Mar 2026 11:14:04 +0000 Subject: [PATCH] Removed dead code/unneeded env setting --- README.md | 10 +- snowflake_mcp_server.py | 304 ++++++++++------------------------------ 2 files changed, 81 insertions(+), 233 deletions(-) diff --git a/README.md b/README.md index dca2edd..002f791 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Gives Claude (or any MCP client) the ability to explore your Snowflake account ## Features - SSO authentication (opens browser on first connection, then reuses the session) -- Connection pooling with 30-minute timeout +- Single persistent connection with 30-minute timeout - Query caching (5 min TTL) - Read-only safety: only SELECT / SHOW / DESCRIBE queries are allowed - Async query support for long-running queries @@ -37,14 +37,14 @@ Set environment variables (recommended) or edit the defaults at the top of `snow export SNOWFLAKE_ACCOUNT="your-account-id" # e.g. "xy12345.eu-west-1" export SNOWFLAKE_USER="your.email@company.com" -# Optional — defaults shown -export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser +# Optional +export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser (default) export SNOWFLAKE_WAREHOUSE="" # uses account default if empty export SNOWFLAKE_ROLE="" # uses account default if empty -export SNOWFLAKE_DATABASE="" # uses account default if empty -export SNOWFLAKE_SCHEMA="" # uses account default if empty ``` +Once connected, use `list_databases` and `list_schemas` to discover what you have access to. + On Windows (PowerShell): ```powershell $env:SNOWFLAKE_ACCOUNT = "your-account-id" diff --git a/snowflake_mcp_server.py b/snowflake_mcp_server.py index 2980f49..0eb006d 100644 --- a/snowflake_mcp_server.py +++ b/snowflake_mcp_server.py @@ -37,101 +37,73 @@ SNOWFLAKE_CONFIG = { "role": os.environ.get("SNOWFLAKE_ROLE"), # None = account default } -# Available databases/schemas to expose -DATABASES = { - "default": { - "database": os.environ.get("SNOWFLAKE_DATABASE"), # None = account default - "schema": os.environ.get("SNOWFLAKE_SCHEMA"), # None = account default - "description": "Default database and schema" - }, -} - -class SnowflakeConnectionPool: - """Manages Snowflake connections with SSO authentication""" +class SnowflakeConnection: + """Manages a single Snowflake connection with SSO authentication""" def __init__(self): - self.connections: Dict[str, snowflake.connector.SnowflakeConnection] = {} - self.last_used: Dict[str, datetime] = {} - self.connection_timeout = timedelta(minutes=30) + self._conn: Optional[snowflake.connector.SnowflakeConnection] = None + self._last_used: Optional[datetime] = None + self._timeout = timedelta(minutes=30) - def get_connection(self, database: str = "default") -> snowflake.connector.SnowflakeConnection: - """Get or create a Snowflake connection with SSO auth""" - - cache_key = database + def get(self) -> snowflake.connector.SnowflakeConnection: + """Get or create the Snowflake connection""" # Check if we have a valid cached connection - if cache_key in self.connections: - conn = self.connections[cache_key] - last = self.last_used.get(cache_key, datetime.now()) - - # Check if connection is still valid and not timed out - if datetime.now() - last < self.connection_timeout: + if self._conn and self._last_used: + if datetime.now() - self._last_used < self._timeout: try: - # Test the connection - cursor = conn.cursor() + cursor = self._conn.cursor() cursor.execute("SELECT 1") cursor.close() - self.last_used[cache_key] = datetime.now() - logger.info(f"Reusing cached Snowflake connection for {database}") - return conn + self._last_used = datetime.now() + logger.info("Reusing cached Snowflake connection") + return self._conn except Exception: - logger.info(f"Cached connection for {database} is stale, reconnecting...") + logger.info("Cached connection is stale, reconnecting...") try: - conn.close() + self._conn.close() except Exception: pass # Create new connection - logger.info(f"Creating new Snowflake connection for {database}") + logger.info("Creating new Snowflake connection") logger.info("Note: First connection will open browser for SSO login") - db_config = DATABASES.get(database, DATABASES["default"]) - connect_params = { "account": SNOWFLAKE_CONFIG["account"], "user": SNOWFLAKE_CONFIG["user"], "authenticator": SNOWFLAKE_CONFIG["authenticator"], } - # Add optional parameters if specified - if db_config.get("database"): - connect_params["database"] = db_config["database"] - if db_config.get("schema"): - connect_params["schema"] = db_config["schema"] if SNOWFLAKE_CONFIG.get("warehouse"): connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"] if SNOWFLAKE_CONFIG.get("role"): connect_params["role"] = SNOWFLAKE_CONFIG["role"] try: - conn = snowflake.connector.connect(**connect_params) - - # Store the connection - self.connections[cache_key] = conn - self.last_used[cache_key] = datetime.now() - - logger.info(f"Successfully connected to Snowflake ({database})") - return conn - + self._conn = snowflake.connector.connect(**connect_params) + self._last_used = datetime.now() + logger.info("Successfully connected to Snowflake") + return self._conn except Exception as e: logger.error(f"Failed to connect to Snowflake: {e}") raise - def close_all(self): - """Close all connections""" - for database, conn in self.connections.items(): + def close(self): + """Close the connection""" + if self._conn: try: - conn.close() - logger.info(f"Closed Snowflake connection to {database}") + self._conn.close() + logger.info("Closed Snowflake connection") except Exception: pass - self.connections.clear() - self.last_used.clear() + self._conn = None + self._last_used = None -# Global connection pool -connection_pool = SnowflakeConnectionPool() +# Global connection +connection = SnowflakeConnection() # Query cache query_cache: Dict[str, tuple] = {} @@ -141,9 +113,9 @@ CACHE_TTL = timedelta(minutes=5) async_queries: Dict[str, Dict[str, Any]] = {} -def get_cache_key(query: str, database: str) -> str: +def get_cache_key(query: str) -> str: """Generate cache key for query""" - return hashlib.md5(f"{database}:{query}".encode()).hexdigest() + return hashlib.md5(query.encode()).hexdigest() def validate_query(query: str) -> tuple[bool, str]: @@ -161,20 +133,18 @@ def validate_query(query: str) -> tuple[bool, str]: 'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE'] for keyword in dangerous: - # Check for keyword as whole word (not part of column name) if f' {keyword} ' in f' {query_upper} ': return False, f"Query contains forbidden keyword: {keyword}" return True, "Query validated successfully" -async def test_connection(database: str = "default") -> Dict[str, Any]: +async def test_connection() -> Dict[str, Any]: """Test Snowflake connection and return system information""" try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor(DictCursor) - # Get session information cursor.execute(""" SELECT CURRENT_ACCOUNT() as account, @@ -211,14 +181,13 @@ async def test_connection(database: str = "default") -> Dict[str, Any]: } -async def list_tables(database: str = "default", schema: Optional[str] = None, +async def list_tables(schema: Optional[str] = None, include_row_counts: bool = False) -> Dict[str, Any]: """List all tables in the database/schema""" try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor(DictCursor) - # Get current database and schema if not specified if schema: query = f"SHOW TABLES IN SCHEMA {schema}" else: @@ -242,7 +211,6 @@ async def list_tables(database: str = "default", schema: Optional[str] = None, cursor.close() - # Return with helpful hint if no tables found if not tables: return { "tables": [], @@ -257,10 +225,10 @@ async def list_tables(database: str = "default", schema: Optional[str] = None, raise -async def list_views(schema_name: str, database: str = "default") -> List[Dict[str, Any]]: +async def list_views(schema_name: str) -> List[Dict[str, Any]]: """List all views in a schema with their descriptions""" try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor(DictCursor) cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}") @@ -287,13 +255,12 @@ async def list_views(schema_name: str, database: str = "default") -> List[Dict[s raise -async def describe_table(table_name: str, database: str = "default") -> Dict[str, Any]: +async def describe_table(table_name: str) -> Dict[str, Any]: """Get detailed schema information for a table or view""" try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor() - # Parse table name to handle quoting parts = table_name.split('.') if len(parts) == 3: db, schema, table = parts @@ -305,7 +272,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str schema = None table = table_name - # Build quoted version for case-sensitive names if len(parts) == 3: quoted_name = f'{parts[0]}.{parts[1]}."{parts[2]}"' elif len(parts) == 2: @@ -313,7 +279,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str else: quoted_name = f'"{table_name}"' - # Try unquoted first, then quoted if that fails rows = None used_name = table_name for try_name in [table_name, quoted_name]: @@ -324,13 +289,12 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str break except Exception as inner_e: if "does not exist" in str(inner_e).lower() and try_name == table_name: - continue # Try quoted version + continue raise if rows is None: rows = [] - # Get column names from description using attribute access (safer for ResultMetadata) desc_columns = [] if cursor.description: for desc in cursor.description: @@ -385,11 +349,9 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str } -async def read_data(query: str, database: str = "default", - max_rows: int = 50000) -> Dict[str, Any]: +async def read_data(query: str, max_rows: int = 50000) -> Dict[str, Any]: """Execute a SELECT query and return results""" - # Validate query is_valid, message = validate_query(query) if not is_valid: return { @@ -397,8 +359,7 @@ async def read_data(query: str, database: str = "default", "suggestion": "Use SELECT statements only. For schema information, use describe_table." } - # Check cache - cache_key = get_cache_key(query, database) + cache_key = get_cache_key(query) if cache_key in query_cache: cached_result, cached_time = query_cache[cache_key] if datetime.now() - cached_time < CACHE_TTL: @@ -410,23 +371,19 @@ async def read_data(query: str, database: str = "default", } try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor(DictCursor) - # Execute query cursor.execute(query) - # Get column names columns = [desc[0] for desc in cursor.description] if cursor.description else [] - # Fetch rows up to limit rows = [] for i, row in enumerate(cursor): if i >= max_rows: logger.warning(f"Query returned more than {max_rows} rows, truncating") break - # Convert row to dict with proper serialization row_dict = {} for col in columns: value = row[col] @@ -446,10 +403,8 @@ async def read_data(query: str, database: str = "default", "rows": rows, "row_count": len(rows), "truncated": len(rows) == max_rows, - "database": database } - # Cache the result query_cache[cache_key] = (result, datetime.now()) return result @@ -474,7 +429,7 @@ async def read_data(query: str, database: str = "default", async def list_databases() -> List[Dict[str, Any]]: """List all accessible databases""" try: - conn = connection_pool.get_connection("default") + conn = connection.get() cursor = conn.cursor(DictCursor) cursor.execute("SHOW DATABASES") @@ -501,7 +456,7 @@ async def list_databases() -> List[Dict[str, Any]]: async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]: """List all schemas in a database""" try: - conn = connection_pool.get_connection("default") + conn = connection.get() cursor = conn.cursor(DictCursor) if database_name: @@ -533,14 +488,6 @@ async def get_system_health() -> Dict[str, Any]: health_info = { "status": "healthy", "timestamp": datetime.now().isoformat(), - "connection_pool": { - "active_connections": len(connection_pool.connections), - "databases": list(connection_pool.connections.keys()), - "last_used": { - db: last.isoformat() - for db, last in connection_pool.last_used.items() - } - }, "cache": { "query_cache_size": len(query_cache), "cache_ttl_minutes": CACHE_TTL.total_seconds() / 60 @@ -552,9 +499,8 @@ async def get_system_health() -> Dict[str, Any]: } } - # Test connection try: - test_result = await test_connection("default") + test_result = await test_connection() health_info["connection"] = test_result except Exception: health_info["status"] = "degraded" @@ -563,7 +509,7 @@ async def get_system_health() -> Dict[str, Any]: return health_info -async def describe_query(query: str, database: str = "default") -> Dict[str, Any]: +async def describe_query(query: str) -> Dict[str, Any]: """Preview query output columns without executing the full query. Uses LIMIT 0 to get column metadata efficiently.""" @@ -572,21 +518,18 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any return {"error": message} try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor(DictCursor) clean_query = query.rstrip(';').strip() - # Remove existing LIMIT and OFFSET clauses (case insensitive) clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE) clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE) - # Add LIMIT 0 to get column metadata without fetching data limited_query = f"{clean_query} LIMIT 0" try: cursor.execute(limited_query) except Exception: - # Fallback: wrap in subquery if direct approach fails limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0" cursor.execute(limited_query) @@ -609,7 +552,6 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any return { "columns": columns, "column_count": len(columns), - "database": database, "query_preview": query[:200] + "..." if len(query) > 200 else query } @@ -621,7 +563,7 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any } -async def execute_async(query: str, database: str = "default") -> Dict[str, Any]: +async def execute_async(query: str) -> Dict[str, Any]: """Submit a query for asynchronous execution. Returns query ID for status tracking.""" is_valid, message = validate_query(query) @@ -629,17 +571,14 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any] return {"error": message} try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor() - # Execute asynchronously cursor.execute_async(query) query_id = cursor.sfqid - # Store query info for tracking async_queries[query_id] = { "query": query[:500], - "database": database, "submitted_at": datetime.now().isoformat(), "status": "RUNNING" } @@ -650,7 +589,6 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any] "query_id": query_id, "status": "SUBMITTED", "message": "Query submitted for async execution. Use get_query_status to check progress.", - "database": database } except Exception as e: @@ -658,11 +596,11 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any] return {"error": str(e)} -async def get_query_status(query_id: str, database: str = "default") -> Dict[str, Any]: +async def get_query_status(query_id: str) -> Dict[str, Any]: """Check the status of an async query by its query ID.""" try: - conn = connection_pool.get_connection(database) + conn = connection.get() status = conn.get_query_status(query_id) status_name = status.name if hasattr(status, 'name') else str(status) @@ -670,7 +608,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str is_running = conn.is_still_running(status) is_error = conn.is_an_error(status) - # Update tracked query info if query_id in async_queries: async_queries[query_id]["status"] = status_name @@ -682,7 +619,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str "can_fetch_results": not is_running and not is_error } - # Add original query info if available if query_id in async_queries: result["query_info"] = async_queries[query_id] @@ -693,14 +629,12 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str return {"error": str(e), "query_id": query_id} -async def get_async_results(query_id: str, database: str = "default", - max_rows: int = 50000) -> Dict[str, Any]: +async def get_async_results(query_id: str, max_rows: int = 50000) -> Dict[str, Any]: """Retrieve results from a completed async query.""" try: - conn = connection_pool.get_connection(database) + conn = connection.get() - # Check status first status = conn.get_query_status_throw_if_error(query_id) if conn.is_still_running(status): @@ -710,17 +644,13 @@ async def get_async_results(query_id: str, database: str = "default", "status": status.name if hasattr(status, 'name') else str(status) } - # Use standard cursor for get_results_from_sfqid cursor = conn.cursor() cursor.get_results_from_sfqid(query_id) - # Fetch results first (description may not populate until after fetch) raw_rows = cursor.fetchmany(max_rows) - # Now get column names from description columns = [desc[0] for desc in cursor.description] if cursor.description else [] - # Convert to list of dicts rows = [] for row in raw_rows: row_dict = {} @@ -734,12 +664,10 @@ async def get_async_results(query_id: str, database: str = "default", row_dict[col] = value rows.append(row_dict) - # Check if there are more rows truncated = len(raw_rows) == max_rows cursor.close() - # Clean up tracking if query_id in async_queries: del async_queries[query_id] @@ -756,8 +684,8 @@ async def get_async_results(query_id: str, database: str = "default", return {"error": str(e), "query_id": query_id} -async def read_data_paginated(query: str, database: str = "default", - page_size: int = 1000, page: int = 1) -> Dict[str, Any]: +async def read_data_paginated(query: str, page_size: int = 1000, + page: int = 1) -> Dict[str, Any]: """Execute a query with pagination support (offset/limit).""" is_valid, message = validate_query(query) @@ -771,12 +699,10 @@ async def read_data_paginated(query: str, database: str = "default", offset = (page - 1) * page_size - # Strip existing LIMIT/OFFSET from query for proper pagination clean_query = query.rstrip(';').strip() clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE) clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE) - # Wrap query with pagination paginated_query = f""" SELECT * FROM ( {clean_query} @@ -784,8 +710,7 @@ async def read_data_paginated(query: str, database: str = "default", LIMIT {page_size} OFFSET {offset} """ - # Also get total count (cached for efficiency) - count_cache_key = get_cache_key(f"COUNT:{clean_query}", database) + count_cache_key = get_cache_key(f"COUNT:{clean_query}") total_count = None if count_cache_key in query_cache: @@ -794,10 +719,9 @@ async def read_data_paginated(query: str, database: str = "default", total_count = cached_count try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor(DictCursor) - # Get total count if not cached if total_count is None: count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq" cursor.execute(count_query) @@ -805,7 +729,6 @@ async def read_data_paginated(query: str, database: str = "default", total_count = count_row["TOTAL"] if count_row else 0 query_cache[count_cache_key] = (total_count, datetime.now()) - # Execute paginated query cursor.execute(paginated_query) columns = [desc[0] for desc in cursor.description] if cursor.description else [] @@ -838,7 +761,6 @@ async def read_data_paginated(query: str, database: str = "default", "has_next": page < total_pages, "has_previous": page > 1 }, - "database": database } except Exception as e: @@ -846,8 +768,7 @@ async def read_data_paginated(query: str, database: str = "default", return {"error": str(e)} -async def read_data_pandas(query: str, database: str = "default", - max_rows: int = 50000, +async def read_data_pandas(query: str, max_rows: int = 50000, output_format: str = "records") -> Dict[str, Any]: """Execute a query and return results in a pandas-friendly format. @@ -866,14 +787,13 @@ async def read_data_pandas(query: str, database: str = "default", return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"} try: - conn = connection_pool.get_connection(database) + conn = connection.get() cursor = conn.cursor(DictCursor) cursor.execute(query) columns = [desc[0] for desc in cursor.description] if cursor.description else [] - # Collect all rows first all_rows = [] for i, row in enumerate(cursor): if i >= max_rows: @@ -894,19 +814,15 @@ async def read_data_pandas(query: str, database: str = "default", cursor.close() truncated = len(all_rows) == max_rows - # Format output based on requested format if output_format == 'records': data = all_rows - elif output_format == 'columns': data = {col: [row.get(col) for row in all_rows] for col in columns} - elif output_format == 'split': data = { "columns": columns, "data": [[row.get(col) for col in columns] for row in all_rows] } - elif output_format == 'csv': output = io.StringIO() writer = csv.DictWriter(output, fieldnames=columns) @@ -920,7 +836,6 @@ async def read_data_pandas(query: str, database: str = "default", "data": data, "row_count": len(all_rows), "truncated": truncated, - "database": database, "pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format" } @@ -942,7 +857,7 @@ server = Server("snowflake-mcp") @server.list_tools() -async def list_tools() -> List[Tool]: +async def handle_list_tools() -> List[Tool]: """List available MCP tools""" return [ Tool( @@ -950,13 +865,7 @@ async def list_tools() -> List[Tool]: description="Test Snowflake connection and get session information", inputSchema={ "type": "object", - "properties": { - "database": { - "type": "string", - "description": "Database context (default: 'default')", - "default": "default" - } - } + "properties": {} } ), Tool( @@ -986,14 +895,9 @@ async def list_tools() -> List[Tool]: inputSchema={ "type": "object", "properties": { - "database": { - "type": "string", - "description": "Database context", - "default": "default" - }, "schema": { "type": "string", - "description": "Schema filter (optional)" + "description": "Schema filter (optional, e.g. 'MY_DB.MY_SCHEMA')" }, "include_row_counts": { "type": "boolean", @@ -1012,11 +916,6 @@ async def list_tools() -> List[Tool]: "schema_name": { "type": "string", "description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')" - }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" } }, "required": ["schema_name"] @@ -1024,18 +923,13 @@ async def list_tools() -> List[Tool]: ), Tool( name="describe_table", - description="Get detailed schema information for a table", + description="Get detailed schema information for a table or view", inputSchema={ "type": "object", "properties": { "table_name": { "type": "string", "description": "Table name (can include database.schema.table)" - }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" } }, "required": ["table_name"] @@ -1051,11 +945,6 @@ async def list_tools() -> List[Tool]: "type": "string", "description": "SQL SELECT query to execute" }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" - }, "max_rows": { "type": "integer", "description": "Maximum rows to return", @@ -1082,11 +971,6 @@ async def list_tools() -> List[Tool]: "query": { "type": "string", "description": "SQL SELECT query to analyze" - }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" } }, "required": ["query"] @@ -1101,11 +985,6 @@ async def list_tools() -> List[Tool]: "query": { "type": "string", "description": "SQL SELECT query to execute asynchronously" - }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" } }, "required": ["query"] @@ -1120,11 +999,6 @@ async def list_tools() -> List[Tool]: "query_id": { "type": "string", "description": "Snowflake query ID from execute_async" - }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" } }, "required": ["query_id"] @@ -1140,11 +1014,6 @@ async def list_tools() -> List[Tool]: "type": "string", "description": "Snowflake query ID from execute_async" }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" - }, "max_rows": { "type": "integer", "description": "Maximum rows to return", @@ -1172,11 +1041,6 @@ async def list_tools() -> List[Tool]: "type": "string", "description": "SQL SELECT query to execute" }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" - }, "page_size": { "type": "integer", "description": "Number of rows per page (1-10000)", @@ -1201,11 +1065,6 @@ async def list_tools() -> List[Tool]: "type": "string", "description": "SQL SELECT query to execute" }, - "database": { - "type": "string", - "description": "Database context", - "default": "default" - }, "max_rows": { "type": "integer", "description": "Maximum rows to return", @@ -1230,8 +1089,7 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]: try: if name == "test_connection": - database = arguments.get("database", "default") - result = await test_connection(database) + result = await test_connection() elif name == "list_databases": result = await list_databases() @@ -1241,67 +1099,57 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]: result = await list_schemas(database_name) elif name == "list_tables": - database = arguments.get("database", "default") schema = arguments.get("schema") include_counts = arguments.get("include_row_counts", False) - result = await list_tables(database, schema, include_counts) + result = await list_tables(schema, include_counts) elif name == "list_views": schema_name = arguments["schema_name"] - database = arguments.get("database", "default") - result = await list_views(schema_name, database) + result = await list_views(schema_name) elif name == "describe_table": table_name = arguments["table_name"] - database = arguments.get("database", "default") - result = await describe_table(table_name, database) + result = await describe_table(table_name) elif name == "read_data": query = arguments["query"] - database = arguments.get("database", "default") max_rows = arguments.get("max_rows", 50000) - result = await read_data(query, database, max_rows) + result = await read_data(query, max_rows) elif name == "get_system_health": result = await get_system_health() elif name == "describe_query": query = arguments["query"] - database = arguments.get("database", "default") - result = await describe_query(query, database) + result = await describe_query(query) elif name == "execute_async": query = arguments["query"] - database = arguments.get("database", "default") - result = await execute_async(query, database) + result = await execute_async(query) elif name == "get_query_status": query_id = arguments["query_id"] - database = arguments.get("database", "default") - result = await get_query_status(query_id, database) + result = await get_query_status(query_id) elif name == "get_async_results": query_id = arguments["query_id"] - database = arguments.get("database", "default") max_rows = arguments.get("max_rows", 50000) - result = await get_async_results(query_id, database, max_rows) + result = await get_async_results(query_id, max_rows) elif name == "list_async_queries": result = await list_async_queries() elif name == "read_data_paginated": query = arguments["query"] - database = arguments.get("database", "default") page_size = arguments.get("page_size", 1000) page = arguments.get("page", 1) - result = await read_data_paginated(query, database, page_size, page) + result = await read_data_paginated(query, page_size, page) elif name == "read_data_pandas": query = arguments["query"] - database = arguments.get("database", "default") max_rows = arguments.get("max_rows", 50000) output_format = arguments.get("output_format", "records") - result = await read_data_pandas(query, database, max_rows, output_format) + result = await read_data_pandas(query, max_rows, output_format) else: raise ValueError(f"Unknown tool: {name}") @@ -1339,4 +1187,4 @@ if __name__ == "__main__": except KeyboardInterrupt: logger.info("Server stopped by user") finally: - connection_pool.close_all() + connection.close()