170 lines
6.9 KiB
Python
170 lines
6.9 KiB
Python
# /opt/docker/dev/service_finder/backend/app/scripts/sync_engine.py
|
||
#!/usr/bin/env python3
|
||
"""
|
||
Universal Schema Synchronizer
|
||
|
||
Dynamically imports all SQLAlchemy models from app.models, compares them with the live database,
|
||
and creates missing tables/columns without dropping anything.
|
||
|
||
Safety First:
|
||
- NEVER drops tables or columns.
|
||
- Prints planned SQL before execution.
|
||
- Performs no destructive operations, so it never asks for confirmation.
|
||
"""
|
||
|
||
import asyncio
|
||
import importlib
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
from sqlalchemy.ext.asyncio import create_async_engine
|
||
from sqlalchemy import inspect, text
|
||
from sqlalchemy.schema import CreateTable, AddConstraint
|
||
from sqlalchemy.sql.ddl import CreateColumn
|
||
|
||
# Add backend to path
|
||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||
|
||
from app.database import Base
|
||
from app.core.config import settings
|
||
|
||
def dynamic_import_models():
    """Import every module in ``app/models`` so ``Base.metadata`` is fully populated.

    SQLAlchemy only registers a table once the module declaring its model has
    been imported. This walks the top-level ``*.py`` files of the models
    directory (sub-directories are not visited) and imports each one. A module
    that fails to import is reported and skipped, so one broken model does not
    abort the whole run.

    Returns:
        list[str]: dotted names of the modules that imported successfully,
        in alphabetical file-name order.
    """
    models_dir = Path(__file__).parent.parent / "models"
    imported = []

    # Path.glob yields entries in filesystem order, which differs between
    # machines; sorting keeps the import order and the log reproducible.
    for py_file in sorted(models_dir.glob("*.py")):
        if py_file.name == "__init__.py":
            continue
        module_name = f"app.models.{py_file.stem}"
        try:
            importlib.import_module(module_name)
        except Exception as e:  # best effort: report and continue with the rest
            print(f"⚠️ Could not import {module_name}: {e}")
            continue
        imported.append(module_name)
        print(f"✅ Imported {module_name}")

    # Also ensure the package itself is loaded; its __init__ imports further
    # models manually.
    importlib.import_module("app.models")
    print(f"📦 Total tables in Base.metadata: {len(Base.metadata.tables)}")
    return imported
|
||
|
||
# ENUM types that this script ensures exist in the ``marketplace`` schema
# (type name -> labels). They are created before any table is created.
_MARKETPLACE_ENUMS = {
    "moderation_status": ("pending", "approved", "rejected"),
    "source_type": ("manual", "ocr", "import"),
}


def _ensure_marketplace_enums(connection):
    """Create each ``_MARKETPLACE_ENUMS`` type in ``marketplace`` unless it already exists.

    The ``marketplace`` schema must already exist; otherwise ``CREATE TYPE`` fails.
    """
    print("\n🔧 Ensuring enum types in marketplace schema...")
    for type_name, labels in _MARKETPLACE_ENUMS.items():
        values = ", ".join(f"'{label}'" for label in labels)
        connection.execute(text(f"""
            DO $$
            BEGIN
                IF NOT EXISTS (
                    SELECT 1 FROM pg_type
                    WHERE typname = '{type_name}'
                      AND typnamespace = (SELECT oid FROM pg_namespace WHERE nspname = 'marketplace')
                ) THEN
                    CREATE TYPE marketplace.{type_name} AS ENUM ({values});
                END IF;
            END $$;
        """))
    print("✅ Enum types ensured.")


def _python_default_sql(col, dialect):
    """Render a column's Python-side ``default=`` as a DDL expression, or return None.

    Only applies when the column has no ``server_default`` (CreateColumn already
    renders that one). Scalar defaults become properly quoted literals and SQL
    expression defaults are compiled for ``dialect``. Callables and sequences
    have no static DDL form, so None is returned for them, as it is when the
    column type has no literal renderer.
    """
    from sqlalchemy import literal
    from sqlalchemy.exc import CompileError

    default = col.default
    if default is None or col.server_default is not None:
        return None
    if getattr(default, "is_clause_element", False):
        expr = default.arg
    elif getattr(default, "is_scalar", False):
        expr = literal(default.arg, type_=col.type)
    else:
        return None
    try:
        return str(expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    except (CompileError, NotImplementedError):
        # No literal renderer for this type: add the column without a default.
        return None


def _add_column_sql(table, col, dialect):
    """Build ``ALTER TABLE ... ADD COLUMN ...`` for ``col`` using the dialect's DDL compiler."""
    # CreateColumn yields the same column specification CREATE TABLE uses:
    # quoted name, type, server default and NOT NULL.
    column_spec = str(CreateColumn(col).compile(dialect=dialect))
    table_name = dialect.identifier_preparer.format_table(table)
    sql = f"ALTER TABLE {table_name} ADD COLUMN {column_spec}"
    default_sql = _python_default_sql(col, dialect)
    if default_sql is not None:
        sql += f" DEFAULT {default_sql}"
    return sql


def _create_table(connection, table, label):
    """Print the CREATE TABLE DDL for ``table``, then create it."""
    print(f"❌ Missing table: {label}")
    print(f"   SQL: {CreateTable(table).compile(dialect=connection.dialect)}")
    # Table.create() runs SQLAlchemy's schema generator instead of a bare
    # CREATE TABLE. It also emits the table's indexes and fires table DDL events
    # such as native ENUM creation; checkfirst skips anything already present.
    table.create(connection, checkfirst=True)
    print(f"✅ Table {label} created.")


def _add_missing_columns(connection, inspector, table, label):
    """Add columns declared on ``table`` that the live table lacks; never drops any."""
    db_columns = {c["name"] for c in inspector.get_columns(table.name, schema=table.schema)}
    missing_cols = [col for col in table.columns if col.name not in db_columns]
    if not missing_cols:
        print(f"✅ Table {label} is up‑to‑date.")
        return

    print(f"⚠️ Table {label} missing columns: {[c.name for c in missing_cols]}")
    for col in missing_cols:
        sql = _add_column_sql(table, col, connection.dialect)
        print(f"   SQL: {sql}")
        connection.execute(text(sql))
        print(f"✅ Column {col.name} added.")


def _sync_schema(connection):
    """Synchronous body of the repair; executed via ``AsyncConnection.run_sync``."""
    dialect = connection.dialect
    inspector = inspect(connection)
    # sorted_tables is in foreign-key dependency order across *all* schemas,
    # so a referenced table is created before any table that references it.
    tables = Base.metadata.sorted_tables

    expected_schemas = sorted({t.schema for t in tables if t.schema})
    print(f"📋 Expected schemas: {expected_schemas}")

    # Schemas first: the enum types and tables below are created inside them.
    existing_schemas = set(inspector.get_schema_names())
    for schema in expected_schemas:
        if schema in existing_schemas:
            continue
        print(f"❌ Schema '{schema}' missing. Creating...")
        quoted = dialect.identifier_preparer.quote_schema(schema)
        connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {quoted}"))
        print(f"✅ Schema '{schema}' created.")

    if "marketplace" in expected_schemas:
        _ensure_marketplace_enums(connection)

    print("\n--- 🔍 Checking tables ---")
    db_tables = {}  # schema -> names of tables present in the live database
    for table in tables:
        schema = table.schema  # None means the connection's default schema
        if schema not in db_tables:
            db_tables[schema] = set(inspector.get_table_names(schema=schema))
        label = f"{schema}.{table.name}" if schema else table.name
        if table.name in db_tables[schema]:
            _add_missing_columns(connection, inspector, table, label)
        else:
            _create_table(connection, table, label)

    print("\n--- ✅ Schema synchronization complete. ---")


async def compare_and_repair():
    """Compare SQLAlchemy metadata with the live database and add what is missing.

    Creates, in order:
      1. missing schemas;
      2. the hard-coded ``marketplace`` ENUM types;
      3. missing tables, including their indexes;
      4. missing columns on existing tables.

    Nothing is dropped. Everything runs inside one ``engine.begin()``
    transaction. Because PostgreSQL DDL is transactional, a failure part-way
    rolls back every change made by this run.
    """
    print("🔗 Connecting to database...")
    engine = create_async_engine(str(settings.SQLALCHEMY_DATABASE_URI))
    try:
        async with engine.begin() as conn:
            await conn.run_sync(_sync_schema)
    finally:
        # Release pooled connections even when the sync fails.
        await engine.dispose()
|
||
|
||
async def main():
    """Entry point: import all models, repair the database schema, then verify it.

    Step 3 delegates to ``compare`` from the internal diagnostics module. The
    closing success message is printed whenever that call returns without
    raising; its return value is not inspected here.
    """
    # Banner
    for banner_line in ("🚀 Universal Schema Synchronizer", "=" * 50):
        print(banner_line)

    # Step 1: populate Base.metadata with every model module.
    print("\n📥 Step 1: Dynamically importing all models...")
    dynamic_import_models()

    # Step 2: create whatever is missing in the live database.
    print("\n🔧 Step 2: Comparing with database and repairing...")
    await compare_and_repair()

    # Step 3: re-run the standalone schema comparison as a final check.
    # Imported lazily, only after the step header has been printed.
    print("\n📊 Step 3: Final verification...")
    from app.tests_internal.diagnostics.compare_schema import compare as verify_schema

    await verify_schema()

    print("\n✨ Synchronization finished successfully!")
|
||
|
||
if __name__ == "__main__":
    # Run the synchronizer only when executed as a script, never on import;
    # asyncio.run drives the async workflow to completion on a fresh event loop.
    asyncio.run(main())