Removed dead code/unneeded env setting
This commit is contained in:
@@ -7,7 +7,7 @@ Gives Claude (or any MCP client) the ability to explore your Snowflake account
|
||||
## Features
|
||||
|
||||
- SSO authentication (opens browser on first connection, then reuses the session)
|
||||
- Connection pooling with 30-minute timeout
|
||||
- Single persistent connection with 30-minute timeout
|
||||
- Query caching (5 min TTL)
|
||||
- Read-only safety: only SELECT / SHOW / DESCRIBE queries are allowed
|
||||
- Async query support for long-running queries
|
||||
@@ -37,14 +37,14 @@ Set environment variables (recommended) or edit the defaults at the top of `snow
|
||||
export SNOWFLAKE_ACCOUNT="your-account-id" # e.g. "xy12345.eu-west-1"
|
||||
export SNOWFLAKE_USER="your.email@company.com"
|
||||
|
||||
# Optional — defaults shown
|
||||
export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser
|
||||
# Optional
|
||||
export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser (default)
|
||||
export SNOWFLAKE_WAREHOUSE="" # uses account default if empty
|
||||
export SNOWFLAKE_ROLE="" # uses account default if empty
|
||||
export SNOWFLAKE_DATABASE="" # uses account default if empty
|
||||
export SNOWFLAKE_SCHEMA="" # uses account default if empty
|
||||
```
|
||||
|
||||
Once connected, use `list_databases` and `list_schemas` to discover what you have access to.
|
||||
|
||||
On Windows (PowerShell):
|
||||
```powershell
|
||||
$env:SNOWFLAKE_ACCOUNT = "your-account-id"
|
||||
|
||||
+76
-228
@@ -37,101 +37,73 @@ SNOWFLAKE_CONFIG = {
|
||||
"role": os.environ.get("SNOWFLAKE_ROLE"), # None = account default
|
||||
}
|
||||
|
||||
# Available databases/schemas to expose
|
||||
DATABASES = {
|
||||
"default": {
|
||||
"database": os.environ.get("SNOWFLAKE_DATABASE"), # None = account default
|
||||
"schema": os.environ.get("SNOWFLAKE_SCHEMA"), # None = account default
|
||||
"description": "Default database and schema"
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SnowflakeConnectionPool:
|
||||
"""Manages Snowflake connections with SSO authentication"""
|
||||
class SnowflakeConnection:
|
||||
"""Manages a single Snowflake connection with SSO authentication"""
|
||||
|
||||
def __init__(self):
|
||||
self.connections: Dict[str, snowflake.connector.SnowflakeConnection] = {}
|
||||
self.last_used: Dict[str, datetime] = {}
|
||||
self.connection_timeout = timedelta(minutes=30)
|
||||
self._conn: Optional[snowflake.connector.SnowflakeConnection] = None
|
||||
self._last_used: Optional[datetime] = None
|
||||
self._timeout = timedelta(minutes=30)
|
||||
|
||||
def get_connection(self, database: str = "default") -> snowflake.connector.SnowflakeConnection:
|
||||
"""Get or create a Snowflake connection with SSO auth"""
|
||||
|
||||
cache_key = database
|
||||
def get(self) -> snowflake.connector.SnowflakeConnection:
|
||||
"""Get or create the Snowflake connection"""
|
||||
|
||||
# Check if we have a valid cached connection
|
||||
if cache_key in self.connections:
|
||||
conn = self.connections[cache_key]
|
||||
last = self.last_used.get(cache_key, datetime.now())
|
||||
|
||||
# Check if connection is still valid and not timed out
|
||||
if datetime.now() - last < self.connection_timeout:
|
||||
if self._conn and self._last_used:
|
||||
if datetime.now() - self._last_used < self._timeout:
|
||||
try:
|
||||
# Test the connection
|
||||
cursor = conn.cursor()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.close()
|
||||
self.last_used[cache_key] = datetime.now()
|
||||
logger.info(f"Reusing cached Snowflake connection for {database}")
|
||||
return conn
|
||||
self._last_used = datetime.now()
|
||||
logger.info("Reusing cached Snowflake connection")
|
||||
return self._conn
|
||||
except Exception:
|
||||
logger.info(f"Cached connection for {database} is stale, reconnecting...")
|
||||
logger.info("Cached connection is stale, reconnecting...")
|
||||
try:
|
||||
conn.close()
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create new connection
|
||||
logger.info(f"Creating new Snowflake connection for {database}")
|
||||
logger.info("Creating new Snowflake connection")
|
||||
logger.info("Note: First connection will open browser for SSO login")
|
||||
|
||||
db_config = DATABASES.get(database, DATABASES["default"])
|
||||
|
||||
connect_params = {
|
||||
"account": SNOWFLAKE_CONFIG["account"],
|
||||
"user": SNOWFLAKE_CONFIG["user"],
|
||||
"authenticator": SNOWFLAKE_CONFIG["authenticator"],
|
||||
}
|
||||
|
||||
# Add optional parameters if specified
|
||||
if db_config.get("database"):
|
||||
connect_params["database"] = db_config["database"]
|
||||
if db_config.get("schema"):
|
||||
connect_params["schema"] = db_config["schema"]
|
||||
if SNOWFLAKE_CONFIG.get("warehouse"):
|
||||
connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"]
|
||||
if SNOWFLAKE_CONFIG.get("role"):
|
||||
connect_params["role"] = SNOWFLAKE_CONFIG["role"]
|
||||
|
||||
try:
|
||||
conn = snowflake.connector.connect(**connect_params)
|
||||
|
||||
# Store the connection
|
||||
self.connections[cache_key] = conn
|
||||
self.last_used[cache_key] = datetime.now()
|
||||
|
||||
logger.info(f"Successfully connected to Snowflake ({database})")
|
||||
return conn
|
||||
|
||||
self._conn = snowflake.connector.connect(**connect_params)
|
||||
self._last_used = datetime.now()
|
||||
logger.info("Successfully connected to Snowflake")
|
||||
return self._conn
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Snowflake: {e}")
|
||||
raise
|
||||
|
||||
def close_all(self):
|
||||
"""Close all connections"""
|
||||
for database, conn in self.connections.items():
|
||||
def close(self):
|
||||
"""Close the connection"""
|
||||
if self._conn:
|
||||
try:
|
||||
conn.close()
|
||||
logger.info(f"Closed Snowflake connection to {database}")
|
||||
self._conn.close()
|
||||
logger.info("Closed Snowflake connection")
|
||||
except Exception:
|
||||
pass
|
||||
self.connections.clear()
|
||||
self.last_used.clear()
|
||||
self._conn = None
|
||||
self._last_used = None
|
||||
|
||||
|
||||
# Global connection pool
|
||||
connection_pool = SnowflakeConnectionPool()
|
||||
# Global connection
|
||||
connection = SnowflakeConnection()
|
||||
|
||||
# Query cache
|
||||
query_cache: Dict[str, tuple] = {}
|
||||
@@ -141,9 +113,9 @@ CACHE_TTL = timedelta(minutes=5)
|
||||
async_queries: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def get_cache_key(query: str, database: str) -> str:
|
||||
def get_cache_key(query: str) -> str:
|
||||
"""Generate cache key for query"""
|
||||
return hashlib.md5(f"{database}:{query}".encode()).hexdigest()
|
||||
return hashlib.md5(query.encode()).hexdigest()
|
||||
|
||||
|
||||
def validate_query(query: str) -> tuple[bool, str]:
|
||||
@@ -161,20 +133,18 @@ def validate_query(query: str) -> tuple[bool, str]:
|
||||
'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE']
|
||||
|
||||
for keyword in dangerous:
|
||||
# Check for keyword as whole word (not part of column name)
|
||||
if f' {keyword} ' in f' {query_upper} ':
|
||||
return False, f"Query contains forbidden keyword: {keyword}"
|
||||
|
||||
return True, "Query validated successfully"
|
||||
|
||||
|
||||
async def test_connection(database: str = "default") -> Dict[str, Any]:
|
||||
async def test_connection() -> Dict[str, Any]:
|
||||
"""Test Snowflake connection and return system information"""
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
# Get session information
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
CURRENT_ACCOUNT() as account,
|
||||
@@ -211,14 +181,13 @@ async def test_connection(database: str = "default") -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
async def list_tables(database: str = "default", schema: Optional[str] = None,
|
||||
async def list_tables(schema: Optional[str] = None,
|
||||
include_row_counts: bool = False) -> Dict[str, Any]:
|
||||
"""List all tables in the database/schema"""
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
# Get current database and schema if not specified
|
||||
if schema:
|
||||
query = f"SHOW TABLES IN SCHEMA {schema}"
|
||||
else:
|
||||
@@ -242,7 +211,6 @@ async def list_tables(database: str = "default", schema: Optional[str] = None,
|
||||
|
||||
cursor.close()
|
||||
|
||||
# Return with helpful hint if no tables found
|
||||
if not tables:
|
||||
return {
|
||||
"tables": [],
|
||||
@@ -257,10 +225,10 @@ async def list_tables(database: str = "default", schema: Optional[str] = None,
|
||||
raise
|
||||
|
||||
|
||||
async def list_views(schema_name: str, database: str = "default") -> List[Dict[str, Any]]:
|
||||
async def list_views(schema_name: str) -> List[Dict[str, Any]]:
|
||||
"""List all views in a schema with their descriptions"""
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}")
|
||||
@@ -287,13 +255,12 @@ async def list_views(schema_name: str, database: str = "default") -> List[Dict[s
|
||||
raise
|
||||
|
||||
|
||||
async def describe_table(table_name: str, database: str = "default") -> Dict[str, Any]:
|
||||
async def describe_table(table_name: str) -> Dict[str, Any]:
|
||||
"""Get detailed schema information for a table or view"""
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Parse table name to handle quoting
|
||||
parts = table_name.split('.')
|
||||
if len(parts) == 3:
|
||||
db, schema, table = parts
|
||||
@@ -305,7 +272,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
|
||||
schema = None
|
||||
table = table_name
|
||||
|
||||
# Build quoted version for case-sensitive names
|
||||
if len(parts) == 3:
|
||||
quoted_name = f'{parts[0]}.{parts[1]}."{parts[2]}"'
|
||||
elif len(parts) == 2:
|
||||
@@ -313,7 +279,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
|
||||
else:
|
||||
quoted_name = f'"{table_name}"'
|
||||
|
||||
# Try unquoted first, then quoted if that fails
|
||||
rows = None
|
||||
used_name = table_name
|
||||
for try_name in [table_name, quoted_name]:
|
||||
@@ -324,13 +289,12 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
|
||||
break
|
||||
except Exception as inner_e:
|
||||
if "does not exist" in str(inner_e).lower() and try_name == table_name:
|
||||
continue # Try quoted version
|
||||
continue
|
||||
raise
|
||||
|
||||
if rows is None:
|
||||
rows = []
|
||||
|
||||
# Get column names from description using attribute access (safer for ResultMetadata)
|
||||
desc_columns = []
|
||||
if cursor.description:
|
||||
for desc in cursor.description:
|
||||
@@ -385,11 +349,9 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
|
||||
}
|
||||
|
||||
|
||||
async def read_data(query: str, database: str = "default",
|
||||
max_rows: int = 50000) -> Dict[str, Any]:
|
||||
async def read_data(query: str, max_rows: int = 50000) -> Dict[str, Any]:
|
||||
"""Execute a SELECT query and return results"""
|
||||
|
||||
# Validate query
|
||||
is_valid, message = validate_query(query)
|
||||
if not is_valid:
|
||||
return {
|
||||
@@ -397,8 +359,7 @@ async def read_data(query: str, database: str = "default",
|
||||
"suggestion": "Use SELECT statements only. For schema information, use describe_table."
|
||||
}
|
||||
|
||||
# Check cache
|
||||
cache_key = get_cache_key(query, database)
|
||||
cache_key = get_cache_key(query)
|
||||
if cache_key in query_cache:
|
||||
cached_result, cached_time = query_cache[cache_key]
|
||||
if datetime.now() - cached_time < CACHE_TTL:
|
||||
@@ -410,23 +371,19 @@ async def read_data(query: str, database: str = "default",
|
||||
}
|
||||
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
# Execute query
|
||||
cursor.execute(query)
|
||||
|
||||
# Get column names
|
||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||
|
||||
# Fetch rows up to limit
|
||||
rows = []
|
||||
for i, row in enumerate(cursor):
|
||||
if i >= max_rows:
|
||||
logger.warning(f"Query returned more than {max_rows} rows, truncating")
|
||||
break
|
||||
|
||||
# Convert row to dict with proper serialization
|
||||
row_dict = {}
|
||||
for col in columns:
|
||||
value = row[col]
|
||||
@@ -446,10 +403,8 @@ async def read_data(query: str, database: str = "default",
|
||||
"rows": rows,
|
||||
"row_count": len(rows),
|
||||
"truncated": len(rows) == max_rows,
|
||||
"database": database
|
||||
}
|
||||
|
||||
# Cache the result
|
||||
query_cache[cache_key] = (result, datetime.now())
|
||||
|
||||
return result
|
||||
@@ -474,7 +429,7 @@ async def read_data(query: str, database: str = "default",
|
||||
async def list_databases() -> List[Dict[str, Any]]:
|
||||
"""List all accessible databases"""
|
||||
try:
|
||||
conn = connection_pool.get_connection("default")
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
cursor.execute("SHOW DATABASES")
|
||||
@@ -501,7 +456,7 @@ async def list_databases() -> List[Dict[str, Any]]:
|
||||
async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""List all schemas in a database"""
|
||||
try:
|
||||
conn = connection_pool.get_connection("default")
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
if database_name:
|
||||
@@ -533,14 +488,6 @@ async def get_system_health() -> Dict[str, Any]:
|
||||
health_info = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"connection_pool": {
|
||||
"active_connections": len(connection_pool.connections),
|
||||
"databases": list(connection_pool.connections.keys()),
|
||||
"last_used": {
|
||||
db: last.isoformat()
|
||||
for db, last in connection_pool.last_used.items()
|
||||
}
|
||||
},
|
||||
"cache": {
|
||||
"query_cache_size": len(query_cache),
|
||||
"cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
|
||||
@@ -552,9 +499,8 @@ async def get_system_health() -> Dict[str, Any]:
|
||||
}
|
||||
}
|
||||
|
||||
# Test connection
|
||||
try:
|
||||
test_result = await test_connection("default")
|
||||
test_result = await test_connection()
|
||||
health_info["connection"] = test_result
|
||||
except Exception:
|
||||
health_info["status"] = "degraded"
|
||||
@@ -563,7 +509,7 @@ async def get_system_health() -> Dict[str, Any]:
|
||||
return health_info
|
||||
|
||||
|
||||
async def describe_query(query: str, database: str = "default") -> Dict[str, Any]:
|
||||
async def describe_query(query: str) -> Dict[str, Any]:
|
||||
"""Preview query output columns without executing the full query.
|
||||
Uses LIMIT 0 to get column metadata efficiently."""
|
||||
|
||||
@@ -572,21 +518,18 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
|
||||
return {"error": message}
|
||||
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
clean_query = query.rstrip(';').strip()
|
||||
# Remove existing LIMIT and OFFSET clauses (case insensitive)
|
||||
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
|
||||
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
|
||||
|
||||
# Add LIMIT 0 to get column metadata without fetching data
|
||||
limited_query = f"{clean_query} LIMIT 0"
|
||||
|
||||
try:
|
||||
cursor.execute(limited_query)
|
||||
except Exception:
|
||||
# Fallback: wrap in subquery if direct approach fails
|
||||
limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0"
|
||||
cursor.execute(limited_query)
|
||||
|
||||
@@ -609,7 +552,6 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
|
||||
return {
|
||||
"columns": columns,
|
||||
"column_count": len(columns),
|
||||
"database": database,
|
||||
"query_preview": query[:200] + "..." if len(query) > 200 else query
|
||||
}
|
||||
|
||||
@@ -621,7 +563,7 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
|
||||
}
|
||||
|
||||
|
||||
async def execute_async(query: str, database: str = "default") -> Dict[str, Any]:
|
||||
async def execute_async(query: str) -> Dict[str, Any]:
|
||||
"""Submit a query for asynchronous execution. Returns query ID for status tracking."""
|
||||
|
||||
is_valid, message = validate_query(query)
|
||||
@@ -629,17 +571,14 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
|
||||
return {"error": message}
|
||||
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Execute asynchronously
|
||||
cursor.execute_async(query)
|
||||
query_id = cursor.sfqid
|
||||
|
||||
# Store query info for tracking
|
||||
async_queries[query_id] = {
|
||||
"query": query[:500],
|
||||
"database": database,
|
||||
"submitted_at": datetime.now().isoformat(),
|
||||
"status": "RUNNING"
|
||||
}
|
||||
@@ -650,7 +589,6 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
|
||||
"query_id": query_id,
|
||||
"status": "SUBMITTED",
|
||||
"message": "Query submitted for async execution. Use get_query_status to check progress.",
|
||||
"database": database
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -658,11 +596,11 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def get_query_status(query_id: str, database: str = "default") -> Dict[str, Any]:
|
||||
async def get_query_status(query_id: str) -> Dict[str, Any]:
|
||||
"""Check the status of an async query by its query ID."""
|
||||
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
|
||||
status = conn.get_query_status(query_id)
|
||||
status_name = status.name if hasattr(status, 'name') else str(status)
|
||||
@@ -670,7 +608,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
|
||||
is_running = conn.is_still_running(status)
|
||||
is_error = conn.is_an_error(status)
|
||||
|
||||
# Update tracked query info
|
||||
if query_id in async_queries:
|
||||
async_queries[query_id]["status"] = status_name
|
||||
|
||||
@@ -682,7 +619,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
|
||||
"can_fetch_results": not is_running and not is_error
|
||||
}
|
||||
|
||||
# Add original query info if available
|
||||
if query_id in async_queries:
|
||||
result["query_info"] = async_queries[query_id]
|
||||
|
||||
@@ -693,14 +629,12 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
|
||||
return {"error": str(e), "query_id": query_id}
|
||||
|
||||
|
||||
async def get_async_results(query_id: str, database: str = "default",
|
||||
max_rows: int = 50000) -> Dict[str, Any]:
|
||||
async def get_async_results(query_id: str, max_rows: int = 50000) -> Dict[str, Any]:
|
||||
"""Retrieve results from a completed async query."""
|
||||
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
|
||||
# Check status first
|
||||
status = conn.get_query_status_throw_if_error(query_id)
|
||||
|
||||
if conn.is_still_running(status):
|
||||
@@ -710,17 +644,13 @@ async def get_async_results(query_id: str, database: str = "default",
|
||||
"status": status.name if hasattr(status, 'name') else str(status)
|
||||
}
|
||||
|
||||
# Use standard cursor for get_results_from_sfqid
|
||||
cursor = conn.cursor()
|
||||
cursor.get_results_from_sfqid(query_id)
|
||||
|
||||
# Fetch results first (description may not populate until after fetch)
|
||||
raw_rows = cursor.fetchmany(max_rows)
|
||||
|
||||
# Now get column names from description
|
||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||
|
||||
# Convert to list of dicts
|
||||
rows = []
|
||||
for row in raw_rows:
|
||||
row_dict = {}
|
||||
@@ -734,12 +664,10 @@ async def get_async_results(query_id: str, database: str = "default",
|
||||
row_dict[col] = value
|
||||
rows.append(row_dict)
|
||||
|
||||
# Check if there are more rows
|
||||
truncated = len(raw_rows) == max_rows
|
||||
|
||||
cursor.close()
|
||||
|
||||
# Clean up tracking
|
||||
if query_id in async_queries:
|
||||
del async_queries[query_id]
|
||||
|
||||
@@ -756,8 +684,8 @@ async def get_async_results(query_id: str, database: str = "default",
|
||||
return {"error": str(e), "query_id": query_id}
|
||||
|
||||
|
||||
async def read_data_paginated(query: str, database: str = "default",
|
||||
page_size: int = 1000, page: int = 1) -> Dict[str, Any]:
|
||||
async def read_data_paginated(query: str, page_size: int = 1000,
|
||||
page: int = 1) -> Dict[str, Any]:
|
||||
"""Execute a query with pagination support (offset/limit)."""
|
||||
|
||||
is_valid, message = validate_query(query)
|
||||
@@ -771,12 +699,10 @@ async def read_data_paginated(query: str, database: str = "default",
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Strip existing LIMIT/OFFSET from query for proper pagination
|
||||
clean_query = query.rstrip(';').strip()
|
||||
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
|
||||
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
|
||||
|
||||
# Wrap query with pagination
|
||||
paginated_query = f"""
|
||||
SELECT * FROM (
|
||||
{clean_query}
|
||||
@@ -784,8 +710,7 @@ async def read_data_paginated(query: str, database: str = "default",
|
||||
LIMIT {page_size} OFFSET {offset}
|
||||
"""
|
||||
|
||||
# Also get total count (cached for efficiency)
|
||||
count_cache_key = get_cache_key(f"COUNT:{clean_query}", database)
|
||||
count_cache_key = get_cache_key(f"COUNT:{clean_query}")
|
||||
total_count = None
|
||||
|
||||
if count_cache_key in query_cache:
|
||||
@@ -794,10 +719,9 @@ async def read_data_paginated(query: str, database: str = "default",
|
||||
total_count = cached_count
|
||||
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
# Get total count if not cached
|
||||
if total_count is None:
|
||||
count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq"
|
||||
cursor.execute(count_query)
|
||||
@@ -805,7 +729,6 @@ async def read_data_paginated(query: str, database: str = "default",
|
||||
total_count = count_row["TOTAL"] if count_row else 0
|
||||
query_cache[count_cache_key] = (total_count, datetime.now())
|
||||
|
||||
# Execute paginated query
|
||||
cursor.execute(paginated_query)
|
||||
|
||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||
@@ -838,7 +761,6 @@ async def read_data_paginated(query: str, database: str = "default",
|
||||
"has_next": page < total_pages,
|
||||
"has_previous": page > 1
|
||||
},
|
||||
"database": database
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -846,8 +768,7 @@ async def read_data_paginated(query: str, database: str = "default",
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
async def read_data_pandas(query: str, database: str = "default",
|
||||
max_rows: int = 50000,
|
||||
async def read_data_pandas(query: str, max_rows: int = 50000,
|
||||
output_format: str = "records") -> Dict[str, Any]:
|
||||
"""Execute a query and return results in a pandas-friendly format.
|
||||
|
||||
@@ -866,14 +787,13 @@ async def read_data_pandas(query: str, database: str = "default",
|
||||
return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}
|
||||
|
||||
try:
|
||||
conn = connection_pool.get_connection(database)
|
||||
conn = connection.get()
|
||||
cursor = conn.cursor(DictCursor)
|
||||
|
||||
cursor.execute(query)
|
||||
|
||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||
|
||||
# Collect all rows first
|
||||
all_rows = []
|
||||
for i, row in enumerate(cursor):
|
||||
if i >= max_rows:
|
||||
@@ -894,19 +814,15 @@ async def read_data_pandas(query: str, database: str = "default",
|
||||
cursor.close()
|
||||
truncated = len(all_rows) == max_rows
|
||||
|
||||
# Format output based on requested format
|
||||
if output_format == 'records':
|
||||
data = all_rows
|
||||
|
||||
elif output_format == 'columns':
|
||||
data = {col: [row.get(col) for row in all_rows] for col in columns}
|
||||
|
||||
elif output_format == 'split':
|
||||
data = {
|
||||
"columns": columns,
|
||||
"data": [[row.get(col) for col in columns] for row in all_rows]
|
||||
}
|
||||
|
||||
elif output_format == 'csv':
|
||||
output = io.StringIO()
|
||||
writer = csv.DictWriter(output, fieldnames=columns)
|
||||
@@ -920,7 +836,6 @@ async def read_data_pandas(query: str, database: str = "default",
|
||||
"data": data,
|
||||
"row_count": len(all_rows),
|
||||
"truncated": truncated,
|
||||
"database": database,
|
||||
"pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
|
||||
}
|
||||
|
||||
@@ -942,7 +857,7 @@ server = Server("snowflake-mcp")
|
||||
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> List[Tool]:
|
||||
async def handle_list_tools() -> List[Tool]:
|
||||
"""List available MCP tools"""
|
||||
return [
|
||||
Tool(
|
||||
@@ -950,13 +865,7 @@ async def list_tools() -> List[Tool]:
|
||||
description="Test Snowflake connection and get session information",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context (default: 'default')",
|
||||
"default": "default"
|
||||
}
|
||||
}
|
||||
"properties": {}
|
||||
}
|
||||
),
|
||||
Tool(
|
||||
@@ -986,14 +895,9 @@ async def list_tools() -> List[Tool]:
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
},
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "Schema filter (optional)"
|
||||
"description": "Schema filter (optional, e.g. 'MY_DB.MY_SCHEMA')"
|
||||
},
|
||||
"include_row_counts": {
|
||||
"type": "boolean",
|
||||
@@ -1012,11 +916,6 @@ async def list_tools() -> List[Tool]:
|
||||
"schema_name": {
|
||||
"type": "string",
|
||||
"description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
}
|
||||
},
|
||||
"required": ["schema_name"]
|
||||
@@ -1024,18 +923,13 @@ async def list_tools() -> List[Tool]:
|
||||
),
|
||||
Tool(
|
||||
name="describe_table",
|
||||
description="Get detailed schema information for a table",
|
||||
description="Get detailed schema information for a table or view",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table_name": {
|
||||
"type": "string",
|
||||
"description": "Table name (can include database.schema.table)"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
}
|
||||
},
|
||||
"required": ["table_name"]
|
||||
@@ -1051,11 +945,6 @@ async def list_tools() -> List[Tool]:
|
||||
"type": "string",
|
||||
"description": "SQL SELECT query to execute"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
},
|
||||
"max_rows": {
|
||||
"type": "integer",
|
||||
"description": "Maximum rows to return",
|
||||
@@ -1082,11 +971,6 @@ async def list_tools() -> List[Tool]:
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "SQL SELECT query to analyze"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
@@ -1101,11 +985,6 @@ async def list_tools() -> List[Tool]:
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "SQL SELECT query to execute asynchronously"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
@@ -1120,11 +999,6 @@ async def list_tools() -> List[Tool]:
|
||||
"query_id": {
|
||||
"type": "string",
|
||||
"description": "Snowflake query ID from execute_async"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
}
|
||||
},
|
||||
"required": ["query_id"]
|
||||
@@ -1140,11 +1014,6 @@ async def list_tools() -> List[Tool]:
|
||||
"type": "string",
|
||||
"description": "Snowflake query ID from execute_async"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
},
|
||||
"max_rows": {
|
||||
"type": "integer",
|
||||
"description": "Maximum rows to return",
|
||||
@@ -1172,11 +1041,6 @@ async def list_tools() -> List[Tool]:
|
||||
"type": "string",
|
||||
"description": "SQL SELECT query to execute"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
},
|
||||
"page_size": {
|
||||
"type": "integer",
|
||||
"description": "Number of rows per page (1-10000)",
|
||||
@@ -1201,11 +1065,6 @@ async def list_tools() -> List[Tool]:
|
||||
"type": "string",
|
||||
"description": "SQL SELECT query to execute"
|
||||
},
|
||||
"database": {
|
||||
"type": "string",
|
||||
"description": "Database context",
|
||||
"default": "default"
|
||||
},
|
||||
"max_rows": {
|
||||
"type": "integer",
|
||||
"description": "Maximum rows to return",
|
||||
@@ -1230,8 +1089,7 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]:
|
||||
|
||||
try:
|
||||
if name == "test_connection":
|
||||
database = arguments.get("database", "default")
|
||||
result = await test_connection(database)
|
||||
result = await test_connection()
|
||||
|
||||
elif name == "list_databases":
|
||||
result = await list_databases()
|
||||
@@ -1241,67 +1099,57 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]:
|
||||
result = await list_schemas(database_name)
|
||||
|
||||
elif name == "list_tables":
|
||||
database = arguments.get("database", "default")
|
||||
schema = arguments.get("schema")
|
||||
include_counts = arguments.get("include_row_counts", False)
|
||||
result = await list_tables(database, schema, include_counts)
|
||||
result = await list_tables(schema, include_counts)
|
||||
|
||||
elif name == "list_views":
|
||||
schema_name = arguments["schema_name"]
|
||||
database = arguments.get("database", "default")
|
||||
result = await list_views(schema_name, database)
|
||||
result = await list_views(schema_name)
|
||||
|
||||
elif name == "describe_table":
|
||||
table_name = arguments["table_name"]
|
||||
database = arguments.get("database", "default")
|
||||
result = await describe_table(table_name, database)
|
||||
result = await describe_table(table_name)
|
||||
|
||||
elif name == "read_data":
|
||||
query = arguments["query"]
|
||||
database = arguments.get("database", "default")
|
||||
max_rows = arguments.get("max_rows", 50000)
|
||||
result = await read_data(query, database, max_rows)
|
||||
result = await read_data(query, max_rows)
|
||||
|
||||
elif name == "get_system_health":
|
||||
result = await get_system_health()
|
||||
|
||||
elif name == "describe_query":
|
||||
query = arguments["query"]
|
||||
database = arguments.get("database", "default")
|
||||
result = await describe_query(query, database)
|
||||
result = await describe_query(query)
|
||||
|
||||
elif name == "execute_async":
|
||||
query = arguments["query"]
|
||||
database = arguments.get("database", "default")
|
||||
result = await execute_async(query, database)
|
||||
result = await execute_async(query)
|
||||
|
||||
elif name == "get_query_status":
|
||||
query_id = arguments["query_id"]
|
||||
database = arguments.get("database", "default")
|
||||
result = await get_query_status(query_id, database)
|
||||
result = await get_query_status(query_id)
|
||||
|
||||
elif name == "get_async_results":
|
||||
query_id = arguments["query_id"]
|
||||
database = arguments.get("database", "default")
|
||||
max_rows = arguments.get("max_rows", 50000)
|
||||
result = await get_async_results(query_id, database, max_rows)
|
||||
result = await get_async_results(query_id, max_rows)
|
||||
|
||||
elif name == "list_async_queries":
|
||||
result = await list_async_queries()
|
||||
|
||||
elif name == "read_data_paginated":
|
||||
query = arguments["query"]
|
||||
database = arguments.get("database", "default")
|
||||
page_size = arguments.get("page_size", 1000)
|
||||
page = arguments.get("page", 1)
|
||||
result = await read_data_paginated(query, database, page_size, page)
|
||||
result = await read_data_paginated(query, page_size, page)
|
||||
|
||||
elif name == "read_data_pandas":
|
||||
query = arguments["query"]
|
||||
database = arguments.get("database", "default")
|
||||
max_rows = arguments.get("max_rows", 50000)
|
||||
output_format = arguments.get("output_format", "records")
|
||||
result = await read_data_pandas(query, database, max_rows, output_format)
|
||||
result = await read_data_pandas(query, max_rows, output_format)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown tool: {name}")
|
||||
@@ -1339,4 +1187,4 @@ if __name__ == "__main__":
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server stopped by user")
|
||||
finally:
|
||||
connection_pool.close_all()
|
||||
connection.close()
|
||||
|
||||
Reference in New Issue
Block a user