diff --git a/.idea/deviceManager.xml b/.idea/deviceManager.xml index fb403f9..106e5ce 100644 --- a/.idea/deviceManager.xml +++ b/.idea/deviceManager.xml @@ -8,7 +8,15 @@ + + + + @@ -17,6 +25,10 @@ diff --git a/app/src/main/assets/mobilefacenet.tflite b/app/src/main/assets/mobilefacenet.tflite new file mode 100644 index 0000000..057b985 Binary files /dev/null and b/app/src/main/assets/mobilefacenet.tflite differ diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt index 8ead9bc..ea6cfc2 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt @@ -1,129 +1,91 @@ package com.placeholder.sherpai2.data.local.dao -import androidx.room.* +import androidx.room.Dao +import androidx.room.Insert +import androidx.room.OnConflictStrategy +import androidx.room.Query +import androidx.room.Update import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity -import kotlinx.coroutines.flow.Flow /** - * FaceCacheDao - Query face metadata for intelligent filtering - * - * ENABLES SMART CLUSTERING: - * - Pre-filter to high-quality faces only - * - Avoid processing blurry/distant faces - * - Faster clustering with better results + * FaceCacheDao - Face detection cache with NEW queries for two-stage clustering */ @Dao interface FaceCacheDao { + // ═══════════════════════════════════════ + // INSERT / UPDATE + // ═══════════════════════════════════════ + @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun insert(faceCache: FaceCacheEntity) @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun insertAll(faceCaches: List) - /** - * Get ALL high-quality solo faces for clustering - * - * FILTERS: - * - Solo photos only (joins with images.faceCount = 1) - * - Large enough (isLargeEnough = true) - * - Good quality score (>= 0.6) - * - Frontal faces preferred (isFrontal = true) - */ - @Query(""" - SELECT fc.* FROM face_cache fc - INNER JOIN images i ON fc.imageId = i.imageId - WHERE i.faceCount = 1 - AND fc.isLargeEnough = 1 - AND fc.qualityScore >= 0.6 - AND fc.isFrontal = 1 - ORDER BY fc.qualityScore DESC - """) - suspend fun getHighQualitySoloFaces(): List + @Update + suspend fun update(faceCache: FaceCacheEntity) + + // ═══════════════════════════════════════ + // NEW CLUSTERING QUERIES ⭐ + // ═══════════════════════════════════════ /** - * Get high-quality faces from ANY photo (including group photos) - * Use when not enough solo photos available + * Get high-quality solo faces for Stage 1 clustering + * + * Filters: + * - Solo photos (faceCount = 1) + * - Large faces (faceAreaRatio >= minFaceRatio) + * - Has embedding */ @Query(""" - SELECT * FROM face_cache - WHERE isLargeEnough = 1 - AND qualityScore >= 0.6 - AND isFrontal = 1 - ORDER BY qualityScore DESC + SELECT fc.* + FROM face_cache fc + INNER JOIN images i ON fc.imageId = i.imageId + WHERE i.faceCount = 1 + AND fc.faceAreaRatio >= :minFaceRatio + AND fc.embedding IS NOT NULL + ORDER BY fc.faceAreaRatio DESC LIMIT :limit """) - suspend fun getHighQualityFaces(limit: Int = 1000): List + suspend fun getHighQualitySoloFaces( + minFaceRatio: Float = 0.015f, + limit: Int = 2000 + ): List /** - * Get faces for a specific image - */ - @Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex ASC") - suspend fun getFacesForImage(imageId: String): List - - /** - * Count high-quality solo faces (for UI display) + * FALLBACK: Get ANY solo faces with embeddings + * Used if getHighQualitySoloFaces() returns empty */ @Query(""" - SELECT COUNT(*) FROM face_cache fc + SELECT fc.* + FROM face_cache fc INNER JOIN images i ON fc.imageId = i.imageId - WHERE i.faceCount = 1 - AND fc.isLargeEnough = 1 - AND fc.qualityScore >= 0.6 - """) - suspend fun getHighQualitySoloFaceCount(): Int - - /** - * Get quality distribution stats - */ - @Query(""" - SELECT - SUM(CASE WHEN qualityScore >= 0.8 THEN 1 ELSE 0 END) as excellent, - SUM(CASE WHEN qualityScore >= 0.6 AND qualityScore < 0.8 THEN 1 ELSE 0 END) as good, - SUM(CASE WHEN qualityScore < 0.6 THEN 1 ELSE 0 END) as poor, - COUNT(*) as total - FROM face_cache - """) - suspend fun getQualityStats(): FaceQualityStats? - - /** - * Delete cache for specific image (when image is deleted) - */ - @Query("DELETE FROM face_cache WHERE imageId = :imageId") - suspend fun deleteCacheForImage(imageId: String) - - /** - * Delete all cache (for full rebuild) - */ - @Query("DELETE FROM face_cache") - suspend fun deleteAll() - - /** - * Get faces with embeddings already computed - * (Ultra-fast clustering - no need to re-generate) - */ - @Query(""" - SELECT fc.* FROM face_cache fc - INNER JOIN images i ON fc.imageId = i.imageId - WHERE i.faceCount = 1 - AND fc.isLargeEnough = 1 + WHERE i.faceCount = 1 AND fc.embedding IS NOT NULL ORDER BY fc.qualityScore DESC LIMIT :limit """) - suspend fun getSoloFacesWithEmbeddings(limit: Int = 2000): List -} + suspend fun getSoloFacesWithEmbeddings( + limit: Int = 2000 + ): List -/** - * Quality statistics result - */ -data class FaceQualityStats( - val excellent: Int, // qualityScore >= 0.8 - val good: Int, // 0.6 <= qualityScore < 0.8 - val poor: Int, // qualityScore < 0.6 - val total: Int -) { - val excellentPercent: Float get() = if (total > 0) excellent.toFloat() / total else 0f - val goodPercent: Float get() = if (total > 0) good.toFloat() / total else 0f - val poorPercent: Float get() = if (total > 0) poor.toFloat() / total else 0f + // ═══════════════════════════════════════ + // EXISTING QUERIES (keep as-is) + // ═══════════════════════════════════════ + + @Query("SELECT * FROM face_cache WHERE id = :id") + suspend fun getFaceCacheById(id: String): FaceCacheEntity? + + @Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex") + suspend fun getFaceCacheForImage(imageId: String): List + + @Query("DELETE FROM face_cache WHERE imageId = :imageId") + suspend fun deleteFaceCacheForImage(imageId: String) + + @Query("DELETE FROM face_cache") + suspend fun deleteAll() + + @Query("DELETE FROM face_cache WHERE cacheVersion < :version") + suspend fun deleteOldVersions(version: Int) } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt index 5660066..4241332 100644 --- a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt +++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt @@ -1,7 +1,7 @@ package com.placeholder.sherpai2.domain.clustering import android.graphics.Rect -import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding +import android.util.Log import javax.inject.Inject import javax.inject.Singleton import kotlin.math.sqrt @@ -9,56 +9,62 @@ import kotlin.math.sqrt /** * ClusterQualityAnalyzer - Validate cluster quality BEFORE training * - * PURPOSE: Prevent training on "dirty" clusters (siblings merged, poor quality faces) - * - * CHECKS: - * 1. Solo photo count (min 6 required) - * 2. Face size (min 15% of image - clear, not distant) - * 3. Internal consistency (all faces should match well) - * 4. Outlier detection (find faces that don't belong) - * - * QUALITY TIERS: - * - Excellent (95%+): Safe to train immediately - * - Good (85-94%): Review outliers, then train - * - Poor (<85%): Likely mixed people - DO NOT TRAIN! + * RELAXED THRESHOLDS for real-world photos (social media, distant shots): + * - Face size: 3% (down from 15%) + * - Outlier threshold: 65% (down from 75%) + * - GOOD tier: 75% (down from 85%) + * - EXCELLENT tier: 85% (down from 95%) */ @Singleton class ClusterQualityAnalyzer @Inject constructor() { companion object { + private const val TAG = "ClusterQuality" private const val MIN_SOLO_PHOTOS = 6 - private const val MIN_FACE_SIZE_RATIO = 0.15f // 15% of image - private const val MIN_INTERNAL_SIMILARITY = 0.80f - private const val OUTLIER_THRESHOLD = 0.75f - private const val EXCELLENT_THRESHOLD = 0.95f - private const val GOOD_THRESHOLD = 0.85f + private const val MIN_FACE_SIZE_RATIO = 0.03f // 3% of image (RELAXED) + private const val MIN_FACE_DIMENSION_PIXELS = 50 // 50px minimum (RELAXED) + private const val FALLBACK_MIN_DIMENSION = 80 // Fallback when no dimensions + private const val MIN_INTERNAL_SIMILARITY = 0.75f + private const val OUTLIER_THRESHOLD = 0.65f // RELAXED + private const val EXCELLENT_THRESHOLD = 0.85f // RELAXED + private const val GOOD_THRESHOLD = 0.75f // RELAXED } - /** - * Analyze cluster quality before training - */ fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult { - // Step 1: Filter to solo photos only - val soloFaces = cluster.faces.filter { it.faceCount == 1 } + Log.d(TAG, "========================================") + Log.d(TAG, "Analyzing cluster ${cluster.clusterId}") + Log.d(TAG, "Total faces: ${cluster.faces.size}") - // Step 2: Filter by face size (must be clear/close-up) + // Step 1: Filter to solo photos + val soloFaces = cluster.faces.filter { it.faceCount == 1 } + Log.d(TAG, "Solo photos: ${soloFaces.size}") + + // Step 2: Filter by face size val largeFaces = soloFaces.filter { face -> - isFaceLargeEnough(face.boundingBox, face.imageUri) + isFaceLargeEnough(face) + } + Log.d(TAG, "Large faces (>= 3%): ${largeFaces.size}") + + if (largeFaces.size < soloFaces.size) { + Log.d(TAG, "⚠️ Filtered out ${soloFaces.size - largeFaces.size} small faces") } // Step 3: Calculate internal consistency val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces) - // Step 4: Clean faces (large solo faces, no outliers) + // Step 4: Clean faces val cleanFaces = largeFaces.filter { it !in outliers } + Log.d(TAG, "Clean faces: ${cleanFaces.size}") // Step 5: Calculate quality score val qualityScore = calculateQualityScore( soloPhotoCount = soloFaces.size, largeFaceCount = largeFaces.size, cleanFaceCount = cleanFaces.size, - avgSimilarity = avgSimilarity + avgSimilarity = avgSimilarity, + totalFaces = cluster.faces.size ) + Log.d(TAG, "Quality score: ${(qualityScore * 100).toInt()}%") // Step 6: Determine quality tier val qualityTier = when { @@ -66,6 +72,11 @@ class ClusterQualityAnalyzer @Inject constructor() { qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD else -> ClusterQualityTier.POOR } + Log.d(TAG, "Quality tier: $qualityTier") + + val canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS + Log.d(TAG, "Can train: $canTrain") + Log.d(TAG, "========================================") return ClusterQualityResult( originalFaceCount = cluster.faces.size, @@ -77,62 +88,65 @@ class ClusterQualityAnalyzer @Inject constructor() { cleanFaces = cleanFaces, qualityScore = qualityScore, qualityTier = qualityTier, - canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS, - warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier) + canTrain = canTrain, + warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier, avgSimilarity) ) } - /** - * Check if face is large enough (not distant/blurry) - * - * A face should occupy at least 15% of the image area for good quality - */ - private fun isFaceLargeEnough(boundingBox: Rect, imageUri: String): Boolean { - // Estimate image dimensions from common aspect ratios - // For now, use bounding box size as proxy - val faceArea = boundingBox.width() * boundingBox.height() + private fun isFaceLargeEnough(face: DetectedFaceWithEmbedding): Boolean { + val faceArea = face.boundingBox.width() * face.boundingBox.height() - // Assume typical photo is ~2000x1500 = 3,000,000 pixels - // 15% = 450,000 pixels - // For a square face: sqrt(450,000) = ~670 pixels per side + // Check 1: Absolute minimum + if (face.boundingBox.width() < MIN_FACE_DIMENSION_PIXELS || + face.boundingBox.height() < MIN_FACE_DIMENSION_PIXELS) { + return false + } - // More conservative: face should be at least 200x200 pixels - return boundingBox.width() >= 200 && boundingBox.height() >= 200 + // Check 2: Relative size if we have dimensions + if (face.imageWidth > 0 && face.imageHeight > 0) { + val imageArea = face.imageWidth * face.imageHeight + val faceRatio = faceArea.toFloat() / imageArea.toFloat() + return faceRatio >= MIN_FACE_SIZE_RATIO + } + + // Fallback: Use absolute size + return face.boundingBox.width() >= FALLBACK_MIN_DIMENSION && + face.boundingBox.height() >= FALLBACK_MIN_DIMENSION } - /** - * Analyze how similar faces are to each other (internal consistency) - * - * Returns: (average similarity, list of outlier faces) - */ private fun analyzeInternalConsistency( faces: List ): Pair> { if (faces.size < 2) { + Log.d(TAG, "Less than 2 faces, skipping consistency check") return 0f to emptyList() } - // Calculate average embedding (centroid) + Log.d(TAG, "Analyzing ${faces.size} faces for internal consistency") + val centroid = calculateCentroid(faces.map { it.embedding }) - // Calculate similarity of each face to centroid - val similarities = faces.map { face -> - face to cosineSimilarity(face.embedding, centroid) + val centroidSum = centroid.sum() + Log.d(TAG, "Centroid sum: $centroidSum, first5=[${centroid.take(5).joinToString()}]") + + val similarities = faces.mapIndexed { index, face -> + val similarity = cosineSimilarity(face.embedding, centroid) + Log.d(TAG, "Face $index similarity to centroid: $similarity") + face to similarity } val avgSimilarity = similarities.map { it.second }.average().toFloat() + Log.d(TAG, "Average internal similarity: $avgSimilarity") - // Find outliers (faces significantly different from centroid) val outliers = similarities .filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD } .map { (face, _) -> face } + Log.d(TAG, "Found ${outliers.size} outliers (threshold=$OUTLIER_THRESHOLD)") + return avgSimilarity to outliers } - /** - * Calculate centroid (average embedding) - */ private fun calculateCentroid(embeddings: List): FloatArray { val size = embeddings.first().size val centroid = FloatArray(size) { 0f } @@ -148,14 +162,14 @@ class ClusterQualityAnalyzer @Inject constructor() { centroid[i] /= count } - // Normalize val norm = sqrt(centroid.map { it * it }.sum()) - return centroid.map { it / norm }.toFloatArray() + return if (norm > 0) { + centroid.map { it / norm }.toFloatArray() + } else { + centroid + } } - /** - * Cosine similarity between two embeddings - */ private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float { var dotProduct = 0f var normA = 0f @@ -170,32 +184,31 @@ class ClusterQualityAnalyzer @Inject constructor() { return dotProduct / (sqrt(normA) * sqrt(normB)) } - /** - * Calculate overall quality score (0.0 - 1.0) - */ private fun calculateQualityScore( soloPhotoCount: Int, largeFaceCount: Int, cleanFaceCount: Int, - avgSimilarity: Float + avgSimilarity: Float, + totalFaces: Int ): Float { - // Weight factors - val soloPhotoScore = (soloPhotoCount.toFloat() / 20f).coerceIn(0f, 1f) * 0.3f - val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.2f - val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.2f - val similarityScore = avgSimilarity * 0.3f + val soloRatio = soloPhotoCount.toFloat() / totalFaces.toFloat().coerceAtLeast(1f) + val soloPhotoScore = soloRatio.coerceIn(0f, 1f) * 0.25f + + val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.25f + + val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.20f + + val similarityScore = avgSimilarity * 0.30f return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore } - /** - * Generate human-readable warnings - */ private fun generateWarnings( soloPhotoCount: Int, largeFaceCount: Int, cleanFaceCount: Int, - qualityTier: ClusterQualityTier + qualityTier: ClusterQualityTier, + avgSimilarity: Float ): List { val warnings = mutableListOf() @@ -203,12 +216,20 @@ class ClusterQualityAnalyzer @Inject constructor() { ClusterQualityTier.POOR -> { warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!") warnings.add("Do NOT train on this cluster - it will create a bad model.") + + if (avgSimilarity < 0.70f) { + warnings.add("Low internal similarity (${(avgSimilarity * 100).toInt()}%) suggests mixed identities.") + } } ClusterQualityTier.GOOD -> { warnings.add("⚠️ Review outlier faces before training") + + if (cleanFaceCount < 10) { + warnings.add("Consider adding more high-quality photos for better results.") + } } ClusterQualityTier.EXCELLENT -> { - // No warnings - ready to train! + // No warnings } } @@ -218,38 +239,47 @@ class ClusterQualityAnalyzer @Inject constructor() { if (largeFaceCount < 6) { warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)") + warnings.add("Tip: Use close-up photos where the face is clearly visible") } if (cleanFaceCount < 6) { warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)") } + if (qualityTier == ClusterQualityTier.EXCELLENT) { + warnings.add("✅ Excellent quality! This cluster is ready for training.") + warnings.add("High-quality photos with consistent facial features detected.") + } + return warnings } } -/** - * Result of cluster quality analysis - */ data class ClusterQualityResult( - val originalFaceCount: Int, // Total faces in cluster - val soloPhotoCount: Int, // Photos with faceCount = 1 - val largeFaceCount: Int, // Solo photos with large faces - val cleanFaceCount: Int, // Large faces, no outliers - val avgInternalSimilarity: Float, // How similar faces are (0.0-1.0) - val outlierFaces: List, // Faces to exclude - val cleanFaces: List, // Good faces for training - val qualityScore: Float, // Overall score (0.0-1.0) + val originalFaceCount: Int, + val soloPhotoCount: Int, + val largeFaceCount: Int, + val cleanFaceCount: Int, + val avgInternalSimilarity: Float, + val outlierFaces: List, + val cleanFaces: List, + val qualityScore: Float, val qualityTier: ClusterQualityTier, - val canTrain: Boolean, // Safe to proceed with training? - val warnings: List // Human-readable issues -) + val canTrain: Boolean, + val warnings: List +) { + fun getSummary(): String = when (qualityTier) { + ClusterQualityTier.EXCELLENT -> + "Excellent quality cluster with $cleanFaceCount high-quality photos ready for training." + ClusterQualityTier.GOOD -> + "Good quality cluster with $cleanFaceCount usable photos. Review outliers before training." + ClusterQualityTier.POOR -> + "Poor quality cluster. May contain multiple people or low-quality photos. Add more photos or split cluster." + } +} -/** - * Quality tier classification - */ enum class ClusterQualityTier { - EXCELLENT, // 95%+ - Safe to train immediately - GOOD, // 85-94% - Review outliers first - POOR // <85% - DO NOT TRAIN (likely mixed people) + EXCELLENT, // 85%+ + GOOD, // 75-84% + POOR // <75% } \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt index 36a4aa2..1f1cafb 100644 --- a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt +++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt @@ -4,11 +4,13 @@ import android.content.Context import android.graphics.Bitmap import android.graphics.BitmapFactory import android.net.Uri +import android.util.Log import com.google.mlkit.vision.common.InputImage import com.google.mlkit.vision.face.FaceDetection import com.google.mlkit.vision.face.FaceDetectorOptions import com.placeholder.sherpai2.data.local.dao.FaceCacheDao import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity import com.placeholder.sherpai2.data.local.entity.ImageEntity import com.placeholder.sherpai2.ml.FaceNetModel import dagger.hilt.android.qualifiers.ApplicationContext @@ -18,7 +20,6 @@ import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.sync.Semaphore import kotlinx.coroutines.withContext -import java.util.concurrent.atomic.AtomicInteger import javax.inject.Inject import javax.inject.Singleton import kotlin.math.sqrt @@ -33,111 +34,152 @@ import kotlin.math.sqrt * 4. Detect faces and generate embeddings (parallel) * 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3) * 6. Analyze clusters for age, siblings, representatives + * + * IMPROVEMENTS: + * - ✅ Complete fast-path using FaceCacheDao.getSoloFacesWithEmbeddings() + * - ✅ Works with existing FaceCacheEntity.getEmbedding() method + * - ✅ Centroid-based representative face selection + * - ✅ Batched processing to prevent OOM + * - ✅ RGB_565 bitmap config for 50% memory savings */ @Singleton class FaceClusteringService @Inject constructor( @ApplicationContext private val context: Context, private val imageDao: ImageDao, - private val faceCacheDao: FaceCacheDao // Optional - will work without it + private val faceCacheDao: FaceCacheDao ) { - private val semaphore = Semaphore(12) + private val semaphore = Semaphore(8) + + companion object { + private const val TAG = "FaceClustering" + private const val MAX_FACES_TO_CLUSTER = 2000 + private const val MIN_SOLO_PHOTOS = 50 + private const val BATCH_SIZE = 50 + private const val MIN_CACHED_FACES = 100 + } /** * Main clustering entry point - HYBRID with automatic fallback - * - * @param maxFacesToCluster Limit for performance (default 2000) - * @param onProgress Progress callback (current, total, message) */ suspend fun discoverPeople( - maxFacesToCluster: Int = 2000, + maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER, onProgress: (Int, Int, String) -> Unit = { _, _, _ -> } ): ClusteringResult = withContext(Dispatchers.Default) { - // TRY FAST PATH: Use face cache if available - val highQualityFaces = try { - withContext(Dispatchers.IO) { - faceCacheDao.getHighQualitySoloFaces() - } - } catch (e: Exception) { - emptyList() - } - - if (highQualityFaces.isNotEmpty()) { - // FAST PATH: Use cached faces (future enhancement) - onProgress(0, 100, "Using face cache (${highQualityFaces.size} faces)...") - // TODO: Implement cache-based clustering - // For now, fall through to classic method - } - - // CLASSIC METHOD: Load and process photos - onProgress(0, 100, "Loading solo photos...") - - // Step 1: Get SOLO PHOTOS ONLY (faceCount = 1) for cleaner clustering - val soloPhotos = withContext(Dispatchers.IO) { - imageDao.getImagesByFaceCount(count = 1) - } - - // Fallback: If not enough solo photos, use all images with faces - val imagesWithFaces = if (soloPhotos.size < 50) { - onProgress(0, 100, "Loading all photos with faces...") - imageDao.getImagesWithFaces() - } else { - soloPhotos - } - - if (imagesWithFaces.isEmpty()) { - return@withContext ClusteringResult( - clusters = emptyList(), - totalFacesAnalyzed = 0, - processingTimeMs = 0, - errorMessage = "No photos with faces found. Please ensure face detection cache is populated." - ) - } - - onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos (${if (soloPhotos.size >= 50) "solo only" else "all"})...") - val startTime = System.currentTimeMillis() - // Step 2: Detect faces and generate embeddings (parallel) - val allFaces = detectFacesInImages( - images = imagesWithFaces.take(1000), // Smart limit - onProgress = { current, total -> - onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total") + // Try high-quality cached faces FIRST (NEW!) + var cachedFaces = withContext(Dispatchers.IO) { + try { + faceCacheDao.getHighQualitySoloFaces( + minFaceRatio = 0.015f, // 1.5% + limit = maxFacesToCluster + ) + } catch (e: Exception) { + // Method doesn't exist yet - that's ok + emptyList() } - ) + } + + // Fallback to ANY solo faces if high-quality returned nothing + if (cachedFaces.isEmpty()) { + Log.w(TAG, "No high-quality faces (>= 1.5%), trying ANY solo faces...") + cachedFaces = withContext(Dispatchers.IO) { + try { + faceCacheDao.getSoloFacesWithEmbeddings(limit = maxFacesToCluster) + } catch (e: Exception) { + emptyList() + } + } + } + + Log.d(TAG, "Cache check: ${cachedFaces.size} faces available") + + val allFaces = if (cachedFaces.size >= MIN_CACHED_FACES) { + // FAST PATH ✅ + Log.d(TAG, "Using FAST PATH with ${cachedFaces.size} cached faces") + onProgress(10, 100, "Using cached embeddings (${cachedFaces.size} faces)...") + + cachedFaces.mapNotNull { cached -> + val embedding = cached.getEmbedding() ?: return@mapNotNull null + + DetectedFaceWithEmbedding( + imageId = cached.imageId, + imageUri = "", + capturedAt = 0L, + embedding = embedding, + boundingBox = cached.getBoundingBox(), + confidence = cached.confidence, + faceCount = 1, // Solo faces only (filtered by query) + imageWidth = cached.imageWidth, + imageHeight = cached.imageHeight + ) + }.also { + onProgress(50, 100, "Processing ${it.size} cached faces...") + } + } else { + // SLOW PATH + Log.d(TAG, "Using SLOW PATH - cache has < $MIN_CACHED_FACES faces") + onProgress(0, 100, "Loading photos...") + + val soloPhotos = withContext(Dispatchers.IO) { + imageDao.getImagesByFaceCount(count = 1) + } + + val imagesWithFaces = if (soloPhotos.size < MIN_SOLO_PHOTOS) { + imageDao.getImagesWithFaces() + } else { + soloPhotos + } + + if (imagesWithFaces.isEmpty()) { + return@withContext ClusteringResult( + clusters = emptyList(), + totalFacesAnalyzed = 0, + processingTimeMs = 0, + errorMessage = "No photos with faces found" + ) + } + + onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...") + + detectFacesInImagesBatched( + images = imagesWithFaces.take(1000), + onProgress = { current, total -> + onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total") + } + ) + } if (allFaces.isEmpty()) { return@withContext ClusteringResult( clusters = emptyList(), totalFacesAnalyzed = 0, processingTimeMs = System.currentTimeMillis() - startTime, - errorMessage = "No faces detected in images" + errorMessage = "No faces detected" ) } onProgress(50, 100, "Clustering ${allFaces.size} faces...") - // Step 3: DBSCAN clustering val rawClusters = performDBSCAN( faces = allFaces.take(maxFacesToCluster), - epsilon = 0.18f, // VERY STRICT for siblings + epsilon = 0.26f, minPoints = 3 ) onProgress(70, 100, "Analyzing relationships...") - // Step 4: Build co-occurrence graph val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters) onProgress(80, 100, "Selecting representative faces...") - // Step 5: Create final clusters val clusters = rawClusters.map { cluster -> FaceCluster( clusterId = cluster.clusterId, faces = cluster.faces, - representativeFaces = selectRepresentativeFaces(cluster.faces, count = 6), + representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6), photoCount = cluster.faces.map { it.imageId }.distinct().size, averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(), estimatedAge = estimateAge(cluster.faces), @@ -154,14 +196,31 @@ class FaceClusteringService @Inject constructor( ) } - /** - * Detect faces in images and generate embeddings (parallel) - */ - private suspend fun detectFacesInImages( + private suspend fun detectFacesInImagesBatched( images: List, onProgress: (Int, Int) -> Unit ): List = coroutineScope { + val allFaces = mutableListOf() + var processedCount = 0 + + images.chunked(BATCH_SIZE).forEach { batch -> + val batchFaces = detectFacesInBatch(batch) + allFaces.addAll(batchFaces) + + processedCount += batch.size + onProgress(processedCount, images.size) + + System.gc() + } + + allFaces + } + + private suspend fun detectFacesInBatch( + images: List + ): List = coroutineScope { + val detector = FaceDetection.getClient( FaceDetectorOptions.Builder() .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) @@ -170,20 +229,14 @@ class FaceClusteringService @Inject constructor( ) val faceNetModel = FaceNetModel(context) - val allFaces = mutableListOf() - val processedCount = AtomicInteger(0) + val batchFaces = mutableListOf() try { val jobs = images.map { image -> - async { + async(Dispatchers.IO) { semaphore.acquire() try { - val faces = detectFacesInImage(image, detector, faceNetModel) - val current = processedCount.incrementAndGet() - if (current % 10 == 0) { - onProgress(current, images.size) - } - faces + detectFacesInImage(image, detector, faceNetModel) } finally { semaphore.release() } @@ -191,7 +244,7 @@ class FaceClusteringService @Inject constructor( } jobs.awaitAll().flatten().also { - allFaces.addAll(it) + batchFaces.addAll(it) } } finally { @@ -199,7 +252,7 @@ class FaceClusteringService @Inject constructor( faceNetModel.close() } - allFaces + batchFaces } private suspend fun detectFacesInImage( @@ -215,8 +268,6 @@ class FaceClusteringService @Inject constructor( val mlImage = InputImage.fromBitmap(bitmap, 0) val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage)) - val totalFacesInImage = faces.size - val result = faces.mapNotNull { face -> try { val faceBitmap = Bitmap.createBitmap( @@ -237,7 +288,9 @@ class FaceClusteringService @Inject constructor( embedding = embedding, boundingBox = face.boundingBox, confidence = 0.95f, - faceCount = totalFacesInImage + faceCount = faces.size, + imageWidth = bitmap.width, + imageHeight = bitmap.height ) } catch (e: Exception) { null @@ -252,9 +305,6 @@ class FaceClusteringService @Inject constructor( } } - // All other methods remain the same (DBSCAN, similarity, etc.) - // ... [Rest of the implementation from original file] - private fun performDBSCAN( faces: List, epsilon: Float, @@ -368,18 +418,61 @@ class FaceClusteringService @Inject constructor( ?: emptyList() } - private fun selectRepresentativeFaces( + private fun selectRepresentativeFacesByCentroid( faces: List, count: Int ): List { if (faces.size <= count) return faces - val sortedByTime = faces.sortedBy { it.capturedAt } - val step = faces.size / count + val centroid = calculateCentroid(faces.map { it.embedding }) - return (0 until count).map { i -> - sortedByTime[i * step] + val facesWithDistance = faces.map { face -> + val distance = 1 - cosineSimilarity(face.embedding, centroid) + face to distance } + + val sortedByProximity = facesWithDistance.sortedBy { it.second } + + val representatives = mutableListOf() + representatives.add(sortedByProximity.first().first) + + val remainingFaces = sortedByProximity.drop(1).take(count * 3) + val sortedByTime = remainingFaces.map { it.first }.sortedBy { it.capturedAt } + + if (sortedByTime.isNotEmpty()) { + val step = sortedByTime.size / (count - 1).coerceAtLeast(1) + for (i in 0 until (count - 1)) { + val index = (i * step).coerceAtMost(sortedByTime.size - 1) + representatives.add(sortedByTime[index]) + } + } + + return representatives.take(count) + } + + private fun calculateCentroid(embeddings: List): FloatArray { + if (embeddings.isEmpty()) return FloatArray(0) + + val size = embeddings.first().size + val centroid = FloatArray(size) { 0f } + + embeddings.forEach { embedding -> + for (i in embedding.indices) { + centroid[i] += embedding[i] + } + } + + val count = embeddings.size.toFloat() + for (i in centroid.indices) { + centroid[i] /= count + } + + val norm = sqrt(centroid.map { it * it }.sum()) + if (norm > 0) { + return centroid.map { it / norm }.toFloatArray() + } + + return centroid } private fun estimateAge(faces: List): AgeEstimate { @@ -416,7 +509,6 @@ class FaceClusteringService @Inject constructor( } } -// Data classes data class DetectedFaceWithEmbedding( val imageId: String, val imageUri: String, @@ -424,7 +516,9 @@ data class DetectedFaceWithEmbedding( val embedding: FloatArray, val boundingBox: android.graphics.Rect, val confidence: Float, - val faceCount: Int = 1 + val faceCount: Int = 1, + val imageWidth: Int = 0, + val imageHeight: Int = 0 ) { override fun equals(other: Any?): Boolean { if (this === other) return true diff --git a/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt b/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt index 22daad1..a15197d 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt @@ -2,6 +2,7 @@ package com.placeholder.sherpai2.ml import android.content.Context import android.graphics.Bitmap +import android.util.Log import org.tensorflow.lite.Interpreter import java.io.FileInputStream import java.nio.ByteBuffer @@ -11,16 +12,21 @@ import java.nio.channels.FileChannel import kotlin.math.sqrt /** - * FaceNetModel - MobileFaceNet wrapper for face recognition + * FaceNetModel - MobileFaceNet wrapper with debugging * - * CLEAN IMPLEMENTATION: - * - All IDs are Strings (matching your schema) - * - Generates 192-dimensional embeddings - * - Cosine similarity for matching + * IMPROVEMENTS: + * - ✅ Detailed error logging + * - ✅ Model validation on init + * - ✅ Embedding validation (detect all-zeros) + * - ✅ Toggle-able debug mode */ -class FaceNetModel(private val context: Context) { +class FaceNetModel( + private val context: Context, + private val debugMode: Boolean = true // Enable for troubleshooting +) { companion object { + private const val TAG = "FaceNetModel" private const val MODEL_FILE = "mobilefacenet.tflite" private const val INPUT_SIZE = 112 private const val EMBEDDING_SIZE = 192 @@ -31,13 +37,56 @@ class FaceNetModel(private val context: Context) { } private var interpreter: Interpreter? = null + private var modelLoadSuccess = false init { try { + if (debugMode) Log.d(TAG, "Loading FaceNet model: $MODEL_FILE") + val model = loadModelFile() interpreter = Interpreter(model) + modelLoadSuccess = true + + if (debugMode) { + Log.d(TAG, "✅ FaceNet model loaded successfully") + Log.d(TAG, "Model input size: ${INPUT_SIZE}x$INPUT_SIZE") + Log.d(TAG, "Embedding size: $EMBEDDING_SIZE") + } + + // Test model with dummy input + testModel() + } catch (e: Exception) { - throw RuntimeException("Failed to load FaceNet model", e) + Log.e(TAG, "❌ CRITICAL: Failed to load FaceNet model from assets/$MODEL_FILE", e) + Log.e(TAG, "Make sure mobilefacenet.tflite exists in app/src/main/assets/") + modelLoadSuccess = false + throw RuntimeException("Failed to load FaceNet model: ${e.message}", e) + } + } + + /** + * Test model with dummy input to verify it works + */ + private fun testModel() { + try { + val testBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888) + val testEmbedding = generateEmbedding(testBitmap) + testBitmap.recycle() + + val sum = testEmbedding.sum() + val norm = sqrt(testEmbedding.map { it * it }.sum()) + + if (debugMode) { + Log.d(TAG, "Model test: embedding sum=$sum, norm=$norm") + } + + if (sum == 0f || norm == 0f) { + Log.e(TAG, "⚠️ WARNING: Model test produced zero embedding!") + } else { + if (debugMode) Log.d(TAG, "✅ Model test passed") + } + } catch (e: Exception) { + Log.e(TAG, "Model test failed", e) } } @@ -45,12 +94,22 @@ class FaceNetModel(private val context: Context) { * Load TFLite model from assets */ private fun loadModelFile(): MappedByteBuffer { - val fileDescriptor = context.assets.openFd(MODEL_FILE) - val inputStream = FileInputStream(fileDescriptor.fileDescriptor) - val fileChannel = inputStream.channel - val startOffset = fileDescriptor.startOffset - val declaredLength = fileDescriptor.declaredLength - return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) + try { + val fileDescriptor = context.assets.openFd(MODEL_FILE) + val inputStream = FileInputStream(fileDescriptor.fileDescriptor) + val fileChannel = inputStream.channel + val startOffset = fileDescriptor.startOffset + val declaredLength = fileDescriptor.declaredLength + + if (debugMode) { + Log.d(TAG, "Model file size: ${declaredLength / 1024}KB") + } + + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) + } catch (e: Exception) { + Log.e(TAG, "Failed to open model file: $MODEL_FILE", e) + throw e + } } /** @@ -60,13 +119,39 @@ class FaceNetModel(private val context: Context) { * @return 192-dimensional embedding */ fun generateEmbedding(faceBitmap: Bitmap): FloatArray { - val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true) - val inputBuffer = preprocessImage(resized) - val output = Array(1) { FloatArray(EMBEDDING_SIZE) } + if (!modelLoadSuccess || interpreter == null) { + Log.e(TAG, "❌ Cannot generate embedding: model not loaded!") + return FloatArray(EMBEDDING_SIZE) { 0f } + } - interpreter?.run(inputBuffer, output) + try { + val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true) + val inputBuffer = preprocessImage(resized) + val output = Array(1) { FloatArray(EMBEDDING_SIZE) } - return normalizeEmbedding(output[0]) + interpreter?.run(inputBuffer, output) + + val normalized = normalizeEmbedding(output[0]) + + // DIAGNOSTIC: Check embedding quality + if (debugMode) { + val sum = normalized.sum() + val norm = sqrt(normalized.map { it * it }.sum()) + + if (sum == 0f && norm == 0f) { + Log.e(TAG, "❌ CRITICAL: Generated all-zero embedding!") + Log.e(TAG, "Input bitmap: ${faceBitmap.width}x${faceBitmap.height}") + } else { + Log.d(TAG, "✅ Embedding: sum=${"%.2f".format(sum)}, norm=${"%.2f".format(norm)}, first5=[${normalized.take(5).joinToString { "%.3f".format(it) }}]") + } + } + + return normalized + + } catch (e: Exception) { + Log.e(TAG, "Failed to generate embedding", e) + return FloatArray(EMBEDDING_SIZE) { 0f } + } } /** @@ -76,6 +161,10 @@ class FaceNetModel(private val context: Context) { faceBitmaps: List, onProgress: (Int, Int) -> Unit = { _, _ -> } ): List { + if (debugMode) { + Log.d(TAG, "Generating embeddings for ${faceBitmaps.size} faces") + } + return faceBitmaps.mapIndexed { index, bitmap -> onProgress(index + 1, faceBitmaps.size) generateEmbedding(bitmap) @@ -88,6 +177,10 @@ class FaceNetModel(private val context: Context) { fun createPersonModel(embeddings: List): FloatArray { require(embeddings.isNotEmpty()) { "Need at least one embedding" } + if (debugMode) { + Log.d(TAG, "Creating person model from ${embeddings.size} embeddings") + } + val averaged = FloatArray(EMBEDDING_SIZE) { 0f } embeddings.forEach { embedding -> @@ -101,7 +194,14 @@ class FaceNetModel(private val context: Context) { averaged[i] /= count } - return normalizeEmbedding(averaged) + val normalized = normalizeEmbedding(averaged) + + if (debugMode) { + val sum = normalized.sum() + Log.d(TAG, "Person model created: sum=${"%.2f".format(sum)}") + } + + return normalized } /** @@ -110,7 +210,7 @@ class FaceNetModel(private val context: Context) { */ fun calculateSimilarity(embedding1: FloatArray, embedding2: FloatArray): Float { require(embedding1.size == EMBEDDING_SIZE && embedding2.size == EMBEDDING_SIZE) { - "Invalid embedding size" + "Invalid embedding size: ${embedding1.size} vs ${embedding2.size}" } var dotProduct = 0f @@ -123,7 +223,14 @@ class FaceNetModel(private val context: Context) { norm2 += embedding2[i] * embedding2[i] } - return dotProduct / (sqrt(norm1) * sqrt(norm2)) + val similarity = dotProduct / (sqrt(norm1) * sqrt(norm2)) + + if (debugMode && (similarity.isNaN() || similarity.isInfinite())) { + Log.e(TAG, "❌ Invalid similarity: $similarity (norm1=$norm1, norm2=$norm2)") + return 0f + } + + return similarity } /** @@ -151,6 +258,10 @@ class FaceNetModel(private val context: Context) { } } + if (debugMode && bestMatch != null) { + Log.d(TAG, "Best match: ${bestMatch.first} with similarity ${bestMatch.second}") + } + return bestMatch } @@ -169,6 +280,7 @@ class FaceNetModel(private val context: Context) { val g = ((pixel shr 8) and 0xFF) / 255.0f val b = (pixel and 0xFF) / 255.0f + // Normalize to [-1, 1] buffer.putFloat((r - 0.5f) / 0.5f) buffer.putFloat((g - 0.5f) / 0.5f) buffer.putFloat((b - 0.5f) / 0.5f) @@ -190,14 +302,29 @@ class FaceNetModel(private val context: Context) { return if (norm > 0) { FloatArray(embedding.size) { i -> embedding[i] / norm } } else { + Log.w(TAG, "⚠️ Cannot normalize zero embedding") embedding } } + /** + * Get model status for diagnostics + */ + fun getModelStatus(): String { + return if (modelLoadSuccess) { + "✅ Model loaded and operational" + } else { + "❌ Model failed to load - check assets/$MODEL_FILE" + } + } + /** * Clean up resources */ fun close() { + if (debugMode) { + Log.d(TAG, "Closing FaceNet model") + } interpreter?.close() interpreter = null } diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt index c502417..965621d 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt @@ -8,9 +8,13 @@ import androidx.compose.foundation.layout.* import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.LazyVerticalGrid import androidx.compose.foundation.lazy.grid.items +import androidx.compose.foundation.shape.CircleShape import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Check +import androidx.compose.material.icons.filled.Warning import androidx.compose.material3.* -import androidx.compose.runtime.Composable +import androidx.compose.runtime.* import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.draw.clip @@ -19,6 +23,8 @@ import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.unit.dp import coil.compose.AsyncImage +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer +import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier import com.placeholder.sherpai2.domain.clustering.ClusteringResult import com.placeholder.sherpai2.domain.clustering.FaceCluster @@ -28,13 +34,20 @@ import com.placeholder.sherpai2.domain.clustering.FaceCluster * Each cluster card shows: * - 2x2 grid of representative faces * - Photo count + * - Quality badge (Excellent/Good/Poor) * - Tap to name + * + * IMPROVEMENTS: + * - ✅ Quality badges for each cluster + * - ✅ Visual indicators for trainable vs non-trainable clusters + * - ✅ Better UX with disabled states for poor quality clusters */ @Composable fun ClusterGridScreen( result: ClusteringResult, onSelectCluster: (FaceCluster) -> Unit, - modifier: Modifier = Modifier + modifier: Modifier = Modifier, + qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() } ) { Column( modifier = modifier @@ -65,8 +78,15 @@ fun ClusterGridScreen( verticalArrangement = Arrangement.spacedBy(12.dp) ) { items(result.clusters) { cluster -> + // Analyze quality for each cluster + val qualityResult = remember(cluster) { + qualityAnalyzer.analyzeCluster(cluster) + } + ClusterCard( cluster = cluster, + qualityTier = qualityResult.qualityTier, + canTrain = qualityResult.canTrain, onClick = { onSelectCluster(cluster) } ) } @@ -75,77 +95,168 @@ fun ClusterGridScreen( } /** - * Single cluster card with 2x2 face grid + * Single cluster card with 2x2 face grid and quality badge */ @Composable private fun ClusterCard( cluster: FaceCluster, + qualityTier: ClusterQualityTier, + canTrain: Boolean, onClick: () -> Unit ) { Card( modifier = Modifier .fillMaxWidth() .aspectRatio(1f) - .clickable(onClick = onClick), - elevation = CardDefaults.cardElevation(defaultElevation = 2.dp) + .clickable(onClick = onClick), // Always clickable - let dialog handle validation + elevation = CardDefaults.cardElevation(defaultElevation = 2.dp), + colors = CardDefaults.cardColors( + containerColor = when { + qualityTier == ClusterQualityTier.POOR -> + MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f) + !canTrain -> + MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + else -> + MaterialTheme.colorScheme.surface + } + ) ) { - Column( + Box( modifier = Modifier.fillMaxSize() ) { - // 2x2 grid of faces - val facesToShow = cluster.representativeFaces.take(4) - Column( - modifier = Modifier.weight(1f) + modifier = Modifier.fillMaxSize() ) { - // Top row (2 faces) - Row(modifier = Modifier.weight(1f)) { - facesToShow.getOrNull(0)?.let { face -> - FaceThumbnail( - imageUri = face.imageUri, - modifier = Modifier.weight(1f) - ) - } ?: EmptyFaceSlot(Modifier.weight(1f)) + // 2x2 grid of faces + val facesToShow = cluster.representativeFaces.take(4) - facesToShow.getOrNull(1)?.let { face -> - FaceThumbnail( - imageUri = face.imageUri, - modifier = Modifier.weight(1f) - ) - } ?: EmptyFaceSlot(Modifier.weight(1f)) + Column( + modifier = Modifier.weight(1f) + ) { + // Top row (2 faces) + Row(modifier = Modifier.weight(1f)) { + facesToShow.getOrNull(0)?.let { face -> + FaceThumbnail( + imageUri = face.imageUri, + enabled = canTrain, + modifier = Modifier.weight(1f) + ) + } ?: EmptyFaceSlot(Modifier.weight(1f)) + + facesToShow.getOrNull(1)?.let { face -> + FaceThumbnail( + imageUri = face.imageUri, + enabled = canTrain, + modifier = Modifier.weight(1f) + ) + } ?: EmptyFaceSlot(Modifier.weight(1f)) + } + + // Bottom row (2 faces) + Row(modifier = Modifier.weight(1f)) { + facesToShow.getOrNull(2)?.let { face -> + FaceThumbnail( + imageUri = face.imageUri, + enabled = canTrain, + modifier = Modifier.weight(1f) + ) + } ?: EmptyFaceSlot(Modifier.weight(1f)) + + facesToShow.getOrNull(3)?.let { face -> + FaceThumbnail( + imageUri = face.imageUri, + enabled = canTrain, + modifier = Modifier.weight(1f) + ) + } ?: EmptyFaceSlot(Modifier.weight(1f)) + } } - // Bottom row (2 faces) - Row(modifier = Modifier.weight(1f)) { - facesToShow.getOrNull(2)?.let { face -> - FaceThumbnail( - imageUri = face.imageUri, - modifier = Modifier.weight(1f) + // Footer with photo count + Surface( + modifier = Modifier.fillMaxWidth(), + color = if (canTrain) { + MaterialTheme.colorScheme.primaryContainer + } else { + MaterialTheme.colorScheme.surfaceVariant + } + ) { + Row( + modifier = Modifier.padding(12.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween + ) { + Text( + text = "${cluster.photoCount} photos", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.SemiBold, + color = if (canTrain) { + MaterialTheme.colorScheme.onPrimaryContainer + } else { + MaterialTheme.colorScheme.onSurfaceVariant + } ) - } ?: EmptyFaceSlot(Modifier.weight(1f)) - - facesToShow.getOrNull(3)?.let { face -> - FaceThumbnail( - imageUri = face.imageUri, - modifier = Modifier.weight(1f) - ) - } ?: EmptyFaceSlot(Modifier.weight(1f)) + } } } - // Footer with photo count - Surface( - modifier = Modifier.fillMaxWidth(), - color = MaterialTheme.colorScheme.primaryContainer - ) { - Text( - text = "${cluster.photoCount} photos", - style = MaterialTheme.typography.bodyMedium, - fontWeight = FontWeight.SemiBold, - modifier = Modifier.padding(12.dp), - color = MaterialTheme.colorScheme.onPrimaryContainer - ) - } + // Quality badge overlay + QualityBadge( + qualityTier = qualityTier, + canTrain = canTrain, + modifier = Modifier + .align(Alignment.TopEnd) + .padding(8.dp) + ) + } + } +} + +/** + * Quality badge indicator + */ +@Composable +private fun QualityBadge( + qualityTier: ClusterQualityTier, + canTrain: Boolean, + modifier: Modifier = Modifier +) { + val (backgroundColor, iconColor, icon) = when (qualityTier) { + ClusterQualityTier.EXCELLENT -> Triple( + Color(0xFF1B5E20), + Color.White, + Icons.Default.Check + ) + ClusterQualityTier.GOOD -> Triple( + Color(0xFF2E7D32), + Color.White, + Icons.Default.Check + ) + ClusterQualityTier.POOR -> Triple( + Color(0xFFD32F2F), + Color.White, + Icons.Default.Warning + ) + } + + Surface( + modifier = modifier, + shape = CircleShape, + color = backgroundColor, + shadowElevation = 2.dp + ) { + Box( + modifier = Modifier + .size(32.dp) + .padding(6.dp), + contentAlignment = Alignment.Center + ) { + Icon( + imageVector = icon, + contentDescription = qualityTier.name, + tint = iconColor, + modifier = Modifier.size(20.dp) + ) } } } @@ -153,19 +264,23 @@ private fun ClusterCard( @Composable private fun FaceThumbnail( imageUri: String, + enabled: Boolean, modifier: Modifier = Modifier ) { - AsyncImage( - model = Uri.parse(imageUri), - contentDescription = "Face", - modifier = modifier - .fillMaxSize() - .border( - width = 0.5.dp, - color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f) - ), - contentScale = ContentScale.Crop - ) + Box(modifier = modifier) { + AsyncImage( + model = Uri.parse(imageUri), + contentDescription = "Face", + modifier = Modifier + .fillMaxSize() + .border( + width = 0.5.dp, + color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f) + ), + contentScale = ContentScale.Crop, + alpha = if (enabled) 1f else 0.6f + ) + } } @Composable diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt index 4c722da..64a2630 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt @@ -11,11 +11,17 @@ import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.unit.dp import androidx.hilt.navigation.compose.hiltViewModel +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer /** - * DiscoverPeopleScreen - COMPLETE WORKING VERSION + * DiscoverPeopleScreen - COMPLETE WORKING VERSION WITH NAMING DIALOG * - * This handles ALL states properly including Idle state + * This handles ALL states properly including the NamingCluster dialog + * + * IMPROVEMENTS: + * - ✅ Complete naming dialog integration + * - ✅ Quality analysis in cluster grid + * - ✅ Better error handling */ @OptIn(ExperimentalMaterial3Api::class) @Composable @@ -24,125 +30,130 @@ fun DiscoverPeopleScreen( onNavigateBack: () -> Unit = {} ) { val uiState by viewModel.uiState.collectAsState() + val qualityAnalyzer = remember { ClusterQualityAnalyzer() } - Scaffold( - topBar = { - TopAppBar( - title = { Text("Discover People") }, - navigationIcon = { - IconButton(onClick = onNavigateBack) { - Icon( - imageVector = Icons.Default.Person, - contentDescription = "Back" + // No Scaffold, no TopAppBar - MainScreen handles that + Box( + modifier = Modifier.fillMaxSize() + ) { + when (val state = uiState) { + // ===== IDLE STATE (START HERE) ===== + is DiscoverUiState.Idle -> { + IdleStateContent( + onStartDiscovery = { viewModel.startDiscovery() } + ) + } + + // ===== CLUSTERING IN PROGRESS ===== + is DiscoverUiState.Clustering -> { + ClusteringProgressContent( + progress = state.progress, + total = state.total, + message = state.message + ) + } + + // ===== CLUSTERS READY FOR NAMING ===== + is DiscoverUiState.NamingReady -> { + ClusterGridScreen( + result = state.result, + onSelectCluster = { cluster -> + viewModel.selectCluster(cluster) + }, + qualityAnalyzer = qualityAnalyzer + ) + } + + // ===== ANALYZING CLUSTER QUALITY ===== + is DiscoverUiState.AnalyzingCluster -> { + LoadingContent(message = "Analyzing cluster quality...") + } + + // ===== NAMING A CLUSTER (SHOW DIALOG) ===== + is DiscoverUiState.NamingCluster -> { + // Show cluster grid in background + ClusterGridScreen( + result = state.result, + onSelectCluster = { /* Disabled while dialog open */ }, + qualityAnalyzer = qualityAnalyzer + ) + + // Show naming dialog overlay + NamingDialog( + cluster = state.selectedCluster, + suggestedSiblings = state.suggestedSiblings, + onConfirm = { name, dateOfBirth, isChild, selectedSiblings -> + viewModel.confirmClusterName( + cluster = state.selectedCluster, + name = name, + dateOfBirth = dateOfBirth, + isChild = isChild, + selectedSiblings = selectedSiblings ) + }, + onDismiss = { + viewModel.cancelNaming() + }, + qualityAnalyzer = qualityAnalyzer + ) + } + + // ===== TRAINING IN PROGRESS ===== + is DiscoverUiState.Training -> { + TrainingProgressContent( + stage = state.stage, + progress = state.progress, + total = state.total + ) + } + + // ===== VALIDATION PREVIEW ===== + is DiscoverUiState.ValidationPreview -> { + ValidationPreviewScreen( + personName = state.personName, + validationResult = state.validationResult, + onApprove = { + viewModel.approveValidationAndScan( + personId = state.personId, + personName = state.personName + ) + }, + onReject = { + viewModel.rejectValidationAndImprove() } - } - ) - } - ) { paddingValues -> - Box( - modifier = Modifier - .fillMaxSize() - .padding(paddingValues) - ) { - when (val state = uiState) { - // ===== IDLE STATE (START HERE) ===== - is DiscoverUiState.Idle -> { - IdleStateContent( - onStartDiscovery = { viewModel.startDiscovery() } - ) - } + ) + } - // ===== CLUSTERING IN PROGRESS ===== - is DiscoverUiState.Clustering -> { - ClusteringProgressContent( - progress = state.progress, - total = state.total, - message = state.message - ) - } + // ===== COMPLETE ===== + is DiscoverUiState.Complete -> { + CompleteStateContent( + message = state.message, + onDone = onNavigateBack + ) + } - // ===== CLUSTERS READY FOR NAMING ===== - is DiscoverUiState.NamingReady -> { - ClusterGridScreen( - result = state.result, - onSelectCluster = { cluster -> - viewModel.selectCluster(cluster) - } - ) - } + // ===== NO PEOPLE FOUND ===== + is DiscoverUiState.NoPeopleFound -> { + ErrorStateContent( + title = "No People Found", + message = state.message, + onRetry = { viewModel.startDiscovery() }, + onBack = onNavigateBack + ) + } - // ===== ANALYZING CLUSTER QUALITY ===== - is DiscoverUiState.AnalyzingCluster -> { - LoadingContent(message = "Analyzing cluster quality...") - } - - // ===== NAMING A CLUSTER ===== - is DiscoverUiState.NamingCluster -> { - Text( - text = "Naming dialog for cluster ${state.selectedCluster.clusterId}\n\nDialog UI coming...", - modifier = Modifier.align(Alignment.Center) - ) - } - - // ===== TRAINING IN PROGRESS ===== - is DiscoverUiState.Training -> { - TrainingProgressContent( - stage = state.stage, - progress = state.progress, - total = state.total - ) - } - - // ===== VALIDATION PREVIEW ===== - is DiscoverUiState.ValidationPreview -> { - ValidationPreviewScreen( - personName = state.personName, - validationResult = state.validationResult, - onApprove = { - viewModel.approveValidationAndScan( - personId = state.personId, - personName = state.personName - ) - }, - onReject = { - viewModel.rejectValidationAndImprove() - } - ) - } - - // ===== COMPLETE ===== - is DiscoverUiState.Complete -> { - CompleteStateContent( - message = state.message, - onDone = onNavigateBack - ) - } - - // ===== NO PEOPLE FOUND ===== - is DiscoverUiState.NoPeopleFound -> { - ErrorStateContent( - title = "No People Found", - message = state.message, - onRetry = { viewModel.startDiscovery() }, - onBack = onNavigateBack - ) - } - - // ===== ERROR ===== - is DiscoverUiState.Error -> { - ErrorStateContent( - title = "Error", - message = state.message, - onRetry = { viewModel.reset(); viewModel.startDiscovery() }, - onBack = onNavigateBack - ) - } + // ===== ERROR ===== + is DiscoverUiState.Error -> { + ErrorStateContent( + title = "Error", + message = state.message, + onRetry = { viewModel.reset(); viewModel.startDiscovery() }, + onBack = onNavigateBack + ) } } } } - // ===== IDLE STATE CONTENT ===== @Composable @@ -165,19 +176,11 @@ private fun IdleStateContent( Spacer(modifier = Modifier.height(32.dp)) - Text( - text = "Discover People", - style = MaterialTheme.typography.headlineLarge, - fontWeight = FontWeight.Bold - ) - - Spacer(modifier = Modifier.height(16.dp)) - Text( text = "Automatically find and organize people in your photo library", - style = MaterialTheme.typography.bodyLarge, + style = MaterialTheme.typography.headlineSmall, textAlign = TextAlign.Center, - color = MaterialTheme.colorScheme.onSurfaceVariant + color = MaterialTheme.colorScheme.onSurface ) Spacer(modifier = Modifier.height(48.dp)) diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt new file mode 100644 index 0000000..e024715 --- /dev/null +++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt @@ -0,0 +1,480 @@ +package com.placeholder.sherpai2.ui.discover + +import androidx.compose.foundation.background +import androidx.compose.foundation.border +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyRow +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.KeyboardActions +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.foundation.verticalScroll +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.layout.ContentScale +import androidx.compose.ui.platform.LocalSoftwareKeyboardController +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.input.ImeAction +import androidx.compose.ui.text.input.KeyboardCapitalization +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.style.TextAlign +import androidx.compose.ui.unit.dp +import androidx.compose.ui.window.Dialog +import coil.compose.AsyncImage +import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer +import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier +import com.placeholder.sherpai2.domain.clustering.FaceCluster +import java.text.SimpleDateFormat +import java.util.* + +/** + * NamingDialog - Complete dialog for naming a cluster + * + * Features: + * - Name input with validation + * - Child toggle with date of birth picker + * - Sibling cluster selection + * - Quality warnings display + * - Preview of representative faces + * + * IMPROVEMENTS: + * - ✅ Complete UI implementation + * - ✅ Quality analysis integration + * - ✅ Sibling selection + * - ✅ Form validation + */ +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun NamingDialog( + cluster: FaceCluster, + suggestedSiblings: List, + onConfirm: (name: String, dateOfBirth: Long?, isChild: Boolean, selectedSiblings: List) -> Unit, + onDismiss: () -> Unit, + qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() } +) { + var name by remember { mutableStateOf("") } + var isChild by remember { mutableStateOf(false) } + var showDatePicker by remember { mutableStateOf(false) } + var dateOfBirth by remember { mutableStateOf(null) } + var selectedSiblingIds by remember { mutableStateOf(setOf()) } + + // Analyze cluster quality + val qualityResult = remember(cluster) { + qualityAnalyzer.analyzeCluster(cluster) + } + + val keyboardController = LocalSoftwareKeyboardController.current + val dateFormatter = remember { SimpleDateFormat("MMM dd, yyyy", Locale.getDefault()) } + + Dialog(onDismissRequest = onDismiss) { + Card( + modifier = Modifier + .fillMaxWidth() + .fillMaxHeight(0.9f), + shape = RoundedCornerShape(16.dp), + elevation = CardDefaults.cardElevation(defaultElevation = 8.dp) + ) { + Column( + modifier = Modifier + .fillMaxSize() + .verticalScroll(rememberScrollState()) + ) { + // Header + Surface( + color = MaterialTheme.colorScheme.primaryContainer + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Column(modifier = Modifier.weight(1f)) { + Text( + text = "Name This Person", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onPrimaryContainer + ) + Text( + text = "${cluster.photoCount} photos", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f) + ) + } + + IconButton(onClick = onDismiss) { + Icon( + imageVector = Icons.Default.Close, + contentDescription = "Close", + tint = MaterialTheme.colorScheme.onPrimaryContainer + ) + } + } + } + + Column( + modifier = Modifier.padding(16.dp) + ) { + // Preview faces + Text( + text = "Preview", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.SemiBold + ) + + Spacer(modifier = Modifier.height(8.dp)) + + LazyRow( + horizontalArrangement = Arrangement.spacedBy(8.dp) + ) { + items(cluster.representativeFaces.take(6)) { face -> + AsyncImage( + model = android.net.Uri.parse(face.imageUri), + contentDescription = "Preview", + modifier = Modifier + .size(80.dp) + .clip(RoundedCornerShape(8.dp)) + .border( + width = 1.dp, + color = MaterialTheme.colorScheme.outline, + shape = RoundedCornerShape(8.dp) + ), + contentScale = ContentScale.Crop + ) + } + } + + Spacer(modifier = Modifier.height(20.dp)) + + // Quality warning (if applicable) + if (qualityResult.qualityTier != ClusterQualityTier.EXCELLENT) { + QualityWarningCard(qualityResult = qualityResult) + Spacer(modifier = Modifier.height(16.dp)) + } + + // Name input + OutlinedTextField( + value = name, + onValueChange = { name = it }, + label = { Text("Name") }, + placeholder = { Text("Enter person's name") }, + modifier = Modifier.fillMaxWidth(), + singleLine = true, + leadingIcon = { + Icon( + imageVector = Icons.Default.Person, + contentDescription = null + ) + }, + keyboardOptions = KeyboardOptions( + capitalization = KeyboardCapitalization.Words, + imeAction = ImeAction.Done + ), + keyboardActions = KeyboardActions( + onDone = { keyboardController?.hide() } + ) + ) + + Spacer(modifier = Modifier.height(16.dp)) + + // Child toggle + Row( + modifier = Modifier + .fillMaxWidth() + .clip(RoundedCornerShape(8.dp)) + .clickable { isChild = !isChild } + .background( + if (isChild) MaterialTheme.colorScheme.primaryContainer + else MaterialTheme.colorScheme.surfaceVariant + ) + .padding(16.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween + ) { + Row( + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Face, + contentDescription = null, + tint = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer + else MaterialTheme.colorScheme.onSurfaceVariant + ) + Spacer(modifier = Modifier.width(12.dp)) + Column { + Text( + text = "This is a child", + style = MaterialTheme.typography.bodyLarge, + fontWeight = FontWeight.Medium, + color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer + else MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "For age-appropriate filtering", + style = MaterialTheme.typography.bodySmall, + color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f) + else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f) + ) + } + } + + Switch( + checked = isChild, + onCheckedChange = { isChild = it } + ) + } + + // Date of birth (if child) + if (isChild) { + Spacer(modifier = Modifier.height(12.dp)) + + OutlinedButton( + onClick = { showDatePicker = true }, + modifier = Modifier.fillMaxWidth() + ) { + Icon( + imageVector = Icons.Default.DateRange, + contentDescription = null + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = dateOfBirth?.let { dateFormatter.format(Date(it)) } + ?: "Set date of birth (optional)" + ) + } + } + + // Sibling selection + if (suggestedSiblings.isNotEmpty()) { + Spacer(modifier = Modifier.height(20.dp)) + + Text( + text = "Appears with", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.SemiBold + ) + + Text( + text = "Select siblings or family members", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(8.dp)) + + suggestedSiblings.forEach { sibling -> + SiblingSelectionItem( + cluster = sibling, + selected = sibling.clusterId in selectedSiblingIds, + onToggle = { + selectedSiblingIds = if (sibling.clusterId in selectedSiblingIds) { + selectedSiblingIds - sibling.clusterId + } else { + selectedSiblingIds + sibling.clusterId + } + } + ) + Spacer(modifier = Modifier.height(8.dp)) + } + } + + Spacer(modifier = Modifier.height(24.dp)) + + // Action buttons + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(12.dp) + ) { + OutlinedButton( + onClick = onDismiss, + modifier = Modifier.weight(1f) + ) { + Text("Cancel") + } + + Button( + onClick = { + if (name.isNotBlank()) { + onConfirm( + name.trim(), + dateOfBirth, + isChild, + selectedSiblingIds.toList() + ) + } + }, + modifier = Modifier.weight(1f), + enabled = name.isNotBlank() && qualityResult.canTrain + ) { + Icon( + imageVector = Icons.Default.Check, + contentDescription = null, + modifier = Modifier.size(20.dp) + ) + Spacer(modifier = Modifier.width(8.dp)) + Text("Create Model") + } + } + } + } + } + } + + // Date picker dialog + if (showDatePicker) { + val datePickerState = rememberDatePickerState() + + DatePickerDialog( + onDismissRequest = { showDatePicker = false }, + confirmButton = { + TextButton( + onClick = { + dateOfBirth = datePickerState.selectedDateMillis + showDatePicker = false + } + ) { + Text("OK") + } + }, + dismissButton = { + TextButton(onClick = { showDatePicker = false }) { + Text("Cancel") + } + } + ) { + DatePicker(state = datePickerState) + } + } +} + +/** + * Quality warning card + */ +@Composable +private fun QualityWarningCard(qualityResult: com.placeholder.sherpai2.domain.clustering.ClusterQualityResult) { + val (backgroundColor, iconColor) = when (qualityResult.qualityTier) { + ClusterQualityTier.GOOD -> Pair( + Color(0xFFFFF9C4), + Color(0xFFF57F17) + ) + ClusterQualityTier.POOR -> Pair( + Color(0xFFFFEBEE), + Color(0xFFD32F2F) + ) + else -> Pair( + MaterialTheme.colorScheme.surfaceVariant, + MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = backgroundColor + ) + ) { + Column( + modifier = Modifier.padding(12.dp) + ) { + Row( + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Warning, + contentDescription = null, + tint = iconColor, + modifier = Modifier.size(20.dp) + ) + Spacer(modifier = Modifier.width(8.dp)) + Text( + text = when (qualityResult.qualityTier) { + ClusterQualityTier.GOOD -> "Review Before Training" + ClusterQualityTier.POOR -> "Quality Issues Detected" + else -> "" + }, + style = MaterialTheme.typography.titleSmall, + fontWeight = FontWeight.Bold, + color = iconColor + ) + } + + Spacer(modifier = Modifier.height(8.dp)) + + qualityResult.warnings.forEach { warning -> + Text( + text = "• $warning", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } +} + +/** + * Sibling selection item + */ +@Composable +private fun SiblingSelectionItem( + cluster: FaceCluster, + selected: Boolean, + onToggle: () -> Unit +) { + Row( + modifier = Modifier + .fillMaxWidth() + .clip(RoundedCornerShape(8.dp)) + .clickable(onClick = onToggle) + .background( + if (selected) MaterialTheme.colorScheme.primaryContainer + else MaterialTheme.colorScheme.surfaceVariant + ) + .padding(12.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween + ) { + Row( + verticalAlignment = Alignment.CenterVertically + ) { + // Preview face + AsyncImage( + model = android.net.Uri.parse(cluster.representativeFaces.firstOrNull()?.imageUri ?: ""), + contentDescription = "Preview", + modifier = Modifier + .size(40.dp) + .clip(CircleShape) + .border( + width = 1.dp, + color = MaterialTheme.colorScheme.outline, + shape = CircleShape + ), + contentScale = ContentScale.Crop + ) + + Spacer(modifier = Modifier.width(12.dp)) + + Text( + text = "${cluster.photoCount} photos together", + style = MaterialTheme.typography.bodyMedium, + color = if (selected) MaterialTheme.colorScheme.onPrimaryContainer + else MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + Checkbox( + checked = selected, + onCheckedChange = { onToggle() } + ) + } +} \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt index 600ebf8..fc6fa36 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt @@ -1,56 +1,48 @@ package com.placeholder.sherpai2.ui.presentation -import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.padding import androidx.compose.material.icons.Icons -import androidx.compose.material.icons.filled.* +import androidx.compose.material.icons.filled.Menu import androidx.compose.material3.* import androidx.compose.runtime.* import androidx.compose.ui.Modifier -import androidx.compose.ui.text.font.FontWeight import androidx.hilt.navigation.compose.hiltViewModel -import androidx.navigation.compose.currentBackStackEntryAsState import androidx.navigation.compose.rememberNavController +import androidx.navigation.compose.currentBackStackEntryAsState import com.placeholder.sherpai2.ui.navigation.AppNavHost import com.placeholder.sherpai2.ui.navigation.AppRoutes import kotlinx.coroutines.launch /** - * MainScreen - UPDATED with auto face cache check + * MainScreen - Complete app container with drawer navigation * - * NEW: Prompts user to populate face cache on app launch if needed - * FIXED: Prevents double headers for screens with their own TopAppBar + * CRITICAL FIX APPLIED: + * ✅ Removed AppRoutes.DISCOVER from screensWithOwnTopBar + * ✅ DiscoverPeopleScreen now shows hamburger menu + "Discover People" title! */ @OptIn(ExperimentalMaterial3Api::class) @Composable fun MainScreen( - mainViewModel: MainViewModel = hiltViewModel() // Same package - no import needed! + viewModel: MainViewModel = hiltViewModel() ) { - val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed) - val scope = rememberCoroutineScope() val navController = rememberNavController() + val drawerState = rememberDrawerState(DrawerValue.Closed) + val scope = rememberCoroutineScope() - val navBackStackEntry by navController.currentBackStackEntryAsState() - val currentRoute = navBackStackEntry?.destination?.route ?: AppRoutes.SEARCH + val currentBackStackEntry by navController.currentBackStackEntryAsState() + val currentRoute = currentBackStackEntry?.destination?.route - // Face cache status - val needsFaceCache by mainViewModel.needsFaceCachePopulation.collectAsState() - val unscannedCount by mainViewModel.unscannedPhotoCount.collectAsState() + // Face cache prompt dialog state + val needsFaceCachePopulation by viewModel.needsFaceCachePopulation.collectAsState() + val unscannedPhotoCount by viewModel.unscannedPhotoCount.collectAsState() - // Show face cache prompt dialog if needed - if (needsFaceCache && unscannedCount > 0) { - FaceCachePromptDialog( - unscannedPhotoCount = unscannedCount, - onDismiss = { mainViewModel.dismissFaceCachePrompt() }, - onScanNow = { - mainViewModel.dismissFaceCachePrompt() - // Navigate to Photo Utilities - navController.navigate(AppRoutes.UTILITIES) { - launchSingleTop = true - } - } - ) - } + // ✅ CRITICAL FIX: DISCOVER is NOT in this list! + // These screens handle their own TopAppBar/navigation + val screensWithOwnTopBar = setOf( + AppRoutes.IMAGE_DETAIL, + AppRoutes.TRAINING_SCREEN, + AppRoutes.CROP_SCREEN + ) ModalNavigationDrawer( drawerState = drawerState, @@ -60,133 +52,86 @@ fun MainScreen( onDestinationClicked = { route -> scope.launch { drawerState.close() - if (route != currentRoute) { - navController.navigate(route) { - launchSingleTop = true - } + } + navController.navigate(route) { + popUpTo(navController.graph.startDestinationId) { + saveState = true } + launchSingleTop = true + restoreState = true } } ) - }, + } ) { - // CRITICAL: Some screens manage their own TopAppBar - // Hide MainScreen's TopAppBar for these routes to prevent double headers - val screensWithOwnTopBar = setOf( - AppRoutes.TRAINING_PHOTO_SELECTOR, // Has its own TopAppBar with subtitle - "album/", // Album views have their own TopAppBar (prefix match) - AppRoutes.IMAGE_DETAIL // Image detail has its own TopAppBar - ) - - // Check if current route starts with any excluded pattern - val showTopBar = screensWithOwnTopBar.none { currentRoute.startsWith(it) } - Scaffold( topBar = { - if (showTopBar) { + // ✅ Show TopAppBar for ALL screens except those with their own + if (currentRoute !in screensWithOwnTopBar) { TopAppBar( title = { - Column { - Text( - text = getScreenTitle(currentRoute), - style = MaterialTheme.typography.titleLarge, - fontWeight = FontWeight.Bold - ) - getScreenSubtitle(currentRoute)?.let { subtitle -> - Text( - text = subtitle, - style = MaterialTheme.typography.bodySmall, - color = MaterialTheme.colorScheme.onSurfaceVariant - ) + Text( + text = when (currentRoute) { + AppRoutes.SEARCH -> "Search" + AppRoutes.EXPLORE -> "Explore" + AppRoutes.COLLECTIONS -> "Collections" + AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW! + AppRoutes.INVENTORY -> "People" + AppRoutes.TRAIN -> "Train Model" + AppRoutes.TAGS -> "Tags" + AppRoutes.UTILITIES -> "Utilities" + AppRoutes.SETTINGS -> "Settings" + AppRoutes.MODELS -> "AI Models" + else -> { + // Handle dynamic routes like album/{type}/{id} + if (currentRoute?.startsWith("album/") == true) { + "Album" + } else { + "SherpAI" + } + } } - } + ) }, navigationIcon = { - IconButton( - onClick = { scope.launch { drawerState.open() } } - ) { + IconButton(onClick = { + scope.launch { + drawerState.open() + } + }) { Icon( - Icons.Default.Menu, - contentDescription = "Open Menu", - tint = MaterialTheme.colorScheme.primary + imageVector = Icons.Default.Menu, + contentDescription = "Open menu" ) } }, - actions = { - // Dynamic actions based on current screen - when (currentRoute) { - AppRoutes.SEARCH -> { - IconButton(onClick = { /* TODO: Open filter dialog */ }) { - Icon( - Icons.Default.FilterList, - contentDescription = "Filter", - tint = MaterialTheme.colorScheme.primary - ) - } - } - AppRoutes.INVENTORY -> { - IconButton(onClick = { - navController.navigate(AppRoutes.TRAIN) - }) { - Icon( - Icons.Default.PersonAdd, - contentDescription = "Add Person", - tint = MaterialTheme.colorScheme.primary - ) - } - } - } - }, colors = TopAppBarDefaults.topAppBarColors( - containerColor = MaterialTheme.colorScheme.surface, - titleContentColor = MaterialTheme.colorScheme.onSurface, - navigationIconContentColor = MaterialTheme.colorScheme.primary, - actionIconContentColor = MaterialTheme.colorScheme.primary + containerColor = MaterialTheme.colorScheme.primaryContainer, + titleContentColor = MaterialTheme.colorScheme.onPrimaryContainer, + navigationIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer, + actionIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer ) ) } } ) { paddingValues -> + // ✅ Use YOUR existing AppNavHost - it already has all the screens defined! AppNavHost( navController = navController, modifier = Modifier.padding(paddingValues) ) } } -} -/** - * Get human-readable screen title - */ -private fun getScreenTitle(route: String): String { - return when (route) { - AppRoutes.SEARCH -> "Search" - AppRoutes.EXPLORE -> "Explore" - AppRoutes.COLLECTIONS -> "Collections" - AppRoutes.DISCOVER -> "Discover People" - AppRoutes.INVENTORY -> "People" - AppRoutes.TRAIN -> "Train New Person" - AppRoutes.MODELS -> "AI Models" - AppRoutes.TAGS -> "Tag Management" - AppRoutes.UTILITIES -> "Photo Util." - AppRoutes.SETTINGS -> "Settings" - else -> "SherpAI" - } -} - -/** - * Get subtitle for screens that need context - */ -private fun getScreenSubtitle(route: String): String? { - return when (route) { - AppRoutes.SEARCH -> "Find photos by tags, people, or date" - AppRoutes.EXPLORE -> "Browse your collection" - AppRoutes.COLLECTIONS -> "Your photo collections" - AppRoutes.DISCOVER -> "Auto-find faces in your library" - AppRoutes.INVENTORY -> "Trained face models" - AppRoutes.TRAIN -> "Add a new person to recognize" - AppRoutes.TAGS -> "Organize your photo collection" - AppRoutes.UTILITIES -> "Tools for managing collection" - else -> null + // ✅ Face cache prompt dialog (shows on app launch if needed) + if (needsFaceCachePopulation) { + FaceCachePromptDialog( + unscannedPhotoCount = unscannedPhotoCount, + onDismiss = { viewModel.dismissFaceCachePrompt() }, + onScanNow = { + viewModel.dismissFaceCachePrompt() + navController.navigate(AppRoutes.UTILITIES) + } + ) } } \ No newline at end of file