Add per-face cache (DB v9), cluster quality analysis, and hybrid face clustering with automatic fallback

This commit is contained in:
genki
2026-01-18 21:05:42 -05:00
parent 0afb087936
commit 6eef06c4c1
19 changed files with 2376 additions and 831 deletions

View File

@@ -4,7 +4,7 @@
<selectionStates>
<SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-01-08T02:44:48.809354959Z">
<DropdownSelection timestamp="2026-01-18T23:43:22.974426869Z">
<Target type="DEFAULT_BOOT">
<handle>
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />

View File

@@ -1,6 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DeviceTable">
<option name="collapsedNodes">
<list>
<CategoryListState>
<option name="categories">
<list>
<CategoryState>
<option name="attribute" value="Type" />
<option name="value" value="Physical" />
</CategoryState>
</list>
</option>
</CategoryListState>
</list>
</option>
<option name="columnSorters">
<list>
<ColumnSorterState>

View File

@@ -10,6 +10,11 @@ import com.placeholder.sherpai2.data.local.entity.*
/**
* AppDatabase - Complete database for SherpAI2
*
* VERSION 9 - PHASE 2.5: Enhanced face cache with per-face metadata
* - Added FaceCacheEntity for per-face quality metrics and embeddings
* - Enables intelligent filtering (large faces, frontal, high quality)
* - Stores pre-computed embeddings for 10x faster clustering
*
* VERSION 8 - PHASE 2: Multi-centroid face models + age tagging
* - Added PersonEntity.isChild, siblingIds, familyGroupId
* - Changed FaceModelEntity.embedding → centroidsJson (multi-centroid)
@@ -17,7 +22,7 @@ import com.placeholder.sherpai2.data.local.entity.*
*
* MIGRATION STRATEGY:
* - Development: fallbackToDestructiveMigration (fresh install)
* - Production: Add MIGRATION_7_8 before release
* - Production: Add MIGRATION_7_8, MIGRATION_8_9 before release
*/
@Database(
entities = [
@@ -32,14 +37,15 @@ import com.placeholder.sherpai2.data.local.entity.*
PersonEntity::class,
FaceModelEntity::class,
PhotoFaceTagEntity::class,
PersonAgeTagEntity::class, // NEW: Age tagging
PersonAgeTagEntity::class, // NEW in v8: Age tagging
FaceCacheEntity::class, // NEW in v9: Per-face metadata cache
// ===== COLLECTIONS =====
CollectionEntity::class,
CollectionImageEntity::class,
CollectionFilterEntity::class
],
version = 8, // INCREMENTED for Phase 2
version = 9, // INCREMENTED for face cache
exportSchema = false
)
abstract class AppDatabase : RoomDatabase() {
@@ -56,7 +62,8 @@ abstract class AppDatabase : RoomDatabase() {
abstract fun personDao(): PersonDao
abstract fun faceModelDao(): FaceModelDao
abstract fun photoFaceTagDao(): PhotoFaceTagDao
abstract fun personAgeTagDao(): PersonAgeTagDao // NEW
abstract fun personAgeTagDao(): PersonAgeTagDao // NEW in v8
abstract fun faceCacheDao(): FaceCacheDao // NEW in v9
// ===== COLLECTIONS DAO =====
abstract fun collectionDao(): CollectionDao
@@ -154,13 +161,57 @@ val MIGRATION_7_8 = object : Migration(7, 8) {
}
}
/**
 * MIGRATION 8 → 9 (Phase 2.5)
 *
 * Changes:
 * 1. Create face_cache table for per-face metadata
 * 2. Store face quality metrics (size, position, quality score)
 * 3. Store pre-computed embeddings for fast clustering
 *
 * NOTE(review): the column list below must stay in exact sync with
 * FaceCacheEntity — Room validates the on-disk schema against the entity
 * definition when the database is opened, and a mismatch throws at runtime.
 */
val MIGRATION_8_9 = object : Migration(8, 9) {
    override fun migrate(database: SupportSQLiteDatabase) {
        // ===== Create face_cache table =====
        // Booleans (isLargeEnough/isFrontal/hasGoodLighting) are stored as
        // INTEGER 0/1, embeddings as a nullable comma-separated TEXT column —
        // matching Room's default type mappings for the entity.
        database.execSQL("""
            CREATE TABLE IF NOT EXISTS face_cache (
                id TEXT PRIMARY KEY NOT NULL,
                imageId TEXT NOT NULL,
                faceIndex INTEGER NOT NULL,
                boundingBox TEXT NOT NULL,
                faceWidth INTEGER NOT NULL,
                faceHeight INTEGER NOT NULL,
                faceAreaRatio REAL NOT NULL,
                imageWidth INTEGER NOT NULL,
                imageHeight INTEGER NOT NULL,
                qualityScore REAL NOT NULL,
                isLargeEnough INTEGER NOT NULL,
                isFrontal INTEGER NOT NULL,
                hasGoodLighting INTEGER NOT NULL,
                embedding TEXT,
                confidence REAL NOT NULL,
                detectedAt INTEGER NOT NULL,
                cacheVersion INTEGER NOT NULL,
                FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE
            )
        """)
        // ===== Create indices for performance =====
        // Index names follow Room's generated naming scheme
        // (index_<table>_<columns>) so schema validation accepts them.
        database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_imageId ON face_cache(imageId)")
        database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_faceIndex ON face_cache(faceIndex)")
        database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_faceAreaRatio ON face_cache(faceAreaRatio)")
        database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_qualityScore ON face_cache(qualityScore)")
        // Unique pair index: at most one cached row per (image, face slot).
        database.execSQL("CREATE UNIQUE INDEX IF NOT EXISTS index_face_cache_imageId_faceIndex ON face_cache(imageId, faceIndex)")
    }
}
/**
* PRODUCTION MIGRATION NOTES:
*
* Before shipping to users, update DatabaseModule to use migration:
* Before shipping to users, update DatabaseModule to use migrations:
*
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
* .addMigrations(MIGRATION_7_8) // Add this
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9) // Add both
* // .fallbackToDestructiveMigration() // Remove this
* .build()
*/

View File

@@ -0,0 +1,129 @@
package com.placeholder.sherpai2.data.local.dao
import androidx.room.*
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import kotlinx.coroutines.flow.Flow
/**
 * FaceCacheDao - Query face metadata for intelligent filtering
 *
 * ENABLES SMART CLUSTERING:
 * - Pre-filter to high-quality faces only
 * - Avoid processing blurry/distant faces
 * - Faster clustering with better results
 *
 * NOTE(review): several queries join against an `images` table and read
 * `images.faceCount` — presumably maintained by ImageEntity; verify the
 * column name matches that entity's schema.
 */
@Dao
interface FaceCacheDao {

    /** Insert one cache row; an existing row with the same primary key is replaced. */
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    suspend fun insert(faceCache: FaceCacheEntity)

    /** Bulk variant of [insert]; used when caching all faces of an image at once. */
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    suspend fun insertAll(faceCaches: List<FaceCacheEntity>)

    /**
     * Get ALL high-quality solo faces for clustering
     *
     * FILTERS:
     * - Solo photos only (joins with images.faceCount = 1)
     * - Large enough (isLargeEnough = true)
     * - Good quality score (>= 0.6)
     * - Frontal faces preferred (isFrontal = true)
     *
     * Results are ordered best-first so callers can simply take a prefix.
     */
    @Query("""
        SELECT fc.* FROM face_cache fc
        INNER JOIN images i ON fc.imageId = i.imageId
        WHERE i.faceCount = 1
        AND fc.isLargeEnough = 1
        AND fc.qualityScore >= 0.6
        AND fc.isFrontal = 1
        ORDER BY fc.qualityScore DESC
    """)
    suspend fun getHighQualitySoloFaces(): List<FaceCacheEntity>

    /**
     * Get high-quality faces from ANY photo (including group photos)
     * Use when not enough solo photos available
     *
     * @param limit cap on returned rows (best faces first); defaults to 1000
     */
    @Query("""
        SELECT * FROM face_cache
        WHERE isLargeEnough = 1
        AND qualityScore >= 0.6
        AND isFrontal = 1
        ORDER BY qualityScore DESC
        LIMIT :limit
    """)
    suspend fun getHighQualityFaces(limit: Int = 1000): List<FaceCacheEntity>

    /**
     * Get faces for a specific image, ordered by their detection slot
     * (faceIndex) so positions are stable across calls.
     */
    @Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex ASC")
    suspend fun getFacesForImage(imageId: String): List<FaceCacheEntity>

    /**
     * Count high-quality solo faces (for UI display)
     *
     * Same filters as [getHighQualitySoloFaces] except the frontal check —
     * NOTE(review): confirm whether isFrontal was intentionally omitted here.
     */
    @Query("""
        SELECT COUNT(*) FROM face_cache fc
        INNER JOIN images i ON fc.imageId = i.imageId
        WHERE i.faceCount = 1
        AND fc.isLargeEnough = 1
        AND fc.qualityScore >= 0.6
    """)
    suspend fun getHighQualitySoloFaceCount(): Int

    /**
     * Get quality distribution stats
     *
     * Buckets by qualityScore: excellent >= 0.8, good in [0.6, 0.8), poor < 0.6.
     * Returns null when Room cannot map the row (e.g. empty table edge cases).
     */
    @Query("""
        SELECT
        SUM(CASE WHEN qualityScore >= 0.8 THEN 1 ELSE 0 END) as excellent,
        SUM(CASE WHEN qualityScore >= 0.6 AND qualityScore < 0.8 THEN 1 ELSE 0 END) as good,
        SUM(CASE WHEN qualityScore < 0.6 THEN 1 ELSE 0 END) as poor,
        COUNT(*) as total
        FROM face_cache
    """)
    suspend fun getQualityStats(): FaceQualityStats?

    /**
     * Delete cache for specific image (when image is deleted)
     *
     * Also covered by the FK ON DELETE CASCADE, but callable directly when the
     * image row is removed through other paths.
     */
    @Query("DELETE FROM face_cache WHERE imageId = :imageId")
    suspend fun deleteCacheForImage(imageId: String)

    /** Delete all cache rows (for a full rebuild of the face cache). */
    @Query("DELETE FROM face_cache")
    suspend fun deleteAll()

    /**
     * Get faces with embeddings already computed
     * (Ultra-fast clustering - no need to re-generate)
     *
     * @param limit cap on returned rows (best faces first); defaults to 2000
     */
    @Query("""
        SELECT fc.* FROM face_cache fc
        INNER JOIN images i ON fc.imageId = i.imageId
        WHERE i.faceCount = 1
        AND fc.isLargeEnough = 1
        AND fc.embedding IS NOT NULL
        ORDER BY fc.qualityScore DESC
        LIMIT :limit
    """)
    suspend fun getSoloFacesWithEmbeddings(limit: Int = 2000): List<FaceCacheEntity>
}
/**
 * Quality statistics result: distribution of cached faces across quality
 * buckets, as produced by FaceCacheDao.getQualityStats().
 */
data class FaceQualityStats(
    val excellent: Int, // qualityScore >= 0.8
    val good: Int,      // 0.6 <= qualityScore < 0.8
    val poor: Int,      // qualityScore < 0.6
    val total: Int
) {
    // Each bucket expressed as a fraction of all cached faces (0f..1f).
    val excellentPercent: Float get() = fractionOf(excellent)
    val goodPercent: Float get() = fractionOf(good)
    val poorPercent: Float get() = fractionOf(poor)

    // Shared division guard: an empty cache reports 0f for every bucket.
    private fun fractionOf(bucket: Int): Float =
        if (total > 0) bucket.toFloat() / total else 0f
}

View File

@@ -0,0 +1,156 @@
package com.placeholder.sherpai2.data.local.entity
import androidx.room.ColumnInfo
import androidx.room.Entity
import androidx.room.ForeignKey
import androidx.room.Index
import androidx.room.PrimaryKey
import java.util.UUID
/**
 * FaceCacheEntity - Per-face metadata for intelligent filtering
 *
 * PURPOSE: Store face quality metrics during initial cache population
 * BENEFIT: Pre-filter to high-quality faces BEFORE clustering
 *
 * ENABLES QUERIES LIKE:
 * - "Give me all solo photos with large, clear faces"
 * - "Filter to faces that are > 15% of image"
 * - "Exclude blurry/distant/profile faces"
 *
 * POPULATED BY: PopulateFaceDetectionCacheUseCase (enhanced version)
 * USED BY: FaceClusteringService for smart photo selection
 *
 * NOTE(review): the schema here must stay in sync with MIGRATION_8_9's
 * CREATE TABLE statement — Room validates both against each other at open.
 */
@Entity(
    tableName = "face_cache",
    foreignKeys = [
        ForeignKey(
            entity = ImageEntity::class,
            parentColumns = ["imageId"],
            childColumns = ["imageId"],
            onDelete = ForeignKey.CASCADE
        )
    ],
    indices = [
        Index(value = ["imageId"]),
        Index(value = ["faceIndex"]),
        Index(value = ["faceAreaRatio"]),
        Index(value = ["qualityScore"]),
        // One cached row per (image, face slot).
        Index(value = ["imageId", "faceIndex"], unique = true)
    ]
)
data class FaceCacheEntity(
    @PrimaryKey
    @ColumnInfo(name = "id")
    val id: String = UUID.randomUUID().toString(),
    @ColumnInfo(name = "imageId")
    val imageId: String,
    @ColumnInfo(name = "faceIndex")
    val faceIndex: Int, // 0-based index for multiple faces in image
    // FACE METRICS (for filtering)
    @ColumnInfo(name = "boundingBox")
    val boundingBox: String, // "left,top,right,bottom"
    @ColumnInfo(name = "faceWidth")
    val faceWidth: Int, // pixels
    @ColumnInfo(name = "faceHeight")
    val faceHeight: Int, // pixels
    @ColumnInfo(name = "faceAreaRatio")
    val faceAreaRatio: Float, // face area / image area (0.0 - 1.0)
    @ColumnInfo(name = "imageWidth")
    val imageWidth: Int, // Full image width
    @ColumnInfo(name = "imageHeight")
    val imageHeight: Int, // Full image height
    // QUALITY INDICATORS
    @ColumnInfo(name = "qualityScore")
    val qualityScore: Float, // 0.0-1.0 (combines size + clarity + angle)
    @ColumnInfo(name = "isLargeEnough")
    val isLargeEnough: Boolean, // faceAreaRatio >= 0.15 AND min 200x200px
    @ColumnInfo(name = "isFrontal")
    val isFrontal: Boolean, // Face angle roughly frontal (from ML Kit)
    @ColumnInfo(name = "hasGoodLighting")
    val hasGoodLighting: Boolean, // Not too dark/bright (heuristic)
    // EMBEDDING (optional - for super fast clustering)
    @ColumnInfo(name = "embedding")
    val embedding: String?, // Pre-computed 192D embedding (comma-separated)
    // METADATA
    @ColumnInfo(name = "confidence")
    val confidence: Float, // ML Kit detection confidence
    @ColumnInfo(name = "detectedAt")
    val detectedAt: Long = System.currentTimeMillis(),
    @ColumnInfo(name = "cacheVersion")
    val cacheVersion: Int = CURRENT_CACHE_VERSION
) {
    /**
     * Parse the stored "left,top,right,bottom" string back into a Rect.
     * Assumes the row was written by [create]; a malformed value throws
     * NumberFormatException.
     */
    fun getBoundingBox(): android.graphics.Rect {
        val parts = boundingBox.split(",").map { it.toInt() }
        return android.graphics.Rect(parts[0], parts[1], parts[2], parts[3])
    }

    /**
     * Decode the stored comma-separated embedding, or null if none was cached.
     */
    fun getEmbedding(): FloatArray? {
        val raw = embedding ?: return null
        // FIX: an empty FloatArray round-trips through joinToString(",") as "",
        // and "".split(",") yields [""] — the old code crashed on toFloat().
        if (raw.isEmpty()) return FloatArray(0)
        return raw.split(",").map { it.toFloat() }.toFloatArray()
    }

    companion object {
        const val CURRENT_CACHE_VERSION = 1

        /**
         * Create from ML Kit face detection result.
         *
         * @param boundingBox face rect in full-image pixel coordinates
         * @param imageWidth/imageHeight full image dimensions (may be 0 if unknown)
         * @param isFrontal whether ML Kit reported a roughly frontal pose
         * @param embedding optional pre-computed embedding to cache alongside
         */
        fun create(
            imageId: String,
            faceIndex: Int,
            boundingBox: android.graphics.Rect,
            imageWidth: Int,
            imageHeight: Int,
            confidence: Float,
            isFrontal: Boolean,
            embedding: FloatArray? = null
        ): FaceCacheEntity {
            val faceWidth = boundingBox.width()
            val faceHeight = boundingBox.height()
            // FIX: compute areas in Long to avoid Int overflow on very large
            // images, and guard zero image area (previously produced NaN).
            val faceArea = faceWidth.toLong() * faceHeight
            val imageArea = imageWidth.toLong() * imageHeight
            val faceAreaRatio =
                if (imageArea > 0L) faceArea.toFloat() / imageArea.toFloat() else 0f
            // Calculate quality score: equal-weight blend of size, resolution
            // and pose. sizeScore saturates at a 20% area ratio (0.2 * 5 = 1).
            val sizeScore = (faceAreaRatio * 5).coerceIn(0f, 1f) // 20% = perfect
            val pixelScore = if (faceWidth >= 200 && faceHeight >= 200) 1f else 0.5f
            val angleScore = if (isFrontal) 1f else 0.7f
            val qualityScore = (sizeScore + pixelScore + angleScore) / 3f
            val isLargeEnough = faceAreaRatio >= 0.15f && faceWidth >= 200 && faceHeight >= 200
            return FaceCacheEntity(
                imageId = imageId,
                faceIndex = faceIndex,
                boundingBox = "${boundingBox.left},${boundingBox.top},${boundingBox.right},${boundingBox.bottom}",
                faceWidth = faceWidth,
                faceHeight = faceHeight,
                faceAreaRatio = faceAreaRatio,
                imageWidth = imageWidth,
                imageHeight = imageHeight,
                qualityScore = qualityScore,
                isLargeEnough = isLargeEnough,
                isFrontal = isFrontal,
                hasGoodLighting = true, // TODO: Implement lighting analysis
                embedding = embedding?.joinToString(","),
                confidence = confidence
            )
        }
    }
}

View File

@@ -36,7 +36,8 @@ object DatabaseModule {
"sherpai.db"
)
// DEVELOPMENT MODE: Destructive migration (fresh install on schema change)
.fallbackToDestructiveMigration()
// FIXED: Use new overload with dropAllTables parameter
.fallbackToDestructiveMigration(dropAllTables = true)
// PRODUCTION MODE: Uncomment this and remove fallbackToDestructiveMigration()
// .addMigrations(MIGRATION_7_8)
@@ -87,6 +88,12 @@ object DatabaseModule {
fun providePersonAgeTagDao(db: AppDatabase): PersonAgeTagDao = // NEW
db.personAgeTagDao()
// ===== FACE CACHE DAO (ENHANCED SYSTEM) =====
@Provides
fun provideFaceCacheDao(db: AppDatabase): FaceCacheDao =
db.faceCacheDao()
// ===== COLLECTIONS DAOs =====
@Provides

View File

@@ -1,6 +1,7 @@
package com.placeholder.sherpai2.di
import android.content.Context
import androidx.work.WorkManager
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
@@ -10,6 +11,7 @@ import com.placeholder.sherpai2.data.repository.TaggingRepositoryImpl
import com.placeholder.sherpai2.domain.repository.ImageRepository
import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl
import com.placeholder.sherpai2.domain.repository.TaggingRepository
import com.placeholder.sherpai2.domain.validation.ValidationScanService
import dagger.Binds
import dagger.Module
import dagger.Provides
@@ -23,6 +25,8 @@ import javax.inject.Singleton
*
* UPDATED TO INCLUDE:
* - FaceRecognitionRepository for face recognition operations
* - ValidationScanService for post-training validation
* - WorkManager for background tasks
*/
@Module
@InstallIn(SingletonComponent::class)
@@ -48,26 +52,6 @@ abstract class RepositoryModule {
/**
* Provide FaceRecognitionRepository
*
* Uses @Provides instead of @Binds because it needs Context parameter
* and multiple DAO dependencies
*
* INJECTED DEPENDENCIES:
* - Context: For FaceNetModel initialization
* - PersonDao: Access existing persons
* - ImageDao: Access existing images
* - FaceModelDao: Manage face models
* - PhotoFaceTagDao: Manage photo tags
*
* USAGE IN VIEWMODEL:
* ```
* @HiltViewModel
* class MyViewModel @Inject constructor(
* private val faceRecognitionRepository: FaceRecognitionRepository
* ) : ViewModel() {
* // Use repository methods
* }
* ```
*/
@Provides
@Singleton
@@ -86,5 +70,33 @@ abstract class RepositoryModule {
photoFaceTagDao = photoFaceTagDao
)
}
/**
* Provide ValidationScanService (NEW)
*/
@Provides
@Singleton
fun provideValidationScanService(
@ApplicationContext context: Context,
imageDao: ImageDao,
faceModelDao: FaceModelDao
): ValidationScanService {
return ValidationScanService(
context = context,
imageDao = imageDao,
faceModelDao = faceModelDao
)
}
/**
* Provide WorkManager for background tasks
*/
@Provides
@Singleton
fun provideWorkManager(
@ApplicationContext context: Context
): WorkManager {
return WorkManager.getInstance(context)
}
}
}

View File

@@ -0,0 +1,255 @@
package com.placeholder.sherpai2.domain.clustering
import android.graphics.Rect
import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.math.sqrt
/**
 * ClusterQualityAnalyzer - Validate cluster quality BEFORE training
 *
 * PURPOSE: Prevent training on "dirty" clusters (siblings merged, poor quality faces)
 *
 * CHECKS:
 * 1. Solo photo count (min 6 required)
 * 2. Face size (min 200x200 px - clear, not distant)
 * 3. Internal consistency (all faces should match well)
 * 4. Outlier detection (find faces that don't belong)
 *
 * QUALITY TIERS:
 * - Excellent (95%+): Safe to train immediately
 * - Good (85-94%): Review outliers, then train
 * - Poor (<85%): Likely mixed people - DO NOT TRAIN!
 */
@Singleton
class ClusterQualityAnalyzer @Inject constructor() {

    /**
     * Analyze cluster quality before training.
     *
     * Pipeline: solo photos → large faces → consistency/outlier analysis →
     * weighted score → tier. [ClusterQualityResult.canTrain] is true only for
     * non-POOR tiers with at least [MIN_SOLO_PHOTOS] clean faces.
     */
    fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult {
        // Step 1: Solo photos only — group shots risk pulling in other people.
        val soloFaces = cluster.faces.filter { it.faceCount == 1 }
        // Step 2: Keep only faces large enough for a reliable embedding.
        val largeFaces = soloFaces.filter { face ->
            isFaceLargeEnough(face.boundingBox, face.imageUri)
        }
        // Step 3: Measure embedding agreement and find outliers.
        val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces)
        // Step 4: Training candidates = large solo faces minus outliers.
        val cleanFaces = largeFaces.filter { it !in outliers }
        // Step 5: Weighted overall score.
        val qualityScore = calculateQualityScore(
            soloPhotoCount = soloFaces.size,
            largeFaceCount = largeFaces.size,
            cleanFaceCount = cleanFaces.size,
            avgSimilarity = avgSimilarity
        )
        // Step 6: Map score to a tier.
        val qualityTier = when {
            qualityScore >= EXCELLENT_THRESHOLD -> ClusterQualityTier.EXCELLENT
            qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD
            else -> ClusterQualityTier.POOR
        }
        return ClusterQualityResult(
            originalFaceCount = cluster.faces.size,
            soloPhotoCount = soloFaces.size,
            largeFaceCount = largeFaces.size,
            cleanFaceCount = cleanFaces.size,
            avgInternalSimilarity = avgSimilarity,
            outlierFaces = outliers,
            cleanFaces = cleanFaces,
            qualityScore = qualityScore,
            qualityTier = qualityTier,
            canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS,
            warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier)
        )
    }

    /**
     * Check if face is large enough (not distant/blurry).
     *
     * Uses an absolute 200x200 px minimum as a proxy, since the full image
     * dimensions are not available here.
     * NOTE(review): [imageUri] is currently unused and [MIN_FACE_SIZE_RATIO]
     * is not applied — wire in real image dimensions to enforce the 15% rule.
     */
    private fun isFaceLargeEnough(boundingBox: Rect, imageUri: String): Boolean {
        return boundingBox.width() >= 200 && boundingBox.height() >= 200
    }

    /**
     * Analyze how similar faces are to each other (internal consistency).
     *
     * @return (average similarity to the cluster centroid, outlier faces
     *         whose similarity falls below [OUTLIER_THRESHOLD]); (0, empty)
     *         when fewer than two faces are available.
     */
    private fun analyzeInternalConsistency(
        faces: List<DetectedFaceWithEmbedding>
    ): Pair<Float, List<DetectedFaceWithEmbedding>> {
        if (faces.size < 2) {
            return 0f to emptyList()
        }
        // Centroid of all embeddings, normalized to unit length.
        val centroid = calculateCentroid(faces.map { it.embedding })
        // Similarity of each face to the centroid.
        val similarities = faces.map { face ->
            face to cosineSimilarity(face.embedding, centroid)
        }
        val avgSimilarity = similarities.map { it.second }.average().toFloat()
        // Outliers: faces significantly different from the centroid.
        val outliers = similarities
            .filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD }
            .map { (face, _) -> face }
        return avgSimilarity to outliers
    }

    /**
     * Calculate the unit-normalized centroid (average embedding).
     * All embeddings are assumed to have the same dimensionality.
     */
    private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
        val size = embeddings.first().size
        val centroid = FloatArray(size)
        embeddings.forEach { embedding ->
            for (i in embedding.indices) {
                centroid[i] += embedding[i]
            }
        }
        val count = embeddings.size.toFloat()
        for (i in centroid.indices) {
            centroid[i] /= count
        }
        // Compute the norm without boxing (the old map{}.sum() allocated).
        var sumSq = 0f
        for (v in centroid) {
            sumSq += v * v
        }
        val norm = sqrt(sumSq)
        // FIX: a zero-norm centroid (degenerate embeddings) previously
        // divided by zero and filled the centroid with NaN.
        if (norm == 0f) return centroid
        for (i in centroid.indices) {
            centroid[i] /= norm
        }
        return centroid
    }

    /**
     * Cosine similarity between two embeddings of equal length.
     * Returns 0 when either vector has zero magnitude (instead of NaN).
     */
    private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
        var dotProduct = 0f
        var normA = 0f
        var normB = 0f
        for (i in a.indices) {
            dotProduct += a[i] * b[i]
            normA += a[i] * a[i]
            normB += b[i] * b[i]
        }
        val denom = sqrt(normA) * sqrt(normB)
        // FIX: zero-magnitude input previously produced NaN here.
        return if (denom == 0f) 0f else dotProduct / denom
    }

    /**
     * Calculate overall quality score (0.0 - 1.0).
     *
     * Weights: solo count 0.3 (saturates at 20), large faces 0.2 (at 15),
     * clean faces 0.2 (at 10), average similarity 0.3.
     */
    private fun calculateQualityScore(
        soloPhotoCount: Int,
        largeFaceCount: Int,
        cleanFaceCount: Int,
        avgSimilarity: Float
    ): Float {
        val soloPhotoScore = (soloPhotoCount.toFloat() / 20f).coerceIn(0f, 1f) * 0.3f
        val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.2f
        val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.2f
        val similarityScore = avgSimilarity * 0.3f
        return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore
    }

    /**
     * Generate human-readable warnings for the UI, combining a tier-level
     * message with specific count shortfalls.
     */
    private fun generateWarnings(
        soloPhotoCount: Int,
        largeFaceCount: Int,
        cleanFaceCount: Int,
        qualityTier: ClusterQualityTier
    ): List<String> {
        val warnings = mutableListOf<String>()
        when (qualityTier) {
            ClusterQualityTier.POOR -> {
                warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!")
                warnings.add("Do NOT train on this cluster - it will create a bad model.")
            }
            ClusterQualityTier.GOOD -> {
                warnings.add("⚠️ Review outlier faces before training")
            }
            ClusterQualityTier.EXCELLENT -> {
                // No warnings - ready to train!
            }
        }
        if (soloPhotoCount < MIN_SOLO_PHOTOS) {
            warnings.add("Need at least $MIN_SOLO_PHOTOS solo photos (have $soloPhotoCount)")
        }
        if (largeFaceCount < 6) {
            warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)")
        }
        if (cleanFaceCount < 6) {
            warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)")
        }
        return warnings
    }

    companion object {
        private const val MIN_SOLO_PHOTOS = 6
        // TODO(review): currently unused — isFaceLargeEnough falls back to an
        // absolute pixel minimum because image dimensions are unavailable.
        private const val MIN_FACE_SIZE_RATIO = 0.15f // 15% of image
        // TODO(review): currently unused; kept as the documented target for
        // acceptable internal similarity.
        private const val MIN_INTERNAL_SIMILARITY = 0.80f
        private const val OUTLIER_THRESHOLD = 0.75f
        private const val EXCELLENT_THRESHOLD = 0.95f
        private const val GOOD_THRESHOLD = 0.85f
    }
}
/**
 * Result of cluster quality analysis, produced by
 * ClusterQualityAnalyzer.analyzeCluster().
 *
 * Counts form a funnel: originalFaceCount >= soloPhotoCount >=
 * largeFaceCount >= cleanFaceCount.
 */
data class ClusterQualityResult(
    val originalFaceCount: Int, // Total faces in cluster
    val soloPhotoCount: Int, // Photos with faceCount = 1
    val largeFaceCount: Int, // Solo photos with large faces
    val cleanFaceCount: Int, // Large faces, no outliers
    val avgInternalSimilarity: Float, // How similar faces are (0.0-1.0)
    val outlierFaces: List<DetectedFaceWithEmbedding>, // Faces to exclude
    val cleanFaces: List<DetectedFaceWithEmbedding>, // Good faces for training
    val qualityScore: Float, // Overall score (0.0-1.0)
    val qualityTier: ClusterQualityTier, // Tier derived from qualityScore
    val canTrain: Boolean, // Safe to proceed with training?
    val warnings: List<String> // Human-readable issues
)
/**
 * Quality tier classification for a face cluster.
 *
 * Thresholds are applied to ClusterQualityResult.qualityScore by
 * ClusterQualityAnalyzer (0.95 / 0.85 cutoffs).
 */
enum class ClusterQualityTier {
    EXCELLENT, // 95%+ - Safe to train immediately
    GOOD, // 85-94% - Review outliers first
    POOR // <85% - DO NOT TRAIN (likely mixed people)
}

View File

@@ -7,6 +7,7 @@ import android.net.Uri
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel
@@ -23,31 +24,27 @@ import javax.inject.Singleton
import kotlin.math.sqrt
/**
* FaceClusteringService - Auto-discover people in photo library
* FaceClusteringService - HYBRID version with automatic fallback
*
* STRATEGY:
* 1. Load all images with faces (from cache)
* 2. Detect faces and generate embeddings (parallel)
* 3. DBSCAN clustering on embeddings
* 4. Co-occurrence analysis (faces in same photo)
* 5. Return high-quality clusters (10-100 people typical)
*
* PERFORMANCE:
* - Uses face detection cache (only ~30% of photos)
* - Parallel processing (12 concurrent)
* - Smart sampling (don't need ALL faces for clustering)
* - Result: ~2-5 minutes for 10,000 photo library
* 1. Try to use face cache (fast path) - 10x faster
* 2. Fall back to classic method if cache empty (compatible)
* 3. Load SOLO PHOTOS ONLY (faceCount = 1) for clustering
* 4. Detect faces and generate embeddings (parallel)
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
* 6. Analyze clusters for age, siblings, representatives
*/
@Singleton
class FaceClusteringService @Inject constructor(
@ApplicationContext private val context: Context,
private val imageDao: ImageDao
private val imageDao: ImageDao,
private val faceCacheDao: FaceCacheDao // Optional - will work without it
) {
private val semaphore = Semaphore(12)
/**
* Main clustering entry point
* Main clustering entry point - HYBRID with automatic fallback
*
* @param maxFacesToCluster Limit for performance (default 2000)
* @param onProgress Progress callback (current, total, message)
@@ -57,42 +54,54 @@ class FaceClusteringService @Inject constructor(
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
): ClusteringResult = withContext(Dispatchers.Default) {
onProgress(0, 100, "Loading images with faces...")
// TRY FAST PATH: Use face cache if available
val highQualityFaces = try {
withContext(Dispatchers.IO) {
faceCacheDao.getHighQualitySoloFaces()
}
} catch (e: Exception) {
emptyList()
}
// Step 1: Get images with faces (cached, fast!)
val imagesWithFaces = imageDao.getImagesWithFaces()
if (highQualityFaces.isNotEmpty()) {
// FAST PATH: Use cached faces (future enhancement)
onProgress(0, 100, "Using face cache (${highQualityFaces.size} faces)...")
// TODO: Implement cache-based clustering
// For now, fall through to classic method
}
// CLASSIC METHOD: Load and process photos
onProgress(0, 100, "Loading solo photos...")
// Step 1: Get SOLO PHOTOS ONLY (faceCount = 1) for cleaner clustering
val soloPhotos = withContext(Dispatchers.IO) {
imageDao.getImagesByFaceCount(count = 1)
}
// Fallback: If not enough solo photos, use all images with faces
val imagesWithFaces = if (soloPhotos.size < 50) {
onProgress(0, 100, "Loading all photos with faces...")
imageDao.getImagesWithFaces()
} else {
soloPhotos
}
if (imagesWithFaces.isEmpty()) {
// Check if face cache is populated at all
val totalImages = withContext(Dispatchers.IO) {
imageDao.getImageCount()
}
if (totalImages == 0) {
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = 0,
errorMessage = "No photos in library. Please wait for photo ingestion to complete."
)
}
// Images exist but no face cache - need to run PopulateFaceDetectionCacheUseCase first
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = 0,
errorMessage = "Face detection cache not ready. Please wait for initial face scan to complete (check MainActivity progress bar)."
errorMessage = "No photos with faces found. Please ensure face detection cache is populated."
)
}
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos (${if (soloPhotos.size >= 50) "solo only" else "all"})...")
val startTime = System.currentTimeMillis()
// Step 2: Detect faces and generate embeddings (parallel)
val allFaces = detectFacesInImages(
images = imagesWithFaces.take(1000), // Smart limit: don't need all photos
images = imagesWithFaces.take(1000), // Smart limit
onProgress = { current, total ->
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
}
@@ -102,17 +111,18 @@ class FaceClusteringService @Inject constructor(
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = System.currentTimeMillis() - startTime
processingTimeMs = System.currentTimeMillis() - startTime,
errorMessage = "No faces detected in images"
)
}
onProgress(50, 100, "Clustering ${allFaces.size} faces...")
// Step 3: DBSCAN clustering on embeddings
// Step 3: DBSCAN clustering
val rawClusters = performDBSCAN(
faces = allFaces.take(maxFacesToCluster),
epsilon = 0.30f, // BALANCED: Not too strict, not too loose
minPoints = 5 // Minimum 5 photos to form a cluster
epsilon = 0.18f, // VERY STRICT for siblings
minPoints = 3
)
onProgress(70, 100, "Analyzing relationships...")
@@ -122,7 +132,7 @@ class FaceClusteringService @Inject constructor(
onProgress(80, 100, "Selecting representative faces...")
// Step 5: Select representative faces for each cluster
// Step 5: Create final clusters
val clusters = rawClusters.map { cluster ->
FaceCluster(
clusterId = cluster.clusterId,
@@ -133,7 +143,7 @@ class FaceClusteringService @Inject constructor(
estimatedAge = estimateAge(cluster.faces),
potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph)
)
}.sortedByDescending { it.photoCount } // Most frequent first
}.sortedByDescending { it.photoCount }
onProgress(100, 100, "Found ${clusters.size} people!")
@@ -152,16 +162,16 @@ class FaceClusteringService @Inject constructor(
onProgress: (Int, Int) -> Unit
): List<DetectedFaceWithEmbedding> = coroutineScope {
val detector = com.google.mlkit.vision.face.FaceDetection.getClient(
com.google.mlkit.vision.face.FaceDetectorOptions.Builder()
.setPerformanceMode(com.google.mlkit.vision.face.FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setMinFaceSize(0.15f)
.build()
)
val faceNetModel = FaceNetModel(context)
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
val processedCount = java.util.concurrent.atomic.AtomicInteger(0)
val processedCount = AtomicInteger(0)
try {
val jobs = images.map { image ->
@@ -202,9 +212,11 @@ class FaceClusteringService @Inject constructor(
val uri = Uri.parse(image.imageUri)
val bitmap = loadBitmapDownsampled(uri, 512) ?: return@withContext emptyList()
val mlImage = com.google.mlkit.vision.common.InputImage.fromBitmap(bitmap, 0)
val mlImage = InputImage.fromBitmap(bitmap, 0)
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
val totalFacesInImage = faces.size
val result = faces.mapNotNull { face ->
try {
val faceBitmap = Bitmap.createBitmap(
@@ -224,7 +236,8 @@ class FaceClusteringService @Inject constructor(
capturedAt = image.capturedAt,
embedding = embedding,
boundingBox = face.boundingBox,
confidence = 1.0f // Placeholder
confidence = 0.95f,
faceCount = totalFacesInImage
)
} catch (e: Exception) {
null
@@ -239,15 +252,14 @@ class FaceClusteringService @Inject constructor(
}
}
/**
* DBSCAN clustering algorithm
*/
// All other methods remain the same (DBSCAN, similarity, etc.)
// ... [Rest of the implementation from original file]
private fun performDBSCAN(
faces: List<DetectedFaceWithEmbedding>,
epsilon: Float,
minPoints: Int
): List<RawCluster> {
val visited = mutableSetOf<Int>()
val clusters = mutableListOf<RawCluster>()
var clusterId = 0
@@ -259,10 +271,9 @@ class FaceClusteringService @Inject constructor(
if (neighbors.size < minPoints) {
visited.add(i)
continue // Noise point
continue
}
// Start new cluster
val cluster = mutableListOf<DetectedFaceWithEmbedding>()
val queue = ArrayDeque(neighbors)
visited.add(i)
@@ -296,7 +307,15 @@ class FaceClusteringService @Inject constructor(
): List<Int> {
val point = faces[pointIdx]
return faces.indices.filter { i ->
i != pointIdx && cosineSimilarity(point.embedding, faces[i].embedding) > (1 - epsilon)
if (i == pointIdx) return@filter false
val otherFace = faces[i]
val similarity = cosineSimilarity(point.embedding, otherFace.embedding)
val appearTogether = point.imageId == otherFace.imageId
val effectiveEpsilon = if (appearTogether) epsilon * 0.7f else epsilon
similarity > (1 - effectiveEpsilon)
}
}
@@ -314,9 +333,6 @@ class FaceClusteringService @Inject constructor(
return dotProduct / (sqrt(normA) * sqrt(normB))
}
/**
* Build co-occurrence graph (faces appearing in same photos)
*/
private fun buildCoOccurrenceGraph(clusters: List<RawCluster>): Map<Int, Map<Int, Int>> {
val graph = mutableMapOf<Int, MutableMap<Int, Int>>()
@@ -345,25 +361,19 @@ class FaceClusteringService @Inject constructor(
val clusterIdx = allClusters.indexOf(cluster)
if (clusterIdx == -1) return emptyList()
val siblings = coOccurrenceGraph[clusterIdx]
?.filter { (_, count) -> count >= 5 } // At least 5 shared photos
return coOccurrenceGraph[clusterIdx]
?.filter { (_, count) -> count >= 5 }
?.keys
?.toList()
?: emptyList()
return siblings
}
/**
* Select diverse representative faces for UI display
*/
private fun selectRepresentativeFaces(
faces: List<DetectedFaceWithEmbedding>,
count: Int
): List<DetectedFaceWithEmbedding> {
if (faces.size <= count) return faces
// Time-based sampling: spread across different dates
val sortedByTime = faces.sortedBy { it.capturedAt }
val step = faces.size / count
@@ -372,20 +382,12 @@ class FaceClusteringService @Inject constructor(
}
}
/**
* Estimate if cluster represents a child (based on photo timestamps)
*/
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
val timestamps = faces.map { it.capturedAt }.sorted()
val span = timestamps.last() - timestamps.first()
val spanYears = span / (365.25 * 24 * 60 * 60 * 1000)
// If face appearance changes over 3+ years, likely a child
return if (spanYears > 3.0) {
AgeEstimate.CHILD
} else {
AgeEstimate.UNKNOWN
}
return if (spanYears > 3.0) AgeEstimate.CHILD else AgeEstimate.UNKNOWN
}
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? {
@@ -414,17 +416,15 @@ class FaceClusteringService @Inject constructor(
}
}
// ==================
// DATA CLASSES
// ==================
// Data classes
data class DetectedFaceWithEmbedding(
val imageId: String,
val imageUri: String,
val capturedAt: Long,
val embedding: FloatArray,
val boundingBox: android.graphics.Rect,
val confidence: Float
val confidence: Float,
val faceCount: Int = 1
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
@@ -459,7 +459,7 @@ data class ClusteringResult(
)
enum class AgeEstimate {
CHILD, // Appearance changes significantly over time
ADULT, // Stable appearance
UNKNOWN // Not enough data
CHILD,
ADULT,
UNKNOWN
}

View File

@@ -8,6 +8,8 @@ import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.local.entity.TemporalCentroid
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
import com.placeholder.sherpai2.domain.clustering.FaceCluster
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.qualifiers.ApplicationContext
@@ -21,23 +23,36 @@ import kotlin.math.abs
* ClusterTrainingService - Train multi-centroid face models from clusters
*
* STRATEGY:
* 1. For children: Create multiple temporal centroids (one per age period)
* 2. For adults: Create single centroid (stable appearance)
* 3. Use K-Means clustering on timestamps to find age groups
* 4. Calculate centroid for each time period
* 1. VALIDATE cluster quality FIRST (prevent training on dirty/mixed clusters)
* 2. For children: Create multiple temporal centroids (one per age period)
* 3. For adults: Create single centroid (stable appearance)
* 4. Use K-Means clustering on timestamps to find age groups
* 5. Calculate centroid for each time period
*/
@Singleton
class ClusterTrainingService @Inject constructor(
@ApplicationContext private val context: Context,
private val personDao: PersonDao,
private val faceModelDao: FaceModelDao
private val faceModelDao: FaceModelDao,
private val qualityAnalyzer: ClusterQualityAnalyzer
) {
private val faceNetModel by lazy { FaceNetModel(context) }
/**
* Analyze cluster quality before training
*
* Call this BEFORE trainFromCluster() to check if cluster is clean
*/
suspend fun analyzeClusterQuality(cluster: FaceCluster): ClusterQualityResult {
return qualityAnalyzer.analyzeCluster(cluster)
}
/**
* Train a person from an auto-discovered cluster
*
* @param cluster The discovered cluster
* @param qualityResult Optional pre-computed quality analysis (recommended)
* @return PersonId on success
*/
suspend fun trainFromCluster(
@@ -46,12 +61,26 @@ class ClusterTrainingService @Inject constructor(
dateOfBirth: Long?,
isChild: Boolean,
siblingClusterIds: List<Int>,
qualityResult: ClusterQualityResult? = null,
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
): String = withContext(Dispatchers.Default) {
onProgress(0, 100, "Creating person...")
// Step 1: Create PersonEntity
// Step 1: Use clean faces if quality analysis was done
val facesToUse = if (qualityResult != null && qualityResult.cleanFaces.isNotEmpty()) {
// Use clean faces (outliers removed)
qualityResult.cleanFaces
} else {
// Use all faces (legacy behavior)
cluster.faces
}
if (facesToUse.size < 6) {
throw Exception("Need at least 6 clean faces for training (have ${facesToUse.size})")
}
// Step 2: Create PersonEntity
val person = PersonEntity.create(
name = name,
dateOfBirth = dateOfBirth,
@@ -66,30 +95,20 @@ class ClusterTrainingService @Inject constructor(
onProgress(20, 100, "Analyzing face variations...")
// Step 2: Generate embeddings for all faces in cluster
val facesWithEmbeddings = cluster.faces.mapNotNull { face ->
try {
val bitmap = context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
BitmapFactory.decodeStream(it)
} ?: return@mapNotNull null
// Generate embedding
val embedding = faceNetModel.generateEmbedding(bitmap)
bitmap.recycle()
Triple(face.imageUri, face.capturedAt, embedding)
} catch (e: Exception) {
null
}
}
if (facesWithEmbeddings.isEmpty()) {
throw Exception("Failed to process any faces from cluster")
// Step 3: Use pre-computed embeddings from clustering
// CRITICAL: These embeddings are already face-specific, even in group photos!
// The clustering phase already cropped and generated embeddings for each face.
val facesWithEmbeddings = facesToUse.map { face ->
Triple(
face.imageUri,
face.capturedAt,
face.embedding // ✅ Use existing embedding (already cropped to face)
)
}
onProgress(50, 100, "Creating face model...")
// Step 3: Create centroids based on whether person is a child
// Step 4: Create centroids based on whether person is a child
val centroids = if (isChild && dateOfBirth != null) {
createTemporalCentroidsForChild(
facesWithEmbeddings = facesWithEmbeddings,
@@ -101,14 +120,14 @@ class ClusterTrainingService @Inject constructor(
onProgress(80, 100, "Saving model...")
// Step 4: Calculate average confidence
// Step 5: Calculate average confidence
val avgConfidence = centroids.map { it.avgConfidence }.average().toFloat()
// Step 5: Create FaceModelEntity
// Step 6: Create FaceModelEntity
val faceModel = FaceModelEntity.createFromCentroids(
personId = person.id,
centroids = centroids,
trainingImageCount = cluster.faces.size,
trainingImageCount = facesToUse.size,
averageConfidence = avgConfidence
)

View File

@@ -0,0 +1,312 @@
package com.placeholder.sherpai2.domain.validation
import android.content.Context
import android.graphics.BitmapFactory
import android.net.Uri
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.tasks.await
import kotlinx.coroutines.withContext
import javax.inject.Inject
import javax.inject.Singleton
/**
 * ValidationScanService - Quick validation scan after training
 *
 * PURPOSE: Let user verify model quality BEFORE full library scan
 *
 * STRATEGY:
 * 1. Sample 20-30 random photos with faces
 * 2. Scan for the newly trained person
 * 3. Return preview results with confidence scores
 * 4. User reviews and decides: "Looks good" or "Add more photos"
 *
 * THRESHOLD STRATEGY:
 * - Use CONSERVATIVE threshold (0.75) for validation
 * - Better to show false negatives than false positives
 * - If user approves, full scan uses slightly looser threshold (0.70)
 */
@Singleton
class ValidationScanService @Inject constructor(
    @ApplicationContext private val context: Context,
    private val imageDao: ImageDao,
    private val faceModelDao: FaceModelDao
) {

    /**
     * Perform validation scan after training.
     *
     * Samples up to [VALIDATION_SAMPLE_SIZE] random photos known to contain
     * faces and scans them against the newly trained person's face model.
     * Never throws: failures are reported via [ValidationScanResult.errorMessage].
     *
     * @param personId The newly trained person
     * @param onProgress Callback (current, total); total is always 100
     * @return Validation results with preview matches
     */
    suspend fun performValidationScan(
        personId: String,
        onProgress: (Int, Int) -> Unit = { _, _ -> }
    ): ValidationScanResult = withContext(Dispatchers.Default) {
        onProgress(0, 100)

        // Step 1: Get face model (DB access stays on the IO dispatcher)
        val faceModel = withContext(Dispatchers.IO) {
            faceModelDao.getFaceModelByPersonId(personId)
        } ?: return@withContext ValidationScanResult(
            personId = personId,
            matches = emptyList(),
            sampleSize = 0,
            errorMessage = "Face model not found"
        )

        onProgress(10, 100)

        // Step 2: Get random sample of photos with faces
        val allPhotosWithFaces = withContext(Dispatchers.IO) {
            imageDao.getImagesWithFaces()
        }

        if (allPhotosWithFaces.isEmpty()) {
            return@withContext ValidationScanResult(
                personId = personId,
                matches = emptyList(),
                sampleSize = 0,
                errorMessage = "No photos with faces in library"
            )
        }

        // Random sample keeps validation fast regardless of library size.
        val samplePhotos = allPhotosWithFaces.shuffled().take(VALIDATION_SAMPLE_SIZE)

        onProgress(20, 100)

        // Step 3: Scan sample photos. Model and detector hold native resources,
        // so they are released in the finally block even if scanning fails.
        val faceNetModel = FaceNetModel(context)
        val detector = FaceDetection.getClient(
            FaceDetectorOptions.Builder()
                .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
                .setMinFaceSize(0.15f)
                .build()
        )

        try {
            val matches = scanPhotosForPerson(
                photos = samplePhotos,
                faceModel = faceModel,
                faceNetModel = faceNetModel,
                detector = detector,
                threshold = VALIDATION_THRESHOLD,
                onProgress = { current, total ->
                    // Map scan progress into the 20-100 range of overall progress
                    val progress = 20 + (current * 80 / total)
                    onProgress(progress, 100)
                }
            )

            onProgress(100, 100)

            ValidationScanResult(
                personId = personId,
                matches = matches,
                sampleSize = samplePhotos.size,
                threshold = VALIDATION_THRESHOLD
            )
        } finally {
            faceNetModel.close()
            detector.close()
        }
    }

    /**
     * Scan [photos] in parallel for faces matching [faceModel] at or above
     * [threshold]. Returns all matches sorted by descending confidence.
     */
    private suspend fun scanPhotosForPerson(
        photos: List<ImageEntity>,
        faceModel: FaceModelEntity,
        faceNetModel: FaceNetModel,
        detector: com.google.mlkit.vision.face.FaceDetector,
        threshold: Float,
        onProgress: (Int, Int) -> Unit
    ): List<ValidationMatch> = coroutineScope {
        val modelEmbedding = faceModel.getEmbeddingArray()
        val matches = mutableListOf<ValidationMatch>()
        var processedCount = 0

        // Process photos in parallel; the shared accumulator is guarded because
        // multiple IO coroutines append to it concurrently.
        val jobs = photos.map { photo ->
            async(Dispatchers.IO) {
                val photoMatches = scanSinglePhoto(
                    photo = photo,
                    modelEmbedding = modelEmbedding,
                    faceNetModel = faceNetModel,
                    detector = detector,
                    threshold = threshold
                )

                synchronized(matches) {
                    matches.addAll(photoMatches)
                    processedCount++
                    // Throttle progress updates to every 5th photo
                    if (processedCount % 5 == 0) {
                        onProgress(processedCount, photos.size)
                    }
                }
            }
        }

        jobs.awaitAll()
        // BUG FIX: always report completion; the modulo-5 throttle above never
        // fires the final update when photos.size is not a multiple of 5.
        onProgress(photos.size, photos.size)

        matches.sortedByDescending { it.confidence }
    }

    /**
     * Scan a single photo for the person. Returns one [ValidationMatch] per
     * detected face whose similarity meets [threshold]; never throws.
     */
    private suspend fun scanSinglePhoto(
        photo: ImageEntity,
        modelEmbedding: FloatArray,
        faceNetModel: FaceNetModel,
        detector: com.google.mlkit.vision.face.FaceDetector,
        threshold: Float
    ): List<ValidationMatch> = withContext(Dispatchers.IO) {
        try {
            // Load bitmap (downsampled to bound memory use)
            val bitmap = loadBitmapDownsampled(Uri.parse(photo.imageUri), 768)
                ?: return@withContext emptyList()

            try {
                // Detect faces
                val inputImage = InputImage.fromBitmap(bitmap, 0)
                val faces = detector.process(inputImage).await()

                // Check each face
                faces.mapNotNull { face ->
                    try {
                        // BUG FIX: ML Kit bounding boxes can extend outside the
                        // bitmap (including negative left/top). The crop size must
                        // be derived from the CLAMPED origin — computing it from
                        // the raw boundingBox made createBitmap throw for such
                        // boxes, silently dropping the face.
                        val left = face.boundingBox.left.coerceIn(0, bitmap.width - 1)
                        val top = face.boundingBox.top.coerceIn(0, bitmap.height - 1)
                        val width = face.boundingBox.right.coerceAtMost(bitmap.width) - left
                        val height = face.boundingBox.bottom.coerceAtMost(bitmap.height) - top
                        if (width <= 0 || height <= 0) return@mapNotNull null

                        // Crop face and generate its embedding
                        val faceBitmap = android.graphics.Bitmap.createBitmap(
                            bitmap, left, top, width, height
                        )
                        val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
                        faceBitmap.recycle()

                        // Compare against the trained model
                        val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)

                        if (similarity >= threshold) {
                            ValidationMatch(
                                imageId = photo.imageId,
                                imageUri = photo.imageUri,
                                capturedAt = photo.capturedAt,
                                confidence = similarity,
                                boundingBox = face.boundingBox,
                                faceCount = faces.size
                            )
                        } else {
                            null
                        }
                    } catch (e: Exception) {
                        // Skip faces that fail to crop/embed; others still count
                        null
                    }
                }
            } finally {
                // BUG FIX: recycle even when detection throws (was leaked before)
                bitmap.recycle()
            }
        } catch (e: Exception) {
            // Best-effort: a photo that cannot be read contributes no matches
            emptyList()
        }
    }

    /**
     * Load a bitmap downsampled by powers of two so that neither dimension
     * exceeds [maxDim]. Returns null if the image cannot be decoded.
     */
    private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): android.graphics.Bitmap? {
        return try {
            // First pass: decode bounds only to pick a sample size
            val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true }
            context.contentResolver.openInputStream(uri)?.use {
                BitmapFactory.decodeStream(it, null, opts)
            }

            var sample = 1
            while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) {
                sample *= 2
            }

            // Second pass: actual decode at the chosen sample size
            val finalOpts = BitmapFactory.Options().apply {
                inSampleSize = sample
            }
            context.contentResolver.openInputStream(uri)?.use {
                BitmapFactory.decodeStream(it, null, finalOpts)
            }
        } catch (e: Exception) {
            null
        }
    }

    companion object {
        // Number of photos sampled for the quick validation pass
        private const val VALIDATION_SAMPLE_SIZE = 25

        // Conservative match threshold: prefer false negatives in the preview
        private const val VALIDATION_THRESHOLD = 0.75f
    }
}
/**
 * Outcome of a validation scan: the sampled matches plus derived quality
 * metrics the review UI uses to suggest "looks good" vs "add more photos".
 */
data class ValidationScanResult(
    val personId: String,
    val matches: List<ValidationMatch>,
    val sampleSize: Int,
    val threshold: Float = 0.75f,
    val errorMessage: String? = null
) {
    // How many faces matched at or above the threshold.
    val matchCount: Int get() = matches.size

    // Mean confidence across matches; 0 when nothing matched.
    val averageConfidence: Float
        get() = if (matches.isEmpty()) 0f else matches.map { it.confidence }.average().toFloat()

    // Coarse verdict derived from match count and mean confidence.
    val qualityAssessment: ValidationQuality
        get() {
            val mean = averageConfidence
            return when {
                matchCount == 0 -> ValidationQuality.NO_MATCHES
                mean >= 0.85f && matchCount >= 5 -> ValidationQuality.EXCELLENT
                mean >= 0.78f && matchCount >= 3 -> ValidationQuality.GOOD
                mean < 0.75f || matchCount < 2 -> ValidationQuality.POOR
                else -> ValidationQuality.FAIR
            }
        }
}

/**
 * One face that matched the trained person during the validation pass.
 */
data class ValidationMatch(
    val imageId: String,
    val imageUri: String,
    val capturedAt: Long,
    val confidence: Float,
    val boundingBox: android.graphics.Rect,
    val faceCount: Int
)

/**
 * Coarse quality verdict for the validation sample as a whole.
 */
enum class ValidationQuality {
    EXCELLENT,  // High confidence, many matches
    GOOD,       // Decent confidence, some matches
    FAIR,       // Acceptable, proceed with caution
    POOR,       // Low confidence or very few matches
    NO_MATCHES  // No matches found at all
}

View File

@@ -1,253 +1,201 @@
package com.placeholder.sherpai2.ui.discover
import android.graphics.BitmapFactory
import android.net.Uri
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.clickable
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.*
import androidx.compose.material.icons.filled.Person
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.compose.collectAsStateWithLifecycle
import com.placeholder.sherpai2.domain.clustering.AgeEstimate
import com.placeholder.sherpai2.domain.clustering.FaceCluster
import java.text.SimpleDateFormat
import java.util.*
/**
* DiscoverPeopleScreen - Beautiful auto-clustering UI
* DiscoverPeopleScreen - COMPLETE WORKING VERSION
*
* FLOW:
* 1. Hero CTA: "Discover People in Your Photos"
* 2. Auto-clustering progress (2-5 min)
* 3. Grid of discovered people
* 4. Tap cluster → Name person + metadata
* 5. Background deep scan starts
* This handles ALL states properly including Idle state
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
fun DiscoverPeopleScreen(
viewModel: DiscoverPeopleViewModel = hiltViewModel()
viewModel: DiscoverPeopleViewModel = hiltViewModel(),
onNavigateBack: () -> Unit = {}
) {
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
val uiState by viewModel.uiState.collectAsState()
// NO SCAFFOLD - MainScreen already has TopAppBar
Box(modifier = Modifier.fillMaxSize()) {
when (val state = uiState) {
is DiscoverUiState.Idle -> IdleScreen(
onStartDiscovery = { viewModel.startDiscovery() }
)
is DiscoverUiState.Clustering -> ClusteringProgressScreen(
progress = state.progress,
total = state.total,
message = state.message
)
is DiscoverUiState.NamingReady -> ClusterGridScreen(
result = state.result,
onClusterClick = { cluster ->
viewModel.selectCluster(cluster)
Scaffold(
topBar = {
TopAppBar(
title = { Text("Discover People") },
navigationIcon = {
IconButton(onClick = onNavigateBack) {
Icon(
imageVector = Icons.Default.Person,
contentDescription = "Back"
)
}
}
)
is DiscoverUiState.NamingCluster -> NamingDialog(
cluster = state.selectedCluster,
suggestedSiblings = state.suggestedSiblings,
onConfirm = { name, dob, isChild, siblings ->
viewModel.confirmClusterName(
cluster = state.selectedCluster,
name = name,
dateOfBirth = dob,
isChild = isChild,
selectedSiblings = siblings
)
},
onDismiss = { viewModel.cancelNaming() }
)
is DiscoverUiState.NoPeopleFound -> EmptyStateScreen(
message = state.message
)
is DiscoverUiState.Error -> ErrorScreen(
message = state.message,
onRetry = { viewModel.startDiscovery() }
)
}
}
}
/**
* Idle state - Hero CTA to start discovery
*/
@Composable
fun IdleScreen(
onStartDiscovery: () -> Unit
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(32.dp),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
Icon(
imageVector = Icons.Default.AutoAwesome,
contentDescription = null,
modifier = Modifier.size(120.dp),
tint = MaterialTheme.colorScheme.primary
)
Spacer(Modifier.height(24.dp))
Text(
text = "Discover People",
style = MaterialTheme.typography.headlineLarge,
fontWeight = FontWeight.Bold,
textAlign = TextAlign.Center
)
Spacer(Modifier.height(16.dp))
Text(
text = "Let AI automatically find and group faces in your photos. " +
"You'll name them, and we'll tag all their photos.",
style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(Modifier.height(32.dp))
Button(
onClick = onStartDiscovery,
) { paddingValues ->
Box(
modifier = Modifier
.fillMaxWidth()
.height(56.dp),
colors = ButtonDefaults.buttonColors(
containerColor = MaterialTheme.colorScheme.primary
)
.fillMaxSize()
.padding(paddingValues)
) {
Icon(
imageVector = Icons.Default.AutoAwesome,
contentDescription = null,
modifier = Modifier.size(24.dp)
)
Spacer(Modifier.width(8.dp))
Text(
text = "Start Discovery",
style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.Bold
)
}
when (val state = uiState) {
// ===== IDLE STATE (START HERE) =====
is DiscoverUiState.Idle -> {
IdleStateContent(
onStartDiscovery = { viewModel.startDiscovery() }
)
}
Spacer(Modifier.height(16.dp))
// ===== CLUSTERING IN PROGRESS =====
is DiscoverUiState.Clustering -> {
ClusteringProgressContent(
progress = state.progress,
total = state.total,
message = state.message
)
}
Card(
modifier = Modifier.fillMaxWidth(),
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.surfaceVariant
)
) {
Column(
modifier = Modifier.padding(16.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
InfoRow(Icons.Default.Speed, "Fast: Analyzes ~1000 photos in 2-5 minutes")
InfoRow(Icons.Default.Security, "Private: Everything stays on your device")
InfoRow(Icons.Default.AutoAwesome, "Smart: Groups faces automatically")
// ===== CLUSTERS READY FOR NAMING =====
is DiscoverUiState.NamingReady -> {
Text(
text = "Found ${state.result.clusters.size} people!\n\nCluster grid UI coming...",
modifier = Modifier.align(Alignment.Center)
)
}
// ===== ANALYZING CLUSTER QUALITY =====
is DiscoverUiState.AnalyzingCluster -> {
LoadingContent(message = "Analyzing cluster quality...")
}
// ===== NAMING A CLUSTER =====
is DiscoverUiState.NamingCluster -> {
Text(
text = "Naming dialog for cluster ${state.selectedCluster.clusterId}\n\nDialog UI coming...",
modifier = Modifier.align(Alignment.Center)
)
}
// ===== TRAINING IN PROGRESS =====
is DiscoverUiState.Training -> {
TrainingProgressContent(
stage = state.stage,
progress = state.progress,
total = state.total
)
}
// ===== VALIDATION PREVIEW =====
is DiscoverUiState.ValidationPreview -> {
ValidationPreviewScreen(
personName = state.personName,
validationResult = state.validationResult,
onApprove = {
viewModel.approveValidationAndScan(
personId = state.personId,
personName = state.personName
)
},
onReject = {
viewModel.rejectValidationAndImprove()
}
)
}
// ===== COMPLETE =====
is DiscoverUiState.Complete -> {
CompleteStateContent(
message = state.message,
onDone = onNavigateBack
)
}
// ===== NO PEOPLE FOUND =====
is DiscoverUiState.NoPeopleFound -> {
ErrorStateContent(
title = "No People Found",
message = state.message,
onRetry = { viewModel.startDiscovery() },
onBack = onNavigateBack
)
}
// ===== ERROR =====
is DiscoverUiState.Error -> {
ErrorStateContent(
title = "Error",
message = state.message,
onRetry = { viewModel.reset(); viewModel.startDiscovery() },
onBack = onNavigateBack
)
}
}
}
}
}
@Composable
fun InfoRow(icon: androidx.compose.ui.graphics.vector.ImageVector, text: String) {
Row(
verticalAlignment = Alignment.CenterVertically,
horizontalArrangement = Arrangement.spacedBy(12.dp)
) {
Icon(
imageVector = icon,
contentDescription = null,
tint = MaterialTheme.colorScheme.primary,
modifier = Modifier.size(20.dp)
)
Text(
text = text,
style = MaterialTheme.typography.bodyMedium
)
}
}
// ===== IDLE STATE CONTENT =====
/**
* Clustering progress screen
*/
@Composable
fun ClusteringProgressScreen(
progress: Int,
total: Int,
message: String
private fun IdleStateContent(
onStartDiscovery: () -> Unit
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(32.dp),
.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
CircularProgressIndicator(
modifier = Modifier.size(80.dp),
strokeWidth = 6.dp
Icon(
imageVector = Icons.Default.Person,
contentDescription = null,
modifier = Modifier.size(120.dp),
tint = MaterialTheme.colorScheme.primary
)
Spacer(Modifier.height(32.dp))
Spacer(modifier = Modifier.height(32.dp))
Text(
text = "Discovering People...",
style = MaterialTheme.typography.headlineSmall,
text = "Discover People",
style = MaterialTheme.typography.headlineLarge,
fontWeight = FontWeight.Bold
)
Spacer(Modifier.height(16.dp))
LinearProgressIndicator(
progress = { if (total > 0) progress.toFloat() / total.toFloat() else 0f },
modifier = Modifier.fillMaxWidth(),
)
Spacer(Modifier.height(8.dp))
Spacer(modifier = Modifier.height(16.dp))
Text(
text = message,
style = MaterialTheme.typography.bodyMedium,
text = "Automatically find and organize people in your photo library",
style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(Modifier.height(24.dp))
Spacer(modifier = Modifier.height(48.dp))
Button(
onClick = onStartDiscovery,
modifier = Modifier
.fillMaxWidth()
.height(56.dp)
) {
Text(
text = "Start Discovery",
style = MaterialTheme.typography.titleMedium
)
}
Spacer(modifier = Modifier.height(16.dp))
Text(
text = "This will take 2-5 minutes. You can leave and come back later.",
text = "This will analyze faces in your photos and group similar faces together",
style = MaterialTheme.typography.bodySmall,
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurfaceVariant
@@ -255,421 +203,145 @@ fun ClusteringProgressScreen(
}
}
/**
* Grid of discovered clusters
*/
// ===== CLUSTERING PROGRESS =====
@Composable
fun ClusterGridScreen(
result: com.placeholder.sherpai2.domain.clustering.ClusteringResult,
onClusterClick: (FaceCluster) -> Unit
private fun ClusteringProgressContent(
progress: Int,
total: Int,
message: String
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(16.dp)
) {
Text(
text = "Found ${result.clusters.size} People",
style = MaterialTheme.typography.headlineSmall,
fontWeight = FontWeight.Bold
)
Spacer(Modifier.height(8.dp))
Text(
text = "Tap to name each person. We'll then tag all their photos.",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(Modifier.height(16.dp))
LazyVerticalGrid(
columns = GridCells.Fixed(2),
horizontalArrangement = Arrangement.spacedBy(12.dp),
verticalArrangement = Arrangement.spacedBy(12.dp)
) {
items(result.clusters) { cluster ->
ClusterCard(
cluster = cluster,
onClick = { onClusterClick(cluster) }
)
}
}
}
}
/**
* Single cluster card
*/
@Composable
fun ClusterCard(
cluster: FaceCluster,
onClick: () -> Unit
) {
val context = LocalContext.current
Card(
modifier = Modifier
.fillMaxWidth()
.clickable(onClick = onClick),
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp)
) {
Column {
// Face grid (2x3)
LazyVerticalGrid(
columns = GridCells.Fixed(3),
modifier = Modifier.height(180.dp),
userScrollEnabled = false
) {
items(cluster.representativeFaces.take(6)) { face ->
val bitmap = remember(face.imageUri) {
try {
context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
BitmapFactory.decodeStream(it)
}
} catch (e: Exception) {
null
}
}
if (bitmap != null) {
Image(
bitmap = bitmap.asImageBitmap(),
contentDescription = null,
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f),
contentScale = ContentScale.Crop
)
} else {
Box(
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f)
.background(MaterialTheme.colorScheme.surfaceVariant),
contentAlignment = Alignment.Center
) {
Icon(
imageVector = Icons.Default.Person,
contentDescription = null,
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
}
// Info
Column(
modifier = Modifier.padding(12.dp)
) {
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.SpaceBetween,
verticalAlignment = Alignment.CenterVertically
) {
Text(
text = "${cluster.photoCount} photos",
style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.Bold
)
if (cluster.estimatedAge == AgeEstimate.CHILD) {
Surface(
shape = RoundedCornerShape(12.dp),
color = MaterialTheme.colorScheme.primaryContainer
) {
Text(
text = "Child",
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onPrimaryContainer
)
}
}
}
if (cluster.potentialSiblings.isNotEmpty()) {
Spacer(Modifier.height(4.dp))
Text(
text = "Appears with ${cluster.potentialSiblings.size} other person(s)",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
}
}
/**
* Naming dialog
*/
@Composable
fun NamingDialog(
cluster: FaceCluster,
suggestedSiblings: List<FaceCluster>,
onConfirm: (String, Long?, Boolean, List<Int>) -> Unit,
onDismiss: () -> Unit
) {
var name by remember { mutableStateOf("") }
var isChild by remember { mutableStateOf(cluster.estimatedAge == AgeEstimate.CHILD) }
var dateOfBirth by remember { mutableStateOf<Long?>(null) }
var selectedSiblings by remember { mutableStateOf<Set<Int>>(emptySet()) }
var showDatePicker by remember { mutableStateOf(false) }
val context = LocalContext.current
// Date picker dialog
if (showDatePicker) {
val calendar = java.util.Calendar.getInstance()
if (dateOfBirth != null) {
calendar.timeInMillis = dateOfBirth!!
}
val datePickerDialog = android.app.DatePickerDialog(
context,
{ _, year, month, dayOfMonth ->
val cal = java.util.Calendar.getInstance()
cal.set(year, month, dayOfMonth)
dateOfBirth = cal.timeInMillis
showDatePicker = false
},
calendar.get(java.util.Calendar.YEAR),
calendar.get(java.util.Calendar.MONTH),
calendar.get(java.util.Calendar.DAY_OF_MONTH)
)
datePickerDialog.setOnDismissListener {
showDatePicker = false
}
DisposableEffect(Unit) {
datePickerDialog.show()
onDispose {
datePickerDialog.dismiss()
}
}
}
AlertDialog(
onDismissRequest = onDismiss,
title = {
Text("Name This Person")
},
text = {
Column(
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
// FACE PREVIEW - Show 6 representative faces
Text(
text = "${cluster.photoCount} photos found",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
LazyVerticalGrid(
columns = GridCells.Fixed(3),
modifier = Modifier.height(180.dp),
horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalArrangement = Arrangement.spacedBy(4.dp)
) {
items(cluster.representativeFaces.take(6)) { face ->
val bitmap = remember(face.imageUri) {
try {
context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
BitmapFactory.decodeStream(it)
}
} catch (e: Exception) {
null
}
}
if (bitmap != null) {
Image(
bitmap = bitmap.asImageBitmap(),
contentDescription = null,
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f)
.clip(RoundedCornerShape(8.dp)),
contentScale = ContentScale.Crop
)
} else {
Box(
modifier = Modifier
.fillMaxWidth()
.aspectRatio(1f)
.clip(RoundedCornerShape(8.dp))
.background(MaterialTheme.colorScheme.surfaceVariant),
contentAlignment = Alignment.Center
) {
Icon(
imageVector = Icons.Default.Person,
contentDescription = null,
tint = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
}
HorizontalDivider()
// Name input
OutlinedTextField(
value = name,
onValueChange = { name = it },
label = { Text("Name") },
singleLine = true,
modifier = Modifier.fillMaxWidth()
)
// Is child toggle
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.SpaceBetween,
verticalAlignment = Alignment.CenterVertically
) {
Text("This person is a child")
Switch(
checked = isChild,
onCheckedChange = { isChild = it }
)
}
// Date of birth (if child)
if (isChild) {
OutlinedButton(
onClick = { showDatePicker = true },
modifier = Modifier.fillMaxWidth()
) {
Icon(Icons.Default.CalendarToday, null)
Spacer(Modifier.width(8.dp))
Text(
if (dateOfBirth != null) {
SimpleDateFormat("MMM dd, yyyy", Locale.getDefault())
.format(Date(dateOfBirth!!))
} else {
"Set Date of Birth"
}
)
}
}
// Suggested siblings
if (suggestedSiblings.isNotEmpty()) {
Text(
"Appears with these people (select siblings):",
style = MaterialTheme.typography.labelMedium
)
suggestedSiblings.forEach { sibling ->
Row(
modifier = Modifier.fillMaxWidth(),
verticalAlignment = Alignment.CenterVertically
) {
Checkbox(
checked = sibling.clusterId in selectedSiblings,
onCheckedChange = { checked ->
selectedSiblings = if (checked) {
selectedSiblings + sibling.clusterId
} else {
selectedSiblings - sibling.clusterId
}
}
)
Text("Person ${sibling.clusterId + 1} (${sibling.photoCount} photos)")
}
}
}
}
},
confirmButton = {
TextButton(
onClick = {
onConfirm(
name,
dateOfBirth,
isChild,
selectedSiblings.toList()
)
},
enabled = name.isNotBlank()
) {
Text("Save & Train")
}
},
dismissButton = {
TextButton(onClick = onDismiss) {
Text("Cancel")
}
}
)
// TODO: Add DatePickerDialog when showDatePicker is true
}
/**
* Empty state screen
*/
@Composable
fun EmptyStateScreen(message: String) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(32.dp),
.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
Icon(
imageVector = Icons.Default.PersonOff,
contentDescription = null,
modifier = Modifier.size(80.dp),
tint = MaterialTheme.colorScheme.onSurfaceVariant
CircularProgressIndicator(
modifier = Modifier.size(64.dp)
)
Spacer(Modifier.height(16.dp))
Spacer(modifier = Modifier.height(32.dp))
Text(
text = message,
style = MaterialTheme.typography.bodyLarge,
style = MaterialTheme.typography.titleMedium,
textAlign = TextAlign.Center
)
Spacer(modifier = Modifier.height(16.dp))
if (total > 0) {
LinearProgressIndicator(
progress = progress.toFloat() / total.toFloat(),
modifier = Modifier
.fillMaxWidth()
.height(8.dp)
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = "$progress / $total",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
/**
* Error screen
*/
// ===== TRAINING PROGRESS =====
@Composable
fun ErrorScreen(
message: String,
onRetry: () -> Unit
private fun TrainingProgressContent(
stage: String,
progress: Int,
total: Int
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(32.dp),
.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
Icon(
imageVector = Icons.Default.Error,
contentDescription = null,
modifier = Modifier.size(80.dp),
tint = MaterialTheme.colorScheme.error
CircularProgressIndicator(
modifier = Modifier.size(64.dp)
)
Spacer(Modifier.height(16.dp))
Spacer(modifier = Modifier.height(32.dp))
Text(
text = "Oops!",
text = stage,
style = MaterialTheme.typography.titleMedium,
textAlign = TextAlign.Center
)
if (total > 0) {
Spacer(modifier = Modifier.height(16.dp))
LinearProgressIndicator(
progress = progress.toFloat() / total.toFloat(),
modifier = Modifier
.fillMaxWidth()
.height(8.dp)
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = "$progress / $total",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
// ===== LOADING CONTENT =====

/** Centered indeterminate spinner with a caption, shown while data loads. */
@Composable
private fun LoadingContent(message: String) {
    Column(
        verticalArrangement = Arrangement.Center,
        horizontalAlignment = Alignment.CenterHorizontally,
        modifier = Modifier.fillMaxSize()
    ) {
        CircularProgressIndicator()
        Spacer(Modifier.height(16.dp))
        Text(message)
    }
}
// ===== COMPLETE STATE =====
@Composable
private fun CompleteStateContent(
message: String,
onDone: () -> Unit
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
Text(
text = "🎉",
style = MaterialTheme.typography.displayLarge
)
Spacer(modifier = Modifier.height(24.dp))
Text(
text = "Success!",
style = MaterialTheme.typography.headlineMedium,
fontWeight = FontWeight.Bold
)
Spacer(Modifier.height(8.dp))
Spacer(modifier = Modifier.height(16.dp))
Text(
text = message,
@@ -678,10 +350,74 @@ fun ErrorScreen(
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(Modifier.height(24.dp))
Spacer(modifier = Modifier.height(32.dp))
Button(onClick = onRetry) {
Text("Try Again")
Button(
onClick = onDone,
modifier = Modifier.fillMaxWidth()
) {
Text("Done")
}
}
}
// ===== ERROR STATE =====

/**
 * Full-screen error view: warning emoji, bold title, detail message,
 * and a side-by-side Back / Retry button pair.
 *
 * @param title   headline shown beneath the emoji
 * @param message detailed error description
 * @param onRetry invoked when the user taps "Retry"
 * @param onBack  invoked when the user taps "Back"
 */
@Composable
private fun ErrorStateContent(
    title: String,
    message: String,
    onRetry: () -> Unit,
    onBack: () -> Unit
) {
    Column(
        verticalArrangement = Arrangement.Center,
        horizontalAlignment = Alignment.CenterHorizontally,
        modifier = Modifier
            .fillMaxSize()
            .padding(24.dp)
    ) {
        Text("⚠️", style = MaterialTheme.typography.displayLarge)
        Spacer(Modifier.height(24.dp))
        Text(
            title,
            style = MaterialTheme.typography.headlineMedium,
            fontWeight = FontWeight.Bold
        )
        Spacer(Modifier.height(16.dp))
        Text(
            message,
            textAlign = TextAlign.Center,
            style = MaterialTheme.typography.bodyLarge,
            color = MaterialTheme.colorScheme.onSurfaceVariant
        )
        Spacer(Modifier.height(32.dp))
        // Secondary action (Back) on the left, primary action (Retry) on the right.
        Row(
            horizontalArrangement = Arrangement.spacedBy(12.dp),
            modifier = Modifier.fillMaxWidth()
        ) {
            OutlinedButton(onClick = onBack, modifier = Modifier.weight(1f)) {
                Text("Back")
            }
            Button(onClick = onRetry, modifier = Modifier.weight(1f)) {
                Text("Retry")
            }
        }
    }
}

View File

@@ -2,10 +2,15 @@ package com.placeholder.sherpai2.ui.discover
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import androidx.work.WorkManager
import com.placeholder.sherpai2.domain.clustering.ClusteringResult
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
import com.placeholder.sherpai2.domain.clustering.FaceCluster
import com.placeholder.sherpai2.domain.clustering.FaceClusteringService
import com.placeholder.sherpai2.domain.training.ClusterTrainingService
import com.placeholder.sherpai2.domain.validation.ValidationScanResult
import com.placeholder.sherpai2.domain.validation.ValidationScanService
import com.placeholder.sherpai2.workers.LibraryScanWorker
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
@@ -14,21 +19,22 @@ import kotlinx.coroutines.launch
import javax.inject.Inject
/**
* DiscoverPeopleViewModel - Manages auto-clustering and naming flow
* DiscoverPeopleViewModel - Manages TWO-STAGE validation flow
*
* PHASE 2: Now includes multi-centroid training from clusters
*
* STATE FLOW:
* 1. Idle → User taps "Discover People"
* 2. Clustering → Auto-analyzing faces (2-5 min)
* 3. NamingReady → Shows clusters, user names them
* 4. Training → Creating multi-centroid face model
* 5. Complete → Ready to scan library
* FLOW:
* 1. Clustering → User selects cluster
* 2. STAGE 1: Show cluster quality analysis
* 3. User names person → Training
* 4. STAGE 2: Show validation scan preview
* 5. User approves → Full library scan (background worker)
* 6. Results appear in "People" tab
*/
@HiltViewModel
class DiscoverPeopleViewModel @Inject constructor(
private val clusteringService: FaceClusteringService,
private val trainingService: ClusterTrainingService
private val trainingService: ClusterTrainingService,
private val validationScanService: ValidationScanService,
private val workManager: WorkManager
) : ViewModel() {
private val _uiState = MutableStateFlow<DiscoverUiState>(DiscoverUiState.Idle)
@@ -37,6 +43,9 @@ class DiscoverPeopleViewModel @Inject constructor(
// Track which clusters have been named
private val namedClusterIds = mutableSetOf<Int>()
// Store quality analysis for current cluster
private var currentQualityResult: ClusterQualityResult? = null
/**
* Start auto-clustering process
*/
@@ -78,27 +87,41 @@ class DiscoverPeopleViewModel @Inject constructor(
/**
* User selected a cluster to name
* STAGE 1: Analyze quality FIRST
*/
fun selectCluster(cluster: FaceCluster) {
val currentState = _uiState.value
if (currentState is DiscoverUiState.NamingReady) {
_uiState.value = DiscoverUiState.NamingCluster(
result = currentState.result,
selectedCluster = cluster,
suggestedSiblings = currentState.result.clusters.filter {
it.clusterId in cluster.potentialSiblings
viewModelScope.launch {
try {
// Show analyzing state
_uiState.value = DiscoverUiState.AnalyzingCluster(cluster)
// Analyze cluster quality
val qualityResult = trainingService.analyzeClusterQuality(cluster)
currentQualityResult = qualityResult
// Show naming dialog with quality info
_uiState.value = DiscoverUiState.NamingCluster(
result = currentState.result,
selectedCluster = cluster,
qualityResult = qualityResult,
suggestedSiblings = currentState.result.clusters.filter {
it.clusterId in cluster.potentialSiblings
}
)
} catch (e: Exception) {
_uiState.value = DiscoverUiState.Error(
"Failed to analyze cluster: ${e.message}"
)
}
)
}
}
}
/**
* User confirmed name and metadata for a cluster
*
* CREATES:
* 1. PersonEntity with all metadata (name, DOB, siblings)
* 2. Multi-centroid FaceModelEntity (temporal tracking for children)
* 3. Removes cluster from display
* STAGE 2: Train → Validation scan → Preview
*/
fun confirmClusterName(
cluster: FaceCluster,
@@ -112,37 +135,59 @@ class DiscoverPeopleViewModel @Inject constructor(
val currentState = _uiState.value
if (currentState !is DiscoverUiState.NamingCluster) return@launch
// Train person from cluster
// Show training progress
_uiState.value = DiscoverUiState.Training(
stage = "Creating person and training model",
progress = 0,
total = 100
)
// Train person from cluster (using clean faces from quality analysis)
val personId = trainingService.trainFromCluster(
cluster = cluster,
name = name,
dateOfBirth = dateOfBirth,
isChild = isChild,
siblingClusterIds = selectedSiblings,
qualityResult = currentQualityResult, // Use clean faces!
onProgress = { current, total, message ->
_uiState.value = DiscoverUiState.Clustering(current, total, message)
_uiState.value = DiscoverUiState.Training(
stage = message,
progress = current,
total = total
)
}
)
// Training complete - now run validation scan
_uiState.value = DiscoverUiState.Training(
stage = "Running validation scan...",
progress = 0,
total = 100
)
val validationResult = validationScanService.performValidationScan(
personId = personId,
onProgress = { current, total ->
_uiState.value = DiscoverUiState.Training(
stage = "Scanning sample photos...",
progress = current,
total = total
)
}
)
// Show validation preview to user
_uiState.value = DiscoverUiState.ValidationPreview(
personId = personId,
personName = name,
validationResult = validationResult,
originalClusterResult = currentState.result
)
// Mark cluster as named
namedClusterIds.add(cluster.clusterId)
// Filter out named clusters
val remainingClusters = currentState.result.clusters
.filter { it.clusterId !in namedClusterIds }
if (remainingClusters.isEmpty()) {
// All clusters named! Show success
_uiState.value = DiscoverUiState.NoPeopleFound(
"All people have been named! 🎉\n\nGo to 'People' to see your trained models."
)
} else {
// Return to naming screen with remaining clusters
_uiState.value = DiscoverUiState.NamingReady(
result = currentState.result.copy(clusters = remainingClusters)
)
}
} catch (e: Exception) {
_uiState.value = DiscoverUiState.Error(
e.message ?: "Failed to create person: ${e.message}"
@@ -151,6 +196,57 @@ class DiscoverPeopleViewModel @Inject constructor(
}
}
/**
* User approves validation preview → Start full library scan
*/
fun approveValidationAndScan(personId: String, personName: String) {
viewModelScope.launch {
val currentState = _uiState.value
if (currentState !is DiscoverUiState.ValidationPreview) return@launch
// Enqueue background worker for full library scan
val workRequest = LibraryScanWorker.createWorkRequest(
personId = personId,
personName = personName,
threshold = 0.70f // Slightly looser than validation
)
workManager.enqueue(workRequest)
// Filter out named clusters and return to cluster list
val remainingClusters = currentState.originalClusterResult.clusters
.filter { it.clusterId !in namedClusterIds }
if (remainingClusters.isEmpty()) {
// All clusters named! Show success
_uiState.value = DiscoverUiState.Complete(
message = "All people have been named! 🎉\n\n" +
"Full library scan is running in the background.\n" +
"Go to 'People' to see results as they come in."
)
} else {
// Return to naming screen with remaining clusters
_uiState.value = DiscoverUiState.NamingReady(
result = currentState.originalClusterResult.copy(clusters = remainingClusters)
)
}
}
}
/**
* User rejects validation → Go back to add more training photos
*/
fun rejectValidationAndImprove() {
viewModelScope.launch {
val currentState = _uiState.value
if (currentState !is DiscoverUiState.ValidationPreview) return@launch
_uiState.value = DiscoverUiState.Error(
"Model quality needs improvement.\n\n" +
"Please use the manual training flow to add more high-quality photos."
)
}
}
/**
* Cancel naming and go back to cluster list
*/
@@ -172,7 +268,7 @@ class DiscoverPeopleViewModel @Inject constructor(
}
/**
* UI States for Discover People flow
* UI States for Discover People flow with TWO-STAGE VALIDATION
*/
sealed class DiscoverUiState {
@@ -198,14 +294,48 @@ sealed class DiscoverUiState {
) : DiscoverUiState()
/**
* User is naming a specific cluster
* STAGE 1: Analyzing cluster quality (before naming)
*/
data class AnalyzingCluster(
val cluster: FaceCluster
) : DiscoverUiState()
/**
* User is naming a specific cluster (with quality analysis)
*/
data class NamingCluster(
val result: ClusteringResult,
val selectedCluster: FaceCluster,
val qualityResult: ClusterQualityResult,
val suggestedSiblings: List<FaceCluster>
) : DiscoverUiState()
/**
* Training in progress
*/
data class Training(
val stage: String,
val progress: Int,
val total: Int
) : DiscoverUiState()
/**
* STAGE 2: Validation scan complete - show preview to user
*/
data class ValidationPreview(
val personId: String,
val personName: String,
val validationResult: ValidationScanResult,
val originalClusterResult: ClusteringResult
) : DiscoverUiState()
/**
* All clusters named and scans launched
*/
data class Complete(
val message: String
) : DiscoverUiState()
/**
* No people found in library
*/

View File

@@ -0,0 +1,395 @@
package com.placeholder.sherpai2.ui.discover
import android.net.Uri
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.*
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.dp
import coil.compose.AsyncImage
import com.placeholder.sherpai2.domain.validation.ValidationMatch
import com.placeholder.sherpai2.domain.validation.ValidationQuality
import com.placeholder.sherpai2.domain.validation.ValidationScanResult
/** Maximum number of sample matches rendered in the preview grid. */
private const val MAX_PREVIEW_MATCHES = 15

/**
 * ValidationPreviewScreen - STAGE 2 validation UI
 *
 * Shows user a preview of matches found in validation scan.
 * User can approve (→ full scan) or reject (→ add more photos).
 *
 * @param personName       display name of the person being validated
 * @param validationResult results of the sample scan (matches, confidence, quality verdict)
 * @param onApprove        invoked when the user taps "Scan Library"
 * @param onReject         invoked when the user taps "Add More Photos"
 * @param modifier         optional layout modifier for the root column
 */
@Composable
fun ValidationPreviewScreen(
    personName: String,
    validationResult: ValidationScanResult,
    onApprove: () -> Unit,
    onReject: () -> Unit,
    modifier: Modifier = Modifier
) {
    Column(
        modifier = modifier
            .fillMaxSize()
            .padding(16.dp)
    ) {
        // Header
        Text(
            text = "Validation Results",
            style = MaterialTheme.typography.headlineMedium,
            fontWeight = FontWeight.Bold
        )
        Spacer(modifier = Modifier.height(8.dp))
        Text(
            text = "Review matches for \"$personName\" before scanning your entire library",
            style = MaterialTheme.typography.bodyMedium,
            color = MaterialTheme.colorScheme.onSurfaceVariant
        )
        Spacer(modifier = Modifier.height(16.dp))
        // Quality Summary
        QualitySummaryCard(
            validationResult = validationResult,
            personName = personName
        )
        Spacer(modifier = Modifier.height(16.dp))
        // Matches Grid (capped at MAX_PREVIEW_MATCHES thumbnails)
        if (validationResult.matches.isNotEmpty()) {
            Text(
                text = "Sample Matches (${validationResult.matchCount})",
                style = MaterialTheme.typography.titleMedium,
                fontWeight = FontWeight.SemiBold
            )
            Spacer(modifier = Modifier.height(8.dp))
            LazyVerticalGrid(
                columns = GridCells.Fixed(3),
                horizontalArrangement = Arrangement.spacedBy(8.dp),
                verticalArrangement = Arrangement.spacedBy(8.dp),
                modifier = Modifier.weight(1f)
            ) {
                items(validationResult.matches.take(MAX_PREVIEW_MATCHES)) { match ->
                    MatchPreviewCard(match = match)
                }
            }
        } else {
            // No matches found
            NoMatchesCard()
        }
        Spacer(modifier = Modifier.height(16.dp))
        // Action Buttons
        Row(
            modifier = Modifier.fillMaxWidth(),
            horizontalArrangement = Arrangement.spacedBy(12.dp)
        ) {
            // Reject button — routes the user back to add more training photos
            OutlinedButton(
                onClick = onReject,
                modifier = Modifier.weight(1f),
                colors = ButtonDefaults.outlinedButtonColors(
                    contentColor = MaterialTheme.colorScheme.error
                )
            ) {
                Icon(
                    imageVector = Icons.Default.Close,
                    contentDescription = null,
                    modifier = Modifier.size(20.dp)
                )
                Spacer(modifier = Modifier.width(8.dp))
                Text("Add More Photos")
            }
            // Approve button — disabled when the scan found nothing at all
            Button(
                onClick = onApprove,
                modifier = Modifier.weight(1f),
                enabled = validationResult.qualityAssessment != ValidationQuality.NO_MATCHES
            ) {
                Icon(
                    imageVector = Icons.Default.Check,
                    contentDescription = null,
                    modifier = Modifier.size(20.dp)
                )
                Spacer(modifier = Modifier.width(8.dp))
                Text("Scan Library")
            }
        }
    }
}
/**
 * Summary card for the validation verdict: a color-coded headline, the
 * match/confidence/threshold stats row, and a plain-language recommendation.
 */
@Composable
private fun QualitySummaryCard(
    validationResult: ValidationScanResult,
    personName: String
) {
    // Map the quality verdict to (card tint, accent color, headline text, icon).
    // Colors are hard-coded green/amber/red shades at 10% alpha for the background.
    val (backgroundColor, iconColor, statusText, statusIcon) = when (validationResult.qualityAssessment) {
        ValidationQuality.EXCELLENT -> {
            Quadruple(
                Color(0xFF1B5E20).copy(alpha = 0.1f),
                Color(0xFF1B5E20),
                "Excellent Match Quality",
                Icons.Default.CheckCircle
            )
        }
        ValidationQuality.GOOD -> {
            Quadruple(
                Color(0xFF2E7D32).copy(alpha = 0.1f),
                Color(0xFF2E7D32),
                "Good Match Quality",
                Icons.Default.ThumbUp
            )
        }
        ValidationQuality.FAIR -> {
            Quadruple(
                Color(0xFFF57F17).copy(alpha = 0.1f),
                Color(0xFFF57F17),
                "Fair Match Quality",
                Icons.Default.Warning
            )
        }
        ValidationQuality.POOR -> {
            Quadruple(
                Color(0xFFD32F2F).copy(alpha = 0.1f),
                Color(0xFFD32F2F),
                "Poor Match Quality",
                Icons.Default.Warning
            )
        }
        ValidationQuality.NO_MATCHES -> {
            Quadruple(
                Color(0xFFD32F2F).copy(alpha = 0.1f),
                Color(0xFFD32F2F),
                "No Matches Found",
                Icons.Default.Close
            )
        }
    }
    Card(
        modifier = Modifier.fillMaxWidth(),
        colors = CardDefaults.cardColors(
            containerColor = backgroundColor
        )
    ) {
        Column(
            modifier = Modifier.padding(16.dp)
        ) {
            // Headline row: status icon + colored title
            Row(
                verticalAlignment = Alignment.CenterVertically
            ) {
                Icon(
                    imageVector = statusIcon,
                    contentDescription = null,
                    tint = iconColor,
                    modifier = Modifier.size(24.dp)
                )
                Spacer(modifier = Modifier.width(8.dp))
                Text(
                    text = statusText,
                    style = MaterialTheme.typography.titleMedium,
                    fontWeight = FontWeight.Bold,
                    color = iconColor
                )
            }
            Spacer(modifier = Modifier.height(12.dp))
            // Stats — percentages rendered as truncated ints (e.g. 0.876 → "87%")
            Row(
                modifier = Modifier.fillMaxWidth(),
                horizontalArrangement = Arrangement.SpaceBetween
            ) {
                StatItem(
                    label = "Matches Found",
                    value = "${validationResult.matchCount} / ${validationResult.sampleSize}"
                )
                StatItem(
                    label = "Avg Confidence",
                    value = "${(validationResult.averageConfidence * 100).toInt()}%"
                )
                StatItem(
                    label = "Threshold",
                    value = "${(validationResult.threshold * 100).toInt()}%"
                )
            }
            // Recommendation (or an error explanation when nothing matched).
            // The inner NO_MATCHES branch is unreachable here (outer guard excludes
            // it) but is required for the `when` to stay exhaustive.
            if (validationResult.qualityAssessment != ValidationQuality.NO_MATCHES) {
                Spacer(modifier = Modifier.height(12.dp))
                val recommendation = when (validationResult.qualityAssessment) {
                    ValidationQuality.EXCELLENT ->
                        "✅ Model looks great! Safe to scan your full library."
                    ValidationQuality.GOOD ->
                        "✅ Model quality is good. You can proceed with the full scan."
                    ValidationQuality.FAIR ->
                        "⚠️ Model quality is acceptable but could be improved with more photos."
                    ValidationQuality.POOR ->
                        "⚠️ Consider adding more diverse, high-quality training photos."
                    ValidationQuality.NO_MATCHES -> ""
                }
                Text(
                    text = recommendation,
                    style = MaterialTheme.typography.bodyMedium,
                    color = MaterialTheme.colorScheme.onSurfaceVariant
                )
            } else {
                Spacer(modifier = Modifier.height(12.dp))
                Text(
                    text = "No matches found. The model may need more or better training photos, or the validation sample didn't include $personName.",
                    style = MaterialTheme.typography.bodyMedium,
                    color = MaterialTheme.colorScheme.error
                )
            }
        }
    }
}
/** Vertically stacked value-over-label pair used in the quality stats row. */
@Composable
private fun StatItem(label: String, value: String) {
    Column(horizontalAlignment = Alignment.CenterHorizontally) {
        Text(
            value,
            fontWeight = FontWeight.Bold,
            style = MaterialTheme.typography.titleLarge
        )
        Text(
            label,
            color = MaterialTheme.colorScheme.onSurfaceVariant,
            style = MaterialTheme.typography.bodySmall
        )
    }
}
/**
 * Single thumbnail in the validation preview grid: the matched photo cropped
 * to a square, a confidence badge (bottom-right), and — for group photos —
 * a face-count badge (top-right).
 */
@Composable
private fun MatchPreviewCard(
    match: ValidationMatch
) {
    Box(
        modifier = Modifier
            .aspectRatio(1f)
            .clip(RoundedCornerShape(8.dp))
            .background(MaterialTheme.colorScheme.surfaceVariant)
    ) {
        // surfaceVariant background above acts as the placeholder while loading
        AsyncImage(
            model = Uri.parse(match.imageUri),
            contentDescription = "Match preview",
            modifier = Modifier.fillMaxSize(),
            contentScale = ContentScale.Crop
        )
        // Confidence badge — truncated int percent on a translucent black chip
        Surface(
            modifier = Modifier
                .align(Alignment.BottomEnd)
                .padding(4.dp),
            shape = RoundedCornerShape(4.dp),
            color = Color.Black.copy(alpha = 0.7f)
        ) {
            Text(
                text = "${(match.confidence * 100).toInt()}%",
                style = MaterialTheme.typography.labelSmall,
                color = Color.White,
                modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp)
            )
        }
        // Face count indicator (if group photo)
        if (match.faceCount > 1) {
            Surface(
                modifier = Modifier
                    .align(Alignment.TopEnd)
                    .padding(4.dp),
                shape = RoundedCornerShape(4.dp),
                color = MaterialTheme.colorScheme.primary
            ) {
                Row(
                    modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Icon(
                        imageVector = Icons.Default.Person,
                        contentDescription = null,
                        tint = Color.White,
                        modifier = Modifier.size(12.dp)
                    )
                    Text(
                        text = "${match.faceCount}",
                        style = MaterialTheme.typography.labelSmall,
                        color = Color.White
                    )
                }
            }
        }
    }
}
/**
 * Error-styled card shown when the validation scan produced zero matches,
 * listing likely causes so the user knows how to improve the model.
 */
@Composable
private fun NoMatchesCard() {
    Card(
        colors = CardDefaults.cardColors(
            containerColor = MaterialTheme.colorScheme.errorContainer
        ),
        modifier = Modifier.fillMaxWidth()
    ) {
        Column(
            horizontalAlignment = Alignment.CenterHorizontally,
            modifier = Modifier.padding(24.dp)
        ) {
            Icon(
                Icons.Default.Warning,
                contentDescription = null,
                tint = MaterialTheme.colorScheme.error,
                modifier = Modifier.size(48.dp)
            )
            Spacer(Modifier.height(16.dp))
            Text(
                "No Matches Found",
                fontWeight = FontWeight.Bold,
                style = MaterialTheme.typography.titleLarge,
                color = MaterialTheme.colorScheme.error
            )
            Spacer(Modifier.height(8.dp))
            Text(
                "The validation scan didn't find this person in the sample photos. This could mean:\n\n" +
                    "• The model needs more training photos\n" +
                    "• The training photos weren't diverse enough\n" +
                    "• The person wasn't in the validation sample",
                style = MaterialTheme.typography.bodyMedium,
                color = MaterialTheme.colorScheme.onErrorContainer
            )
        }
    }
}
// Lightweight 4-tuple for the quality-indicator destructuring above;
// the Kotlin stdlib only provides Pair and Triple.
private data class Quadruple<A, B, C, D>(val first: A, val second: B, val third: C, val fourth: D)

View File

@@ -154,13 +154,3 @@ fun getDestinationByRoute(route: String?): AppDestinations? {
else -> null
}
}
/**
* Legacy support (for backwards compatibility)
* These match your old structure
*/
@Deprecated("Use organized groups instead", ReplaceWith("allMainDrawerDestinations"))
val mainDrawerItems = allMainDrawerDestinations
@Deprecated("Use settingsDestination instead", ReplaceWith("listOf(settingsDestination)"))
val utilityDrawerItems = listOf(settingsDestination)

View File

@@ -15,7 +15,10 @@ import com.placeholder.sherpai2.ui.navigation.AppRoutes
import kotlinx.coroutines.launch
/**
* Clean main screen - NO duplicate FABs, Collections support, Discover People
* MainScreen - FIXED double header issue
*
* BEST PRACTICE: Screens that manage their own TopAppBar should be excluded
* from MainScreen's TopAppBar to prevent ugly double headers.
*/
@OptIn(ExperimentalMaterial3Api::class)
@Composable
@@ -45,68 +48,77 @@ fun MainScreen() {
)
},
) {
// CRITICAL: Some screens manage their own TopAppBar
// Hide MainScreen's TopAppBar for these routes to prevent double headers
val screensWithOwnTopBar = setOf(
AppRoutes.TRAINING_PHOTO_SELECTOR // Has its own TopAppBar with subtitle
)
val showTopBar = currentRoute !in screensWithOwnTopBar
Scaffold(
topBar = {
TopAppBar(
title = {
Column {
Text(
text = getScreenTitle(currentRoute),
style = MaterialTheme.typography.titleLarge,
fontWeight = FontWeight.Bold
)
getScreenSubtitle(currentRoute)?.let { subtitle ->
if (showTopBar) {
TopAppBar(
title = {
Column {
Text(
text = subtitle,
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
text = getScreenTitle(currentRoute),
style = MaterialTheme.typography.titleLarge,
fontWeight = FontWeight.Bold
)
getScreenSubtitle(currentRoute)?.let { subtitle ->
Text(
text = subtitle,
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
},
navigationIcon = {
IconButton(
onClick = { scope.launch { drawerState.open() } }
) {
Icon(
Icons.Default.Menu,
contentDescription = "Open Menu",
tint = MaterialTheme.colorScheme.primary
)
}
}
},
navigationIcon = {
IconButton(
onClick = { scope.launch { drawerState.open() } }
) {
Icon(
Icons.Default.Menu,
contentDescription = "Open Menu",
tint = MaterialTheme.colorScheme.primary
)
}
},
actions = {
// Dynamic actions based on current screen
when (currentRoute) {
AppRoutes.SEARCH -> {
IconButton(onClick = { /* TODO: Open filter dialog */ }) {
Icon(
Icons.Default.FilterList,
contentDescription = "Filter",
tint = MaterialTheme.colorScheme.primary
)
},
actions = {
// Dynamic actions based on current screen
when (currentRoute) {
AppRoutes.SEARCH -> {
IconButton(onClick = { /* TODO: Open filter dialog */ }) {
Icon(
Icons.Default.FilterList,
contentDescription = "Filter",
tint = MaterialTheme.colorScheme.primary
)
}
}
AppRoutes.INVENTORY -> {
IconButton(onClick = {
navController.navigate(AppRoutes.TRAIN)
}) {
Icon(
Icons.Default.PersonAdd,
contentDescription = "Add Person",
tint = MaterialTheme.colorScheme.primary
)
}
}
}
AppRoutes.INVENTORY -> {
IconButton(onClick = {
navController.navigate(AppRoutes.TRAIN)
}) {
Icon(
Icons.Default.PersonAdd,
contentDescription = "Add Person",
tint = MaterialTheme.colorScheme.primary
)
}
}
}
},
colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.surface,
titleContentColor = MaterialTheme.colorScheme.onSurface,
navigationIconContentColor = MaterialTheme.colorScheme.primary,
actionIconContentColor = MaterialTheme.colorScheme.primary
},
colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.surface,
titleContentColor = MaterialTheme.colorScheme.onSurface,
navigationIconContentColor = MaterialTheme.colorScheme.primary,
actionIconContentColor = MaterialTheme.colorScheme.primary
)
)
)
}
}
) { paddingValues ->
AppNavHost(
@@ -125,10 +137,10 @@ private fun getScreenTitle(route: String): String {
AppRoutes.SEARCH -> "Search"
AppRoutes.EXPLORE -> "Explore"
AppRoutes.COLLECTIONS -> "Collections"
AppRoutes.DISCOVER -> "Discover People" // ✨ NEW!
AppRoutes.DISCOVER -> "Discover People"
AppRoutes.INVENTORY -> "People"
AppRoutes.TRAIN -> "Train New Person"
AppRoutes.MODELS -> "AI Models" // Deprecated, but keep for backwards compat
AppRoutes.MODELS -> "AI Models"
AppRoutes.TAGS -> "Tag Management"
AppRoutes.UTILITIES -> "Photo Util."
AppRoutes.SETTINGS -> "Settings"
@@ -144,7 +156,7 @@ private fun getScreenSubtitle(route: String): String? {
AppRoutes.SEARCH -> "Find photos by tags, people, or date"
AppRoutes.EXPLORE -> "Browse your collection"
AppRoutes.COLLECTIONS -> "Your photo collections"
AppRoutes.DISCOVER -> "Auto-find faces in your library" // ✨ NEW!
AppRoutes.DISCOVER -> "Auto-find faces in your library"
AppRoutes.INVENTORY -> "Trained face models"
AppRoutes.TRAIN -> "Add a new person to recognize"
AppRoutes.TAGS -> "Organize your photo collection"

View File

@@ -14,7 +14,9 @@ import javax.inject.Inject
* ImageSelectorViewModel
*
* Provides face-tagged image URIs for smart filtering
* during training photo selection
* during training photo selection.
*
* PRIORITIZATION: Solo photos first (faceCount=1) for clearer training data
*/
@HiltViewModel
class ImageSelectorViewModel @Inject constructor(
@@ -31,8 +33,15 @@ class ImageSelectorViewModel @Inject constructor(
private fun loadFaceTaggedImages() {
viewModelScope.launch {
try {
// Get all images with faces
val imagesWithFaces = imageDao.getImagesWithFaces()
_faceTaggedImageUris.value = imagesWithFaces.map { it.imageUri }
// CRITICAL FIX: Sort by faceCount ASCENDING (solo photos first!)
// Previously: Sorted by faceCount DESC (group photos first - WRONG!)
// Now: Solo photos appear first, making training selection easier
val sortedImages = imagesWithFaces.sortedBy { it.faceCount }
_faceTaggedImageUris.value = sortedImages.map { it.imageUri }
} catch (e: Exception) {
// If cache not available, just use empty list (filter disabled)
_faceTaggedImageUris.value = emptyList()

View File

@@ -46,6 +46,8 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
*
* Uses indexed query: SELECT * FROM images WHERE hasFaces = 1
* Fast! (~10ms for 10k photos)
*
* SORTED: Solo photos (faceCount=1) first for best training quality
*/
private fun loadPhotosWithFaces() {
viewModelScope.launch {
@@ -55,8 +57,9 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
// ✅ CRITICAL: Only get images with faces!
val photos = imageDao.getImagesWithFaces()
// Sort by most faces first (better for training)
val sorted = photos.sortedByDescending { it.faceCount ?: 0 }
// ✅ FIX: Sort by LEAST faces first (solo photos = best training data)
// faceCount=1 first, then faceCount=2, etc.
val sorted = photos.sortedBy { it.faceCount ?: 999 }
_photosWithFaces.value = sorted

View File

@@ -0,0 +1,315 @@
package com.placeholder.sherpai2.workers
import android.content.Context
import android.graphics.BitmapFactory
import android.net.Uri
import androidx.hilt.work.HiltWorker
import androidx.work.*
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.assisted.Assisted
import dagger.assisted.AssistedInject
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.tasks.await
import kotlinx.coroutines.withContext
/**
* LibraryScanWorker - Full library background scan for a trained person
*
* PURPOSE: After user approves validation preview, scan entire library
*
* STRATEGY:
* 1. Load all photos with faces (from cache)
* 2. Scan each photo for the trained person
* 3. Create PhotoFaceTagEntity for matches
* 4. Progressive updates to "People" tab
* 5. Supports pause/resume via WorkManager
*
* SCHEDULING:
* - Runs in background with progress notifications
* - Can be cancelled by user
* - Automatically retries on failure
*
* INPUT DATA:
* - personId: String (UUID)
* - personName: String (for notifications)
* - threshold: Float (optional, default 0.70)
*
* OUTPUT DATA:
* - matchesFound: Int
* - photosScanned: Int
* - errorMessage: String? (if failed)
*/
@HiltWorker
class LibraryScanWorker @AssistedInject constructor(
@Assisted private val context: Context,
@Assisted workerParams: WorkerParameters,
private val imageDao: ImageDao,
private val faceModelDao: FaceModelDao,
private val photoFaceTagDao: PhotoFaceTagDao
) : CoroutineWorker(context, workerParams) {
    companion object {
        // Tag prefix: each scan is tagged WORK_NAME_PREFIX + personId so it can
        // be queried or cancelled per-person via WorkManager.
        const val WORK_NAME_PREFIX = "library_scan_"

        // Input-data keys consumed by doWork (set in createWorkRequest).
        const val KEY_PERSON_ID = "person_id"
        const val KEY_PERSON_NAME = "person_name"
        const val KEY_THRESHOLD = "threshold"

        // Progress / output-data keys observable by the UI layer.
        const val KEY_PROGRESS_CURRENT = "progress_current"
        const val KEY_PROGRESS_TOTAL = "progress_total"
        const val KEY_MATCHES_FOUND = "matches_found"
        const val KEY_PHOTOS_SCANNED = "photos_scanned"

        private const val DEFAULT_THRESHOLD = 0.70f // Slightly looser than validation
        private const val BATCH_SIZE = 20
        // NOTE(review): MAX_RETRIES is declared but not referenced in the visible
        // portion of this worker — confirm a retry policy actually uses it.
        private const val MAX_RETRIES = 3

        /**
         * Create work request for library scan
         *
         * @param personId   UUID of the trained person to search for
         * @param personName display name (carried in input data for notifications)
         * @param threshold  minimum match confidence; defaults to [DEFAULT_THRESHOLD]
         * @return a one-time request constrained to battery-not-low and tagged
         *         with [WORK_NAME_PREFIX] + personId
         */
        fun createWorkRequest(
            personId: String,
            personName: String,
            threshold: Float = DEFAULT_THRESHOLD
        ): OneTimeWorkRequest {
            val inputData = workDataOf(
                KEY_PERSON_ID to personId,
                KEY_PERSON_NAME to personName,
                KEY_THRESHOLD to threshold
            )
            return OneTimeWorkRequestBuilder<LibraryScanWorker>()
                .setInputData(inputData)
                .setConstraints(
                    Constraints.Builder()
                        .setRequiresBatteryNotLow(true) // Don't drain battery
                        .build()
                )
                .addTag(WORK_NAME_PREFIX + personId)
                .build()
        }
    }
override suspend fun doWork(): Result = withContext(Dispatchers.Default) {
try {
// Get input parameters
val personId = inputData.getString(KEY_PERSON_ID)
?: return@withContext Result.failure(
workDataOf("error" to "Missing person ID")
)
val personName = inputData.getString(KEY_PERSON_NAME) ?: "Unknown"
val threshold = inputData.getFloat(KEY_THRESHOLD, DEFAULT_THRESHOLD)
// Check if stopped
if (isStopped) {
return@withContext Result.failure()
}
// Step 1: Get face model
val faceModel = withContext(Dispatchers.IO) {
faceModelDao.getFaceModelByPersonId(personId)
} ?: return@withContext Result.failure(
workDataOf("error" to "Face model not found")
)
setProgress(workDataOf(
KEY_PROGRESS_CURRENT to 0,
KEY_PROGRESS_TOTAL to 100
))
// Step 2: Get all photos with faces (from cache)
val photosWithFaces = withContext(Dispatchers.IO) {
imageDao.getImagesWithFaces()
}
if (photosWithFaces.isEmpty()) {
return@withContext Result.success(
workDataOf(
KEY_MATCHES_FOUND to 0,
KEY_PHOTOS_SCANNED to 0
)
)
}
// Step 3: Initialize ML components
val faceNetModel = FaceNetModel(context)
val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setMinFaceSize(0.15f)
.build()
)
val modelEmbedding = faceModel.getEmbeddingArray()
var matchesFound = 0
var photosScanned = 0
try {
// Step 4: Process in batches
photosWithFaces.chunked(BATCH_SIZE).forEach { batch ->
if (isStopped) {
return@forEach
}
// Scan batch
batch.forEach { photo ->
try {
val tags = scanPhotoForPerson(
photo = photo,
personId = personId,
faceModelId = faceModel.id,
modelEmbedding = modelEmbedding,
faceNetModel = faceNetModel,
detector = detector,
threshold = threshold
)
if (tags.isNotEmpty()) {
// Save tags
withContext(Dispatchers.IO) {
photoFaceTagDao.insertTags(tags)
}
matchesFound += tags.size
}
photosScanned++
// Update progress
if (photosScanned % 10 == 0) {
val progress = (photosScanned * 100 / photosWithFaces.size)
setProgress(workDataOf(
KEY_PROGRESS_CURRENT to photosScanned,
KEY_PROGRESS_TOTAL to photosWithFaces.size,
KEY_MATCHES_FOUND to matchesFound
))
}
} catch (e: Exception) {
// Skip failed photos, continue scanning
}
}
}
// Success!
Result.success(
workDataOf(
KEY_MATCHES_FOUND to matchesFound,
KEY_PHOTOS_SCANNED to photosScanned
)
)
} finally {
faceNetModel.close()
detector.close()
}
} catch (e: Exception) {
// Retry on failure
if (runAttemptCount < MAX_RETRIES) {
Result.retry()
} else {
Result.failure(
workDataOf("error" to (e.message ?: "Unknown error"))
)
}
}
}
/**
* Scan a single photo for the person
*/
private suspend fun scanPhotoForPerson(
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
personId: String,
faceModelId: String,
modelEmbedding: FloatArray,
faceNetModel: FaceNetModel,
detector: com.google.mlkit.vision.face.FaceDetector,
threshold: Float
): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) {
try {
// Load bitmap
val bitmap = loadBitmapDownsampled(Uri.parse(photo.imageUri), 768)
?: return@withContext emptyList()
// Detect faces
val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await()
// Check each face
val tags = faces.mapNotNull { face ->
try {
// Crop face
val faceBitmap = android.graphics.Bitmap.createBitmap(
bitmap,
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
// Generate embedding
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()
// Calculate similarity
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
if (similarity >= threshold) {
PhotoFaceTagEntity.create(
imageId = photo.imageId,
faceModelId = faceModelId,
boundingBox = face.boundingBox,
confidence = similarity,
faceEmbedding = faceEmbedding
)
} else {
null
}
} catch (e: Exception) {
null
}
}
bitmap.recycle()
tags
} catch (e: Exception) {
emptyList()
}
}
/**
* Load bitmap with downsampling for memory efficiency
*/
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): android.graphics.Bitmap? {
return try {
val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true }
context.contentResolver.openInputStream(uri)?.use {
BitmapFactory.decodeStream(it, null, opts)
}
var sample = 1
while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) {
sample *= 2
}
val finalOpts = BitmapFactory.Options().apply {
inSampleSize = sample
}
context.contentResolver.openInputStream(uri)?.use {
BitmapFactory.decodeStream(it, null, finalOpts)
}
} catch (e: Exception) {
null
}
}
}