350 lines
12 KiB
Python
350 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Collect table and column lineage from Snowflake — collection only.
|
|
|
|
Queries ACCOUNT_USAGE for DML/DDL statements in the last 24 hours, parses each
|
|
QUERY_TEXT with regex to extract source and destination tables, then writes the
|
|
resulting lineage edges to a JSON manifest file.
|
|
|
|
Can be run standalone via CLI or imported (use the ``collect()`` function).
|
|
|
|
Note: ACCOUNT_USAGE views have an approximate latency of 45 minutes, so very
|
|
recent queries may not yet appear.
|
|
|
|
Substitution points
|
|
-------------------
|
|
- SNOWFLAKE_ACCOUNT (env) / --account (CLI) : Snowflake account identifier
|
|
- SNOWFLAKE_USER (env) / --user (CLI) : Snowflake username
|
|
- SNOWFLAKE_PASSWORD (env) / --password (CLI) : Snowflake password
|
|
- SNOWFLAKE_WAREHOUSE (env) / --warehouse (CLI) : Snowflake virtual warehouse
|
|
|
|
Prerequisites
|
|
-------------
|
|
pip install snowflake-connector-python
|
|
|
|
Usage (table-level):
|
|
python collect_lineage.py \\
|
|
--account <SNOWFLAKE_ACCOUNT> \\
|
|
--user <SNOWFLAKE_USER> \\
|
|
--password <SNOWFLAKE_PASSWORD> \\
|
|
--warehouse <SNOWFLAKE_WAREHOUSE>
|
|
|
|
Usage (column-level):
|
|
python collect_lineage.py ... --column-lineage
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
|
|
import snowflake.connector
|
|
|
|
# ← SUBSTITUTE: set RESOURCE_TYPE to match your Monte Carlo connection type
|
|
RESOURCE_TYPE = "snowflake"
|
|
|
|
|
|
def _check_available_memory(min_gb: float = 2.0) -> None:
|
|
"""Warn if available memory is below the threshold."""
|
|
try:
|
|
if hasattr(os, "sysconf"): # Linux / macOS
|
|
page_size = os.sysconf("SC_PAGE_SIZE")
|
|
avail_pages = os.sysconf("SC_AVPHYS_PAGES")
|
|
avail_gb = (page_size * avail_pages) / (1024 ** 3)
|
|
else:
|
|
return # Windows — skip check
|
|
except (ValueError, OSError):
|
|
return
|
|
if avail_gb < min_gb:
|
|
print(
|
|
f"WARNING: Only {avail_gb:.1f} GB of memory available "
|
|
f"(minimum recommended: {min_gb:.1f} GB). "
|
|
f"Consider reducing the lookback window or increasing available memory."
|
|
)
|
|
|
|
# Hours to look back in ACCOUNT_USAGE.QUERY_HISTORY
|
|
# ← SUBSTITUTE: adjust the lookback window to match your collection cadence
|
|
_LOOKBACK_HOURS = 24
|
|
|
|
# Regex for CTAS: CREATE [OR REPLACE] [TRANSIENT] TABLE [IF NOT EXISTS] [db.][schema.]table AS SELECT
|
|
_CTAS_RE = re.compile(
|
|
r"CREATE\s+(?:OR\s+REPLACE\s+)?(?:TRANSIENT\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
|
|
r"(?:(?P<dest_db>\w+)\.)?(?:(?P<dest_schema>\w+)\.)?(?P<dest_table>\w+)"
|
|
r".*?AS\s+SELECT\s+(?P<select_cols>.+?)\s+FROM\s+"
|
|
r"(?:(?P<src_db>\w+)\.)?(?:(?P<src_schema>\w+)\.)?(?P<src_table>\w+)",
|
|
re.IGNORECASE | re.DOTALL,
|
|
)
|
|
|
|
# Regex for INSERT INTO [db.][schema.]table SELECT ... FROM [db.][schema.]table
|
|
_INSERT_RE = re.compile(
|
|
r"INSERT\s+(?:INTO|OVERWRITE)\s+"
|
|
r"(?:(?P<dest_db>\w+)\.)?(?:(?P<dest_schema>\w+)\.)?(?P<dest_table>\w+)"
|
|
r".*?SELECT\s+(?P<select_cols>.+?)\s+FROM\s+"
|
|
r"(?:(?P<src_db>\w+)\.)?(?:(?P<src_schema>\w+)\.)?(?P<src_table>\w+)",
|
|
re.IGNORECASE | re.DOTALL,
|
|
)
|
|
|
|
# Regex for CREATE [OR REPLACE] VIEW [db.][schema.]view AS SELECT ... FROM ...
|
|
_CREATE_VIEW_RE = re.compile(
|
|
r"CREATE\s+(?:OR\s+REPLACE\s+)?(?:SECURE\s+)?VIEW\s+"
|
|
r"(?:(?P<dest_db>\w+)\.)?(?:(?P<dest_schema>\w+)\.)?(?P<dest_table>\w+)"
|
|
r".*?AS\s+SELECT\s+(?P<select_cols>.+?)\s+FROM\s+"
|
|
r"(?:(?P<src_db>\w+)\.)?(?:(?P<src_schema>\w+)\.)?(?P<src_table>\w+)",
|
|
re.IGNORECASE | re.DOTALL,
|
|
)
|
|
|
|
# Additional JOIN sources
|
|
_JOIN_RE = re.compile(
|
|
r"JOIN\s+(?:(?P<src_db>\w+)\.)?(?:(?P<src_schema>\w+)\.)?(?P<src_table>\w+)",
|
|
re.IGNORECASE,
|
|
)
|
|
|
|
# Simple column alias extraction from SELECT clause
|
|
_COL_RE = re.compile(r"(?:(\w+)\.)?(\w+)(?:\s+AS\s+(\w+))?", re.IGNORECASE)
|
|
_SQL_KEYWORDS = {
|
|
"FROM", "SELECT", "WHERE", "JOIN", "ON", "AS", "*", "AND", "OR",
|
|
"GROUP", "ORDER", "BY", "HAVING", "LIMIT", "DISTINCT", "CASE", "WHEN",
|
|
"THEN", "ELSE", "END", "NULL", "NOT", "IN", "IS", "BETWEEN",
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class _LineageEdge:
|
|
dest_db: str
|
|
dest_schema: str
|
|
dest_table: str
|
|
sources: list[tuple[str, str, str]] = field(default_factory=list)
|
|
# col_mappings: (dest_col, src_table, src_col)
|
|
col_mappings: list[tuple[str, str, str]] = field(default_factory=list)
|
|
|
|
|
|
def _parse_select_cols(select_clause: str, src_table: str) -> list[tuple[str, str, str]]:
|
|
mappings = []
|
|
for m in _COL_RE.finditer(select_clause):
|
|
src_col = m.group(2)
|
|
dest_col = m.group(3) or src_col
|
|
if src_col.upper() in _SQL_KEYWORDS:
|
|
continue
|
|
mappings.append((dest_col, src_table, src_col))
|
|
return mappings
|
|
|
|
|
|
def _parse_edges(rows: list[dict]) -> list[_LineageEdge]:
|
|
"""Parse QUERY_HISTORY rows into _LineageEdge objects."""
|
|
edges: dict[str, _LineageEdge] = {}
|
|
|
|
for row in rows:
|
|
query_text = row.get("QUERY_TEXT") or ""
|
|
default_db = (row.get("DATABASE_NAME") or "").lower()
|
|
sql_clean = re.sub(r"\s+", " ", query_text).strip()
|
|
|
|
for pattern in (_CTAS_RE, _INSERT_RE, _CREATE_VIEW_RE):
|
|
m = pattern.search(sql_clean)
|
|
if not m:
|
|
continue
|
|
|
|
dest_db = (m.group("dest_db") or default_db).lower()
|
|
dest_schema = (m.group("dest_schema") or "public").lower()
|
|
dest_table = m.group("dest_table").lower()
|
|
src_db = (m.group("src_db") or default_db).lower()
|
|
src_schema = (m.group("src_schema") or "public").lower()
|
|
src_table = m.group("src_table").lower()
|
|
select_cols = m.group("select_cols")
|
|
|
|
key = f"{dest_db}.{dest_schema}.{dest_table}"
|
|
if key not in edges:
|
|
edges[key] = _LineageEdge(
|
|
dest_db=dest_db, dest_schema=dest_schema, dest_table=dest_table
|
|
)
|
|
|
|
edge = edges[key]
|
|
src_triple = (src_db, src_schema, src_table)
|
|
if src_triple not in edge.sources:
|
|
edge.sources.append(src_triple)
|
|
|
|
for jm in _JOIN_RE.finditer(sql_clean):
|
|
jt = jm.group("src_table").lower()
|
|
jschema = (jm.group("src_schema") or src_schema).lower()
|
|
jdb = (jm.group("src_db") or src_db).lower()
|
|
jp = (jdb, jschema, jt)
|
|
if jp not in edge.sources:
|
|
edge.sources.append(jp)
|
|
|
|
edge.col_mappings.extend(_parse_select_cols(select_cols, src_table))
|
|
break
|
|
|
|
return list(edges.values())
|
|
|
|
|
|
def _fetch_query_history(conn, lookback_hours: int) -> list[dict]:
|
|
cursor = conn.cursor()
|
|
cursor.execute(
|
|
f"""
|
|
SELECT QUERY_ID, QUERY_TEXT, START_TIME, END_TIME, USER_NAME, DATABASE_NAME, EXECUTION_STATUS
|
|
FROM SNOWFLAKE.ACCOUNT_USAGE.QUERY_HISTORY
|
|
WHERE START_TIME >= DATEADD(hour, -{lookback_hours}, CURRENT_TIMESTAMP())
|
|
AND EXECUTION_STATUS = 'SUCCESS'
|
|
AND QUERY_TYPE IN ('CREATE_TABLE_AS_SELECT', 'INSERT', 'MERGE', 'CREATE_VIEW')
|
|
ORDER BY START_TIME
|
|
LIMIT 50000
|
|
"""
|
|
# ← SUBSTITUTE: adjust QUERY_TYPE list, LIMIT, or add a WHERE clause to scope to specific databases
|
|
)
|
|
columns = [col[0] for col in cursor.description]
|
|
rows = []
|
|
while True:
|
|
batch = cursor.fetchmany(1000)
|
|
if not batch:
|
|
break
|
|
rows.extend(dict(zip(columns, row)) for row in batch)
|
|
cursor.close()
|
|
return rows
|
|
|
|
|
|
def collect(
|
|
account: str,
|
|
user: str,
|
|
password: str,
|
|
warehouse: str,
|
|
lookback_hours: int = _LOOKBACK_HOURS,
|
|
column_lineage: bool = False,
|
|
output_file: str = "lineage_output.json",
|
|
) -> dict:
|
|
"""
|
|
Connect to Snowflake, collect lineage edges, and write a JSON manifest.
|
|
|
|
Returns the manifest dict.
|
|
"""
|
|
_check_available_memory()
|
|
print(f"Connecting to Snowflake account: {account} ...")
|
|
conn = snowflake.connector.connect(
|
|
account=account,
|
|
user=user,
|
|
password=password,
|
|
warehouse=warehouse,
|
|
)
|
|
|
|
print(f"Fetching QUERY_HISTORY for the last {lookback_hours} hour(s) ...")
|
|
rows = _fetch_query_history(conn, lookback_hours)
|
|
conn.close()
|
|
print(f" Retrieved {len(rows)} qualifying query/queries.")
|
|
|
|
if not rows:
|
|
print("No lineage queries found in the specified window.")
|
|
manifest = {
|
|
"resource_type": RESOURCE_TYPE,
|
|
"collected_at": datetime.now(tz=timezone.utc).isoformat(),
|
|
"column_lineage": column_lineage,
|
|
"edges": [],
|
|
}
|
|
with open(output_file, "w") as fh:
|
|
json.dump(manifest, fh, indent=2)
|
|
return manifest
|
|
|
|
edges = _parse_edges(rows)
|
|
print(f" Parsed {len(edges)} lineage edge(s).")
|
|
|
|
manifest = {
|
|
"resource_type": RESOURCE_TYPE,
|
|
"collected_at": datetime.now(tz=timezone.utc).isoformat(),
|
|
"column_lineage": column_lineage,
|
|
"edges": [
|
|
{
|
|
"destination": {
|
|
"database": e.dest_db,
|
|
"schema": e.dest_schema,
|
|
"table": e.dest_table,
|
|
},
|
|
"sources": [
|
|
{"database": sdb, "schema": sschema, "table": stbl}
|
|
for sdb, sschema, stbl in e.sources
|
|
],
|
|
"col_mappings": [
|
|
{"dest_col": dc, "src_table": st, "src_col": sc}
|
|
for dc, st, sc in e.col_mappings
|
|
],
|
|
}
|
|
for e in edges
|
|
],
|
|
}
|
|
with open(output_file, "w") as fh:
|
|
json.dump(manifest, fh, indent=2)
|
|
print(f"Lineage manifest written to {output_file}")
|
|
|
|
return manifest
|
|
|
|
|
|
def main() -> None:
|
|
parser = argparse.ArgumentParser(
|
|
description="Collect Snowflake lineage from ACCOUNT_USAGE and write to a manifest file",
|
|
)
|
|
parser.add_argument(
|
|
"--account",
|
|
default=os.environ.get("SNOWFLAKE_ACCOUNT"),
|
|
help="Snowflake account identifier (env: SNOWFLAKE_ACCOUNT)",
|
|
)
|
|
parser.add_argument(
|
|
"--user",
|
|
default=os.environ.get("SNOWFLAKE_USER"),
|
|
help="Snowflake username (env: SNOWFLAKE_USER)",
|
|
)
|
|
parser.add_argument(
|
|
"--password",
|
|
default=os.environ.get("SNOWFLAKE_PASSWORD"),
|
|
help="Snowflake password (env: SNOWFLAKE_PASSWORD)",
|
|
)
|
|
parser.add_argument(
|
|
"--warehouse",
|
|
default=os.environ.get("SNOWFLAKE_WAREHOUSE"),
|
|
help="Snowflake virtual warehouse (env: SNOWFLAKE_WAREHOUSE)",
|
|
)
|
|
parser.add_argument(
|
|
"--lookback-hours",
|
|
type=int,
|
|
default=_LOOKBACK_HOURS,
|
|
help=f"Hours of QUERY_HISTORY to scan (default: {_LOOKBACK_HOURS})",
|
|
)
|
|
parser.add_argument(
|
|
"--column-lineage",
|
|
action="store_true",
|
|
help="Include column-level lineage mappings in the manifest",
|
|
)
|
|
parser.add_argument(
|
|
"--output-file",
|
|
default="lineage_output.json",
|
|
help="Path to write the lineage manifest (default: lineage_output.json)",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
missing = [
|
|
name
|
|
for name, val in [
|
|
("--account", args.account),
|
|
("--user", args.user),
|
|
("--password", args.password),
|
|
("--warehouse", args.warehouse),
|
|
]
|
|
if not val
|
|
]
|
|
if missing:
|
|
parser.error(f"Missing required arguments: {', '.join(missing)}")
|
|
|
|
collect(
|
|
account=args.account,
|
|
user=args.user,
|
|
password=args.password,
|
|
warehouse=args.warehouse,
|
|
lookback_hours=args.lookback_hours,
|
|
column_lineage=args.column_lineage,
|
|
output_file=args.output_file,
|
|
)
|
|
print("Done.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|