# activity_categorization_api.py - Fixed version
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from sqlalchemy import text
from datetime import datetime, timezone
from typing import Optional, List, Dict
from database import get_db
from activity_categorizer import ActivityCategorizer
import json

# Router on which this module's endpoints are registered through the
# @router.get / @router.post decorators.
router = APIRouter()

@router.get("/api/activity-categories/{developer_id}")
async def get_categorized_activities(
    developer_id: str,
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    db: Session = Depends(get_db)
):
    """Get activities categorized into Productive, Browser, and Server categories.

    Args:
        developer_id: Value matched against ``activity_records.developer_id``.
        start_date: Optional ISO-8601 lower bound (inclusive). A trailing
            ``Z`` is accepted. Defaults to midnight UTC of the current day.
        end_date: Optional ISO-8601 upper bound (inclusive). A trailing ``Z``
            is accepted. Defaults to the current UTC time.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        A dict holding the developer id, the effective date range,
        per-category statistics (count, duration, percentage,
        duration_hours), the total duration in hours, up to ten top entries
        per category and a productivity score.

    Raises:
        HTTPException: 400 when ``start_date``/``end_date`` cannot be parsed,
            500 for any other failure.
    """
    # Validate the user-supplied dates outside the catch-all handler so a
    # malformed value is reported as a client error rather than a 500.
    try:
        if start_date:
            start = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        else:
            start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)

        if end_date:
            end = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
        else:
            end = datetime.now(timezone.utc)
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid date format: {exc}"
        ) from exc

    try:
        # Initialize categorizer
        categorizer = ActivityCategorizer()

        # Fetch activities from database
        query = text("""
            SELECT 
                id,
                developer_id,
                application_name,
                window_title,
                duration,
                timestamp,
                url,
                file_path,
                project_name,
                project_type,
                category
            FROM activity_records
            WHERE developer_id = :dev_id
            AND timestamp >= :start_date
            AND timestamp <= :end_date
            ORDER BY timestamp DESC
        """)

        result = db.execute(query, {
            "dev_id": developer_id,
            "start_date": start,
            "end_date": end
        }).fetchall()

        # Pre-seed the known categories so they always appear in the
        # response, even when no activity falls into them.
        activities_by_category = {
            "productive": [],
            "browser": [],
            "server": [],
            "non-work": []
        }

        category_stats = {
            "productive": {"count": 0, "duration": 0},
            "browser": {"count": 0, "duration": 0},
            "server": {"count": 0, "duration": 0},
            "non-work": {"count": 0, "duration": 0}
        }

        total_duration = 0

        for row in result:
            activity = {
                "id": row[0],
                "developer_id": row[1],
                "application_name": row[2] or "",
                "window_title": row[3] or "",
                "duration": row[4] or 0,
                "timestamp": row[5].isoformat() if row[5] else None,
                "url": row[6],
                "file_path": row[7],
                "project_name": row[8],
                "project_type": row[9],
                "existing_category": row[10]
            }

            # Get categorization
            category_info = categorizer.get_detailed_category(
                activity["window_title"],
                activity["application_name"]
            )

            activity["new_category"] = category_info["category"]
            activity["subcategory"] = category_info["subcategory"]
            activity["confidence"] = category_info["confidence"]

            # Create the bucket on demand in case the categorizer returns a
            # category that was not pre-seeded, instead of failing with
            # KeyError.
            category = category_info["category"]
            activities_by_category.setdefault(category, []).append(activity)

            # Update statistics
            stats = category_stats.setdefault(category, {"count": 0, "duration": 0})
            stats["count"] += 1
            stats["duration"] += activity["duration"]
            total_duration += activity["duration"]

        # Calculate percentages and hour totals
        for stats in category_stats.values():
            if total_duration > 0:
                stats["percentage"] = stats["duration"] / total_duration * 100
            else:
                stats["percentage"] = 0

            # The stored duration is divided by 3600 to express it in hours
            stats["duration_hours"] = round(stats["duration"] / 3600, 2)

        # Get top activities by category with proper grouping
        top_activities_by_category = {}
        for category, activities in activities_by_category.items():
            if category == "productive":
                # Group productive activities by actual projects
                grouped_activities = group_productive_activities_by_project(activities)
                top_activities_by_category[category] = grouped_activities[:10]
            else:
                # For other categories, show individual activities
                sorted_activities = sorted(
                    activities,
                    key=lambda x: x["duration"],
                    reverse=True
                )[:10]

                top_activities_by_category[category] = [
                    {
                        "window_title": act["window_title"],
                        "application_name": act["application_name"],
                        "duration": act["duration"],
                        "duration_hours": round(act["duration"] / 3600, 2),
                        "subcategory": act["subcategory"],
                        "confidence": act["confidence"]
                    }
                    for act in sorted_activities
                ]

        return {
            "developer_id": developer_id,
            "date_range": {
                "start": start.isoformat(),
                "end": end.isoformat()
            },
            "statistics": category_stats,
            "total_duration_hours": round(total_duration / 3600, 2),
            "top_activities_by_category": top_activities_by_category,
            "productivity_score": calculate_productivity_score(category_stats)
        }

    except HTTPException:
        # Never re-wrap an HTTP error that already carries its own status.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error categorizing activities: {str(e)}"
        ) from e

def group_productive_activities_by_project(activities: List[Dict]) -> List[Dict]:
    """Group productive activities by project name, filtering out non-project items.

    An activity is kept only when its window title contains none of the
    exclusion substrings (mail, URLs, generic browser titles), its
    application name contains a known code-editor substring or "filezilla",
    and ``extract_project_name`` yields a usable name.

    Args:
        activities: Activity dicts with ``window_title``, ``application_name``
            and ``duration`` keys. Missing or ``None`` values are treated as
            empty / zero.

    Returns:
        One dict per project, sorted by total duration (descending). Each
        holds ``project_name``, a summary ``window_title`` of the form
        ``"<project> - <time> (<n> files)"``, ``application_name`` (the
        application with the largest accumulated duration), ``duration``,
        ``duration_hours``, ``file_count``, sorted ``files`` and
        ``applications`` lists, ``subcategory`` and ``confidence``.
    """
    # Lower-case substrings that mark a window title as a non-project item
    exclude_patterns = (
        'inbox', 'email', 'gmail', 'outlook', 'mail',
        'fw:', 're:', 'fwd:',  # Email subjects
        'http://', 'https://', 'www.',  # URLs
        'generate captcha', 'file manager',  # Generic browser titles
        'websites & domains', 'repository system proposal',
        'contact-form', 'form', 'survey'
    )

    # Lower-case substrings identifying code editor applications
    code_editors = (
        'visual studio code', 'vs code', 'code.exe', 'code',
        'cursor', 'cursor.exe',
        'sublime text', 'sublime_text.exe',
        'notepad++', 'notepad++.exe',
        'atom', 'atom.exe',
        'phpstorm', 'webstorm', 'intellij', 'pycharm',
        'vim', 'nvim', 'emacs',
        'brackets', 'vscode'
    )

    placeholder_names = ('general', 'unknown', 'no project', '')

    project_groups = {}
    # project name -> {application name -> accumulated duration}; kept apart
    # from project_groups so it does not leak into the returned dicts.
    app_usage = {}

    for activity in activities:
        raw_title = activity.get("window_title") or ""
        raw_app = activity.get("application_name") or ""
        window_title = raw_title.lower()
        app_name = raw_app.lower()

        # Skip emails, browser tabs and other non-project items
        if any(pattern in window_title for pattern in exclude_patterns):
            continue

        # Only process activities from a code editor or FileZilla
        is_code_editor = any(editor in app_name for editor in code_editors)
        if not is_code_editor and 'filezilla' not in app_name:
            continue

        project_name = extract_project_name(activity)

        # Skip if no valid project name found
        if not project_name or project_name.lower() in placeholder_names:
            continue

        duration = activity.get("duration") or 0

        group = project_groups.get(project_name)
        if group is None:
            group = project_groups[project_name] = {
                "project_name": project_name,
                "window_title": project_name,  # Replaced by a summary below
                "application_name": "Development",
                "duration": 0,
                "duration_hours": 0,
                "file_count": 0,
                "files": set(),
                "applications": set(),
                "subcategory": "development",
                "confidence": 0.9
            }

        group["duration"] += duration

        # Track unique files
        file_name = extract_file_name(raw_title)
        if file_name:
            group["files"].add(file_name)

        # Track applications used and how long each was used
        if raw_app:
            group["applications"].add(raw_app)
            usage = app_usage.setdefault(project_name, {})
            usage[raw_app] = usage.get(raw_app, 0) + duration

    # Convert to list and finalize calculations
    result = []
    for project_name, group in project_groups.items():
        group["file_count"] = len(group["files"])
        group["duration_hours"] = round(group["duration"] / 3600, 2)

        # Format the display
        hours = int(group["duration"] // 3600)
        minutes = int((group["duration"] % 3600) // 60)
        seconds = int(group["duration"] % 60)

        if hours > 0:
            time_str = f"{hours}h {minutes}m"
        elif minutes > 0:
            time_str = f"{minutes}m {seconds}s"
        else:
            time_str = f"{seconds}s"

        # Update window title to show summary
        group["window_title"] = f"{project_name} - {time_str} ({group['file_count']} files)"

        # Sets are not JSON serializable; sorting also makes the output
        # order deterministic across runs.
        group["files"] = sorted(group["files"])
        group["applications"] = sorted(group["applications"])

        # Report the application with the largest accumulated duration;
        # iterating the sorted names makes ties resolve alphabetically.
        usage = app_usage.get(project_name)
        if usage:
            group["application_name"] = max(sorted(usage), key=usage.get)

        result.append(group)

    # Sort by total duration (descending)
    return sorted(result, key=lambda x: x["duration"], reverse=True)

def extract_project_name(activity: Dict) -> Optional[str]:
    """Extract a project name from activity data.

    Resolution order:
      1. The activity's own ``project_name`` unless it is empty or one of
         the placeholders "general" / "unknown".
      2. For FileZilla windows, the text before the first " - ".
      3. " - "-separated titles such as "file - project - app",
         "project - app" or "file - project".
      4. Path-like titles: the first directory component that is not a
         common directory name, not hidden, and longer than two characters.

    Args:
        activity: Dict that may carry ``window_title``, ``application_name``
            and ``project_name``. Missing or ``None`` values are treated as
            empty strings.

    Returns:
        The project name, or ``None`` when none could be determined.
    """
    window_title = activity.get("window_title") or ""
    app_name = activity.get("application_name") or ""
    existing_project = activity.get("project_name") or ""

    # If we have a valid existing project name, use it
    if existing_project and existing_project.lower() not in ('general', 'unknown', ''):
        return existing_project

    # FileZilla must be handled before the generic " - " parsing: otherwise
    # "projectname - FileZilla" hits the two-part branch and returns the
    # application name "FileZilla" as the project.
    if 'filezilla' in app_name.lower() and " - " in window_title:
        return window_title.split(" - ")[0].strip() or None

    # Substrings that mark the middle part of a three-part title as a file
    source_extensions = (
        '.js', '.py', '.php', '.html', '.css', '.jsx', '.ts', '.tsx', '.vue',
        '.java', '.cpp', '.c', '.h', '.cs', '.rb', '.go', '.rs', '.dart',
        '.kt', '.swift', '.m', '.mm'
    )
    # NOTE(review): the two-part check uses a shorter list than the
    # three-part check; confirm whether the two should match.
    two_part_extensions = ('.js', '.py', '.php', '.html', '.css')
    editor_markers = ('code', 'cursor', 'studio', 'storm', 'text', 'vim', 'atom')
    common_directories = (
        'src', 'app', 'components', 'pages', 'views', 'public', 'assets',
        'dist', 'build', 'node_modules', 'vendor', 'lib', 'bin', 'test',
        'tests', 'spec', 'docs', 'config'
    )

    # Common patterns in window titles:
    # "filename.ext - projectname - VS Code"
    # "projectname - filename.ext - VS Code"
    # "filename.ext • projectname — VS Code"
    # "projectname/filename.ext - Cursor"
    if " - " in window_title:
        parts = window_title.split(" - ")

        if len(parts) >= 3:
            # Pattern: file - project - app
            potential_project = parts[1].strip()

            # Accept it only when it does not look like a filename
            if not any(ext in potential_project.lower() for ext in source_extensions):
                return potential_project

        elif len(parts) == 2:
            # Could be "project - app" or "file - project"
            part1 = parts[0].strip()
            part2 = parts[1].strip()

            if any(marker in part2.lower() for marker in editor_markers):
                # part2 is an app name, so part1 is the project unless its
                # last path component contains a dot (i.e. it is a file).
                if '.' not in part1.split('/')[-1].split('\\')[-1]:
                    return part1
            elif not any(ext in part2.lower() for ext in two_part_extensions):
                # Otherwise part2 might be the project
                return part2

    # Try path-based extraction for patterns like "project/file.ext"
    if "/" in window_title or "\\" in window_title:
        path_parts = window_title.replace("\\", "/").split("/")
        # Every component except the last one is a candidate directory
        for raw_part in path_parts[:-1]:
            part = raw_part.strip()
            if part.lower() in common_directories:
                continue
            if not part.startswith('.') and len(part) > 2:
                return part

    return None

def extract_file_name(window_title: str) -> Optional[str]:
    """Extract a filename from a window title.

    A token counts as a filename when it contains a dot followed by a
    non-empty extension of at most four characters. The part before the
    first " - " is tried first, then each whitespace-separated word of the
    whole title. Any leading directory path is stripped from the result.

    Args:
        window_title: The raw window title; may be empty or ``None``.

    Returns:
        The bare filename, or ``None`` when nothing in the title looks like
        a file.
    """
    if not window_title:
        return None

    def has_extension(token):
        # Require a non-empty extension so that words ending in a period
        # ("notes.") are not mistaken for filenames.
        _, dot, extension = token.rpartition('.')
        return bool(dot) and 0 < len(extension) <= 4

    def strip_directories(token):
        return token.split('/')[-1].split('\\')[-1]

    # Common pattern: the first " - " separated part is usually the filename
    if " - " in window_title:
        potential_file = window_title.split(" - ")[0].strip()
        if has_extension(potential_file):
            return strip_directories(potential_file)

    # Otherwise look for the first word in the title that looks like a file
    for word in window_title.split():
        if has_extension(word):
            return strip_directories(word)

    return None

def calculate_productivity_score(category_stats: Dict) -> float:
    """Return a 0-100 productivity score derived from per-category durations.

    Productive and server time are credited in full and browser time at
    half weight. The credited time is divided by the total time of all
    categories minus the non-work time, expressed as a percentage, capped
    at 100 and rounded to two decimals. The score is 0 when that
    denominator is not positive.

    Args:
        category_stats: Mapping of category name to a dict holding a
            ``duration`` entry; must contain "productive", "server",
            "browser" and "non-work".
    """
    tracked = ("productive", "server", "browser", "non-work")
    spent = {name: category_stats[name]["duration"] for name in tracked}

    # Full credit for productive/server work, half credit for browser time
    # (some of it is work-related).
    credited = spent["productive"] + spent["server"] + spent["browser"] * 0.5

    # Denominator: everything that was tracked except non-work time
    overall = sum(entry["duration"] for entry in category_stats.values())
    work_time = overall - spent["non-work"]

    score = (credited / work_time) * 100 if work_time > 0 else 0
    return round(min(score, 100), 2)

# Additional endpoints: persist computed categories to the database and
# summarise categories across all developers.
@router.post("/api/update-activity-categories/{developer_id}")
async def update_activity_categories(
    developer_id: str,
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    db: Session = Depends(get_db)
):
    """Update the category fields in the database for a developer's activities.

    Every ``activity_records`` row of the developer inside the date range is
    re-categorized and its ``category``, ``subcategory`` and
    ``category_confidence`` columns are overwritten. All updates are
    committed together; on failure the session is rolled back.

    Args:
        developer_id: Value matched against ``activity_records.developer_id``.
        start_date: Optional ISO-8601 lower bound (inclusive). A trailing
            ``Z`` is accepted. Defaults to midnight UTC of the current day.
        end_date: Optional ISO-8601 upper bound (inclusive). A trailing ``Z``
            is accepted. Defaults to the current UTC time.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        A dict with ``success``, ``updated_count`` and a ``message``.

    Raises:
        HTTPException: 400 when ``start_date``/``end_date`` cannot be parsed,
            500 for any other failure.
    """
    # Validate the user-supplied dates before touching the database so a
    # malformed value is reported as a client error rather than a 500.
    try:
        if start_date:
            start = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        else:
            start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)

        if end_date:
            end = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
        else:
            end = datetime.now(timezone.utc)
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid date format: {exc}"
        ) from exc

    try:
        categorizer = ActivityCategorizer()

        # Fetch activities
        query = text("""
            SELECT id, window_title, application_name
            FROM activity_records
            WHERE developer_id = :dev_id
            AND timestamp >= :start_date
            AND timestamp <= :end_date
        """)

        result = db.execute(query, {
            "dev_id": developer_id,
            "start_date": start,
            "end_date": end
        }).fetchall()

        # The statement is identical for every row, so build it once.
        update_query = text("""
                UPDATE activity_records
                SET 
                    category = :category,
                    subcategory = :subcategory,
                    category_confidence = :confidence
                WHERE id = :activity_id
            """)

        updated_count = 0

        for row in result:
            activity_id = row[0]
            window_title = row[1] or ""
            app_name = row[2] or ""

            # Get category
            category_info = categorizer.get_detailed_category(window_title, app_name)

            # Update database
            db.execute(update_query, {
                "category": category_info["category"],
                "subcategory": category_info["subcategory"],
                "confidence": category_info["confidence"],
                "activity_id": activity_id
            })

            updated_count += 1

        db.commit()

        return {
            "success": True,
            "updated_count": updated_count,
            "message": f"Successfully updated {updated_count} activities"
        }

    except HTTPException:
        # Undo partial work but keep the original HTTP status.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"Error updating categories: {str(e)}"
        ) from e

@router.get("/api/category-summary")
async def get_category_summary(
    start_date: Optional[str] = Query(None),
    end_date: Optional[str] = Query(None),
    db: Session = Depends(get_db)
):
    """Get summary of all activities by category across all developers.

    Aggregates ``activity_records`` inside the date range by the stored
    ``category`` column; rows without a category are reported as
    "uncategorized".

    Args:
        start_date: Optional ISO-8601 lower bound (inclusive). A trailing
            ``Z`` is accepted. Defaults to midnight UTC of the current day.
        end_date: Optional ISO-8601 upper bound (inclusive). A trailing ``Z``
            is accepted. Defaults to the current UTC time.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        A dict with the effective date range, a ``summary`` list (category,
        activity_count, total_duration_seconds, total_duration_hours,
        developer_count, percentage) ordered by total duration descending,
        and ``total_duration_hours``.

    Raises:
        HTTPException: 400 when ``start_date``/``end_date`` cannot be parsed,
            500 for any other failure.
    """
    # Validate the user-supplied dates outside the catch-all handler so a
    # malformed value is reported as a client error rather than a 500.
    try:
        if start_date:
            start = datetime.fromisoformat(start_date.replace('Z', '+00:00'))
        else:
            start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)

        if end_date:
            end = datetime.fromisoformat(end_date.replace('Z', '+00:00'))
        else:
            end = datetime.now(timezone.utc)
    except ValueError as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid date format: {exc}"
        ) from exc

    try:
        # Get summary by category
        query = text("""
            SELECT 
                COALESCE(category, 'uncategorized') as category,
                COUNT(*) as activity_count,
                SUM(duration) as total_duration,
                COUNT(DISTINCT developer_id) as developer_count
            FROM activity_records
            WHERE timestamp >= :start_date
            AND timestamp <= :end_date
            GROUP BY category
            ORDER BY total_duration DESC
        """)

        result = db.execute(query, {
            "start_date": start,
            "end_date": end
        }).fetchall()

        summary = []
        total_duration = 0

        for row in result:
            # SUM() yields NULL when every duration in the group is NULL
            duration_seconds = row[2] or 0
            summary.append({
                "category": row[0],
                "activity_count": row[1],
                "total_duration_seconds": duration_seconds,
                "total_duration_hours": round(duration_seconds / 3600, 2),
                "developer_count": row[3]
            })
            total_duration += duration_seconds

        # Add percentages
        for item in summary:
            if total_duration > 0:
                item["percentage"] = round(
                    item["total_duration_seconds"] / total_duration * 100, 2
                )
            else:
                item["percentage"] = 0

        return {
            "date_range": {
                "start": start.isoformat(),
                "end": end.isoformat()
            },
            "summary": summary,
            "total_duration_hours": round(total_duration / 3600, 2)
        }

    except HTTPException:
        # Never re-wrap an HTTP error that already carries its own status.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error getting category summary: {str(e)}"
        ) from e
