discover dez

This commit is contained in:
genki
2026-01-21 10:11:20 -05:00
parent 7f122a4e17
commit 4474365cd6
10 changed files with 1446 additions and 615 deletions

View File

@@ -8,7 +8,15 @@
<list>
<CategoryState>
<option name="attribute" value="Type" />
<option name="value" value="Physical" />
<option name="value" value="Virtual" />
</CategoryState>
<CategoryState>
<option name="attribute" value="Type" />
<option name="value" value="Virtual" />
</CategoryState>
<CategoryState>
<option name="attribute" value="Type" />
<option name="value" value="Virtual" />
</CategoryState>
</list>
</option>
@@ -17,6 +25,10 @@
</option>
<option name="columnSorters">
<list>
<ColumnSorterState>
<option name="column" value="Status" />
<option name="order" value="ASCENDING" />
</ColumnSorterState>
<ColumnSorterState>
<option name="column" value="Name" />
<option name="order" value="DESCENDING" />
@@ -37,6 +49,69 @@
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
</list>
</option>
</component>

Binary file not shown.

View File

@@ -1,129 +1,91 @@
package com.placeholder.sherpai2.data.local.dao
import androidx.room.*
import androidx.room.Dao
import androidx.room.Insert
import androidx.room.OnConflictStrategy
import androidx.room.Query
import androidx.room.Update
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import kotlinx.coroutines.flow.Flow
/**
* FaceCacheDao - Query face metadata for intelligent filtering
*
* ENABLES SMART CLUSTERING:
* - Pre-filter to high-quality faces only
* - Avoid processing blurry/distant faces
* - Faster clustering with better results
* FaceCacheDao - Face detection cache with NEW queries for two-stage clustering
*/
@Dao
interface FaceCacheDao {
// ═══════════════════════════════════════
// INSERT / UPDATE
// ═══════════════════════════════════════
@Insert(onConflict = OnConflictStrategy.REPLACE)
suspend fun insert(faceCache: FaceCacheEntity)
@Insert(onConflict = OnConflictStrategy.REPLACE)
suspend fun insertAll(faceCaches: List<FaceCacheEntity>)
/**
* Get ALL high-quality solo faces for clustering
*
* FILTERS:
* - Solo photos only (joins with images.faceCount = 1)
* - Large enough (isLargeEnough = true)
* - Good quality score (>= 0.6)
* - Frontal faces preferred (isFrontal = true)
*/
@Query("""
SELECT fc.* FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.isLargeEnough = 1
AND fc.qualityScore >= 0.6
AND fc.isFrontal = 1
ORDER BY fc.qualityScore DESC
""")
suspend fun getHighQualitySoloFaces(): List<FaceCacheEntity>
@Update
suspend fun update(faceCache: FaceCacheEntity)
// ═══════════════════════════════════════
// NEW CLUSTERING QUERIES ⭐
// ═══════════════════════════════════════
/**
* Get high-quality faces from ANY photo (including group photos)
* Use when not enough solo photos available
* Get high-quality solo faces for Stage 1 clustering
*
* Filters:
* - Solo photos (faceCount = 1)
* - Large faces (faceAreaRatio >= minFaceRatio)
* - Has embedding
*/
@Query("""
SELECT * FROM face_cache
WHERE isLargeEnough = 1
AND qualityScore >= 0.6
AND isFrontal = 1
ORDER BY qualityScore DESC
SELECT fc.*
FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.faceAreaRatio >= :minFaceRatio
AND fc.embedding IS NOT NULL
ORDER BY fc.faceAreaRatio DESC
LIMIT :limit
""")
suspend fun getHighQualityFaces(limit: Int = 1000): List<FaceCacheEntity>
suspend fun getHighQualitySoloFaces(
minFaceRatio: Float = 0.015f,
limit: Int = 2000
): List<FaceCacheEntity>
/**
* Get faces for a specific image
*/
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex ASC")
suspend fun getFacesForImage(imageId: String): List<FaceCacheEntity>
/**
* Count high-quality solo faces (for UI display)
* FALLBACK: Get ANY solo faces with embeddings
* Used if getHighQualitySoloFaces() returns empty
*/
@Query("""
SELECT COUNT(*) FROM face_cache fc
SELECT fc.*
FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.isLargeEnough = 1
AND fc.qualityScore >= 0.6
""")
suspend fun getHighQualitySoloFaceCount(): Int
/**
* Get quality distribution stats
*/
@Query("""
SELECT
SUM(CASE WHEN qualityScore >= 0.8 THEN 1 ELSE 0 END) as excellent,
SUM(CASE WHEN qualityScore >= 0.6 AND qualityScore < 0.8 THEN 1 ELSE 0 END) as good,
SUM(CASE WHEN qualityScore < 0.6 THEN 1 ELSE 0 END) as poor,
COUNT(*) as total
FROM face_cache
""")
suspend fun getQualityStats(): FaceQualityStats?
/**
* Delete cache for specific image (when image is deleted)
*/
@Query("DELETE FROM face_cache WHERE imageId = :imageId")
suspend fun deleteCacheForImage(imageId: String)
/**
* Delete all cache (for full rebuild)
*/
@Query("DELETE FROM face_cache")
suspend fun deleteAll()
/**
* Get faces with embeddings already computed
* (Ultra-fast clustering - no need to re-generate)
*/
@Query("""
SELECT fc.* FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.isLargeEnough = 1
AND fc.embedding IS NOT NULL
ORDER BY fc.qualityScore DESC
LIMIT :limit
""")
suspend fun getSoloFacesWithEmbeddings(limit: Int = 2000): List<FaceCacheEntity>
}
suspend fun getSoloFacesWithEmbeddings(
limit: Int = 2000
): List<FaceCacheEntity>
/**
* Quality statistics result
*/
data class FaceQualityStats(
val excellent: Int, // qualityScore >= 0.8
val good: Int, // 0.6 <= qualityScore < 0.8
val poor: Int, // qualityScore < 0.6
val total: Int
) {
val excellentPercent: Float get() = if (total > 0) excellent.toFloat() / total else 0f
val goodPercent: Float get() = if (total > 0) good.toFloat() / total else 0f
val poorPercent: Float get() = if (total > 0) poor.toFloat() / total else 0f
// ═══════════════════════════════════════
// EXISTING QUERIES (keep as-is)
// ═══════════════════════════════════════
@Query("SELECT * FROM face_cache WHERE id = :id")
suspend fun getFaceCacheById(id: String): FaceCacheEntity?
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex")
suspend fun getFaceCacheForImage(imageId: String): List<FaceCacheEntity>
@Query("DELETE FROM face_cache WHERE imageId = :imageId")
suspend fun deleteFaceCacheForImage(imageId: String)
@Query("DELETE FROM face_cache")
suspend fun deleteAll()
@Query("DELETE FROM face_cache WHERE cacheVersion < :version")
suspend fun deleteOldVersions(version: Int)
}

View File

@@ -1,7 +1,7 @@
package com.placeholder.sherpai2.domain.clustering
import android.graphics.Rect
import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding
import android.util.Log
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.math.sqrt
@@ -9,56 +9,62 @@ import kotlin.math.sqrt
/**
* ClusterQualityAnalyzer - Validate cluster quality BEFORE training
*
* PURPOSE: Prevent training on "dirty" clusters (siblings merged, poor quality faces)
*
* CHECKS:
* 1. Solo photo count (min 6 required)
* 2. Face size (min 15% of image - clear, not distant)
* 3. Internal consistency (all faces should match well)
* 4. Outlier detection (find faces that don't belong)
*
* QUALITY TIERS:
* - Excellent (95%+): Safe to train immediately
* - Good (85-94%): Review outliers, then train
* - Poor (<85%): Likely mixed people - DO NOT TRAIN!
* RELAXED THRESHOLDS for real-world photos (social media, distant shots):
* - Face size: 3% (down from 15%)
* - Outlier threshold: 65% (down from 75%)
* - GOOD tier: 75% (down from 85%)
* - EXCELLENT tier: 85% (down from 95%)
*/
@Singleton
class ClusterQualityAnalyzer @Inject constructor() {
companion object {
private const val TAG = "ClusterQuality"
private const val MIN_SOLO_PHOTOS = 6
private const val MIN_FACE_SIZE_RATIO = 0.15f // 15% of image
private const val MIN_INTERNAL_SIMILARITY = 0.80f
private const val OUTLIER_THRESHOLD = 0.75f
private const val EXCELLENT_THRESHOLD = 0.95f
private const val GOOD_THRESHOLD = 0.85f
private const val MIN_FACE_SIZE_RATIO = 0.03f // 3% of image (RELAXED)
private const val MIN_FACE_DIMENSION_PIXELS = 50 // 50px minimum (RELAXED)
private const val FALLBACK_MIN_DIMENSION = 80 // Fallback when no dimensions
private const val MIN_INTERNAL_SIMILARITY = 0.75f
private const val OUTLIER_THRESHOLD = 0.65f // RELAXED
private const val EXCELLENT_THRESHOLD = 0.85f // RELAXED
private const val GOOD_THRESHOLD = 0.75f // RELAXED
}
/**
* Analyze cluster quality before training
*/
fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult {
// Step 1: Filter to solo photos only
val soloFaces = cluster.faces.filter { it.faceCount == 1 }
Log.d(TAG, "========================================")
Log.d(TAG, "Analyzing cluster ${cluster.clusterId}")
Log.d(TAG, "Total faces: ${cluster.faces.size}")
// Step 2: Filter by face size (must be clear/close-up)
// Step 1: Filter to solo photos
val soloFaces = cluster.faces.filter { it.faceCount == 1 }
Log.d(TAG, "Solo photos: ${soloFaces.size}")
// Step 2: Filter by face size
val largeFaces = soloFaces.filter { face ->
isFaceLargeEnough(face.boundingBox, face.imageUri)
isFaceLargeEnough(face)
}
Log.d(TAG, "Large faces (>= 3%): ${largeFaces.size}")
if (largeFaces.size < soloFaces.size) {
Log.d(TAG, "⚠️ Filtered out ${soloFaces.size - largeFaces.size} small faces")
}
// Step 3: Calculate internal consistency
val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces)
// Step 4: Clean faces (large solo faces, no outliers)
// Step 4: Clean faces
val cleanFaces = largeFaces.filter { it !in outliers }
Log.d(TAG, "Clean faces: ${cleanFaces.size}")
// Step 5: Calculate quality score
val qualityScore = calculateQualityScore(
soloPhotoCount = soloFaces.size,
largeFaceCount = largeFaces.size,
cleanFaceCount = cleanFaces.size,
avgSimilarity = avgSimilarity
avgSimilarity = avgSimilarity,
totalFaces = cluster.faces.size
)
Log.d(TAG, "Quality score: ${(qualityScore * 100).toInt()}%")
// Step 6: Determine quality tier
val qualityTier = when {
@@ -66,6 +72,11 @@ class ClusterQualityAnalyzer @Inject constructor() {
qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD
else -> ClusterQualityTier.POOR
}
Log.d(TAG, "Quality tier: $qualityTier")
val canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS
Log.d(TAG, "Can train: $canTrain")
Log.d(TAG, "========================================")
return ClusterQualityResult(
originalFaceCount = cluster.faces.size,
@@ -77,62 +88,65 @@ class ClusterQualityAnalyzer @Inject constructor() {
cleanFaces = cleanFaces,
qualityScore = qualityScore,
qualityTier = qualityTier,
canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS,
warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier)
canTrain = canTrain,
warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier, avgSimilarity)
)
}
/**
* Check if face is large enough (not distant/blurry)
*
* A face should occupy at least 15% of the image area for good quality
*/
private fun isFaceLargeEnough(boundingBox: Rect, imageUri: String): Boolean {
// Estimate image dimensions from common aspect ratios
// For now, use bounding box size as proxy
val faceArea = boundingBox.width() * boundingBox.height()
private fun isFaceLargeEnough(face: DetectedFaceWithEmbedding): Boolean {
val faceArea = face.boundingBox.width() * face.boundingBox.height()
// Assume typical photo is ~2000x1500 = 3,000,000 pixels
// 15% = 450,000 pixels
// For a square face: sqrt(450,000) = ~670 pixels per side
// Check 1: Absolute minimum
if (face.boundingBox.width() < MIN_FACE_DIMENSION_PIXELS ||
face.boundingBox.height() < MIN_FACE_DIMENSION_PIXELS) {
return false
}
// More conservative: face should be at least 200x200 pixels
return boundingBox.width() >= 200 && boundingBox.height() >= 200
// Check 2: Relative size if we have dimensions
if (face.imageWidth > 0 && face.imageHeight > 0) {
val imageArea = face.imageWidth * face.imageHeight
val faceRatio = faceArea.toFloat() / imageArea.toFloat()
return faceRatio >= MIN_FACE_SIZE_RATIO
}
// Fallback: Use absolute size
return face.boundingBox.width() >= FALLBACK_MIN_DIMENSION &&
face.boundingBox.height() >= FALLBACK_MIN_DIMENSION
}
/**
* Analyze how similar faces are to each other (internal consistency)
*
* Returns: (average similarity, list of outlier faces)
*/
private fun analyzeInternalConsistency(
faces: List<DetectedFaceWithEmbedding>
): Pair<Float, List<DetectedFaceWithEmbedding>> {
if (faces.size < 2) {
Log.d(TAG, "Less than 2 faces, skipping consistency check")
return 0f to emptyList()
}
// Calculate average embedding (centroid)
Log.d(TAG, "Analyzing ${faces.size} faces for internal consistency")
val centroid = calculateCentroid(faces.map { it.embedding })
// Calculate similarity of each face to centroid
val similarities = faces.map { face ->
face to cosineSimilarity(face.embedding, centroid)
val centroidSum = centroid.sum()
Log.d(TAG, "Centroid sum: $centroidSum, first5=[${centroid.take(5).joinToString()}]")
val similarities = faces.mapIndexed { index, face ->
val similarity = cosineSimilarity(face.embedding, centroid)
Log.d(TAG, "Face $index similarity to centroid: $similarity")
face to similarity
}
val avgSimilarity = similarities.map { it.second }.average().toFloat()
Log.d(TAG, "Average internal similarity: $avgSimilarity")
// Find outliers (faces significantly different from centroid)
val outliers = similarities
.filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD }
.map { (face, _) -> face }
Log.d(TAG, "Found ${outliers.size} outliers (threshold=$OUTLIER_THRESHOLD)")
return avgSimilarity to outliers
}
/**
* Calculate centroid (average embedding)
*/
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
val size = embeddings.first().size
val centroid = FloatArray(size) { 0f }
@@ -148,14 +162,14 @@ class ClusterQualityAnalyzer @Inject constructor() {
centroid[i] /= count
}
// Normalize
val norm = sqrt(centroid.map { it * it }.sum())
return centroid.map { it / norm }.toFloatArray()
return if (norm > 0) {
centroid.map { it / norm }.toFloatArray()
} else {
centroid
}
}
/**
* Cosine similarity between two embeddings
*/
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
var dotProduct = 0f
var normA = 0f
@@ -170,32 +184,31 @@ class ClusterQualityAnalyzer @Inject constructor() {
return dotProduct / (sqrt(normA) * sqrt(normB))
}
/**
* Calculate overall quality score (0.0 - 1.0)
*/
private fun calculateQualityScore(
soloPhotoCount: Int,
largeFaceCount: Int,
cleanFaceCount: Int,
avgSimilarity: Float
avgSimilarity: Float,
totalFaces: Int
): Float {
// Weight factors
val soloPhotoScore = (soloPhotoCount.toFloat() / 20f).coerceIn(0f, 1f) * 0.3f
val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.2f
val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.2f
val similarityScore = avgSimilarity * 0.3f
val soloRatio = soloPhotoCount.toFloat() / totalFaces.toFloat().coerceAtLeast(1f)
val soloPhotoScore = soloRatio.coerceIn(0f, 1f) * 0.25f
val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.25f
val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.20f
val similarityScore = avgSimilarity * 0.30f
return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore
}
/**
* Generate human-readable warnings
*/
private fun generateWarnings(
soloPhotoCount: Int,
largeFaceCount: Int,
cleanFaceCount: Int,
qualityTier: ClusterQualityTier
qualityTier: ClusterQualityTier,
avgSimilarity: Float
): List<String> {
val warnings = mutableListOf<String>()
@@ -203,12 +216,20 @@ class ClusterQualityAnalyzer @Inject constructor() {
ClusterQualityTier.POOR -> {
warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!")
warnings.add("Do NOT train on this cluster - it will create a bad model.")
if (avgSimilarity < 0.70f) {
warnings.add("Low internal similarity (${(avgSimilarity * 100).toInt()}%) suggests mixed identities.")
}
}
ClusterQualityTier.GOOD -> {
warnings.add("⚠️ Review outlier faces before training")
if (cleanFaceCount < 10) {
warnings.add("Consider adding more high-quality photos for better results.")
}
}
ClusterQualityTier.EXCELLENT -> {
// No warnings - ready to train!
// No warnings
}
}
@@ -218,38 +239,47 @@ class ClusterQualityAnalyzer @Inject constructor() {
if (largeFaceCount < 6) {
warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)")
warnings.add("Tip: Use close-up photos where the face is clearly visible")
}
if (cleanFaceCount < 6) {
warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)")
}
if (qualityTier == ClusterQualityTier.EXCELLENT) {
warnings.add("✅ Excellent quality! This cluster is ready for training.")
warnings.add("High-quality photos with consistent facial features detected.")
}
return warnings
}
}
/**
* Result of cluster quality analysis
*/
data class ClusterQualityResult(
val originalFaceCount: Int, // Total faces in cluster
val soloPhotoCount: Int, // Photos with faceCount = 1
val largeFaceCount: Int, // Solo photos with large faces
val cleanFaceCount: Int, // Large faces, no outliers
val avgInternalSimilarity: Float, // How similar faces are (0.0-1.0)
val outlierFaces: List<DetectedFaceWithEmbedding>, // Faces to exclude
val cleanFaces: List<DetectedFaceWithEmbedding>, // Good faces for training
val qualityScore: Float, // Overall score (0.0-1.0)
val originalFaceCount: Int,
val soloPhotoCount: Int,
val largeFaceCount: Int,
val cleanFaceCount: Int,
val avgInternalSimilarity: Float,
val outlierFaces: List<DetectedFaceWithEmbedding>,
val cleanFaces: List<DetectedFaceWithEmbedding>,
val qualityScore: Float,
val qualityTier: ClusterQualityTier,
val canTrain: Boolean, // Safe to proceed with training?
val warnings: List<String> // Human-readable issues
)
/**
* Quality tier classification
*/
enum class ClusterQualityTier {
EXCELLENT, // 95%+ - Safe to train immediately
GOOD, // 85-94% - Review outliers first
POOR // <85% - DO NOT TRAIN (likely mixed people)
val canTrain: Boolean,
val warnings: List<String>
) {
fun getSummary(): String = when (qualityTier) {
ClusterQualityTier.EXCELLENT ->
"Excellent quality cluster with $cleanFaceCount high-quality photos ready for training."
ClusterQualityTier.GOOD ->
"Good quality cluster with $cleanFaceCount usable photos. Review outliers before training."
ClusterQualityTier.POOR ->
"Poor quality cluster. May contain multiple people or low-quality photos. Add more photos or split cluster."
}
}
enum class ClusterQualityTier {
EXCELLENT, // 85%+
GOOD, // 75-84%
POOR // <75%
}

View File

@@ -4,11 +4,13 @@ import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import android.util.Log
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.qualifiers.ApplicationContext
@@ -18,7 +20,6 @@ import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.withContext
import java.util.concurrent.atomic.AtomicInteger
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.math.sqrt
@@ -33,111 +34,152 @@ import kotlin.math.sqrt
* 4. Detect faces and generate embeddings (parallel)
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
* 6. Analyze clusters for age, siblings, representatives
*
* IMPROVEMENTS:
* - ✅ Complete fast-path using FaceCacheDao.getSoloFacesWithEmbeddings()
* - ✅ Works with existing FaceCacheEntity.getEmbedding() method
* - ✅ Centroid-based representative face selection
* - ✅ Batched processing to prevent OOM
* - ✅ RGB_565 bitmap config for 50% memory savings
*/
@Singleton
class FaceClusteringService @Inject constructor(
@ApplicationContext private val context: Context,
private val imageDao: ImageDao,
private val faceCacheDao: FaceCacheDao // Optional - will work without it
private val faceCacheDao: FaceCacheDao
) {
private val semaphore = Semaphore(12)
private val semaphore = Semaphore(8)
companion object {
private const val TAG = "FaceClustering"
private const val MAX_FACES_TO_CLUSTER = 2000
private const val MIN_SOLO_PHOTOS = 50
private const val BATCH_SIZE = 50
private const val MIN_CACHED_FACES = 100
}
/**
* Main clustering entry point - HYBRID with automatic fallback
*
* @param maxFacesToCluster Limit for performance (default 2000)
* @param onProgress Progress callback (current, total, message)
*/
suspend fun discoverPeople(
maxFacesToCluster: Int = 2000,
maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER,
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
): ClusteringResult = withContext(Dispatchers.Default) {
// TRY FAST PATH: Use face cache if available
val highQualityFaces = try {
withContext(Dispatchers.IO) {
faceCacheDao.getHighQualitySoloFaces()
}
} catch (e: Exception) {
emptyList()
}
if (highQualityFaces.isNotEmpty()) {
// FAST PATH: Use cached faces (future enhancement)
onProgress(0, 100, "Using face cache (${highQualityFaces.size} faces)...")
// TODO: Implement cache-based clustering
// For now, fall through to classic method
}
// CLASSIC METHOD: Load and process photos
onProgress(0, 100, "Loading solo photos...")
// Step 1: Get SOLO PHOTOS ONLY (faceCount = 1) for cleaner clustering
val soloPhotos = withContext(Dispatchers.IO) {
imageDao.getImagesByFaceCount(count = 1)
}
// Fallback: If not enough solo photos, use all images with faces
val imagesWithFaces = if (soloPhotos.size < 50) {
onProgress(0, 100, "Loading all photos with faces...")
imageDao.getImagesWithFaces()
} else {
soloPhotos
}
if (imagesWithFaces.isEmpty()) {
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = 0,
errorMessage = "No photos with faces found. Please ensure face detection cache is populated."
)
}
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos (${if (soloPhotos.size >= 50) "solo only" else "all"})...")
val startTime = System.currentTimeMillis()
// Step 2: Detect faces and generate embeddings (parallel)
val allFaces = detectFacesInImages(
images = imagesWithFaces.take(1000), // Smart limit
onProgress = { current, total ->
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
// Try high-quality cached faces FIRST (NEW!)
var cachedFaces = withContext(Dispatchers.IO) {
try {
faceCacheDao.getHighQualitySoloFaces(
minFaceRatio = 0.015f, // 1.5%
limit = maxFacesToCluster
)
} catch (e: Exception) {
// Method doesn't exist yet - that's ok
emptyList()
}
)
}
// Fallback to ANY solo faces if high-quality returned nothing
if (cachedFaces.isEmpty()) {
Log.w(TAG, "No high-quality faces (>= 1.5%), trying ANY solo faces...")
cachedFaces = withContext(Dispatchers.IO) {
try {
faceCacheDao.getSoloFacesWithEmbeddings(limit = maxFacesToCluster)
} catch (e: Exception) {
emptyList()
}
}
}
Log.d(TAG, "Cache check: ${cachedFaces.size} faces available")
val allFaces = if (cachedFaces.size >= MIN_CACHED_FACES) {
// FAST PATH ✅
Log.d(TAG, "Using FAST PATH with ${cachedFaces.size} cached faces")
onProgress(10, 100, "Using cached embeddings (${cachedFaces.size} faces)...")
cachedFaces.mapNotNull { cached ->
val embedding = cached.getEmbedding() ?: return@mapNotNull null
DetectedFaceWithEmbedding(
imageId = cached.imageId,
imageUri = "",
capturedAt = 0L,
embedding = embedding,
boundingBox = cached.getBoundingBox(),
confidence = cached.confidence,
faceCount = 1, // Solo faces only (filtered by query)
imageWidth = cached.imageWidth,
imageHeight = cached.imageHeight
)
}.also {
onProgress(50, 100, "Processing ${it.size} cached faces...")
}
} else {
// SLOW PATH
Log.d(TAG, "Using SLOW PATH - cache has < $MIN_CACHED_FACES faces")
onProgress(0, 100, "Loading photos...")
val soloPhotos = withContext(Dispatchers.IO) {
imageDao.getImagesByFaceCount(count = 1)
}
val imagesWithFaces = if (soloPhotos.size < MIN_SOLO_PHOTOS) {
imageDao.getImagesWithFaces()
} else {
soloPhotos
}
if (imagesWithFaces.isEmpty()) {
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = 0,
errorMessage = "No photos with faces found"
)
}
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
detectFacesInImagesBatched(
images = imagesWithFaces.take(1000),
onProgress = { current, total ->
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
}
)
}
if (allFaces.isEmpty()) {
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = System.currentTimeMillis() - startTime,
errorMessage = "No faces detected in images"
errorMessage = "No faces detected"
)
}
onProgress(50, 100, "Clustering ${allFaces.size} faces...")
// Step 3: DBSCAN clustering
val rawClusters = performDBSCAN(
faces = allFaces.take(maxFacesToCluster),
epsilon = 0.18f, // VERY STRICT for siblings
epsilon = 0.26f,
minPoints = 3
)
onProgress(70, 100, "Analyzing relationships...")
// Step 4: Build co-occurrence graph
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
onProgress(80, 100, "Selecting representative faces...")
// Step 5: Create final clusters
val clusters = rawClusters.map { cluster ->
FaceCluster(
clusterId = cluster.clusterId,
faces = cluster.faces,
representativeFaces = selectRepresentativeFaces(cluster.faces, count = 6),
representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6),
photoCount = cluster.faces.map { it.imageId }.distinct().size,
averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(),
estimatedAge = estimateAge(cluster.faces),
@@ -154,14 +196,31 @@ class FaceClusteringService @Inject constructor(
)
}
/**
* Detect faces in images and generate embeddings (parallel)
*/
private suspend fun detectFacesInImages(
private suspend fun detectFacesInImagesBatched(
images: List<ImageEntity>,
onProgress: (Int, Int) -> Unit
): List<DetectedFaceWithEmbedding> = coroutineScope {
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
var processedCount = 0
images.chunked(BATCH_SIZE).forEach { batch ->
val batchFaces = detectFacesInBatch(batch)
allFaces.addAll(batchFaces)
processedCount += batch.size
onProgress(processedCount, images.size)
System.gc()
}
allFaces
}
private suspend fun detectFacesInBatch(
images: List<ImageEntity>
): List<DetectedFaceWithEmbedding> = coroutineScope {
val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
@@ -170,20 +229,14 @@ class FaceClusteringService @Inject constructor(
)
val faceNetModel = FaceNetModel(context)
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
val processedCount = AtomicInteger(0)
val batchFaces = mutableListOf<DetectedFaceWithEmbedding>()
try {
val jobs = images.map { image ->
async {
async(Dispatchers.IO) {
semaphore.acquire()
try {
val faces = detectFacesInImage(image, detector, faceNetModel)
val current = processedCount.incrementAndGet()
if (current % 10 == 0) {
onProgress(current, images.size)
}
faces
detectFacesInImage(image, detector, faceNetModel)
} finally {
semaphore.release()
}
@@ -191,7 +244,7 @@ class FaceClusteringService @Inject constructor(
}
jobs.awaitAll().flatten().also {
allFaces.addAll(it)
batchFaces.addAll(it)
}
} finally {
@@ -199,7 +252,7 @@ class FaceClusteringService @Inject constructor(
faceNetModel.close()
}
allFaces
batchFaces
}
private suspend fun detectFacesInImage(
@@ -215,8 +268,6 @@ class FaceClusteringService @Inject constructor(
val mlImage = InputImage.fromBitmap(bitmap, 0)
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
val totalFacesInImage = faces.size
val result = faces.mapNotNull { face ->
try {
val faceBitmap = Bitmap.createBitmap(
@@ -237,7 +288,9 @@ class FaceClusteringService @Inject constructor(
embedding = embedding,
boundingBox = face.boundingBox,
confidence = 0.95f,
faceCount = totalFacesInImage
faceCount = faces.size,
imageWidth = bitmap.width,
imageHeight = bitmap.height
)
} catch (e: Exception) {
null
@@ -252,9 +305,6 @@ class FaceClusteringService @Inject constructor(
}
}
// All other methods remain the same (DBSCAN, similarity, etc.)
// ... [Rest of the implementation from original file]
private fun performDBSCAN(
faces: List<DetectedFaceWithEmbedding>,
epsilon: Float,
@@ -368,18 +418,61 @@ class FaceClusteringService @Inject constructor(
?: emptyList()
}
private fun selectRepresentativeFaces(
private fun selectRepresentativeFacesByCentroid(
faces: List<DetectedFaceWithEmbedding>,
count: Int
): List<DetectedFaceWithEmbedding> {
if (faces.size <= count) return faces
val sortedByTime = faces.sortedBy { it.capturedAt }
val step = faces.size / count
val centroid = calculateCentroid(faces.map { it.embedding })
return (0 until count).map { i ->
sortedByTime[i * step]
val facesWithDistance = faces.map { face ->
val distance = 1 - cosineSimilarity(face.embedding, centroid)
face to distance
}
val sortedByProximity = facesWithDistance.sortedBy { it.second }
val representatives = mutableListOf<DetectedFaceWithEmbedding>()
representatives.add(sortedByProximity.first().first)
val remainingFaces = sortedByProximity.drop(1).take(count * 3)
val sortedByTime = remainingFaces.map { it.first }.sortedBy { it.capturedAt }
if (sortedByTime.isNotEmpty()) {
val step = sortedByTime.size / (count - 1).coerceAtLeast(1)
for (i in 0 until (count - 1)) {
val index = (i * step).coerceAtMost(sortedByTime.size - 1)
representatives.add(sortedByTime[index])
}
}
return representatives.take(count)
}
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
if (embeddings.isEmpty()) return FloatArray(0)
val size = embeddings.first().size
val centroid = FloatArray(size) { 0f }
embeddings.forEach { embedding ->
for (i in embedding.indices) {
centroid[i] += embedding[i]
}
}
val count = embeddings.size.toFloat()
for (i in centroid.indices) {
centroid[i] /= count
}
val norm = sqrt(centroid.map { it * it }.sum())
if (norm > 0) {
return centroid.map { it / norm }.toFloatArray()
}
return centroid
}
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
@@ -416,7 +509,6 @@ class FaceClusteringService @Inject constructor(
}
}
// Data classes
data class DetectedFaceWithEmbedding(
val imageId: String,
val imageUri: String,
@@ -424,7 +516,9 @@ data class DetectedFaceWithEmbedding(
val embedding: FloatArray,
val boundingBox: android.graphics.Rect,
val confidence: Float,
val faceCount: Int = 1
val faceCount: Int = 1,
val imageWidth: Int = 0,
val imageHeight: Int = 0
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true

View File

@@ -2,6 +2,7 @@ package com.placeholder.sherpai2.ml
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import org.tensorflow.lite.Interpreter
import java.io.FileInputStream
import java.nio.ByteBuffer
@@ -11,16 +12,21 @@ import java.nio.channels.FileChannel
import kotlin.math.sqrt
/**
* FaceNetModel - MobileFaceNet wrapper for face recognition
* FaceNetModel - MobileFaceNet wrapper with debugging
*
* CLEAN IMPLEMENTATION:
* - All IDs are Strings (matching your schema)
* - Generates 192-dimensional embeddings
* - Cosine similarity for matching
* IMPROVEMENTS:
* - ✅ Detailed error logging
* - ✅ Model validation on init
* - ✅ Embedding validation (detect all-zeros)
* - ✅ Toggle-able debug mode
*/
class FaceNetModel(private val context: Context) {
class FaceNetModel(
private val context: Context,
private val debugMode: Boolean = true // Enable for troubleshooting
) {
companion object {
private const val TAG = "FaceNetModel"
private const val MODEL_FILE = "mobilefacenet.tflite"
private const val INPUT_SIZE = 112
private const val EMBEDDING_SIZE = 192
@@ -31,13 +37,56 @@ class FaceNetModel(private val context: Context) {
}
private var interpreter: Interpreter? = null
private var modelLoadSuccess = false
init {
try {
if (debugMode) Log.d(TAG, "Loading FaceNet model: $MODEL_FILE")
val model = loadModelFile()
interpreter = Interpreter(model)
modelLoadSuccess = true
if (debugMode) {
Log.d(TAG, "✅ FaceNet model loaded successfully")
Log.d(TAG, "Model input size: ${INPUT_SIZE}x$INPUT_SIZE")
Log.d(TAG, "Embedding size: $EMBEDDING_SIZE")
}
// Test model with dummy input
testModel()
} catch (e: Exception) {
throw RuntimeException("Failed to load FaceNet model", e)
Log.e(TAG, "❌ CRITICAL: Failed to load FaceNet model from assets/$MODEL_FILE", e)
Log.e(TAG, "Make sure mobilefacenet.tflite exists in app/src/main/assets/")
modelLoadSuccess = false
throw RuntimeException("Failed to load FaceNet model: ${e.message}", e)
}
}
/**
* Test model with dummy input to verify it works
*/
private fun testModel() {
try {
val testBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888)
val testEmbedding = generateEmbedding(testBitmap)
testBitmap.recycle()
val sum = testEmbedding.sum()
val norm = sqrt(testEmbedding.map { it * it }.sum())
if (debugMode) {
Log.d(TAG, "Model test: embedding sum=$sum, norm=$norm")
}
if (sum == 0f || norm == 0f) {
Log.e(TAG, "⚠️ WARNING: Model test produced zero embedding!")
} else {
if (debugMode) Log.d(TAG, "✅ Model test passed")
}
} catch (e: Exception) {
Log.e(TAG, "Model test failed", e)
}
}
@@ -45,12 +94,22 @@ class FaceNetModel(private val context: Context) {
* Load TFLite model from assets
*/
private fun loadModelFile(): MappedByteBuffer {
val fileDescriptor = context.assets.openFd(MODEL_FILE)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
try {
val fileDescriptor = context.assets.openFd(MODEL_FILE)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
if (debugMode) {
Log.d(TAG, "Model file size: ${declaredLength / 1024}KB")
}
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
} catch (e: Exception) {
Log.e(TAG, "Failed to open model file: $MODEL_FILE", e)
throw e
}
}
/**
@@ -60,13 +119,39 @@ class FaceNetModel(private val context: Context) {
* @return 192-dimensional embedding
*/
fun generateEmbedding(faceBitmap: Bitmap): FloatArray {
val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true)
val inputBuffer = preprocessImage(resized)
val output = Array(1) { FloatArray(EMBEDDING_SIZE) }
if (!modelLoadSuccess || interpreter == null) {
Log.e(TAG, "❌ Cannot generate embedding: model not loaded!")
return FloatArray(EMBEDDING_SIZE) { 0f }
}
interpreter?.run(inputBuffer, output)
try {
val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true)
val inputBuffer = preprocessImage(resized)
val output = Array(1) { FloatArray(EMBEDDING_SIZE) }
return normalizeEmbedding(output[0])
interpreter?.run(inputBuffer, output)
val normalized = normalizeEmbedding(output[0])
// DIAGNOSTIC: Check embedding quality
if (debugMode) {
val sum = normalized.sum()
val norm = sqrt(normalized.map { it * it }.sum())
if (sum == 0f && norm == 0f) {
Log.e(TAG, "❌ CRITICAL: Generated all-zero embedding!")
Log.e(TAG, "Input bitmap: ${faceBitmap.width}x${faceBitmap.height}")
} else {
Log.d(TAG, "✅ Embedding: sum=${"%.2f".format(sum)}, norm=${"%.2f".format(norm)}, first5=[${normalized.take(5).joinToString { "%.3f".format(it) }}]")
}
}
return normalized
} catch (e: Exception) {
Log.e(TAG, "Failed to generate embedding", e)
return FloatArray(EMBEDDING_SIZE) { 0f }
}
}
/**
@@ -76,6 +161,10 @@ class FaceNetModel(private val context: Context) {
faceBitmaps: List<Bitmap>,
onProgress: (Int, Int) -> Unit = { _, _ -> }
): List<FloatArray> {
if (debugMode) {
Log.d(TAG, "Generating embeddings for ${faceBitmaps.size} faces")
}
return faceBitmaps.mapIndexed { index, bitmap ->
onProgress(index + 1, faceBitmaps.size)
generateEmbedding(bitmap)
@@ -88,6 +177,10 @@ class FaceNetModel(private val context: Context) {
fun createPersonModel(embeddings: List<FloatArray>): FloatArray {
require(embeddings.isNotEmpty()) { "Need at least one embedding" }
if (debugMode) {
Log.d(TAG, "Creating person model from ${embeddings.size} embeddings")
}
val averaged = FloatArray(EMBEDDING_SIZE) { 0f }
embeddings.forEach { embedding ->
@@ -101,7 +194,14 @@ class FaceNetModel(private val context: Context) {
averaged[i] /= count
}
return normalizeEmbedding(averaged)
val normalized = normalizeEmbedding(averaged)
if (debugMode) {
val sum = normalized.sum()
Log.d(TAG, "Person model created: sum=${"%.2f".format(sum)}")
}
return normalized
}
/**
@@ -110,7 +210,7 @@ class FaceNetModel(private val context: Context) {
*/
fun calculateSimilarity(embedding1: FloatArray, embedding2: FloatArray): Float {
require(embedding1.size == EMBEDDING_SIZE && embedding2.size == EMBEDDING_SIZE) {
"Invalid embedding size"
"Invalid embedding size: ${embedding1.size} vs ${embedding2.size}"
}
var dotProduct = 0f
@@ -123,7 +223,14 @@ class FaceNetModel(private val context: Context) {
norm2 += embedding2[i] * embedding2[i]
}
return dotProduct / (sqrt(norm1) * sqrt(norm2))
val similarity = dotProduct / (sqrt(norm1) * sqrt(norm2))
if (debugMode && (similarity.isNaN() || similarity.isInfinite())) {
Log.e(TAG, "❌ Invalid similarity: $similarity (norm1=$norm1, norm2=$norm2)")
return 0f
}
return similarity
}
/**
@@ -151,6 +258,10 @@ class FaceNetModel(private val context: Context) {
}
}
if (debugMode && bestMatch != null) {
Log.d(TAG, "Best match: ${bestMatch.first} with similarity ${bestMatch.second}")
}
return bestMatch
}
@@ -169,6 +280,7 @@ class FaceNetModel(private val context: Context) {
val g = ((pixel shr 8) and 0xFF) / 255.0f
val b = (pixel and 0xFF) / 255.0f
// Normalize to [-1, 1]
buffer.putFloat((r - 0.5f) / 0.5f)
buffer.putFloat((g - 0.5f) / 0.5f)
buffer.putFloat((b - 0.5f) / 0.5f)
@@ -190,14 +302,29 @@ class FaceNetModel(private val context: Context) {
return if (norm > 0) {
FloatArray(embedding.size) { i -> embedding[i] / norm }
} else {
Log.w(TAG, "⚠️ Cannot normalize zero embedding")
embedding
}
}
/**
* Get model status for diagnostics
*/
fun getModelStatus(): String {
return if (modelLoadSuccess) {
"✅ Model loaded and operational"
} else {
"❌ Model failed to load - check assets/$MODEL_FILE"
}
}
/**
* Clean up resources
*/
fun close() {
if (debugMode) {
Log.d(TAG, "Closing FaceNet model")
}
interpreter?.close()
interpreter = null
}

View File

@@ -8,9 +8,13 @@ import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Check
import androidx.compose.material.icons.filled.Warning
import androidx.compose.material3.*
import androidx.compose.runtime.Composable
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
@@ -19,6 +23,8 @@ import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import coil.compose.AsyncImage
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier
import com.placeholder.sherpai2.domain.clustering.ClusteringResult
import com.placeholder.sherpai2.domain.clustering.FaceCluster
@@ -28,13 +34,20 @@ import com.placeholder.sherpai2.domain.clustering.FaceCluster
* Each cluster card shows:
* - 2x2 grid of representative faces
* - Photo count
* - Quality badge (Excellent/Good/Poor)
* - Tap to name
*
* IMPROVEMENTS:
* - ✅ Quality badges for each cluster
* - ✅ Visual indicators for trainable vs non-trainable clusters
* - ✅ Better UX with disabled states for poor quality clusters
*/
@Composable
fun ClusterGridScreen(
result: ClusteringResult,
onSelectCluster: (FaceCluster) -> Unit,
modifier: Modifier = Modifier
modifier: Modifier = Modifier,
qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() }
) {
Column(
modifier = modifier
@@ -65,8 +78,15 @@ fun ClusterGridScreen(
verticalArrangement = Arrangement.spacedBy(12.dp)
) {
items(result.clusters) { cluster ->
// Analyze quality for each cluster
val qualityResult = remember(cluster) {
qualityAnalyzer.analyzeCluster(cluster)
}
ClusterCard(
cluster = cluster,
qualityTier = qualityResult.qualityTier,
canTrain = qualityResult.canTrain,
onClick = { onSelectCluster(cluster) }
)
}
@@ -75,77 +95,168 @@ fun ClusterGridScreen(
}
/**
* Single cluster card with 2x2 face grid
* Single cluster card with 2x2 face grid and quality badge
*/
@Composable
private fun ClusterCard(
cluster: FaceCluster,
qualityTier: ClusterQualityTier,
canTrain: Boolean,
onClick: () -> Unit
) {
Card(
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f)
.clickable(onClick = onClick),
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp)
.clickable(onClick = onClick), // Always clickable - let dialog handle validation
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp),
colors = CardDefaults.cardColors(
containerColor = when {
qualityTier == ClusterQualityTier.POOR ->
MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
!canTrain ->
MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f)
else ->
MaterialTheme.colorScheme.surface
}
)
) {
Column(
Box(
modifier = Modifier.fillMaxSize()
) {
// 2x2 grid of faces
val facesToShow = cluster.representativeFaces.take(4)
Column(
modifier = Modifier.weight(1f)
modifier = Modifier.fillMaxSize()
) {
// Top row (2 faces)
Row(modifier = Modifier.weight(1f)) {
facesToShow.getOrNull(0)?.let { face ->
FaceThumbnail(
imageUri = face.imageUri,
modifier = Modifier.weight(1f)
)
} ?: EmptyFaceSlot(Modifier.weight(1f))
// 2x2 grid of faces
val facesToShow = cluster.representativeFaces.take(4)
facesToShow.getOrNull(1)?.let { face ->
FaceThumbnail(
imageUri = face.imageUri,
modifier = Modifier.weight(1f)
)
} ?: EmptyFaceSlot(Modifier.weight(1f))
Column(
modifier = Modifier.weight(1f)
) {
// Top row (2 faces)
Row(modifier = Modifier.weight(1f)) {
facesToShow.getOrNull(0)?.let { face ->
FaceThumbnail(
imageUri = face.imageUri,
enabled = canTrain,
modifier = Modifier.weight(1f)
)
} ?: EmptyFaceSlot(Modifier.weight(1f))
facesToShow.getOrNull(1)?.let { face ->
FaceThumbnail(
imageUri = face.imageUri,
enabled = canTrain,
modifier = Modifier.weight(1f)
)
} ?: EmptyFaceSlot(Modifier.weight(1f))
}
// Bottom row (2 faces)
Row(modifier = Modifier.weight(1f)) {
facesToShow.getOrNull(2)?.let { face ->
FaceThumbnail(
imageUri = face.imageUri,
enabled = canTrain,
modifier = Modifier.weight(1f)
)
} ?: EmptyFaceSlot(Modifier.weight(1f))
facesToShow.getOrNull(3)?.let { face ->
FaceThumbnail(
imageUri = face.imageUri,
enabled = canTrain,
modifier = Modifier.weight(1f)
)
} ?: EmptyFaceSlot(Modifier.weight(1f))
}
}
// Bottom row (2 faces)
Row(modifier = Modifier.weight(1f)) {
facesToShow.getOrNull(2)?.let { face ->
FaceThumbnail(
imageUri = face.imageUri,
modifier = Modifier.weight(1f)
// Footer with photo count
Surface(
modifier = Modifier.fillMaxWidth(),
color = if (canTrain) {
MaterialTheme.colorScheme.primaryContainer
} else {
MaterialTheme.colorScheme.surfaceVariant
}
) {
Row(
modifier = Modifier.padding(12.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween
) {
Text(
text = "${cluster.photoCount} photos",
style = MaterialTheme.typography.bodyMedium,
fontWeight = FontWeight.SemiBold,
color = if (canTrain) {
MaterialTheme.colorScheme.onPrimaryContainer
} else {
MaterialTheme.colorScheme.onSurfaceVariant
}
)
} ?: EmptyFaceSlot(Modifier.weight(1f))
facesToShow.getOrNull(3)?.let { face ->
FaceThumbnail(
imageUri = face.imageUri,
modifier = Modifier.weight(1f)
)
} ?: EmptyFaceSlot(Modifier.weight(1f))
}
}
}
// Footer with photo count
Surface(
modifier = Modifier.fillMaxWidth(),
color = MaterialTheme.colorScheme.primaryContainer
) {
Text(
text = "${cluster.photoCount} photos",
style = MaterialTheme.typography.bodyMedium,
fontWeight = FontWeight.SemiBold,
modifier = Modifier.padding(12.dp),
color = MaterialTheme.colorScheme.onPrimaryContainer
)
}
// Quality badge overlay
QualityBadge(
qualityTier = qualityTier,
canTrain = canTrain,
modifier = Modifier
.align(Alignment.TopEnd)
.padding(8.dp)
)
}
}
}
/**
* Quality badge indicator
*/
@Composable
private fun QualityBadge(
qualityTier: ClusterQualityTier,
canTrain: Boolean,
modifier: Modifier = Modifier
) {
val (backgroundColor, iconColor, icon) = when (qualityTier) {
ClusterQualityTier.EXCELLENT -> Triple(
Color(0xFF1B5E20),
Color.White,
Icons.Default.Check
)
ClusterQualityTier.GOOD -> Triple(
Color(0xFF2E7D32),
Color.White,
Icons.Default.Check
)
ClusterQualityTier.POOR -> Triple(
Color(0xFFD32F2F),
Color.White,
Icons.Default.Warning
)
}
Surface(
modifier = modifier,
shape = CircleShape,
color = backgroundColor,
shadowElevation = 2.dp
) {
Box(
modifier = Modifier
.size(32.dp)
.padding(6.dp),
contentAlignment = Alignment.Center
) {
Icon(
imageVector = icon,
contentDescription = qualityTier.name,
tint = iconColor,
modifier = Modifier.size(20.dp)
)
}
}
}
@@ -153,19 +264,23 @@ private fun ClusterCard(
@Composable
private fun FaceThumbnail(
imageUri: String,
enabled: Boolean,
modifier: Modifier = Modifier
) {
AsyncImage(
model = Uri.parse(imageUri),
contentDescription = "Face",
modifier = modifier
.fillMaxSize()
.border(
width = 0.5.dp,
color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
),
contentScale = ContentScale.Crop
)
Box(modifier = modifier) {
AsyncImage(
model = Uri.parse(imageUri),
contentDescription = "Face",
modifier = Modifier
.fillMaxSize()
.border(
width = 0.5.dp,
color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
),
contentScale = ContentScale.Crop,
alpha = if (enabled) 1f else 0.6f
)
}
}
@Composable

View File

@@ -11,11 +11,17 @@ import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
/**
* DiscoverPeopleScreen - COMPLETE WORKING VERSION
* DiscoverPeopleScreen - COMPLETE WORKING VERSION WITH NAMING DIALOG
*
* This handles ALL states properly including Idle state
* This handles ALL states properly including the NamingCluster dialog
*
* IMPROVEMENTS:
* - ✅ Complete naming dialog integration
* - ✅ Quality analysis in cluster grid
* - ✅ Better error handling
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
@@ -24,125 +30,130 @@ fun DiscoverPeopleScreen(
onNavigateBack: () -> Unit = {}
) {
val uiState by viewModel.uiState.collectAsState()
val qualityAnalyzer = remember { ClusterQualityAnalyzer() }
Scaffold(
topBar = {
TopAppBar(
title = { Text("Discover People") },
navigationIcon = {
IconButton(onClick = onNavigateBack) {
Icon(
imageVector = Icons.Default.Person,
contentDescription = "Back"
// No Scaffold, no TopAppBar - MainScreen handles that
Box(
modifier = Modifier.fillMaxSize()
) {
when (val state = uiState) {
// ===== IDLE STATE (START HERE) =====
is DiscoverUiState.Idle -> {
IdleStateContent(
onStartDiscovery = { viewModel.startDiscovery() }
)
}
// ===== CLUSTERING IN PROGRESS =====
is DiscoverUiState.Clustering -> {
ClusteringProgressContent(
progress = state.progress,
total = state.total,
message = state.message
)
}
// ===== CLUSTERS READY FOR NAMING =====
is DiscoverUiState.NamingReady -> {
ClusterGridScreen(
result = state.result,
onSelectCluster = { cluster ->
viewModel.selectCluster(cluster)
},
qualityAnalyzer = qualityAnalyzer
)
}
// ===== ANALYZING CLUSTER QUALITY =====
is DiscoverUiState.AnalyzingCluster -> {
LoadingContent(message = "Analyzing cluster quality...")
}
// ===== NAMING A CLUSTER (SHOW DIALOG) =====
is DiscoverUiState.NamingCluster -> {
// Show cluster grid in background
ClusterGridScreen(
result = state.result,
onSelectCluster = { /* Disabled while dialog open */ },
qualityAnalyzer = qualityAnalyzer
)
// Show naming dialog overlay
NamingDialog(
cluster = state.selectedCluster,
suggestedSiblings = state.suggestedSiblings,
onConfirm = { name, dateOfBirth, isChild, selectedSiblings ->
viewModel.confirmClusterName(
cluster = state.selectedCluster,
name = name,
dateOfBirth = dateOfBirth,
isChild = isChild,
selectedSiblings = selectedSiblings
)
},
onDismiss = {
viewModel.cancelNaming()
},
qualityAnalyzer = qualityAnalyzer
)
}
// ===== TRAINING IN PROGRESS =====
is DiscoverUiState.Training -> {
TrainingProgressContent(
stage = state.stage,
progress = state.progress,
total = state.total
)
}
// ===== VALIDATION PREVIEW =====
is DiscoverUiState.ValidationPreview -> {
ValidationPreviewScreen(
personName = state.personName,
validationResult = state.validationResult,
onApprove = {
viewModel.approveValidationAndScan(
personId = state.personId,
personName = state.personName
)
},
onReject = {
viewModel.rejectValidationAndImprove()
}
}
)
}
) { paddingValues ->
Box(
modifier = Modifier
.fillMaxSize()
.padding(paddingValues)
) {
when (val state = uiState) {
// ===== IDLE STATE (START HERE) =====
is DiscoverUiState.Idle -> {
IdleStateContent(
onStartDiscovery = { viewModel.startDiscovery() }
)
}
)
}
// ===== CLUSTERING IN PROGRESS =====
is DiscoverUiState.Clustering -> {
ClusteringProgressContent(
progress = state.progress,
total = state.total,
message = state.message
)
}
// ===== COMPLETE =====
is DiscoverUiState.Complete -> {
CompleteStateContent(
message = state.message,
onDone = onNavigateBack
)
}
// ===== CLUSTERS READY FOR NAMING =====
is DiscoverUiState.NamingReady -> {
ClusterGridScreen(
result = state.result,
onSelectCluster = { cluster ->
viewModel.selectCluster(cluster)
}
)
}
// ===== NO PEOPLE FOUND =====
is DiscoverUiState.NoPeopleFound -> {
ErrorStateContent(
title = "No People Found",
message = state.message,
onRetry = { viewModel.startDiscovery() },
onBack = onNavigateBack
)
}
// ===== ANALYZING CLUSTER QUALITY =====
is DiscoverUiState.AnalyzingCluster -> {
LoadingContent(message = "Analyzing cluster quality...")
}
// ===== NAMING A CLUSTER =====
is DiscoverUiState.NamingCluster -> {
Text(
text = "Naming dialog for cluster ${state.selectedCluster.clusterId}\n\nDialog UI coming...",
modifier = Modifier.align(Alignment.Center)
)
}
// ===== TRAINING IN PROGRESS =====
is DiscoverUiState.Training -> {
TrainingProgressContent(
stage = state.stage,
progress = state.progress,
total = state.total
)
}
// ===== VALIDATION PREVIEW =====
is DiscoverUiState.ValidationPreview -> {
ValidationPreviewScreen(
personName = state.personName,
validationResult = state.validationResult,
onApprove = {
viewModel.approveValidationAndScan(
personId = state.personId,
personName = state.personName
)
},
onReject = {
viewModel.rejectValidationAndImprove()
}
)
}
// ===== COMPLETE =====
is DiscoverUiState.Complete -> {
CompleteStateContent(
message = state.message,
onDone = onNavigateBack
)
}
// ===== NO PEOPLE FOUND =====
is DiscoverUiState.NoPeopleFound -> {
ErrorStateContent(
title = "No People Found",
message = state.message,
onRetry = { viewModel.startDiscovery() },
onBack = onNavigateBack
)
}
// ===== ERROR =====
is DiscoverUiState.Error -> {
ErrorStateContent(
title = "Error",
message = state.message,
onRetry = { viewModel.reset(); viewModel.startDiscovery() },
onBack = onNavigateBack
)
}
// ===== ERROR =====
is DiscoverUiState.Error -> {
ErrorStateContent(
title = "Error",
message = state.message,
onRetry = { viewModel.reset(); viewModel.startDiscovery() },
onBack = onNavigateBack
)
}
}
}
}
// ===== IDLE STATE CONTENT =====
@Composable
@@ -165,19 +176,11 @@ private fun IdleStateContent(
Spacer(modifier = Modifier.height(32.dp))
Text(
text = "Discover People",
style = MaterialTheme.typography.headlineLarge,
fontWeight = FontWeight.Bold
)
Spacer(modifier = Modifier.height(16.dp))
Text(
text = "Automatically find and organize people in your photo library",
style = MaterialTheme.typography.bodyLarge,
style = MaterialTheme.typography.headlineSmall,
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurfaceVariant
color = MaterialTheme.colorScheme.onSurface
)
Spacer(modifier = Modifier.height(48.dp))

View File

@@ -0,0 +1,480 @@
package com.placeholder.sherpai2.ui.discover
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.LazyRow
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.KeyboardActions
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.*
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalSoftwareKeyboardController
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.input.ImeAction
import androidx.compose.ui.text.input.KeyboardCapitalization
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import coil.compose.AsyncImage
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier
import com.placeholder.sherpai2.domain.clustering.FaceCluster
import java.text.SimpleDateFormat
import java.util.*
/**
* NamingDialog - Complete dialog for naming a cluster
*
* Features:
* - Name input with validation
* - Child toggle with date of birth picker
* - Sibling cluster selection
* - Quality warnings display
* - Preview of representative faces
*
* IMPROVEMENTS:
* - ✅ Complete UI implementation
* - ✅ Quality analysis integration
* - ✅ Sibling selection
* - ✅ Form validation
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun NamingDialog(
cluster: FaceCluster,
suggestedSiblings: List<FaceCluster>,
onConfirm: (name: String, dateOfBirth: Long?, isChild: Boolean, selectedSiblings: List<Int>) -> Unit,
onDismiss: () -> Unit,
qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() }
) {
var name by remember { mutableStateOf("") }
var isChild by remember { mutableStateOf(false) }
var showDatePicker by remember { mutableStateOf(false) }
var dateOfBirth by remember { mutableStateOf<Long?>(null) }
var selectedSiblingIds by remember { mutableStateOf(setOf<Int>()) }
// Analyze cluster quality
val qualityResult = remember(cluster) {
qualityAnalyzer.analyzeCluster(cluster)
}
val keyboardController = LocalSoftwareKeyboardController.current
val dateFormatter = remember { SimpleDateFormat("MMM dd, yyyy", Locale.getDefault()) }
Dialog(onDismissRequest = onDismiss) {
Card(
modifier = Modifier
.fillMaxWidth()
.fillMaxHeight(0.9f),
shape = RoundedCornerShape(16.dp),
elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
) {
Column(
modifier = Modifier
.fillMaxSize()
.verticalScroll(rememberScrollState())
) {
// Header
Surface(
color = MaterialTheme.colorScheme.primaryContainer
) {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(16.dp),
horizontalArrangement = Arrangement.SpaceBetween,
verticalAlignment = Alignment.CenterVertically
) {
Column(modifier = Modifier.weight(1f)) {
Text(
text = "Name This Person",
style = MaterialTheme.typography.titleLarge,
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.onPrimaryContainer
)
Text(
text = "${cluster.photoCount} photos",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
)
}
IconButton(onClick = onDismiss) {
Icon(
imageVector = Icons.Default.Close,
contentDescription = "Close",
tint = MaterialTheme.colorScheme.onPrimaryContainer
)
}
}
}
Column(
modifier = Modifier.padding(16.dp)
) {
// Preview faces
Text(
text = "Preview",
style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.SemiBold
)
Spacer(modifier = Modifier.height(8.dp))
LazyRow(
horizontalArrangement = Arrangement.spacedBy(8.dp)
) {
items(cluster.representativeFaces.take(6)) { face ->
AsyncImage(
model = android.net.Uri.parse(face.imageUri),
contentDescription = "Preview",
modifier = Modifier
.size(80.dp)
.clip(RoundedCornerShape(8.dp))
.border(
width = 1.dp,
color = MaterialTheme.colorScheme.outline,
shape = RoundedCornerShape(8.dp)
),
contentScale = ContentScale.Crop
)
}
}
Spacer(modifier = Modifier.height(20.dp))
// Quality warning (if applicable)
if (qualityResult.qualityTier != ClusterQualityTier.EXCELLENT) {
QualityWarningCard(qualityResult = qualityResult)
Spacer(modifier = Modifier.height(16.dp))
}
// Name input
OutlinedTextField(
value = name,
onValueChange = { name = it },
label = { Text("Name") },
placeholder = { Text("Enter person's name") },
modifier = Modifier.fillMaxWidth(),
singleLine = true,
leadingIcon = {
Icon(
imageVector = Icons.Default.Person,
contentDescription = null
)
},
keyboardOptions = KeyboardOptions(
capitalization = KeyboardCapitalization.Words,
imeAction = ImeAction.Done
),
keyboardActions = KeyboardActions(
onDone = { keyboardController?.hide() }
)
)
Spacer(modifier = Modifier.height(16.dp))
// Child toggle
Row(
modifier = Modifier
.fillMaxWidth()
.clip(RoundedCornerShape(8.dp))
.clickable { isChild = !isChild }
.background(
if (isChild) MaterialTheme.colorScheme.primaryContainer
else MaterialTheme.colorScheme.surfaceVariant
)
.padding(16.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween
) {
Row(
verticalAlignment = Alignment.CenterVertically
) {
Icon(
imageVector = Icons.Default.Face,
contentDescription = null,
tint = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
else MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.width(12.dp))
Column {
Text(
text = "This is a child",
style = MaterialTheme.typography.bodyLarge,
fontWeight = FontWeight.Medium,
color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
else MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
text = "For age-appropriate filtering",
style = MaterialTheme.typography.bodySmall,
color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f)
)
}
}
Switch(
checked = isChild,
onCheckedChange = { isChild = it }
)
}
// Date of birth (if child)
if (isChild) {
Spacer(modifier = Modifier.height(12.dp))
OutlinedButton(
onClick = { showDatePicker = true },
modifier = Modifier.fillMaxWidth()
) {
Icon(
imageVector = Icons.Default.DateRange,
contentDescription = null
)
Spacer(modifier = Modifier.width(8.dp))
Text(
text = dateOfBirth?.let { dateFormatter.format(Date(it)) }
?: "Set date of birth (optional)"
)
}
}
// Sibling selection
if (suggestedSiblings.isNotEmpty()) {
Spacer(modifier = Modifier.height(20.dp))
Text(
text = "Appears with",
style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.SemiBold
)
Text(
text = "Select siblings or family members",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.height(8.dp))
suggestedSiblings.forEach { sibling ->
SiblingSelectionItem(
cluster = sibling,
selected = sibling.clusterId in selectedSiblingIds,
onToggle = {
selectedSiblingIds = if (sibling.clusterId in selectedSiblingIds) {
selectedSiblingIds - sibling.clusterId
} else {
selectedSiblingIds + sibling.clusterId
}
}
)
Spacer(modifier = Modifier.height(8.dp))
}
}
Spacer(modifier = Modifier.height(24.dp))
// Action buttons
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.spacedBy(12.dp)
) {
OutlinedButton(
onClick = onDismiss,
modifier = Modifier.weight(1f)
) {
Text("Cancel")
}
Button(
onClick = {
if (name.isNotBlank()) {
onConfirm(
name.trim(),
dateOfBirth,
isChild,
selectedSiblingIds.toList()
)
}
},
modifier = Modifier.weight(1f),
enabled = name.isNotBlank() && qualityResult.canTrain
) {
Icon(
imageVector = Icons.Default.Check,
contentDescription = null,
modifier = Modifier.size(20.dp)
)
Spacer(modifier = Modifier.width(8.dp))
Text("Create Model")
}
}
}
}
}
}
// Date picker dialog
if (showDatePicker) {
val datePickerState = rememberDatePickerState()
DatePickerDialog(
onDismissRequest = { showDatePicker = false },
confirmButton = {
TextButton(
onClick = {
dateOfBirth = datePickerState.selectedDateMillis
showDatePicker = false
}
) {
Text("OK")
}
},
dismissButton = {
TextButton(onClick = { showDatePicker = false }) {
Text("Cancel")
}
}
) {
DatePicker(state = datePickerState)
}
}
}
/**
* Quality warning card
*/
@Composable
private fun QualityWarningCard(qualityResult: com.placeholder.sherpai2.domain.clustering.ClusterQualityResult) {
val (backgroundColor, iconColor) = when (qualityResult.qualityTier) {
ClusterQualityTier.GOOD -> Pair(
Color(0xFFFFF9C4),
Color(0xFFF57F17)
)
ClusterQualityTier.POOR -> Pair(
Color(0xFFFFEBEE),
Color(0xFFD32F2F)
)
else -> Pair(
MaterialTheme.colorScheme.surfaceVariant,
MaterialTheme.colorScheme.onSurfaceVariant
)
}
Card(
modifier = Modifier.fillMaxWidth(),
colors = CardDefaults.cardColors(
containerColor = backgroundColor
)
) {
Column(
modifier = Modifier.padding(12.dp)
) {
Row(
verticalAlignment = Alignment.CenterVertically
) {
Icon(
imageVector = Icons.Default.Warning,
contentDescription = null,
tint = iconColor,
modifier = Modifier.size(20.dp)
)
Spacer(modifier = Modifier.width(8.dp))
Text(
text = when (qualityResult.qualityTier) {
ClusterQualityTier.GOOD -> "Review Before Training"
ClusterQualityTier.POOR -> "Quality Issues Detected"
else -> ""
},
style = MaterialTheme.typography.titleSmall,
fontWeight = FontWeight.Bold,
color = iconColor
)
}
Spacer(modifier = Modifier.height(8.dp))
qualityResult.warnings.forEach { warning ->
Text(
text = "$warning",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
}
/**
* Sibling selection item
*/
@Composable
private fun SiblingSelectionItem(
cluster: FaceCluster,
selected: Boolean,
onToggle: () -> Unit
) {
Row(
modifier = Modifier
.fillMaxWidth()
.clip(RoundedCornerShape(8.dp))
.clickable(onClick = onToggle)
.background(
if (selected) MaterialTheme.colorScheme.primaryContainer
else MaterialTheme.colorScheme.surfaceVariant
)
.padding(12.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween
) {
Row(
verticalAlignment = Alignment.CenterVertically
) {
// Preview face
AsyncImage(
model = android.net.Uri.parse(cluster.representativeFaces.firstOrNull()?.imageUri ?: ""),
contentDescription = "Preview",
modifier = Modifier
.size(40.dp)
.clip(CircleShape)
.border(
width = 1.dp,
color = MaterialTheme.colorScheme.outline,
shape = CircleShape
),
contentScale = ContentScale.Crop
)
Spacer(modifier = Modifier.width(12.dp))
Text(
text = "${cluster.photoCount} photos together",
style = MaterialTheme.typography.bodyMedium,
color = if (selected) MaterialTheme.colorScheme.onPrimaryContainer
else MaterialTheme.colorScheme.onSurfaceVariant
)
}
Checkbox(
checked = selected,
onCheckedChange = { onToggle() }
)
}
}

View File

@@ -1,56 +1,48 @@
package com.placeholder.sherpai2.ui.presentation
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.*
import androidx.compose.material.icons.filled.Menu
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.font.FontWeight
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.navigation.compose.currentBackStackEntryAsState
import androidx.navigation.compose.rememberNavController
import androidx.navigation.compose.currentBackStackEntryAsState
import com.placeholder.sherpai2.ui.navigation.AppNavHost
import com.placeholder.sherpai2.ui.navigation.AppRoutes
import kotlinx.coroutines.launch
/**
* MainScreen - UPDATED with auto face cache check
* MainScreen - Complete app container with drawer navigation
*
* NEW: Prompts user to populate face cache on app launch if needed
* FIXED: Prevents double headers for screens with their own TopAppBar
* CRITICAL FIX APPLIED:
* ✅ Removed AppRoutes.DISCOVER from screensWithOwnTopBar
* ✅ DiscoverPeopleScreen now shows hamburger menu + "Discover People" title!
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun MainScreen(
mainViewModel: MainViewModel = hiltViewModel() // Same package - no import needed!
viewModel: MainViewModel = hiltViewModel()
) {
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
val scope = rememberCoroutineScope()
val navController = rememberNavController()
val drawerState = rememberDrawerState(DrawerValue.Closed)
val scope = rememberCoroutineScope()
val navBackStackEntry by navController.currentBackStackEntryAsState()
val currentRoute = navBackStackEntry?.destination?.route ?: AppRoutes.SEARCH
val currentBackStackEntry by navController.currentBackStackEntryAsState()
val currentRoute = currentBackStackEntry?.destination?.route
// Face cache status
val needsFaceCache by mainViewModel.needsFaceCachePopulation.collectAsState()
val unscannedCount by mainViewModel.unscannedPhotoCount.collectAsState()
// Face cache prompt dialog state
val needsFaceCachePopulation by viewModel.needsFaceCachePopulation.collectAsState()
val unscannedPhotoCount by viewModel.unscannedPhotoCount.collectAsState()
// Show face cache prompt dialog if needed
if (needsFaceCache && unscannedCount > 0) {
FaceCachePromptDialog(
unscannedPhotoCount = unscannedCount,
onDismiss = { mainViewModel.dismissFaceCachePrompt() },
onScanNow = {
mainViewModel.dismissFaceCachePrompt()
// Navigate to Photo Utilities
navController.navigate(AppRoutes.UTILITIES) {
launchSingleTop = true
}
}
)
}
// ✅ CRITICAL FIX: DISCOVER is NOT in this list!
// These screens handle their own TopAppBar/navigation
val screensWithOwnTopBar = setOf(
AppRoutes.IMAGE_DETAIL,
AppRoutes.TRAINING_SCREEN,
AppRoutes.CROP_SCREEN
)
ModalNavigationDrawer(
drawerState = drawerState,
@@ -60,133 +52,86 @@ fun MainScreen(
onDestinationClicked = { route ->
scope.launch {
drawerState.close()
if (route != currentRoute) {
navController.navigate(route) {
launchSingleTop = true
}
}
navController.navigate(route) {
popUpTo(navController.graph.startDestinationId) {
saveState = true
}
launchSingleTop = true
restoreState = true
}
}
)
},
}
) {
// CRITICAL: Some screens manage their own TopAppBar
// Hide MainScreen's TopAppBar for these routes to prevent double headers
val screensWithOwnTopBar = setOf(
AppRoutes.TRAINING_PHOTO_SELECTOR, // Has its own TopAppBar with subtitle
"album/", // Album views have their own TopAppBar (prefix match)
AppRoutes.IMAGE_DETAIL // Image detail has its own TopAppBar
)
// Check if current route starts with any excluded pattern
val showTopBar = screensWithOwnTopBar.none { currentRoute.startsWith(it) }
Scaffold(
topBar = {
if (showTopBar) {
// ✅ Show TopAppBar for ALL screens except those with their own
if (currentRoute !in screensWithOwnTopBar) {
TopAppBar(
title = {
Column {
Text(
text = getScreenTitle(currentRoute),
style = MaterialTheme.typography.titleLarge,
fontWeight = FontWeight.Bold
)
getScreenSubtitle(currentRoute)?.let { subtitle ->
Text(
text = subtitle,
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
text = when (currentRoute) {
AppRoutes.SEARCH -> "Search"
AppRoutes.EXPLORE -> "Explore"
AppRoutes.COLLECTIONS -> "Collections"
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
AppRoutes.INVENTORY -> "People"
AppRoutes.TRAIN -> "Train Model"
AppRoutes.TAGS -> "Tags"
AppRoutes.UTILITIES -> "Utilities"
AppRoutes.SETTINGS -> "Settings"
AppRoutes.MODELS -> "AI Models"
else -> {
// Handle dynamic routes like album/{type}/{id}
if (currentRoute?.startsWith("album/") == true) {
"Album"
} else {
"SherpAI"
}
}
}
}
)
},
navigationIcon = {
IconButton(
onClick = { scope.launch { drawerState.open() } }
) {
IconButton(onClick = {
scope.launch {
drawerState.open()
}
}) {
Icon(
Icons.Default.Menu,
contentDescription = "Open Menu",
tint = MaterialTheme.colorScheme.primary
imageVector = Icons.Default.Menu,
contentDescription = "Open menu"
)
}
},
actions = {
// Dynamic actions based on current screen
when (currentRoute) {
AppRoutes.SEARCH -> {
IconButton(onClick = { /* TODO: Open filter dialog */ }) {
Icon(
Icons.Default.FilterList,
contentDescription = "Filter",
tint = MaterialTheme.colorScheme.primary
)
}
}
AppRoutes.INVENTORY -> {
IconButton(onClick = {
navController.navigate(AppRoutes.TRAIN)
}) {
Icon(
Icons.Default.PersonAdd,
contentDescription = "Add Person",
tint = MaterialTheme.colorScheme.primary
)
}
}
}
},
colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.surface,
titleContentColor = MaterialTheme.colorScheme.onSurface,
navigationIconContentColor = MaterialTheme.colorScheme.primary,
actionIconContentColor = MaterialTheme.colorScheme.primary
containerColor = MaterialTheme.colorScheme.primaryContainer,
titleContentColor = MaterialTheme.colorScheme.onPrimaryContainer,
navigationIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer,
actionIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer
)
)
}
}
) { paddingValues ->
// ✅ Use YOUR existing AppNavHost - it already has all the screens defined!
AppNavHost(
navController = navController,
modifier = Modifier.padding(paddingValues)
)
}
}
}
/**
* Get human-readable screen title
*/
private fun getScreenTitle(route: String): String {
return when (route) {
AppRoutes.SEARCH -> "Search"
AppRoutes.EXPLORE -> "Explore"
AppRoutes.COLLECTIONS -> "Collections"
AppRoutes.DISCOVER -> "Discover People"
AppRoutes.INVENTORY -> "People"
AppRoutes.TRAIN -> "Train New Person"
AppRoutes.MODELS -> "AI Models"
AppRoutes.TAGS -> "Tag Management"
AppRoutes.UTILITIES -> "Photo Util."
AppRoutes.SETTINGS -> "Settings"
else -> "SherpAI"
}
}
/**
* Get subtitle for screens that need context
*/
private fun getScreenSubtitle(route: String): String? {
return when (route) {
AppRoutes.SEARCH -> "Find photos by tags, people, or date"
AppRoutes.EXPLORE -> "Browse your collection"
AppRoutes.COLLECTIONS -> "Your photo collections"
AppRoutes.DISCOVER -> "Auto-find faces in your library"
AppRoutes.INVENTORY -> "Trained face models"
AppRoutes.TRAIN -> "Add a new person to recognize"
AppRoutes.TAGS -> "Organize your photo collection"
AppRoutes.UTILITIES -> "Tools for managing collection"
else -> null
// ✅ Face cache prompt dialog (shows on app launch if needed)
if (needsFaceCachePopulation) {
FaceCachePromptDialog(
unscannedPhotoCount = unscannedPhotoCount,
onDismiss = { viewModel.dismissFaceCachePrompt() },
onScanNow = {
viewModel.dismissFaceCachePrompt()
navController.navigate(AppRoutes.UTILITIES)
}
)
}
}