discover dez
This commit is contained in:
BIN
app/src/main/assets/mobilefacenet.tflite
Normal file
BIN
app/src/main/assets/mobilefacenet.tflite
Normal file
Binary file not shown.
@@ -1,129 +1,91 @@
|
||||
package com.placeholder.sherpai2.data.local.dao
|
||||
|
||||
import androidx.room.*
|
||||
import androidx.room.Dao
|
||||
import androidx.room.Insert
|
||||
import androidx.room.OnConflictStrategy
|
||||
import androidx.room.Query
|
||||
import androidx.room.Update
|
||||
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
/**
|
||||
* FaceCacheDao - Query face metadata for intelligent filtering
|
||||
*
|
||||
* ENABLES SMART CLUSTERING:
|
||||
* - Pre-filter to high-quality faces only
|
||||
* - Avoid processing blurry/distant faces
|
||||
* - Faster clustering with better results
|
||||
* FaceCacheDao - Face detection cache with NEW queries for two-stage clustering
|
||||
*/
|
||||
@Dao
|
||||
interface FaceCacheDao {
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// INSERT / UPDATE
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||
suspend fun insert(faceCache: FaceCacheEntity)
|
||||
|
||||
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||
suspend fun insertAll(faceCaches: List<FaceCacheEntity>)
|
||||
|
||||
/**
|
||||
* Get ALL high-quality solo faces for clustering
|
||||
*
|
||||
* FILTERS:
|
||||
* - Solo photos only (joins with images.faceCount = 1)
|
||||
* - Large enough (isLargeEnough = true)
|
||||
* - Good quality score (>= 0.6)
|
||||
* - Frontal faces preferred (isFrontal = true)
|
||||
*/
|
||||
@Query("""
|
||||
SELECT fc.* FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.isLargeEnough = 1
|
||||
AND fc.qualityScore >= 0.6
|
||||
AND fc.isFrontal = 1
|
||||
ORDER BY fc.qualityScore DESC
|
||||
""")
|
||||
suspend fun getHighQualitySoloFaces(): List<FaceCacheEntity>
|
||||
@Update
|
||||
suspend fun update(faceCache: FaceCacheEntity)
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// NEW CLUSTERING QUERIES ⭐
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
/**
|
||||
* Get high-quality faces from ANY photo (including group photos)
|
||||
* Use when not enough solo photos available
|
||||
* Get high-quality solo faces for Stage 1 clustering
|
||||
*
|
||||
* Filters:
|
||||
* - Solo photos (faceCount = 1)
|
||||
* - Large faces (faceAreaRatio >= minFaceRatio)
|
||||
* - Has embedding
|
||||
*/
|
||||
@Query("""
|
||||
SELECT * FROM face_cache
|
||||
WHERE isLargeEnough = 1
|
||||
AND qualityScore >= 0.6
|
||||
AND isFrontal = 1
|
||||
ORDER BY qualityScore DESC
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.faceAreaRatio >= :minFaceRatio
|
||||
AND fc.embedding IS NOT NULL
|
||||
ORDER BY fc.faceAreaRatio DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getHighQualityFaces(limit: Int = 1000): List<FaceCacheEntity>
|
||||
suspend fun getHighQualitySoloFaces(
|
||||
minFaceRatio: Float = 0.015f,
|
||||
limit: Int = 2000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* Get faces for a specific image
|
||||
*/
|
||||
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex ASC")
|
||||
suspend fun getFacesForImage(imageId: String): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* Count high-quality solo faces (for UI display)
|
||||
* FALLBACK: Get ANY solo faces with embeddings
|
||||
* Used if getHighQualitySoloFaces() returns empty
|
||||
*/
|
||||
@Query("""
|
||||
SELECT COUNT(*) FROM face_cache fc
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.isLargeEnough = 1
|
||||
AND fc.qualityScore >= 0.6
|
||||
""")
|
||||
suspend fun getHighQualitySoloFaceCount(): Int
|
||||
|
||||
/**
|
||||
* Get quality distribution stats
|
||||
*/
|
||||
@Query("""
|
||||
SELECT
|
||||
SUM(CASE WHEN qualityScore >= 0.8 THEN 1 ELSE 0 END) as excellent,
|
||||
SUM(CASE WHEN qualityScore >= 0.6 AND qualityScore < 0.8 THEN 1 ELSE 0 END) as good,
|
||||
SUM(CASE WHEN qualityScore < 0.6 THEN 1 ELSE 0 END) as poor,
|
||||
COUNT(*) as total
|
||||
FROM face_cache
|
||||
""")
|
||||
suspend fun getQualityStats(): FaceQualityStats?
|
||||
|
||||
/**
|
||||
* Delete cache for specific image (when image is deleted)
|
||||
*/
|
||||
@Query("DELETE FROM face_cache WHERE imageId = :imageId")
|
||||
suspend fun deleteCacheForImage(imageId: String)
|
||||
|
||||
/**
|
||||
* Delete all cache (for full rebuild)
|
||||
*/
|
||||
@Query("DELETE FROM face_cache")
|
||||
suspend fun deleteAll()
|
||||
|
||||
/**
|
||||
* Get faces with embeddings already computed
|
||||
* (Ultra-fast clustering - no need to re-generate)
|
||||
*/
|
||||
@Query("""
|
||||
SELECT fc.* FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.isLargeEnough = 1
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.embedding IS NOT NULL
|
||||
ORDER BY fc.qualityScore DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getSoloFacesWithEmbeddings(limit: Int = 2000): List<FaceCacheEntity>
|
||||
}
|
||||
suspend fun getSoloFacesWithEmbeddings(
|
||||
limit: Int = 2000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* Quality statistics result
|
||||
*/
|
||||
data class FaceQualityStats(
|
||||
val excellent: Int, // qualityScore >= 0.8
|
||||
val good: Int, // 0.6 <= qualityScore < 0.8
|
||||
val poor: Int, // qualityScore < 0.6
|
||||
val total: Int
|
||||
) {
|
||||
val excellentPercent: Float get() = if (total > 0) excellent.toFloat() / total else 0f
|
||||
val goodPercent: Float get() = if (total > 0) good.toFloat() / total else 0f
|
||||
val poorPercent: Float get() = if (total > 0) poor.toFloat() / total else 0f
|
||||
// ═══════════════════════════════════════
|
||||
// EXISTING QUERIES (keep as-is)
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
@Query("SELECT * FROM face_cache WHERE id = :id")
|
||||
suspend fun getFaceCacheById(id: String): FaceCacheEntity?
|
||||
|
||||
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex")
|
||||
suspend fun getFaceCacheForImage(imageId: String): List<FaceCacheEntity>
|
||||
|
||||
@Query("DELETE FROM face_cache WHERE imageId = :imageId")
|
||||
suspend fun deleteFaceCacheForImage(imageId: String)
|
||||
|
||||
@Query("DELETE FROM face_cache")
|
||||
suspend fun deleteAll()
|
||||
|
||||
@Query("DELETE FROM face_cache WHERE cacheVersion < :version")
|
||||
suspend fun deleteOldVersions(version: Int)
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.placeholder.sherpai2.domain.clustering
|
||||
|
||||
import android.graphics.Rect
|
||||
import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding
|
||||
import android.util.Log
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.math.sqrt
|
||||
@@ -9,56 +9,62 @@ import kotlin.math.sqrt
|
||||
/**
|
||||
* ClusterQualityAnalyzer - Validate cluster quality BEFORE training
|
||||
*
|
||||
* PURPOSE: Prevent training on "dirty" clusters (siblings merged, poor quality faces)
|
||||
*
|
||||
* CHECKS:
|
||||
* 1. Solo photo count (min 6 required)
|
||||
* 2. Face size (min 15% of image - clear, not distant)
|
||||
* 3. Internal consistency (all faces should match well)
|
||||
* 4. Outlier detection (find faces that don't belong)
|
||||
*
|
||||
* QUALITY TIERS:
|
||||
* - Excellent (95%+): Safe to train immediately
|
||||
* - Good (85-94%): Review outliers, then train
|
||||
* - Poor (<85%): Likely mixed people - DO NOT TRAIN!
|
||||
* RELAXED THRESHOLDS for real-world photos (social media, distant shots):
|
||||
* - Face size: 3% (down from 15%)
|
||||
* - Outlier threshold: 65% (down from 75%)
|
||||
* - GOOD tier: 75% (down from 85%)
|
||||
* - EXCELLENT tier: 85% (down from 95%)
|
||||
*/
|
||||
@Singleton
|
||||
class ClusterQualityAnalyzer @Inject constructor() {
|
||||
|
||||
companion object {
|
||||
private const val TAG = "ClusterQuality"
|
||||
private const val MIN_SOLO_PHOTOS = 6
|
||||
private const val MIN_FACE_SIZE_RATIO = 0.15f // 15% of image
|
||||
private const val MIN_INTERNAL_SIMILARITY = 0.80f
|
||||
private const val OUTLIER_THRESHOLD = 0.75f
|
||||
private const val EXCELLENT_THRESHOLD = 0.95f
|
||||
private const val GOOD_THRESHOLD = 0.85f
|
||||
private const val MIN_FACE_SIZE_RATIO = 0.03f // 3% of image (RELAXED)
|
||||
private const val MIN_FACE_DIMENSION_PIXELS = 50 // 50px minimum (RELAXED)
|
||||
private const val FALLBACK_MIN_DIMENSION = 80 // Fallback when no dimensions
|
||||
private const val MIN_INTERNAL_SIMILARITY = 0.75f
|
||||
private const val OUTLIER_THRESHOLD = 0.65f // RELAXED
|
||||
private const val EXCELLENT_THRESHOLD = 0.85f // RELAXED
|
||||
private const val GOOD_THRESHOLD = 0.75f // RELAXED
|
||||
}
|
||||
|
||||
/**
|
||||
* Analyze cluster quality before training
|
||||
*/
|
||||
fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult {
|
||||
// Step 1: Filter to solo photos only
|
||||
val soloFaces = cluster.faces.filter { it.faceCount == 1 }
|
||||
Log.d(TAG, "========================================")
|
||||
Log.d(TAG, "Analyzing cluster ${cluster.clusterId}")
|
||||
Log.d(TAG, "Total faces: ${cluster.faces.size}")
|
||||
|
||||
// Step 2: Filter by face size (must be clear/close-up)
|
||||
// Step 1: Filter to solo photos
|
||||
val soloFaces = cluster.faces.filter { it.faceCount == 1 }
|
||||
Log.d(TAG, "Solo photos: ${soloFaces.size}")
|
||||
|
||||
// Step 2: Filter by face size
|
||||
val largeFaces = soloFaces.filter { face ->
|
||||
isFaceLargeEnough(face.boundingBox, face.imageUri)
|
||||
isFaceLargeEnough(face)
|
||||
}
|
||||
Log.d(TAG, "Large faces (>= 3%): ${largeFaces.size}")
|
||||
|
||||
if (largeFaces.size < soloFaces.size) {
|
||||
Log.d(TAG, "⚠️ Filtered out ${soloFaces.size - largeFaces.size} small faces")
|
||||
}
|
||||
|
||||
// Step 3: Calculate internal consistency
|
||||
val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces)
|
||||
|
||||
// Step 4: Clean faces (large solo faces, no outliers)
|
||||
// Step 4: Clean faces
|
||||
val cleanFaces = largeFaces.filter { it !in outliers }
|
||||
Log.d(TAG, "Clean faces: ${cleanFaces.size}")
|
||||
|
||||
// Step 5: Calculate quality score
|
||||
val qualityScore = calculateQualityScore(
|
||||
soloPhotoCount = soloFaces.size,
|
||||
largeFaceCount = largeFaces.size,
|
||||
cleanFaceCount = cleanFaces.size,
|
||||
avgSimilarity = avgSimilarity
|
||||
avgSimilarity = avgSimilarity,
|
||||
totalFaces = cluster.faces.size
|
||||
)
|
||||
Log.d(TAG, "Quality score: ${(qualityScore * 100).toInt()}%")
|
||||
|
||||
// Step 6: Determine quality tier
|
||||
val qualityTier = when {
|
||||
@@ -66,6 +72,11 @@ class ClusterQualityAnalyzer @Inject constructor() {
|
||||
qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD
|
||||
else -> ClusterQualityTier.POOR
|
||||
}
|
||||
Log.d(TAG, "Quality tier: $qualityTier")
|
||||
|
||||
val canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS
|
||||
Log.d(TAG, "Can train: $canTrain")
|
||||
Log.d(TAG, "========================================")
|
||||
|
||||
return ClusterQualityResult(
|
||||
originalFaceCount = cluster.faces.size,
|
||||
@@ -77,62 +88,65 @@ class ClusterQualityAnalyzer @Inject constructor() {
|
||||
cleanFaces = cleanFaces,
|
||||
qualityScore = qualityScore,
|
||||
qualityTier = qualityTier,
|
||||
canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS,
|
||||
warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier)
|
||||
canTrain = canTrain,
|
||||
warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier, avgSimilarity)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if face is large enough (not distant/blurry)
|
||||
*
|
||||
* A face should occupy at least 15% of the image area for good quality
|
||||
*/
|
||||
private fun isFaceLargeEnough(boundingBox: Rect, imageUri: String): Boolean {
|
||||
// Estimate image dimensions from common aspect ratios
|
||||
// For now, use bounding box size as proxy
|
||||
val faceArea = boundingBox.width() * boundingBox.height()
|
||||
private fun isFaceLargeEnough(face: DetectedFaceWithEmbedding): Boolean {
|
||||
val faceArea = face.boundingBox.width() * face.boundingBox.height()
|
||||
|
||||
// Assume typical photo is ~2000x1500 = 3,000,000 pixels
|
||||
// 15% = 450,000 pixels
|
||||
// For a square face: sqrt(450,000) = ~670 pixels per side
|
||||
// Check 1: Absolute minimum
|
||||
if (face.boundingBox.width() < MIN_FACE_DIMENSION_PIXELS ||
|
||||
face.boundingBox.height() < MIN_FACE_DIMENSION_PIXELS) {
|
||||
return false
|
||||
}
|
||||
|
||||
// More conservative: face should be at least 200x200 pixels
|
||||
return boundingBox.width() >= 200 && boundingBox.height() >= 200
|
||||
// Check 2: Relative size if we have dimensions
|
||||
if (face.imageWidth > 0 && face.imageHeight > 0) {
|
||||
val imageArea = face.imageWidth * face.imageHeight
|
||||
val faceRatio = faceArea.toFloat() / imageArea.toFloat()
|
||||
return faceRatio >= MIN_FACE_SIZE_RATIO
|
||||
}
|
||||
|
||||
// Fallback: Use absolute size
|
||||
return face.boundingBox.width() >= FALLBACK_MIN_DIMENSION &&
|
||||
face.boundingBox.height() >= FALLBACK_MIN_DIMENSION
|
||||
}
|
||||
|
||||
/**
|
||||
* Analyze how similar faces are to each other (internal consistency)
|
||||
*
|
||||
* Returns: (average similarity, list of outlier faces)
|
||||
*/
|
||||
private fun analyzeInternalConsistency(
|
||||
faces: List<DetectedFaceWithEmbedding>
|
||||
): Pair<Float, List<DetectedFaceWithEmbedding>> {
|
||||
if (faces.size < 2) {
|
||||
Log.d(TAG, "Less than 2 faces, skipping consistency check")
|
||||
return 0f to emptyList()
|
||||
}
|
||||
|
||||
// Calculate average embedding (centroid)
|
||||
Log.d(TAG, "Analyzing ${faces.size} faces for internal consistency")
|
||||
|
||||
val centroid = calculateCentroid(faces.map { it.embedding })
|
||||
|
||||
// Calculate similarity of each face to centroid
|
||||
val similarities = faces.map { face ->
|
||||
face to cosineSimilarity(face.embedding, centroid)
|
||||
val centroidSum = centroid.sum()
|
||||
Log.d(TAG, "Centroid sum: $centroidSum, first5=[${centroid.take(5).joinToString()}]")
|
||||
|
||||
val similarities = faces.mapIndexed { index, face ->
|
||||
val similarity = cosineSimilarity(face.embedding, centroid)
|
||||
Log.d(TAG, "Face $index similarity to centroid: $similarity")
|
||||
face to similarity
|
||||
}
|
||||
|
||||
val avgSimilarity = similarities.map { it.second }.average().toFloat()
|
||||
Log.d(TAG, "Average internal similarity: $avgSimilarity")
|
||||
|
||||
// Find outliers (faces significantly different from centroid)
|
||||
val outliers = similarities
|
||||
.filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD }
|
||||
.map { (face, _) -> face }
|
||||
|
||||
Log.d(TAG, "Found ${outliers.size} outliers (threshold=$OUTLIER_THRESHOLD)")
|
||||
|
||||
return avgSimilarity to outliers
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate centroid (average embedding)
|
||||
*/
|
||||
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
|
||||
val size = embeddings.first().size
|
||||
val centroid = FloatArray(size) { 0f }
|
||||
@@ -148,14 +162,14 @@ class ClusterQualityAnalyzer @Inject constructor() {
|
||||
centroid[i] /= count
|
||||
}
|
||||
|
||||
// Normalize
|
||||
val norm = sqrt(centroid.map { it * it }.sum())
|
||||
return centroid.map { it / norm }.toFloatArray()
|
||||
return if (norm > 0) {
|
||||
centroid.map { it / norm }.toFloatArray()
|
||||
} else {
|
||||
centroid
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cosine similarity between two embeddings
|
||||
*/
|
||||
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
|
||||
var dotProduct = 0f
|
||||
var normA = 0f
|
||||
@@ -170,32 +184,31 @@ class ClusterQualityAnalyzer @Inject constructor() {
|
||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate overall quality score (0.0 - 1.0)
|
||||
*/
|
||||
private fun calculateQualityScore(
|
||||
soloPhotoCount: Int,
|
||||
largeFaceCount: Int,
|
||||
cleanFaceCount: Int,
|
||||
avgSimilarity: Float
|
||||
avgSimilarity: Float,
|
||||
totalFaces: Int
|
||||
): Float {
|
||||
// Weight factors
|
||||
val soloPhotoScore = (soloPhotoCount.toFloat() / 20f).coerceIn(0f, 1f) * 0.3f
|
||||
val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.2f
|
||||
val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.2f
|
||||
val similarityScore = avgSimilarity * 0.3f
|
||||
val soloRatio = soloPhotoCount.toFloat() / totalFaces.toFloat().coerceAtLeast(1f)
|
||||
val soloPhotoScore = soloRatio.coerceIn(0f, 1f) * 0.25f
|
||||
|
||||
val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.25f
|
||||
|
||||
val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.20f
|
||||
|
||||
val similarityScore = avgSimilarity * 0.30f
|
||||
|
||||
return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate human-readable warnings
|
||||
*/
|
||||
private fun generateWarnings(
|
||||
soloPhotoCount: Int,
|
||||
largeFaceCount: Int,
|
||||
cleanFaceCount: Int,
|
||||
qualityTier: ClusterQualityTier
|
||||
qualityTier: ClusterQualityTier,
|
||||
avgSimilarity: Float
|
||||
): List<String> {
|
||||
val warnings = mutableListOf<String>()
|
||||
|
||||
@@ -203,12 +216,20 @@ class ClusterQualityAnalyzer @Inject constructor() {
|
||||
ClusterQualityTier.POOR -> {
|
||||
warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!")
|
||||
warnings.add("Do NOT train on this cluster - it will create a bad model.")
|
||||
|
||||
if (avgSimilarity < 0.70f) {
|
||||
warnings.add("Low internal similarity (${(avgSimilarity * 100).toInt()}%) suggests mixed identities.")
|
||||
}
|
||||
}
|
||||
ClusterQualityTier.GOOD -> {
|
||||
warnings.add("⚠️ Review outlier faces before training")
|
||||
|
||||
if (cleanFaceCount < 10) {
|
||||
warnings.add("Consider adding more high-quality photos for better results.")
|
||||
}
|
||||
}
|
||||
ClusterQualityTier.EXCELLENT -> {
|
||||
// No warnings - ready to train!
|
||||
// No warnings
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,38 +239,47 @@ class ClusterQualityAnalyzer @Inject constructor() {
|
||||
|
||||
if (largeFaceCount < 6) {
|
||||
warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)")
|
||||
warnings.add("Tip: Use close-up photos where the face is clearly visible")
|
||||
}
|
||||
|
||||
if (cleanFaceCount < 6) {
|
||||
warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)")
|
||||
}
|
||||
|
||||
if (qualityTier == ClusterQualityTier.EXCELLENT) {
|
||||
warnings.add("✅ Excellent quality! This cluster is ready for training.")
|
||||
warnings.add("High-quality photos with consistent facial features detected.")
|
||||
}
|
||||
|
||||
return warnings
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of cluster quality analysis
|
||||
*/
|
||||
data class ClusterQualityResult(
|
||||
val originalFaceCount: Int, // Total faces in cluster
|
||||
val soloPhotoCount: Int, // Photos with faceCount = 1
|
||||
val largeFaceCount: Int, // Solo photos with large faces
|
||||
val cleanFaceCount: Int, // Large faces, no outliers
|
||||
val avgInternalSimilarity: Float, // How similar faces are (0.0-1.0)
|
||||
val outlierFaces: List<DetectedFaceWithEmbedding>, // Faces to exclude
|
||||
val cleanFaces: List<DetectedFaceWithEmbedding>, // Good faces for training
|
||||
val qualityScore: Float, // Overall score (0.0-1.0)
|
||||
val originalFaceCount: Int,
|
||||
val soloPhotoCount: Int,
|
||||
val largeFaceCount: Int,
|
||||
val cleanFaceCount: Int,
|
||||
val avgInternalSimilarity: Float,
|
||||
val outlierFaces: List<DetectedFaceWithEmbedding>,
|
||||
val cleanFaces: List<DetectedFaceWithEmbedding>,
|
||||
val qualityScore: Float,
|
||||
val qualityTier: ClusterQualityTier,
|
||||
val canTrain: Boolean, // Safe to proceed with training?
|
||||
val warnings: List<String> // Human-readable issues
|
||||
)
|
||||
val canTrain: Boolean,
|
||||
val warnings: List<String>
|
||||
) {
|
||||
fun getSummary(): String = when (qualityTier) {
|
||||
ClusterQualityTier.EXCELLENT ->
|
||||
"Excellent quality cluster with $cleanFaceCount high-quality photos ready for training."
|
||||
ClusterQualityTier.GOOD ->
|
||||
"Good quality cluster with $cleanFaceCount usable photos. Review outliers before training."
|
||||
ClusterQualityTier.POOR ->
|
||||
"Poor quality cluster. May contain multiple people or low-quality photos. Add more photos or split cluster."
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Quality tier classification
|
||||
*/
|
||||
enum class ClusterQualityTier {
|
||||
EXCELLENT, // 95%+ - Safe to train immediately
|
||||
GOOD, // 85-94% - Review outliers first
|
||||
POOR // <85% - DO NOT TRAIN (likely mixed people)
|
||||
EXCELLENT, // 85%+
|
||||
GOOD, // 75-84%
|
||||
POOR // <75%
|
||||
}
|
||||
@@ -4,11 +4,13 @@ import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.BitmapFactory
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.face.FaceDetection
|
||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
@@ -18,7 +20,6 @@ import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.sync.Semaphore
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.math.sqrt
|
||||
@@ -33,111 +34,152 @@ import kotlin.math.sqrt
|
||||
* 4. Detect faces and generate embeddings (parallel)
|
||||
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
|
||||
* 6. Analyze clusters for age, siblings, representatives
|
||||
*
|
||||
* IMPROVEMENTS:
|
||||
* - ✅ Complete fast-path using FaceCacheDao.getSoloFacesWithEmbeddings()
|
||||
* - ✅ Works with existing FaceCacheEntity.getEmbedding() method
|
||||
* - ✅ Centroid-based representative face selection
|
||||
* - ✅ Batched processing to prevent OOM
|
||||
* - ✅ RGB_565 bitmap config for 50% memory savings
|
||||
*/
|
||||
@Singleton
|
||||
class FaceClusteringService @Inject constructor(
|
||||
@ApplicationContext private val context: Context,
|
||||
private val imageDao: ImageDao,
|
||||
private val faceCacheDao: FaceCacheDao // Optional - will work without it
|
||||
private val faceCacheDao: FaceCacheDao
|
||||
) {
|
||||
|
||||
private val semaphore = Semaphore(12)
|
||||
private val semaphore = Semaphore(8)
|
||||
|
||||
companion object {
|
||||
private const val TAG = "FaceClustering"
|
||||
private const val MAX_FACES_TO_CLUSTER = 2000
|
||||
private const val MIN_SOLO_PHOTOS = 50
|
||||
private const val BATCH_SIZE = 50
|
||||
private const val MIN_CACHED_FACES = 100
|
||||
}
|
||||
|
||||
/**
|
||||
* Main clustering entry point - HYBRID with automatic fallback
|
||||
*
|
||||
* @param maxFacesToCluster Limit for performance (default 2000)
|
||||
* @param onProgress Progress callback (current, total, message)
|
||||
*/
|
||||
suspend fun discoverPeople(
|
||||
maxFacesToCluster: Int = 2000,
|
||||
maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER,
|
||||
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
||||
): ClusteringResult = withContext(Dispatchers.Default) {
|
||||
|
||||
// TRY FAST PATH: Use face cache if available
|
||||
val highQualityFaces = try {
|
||||
withContext(Dispatchers.IO) {
|
||||
faceCacheDao.getHighQualitySoloFaces()
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
emptyList()
|
||||
}
|
||||
|
||||
if (highQualityFaces.isNotEmpty()) {
|
||||
// FAST PATH: Use cached faces (future enhancement)
|
||||
onProgress(0, 100, "Using face cache (${highQualityFaces.size} faces)...")
|
||||
// TODO: Implement cache-based clustering
|
||||
// For now, fall through to classic method
|
||||
}
|
||||
|
||||
// CLASSIC METHOD: Load and process photos
|
||||
onProgress(0, 100, "Loading solo photos...")
|
||||
|
||||
// Step 1: Get SOLO PHOTOS ONLY (faceCount = 1) for cleaner clustering
|
||||
val soloPhotos = withContext(Dispatchers.IO) {
|
||||
imageDao.getImagesByFaceCount(count = 1)
|
||||
}
|
||||
|
||||
// Fallback: If not enough solo photos, use all images with faces
|
||||
val imagesWithFaces = if (soloPhotos.size < 50) {
|
||||
onProgress(0, 100, "Loading all photos with faces...")
|
||||
imageDao.getImagesWithFaces()
|
||||
} else {
|
||||
soloPhotos
|
||||
}
|
||||
|
||||
if (imagesWithFaces.isEmpty()) {
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "No photos with faces found. Please ensure face detection cache is populated."
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos (${if (soloPhotos.size >= 50) "solo only" else "all"})...")
|
||||
|
||||
val startTime = System.currentTimeMillis()
|
||||
|
||||
// Step 2: Detect faces and generate embeddings (parallel)
|
||||
val allFaces = detectFacesInImages(
|
||||
images = imagesWithFaces.take(1000), // Smart limit
|
||||
onProgress = { current, total ->
|
||||
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
|
||||
// Try high-quality cached faces FIRST (NEW!)
|
||||
var cachedFaces = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
faceCacheDao.getHighQualitySoloFaces(
|
||||
minFaceRatio = 0.015f, // 1.5%
|
||||
limit = maxFacesToCluster
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
// Method doesn't exist yet - that's ok
|
||||
emptyList()
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// Fallback to ANY solo faces if high-quality returned nothing
|
||||
if (cachedFaces.isEmpty()) {
|
||||
Log.w(TAG, "No high-quality faces (>= 1.5%), trying ANY solo faces...")
|
||||
cachedFaces = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
faceCacheDao.getSoloFacesWithEmbeddings(limit = maxFacesToCluster)
|
||||
} catch (e: Exception) {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Log.d(TAG, "Cache check: ${cachedFaces.size} faces available")
|
||||
|
||||
val allFaces = if (cachedFaces.size >= MIN_CACHED_FACES) {
|
||||
// FAST PATH ✅
|
||||
Log.d(TAG, "Using FAST PATH with ${cachedFaces.size} cached faces")
|
||||
onProgress(10, 100, "Using cached embeddings (${cachedFaces.size} faces)...")
|
||||
|
||||
cachedFaces.mapNotNull { cached ->
|
||||
val embedding = cached.getEmbedding() ?: return@mapNotNull null
|
||||
|
||||
DetectedFaceWithEmbedding(
|
||||
imageId = cached.imageId,
|
||||
imageUri = "",
|
||||
capturedAt = 0L,
|
||||
embedding = embedding,
|
||||
boundingBox = cached.getBoundingBox(),
|
||||
confidence = cached.confidence,
|
||||
faceCount = 1, // Solo faces only (filtered by query)
|
||||
imageWidth = cached.imageWidth,
|
||||
imageHeight = cached.imageHeight
|
||||
)
|
||||
}.also {
|
||||
onProgress(50, 100, "Processing ${it.size} cached faces...")
|
||||
}
|
||||
} else {
|
||||
// SLOW PATH
|
||||
Log.d(TAG, "Using SLOW PATH - cache has < $MIN_CACHED_FACES faces")
|
||||
onProgress(0, 100, "Loading photos...")
|
||||
|
||||
val soloPhotos = withContext(Dispatchers.IO) {
|
||||
imageDao.getImagesByFaceCount(count = 1)
|
||||
}
|
||||
|
||||
val imagesWithFaces = if (soloPhotos.size < MIN_SOLO_PHOTOS) {
|
||||
imageDao.getImagesWithFaces()
|
||||
} else {
|
||||
soloPhotos
|
||||
}
|
||||
|
||||
if (imagesWithFaces.isEmpty()) {
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "No photos with faces found"
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
|
||||
|
||||
detectFacesInImagesBatched(
|
||||
images = imagesWithFaces.take(1000),
|
||||
onProgress = { current, total ->
|
||||
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
if (allFaces.isEmpty()) {
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = System.currentTimeMillis() - startTime,
|
||||
errorMessage = "No faces detected in images"
|
||||
errorMessage = "No faces detected"
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(50, 100, "Clustering ${allFaces.size} faces...")
|
||||
|
||||
// Step 3: DBSCAN clustering
|
||||
val rawClusters = performDBSCAN(
|
||||
faces = allFaces.take(maxFacesToCluster),
|
||||
epsilon = 0.18f, // VERY STRICT for siblings
|
||||
epsilon = 0.26f,
|
||||
minPoints = 3
|
||||
)
|
||||
|
||||
onProgress(70, 100, "Analyzing relationships...")
|
||||
|
||||
// Step 4: Build co-occurrence graph
|
||||
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
|
||||
|
||||
onProgress(80, 100, "Selecting representative faces...")
|
||||
|
||||
// Step 5: Create final clusters
|
||||
val clusters = rawClusters.map { cluster ->
|
||||
FaceCluster(
|
||||
clusterId = cluster.clusterId,
|
||||
faces = cluster.faces,
|
||||
representativeFaces = selectRepresentativeFaces(cluster.faces, count = 6),
|
||||
representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6),
|
||||
photoCount = cluster.faces.map { it.imageId }.distinct().size,
|
||||
averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(),
|
||||
estimatedAge = estimateAge(cluster.faces),
|
||||
@@ -154,14 +196,31 @@ class FaceClusteringService @Inject constructor(
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect faces in images and generate embeddings (parallel)
|
||||
*/
|
||||
private suspend fun detectFacesInImages(
|
||||
private suspend fun detectFacesInImagesBatched(
|
||||
images: List<ImageEntity>,
|
||||
onProgress: (Int, Int) -> Unit
|
||||
): List<DetectedFaceWithEmbedding> = coroutineScope {
|
||||
|
||||
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
var processedCount = 0
|
||||
|
||||
images.chunked(BATCH_SIZE).forEach { batch ->
|
||||
val batchFaces = detectFacesInBatch(batch)
|
||||
allFaces.addAll(batchFaces)
|
||||
|
||||
processedCount += batch.size
|
||||
onProgress(processedCount, images.size)
|
||||
|
||||
System.gc()
|
||||
}
|
||||
|
||||
allFaces
|
||||
}
|
||||
|
||||
private suspend fun detectFacesInBatch(
|
||||
images: List<ImageEntity>
|
||||
): List<DetectedFaceWithEmbedding> = coroutineScope {
|
||||
|
||||
val detector = FaceDetection.getClient(
|
||||
FaceDetectorOptions.Builder()
|
||||
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||
@@ -170,20 +229,14 @@ class FaceClusteringService @Inject constructor(
|
||||
)
|
||||
|
||||
val faceNetModel = FaceNetModel(context)
|
||||
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
val processedCount = AtomicInteger(0)
|
||||
val batchFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
|
||||
try {
|
||||
val jobs = images.map { image ->
|
||||
async {
|
||||
async(Dispatchers.IO) {
|
||||
semaphore.acquire()
|
||||
try {
|
||||
val faces = detectFacesInImage(image, detector, faceNetModel)
|
||||
val current = processedCount.incrementAndGet()
|
||||
if (current % 10 == 0) {
|
||||
onProgress(current, images.size)
|
||||
}
|
||||
faces
|
||||
detectFacesInImage(image, detector, faceNetModel)
|
||||
} finally {
|
||||
semaphore.release()
|
||||
}
|
||||
@@ -191,7 +244,7 @@ class FaceClusteringService @Inject constructor(
|
||||
}
|
||||
|
||||
jobs.awaitAll().flatten().also {
|
||||
allFaces.addAll(it)
|
||||
batchFaces.addAll(it)
|
||||
}
|
||||
|
||||
} finally {
|
||||
@@ -199,7 +252,7 @@ class FaceClusteringService @Inject constructor(
|
||||
faceNetModel.close()
|
||||
}
|
||||
|
||||
allFaces
|
||||
batchFaces
|
||||
}
|
||||
|
||||
private suspend fun detectFacesInImage(
|
||||
@@ -215,8 +268,6 @@ class FaceClusteringService @Inject constructor(
|
||||
val mlImage = InputImage.fromBitmap(bitmap, 0)
|
||||
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
|
||||
|
||||
val totalFacesInImage = faces.size
|
||||
|
||||
val result = faces.mapNotNull { face ->
|
||||
try {
|
||||
val faceBitmap = Bitmap.createBitmap(
|
||||
@@ -237,7 +288,9 @@ class FaceClusteringService @Inject constructor(
|
||||
embedding = embedding,
|
||||
boundingBox = face.boundingBox,
|
||||
confidence = 0.95f,
|
||||
faceCount = totalFacesInImage
|
||||
faceCount = faces.size,
|
||||
imageWidth = bitmap.width,
|
||||
imageHeight = bitmap.height
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
@@ -252,9 +305,6 @@ class FaceClusteringService @Inject constructor(
|
||||
}
|
||||
}
|
||||
|
||||
// All other methods remain the same (DBSCAN, similarity, etc.)
|
||||
// ... [Rest of the implementation from original file]
|
||||
|
||||
private fun performDBSCAN(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
epsilon: Float,
|
||||
@@ -368,18 +418,61 @@ class FaceClusteringService @Inject constructor(
|
||||
?: emptyList()
|
||||
}
|
||||
|
||||
private fun selectRepresentativeFaces(
|
||||
private fun selectRepresentativeFacesByCentroid(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
count: Int
|
||||
): List<DetectedFaceWithEmbedding> {
|
||||
if (faces.size <= count) return faces
|
||||
|
||||
val sortedByTime = faces.sortedBy { it.capturedAt }
|
||||
val step = faces.size / count
|
||||
val centroid = calculateCentroid(faces.map { it.embedding })
|
||||
|
||||
return (0 until count).map { i ->
|
||||
sortedByTime[i * step]
|
||||
val facesWithDistance = faces.map { face ->
|
||||
val distance = 1 - cosineSimilarity(face.embedding, centroid)
|
||||
face to distance
|
||||
}
|
||||
|
||||
val sortedByProximity = facesWithDistance.sortedBy { it.second }
|
||||
|
||||
val representatives = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
representatives.add(sortedByProximity.first().first)
|
||||
|
||||
val remainingFaces = sortedByProximity.drop(1).take(count * 3)
|
||||
val sortedByTime = remainingFaces.map { it.first }.sortedBy { it.capturedAt }
|
||||
|
||||
if (sortedByTime.isNotEmpty()) {
|
||||
val step = sortedByTime.size / (count - 1).coerceAtLeast(1)
|
||||
for (i in 0 until (count - 1)) {
|
||||
val index = (i * step).coerceAtMost(sortedByTime.size - 1)
|
||||
representatives.add(sortedByTime[index])
|
||||
}
|
||||
}
|
||||
|
||||
return representatives.take(count)
|
||||
}
|
||||
|
||||
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
|
||||
if (embeddings.isEmpty()) return FloatArray(0)
|
||||
|
||||
val size = embeddings.first().size
|
||||
val centroid = FloatArray(size) { 0f }
|
||||
|
||||
embeddings.forEach { embedding ->
|
||||
for (i in embedding.indices) {
|
||||
centroid[i] += embedding[i]
|
||||
}
|
||||
}
|
||||
|
||||
val count = embeddings.size.toFloat()
|
||||
for (i in centroid.indices) {
|
||||
centroid[i] /= count
|
||||
}
|
||||
|
||||
val norm = sqrt(centroid.map { it * it }.sum())
|
||||
if (norm > 0) {
|
||||
return centroid.map { it / norm }.toFloatArray()
|
||||
}
|
||||
|
||||
return centroid
|
||||
}
|
||||
|
||||
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
|
||||
@@ -416,7 +509,6 @@ class FaceClusteringService @Inject constructor(
|
||||
}
|
||||
}
|
||||
|
||||
// Data classes
|
||||
data class DetectedFaceWithEmbedding(
|
||||
val imageId: String,
|
||||
val imageUri: String,
|
||||
@@ -424,7 +516,9 @@ data class DetectedFaceWithEmbedding(
|
||||
val embedding: FloatArray,
|
||||
val boundingBox: android.graphics.Rect,
|
||||
val confidence: Float,
|
||||
val faceCount: Int = 1
|
||||
val faceCount: Int = 1,
|
||||
val imageWidth: Int = 0,
|
||||
val imageHeight: Int = 0
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.placeholder.sherpai2.ml
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.util.Log
|
||||
import org.tensorflow.lite.Interpreter
|
||||
import java.io.FileInputStream
|
||||
import java.nio.ByteBuffer
|
||||
@@ -11,16 +12,21 @@ import java.nio.channels.FileChannel
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* FaceNetModel - MobileFaceNet wrapper for face recognition
|
||||
* FaceNetModel - MobileFaceNet wrapper with debugging
|
||||
*
|
||||
* CLEAN IMPLEMENTATION:
|
||||
* - All IDs are Strings (matching your schema)
|
||||
* - Generates 192-dimensional embeddings
|
||||
* - Cosine similarity for matching
|
||||
* IMPROVEMENTS:
|
||||
* - ✅ Detailed error logging
|
||||
* - ✅ Model validation on init
|
||||
* - ✅ Embedding validation (detect all-zeros)
|
||||
* - ✅ Toggle-able debug mode
|
||||
*/
|
||||
class FaceNetModel(private val context: Context) {
|
||||
class FaceNetModel(
|
||||
private val context: Context,
|
||||
private val debugMode: Boolean = true // Enable for troubleshooting
|
||||
) {
|
||||
|
||||
companion object {
|
||||
private const val TAG = "FaceNetModel"
|
||||
private const val MODEL_FILE = "mobilefacenet.tflite"
|
||||
private const val INPUT_SIZE = 112
|
||||
private const val EMBEDDING_SIZE = 192
|
||||
@@ -31,13 +37,56 @@ class FaceNetModel(private val context: Context) {
|
||||
}
|
||||
|
||||
private var interpreter: Interpreter? = null
|
||||
private var modelLoadSuccess = false
|
||||
|
||||
init {
|
||||
try {
|
||||
if (debugMode) Log.d(TAG, "Loading FaceNet model: $MODEL_FILE")
|
||||
|
||||
val model = loadModelFile()
|
||||
interpreter = Interpreter(model)
|
||||
modelLoadSuccess = true
|
||||
|
||||
if (debugMode) {
|
||||
Log.d(TAG, "✅ FaceNet model loaded successfully")
|
||||
Log.d(TAG, "Model input size: ${INPUT_SIZE}x$INPUT_SIZE")
|
||||
Log.d(TAG, "Embedding size: $EMBEDDING_SIZE")
|
||||
}
|
||||
|
||||
// Test model with dummy input
|
||||
testModel()
|
||||
|
||||
} catch (e: Exception) {
|
||||
throw RuntimeException("Failed to load FaceNet model", e)
|
||||
Log.e(TAG, "❌ CRITICAL: Failed to load FaceNet model from assets/$MODEL_FILE", e)
|
||||
Log.e(TAG, "Make sure mobilefacenet.tflite exists in app/src/main/assets/")
|
||||
modelLoadSuccess = false
|
||||
throw RuntimeException("Failed to load FaceNet model: ${e.message}", e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Test model with dummy input to verify it works
|
||||
*/
|
||||
private fun testModel() {
|
||||
try {
|
||||
val testBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Bitmap.Config.ARGB_8888)
|
||||
val testEmbedding = generateEmbedding(testBitmap)
|
||||
testBitmap.recycle()
|
||||
|
||||
val sum = testEmbedding.sum()
|
||||
val norm = sqrt(testEmbedding.map { it * it }.sum())
|
||||
|
||||
if (debugMode) {
|
||||
Log.d(TAG, "Model test: embedding sum=$sum, norm=$norm")
|
||||
}
|
||||
|
||||
if (sum == 0f || norm == 0f) {
|
||||
Log.e(TAG, "⚠️ WARNING: Model test produced zero embedding!")
|
||||
} else {
|
||||
if (debugMode) Log.d(TAG, "✅ Model test passed")
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Model test failed", e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,12 +94,22 @@ class FaceNetModel(private val context: Context) {
|
||||
* Load TFLite model from assets
|
||||
*/
|
||||
private fun loadModelFile(): MappedByteBuffer {
|
||||
val fileDescriptor = context.assets.openFd(MODEL_FILE)
|
||||
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
|
||||
val fileChannel = inputStream.channel
|
||||
val startOffset = fileDescriptor.startOffset
|
||||
val declaredLength = fileDescriptor.declaredLength
|
||||
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
|
||||
try {
|
||||
val fileDescriptor = context.assets.openFd(MODEL_FILE)
|
||||
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
|
||||
val fileChannel = inputStream.channel
|
||||
val startOffset = fileDescriptor.startOffset
|
||||
val declaredLength = fileDescriptor.declaredLength
|
||||
|
||||
if (debugMode) {
|
||||
Log.d(TAG, "Model file size: ${declaredLength / 1024}KB")
|
||||
}
|
||||
|
||||
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to open model file: $MODEL_FILE", e)
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -60,13 +119,39 @@ class FaceNetModel(private val context: Context) {
|
||||
* @return 192-dimensional embedding
|
||||
*/
|
||||
fun generateEmbedding(faceBitmap: Bitmap): FloatArray {
|
||||
val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true)
|
||||
val inputBuffer = preprocessImage(resized)
|
||||
val output = Array(1) { FloatArray(EMBEDDING_SIZE) }
|
||||
if (!modelLoadSuccess || interpreter == null) {
|
||||
Log.e(TAG, "❌ Cannot generate embedding: model not loaded!")
|
||||
return FloatArray(EMBEDDING_SIZE) { 0f }
|
||||
}
|
||||
|
||||
interpreter?.run(inputBuffer, output)
|
||||
try {
|
||||
val resized = Bitmap.createScaledBitmap(faceBitmap, INPUT_SIZE, INPUT_SIZE, true)
|
||||
val inputBuffer = preprocessImage(resized)
|
||||
val output = Array(1) { FloatArray(EMBEDDING_SIZE) }
|
||||
|
||||
return normalizeEmbedding(output[0])
|
||||
interpreter?.run(inputBuffer, output)
|
||||
|
||||
val normalized = normalizeEmbedding(output[0])
|
||||
|
||||
// DIAGNOSTIC: Check embedding quality
|
||||
if (debugMode) {
|
||||
val sum = normalized.sum()
|
||||
val norm = sqrt(normalized.map { it * it }.sum())
|
||||
|
||||
if (sum == 0f && norm == 0f) {
|
||||
Log.e(TAG, "❌ CRITICAL: Generated all-zero embedding!")
|
||||
Log.e(TAG, "Input bitmap: ${faceBitmap.width}x${faceBitmap.height}")
|
||||
} else {
|
||||
Log.d(TAG, "✅ Embedding: sum=${"%.2f".format(sum)}, norm=${"%.2f".format(norm)}, first5=[${normalized.take(5).joinToString { "%.3f".format(it) }}]")
|
||||
}
|
||||
}
|
||||
|
||||
return normalized
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to generate embedding", e)
|
||||
return FloatArray(EMBEDDING_SIZE) { 0f }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -76,6 +161,10 @@ class FaceNetModel(private val context: Context) {
|
||||
faceBitmaps: List<Bitmap>,
|
||||
onProgress: (Int, Int) -> Unit = { _, _ -> }
|
||||
): List<FloatArray> {
|
||||
if (debugMode) {
|
||||
Log.d(TAG, "Generating embeddings for ${faceBitmaps.size} faces")
|
||||
}
|
||||
|
||||
return faceBitmaps.mapIndexed { index, bitmap ->
|
||||
onProgress(index + 1, faceBitmaps.size)
|
||||
generateEmbedding(bitmap)
|
||||
@@ -88,6 +177,10 @@ class FaceNetModel(private val context: Context) {
|
||||
fun createPersonModel(embeddings: List<FloatArray>): FloatArray {
|
||||
require(embeddings.isNotEmpty()) { "Need at least one embedding" }
|
||||
|
||||
if (debugMode) {
|
||||
Log.d(TAG, "Creating person model from ${embeddings.size} embeddings")
|
||||
}
|
||||
|
||||
val averaged = FloatArray(EMBEDDING_SIZE) { 0f }
|
||||
|
||||
embeddings.forEach { embedding ->
|
||||
@@ -101,7 +194,14 @@ class FaceNetModel(private val context: Context) {
|
||||
averaged[i] /= count
|
||||
}
|
||||
|
||||
return normalizeEmbedding(averaged)
|
||||
val normalized = normalizeEmbedding(averaged)
|
||||
|
||||
if (debugMode) {
|
||||
val sum = normalized.sum()
|
||||
Log.d(TAG, "Person model created: sum=${"%.2f".format(sum)}")
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -110,7 +210,7 @@ class FaceNetModel(private val context: Context) {
|
||||
*/
|
||||
fun calculateSimilarity(embedding1: FloatArray, embedding2: FloatArray): Float {
|
||||
require(embedding1.size == EMBEDDING_SIZE && embedding2.size == EMBEDDING_SIZE) {
|
||||
"Invalid embedding size"
|
||||
"Invalid embedding size: ${embedding1.size} vs ${embedding2.size}"
|
||||
}
|
||||
|
||||
var dotProduct = 0f
|
||||
@@ -123,7 +223,14 @@ class FaceNetModel(private val context: Context) {
|
||||
norm2 += embedding2[i] * embedding2[i]
|
||||
}
|
||||
|
||||
return dotProduct / (sqrt(norm1) * sqrt(norm2))
|
||||
val similarity = dotProduct / (sqrt(norm1) * sqrt(norm2))
|
||||
|
||||
if (debugMode && (similarity.isNaN() || similarity.isInfinite())) {
|
||||
Log.e(TAG, "❌ Invalid similarity: $similarity (norm1=$norm1, norm2=$norm2)")
|
||||
return 0f
|
||||
}
|
||||
|
||||
return similarity
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -151,6 +258,10 @@ class FaceNetModel(private val context: Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if (debugMode && bestMatch != null) {
|
||||
Log.d(TAG, "Best match: ${bestMatch.first} with similarity ${bestMatch.second}")
|
||||
}
|
||||
|
||||
return bestMatch
|
||||
}
|
||||
|
||||
@@ -169,6 +280,7 @@ class FaceNetModel(private val context: Context) {
|
||||
val g = ((pixel shr 8) and 0xFF) / 255.0f
|
||||
val b = (pixel and 0xFF) / 255.0f
|
||||
|
||||
// Normalize to [-1, 1]
|
||||
buffer.putFloat((r - 0.5f) / 0.5f)
|
||||
buffer.putFloat((g - 0.5f) / 0.5f)
|
||||
buffer.putFloat((b - 0.5f) / 0.5f)
|
||||
@@ -190,14 +302,29 @@ class FaceNetModel(private val context: Context) {
|
||||
return if (norm > 0) {
|
||||
FloatArray(embedding.size) { i -> embedding[i] / norm }
|
||||
} else {
|
||||
Log.w(TAG, "⚠️ Cannot normalize zero embedding")
|
||||
embedding
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get model status for diagnostics
|
||||
*/
|
||||
fun getModelStatus(): String {
|
||||
return if (modelLoadSuccess) {
|
||||
"✅ Model loaded and operational"
|
||||
} else {
|
||||
"❌ Model failed to load - check assets/$MODEL_FILE"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up resources
|
||||
*/
|
||||
fun close() {
|
||||
if (debugMode) {
|
||||
Log.d(TAG, "Closing FaceNet model")
|
||||
}
|
||||
interpreter?.close()
|
||||
interpreter = null
|
||||
}
|
||||
|
||||
@@ -8,9 +8,13 @@ import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.lazy.grid.GridCells
|
||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||
import androidx.compose.foundation.lazy.grid.items
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.Check
|
||||
import androidx.compose.material.icons.filled.Warning
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.*
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
@@ -19,6 +23,8 @@ import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.dp
|
||||
import coil.compose.AsyncImage
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusteringResult
|
||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||
|
||||
@@ -28,13 +34,20 @@ import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||
* Each cluster card shows:
|
||||
* - 2x2 grid of representative faces
|
||||
* - Photo count
|
||||
* - Quality badge (Excellent/Good/Poor)
|
||||
* - Tap to name
|
||||
*
|
||||
* IMPROVEMENTS:
|
||||
* - ✅ Quality badges for each cluster
|
||||
* - ✅ Visual indicators for trainable vs non-trainable clusters
|
||||
* - ✅ Better UX with disabled states for poor quality clusters
|
||||
*/
|
||||
@Composable
|
||||
fun ClusterGridScreen(
|
||||
result: ClusteringResult,
|
||||
onSelectCluster: (FaceCluster) -> Unit,
|
||||
modifier: Modifier = Modifier
|
||||
modifier: Modifier = Modifier,
|
||||
qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() }
|
||||
) {
|
||||
Column(
|
||||
modifier = modifier
|
||||
@@ -65,8 +78,15 @@ fun ClusterGridScreen(
|
||||
verticalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
items(result.clusters) { cluster ->
|
||||
// Analyze quality for each cluster
|
||||
val qualityResult = remember(cluster) {
|
||||
qualityAnalyzer.analyzeCluster(cluster)
|
||||
}
|
||||
|
||||
ClusterCard(
|
||||
cluster = cluster,
|
||||
qualityTier = qualityResult.qualityTier,
|
||||
canTrain = qualityResult.canTrain,
|
||||
onClick = { onSelectCluster(cluster) }
|
||||
)
|
||||
}
|
||||
@@ -75,77 +95,168 @@ fun ClusterGridScreen(
|
||||
}
|
||||
|
||||
/**
|
||||
* Single cluster card with 2x2 face grid
|
||||
* Single cluster card with 2x2 face grid and quality badge
|
||||
*/
|
||||
@Composable
|
||||
private fun ClusterCard(
|
||||
cluster: FaceCluster,
|
||||
qualityTier: ClusterQualityTier,
|
||||
canTrain: Boolean,
|
||||
onClick: () -> Unit
|
||||
) {
|
||||
Card(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.aspectRatio(1f)
|
||||
.clickable(onClick = onClick),
|
||||
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp)
|
||||
.clickable(onClick = onClick), // Always clickable - let dialog handle validation
|
||||
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp),
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = when {
|
||||
qualityTier == ClusterQualityTier.POOR ->
|
||||
MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f)
|
||||
!canTrain ->
|
||||
MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f)
|
||||
else ->
|
||||
MaterialTheme.colorScheme.surface
|
||||
}
|
||||
)
|
||||
) {
|
||||
Column(
|
||||
Box(
|
||||
modifier = Modifier.fillMaxSize()
|
||||
) {
|
||||
// 2x2 grid of faces
|
||||
val facesToShow = cluster.representativeFaces.take(4)
|
||||
|
||||
Column(
|
||||
modifier = Modifier.weight(1f)
|
||||
modifier = Modifier.fillMaxSize()
|
||||
) {
|
||||
// Top row (2 faces)
|
||||
Row(modifier = Modifier.weight(1f)) {
|
||||
facesToShow.getOrNull(0)?.let { face ->
|
||||
FaceThumbnail(
|
||||
imageUri = face.imageUri,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
} ?: EmptyFaceSlot(Modifier.weight(1f))
|
||||
// 2x2 grid of faces
|
||||
val facesToShow = cluster.representativeFaces.take(4)
|
||||
|
||||
facesToShow.getOrNull(1)?.let { face ->
|
||||
FaceThumbnail(
|
||||
imageUri = face.imageUri,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
} ?: EmptyFaceSlot(Modifier.weight(1f))
|
||||
Column(
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
// Top row (2 faces)
|
||||
Row(modifier = Modifier.weight(1f)) {
|
||||
facesToShow.getOrNull(0)?.let { face ->
|
||||
FaceThumbnail(
|
||||
imageUri = face.imageUri,
|
||||
enabled = canTrain,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
} ?: EmptyFaceSlot(Modifier.weight(1f))
|
||||
|
||||
facesToShow.getOrNull(1)?.let { face ->
|
||||
FaceThumbnail(
|
||||
imageUri = face.imageUri,
|
||||
enabled = canTrain,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
} ?: EmptyFaceSlot(Modifier.weight(1f))
|
||||
}
|
||||
|
||||
// Bottom row (2 faces)
|
||||
Row(modifier = Modifier.weight(1f)) {
|
||||
facesToShow.getOrNull(2)?.let { face ->
|
||||
FaceThumbnail(
|
||||
imageUri = face.imageUri,
|
||||
enabled = canTrain,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
} ?: EmptyFaceSlot(Modifier.weight(1f))
|
||||
|
||||
facesToShow.getOrNull(3)?.let { face ->
|
||||
FaceThumbnail(
|
||||
imageUri = face.imageUri,
|
||||
enabled = canTrain,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
} ?: EmptyFaceSlot(Modifier.weight(1f))
|
||||
}
|
||||
}
|
||||
|
||||
// Bottom row (2 faces)
|
||||
Row(modifier = Modifier.weight(1f)) {
|
||||
facesToShow.getOrNull(2)?.let { face ->
|
||||
FaceThumbnail(
|
||||
imageUri = face.imageUri,
|
||||
modifier = Modifier.weight(1f)
|
||||
// Footer with photo count
|
||||
Surface(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
color = if (canTrain) {
|
||||
MaterialTheme.colorScheme.primaryContainer
|
||||
} else {
|
||||
MaterialTheme.colorScheme.surfaceVariant
|
||||
}
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier.padding(12.dp),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
Text(
|
||||
text = "${cluster.photoCount} photos",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
fontWeight = FontWeight.SemiBold,
|
||||
color = if (canTrain) {
|
||||
MaterialTheme.colorScheme.onPrimaryContainer
|
||||
} else {
|
||||
MaterialTheme.colorScheme.onSurfaceVariant
|
||||
}
|
||||
)
|
||||
} ?: EmptyFaceSlot(Modifier.weight(1f))
|
||||
|
||||
facesToShow.getOrNull(3)?.let { face ->
|
||||
FaceThumbnail(
|
||||
imageUri = face.imageUri,
|
||||
modifier = Modifier.weight(1f)
|
||||
)
|
||||
} ?: EmptyFaceSlot(Modifier.weight(1f))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Footer with photo count
|
||||
Surface(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
color = MaterialTheme.colorScheme.primaryContainer
|
||||
) {
|
||||
Text(
|
||||
text = "${cluster.photoCount} photos",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
fontWeight = FontWeight.SemiBold,
|
||||
modifier = Modifier.padding(12.dp),
|
||||
color = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
}
|
||||
// Quality badge overlay
|
||||
QualityBadge(
|
||||
qualityTier = qualityTier,
|
||||
canTrain = canTrain,
|
||||
modifier = Modifier
|
||||
.align(Alignment.TopEnd)
|
||||
.padding(8.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Quality badge indicator
|
||||
*/
|
||||
@Composable
|
||||
private fun QualityBadge(
|
||||
qualityTier: ClusterQualityTier,
|
||||
canTrain: Boolean,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
val (backgroundColor, iconColor, icon) = when (qualityTier) {
|
||||
ClusterQualityTier.EXCELLENT -> Triple(
|
||||
Color(0xFF1B5E20),
|
||||
Color.White,
|
||||
Icons.Default.Check
|
||||
)
|
||||
ClusterQualityTier.GOOD -> Triple(
|
||||
Color(0xFF2E7D32),
|
||||
Color.White,
|
||||
Icons.Default.Check
|
||||
)
|
||||
ClusterQualityTier.POOR -> Triple(
|
||||
Color(0xFFD32F2F),
|
||||
Color.White,
|
||||
Icons.Default.Warning
|
||||
)
|
||||
}
|
||||
|
||||
Surface(
|
||||
modifier = modifier,
|
||||
shape = CircleShape,
|
||||
color = backgroundColor,
|
||||
shadowElevation = 2.dp
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.size(32.dp)
|
||||
.padding(6.dp),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
Icon(
|
||||
imageVector = icon,
|
||||
contentDescription = qualityTier.name,
|
||||
tint = iconColor,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -153,19 +264,23 @@ private fun ClusterCard(
|
||||
@Composable
|
||||
private fun FaceThumbnail(
|
||||
imageUri: String,
|
||||
enabled: Boolean,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
AsyncImage(
|
||||
model = Uri.parse(imageUri),
|
||||
contentDescription = "Face",
|
||||
modifier = modifier
|
||||
.fillMaxSize()
|
||||
.border(
|
||||
width = 0.5.dp,
|
||||
color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
|
||||
),
|
||||
contentScale = ContentScale.Crop
|
||||
)
|
||||
Box(modifier = modifier) {
|
||||
AsyncImage(
|
||||
model = Uri.parse(imageUri),
|
||||
contentDescription = "Face",
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.border(
|
||||
width = 0.5.dp,
|
||||
color = MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
|
||||
),
|
||||
contentScale = ContentScale.Crop,
|
||||
alpha = if (enabled) 1f else 0.6f
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
|
||||
@@ -11,11 +11,17 @@ import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||
|
||||
/**
|
||||
* DiscoverPeopleScreen - COMPLETE WORKING VERSION
|
||||
* DiscoverPeopleScreen - COMPLETE WORKING VERSION WITH NAMING DIALOG
|
||||
*
|
||||
* This handles ALL states properly including Idle state
|
||||
* This handles ALL states properly including the NamingCluster dialog
|
||||
*
|
||||
* IMPROVEMENTS:
|
||||
* - ✅ Complete naming dialog integration
|
||||
* - ✅ Quality analysis in cluster grid
|
||||
* - ✅ Better error handling
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
@@ -24,125 +30,130 @@ fun DiscoverPeopleScreen(
|
||||
onNavigateBack: () -> Unit = {}
|
||||
) {
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
val qualityAnalyzer = remember { ClusterQualityAnalyzer() }
|
||||
|
||||
Scaffold(
|
||||
topBar = {
|
||||
TopAppBar(
|
||||
title = { Text("Discover People") },
|
||||
navigationIcon = {
|
||||
IconButton(onClick = onNavigateBack) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = "Back"
|
||||
// No Scaffold, no TopAppBar - MainScreen handles that
|
||||
Box(
|
||||
modifier = Modifier.fillMaxSize()
|
||||
) {
|
||||
when (val state = uiState) {
|
||||
// ===== IDLE STATE (START HERE) =====
|
||||
is DiscoverUiState.Idle -> {
|
||||
IdleStateContent(
|
||||
onStartDiscovery = { viewModel.startDiscovery() }
|
||||
)
|
||||
}
|
||||
|
||||
// ===== CLUSTERING IN PROGRESS =====
|
||||
is DiscoverUiState.Clustering -> {
|
||||
ClusteringProgressContent(
|
||||
progress = state.progress,
|
||||
total = state.total,
|
||||
message = state.message
|
||||
)
|
||||
}
|
||||
|
||||
// ===== CLUSTERS READY FOR NAMING =====
|
||||
is DiscoverUiState.NamingReady -> {
|
||||
ClusterGridScreen(
|
||||
result = state.result,
|
||||
onSelectCluster = { cluster ->
|
||||
viewModel.selectCluster(cluster)
|
||||
},
|
||||
qualityAnalyzer = qualityAnalyzer
|
||||
)
|
||||
}
|
||||
|
||||
// ===== ANALYZING CLUSTER QUALITY =====
|
||||
is DiscoverUiState.AnalyzingCluster -> {
|
||||
LoadingContent(message = "Analyzing cluster quality...")
|
||||
}
|
||||
|
||||
// ===== NAMING A CLUSTER (SHOW DIALOG) =====
|
||||
is DiscoverUiState.NamingCluster -> {
|
||||
// Show cluster grid in background
|
||||
ClusterGridScreen(
|
||||
result = state.result,
|
||||
onSelectCluster = { /* Disabled while dialog open */ },
|
||||
qualityAnalyzer = qualityAnalyzer
|
||||
)
|
||||
|
||||
// Show naming dialog overlay
|
||||
NamingDialog(
|
||||
cluster = state.selectedCluster,
|
||||
suggestedSiblings = state.suggestedSiblings,
|
||||
onConfirm = { name, dateOfBirth, isChild, selectedSiblings ->
|
||||
viewModel.confirmClusterName(
|
||||
cluster = state.selectedCluster,
|
||||
name = name,
|
||||
dateOfBirth = dateOfBirth,
|
||||
isChild = isChild,
|
||||
selectedSiblings = selectedSiblings
|
||||
)
|
||||
},
|
||||
onDismiss = {
|
||||
viewModel.cancelNaming()
|
||||
},
|
||||
qualityAnalyzer = qualityAnalyzer
|
||||
)
|
||||
}
|
||||
|
||||
// ===== TRAINING IN PROGRESS =====
|
||||
is DiscoverUiState.Training -> {
|
||||
TrainingProgressContent(
|
||||
stage = state.stage,
|
||||
progress = state.progress,
|
||||
total = state.total
|
||||
)
|
||||
}
|
||||
|
||||
// ===== VALIDATION PREVIEW =====
|
||||
is DiscoverUiState.ValidationPreview -> {
|
||||
ValidationPreviewScreen(
|
||||
personName = state.personName,
|
||||
validationResult = state.validationResult,
|
||||
onApprove = {
|
||||
viewModel.approveValidationAndScan(
|
||||
personId = state.personId,
|
||||
personName = state.personName
|
||||
)
|
||||
},
|
||||
onReject = {
|
||||
viewModel.rejectValidationAndImprove()
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
) { paddingValues ->
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(paddingValues)
|
||||
) {
|
||||
when (val state = uiState) {
|
||||
// ===== IDLE STATE (START HERE) =====
|
||||
is DiscoverUiState.Idle -> {
|
||||
IdleStateContent(
|
||||
onStartDiscovery = { viewModel.startDiscovery() }
|
||||
)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// ===== CLUSTERING IN PROGRESS =====
|
||||
is DiscoverUiState.Clustering -> {
|
||||
ClusteringProgressContent(
|
||||
progress = state.progress,
|
||||
total = state.total,
|
||||
message = state.message
|
||||
)
|
||||
}
|
||||
// ===== COMPLETE =====
|
||||
is DiscoverUiState.Complete -> {
|
||||
CompleteStateContent(
|
||||
message = state.message,
|
||||
onDone = onNavigateBack
|
||||
)
|
||||
}
|
||||
|
||||
// ===== CLUSTERS READY FOR NAMING =====
|
||||
is DiscoverUiState.NamingReady -> {
|
||||
ClusterGridScreen(
|
||||
result = state.result,
|
||||
onSelectCluster = { cluster ->
|
||||
viewModel.selectCluster(cluster)
|
||||
}
|
||||
)
|
||||
}
|
||||
// ===== NO PEOPLE FOUND =====
|
||||
is DiscoverUiState.NoPeopleFound -> {
|
||||
ErrorStateContent(
|
||||
title = "No People Found",
|
||||
message = state.message,
|
||||
onRetry = { viewModel.startDiscovery() },
|
||||
onBack = onNavigateBack
|
||||
)
|
||||
}
|
||||
|
||||
// ===== ANALYZING CLUSTER QUALITY =====
|
||||
is DiscoverUiState.AnalyzingCluster -> {
|
||||
LoadingContent(message = "Analyzing cluster quality...")
|
||||
}
|
||||
|
||||
// ===== NAMING A CLUSTER =====
|
||||
is DiscoverUiState.NamingCluster -> {
|
||||
Text(
|
||||
text = "Naming dialog for cluster ${state.selectedCluster.clusterId}\n\nDialog UI coming...",
|
||||
modifier = Modifier.align(Alignment.Center)
|
||||
)
|
||||
}
|
||||
|
||||
// ===== TRAINING IN PROGRESS =====
|
||||
is DiscoverUiState.Training -> {
|
||||
TrainingProgressContent(
|
||||
stage = state.stage,
|
||||
progress = state.progress,
|
||||
total = state.total
|
||||
)
|
||||
}
|
||||
|
||||
// ===== VALIDATION PREVIEW =====
|
||||
is DiscoverUiState.ValidationPreview -> {
|
||||
ValidationPreviewScreen(
|
||||
personName = state.personName,
|
||||
validationResult = state.validationResult,
|
||||
onApprove = {
|
||||
viewModel.approveValidationAndScan(
|
||||
personId = state.personId,
|
||||
personName = state.personName
|
||||
)
|
||||
},
|
||||
onReject = {
|
||||
viewModel.rejectValidationAndImprove()
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// ===== COMPLETE =====
|
||||
is DiscoverUiState.Complete -> {
|
||||
CompleteStateContent(
|
||||
message = state.message,
|
||||
onDone = onNavigateBack
|
||||
)
|
||||
}
|
||||
|
||||
// ===== NO PEOPLE FOUND =====
|
||||
is DiscoverUiState.NoPeopleFound -> {
|
||||
ErrorStateContent(
|
||||
title = "No People Found",
|
||||
message = state.message,
|
||||
onRetry = { viewModel.startDiscovery() },
|
||||
onBack = onNavigateBack
|
||||
)
|
||||
}
|
||||
|
||||
// ===== ERROR =====
|
||||
is DiscoverUiState.Error -> {
|
||||
ErrorStateContent(
|
||||
title = "Error",
|
||||
message = state.message,
|
||||
onRetry = { viewModel.reset(); viewModel.startDiscovery() },
|
||||
onBack = onNavigateBack
|
||||
)
|
||||
}
|
||||
// ===== ERROR =====
|
||||
is DiscoverUiState.Error -> {
|
||||
ErrorStateContent(
|
||||
title = "Error",
|
||||
message = state.message,
|
||||
onRetry = { viewModel.reset(); viewModel.startDiscovery() },
|
||||
onBack = onNavigateBack
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== IDLE STATE CONTENT =====
|
||||
|
||||
@Composable
|
||||
@@ -165,19 +176,11 @@ private fun IdleStateContent(
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
Text(
|
||||
text = "Discover People",
|
||||
style = MaterialTheme.typography.headlineLarge,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = "Automatically find and organize people in your photo library",
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
style = MaterialTheme.typography.headlineSmall,
|
||||
textAlign = TextAlign.Center,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
color = MaterialTheme.colorScheme.onSurface
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(48.dp))
|
||||
|
||||
@@ -0,0 +1,480 @@
|
||||
package com.placeholder.sherpai2.ui.discover
|
||||
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.lazy.LazyRow
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.foundation.text.KeyboardActions
|
||||
import androidx.compose.foundation.text.KeyboardOptions
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.*
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.*
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.platform.LocalSoftwareKeyboardController
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.input.ImeAction
|
||||
import androidx.compose.ui.text.input.KeyboardCapitalization
|
||||
import androidx.compose.ui.text.input.KeyboardType
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.Dialog
|
||||
import coil.compose.AsyncImage
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityTier
|
||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* NamingDialog - Complete dialog for naming a cluster
|
||||
*
|
||||
* Features:
|
||||
* - Name input with validation
|
||||
* - Child toggle with date of birth picker
|
||||
* - Sibling cluster selection
|
||||
* - Quality warnings display
|
||||
* - Preview of representative faces
|
||||
*
|
||||
* IMPROVEMENTS:
|
||||
* - ✅ Complete UI implementation
|
||||
* - ✅ Quality analysis integration
|
||||
* - ✅ Sibling selection
|
||||
* - ✅ Form validation
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun NamingDialog(
|
||||
cluster: FaceCluster,
|
||||
suggestedSiblings: List<FaceCluster>,
|
||||
onConfirm: (name: String, dateOfBirth: Long?, isChild: Boolean, selectedSiblings: List<Int>) -> Unit,
|
||||
onDismiss: () -> Unit,
|
||||
qualityAnalyzer: ClusterQualityAnalyzer = remember { ClusterQualityAnalyzer() }
|
||||
) {
|
||||
var name by remember { mutableStateOf("") }
|
||||
var isChild by remember { mutableStateOf(false) }
|
||||
var showDatePicker by remember { mutableStateOf(false) }
|
||||
var dateOfBirth by remember { mutableStateOf<Long?>(null) }
|
||||
var selectedSiblingIds by remember { mutableStateOf(setOf<Int>()) }
|
||||
|
||||
// Analyze cluster quality
|
||||
val qualityResult = remember(cluster) {
|
||||
qualityAnalyzer.analyzeCluster(cluster)
|
||||
}
|
||||
|
||||
val keyboardController = LocalSoftwareKeyboardController.current
|
||||
val dateFormatter = remember { SimpleDateFormat("MMM dd, yyyy", Locale.getDefault()) }
|
||||
|
||||
Dialog(onDismissRequest = onDismiss) {
|
||||
Card(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.fillMaxHeight(0.9f),
|
||||
shape = RoundedCornerShape(16.dp),
|
||||
elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.verticalScroll(rememberScrollState())
|
||||
) {
|
||||
// Header
|
||||
Surface(
|
||||
color = MaterialTheme.colorScheme.primaryContainer
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp),
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Column(modifier = Modifier.weight(1f)) {
|
||||
Text(
|
||||
text = "Name This Person",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
fontWeight = FontWeight.Bold,
|
||||
color = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
Text(
|
||||
text = "${cluster.photoCount} photos",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
|
||||
)
|
||||
}
|
||||
|
||||
IconButton(onClick = onDismiss) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Close,
|
||||
contentDescription = "Close",
|
||||
tint = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Column(
|
||||
modifier = Modifier.padding(16.dp)
|
||||
) {
|
||||
// Preview faces
|
||||
Text(
|
||||
text = "Preview",
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.SemiBold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
LazyRow(
|
||||
horizontalArrangement = Arrangement.spacedBy(8.dp)
|
||||
) {
|
||||
items(cluster.representativeFaces.take(6)) { face ->
|
||||
AsyncImage(
|
||||
model = android.net.Uri.parse(face.imageUri),
|
||||
contentDescription = "Preview",
|
||||
modifier = Modifier
|
||||
.size(80.dp)
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.border(
|
||||
width = 1.dp,
|
||||
color = MaterialTheme.colorScheme.outline,
|
||||
shape = RoundedCornerShape(8.dp)
|
||||
),
|
||||
contentScale = ContentScale.Crop
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(20.dp))
|
||||
|
||||
// Quality warning (if applicable)
|
||||
if (qualityResult.qualityTier != ClusterQualityTier.EXCELLENT) {
|
||||
QualityWarningCard(qualityResult = qualityResult)
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
}
|
||||
|
||||
// Name input
|
||||
OutlinedTextField(
|
||||
value = name,
|
||||
onValueChange = { name = it },
|
||||
label = { Text("Name") },
|
||||
placeholder = { Text("Enter person's name") },
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
singleLine = true,
|
||||
leadingIcon = {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = null
|
||||
)
|
||||
},
|
||||
keyboardOptions = KeyboardOptions(
|
||||
capitalization = KeyboardCapitalization.Words,
|
||||
imeAction = ImeAction.Done
|
||||
),
|
||||
keyboardActions = KeyboardActions(
|
||||
onDone = { keyboardController?.hide() }
|
||||
)
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Child toggle
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.clickable { isChild = !isChild }
|
||||
.background(
|
||||
if (isChild) MaterialTheme.colorScheme.primaryContainer
|
||||
else MaterialTheme.colorScheme.surfaceVariant
|
||||
)
|
||||
.padding(16.dp),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Face,
|
||||
contentDescription = null,
|
||||
tint = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
|
||||
else MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Spacer(modifier = Modifier.width(12.dp))
|
||||
Column {
|
||||
Text(
|
||||
text = "This is a child",
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
fontWeight = FontWeight.Medium,
|
||||
color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
|
||||
else MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Text(
|
||||
text = "For age-appropriate filtering",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
|
||||
else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Switch(
|
||||
checked = isChild,
|
||||
onCheckedChange = { isChild = it }
|
||||
)
|
||||
}
|
||||
|
||||
// Date of birth (if child)
|
||||
if (isChild) {
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
OutlinedButton(
|
||||
onClick = { showDatePicker = true },
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.DateRange,
|
||||
contentDescription = null
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text(
|
||||
text = dateOfBirth?.let { dateFormatter.format(Date(it)) }
|
||||
?: "Set date of birth (optional)"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Sibling selection
|
||||
if (suggestedSiblings.isNotEmpty()) {
|
||||
Spacer(modifier = Modifier.height(20.dp))
|
||||
|
||||
Text(
|
||||
text = "Appears with",
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.SemiBold
|
||||
)
|
||||
|
||||
Text(
|
||||
text = "Select siblings or family members",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
suggestedSiblings.forEach { sibling ->
|
||||
SiblingSelectionItem(
|
||||
cluster = sibling,
|
||||
selected = sibling.clusterId in selectedSiblingIds,
|
||||
onToggle = {
|
||||
selectedSiblingIds = if (sibling.clusterId in selectedSiblingIds) {
|
||||
selectedSiblingIds - sibling.clusterId
|
||||
} else {
|
||||
selectedSiblingIds + sibling.clusterId
|
||||
}
|
||||
}
|
||||
)
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
// Action buttons
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
OutlinedButton(
|
||||
onClick = onDismiss,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Text("Cancel")
|
||||
}
|
||||
|
||||
Button(
|
||||
onClick = {
|
||||
if (name.isNotBlank()) {
|
||||
onConfirm(
|
||||
name.trim(),
|
||||
dateOfBirth,
|
||||
isChild,
|
||||
selectedSiblingIds.toList()
|
||||
)
|
||||
}
|
||||
},
|
||||
modifier = Modifier.weight(1f),
|
||||
enabled = name.isNotBlank() && qualityResult.canTrain
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Check,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text("Create Model")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Date picker dialog
|
||||
if (showDatePicker) {
|
||||
val datePickerState = rememberDatePickerState()
|
||||
|
||||
DatePickerDialog(
|
||||
onDismissRequest = { showDatePicker = false },
|
||||
confirmButton = {
|
||||
TextButton(
|
||||
onClick = {
|
||||
dateOfBirth = datePickerState.selectedDateMillis
|
||||
showDatePicker = false
|
||||
}
|
||||
) {
|
||||
Text("OK")
|
||||
}
|
||||
},
|
||||
dismissButton = {
|
||||
TextButton(onClick = { showDatePicker = false }) {
|
||||
Text("Cancel")
|
||||
}
|
||||
}
|
||||
) {
|
||||
DatePicker(state = datePickerState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Quality warning card
|
||||
*/
|
||||
@Composable
|
||||
private fun QualityWarningCard(qualityResult: com.placeholder.sherpai2.domain.clustering.ClusterQualityResult) {
|
||||
val (backgroundColor, iconColor) = when (qualityResult.qualityTier) {
|
||||
ClusterQualityTier.GOOD -> Pair(
|
||||
Color(0xFFFFF9C4),
|
||||
Color(0xFFF57F17)
|
||||
)
|
||||
ClusterQualityTier.POOR -> Pair(
|
||||
Color(0xFFFFEBEE),
|
||||
Color(0xFFD32F2F)
|
||||
)
|
||||
else -> Pair(
|
||||
MaterialTheme.colorScheme.surfaceVariant,
|
||||
MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
|
||||
Card(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = backgroundColor
|
||||
)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(12.dp)
|
||||
) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Warning,
|
||||
contentDescription = null,
|
||||
tint = iconColor,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text(
|
||||
text = when (qualityResult.qualityTier) {
|
||||
ClusterQualityTier.GOOD -> "Review Before Training"
|
||||
ClusterQualityTier.POOR -> "Quality Issues Detected"
|
||||
else -> ""
|
||||
},
|
||||
style = MaterialTheme.typography.titleSmall,
|
||||
fontWeight = FontWeight.Bold,
|
||||
color = iconColor
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
qualityResult.warnings.forEach { warning ->
|
||||
Text(
|
||||
text = "• $warning",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sibling selection item
|
||||
*/
|
||||
@Composable
|
||||
private fun SiblingSelectionItem(
|
||||
cluster: FaceCluster,
|
||||
selected: Boolean,
|
||||
onToggle: () -> Unit
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.clickable(onClick = onToggle)
|
||||
.background(
|
||||
if (selected) MaterialTheme.colorScheme.primaryContainer
|
||||
else MaterialTheme.colorScheme.surfaceVariant
|
||||
)
|
||||
.padding(12.dp),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
// Preview face
|
||||
AsyncImage(
|
||||
model = android.net.Uri.parse(cluster.representativeFaces.firstOrNull()?.imageUri ?: ""),
|
||||
contentDescription = "Preview",
|
||||
modifier = Modifier
|
||||
.size(40.dp)
|
||||
.clip(CircleShape)
|
||||
.border(
|
||||
width = 1.dp,
|
||||
color = MaterialTheme.colorScheme.outline,
|
||||
shape = CircleShape
|
||||
),
|
||||
contentScale = ContentScale.Crop
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.width(12.dp))
|
||||
|
||||
Text(
|
||||
text = "${cluster.photoCount} photos together",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = if (selected) MaterialTheme.colorScheme.onPrimaryContainer
|
||||
else MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
|
||||
Checkbox(
|
||||
checked = selected,
|
||||
onCheckedChange = { onToggle() }
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,56 +1,48 @@
|
||||
package com.placeholder.sherpai2.ui.presentation
|
||||
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.*
|
||||
import androidx.compose.material.icons.filled.Menu
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.*
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import androidx.navigation.compose.currentBackStackEntryAsState
|
||||
import androidx.navigation.compose.rememberNavController
|
||||
import androidx.navigation.compose.currentBackStackEntryAsState
|
||||
import com.placeholder.sherpai2.ui.navigation.AppNavHost
|
||||
import com.placeholder.sherpai2.ui.navigation.AppRoutes
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
/**
|
||||
* MainScreen - UPDATED with auto face cache check
|
||||
* MainScreen - Complete app container with drawer navigation
|
||||
*
|
||||
* NEW: Prompts user to populate face cache on app launch if needed
|
||||
* FIXED: Prevents double headers for screens with their own TopAppBar
|
||||
* CRITICAL FIX APPLIED:
|
||||
* ✅ Removed AppRoutes.DISCOVER from screensWithOwnTopBar
|
||||
* ✅ DiscoverPeopleScreen now shows hamburger menu + "Discover People" title!
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun MainScreen(
|
||||
mainViewModel: MainViewModel = hiltViewModel() // Same package - no import needed!
|
||||
viewModel: MainViewModel = hiltViewModel()
|
||||
) {
|
||||
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
|
||||
val scope = rememberCoroutineScope()
|
||||
val navController = rememberNavController()
|
||||
val drawerState = rememberDrawerState(DrawerValue.Closed)
|
||||
val scope = rememberCoroutineScope()
|
||||
|
||||
val navBackStackEntry by navController.currentBackStackEntryAsState()
|
||||
val currentRoute = navBackStackEntry?.destination?.route ?: AppRoutes.SEARCH
|
||||
val currentBackStackEntry by navController.currentBackStackEntryAsState()
|
||||
val currentRoute = currentBackStackEntry?.destination?.route
|
||||
|
||||
// Face cache status
|
||||
val needsFaceCache by mainViewModel.needsFaceCachePopulation.collectAsState()
|
||||
val unscannedCount by mainViewModel.unscannedPhotoCount.collectAsState()
|
||||
// Face cache prompt dialog state
|
||||
val needsFaceCachePopulation by viewModel.needsFaceCachePopulation.collectAsState()
|
||||
val unscannedPhotoCount by viewModel.unscannedPhotoCount.collectAsState()
|
||||
|
||||
// Show face cache prompt dialog if needed
|
||||
if (needsFaceCache && unscannedCount > 0) {
|
||||
FaceCachePromptDialog(
|
||||
unscannedPhotoCount = unscannedCount,
|
||||
onDismiss = { mainViewModel.dismissFaceCachePrompt() },
|
||||
onScanNow = {
|
||||
mainViewModel.dismissFaceCachePrompt()
|
||||
// Navigate to Photo Utilities
|
||||
navController.navigate(AppRoutes.UTILITIES) {
|
||||
launchSingleTop = true
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
// ✅ CRITICAL FIX: DISCOVER is NOT in this list!
|
||||
// These screens handle their own TopAppBar/navigation
|
||||
val screensWithOwnTopBar = setOf(
|
||||
AppRoutes.IMAGE_DETAIL,
|
||||
AppRoutes.TRAINING_SCREEN,
|
||||
AppRoutes.CROP_SCREEN
|
||||
)
|
||||
|
||||
ModalNavigationDrawer(
|
||||
drawerState = drawerState,
|
||||
@@ -60,133 +52,86 @@ fun MainScreen(
|
||||
onDestinationClicked = { route ->
|
||||
scope.launch {
|
||||
drawerState.close()
|
||||
if (route != currentRoute) {
|
||||
navController.navigate(route) {
|
||||
launchSingleTop = true
|
||||
}
|
||||
}
|
||||
navController.navigate(route) {
|
||||
popUpTo(navController.graph.startDestinationId) {
|
||||
saveState = true
|
||||
}
|
||||
launchSingleTop = true
|
||||
restoreState = true
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
}
|
||||
) {
|
||||
// CRITICAL: Some screens manage their own TopAppBar
|
||||
// Hide MainScreen's TopAppBar for these routes to prevent double headers
|
||||
val screensWithOwnTopBar = setOf(
|
||||
AppRoutes.TRAINING_PHOTO_SELECTOR, // Has its own TopAppBar with subtitle
|
||||
"album/", // Album views have their own TopAppBar (prefix match)
|
||||
AppRoutes.IMAGE_DETAIL // Image detail has its own TopAppBar
|
||||
)
|
||||
|
||||
// Check if current route starts with any excluded pattern
|
||||
val showTopBar = screensWithOwnTopBar.none { currentRoute.startsWith(it) }
|
||||
|
||||
Scaffold(
|
||||
topBar = {
|
||||
if (showTopBar) {
|
||||
// ✅ Show TopAppBar for ALL screens except those with their own
|
||||
if (currentRoute !in screensWithOwnTopBar) {
|
||||
TopAppBar(
|
||||
title = {
|
||||
Column {
|
||||
Text(
|
||||
text = getScreenTitle(currentRoute),
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
getScreenSubtitle(currentRoute)?.let { subtitle ->
|
||||
Text(
|
||||
text = subtitle,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
Text(
|
||||
text = when (currentRoute) {
|
||||
AppRoutes.SEARCH -> "Search"
|
||||
AppRoutes.EXPLORE -> "Explore"
|
||||
AppRoutes.COLLECTIONS -> "Collections"
|
||||
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
|
||||
AppRoutes.INVENTORY -> "People"
|
||||
AppRoutes.TRAIN -> "Train Model"
|
||||
AppRoutes.TAGS -> "Tags"
|
||||
AppRoutes.UTILITIES -> "Utilities"
|
||||
AppRoutes.SETTINGS -> "Settings"
|
||||
AppRoutes.MODELS -> "AI Models"
|
||||
else -> {
|
||||
// Handle dynamic routes like album/{type}/{id}
|
||||
if (currentRoute?.startsWith("album/") == true) {
|
||||
"Album"
|
||||
} else {
|
||||
"SherpAI"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
navigationIcon = {
|
||||
IconButton(
|
||||
onClick = { scope.launch { drawerState.open() } }
|
||||
) {
|
||||
IconButton(onClick = {
|
||||
scope.launch {
|
||||
drawerState.open()
|
||||
}
|
||||
}) {
|
||||
Icon(
|
||||
Icons.Default.Menu,
|
||||
contentDescription = "Open Menu",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
imageVector = Icons.Default.Menu,
|
||||
contentDescription = "Open menu"
|
||||
)
|
||||
}
|
||||
},
|
||||
actions = {
|
||||
// Dynamic actions based on current screen
|
||||
when (currentRoute) {
|
||||
AppRoutes.SEARCH -> {
|
||||
IconButton(onClick = { /* TODO: Open filter dialog */ }) {
|
||||
Icon(
|
||||
Icons.Default.FilterList,
|
||||
contentDescription = "Filter",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
AppRoutes.INVENTORY -> {
|
||||
IconButton(onClick = {
|
||||
navController.navigate(AppRoutes.TRAIN)
|
||||
}) {
|
||||
Icon(
|
||||
Icons.Default.PersonAdd,
|
||||
contentDescription = "Add Person",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
colors = TopAppBarDefaults.topAppBarColors(
|
||||
containerColor = MaterialTheme.colorScheme.surface,
|
||||
titleContentColor = MaterialTheme.colorScheme.onSurface,
|
||||
navigationIconContentColor = MaterialTheme.colorScheme.primary,
|
||||
actionIconContentColor = MaterialTheme.colorScheme.primary
|
||||
containerColor = MaterialTheme.colorScheme.primaryContainer,
|
||||
titleContentColor = MaterialTheme.colorScheme.onPrimaryContainer,
|
||||
navigationIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer,
|
||||
actionIconContentColor = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
) { paddingValues ->
|
||||
// ✅ Use YOUR existing AppNavHost - it already has all the screens defined!
|
||||
AppNavHost(
|
||||
navController = navController,
|
||||
modifier = Modifier.padding(paddingValues)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get human-readable screen title
|
||||
*/
|
||||
private fun getScreenTitle(route: String): String {
|
||||
return when (route) {
|
||||
AppRoutes.SEARCH -> "Search"
|
||||
AppRoutes.EXPLORE -> "Explore"
|
||||
AppRoutes.COLLECTIONS -> "Collections"
|
||||
AppRoutes.DISCOVER -> "Discover People"
|
||||
AppRoutes.INVENTORY -> "People"
|
||||
AppRoutes.TRAIN -> "Train New Person"
|
||||
AppRoutes.MODELS -> "AI Models"
|
||||
AppRoutes.TAGS -> "Tag Management"
|
||||
AppRoutes.UTILITIES -> "Photo Util."
|
||||
AppRoutes.SETTINGS -> "Settings"
|
||||
else -> "SherpAI"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get subtitle for screens that need context
|
||||
*/
|
||||
private fun getScreenSubtitle(route: String): String? {
|
||||
return when (route) {
|
||||
AppRoutes.SEARCH -> "Find photos by tags, people, or date"
|
||||
AppRoutes.EXPLORE -> "Browse your collection"
|
||||
AppRoutes.COLLECTIONS -> "Your photo collections"
|
||||
AppRoutes.DISCOVER -> "Auto-find faces in your library"
|
||||
AppRoutes.INVENTORY -> "Trained face models"
|
||||
AppRoutes.TRAIN -> "Add a new person to recognize"
|
||||
AppRoutes.TAGS -> "Organize your photo collection"
|
||||
AppRoutes.UTILITIES -> "Tools for managing collection"
|
||||
else -> null
|
||||
// ✅ Face cache prompt dialog (shows on app launch if needed)
|
||||
if (needsFaceCachePopulation) {
|
||||
FaceCachePromptDialog(
|
||||
unscannedPhotoCount = unscannedPhotoCount,
|
||||
onDismiss = { viewModel.dismissFaceCachePrompt() },
|
||||
onScanNow = {
|
||||
viewModel.dismissFaceCachePrompt()
|
||||
navController.navigate(AppRoutes.UTILITIES)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user