# activity_categorization_api.py
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import text
from datetime import datetime, timezone
from typing import Optional, List, Dict
from database import get_db
from activity_categorizer import ActivityCategorizer
import json

# Router that collects the activity-categorization endpoints defined in this
# module. NOTE(review): presumably included into the FastAPI app elsewhere —
# the mounting point is not visible in this file.
router = APIRouter()

@router.get("/api/activity-categories/{developer_id}")
async def get_categorized_activities(
    developer_id: str,
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    db: Session = Depends(get_db)
):
    """Get activities categorized into Productive, Browser, and Server categories.

    Reads the developer's rows from ``activity_records`` whose ``timestamp``
    lies in ``[start_date, end_date]`` and re-categorizes every row with
    ``ActivityCategorizer.get_detailed_category``. The stored ``category``
    column is not used for grouping. Rows are then grouped by their
    whitespace-stripped window title inside each category. Categories outside
    the four known buckets are folded into ``"browser"``.

    Args:
        developer_id: Developer whose records are read.
        start_date: ISO-8601 start of the range; a ``Z`` suffix is accepted.
            Defaults to today's midnight in UTC.
        end_date: ISO-8601 end of the range; a ``Z`` suffix is accepted.
            Defaults to the current UTC time.
        db: SQLAlchemy session supplied by ``get_db``.

    Returns:
        A dict with the following keys:

        - ``developer_id`` and ``date_range``.
        - ``statistics``: per-category count, duration, percentage and display
          values.
        - ``total_duration_hours``.
        - ``activities_by_category``: all grouped activities per category,
          sorted by duration descending.
        - ``top_activities_by_category``: the first 10 of those per category.
        - ``productivity_score``.

        Durations are treated as seconds. Each group's ``first_timestamp`` /
        ``last_timestamp`` are its oldest / newest record timestamps.

    Raises:
        HTTPException: 400 when a date string cannot be parsed; 500 for any
            other failure.
    """
    import logging  # function-scope import keeps this endpoint self-contained

    logger = logging.getLogger(__name__)

    # Parse the range before the generic error handler so malformed client
    # input is reported as a 400 rather than as a server error.
    try:
        if start_date:
            start = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        else:
            start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)

        if end_date:
            end = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
        else:
            end = datetime.now(timezone.utc)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid date format: {str(e)}"
        ) from e

    try:
        categorizer = ActivityCategorizer()

        query = text("""
            SELECT 
                id,
                developer_id,
                application_name,
                window_title,
                duration,
                timestamp,
                url,
                file_path,
                project_name,
                project_type,
                category
            FROM activity_records
            WHERE developer_id = :dev_id
            AND timestamp >= :start_date
            AND timestamp <= :end_date
            ORDER BY timestamp DESC
        """)

        rows = db.execute(query, {
            "dev_id": developer_id,
            "start_date": start,
            "end_date": end
        }).fetchall()

        # category -> {window title -> grouped entry}
        activities_grouped = {
            "productive": {},
            "browser": {},
            "server": {},
            "non-work": {}
        }
        category_stats = {
            cat: {"count": 0, "duration": 0} for cat in activities_grouped
        }
        total_duration = 0

        for row in rows:
            application_name = row[2] or ""
            window_title = row[3] or ""
            duration = row[4] or 0
            timestamp = row[5].isoformat() if row[5] else None

            category_info = categorizer.get_detailed_category(window_title, application_name)

            # Group on the whitespace-stripped title. Case is preserved, so
            # titles that differ only in case remain separate groups.
            window_title_key = window_title.strip() or "Untitled"

            category = category_info["category"]
            if category not in activities_grouped:
                # Unknown categories are folded into "browser".
                category = "browser"

            grouped = activities_grouped[category].get(window_title_key)
            if grouped is None:
                # First row seen for this title. Rows arrive newest-first
                # (ORDER BY timestamp DESC), so this is the group's newest
                # timestamp, and its metadata comes from the latest record.
                grouped = {
                    "window_title": window_title_key,
                    "application_name": application_name,
                    "duration": 0,
                    "activity_count": 0,
                    "first_timestamp": timestamp,
                    "last_timestamp": timestamp,
                    "category": category,
                    "subcategory": category_info["subcategory"],
                    "confidence": category_info["confidence"],
                    "url": row[6],
                    "file_path": row[7],
                    "project_name": row[8],
                    "project_type": row[9],
                }
                activities_grouped[category][window_title_key] = grouped
            else:
                # Each later row is older or equal, so it moves the group's
                # start back; last_timestamp keeps the newest value.
                grouped["first_timestamp"] = timestamp

            grouped["duration"] += duration
            grouped["activity_count"] += 1

            category_stats[category]["count"] += 1
            category_stats[category]["duration"] += duration
            total_duration += duration

        logger.debug(
            "Activity records for developer %s: %d rows, total duration %s s",
            developer_id, len(rows), total_duration,
        )

        # Attach display values and turn each category's groups into a list
        # sorted by duration (descending). Durations are already in seconds.
        activities_by_category = {}
        for cat, grouped_dict in activities_grouped.items():
            for grouped in grouped_dict.values():
                seconds = grouped["duration"]
                if seconds >= 3600:
                    grouped["duration_hours"] = round(seconds / 3600, 2)
                    grouped["duration_display"] = f"{grouped['duration_hours']}h"
                else:
                    grouped["duration_hours"] = round(seconds / 3600, 4)  # Keep precise value
                    if seconds >= 60:
                        grouped["duration_display"] = f"{round(seconds / 60, 1)}m"
                    else:
                        grouped["duration_display"] = f"{round(seconds, 1)}s"
            activities_by_category[cat] = sorted(
                grouped_dict.values(), key=lambda x: x["duration"], reverse=True
            )

        # Percentages and human-readable durations per category.
        for stats in category_stats.values():
            seconds = stats["duration"]
            if total_duration > 0:
                stats["percentage"] = seconds / total_duration * 100
            else:
                stats["percentage"] = 0

            hours = seconds / 3600
            stats["duration_hours"] = round(hours, 4)  # Keep precise value

            if seconds >= 3600:
                stats["duration_display"] = f"{round(hours, 1)}h"
            elif seconds >= 60:
                stats["duration_display"] = f"{round(seconds / 60, 1)}m"
            else:
                stats["duration_display"] = f"{round(seconds, 1)}s"

        # Lists are already sorted by duration, so a slice yields the top 10.
        top_activities_by_category = {
            cat: activities[:10] for cat, activities in activities_by_category.items()
        }

        return {
            "developer_id": developer_id,
            "date_range": {
                "start": start.isoformat(),
                "end": end.isoformat()
            },
            "statistics": category_stats,
            "total_duration_hours": round(total_duration / 3600, 2),
            "activities_by_category": activities_by_category,  # All grouped activities
            "top_activities_by_category": top_activities_by_category,  # Top 10 per category
            "productivity_score": calculate_productivity_score(category_stats)
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error categorizing activities: {str(e)}"
        ) from e

@router.post("/api/update-activity-categories/{developer_id}")
async def update_activity_categories(
    developer_id: str,
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    db: Session = Depends(get_db)
):
    """Update the category field in the database for all activities.

    Re-runs ``ActivityCategorizer.get_detailed_category`` on every record of
    ``developer_id`` whose ``timestamp`` lies in ``[start_date, end_date]``.
    It writes ``category``, ``subcategory`` and ``category_confidence`` back
    to ``activity_records`` and commits once at the end.

    Args:
        developer_id: Developer whose records are re-categorized.
        start_date: ISO-8601 start of the range; a ``Z`` suffix is accepted.
            Defaults to today's midnight in UTC.
        end_date: ISO-8601 end of the range; a ``Z`` suffix is accepted.
            Defaults to the current UTC time.
        db: SQLAlchemy session supplied by ``get_db``.

    Returns:
        ``{"success": True, "updated_count": n, "message": ...}``, where ``n``
        is the number of records that were re-categorized.

    Raises:
        HTTPException: 400 when a date string cannot be parsed; 500, after
            rolling back the session, for any other failure.
    """
    # Parse the range before the generic error handler so malformed client
    # input is reported as a 400 rather than as a server error.
    try:
        if start_date:
            start = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        else:
            start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)

        if end_date:
            end = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
        else:
            end = datetime.now(timezone.utc)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid date format: {str(e)}"
        ) from e

    # Both statements are built once, not once per row.
    select_query = text("""
        SELECT id, window_title, application_name
        FROM activity_records
        WHERE developer_id = :dev_id
        AND timestamp >= :start_date
        AND timestamp <= :end_date
    """)

    update_query = text("""
        UPDATE activity_records
        SET 
            category = :category,
            subcategory = :subcategory,
            category_confidence = :confidence
        WHERE id = :activity_id
    """)

    try:
        categorizer = ActivityCategorizer()

        rows = db.execute(select_query, {
            "dev_id": developer_id,
            "start_date": start,
            "end_date": end
        }).fetchall()

        update_params = []
        for activity_id, window_title, app_name in rows:
            category_info = categorizer.get_detailed_category(
                window_title or "", app_name or ""
            )
            update_params.append({
                "category": category_info["category"],
                "subcategory": category_info["subcategory"],
                "confidence": category_info["confidence"],
                "activity_id": activity_id
            })

        # Passing a list of parameter dicts makes SQLAlchemy issue a single
        # executemany instead of one round trip per row. Skip it when there
        # is nothing to update.
        if update_params:
            db.execute(update_query, update_params)

        db.commit()

        updated_count = len(update_params)
        return {
            "success": True,
            "updated_count": updated_count,
            "message": f"Successfully updated {updated_count} activities"
        }

    except Exception as e:
        db.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Error updating categories: {str(e)}"
        ) from e

@router.get("/api/category-summary")
async def get_category_summary(
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    db: Session = Depends(get_db)
):
    """Get summary of all activities by category across all developers.

    Aggregates ``activity_records`` in ``[start_date, end_date]`` by the
    stored ``category`` column. Any value outside
    ``productive``/``browser``/``server``/``non-work`` is counted as
    ``browser``.

    Args:
        start_date: ISO-8601 start of the range; a ``Z`` suffix is accepted.
            Defaults to today's midnight in UTC.
        end_date: ISO-8601 end of the range; a ``Z`` suffix is accepted.
            Defaults to the current UTC time.
        db: SQLAlchemy session supplied by ``get_db``.

    Returns:
        A dict with the following keys:

        - ``date_range``.
        - ``summary``: one entry per category, ordered by total duration
          descending, with activity count, total seconds/hours, distinct
          developer count and percentage of total time.
        - ``total_duration_hours``.

    Raises:
        HTTPException: 400 when a date string cannot be parsed; 500 for any
            other failure.
    """
    # Parse the range before the generic error handler so malformed client
    # input is reported as a 400 rather than as a server error.
    try:
        if start_date:
            start = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        else:
            start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)

        if end_date:
            end = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
        else:
            end = datetime.now(timezone.utc)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid date format: {str(e)}"
        ) from e

    try:
        # The CASE is repeated in GROUP BY so unknown categories are merged
        # into the 'browser' bucket before aggregation.
        query = text("""
            SELECT 
                CASE 
                    WHEN category IN ('productive', 'browser', 'server', 'non-work') THEN category
                    ELSE 'browser'
                END as category,
                COUNT(*) as activity_count,
                SUM(duration) as total_duration,
                COUNT(DISTINCT developer_id) as developer_count
            FROM activity_records
            WHERE timestamp >= :start_date
            AND timestamp <= :end_date
            GROUP BY 
                CASE 
                    WHEN category IN ('productive', 'browser', 'server', 'non-work') THEN category
                    ELSE 'browser'
                END
            ORDER BY total_duration DESC
        """)

        result = db.execute(query, {
            "start_date": start,
            "end_date": end
        }).fetchall()

        # SUM(duration) is NULL when every duration in the group is NULL.
        summary = [
            {
                "category": category,
                "activity_count": activity_count,
                "total_duration_seconds": duration or 0,
                "total_duration_hours": round((duration or 0) / 3600, 2),
                "developer_count": developer_count
            }
            for category, activity_count, duration, developer_count in result
        ]
        total_duration = sum(item["total_duration_seconds"] for item in summary)

        for item in summary:
            if total_duration > 0:
                item["percentage"] = round(
                    item["total_duration_seconds"] / total_duration * 100, 2
                )
            else:
                item["percentage"] = 0

        return {
            "date_range": {
                "start": start.isoformat(),
                "end": end.isoformat()
            },
            "summary": summary,
            "total_duration_hours": round(total_duration / 3600, 2)
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error getting category summary: {str(e)}"
        ) from e

def calculate_productivity_score(category_stats: Dict) -> float:
    """Return a 0-100 productivity score from per-category durations.

    Weighting:

    - Productive and server time count in full.
    - Browser time counts at half weight.
    - Non-work time is excluded from both the numerator and the denominator.

    A missing category, or a category dict without ``"duration"``, counts as
    zero.

    Args:
        category_stats: Mapping of category name to a dict holding a
            ``"duration"`` value.

    Returns:
        The credited share of work time as a percentage, capped at 100 and
        rounded to two decimals; 0 when there is no work time.
    """
    def seconds_in(bucket):
        return category_stats.get(bucket, {}).get("duration", 0)

    productive_plus_server = seconds_in("productive") + seconds_in("server")
    browser = seconds_in("browser")
    non_work = seconds_in("non-work")

    # Browser time is credited at 50% because only part of it is work-related.
    credited = productive_plus_server + browser * 0.5

    # Work time is all tracked time with the non-work portion taken out.
    work = (productive_plus_server + browser + non_work) - non_work

    score = credited / work * 100 if work > 0 else 0
    return round(min(score, 100), 2)

@router.get("/api/debug-durations/{developer_id}")
async def debug_activity_durations(
    developer_id: str,
    db: Session = Depends(get_db)
):
    """Debug endpoint to check activity durations and grouping.

    Looks at the developer's ``activity_records`` timestamped on or after the
    database's ``CURRENT_DATE``. It reports two things:

    - The 20 window titles with the largest summed ``duration``, with count,
      average and first/last seen times.
    - Overall count/distinct-window/sum/avg/min/max aggregates.

    Durations are reported as stored and treated as seconds.

    Raises:
        HTTPException: 500 wrapping any failure.
    """
    bind = {"dev_id": developer_id}
    try:
        per_window_sql = text("""
            SELECT 
                window_title,
                COUNT(*) as count,
                SUM(duration) as total_duration,
                AVG(duration) as avg_duration,
                MIN(timestamp) as first_seen,
                MAX(timestamp) as last_seen
            FROM activity_records
            WHERE developer_id = :dev_id
            AND timestamp >= CURRENT_DATE
            GROUP BY window_title
            ORDER BY total_duration DESC
            LIMIT 20
        """)

        grouped_rows = db.execute(per_window_sql, bind).fetchall()

        top_windows = []
        for title, n_records, seconds_sum, seconds_avg, seen_first, seen_last in grouped_rows:
            seconds = seconds_sum or 0
            rounded_total = round(seconds, 1)
            if seconds >= 60:
                display = f"{round(seconds/60, 1)}m"
            else:
                display = f"{round(seconds, 1)}s"
            top_windows.append({
                "window_title": title,
                "activity_count": n_records,
                "total_duration_seconds": rounded_total,
                "total_duration_display": display,
                "avg_duration_seconds": round(seconds_avg or 0, 1),
                "first_seen": seen_first.isoformat() if seen_first else None,
                "last_seen": seen_last.isoformat() if seen_last else None
            })

        overall_sql = text("""
            SELECT 
                COUNT(*) as total_records,
                COUNT(DISTINCT window_title) as unique_windows,
                SUM(duration) as total_duration_seconds,
                AVG(duration) as avg_duration_seconds,
                MIN(duration) as min_duration_seconds,
                MAX(duration) as max_duration_seconds
            FROM activity_records
            WHERE developer_id = :dev_id
            AND timestamp >= CURRENT_DATE
        """)

        (
            record_count,
            window_count,
            seconds_total,
            seconds_avg_all,
            seconds_min,
            seconds_max,
        ) = db.execute(overall_sql, bind).fetchone()

        hours_total = (seconds_total or 0) / 3600

        return {
            "summary": {
                "total_records": record_count,
                "unique_windows": window_count,
                "total_duration_seconds": seconds_total or 0,
                "total_duration_hours": round(hours_total, 2),
                "avg_duration_per_record_seconds": round(seconds_avg_all or 0, 1),
                "min_duration_seconds": seconds_min,
                "max_duration_seconds": seconds_max,
                "recording_interval": "ActivityWatch records activities every 6 seconds"
            },
            "top_activities_grouped": top_windows,
            "note": "Durations are in seconds. ActivityWatch polls every few seconds."
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error debugging durations: {str(e)}"
        )