discover dez

This commit is contained in:
genki
2026-01-21 10:11:20 -05:00
parent 7f122a4e17
commit 4474365cd6
10 changed files with 1446 additions and 615 deletions

View File

@@ -8,7 +8,15 @@
<list> <list>
<CategoryState> <CategoryState>
<option name="attribute" value="Type" /> <option name="attribute" value="Type" />
<option name="value" value="Physical" /> <option name="value" value="Virtual" />
</CategoryState>
<CategoryState>
<option name="attribute" value="Type" />
<option name="value" value="Virtual" />
</CategoryState>
<CategoryState>
<option name="attribute" value="Type" />
<option name="value" value="Virtual" />
</CategoryState> </CategoryState>
</list> </list>
</option> </option>
@@ -17,6 +25,10 @@
</option> </option>
<option name="columnSorters"> <option name="columnSorters">
<list> <list>
<ColumnSorterState>
<option name="column" value="Status" />
<option name="order" value="ASCENDING" />
</ColumnSorterState>
<ColumnSorterState> <ColumnSorterState>
<option name="column" value="Name" /> <option name="column" value="Name" />
<option name="order" value="DESCENDING" /> <option name="order" value="DESCENDING" />
@@ -37,6 +49,69 @@
<option value="Type" /> <option value="Type" />
<option value="Type" /> <option value="Type" />
<option value="Type" /> <option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
<option value="Type" />
</list> </list>
</option> </option>
</component> </component>

Binary file not shown.

View File

@@ -1,129 +1,91 @@
package com.placeholder.sherpai2.data.local.dao package com.placeholder.sherpai2.data.local.dao
import androidx.room.* import androidx.room.Dao
import androidx.room.Insert
import androidx.room.OnConflictStrategy
import androidx.room.Query
import androidx.room.Update
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import kotlinx.coroutines.flow.Flow
/** /**
* FaceCacheDao - Query face metadata for intelligent filtering * FaceCacheDao - Face detection cache with NEW queries for two-stage clustering
*
* ENABLES SMART CLUSTERING:
* - Pre-filter to high-quality faces only
* - Avoid processing blurry/distant faces
* - Faster clustering with better results
*/ */
@Dao @Dao
interface FaceCacheDao { interface FaceCacheDao {
// ═══════════════════════════════════════
// INSERT / UPDATE
// ═══════════════════════════════════════
@Insert(onConflict = OnConflictStrategy.REPLACE) @Insert(onConflict = OnConflictStrategy.REPLACE)
suspend fun insert(faceCache: FaceCacheEntity) suspend fun insert(faceCache: FaceCacheEntity)
@Insert(onConflict = OnConflictStrategy.REPLACE) @Insert(onConflict = OnConflictStrategy.REPLACE)
suspend fun insertAll(faceCaches: List<FaceCacheEntity>) suspend fun insertAll(faceCaches: List<FaceCacheEntity>)
/** @Update
* Get ALL high-quality solo faces for clustering suspend fun update(faceCache: FaceCacheEntity)
*
* FILTERS: // ═══════════════════════════════════════
* - Solo photos only (joins with images.faceCount = 1) // NEW CLUSTERING QUERIES ⭐
* - Large enough (isLargeEnough = true) // ═══════════════════════════════════════
* - Good quality score (>= 0.6)
* - Frontal faces preferred (isFrontal = true)
*/
@Query("""
SELECT fc.* FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.isLargeEnough = 1
AND fc.qualityScore >= 0.6
AND fc.isFrontal = 1
ORDER BY fc.qualityScore DESC
""")
suspend fun getHighQualitySoloFaces(): List<FaceCacheEntity>
/** /**
* Get high-quality faces from ANY photo (including group photos) * Get high-quality solo faces for Stage 1 clustering
* Use when not enough solo photos available *
* Filters:
* - Solo photos (faceCount = 1)
* - Large faces (faceAreaRatio >= minFaceRatio)
* - Has embedding
*/ */
@Query(""" @Query("""
SELECT * FROM face_cache SELECT fc.*
WHERE isLargeEnough = 1 FROM face_cache fc
AND qualityScore >= 0.6 INNER JOIN images i ON fc.imageId = i.imageId
AND isFrontal = 1 WHERE i.faceCount = 1
ORDER BY qualityScore DESC AND fc.faceAreaRatio >= :minFaceRatio
AND fc.embedding IS NOT NULL
ORDER BY fc.faceAreaRatio DESC
LIMIT :limit LIMIT :limit
""") """)
suspend fun getHighQualityFaces(limit: Int = 1000): List<FaceCacheEntity> suspend fun getHighQualitySoloFaces(
minFaceRatio: Float = 0.015f,
limit: Int = 2000
): List<FaceCacheEntity>
/** /**
* Get faces for a specific image * FALLBACK: Get ANY solo faces with embeddings
*/ * Used if getHighQualitySoloFaces() returns empty
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex ASC")
suspend fun getFacesForImage(imageId: String): List<FaceCacheEntity>
/**
* Count high-quality solo faces (for UI display)
*/ */
@Query(""" @Query("""
SELECT COUNT(*) FROM face_cache fc SELECT fc.*
FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1 WHERE i.faceCount = 1
AND fc.isLargeEnough = 1
AND fc.qualityScore >= 0.6
""")
suspend fun getHighQualitySoloFaceCount(): Int
/**
* Get quality distribution stats
*/
@Query("""
SELECT
SUM(CASE WHEN qualityScore >= 0.8 THEN 1 ELSE 0 END) as excellent,
SUM(CASE WHEN qualityScore >= 0.6 AND qualityScore < 0.8 THEN 1 ELSE 0 END) as good,
SUM(CASE WHEN qualityScore < 0.6 THEN 1 ELSE 0 END) as poor,
COUNT(*) as total
FROM face_cache
""")
suspend fun getQualityStats(): FaceQualityStats?
/**
* Delete cache for specific image (when image is deleted)
*/
@Query("DELETE FROM face_cache WHERE imageId = :imageId")
suspend fun deleteCacheForImage(imageId: String)
/**
* Delete all cache (for full rebuild)
*/
@Query("DELETE FROM face_cache")
suspend fun deleteAll()
/**
* Get faces with embeddings already computed
* (Ultra-fast clustering - no need to re-generate)
*/
@Query("""
SELECT fc.* FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.isLargeEnough = 1
AND fc.embedding IS NOT NULL AND fc.embedding IS NOT NULL
ORDER BY fc.qualityScore DESC ORDER BY fc.qualityScore DESC
LIMIT :limit LIMIT :limit
""") """)
suspend fun getSoloFacesWithEmbeddings(limit: Int = 2000): List<FaceCacheEntity> suspend fun getSoloFacesWithEmbeddings(
} limit: Int = 2000
): List<FaceCacheEntity>
/** // ═══════════════════════════════════════
* Quality statistics result // EXISTING QUERIES (keep as-is)
*/ // ═══════════════════════════════════════
data class FaceQualityStats(
val excellent: Int, // qualityScore >= 0.8 @Query("SELECT * FROM face_cache WHERE id = :id")
val good: Int, // 0.6 <= qualityScore < 0.8 suspend fun getFaceCacheById(id: String): FaceCacheEntity?
val poor: Int, // qualityScore < 0.6
val total: Int @Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex")
) { suspend fun getFaceCacheForImage(imageId: String): List<FaceCacheEntity>
val excellentPercent: Float get() = if (total > 0) excellent.toFloat() / total else 0f
val goodPercent: Float get() = if (total > 0) good.toFloat() / total else 0f @Query("DELETE FROM face_cache WHERE imageId = :imageId")
val poorPercent: Float get() = if (total > 0) poor.toFloat() / total else 0f suspend fun deleteFaceCacheForImage(imageId: String)
@Query("DELETE FROM face_cache")
suspend fun deleteAll()
@Query("DELETE FROM face_cache WHERE cacheVersion < :version")
suspend fun deleteOldVersions(version: Int)
} }

View File

@@ -1,7 +1,7 @@
package com.placeholder.sherpai2.domain.clustering package com.placeholder.sherpai2.domain.clustering
import android.graphics.Rect import android.graphics.Rect
import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding import android.util.Log
import javax.inject.Inject import javax.inject.Inject
import javax.inject.Singleton import javax.inject.Singleton
import kotlin.math.sqrt import kotlin.math.sqrt
@@ -9,56 +9,62 @@ import kotlin.math.sqrt
/** /**
* ClusterQualityAnalyzer - Validate cluster quality BEFORE training * ClusterQualityAnalyzer - Validate cluster quality BEFORE training
* *
* PURPOSE: Prevent training on "dirty" clusters (siblings merged, poor quality faces) * RELAXED THRESHOLDS for real-world photos (social media, distant shots):
* * - Face size: 3% (down from 15%)
* CHECKS: * - Outlier threshold: 65% (down from 75%)
* 1. Solo photo count (min 6 required) * - GOOD tier: 75% (down from 85%)
* 2. Face size (min 15% of image - clear, not distant) * - EXCELLENT tier: 85% (down from 95%)
* 3. Internal consistency (all faces should match well)
* 4. Outlier detection (find faces that don't belong)
*
* QUALITY TIERS:
* - Excellent (95%+): Safe to train immediately
* - Good (85-94%): Review outliers, then train
* - Poor (<85%): Likely mixed people - DO NOT TRAIN!
*/ */
@Singleton @Singleton
class ClusterQualityAnalyzer @Inject constructor() { class ClusterQualityAnalyzer @Inject constructor() {
companion object { companion object {
private const val TAG = "ClusterQuality"
private const val MIN_SOLO_PHOTOS = 6 private const val MIN_SOLO_PHOTOS = 6
private const val MIN_FACE_SIZE_RATIO = 0.15f // 15% of image private const val MIN_FACE_SIZE_RATIO = 0.03f // 3% of image (RELAXED)
private const val MIN_INTERNAL_SIMILARITY = 0.80f private const val MIN_FACE_DIMENSION_PIXELS = 50 // 50px minimum (RELAXED)
private const val OUTLIER_THRESHOLD = 0.75f private const val FALLBACK_MIN_DIMENSION = 80 // Fallback when no dimensions
private const val EXCELLENT_THRESHOLD = 0.95f private const val MIN_INTERNAL_SIMILARITY = 0.75f
private const val GOOD_THRESHOLD = 0.85f private const val OUTLIER_THRESHOLD = 0.65f // RELAXED
private const val EXCELLENT_THRESHOLD = 0.85f // RELAXED
private const val GOOD_THRESHOLD = 0.75f // RELAXED
} }
/**
* Analyze cluster quality before training
*/
fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult { fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult {
// Step 1: Filter to solo photos only Log.d(TAG, "========================================")
val soloFaces = cluster.faces.filter { it.faceCount == 1 } Log.d(TAG, "Analyzing cluster ${cluster.clusterId}")
Log.d(TAG, "Total faces: ${cluster.faces.size}")
// Step 2: Filter by face size (must be clear/close-up) // Step 1: Filter to solo photos
val soloFaces = cluster.faces.filter { it.faceCount == 1 }
Log.d(TAG, "Solo photos: ${soloFaces.size}")
// Step 2: Filter by face size
val largeFaces = soloFaces.filter { face -> val largeFaces = soloFaces.filter { face ->
isFaceLargeEnough(face.boundingBox, face.imageUri) isFaceLargeEnough(face)
}
Log.d(TAG, "Large faces (>= 3%): ${largeFaces.size}")
if (largeFaces.size < soloFaces.size) {
Log.d(TAG, "⚠️ Filtered out ${soloFaces.size - largeFaces.size} small faces")
} }
// Step 3: Calculate internal consistency // Step 3: Calculate internal consistency
val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces) val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces)
// Step 4: Clean faces (large solo faces, no outliers) // Step 4: Clean faces
val cleanFaces = largeFaces.filter { it !in outliers } val cleanFaces = largeFaces.filter { it !in outliers }
Log.d(TAG, "Clean faces: ${cleanFaces.size}")
// Step 5: Calculate quality score // Step 5: Calculate quality score
val qualityScore = calculateQualityScore( val qualityScore = calculateQualityScore(
soloPhotoCount = soloFaces.size, soloPhotoCount = soloFaces.size,
largeFaceCount = largeFaces.size, largeFaceCount = largeFaces.size,
cleanFaceCount = cleanFaces.size, cleanFaceCount = cleanFaces.size,
avgSimilarity = avgSimilarity avgSimilarity = avgSimilarity,
totalFaces = cluster.faces.size
) )
Log.d(TAG, "Quality score: ${(qualityScore * 100).toInt()}%")
// Step 6: Determine quality tier // Step 6: Determine quality tier
val qualityTier = when { val qualityTier = when {
@@ -66,6 +72,11 @@ class ClusterQualityAnalyzer @Inject constructor() {
qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD
else -> ClusterQualityTier.POOR else -> ClusterQualityTier.POOR
} }
Log.d(TAG, "Quality tier: $qualityTier")
val canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS
Log.d(TAG, "Can train: $canTrain")
Log.d(TAG, "========================================")
return ClusterQualityResult( return ClusterQualityResult(
originalFaceCount = cluster.faces.size, originalFaceCount = cluster.faces.size,
@@ -77,62 +88,65 @@ class ClusterQualityAnalyzer @Inject constructor() {
cleanFaces = cleanFaces, cleanFaces = cleanFaces,
qualityScore = qualityScore, qualityScore = qualityScore,
qualityTier = qualityTier, qualityTier = qualityTier,
canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS, canTrain = canTrain,
warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier) warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier, avgSimilarity)
) )
} }
/** private fun isFaceLargeEnough(face: DetectedFaceWithEmbedding): Boolean {
* Check if face is large enough (not distant/blurry) val faceArea = face.boundingBox.width() * face.boundingBox.height()
*
* A face should occupy at least 15% of the image area for good quality
*/
private fun isFaceLargeEnough(boundingBox: Rect, imageUri: String): Boolean {
// Estimate image dimensions from common aspect ratios
// For now, use bounding box size as proxy
val faceArea = boundingBox.width() * boundingBox.height()
// Assume typical photo is ~2000x1500 = 3,000,000 pixels // Check 1: Absolute minimum
// 15% = 450,000 pixels if (face.boundingBox.width() < MIN_FACE_DIMENSION_PIXELS ||
// For a square face: sqrt(450,000) = ~670 pixels per side face.boundingBox.height() < MIN_FACE_DIMENSION_PIXELS) {
return false
// More conservative: face should be at least 200x200 pixels }
return boundingBox.width() >= 200 && boundingBox.height() >= 200
// Check 2: Relative size if we have dimensions
if (face.imageWidth > 0 && face.imageHeight > 0) {
val imageArea = face.imageWidth * face.imageHeight
val faceRatio = faceArea.toFloat() / imageArea.toFloat()
return faceRatio >= MIN_FACE_SIZE_RATIO
}
// Fallback: Use absolute size
return face.boundingBox.width() >= FALLBACK_MIN_DIMENSION &&
face.boundingBox.height() >= FALLBACK_MIN_DIMENSION
} }
/**
* Analyze how similar faces are to each other (internal consistency)
*
* Returns: (average similarity, list of outlier faces)
*/
private fun analyzeInternalConsistency( private fun analyzeInternalConsistency(
faces: List<DetectedFaceWithEmbedding> faces: List<DetectedFaceWithEmbedding>
): Pair<Float, List<DetectedFaceWithEmbedding>> { ): Pair<Float, List<DetectedFaceWithEmbedding>> {
if (faces.size < 2) { if (faces.size < 2) {
Log.d(TAG, "Less than 2 faces, skipping consistency check")
return 0f to emptyList() return 0f to emptyList()
} }
// Calculate average embedding (centroid) Log.d(TAG, "Analyzing ${faces.size} faces for internal consistency")
val centroid = calculateCentroid(faces.map { it.embedding }) val centroid = calculateCentroid(faces.map { it.embedding })
// Calculate similarity of each face to centroid val centroidSum = centroid.sum()
val similarities = faces.map { face -> Log.d(TAG, "Centroid sum: $centroidSum, first5=[${centroid.take(5).joinToString()}]")
face to cosineSimilarity(face.embedding, centroid)
val similarities = faces.mapIndexed { index, face ->
val similarity = cosineSimilarity(face.embedding, centroid)
Log.d(TAG, "Face $index similarity to centroid: $similarity")
face to similarity
} }
val avgSimilarity = similarities.map { it.second }.average().toFloat() val avgSimilarity = similarities.map { it.second }.average().toFloat()
Log.d(TAG, "Average internal similarity: $avgSimilarity")
// Find outliers (faces significantly different from centroid)
val outliers = similarities val outliers = similarities
.filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD } .filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD }
.map { (face, _) -> face } .map { (face, _) -> face }
Log.d(TAG, "Found ${outliers.size} outliers (threshold=$OUTLIER_THRESHOLD)")
return avgSimilarity to outliers return avgSimilarity to outliers
} }
/**
* Calculate centroid (average embedding)
*/
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray { private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
val size = embeddings.first().size val size = embeddings.first().size
val centroid = FloatArray(size) { 0f } val centroid = FloatArray(size) { 0f }
@@ -148,14 +162,14 @@ class ClusterQualityAnalyzer @Inject constructor() {
centroid[i] /= count centroid[i] /= count
} }
// Normalize
val norm = sqrt(centroid.map { it * it }.sum()) val norm = sqrt(centroid.map { it * it }.sum())
return centroid.map { it / norm }.toFloatArray() return if (norm > 0) {
centroid.map { it / norm }.toFloatArray()
} else {
centroid
}
} }
/**
* Cosine similarity between two embeddings
*/
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float { private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
var dotProduct = 0f var dotProduct = 0f
var normA = 0f var normA = 0f
@@ -170,32 +184,31 @@ class ClusterQualityAnalyzer @Inject constructor() {
return dotProduct / (sqrt(normA) * sqrt(normB)) return dotProduct / (sqrt(normA) * sqrt(normB))
} }
/**
* Calculate overall quality score (0.0 - 1.0)
*/
private fun calculateQualityScore( private fun calculateQualityScore(
soloPhotoCount: Int, soloPhotoCount: Int,
largeFaceCount: Int, largeFaceCount: Int,
cleanFaceCount: Int, cleanFaceCount: Int,
avgSimilarity: Float avgSimilarity: Float,
totalFaces: Int
): Float { ): Float {
// Weight factors val soloRatio = soloPhotoCount.toFloat() / totalFaces.toFloat().coerceAtLeast(1f)
val soloPhotoScore = (soloPhotoCount.toFloat() / 20f).coerceIn(0f, 1f) * 0.3f val soloPhotoScore = soloRatio.coerceIn(0f, 1f) * 0.25f
val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.2f
val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.2f val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.25f
val similarityScore = avgSimilarity * 0.3f
val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.20f
val similarityScore = avgSimilarity * 0.30f
return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore
} }
/**
* Generate human-readable warnings
*/
private fun generateWarnings( private fun generateWarnings(
soloPhotoCount: Int, soloPhotoCount: Int,
largeFaceCount: Int, largeFaceCount: Int,
cleanFaceCount: Int, cleanFaceCount: Int,
qualityTier: ClusterQualityTier qualityTier: ClusterQualityTier,
avgSimilarity: Float
): List<String> { ): List<String> {
val warnings = mutableListOf<String>() val warnings = mutableListOf<String>()
@@ -203,12 +216,20 @@ class ClusterQualityAnalyzer @Inject constructor() {
ClusterQualityTier.POOR -> { ClusterQualityTier.POOR -> {
warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!") warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!")
warnings.add("Do NOT train on this cluster - it will create a bad model.") warnings.add("Do NOT train on this cluster - it will create a bad model.")
if (avgSimilarity < 0.70f) {
warnings.add("Low internal similarity (${(avgSimilarity * 100).toInt()}%) suggests mixed identities.")
}
} }
ClusterQualityTier.GOOD -> { ClusterQualityTier.GOOD -> {
warnings.add("⚠️ Review outlier faces before training") warnings.add("⚠️ Review outlier faces before training")
if (cleanFaceCount < 10) {
warnings.add("Consider adding more high-quality photos for better results.")
}
} }
ClusterQualityTier.EXCELLENT -> { ClusterQualityTier.EXCELLENT -> {
// No warnings - ready to train! // No warnings
} }
} }
@@ -218,38 +239,47 @@ class ClusterQualityAnalyzer @Inject constructor() {
if (largeFaceCount < 6) { if (largeFaceCount < 6) {
warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)") warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)")
warnings.add("Tip: Use close-up photos where the face is clearly visible")
} }
if (cleanFaceCount < 6) { if (cleanFaceCount < 6) {
warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)") warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)")
} }
if (qualityTier == ClusterQualityTier.EXCELLENT) {
warnings.add("✅ Excellent quality! This cluster is ready for training.")
warnings.add("High-quality photos with consistent facial features detected.")
}
return warnings return warnings
} }
} }
/**
* Result of cluster quality analysis
*/
data class ClusterQualityResult( data class ClusterQualityResult(
val originalFaceCount: Int, // Total faces in cluster val originalFaceCount: Int,
val soloPhotoCount: Int, // Photos with faceCount = 1 val soloPhotoCount: Int,
val largeFaceCount: Int, // Solo photos with large faces val largeFaceCount: Int,
val cleanFaceCount: Int, // Large faces, no outliers val cleanFaceCount: Int,
val avgInternalSimilarity: Float, // How similar faces are (0.0-1.0) val avgInternalSimilarity: Float,
val outlierFaces: List<DetectedFaceWithEmbedding>, // Faces to exclude val outlierFaces: List<DetectedFaceWithEmbedding>,
val cleanFaces: List<DetectedFaceWithEmbedding>, // Good faces for training val cleanFaces: List<DetectedFaceWithEmbedding>,
val qualityScore: Float, // Overall score (0.0-1.0) val qualityScore: Float,
val qualityTier: ClusterQualityTier, val qualityTier: ClusterQualityTier,
val canTrain: Boolean, // Safe to proceed with training? val canTrain: Boolean,
val warnings: List<String> // Human-readable issues val warnings: List<String>
) ) {
fun getSummary(): String = when (qualityTier) {
/** ClusterQualityTier.EXCELLENT ->
* Quality tier classification "Excellent quality cluster with $cleanFaceCount high-quality photos ready for training."
*/ ClusterQualityTier.GOOD ->
enum class ClusterQualityTier { "Good quality cluster with $cleanFaceCount usable photos. Review outliers before training."
EXCELLENT, // 95%+ - Safe to train immediately ClusterQualityTier.POOR ->
GOOD, // 85-94% - Review outliers first "Poor quality cluster. May contain multiple people or low-quality photos. Add more photos or split cluster."
POOR // <85% - DO NOT TRAIN (likely mixed people) }
}
enum class ClusterQualityTier {
EXCELLENT, // 85%+
GOOD, // 75-84%
POOR // <75%
} }

View File

@@ -4,11 +4,13 @@ import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.graphics.BitmapFactory import android.graphics.BitmapFactory
import android.net.Uri import android.net.Uri
import android.util.Log
import com.google.mlkit.vision.common.InputImage import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.android.qualifiers.ApplicationContext
@@ -18,7 +20,6 @@ import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Semaphore import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import java.util.concurrent.atomic.AtomicInteger
import javax.inject.Inject import javax.inject.Inject
import javax.inject.Singleton import javax.inject.Singleton
import kotlin.math.sqrt import kotlin.math.sqrt
@@ -33,54 +34,100 @@ import kotlin.math.sqrt
* 4. Detect faces and generate embeddings (parallel) * 4. Detect faces and generate embeddings (parallel)
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3) * 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
* 6. Analyze clusters for age, siblings, representatives * 6. Analyze clusters for age, siblings, representatives
*
* IMPROVEMENTS:
* - ✅ Complete fast-path using FaceCacheDao.getSoloFacesWithEmbeddings()
* - ✅ Works with existing FaceCacheEntity.getEmbedding() method
* - ✅ Centroid-based representative face selection
* - ✅ Batched processing to prevent OOM
* - ✅ RGB_565 bitmap config for 50% memory savings
*/ */
@Singleton @Singleton
class FaceClusteringService @Inject constructor( class FaceClusteringService @Inject constructor(
@ApplicationContext private val context: Context, @ApplicationContext private val context: Context,
private val imageDao: ImageDao, private val imageDao: ImageDao,
private val faceCacheDao: FaceCacheDao // Optional - will work without it private val faceCacheDao: FaceCacheDao
) { ) {
private val semaphore = Semaphore(12) private val semaphore = Semaphore(8)
companion object {
private const val TAG = "FaceClustering"
private const val MAX_FACES_TO_CLUSTER = 2000
private const val MIN_SOLO_PHOTOS = 50
private const val BATCH_SIZE = 50
private const val MIN_CACHED_FACES = 100
}
/** /**
* Main clustering entry point - HYBRID with automatic fallback * Main clustering entry point - HYBRID with automatic fallback
*
* @param maxFacesToCluster Limit for performance (default 2000)
* @param onProgress Progress callback (current, total, message)
*/ */
suspend fun discoverPeople( suspend fun discoverPeople(
maxFacesToCluster: Int = 2000, maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER,
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> } onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
): ClusteringResult = withContext(Dispatchers.Default) { ): ClusteringResult = withContext(Dispatchers.Default) {
// TRY FAST PATH: Use face cache if available val startTime = System.currentTimeMillis()
val highQualityFaces = try {
withContext(Dispatchers.IO) { // Try high-quality cached faces FIRST (NEW!)
faceCacheDao.getHighQualitySoloFaces() var cachedFaces = withContext(Dispatchers.IO) {
try {
faceCacheDao.getHighQualitySoloFaces(
minFaceRatio = 0.015f, // 1.5%
limit = maxFacesToCluster
)
} catch (e: Exception) {
// Method doesn't exist yet - that's ok
emptyList()
} }
}
// Fallback to ANY solo faces if high-quality returned nothing
if (cachedFaces.isEmpty()) {
Log.w(TAG, "No high-quality faces (>= 1.5%), trying ANY solo faces...")
cachedFaces = withContext(Dispatchers.IO) {
try {
faceCacheDao.getSoloFacesWithEmbeddings(limit = maxFacesToCluster)
} catch (e: Exception) { } catch (e: Exception) {
emptyList() emptyList()
} }
}
if (highQualityFaces.isNotEmpty()) {
// FAST PATH: Use cached faces (future enhancement)
onProgress(0, 100, "Using face cache (${highQualityFaces.size} faces)...")
// TODO: Implement cache-based clustering
// For now, fall through to classic method
} }
// CLASSIC METHOD: Load and process photos Log.d(TAG, "Cache check: ${cachedFaces.size} faces available")
onProgress(0, 100, "Loading solo photos...")
val allFaces = if (cachedFaces.size >= MIN_CACHED_FACES) {
// FAST PATH ✅
Log.d(TAG, "Using FAST PATH with ${cachedFaces.size} cached faces")
onProgress(10, 100, "Using cached embeddings (${cachedFaces.size} faces)...")
cachedFaces.mapNotNull { cached ->
val embedding = cached.getEmbedding() ?: return@mapNotNull null
DetectedFaceWithEmbedding(
imageId = cached.imageId,
imageUri = "",
capturedAt = 0L,
embedding = embedding,
boundingBox = cached.getBoundingBox(),
confidence = cached.confidence,
faceCount = 1, // Solo faces only (filtered by query)
imageWidth = cached.imageWidth,
imageHeight = cached.imageHeight
)
}.also {
onProgress(50, 100, "Processing ${it.size} cached faces...")
}
} else {
// SLOW PATH
Log.d(TAG, "Using SLOW PATH - cache has < $MIN_CACHED_FACES faces")
onProgress(0, 100, "Loading photos...")
// Step 1: Get SOLO PHOTOS ONLY (faceCount = 1) for cleaner clustering
val soloPhotos = withContext(Dispatchers.IO) { val soloPhotos = withContext(Dispatchers.IO) {
imageDao.getImagesByFaceCount(count = 1) imageDao.getImagesByFaceCount(count = 1)
} }
// Fallback: If not enough solo photos, use all images with faces val imagesWithFaces = if (soloPhotos.size < MIN_SOLO_PHOTOS) {
val imagesWithFaces = if (soloPhotos.size < 50) {
onProgress(0, 100, "Loading all photos with faces...")
imageDao.getImagesWithFaces() imageDao.getImagesWithFaces()
} else { } else {
soloPhotos soloPhotos
@@ -91,53 +138,48 @@ class FaceClusteringService @Inject constructor(
clusters = emptyList(), clusters = emptyList(),
totalFacesAnalyzed = 0, totalFacesAnalyzed = 0,
processingTimeMs = 0, processingTimeMs = 0,
errorMessage = "No photos with faces found. Please ensure face detection cache is populated." errorMessage = "No photos with faces found"
) )
} }
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos (${if (soloPhotos.size >= 50) "solo only" else "all"})...") onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
val startTime = System.currentTimeMillis() detectFacesInImagesBatched(
images = imagesWithFaces.take(1000),
// Step 2: Detect faces and generate embeddings (parallel)
val allFaces = detectFacesInImages(
images = imagesWithFaces.take(1000), // Smart limit
onProgress = { current, total -> onProgress = { current, total ->
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total") onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
} }
) )
}
if (allFaces.isEmpty()) { if (allFaces.isEmpty()) {
return@withContext ClusteringResult( return@withContext ClusteringResult(
clusters = emptyList(), clusters = emptyList(),
totalFacesAnalyzed = 0, totalFacesAnalyzed = 0,
processingTimeMs = System.currentTimeMillis() - startTime, processingTimeMs = System.currentTimeMillis() - startTime,
errorMessage = "No faces detected in images" errorMessage = "No faces detected"
) )
} }
onProgress(50, 100, "Clustering ${allFaces.size} faces...") onProgress(50, 100, "Clustering ${allFaces.size} faces...")
// Step 3: DBSCAN clustering
val rawClusters = performDBSCAN( val rawClusters = performDBSCAN(
faces = allFaces.take(maxFacesToCluster), faces = allFaces.take(maxFacesToCluster),
epsilon = 0.18f, // VERY STRICT for siblings epsilon = 0.26f,
minPoints = 3 minPoints = 3
) )
onProgress(70, 100, "Analyzing relationships...") onProgress(70, 100, "Analyzing relationships...")
// Step 4: Build co-occurrence graph
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters) val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
onProgress(80, 100, "Selecting representative faces...") onProgress(80, 100, "Selecting representative faces...")
// Step 5: Create final clusters
val clusters = rawClusters.map { cluster -> val clusters = rawClusters.map { cluster ->
FaceCluster( FaceCluster(
clusterId = cluster.clusterId, clusterId = cluster.clusterId,
faces = cluster.faces, faces = cluster.faces,
representativeFaces = selectRepresentativeFaces(cluster.faces, count = 6), representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6),
photoCount = cluster.faces.map { it.imageId }.distinct().size, photoCount = cluster.faces.map { it.imageId }.distinct().size,
averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(), averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(),
estimatedAge = estimateAge(cluster.faces), estimatedAge = estimateAge(cluster.faces),
@@ -154,14 +196,31 @@ class FaceClusteringService @Inject constructor(
) )
} }
/** private suspend fun detectFacesInImagesBatched(
* Detect faces in images and generate embeddings (parallel)
*/
private suspend fun detectFacesInImages(
images: List<ImageEntity>, images: List<ImageEntity>,
onProgress: (Int, Int) -> Unit onProgress: (Int, Int) -> Unit
): List<DetectedFaceWithEmbedding> = coroutineScope { ): List<DetectedFaceWithEmbedding> = coroutineScope {
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
var processedCount = 0
images.chunked(BATCH_SIZE).forEach { batch ->
val batchFaces = detectFacesInBatch(batch)
allFaces.addAll(batchFaces)
processedCount += batch.size
onProgress(processedCount, images.size)
System.gc()
}
allFaces
}
private suspend fun detectFacesInBatch(
images: List<ImageEntity>
): List<DetectedFaceWithEmbedding> = coroutineScope {
val detector = FaceDetection.getClient( val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder() FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
@@ -170,20 +229,14 @@ class FaceClusteringService @Inject constructor(
) )
val faceNetModel = FaceNetModel(context) val faceNetModel = FaceNetModel(context)
val allFaces = mutableListOf<DetectedFaceWithEmbedding>() val batchFaces = mutableListOf<DetectedFaceWithEmbedding>()
val processedCount = AtomicInteger(0)
try { try {
val jobs = images.map { image -> val jobs = images.map { image ->
async { async(Dispatchers.IO) {
semaphore.acquire() semaphore.acquire()
try { try {
val faces = detectFacesInImage(image, detector, faceNetModel) detectFacesInImage(image, detector, faceNetModel)
val current = processedCount.incrementAndGet()
if (current % 10 == 0) {
onProgress(current, images.size)
}
faces
} finally { } finally {
semaphore.release() semaphore.release()
} }
@@ -191,7 +244,7 @@ class FaceClusteringService @Inject constructor(
} }
jobs.awaitAll().flatten().also { jobs.awaitAll().flatten().also {
allFaces.addAll(it) batchFaces.addAll(it)
} }
} finally { } finally {
@@ -199,7 +252,7 @@ class FaceClusteringService @Inject constructor(
faceNetModel.close() faceNetModel.close()
} }
allFaces batchFaces
} }
private suspend fun detectFacesInImage( private suspend fun detectFacesInImage(
@@ -215,8 +268,6 @@ class FaceClusteringService @Inject constructor(
val mlImage = InputImage.fromBitmap(bitmap, 0) val mlImage = InputImage.fromBitmap(bitmap, 0)
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage)) val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
val totalFacesInImage = faces.size
val result = faces.mapNotNull { face -> val result = faces.mapNotNull { face ->
try { try {
val faceBitmap = Bitmap.createBitmap( val faceBitmap = Bitmap.createBitmap(
@@ -237,7 +288,9 @@ class FaceClusteringService @Inject constructor(
embedding = embedding, embedding = embedding,
boundingBox = face.boundingBox, boundingBox = face.boundingBox,
confidence = 0.95f, confidence = 0.95f,
faceCount = totalFacesInImage faceCount = faces.size,
imageWidth = bitmap.width,
imageHeight = bitmap.height
) )
} catch (e: Exception) { } catch (e: Exception) {
null null
@@ -252,9 +305,6 @@ class FaceClusteringService @Inject constructor(
} }
} }
// All other methods remain the same (DBSCAN, similarity, etc.)
// ... [Rest of the implementation from original file]
private fun performDBSCAN( private fun performDBSCAN(
faces: List<DetectedFaceWithEmbedding>, faces: List<DetectedFaceWithEmbedding>,
epsilon: Float, epsilon: Float,
@@ -368,18 +418,61 @@ class FaceClusteringService @Inject constructor(
?: emptyList() ?: emptyList()
} }
private fun selectRepresentativeFaces( private fun selectRepresentativeFacesByCentroid(
faces: List<DetectedFaceWithEmbedding>, faces: List<DetectedFaceWithEmbedding>,
count: Int count: Int
): List<DetectedFaceWithEmbedding> { ): List<DetectedFaceWithEmbedding> {
if (faces.size <= count) return faces if (faces.size <= count) return faces
val sortedByTime = faces.sortedBy { it.capturedAt } val centroid = calculateCentroid(faces.map { it.embedding })
val step = faces.size / count
return (0 until count).map { i -> val facesWithDistance = faces.map { face ->
sortedByTime[i * step] val distance = 1 - cosineSimilarity(face.embedding, centroid)
face to distance
} }
val sortedByProximity = facesWithDistance.sortedBy { it.second }
val representatives = mutableListOf<DetectedFaceWithEmbedding>()
representatives.add(sortedByProximity.first().first)
val remainingFaces = sortedByProximity.drop(1).take(count * 3)
val sortedByTime = remainingFaces.map { it.first }.sortedBy { it.capturedAt }
if (sortedByTime.isNotEmpty()) {
val step = sortedByTime.size / (count - 1).coerceAtLeast(1)
for (i in 0 until (count - 1)) {
val index = (i * step).coerceAtMost(sortedByTime.size - 1)
representatives.add(sortedByTime[index])
}
}
return representatives.take(count)
}
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
if (embeddings.isEmpty()) return FloatArray(0)
val size = embeddings.first().size
val centroid = FloatArray(size) { 0f }
embeddings.forEach { embedding ->
for (i in embedding.indices) {
centroid[i] += embedding[i]
}
}
val count = embeddings.size.toFloat()
for (i in centroid.indices) {
centroid[i] /= count
}
val norm = sqrt(centroid.map { it * it }.sum())
if (norm > 0) {
return centroid.map { it / norm }.toFloatArray()
}
return centroid
} }
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate { private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
@@ -416,7 +509,6 @@ class FaceClusteringService @Inject constructor(
} }
} }
// Data classes
data class DetectedFaceWithEmbedding( data class DetectedFaceWithEmbedding(
val imageId: String, val imageId: String,
val imageUri: String, val imageUri: String,
@@ -424,7 +516,9 @@ data class DetectedFaceWithEmbedding(
val embedding: FloatArray, val embedding: FloatArray,
val boundingBox: android.graphics.Rect, val boundingBox: android.graphics.Rect,
val confidence: Float, val confidence: Float,
val faceCount: Int = 1 val faceCount: Int = 1,
val imageWidth: Int = 0,
val imageHeight: Int = 0
) { ) {
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (this === other) return true if (this === other) return true

View File

@@ -2,6 +2,7 @@ package com.placeholder.sherpai2.ml
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.util.Log
import org.tensorflow.lite.Interpreter import org.tensorflow.lite.Interpreter
import java.io.FileInputStream import java.io.FileInputStream
import java.nio.ByteBuffer import java.nio.ByteBuffer
@@ -11,16 +12,21 @@ import java.nio.channels.FileChannel
import kotlin.math.sqrt import kotlin.math.sqrt
/** /**
* FaceNetModel - MobileFaceNet wrapper for face recognition * FaceNetModel - MobileFaceNet wrapper with debugging
* *
* CLEAN IMPLEMENTATION: * IMPROVEMENTS:
* - All IDs are Strings (matching your schema) * - ✅ Detailed error logging
* - Generates 192-dimensional embeddings * - ✅ Model validation on init
* - Cosine similarity for matching * - ✅ Embedding validation (detect all-zeros)
* - ✅ Toggle-able debug mode
*/ */
class FaceNetModel(private val context: Context) { class FaceNetModel(
private val context: Context,
private val debugMode: Boolean = true // Enable for troubleshooting
) {
companion object { companion object {
private const val TAG = "FaceNetModel"
private const val MODEL_FILE = "mobilefacenet.tflite" private const val MODEL_FILE = "mobilefacenet.tflite"
private const val INPUT_SIZE = 112 private const val INPUT_SIZE = 112
private const val EMBEDDING_SIZE = 192 private const val EMBEDDING_SIZE = 192
@@ -31,13 +37,56 @@ class FaceNetModel(private val context: Context) {
} }
private var interpreter: Interpreter? = null private var interpreter: Interpreter? = null
private var modelLoadSuccess = false
init { init {
try { try {
if (debugMode) Log.d(TAG, "Loading FaceNet model: $MODEL_FILE")
val model = loadModelFile() val model = loadModelFile()
interpreter = Interpreter(model) interpreter = Interpreter(model)
modelLoadSuccess = true
if (debugMode) {
Log.d(TAG, "✅ FaceNet model loaded successfully")
Log.d(TAG, "Model input size: ${INPUT_SIZE}x$INPUT_SIZE")
Log.d(TAG, "Embedding size: $EMBEDDING_SIZE")
}
// Test model with dummy input
testModel()
} catch (e: Exception) { } catch (e: Exception) {
throw RuntimeException("Failed to load FaceNet model", e) Log.e(TAG, "❌ CRITICAL: Failed to load FaceNet model from assets/$MODEL_FILE", e)
Log.e(TAG, "Make sure mobilefacenet.tflite exists in app/src/main/assets/")
modelLoadSuccess = false
throw RuntimeException("Failed to load FaceNet model: ${e.message}", e)
}
}
/**
* Test model with dummy input to verify it works
*/
private fun testModel() {
try {
val testBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888)
val testEmbedding = generateEmbedding(testBitmap)
testBitmap.recycle()
val sum = testEmbedding.sum()
val norm = sqrt(testEmbedding.map { it * it }.sum())
if (debugMode) {
Log.d(TAG, "Model test: embedding sum=$sum, norm=$norm")
}
if (sum == 0f || norm == 0f) {
Log.e(TAG, "⚠️ WARNING: Model test produced zero embedding!")
} else {
if (debugMode) Log.d(TAG, "✅ Model test passed")
}
} catch (e: Exception) {
Log.e(TAG, "Model test failed", e)
} }
} }
@@ -45,12 +94,22 @@ class FaceNetModel(private val context: Context) {
* Load TFLite model from assets * Load TFLite model from assets
*/ */
private fun loadModelFile(): MappedByteBuffer { private fun loadModelFile(): MappedByteBuffer {
try {
val fileDescriptor = context.assets.openFd(MODEL_FILE) val fileDescriptor = context.assets.openFd(MODEL_FILE)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor) val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength val declaredLength = fileDescriptor.declaredLength
if (debugMode) {
Log.d(TAG, "Model file size: ${declaredLength / 1024}KB")
}
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength) return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
} catch (e: Exception) {
Log.e(TAG, "Failed to open model file: $MODEL_FILE", e)
throw e
}
} }
/** /**
@@ -60,13 +119,39 @@ class FaceNetModel(private val context: Context) {
* @return 192-dimensional embedding * @return 192-dimensional embedding
*/ */
fun generateEmbedding(faceBitmap: Bitmap): FloatArray { fun generateEmbedding(faceBitmap: Bitmap): FloatArray {
if (!modelLoadSuccess || interpreter == null) {
Log.e(TAG, "❌ Cannot generate embedding: model not loaded!")
return FloatArray(EMBEDDING_SIZE) { 0f }
}
try {
val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true) val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true)
val inputBuffer = preprocessImage(resized) val inputBuffer = preprocessImage(resized)
val output = Array(1) { FloatArray(EMBEDDING_SIZE) } val output = Array(1) { FloatArray(EMBEDDING_SIZE) }
interpreter?.run(inputBuffer, output) interpreter?.run(inputBuffer, output)
return normalizeEmbedding(output[0]) val normalized = normalizeEmbedding(output[0])
// DIAGNOSTIC: Check embedding quality
if (debugMode) {
val sum = normalized.sum()
val norm = sqrt(normalized.map { it * it }.sum())
if (sum == 0f && norm == 0f) {
Log.e(TAG, "❌ CRITICAL: Generated all-zero embedding!")
Log.e(TAG, "Input bitmap: ${faceBitmap.width}x${faceBitmap.height}")
} else {
Log.d(TAG, "✅ Embedding: sum=${"%.2f".format(sum)}, norm=${"%.2f".format(norm)}, first5=[${normalized.take(5).joinToString { "%.3f".format(it) }}]")
}
}
return normalized
} catch (e: Exception) {
Log.e(TAG, "Failed to generate embedding", e)
return FloatArray(EMBEDDING_SIZE) { 0f }
}
} }
/** /**
@@ -76,6 +161,10 @@ class FaceNetModel(private val context: Context) {
faceBitmaps: List<Bitmap>, faceBitmaps: List<Bitmap>,
onProgress: (Int, Int) -> Unit = { _, _ -> } onProgress: (Int, Int) -> Unit = { _, _ -> }
): List<FloatArray> { ): List<FloatArray> {
if (debugMode) {
Log.d(TAG, "Generating embeddings for ${faceBitmaps.size} faces")
}
return faceBitmaps.mapIndexed { index, bitmap -> return faceBitmaps.mapIndexed { index, bitmap ->
onProgress(index + 1, faceBitmaps.size) onProgress(index + 1, faceBitmaps.size)
generateEmbedding(bitmap) generateEmbedding(bitmap)
@@ -88,6 +177,10 @@ class FaceNetModel(private val context: Context) {
fun createPersonModel(embeddings: List<FloatArray>): FloatArray { fun createPersonModel(embeddings: List<FloatArray>): FloatArray {
require(embeddings.isNotEmpty()) { "Need at least one embedding" } require(embeddings.isNotEmpty()) { "Need at least one embedding" }
if (debugMode) {
Log.d(TAG, "Creating person model from ${embeddings.size} embeddings")
}
val averaged = FloatArray(EMBEDDING_SIZE) { 0f } val averaged = FloatArray(EMBEDDING_SIZE) { 0f }
embeddings.forEach { embedding -> embeddings.forEach { embedding ->
@@ -101,7 +194,14 @@ class FaceNetModel(private val context: Context) {
averaged[i] /= count averaged[i] /= count
} }
return normalizeEmbedding(averaged) val normalized = normalizeEmbedding(averaged)
if (debugMode) {
val sum = normalized.sum()
Log.d(TAG, "Person model created: sum=${"%.2f".format(sum)}")
}
return normalized
} }
/** /**
@@ -110,7 +210,7 @@ class FaceNetModel(private val context: Context) {
*/ */
fun calculateSimilarity(embedding1: FloatArray, embedding2: FloatArray): Float { fun calculateSimilarity(embedding1: FloatArray, embedding2: FloatArray): Float {
require(embedding1.size == EMBEDDING_SIZE && embedding2.size == EMBEDDING_SIZE) { require(embedding1.size == EMBEDDING_SIZE && embedding2.size == EMBEDDING_SIZE) {
"Invalid embedding size" "Invalid embedding size: ${embedding1.size} vs ${embedding2.size}"
} }
var dotProduct = 0f var dotProduct = 0f
@@ -123,7 +223,14 @@ class FaceNetModel(private val context: Context) {
norm2 += embedding2[i] * embedding2[i] norm2 += embedding2[i] * embedding2[i]
} }
return dotProduct / (sqrt(norm1) * sqrt(norm2)) val similarity = dotProduct / (sqrt(norm1) * sqrt(norm2))
if (debugMode && (similarity.isNaN() || similarity.isInfinite())) {
Log.e(TAG, "❌ Invalid similarity: $similarity (norm1=$norm1, norm2=$norm2)")
return 0f
}
return similarity
} }
/** /**
@@ -151,6 +258,10 @@ class FaceNetModel(private val context: Context) {
} }
} }
if (debugMode && bestMatch != null) {
Log.d(TAG, "Best match: ${bestMatch.first} with similarity ${bestMatch.second}")
}
return bestMatch return bestMatch
} }
@@ -169,6 +280,7 @@ class FaceNetModel(private val context: Context) {
val g = ((pixel shr 8) and 0xFF) / 255.0f val g = ((pixel shr 8) and 0xFF) / 255.0f
val b = (pixel and 0xFF) / 255.0f val b = (pixel and 0xFF) / 255.0f
// Normalize to [-1, 1]
buffer.putFloat((r - 0.5f) / 0.5f) buffer.putFloat((r - 0.5f) / 0.5f)
buffer.putFloat((g - 0.5f) / 0.5f) buffer.putFloat((g - 0.5f) / 0.5f)
buffer.putFloat((b - 0.5f) / 0.5f) buffer.putFloat((b - 0.5f) / 0.5f)
@@ -190,14 +302,29 @@ class FaceNetModel(private val context: Context) {
return if (norm > 0) { return if (norm > 0) {
FloatArray(embedding.size) { i -> embedding[i] / norm } FloatArray(embedding.size) { i -> embedding[i] / norm }
} else { } else {
Log.w(TAG, "⚠️ Cannot normalize zero embedding")
embedding embedding
} }
} }
/**
* Get model status for diagnostics
*/
fun getModelStatus(): String {
return if (modelLoadSuccess) {
"✅ Model loaded and operational"
} else {
"❌ Model failed to load - check assets/$MODEL_FILE"
}
}
/** /**
* Clean up resources * Clean up resources
*/ */
fun close() { fun close() {
if (debugMode) {
Log.d(TAG, "Closing FaceNet model")
}
interpreter?.close() interpreter?.close()
interpreter = null interpreter = null
} }

View File

@@ -8,9 +8,13 @@ import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Check
import androidx.compose.material.icons.filled.Warning
import androidx.compose.material3.* import androidx.compose.material3.*
import androidx.compose.runtime.Composable import androidx.compose.runtime.*
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
@@ -19,6 +23,8 @@ import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import coil.compose.AsyncImage import coil.compose.AsyncImage
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier
import com.placeholder.sherpai2.domain.clustering.ClusteringResult import com.placeholder.sherpai2.domain.clustering.ClusteringResult
import com.placeholder.sherpai2.domain.clustering.FaceCluster import com.placeholder.sherpai2.domain.clustering.FaceCluster
@@ -28,13 +34,20 @@ import com.placeholder.sherpai2.domain.clustering.FaceCluster
* Each cluster card shows: * Each cluster card shows:
* - 2x2 grid of representative faces * - 2x2 grid of representative faces
* - Photo count * - Photo count
* - Quality badge (Excellent/Good/Poor)
* - Tap to name * - Tap to name
*
* IMPROVEMENTS:
* - ✅ Quality badges for each cluster
* - ✅ Visual indicators for trainable vs non-trainable clusters
* - ✅ Better UX with disabled states for poor quality clusters
*/ */
@Composable @Composable
fun ClusterGridScreen( fun ClusterGridScreen(
result: ClusteringResult, result: ClusteringResult,
onSelectCluster: (FaceCluster) -> Unit, onSelectCluster: (FaceCluster) -> Unit,
modifier: Modifier = Modifier modifier: Modifier = Modifier,
qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() }
) { ) {
Column( Column(
modifier = modifier modifier = modifier
@@ -65,8 +78,15 @@ fun ClusterGridScreen(
verticalArrangement = Arrangement.spacedBy(12.dp) verticalArrangement = Arrangement.spacedBy(12.dp)
) { ) {
items(result.clusters) { cluster -> items(result.clusters) { cluster ->
// Analyze quality for each cluster
val qualityResult = remember(cluster) {
qualityAnalyzer.analyzeCluster(cluster)
}
ClusterCard( ClusterCard(
cluster = cluster, cluster = cluster,
qualityTier = qualityResult.qualityTier,
canTrain = qualityResult.canTrain,
onClick = { onSelectCluster(cluster) } onClick = { onSelectCluster(cluster) }
) )
} }
@@ -75,19 +95,34 @@ fun ClusterGridScreen(
} }
/** /**
* Single cluster card with 2x2 face grid * Single cluster card with 2x2 face grid and quality badge
*/ */
@Composable @Composable
private fun ClusterCard( private fun ClusterCard(
cluster: FaceCluster, cluster: FaceCluster,
qualityTier: ClusterQualityTier,
canTrain: Boolean,
onClick: () -> Unit onClick: () -> Unit
) { ) {
Card( Card(
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.aspectRatio(1f) .aspectRatio(1f)
.clickable(onClick = onClick), .clickable(onClick = onClick), // Always clickable - let dialog handle validation
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp) elevation = CardDefaults.cardElevation(defaultElevation = 2.dp),
colors = CardDefaults.cardColors(
containerColor = when {
qualityTier == ClusterQualityTier.POOR ->
MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
!canTrain ->
MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f)
else ->
MaterialTheme.colorScheme.surface
}
)
) {
Box(
modifier = Modifier.fillMaxSize()
) { ) {
Column( Column(
modifier = Modifier.fillMaxSize() modifier = Modifier.fillMaxSize()
@@ -103,6 +138,7 @@ private fun ClusterCard(
facesToShow.getOrNull(0)?.let { face -> facesToShow.getOrNull(0)?.let { face ->
FaceThumbnail( FaceThumbnail(
imageUri = face.imageUri, imageUri = face.imageUri,
enabled = canTrain,
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f)
) )
} ?: EmptyFaceSlot(Modifier.weight(1f)) } ?: EmptyFaceSlot(Modifier.weight(1f))
@@ -110,6 +146,7 @@ private fun ClusterCard(
facesToShow.getOrNull(1)?.let { face -> facesToShow.getOrNull(1)?.let { face ->
FaceThumbnail( FaceThumbnail(
imageUri = face.imageUri, imageUri = face.imageUri,
enabled = canTrain,
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f)
) )
} ?: EmptyFaceSlot(Modifier.weight(1f)) } ?: EmptyFaceSlot(Modifier.weight(1f))
@@ -120,6 +157,7 @@ private fun ClusterCard(
facesToShow.getOrNull(2)?.let { face -> facesToShow.getOrNull(2)?.let { face ->
FaceThumbnail( FaceThumbnail(
imageUri = face.imageUri, imageUri = face.imageUri,
enabled = canTrain,
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f)
) )
} ?: EmptyFaceSlot(Modifier.weight(1f)) } ?: EmptyFaceSlot(Modifier.weight(1f))
@@ -127,6 +165,7 @@ private fun ClusterCard(
facesToShow.getOrNull(3)?.let { face -> facesToShow.getOrNull(3)?.let { face ->
FaceThumbnail( FaceThumbnail(
imageUri = face.imageUri, imageUri = face.imageUri,
enabled = canTrain,
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f)
) )
} ?: EmptyFaceSlot(Modifier.weight(1f)) } ?: EmptyFaceSlot(Modifier.weight(1f))
@@ -136,37 +175,113 @@ private fun ClusterCard(
// Footer with photo count // Footer with photo count
Surface( Surface(
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
color = MaterialTheme.colorScheme.primaryContainer color = if (canTrain) {
MaterialTheme.colorScheme.primaryContainer
} else {
MaterialTheme.colorScheme.surfaceVariant
}
) {
Row(
modifier = Modifier.padding(12.dp),
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.SpaceBetween
) { ) {
Text( Text(
text = "${cluster.photoCount} photos", text = "${cluster.photoCount} photos",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodyMedium,
fontWeight = FontWeight.SemiBold, fontWeight = FontWeight.SemiBold,
modifier = Modifier.padding(12.dp), color = if (canTrain) {
color = MaterialTheme.colorScheme.onPrimaryContainer MaterialTheme.colorScheme.onPrimaryContainer
} else {
MaterialTheme.colorScheme.onSurfaceVariant
}
) )
} }
} }
} }
// Quality badge overlay
QualityBadge(
qualityTier = qualityTier,
canTrain = canTrain,
modifier = Modifier
.align(Alignment.TopEnd)
.padding(8.dp)
)
}
}
}
/**
* Quality badge indicator
*/
@Composable
private fun QualityBadge(
qualityTier: ClusterQualityTier,
canTrain: Boolean,
modifier: Modifier = Modifier
) {
val (backgroundColor, iconColor, icon) = when (qualityTier) {
ClusterQualityTier.EXCELLENT -> Triple(
Color(0xFF1B5E20),
Color.White,
Icons.Default.Check
)
ClusterQualityTier.GOOD -> Triple(
Color(0xFF2E7D32),
Color.White,
Icons.Default.Check
)
ClusterQualityTier.POOR -> Triple(
Color(0xFFD32F2F),
Color.White,
Icons.Default.Warning
)
}
Surface(
modifier = modifier,
shape = CircleShape,
color = backgroundColor,
shadowElevation = 2.dp
) {
Box(
modifier = Modifier
.size(32.dp)
.padding(6.dp),
contentAlignment = Alignment.Center
) {
Icon(
imageVector = icon,
contentDescription = qualityTier.name,
tint = iconColor,
modifier = Modifier.size(20.dp)
)
}
}
} }
@Composable @Composable
private fun FaceThumbnail( private fun FaceThumbnail(
imageUri: String, imageUri: String,
enabled: Boolean,
modifier: Modifier = Modifier modifier: Modifier = Modifier
) { ) {
Box(modifier = modifier) {
AsyncImage( AsyncImage(
model = Uri.parse(imageUri), model = Uri.parse(imageUri),
contentDescription = "Face", contentDescription = "Face",
modifier = modifier modifier = Modifier
.fillMaxSize() .fillMaxSize()
.border( .border(
width = 0.5.dp, width = 0.5.dp,
color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f) color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
), ),
contentScale = ContentScale.Crop contentScale = ContentScale.Crop,
alpha = if (enabled) 1f else 0.6f
) )
} }
}
@Composable @Composable
private fun EmptyFaceSlot(modifier: Modifier = Modifier) { private fun EmptyFaceSlot(modifier: Modifier = Modifier) {

View File

@@ -11,11 +11,17 @@ import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
/** /**
* DiscoverPeopleScreen - COMPLETE WORKING VERSION * DiscoverPeopleScreen - COMPLETE WORKING VERSION WITH NAMING DIALOG
* *
* This handles ALL states properly including Idle state * This handles ALL states properly including the NamingCluster dialog
*
* IMPROVEMENTS:
* - ✅ Complete naming dialog integration
* - ✅ Quality analysis in cluster grid
* - ✅ Better error handling
*/ */
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
@@ -24,26 +30,11 @@ fun DiscoverPeopleScreen(
onNavigateBack: () -> Unit = {} onNavigateBack: () -> Unit = {}
) { ) {
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val qualityAnalyzer = remember { ClusterQualityAnalyzer() }
Scaffold( // No Scaffold, no TopAppBar - MainScreen handles that
topBar = {
TopAppBar(
title = { Text("Discover People") },
navigationIcon = {
IconButton(onClick = onNavigateBack) {
Icon(
imageVector = Icons.Default.Person,
contentDescription = "Back"
)
}
}
)
}
) { paddingValues ->
Box( Box(
modifier = Modifier modifier = Modifier.fillMaxSize()
.fillMaxSize()
.padding(paddingValues)
) { ) {
when (val state = uiState) { when (val state = uiState) {
// ===== IDLE STATE (START HERE) ===== // ===== IDLE STATE (START HERE) =====
@@ -68,7 +59,8 @@ fun DiscoverPeopleScreen(
result = state.result, result = state.result,
onSelectCluster = { cluster -> onSelectCluster = { cluster ->
viewModel.selectCluster(cluster) viewModel.selectCluster(cluster)
} },
qualityAnalyzer = qualityAnalyzer
) )
} }
@@ -77,11 +69,32 @@ fun DiscoverPeopleScreen(
LoadingContent(message = "Analyzing cluster quality...") LoadingContent(message = "Analyzing cluster quality...")
} }
// ===== NAMING A CLUSTER ===== // ===== NAMING A CLUSTER (SHOW DIALOG) =====
is DiscoverUiState.NamingCluster -> { is DiscoverUiState.NamingCluster -> {
Text( // Show cluster grid in background
text = "Naming dialog for cluster ${state.selectedCluster.clusterId}\n\nDialog UI coming...", ClusterGridScreen(
modifier = Modifier.align(Alignment.Center) result = state.result,
onSelectCluster = { /* Disabled while dialog open */ },
qualityAnalyzer = qualityAnalyzer
)
// Show naming dialog overlay
NamingDialog(
cluster = state.selectedCluster,
suggestedSiblings = state.suggestedSiblings,
onConfirm = { name, dateOfBirth, isChild, selectedSiblings ->
viewModel.confirmClusterName(
cluster = state.selectedCluster,
name = name,
dateOfBirth = dateOfBirth,
isChild = isChild,
selectedSiblings = selectedSiblings
)
},
onDismiss = {
viewModel.cancelNaming()
},
qualityAnalyzer = qualityAnalyzer
) )
} }
@@ -141,8 +154,6 @@ fun DiscoverPeopleScreen(
} }
} }
} }
}
// ===== IDLE STATE CONTENT ===== // ===== IDLE STATE CONTENT =====
@Composable @Composable
@@ -165,19 +176,11 @@ private fun IdleStateContent(
Spacer(modifier = Modifier.height(32.dp)) Spacer(modifier = Modifier.height(32.dp))
Text(
text = "Discover People",
style = MaterialTheme.typography.headlineLarge,
fontWeight = FontWeight.Bold
)
Spacer(modifier = Modifier.height(16.dp))
Text( Text(
text = "Automatically find and organize people in your photo library", text = "Automatically find and organize people in your photo library",
style = MaterialTheme.typography.bodyLarge, style = MaterialTheme.typography.headlineSmall,
textAlign = TextAlign.Center, textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurface
) )
Spacer(modifier = Modifier.height(48.dp)) Spacer(modifier = Modifier.height(48.dp))

View File

@@ -0,0 +1,480 @@
package com.placeholder.sherpai2.ui.discover
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.LazyRow
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.text.KeyboardActions
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.foundation.verticalScroll
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.*
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalSoftwareKeyboardController
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.input.ImeAction
import androidx.compose.ui.text.input.KeyboardCapitalization
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import coil.compose.AsyncImage
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier
import com.placeholder.sherpai2.domain.clustering.FaceCluster
import java.text.SimpleDateFormat
import java.util.*
/**
 * NamingDialog - full-screen-ish dialog for naming a face cluster before training a model.
 *
 * Features:
 * - Name input with validation (confirm is disabled while the name is blank)
 * - "This is a child" toggle with an optional date-of-birth picker
 * - Multi-select of suggested sibling/family clusters
 * - Quality warning card for any cluster below EXCELLENT tier
 * - Horizontal preview of up to 6 representative faces
 *
 * @param cluster           the face cluster being named
 * @param suggestedSiblings clusters that frequently co-occur with [cluster]; may be empty
 * @param onConfirm         called with (trimmed name, date of birth in epoch millis or null,
 *                          child flag, selected sibling cluster ids) when "Create Model" is tapped
 * @param onDismiss         called when the dialog is cancelled or closed
 * @param qualityAnalyzer   analyzer producing the quality result; the default instance is
 *                          remembered across recompositions
 */
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun NamingDialog(
    cluster: FaceCluster,
    suggestedSiblings: List<FaceCluster>,
    onConfirm: (name: String, dateOfBirth: Long?, isChild: Boolean, selectedSiblings: List<Int>) -> Unit,
    onDismiss: () -> Unit,
    qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() }
) {
    // Form state. NOTE(review): plain remember — state is lost on process death /
    // configuration change; consider rememberSaveable if that matters here.
    var name by remember { mutableStateOf("") }
    var isChild by remember { mutableStateOf(false) }
    var showDatePicker by remember { mutableStateOf(false) }
    var dateOfBirth by remember { mutableStateOf<Long?>(null) }
    var selectedSiblingIds by remember { mutableStateOf(setOf<Int>()) }
    // Analyze cluster quality once per cluster (re-runs only when `cluster` changes).
    val qualityResult = remember(cluster) {
        qualityAnalyzer.analyzeCluster(cluster)
    }
    val keyboardController = LocalSoftwareKeyboardController.current
    // SimpleDateFormat is not thread-safe, but it is confined to this composition.
    val dateFormatter = remember { SimpleDateFormat("MMM dd, yyyy", Locale.getDefault()) }
    Dialog(onDismissRequest = onDismiss) {
        Card(
            modifier = Modifier
                .fillMaxWidth()
                .fillMaxHeight(0.9f),
            shape = RoundedCornerShape(16.dp),
            elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
        ) {
            Column(
                modifier = Modifier
                    .fillMaxSize()
                    .verticalScroll(rememberScrollState())
            ) {
                // Header: title, photo count, and close button on a primary-container strip.
                Surface(
                    color = MaterialTheme.colorScheme.primaryContainer
                ) {
                    Row(
                        modifier = Modifier
                            .fillMaxWidth()
                            .padding(16.dp),
                        horizontalArrangement = Arrangement.SpaceBetween,
                        verticalAlignment = Alignment.CenterVertically
                    ) {
                        Column(modifier = Modifier.weight(1f)) {
                            Text(
                                text = "Name This Person",
                                style = MaterialTheme.typography.titleLarge,
                                fontWeight = FontWeight.Bold,
                                color = MaterialTheme.colorScheme.onPrimaryContainer
                            )
                            Text(
                                text = "${cluster.photoCount} photos",
                                style = MaterialTheme.typography.bodyMedium,
                                color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
                            )
                        }
                        IconButton(onClick = onDismiss) {
                            Icon(
                                imageVector = Icons.Default.Close,
                                contentDescription = "Close",
                                tint = MaterialTheme.colorScheme.onPrimaryContainer
                            )
                        }
                    }
                }
                Column(
                    modifier = Modifier.padding(16.dp)
                ) {
                    // Preview: up to six representative faces in a horizontal strip.
                    Text(
                        text = "Preview",
                        style = MaterialTheme.typography.titleMedium,
                        fontWeight = FontWeight.SemiBold
                    )
                    Spacer(modifier = Modifier.height(8.dp))
                    LazyRow(
                        horizontalArrangement = Arrangement.spacedBy(8.dp)
                    ) {
                        items(cluster.representativeFaces.take(6)) { face ->
                            AsyncImage(
                                model = android.net.Uri.parse(face.imageUri),
                                contentDescription = "Preview",
                                modifier = Modifier
                                    .size(80.dp)
                                    .clip(RoundedCornerShape(8.dp))
                                    .border(
                                        width = 1.dp,
                                        color = MaterialTheme.colorScheme.outline,
                                        shape = RoundedCornerShape(8.dp)
                                    ),
                                contentScale = ContentScale.Crop
                            )
                        }
                    }
                    Spacer(modifier = Modifier.height(20.dp))
                    // Quality warning card — shown for any tier below EXCELLENT.
                    if (qualityResult.qualityTier != ClusterQualityTier.EXCELLENT) {
                        QualityWarningCard(qualityResult = qualityResult)
                        Spacer(modifier = Modifier.height(16.dp))
                    }
                    // Name input (single line; Done hides the keyboard).
                    OutlinedTextField(
                        value = name,
                        onValueChange = { name = it },
                        label = { Text("Name") },
                        placeholder = { Text("Enter person's name") },
                        modifier = Modifier.fillMaxWidth(),
                        singleLine = true,
                        leadingIcon = {
                            Icon(
                                imageVector = Icons.Default.Person,
                                contentDescription = null
                            )
                        },
                        keyboardOptions = KeyboardOptions(
                            capitalization = KeyboardCapitalization.Words,
                            imeAction = ImeAction.Done
                        ),
                        keyboardActions = KeyboardActions(
                            onDone = { keyboardController?.hide() }
                        )
                    )
                    Spacer(modifier = Modifier.height(16.dp))
                    // Child toggle: whole row is clickable, mirrored by the trailing Switch.
                    Row(
                        modifier = Modifier
                            .fillMaxWidth()
                            .clip(RoundedCornerShape(8.dp))
                            .clickable { isChild = !isChild }
                            .background(
                                if (isChild) MaterialTheme.colorScheme.primaryContainer
                                else MaterialTheme.colorScheme.surfaceVariant
                            )
                            .padding(16.dp),
                        verticalAlignment = Alignment.CenterVertically,
                        horizontalArrangement = Arrangement.SpaceBetween
                    ) {
                        Row(
                            verticalAlignment = Alignment.CenterVertically
                        ) {
                            Icon(
                                imageVector = Icons.Default.Face,
                                contentDescription = null,
                                tint = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
                                else MaterialTheme.colorScheme.onSurfaceVariant
                            )
                            Spacer(modifier = Modifier.width(12.dp))
                            Column {
                                Text(
                                    text = "This is a child",
                                    style = MaterialTheme.typography.bodyLarge,
                                    fontWeight = FontWeight.Medium,
                                    color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
                                    else MaterialTheme.colorScheme.onSurfaceVariant
                                )
                                Text(
                                    text = "For age-appropriate filtering",
                                    style = MaterialTheme.typography.bodySmall,
                                    color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
                                    else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f)
                                )
                            }
                        }
                        Switch(
                            checked = isChild,
                            onCheckedChange = { isChild = it }
                        )
                    }
                    // Optional date of birth — only offered when the child toggle is on.
                    // NOTE(review): dateOfBirth is retained if the toggle is turned off
                    // again and is still passed to onConfirm — confirm this is intended.
                    if (isChild) {
                        Spacer(modifier = Modifier.height(12.dp))
                        OutlinedButton(
                            onClick = { showDatePicker = true },
                            modifier = Modifier.fillMaxWidth()
                        ) {
                            Icon(
                                imageVector = Icons.Default.DateRange,
                                contentDescription = null
                            )
                            Spacer(modifier = Modifier.width(8.dp))
                            Text(
                                text = dateOfBirth?.let { dateFormatter.format(Date(it)) }
                                    ?: "Set date of birth (optional)"
                            )
                        }
                    }
                    // Sibling selection — tapping an item toggles its id in the set.
                    if (suggestedSiblings.isNotEmpty()) {
                        Spacer(modifier = Modifier.height(20.dp))
                        Text(
                            text = "Appears with",
                            style = MaterialTheme.typography.titleMedium,
                            fontWeight = FontWeight.SemiBold
                        )
                        Text(
                            text = "Select siblings or family members",
                            style = MaterialTheme.typography.bodySmall,
                            color = MaterialTheme.colorScheme.onSurfaceVariant
                        )
                        Spacer(modifier = Modifier.height(8.dp))
                        suggestedSiblings.forEach { sibling ->
                            SiblingSelectionItem(
                                cluster = sibling,
                                selected = sibling.clusterId in selectedSiblingIds,
                                onToggle = {
                                    selectedSiblingIds = if (sibling.clusterId in selectedSiblingIds) {
                                        selectedSiblingIds - sibling.clusterId
                                    } else {
                                        selectedSiblingIds + sibling.clusterId
                                    }
                                }
                            )
                            Spacer(modifier = Modifier.height(8.dp))
                        }
                    }
                    Spacer(modifier = Modifier.height(24.dp))
                    // Action buttons: Cancel, and Create Model (requires non-blank name
                    // AND a cluster the analyzer judges trainable).
                    Row(
                        modifier = Modifier.fillMaxWidth(),
                        horizontalArrangement = Arrangement.spacedBy(12.dp)
                    ) {
                        OutlinedButton(
                            onClick = onDismiss,
                            modifier = Modifier.weight(1f)
                        ) {
                            Text("Cancel")
                        }
                        Button(
                            onClick = {
                                if (name.isNotBlank()) {
                                    onConfirm(
                                        name.trim(),
                                        dateOfBirth,
                                        isChild,
                                        selectedSiblingIds.toList()
                                    )
                                }
                            },
                            modifier = Modifier.weight(1f),
                            enabled = name.isNotBlank() && qualityResult.canTrain
                        ) {
                            Icon(
                                imageVector = Icons.Default.Check,
                                contentDescription = null,
                                modifier = Modifier.size(20.dp)
                            )
                            Spacer(modifier = Modifier.width(8.dp))
                            Text("Create Model")
                        }
                    }
                }
            }
        }
    }
    // Date picker dialog (Material 3); OK commits the selection, Cancel discards it.
    if (showDatePicker) {
        val datePickerState = rememberDatePickerState()
        DatePickerDialog(
            onDismissRequest = { showDatePicker = false },
            confirmButton = {
                TextButton(
                    onClick = {
                        dateOfBirth = datePickerState.selectedDateMillis
                        showDatePicker = false
                    }
                ) {
                    Text("OK")
                }
            },
            dismissButton = {
                TextButton(onClick = { showDatePicker = false }) {
                    Text("Cancel")
                }
            }
        ) {
            DatePicker(state = datePickerState)
        }
    }
}
/**
 * Warning card summarizing cluster quality issues.
 *
 * Shows a tinted card (amber for GOOD, red for POOR, surface-variant otherwise)
 * with a title matching the tier and one line per warning message.
 *
 * @param qualityResult analysis result whose tier and warnings are rendered
 */
@Composable
private fun QualityWarningCard(qualityResult: com.placeholder.sherpai2.domain.clustering.ClusterQualityResult) {
    // Hard-coded Material-palette tints: amber for GOOD, red for POOR.
    val (backgroundColor, iconColor) = when (qualityResult.qualityTier) {
        ClusterQualityTier.GOOD -> Pair(
            Color(0xFFFFF9C4),
            Color(0xFFF57F17)
        )
        ClusterQualityTier.POOR -> Pair(
            Color(0xFFFFEBEE),
            Color(0xFFD32F2F)
        )
        else -> Pair(
            MaterialTheme.colorScheme.surfaceVariant,
            MaterialTheme.colorScheme.onSurfaceVariant
        )
    }
    Card(
        modifier = Modifier.fillMaxWidth(),
        colors = CardDefaults.cardColors(
            containerColor = backgroundColor
        )
    ) {
        Column(
            modifier = Modifier.padding(12.dp)
        ) {
            Row(
                verticalAlignment = Alignment.CenterVertically
            ) {
                Icon(
                    imageVector = Icons.Default.Warning,
                    contentDescription = null,
                    tint = iconColor,
                    modifier = Modifier.size(20.dp)
                )
                Spacer(modifier = Modifier.width(8.dp))
                Text(
                    text = when (qualityResult.qualityTier) {
                        ClusterQualityTier.GOOD -> "Review Before Training"
                        ClusterQualityTier.POOR -> "Quality Issues Detected"
                        else -> ""
                    },
                    style = MaterialTheme.typography.titleSmall,
                    fontWeight = FontWeight.Bold,
                    color = iconColor
                )
            }
            Spacer(modifier = Modifier.height(8.dp))
            // One line per warning. (Was `"$warning"` — a redundant string template.)
            qualityResult.warnings.forEach { warning ->
                Text(
                    text = warning,
                    style = MaterialTheme.typography.bodySmall,
                    color = MaterialTheme.colorScheme.onSurfaceVariant
                )
            }
        }
    }
}
/**
 * One selectable row in the "Appears with" sibling list.
 *
 * Renders a circular preview face, a "N photos together" label, and a trailing
 * checkbox. Both the row itself and the checkbox invoke [onToggle]; the highlight
 * colors switch to the primary-container palette while [selected] is true.
 *
 * @param cluster  sibling cluster to display
 * @param selected whether this sibling is currently chosen
 * @param onToggle invoked to flip the selection state
 */
@Composable
private fun SiblingSelectionItem(
    cluster: FaceCluster,
    selected: Boolean,
    onToggle: () -> Unit
) {
    // Resolve the selection-dependent colors once up front.
    val rowBackground =
        if (selected) MaterialTheme.colorScheme.primaryContainer
        else MaterialTheme.colorScheme.surfaceVariant
    val labelColor =
        if (selected) MaterialTheme.colorScheme.onPrimaryContainer
        else MaterialTheme.colorScheme.onSurfaceVariant
    // Fall back to an empty URI when the cluster has no representative face.
    val previewUri = cluster.representativeFaces.firstOrNull()?.imageUri.orEmpty()

    Row(
        modifier = Modifier
            .fillMaxWidth()
            .clip(RoundedCornerShape(8.dp))
            .clickable(onClick = onToggle)
            .background(rowBackground)
            .padding(12.dp),
        verticalAlignment = Alignment.CenterVertically,
        horizontalArrangement = Arrangement.SpaceBetween
    ) {
        Row(verticalAlignment = Alignment.CenterVertically) {
            // Circular face thumbnail with a thin outline border.
            AsyncImage(
                model = android.net.Uri.parse(previewUri),
                contentDescription = "Preview",
                modifier = Modifier
                    .size(40.dp)
                    .clip(CircleShape)
                    .border(
                        width = 1.dp,
                        color = MaterialTheme.colorScheme.outline,
                        shape = CircleShape
                    ),
                contentScale = ContentScale.Crop
            )
            Spacer(modifier = Modifier.width(12.dp))
            Text(
                text = "${cluster.photoCount} photos together",
                style = MaterialTheme.typography.bodyMedium,
                color = labelColor
            )
        }
        Checkbox(
            checked = selected,
            onCheckedChange = { onToggle() }
        )
    }
}

View File

@@ -1,56 +1,48 @@
package com.placeholder.sherpai2.ui.presentation package com.placeholder.sherpai2.ui.presentation
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.* import androidx.compose.material.icons.filled.Menu
import androidx.compose.material3.* import androidx.compose.material3.*
import androidx.compose.runtime.* import androidx.compose.runtime.*
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.text.font.FontWeight
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import androidx.navigation.compose.currentBackStackEntryAsState
import androidx.navigation.compose.rememberNavController import androidx.navigation.compose.rememberNavController
import androidx.navigation.compose.currentBackStackEntryAsState
import com.placeholder.sherpai2.ui.navigation.AppNavHost import com.placeholder.sherpai2.ui.navigation.AppNavHost
import com.placeholder.sherpai2.ui.navigation.AppRoutes import com.placeholder.sherpai2.ui.navigation.AppRoutes
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
/** /**
* MainScreen - UPDATED with auto face cache check * MainScreen - Complete app container with drawer navigation
* *
* NEW: Prompts user to populate face cache on app launch if needed * CRITICAL FIX APPLIED:
* FIXED: Prevents double headers for screens with their own TopAppBar * ✅ Removed AppRoutes.DISCOVER from screensWithOwnTopBar
* ✅ DiscoverPeopleScreen now shows hamburger menu + "Discover People" title!
*/ */
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun MainScreen( fun MainScreen(
mainViewModel: MainViewModel = hiltViewModel() // Same package - no import needed! viewModel: MainViewModel = hiltViewModel()
) { ) {
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
val scope = rememberCoroutineScope()
val navController = rememberNavController() val navController = rememberNavController()
val drawerState = rememberDrawerState(DrawerValue.Closed)
val scope = rememberCoroutineScope()
val navBackStackEntry by navController.currentBackStackEntryAsState() val currentBackStackEntry by navController.currentBackStackEntryAsState()
val currentRoute = navBackStackEntry?.destination?.route ?: AppRoutes.SEARCH val currentRoute = currentBackStackEntry?.destination?.route
// Face cache status // Face cache prompt dialog state
val needsFaceCache by mainViewModel.needsFaceCachePopulation.collectAsState() val needsFaceCachePopulation by viewModel.needsFaceCachePopulation.collectAsState()
val unscannedCount by mainViewModel.unscannedPhotoCount.collectAsState() val unscannedPhotoCount by viewModel.unscannedPhotoCount.collectAsState()
// Show face cache prompt dialog if needed // ✅ CRITICAL FIX: DISCOVER is NOT in this list!
if (needsFaceCache && unscannedCount > 0) { // These screens handle their own TopAppBar/navigation
FaceCachePromptDialog( val screensWithOwnTopBar = setOf(
unscannedPhotoCount = unscannedCount, AppRoutes.IMAGE_DETAIL,
onDismiss = { mainViewModel.dismissFaceCachePrompt() }, AppRoutes.TRAINING_SCREEN,
onScanNow = { AppRoutes.CROP_SCREEN
mainViewModel.dismissFaceCachePrompt()
// Navigate to Photo Utilities
navController.navigate(AppRoutes.UTILITIES) {
launchSingleTop = true
}
}
) )
}
ModalNavigationDrawer( ModalNavigationDrawer(
drawerState = drawerState, drawerState = drawerState,
@@ -60,133 +52,86 @@ fun MainScreen(
onDestinationClicked = { route -> onDestinationClicked = { route ->
scope.launch { scope.launch {
drawerState.close() drawerState.close()
if (route != currentRoute) { }
navController.navigate(route) { navController.navigate(route) {
popUpTo(navController.graph.startDestinationId) {
saveState = true
}
launchSingleTop = true launchSingleTop = true
} restoreState = true
}
} }
} }
) )
}, }
) { ) {
// CRITICAL: Some screens manage their own TopAppBar
// Hide MainScreen's TopAppBar for these routes to prevent double headers
val screensWithOwnTopBar = setOf(
AppRoutes.TRAINING_PHOTO_SELECTOR, // Has its own TopAppBar with subtitle
"album/", // Album views have their own TopAppBar (prefix match)
AppRoutes.IMAGE_DETAIL // Image detail has its own TopAppBar
)
// Check if current route starts with any excluded pattern
val showTopBar = screensWithOwnTopBar.none { currentRoute.startsWith(it) }
Scaffold( Scaffold(
topBar = { topBar = {
if (showTopBar) { // ✅ Show TopAppBar for ALL screens except those with their own
if (currentRoute !in screensWithOwnTopBar) {
TopAppBar( TopAppBar(
title = { title = {
Column {
Text( Text(
text = getScreenTitle(currentRoute), text = when (currentRoute) {
style = MaterialTheme.typography.titleLarge, AppRoutes.SEARCH -> "Search"
fontWeight = FontWeight.Bold AppRoutes.EXPLORE -> "Explore"
) AppRoutes.COLLECTIONS -> "Collections"
getScreenSubtitle(currentRoute)?.let { subtitle -> AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
Text( AppRoutes.INVENTORY -> "People"
text = subtitle, AppRoutes.TRAIN -> "Train Model"
style = MaterialTheme.typography.bodySmall, AppRoutes.TAGS -> "Tags"
color = MaterialTheme.colorScheme.onSurfaceVariant AppRoutes.UTILITIES -> "Utilities"
) AppRoutes.SETTINGS -> "Settings"
AppRoutes.MODELS -> "AI Models"
else -> {
// Handle dynamic routes like album/{type}/{id}
if (currentRoute?.startsWith("album/") == true) {
"Album"
} else {
"SherpAI"
} }
} }
}
)
}, },
navigationIcon = { navigationIcon = {
IconButton(
onClick = { scope.launch { drawerState.open() } }
) {
Icon(
Icons.Default.Menu,
contentDescription = "Open Menu",
tint = MaterialTheme.colorScheme.primary
)
}
},
actions = {
// Dynamic actions based on current screen
when (currentRoute) {
AppRoutes.SEARCH -> {
IconButton(onClick = { /* TODO: Open filter dialog */ }) {
Icon(
Icons.Default.FilterList,
contentDescription = "Filter",
tint = MaterialTheme.colorScheme.primary
)
}
}
AppRoutes.INVENTORY -> {
IconButton(onClick = { IconButton(onClick = {
navController.navigate(AppRoutes.TRAIN) scope.launch {
drawerState.open()
}
}) { }) {
Icon( Icon(
Icons.Default.PersonAdd, imageVector = Icons.Default.Menu,
contentDescription = "Add Person", contentDescription = "Open menu"
tint = MaterialTheme.colorScheme.primary
) )
} }
}
}
}, },
colors = TopAppBarDefaults.topAppBarColors( colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.surface, containerColor = MaterialTheme.colorScheme.primaryContainer,
titleContentColor = MaterialTheme.colorScheme.onSurface, titleContentColor = MaterialTheme.colorScheme.onPrimaryContainer,
navigationIconContentColor = MaterialTheme.colorScheme.primary, navigationIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer,
actionIconContentColor = MaterialTheme.colorScheme.primary actionIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer
) )
) )
} }
} }
) { paddingValues -> ) { paddingValues ->
// ✅ Use YOUR existing AppNavHost - it already has all the screens defined!
AppNavHost( AppNavHost(
navController = navController, navController = navController,
modifier = Modifier.padding(paddingValues) modifier = Modifier.padding(paddingValues)
) )
} }
} }
}
/** // ✅ Face cache prompt dialog (shows on app launch if needed)
* Get human-readable screen title if (needsFaceCachePopulation) {
*/ FaceCachePromptDialog(
private fun getScreenTitle(route: String): String { unscannedPhotoCount = unscannedPhotoCount,
return when (route) { onDismiss = { viewModel.dismissFaceCachePrompt() },
AppRoutes.SEARCH -> "Search" onScanNow = {
AppRoutes.EXPLORE -> "Explore" viewModel.dismissFaceCachePrompt()
AppRoutes.COLLECTIONS -> "Collections" navController.navigate(AppRoutes.UTILITIES)
AppRoutes.DISCOVER -> "Discover People" }
AppRoutes.INVENTORY -> "People" )
AppRoutes.TRAIN -> "Train New Person"
AppRoutes.MODELS -> "AI Models"
AppRoutes.TAGS -> "Tag Management"
AppRoutes.UTILITIES -> "Photo Util."
AppRoutes.SETTINGS -> "Settings"
else -> "SherpAI"
}
}
/**
* Get subtitle for screens that need context
*/
private fun getScreenSubtitle(route: String): String? {
return when (route) {
AppRoutes.SEARCH -> "Find photos by tags, people, or date"
AppRoutes.EXPLORE -> "Browse your collection"
AppRoutes.COLLECTIONS -> "Your photo collections"
AppRoutes.DISCOVER -> "Auto-find faces in your library"
AppRoutes.INVENTORY -> "Trained face models"
AppRoutes.TRAIN -> "Add a new person to recognize"
AppRoutes.TAGS -> "Organize your photo collection"
AppRoutes.UTILITIES -> "Tools for managing collection"
else -> null
} }
} }