"""
Script to update existing activity_records in database with new categorization logic
Run this after updating activity_categorizer.py
"""

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
import os
from datetime import datetime
from activity_categorizer import ActivityCategorizer
import sys

# Import your database configuration
from database import get_db, engine

def _tally_transition(changes, old_category, new_category):
    """Record a single old -> new category transition in the ``changes`` tally."""
    if not old_category:
        changes['null_or_empty'] += 1
        return

    bucket = changes.get(old_category)
    # Only the tracked categories have a per-category dict; anything else
    # (including the scalar 'null_or_empty' counter itself) is not tallied.
    if not isinstance(bucket, dict):
        return

    if old_category == new_category:
        bucket['unchanged'] += 1
        return

    # Tally keys use underscores ('to_non_work') while categories use hyphens.
    transition_key = f'to_{new_category.replace("-", "_")}'
    if transition_key in bucket:
        bucket[transition_key] += 1


def _print_migration_summary(changes):
    """Print the per-category transition counts collected in ``changes``."""
    print("\nCategory migration summary:")
    print("-" * 40)

    for old_cat in ['productive', 'browser', 'server', 'non-work']:
        counts = changes[old_cat]
        # Skip categories that no processed record started in.
        if sum(counts.values()) > 0:
            print(f"\n{old_cat.upper()}:")
            print(f"  → browser: {counts.get('to_browser', 0)}")
            print(f"  → productive: {counts.get('to_productive', 0)}")
            print(f"  → server: {counts.get('to_server', 0)}")
            print(f"  → non-work: {counts.get('to_non_work', 0)}")
            print(f"  → unchanged: {counts.get('unchanged', 0)}")


def update_activity_categories():
    """
    Update all activity records in the database with new categorization.

    Every row of ``activity_records`` is loaded, re-categorized with
    ``ActivityCategorizer.get_detailed_category(window_title, app_name)``
    (the ``'category'`` entry of its result is used), and rows whose stored
    category differs from the new one are updated.

    Rows are processed in batches of 100; each batch that contains changes is
    written with a single executemany call and committed on its own. Progress
    and a migration summary are printed to stdout.

    Raises:
        Exception: any error raised while reading or updating is printed,
            the session is rolled back, and the error is re-raised. Batches
            committed before the failure are not undone by that rollback.
    """

    # Initialize categorizer
    categorizer = ActivityCategorizer()

    # Create session
    Session = sessionmaker(bind=engine)
    session = Session()

    try:
        print("Starting activity categorization update...")
        print("-" * 60)

        # Get all activity records
        query = text("""
            SELECT id, application_name, window_title, category
            FROM activity_records
            ORDER BY id
        """)

        records = session.execute(query).fetchall()
        total_records = len(records)

        print(f"Found {total_records} records to process")
        print("-" * 60)

        # Track changes
        changes = {
            'productive': {'to_browser': 0, 'to_server': 0, 'to_non_work': 0, 'unchanged': 0},
            'browser': {'to_productive': 0, 'to_server': 0, 'to_non_work': 0, 'unchanged': 0},
            'server': {'to_productive': 0, 'to_browser': 0, 'to_non_work': 0, 'unchanged': 0},
            'non-work': {'to_productive': 0, 'to_browser': 0, 'to_server': 0, 'unchanged': 0},
            'null_or_empty': 0
        }

        # Built once and reused for every batch instead of once per row.
        update_query = text("""
            UPDATE activity_records
            SET category = :new_category
            WHERE id = :id
        """)

        # Process in batches
        batch_size = 100
        updated_count = 0

        for i in range(0, total_records, batch_size):
            updates = []

            for record in records[i:i + batch_size]:
                # NULL columns are treated as empty strings.
                app_name = record.application_name or ""
                window_title = record.window_title or ""
                old_category = record.category or ""

                # Get new category
                category_info = categorizer.get_detailed_category(window_title, app_name)
                new_category = category_info['category']

                _tally_transition(changes, old_category, new_category)

                # Add to updates if category changed
                if old_category != new_category:
                    updates.append({
                        'id': record.id,
                        'new_category': new_category,
                        'old_category': old_category,
                        'window_title': window_title[:50]  # For logging
                    })

            if not updates:
                continue

            # A list of parameter dicts makes this one executemany call
            # rather than one round trip per row. Only the keys the
            # statement binds are passed.
            session.execute(
                update_query,
                [
                    {'new_category': update['new_category'], 'id': update['id']}
                    for update in updates
                ],
            )
            session.commit()
            updated_count += len(updates)

            # Log some examples
            print(f"\nBatch {i//batch_size + 1}: Updated {len(updates)} records")
            for update in updates[:3]:  # Show first 3 examples
                print(f"  - '{update['window_title']}' : {update['old_category']} → {update['new_category']}")
            if len(updates) > 3:
                print(f"  ... and {len(updates) - 3} more")

        print("\n" + "=" * 60)
        print("UPDATE COMPLETE!")
        print("=" * 60)
        print(f"\nTotal records processed: {total_records}")
        print(f"Total records updated: {updated_count}")
        print(f"Records with null/empty category: {changes['null_or_empty']}")

        _print_migration_summary(changes)

        # Show specific examples of incognito/new tab fixes
        print("\n" + "-" * 40)
        print("Checking for fixed 'New Incognito Tab' entries...")

        check_query = text("""
            SELECT window_title, category, application_name 
            FROM activity_records 
            WHERE LOWER(window_title) LIKE '%incognito%' 
               OR LOWER(window_title) LIKE '%new tab%'
            LIMIT 10
        """)

        fixed_records = session.execute(check_query).fetchall()

        if fixed_records:
            print(f"\nFound {len(fixed_records)} incognito/new tab entries:")
            for rec in fixed_records:
                print(f"  - '{rec.window_title}' → {rec.category}")
        else:
            print("No incognito/new tab entries found")

    except Exception as e:
        print(f"\nERROR: {str(e)}")
        session.rollback()
        raise
    finally:
        session.close()
        print("\nDatabase connection closed.")


def verify_update(sample_size=20):
    """
    Verify the update by printing a random sample of records per category.

    Args:
        sample_size: Total number of records to sample, split evenly across
            the four categories (the default of 20 gives 5 per category).
            At least one record per category is always requested, so zero
            or negative values never reach the SQL ``LIMIT`` clause.
    """
    categories = ['productive', 'browser', 'server', 'non-work']
    per_category = max(1, int(sample_size) // len(categories))

    Session = sessionmaker(bind=engine)
    session = Session()

    try:
        print("\n" + "=" * 60)
        print("VERIFICATION - Sampling updated records")
        print("=" * 60)

        # The statement is identical for every category; only the bound
        # values change, so build it once.
        # NOTE(review): RANDOM() is dialect-specific — confirm it matches
        # the database behind `engine`.
        query = text("""
            SELECT window_title, application_name, category
            FROM activity_records
            WHERE category = :category
            ORDER BY RANDOM()
            LIMIT :limit
        """)

        # Check different categories
        for category in categories:
            print(f"\n{category.upper()} samples:")

            samples = session.execute(
                query, {'category': category, 'limit': per_category}
            ).fetchall()

            for sample in samples:
                # Titles are truncated to 60 characters; NULLs print as "N/A".
                title = sample.window_title[:60] if sample.window_title else "N/A"
                app = sample.application_name or "N/A"
                print(f"  - {title} ({app})")

    finally:
        session.close()


if __name__ == "__main__":
    # Interactive entry point: ask for confirmation, run the update, then
    # optionally print verification samples.
    print("Activity Category Update Script")
    print("==============================\n")

    affirmative = ('yes', 'y')

    # Confirm before proceeding
    answer = input("This will update all activity records in the database. Continue? (yes/no): ")

    if answer.lower() not in affirmative:
        print("Update cancelled.")
    else:
        # Run the update
        update_activity_categories()

        # Verification is optional and asked for separately
        verify_answer = input("\nDo you want to verify the update with sample records? (yes/no): ")
        if verify_answer.lower() in affirmative:
            verify_update()

        print("\nUpdate completed successfully!")