Removed dead code/unneeded env setting

This commit is contained in:
2026-03-16 11:14:04 +00:00
parent 4ae150f420
commit 5dc546014a
2 changed files with 81 additions and 233 deletions
+76 -228
View File
@@ -37,101 +37,73 @@ SNOWFLAKE_CONFIG = {
"role": os.environ.get("SNOWFLAKE_ROLE"), # None = account default
}
# Available databases/schemas to expose
DATABASES = {
"default": {
"database": os.environ.get("SNOWFLAKE_DATABASE"), # None = account default
"schema": os.environ.get("SNOWFLAKE_SCHEMA"), # None = account default
"description": "Default database and schema"
},
}
class SnowflakeConnectionPool:
"""Manages Snowflake connections with SSO authentication"""
class SnowflakeConnection:
"""Manages a single Snowflake connection with SSO authentication"""
def __init__(self):
self.connections: Dict[str, snowflake.connector.SnowflakeConnection] = {}
self.last_used: Dict[str, datetime] = {}
self.connection_timeout = timedelta(minutes=30)
self._conn: Optional[snowflake.connector.SnowflakeConnection] = None
self._last_used: Optional[datetime] = None
self._timeout = timedelta(minutes=30)
def get_connection(self, database: str = "default") -> snowflake.connector.SnowflakeConnection:
"""Get or create a Snowflake connection with SSO auth"""
cache_key = database
def get(self) -> snowflake.connector.SnowflakeConnection:
"""Get or create the Snowflake connection"""
# Check if we have a valid cached connection
if cache_key in self.connections:
conn = self.connections[cache_key]
last = self.last_used.get(cache_key, datetime.now())
# Check if connection is still valid and not timed out
if datetime.now() - last < self.connection_timeout:
if self._conn and self._last_used:
if datetime.now() - self._last_used < self._timeout:
try:
# Test the connection
cursor = conn.cursor()
cursor = self._conn.cursor()
cursor.execute("SELECT 1")
cursor.close()
self.last_used[cache_key] = datetime.now()
logger.info(f"Reusing cached Snowflake connection for {database}")
return conn
self._last_used = datetime.now()
logger.info("Reusing cached Snowflake connection")
return self._conn
except Exception:
logger.info(f"Cached connection for {database} is stale, reconnecting...")
logger.info("Cached connection is stale, reconnecting...")
try:
conn.close()
self._conn.close()
except Exception:
pass
# Create new connection
logger.info(f"Creating new Snowflake connection for {database}")
logger.info("Creating new Snowflake connection")
logger.info("Note: First connection will open browser for SSO login")
db_config = DATABASES.get(database, DATABASES["default"])
connect_params = {
"account": SNOWFLAKE_CONFIG["account"],
"user": SNOWFLAKE_CONFIG["user"],
"authenticator": SNOWFLAKE_CONFIG["authenticator"],
}
# Add optional parameters if specified
if db_config.get("database"):
connect_params["database"] = db_config["database"]
if db_config.get("schema"):
connect_params["schema"] = db_config["schema"]
if SNOWFLAKE_CONFIG.get("warehouse"):
connect_params["warehouse"] = SNOWFLAKE_CONFIG["warehouse"]
if SNOWFLAKE_CONFIG.get("role"):
connect_params["role"] = SNOWFLAKE_CONFIG["role"]
try:
conn = snowflake.connector.connect(**connect_params)
# Store the connection
self.connections[cache_key] = conn
self.last_used[cache_key] = datetime.now()
logger.info(f"Successfully connected to Snowflake ({database})")
return conn
self._conn = snowflake.connector.connect(**connect_params)
self._last_used = datetime.now()
logger.info("Successfully connected to Snowflake")
return self._conn
except Exception as e:
logger.error(f"Failed to connect to Snowflake: {e}")
raise
def close_all(self):
"""Close all connections"""
for database, conn in self.connections.items():
def close(self):
"""Close the connection"""
if self._conn:
try:
conn.close()
logger.info(f"Closed Snowflake connection to {database}")
self._conn.close()
logger.info("Closed Snowflake connection")
except Exception:
pass
self.connections.clear()
self.last_used.clear()
self._conn = None
self._last_used = None
# Global connection pool
connection_pool = SnowflakeConnectionPool()
# Global connection
connection = SnowflakeConnection()
# Query cache
query_cache: Dict[str, tuple] = {}
@@ -141,9 +113,9 @@ CACHE_TTL = timedelta(minutes=5)
async_queries: Dict[str, Dict[str, Any]] = {}
def get_cache_key(query: str, database: str) -> str:
def get_cache_key(query: str) -> str:
"""Generate cache key for query"""
return hashlib.md5(f"{database}:{query}".encode()).hexdigest()
return hashlib.md5(query.encode()).hexdigest()
def validate_query(query: str) -> tuple[bool, str]:
@@ -161,20 +133,18 @@ def validate_query(query: str) -> tuple[bool, str]:
'CREATE', 'ALTER', 'GRANT', 'REVOKE', 'MERGE']
for keyword in dangerous:
# Check for keyword as whole word (not part of column name)
if f' {keyword} ' in f' {query_upper} ':
return False, f"Query contains forbidden keyword: {keyword}"
return True, "Query validated successfully"
async def test_connection(database: str = "default") -> Dict[str, Any]:
async def test_connection() -> Dict[str, Any]:
"""Test Snowflake connection and return system information"""
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor(DictCursor)
# Get session information
cursor.execute("""
SELECT
CURRENT_ACCOUNT() as account,
@@ -211,14 +181,13 @@ async def test_connection(database: str = "default") -> Dict[str, Any]:
}
async def list_tables(database: str = "default", schema: Optional[str] = None,
async def list_tables(schema: Optional[str] = None,
include_row_counts: bool = False) -> Dict[str, Any]:
"""List all tables in the database/schema"""
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor(DictCursor)
# Get current database and schema if not specified
if schema:
query = f"SHOW TABLES IN SCHEMA {schema}"
else:
@@ -242,7 +211,6 @@ async def list_tables(database: str = "default", schema: Optional[str] = None,
cursor.close()
# Return with helpful hint if no tables found
if not tables:
return {
"tables": [],
@@ -257,10 +225,10 @@ async def list_tables(database: str = "default", schema: Optional[str] = None,
raise
async def list_views(schema_name: str, database: str = "default") -> List[Dict[str, Any]]:
async def list_views(schema_name: str) -> List[Dict[str, Any]]:
"""List all views in a schema with their descriptions"""
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor(DictCursor)
cursor.execute(f"SHOW VIEWS IN SCHEMA {schema_name}")
@@ -287,13 +255,12 @@ async def list_views(schema_name: str, database: str = "default") -> List[Dict[s
raise
async def describe_table(table_name: str, database: str = "default") -> Dict[str, Any]:
async def describe_table(table_name: str) -> Dict[str, Any]:
"""Get detailed schema information for a table or view"""
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor()
# Parse table name to handle quoting
parts = table_name.split('.')
if len(parts) == 3:
db, schema, table = parts
@@ -305,7 +272,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
schema = None
table = table_name
# Build quoted version for case-sensitive names
if len(parts) == 3:
quoted_name = f'{parts[0]}.{parts[1]}."{parts[2]}"'
elif len(parts) == 2:
@@ -313,7 +279,6 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
else:
quoted_name = f'"{table_name}"'
# Try unquoted first, then quoted if that fails
rows = None
used_name = table_name
for try_name in [table_name, quoted_name]:
@@ -324,13 +289,12 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
break
except Exception as inner_e:
if "does not exist" in str(inner_e).lower() and try_name == table_name:
continue # Try quoted version
continue
raise
if rows is None:
rows = []
# Get column names from description using attribute access (safer for ResultMetadata)
desc_columns = []
if cursor.description:
for desc in cursor.description:
@@ -385,11 +349,9 @@ async def describe_table(table_name: str, database: str = "default") -> Dict[str
}
async def read_data(query: str, database: str = "default",
max_rows: int = 50000) -> Dict[str, Any]:
async def read_data(query: str, max_rows: int = 50000) -> Dict[str, Any]:
"""Execute a SELECT query and return results"""
# Validate query
is_valid, message = validate_query(query)
if not is_valid:
return {
@@ -397,8 +359,7 @@ async def read_data(query: str, database: str = "default",
"suggestion": "Use SELECT statements only. For schema information, use describe_table."
}
# Check cache
cache_key = get_cache_key(query, database)
cache_key = get_cache_key(query)
if cache_key in query_cache:
cached_result, cached_time = query_cache[cache_key]
if datetime.now() - cached_time < CACHE_TTL:
@@ -410,23 +371,19 @@ async def read_data(query: str, database: str = "default",
}
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor(DictCursor)
# Execute query
cursor.execute(query)
# Get column names
columns = [desc[0] for desc in cursor.description] if cursor.description else []
# Fetch rows up to limit
rows = []
for i, row in enumerate(cursor):
if i >= max_rows:
logger.warning(f"Query returned more than {max_rows} rows, truncating")
break
# Convert row to dict with proper serialization
row_dict = {}
for col in columns:
value = row[col]
@@ -446,10 +403,8 @@ async def read_data(query: str, database: str = "default",
"rows": rows,
"row_count": len(rows),
"truncated": len(rows) == max_rows,
"database": database
}
# Cache the result
query_cache[cache_key] = (result, datetime.now())
return result
@@ -474,7 +429,7 @@ async def read_data(query: str, database: str = "default",
async def list_databases() -> List[Dict[str, Any]]:
"""List all accessible databases"""
try:
conn = connection_pool.get_connection("default")
conn = connection.get()
cursor = conn.cursor(DictCursor)
cursor.execute("SHOW DATABASES")
@@ -501,7 +456,7 @@ async def list_databases() -> List[Dict[str, Any]]:
async def list_schemas(database_name: Optional[str] = None) -> List[Dict[str, Any]]:
"""List all schemas in a database"""
try:
conn = connection_pool.get_connection("default")
conn = connection.get()
cursor = conn.cursor(DictCursor)
if database_name:
@@ -533,14 +488,6 @@ async def get_system_health() -> Dict[str, Any]:
health_info = {
"status": "healthy",
"timestamp": datetime.now().isoformat(),
"connection_pool": {
"active_connections": len(connection_pool.connections),
"databases": list(connection_pool.connections.keys()),
"last_used": {
db: last.isoformat()
for db, last in connection_pool.last_used.items()
}
},
"cache": {
"query_cache_size": len(query_cache),
"cache_ttl_minutes": CACHE_TTL.total_seconds() / 60
@@ -552,9 +499,8 @@ async def get_system_health() -> Dict[str, Any]:
}
}
# Test connection
try:
test_result = await test_connection("default")
test_result = await test_connection()
health_info["connection"] = test_result
except Exception:
health_info["status"] = "degraded"
@@ -563,7 +509,7 @@ async def get_system_health() -> Dict[str, Any]:
return health_info
async def describe_query(query: str, database: str = "default") -> Dict[str, Any]:
async def describe_query(query: str) -> Dict[str, Any]:
"""Preview query output columns without executing the full query.
Uses LIMIT 0 to get column metadata efficiently."""
@@ -572,21 +518,18 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
return {"error": message}
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor(DictCursor)
clean_query = query.rstrip(';').strip()
# Remove existing LIMIT and OFFSET clauses (case insensitive)
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
# Add LIMIT 0 to get column metadata without fetching data
limited_query = f"{clean_query} LIMIT 0"
try:
cursor.execute(limited_query)
except Exception:
# Fallback: wrap in subquery if direct approach fails
limited_query = f"SELECT * FROM ({clean_query}) AS subq LIMIT 0"
cursor.execute(limited_query)
@@ -609,7 +552,6 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
return {
"columns": columns,
"column_count": len(columns),
"database": database,
"query_preview": query[:200] + "..." if len(query) > 200 else query
}
@@ -621,7 +563,7 @@ async def describe_query(query: str, database: str = "default") -> Dict[str, Any
}
async def execute_async(query: str, database: str = "default") -> Dict[str, Any]:
async def execute_async(query: str) -> Dict[str, Any]:
"""Submit a query for asynchronous execution. Returns query ID for status tracking."""
is_valid, message = validate_query(query)
@@ -629,17 +571,14 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
return {"error": message}
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor()
# Execute asynchronously
cursor.execute_async(query)
query_id = cursor.sfqid
# Store query info for tracking
async_queries[query_id] = {
"query": query[:500],
"database": database,
"submitted_at": datetime.now().isoformat(),
"status": "RUNNING"
}
@@ -650,7 +589,6 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
"query_id": query_id,
"status": "SUBMITTED",
"message": "Query submitted for async execution. Use get_query_status to check progress.",
"database": database
}
except Exception as e:
@@ -658,11 +596,11 @@ async def execute_async(query: str, database: str = "default") -> Dict[str, Any]
return {"error": str(e)}
async def get_query_status(query_id: str, database: str = "default") -> Dict[str, Any]:
async def get_query_status(query_id: str) -> Dict[str, Any]:
"""Check the status of an async query by its query ID."""
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
status = conn.get_query_status(query_id)
status_name = status.name if hasattr(status, 'name') else str(status)
@@ -670,7 +608,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
is_running = conn.is_still_running(status)
is_error = conn.is_an_error(status)
# Update tracked query info
if query_id in async_queries:
async_queries[query_id]["status"] = status_name
@@ -682,7 +619,6 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
"can_fetch_results": not is_running and not is_error
}
# Add original query info if available
if query_id in async_queries:
result["query_info"] = async_queries[query_id]
@@ -693,14 +629,12 @@ async def get_query_status(query_id: str, database: str = "default") -> Dict[str
return {"error": str(e), "query_id": query_id}
async def get_async_results(query_id: str, database: str = "default",
max_rows: int = 50000) -> Dict[str, Any]:
async def get_async_results(query_id: str, max_rows: int = 50000) -> Dict[str, Any]:
"""Retrieve results from a completed async query."""
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
# Check status first
status = conn.get_query_status_throw_if_error(query_id)
if conn.is_still_running(status):
@@ -710,17 +644,13 @@ async def get_async_results(query_id: str, database: str = "default",
"status": status.name if hasattr(status, 'name') else str(status)
}
# Use standard cursor for get_results_from_sfqid
cursor = conn.cursor()
cursor.get_results_from_sfqid(query_id)
# Fetch results first (description may not populate until after fetch)
raw_rows = cursor.fetchmany(max_rows)
# Now get column names from description
columns = [desc[0] for desc in cursor.description] if cursor.description else []
# Convert to list of dicts
rows = []
for row in raw_rows:
row_dict = {}
@@ -734,12 +664,10 @@ async def get_async_results(query_id: str, database: str = "default",
row_dict[col] = value
rows.append(row_dict)
# Check if there are more rows
truncated = len(raw_rows) == max_rows
cursor.close()
# Clean up tracking
if query_id in async_queries:
del async_queries[query_id]
@@ -756,8 +684,8 @@ async def get_async_results(query_id: str, database: str = "default",
return {"error": str(e), "query_id": query_id}
async def read_data_paginated(query: str, database: str = "default",
page_size: int = 1000, page: int = 1) -> Dict[str, Any]:
async def read_data_paginated(query: str, page_size: int = 1000,
page: int = 1) -> Dict[str, Any]:
"""Execute a query with pagination support (offset/limit)."""
is_valid, message = validate_query(query)
@@ -771,12 +699,10 @@ async def read_data_paginated(query: str, database: str = "default",
offset = (page - 1) * page_size
# Strip existing LIMIT/OFFSET from query for proper pagination
clean_query = query.rstrip(';').strip()
clean_query = re.sub(r'\s+LIMIT\s+\d+(\s+OFFSET\s+\d+)?\s*$', '', clean_query, flags=re.IGNORECASE)
clean_query = re.sub(r'\s+OFFSET\s+\d+\s*$', '', clean_query, flags=re.IGNORECASE)
# Wrap query with pagination
paginated_query = f"""
SELECT * FROM (
{clean_query}
@@ -784,8 +710,7 @@ async def read_data_paginated(query: str, database: str = "default",
LIMIT {page_size} OFFSET {offset}
"""
# Also get total count (cached for efficiency)
count_cache_key = get_cache_key(f"COUNT:{clean_query}", database)
count_cache_key = get_cache_key(f"COUNT:{clean_query}")
total_count = None
if count_cache_key in query_cache:
@@ -794,10 +719,9 @@ async def read_data_paginated(query: str, database: str = "default",
total_count = cached_count
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor(DictCursor)
# Get total count if not cached
if total_count is None:
count_query = f"SELECT COUNT(*) as total FROM ({clean_query}) AS count_subq"
cursor.execute(count_query)
@@ -805,7 +729,6 @@ async def read_data_paginated(query: str, database: str = "default",
total_count = count_row["TOTAL"] if count_row else 0
query_cache[count_cache_key] = (total_count, datetime.now())
# Execute paginated query
cursor.execute(paginated_query)
columns = [desc[0] for desc in cursor.description] if cursor.description else []
@@ -838,7 +761,6 @@ async def read_data_paginated(query: str, database: str = "default",
"has_next": page < total_pages,
"has_previous": page > 1
},
"database": database
}
except Exception as e:
@@ -846,8 +768,7 @@ async def read_data_paginated(query: str, database: str = "default",
return {"error": str(e)}
async def read_data_pandas(query: str, database: str = "default",
max_rows: int = 50000,
async def read_data_pandas(query: str, max_rows: int = 50000,
output_format: str = "records") -> Dict[str, Any]:
"""Execute a query and return results in a pandas-friendly format.
@@ -866,14 +787,13 @@ async def read_data_pandas(query: str, database: str = "default",
return {"error": f"Invalid output_format: {output_format}. Use: records, columns, split, csv"}
try:
conn = connection_pool.get_connection(database)
conn = connection.get()
cursor = conn.cursor(DictCursor)
cursor.execute(query)
columns = [desc[0] for desc in cursor.description] if cursor.description else []
# Collect all rows first
all_rows = []
for i, row in enumerate(cursor):
if i >= max_rows:
@@ -894,19 +814,15 @@ async def read_data_pandas(query: str, database: str = "default",
cursor.close()
truncated = len(all_rows) == max_rows
# Format output based on requested format
if output_format == 'records':
data = all_rows
elif output_format == 'columns':
data = {col: [row.get(col) for row in all_rows] for col in columns}
elif output_format == 'split':
data = {
"columns": columns,
"data": [[row.get(col) for col in columns] for row in all_rows]
}
elif output_format == 'csv':
output = io.StringIO()
writer = csv.DictWriter(output, fieldnames=columns)
@@ -920,7 +836,6 @@ async def read_data_pandas(query: str, database: str = "default",
"data": data,
"row_count": len(all_rows),
"truncated": truncated,
"database": database,
"pandas_hint": "Use pd.DataFrame(result['data']) for 'records' format, pd.DataFrame(result['data']) for 'columns' format, or pd.read_csv(io.StringIO(result['data'])) for 'csv' format"
}
@@ -942,7 +857,7 @@ server = Server("snowflake-mcp")
@server.list_tools()
async def list_tools() -> List[Tool]:
async def handle_list_tools() -> List[Tool]:
"""List available MCP tools"""
return [
Tool(
@@ -950,13 +865,7 @@ async def list_tools() -> List[Tool]:
description="Test Snowflake connection and get session information",
inputSchema={
"type": "object",
"properties": {
"database": {
"type": "string",
"description": "Database context (default: 'default')",
"default": "default"
}
}
"properties": {}
}
),
Tool(
@@ -986,14 +895,9 @@ async def list_tools() -> List[Tool]:
inputSchema={
"type": "object",
"properties": {
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"schema": {
"type": "string",
"description": "Schema filter (optional)"
"description": "Schema filter (optional, e.g. 'MY_DB.MY_SCHEMA')"
},
"include_row_counts": {
"type": "boolean",
@@ -1012,11 +916,6 @@ async def list_tools() -> List[Tool]:
"schema_name": {
"type": "string",
"description": "Schema name (e.g., 'MY_DB.MY_SCHEMA')"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
}
},
"required": ["schema_name"]
@@ -1024,18 +923,13 @@ async def list_tools() -> List[Tool]:
),
Tool(
name="describe_table",
description="Get detailed schema information for a table",
description="Get detailed schema information for a table or view",
inputSchema={
"type": "object",
"properties": {
"table_name": {
"type": "string",
"description": "Table name (can include database.schema.table)"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
}
},
"required": ["table_name"]
@@ -1051,11 +945,6 @@ async def list_tools() -> List[Tool]:
"type": "string",
"description": "SQL SELECT query to execute"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"max_rows": {
"type": "integer",
"description": "Maximum rows to return",
@@ -1082,11 +971,6 @@ async def list_tools() -> List[Tool]:
"query": {
"type": "string",
"description": "SQL SELECT query to analyze"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
}
},
"required": ["query"]
@@ -1101,11 +985,6 @@ async def list_tools() -> List[Tool]:
"query": {
"type": "string",
"description": "SQL SELECT query to execute asynchronously"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
}
},
"required": ["query"]
@@ -1120,11 +999,6 @@ async def list_tools() -> List[Tool]:
"query_id": {
"type": "string",
"description": "Snowflake query ID from execute_async"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
}
},
"required": ["query_id"]
@@ -1140,11 +1014,6 @@ async def list_tools() -> List[Tool]:
"type": "string",
"description": "Snowflake query ID from execute_async"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"max_rows": {
"type": "integer",
"description": "Maximum rows to return",
@@ -1172,11 +1041,6 @@ async def list_tools() -> List[Tool]:
"type": "string",
"description": "SQL SELECT query to execute"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"page_size": {
"type": "integer",
"description": "Number of rows per page (1-10000)",
@@ -1201,11 +1065,6 @@ async def list_tools() -> List[Tool]:
"type": "string",
"description": "SQL SELECT query to execute"
},
"database": {
"type": "string",
"description": "Database context",
"default": "default"
},
"max_rows": {
"type": "integer",
"description": "Maximum rows to return",
@@ -1230,8 +1089,7 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]:
try:
if name == "test_connection":
database = arguments.get("database", "default")
result = await test_connection(database)
result = await test_connection()
elif name == "list_databases":
result = await list_databases()
@@ -1241,67 +1099,57 @@ async def call_tool(name: str, arguments: Any) -> List[TextContent]:
result = await list_schemas(database_name)
elif name == "list_tables":
database = arguments.get("database", "default")
schema = arguments.get("schema")
include_counts = arguments.get("include_row_counts", False)
result = await list_tables(database, schema, include_counts)
result = await list_tables(schema, include_counts)
elif name == "list_views":
schema_name = arguments["schema_name"]
database = arguments.get("database", "default")
result = await list_views(schema_name, database)
result = await list_views(schema_name)
elif name == "describe_table":
table_name = arguments["table_name"]
database = arguments.get("database", "default")
result = await describe_table(table_name, database)
result = await describe_table(table_name)
elif name == "read_data":
query = arguments["query"]
database = arguments.get("database", "default")
max_rows = arguments.get("max_rows", 50000)
result = await read_data(query, database, max_rows)
result = await read_data(query, max_rows)
elif name == "get_system_health":
result = await get_system_health()
elif name == "describe_query":
query = arguments["query"]
database = arguments.get("database", "default")
result = await describe_query(query, database)
result = await describe_query(query)
elif name == "execute_async":
query = arguments["query"]
database = arguments.get("database", "default")
result = await execute_async(query, database)
result = await execute_async(query)
elif name == "get_query_status":
query_id = arguments["query_id"]
database = arguments.get("database", "default")
result = await get_query_status(query_id, database)
result = await get_query_status(query_id)
elif name == "get_async_results":
query_id = arguments["query_id"]
database = arguments.get("database", "default")
max_rows = arguments.get("max_rows", 50000)
result = await get_async_results(query_id, database, max_rows)
result = await get_async_results(query_id, max_rows)
elif name == "list_async_queries":
result = await list_async_queries()
elif name == "read_data_paginated":
query = arguments["query"]
database = arguments.get("database", "default")
page_size = arguments.get("page_size", 1000)
page = arguments.get("page", 1)
result = await read_data_paginated(query, database, page_size, page)
result = await read_data_paginated(query, page_size, page)
elif name == "read_data_pandas":
query = arguments["query"]
database = arguments.get("database", "default")
max_rows = arguments.get("max_rows", 50000)
output_format = arguments.get("output_format", "records")
result = await read_data_pandas(query, database, max_rows, output_format)
result = await read_data_pandas(query, max_rows, output_format)
else:
raise ValueError(f"Unknown tool: {name}")
@@ -1339,4 +1187,4 @@ if __name__ == "__main__":
except KeyboardInterrupt:
logger.info("Server stopped by user")
finally:
connection_pool.close_all()
connection.close()