""" Redshift — Query Log Collection (collect-only) ================================================ Collects completed query execution records from Redshift using sys_query_history and sys_querytext (modern RA3/serverless), assembles full SQL text from multi-row text chunks, and writes a JSON manifest file that can be consumed by push_query_logs.py. Substitution points (search for "← SUBSTITUTE"): - REDSHIFT_HOST / REDSHIFT_DB / REDSHIFT_USER / REDSHIFT_PASSWORD : connection - LOOKBACK_HOURS : hours back from [now - LAG_HOURS] to collect (default 25) - LOOKBACK_LAG_HOURS: lag behind now to avoid in-flight queries (default 1) - BATCH_SIZE : number of query_ids to fetch texts for in one SQL call - MAX_QUERIES : maximum query rows to process per run Prerequisites: pip install psycopg2-binary """ from __future__ import annotations import argparse import ipaddress import json import logging import os import re from datetime import datetime, timezone from typing import Any import psycopg2 from _safe_paths import safe_existing_directory, safe_input_json_path, safe_output_json_path, write_json_file logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) LOG_TYPE = "redshift" LOOKBACK_HOURS: int = int(os.getenv("LOOKBACK_HOURS", "25")) # ← SUBSTITUTE LOOKBACK_LAG_HOURS: int = int(os.getenv("LOOKBACK_LAG_HOURS", "1")) # ← SUBSTITUTE BATCH_SIZE: int = int(os.getenv("BATCH_SIZE", "200")) # ← SUBSTITUTE MAX_QUERIES: int = int(os.getenv("MAX_QUERIES", "10000")) # ← SUBSTITUTE _ALLOWED_REDSHIFT_HOST_RE = re.compile( r"^[a-z0-9][a-z0-9.-]*\.(?:redshift|redshift-serverless)\.[a-z0-9-]+\.amazonaws\.com(?:\.cn)?$", re.IGNORECASE, ) def _explicitly_allowed_redshift_hosts() -> set[str]: raw_hosts = os.getenv("REDSHIFT_ALLOWED_HOSTS", "") return {host.strip().lower().rstrip(".") for host in raw_hosts.split(",") if host.strip()} def validate_redshift_host(host: str, *, allow_private: bool = False) -> str: value = str(host).strip() if not value or any(part in value for part in ("/", "\\", "@", ":")): raise ValueError(f"Invalid Redshift host: {host!r}") hostname = value.lower().rstrip(".") allowed_hosts = _explicitly_allowed_redshift_hosts() try: address = ipaddress.ip_address(value) except ValueError: if hostname in allowed_hosts: return hostname match = _ALLOWED_REDSHIFT_HOST_RE.fullmatch(hostname) if match: return match.group(0) raise ValueError( "Redshift host must be an AWS Redshift endpoint or be listed in REDSHIFT_ALLOWED_HOSTS" ) if hostname not in allowed_hosts: raise ValueError("Redshift IP hosts must be listed in REDSHIFT_ALLOWED_HOSTS") blocked = ( address.is_loopback or address.is_link_local or address.is_multicast or address.is_unspecified or address.is_reserved or (address.is_private and not allow_private) ) if blocked: raise ValueError(f"Redshift host address is not allowed: {host!r}") return str(address) def _bounded_int(value: int, field: str, *, minimum: int, maximum: int) -> int: value = int(value) if value < minimum or value > maximum: raise ValueError(f"{field} must be between {minimum} and {maximum}") return value def _check_available_memory(min_gb: float = 2.0) -> None: """Warn if available memory is below the threshold.""" try: if hasattr(os, "sysconf"): # Linux / macOS page_size = os.sysconf("SC_PAGE_SIZE") avail_pages = os.sysconf("SC_AVPHYS_PAGES") avail_gb = (page_size * avail_pages) / (1024 ** 3) else: return # Windows — skip check except (ValueError, OSError): return if avail_gb < min_gb: log.warning( "Only %.1f GB of memory available (minimum recommended: %.1f GB). " "Consider reducing the collection scope or increasing available memory.", avail_gb, min_gb, ) def _dictfetch(cursor: Any, sql: str, params: tuple | None = None) -> list[dict[str, Any]]: cursor.execute(sql, params) cols = [d.name for d in cursor.description] rows = [] while True: chunk = cursor.fetchmany(1000) if not chunk: break rows.extend(dict(zip(cols, row)) for row in chunk) return rows def _safe_isoformat(dt: Any) -> str | None: if dt is None: return None if hasattr(dt, "isoformat"): if dt.tzinfo is None: dt = dt.replace(tzinfo=timezone.utc) return dt.isoformat() return str(dt) def fetch_query_metadata( cursor: Any, lookback_hours: int, lag_hours: int, max_queries: int, ) -> list[dict[str, Any]]: """Fetch query execution metadata from sys_query_history.""" lookback_hours = _bounded_int(lookback_hours, "lookback_hours", minimum=1, maximum=24 * 31) lag_hours = _bounded_int(lag_hours, "lag_hours", minimum=0, maximum=24 * 7) max_queries = _bounded_int(max_queries, "max_queries", minimum=1, maximum=100000) return _dictfetch( cursor, """ SELECT query_id, start_time, end_time, status, user_id, database_name, elapsed_time FROM sys_query_history WHERE start_time >= DATEADD(hour, -%s, GETDATE()) AND start_time < DATEADD(hour, -%s, GETDATE()) AND status = 'success' ORDER BY start_time LIMIT %s """, # ← SUBSTITUTE: add AND database_name = 'mydb' to narrow scope (lookback_hours, lag_hours, max_queries), ) def fetch_query_texts_batch(cursor: Any, query_ids: list[int]) -> dict[int, str]: """Batch-fetch and assemble multi-row query texts for a list of query_ids.""" if not query_ids: return {} query_ids = [_bounded_int(qid, "query_id", minimum=1, maximum=2**63 - 1) for qid in query_ids] rows = _dictfetch( cursor, """ SELECT query_id, LISTAGG( CASE WHEN LEN(text) <= 200 THEN text ELSE LEFT(text, 200) END, '' ) WITHIN GROUP (ORDER BY sequence) AS query_text FROM sys_querytext WHERE query_id = ANY(%s) GROUP BY query_id """, (query_ids,), ) return {r["query_id"]: r["query_text"] for r in rows if r.get("query_text")} def collect( host: str, db: str, user: str, password: str, manifest_path: str = "manifest_query_logs.json", port: int = 5439, lookback_hours: int = LOOKBACK_HOURS, lookback_lag_hours: int = LOOKBACK_LAG_HOURS, batch_size: int = BATCH_SIZE, max_queries: int = MAX_QUERIES, ) -> list[dict[str, Any]]: """Connect to Redshift, collect query logs, write a JSON manifest, and return entries.""" _check_available_memory() allow_private_host = os.getenv("REDSHIFT_ALLOW_PRIVATE_HOST", "").lower() in {"1", "true", "yes"} host = validate_redshift_host(host, allow_private=allow_private_host) port = _bounded_int(port, "port", minimum=1, maximum=65535) lookback_hours = _bounded_int(lookback_hours, "lookback_hours", minimum=1, maximum=24 * 31) lookback_lag_hours = _bounded_int(lookback_lag_hours, "lookback_lag_hours", minimum=0, maximum=24 * 7) batch_size = _bounded_int(batch_size, "batch_size", minimum=1, maximum=10000) max_queries = _bounded_int(max_queries, "max_queries", minimum=1, maximum=100000) collected_at = datetime.now(timezone.utc).isoformat() conn = psycopg2.connect( host=host, port=port, dbname=db, user=user, password=password, connect_timeout=30, ) try: with conn.cursor() as cursor: query_meta = fetch_query_metadata(cursor, lookback_hours, lookback_lag_hours, max_queries) log.info("Retrieved %d query metadata rows", len(query_meta)) # Batch-fetch texts to avoid enormous single queries query_ids = [r["query_id"] for r in query_meta] text_map: dict[int, str] = {} for i in range(0, len(query_ids), batch_size): batch = query_ids[i : i + batch_size] text_map.update(fetch_query_texts_batch(cursor, batch)) log.debug("Fetched texts for batch %d–%d", i, i + len(batch)) finally: conn.close() entries: list[dict[str, Any]] = [] for row in query_meta: qid = row["query_id"] query_text = text_map.get(qid, "") if not query_text.strip(): continue # ← SUBSTITUTE: decide whether to push rows with missing text entry = { "query_id": str(qid), "query_text": query_text, "start_time": _safe_isoformat(row.get("start_time")), "end_time": _safe_isoformat(row.get("end_time")), "user": str(row.get("user_id")) if row.get("user_id") is not None else None, "database_name": row.get("database_name"), "elapsed_time_us": row.get("elapsed_time"), } entries.append(entry) log.info("Collected %d query log entries", len(entries)) manifest = { "log_type": LOG_TYPE, "collected_at": collected_at, "lookback_hours": lookback_hours, "lookback_lag_hours": lookback_lag_hours, "query_log_count": len(entries), "entries": entries, } write_json_file(manifest_path, manifest) log.info("Manifest written to %s (%d entries)", manifest_path, len(entries)) return entries def main() -> None: parser = argparse.ArgumentParser(description="Collect Redshift query logs to a manifest file") parser.add_argument("--db", default=os.getenv("REDSHIFT_DB")) # ← SUBSTITUTE parser.add_argument("--user", default=os.getenv("REDSHIFT_USER")) # ← SUBSTITUTE parser.add_argument("--password", default=os.getenv("REDSHIFT_PASSWORD")) # ← SUBSTITUTE parser.add_argument("--port", type=int, default=int(os.getenv("REDSHIFT_PORT", "5439"))) parser.add_argument("--lookback-hours", type=int, default=LOOKBACK_HOURS) parser.add_argument("--lookback-lag-hours", type=int, default=LOOKBACK_LAG_HOURS) parser.add_argument("--batch-size", type=int, default=BATCH_SIZE) parser.add_argument("--max-queries", type=int, default=MAX_QUERIES) parser.add_argument("--manifest", default="manifest_query_logs.json") args = parser.parse_args() required = ["db", "user", "password"] missing = [k for k in required if getattr(args, k) is None] if missing: parser.error(f"Missing required arguments/env vars: {missing}") redshift_host = os.getenv("REDSHIFT_HOST") if not redshift_host: parser.error("Missing required env var: REDSHIFT_HOST") redshift_host = validate_redshift_host( redshift_host, allow_private=os.getenv("REDSHIFT_ALLOW_PRIVATE_HOST", "").lower() in {"1", "true", "yes"}, ) collect( host=redshift_host, db=args.db, user=args.user, password=args.password, manifest_path=args.manifest, port=args.port, lookback_hours=args.lookback_hours, lookback_lag_hours=args.lookback_lag_hours, batch_size=args.batch_size, max_queries=args.max_queries, ) if __name__ == "__main__": main()