Removed dead code/unneeded env setting
This commit is contained in:
@@ -7,7 +7,7 @@ Gives Claude (or any MCP client) the ability to explore your Snowflake account
|
|||||||
## Features
|
## Features
|
||||||
|
|
||||||
- SSO authentication (opens browser on first connection, then reuses the session)
|
- SSO authentication (opens browser on first connection, then reuses the session)
|
||||||
- Connection pooling with 30-minute timeout
|
- Single persistent connection with 30-minute timeout
|
||||||
- Query caching (5 min TTL)
|
- Query caching (5 min TTL)
|
||||||
- Read-only safety: only SELECT / SHOW / DESCRIBE queries are allowed
|
- Read-only safety: only SELECT / SHOW / DESCRIBE queries are allowed
|
||||||
- Async query support for long-running queries
|
- Async query support for long-running queries
|
||||||
@@ -37,14 +37,14 @@ Set environment variables (recommended) or edit the defaults at the top of `snow
|
|||||||
export SNOWFLAKE_ACCOUNT="your-account-id" # e.g. "xy12345.eu-west-1"
|
export SNOWFLAKE_ACCOUNT="your-account-id" # e.g. "xy12345.eu-west-1"
|
||||||
export SNOWFLAKE_USER="your.email@company.com"
|
export SNOWFLAKE_USER="your.email@company.com"
|
||||||
|
|
||||||
# Optional — defaults shown
|
# Optional
|
||||||
export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser
|
export SNOWFLAKE_AUTHENTICATOR="externalbrowser" # SSO via browser (default)
|
||||||
export SNOWFLAKE_WAREHOUSE="" # uses account default if empty
|
export SNOWFLAKE_WAREHOUSE="" # uses account default if empty
|
||||||
export SNOWFLAKE_ROLE="" # uses account default if empty
|
export SNOWFLAKE_ROLE="" # uses account default if empty
|
||||||
export SNOWFLAKE_DATABASE="" # uses account default if empty
|
|
||||||
export SNOWFLAKE_SCHEMA="" # uses account default if empty
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Once connected, use `list_databases` and `list_schemas` to discover what you have access to.
|
||||||
|
|
||||||
On Windows (PowerShell):
|
On Windows (PowerShell):
|
||||||
```powershell
|
```powershell
|
||||||
$env:SNOWFLAKE_ACCOUNT = "your-account-id"
|
$env:SNOWFLAKE_ACCOUNT = "your-account-id"
|
||||||
|
|||||||
+76
-228
@@ -37,101 +37,73 @@ SNOWFLAKE_CONFIG = {
|
|||||||
"role": os.environ.get("SNOWFLAKE_ROLE"), # None = account default
|
"role": os.environ.get("SNOWFLAKE_ROLE"), # None = account default
|
||||||
}
|
}
|
||||||
|
|
||||||
# Available databases/schemas to expose
|
|
||||||
DATABASES = {
|
|
||||||
"default": {
|
|
||||||
"database": os.environ.get("SNOWFLAKE_DATABASE"), # None = account default
|
|
||||||
"schema": os.environ.get("SNOWFLAKE_SCHEMA"), # None = account default
|
|
||||||
"description": "Default database and schema"
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
class SnowflakeConnection:
|
||||||
class SnowflakeConnectionPool:
|
"""Manages a single Snowflake connection with SSO authentication"""
|
||||||
"""Manages Snowflake connections with SSO authentication"""
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.connections: Dict[str, snowflake.connector.SnowflakeConnection] = {}
|
self._conn: Optional[snowflake.connector.SnowflakeConnection] = None
|
||||||
self.last_used: Dict[str, datetime] = {}
|
self._last_used: Optional[datetime] = None
|
||||||
self.connection_timeout = timedelta(minutes=30)
|
self._timeout = timedelta(minutes=30)
|
||||||
|
|
||||||
def get_connection(self, database: str = "default") -> snowflake.connector.SnowflakeConnection:
|
def get(self) -> snowflake.connector.SnowflakeConnection:
|
||||||
"""Get or create a Snowflake connection with SSO auth"""
|
"""Get or create the Snowflake connection"""
|
||||||
|
|
||||||
cache_key = database
|
|
||||||
|
|
||||||
# Check if we have a valid cached connection
|
# Check if we have a valid cached connection
|
||||||
if cache_key in self.connections:
|
if self._conn and self._last_used:
|
||||||
conn = self.connections[cache_key]
|
if datetime.now() - self._last_used < self._timeout:
|
||||||
last = self.last_used.get(cache_key, datetime.now())
|
|
||||||
|
|
||||||
# Check if connection is still valid and not timed out
|
|
||||||
if datetime.now() - last < self.connection_timeout:
|
|
||||||
try:
|
try:
|
||||||
# Test the connection
|
cursor = self._conn.cursor()
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.execute("SELECT 1")
|
cursor.execute("SELECT 1")
|
||||||
cursor.close()
|
cursor.close()
|
||||||
self.last_used[cache_key] = datetime.now()
|
self._last_used = datetime.now()
|
||||||
logger.info(f"Reusing cached Snowflake connection for {database}")
|
logger.info("Reusing cached Snowflake connection")
|
||||||
return conn
|
return self._conn
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.info(f"Cached connection for {database} is stale, reconnecting...")
|
logger.info("Cached connection is stale, reconnecting...")
|
||||||
try:
|
try:
|
||||||
conn.close()
|
self._conn.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Create new connection
|
# Create new connection
|
||||||
logger.info(f"Creating new Snowflake connection for {database}")
|
logger.info("Creating new Snowflake connection")
|
||||||
logger.info("Note: First connection will open browser for SSO login")
|
logger.info("Note: First connection will open browser for SSO login")
|
||||||
|
|
||||||
db_config = DATABASES.get(database, DATABASES["default"])
|
|
||||||
|
|
||||||
connect_params = {
|
connect_params = {
|
||||||
"account": SNOWFLAKE_CONFIG["account"],
|
"account": SNOWFLAKE_CONFIG["account"],
|
||||||
"user": SNOWFLAKE_CONFIG["user"],
|
"user": SNOWFLAKE_CONFIG["user"],
|
||||||
"authenticator": SNOWFLAKE_CONFIG["authenticator"],
|
"authenticator": SNOWFLAKE_CONFIG["authenticator"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add optional parameters if specified
|
|
||||||
if db_config.get("database"):
|
|
||||||
connect_params["database"] = db_config["database"]
|
|
||||||
if db_config.get("schema"):
|
|
||||||
connect_params["schema"] = db_config["schema"]
|
|
||||||
if SNOWFLAKE_CONFIG.get("warehouse"):
|
if SNOWFLAKE_CONFIG.get("warehouse"):
|
||||||
connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"]
|
connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"]
|
||||||
if SNOWFLAKE_CONFIG.get("role"):
|
if SNOWFLAKE_CONFIG.get("role"):
|
||||||
connect_params["role"] = SNOWFLAKE_CONFIG["role"]
|
connect_params["role"] = SNOWFLAKE_CONFIG["role"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = snowflake.connector.connect(**connect_params)
|
self._conn = snowflake.connector.connect(**connect_params)
|
||||||
|
self._last_used = datetime.now()
|
||||||
# Store the connection
|
logger.info("Successfully connected to Snowflake")
|
||||||
self.connections[cache_key] = conn
|
return self._conn
|
||||||
self.last_used[cache_key] = datetime.now()
|
|
||||||
|
|
||||||
logger.info(f"Successfully connected to Snowflake ({database})")
|
|
||||||
return conn
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to Snowflake: {e}")
|
logger.error(f"Failed to connect to Snowflake: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def close_all(self):
|
def close(self):
|
||||||
"""Close all connections"""
|
"""Close the connection"""
|
||||||
for database, conn in self.connections.items():
|
if self._conn:
|
||||||
try:
|
try:
|
||||||
conn.close()
|
self._conn.close()
|
||||||
logger.info(f"Closed Snowflake connection to {database}")
|
logger.info("Closed Snowflake connection")
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self.connections.clear()
|
self._conn = None
|
||||||
self.last_used.clear()
|
self._last_used = None
|
||||||
|
|
||||||
|
|
||||||
# Global connection pool
|
# Global connection
|
||||||
connection_pool = SnowflakeConnectionPool()
|
connection = SnowflakeConnection()
|
||||||
|
|
||||||
# Query cache
|
# Query cache
|
||||||
query_cache: Dict[str, tuple] = {}
|
query_cache: Dict[str, tuple] = {}
|
||||||
@@ -141,9 +113,9 @@ CACHE_TTL = timedelta(minutes=5)
|
|||||||
async_queries: Dict[str, Dict[str, Any]] = {}
|
async_queries: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_cache_key(query: str, database: str) -> str:
|
def get_cache_key(query: str) -> str:
|
||||||
"""Generate cache key for query"""
|
"""Generate cache key for query"""
|
||||||
return hashlib.md5(f"{database}:{query}".encode()).hexdigest()
|
return hashlib.md5(query.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def validate_query(query: str) -> tuple[bool, str]:
|
def validate_query(query: str) -> tuple[bool, str]:
|
||||||
@@ -161,20 +133,18 @@ def validate_query(query: str) -> tuple[bool, str]:
|
|||||||
'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE']
|
'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE']
|
||||||
|
|
||||||
for keyword in dangerous:
|
for keyword in dangerous:
|
||||||
# Check for keyword as whole word (not part of column name)
|
|
||||||
if f' {keyword} ' in f' {query_upper} ':
|
if f' {keyword} ' in f' {query_upper} ':
|
||||||
return False, f"Query contains forbidden keyword: {keyword}"
|
return False, f"Query contains forbidden keyword: {keyword}"
|
||||||
|
|
||||||
return True, "Query validated successfully"
|
return True, "Query validated successfully"
|
||||||
|
|
||||||
|
|
||||||
async def test_connection(database: str = "default") -> Dict[str, Any]:
|
async def test_connection() -> Dict[str, Any]:
|
||||||
"""Test Snowflake connection and return system information"""
|
"""Test Snowflake connection and return system information"""
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
# Get session information
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT
|
SELECT
|
||||||
CURRENT_ACCOUNT() as account,
|
CURRENT_ACCOUNT() as account,
|
||||||
@@ -211,14 +181,13 @@ async def test_connection(database: str = "default") -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def list_tables(database: str = "default", schema: Optional[str] = None,
|
async def list_tables(schema: Optional[str] = None,
|
||||||
include_row_counts: bool = False) -> Dict[str, Any]:
|
include_row_counts: bool = False) -> Dict[str, Any]:
|
||||||
"""List all tables in the database/schema"""
|
"""List all tables in the database/schema"""
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
# Get current database and schema if not specified
|
|
||||||
if schema:
|
if schema:
|
||||||
query = f"SHOW TABLES IN SCHEMA {schema}"
|
query = f"SHOW TABLES IN SCHEMA {schema}"
|
||||||
else:
|
else:
|
||||||
@@ -242,7 +211,6 @@ async def list_tables(database: str = "default", schema: Optional[str] = None,
|
|||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
# Return with helpful hint if no tables found
|
|
||||||
if not tables:
|
if not tables:
|
||||||
return {
|
return {
|
||||||
"tables": [],
|
"tables": [],
|
||||||
@@ -257,10 +225,10 @@ async def list_tables(database: str = "default", schema: Optional[str] = None,
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def list_views(schema_name: str, database: str = "default") -> List[Dict[str, Any]]:
|
async def list_views(schema_name: str) -> List[Dict[str, Any]]:
|
||||||
"""List all views in a schema with their descriptions"""
|
"""List all views in a schema with their descriptions"""
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}")
|
cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}")
|
||||||
@@ -287,13 +255,12 @@ async def list_views(schema_name: str, database: str = "default") -> List[Dict[s
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
async def describe_table(table_name: str, database: str = "default") -> Dict[str, Any]:
|
async def describe_table(table_name: str) -> Dict[str, Any]:
|
||||||
"""Get detailed schema information for a table or view"""
|
"""Get detailed schema information for a table or view"""
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Parse table name to handle quoting
|
|
||||||
parts = table_name.split('.')
|
parts = table_name.split('.')
|
||||||
if len(parts) == 3:
|
if len(parts) == 3:
|
||||||
db, schema, table = parts
|
db, schema, table = parts
|
||||||
@@ -305,7 +272,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
|
|||||||
schema = None
|
schema = None
|
||||||
table = table_name
|
table = table_name
|
||||||
|
|
||||||
# Build quoted version for case-sensitive names
|
|
||||||
if len(parts) == 3:
|
if len(parts) == 3:
|
||||||
quoted_name = f'{parts[0]}.{parts[1]}."{parts[2]}"'
|
quoted_name = f'{parts[0]}.{parts[1]}."{parts[2]}"'
|
||||||
elif len(parts) == 2:
|
elif len(parts) == 2:
|
||||||
@@ -313,7 +279,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
|
|||||||
else:
|
else:
|
||||||
quoted_name = f'"{table_name}"'
|
quoted_name = f'"{table_name}"'
|
||||||
|
|
||||||
# Try unquoted first, then quoted if that fails
|
|
||||||
rows = None
|
rows = None
|
||||||
used_name = table_name
|
used_name = table_name
|
||||||
for try_name in [table_name, quoted_name]:
|
for try_name in [table_name, quoted_name]:
|
||||||
@@ -324,13 +289,12 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
|
|||||||
break
|
break
|
||||||
except Exception as inner_e:
|
except Exception as inner_e:
|
||||||
if "does not exist" in str(inner_e).lower() and try_name == table_name:
|
if "does not exist" in str(inner_e).lower() and try_name == table_name:
|
||||||
continue # Try quoted version
|
continue
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if rows is None:
|
if rows is None:
|
||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
# Get column names from description using attribute access (safer for ResultMetadata)
|
|
||||||
desc_columns = []
|
desc_columns = []
|
||||||
if cursor.description:
|
if cursor.description:
|
||||||
for desc in cursor.description:
|
for desc in cursor.description:
|
||||||
@@ -385,11 +349,9 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def read_data(query: str, database: str = "default",
|
async def read_data(query: str, max_rows: int = 50000) -> Dict[str, Any]:
|
||||||
max_rows: int = 50000) -> Dict[str, Any]:
|
|
||||||
"""Execute a SELECT query and return results"""
|
"""Execute a SELECT query and return results"""
|
||||||
|
|
||||||
# Validate query
|
|
||||||
is_valid, message = validate_query(query)
|
is_valid, message = validate_query(query)
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
return {
|
return {
|
||||||
@@ -397,8 +359,7 @@ async def read_data(query: str, database: str = "default",
|
|||||||
"suggestion": "Use SELECT statements only. For schema information, use describe_table."
|
"suggestion": "Use SELECT statements only. For schema information, use describe_table."
|
||||||
}
|
}
|
||||||
|
|
||||||
# Check cache
|
cache_key = get_cache_key(query)
|
||||||
cache_key = get_cache_key(query, database)
|
|
||||||
if cache_key in query_cache:
|
if cache_key in query_cache:
|
||||||
cached_result, cached_time = query_cache[cache_key]
|
cached_result, cached_time = query_cache[cache_key]
|
||||||
if datetime.now() - cached_time < CACHE_TTL:
|
if datetime.now() - cached_time < CACHE_TTL:
|
||||||
@@ -410,23 +371,19 @@ async def read_data(query: str, database: str = "default",
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
# Execute query
|
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
|
|
||||||
# Get column names
|
|
||||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||||
|
|
||||||
# Fetch rows up to limit
|
|
||||||
rows = []
|
rows = []
|
||||||
for i, row in enumerate(cursor):
|
for i, row in enumerate(cursor):
|
||||||
if i >= max_rows:
|
if i >= max_rows:
|
||||||
logger.warning(f"Query returned more than {max_rows} rows, truncating")
|
logger.warning(f"Query returned more than {max_rows} rows, truncating")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Convert row to dict with proper serialization
|
|
||||||
row_dict = {}
|
row_dict = {}
|
||||||
for col in columns:
|
for col in columns:
|
||||||
value = row[col]
|
value = row[col]
|
||||||
@@ -446,10 +403,8 @@ async def read_data(query: str, database: str = "default",
|
|||||||
"rows": rows,
|
"rows": rows,
|
||||||
"row_count": len(rows),
|
"row_count": len(rows),
|
||||||
"truncated": len(rows) == max_rows,
|
"truncated": len(rows) == max_rows,
|
||||||
"database": database
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Cache the result
|
|
||||||
query_cache[cache_key] = (result, datetime.now())
|
query_cache[cache_key] = (result, datetime.now())
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -474,7 +429,7 @@ async def read_data(query: str, database: str = "default",
|
|||||||
async def list_databases() -> List[Dict[str, Any]]:
|
async def list_databases() -> List[Dict[str, Any]]:
|
||||||
"""List all accessible databases"""
|
"""List all accessible databases"""
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection("default")
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
cursor.execute("SHOW DATABASES")
|
cursor.execute("SHOW DATABASES")
|
||||||
@@ -501,7 +456,7 @@ async def list_databases() -> List[Dict[str, Any]]:
|
|||||||
async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
|
async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
"""List all schemas in a database"""
|
"""List all schemas in a database"""
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection("default")
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
if database_name:
|
if database_name:
|
||||||
@@ -533,14 +488,6 @@ async def get_system_health() -> Dict[str, Any]:
|
|||||||
health_info = {
|
health_info = {
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"timestamp": datetime.now().isoformat(),
|
"timestamp": datetime.now().isoformat(),
|
||||||
"connection_pool": {
|
|
||||||
"active_connections": len(connection_pool.connections),
|
|
||||||
"databases": list(connection_pool.connections.keys()),
|
|
||||||
"last_used": {
|
|
||||||
db: last.isoformat()
|
|
||||||
for db, last in connection_pool.last_used.items()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"cache": {
|
"cache": {
|
||||||
"query_cache_size": len(query_cache),
|
"query_cache_size": len(query_cache),
|
||||||
"cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
|
"cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
|
||||||
@@ -552,9 +499,8 @@ async def get_system_health() -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# Test connection
|
|
||||||
try:
|
try:
|
||||||
test_result = await test_connection("default")
|
test_result = await test_connection()
|
||||||
health_info["connection"] = test_result
|
health_info["connection"] = test_result
|
||||||
except Exception:
|
except Exception:
|
||||||
health_info["status"] = "degraded"
|
health_info["status"] = "degraded"
|
||||||
@@ -563,7 +509,7 @@ async def get_system_health() -> Dict[str, Any]:
|
|||||||
return health_info
|
return health_info
|
||||||
|
|
||||||
|
|
||||||
async def describe_query(query: str, database: str = "default") -> Dict[str, Any]:
|
async def describe_query(query: str) -> Dict[str, Any]:
|
||||||
"""Preview query output columns without executing the full query.
|
"""Preview query output columns without executing the full query.
|
||||||
Uses LIMIT 0 to get column metadata efficiently."""
|
Uses LIMIT 0 to get column metadata efficiently."""
|
||||||
|
|
||||||
@@ -572,21 +518,18 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
|
|||||||
return {"error": message}
|
return {"error": message}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
clean_query = query.rstrip(';').strip()
|
clean_query = query.rstrip(';').strip()
|
||||||
# Remove existing LIMIT and OFFSET clauses (case insensitive)
|
|
||||||
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
|
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
|
||||||
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
|
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
|
||||||
|
|
||||||
# Add LIMIT 0 to get column metadata without fetching data
|
|
||||||
limited_query = f"{clean_query} LIMIT 0"
|
limited_query = f"{clean_query} LIMIT 0"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cursor.execute(limited_query)
|
cursor.execute(limited_query)
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback: wrap in subquery if direct approach fails
|
|
||||||
limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0"
|
limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0"
|
||||||
cursor.execute(limited_query)
|
cursor.execute(limited_query)
|
||||||
|
|
||||||
@@ -609,7 +552,6 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
|
|||||||
return {
|
return {
|
||||||
"columns": columns,
|
"columns": columns,
|
||||||
"column_count": len(columns),
|
"column_count": len(columns),
|
||||||
"database": database,
|
|
||||||
"query_preview": query[:200] + "..." if len(query) > 200 else query
|
"query_preview": query[:200] + "..." if len(query) > 200 else query
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -621,7 +563,7 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def execute_async(query: str, database: str = "default") -> Dict[str, Any]:
|
async def execute_async(query: str) -> Dict[str, Any]:
|
||||||
"""Submit a query for asynchronous execution. Returns query ID for status tracking."""
|
"""Submit a query for asynchronous execution. Returns query ID for status tracking."""
|
||||||
|
|
||||||
is_valid, message = validate_query(query)
|
is_valid, message = validate_query(query)
|
||||||
@@ -629,17 +571,14 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
|
|||||||
return {"error": message}
|
return {"error": message}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Execute asynchronously
|
|
||||||
cursor.execute_async(query)
|
cursor.execute_async(query)
|
||||||
query_id = cursor.sfqid
|
query_id = cursor.sfqid
|
||||||
|
|
||||||
# Store query info for tracking
|
|
||||||
async_queries[query_id] = {
|
async_queries[query_id] = {
|
||||||
"query": query[:500],
|
"query": query[:500],
|
||||||
"database": database,
|
|
||||||
"submitted_at": datetime.now().isoformat(),
|
"submitted_at": datetime.now().isoformat(),
|
||||||
"status": "RUNNING"
|
"status": "RUNNING"
|
||||||
}
|
}
|
||||||
@@ -650,7 +589,6 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
|
|||||||
"query_id": query_id,
|
"query_id": query_id,
|
||||||
"status": "SUBMITTED",
|
"status": "SUBMITTED",
|
||||||
"message": "Query submitted for async execution. Use get_query_status to check progress.",
|
"message": "Query submitted for async execution. Use get_query_status to check progress.",
|
||||||
"database": database
|
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -658,11 +596,11 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
|
|||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
async def get_query_status(query_id: str, database: str = "default") -> Dict[str, Any]:
|
async def get_query_status(query_id: str) -> Dict[str, Any]:
|
||||||
"""Check the status of an async query by its query ID."""
|
"""Check the status of an async query by its query ID."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
|
|
||||||
status = conn.get_query_status(query_id)
|
status = conn.get_query_status(query_id)
|
||||||
status_name = status.name if hasattr(status, 'name') else str(status)
|
status_name = status.name if hasattr(status, 'name') else str(status)
|
||||||
@@ -670,7 +608,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
|
|||||||
is_running = conn.is_still_running(status)
|
is_running = conn.is_still_running(status)
|
||||||
is_error = conn.is_an_error(status)
|
is_error = conn.is_an_error(status)
|
||||||
|
|
||||||
# Update tracked query info
|
|
||||||
if query_id in async_queries:
|
if query_id in async_queries:
|
||||||
async_queries[query_id]["status"] = status_name
|
async_queries[query_id]["status"] = status_name
|
||||||
|
|
||||||
@@ -682,7 +619,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
|
|||||||
"can_fetch_results": not is_running and not is_error
|
"can_fetch_results": not is_running and not is_error
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add original query info if available
|
|
||||||
if query_id in async_queries:
|
if query_id in async_queries:
|
||||||
result["query_info"] = async_queries[query_id]
|
result["query_info"] = async_queries[query_id]
|
||||||
|
|
||||||
@@ -693,14 +629,12 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
|
|||||||
return {"error": str(e), "query_id": query_id}
|
return {"error": str(e), "query_id": query_id}
|
||||||
|
|
||||||
|
|
||||||
async def get_async_results(query_id: str, database: str = "default",
|
async def get_async_results(query_id: str, max_rows: int = 50000) -> Dict[str, Any]:
|
||||||
max_rows: int = 50000) -> Dict[str, Any]:
|
|
||||||
"""Retrieve results from a completed async query."""
|
"""Retrieve results from a completed async query."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
|
|
||||||
# Check status first
|
|
||||||
status = conn.get_query_status_throw_if_error(query_id)
|
status = conn.get_query_status_throw_if_error(query_id)
|
||||||
|
|
||||||
if conn.is_still_running(status):
|
if conn.is_still_running(status):
|
||||||
@@ -710,17 +644,13 @@ async def get_async_results(query_id: str, database: str = "default",
|
|||||||
"status": status.name if hasattr(status, 'name') else str(status)
|
"status": status.name if hasattr(status, 'name') else str(status)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Use standard cursor for get_results_from_sfqid
|
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.get_results_from_sfqid(query_id)
|
cursor.get_results_from_sfqid(query_id)
|
||||||
|
|
||||||
# Fetch results first (description may not populate until after fetch)
|
|
||||||
raw_rows = cursor.fetchmany(max_rows)
|
raw_rows = cursor.fetchmany(max_rows)
|
||||||
|
|
||||||
# Now get column names from description
|
|
||||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||||
|
|
||||||
# Convert to list of dicts
|
|
||||||
rows = []
|
rows = []
|
||||||
for row in raw_rows:
|
for row in raw_rows:
|
||||||
row_dict = {}
|
row_dict = {}
|
||||||
@@ -734,12 +664,10 @@ async def get_async_results(query_id: str, database: str = "default",
|
|||||||
row_dict[col] = value
|
row_dict[col] = value
|
||||||
rows.append(row_dict)
|
rows.append(row_dict)
|
||||||
|
|
||||||
# Check if there are more rows
|
|
||||||
truncated = len(raw_rows) == max_rows
|
truncated = len(raw_rows) == max_rows
|
||||||
|
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
# Clean up tracking
|
|
||||||
if query_id in async_queries:
|
if query_id in async_queries:
|
||||||
del async_queries[query_id]
|
del async_queries[query_id]
|
||||||
|
|
||||||
@@ -756,8 +684,8 @@ async def get_async_results(query_id: str, database: str = "default",
|
|||||||
return {"error": str(e), "query_id": query_id}
|
return {"error": str(e), "query_id": query_id}
|
||||||
|
|
||||||
|
|
||||||
async def read_data_paginated(query: str, database: str = "default",
|
async def read_data_paginated(query: str, page_size: int = 1000,
|
||||||
page_size: int = 1000, page: int = 1) -> Dict[str, Any]:
|
page: int = 1) -> Dict[str, Any]:
|
||||||
"""Execute a query with pagination support (offset/limit)."""
|
"""Execute a query with pagination support (offset/limit)."""
|
||||||
|
|
||||||
is_valid, message = validate_query(query)
|
is_valid, message = validate_query(query)
|
||||||
@@ -771,12 +699,10 @@ async def read_data_paginated(query: str, database: str = "default",
|
|||||||
|
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
|
|
||||||
# Strip existing LIMIT/OFFSET from query for proper pagination
|
|
||||||
clean_query = query.rstrip(';').strip()
|
clean_query = query.rstrip(';').strip()
|
||||||
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
|
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
|
||||||
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
|
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
|
||||||
|
|
||||||
# Wrap query with pagination
|
|
||||||
paginated_query = f"""
|
paginated_query = f"""
|
||||||
SELECT * FROM (
|
SELECT * FROM (
|
||||||
{clean_query}
|
{clean_query}
|
||||||
@@ -784,8 +710,7 @@ async def read_data_paginated(query: str, database: str = "default",
|
|||||||
LIMIT {page_size} OFFSET {offset}
|
LIMIT {page_size} OFFSET {offset}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Also get total count (cached for efficiency)
|
count_cache_key = get_cache_key(f"COUNT:{clean_query}")
|
||||||
count_cache_key = get_cache_key(f"COUNT:{clean_query}", database)
|
|
||||||
total_count = None
|
total_count = None
|
||||||
|
|
||||||
if count_cache_key in query_cache:
|
if count_cache_key in query_cache:
|
||||||
@@ -794,10 +719,9 @@ async def read_data_paginated(query: str, database: str = "default",
|
|||||||
total_count = cached_count
|
total_count = cached_count
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
# Get total count if not cached
|
|
||||||
if total_count is None:
|
if total_count is None:
|
||||||
count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq"
|
count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq"
|
||||||
cursor.execute(count_query)
|
cursor.execute(count_query)
|
||||||
@@ -805,7 +729,6 @@ async def read_data_paginated(query: str, database: str = "default",
|
|||||||
total_count = count_row["TOTAL"] if count_row else 0
|
total_count = count_row["TOTAL"] if count_row else 0
|
||||||
query_cache[count_cache_key] = (total_count, datetime.now())
|
query_cache[count_cache_key] = (total_count, datetime.now())
|
||||||
|
|
||||||
# Execute paginated query
|
|
||||||
cursor.execute(paginated_query)
|
cursor.execute(paginated_query)
|
||||||
|
|
||||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||||
@@ -838,7 +761,6 @@ async def read_data_paginated(query: str, database: str = "default",
|
|||||||
"has_next": page < total_pages,
|
"has_next": page < total_pages,
|
||||||
"has_previous": page > 1
|
"has_previous": page > 1
|
||||||
},
|
},
|
||||||
"database": database
|
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -846,8 +768,7 @@ async def read_data_paginated(query: str, database: str = "default",
|
|||||||
return {"error": str(e)}
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
async def read_data_pandas(query: str, database: str = "default",
|
async def read_data_pandas(query: str, max_rows: int = 50000,
|
||||||
max_rows: int = 50000,
|
|
||||||
output_format: str = "records") -> Dict[str, Any]:
|
output_format: str = "records") -> Dict[str, Any]:
|
||||||
"""Execute a query and return results in a pandas-friendly format.
|
"""Execute a query and return results in a pandas-friendly format.
|
||||||
|
|
||||||
@@ -866,14 +787,13 @@ async def read_data_pandas(query: str, database: str = "default",
|
|||||||
return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}
|
return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = connection_pool.get_connection(database)
|
conn = connection.get()
|
||||||
cursor = conn.cursor(DictCursor)
|
cursor = conn.cursor(DictCursor)
|
||||||
|
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
|
|
||||||
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
columns = [desc[0] for desc in cursor.description] if cursor.description else []
|
||||||
|
|
||||||
# Collect all rows first
|
|
||||||
all_rows = []
|
all_rows = []
|
||||||
for i, row in enumerate(cursor):
|
for i, row in enumerate(cursor):
|
||||||
if i >= max_rows:
|
if i >= max_rows:
|
||||||
@@ -894,19 +814,15 @@ async def read_data_pandas(query: str, database: str = "default",
|
|||||||
cursor.close()
|
cursor.close()
|
||||||
truncated = len(all_rows) == max_rows
|
truncated = len(all_rows) == max_rows
|
||||||
|
|
||||||
# Format output based on requested format
|
|
||||||
if output_format == 'records':
|
if output_format == 'records':
|
||||||
data = all_rows
|
data = all_rows
|
||||||
|
|
||||||
elif output_format == 'columns':
|
elif output_format == 'columns':
|
||||||
data = {col: [row.get(col) for row in all_rows] for col in columns}
|
data = {col: [row.get(col) for row in all_rows] for col in columns}
|
||||||
|
|
||||||
elif output_format == 'split':
|
elif output_format == 'split':
|
||||||
data = {
|
data = {
|
||||||
"columns": columns,
|
"columns": columns,
|
||||||
"data": [[row.get(col) for col in columns] for row in all_rows]
|
"data": [[row.get(col) for col in columns] for row in all_rows]
|
||||||
}
|
}
|
||||||
|
|
||||||
elif output_format == 'csv':
|
elif output_format == 'csv':
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
writer = csv.DictWriter(output, fieldnames=columns)
|
writer = csv.DictWriter(output, fieldnames=columns)
|
||||||
@@ -920,7 +836,6 @@ async def read_data_pandas(query: str, database: str = "default",
|
|||||||
"data": data,
|
"data": data,
|
||||||
"row_count": len(all_rows),
|
"row_count": len(all_rows),
|
||||||
"truncated": truncated,
|
"truncated": truncated,
|
||||||
"database": database,
|
|
||||||
"pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
|
"pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -942,7 +857,7 @@ server = Server("snowflake-mcp")
|
|||||||
|
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def list_tools() -> List[Tool]:
|
async def handle_list_tools() -> List[Tool]:
|
||||||
"""List available MCP tools"""
|
"""List available MCP tools"""
|
||||||
return [
|
return [
|
||||||
Tool(
|
Tool(
|
||||||
@@ -950,13 +865,7 @@ async def list_tools() -> List[Tool]:
|
|||||||
description="Test Snowflake connection and get session information",
|
description="Test Snowflake connection and get session information",
|
||||||
inputSchema={
|
inputSchema={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {}
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context (default: 'default')",
|
|
||||||
"default": "default"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
Tool(
|
Tool(
|
||||||
@@ -986,14 +895,9 @@ async def list_tools() -> List[Tool]:
|
|||||||
inputSchema={
|
inputSchema={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
},
|
|
||||||
"schema": {
|
"schema": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Schema filter (optional)"
|
"description": "Schema filter (optional, e.g. 'MY_DB.MY_SCHEMA')"
|
||||||
},
|
},
|
||||||
"include_row_counts": {
|
"include_row_counts": {
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
@@ -1012,11 +916,6 @@ async def list_tools() -> List[Tool]:
|
|||||||
"schema_name": {
|
"schema_name": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')"
|
"description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')"
|
||||||
},
|
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["schema_name"]
|
"required": ["schema_name"]
|
||||||
@@ -1024,18 +923,13 @@ async def list_tools() -> List[Tool]:
|
|||||||
),
|
),
|
||||||
Tool(
|
Tool(
|
||||||
name="describe_table",
|
name="describe_table",
|
||||||
description="Get detailed schema information for a table",
|
description="Get detailed schema information for a table or view",
|
||||||
inputSchema={
|
inputSchema={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"table_name": {
|
"table_name": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Table name (can include database.schema.table)"
|
"description": "Table name (can include database.schema.table)"
|
||||||
},
|
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["table_name"]
|
"required": ["table_name"]
|
||||||
@@ -1051,11 +945,6 @@ async def list_tools() -> List[Tool]:
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "SQL SELECT query to execute"
|
"description": "SQL SELECT query to execute"
|
||||||
},
|
},
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
},
|
|
||||||
"max_rows": {
|
"max_rows": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Maximum rows to return",
|
"description": "Maximum rows to return",
|
||||||
@@ -1082,11 +971,6 @@ async def list_tools() -> List[Tool]:
|
|||||||
"query": {
|
"query": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "SQL SELECT query to analyze"
|
"description": "SQL SELECT query to analyze"
|
||||||
},
|
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["query"]
|
"required": ["query"]
|
||||||
@@ -1101,11 +985,6 @@ async def list_tools() -> List[Tool]:
|
|||||||
"query": {
|
"query": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "SQL SELECT query to execute asynchronously"
|
"description": "SQL SELECT query to execute asynchronously"
|
||||||
},
|
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["query"]
|
"required": ["query"]
|
||||||
@@ -1120,11 +999,6 @@ async def list_tools() -> List[Tool]:
|
|||||||
"query_id": {
|
"query_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Snowflake query ID from execute_async"
|
"description": "Snowflake query ID from execute_async"
|
||||||
},
|
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["query_id"]
|
"required": ["query_id"]
|
||||||
@@ -1140,11 +1014,6 @@ async def list_tools() -> List[Tool]:
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Snowflake query ID from execute_async"
|
"description": "Snowflake query ID from execute_async"
|
||||||
},
|
},
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
},
|
|
||||||
"max_rows": {
|
"max_rows": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Maximum rows to return",
|
"description": "Maximum rows to return",
|
||||||
@@ -1172,11 +1041,6 @@ async def list_tools() -> List[Tool]:
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "SQL SELECT query to execute"
|
"description": "SQL SELECT query to execute"
|
||||||
},
|
},
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
},
|
|
||||||
"page_size": {
|
"page_size": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Number of rows per page (1-10000)",
|
"description": "Number of rows per page (1-10000)",
|
||||||
@@ -1201,11 +1065,6 @@ async def list_tools() -> List[Tool]:
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "SQL SELECT query to execute"
|
"description": "SQL SELECT query to execute"
|
||||||
},
|
},
|
||||||
"database": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Database context",
|
|
||||||
"default": "default"
|
|
||||||
},
|
|
||||||
"max_rows": {
|
"max_rows": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"description": "Maximum rows to return",
|
"description": "Maximum rows to return",
|
||||||
@@ -1230,8 +1089,7 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if name == "test_connection":
|
if name == "test_connection":
|
||||||
database = arguments.get("database", "default")
|
result = await test_connection()
|
||||||
result = await test_connection(database)
|
|
||||||
|
|
||||||
elif name == "list_databases":
|
elif name == "list_databases":
|
||||||
result = await list_databases()
|
result = await list_databases()
|
||||||
@@ -1241,67 +1099,57 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]:
|
|||||||
result = await list_schemas(database_name)
|
result = await list_schemas(database_name)
|
||||||
|
|
||||||
elif name == "list_tables":
|
elif name == "list_tables":
|
||||||
database = arguments.get("database", "default")
|
|
||||||
schema = arguments.get("schema")
|
schema = arguments.get("schema")
|
||||||
include_counts = arguments.get("include_row_counts", False)
|
include_counts = arguments.get("include_row_counts", False)
|
||||||
result = await list_tables(database, schema, include_counts)
|
result = await list_tables(schema, include_counts)
|
||||||
|
|
||||||
elif name == "list_views":
|
elif name == "list_views":
|
||||||
schema_name = arguments["schema_name"]
|
schema_name = arguments["schema_name"]
|
||||||
database = arguments.get("database", "default")
|
result = await list_views(schema_name)
|
||||||
result = await list_views(schema_name, database)
|
|
||||||
|
|
||||||
elif name == "describe_table":
|
elif name == "describe_table":
|
||||||
table_name = arguments["table_name"]
|
table_name = arguments["table_name"]
|
||||||
database = arguments.get("database", "default")
|
result = await describe_table(table_name)
|
||||||
result = await describe_table(table_name, database)
|
|
||||||
|
|
||||||
elif name == "read_data":
|
elif name == "read_data":
|
||||||
query = arguments["query"]
|
query = arguments["query"]
|
||||||
database = arguments.get("database", "default")
|
|
||||||
max_rows = arguments.get("max_rows", 50000)
|
max_rows = arguments.get("max_rows", 50000)
|
||||||
result = await read_data(query, database, max_rows)
|
result = await read_data(query, max_rows)
|
||||||
|
|
||||||
elif name == "get_system_health":
|
elif name == "get_system_health":
|
||||||
result = await get_system_health()
|
result = await get_system_health()
|
||||||
|
|
||||||
elif name == "describe_query":
|
elif name == "describe_query":
|
||||||
query = arguments["query"]
|
query = arguments["query"]
|
||||||
database = arguments.get("database", "default")
|
result = await describe_query(query)
|
||||||
result = await describe_query(query, database)
|
|
||||||
|
|
||||||
elif name == "execute_async":
|
elif name == "execute_async":
|
||||||
query = arguments["query"]
|
query = arguments["query"]
|
||||||
database = arguments.get("database", "default")
|
result = await execute_async(query)
|
||||||
result = await execute_async(query, database)
|
|
||||||
|
|
||||||
elif name == "get_query_status":
|
elif name == "get_query_status":
|
||||||
query_id = arguments["query_id"]
|
query_id = arguments["query_id"]
|
||||||
database = arguments.get("database", "default")
|
result = await get_query_status(query_id)
|
||||||
result = await get_query_status(query_id, database)
|
|
||||||
|
|
||||||
elif name == "get_async_results":
|
elif name == "get_async_results":
|
||||||
query_id = arguments["query_id"]
|
query_id = arguments["query_id"]
|
||||||
database = arguments.get("database", "default")
|
|
||||||
max_rows = arguments.get("max_rows", 50000)
|
max_rows = arguments.get("max_rows", 50000)
|
||||||
result = await get_async_results(query_id, database, max_rows)
|
result = await get_async_results(query_id, max_rows)
|
||||||
|
|
||||||
elif name == "list_async_queries":
|
elif name == "list_async_queries":
|
||||||
result = await list_async_queries()
|
result = await list_async_queries()
|
||||||
|
|
||||||
elif name == "read_data_paginated":
|
elif name == "read_data_paginated":
|
||||||
query = arguments["query"]
|
query = arguments["query"]
|
||||||
database = arguments.get("database", "default")
|
|
||||||
page_size = arguments.get("page_size", 1000)
|
page_size = arguments.get("page_size", 1000)
|
||||||
page = arguments.get("page", 1)
|
page = arguments.get("page", 1)
|
||||||
result = await read_data_paginated(query, database, page_size, page)
|
result = await read_data_paginated(query, page_size, page)
|
||||||
|
|
||||||
elif name == "read_data_pandas":
|
elif name == "read_data_pandas":
|
||||||
query = arguments["query"]
|
query = arguments["query"]
|
||||||
database = arguments.get("database", "default")
|
|
||||||
max_rows = arguments.get("max_rows", 50000)
|
max_rows = arguments.get("max_rows", 50000)
|
||||||
output_format = arguments.get("output_format", "records")
|
output_format = arguments.get("output_format", "records")
|
||||||
result = await read_data_pandas(query, database, max_rows, output_format)
|
result = await read_data_pandas(query, max_rows, output_format)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
@@ -1339,4 +1187,4 @@ if __name__ == "__main__":
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Server stopped by user")
|
logger.info("Server stopped by user")
|
||||||
finally:
|
finally:
|
||||||
connection_pool.close_all()
|
connection.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user