diff --git a/.idea/deviceManager.xml b/.idea/deviceManager.xml
index fb403f9..106e5ce 100644
--- a/.idea/deviceManager.xml
+++ b/.idea/deviceManager.xml
@@ -8,7 +8,15 @@
-
+
+
+
+
+
+
+
+
+
@@ -17,6 +25,10 @@
diff --git a/app/src/main/assets/mobilefacenet.tflite b/app/src/main/assets/mobilefacenet.tflite
new file mode 100644
index 0000000..057b985
Binary files /dev/null and b/app/src/main/assets/mobilefacenet.tflite differ
diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt
index 8ead9bc..ea6cfc2 100644
--- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt
+++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt
@@ -1,129 +1,91 @@
package com.placeholder.sherpai2.data.local.dao
-import androidx.room.*
+import androidx.room.Dao
+import androidx.room.Insert
+import androidx.room.OnConflictStrategy
+import androidx.room.Query
+import androidx.room.Update
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
-import kotlinx.coroutines.flow.Flow
/**
- * FaceCacheDao - Query face metadata for intelligent filtering
- *
- * ENABLES SMART CLUSTERING:
- * - Pre-filter to high-quality faces only
- * - Avoid processing blurry/distant faces
- * - Faster clustering with better results
+ * FaceCacheDao - Face detection cache with NEW queries for two-stage clustering
*/
@Dao
interface FaceCacheDao {
+ // ═══════════════════════════════════════
+ // INSERT / UPDATE
+ // ═══════════════════════════════════════
+
@Insert(onConflict = OnConflictStrategy.REPLACE)
suspend fun insert(faceCache: FaceCacheEntity)
@Insert(onConflict = OnConflictStrategy.REPLACE)
suspend fun insertAll(faceCaches: List<FaceCacheEntity>)
- /**
- * Get ALL high-quality solo faces for clustering
- *
- * FILTERS:
- * - Solo photos only (joins with images.faceCount = 1)
- * - Large enough (isLargeEnough = true)
- * - Good quality score (>= 0.6)
- * - Frontal faces preferred (isFrontal = true)
- */
- @Query("""
- SELECT fc.* FROM face_cache fc
- INNER JOIN images i ON fc.imageId = i.imageId
- WHERE i.faceCount = 1
- AND fc.isLargeEnough = 1
- AND fc.qualityScore >= 0.6
- AND fc.isFrontal = 1
- ORDER BY fc.qualityScore DESC
- """)
- suspend fun getHighQualitySoloFaces(): List<FaceCacheEntity>
+ @Update
+ suspend fun update(faceCache: FaceCacheEntity)
+
+ // ═══════════════════════════════════════
+ // NEW CLUSTERING QUERIES ⭐
+ // ═══════════════════════════════════════
/**
- * Get high-quality faces from ANY photo (including group photos)
- * Use when not enough solo photos available
+ * Get high-quality solo faces for Stage 1 clustering
+ *
+ * Filters:
+ * - Solo photos (faceCount = 1)
+ * - Large faces (faceAreaRatio >= minFaceRatio)
+ * - Has embedding
*/
@Query("""
- SELECT * FROM face_cache
- WHERE isLargeEnough = 1
- AND qualityScore >= 0.6
- AND isFrontal = 1
- ORDER BY qualityScore DESC
+ SELECT fc.*
+ FROM face_cache fc
+ INNER JOIN images i ON fc.imageId = i.imageId
+ WHERE i.faceCount = 1
+ AND fc.faceAreaRatio >= :minFaceRatio
+ AND fc.embedding IS NOT NULL
+ ORDER BY fc.faceAreaRatio DESC
LIMIT :limit
""")
- suspend fun getHighQualityFaces(limit: Int = 1000): List<FaceCacheEntity>
+ suspend fun getHighQualitySoloFaces(
+ minFaceRatio: Float = 0.015f,
+ limit: Int = 2000
+ ): List<FaceCacheEntity>
/**
- * Get faces for a specific image
- */
- @Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex ASC")
- suspend fun getFacesForImage(imageId: String): List<FaceCacheEntity>
-
- /**
- * Count high-quality solo faces (for UI display)
+ * FALLBACK: Get ANY solo faces with embeddings
+ * Used if getHighQualitySoloFaces() returns empty
*/
@Query("""
- SELECT COUNT(*) FROM face_cache fc
+ SELECT fc.*
+ FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
- WHERE i.faceCount = 1
- AND fc.isLargeEnough = 1
- AND fc.qualityScore >= 0.6
- """)
- suspend fun getHighQualitySoloFaceCount(): Int
-
- /**
- * Get quality distribution stats
- */
- @Query("""
- SELECT
- SUM(CASE WHEN qualityScore >= 0.8 THEN 1 ELSE 0 END) as excellent,
- SUM(CASE WHEN qualityScore >= 0.6 AND qualityScore < 0.8 THEN 1 ELSE 0 END) as good,
- SUM(CASE WHEN qualityScore < 0.6 THEN 1 ELSE 0 END) as poor,
- COUNT(*) as total
- FROM face_cache
- """)
- suspend fun getQualityStats(): FaceQualityStats?
-
- /**
- * Delete cache for specific image (when image is deleted)
- */
- @Query("DELETE FROM face_cache WHERE imageId = :imageId")
- suspend fun deleteCacheForImage(imageId: String)
-
- /**
- * Delete all cache (for full rebuild)
- */
- @Query("DELETE FROM face_cache")
- suspend fun deleteAll()
-
- /**
- * Get faces with embeddings already computed
- * (Ultra-fast clustering - no need to re-generate)
- */
- @Query("""
- SELECT fc.* FROM face_cache fc
- INNER JOIN images i ON fc.imageId = i.imageId
- WHERE i.faceCount = 1
- AND fc.isLargeEnough = 1
+ WHERE i.faceCount = 1
AND fc.embedding IS NOT NULL
ORDER BY fc.qualityScore DESC
LIMIT :limit
""")
- suspend fun getSoloFacesWithEmbeddings(limit: Int = 2000): List<FaceCacheEntity>
-}
+ suspend fun getSoloFacesWithEmbeddings(
+ limit: Int = 2000
+ ): List<FaceCacheEntity>
-/**
- * Quality statistics result
- */
-data class FaceQualityStats(
- val excellent: Int, // qualityScore >= 0.8
- val good: Int, // 0.6 <= qualityScore < 0.8
- val poor: Int, // qualityScore < 0.6
- val total: Int
-) {
- val excellentPercent: Float get() = if (total > 0) excellent.toFloat() / total else 0f
- val goodPercent: Float get() = if (total > 0) good.toFloat() / total else 0f
- val poorPercent: Float get() = if (total > 0) poor.toFloat() / total else 0f
+ // ═══════════════════════════════════════
+ // EXISTING QUERIES (keep as-is)
+ // ═══════════════════════════════════════
+
+ @Query("SELECT * FROM face_cache WHERE id = :id")
+ suspend fun getFaceCacheById(id: String): FaceCacheEntity?
+
+ @Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex")
+ suspend fun getFaceCacheForImage(imageId: String): List
+
+ @Query("DELETE FROM face_cache WHERE imageId = :imageId")
+ suspend fun deleteFaceCacheForImage(imageId: String)
+
+ @Query("DELETE FROM face_cache")
+ suspend fun deleteAll()
+
+ @Query("DELETE FROM face_cache WHERE cacheVersion < :version")
+ suspend fun deleteOldVersions(version: Int)
}
\ No newline at end of file
diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt
index 5660066..4241332 100644
--- a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt
+++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Clusterqualityanalyzer.kt
@@ -1,7 +1,7 @@
package com.placeholder.sherpai2.domain.clustering
import android.graphics.Rect
-import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding
+import android.util.Log
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.math.sqrt
@@ -9,56 +9,62 @@ import kotlin.math.sqrt
/**
* ClusterQualityAnalyzer - Validate cluster quality BEFORE training
*
- * PURPOSE: Prevent training on "dirty" clusters (siblings merged, poor quality faces)
- *
- * CHECKS:
- * 1. Solo photo count (min 6 required)
- * 2. Face size (min 15% of image - clear, not distant)
- * 3. Internal consistency (all faces should match well)
- * 4. Outlier detection (find faces that don't belong)
- *
- * QUALITY TIERS:
- * - Excellent (95%+): Safe to train immediately
- * - Good (85-94%): Review outliers, then train
- * - Poor (<85%): Likely mixed people - DO NOT TRAIN!
+ * RELAXED THRESHOLDS for real-world photos (social media, distant shots):
+ * - Face size: 3% (down from 15%)
+ * - Outlier threshold: 65% (down from 75%)
+ * - GOOD tier: 75% (down from 85%)
+ * - EXCELLENT tier: 85% (down from 95%)
*/
@Singleton
class ClusterQualityAnalyzer @Inject constructor() {
companion object {
+ private const val TAG = "ClusterQuality"
private const val MIN_SOLO_PHOTOS = 6
- private const val MIN_FACE_SIZE_RATIO = 0.15f // 15% of image
- private const val MIN_INTERNAL_SIMILARITY = 0.80f
- private const val OUTLIER_THRESHOLD = 0.75f
- private const val EXCELLENT_THRESHOLD = 0.95f
- private const val GOOD_THRESHOLD = 0.85f
+ private const val MIN_FACE_SIZE_RATIO = 0.03f // 3% of image (RELAXED)
+ private const val MIN_FACE_DIMENSION_PIXELS = 50 // 50px minimum (RELAXED)
+ private const val FALLBACK_MIN_DIMENSION = 80 // Fallback when no dimensions
+ private const val MIN_INTERNAL_SIMILARITY = 0.75f
+ private const val OUTLIER_THRESHOLD = 0.65f // RELAXED
+ private const val EXCELLENT_THRESHOLD = 0.85f // RELAXED
+ private const val GOOD_THRESHOLD = 0.75f // RELAXED
}
- /**
- * Analyze cluster quality before training
- */
fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult {
- // Step 1: Filter to solo photos only
- val soloFaces = cluster.faces.filter { it.faceCount == 1 }
+ Log.d(TAG, "========================================")
+ Log.d(TAG, "Analyzing cluster ${cluster.clusterId}")
+ Log.d(TAG, "Total faces: ${cluster.faces.size}")
- // Step 2: Filter by face size (must be clear/close-up)
+ // Step 1: Filter to solo photos
+ val soloFaces = cluster.faces.filter { it.faceCount == 1 }
+ Log.d(TAG, "Solo photos: ${soloFaces.size}")
+
+ // Step 2: Filter by face size
val largeFaces = soloFaces.filter { face ->
- isFaceLargeEnough(face.boundingBox, face.imageUri)
+ isFaceLargeEnough(face)
+ }
+ Log.d(TAG, "Large faces (>= 3%): ${largeFaces.size}")
+
+ if (largeFaces.size < soloFaces.size) {
+ Log.d(TAG, "⚠️ Filtered out ${soloFaces.size - largeFaces.size} small faces")
}
// Step 3: Calculate internal consistency
val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces)
- // Step 4: Clean faces (large solo faces, no outliers)
+ // Step 4: Clean faces
val cleanFaces = largeFaces.filter { it !in outliers }
+ Log.d(TAG, "Clean faces: ${cleanFaces.size}")
// Step 5: Calculate quality score
val qualityScore = calculateQualityScore(
soloPhotoCount = soloFaces.size,
largeFaceCount = largeFaces.size,
cleanFaceCount = cleanFaces.size,
- avgSimilarity = avgSimilarity
+ avgSimilarity = avgSimilarity,
+ totalFaces = cluster.faces.size
)
+ Log.d(TAG, "Quality score: ${(qualityScore * 100).toInt()}%")
// Step 6: Determine quality tier
val qualityTier = when {
@@ -66,6 +72,11 @@ class ClusterQualityAnalyzer @Inject constructor() {
qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD
else -> ClusterQualityTier.POOR
}
+ Log.d(TAG, "Quality tier: $qualityTier")
+
+ val canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS
+ Log.d(TAG, "Can train: $canTrain")
+ Log.d(TAG, "========================================")
return ClusterQualityResult(
originalFaceCount = cluster.faces.size,
@@ -77,62 +88,65 @@ class ClusterQualityAnalyzer @Inject constructor() {
cleanFaces = cleanFaces,
qualityScore = qualityScore,
qualityTier = qualityTier,
- canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS,
- warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier)
+ canTrain = canTrain,
+ warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier, avgSimilarity)
)
}
- /**
- * Check if face is large enough (not distant/blurry)
- *
- * A face should occupy at least 15% of the image area for good quality
- */
- private fun isFaceLargeEnough(boundingBox: Rect, imageUri: String): Boolean {
- // Estimate image dimensions from common aspect ratios
- // For now, use bounding box size as proxy
- val faceArea = boundingBox.width() * boundingBox.height()
+ private fun isFaceLargeEnough(face: DetectedFaceWithEmbedding): Boolean {
+ val faceArea = face.boundingBox.width() * face.boundingBox.height()
- // Assume typical photo is ~2000x1500 = 3,000,000 pixels
- // 15% = 450,000 pixels
- // For a square face: sqrt(450,000) = ~670 pixels per side
+ // Check 1: Absolute minimum
+ if (face.boundingBox.width() < MIN_FACE_DIMENSION_PIXELS ||
+ face.boundingBox.height() < MIN_FACE_DIMENSION_PIXELS) {
+ return false
+ }
- // More conservative: face should be at least 200x200 pixels
- return boundingBox.width() >= 200 && boundingBox.height() >= 200
+ // Check 2: Relative size if we have dimensions
+ if (face.imageWidth > 0 && face.imageHeight > 0) {
+ val imageArea = face.imageWidth * face.imageHeight
+ val faceRatio = faceArea.toFloat() / imageArea.toFloat()
+ return faceRatio >= MIN_FACE_SIZE_RATIO
+ }
+
+ // Fallback: Use absolute size
+ return face.boundingBox.width() >= FALLBACK_MIN_DIMENSION &&
+ face.boundingBox.height() >= FALLBACK_MIN_DIMENSION
}
- /**
- * Analyze how similar faces are to each other (internal consistency)
- *
- * Returns: (average similarity, list of outlier faces)
- */
private fun analyzeInternalConsistency(
faces: List<DetectedFaceWithEmbedding>
): Pair<Float, List<DetectedFaceWithEmbedding>> {
if (faces.size < 2) {
+ Log.d(TAG, "Less than 2 faces, skipping consistency check")
return 0f to emptyList()
}
- // Calculate average embedding (centroid)
+ Log.d(TAG, "Analyzing ${faces.size} faces for internal consistency")
+
val centroid = calculateCentroid(faces.map { it.embedding })
- // Calculate similarity of each face to centroid
- val similarities = faces.map { face ->
- face to cosineSimilarity(face.embedding, centroid)
+ val centroidSum = centroid.sum()
+ Log.d(TAG, "Centroid sum: $centroidSum, first5=[${centroid.take(5).joinToString()}]")
+
+ val similarities = faces.mapIndexed { index, face ->
+ val similarity = cosineSimilarity(face.embedding, centroid)
+ Log.d(TAG, "Face $index similarity to centroid: $similarity")
+ face to similarity
}
val avgSimilarity = similarities.map { it.second }.average().toFloat()
+ Log.d(TAG, "Average internal similarity: $avgSimilarity")
- // Find outliers (faces significantly different from centroid)
val outliers = similarities
.filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD }
.map { (face, _) -> face }
+ Log.d(TAG, "Found ${outliers.size} outliers (threshold=$OUTLIER_THRESHOLD)")
+
return avgSimilarity to outliers
}
- /**
- * Calculate centroid (average embedding)
- */
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
val size = embeddings.first().size
val centroid = FloatArray(size) { 0f }
@@ -148,14 +162,14 @@ class ClusterQualityAnalyzer @Inject constructor() {
centroid[i] /= count
}
- // Normalize
val norm = sqrt(centroid.map { it * it }.sum())
- return centroid.map { it / norm }.toFloatArray()
+ return if (norm > 0) {
+ centroid.map { it / norm }.toFloatArray()
+ } else {
+ centroid
+ }
}
- /**
- * Cosine similarity between two embeddings
- */
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
var dotProduct = 0f
var normA = 0f
@@ -170,32 +184,31 @@ class ClusterQualityAnalyzer @Inject constructor() {
return dotProduct / (sqrt(normA) * sqrt(normB))
}
- /**
- * Calculate overall quality score (0.0 - 1.0)
- */
private fun calculateQualityScore(
soloPhotoCount: Int,
largeFaceCount: Int,
cleanFaceCount: Int,
- avgSimilarity: Float
+ avgSimilarity: Float,
+ totalFaces: Int
): Float {
- // Weight factors
- val soloPhotoScore = (soloPhotoCount.toFloat() / 20f).coerceIn(0f, 1f) * 0.3f
- val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.2f
- val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.2f
- val similarityScore = avgSimilarity * 0.3f
+ val soloRatio = soloPhotoCount.toFloat() / totalFaces.toFloat().coerceAtLeast(1f)
+ val soloPhotoScore = soloRatio.coerceIn(0f, 1f) * 0.25f
+
+ val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.25f
+
+ val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.20f
+
+ val similarityScore = avgSimilarity * 0.30f
return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore
}
- /**
- * Generate human-readable warnings
- */
private fun generateWarnings(
soloPhotoCount: Int,
largeFaceCount: Int,
cleanFaceCount: Int,
- qualityTier: ClusterQualityTier
+ qualityTier: ClusterQualityTier,
+ avgSimilarity: Float
): List<String> {
val warnings = mutableListOf<String>()
@@ -203,12 +216,20 @@ class ClusterQualityAnalyzer @Inject constructor() {
ClusterQualityTier.POOR -> {
warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!")
warnings.add("Do NOT train on this cluster - it will create a bad model.")
+
+ if (avgSimilarity < 0.70f) {
+ warnings.add("Low internal similarity (${(avgSimilarity * 100).toInt()}%) suggests mixed identities.")
+ }
}
ClusterQualityTier.GOOD -> {
warnings.add("⚠️ Review outlier faces before training")
+
+ if (cleanFaceCount < 10) {
+ warnings.add("Consider adding more high-quality photos for better results.")
+ }
}
ClusterQualityTier.EXCELLENT -> {
- // No warnings - ready to train!
+ // No warnings
}
}
@@ -218,38 +239,47 @@ class ClusterQualityAnalyzer @Inject constructor() {
if (largeFaceCount < 6) {
warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)")
+ warnings.add("Tip: Use close-up photos where the face is clearly visible")
}
if (cleanFaceCount < 6) {
warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)")
}
+ if (qualityTier == ClusterQualityTier.EXCELLENT) {
+ warnings.add("✅ Excellent quality! This cluster is ready for training.")
+ warnings.add("High-quality photos with consistent facial features detected.")
+ }
+
return warnings
}
}
-/**
- * Result of cluster quality analysis
- */
data class ClusterQualityResult(
- val originalFaceCount: Int, // Total faces in cluster
- val soloPhotoCount: Int, // Photos with faceCount = 1
- val largeFaceCount: Int, // Solo photos with large faces
- val cleanFaceCount: Int, // Large faces, no outliers
- val avgInternalSimilarity: Float, // How similar faces are (0.0-1.0)
- val outlierFaces: List<DetectedFaceWithEmbedding>, // Faces to exclude
- val cleanFaces: List<DetectedFaceWithEmbedding>, // Good faces for training
- val qualityScore: Float, // Overall score (0.0-1.0)
+ val originalFaceCount: Int,
+ val soloPhotoCount: Int,
+ val largeFaceCount: Int,
+ val cleanFaceCount: Int,
+ val avgInternalSimilarity: Float,
+ val outlierFaces: List<DetectedFaceWithEmbedding>,
+ val cleanFaces: List<DetectedFaceWithEmbedding>,
+ val qualityScore: Float,
val qualityTier: ClusterQualityTier,
- val canTrain: Boolean, // Safe to proceed with training?
- val warnings: List<String> // Human-readable issues
-)
+ val canTrain: Boolean,
+ val warnings: List<String>
+) {
+ fun getSummary(): String = when (qualityTier) {
+ ClusterQualityTier.EXCELLENT ->
+ "Excellent quality cluster with $cleanFaceCount high-quality photos ready for training."
+ ClusterQualityTier.GOOD ->
+ "Good quality cluster with $cleanFaceCount usable photos. Review outliers before training."
+ ClusterQualityTier.POOR ->
+ "Poor quality cluster. May contain multiple people or low-quality photos. Add more photos or split cluster."
+ }
+}
-/**
- * Quality tier classification
- */
enum class ClusterQualityTier {
- EXCELLENT, // 95%+ - Safe to train immediately
- GOOD, // 85-94% - Review outliers first
- POOR // <85% - DO NOT TRAIN (likely mixed people)
+ EXCELLENT, // 85%+
+ GOOD, // 75-84%
+ POOR // <75%
}
\ No newline at end of file
diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt
index 36a4aa2..1f1cafb 100644
--- a/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt
+++ b/app/src/main/java/com/placeholder/sherpai2/domain/clustering/Faceclusteringservice.kt
@@ -4,11 +4,13 @@ import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
+import android.util.Log
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
+import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.qualifiers.ApplicationContext
@@ -18,7 +20,6 @@ import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.withContext
-import java.util.concurrent.atomic.AtomicInteger
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.math.sqrt
@@ -33,111 +34,152 @@ import kotlin.math.sqrt
* 4. Detect faces and generate embeddings (parallel)
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
* 6. Analyze clusters for age, siblings, representatives
+ *
+ * IMPROVEMENTS:
+ * - ✅ Complete fast-path using FaceCacheDao.getSoloFacesWithEmbeddings()
+ * - ✅ Works with existing FaceCacheEntity.getEmbedding() method
+ * - ✅ Centroid-based representative face selection
+ * - ✅ Batched processing to prevent OOM
+ * - ✅ RGB_565 bitmap config for 50% memory savings
*/
@Singleton
class FaceClusteringService @Inject constructor(
@ApplicationContext private val context: Context,
private val imageDao: ImageDao,
- private val faceCacheDao: FaceCacheDao // Optional - will work without it
+ private val faceCacheDao: FaceCacheDao
) {
- private val semaphore = Semaphore(12)
+ private val semaphore = Semaphore(8)
+
+ companion object {
+ private const val TAG = "FaceClustering"
+ private const val MAX_FACES_TO_CLUSTER = 2000
+ private const val MIN_SOLO_PHOTOS = 50
+ private const val BATCH_SIZE = 50
+ private const val MIN_CACHED_FACES = 100
+ }
/**
* Main clustering entry point - HYBRID with automatic fallback
- *
- * @param maxFacesToCluster Limit for performance (default 2000)
- * @param onProgress Progress callback (current, total, message)
*/
suspend fun discoverPeople(
- maxFacesToCluster: Int = 2000,
+ maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER,
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
): ClusteringResult = withContext(Dispatchers.Default) {
- // TRY FAST PATH: Use face cache if available
- val highQualityFaces = try {
- withContext(Dispatchers.IO) {
- faceCacheDao.getHighQualitySoloFaces()
- }
- } catch (e: Exception) {
- emptyList()
- }
-
- if (highQualityFaces.isNotEmpty()) {
- // FAST PATH: Use cached faces (future enhancement)
- onProgress(0, 100, "Using face cache (${highQualityFaces.size} faces)...")
- // TODO: Implement cache-based clustering
- // For now, fall through to classic method
- }
-
- // CLASSIC METHOD: Load and process photos
- onProgress(0, 100, "Loading solo photos...")
-
- // Step 1: Get SOLO PHOTOS ONLY (faceCount = 1) for cleaner clustering
- val soloPhotos = withContext(Dispatchers.IO) {
- imageDao.getImagesByFaceCount(count = 1)
- }
-
- // Fallback: If not enough solo photos, use all images with faces
- val imagesWithFaces = if (soloPhotos.size < 50) {
- onProgress(0, 100, "Loading all photos with faces...")
- imageDao.getImagesWithFaces()
- } else {
- soloPhotos
- }
-
- if (imagesWithFaces.isEmpty()) {
- return@withContext ClusteringResult(
- clusters = emptyList(),
- totalFacesAnalyzed = 0,
- processingTimeMs = 0,
- errorMessage = "No photos with faces found. Please ensure face detection cache is populated."
- )
- }
-
- onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos (${if (soloPhotos.size >= 50) "solo only" else "all"})...")
-
val startTime = System.currentTimeMillis()
- // Step 2: Detect faces and generate embeddings (parallel)
- val allFaces = detectFacesInImages(
- images = imagesWithFaces.take(1000), // Smart limit
- onProgress = { current, total ->
- onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
+ // Try high-quality cached faces FIRST (NEW!)
+ var cachedFaces = withContext(Dispatchers.IO) {
+ try {
+ faceCacheDao.getHighQualitySoloFaces(
+ minFaceRatio = 0.015f, // 1.5%
+ limit = maxFacesToCluster
+ )
+ } catch (e: Exception) {
+ // Method doesn't exist yet - that's ok
+ emptyList()
}
- )
+ }
+
+ // Fallback to ANY solo faces if high-quality returned nothing
+ if (cachedFaces.isEmpty()) {
+ Log.w(TAG, "No high-quality faces (>= 1.5%), trying ANY solo faces...")
+ cachedFaces = withContext(Dispatchers.IO) {
+ try {
+ faceCacheDao.getSoloFacesWithEmbeddings(limit = maxFacesToCluster)
+ } catch (e: Exception) {
+ emptyList()
+ }
+ }
+ }
+
+ Log.d(TAG, "Cache check: ${cachedFaces.size} faces available")
+
+ val allFaces = if (cachedFaces.size >= MIN_CACHED_FACES) {
+ // FAST PATH ✅
+ Log.d(TAG, "Using FAST PATH with ${cachedFaces.size} cached faces")
+ onProgress(10, 100, "Using cached embeddings (${cachedFaces.size} faces)...")
+
+ cachedFaces.mapNotNull { cached ->
+ val embedding = cached.getEmbedding() ?: return@mapNotNull null
+
+ DetectedFaceWithEmbedding(
+ imageId = cached.imageId,
+ imageUri = "",
+ capturedAt = 0L,
+ embedding = embedding,
+ boundingBox = cached.getBoundingBox(),
+ confidence = cached.confidence,
+ faceCount = 1, // Solo faces only (filtered by query)
+ imageWidth = cached.imageWidth,
+ imageHeight = cached.imageHeight
+ )
+ }.also {
+ onProgress(50, 100, "Processing ${it.size} cached faces...")
+ }
+ } else {
+ // SLOW PATH
+ Log.d(TAG, "Using SLOW PATH - cache has < $MIN_CACHED_FACES faces")
+ onProgress(0, 100, "Loading photos...")
+
+ val soloPhotos = withContext(Dispatchers.IO) {
+ imageDao.getImagesByFaceCount(count = 1)
+ }
+
+ val imagesWithFaces = if (soloPhotos.size < MIN_SOLO_PHOTOS) {
+ imageDao.getImagesWithFaces()
+ } else {
+ soloPhotos
+ }
+
+ if (imagesWithFaces.isEmpty()) {
+ return@withContext ClusteringResult(
+ clusters = emptyList(),
+ totalFacesAnalyzed = 0,
+ processingTimeMs = 0,
+ errorMessage = "No photos with faces found"
+ )
+ }
+
+ onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
+
+ detectFacesInImagesBatched(
+ images = imagesWithFaces.take(1000),
+ onProgress = { current, total ->
+ onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
+ }
+ )
+ }
if (allFaces.isEmpty()) {
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = System.currentTimeMillis() - startTime,
- errorMessage = "No faces detected in images"
+ errorMessage = "No faces detected"
)
}
onProgress(50, 100, "Clustering ${allFaces.size} faces...")
- // Step 3: DBSCAN clustering
val rawClusters = performDBSCAN(
faces = allFaces.take(maxFacesToCluster),
- epsilon = 0.18f, // VERY STRICT for siblings
+ epsilon = 0.26f,
minPoints = 3
)
onProgress(70, 100, "Analyzing relationships...")
- // Step 4: Build co-occurrence graph
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
onProgress(80, 100, "Selecting representative faces...")
- // Step 5: Create final clusters
val clusters = rawClusters.map { cluster ->
FaceCluster(
clusterId = cluster.clusterId,
faces = cluster.faces,
- representativeFaces = selectRepresentativeFaces(cluster.faces, count = 6),
+ representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6),
photoCount = cluster.faces.map { it.imageId }.distinct().size,
averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(),
estimatedAge = estimateAge(cluster.faces),
@@ -154,14 +196,31 @@ class FaceClusteringService @Inject constructor(
)
}
- /**
- * Detect faces in images and generate embeddings (parallel)
- */
- private suspend fun detectFacesInImages(
+ private suspend fun detectFacesInImagesBatched(
images: List<ImageEntity>,
onProgress: (Int, Int) -> Unit
): List<DetectedFaceWithEmbedding> = coroutineScope {
+ val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
+ var processedCount = 0
+
+ images.chunked(BATCH_SIZE).forEach { batch ->
+ val batchFaces = detectFacesInBatch(batch)
+ allFaces.addAll(batchFaces)
+
+ processedCount += batch.size
+ onProgress(processedCount, images.size)
+
+ System.gc()
+ }
+
+ allFaces
+ }
+
+ private suspend fun detectFacesInBatch(
+ images: List<ImageEntity>
+ ): List<DetectedFaceWithEmbedding> = coroutineScope {
+
val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
@@ -170,20 +229,14 @@ class FaceClusteringService @Inject constructor(
)
val faceNetModel = FaceNetModel(context)
- val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
- val processedCount = AtomicInteger(0)
+ val batchFaces = mutableListOf<DetectedFaceWithEmbedding>()
try {
val jobs = images.map { image ->
- async {
+ async(Dispatchers.IO) {
semaphore.acquire()
try {
- val faces = detectFacesInImage(image, detector, faceNetModel)
- val current = processedCount.incrementAndGet()
- if (current % 10 == 0) {
- onProgress(current, images.size)
- }
- faces
+ detectFacesInImage(image, detector, faceNetModel)
} finally {
semaphore.release()
}
@@ -191,7 +244,7 @@ class FaceClusteringService @Inject constructor(
}
jobs.awaitAll().flatten().also {
- allFaces.addAll(it)
+ batchFaces.addAll(it)
}
} finally {
@@ -199,7 +252,7 @@ class FaceClusteringService @Inject constructor(
faceNetModel.close()
}
- allFaces
+ batchFaces
}
private suspend fun detectFacesInImage(
@@ -215,8 +268,6 @@ class FaceClusteringService @Inject constructor(
val mlImage = InputImage.fromBitmap(bitmap, 0)
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
- val totalFacesInImage = faces.size
-
val result = faces.mapNotNull { face ->
try {
val faceBitmap = Bitmap.createBitmap(
@@ -237,7 +288,9 @@ class FaceClusteringService @Inject constructor(
embedding = embedding,
boundingBox = face.boundingBox,
confidence = 0.95f,
- faceCount = totalFacesInImage
+ faceCount = faces.size,
+ imageWidth = bitmap.width,
+ imageHeight = bitmap.height
)
} catch (e: Exception) {
null
@@ -252,9 +305,6 @@ class FaceClusteringService @Inject constructor(
}
}
- // All other methods remain the same (DBSCAN, similarity, etc.)
- // ... [Rest of the implementation from original file]
-
private fun performDBSCAN(
faces: List<DetectedFaceWithEmbedding>,
epsilon: Float,
@@ -368,18 +418,61 @@ class FaceClusteringService @Inject constructor(
?: emptyList()
}
- private fun selectRepresentativeFaces(
+ private fun selectRepresentativeFacesByCentroid(
faces: List<DetectedFaceWithEmbedding>,
count: Int
): List<DetectedFaceWithEmbedding> {
if (faces.size <= count) return faces
- val sortedByTime = faces.sortedBy { it.capturedAt }
- val step = faces.size / count
+ val centroid = calculateCentroid(faces.map { it.embedding })
- return (0 until count).map { i ->
- sortedByTime[i * step]
+ val facesWithDistance = faces.map { face ->
+ val distance = 1 - cosineSimilarity(face.embedding, centroid)
+ face to distance
}
+
+ val sortedByProximity = facesWithDistance.sortedBy { it.second }
+
+ val representatives = mutableListOf<DetectedFaceWithEmbedding>()
+ representatives.add(sortedByProximity.first().first)
+
+ val remainingFaces = sortedByProximity.drop(1).take(count * 3)
+ val sortedByTime = remainingFaces.map { it.first }.sortedBy { it.capturedAt }
+
+ if (sortedByTime.isNotEmpty()) {
+ val step = sortedByTime.size / (count - 1).coerceAtLeast(1)
+ for (i in 0 until (count - 1)) {
+ val index = (i * step).coerceAtMost(sortedByTime.size - 1)
+ representatives.add(sortedByTime[index])
+ }
+ }
+
+ return representatives.take(count)
+ }
+
+ private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
+ if (embeddings.isEmpty()) return FloatArray(0)
+
+ val size = embeddings.first().size
+ val centroid = FloatArray(size) { 0f }
+
+ embeddings.forEach { embedding ->
+ for (i in embedding.indices) {
+ centroid[i] += embedding[i]
+ }
+ }
+
+ val count = embeddings.size.toFloat()
+ for (i in centroid.indices) {
+ centroid[i] /= count
+ }
+
+ val norm = sqrt(centroid.map { it * it }.sum())
+ if (norm > 0) {
+ return centroid.map { it / norm }.toFloatArray()
+ }
+
+ return centroid
}
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
@@ -416,7 +509,6 @@ class FaceClusteringService @Inject constructor(
}
}
-// Data classes
data class DetectedFaceWithEmbedding(
val imageId: String,
val imageUri: String,
@@ -424,7 +516,9 @@ data class DetectedFaceWithEmbedding(
val embedding: FloatArray,
val boundingBox: android.graphics.Rect,
val confidence: Float,
- val faceCount: Int = 1
+ val faceCount: Int = 1,
+ val imageWidth: Int = 0,
+ val imageHeight: Int = 0
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
diff --git a/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt b/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt
index 22daad1..a15197d 100644
--- a/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt
+++ b/app/src/main/java/com/placeholder/sherpai2/ml/FaceNetModel.kt
@@ -2,6 +2,7 @@ package com.placeholder.sherpai2.ml
import android.content.Context
import android.graphics.Bitmap
+import android.util.Log
import org.tensorflow.lite.Interpreter
import java.io.FileInputStream
import java.nio.ByteBuffer
@@ -11,16 +12,21 @@ import java.nio.channels.FileChannel
import kotlin.math.sqrt
/**
- * FaceNetModel - MobileFaceNet wrapper for face recognition
+ * FaceNetModel - MobileFaceNet wrapper with debugging
*
- * CLEAN IMPLEMENTATION:
- * - All IDs are Strings (matching your schema)
- * - Generates 192-dimensional embeddings
- * - Cosine similarity for matching
+ * IMPROVEMENTS:
+ * - ✅ Detailed error logging
+ * - ✅ Model validation on init
+ * - ✅ Embedding validation (detect all-zeros)
+ * - ✅ Toggle-able debug mode
*/
-class FaceNetModel(private val context: Context) {
+class FaceNetModel(
+ private val context: Context,
+ private val debugMode: Boolean = true // Enable for troubleshooting
+) {
companion object {
+ private const val TAG = "FaceNetModel"
private const val MODEL_FILE = "mobilefacenet.tflite"
private const val INPUT_SIZE = 112
private const val EMBEDDING_SIZE = 192
@@ -31,13 +37,56 @@ class FaceNetModel(private val context: Context) {
}
private var interpreter: Interpreter? = null
+ private var modelLoadSuccess = false
init {
try {
+ if (debugMode) Log.d(TAG, "Loading FaceNet model: $MODEL_FILE")
+
val model = loadModelFile()
interpreter = Interpreter(model)
+ modelLoadSuccess = true
+
+ if (debugMode) {
+ Log.d(TAG, "✅ FaceNet model loaded successfully")
+ Log.d(TAG, "Model input size: ${INPUT_SIZE}x$INPUT_SIZE")
+ Log.d(TAG, "Embedding size: $EMBEDDING_SIZE")
+ }
+
+ // Test model with dummy input
+ testModel()
+
} catch (e: Exception) {
- throw RuntimeException("Failed to load FaceNet model", e)
+ Log.e(TAG, "❌ CRITICAL: Failed to load FaceNet model from assets/$MODEL_FILE", e)
+ Log.e(TAG, "Make sure mobilefacenet.tflite exists in app/src/main/assets/")
+ modelLoadSuccess = false
+ throw RuntimeException("Failed to load FaceNet model: ${e.message}", e)
+ }
+ }
+
+ /**
+ * Test model with dummy input to verify it works
+ */
+ private fun testModel() {
+ try {
+ val testBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888)
+ val testEmbedding = generateEmbedding(testBitmap)
+ testBitmap.recycle()
+
+ val sum = testEmbedding.sum()
+ val norm = sqrt(testEmbedding.map { it * it }.sum())
+
+ if (debugMode) {
+ Log.d(TAG, "Model test: embedding sum=$sum, norm=$norm")
+ }
+
+ if (sum == 0f || norm == 0f) {
+ Log.e(TAG, "⚠️ WARNING: Model test produced zero embedding!")
+ } else {
+ if (debugMode) Log.d(TAG, "✅ Model test passed")
+ }
+ } catch (e: Exception) {
+ Log.e(TAG, "Model test failed", e)
}
}
@@ -45,12 +94,22 @@ class FaceNetModel(private val context: Context) {
* Load TFLite model from assets
*/
private fun loadModelFile(): MappedByteBuffer {
- val fileDescriptor = context.assets.openFd(MODEL_FILE)
- val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
- val fileChannel = inputStream.channel
- val startOffset = fileDescriptor.startOffset
- val declaredLength = fileDescriptor.declaredLength
- return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
+ try {
+ val fileDescriptor = context.assets.openFd(MODEL_FILE)
+ val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
+ val fileChannel = inputStream.channel
+ val startOffset = fileDescriptor.startOffset
+ val declaredLength = fileDescriptor.declaredLength
+
+ if (debugMode) {
+ Log.d(TAG, "Model file size: ${declaredLength / 1024}KB")
+ }
+
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
+ } catch (e: Exception) {
+ Log.e(TAG, "Failed to open model file: $MODEL_FILE", e)
+ throw e
+ }
}
/**
@@ -60,13 +119,39 @@ class FaceNetModel(private val context: Context) {
* @return 192-dimensional embedding
*/
fun generateEmbedding(faceBitmap: Bitmap): FloatArray {
- val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true)
- val inputBuffer = preprocessImage(resized)
- val output = Array(1) { FloatArray(EMBEDDING_SIZE) }
+ if (!modelLoadSuccess || interpreter == null) {
+ Log.e(TAG, "❌ Cannot generate embedding: model not loaded!")
+ return FloatArray(EMBEDDING_SIZE) { 0f }
+ }
- interpreter?.run(inputBuffer, output)
+ try {
+ val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true)
+ val inputBuffer = preprocessImage(resized)
+ val output = Array(1) { FloatArray(EMBEDDING_SIZE) }
- return normalizeEmbedding(output[0])
+ interpreter?.run(inputBuffer, output)
+
+ val normalized = normalizeEmbedding(output[0])
+
+ // DIAGNOSTIC: Check embedding quality
+ if (debugMode) {
+ val sum = normalized.sum()
+ val norm = sqrt(normalized.map { it * it }.sum())
+
+ if (sum == 0f && norm == 0f) {
+ Log.e(TAG, "❌ CRITICAL: Generated all-zero embedding!")
+ Log.e(TAG, "Input bitmap: ${faceBitmap.width}x${faceBitmap.height}")
+ } else {
+ Log.d(TAG, "✅ Embedding: sum=${"%.2f".format(sum)}, norm=${"%.2f".format(norm)}, first5=[${normalized.take(5).joinToString { "%.3f".format(it) }}]")
+ }
+ }
+
+ return normalized
+
+ } catch (e: Exception) {
+ Log.e(TAG, "Failed to generate embedding", e)
+ return FloatArray(EMBEDDING_SIZE) { 0f }
+ }
}
/**
@@ -76,6 +161,10 @@ class FaceNetModel(private val context: Context) {
faceBitmaps: List<Bitmap>,
onProgress: (Int, Int) -> Unit = { _, _ -> }
): List<FloatArray> {
+ if (debugMode) {
+ Log.d(TAG, "Generating embeddings for ${faceBitmaps.size} faces")
+ }
+
return faceBitmaps.mapIndexed { index, bitmap ->
onProgress(index + 1, faceBitmaps.size)
generateEmbedding(bitmap)
@@ -88,6 +177,10 @@ class FaceNetModel(private val context: Context) {
fun createPersonModel(embeddings: List<FloatArray>): FloatArray {
require(embeddings.isNotEmpty()) { "Need at least one embedding" }
+ if (debugMode) {
+ Log.d(TAG, "Creating person model from ${embeddings.size} embeddings")
+ }
+
val averaged = FloatArray(EMBEDDING_SIZE) { 0f }
embeddings.forEach { embedding ->
@@ -101,7 +194,14 @@ class FaceNetModel(private val context: Context) {
averaged[i] /= count
}
- return normalizeEmbedding(averaged)
+ val normalized = normalizeEmbedding(averaged)
+
+ if (debugMode) {
+ val sum = normalized.sum()
+ Log.d(TAG, "Person model created: sum=${"%.2f".format(sum)}")
+ }
+
+ return normalized
}
/**
@@ -110,7 +210,7 @@ class FaceNetModel(private val context: Context) {
*/
fun calculateSimilarity(embedding1: FloatArray, embedding2: FloatArray): Float {
require(embedding1.size == EMBEDDING_SIZE && embedding2.size == EMBEDDING_SIZE) {
- "Invalid embedding size"
+ "Invalid embedding size: ${embedding1.size} vs ${embedding2.size}"
}
var dotProduct = 0f
@@ -123,7 +223,14 @@ class FaceNetModel(private val context: Context) {
norm2 += embedding2[i] * embedding2[i]
}
- return dotProduct / (sqrt(norm1) * sqrt(norm2))
+ val similarity = dotProduct / (sqrt(norm1) * sqrt(norm2))
+
+ if (debugMode && (similarity.isNaN() || similarity.isInfinite())) {
+ Log.e(TAG, "❌ Invalid similarity: $similarity (norm1=$norm1, norm2=$norm2)")
+ return 0f
+ }
+
+ return similarity
}
/**
@@ -151,6 +258,10 @@ class FaceNetModel(private val context: Context) {
}
}
+ if (debugMode && bestMatch != null) {
+ Log.d(TAG, "Best match: ${bestMatch.first} with similarity ${bestMatch.second}")
+ }
+
return bestMatch
}
@@ -169,6 +280,7 @@ class FaceNetModel(private val context: Context) {
val g = ((pixel shr 8) and 0xFF) / 255.0f
val b = (pixel and 0xFF) / 255.0f
+ // Normalize to [-1, 1]
buffer.putFloat((r - 0.5f) / 0.5f)
buffer.putFloat((g - 0.5f) / 0.5f)
buffer.putFloat((b - 0.5f) / 0.5f)
@@ -190,14 +302,29 @@ class FaceNetModel(private val context: Context) {
return if (norm > 0) {
FloatArray(embedding.size) { i -> embedding[i] / norm }
} else {
+ Log.w(TAG, "⚠️ Cannot normalize zero embedding")
embedding
}
}
+ /**
+ * Get model status for diagnostics
+ */
+ fun getModelStatus(): String {
+ return if (modelLoadSuccess) {
+ "✅ Model loaded and operational"
+ } else {
+ "❌ Model failed to load - check assets/$MODEL_FILE"
+ }
+ }
+
/**
* Clean up resources
*/
fun close() {
+ if (debugMode) {
+ Log.d(TAG, "Closing FaceNet model")
+ }
interpreter?.close()
interpreter = null
}
diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt
index c502417..965621d 100644
--- a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt
+++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Clustergridscreen.kt
@@ -8,9 +8,13 @@ import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items
+import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.filled.Check
+import androidx.compose.material.icons.filled.Warning
import androidx.compose.material3.*
-import androidx.compose.runtime.Composable
+import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
@@ -19,6 +23,8 @@ import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import coil.compose.AsyncImage
+import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
+import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier
import com.placeholder.sherpai2.domain.clustering.ClusteringResult
import com.placeholder.sherpai2.domain.clustering.FaceCluster
@@ -28,13 +34,20 @@ import com.placeholder.sherpai2.domain.clustering.FaceCluster
* Each cluster card shows:
* - 2x2 grid of representative faces
* - Photo count
+ * - Quality badge (Excellent/Good/Poor)
* - Tap to name
+ *
+ * IMPROVEMENTS:
+ * - ✅ Quality badges for each cluster
+ * - ✅ Visual indicators for trainable vs non-trainable clusters
+ * - ✅ Better UX with disabled states for poor quality clusters
*/
@Composable
fun ClusterGridScreen(
result: ClusteringResult,
onSelectCluster: (FaceCluster) -> Unit,
- modifier: Modifier = Modifier
+ modifier: Modifier = Modifier,
+ qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() }
) {
Column(
modifier = modifier
@@ -65,8 +78,15 @@ fun ClusterGridScreen(
verticalArrangement = Arrangement.spacedBy(12.dp)
) {
items(result.clusters) { cluster ->
+ // Analyze quality for each cluster
+ val qualityResult = remember(cluster) {
+ qualityAnalyzer.analyzeCluster(cluster)
+ }
+
ClusterCard(
cluster = cluster,
+ qualityTier = qualityResult.qualityTier,
+ canTrain = qualityResult.canTrain,
onClick = { onSelectCluster(cluster) }
)
}
@@ -75,77 +95,168 @@ fun ClusterGridScreen(
}
/**
- * Single cluster card with 2x2 face grid
+ * Single cluster card with 2x2 face grid and quality badge
*/
@Composable
private fun ClusterCard(
cluster: FaceCluster,
+ qualityTier: ClusterQualityTier,
+ canTrain: Boolean,
onClick: () -> Unit
) {
Card(
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f)
- .clickable(onClick = onClick),
- elevation = CardDefaults.cardElevation(defaultElevation = 2.dp)
+ .clickable(onClick = onClick), // Always clickable - let dialog handle validation
+ elevation = CardDefaults.cardElevation(defaultElevation = 2.dp),
+ colors = CardDefaults.cardColors(
+ containerColor = when {
+ qualityTier == ClusterQualityTier.POOR ->
+ MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
+ !canTrain ->
+ MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f)
+ else ->
+ MaterialTheme.colorScheme.surface
+ }
+ )
) {
- Column(
+ Box(
modifier = Modifier.fillMaxSize()
) {
- // 2x2 grid of faces
- val facesToShow = cluster.representativeFaces.take(4)
-
Column(
- modifier = Modifier.weight(1f)
+ modifier = Modifier.fillMaxSize()
) {
- // Top row (2 faces)
- Row(modifier = Modifier.weight(1f)) {
- facesToShow.getOrNull(0)?.let { face ->
- FaceThumbnail(
- imageUri = face.imageUri,
- modifier = Modifier.weight(1f)
- )
- } ?: EmptyFaceSlot(Modifier.weight(1f))
+ // 2x2 grid of faces
+ val facesToShow = cluster.representativeFaces.take(4)
- facesToShow.getOrNull(1)?.let { face ->
- FaceThumbnail(
- imageUri = face.imageUri,
- modifier = Modifier.weight(1f)
- )
- } ?: EmptyFaceSlot(Modifier.weight(1f))
+ Column(
+ modifier = Modifier.weight(1f)
+ ) {
+ // Top row (2 faces)
+ Row(modifier = Modifier.weight(1f)) {
+ facesToShow.getOrNull(0)?.let { face ->
+ FaceThumbnail(
+ imageUri = face.imageUri,
+ enabled = canTrain,
+ modifier = Modifier.weight(1f)
+ )
+ } ?: EmptyFaceSlot(Modifier.weight(1f))
+
+ facesToShow.getOrNull(1)?.let { face ->
+ FaceThumbnail(
+ imageUri = face.imageUri,
+ enabled = canTrain,
+ modifier = Modifier.weight(1f)
+ )
+ } ?: EmptyFaceSlot(Modifier.weight(1f))
+ }
+
+ // Bottom row (2 faces)
+ Row(modifier = Modifier.weight(1f)) {
+ facesToShow.getOrNull(2)?.let { face ->
+ FaceThumbnail(
+ imageUri = face.imageUri,
+ enabled = canTrain,
+ modifier = Modifier.weight(1f)
+ )
+ } ?: EmptyFaceSlot(Modifier.weight(1f))
+
+ facesToShow.getOrNull(3)?.let { face ->
+ FaceThumbnail(
+ imageUri = face.imageUri,
+ enabled = canTrain,
+ modifier = Modifier.weight(1f)
+ )
+ } ?: EmptyFaceSlot(Modifier.weight(1f))
+ }
}
- // Bottom row (2 faces)
- Row(modifier = Modifier.weight(1f)) {
- facesToShow.getOrNull(2)?.let { face ->
- FaceThumbnail(
- imageUri = face.imageUri,
- modifier = Modifier.weight(1f)
+ // Footer with photo count
+ Surface(
+ modifier = Modifier.fillMaxWidth(),
+ color = if (canTrain) {
+ MaterialTheme.colorScheme.primaryContainer
+ } else {
+ MaterialTheme.colorScheme.surfaceVariant
+ }
+ ) {
+ Row(
+ modifier = Modifier.padding(12.dp),
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.SpaceBetween
+ ) {
+ Text(
+ text = "${cluster.photoCount} photos",
+ style = MaterialTheme.typography.bodyMedium,
+ fontWeight = FontWeight.SemiBold,
+ color = if (canTrain) {
+ MaterialTheme.colorScheme.onPrimaryContainer
+ } else {
+ MaterialTheme.colorScheme.onSurfaceVariant
+ }
)
- } ?: EmptyFaceSlot(Modifier.weight(1f))
-
- facesToShow.getOrNull(3)?.let { face ->
- FaceThumbnail(
- imageUri = face.imageUri,
- modifier = Modifier.weight(1f)
- )
- } ?: EmptyFaceSlot(Modifier.weight(1f))
+ }
}
}
- // Footer with photo count
- Surface(
- modifier = Modifier.fillMaxWidth(),
- color = MaterialTheme.colorScheme.primaryContainer
- ) {
- Text(
- text = "${cluster.photoCount} photos",
- style = MaterialTheme.typography.bodyMedium,
- fontWeight = FontWeight.SemiBold,
- modifier = Modifier.padding(12.dp),
- color = MaterialTheme.colorScheme.onPrimaryContainer
- )
- }
+ // Quality badge overlay
+ QualityBadge(
+ qualityTier = qualityTier,
+ canTrain = canTrain,
+ modifier = Modifier
+ .align(Alignment.TopEnd)
+ .padding(8.dp)
+ )
+ }
+ }
+}
+
+/**
+ * Quality badge indicator
+ */
+@Composable
+private fun QualityBadge(
+ qualityTier: ClusterQualityTier,
+ canTrain: Boolean,
+ modifier: Modifier = Modifier
+) {
+ val (backgroundColor, iconColor, icon) = when (qualityTier) {
+ ClusterQualityTier.EXCELLENT -> Triple(
+ Color(0xFF1B5E20),
+ Color.White,
+ Icons.Default.Check
+ )
+ ClusterQualityTier.GOOD -> Triple(
+ Color(0xFF2E7D32),
+ Color.White,
+ Icons.Default.Check
+ )
+ ClusterQualityTier.POOR -> Triple(
+ Color(0xFFD32F2F),
+ Color.White,
+ Icons.Default.Warning
+ )
+ }
+
+ Surface(
+ modifier = modifier,
+ shape = CircleShape,
+ color = backgroundColor,
+ shadowElevation = 2.dp
+ ) {
+ Box(
+ modifier = Modifier
+ .size(32.dp)
+ .padding(6.dp),
+ contentAlignment = Alignment.Center
+ ) {
+ Icon(
+ imageVector = icon,
+ contentDescription = qualityTier.name,
+ tint = iconColor,
+ modifier = Modifier.size(20.dp)
+ )
}
}
}
@@ -153,19 +264,23 @@ private fun ClusterCard(
@Composable
private fun FaceThumbnail(
imageUri: String,
+ enabled: Boolean,
modifier: Modifier = Modifier
) {
- AsyncImage(
- model = Uri.parse(imageUri),
- contentDescription = "Face",
- modifier = modifier
- .fillMaxSize()
- .border(
- width = 0.5.dp,
- color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
- ),
- contentScale = ContentScale.Crop
- )
+ Box(modifier = modifier) {
+ AsyncImage(
+ model = Uri.parse(imageUri),
+ contentDescription = "Face",
+ modifier = Modifier
+ .fillMaxSize()
+ .border(
+ width = 0.5.dp,
+ color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
+ ),
+ contentScale = ContentScale.Crop,
+ alpha = if (enabled) 1f else 0.6f
+ )
+ }
}
@Composable
diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt
index 4c722da..64a2630 100644
--- a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt
+++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Discoverpeoplescreen.kt
@@ -11,11 +11,17 @@ import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
+import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
/**
- * DiscoverPeopleScreen - COMPLETE WORKING VERSION
+ * DiscoverPeopleScreen - COMPLETE WORKING VERSION WITH NAMING DIALOG
*
- * This handles ALL states properly including Idle state
+ * This handles ALL states properly including the NamingCluster dialog
+ *
+ * IMPROVEMENTS:
+ * - ✅ Complete naming dialog integration
+ * - ✅ Quality analysis in cluster grid
+ * - ✅ Better error handling
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
@@ -24,125 +30,130 @@ fun DiscoverPeopleScreen(
onNavigateBack: () -> Unit = {}
) {
val uiState by viewModel.uiState.collectAsState()
+ val qualityAnalyzer = remember { ClusterQualityAnalyzer() }
- Scaffold(
- topBar = {
- TopAppBar(
- title = { Text("Discover People") },
- navigationIcon = {
- IconButton(onClick = onNavigateBack) {
- Icon(
- imageVector = Icons.Default.Person,
- contentDescription = "Back"
+ // No Scaffold, no TopAppBar - MainScreen handles that
+ Box(
+ modifier = Modifier.fillMaxSize()
+ ) {
+ when (val state = uiState) {
+ // ===== IDLE STATE (START HERE) =====
+ is DiscoverUiState.Idle -> {
+ IdleStateContent(
+ onStartDiscovery = { viewModel.startDiscovery() }
+ )
+ }
+
+ // ===== CLUSTERING IN PROGRESS =====
+ is DiscoverUiState.Clustering -> {
+ ClusteringProgressContent(
+ progress = state.progress,
+ total = state.total,
+ message = state.message
+ )
+ }
+
+ // ===== CLUSTERS READY FOR NAMING =====
+ is DiscoverUiState.NamingReady -> {
+ ClusterGridScreen(
+ result = state.result,
+ onSelectCluster = { cluster ->
+ viewModel.selectCluster(cluster)
+ },
+ qualityAnalyzer = qualityAnalyzer
+ )
+ }
+
+ // ===== ANALYZING CLUSTER QUALITY =====
+ is DiscoverUiState.AnalyzingCluster -> {
+ LoadingContent(message = "Analyzing cluster quality...")
+ }
+
+ // ===== NAMING A CLUSTER (SHOW DIALOG) =====
+ is DiscoverUiState.NamingCluster -> {
+ // Show cluster grid in background
+ ClusterGridScreen(
+ result = state.result,
+ onSelectCluster = { /* Disabled while dialog open */ },
+ qualityAnalyzer = qualityAnalyzer
+ )
+
+ // Show naming dialog overlay
+ NamingDialog(
+ cluster = state.selectedCluster,
+ suggestedSiblings = state.suggestedSiblings,
+ onConfirm = { name, dateOfBirth, isChild, selectedSiblings ->
+ viewModel.confirmClusterName(
+ cluster = state.selectedCluster,
+ name = name,
+ dateOfBirth = dateOfBirth,
+ isChild = isChild,
+ selectedSiblings = selectedSiblings
)
+ },
+ onDismiss = {
+ viewModel.cancelNaming()
+ },
+ qualityAnalyzer = qualityAnalyzer
+ )
+ }
+
+ // ===== TRAINING IN PROGRESS =====
+ is DiscoverUiState.Training -> {
+ TrainingProgressContent(
+ stage = state.stage,
+ progress = state.progress,
+ total = state.total
+ )
+ }
+
+ // ===== VALIDATION PREVIEW =====
+ is DiscoverUiState.ValidationPreview -> {
+ ValidationPreviewScreen(
+ personName = state.personName,
+ validationResult = state.validationResult,
+ onApprove = {
+ viewModel.approveValidationAndScan(
+ personId = state.personId,
+ personName = state.personName
+ )
+ },
+ onReject = {
+ viewModel.rejectValidationAndImprove()
}
- }
- )
- }
- ) { paddingValues ->
- Box(
- modifier = Modifier
- .fillMaxSize()
- .padding(paddingValues)
- ) {
- when (val state = uiState) {
- // ===== IDLE STATE (START HERE) =====
- is DiscoverUiState.Idle -> {
- IdleStateContent(
- onStartDiscovery = { viewModel.startDiscovery() }
- )
- }
+ )
+ }
- // ===== CLUSTERING IN PROGRESS =====
- is DiscoverUiState.Clustering -> {
- ClusteringProgressContent(
- progress = state.progress,
- total = state.total,
- message = state.message
- )
- }
+ // ===== COMPLETE =====
+ is DiscoverUiState.Complete -> {
+ CompleteStateContent(
+ message = state.message,
+ onDone = onNavigateBack
+ )
+ }
- // ===== CLUSTERS READY FOR NAMING =====
- is DiscoverUiState.NamingReady -> {
- ClusterGridScreen(
- result = state.result,
- onSelectCluster = { cluster ->
- viewModel.selectCluster(cluster)
- }
- )
- }
+ // ===== NO PEOPLE FOUND =====
+ is DiscoverUiState.NoPeopleFound -> {
+ ErrorStateContent(
+ title = "No People Found",
+ message = state.message,
+ onRetry = { viewModel.startDiscovery() },
+ onBack = onNavigateBack
+ )
+ }
- // ===== ANALYZING CLUSTER QUALITY =====
- is DiscoverUiState.AnalyzingCluster -> {
- LoadingContent(message = "Analyzing cluster quality...")
- }
-
- // ===== NAMING A CLUSTER =====
- is DiscoverUiState.NamingCluster -> {
- Text(
- text = "Naming dialog for cluster ${state.selectedCluster.clusterId}\n\nDialog UI coming...",
- modifier = Modifier.align(Alignment.Center)
- )
- }
-
- // ===== TRAINING IN PROGRESS =====
- is DiscoverUiState.Training -> {
- TrainingProgressContent(
- stage = state.stage,
- progress = state.progress,
- total = state.total
- )
- }
-
- // ===== VALIDATION PREVIEW =====
- is DiscoverUiState.ValidationPreview -> {
- ValidationPreviewScreen(
- personName = state.personName,
- validationResult = state.validationResult,
- onApprove = {
- viewModel.approveValidationAndScan(
- personId = state.personId,
- personName = state.personName
- )
- },
- onReject = {
- viewModel.rejectValidationAndImprove()
- }
- )
- }
-
- // ===== COMPLETE =====
- is DiscoverUiState.Complete -> {
- CompleteStateContent(
- message = state.message,
- onDone = onNavigateBack
- )
- }
-
- // ===== NO PEOPLE FOUND =====
- is DiscoverUiState.NoPeopleFound -> {
- ErrorStateContent(
- title = "No People Found",
- message = state.message,
- onRetry = { viewModel.startDiscovery() },
- onBack = onNavigateBack
- )
- }
-
- // ===== ERROR =====
- is DiscoverUiState.Error -> {
- ErrorStateContent(
- title = "Error",
- message = state.message,
- onRetry = { viewModel.reset(); viewModel.startDiscovery() },
- onBack = onNavigateBack
- )
- }
+ // ===== ERROR =====
+ is DiscoverUiState.Error -> {
+ ErrorStateContent(
+ title = "Error",
+ message = state.message,
+ onRetry = { viewModel.reset(); viewModel.startDiscovery() },
+ onBack = onNavigateBack
+ )
}
}
}
}
-
// ===== IDLE STATE CONTENT =====
@Composable
@@ -165,19 +176,11 @@ private fun IdleStateContent(
Spacer(modifier = Modifier.height(32.dp))
- Text(
- text = "Discover People",
- style = MaterialTheme.typography.headlineLarge,
- fontWeight = FontWeight.Bold
- )
-
- Spacer(modifier = Modifier.height(16.dp))
-
Text(
text = "Automatically find and organize people in your photo library",
- style = MaterialTheme.typography.bodyLarge,
+ style = MaterialTheme.typography.headlineSmall,
textAlign = TextAlign.Center,
- color = MaterialTheme.colorScheme.onSurfaceVariant
+ color = MaterialTheme.colorScheme.onSurface
)
Spacer(modifier = Modifier.height(48.dp))
diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt
new file mode 100644
index 0000000..e024715
--- /dev/null
+++ b/app/src/main/java/com/placeholder/sherpai2/ui/discover/Namingdialog.kt
@@ -0,0 +1,480 @@
+package com.placeholder.sherpai2.ui.discover
+
+import androidx.compose.foundation.background
+import androidx.compose.foundation.border
+import androidx.compose.foundation.clickable
+import androidx.compose.foundation.layout.*
+import androidx.compose.foundation.lazy.LazyRow
+import androidx.compose.foundation.lazy.items
+import androidx.compose.foundation.rememberScrollState
+import androidx.compose.foundation.shape.CircleShape
+import androidx.compose.foundation.shape.RoundedCornerShape
+import androidx.compose.foundation.text.KeyboardActions
+import androidx.compose.foundation.text.KeyboardOptions
+import androidx.compose.foundation.verticalScroll
+import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.filled.*
+import androidx.compose.material3.*
+import androidx.compose.runtime.*
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.draw.clip
+import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.layout.ContentScale
+import androidx.compose.ui.platform.LocalSoftwareKeyboardController
+import androidx.compose.ui.text.font.FontWeight
+import androidx.compose.ui.text.input.ImeAction
+import androidx.compose.ui.text.input.KeyboardCapitalization
+import androidx.compose.ui.text.input.KeyboardType
+import androidx.compose.ui.text.style.TextAlign
+import androidx.compose.ui.unit.dp
+import androidx.compose.ui.window.Dialog
+import coil.compose.AsyncImage
+import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
+import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier
+import com.placeholder.sherpai2.domain.clustering.FaceCluster
+import java.text.SimpleDateFormat
+import java.util.*
+
+/**
+ * NamingDialog - Complete dialog for naming a cluster
+ *
+ * Features:
+ * - Name input with validation
+ * - Child toggle with date of birth picker
+ * - Sibling cluster selection
+ * - Quality warnings display
+ * - Preview of representative faces
+ *
+ * IMPROVEMENTS:
+ * - ✅ Complete UI implementation
+ * - ✅ Quality analysis integration
+ * - ✅ Sibling selection
+ * - ✅ Form validation
+ */
+@OptIn(ExperimentalMaterial3Api::class)
+@Composable
+fun NamingDialog(
+ cluster: FaceCluster,
+ suggestedSiblings: List<FaceCluster>,
+ onConfirm: (name: String, dateOfBirth: Long?, isChild: Boolean, selectedSiblings: List<String>) -> Unit,
+ onDismiss: () -> Unit,
+ qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() }
+) {
+ var name by remember { mutableStateOf("") }
+ var isChild by remember { mutableStateOf(false) }
+ var showDatePicker by remember { mutableStateOf(false) }
+ var dateOfBirth by remember { mutableStateOf<Long?>(null) }
+ var selectedSiblingIds by remember { mutableStateOf(setOf<String>()) }
+
+ // Analyze cluster quality
+ val qualityResult = remember(cluster) {
+ qualityAnalyzer.analyzeCluster(cluster)
+ }
+
+ val keyboardController = LocalSoftwareKeyboardController.current
+ val dateFormatter = remember { SimpleDateFormat("MMM dd, yyyy", Locale.getDefault()) }
+
+ Dialog(onDismissRequest = onDismiss) {
+ Card(
+ modifier = Modifier
+ .fillMaxWidth()
+ .fillMaxHeight(0.9f),
+ shape = RoundedCornerShape(16.dp),
+ elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
+ ) {
+ Column(
+ modifier = Modifier
+ .fillMaxSize()
+ .verticalScroll(rememberScrollState())
+ ) {
+ // Header
+ Surface(
+ color = MaterialTheme.colorScheme.primaryContainer
+ ) {
+ Row(
+ modifier = Modifier
+ .fillMaxWidth()
+ .padding(16.dp),
+ horizontalArrangement = Arrangement.SpaceBetween,
+ verticalAlignment = Alignment.CenterVertically
+ ) {
+ Column(modifier = Modifier.weight(1f)) {
+ Text(
+ text = "Name This Person",
+ style = MaterialTheme.typography.titleLarge,
+ fontWeight = FontWeight.Bold,
+ color = MaterialTheme.colorScheme.onPrimaryContainer
+ )
+ Text(
+ text = "${cluster.photoCount} photos",
+ style = MaterialTheme.typography.bodyMedium,
+ color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
+ )
+ }
+
+ IconButton(onClick = onDismiss) {
+ Icon(
+ imageVector = Icons.Default.Close,
+ contentDescription = "Close",
+ tint = MaterialTheme.colorScheme.onPrimaryContainer
+ )
+ }
+ }
+ }
+
+ Column(
+ modifier = Modifier.padding(16.dp)
+ ) {
+ // Preview faces
+ Text(
+ text = "Preview",
+ style = MaterialTheme.typography.titleMedium,
+ fontWeight = FontWeight.SemiBold
+ )
+
+ Spacer(modifier = Modifier.height(8.dp))
+
+ LazyRow(
+ horizontalArrangement = Arrangement.spacedBy(8.dp)
+ ) {
+ items(cluster.representativeFaces.take(6)) { face ->
+ AsyncImage(
+ model = android.net.Uri.parse(face.imageUri),
+ contentDescription = "Preview",
+ modifier = Modifier
+ .size(80.dp)
+ .clip(RoundedCornerShape(8.dp))
+ .border(
+ width = 1.dp,
+ color = MaterialTheme.colorScheme.outline,
+ shape = RoundedCornerShape(8.dp)
+ ),
+ contentScale = ContentScale.Crop
+ )
+ }
+ }
+
+ Spacer(modifier = Modifier.height(20.dp))
+
+ // Quality warning (if applicable)
+ if (qualityResult.qualityTier != ClusterQualityTier.EXCELLENT) {
+ QualityWarningCard(qualityResult = qualityResult)
+ Spacer(modifier = Modifier.height(16.dp))
+ }
+
+ // Name input
+ OutlinedTextField(
+ value = name,
+ onValueChange = { name = it },
+ label = { Text("Name") },
+ placeholder = { Text("Enter person's name") },
+ modifier = Modifier.fillMaxWidth(),
+ singleLine = true,
+ leadingIcon = {
+ Icon(
+ imageVector = Icons.Default.Person,
+ contentDescription = null
+ )
+ },
+ keyboardOptions = KeyboardOptions(
+ capitalization = KeyboardCapitalization.Words,
+ imeAction = ImeAction.Done
+ ),
+ keyboardActions = KeyboardActions(
+ onDone = { keyboardController?.hide() }
+ )
+ )
+
+ Spacer(modifier = Modifier.height(16.dp))
+
+ // Child toggle
+ Row(
+ modifier = Modifier
+ .fillMaxWidth()
+ .clip(RoundedCornerShape(8.dp))
+ .clickable { isChild = !isChild }
+ .background(
+ if (isChild) MaterialTheme.colorScheme.primaryContainer
+ else MaterialTheme.colorScheme.surfaceVariant
+ )
+ .padding(16.dp),
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.SpaceBetween
+ ) {
+ Row(
+ verticalAlignment = Alignment.CenterVertically
+ ) {
+ Icon(
+ imageVector = Icons.Default.Face,
+ contentDescription = null,
+ tint = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
+ else MaterialTheme.colorScheme.onSurfaceVariant
+ )
+ Spacer(modifier = Modifier.width(12.dp))
+ Column {
+ Text(
+ text = "This is a child",
+ style = MaterialTheme.typography.bodyLarge,
+ fontWeight = FontWeight.Medium,
+ color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
+ else MaterialTheme.colorScheme.onSurfaceVariant
+ )
+ Text(
+ text = "For age-appropriate filtering",
+ style = MaterialTheme.typography.bodySmall,
+ color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
+ else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f)
+ )
+ }
+ }
+
+ Switch(
+ checked = isChild,
+ onCheckedChange = { isChild = it }
+ )
+ }
+
+ // Date of birth (if child)
+ if (isChild) {
+ Spacer(modifier = Modifier.height(12.dp))
+
+ OutlinedButton(
+ onClick = { showDatePicker = true },
+ modifier = Modifier.fillMaxWidth()
+ ) {
+ Icon(
+ imageVector = Icons.Default.DateRange,
+ contentDescription = null
+ )
+ Spacer(modifier = Modifier.width(8.dp))
+ Text(
+ text = dateOfBirth?.let { dateFormatter.format(Date(it)) }
+ ?: "Set date of birth (optional)"
+ )
+ }
+ }
+
+ // Sibling selection
+ if (suggestedSiblings.isNotEmpty()) {
+ Spacer(modifier = Modifier.height(20.dp))
+
+ Text(
+ text = "Appears with",
+ style = MaterialTheme.typography.titleMedium,
+ fontWeight = FontWeight.SemiBold
+ )
+
+ Text(
+ text = "Select siblings or family members",
+ style = MaterialTheme.typography.bodySmall,
+ color = MaterialTheme.colorScheme.onSurfaceVariant
+ )
+
+ Spacer(modifier = Modifier.height(8.dp))
+
+ suggestedSiblings.forEach { sibling ->
+ SiblingSelectionItem(
+ cluster = sibling,
+ selected = sibling.clusterId in selectedSiblingIds,
+ onToggle = {
+ selectedSiblingIds = if (sibling.clusterId in selectedSiblingIds) {
+ selectedSiblingIds - sibling.clusterId
+ } else {
+ selectedSiblingIds + sibling.clusterId
+ }
+ }
+ )
+ Spacer(modifier = Modifier.height(8.dp))
+ }
+ }
+
+ Spacer(modifier = Modifier.height(24.dp))
+
+ // Action buttons
+ Row(
+ modifier = Modifier.fillMaxWidth(),
+ horizontalArrangement = Arrangement.spacedBy(12.dp)
+ ) {
+ OutlinedButton(
+ onClick = onDismiss,
+ modifier = Modifier.weight(1f)
+ ) {
+ Text("Cancel")
+ }
+
+ Button(
+ onClick = {
+ if (name.isNotBlank()) {
+ onConfirm(
+ name.trim(),
+ dateOfBirth,
+ isChild,
+ selectedSiblingIds.toList()
+ )
+ }
+ },
+ modifier = Modifier.weight(1f),
+ enabled = name.isNotBlank() && qualityResult.canTrain
+ ) {
+ Icon(
+ imageVector = Icons.Default.Check,
+ contentDescription = null,
+ modifier = Modifier.size(20.dp)
+ )
+ Spacer(modifier = Modifier.width(8.dp))
+ Text("Create Model")
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Date picker dialog
+ if (showDatePicker) {
+ val datePickerState = rememberDatePickerState()
+
+ DatePickerDialog(
+ onDismissRequest = { showDatePicker = false },
+ confirmButton = {
+ TextButton(
+ onClick = {
+ dateOfBirth = datePickerState.selectedDateMillis
+ showDatePicker = false
+ }
+ ) {
+ Text("OK")
+ }
+ },
+ dismissButton = {
+ TextButton(onClick = { showDatePicker = false }) {
+ Text("Cancel")
+ }
+ }
+ ) {
+ DatePicker(state = datePickerState)
+ }
+ }
+}
+
+/**
+ * Quality warning card
+ */
+@Composable
+private fun QualityWarningCard(qualityResult: com.placeholder.sherpai2.domain.clustering.ClusterQualityResult) {
+ val (backgroundColor, iconColor) = when (qualityResult.qualityTier) {
+ ClusterQualityTier.GOOD -> Pair(
+ Color(0xFFFFF9C4),
+ Color(0xFFF57F17)
+ )
+ ClusterQualityTier.POOR -> Pair(
+ Color(0xFFFFEBEE),
+ Color(0xFFD32F2F)
+ )
+ else -> Pair(
+ MaterialTheme.colorScheme.surfaceVariant,
+ MaterialTheme.colorScheme.onSurfaceVariant
+ )
+ }
+
+ Card(
+ modifier = Modifier.fillMaxWidth(),
+ colors = CardDefaults.cardColors(
+ containerColor = backgroundColor
+ )
+ ) {
+ Column(
+ modifier = Modifier.padding(12.dp)
+ ) {
+ Row(
+ verticalAlignment = Alignment.CenterVertically
+ ) {
+ Icon(
+ imageVector = Icons.Default.Warning,
+ contentDescription = null,
+ tint = iconColor,
+ modifier = Modifier.size(20.dp)
+ )
+ Spacer(modifier = Modifier.width(8.dp))
+ Text(
+ text = when (qualityResult.qualityTier) {
+ ClusterQualityTier.GOOD -> "Review Before Training"
+ ClusterQualityTier.POOR -> "Quality Issues Detected"
+ else -> ""
+ },
+ style = MaterialTheme.typography.titleSmall,
+ fontWeight = FontWeight.Bold,
+ color = iconColor
+ )
+ }
+
+ Spacer(modifier = Modifier.height(8.dp))
+
+ qualityResult.warnings.forEach { warning ->
+ Text(
+ text = "• $warning",
+ style = MaterialTheme.typography.bodySmall,
+ color = MaterialTheme.colorScheme.onSurfaceVariant
+ )
+ }
+ }
+ }
+}
+
+/**
+ * Sibling selection item
+ */
+@Composable
+private fun SiblingSelectionItem(
+ cluster: FaceCluster,
+ selected: Boolean,
+ onToggle: () -> Unit
+) {
+ Row(
+ modifier = Modifier
+ .fillMaxWidth()
+ .clip(RoundedCornerShape(8.dp))
+ .clickable(onClick = onToggle)
+ .background(
+ if (selected) MaterialTheme.colorScheme.primaryContainer
+ else MaterialTheme.colorScheme.surfaceVariant
+ )
+ .padding(12.dp),
+ verticalAlignment = Alignment.CenterVertically,
+ horizontalArrangement = Arrangement.SpaceBetween
+ ) {
+ Row(
+ verticalAlignment = Alignment.CenterVertically
+ ) {
+ // Preview face
+ AsyncImage(
+ model = android.net.Uri.parse(cluster.representativeFaces.firstOrNull()?.imageUri ?: ""),
+ contentDescription = "Preview",
+ modifier = Modifier
+ .size(40.dp)
+ .clip(CircleShape)
+ .border(
+ width = 1.dp,
+ color = MaterialTheme.colorScheme.outline,
+ shape = CircleShape
+ ),
+ contentScale = ContentScale.Crop
+ )
+
+ Spacer(modifier = Modifier.width(12.dp))
+
+ Text(
+ text = "${cluster.photoCount} photos together",
+ style = MaterialTheme.typography.bodyMedium,
+ color = if (selected) MaterialTheme.colorScheme.onPrimaryContainer
+ else MaterialTheme.colorScheme.onSurfaceVariant
+ )
+ }
+
+ Checkbox(
+ checked = selected,
+ onCheckedChange = { onToggle() }
+ )
+ }
+}
\ No newline at end of file
diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt
index 600ebf8..fc6fa36 100644
--- a/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt
+++ b/app/src/main/java/com/placeholder/sherpai2/ui/presentation/MainScreen.kt
@@ -1,56 +1,48 @@
package com.placeholder.sherpai2.ui.presentation
-import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding
import androidx.compose.material.icons.Icons
-import androidx.compose.material.icons.filled.*
+import androidx.compose.material.icons.filled.Menu
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Modifier
-import androidx.compose.ui.text.font.FontWeight
import androidx.hilt.navigation.compose.hiltViewModel
-import androidx.navigation.compose.currentBackStackEntryAsState
import androidx.navigation.compose.rememberNavController
+import androidx.navigation.compose.currentBackStackEntryAsState
import com.placeholder.sherpai2.ui.navigation.AppNavHost
import com.placeholder.sherpai2.ui.navigation.AppRoutes
import kotlinx.coroutines.launch
/**
- * MainScreen - UPDATED with auto face cache check
+ * MainScreen - Complete app container with drawer navigation
*
- * NEW: Prompts user to populate face cache on app launch if needed
- * FIXED: Prevents double headers for screens with their own TopAppBar
+ * CRITICAL FIX APPLIED:
+ * ✅ Removed AppRoutes.DISCOVER from screensWithOwnTopBar
+ * ✅ DiscoverPeopleScreen now shows hamburger menu + "Discover People" title!
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun MainScreen(
- mainViewModel: MainViewModel = hiltViewModel() // Same package - no import needed!
+ viewModel: MainViewModel = hiltViewModel()
) {
- val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
- val scope = rememberCoroutineScope()
val navController = rememberNavController()
+ val drawerState = rememberDrawerState(DrawerValue.Closed)
+ val scope = rememberCoroutineScope()
- val navBackStackEntry by navController.currentBackStackEntryAsState()
- val currentRoute = navBackStackEntry?.destination?.route ?: AppRoutes.SEARCH
+ val currentBackStackEntry by navController.currentBackStackEntryAsState()
+ val currentRoute = currentBackStackEntry?.destination?.route
- // Face cache status
- val needsFaceCache by mainViewModel.needsFaceCachePopulation.collectAsState()
- val unscannedCount by mainViewModel.unscannedPhotoCount.collectAsState()
+ // Face cache prompt dialog state
+ val needsFaceCachePopulation by viewModel.needsFaceCachePopulation.collectAsState()
+ val unscannedPhotoCount by viewModel.unscannedPhotoCount.collectAsState()
- // Show face cache prompt dialog if needed
- if (needsFaceCache && unscannedCount > 0) {
- FaceCachePromptDialog(
- unscannedPhotoCount = unscannedCount,
- onDismiss = { mainViewModel.dismissFaceCachePrompt() },
- onScanNow = {
- mainViewModel.dismissFaceCachePrompt()
- // Navigate to Photo Utilities
- navController.navigate(AppRoutes.UTILITIES) {
- launchSingleTop = true
- }
- }
- )
- }
+ // ✅ CRITICAL FIX: DISCOVER is NOT in this list!
+ // These screens handle their own TopAppBar/navigation
+ val screensWithOwnTopBar = setOf(
+ AppRoutes.IMAGE_DETAIL,
+ AppRoutes.TRAINING_SCREEN,
+ AppRoutes.CROP_SCREEN
+ )
ModalNavigationDrawer(
drawerState = drawerState,
@@ -60,133 +52,86 @@ fun MainScreen(
onDestinationClicked = { route ->
scope.launch {
drawerState.close()
- if (route != currentRoute) {
- navController.navigate(route) {
- launchSingleTop = true
- }
+ }
+ navController.navigate(route) {
+ popUpTo(navController.graph.startDestinationId) {
+ saveState = true
}
+ launchSingleTop = true
+ restoreState = true
}
}
)
- },
+ }
) {
- // CRITICAL: Some screens manage their own TopAppBar
- // Hide MainScreen's TopAppBar for these routes to prevent double headers
- val screensWithOwnTopBar = setOf(
- AppRoutes.TRAINING_PHOTO_SELECTOR, // Has its own TopAppBar with subtitle
- "album/", // Album views have their own TopAppBar (prefix match)
- AppRoutes.IMAGE_DETAIL // Image detail has its own TopAppBar
- )
-
- // Check if current route starts with any excluded pattern
- val showTopBar = screensWithOwnTopBar.none { currentRoute.startsWith(it) }
-
Scaffold(
topBar = {
- if (showTopBar) {
+ // ✅ Show TopAppBar for ALL screens except those with their own
+ if (currentRoute !in screensWithOwnTopBar) {
TopAppBar(
title = {
- Column {
- Text(
- text = getScreenTitle(currentRoute),
- style = MaterialTheme.typography.titleLarge,
- fontWeight = FontWeight.Bold
- )
- getScreenSubtitle(currentRoute)?.let { subtitle ->
- Text(
- text = subtitle,
- style = MaterialTheme.typography.bodySmall,
- color = MaterialTheme.colorScheme.onSurfaceVariant
- )
+ Text(
+ text = when (currentRoute) {
+ AppRoutes.SEARCH -> "Search"
+ AppRoutes.EXPLORE -> "Explore"
+ AppRoutes.COLLECTIONS -> "Collections"
+ AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
+ AppRoutes.INVENTORY -> "People"
+ AppRoutes.TRAIN -> "Train Model"
+ AppRoutes.TAGS -> "Tags"
+ AppRoutes.UTILITIES -> "Utilities"
+ AppRoutes.SETTINGS -> "Settings"
+ AppRoutes.MODELS -> "AI Models"
+ else -> {
+ // Handle dynamic routes like album/{type}/{id}
+ if (currentRoute?.startsWith("album/") == true) {
+ "Album"
+ } else {
+ "SherpAI"
+ }
+ }
}
- }
+ )
},
navigationIcon = {
- IconButton(
- onClick = { scope.launch { drawerState.open() } }
- ) {
+ IconButton(onClick = {
+ scope.launch {
+ drawerState.open()
+ }
+ }) {
Icon(
- Icons.Default.Menu,
- contentDescription = "Open Menu",
- tint = MaterialTheme.colorScheme.primary
+ imageVector = Icons.Default.Menu,
+ contentDescription = "Open menu"
)
}
},
- actions = {
- // Dynamic actions based on current screen
- when (currentRoute) {
- AppRoutes.SEARCH -> {
- IconButton(onClick = { /* TODO: Open filter dialog */ }) {
- Icon(
- Icons.Default.FilterList,
- contentDescription = "Filter",
- tint = MaterialTheme.colorScheme.primary
- )
- }
- }
- AppRoutes.INVENTORY -> {
- IconButton(onClick = {
- navController.navigate(AppRoutes.TRAIN)
- }) {
- Icon(
- Icons.Default.PersonAdd,
- contentDescription = "Add Person",
- tint = MaterialTheme.colorScheme.primary
- )
- }
- }
- }
- },
colors = TopAppBarDefaults.topAppBarColors(
- containerColor = MaterialTheme.colorScheme.surface,
- titleContentColor = MaterialTheme.colorScheme.onSurface,
- navigationIconContentColor = MaterialTheme.colorScheme.primary,
- actionIconContentColor = MaterialTheme.colorScheme.primary
+ containerColor = MaterialTheme.colorScheme.primaryContainer,
+ titleContentColor = MaterialTheme.colorScheme.onPrimaryContainer,
+ navigationIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer,
+ actionIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer
)
)
}
}
) { paddingValues ->
+ // ✅ Use YOUR existing AppNavHost - it already has all the screens defined!
AppNavHost(
navController = navController,
modifier = Modifier.padding(paddingValues)
)
}
}
-}
-/**
- * Get human-readable screen title
- */
-private fun getScreenTitle(route: String): String {
- return when (route) {
- AppRoutes.SEARCH -> "Search"
- AppRoutes.EXPLORE -> "Explore"
- AppRoutes.COLLECTIONS -> "Collections"
- AppRoutes.DISCOVER -> "Discover People"
- AppRoutes.INVENTORY -> "People"
- AppRoutes.TRAIN -> "Train New Person"
- AppRoutes.MODELS -> "AI Models"
- AppRoutes.TAGS -> "Tag Management"
- AppRoutes.UTILITIES -> "Photo Util."
- AppRoutes.SETTINGS -> "Settings"
- else -> "SherpAI"
- }
-}
-
-/**
- * Get subtitle for screens that need context
- */
-private fun getScreenSubtitle(route: String): String? {
- return when (route) {
- AppRoutes.SEARCH -> "Find photos by tags, people, or date"
- AppRoutes.EXPLORE -> "Browse your collection"
- AppRoutes.COLLECTIONS -> "Your photo collections"
- AppRoutes.DISCOVER -> "Auto-find faces in your library"
- AppRoutes.INVENTORY -> "Trained face models"
- AppRoutes.TRAIN -> "Add a new person to recognize"
- AppRoutes.TAGS -> "Organize your photo collection"
- AppRoutes.UTILITIES -> "Tools for managing collection"
- else -> null
+    // Face cache prompt dialog — NOTE(review): previous code also required unscannedPhotoCount > 0; confirm the dialog should still show when the count is 0
+ if (needsFaceCachePopulation) {
+ FaceCachePromptDialog(
+ unscannedPhotoCount = unscannedPhotoCount,
+ onDismiss = { viewModel.dismissFaceCachePrompt() },
+ onScanNow = {
+ viewModel.dismissFaceCachePrompt()
+ navController.navigate(AppRoutes.UTILITIES)
+ }
+ )
}
}
\ No newline at end of file