# File: playbook/antigravity-awesome-skills/skills/monte-carlo-push-ingestion/scripts/templates/redshift/collect_query_logs.py
# (309 lines, 11 KiB, Python)
#
# Note: this file contains non-ASCII characters (for example "—" in the module
# docstring and "←" in the "← SUBSTITUTE" markers) that some code viewers flag
# as ambiguous Unicode.

"""
Redshift — Query Log Collection (collect-only)
================================================
Collects completed query execution records from Redshift using sys_query_history
and sys_querytext (modern RA3/serverless), assembles full SQL text from
multi-row text chunks, and writes a JSON manifest file that can be consumed
by push_query_logs.py.
Substitution points (search for "← SUBSTITUTE"):
- REDSHIFT_HOST / REDSHIFT_DB / REDSHIFT_USER / REDSHIFT_PASSWORD : connection
- LOOKBACK_HOURS : hours back from [now - LAG_HOURS] to collect (default 25)
- LOOKBACK_LAG_HOURS: lag behind now to avoid in-flight queries (default 1)
- BATCH_SIZE : number of query_ids to fetch texts for in one SQL call
- MAX_QUERIES : maximum query rows to process per run
Prerequisites:
pip install psycopg2-binary
"""
from __future__ import annotations
import argparse
import ipaddress
import json
import logging
import os
import re
from datetime import datetime, timezone
from typing import Any
import psycopg2
from _safe_paths import safe_existing_directory, safe_input_json_path, safe_output_json_path, write_json_file
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
LOG_TYPE = "redshift"
LOOKBACK_HOURS: int = int(os.getenv("LOOKBACK_HOURS", "25")) # ← SUBSTITUTE
LOOKBACK_LAG_HOURS: int = int(os.getenv("LOOKBACK_LAG_HOURS", "1")) # ← SUBSTITUTE
BATCH_SIZE: int = int(os.getenv("BATCH_SIZE", "200")) # ← SUBSTITUTE
MAX_QUERIES: int = int(os.getenv("MAX_QUERIES", "10000")) # ← SUBSTITUTE
_ALLOWED_REDSHIFT_HOST_RE = re.compile(
r"^[a-z0-9][a-z0-9.-]*\.(?:redshift|redshift-serverless)\.[a-z0-9-]+\.amazonaws\.com(?:\.cn)?$",
re.IGNORECASE,
)
def _explicitly_allowed_redshift_hosts() -> set[str]:
    """Return the host names listed in the REDSHIFT_ALLOWED_HOSTS env var.

    The variable is a comma-separated list. Each non-blank entry is trimmed,
    lower-cased and stripped of trailing dots. Blank entries are ignored.
    """
    allowed: set[str] = set()
    for entry in os.getenv("REDSHIFT_ALLOWED_HOSTS", "").split(","):
        trimmed = entry.strip()
        if not trimmed:
            continue
        allowed.add(trimmed.lower().rstrip("."))
    return allowed
def validate_redshift_host(host: str, *, allow_private: bool = False) -> str:
    """Validate a Redshift host name or IP address and return its canonical form.

    Accepted inputs:

    * a host name listed in REDSHIFT_ALLOWED_HOSTS, returned lower-cased and
      without a trailing dot;
    * a host name matching ``_ALLOWED_REDSHIFT_HOST_RE`` (AWS Redshift /
      Redshift Serverless endpoints);
    * an IP address that is listed in REDSHIFT_ALLOWED_HOSTS and is not
      loopback, link-local, multicast, unspecified or reserved. Private
      addresses are accepted only when *allow_private* is true.

    Values containing "/", "\\", "@" or ":" are rejected up front. URLs,
    user-info, explicit ports and IPv6 literals are therefore never accepted.

    Raises:
        ValueError: if the host fails any of the checks above.
    """
    candidate = str(host).strip()
    if not candidate or set(candidate) & {"/", "\\", "@", ":"}:
        raise ValueError(f"Invalid Redshift host: {host!r}")
    normalized = candidate.lower().rstrip(".")
    allow_list = _explicitly_allowed_redshift_hosts()

    try:
        ip = ipaddress.ip_address(candidate)
    except ValueError:
        # Not an IP literal, so treat it as a DNS name.
        if normalized in allow_list:
            return normalized
        endpoint = _ALLOWED_REDSHIFT_HOST_RE.fullmatch(normalized)
        if endpoint is None:
            raise ValueError(
                "Redshift host must be an AWS Redshift endpoint or be listed in REDSHIFT_ALLOWED_HOSTS"
            )
        return endpoint.group(0)

    # IP literals are never trusted implicitly; they must be allow-listed by name.
    if normalized not in allow_list:
        raise ValueError("Redshift IP hosts must be listed in REDSHIFT_ALLOWED_HOSTS")
    forbidden = any(
        (
            ip.is_loopback,
            ip.is_link_local,
            ip.is_multicast,
            ip.is_unspecified,
            ip.is_reserved,
            ip.is_private and not allow_private,
        )
    )
    if forbidden:
        raise ValueError(f"Redshift host address is not allowed: {host!r}")
    return str(ip)
def _bounded_int(value: int, field: str, *, minimum: int, maximum: int) -> int:
    """Convert *value* to ``int`` and require ``minimum <= value <= maximum``.

    *field* names the setting in the error message. ``int()`` itself raises
    for values that cannot be converted.

    Raises:
        ValueError: if the converted value lies outside the inclusive range.
    """
    number = int(value)
    if not minimum <= number <= maximum:
        raise ValueError(f"{field} must be between {minimum} and {maximum}")
    return number
def _available_memory_bytes() -> int | None:
    """Return a best-effort estimate of available memory in bytes.

    Returns None when no estimate can be obtained on this platform.
    """
    # On Linux, prefer the kernel's MemAvailable estimate. SC_AVPHYS_PAGES
    # counts only completely free pages and ignores reclaimable page cache, so
    # on a long-running host it under-reports and triggers spurious warnings.
    try:
        with open("/proc/meminfo", encoding="ascii") as meminfo:
            for line in meminfo:
                if line.startswith("MemAvailable:"):
                    return int(line.split()[1]) * 1024  # value is reported in kB
    except (OSError, ValueError, IndexError):
        pass  # not Linux, or an unexpected format: fall back to sysconf below
    if not hasattr(os, "sysconf"):  # e.g. Windows
        return None
    try:
        page_size = os.sysconf("SC_PAGE_SIZE")
        avail_pages = os.sysconf("SC_AVPHYS_PAGES")
    except (ValueError, OSError):
        # SC_AVPHYS_PAGES is a glibc extension; other platforms may not define it.
        return None
    if page_size <= 0 or avail_pages < 0:  # sysconf may return -1 for "indeterminate"
        return None
    return page_size * avail_pages


def _check_available_memory(min_gb: float = 2.0) -> None:
    """Log a warning if available memory is below *min_gb* GiB.

    Advisory only. When memory cannot be measured, nothing is logged and the
    function simply returns.
    """
    avail_bytes = _available_memory_bytes()
    if avail_bytes is None:
        return
    avail_gb = avail_bytes / (1024 ** 3)
    if avail_gb < min_gb:
        log.warning(
            "Only %.1f GB of memory available (minimum recommended: %.1f GB). "
            "Consider reducing the collection scope or increasing available memory.",
            avail_gb,
            min_gb,
        )
def _dictfetch(cursor: Any, sql: str, params: tuple | None = None) -> list[dict[str, Any]]:
    """Execute *sql* with *params* and return all result rows as dicts.

    Each dict maps column name (taken from ``cursor.description``) to value.
    Rows are pulled with ``fetchmany`` in chunks of 1000 until the cursor is
    exhausted.
    """
    cursor.execute(sql, params)
    names = [column.name for column in cursor.description]
    result: list[dict[str, Any]] = []
    while batch := cursor.fetchmany(1000):
        result += [dict(zip(names, record)) for record in batch]
    return result
def _safe_isoformat(dt: Any) -> str | None:
    """Render a timestamp-like value as an ISO-8601 string.

    * ``None`` returns ``None``.
    * Objects with ``isoformat()`` are formatted with it. If the object has a
      ``tzinfo`` attribute that is ``None`` (naive ``datetime`` / ``time``),
      it is first tagged as UTC. ``date`` objects have no ``tzinfo`` and are
      formatted as-is.
    * Anything else is converted with ``str()``.
    """
    if dt is None:
        return None
    if not hasattr(dt, "isoformat"):
        return str(dt)
    # Check for the attribute first: ``date`` has isoformat() but no tzinfo,
    # so reading dt.tzinfo unconditionally would raise AttributeError.
    if hasattr(dt, "tzinfo") and dt.tzinfo is None:
        # NOTE(review): assumes naive values from Redshift are UTC - confirm.
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.isoformat()
def fetch_query_metadata(
    cursor: Any,
    lookback_hours: int,
    lag_hours: int,
    max_queries: int,
) -> list[dict[str, Any]]:
    """Fetch metadata for successful queries from ``sys_query_history``.

    Only rows whose ``start_time`` lies in the half-open window
    ``[GETDATE() - lookback_hours, GETDATE() - lag_hours)`` are selected.
    Rows are ordered oldest first and capped at *max_queries*, so when the
    cap is hit the most recent queries in the window are the ones omitted.

    Args:
        cursor: an open DB-API cursor (used through ``_dictfetch``).
        lookback_hours: window start, in hours before now (1 to 744).
        lag_hours: window end, in hours before now (0 to 168).
        max_queries: maximum rows returned (1 to 100000).

    Returns:
        One dict per row with keys query_id, start_time, end_time, status,
        user_id, database_name and elapsed_time.

    Raises:
        ValueError: if an argument is out of range, or if
            ``lookback_hours <= lag_hours``. That window would be empty, so
            nothing would ever be collected.
    """
    lookback_hours = _bounded_int(lookback_hours, "lookback_hours", minimum=1, maximum=24 * 31)
    lag_hours = _bounded_int(lag_hours, "lag_hours", minimum=0, maximum=24 * 7)
    max_queries = _bounded_int(max_queries, "max_queries", minimum=1, maximum=100000)
    if lookback_hours <= lag_hours:
        raise ValueError("lookback_hours must be greater than lag_hours")
    # "-%s" renders as e.g. "-25" because psycopg2 substitutes the int unquoted.
    return _dictfetch(
        cursor,
        """
        SELECT
            query_id,
            start_time,
            end_time,
            status,
            user_id,
            database_name,
            elapsed_time
        FROM sys_query_history
        WHERE start_time >= DATEADD(hour, -%s, GETDATE())
          AND start_time < DATEADD(hour, -%s, GETDATE())
          AND status = 'success'
        ORDER BY start_time
        LIMIT %s
        """,  # ← SUBSTITUTE: add AND database_name = 'mydb' to narrow scope
        (lookback_hours, lag_hours, max_queries),
    )
def fetch_query_texts_batch(cursor: Any, query_ids: list[int]) -> dict[int, str]:
    """Fetch and assemble the full SQL text for each id in *query_ids*.

    ``sys_querytext`` stores a statement as several text rows ordered by
    ``sequence``. The chunks are concatenated here in Python rather than with
    LISTAGG in SQL. Redshift's LISTAGG fails the entire statement as soon as
    one group's result exceeds 65,535 bytes, so a single very long query
    would otherwise abort the text fetch for its whole batch.

    Returns:
        Mapping of query_id to assembled text, with trailing whitespace (e.g.
        fixed-width padding on the last chunk) removed. Ids whose text is
        missing or blank are omitted.

    Raises:
        ValueError: if a query id is not a positive 64-bit integer.
    """
    if not query_ids:
        return {}
    query_ids = [_bounded_int(qid, "query_id", minimum=1, maximum=2**63 - 1) for qid in query_ids]
    # psycopg2 adapts a Python tuple to "(v1, v2, ...)", which suits IN. The
    # list is non-empty here, so "IN ()" can never be produced.
    rows = _dictfetch(
        cursor,
        """
        SELECT query_id, sequence, text
        FROM sys_querytext
        WHERE query_id IN %s
        ORDER BY query_id, sequence
        """,
        (tuple(query_ids),),
    )
    chunks: dict[int, list[tuple[int, str]]] = {}
    for row in rows:
        chunks.setdefault(row["query_id"], []).append((row["sequence"], row["text"] or ""))

    texts: dict[int, str] = {}
    for qid, parts in chunks.items():
        parts.sort(key=lambda part: part[0])  # defensive: do not rely on SQL ordering alone
        # NOTE(review): AWS documents that STL_QUERYTEXT stores newlines as the
        # two characters backslash + "n". If sys_querytext does the same,
        # consumers may need to unescape them - confirm.
        assembled = "".join(piece for _, piece in parts).rstrip()
        if assembled:
            texts[qid] = assembled
    return texts
def _build_entries(
    query_meta: list[dict[str, Any]],
    text_map: dict[int, str],
) -> list[dict[str, Any]]:
    """Combine sys_query_history rows with their assembled SQL into manifest entries.

    Rows whose SQL text is missing or blank are skipped.
    """
    entries: list[dict[str, Any]] = []
    for row in query_meta:
        qid = row["query_id"]
        query_text = text_map.get(qid, "")
        if not query_text.strip():
            continue  # ← SUBSTITUTE: decide whether to push rows with missing text
        user_id = row.get("user_id")
        entries.append(
            {
                "query_id": str(qid),
                "query_text": query_text,
                "start_time": _safe_isoformat(row.get("start_time")),
                "end_time": _safe_isoformat(row.get("end_time")),
                # The numeric user_id from sys_query_history, as a string (not a user name).
                "user": str(user_id) if user_id is not None else None,
                "database_name": row.get("database_name"),
                "elapsed_time_us": row.get("elapsed_time"),
            }
        )
    return entries


def collect(
    host: str,
    db: str,
    user: str,
    password: str,
    manifest_path: str = "manifest_query_logs.json",
    port: int = 5439,
    lookback_hours: int = LOOKBACK_HOURS,
    lookback_lag_hours: int = LOOKBACK_LAG_HOURS,
    batch_size: int = BATCH_SIZE,
    max_queries: int = MAX_QUERIES,
) -> list[dict[str, Any]]:
    """Connect to Redshift, collect query logs, write a JSON manifest, and return entries.

    Args:
        host: Redshift endpoint, checked with ``validate_redshift_host``.
            Private IPs are allowed only when REDSHIFT_ALLOW_PRIVATE_HOST is
            1/true/yes.
        db: database name to connect to.
        user: user name to connect as.
        password: password for *user*.
        manifest_path: destination of the JSON manifest (written via ``write_json_file``).
        port: TCP port, 1 to 65535.
        lookback_hours: window start, in hours before now (1 to 744).
        lookback_lag_hours: window end, in hours before now (0 to 168). Must be
            smaller than *lookback_hours*.
        batch_size: number of query ids per sys_querytext fetch (1 to 10000).
        max_queries: cap on sys_query_history rows (1 to 100000).

    Returns:
        The manifest entries, one per query whose SQL text is non-blank.

    Raises:
        ValueError: if the host or any numeric argument is invalid. Validation
            happens before a connection is opened.
    """
    _check_available_memory()
    allow_private_host = os.getenv("REDSHIFT_ALLOW_PRIVATE_HOST", "").lower() in {"1", "true", "yes"}
    host = validate_redshift_host(host, allow_private=allow_private_host)
    port = _bounded_int(port, "port", minimum=1, maximum=65535)
    lookback_hours = _bounded_int(lookback_hours, "lookback_hours", minimum=1, maximum=24 * 31)
    lookback_lag_hours = _bounded_int(lookback_lag_hours, "lookback_lag_hours", minimum=0, maximum=24 * 7)
    batch_size = _bounded_int(batch_size, "batch_size", minimum=1, maximum=10000)
    max_queries = _bounded_int(max_queries, "max_queries", minimum=1, maximum=100000)
    # Fail fast: an empty window would silently produce an empty manifest.
    if lookback_hours <= lookback_lag_hours:
        raise ValueError("lookback_hours must be greater than lookback_lag_hours")
    collected_at = datetime.now(timezone.utc).isoformat()

    conn = psycopg2.connect(
        host=host, port=port, dbname=db, user=user, password=password, connect_timeout=30,
    )
    try:
        with conn.cursor() as cursor:
            query_meta = fetch_query_metadata(cursor, lookback_hours, lookback_lag_hours, max_queries)
            log.info("Retrieved %d query metadata rows", len(query_meta))
            # Fetch texts in batches so no single SQL statement gets enormous.
            query_ids = [row["query_id"] for row in query_meta]
            text_map: dict[int, str] = {}
            for start in range(0, len(query_ids), batch_size):
                batch = query_ids[start : start + batch_size]
                text_map.update(fetch_query_texts_batch(cursor, batch))
                log.debug("Fetched texts for batch %d-%d", start, start + len(batch))
    finally:
        conn.close()

    entries = _build_entries(query_meta, text_map)
    log.info("Collected %d query log entries", len(entries))
    manifest = {
        "log_type": LOG_TYPE,
        "collected_at": collected_at,
        "lookback_hours": lookback_hours,
        "lookback_lag_hours": lookback_lag_hours,
        "query_log_count": len(entries),
        "entries": entries,
    }
    write_json_file(manifest_path, manifest)
    log.info("Manifest written to %s (%d entries)", manifest_path, len(entries))
    return entries
def main() -> None:
    """Command-line entry point: resolve settings from flags/env vars and run collect().

    Connection settings fall back to REDSHIFT_HOST, REDSHIFT_DB, REDSHIFT_USER,
    REDSHIFT_PASSWORD and REDSHIFT_PORT. Missing settings and an invalid host
    are reported as argparse usage errors rather than tracebacks.
    """
    parser = argparse.ArgumentParser(description="Collect Redshift query logs to a manifest file")
    parser.add_argument("--host", default=os.getenv("REDSHIFT_HOST"))  # ← SUBSTITUTE
    parser.add_argument("--db", default=os.getenv("REDSHIFT_DB"))  # ← SUBSTITUTE
    parser.add_argument("--user", default=os.getenv("REDSHIFT_USER"))  # ← SUBSTITUTE
    # Prefer the REDSHIFT_PASSWORD env var: command-line arguments are visible
    # to other local users through the process list.
    parser.add_argument("--password", default=os.getenv("REDSHIFT_PASSWORD"))  # ← SUBSTITUTE
    parser.add_argument("--port", type=int, default=int(os.getenv("REDSHIFT_PORT", "5439")))
    parser.add_argument("--lookback-hours", type=int, default=LOOKBACK_HOURS)
    parser.add_argument("--lookback-lag-hours", type=int, default=LOOKBACK_LAG_HOURS)
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
    parser.add_argument("--max-queries", type=int, default=MAX_QUERIES)
    parser.add_argument("--manifest", default="manifest_query_logs.json")
    args = parser.parse_args()

    missing = [name for name in ("db", "user", "password") if getattr(args, name) is None]
    if missing:
        parser.error(f"Missing required arguments/env vars: {missing}")
    if not args.host:
        parser.error("Missing required argument/env var: --host / REDSHIFT_HOST")

    allow_private = os.getenv("REDSHIFT_ALLOW_PRIVATE_HOST", "").lower() in {"1", "true", "yes"}
    try:
        redshift_host = validate_redshift_host(args.host, allow_private=allow_private)
    except ValueError as exc:
        parser.error(str(exc))  # exits with status 2; never returns

    collect(
        host=redshift_host,
        db=args.db,
        user=args.user,
        password=args.password,
        manifest_path=args.manifest,
        port=args.port,
        lookback_hours=args.lookback_hours,
        lookback_lag_hours=args.lookback_lag_hours,
        batch_size=args.batch_size,
        max_queries=args.max_queries,
    )


if __name__ == "__main__":
    main()