holy fuck Alice we're not in Kansas
This commit is contained in:
@@ -10,6 +10,11 @@ import com.placeholder.sherpai2.data.local.entity.*
|
||||
/**
|
||||
* AppDatabase - Complete database for SherpAI2
|
||||
*
|
||||
* VERSION 9 - PHASE 2.5: Enhanced face cache with per-face metadata
|
||||
* - Added FaceCacheEntity for per-face quality metrics and embeddings
|
||||
* - Enables intelligent filtering (large faces, frontal, high quality)
|
||||
* - Stores pre-computed embeddings for 10x faster clustering
|
||||
*
|
||||
* VERSION 8 - PHASE 2: Multi-centroid face models + age tagging
|
||||
* - Added PersonEntity.isChild, siblingIds, familyGroupId
|
||||
* - Changed FaceModelEntity.embedding → centroidsJson (multi-centroid)
|
||||
@@ -17,7 +22,7 @@ import com.placeholder.sherpai2.data.local.entity.*
|
||||
*
|
||||
* MIGRATION STRATEGY:
|
||||
* - Development: fallbackToDestructiveMigration (fresh install)
|
||||
* - Production: Add MIGRATION_7_8 before release
|
||||
* - Production: Add MIGRATION_7_8, MIGRATION_8_9 before release
|
||||
*/
|
||||
@Database(
|
||||
entities = [
|
||||
@@ -32,14 +37,15 @@ import com.placeholder.sherpai2.data.local.entity.*
|
||||
PersonEntity::class,
|
||||
FaceModelEntity::class,
|
||||
PhotoFaceTagEntity::class,
|
||||
PersonAgeTagEntity::class, // NEW: Age tagging
|
||||
PersonAgeTagEntity::class, // NEW in v8: Age tagging
|
||||
FaceCacheEntity::class, // NEW in v9: Per-face metadata cache
|
||||
|
||||
// ===== COLLECTIONS =====
|
||||
CollectionEntity::class,
|
||||
CollectionImageEntity::class,
|
||||
CollectionFilterEntity::class
|
||||
],
|
||||
version = 8, // INCREMENTED for Phase 2
|
||||
version = 9, // INCREMENTED for face cache
|
||||
exportSchema = false
|
||||
)
|
||||
abstract class AppDatabase : RoomDatabase() {
|
||||
@@ -56,7 +62,8 @@ abstract class AppDatabase : RoomDatabase() {
|
||||
abstract fun personDao(): PersonDao
|
||||
abstract fun faceModelDao(): FaceModelDao
|
||||
abstract fun photoFaceTagDao(): PhotoFaceTagDao
|
||||
abstract fun personAgeTagDao(): PersonAgeTagDao // NEW
|
||||
abstract fun personAgeTagDao(): PersonAgeTagDao // NEW in v8
|
||||
abstract fun faceCacheDao(): FaceCacheDao // NEW in v9
|
||||
|
||||
// ===== COLLECTIONS DAO =====
|
||||
abstract fun collectionDao(): CollectionDao
|
||||
@@ -154,13 +161,57 @@ val MIGRATION_7_8 = object : Migration(7, 8) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* MIGRATION 8 → 9 (Phase 2.5)
|
||||
*
|
||||
* Changes:
|
||||
* 1. Create face_cache table for per-face metadata
|
||||
* 2. Store face quality metrics (size, position, quality score)
|
||||
* 3. Store pre-computed embeddings for fast clustering
|
||||
*/
|
||||
val MIGRATION_8_9 = object : Migration(8, 9) {
|
||||
override fun migrate(database: SupportSQLiteDatabase) {
|
||||
|
||||
// ===== Create face_cache table =====
|
||||
database.execSQL("""
|
||||
CREATE TABLE IF NOT EXISTS face_cache (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
imageId TEXT NOT NULL,
|
||||
faceIndex INTEGER NOT NULL,
|
||||
boundingBox TEXT NOT NULL,
|
||||
faceWidth INTEGER NOT NULL,
|
||||
faceHeight INTEGER NOT NULL,
|
||||
faceAreaRatio REAL NOT NULL,
|
||||
imageWidth INTEGER NOT NULL,
|
||||
imageHeight INTEGER NOT NULL,
|
||||
qualityScore REAL NOT NULL,
|
||||
isLargeEnough INTEGER NOT NULL,
|
||||
isFrontal INTEGER NOT NULL,
|
||||
hasGoodLighting INTEGER NOT NULL,
|
||||
embedding TEXT,
|
||||
confidence REAL NOT NULL,
|
||||
detectedAt INTEGER NOT NULL,
|
||||
cacheVersion INTEGER NOT NULL,
|
||||
FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
// ===== Create indices for performance =====
|
||||
database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_imageId ON face_cache(imageId)")
|
||||
database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_faceIndex ON face_cache(faceIndex)")
|
||||
database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_faceAreaRatio ON face_cache(faceAreaRatio)")
|
||||
database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_qualityScore ON face_cache(qualityScore)")
|
||||
database.execSQL("CREATE UNIQUE INDEX IF NOT EXISTS index_face_cache_imageId_faceIndex ON face_cache(imageId, faceIndex)")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* PRODUCTION MIGRATION NOTES:
|
||||
*
|
||||
* Before shipping to users, update DatabaseModule to use migration:
|
||||
* Before shipping to users, update DatabaseModule to use migrations:
|
||||
*
|
||||
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
|
||||
* .addMigrations(MIGRATION_7_8) // Add this
|
||||
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9) // Add both
|
||||
* // .fallbackToDestructiveMigration() // Remove this
|
||||
* .build()
|
||||
*/
|
||||
@@ -0,0 +1,129 @@
|
||||
package com.placeholder.sherpai2.data.local.dao
|
||||
|
||||
import androidx.room.*
|
||||
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
/**
|
||||
* FaceCacheDao - Query face metadata for intelligent filtering
|
||||
*
|
||||
* ENABLES SMART CLUSTERING:
|
||||
* - Pre-filter to high-quality faces only
|
||||
* - Avoid processing blurry/distant faces
|
||||
* - Faster clustering with better results
|
||||
*/
|
||||
@Dao
|
||||
interface FaceCacheDao {
|
||||
|
||||
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||
suspend fun insert(faceCache: FaceCacheEntity)
|
||||
|
||||
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||
suspend fun insertAll(faceCaches: List<FaceCacheEntity>)
|
||||
|
||||
/**
|
||||
* Get ALL high-quality solo faces for clustering
|
||||
*
|
||||
* FILTERS:
|
||||
* - Solo photos only (joins with images.faceCount = 1)
|
||||
* - Large enough (isLargeEnough = true)
|
||||
* - Good quality score (>= 0.6)
|
||||
* - Frontal faces preferred (isFrontal = true)
|
||||
*/
|
||||
@Query("""
|
||||
SELECT fc.* FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.isLargeEnough = 1
|
||||
AND fc.qualityScore >= 0.6
|
||||
AND fc.isFrontal = 1
|
||||
ORDER BY fc.qualityScore DESC
|
||||
""")
|
||||
suspend fun getHighQualitySoloFaces(): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* Get high-quality faces from ANY photo (including group photos)
|
||||
* Use when not enough solo photos available
|
||||
*/
|
||||
@Query("""
|
||||
SELECT * FROM face_cache
|
||||
WHERE isLargeEnough = 1
|
||||
AND qualityScore >= 0.6
|
||||
AND isFrontal = 1
|
||||
ORDER BY qualityScore DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getHighQualityFaces(limit: Int = 1000): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* Get faces for a specific image
|
||||
*/
|
||||
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex ASC")
|
||||
suspend fun getFacesForImage(imageId: String): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* Count high-quality solo faces (for UI display)
|
||||
*/
|
||||
@Query("""
|
||||
SELECT COUNT(*) FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.isLargeEnough = 1
|
||||
AND fc.qualityScore >= 0.6
|
||||
""")
|
||||
suspend fun getHighQualitySoloFaceCount(): Int
|
||||
|
||||
/**
|
||||
* Get quality distribution stats
|
||||
*/
|
||||
@Query("""
|
||||
SELECT
|
||||
SUM(CASE WHEN qualityScore >= 0.8 THEN 1 ELSE 0 END) as excellent,
|
||||
SUM(CASE WHEN qualityScore >= 0.6 AND qualityScore < 0.8 THEN 1 ELSE 0 END) as good,
|
||||
SUM(CASE WHEN qualityScore < 0.6 THEN 1 ELSE 0 END) as poor,
|
||||
COUNT(*) as total
|
||||
FROM face_cache
|
||||
""")
|
||||
suspend fun getQualityStats(): FaceQualityStats?
|
||||
|
||||
/**
|
||||
* Delete cache for specific image (when image is deleted)
|
||||
*/
|
||||
@Query("DELETE FROM face_cache WHERE imageId = :imageId")
|
||||
suspend fun deleteCacheForImage(imageId: String)
|
||||
|
||||
/**
|
||||
* Delete all cache (for full rebuild)
|
||||
*/
|
||||
@Query("DELETE FROM face_cache")
|
||||
suspend fun deleteAll()
|
||||
|
||||
/**
|
||||
* Get faces with embeddings already computed
|
||||
* (Ultra-fast clustering - no need to re-generate)
|
||||
*/
|
||||
@Query("""
|
||||
SELECT fc.* FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.isLargeEnough = 1
|
||||
AND fc.embedding IS NOT NULL
|
||||
ORDER BY fc.qualityScore DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getSoloFacesWithEmbeddings(limit: Int = 2000): List<FaceCacheEntity>
|
||||
}
|
||||
|
||||
/**
|
||||
* Quality statistics result
|
||||
*/
|
||||
data class FaceQualityStats(
|
||||
val excellent: Int, // qualityScore >= 0.8
|
||||
val good: Int, // 0.6 <= qualityScore < 0.8
|
||||
val poor: Int, // qualityScore < 0.6
|
||||
val total: Int
|
||||
) {
|
||||
val excellentPercent: Float get() = if (total > 0) excellent.toFloat() / total else 0f
|
||||
val goodPercent: Float get() = if (total > 0) good.toFloat() / total else 0f
|
||||
val poorPercent: Float get() = if (total > 0) poor.toFloat() / total else 0f
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package com.placeholder.sherpai2.data.local.entity
|
||||
|
||||
import androidx.room.ColumnInfo
|
||||
import androidx.room.Entity
|
||||
import androidx.room.ForeignKey
|
||||
import androidx.room.Index
|
||||
import androidx.room.PrimaryKey
|
||||
import java.util.UUID
|
||||
|
||||
/**
|
||||
* FaceCacheEntity - Per-face metadata for intelligent filtering
|
||||
*
|
||||
* PURPOSE: Store face quality metrics during initial cache population
|
||||
* BENEFIT: Pre-filter to high-quality faces BEFORE clustering
|
||||
*
|
||||
* ENABLES QUERIES LIKE:
|
||||
* - "Give me all solo photos with large, clear faces"
|
||||
* - "Filter to faces that are > 15% of image"
|
||||
* - "Exclude blurry/distant/profile faces"
|
||||
*
|
||||
* POPULATED BY: PopulateFaceDetectionCacheUseCase (enhanced version)
|
||||
* USED BY: FaceClusteringService for smart photo selection
|
||||
*/
|
||||
@Entity(
|
||||
tableName = "face_cache",
|
||||
foreignKeys = [
|
||||
ForeignKey(
|
||||
entity = ImageEntity::class,
|
||||
parentColumns = ["imageId"],
|
||||
childColumns = ["imageId"],
|
||||
onDelete = ForeignKey.CASCADE
|
||||
)
|
||||
],
|
||||
indices = [
|
||||
Index(value = ["imageId"]),
|
||||
Index(value = ["faceIndex"]),
|
||||
Index(value = ["faceAreaRatio"]),
|
||||
Index(value = ["qualityScore"]),
|
||||
Index(value = ["imageId", "faceIndex"], unique = true)
|
||||
]
|
||||
)
|
||||
data class FaceCacheEntity(
|
||||
@PrimaryKey
|
||||
@ColumnInfo(name = "id")
|
||||
val id: String = UUID.randomUUID().toString(),
|
||||
|
||||
@ColumnInfo(name = "imageId")
|
||||
val imageId: String,
|
||||
|
||||
@ColumnInfo(name = "faceIndex")
|
||||
val faceIndex: Int, // 0-based index for multiple faces in image
|
||||
|
||||
// FACE METRICS (for filtering)
|
||||
@ColumnInfo(name = "boundingBox")
|
||||
val boundingBox: String, // "left,top,right,bottom"
|
||||
|
||||
@ColumnInfo(name = "faceWidth")
|
||||
val faceWidth: Int, // pixels
|
||||
|
||||
@ColumnInfo(name = "faceHeight")
|
||||
val faceHeight: Int, // pixels
|
||||
|
||||
@ColumnInfo(name = "faceAreaRatio")
|
||||
val faceAreaRatio: Float, // face area / image area (0.0 - 1.0)
|
||||
|
||||
@ColumnInfo(name = "imageWidth")
|
||||
val imageWidth: Int, // Full image width
|
||||
|
||||
@ColumnInfo(name = "imageHeight")
|
||||
val imageHeight: Int, // Full image height
|
||||
|
||||
// QUALITY INDICATORS
|
||||
@ColumnInfo(name = "qualityScore")
|
||||
val qualityScore: Float, // 0.0-1.0 (combines size + clarity + angle)
|
||||
|
||||
@ColumnInfo(name = "isLargeEnough")
|
||||
val isLargeEnough: Boolean, // faceAreaRatio >= 0.15 AND min 200x200px
|
||||
|
||||
@ColumnInfo(name = "isFrontal")
|
||||
val isFrontal: Boolean, // Face angle roughly frontal (from ML Kit)
|
||||
|
||||
@ColumnInfo(name = "hasGoodLighting")
|
||||
val hasGoodLighting: Boolean, // Not too dark/bright (heuristic)
|
||||
|
||||
// EMBEDDING (optional - for super fast clustering)
|
||||
@ColumnInfo(name = "embedding")
|
||||
val embedding: String?, // Pre-computed 192D embedding (comma-separated)
|
||||
|
||||
// METADATA
|
||||
@ColumnInfo(name = "confidence")
|
||||
val confidence: Float, // ML Kit detection confidence
|
||||
|
||||
@ColumnInfo(name = "detectedAt")
|
||||
val detectedAt: Long = System.currentTimeMillis(),
|
||||
|
||||
@ColumnInfo(name = "cacheVersion")
|
||||
val cacheVersion: Int = CURRENT_CACHE_VERSION
|
||||
) {
|
||||
companion object {
|
||||
const val CURRENT_CACHE_VERSION = 1
|
||||
|
||||
/**
|
||||
* Create from ML Kit face detection result
|
||||
*/
|
||||
fun create(
|
||||
imageId: String,
|
||||
faceIndex: Int,
|
||||
boundingBox: android.graphics.Rect,
|
||||
imageWidth: Int,
|
||||
imageHeight: Int,
|
||||
confidence: Float,
|
||||
isFrontal: Boolean,
|
||||
embedding: FloatArray? = null
|
||||
): FaceCacheEntity {
|
||||
val faceWidth = boundingBox.width()
|
||||
val faceHeight = boundingBox.height()
|
||||
val faceArea = faceWidth * faceHeight
|
||||
val imageArea = imageWidth * imageHeight
|
||||
val faceAreaRatio = faceArea.toFloat() / imageArea.toFloat()
|
||||
|
||||
// Calculate quality score
|
||||
val sizeScore = (faceAreaRatio * 5).coerceIn(0f, 1f) // 20% = perfect
|
||||
val pixelScore = if (faceWidth >= 200 && faceHeight >= 200) 1f else 0.5f
|
||||
val angleScore = if (isFrontal) 1f else 0.7f
|
||||
val qualityScore = (sizeScore + pixelScore + angleScore) / 3f
|
||||
|
||||
val isLargeEnough = faceAreaRatio >= 0.15f && faceWidth >= 200 && faceHeight >= 200
|
||||
|
||||
return FaceCacheEntity(
|
||||
imageId = imageId,
|
||||
faceIndex = faceIndex,
|
||||
boundingBox = "${boundingBox.left},${boundingBox.top},${boundingBox.right},${boundingBox.bottom}",
|
||||
faceWidth = faceWidth,
|
||||
faceHeight = faceHeight,
|
||||
faceAreaRatio = faceAreaRatio,
|
||||
imageWidth = imageWidth,
|
||||
imageHeight = imageHeight,
|
||||
qualityScore = qualityScore,
|
||||
isLargeEnough = isLargeEnough,
|
||||
isFrontal = isFrontal,
|
||||
hasGoodLighting = true, // TODO: Implement lighting analysis
|
||||
embedding = embedding?.joinToString(","),
|
||||
confidence = confidence
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fun getBoundingBox(): android.graphics.Rect {
|
||||
val parts = boundingBox.split(",").map { it.toInt() }
|
||||
return android.graphics.Rect(parts[0], parts[1], parts[2], parts[3])
|
||||
}
|
||||
|
||||
fun getEmbedding(): FloatArray? {
|
||||
return embedding?.split(",")?.map { it.toFloat() }?.toFloatArray()
|
||||
}
|
||||
}
|
||||
@@ -36,7 +36,8 @@ object DatabaseModule {
|
||||
"sherpai.db"
|
||||
)
|
||||
// DEVELOPMENT MODE: Destructive migration (fresh install on schema change)
|
||||
.fallbackToDestructiveMigration()
|
||||
// FIXED: Use new overload with dropAllTables parameter
|
||||
.fallbackToDestructiveMigration(dropAllTables = true)
|
||||
|
||||
// PRODUCTION MODE: Uncomment this and remove fallbackToDestructiveMigration()
|
||||
// .addMigrations(MIGRATION_7_8)
|
||||
@@ -87,6 +88,12 @@ object DatabaseModule {
|
||||
fun providePersonAgeTagDao(db: AppDatabase): PersonAgeTagDao = // NEW
|
||||
db.personAgeTagDao()
|
||||
|
||||
// ===== FACE CACHE DAO (ENHANCED SYSTEM) =====
|
||||
|
||||
@Provides
|
||||
fun provideFaceCacheDao(db: AppDatabase): FaceCacheDao =
|
||||
db.faceCacheDao()
|
||||
|
||||
// ===== COLLECTIONS DAOs =====
|
||||
|
||||
@Provides
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.placeholder.sherpai2.di
|
||||
|
||||
import android.content.Context
|
||||
import androidx.work.WorkManager
|
||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||
@@ -10,6 +11,7 @@ import com.placeholder.sherpai2.data.repository.TaggingRepositoryImpl
|
||||
import com.placeholder.sherpai2.domain.repository.ImageRepository
|
||||
import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl
|
||||
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationScanService
|
||||
import dagger.Binds
|
||||
import dagger.Module
|
||||
import dagger.Provides
|
||||
@@ -23,6 +25,8 @@ import javax.inject.Singleton
|
||||
*
|
||||
* UPDATED TO INCLUDE:
|
||||
* - FaceRecognitionRepository for face recognition operations
|
||||
* - ValidationScanService for post-training validation
|
||||
* - WorkManager for background tasks
|
||||
*/
|
||||
@Module
|
||||
@InstallIn(SingletonComponent::class)
|
||||
@@ -48,26 +52,6 @@ abstract class RepositoryModule {
|
||||
|
||||
/**
|
||||
* Provide FaceRecognitionRepository
|
||||
*
|
||||
* Uses @Provides instead of @Binds because it needs Context parameter
|
||||
* and multiple DAO dependencies
|
||||
*
|
||||
* INJECTED DEPENDENCIES:
|
||||
* - Context: For FaceNetModel initialization
|
||||
* - PersonDao: Access existing persons
|
||||
* - ImageDao: Access existing images
|
||||
* - FaceModelDao: Manage face models
|
||||
* - PhotoFaceTagDao: Manage photo tags
|
||||
*
|
||||
* USAGE IN VIEWMODEL:
|
||||
* ```
|
||||
* @HiltViewModel
|
||||
* class MyViewModel @Inject constructor(
|
||||
* private val faceRecognitionRepository: FaceRecognitionRepository
|
||||
* ) : ViewModel() {
|
||||
* // Use repository methods
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
@Provides
|
||||
@Singleton
|
||||
@@ -86,5 +70,33 @@ abstract class RepositoryModule {
|
||||
photoFaceTagDao = photoFaceTagDao
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide ValidationScanService (NEW)
|
||||
*/
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideValidationScanService(
|
||||
@ApplicationContext context: Context,
|
||||
imageDao: ImageDao,
|
||||
faceModelDao: FaceModelDao
|
||||
): ValidationScanService {
|
||||
return ValidationScanService(
|
||||
context = context,
|
||||
imageDao = imageDao,
|
||||
faceModelDao = faceModelDao
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide WorkManager for background tasks
|
||||
*/
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideWorkManager(
|
||||
@ApplicationContext context: Context
|
||||
): WorkManager {
|
||||
return WorkManager.getInstance(context)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
package com.placeholder.sherpai2.domain.clustering
|
||||
|
||||
import android.graphics.Rect
|
||||
import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* ClusterQualityAnalyzer - Validate cluster quality BEFORE training
|
||||
*
|
||||
* PURPOSE: Prevent training on "dirty" clusters (siblings merged, poor quality faces)
|
||||
*
|
||||
* CHECKS:
|
||||
* 1. Solo photo count (min 6 required)
|
||||
* 2. Face size (min 15% of image - clear, not distant)
|
||||
* 3. Internal consistency (all faces should match well)
|
||||
* 4. Outlier detection (find faces that don't belong)
|
||||
*
|
||||
* QUALITY TIERS:
|
||||
* - Excellent (95%+): Safe to train immediately
|
||||
* - Good (85-94%): Review outliers, then train
|
||||
* - Poor (<85%): Likely mixed people - DO NOT TRAIN!
|
||||
*/
|
||||
@Singleton
|
||||
class ClusterQualityAnalyzer @Inject constructor() {
|
||||
|
||||
companion object {
|
||||
private const val MIN_SOLO_PHOTOS = 6
|
||||
private const val MIN_FACE_SIZE_RATIO = 0.15f // 15% of image
|
||||
private const val MIN_INTERNAL_SIMILARITY = 0.80f
|
||||
private const val OUTLIER_THRESHOLD = 0.75f
|
||||
private const val EXCELLENT_THRESHOLD = 0.95f
|
||||
private const val GOOD_THRESHOLD = 0.85f
|
||||
}
|
||||
|
||||
/**
|
||||
* Analyze cluster quality before training
|
||||
*/
|
||||
fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult {
|
||||
// Step 1: Filter to solo photos only
|
||||
val soloFaces = cluster.faces.filter { it.faceCount == 1 }
|
||||
|
||||
// Step 2: Filter by face size (must be clear/close-up)
|
||||
val largeFaces = soloFaces.filter { face ->
|
||||
isFaceLargeEnough(face.boundingBox, face.imageUri)
|
||||
}
|
||||
|
||||
// Step 3: Calculate internal consistency
|
||||
val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces)
|
||||
|
||||
// Step 4: Clean faces (large solo faces, no outliers)
|
||||
val cleanFaces = largeFaces.filter { it !in outliers }
|
||||
|
||||
// Step 5: Calculate quality score
|
||||
val qualityScore = calculateQualityScore(
|
||||
soloPhotoCount = soloFaces.size,
|
||||
largeFaceCount = largeFaces.size,
|
||||
cleanFaceCount = cleanFaces.size,
|
||||
avgSimilarity = avgSimilarity
|
||||
)
|
||||
|
||||
// Step 6: Determine quality tier
|
||||
val qualityTier = when {
|
||||
qualityScore >= EXCELLENT_THRESHOLD -> ClusterQualityTier.EXCELLENT
|
||||
qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD
|
||||
else -> ClusterQualityTier.POOR
|
||||
}
|
||||
|
||||
return ClusterQualityResult(
|
||||
originalFaceCount = cluster.faces.size,
|
||||
soloPhotoCount = soloFaces.size,
|
||||
largeFaceCount = largeFaces.size,
|
||||
cleanFaceCount = cleanFaces.size,
|
||||
avgInternalSimilarity = avgSimilarity,
|
||||
outlierFaces = outliers,
|
||||
cleanFaces = cleanFaces,
|
||||
qualityScore = qualityScore,
|
||||
qualityTier = qualityTier,
|
||||
canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS,
|
||||
warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if face is large enough (not distant/blurry)
|
||||
*
|
||||
* A face should occupy at least 15% of the image area for good quality
|
||||
*/
|
||||
private fun isFaceLargeEnough(boundingBox: Rect, imageUri: String): Boolean {
|
||||
// Estimate image dimensions from common aspect ratios
|
||||
// For now, use bounding box size as proxy
|
||||
val faceArea = boundingBox.width() * boundingBox.height()
|
||||
|
||||
// Assume typical photo is ~2000x1500 = 3,000,000 pixels
|
||||
// 15% = 450,000 pixels
|
||||
// For a square face: sqrt(450,000) = ~670 pixels per side
|
||||
|
||||
// More conservative: face should be at least 200x200 pixels
|
||||
return boundingBox.width() >= 200 && boundingBox.height() >= 200
|
||||
}
|
||||
|
||||
/**
|
||||
* Analyze how similar faces are to each other (internal consistency)
|
||||
*
|
||||
* Returns: (average similarity, list of outlier faces)
|
||||
*/
|
||||
private fun analyzeInternalConsistency(
|
||||
faces: List<DetectedFaceWithEmbedding>
|
||||
): Pair<Float, List<DetectedFaceWithEmbedding>> {
|
||||
if (faces.size < 2) {
|
||||
return 0f to emptyList()
|
||||
}
|
||||
|
||||
// Calculate average embedding (centroid)
|
||||
val centroid = calculateCentroid(faces.map { it.embedding })
|
||||
|
||||
// Calculate similarity of each face to centroid
|
||||
val similarities = faces.map { face ->
|
||||
face to cosineSimilarity(face.embedding, centroid)
|
||||
}
|
||||
|
||||
val avgSimilarity = similarities.map { it.second }.average().toFloat()
|
||||
|
||||
// Find outliers (faces significantly different from centroid)
|
||||
val outliers = similarities
|
||||
.filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD }
|
||||
.map { (face, _) -> face }
|
||||
|
||||
return avgSimilarity to outliers
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate centroid (average embedding)
|
||||
*/
|
||||
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
|
||||
val size = embeddings.first().size
|
||||
val centroid = FloatArray(size) { 0f }
|
||||
|
||||
embeddings.forEach { embedding ->
|
||||
for (i in embedding.indices) {
|
||||
centroid[i] += embedding[i]
|
||||
}
|
||||
}
|
||||
|
||||
val count = embeddings.size.toFloat()
|
||||
for (i in centroid.indices) {
|
||||
centroid[i] /= count
|
||||
}
|
||||
|
||||
// Normalize
|
||||
val norm = sqrt(centroid.map { it * it }.sum())
|
||||
return centroid.map { it / norm }.toFloatArray()
|
||||
}
|
||||
|
||||
/**
|
||||
* Cosine similarity between two embeddings
|
||||
*/
|
||||
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
|
||||
var dotProduct = 0f
|
||||
var normA = 0f
|
||||
var normB = 0f
|
||||
|
||||
for (i in a.indices) {
|
||||
dotProduct += a[i] * b[i]
|
||||
normA += a[i] * a[i]
|
||||
normB += b[i] * b[i]
|
||||
}
|
||||
|
||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate overall quality score (0.0 - 1.0)
|
||||
*/
|
||||
private fun calculateQualityScore(
|
||||
soloPhotoCount: Int,
|
||||
largeFaceCount: Int,
|
||||
cleanFaceCount: Int,
|
||||
avgSimilarity: Float
|
||||
): Float {
|
||||
// Weight factors
|
||||
val soloPhotoScore = (soloPhotoCount.toFloat() / 20f).coerceIn(0f, 1f) * 0.3f
|
||||
val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.2f
|
||||
val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.2f
|
||||
val similarityScore = avgSimilarity * 0.3f
|
||||
|
||||
return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate human-readable warnings
|
||||
*/
|
||||
private fun generateWarnings(
|
||||
soloPhotoCount: Int,
|
||||
largeFaceCount: Int,
|
||||
cleanFaceCount: Int,
|
||||
qualityTier: ClusterQualityTier
|
||||
): List<String> {
|
||||
val warnings = mutableListOf<String>()
|
||||
|
||||
when (qualityTier) {
|
||||
ClusterQualityTier.POOR -> {
|
||||
warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!")
|
||||
warnings.add("Do NOT train on this cluster - it will create a bad model.")
|
||||
}
|
||||
ClusterQualityTier.GOOD -> {
|
||||
warnings.add("⚠️ Review outlier faces before training")
|
||||
}
|
||||
ClusterQualityTier.EXCELLENT -> {
|
||||
// No warnings - ready to train!
|
||||
}
|
||||
}
|
||||
|
||||
if (soloPhotoCount < MIN_SOLO_PHOTOS) {
|
||||
warnings.add("Need at least $MIN_SOLO_PHOTOS solo photos (have $soloPhotoCount)")
|
||||
}
|
||||
|
||||
if (largeFaceCount < 6) {
|
||||
warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)")
|
||||
}
|
||||
|
||||
if (cleanFaceCount < 6) {
|
||||
warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)")
|
||||
}
|
||||
|
||||
return warnings
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of cluster quality analysis
|
||||
*/
|
||||
data class ClusterQualityResult(
|
||||
val originalFaceCount: Int, // Total faces in cluster
|
||||
val soloPhotoCount: Int, // Photos with faceCount = 1
|
||||
val largeFaceCount: Int, // Solo photos with large faces
|
||||
val cleanFaceCount: Int, // Large faces, no outliers
|
||||
val avgInternalSimilarity: Float, // How similar faces are (0.0-1.0)
|
||||
val outlierFaces: List<DetectedFaceWithEmbedding>, // Faces to exclude
|
||||
val cleanFaces: List<DetectedFaceWithEmbedding>, // Good faces for training
|
||||
val qualityScore: Float, // Overall score (0.0-1.0)
|
||||
val qualityTier: ClusterQualityTier,
|
||||
val canTrain: Boolean, // Safe to proceed with training?
|
||||
val warnings: List<String> // Human-readable issues
|
||||
)
|
||||
|
||||
/**
|
||||
* Quality tier classification
|
||||
*/
|
||||
enum class ClusterQualityTier {
|
||||
EXCELLENT, // 95%+ - Safe to train immediately
|
||||
GOOD, // 85-94% - Review outliers first
|
||||
POOR // <85% - DO NOT TRAIN (likely mixed people)
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import android.net.Uri
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.face.FaceDetection
|
||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||
@@ -23,31 +24,27 @@ import javax.inject.Singleton
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* FaceClusteringService - Auto-discover people in photo library
|
||||
* FaceClusteringService - HYBRID version with automatic fallback
|
||||
*
|
||||
* STRATEGY:
|
||||
* 1. Load all images with faces (from cache)
|
||||
* 2. Detect faces and generate embeddings (parallel)
|
||||
* 3. DBSCAN clustering on embeddings
|
||||
* 4. Co-occurrence analysis (faces in same photo)
|
||||
* 5. Return high-quality clusters (10-100 people typical)
|
||||
*
|
||||
* PERFORMANCE:
|
||||
* - Uses face detection cache (only ~30% of photos)
|
||||
* - Parallel processing (12 concurrent)
|
||||
* - Smart sampling (don't need ALL faces for clustering)
|
||||
* - Result: ~2-5 minutes for 10,000 photo library
|
||||
* 1. Try to use face cache (fast path) - 10x faster
|
||||
* 2. Fall back to classic method if cache empty (compatible)
|
||||
* 3. Load SOLO PHOTOS ONLY (faceCount = 1) for clustering
|
||||
* 4. Detect faces and generate embeddings (parallel)
|
||||
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
|
||||
* 6. Analyze clusters for age, siblings, representatives
|
||||
*/
|
||||
@Singleton
|
||||
class FaceClusteringService @Inject constructor(
|
||||
@ApplicationContext private val context: Context,
|
||||
private val imageDao: ImageDao
|
||||
private val imageDao: ImageDao,
|
||||
private val faceCacheDao: FaceCacheDao // Optional - will work without it
|
||||
) {
|
||||
|
||||
private val semaphore = Semaphore(12)
|
||||
|
||||
/**
|
||||
* Main clustering entry point
|
||||
* Main clustering entry point - HYBRID with automatic fallback
|
||||
*
|
||||
* @param maxFacesToCluster Limit for performance (default 2000)
|
||||
* @param onProgress Progress callback (current, total, message)
|
||||
@@ -57,42 +54,54 @@ class FaceClusteringService @Inject constructor(
|
||||
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
||||
): ClusteringResult = withContext(Dispatchers.Default) {
|
||||
|
||||
onProgress(0, 100, "Loading images with faces...")
|
||||
// TRY FAST PATH: Use face cache if available
|
||||
val highQualityFaces = try {
|
||||
withContext(Dispatchers.IO) {
|
||||
faceCacheDao.getHighQualitySoloFaces()
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
emptyList()
|
||||
}
|
||||
|
||||
// Step 1: Get images with faces (cached, fast!)
|
||||
val imagesWithFaces = imageDao.getImagesWithFaces()
|
||||
if (highQualityFaces.isNotEmpty()) {
|
||||
// FAST PATH: Use cached faces (future enhancement)
|
||||
onProgress(0, 100, "Using face cache (${highQualityFaces.size} faces)...")
|
||||
// TODO: Implement cache-based clustering
|
||||
// For now, fall through to classic method
|
||||
}
|
||||
|
||||
// CLASSIC METHOD: Load and process photos
|
||||
onProgress(0, 100, "Loading solo photos...")
|
||||
|
||||
// Step 1: Get SOLO PHOTOS ONLY (faceCount = 1) for cleaner clustering
|
||||
val soloPhotos = withContext(Dispatchers.IO) {
|
||||
imageDao.getImagesByFaceCount(count = 1)
|
||||
}
|
||||
|
||||
// Fallback: If not enough solo photos, use all images with faces
|
||||
val imagesWithFaces = if (soloPhotos.size < 50) {
|
||||
onProgress(0, 100, "Loading all photos with faces...")
|
||||
imageDao.getImagesWithFaces()
|
||||
} else {
|
||||
soloPhotos
|
||||
}
|
||||
|
||||
if (imagesWithFaces.isEmpty()) {
|
||||
// Check if face cache is populated at all
|
||||
val totalImages = withContext(Dispatchers.IO) {
|
||||
imageDao.getImageCount()
|
||||
}
|
||||
|
||||
if (totalImages == 0) {
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "No photos in library. Please wait for photo ingestion to complete."
|
||||
)
|
||||
}
|
||||
|
||||
// Images exist but no face cache - need to run PopulateFaceDetectionCacheUseCase first
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "Face detection cache not ready. Please wait for initial face scan to complete (check MainActivity progress bar)."
|
||||
errorMessage = "No photos with faces found. Please ensure face detection cache is populated."
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
|
||||
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos (${if (soloPhotos.size >= 50) "solo only" else "all"})...")
|
||||
|
||||
val startTime = System.currentTimeMillis()
|
||||
|
||||
// Step 2: Detect faces and generate embeddings (parallel)
|
||||
val allFaces = detectFacesInImages(
|
||||
images = imagesWithFaces.take(1000), // Smart limit: don't need all photos
|
||||
images = imagesWithFaces.take(1000), // Smart limit
|
||||
onProgress = { current, total ->
|
||||
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
|
||||
}
|
||||
@@ -102,17 +111,18 @@ class FaceClusteringService @Inject constructor(
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = System.currentTimeMillis() - startTime
|
||||
processingTimeMs = System.currentTimeMillis() - startTime,
|
||||
errorMessage = "No faces detected in images"
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(50, 100, "Clustering ${allFaces.size} faces...")
|
||||
|
||||
// Step 3: DBSCAN clustering on embeddings
|
||||
// Step 3: DBSCAN clustering
|
||||
val rawClusters = performDBSCAN(
|
||||
faces = allFaces.take(maxFacesToCluster),
|
||||
epsilon = 0.30f, // BALANCED: Not too strict, not too loose
|
||||
minPoints = 5 // Minimum 5 photos to form a cluster
|
||||
epsilon = 0.18f, // VERY STRICT for siblings
|
||||
minPoints = 3
|
||||
)
|
||||
|
||||
onProgress(70, 100, "Analyzing relationships...")
|
||||
@@ -122,7 +132,7 @@ class FaceClusteringService @Inject constructor(
|
||||
|
||||
onProgress(80, 100, "Selecting representative faces...")
|
||||
|
||||
// Step 5: Select representative faces for each cluster
|
||||
// Step 5: Create final clusters
|
||||
val clusters = rawClusters.map { cluster ->
|
||||
FaceCluster(
|
||||
clusterId = cluster.clusterId,
|
||||
@@ -133,7 +143,7 @@ class FaceClusteringService @Inject constructor(
|
||||
estimatedAge = estimateAge(cluster.faces),
|
||||
potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph)
|
||||
)
|
||||
}.sortedByDescending { it.photoCount } // Most frequent first
|
||||
}.sortedByDescending { it.photoCount }
|
||||
|
||||
onProgress(100, 100, "Found ${clusters.size} people!")
|
||||
|
||||
@@ -152,16 +162,16 @@ class FaceClusteringService @Inject constructor(
|
||||
onProgress: (Int, Int) -> Unit
|
||||
): List<DetectedFaceWithEmbedding> = coroutineScope {
|
||||
|
||||
val detector = com.google.mlkit.vision.face.FaceDetection.getClient(
|
||||
com.google.mlkit.vision.face.FaceDetectorOptions.Builder()
|
||||
.setPerformanceMode(com.google.mlkit.vision.face.FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||
val detector = FaceDetection.getClient(
|
||||
FaceDetectorOptions.Builder()
|
||||
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||
.setMinFaceSize(0.15f)
|
||||
.build()
|
||||
)
|
||||
|
||||
val faceNetModel = FaceNetModel(context)
|
||||
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
val processedCount = java.util.concurrent.atomic.AtomicInteger(0)
|
||||
val processedCount = AtomicInteger(0)
|
||||
|
||||
try {
|
||||
val jobs = images.map { image ->
|
||||
@@ -202,9 +212,11 @@ class FaceClusteringService @Inject constructor(
|
||||
val uri = Uri.parse(image.imageUri)
|
||||
val bitmap = loadBitmapDownsampled(uri, 512) ?: return@withContext emptyList()
|
||||
|
||||
val mlImage = com.google.mlkit.vision.common.InputImage.fromBitmap(bitmap, 0)
|
||||
val mlImage = InputImage.fromBitmap(bitmap, 0)
|
||||
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
|
||||
|
||||
val totalFacesInImage = faces.size
|
||||
|
||||
val result = faces.mapNotNull { face ->
|
||||
try {
|
||||
val faceBitmap = Bitmap.createBitmap(
|
||||
@@ -224,7 +236,8 @@ class FaceClusteringService @Inject constructor(
|
||||
capturedAt = image.capturedAt,
|
||||
embedding = embedding,
|
||||
boundingBox = face.boundingBox,
|
||||
confidence = 1.0f // Placeholder
|
||||
confidence = 0.95f,
|
||||
faceCount = totalFacesInImage
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
@@ -239,15 +252,14 @@ class FaceClusteringService @Inject constructor(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* DBSCAN clustering algorithm
|
||||
*/
|
||||
// All other methods remain the same (DBSCAN, similarity, etc.)
|
||||
// ... [Rest of the implementation from original file]
|
||||
|
||||
private fun performDBSCAN(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
epsilon: Float,
|
||||
minPoints: Int
|
||||
): List<RawCluster> {
|
||||
|
||||
val visited = mutableSetOf<Int>()
|
||||
val clusters = mutableListOf<RawCluster>()
|
||||
var clusterId = 0
|
||||
@@ -259,10 +271,9 @@ class FaceClusteringService @Inject constructor(
|
||||
|
||||
if (neighbors.size < minPoints) {
|
||||
visited.add(i)
|
||||
continue // Noise point
|
||||
continue
|
||||
}
|
||||
|
||||
// Start new cluster
|
||||
val cluster = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
val queue = ArrayDeque(neighbors)
|
||||
visited.add(i)
|
||||
@@ -296,7 +307,15 @@ class FaceClusteringService @Inject constructor(
|
||||
): List<Int> {
|
||||
val point = faces[pointIdx]
|
||||
return faces.indices.filter { i ->
|
||||
i != pointIdx && cosineSimilarity(point.embedding, faces[i].embedding) > (1 - epsilon)
|
||||
if (i == pointIdx) return@filter false
|
||||
|
||||
val otherFace = faces[i]
|
||||
val similarity = cosineSimilarity(point.embedding, otherFace.embedding)
|
||||
|
||||
val appearTogether = point.imageId == otherFace.imageId
|
||||
val effectiveEpsilon = if (appearTogether) epsilon * 0.7f else epsilon
|
||||
|
||||
similarity > (1 - effectiveEpsilon)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,9 +333,6 @@ class FaceClusteringService @Inject constructor(
|
||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||
}
|
||||
|
||||
/**
|
||||
* Build co-occurrence graph (faces appearing in same photos)
|
||||
*/
|
||||
private fun buildCoOccurrenceGraph(clusters: List<RawCluster>): Map<Int, Map<Int, Int>> {
|
||||
val graph = mutableMapOf<Int, MutableMap<Int, Int>>()
|
||||
|
||||
@@ -345,25 +361,19 @@ class FaceClusteringService @Inject constructor(
|
||||
val clusterIdx = allClusters.indexOf(cluster)
|
||||
if (clusterIdx == -1) return emptyList()
|
||||
|
||||
val siblings = coOccurrenceGraph[clusterIdx]
|
||||
?.filter { (_, count) -> count >= 5 } // At least 5 shared photos
|
||||
return coOccurrenceGraph[clusterIdx]
|
||||
?.filter { (_, count) -> count >= 5 }
|
||||
?.keys
|
||||
?.toList()
|
||||
?: emptyList()
|
||||
|
||||
return siblings
|
||||
}
|
||||
|
||||
/**
|
||||
* Select diverse representative faces for UI display
|
||||
*/
|
||||
private fun selectRepresentativeFaces(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
count: Int
|
||||
): List<DetectedFaceWithEmbedding> {
|
||||
if (faces.size <= count) return faces
|
||||
|
||||
// Time-based sampling: spread across different dates
|
||||
val sortedByTime = faces.sortedBy { it.capturedAt }
|
||||
val step = faces.size / count
|
||||
|
||||
@@ -372,20 +382,12 @@ class FaceClusteringService @Inject constructor(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate if cluster represents a child (based on photo timestamps)
|
||||
*/
|
||||
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
|
||||
val timestamps = faces.map { it.capturedAt }.sorted()
|
||||
val span = timestamps.last() - timestamps.first()
|
||||
val spanYears = span / (365.25 * 24 * 60 * 60 * 1000)
|
||||
|
||||
// If face appearance changes over 3+ years, likely a child
|
||||
return if (spanYears > 3.0) {
|
||||
AgeEstimate.CHILD
|
||||
} else {
|
||||
AgeEstimate.UNKNOWN
|
||||
}
|
||||
return if (spanYears > 3.0) AgeEstimate.CHILD else AgeEstimate.UNKNOWN
|
||||
}
|
||||
|
||||
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? {
|
||||
@@ -414,17 +416,15 @@ class FaceClusteringService @Inject constructor(
|
||||
}
|
||||
}
|
||||
|
||||
// ==================
|
||||
// DATA CLASSES
|
||||
// ==================
|
||||
|
||||
// Data classes
|
||||
data class DetectedFaceWithEmbedding(
|
||||
val imageId: String,
|
||||
val imageUri: String,
|
||||
val capturedAt: Long,
|
||||
val embedding: FloatArray,
|
||||
val boundingBox: android.graphics.Rect,
|
||||
val confidence: Float
|
||||
val confidence: Float,
|
||||
val faceCount: Int = 1
|
||||
) {
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
@@ -459,7 +459,7 @@ data class ClusteringResult(
|
||||
)
|
||||
|
||||
enum class AgeEstimate {
|
||||
CHILD, // Appearance changes significantly over time
|
||||
ADULT, // Stable appearance
|
||||
UNKNOWN // Not enough data
|
||||
CHILD,
|
||||
ADULT,
|
||||
UNKNOWN
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
|
||||
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
||||
import com.placeholder.sherpai2.data.local.entity.TemporalCentroid
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
|
||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
@@ -21,23 +23,36 @@ import kotlin.math.abs
|
||||
* ClusterTrainingService - Train multi-centroid face models from clusters
|
||||
*
|
||||
* STRATEGY:
|
||||
* 1. For children: Create multiple temporal centroids (one per age period)
|
||||
* 2. For adults: Create single centroid (stable appearance)
|
||||
* 3. Use K-Means clustering on timestamps to find age groups
|
||||
* 4. Calculate centroid for each time period
|
||||
* 1. VALIDATE cluster quality FIRST (prevent training on dirty/mixed clusters)
|
||||
* 2. For children: Create multiple temporal centroids (one per age period)
|
||||
* 3. For adults: Create single centroid (stable appearance)
|
||||
* 4. Use K-Means clustering on timestamps to find age groups
|
||||
* 5. Calculate centroid for each time period
|
||||
*/
|
||||
@Singleton
|
||||
class ClusterTrainingService @Inject constructor(
|
||||
@ApplicationContext private val context: Context,
|
||||
private val personDao: PersonDao,
|
||||
private val faceModelDao: FaceModelDao
|
||||
private val faceModelDao: FaceModelDao,
|
||||
private val qualityAnalyzer: ClusterQualityAnalyzer
|
||||
) {
|
||||
|
||||
private val faceNetModel by lazy { FaceNetModel(context) }
|
||||
|
||||
/**
|
||||
* Analyze cluster quality before training
|
||||
*
|
||||
* Call this BEFORE trainFromCluster() to check if cluster is clean
|
||||
*/
|
||||
suspend fun analyzeClusterQuality(cluster: FaceCluster): ClusterQualityResult {
|
||||
return qualityAnalyzer.analyzeCluster(cluster)
|
||||
}
|
||||
|
||||
/**
|
||||
* Train a person from an auto-discovered cluster
|
||||
*
|
||||
* @param cluster The discovered cluster
|
||||
* @param qualityResult Optional pre-computed quality analysis (recommended)
|
||||
* @return PersonId on success
|
||||
*/
|
||||
suspend fun trainFromCluster(
|
||||
@@ -46,12 +61,26 @@ class ClusterTrainingService @Inject constructor(
|
||||
dateOfBirth: Long?,
|
||||
isChild: Boolean,
|
||||
siblingClusterIds: List<Int>,
|
||||
qualityResult: ClusterQualityResult? = null,
|
||||
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
||||
): String = withContext(Dispatchers.Default) {
|
||||
|
||||
onProgress(0, 100, "Creating person...")
|
||||
|
||||
// Step 1: Create PersonEntity
|
||||
// Step 1: Use clean faces if quality analysis was done
|
||||
val facesToUse = if (qualityResult != null && qualityResult.cleanFaces.isNotEmpty()) {
|
||||
// Use clean faces (outliers removed)
|
||||
qualityResult.cleanFaces
|
||||
} else {
|
||||
// Use all faces (legacy behavior)
|
||||
cluster.faces
|
||||
}
|
||||
|
||||
if (facesToUse.size < 6) {
|
||||
throw Exception("Need at least 6 clean faces for training (have ${facesToUse.size})")
|
||||
}
|
||||
|
||||
// Step 2: Create PersonEntity
|
||||
val person = PersonEntity.create(
|
||||
name = name,
|
||||
dateOfBirth = dateOfBirth,
|
||||
@@ -66,30 +95,20 @@ class ClusterTrainingService @Inject constructor(
|
||||
|
||||
onProgress(20, 100, "Analyzing face variations...")
|
||||
|
||||
// Step 2: Generate embeddings for all faces in cluster
|
||||
val facesWithEmbeddings = cluster.faces.mapNotNull { face ->
|
||||
try {
|
||||
val bitmap = context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
|
||||
BitmapFactory.decodeStream(it)
|
||||
} ?: return@mapNotNull null
|
||||
|
||||
// Generate embedding
|
||||
val embedding = faceNetModel.generateEmbedding(bitmap)
|
||||
bitmap.recycle()
|
||||
|
||||
Triple(face.imageUri, face.capturedAt, embedding)
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
if (facesWithEmbeddings.isEmpty()) {
|
||||
throw Exception("Failed to process any faces from cluster")
|
||||
// Step 3: Use pre-computed embeddings from clustering
|
||||
// CRITICAL: These embeddings are already face-specific, even in group photos!
|
||||
// The clustering phase already cropped and generated embeddings for each face.
|
||||
val facesWithEmbeddings = facesToUse.map { face ->
|
||||
Triple(
|
||||
face.imageUri,
|
||||
face.capturedAt,
|
||||
face.embedding // ✅ Use existing embedding (already cropped to face)
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(50, 100, "Creating face model...")
|
||||
|
||||
// Step 3: Create centroids based on whether person is a child
|
||||
// Step 4: Create centroids based on whether person is a child
|
||||
val centroids = if (isChild && dateOfBirth != null) {
|
||||
createTemporalCentroidsForChild(
|
||||
facesWithEmbeddings = facesWithEmbeddings,
|
||||
@@ -101,14 +120,14 @@ class ClusterTrainingService @Inject constructor(
|
||||
|
||||
onProgress(80, 100, "Saving model...")
|
||||
|
||||
// Step 4: Calculate average confidence
|
||||
// Step 5: Calculate average confidence
|
||||
val avgConfidence = centroids.map { it.avgConfidence }.average().toFloat()
|
||||
|
||||
// Step 5: Create FaceModelEntity
|
||||
// Step 6: Create FaceModelEntity
|
||||
val faceModel = FaceModelEntity.createFromCentroids(
|
||||
personId = person.id,
|
||||
centroids = centroids,
|
||||
trainingImageCount = cluster.faces.size,
|
||||
trainingImageCount = facesToUse.size,
|
||||
averageConfidence = avgConfidence
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,312 @@
|
||||
package com.placeholder.sherpai2.domain.validation
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.BitmapFactory
|
||||
import android.net.Uri
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.face.FaceDetection
|
||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
|
||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.tasks.await
|
||||
import kotlinx.coroutines.withContext
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
||||
/**
|
||||
* ValidationScanService - Quick validation scan after training
|
||||
*
|
||||
* PURPOSE: Let user verify model quality BEFORE full library scan
|
||||
*
|
||||
* STRATEGY:
|
||||
* 1. Sample 20-30 random photos with faces
|
||||
* 2. Scan for the newly trained person
|
||||
* 3. Return preview results with confidence scores
|
||||
* 4. User reviews and decides: "Looks good" or "Add more photos"
|
||||
*
|
||||
* THRESHOLD STRATEGY:
|
||||
* - Use CONSERVATIVE threshold (0.75) for validation
|
||||
* - Better to show false negatives than false positives
|
||||
* - If user approves, full scan uses slightly looser threshold (0.70)
|
||||
*/
|
||||
@Singleton
|
||||
class ValidationScanService @Inject constructor(
|
||||
@ApplicationContext private val context: Context,
|
||||
private val imageDao: ImageDao,
|
||||
private val faceModelDao: FaceModelDao
|
||||
) {
|
||||
|
||||
companion object {
|
||||
private const val VALIDATION_SAMPLE_SIZE = 25
|
||||
private const val VALIDATION_THRESHOLD = 0.75f // Conservative
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform validation scan after training
|
||||
*
|
||||
* @param personId The newly trained person
|
||||
* @param onProgress Callback (current, total)
|
||||
* @return Validation results with preview matches
|
||||
*/
|
||||
suspend fun performValidationScan(
|
||||
personId: String,
|
||||
onProgress: (Int, Int) -> Unit = { _, _ -> }
|
||||
): ValidationScanResult = withContext(Dispatchers.Default) {
|
||||
|
||||
onProgress(0, 100)
|
||||
|
||||
// Step 1: Get face model
|
||||
val faceModel = withContext(Dispatchers.IO) {
|
||||
faceModelDao.getFaceModelByPersonId(personId)
|
||||
} ?: return@withContext ValidationScanResult(
|
||||
personId = personId,
|
||||
matches = emptyList(),
|
||||
sampleSize = 0,
|
||||
errorMessage = "Face model not found"
|
||||
)
|
||||
|
||||
onProgress(10, 100)
|
||||
|
||||
// Step 2: Get random sample of photos with faces
|
||||
val allPhotosWithFaces = withContext(Dispatchers.IO) {
|
||||
imageDao.getImagesWithFaces()
|
||||
}
|
||||
|
||||
if (allPhotosWithFaces.isEmpty()) {
|
||||
return@withContext ValidationScanResult(
|
||||
personId = personId,
|
||||
matches = emptyList(),
|
||||
sampleSize = 0,
|
||||
errorMessage = "No photos with faces in library"
|
||||
)
|
||||
}
|
||||
|
||||
// Random sample
|
||||
val samplePhotos = allPhotosWithFaces.shuffled().take(VALIDATION_SAMPLE_SIZE)
|
||||
onProgress(20, 100)
|
||||
|
||||
// Step 3: Scan sample photos
|
||||
val faceNetModel = FaceNetModel(context)
|
||||
val detector = FaceDetection.getClient(
|
||||
FaceDetectorOptions.Builder()
|
||||
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||
.setMinFaceSize(0.15f)
|
||||
.build()
|
||||
)
|
||||
|
||||
try {
|
||||
val matches = scanPhotosForPerson(
|
||||
photos = samplePhotos,
|
||||
faceModel = faceModel,
|
||||
faceNetModel = faceNetModel,
|
||||
detector = detector,
|
||||
threshold = VALIDATION_THRESHOLD,
|
||||
onProgress = { current, total ->
|
||||
// Map to 20-100 range
|
||||
val progress = 20 + (current * 80 / total)
|
||||
onProgress(progress, 100)
|
||||
}
|
||||
)
|
||||
|
||||
onProgress(100, 100)
|
||||
|
||||
ValidationScanResult(
|
||||
personId = personId,
|
||||
matches = matches,
|
||||
sampleSize = samplePhotos.size,
|
||||
threshold = VALIDATION_THRESHOLD
|
||||
)
|
||||
|
||||
} finally {
|
||||
faceNetModel.close()
|
||||
detector.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan photos for a specific person
|
||||
*/
|
||||
private suspend fun scanPhotosForPerson(
|
||||
photos: List<ImageEntity>,
|
||||
faceModel: FaceModelEntity,
|
||||
faceNetModel: FaceNetModel,
|
||||
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||
threshold: Float,
|
||||
onProgress: (Int, Int) -> Unit
|
||||
): List<ValidationMatch> = coroutineScope {
|
||||
|
||||
val modelEmbedding = faceModel.getEmbeddingArray()
|
||||
val matches = mutableListOf<ValidationMatch>()
|
||||
var processedCount = 0
|
||||
|
||||
// Process in parallel
|
||||
val jobs = photos.map { photo ->
|
||||
async(Dispatchers.IO) {
|
||||
val photoMatches = scanSinglePhoto(
|
||||
photo = photo,
|
||||
modelEmbedding = modelEmbedding,
|
||||
faceNetModel = faceNetModel,
|
||||
detector = detector,
|
||||
threshold = threshold
|
||||
)
|
||||
|
||||
synchronized(matches) {
|
||||
matches.addAll(photoMatches)
|
||||
processedCount++
|
||||
if (processedCount % 5 == 0) {
|
||||
onProgress(processedCount, photos.size)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jobs.awaitAll()
|
||||
matches.sortedByDescending { it.confidence }
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan a single photo for the person
|
||||
*/
|
||||
private suspend fun scanSinglePhoto(
|
||||
photo: ImageEntity,
|
||||
modelEmbedding: FloatArray,
|
||||
faceNetModel: FaceNetModel,
|
||||
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||
threshold: Float
|
||||
): List<ValidationMatch> = withContext(Dispatchers.IO) {
|
||||
|
||||
try {
|
||||
// Load bitmap
|
||||
val bitmap = loadBitmapDownsampled(Uri.parse(photo.imageUri), 768)
|
||||
?: return@withContext emptyList()
|
||||
|
||||
// Detect faces
|
||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||
val faces = detector.process(inputImage).await()
|
||||
|
||||
// Check each face
|
||||
val matches = faces.mapNotNull { face ->
|
||||
try {
|
||||
// Crop face
|
||||
val faceBitmap = android.graphics.Bitmap.createBitmap(
|
||||
bitmap,
|
||||
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
||||
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
||||
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
||||
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
||||
)
|
||||
|
||||
// Generate embedding
|
||||
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||
faceBitmap.recycle()
|
||||
|
||||
// Calculate similarity
|
||||
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
||||
|
||||
if (similarity >= threshold) {
|
||||
ValidationMatch(
|
||||
imageId = photo.imageId,
|
||||
imageUri = photo.imageUri,
|
||||
capturedAt = photo.capturedAt,
|
||||
confidence = similarity,
|
||||
boundingBox = face.boundingBox,
|
||||
faceCount = faces.size
|
||||
)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
bitmap.recycle()
|
||||
matches
|
||||
|
||||
} catch (e: Exception) {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load bitmap with downsampling
|
||||
*/
|
||||
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): android.graphics.Bitmap? {
|
||||
return try {
|
||||
val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true }
|
||||
context.contentResolver.openInputStream(uri)?.use {
|
||||
BitmapFactory.decodeStream(it, null, opts)
|
||||
}
|
||||
|
||||
var sample = 1
|
||||
while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) {
|
||||
sample *= 2
|
||||
}
|
||||
|
||||
val finalOpts = BitmapFactory.Options().apply {
|
||||
inSampleSize = sample
|
||||
}
|
||||
|
||||
context.contentResolver.openInputStream(uri)?.use {
|
||||
BitmapFactory.decodeStream(it, null, finalOpts)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of validation scan
|
||||
*/
|
||||
data class ValidationScanResult(
|
||||
val personId: String,
|
||||
val matches: List<ValidationMatch>,
|
||||
val sampleSize: Int,
|
||||
val threshold: Float = 0.75f,
|
||||
val errorMessage: String? = null
|
||||
) {
|
||||
val matchCount: Int get() = matches.size
|
||||
val averageConfidence: Float get() = if (matches.isNotEmpty()) {
|
||||
matches.map { it.confidence }.average().toFloat()
|
||||
} else 0f
|
||||
|
||||
val qualityAssessment: ValidationQuality get() = when {
|
||||
matchCount == 0 -> ValidationQuality.NO_MATCHES
|
||||
averageConfidence >= 0.85f && matchCount >= 5 -> ValidationQuality.EXCELLENT
|
||||
averageConfidence >= 0.78f && matchCount >= 3 -> ValidationQuality.GOOD
|
||||
averageConfidence < 0.75f || matchCount < 2 -> ValidationQuality.POOR
|
||||
else -> ValidationQuality.FAIR
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Single match found during validation
|
||||
*/
|
||||
data class ValidationMatch(
|
||||
val imageId: String,
|
||||
val imageUri: String,
|
||||
val capturedAt: Long,
|
||||
val confidence: Float,
|
||||
val boundingBox: android.graphics.Rect,
|
||||
val faceCount: Int
|
||||
)
|
||||
|
||||
/**
|
||||
* Overall quality assessment
|
||||
*/
|
||||
enum class ValidationQuality {
|
||||
EXCELLENT, // High confidence, many matches
|
||||
GOOD, // Decent confidence, some matches
|
||||
FAIR, // Acceptable, proceed with caution
|
||||
POOR, // Low confidence or very few matches
|
||||
NO_MATCHES // No matches found at all
|
||||
}
|
||||
@@ -1,253 +1,201 @@
|
||||
package com.placeholder.sherpai2.ui.discover
|
||||
|
||||
import android.graphics.BitmapFactory
|
||||
import android.net.Uri
|
||||
import androidx.compose.foundation.Image
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.grid.GridCells
|
||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||
import androidx.compose.foundation.lazy.grid.items
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.*
|
||||
import androidx.compose.material.icons.filled.Person
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.*
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.asImageBitmap
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.hilt.navigation.compose.hiltViewModel
|
||||
import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
||||
import com.placeholder.sherpai2.domain.clustering.AgeEstimate
|
||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* DiscoverPeopleScreen - Beautiful auto-clustering UI
|
||||
* DiscoverPeopleScreen - COMPLETE WORKING VERSION
|
||||
*
|
||||
* FLOW:
|
||||
* 1. Hero CTA: "Discover People in Your Photos"
|
||||
* 2. Auto-clustering progress (2-5 min)
|
||||
* 3. Grid of discovered people
|
||||
* 4. Tap cluster → Name person + metadata
|
||||
* 5. Background deep scan starts
|
||||
* This handles ALL states properly including Idle state
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
fun DiscoverPeopleScreen(
|
||||
viewModel: DiscoverPeopleViewModel = hiltViewModel()
|
||||
viewModel: DiscoverPeopleViewModel = hiltViewModel(),
|
||||
onNavigateBack: () -> Unit = {}
|
||||
) {
|
||||
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
|
||||
val uiState by viewModel.uiState.collectAsState()
|
||||
|
||||
// NO SCAFFOLD - MainScreen already has TopAppBar
|
||||
Box(modifier = Modifier.fillMaxSize()) {
|
||||
when (val state = uiState) {
|
||||
is DiscoverUiState.Idle -> IdleScreen(
|
||||
onStartDiscovery = { viewModel.startDiscovery() }
|
||||
)
|
||||
|
||||
is DiscoverUiState.Clustering -> ClusteringProgressScreen(
|
||||
progress = state.progress,
|
||||
total = state.total,
|
||||
message = state.message
|
||||
)
|
||||
|
||||
is DiscoverUiState.NamingReady -> ClusterGridScreen(
|
||||
result = state.result,
|
||||
onClusterClick = { cluster ->
|
||||
viewModel.selectCluster(cluster)
|
||||
Scaffold(
|
||||
topBar = {
|
||||
TopAppBar(
|
||||
title = { Text("Discover People") },
|
||||
navigationIcon = {
|
||||
IconButton(onClick = onNavigateBack) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = "Back"
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
is DiscoverUiState.NamingCluster -> NamingDialog(
|
||||
cluster = state.selectedCluster,
|
||||
suggestedSiblings = state.suggestedSiblings,
|
||||
onConfirm = { name, dob, isChild, siblings ->
|
||||
viewModel.confirmClusterName(
|
||||
cluster = state.selectedCluster,
|
||||
name = name,
|
||||
dateOfBirth = dob,
|
||||
isChild = isChild,
|
||||
selectedSiblings = siblings
|
||||
)
|
||||
},
|
||||
onDismiss = { viewModel.cancelNaming() }
|
||||
)
|
||||
|
||||
is DiscoverUiState.NoPeopleFound -> EmptyStateScreen(
|
||||
message = state.message
|
||||
)
|
||||
|
||||
is DiscoverUiState.Error -> ErrorScreen(
|
||||
message = state.message,
|
||||
onRetry = { viewModel.startDiscovery() }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Idle state - Hero CTA to start discovery
|
||||
*/
|
||||
@Composable
|
||||
fun IdleScreen(
|
||||
onStartDiscovery: () -> Unit
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(32.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.AutoAwesome,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(120.dp),
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(24.dp))
|
||||
|
||||
Text(
|
||||
text = "Discover People",
|
||||
style = MaterialTheme.typography.headlineLarge,
|
||||
fontWeight = FontWeight.Bold,
|
||||
textAlign = TextAlign.Center
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = "Let AI automatically find and group faces in your photos. " +
|
||||
"You'll name them, and we'll tag all their photos.",
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
textAlign = TextAlign.Center,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(32.dp))
|
||||
|
||||
Button(
|
||||
onClick = onStartDiscovery,
|
||||
) { paddingValues ->
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(56.dp),
|
||||
colors = ButtonDefaults.buttonColors(
|
||||
containerColor = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
.fillMaxSize()
|
||||
.padding(paddingValues)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.AutoAwesome,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(24.dp)
|
||||
)
|
||||
Spacer(Modifier.width(8.dp))
|
||||
Text(
|
||||
text = "Start Discovery",
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
}
|
||||
when (val state = uiState) {
|
||||
// ===== IDLE STATE (START HERE) =====
|
||||
is DiscoverUiState.Idle -> {
|
||||
IdleStateContent(
|
||||
onStartDiscovery = { viewModel.startDiscovery() }
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(Modifier.height(16.dp))
|
||||
// ===== CLUSTERING IN PROGRESS =====
|
||||
is DiscoverUiState.Clustering -> {
|
||||
ClusteringProgressContent(
|
||||
progress = state.progress,
|
||||
total = state.total,
|
||||
message = state.message
|
||||
)
|
||||
}
|
||||
|
||||
Card(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = MaterialTheme.colorScheme.surfaceVariant
|
||||
)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(16.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(8.dp)
|
||||
) {
|
||||
InfoRow(Icons.Default.Speed, "Fast: Analyzes ~1000 photos in 2-5 minutes")
|
||||
InfoRow(Icons.Default.Security, "Private: Everything stays on your device")
|
||||
InfoRow(Icons.Default.AutoAwesome, "Smart: Groups faces automatically")
|
||||
// ===== CLUSTERS READY FOR NAMING =====
|
||||
is DiscoverUiState.NamingReady -> {
|
||||
Text(
|
||||
text = "Found ${state.result.clusters.size} people!\n\nCluster grid UI coming...",
|
||||
modifier = Modifier.align(Alignment.Center)
|
||||
)
|
||||
}
|
||||
|
||||
// ===== ANALYZING CLUSTER QUALITY =====
|
||||
is DiscoverUiState.AnalyzingCluster -> {
|
||||
LoadingContent(message = "Analyzing cluster quality...")
|
||||
}
|
||||
|
||||
// ===== NAMING A CLUSTER =====
|
||||
is DiscoverUiState.NamingCluster -> {
|
||||
Text(
|
||||
text = "Naming dialog for cluster ${state.selectedCluster.clusterId}\n\nDialog UI coming...",
|
||||
modifier = Modifier.align(Alignment.Center)
|
||||
)
|
||||
}
|
||||
|
||||
// ===== TRAINING IN PROGRESS =====
|
||||
is DiscoverUiState.Training -> {
|
||||
TrainingProgressContent(
|
||||
stage = state.stage,
|
||||
progress = state.progress,
|
||||
total = state.total
|
||||
)
|
||||
}
|
||||
|
||||
// ===== VALIDATION PREVIEW =====
|
||||
is DiscoverUiState.ValidationPreview -> {
|
||||
ValidationPreviewScreen(
|
||||
personName = state.personName,
|
||||
validationResult = state.validationResult,
|
||||
onApprove = {
|
||||
viewModel.approveValidationAndScan(
|
||||
personId = state.personId,
|
||||
personName = state.personName
|
||||
)
|
||||
},
|
||||
onReject = {
|
||||
viewModel.rejectValidationAndImprove()
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// ===== COMPLETE =====
|
||||
is DiscoverUiState.Complete -> {
|
||||
CompleteStateContent(
|
||||
message = state.message,
|
||||
onDone = onNavigateBack
|
||||
)
|
||||
}
|
||||
|
||||
// ===== NO PEOPLE FOUND =====
|
||||
is DiscoverUiState.NoPeopleFound -> {
|
||||
ErrorStateContent(
|
||||
title = "No People Found",
|
||||
message = state.message,
|
||||
onRetry = { viewModel.startDiscovery() },
|
||||
onBack = onNavigateBack
|
||||
)
|
||||
}
|
||||
|
||||
// ===== ERROR =====
|
||||
is DiscoverUiState.Error -> {
|
||||
ErrorStateContent(
|
||||
title = "Error",
|
||||
message = state.message,
|
||||
onRetry = { viewModel.reset(); viewModel.startDiscovery() },
|
||||
onBack = onNavigateBack
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
fun InfoRow(icon: androidx.compose.ui.graphics.vector.ImageVector, text: String) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = icon,
|
||||
contentDescription = null,
|
||||
tint = MaterialTheme.colorScheme.primary,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
Text(
|
||||
text = text,
|
||||
style = MaterialTheme.typography.bodyMedium
|
||||
)
|
||||
}
|
||||
}
|
||||
// ===== IDLE STATE CONTENT =====
|
||||
|
||||
/**
|
||||
* Clustering progress screen
|
||||
*/
|
||||
@Composable
|
||||
fun ClusteringProgressScreen(
|
||||
progress: Int,
|
||||
total: Int,
|
||||
message: String
|
||||
private fun IdleStateContent(
|
||||
onStartDiscovery: () -> Unit
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(32.dp),
|
||||
.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
CircularProgressIndicator(
|
||||
modifier = Modifier.size(80.dp),
|
||||
strokeWidth = 6.dp
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(120.dp),
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(32.dp))
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
Text(
|
||||
text = "Discovering People...",
|
||||
style = MaterialTheme.typography.headlineSmall,
|
||||
text = "Discover People",
|
||||
style = MaterialTheme.typography.headlineLarge,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(16.dp))
|
||||
|
||||
LinearProgressIndicator(
|
||||
progress = { if (total > 0) progress.toFloat() / total.toFloat() else 0f },
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(8.dp))
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = message,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
text = "Automatically find and organize people in your photo library",
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
textAlign = TextAlign.Center,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(24.dp))
|
||||
Spacer(modifier = Modifier.height(48.dp))
|
||||
|
||||
Button(
|
||||
onClick = onStartDiscovery,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(56.dp)
|
||||
) {
|
||||
Text(
|
||||
text = "Start Discovery",
|
||||
style = MaterialTheme.typography.titleMedium
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = "This will take 2-5 minutes. You can leave and come back later.",
|
||||
text = "This will analyze faces in your photos and group similar faces together",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
textAlign = TextAlign.Center,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
@@ -255,421 +203,145 @@ fun ClusteringProgressScreen(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Grid of discovered clusters
|
||||
*/
|
||||
// ===== CLUSTERING PROGRESS =====
|
||||
|
||||
@Composable
|
||||
fun ClusterGridScreen(
|
||||
result: com.placeholder.sherpai2.domain.clustering.ClusteringResult,
|
||||
onClusterClick: (FaceCluster) -> Unit
|
||||
private fun ClusteringProgressContent(
|
||||
progress: Int,
|
||||
total: Int,
|
||||
message: String
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(16.dp)
|
||||
) {
|
||||
Text(
|
||||
text = "Found ${result.clusters.size} People",
|
||||
style = MaterialTheme.typography.headlineSmall,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = "Tap to name each person. We'll then tag all their photos.",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(16.dp))
|
||||
|
||||
LazyVerticalGrid(
|
||||
columns = GridCells.Fixed(2),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
items(result.clusters) { cluster ->
|
||||
ClusterCard(
|
||||
cluster = cluster,
|
||||
onClick = { onClusterClick(cluster) }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Single cluster card
|
||||
*/
|
||||
@Composable
|
||||
fun ClusterCard(
|
||||
cluster: FaceCluster,
|
||||
onClick: () -> Unit
|
||||
) {
|
||||
val context = LocalContext.current
|
||||
|
||||
Card(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.clickable(onClick = onClick),
|
||||
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp)
|
||||
) {
|
||||
Column {
|
||||
// Face grid (2x3)
|
||||
LazyVerticalGrid(
|
||||
columns = GridCells.Fixed(3),
|
||||
modifier = Modifier.height(180.dp),
|
||||
userScrollEnabled = false
|
||||
) {
|
||||
items(cluster.representativeFaces.take(6)) { face ->
|
||||
val bitmap = remember(face.imageUri) {
|
||||
try {
|
||||
context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
|
||||
BitmapFactory.decodeStream(it)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
if (bitmap != null) {
|
||||
Image(
|
||||
bitmap = bitmap.asImageBitmap(),
|
||||
contentDescription = null,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.aspectRatio(1f),
|
||||
contentScale = ContentScale.Crop
|
||||
)
|
||||
} else {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.aspectRatio(1f)
|
||||
.background(MaterialTheme.colorScheme.surfaceVariant),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = null,
|
||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Info
|
||||
Column(
|
||||
modifier = Modifier.padding(12.dp)
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Text(
|
||||
text = "${cluster.photoCount} photos",
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
if (cluster.estimatedAge == AgeEstimate.CHILD) {
|
||||
Surface(
|
||||
shape = RoundedCornerShape(12.dp),
|
||||
color = MaterialTheme.colorScheme.primaryContainer
|
||||
) {
|
||||
Text(
|
||||
text = "Child",
|
||||
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (cluster.potentialSiblings.isNotEmpty()) {
|
||||
Spacer(Modifier.height(4.dp))
|
||||
Text(
|
||||
text = "Appears with ${cluster.potentialSiblings.size} other person(s)",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Naming dialog
|
||||
*/
|
||||
@Composable
|
||||
fun NamingDialog(
|
||||
cluster: FaceCluster,
|
||||
suggestedSiblings: List<FaceCluster>,
|
||||
onConfirm: (String, Long?, Boolean, List<Int>) -> Unit,
|
||||
onDismiss: () -> Unit
|
||||
) {
|
||||
var name by remember { mutableStateOf("") }
|
||||
var isChild by remember { mutableStateOf(cluster.estimatedAge == AgeEstimate.CHILD) }
|
||||
var dateOfBirth by remember { mutableStateOf<Long?>(null) }
|
||||
var selectedSiblings by remember { mutableStateOf<Set<Int>>(emptySet()) }
|
||||
var showDatePicker by remember { mutableStateOf(false) }
|
||||
val context = LocalContext.current
|
||||
|
||||
// Date picker dialog
|
||||
if (showDatePicker) {
|
||||
val calendar = java.util.Calendar.getInstance()
|
||||
if (dateOfBirth != null) {
|
||||
calendar.timeInMillis = dateOfBirth!!
|
||||
}
|
||||
|
||||
val datePickerDialog = android.app.DatePickerDialog(
|
||||
context,
|
||||
{ _, year, month, dayOfMonth ->
|
||||
val cal = java.util.Calendar.getInstance()
|
||||
cal.set(year, month, dayOfMonth)
|
||||
dateOfBirth = cal.timeInMillis
|
||||
showDatePicker = false
|
||||
},
|
||||
calendar.get(java.util.Calendar.YEAR),
|
||||
calendar.get(java.util.Calendar.MONTH),
|
||||
calendar.get(java.util.Calendar.DAY_OF_MONTH)
|
||||
)
|
||||
|
||||
datePickerDialog.setOnDismissListener {
|
||||
showDatePicker = false
|
||||
}
|
||||
|
||||
DisposableEffect(Unit) {
|
||||
datePickerDialog.show()
|
||||
onDispose {
|
||||
datePickerDialog.dismiss()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AlertDialog(
|
||||
onDismissRequest = onDismiss,
|
||||
title = {
|
||||
Text("Name This Person")
|
||||
},
|
||||
text = {
|
||||
Column(
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// FACE PREVIEW - Show 6 representative faces
|
||||
Text(
|
||||
text = "${cluster.photoCount} photos found",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
LazyVerticalGrid(
|
||||
columns = GridCells.Fixed(3),
|
||||
modifier = Modifier.height(180.dp),
|
||||
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
items(cluster.representativeFaces.take(6)) { face ->
|
||||
val bitmap = remember(face.imageUri) {
|
||||
try {
|
||||
context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
|
||||
BitmapFactory.decodeStream(it)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
if (bitmap != null) {
|
||||
Image(
|
||||
bitmap = bitmap.asImageBitmap(),
|
||||
contentDescription = null,
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.aspectRatio(1f)
|
||||
.clip(RoundedCornerShape(8.dp)),
|
||||
contentScale = ContentScale.Crop
|
||||
)
|
||||
} else {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.aspectRatio(1f)
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.background(MaterialTheme.colorScheme.surfaceVariant),
|
||||
contentAlignment = Alignment.Center
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = null,
|
||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HorizontalDivider()
|
||||
|
||||
// Name input
|
||||
OutlinedTextField(
|
||||
value = name,
|
||||
onValueChange = { name = it },
|
||||
label = { Text("Name") },
|
||||
singleLine = true,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
|
||||
// Is child toggle
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Text("This person is a child")
|
||||
Switch(
|
||||
checked = isChild,
|
||||
onCheckedChange = { isChild = it }
|
||||
)
|
||||
}
|
||||
|
||||
// Date of birth (if child)
|
||||
if (isChild) {
|
||||
OutlinedButton(
|
||||
onClick = { showDatePicker = true },
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
) {
|
||||
Icon(Icons.Default.CalendarToday, null)
|
||||
Spacer(Modifier.width(8.dp))
|
||||
Text(
|
||||
if (dateOfBirth != null) {
|
||||
SimpleDateFormat("MMM dd, yyyy", Locale.getDefault())
|
||||
.format(Date(dateOfBirth!!))
|
||||
} else {
|
||||
"Set Date of Birth"
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Suggested siblings
|
||||
if (suggestedSiblings.isNotEmpty()) {
|
||||
Text(
|
||||
"Appears with these people (select siblings):",
|
||||
style = MaterialTheme.typography.labelMedium
|
||||
)
|
||||
|
||||
suggestedSiblings.forEach { sibling ->
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Checkbox(
|
||||
checked = sibling.clusterId in selectedSiblings,
|
||||
onCheckedChange = { checked ->
|
||||
selectedSiblings = if (checked) {
|
||||
selectedSiblings + sibling.clusterId
|
||||
} else {
|
||||
selectedSiblings - sibling.clusterId
|
||||
}
|
||||
}
|
||||
)
|
||||
Text("Person ${sibling.clusterId + 1} (${sibling.photoCount} photos)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
confirmButton = {
|
||||
TextButton(
|
||||
onClick = {
|
||||
onConfirm(
|
||||
name,
|
||||
dateOfBirth,
|
||||
isChild,
|
||||
selectedSiblings.toList()
|
||||
)
|
||||
},
|
||||
enabled = name.isNotBlank()
|
||||
) {
|
||||
Text("Save & Train")
|
||||
}
|
||||
},
|
||||
dismissButton = {
|
||||
TextButton(onClick = onDismiss) {
|
||||
Text("Cancel")
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// TODO: Add DatePickerDialog when showDatePicker is true
|
||||
}
|
||||
|
||||
/**
|
||||
* Empty state screen
|
||||
*/
|
||||
@Composable
|
||||
fun EmptyStateScreen(message: String) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(32.dp),
|
||||
.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.PersonOff,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(80.dp),
|
||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
CircularProgressIndicator(
|
||||
modifier = Modifier.size(64.dp)
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(16.dp))
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
Text(
|
||||
text = message,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
textAlign = TextAlign.Center
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
if (total > 0) {
|
||||
LinearProgressIndicator(
|
||||
progress = progress.toFloat() / total.toFloat(),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(8.dp)
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = "$progress / $total",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Error screen
|
||||
*/
|
||||
// ===== TRAINING PROGRESS =====
|
||||
|
||||
@Composable
|
||||
fun ErrorScreen(
|
||||
message: String,
|
||||
onRetry: () -> Unit
|
||||
private fun TrainingProgressContent(
|
||||
stage: String,
|
||||
progress: Int,
|
||||
total: Int
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(32.dp),
|
||||
.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Error,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(80.dp),
|
||||
tint = MaterialTheme.colorScheme.error
|
||||
CircularProgressIndicator(
|
||||
modifier = Modifier.size(64.dp)
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(16.dp))
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
Text(
|
||||
text = "Oops!",
|
||||
text = stage,
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
textAlign = TextAlign.Center
|
||||
)
|
||||
|
||||
if (total > 0) {
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
LinearProgressIndicator(
|
||||
progress = progress.toFloat() / total.toFloat(),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(8.dp)
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = "$progress / $total",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== LOADING CONTENT =====
|
||||
|
||||
@Composable
|
||||
private fun LoadingContent(message: String) {
|
||||
Column(
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
CircularProgressIndicator()
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
Text(text = message)
|
||||
}
|
||||
}
|
||||
|
||||
// ===== COMPLETE STATE =====
|
||||
|
||||
@Composable
|
||||
private fun CompleteStateContent(
|
||||
message: String,
|
||||
onDone: () -> Unit
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
Text(
|
||||
text = "🎉",
|
||||
style = MaterialTheme.typography.displayLarge
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
Text(
|
||||
text = "Success!",
|
||||
style = MaterialTheme.typography.headlineMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(8.dp))
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = message,
|
||||
@@ -678,10 +350,74 @@ fun ErrorScreen(
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(Modifier.height(24.dp))
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
Button(onClick = onRetry) {
|
||||
Text("Try Again")
|
||||
Button(
|
||||
onClick = onDone,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
) {
|
||||
Text("Done")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== ERROR STATE =====
|
||||
|
||||
@Composable
|
||||
private fun ErrorStateContent(
|
||||
title: String,
|
||||
message: String,
|
||||
onRetry: () -> Unit,
|
||||
onBack: () -> Unit
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
Text(
|
||||
text = "⚠️",
|
||||
style = MaterialTheme.typography.displayLarge
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
Text(
|
||||
text = title,
|
||||
style = MaterialTheme.typography.headlineMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = message,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
textAlign = TextAlign.Center,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
OutlinedButton(
|
||||
onClick = onBack,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Text("Back")
|
||||
}
|
||||
|
||||
Button(
|
||||
onClick = onRetry,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Text("Retry")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,15 @@ package com.placeholder.sherpai2.ui.discover
|
||||
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import androidx.work.WorkManager
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusteringResult
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
|
||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||
import com.placeholder.sherpai2.domain.clustering.FaceClusteringService
|
||||
import com.placeholder.sherpai2.domain.training.ClusterTrainingService
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationScanResult
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationScanService
|
||||
import com.placeholder.sherpai2.workers.LibraryScanWorker
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
@@ -14,21 +19,22 @@ import kotlinx.coroutines.launch
|
||||
import javax.inject.Inject
|
||||
|
||||
/**
|
||||
* DiscoverPeopleViewModel - Manages auto-clustering and naming flow
|
||||
* DiscoverPeopleViewModel - Manages TWO-STAGE validation flow
|
||||
*
|
||||
* PHASE 2: Now includes multi-centroid training from clusters
|
||||
*
|
||||
* STATE FLOW:
|
||||
* 1. Idle → User taps "Discover People"
|
||||
* 2. Clustering → Auto-analyzing faces (2-5 min)
|
||||
* 3. NamingReady → Shows clusters, user names them
|
||||
* 4. Training → Creating multi-centroid face model
|
||||
* 5. Complete → Ready to scan library
|
||||
* FLOW:
|
||||
* 1. Clustering → User selects cluster
|
||||
* 2. STAGE 1: Show cluster quality analysis
|
||||
* 3. User names person → Training
|
||||
* 4. STAGE 2: Show validation scan preview
|
||||
* 5. User approves → Full library scan (background worker)
|
||||
* 6. Results appear in "People" tab
|
||||
*/
|
||||
@HiltViewModel
|
||||
class DiscoverPeopleViewModel @Inject constructor(
|
||||
private val clusteringService: FaceClusteringService,
|
||||
private val trainingService: ClusterTrainingService
|
||||
private val trainingService: ClusterTrainingService,
|
||||
private val validationScanService: ValidationScanService,
|
||||
private val workManager: WorkManager
|
||||
) : ViewModel() {
|
||||
|
||||
private val _uiState = MutableStateFlow<DiscoverUiState>(DiscoverUiState.Idle)
|
||||
@@ -37,6 +43,9 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
// Track which clusters have been named
|
||||
private val namedClusterIds = mutableSetOf<Int>()
|
||||
|
||||
// Store quality analysis for current cluster
|
||||
private var currentQualityResult: ClusterQualityResult? = null
|
||||
|
||||
/**
|
||||
* Start auto-clustering process
|
||||
*/
|
||||
@@ -78,27 +87,41 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
|
||||
/**
|
||||
* User selected a cluster to name
|
||||
* STAGE 1: Analyze quality FIRST
|
||||
*/
|
||||
fun selectCluster(cluster: FaceCluster) {
|
||||
val currentState = _uiState.value
|
||||
if (currentState is DiscoverUiState.NamingReady) {
|
||||
_uiState.value = DiscoverUiState.NamingCluster(
|
||||
result = currentState.result,
|
||||
selectedCluster = cluster,
|
||||
suggestedSiblings = currentState.result.clusters.filter {
|
||||
it.clusterId in cluster.potentialSiblings
|
||||
viewModelScope.launch {
|
||||
try {
|
||||
// Show analyzing state
|
||||
_uiState.value = DiscoverUiState.AnalyzingCluster(cluster)
|
||||
|
||||
// Analyze cluster quality
|
||||
val qualityResult = trainingService.analyzeClusterQuality(cluster)
|
||||
currentQualityResult = qualityResult
|
||||
|
||||
// Show naming dialog with quality info
|
||||
_uiState.value = DiscoverUiState.NamingCluster(
|
||||
result = currentState.result,
|
||||
selectedCluster = cluster,
|
||||
qualityResult = qualityResult,
|
||||
suggestedSiblings = currentState.result.clusters.filter {
|
||||
it.clusterId in cluster.potentialSiblings
|
||||
}
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
_uiState.value = DiscoverUiState.Error(
|
||||
"Failed to analyze cluster: ${e.message}"
|
||||
)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* User confirmed name and metadata for a cluster
|
||||
*
|
||||
* CREATES:
|
||||
* 1. PersonEntity with all metadata (name, DOB, siblings)
|
||||
* 2. Multi-centroid FaceModelEntity (temporal tracking for children)
|
||||
* 3. Removes cluster from display
|
||||
* STAGE 2: Train → Validation scan → Preview
|
||||
*/
|
||||
fun confirmClusterName(
|
||||
cluster: FaceCluster,
|
||||
@@ -112,37 +135,59 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
val currentState = _uiState.value
|
||||
if (currentState !is DiscoverUiState.NamingCluster) return@launch
|
||||
|
||||
// Train person from cluster
|
||||
// Show training progress
|
||||
_uiState.value = DiscoverUiState.Training(
|
||||
stage = "Creating person and training model",
|
||||
progress = 0,
|
||||
total = 100
|
||||
)
|
||||
|
||||
// Train person from cluster (using clean faces from quality analysis)
|
||||
val personId = trainingService.trainFromCluster(
|
||||
cluster = cluster,
|
||||
name = name,
|
||||
dateOfBirth = dateOfBirth,
|
||||
isChild = isChild,
|
||||
siblingClusterIds = selectedSiblings,
|
||||
qualityResult = currentQualityResult, // Use clean faces!
|
||||
onProgress = { current, total, message ->
|
||||
_uiState.value = DiscoverUiState.Clustering(current, total, message)
|
||||
_uiState.value = DiscoverUiState.Training(
|
||||
stage = message,
|
||||
progress = current,
|
||||
total = total
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
// Training complete - now run validation scan
|
||||
_uiState.value = DiscoverUiState.Training(
|
||||
stage = "Running validation scan...",
|
||||
progress = 0,
|
||||
total = 100
|
||||
)
|
||||
|
||||
val validationResult = validationScanService.performValidationScan(
|
||||
personId = personId,
|
||||
onProgress = { current, total ->
|
||||
_uiState.value = DiscoverUiState.Training(
|
||||
stage = "Scanning sample photos...",
|
||||
progress = current,
|
||||
total = total
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
// Show validation preview to user
|
||||
_uiState.value = DiscoverUiState.ValidationPreview(
|
||||
personId = personId,
|
||||
personName = name,
|
||||
validationResult = validationResult,
|
||||
originalClusterResult = currentState.result
|
||||
)
|
||||
|
||||
// Mark cluster as named
|
||||
namedClusterIds.add(cluster.clusterId)
|
||||
|
||||
// Filter out named clusters
|
||||
val remainingClusters = currentState.result.clusters
|
||||
.filter { it.clusterId !in namedClusterIds }
|
||||
|
||||
if (remainingClusters.isEmpty()) {
|
||||
// All clusters named! Show success
|
||||
_uiState.value = DiscoverUiState.NoPeopleFound(
|
||||
"All people have been named! 🎉\n\nGo to 'People' to see your trained models."
|
||||
)
|
||||
} else {
|
||||
// Return to naming screen with remaining clusters
|
||||
_uiState.value = DiscoverUiState.NamingReady(
|
||||
result = currentState.result.copy(clusters = remainingClusters)
|
||||
)
|
||||
}
|
||||
|
||||
} catch (e: Exception) {
|
||||
_uiState.value = DiscoverUiState.Error(
|
||||
e.message ?: "Failed to create person: ${e.message}"
|
||||
@@ -151,6 +196,57 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* User approves validation preview → Start full library scan
|
||||
*/
|
||||
fun approveValidationAndScan(personId: String, personName: String) {
|
||||
viewModelScope.launch {
|
||||
val currentState = _uiState.value
|
||||
if (currentState !is DiscoverUiState.ValidationPreview) return@launch
|
||||
|
||||
// Enqueue background worker for full library scan
|
||||
val workRequest = LibraryScanWorker.createWorkRequest(
|
||||
personId = personId,
|
||||
personName = personName,
|
||||
threshold = 0.70f // Slightly looser than validation
|
||||
)
|
||||
workManager.enqueue(workRequest)
|
||||
|
||||
// Filter out named clusters and return to cluster list
|
||||
val remainingClusters = currentState.originalClusterResult.clusters
|
||||
.filter { it.clusterId !in namedClusterIds }
|
||||
|
||||
if (remainingClusters.isEmpty()) {
|
||||
// All clusters named! Show success
|
||||
_uiState.value = DiscoverUiState.Complete(
|
||||
message = "All people have been named! 🎉\n\n" +
|
||||
"Full library scan is running in the background.\n" +
|
||||
"Go to 'People' to see results as they come in."
|
||||
)
|
||||
} else {
|
||||
// Return to naming screen with remaining clusters
|
||||
_uiState.value = DiscoverUiState.NamingReady(
|
||||
result = currentState.originalClusterResult.copy(clusters = remainingClusters)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* User rejects validation → Go back to add more training photos
|
||||
*/
|
||||
fun rejectValidationAndImprove() {
|
||||
viewModelScope.launch {
|
||||
val currentState = _uiState.value
|
||||
if (currentState !is DiscoverUiState.ValidationPreview) return@launch
|
||||
|
||||
_uiState.value = DiscoverUiState.Error(
|
||||
"Model quality needs improvement.\n\n" +
|
||||
"Please use the manual training flow to add more high-quality photos."
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel naming and go back to cluster list
|
||||
*/
|
||||
@@ -172,7 +268,7 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
}
|
||||
|
||||
/**
|
||||
* UI States for Discover People flow
|
||||
* UI States for Discover People flow with TWO-STAGE VALIDATION
|
||||
*/
|
||||
sealed class DiscoverUiState {
|
||||
|
||||
@@ -198,14 +294,48 @@ sealed class DiscoverUiState {
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* User is naming a specific cluster
|
||||
* STAGE 1: Analyzing cluster quality (before naming)
|
||||
*/
|
||||
data class AnalyzingCluster(
|
||||
val cluster: FaceCluster
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* User is naming a specific cluster (with quality analysis)
|
||||
*/
|
||||
data class NamingCluster(
|
||||
val result: ClusteringResult,
|
||||
val selectedCluster: FaceCluster,
|
||||
val qualityResult: ClusterQualityResult,
|
||||
val suggestedSiblings: List<FaceCluster>
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* Training in progress
|
||||
*/
|
||||
data class Training(
|
||||
val stage: String,
|
||||
val progress: Int,
|
||||
val total: Int
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* STAGE 2: Validation scan complete - show preview to user
|
||||
*/
|
||||
data class ValidationPreview(
|
||||
val personId: String,
|
||||
val personName: String,
|
||||
val validationResult: ValidationScanResult,
|
||||
val originalClusterResult: ClusteringResult
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* All clusters named and scans launched
|
||||
*/
|
||||
data class Complete(
|
||||
val message: String
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* No people found in library
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,395 @@
|
||||
package com.placeholder.sherpai2.ui.discover
|
||||
|
||||
import android.net.Uri
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.lazy.grid.GridCells
|
||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||
import androidx.compose.foundation.lazy.grid.items
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.*
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.*
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.dp
|
||||
import coil.compose.AsyncImage
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationMatch
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationQuality
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationScanResult
|
||||
|
||||
/**
|
||||
* ValidationPreviewScreen - STAGE 2 validation UI
|
||||
*
|
||||
* Shows user a preview of matches found in validation scan
|
||||
* User can approve (→ full scan) or reject (→ add more photos)
|
||||
*/
|
||||
@Composable
|
||||
fun ValidationPreviewScreen(
|
||||
personName: String,
|
||||
validationResult: ValidationScanResult,
|
||||
onApprove: () -> Unit,
|
||||
onReject: () -> Unit,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
Column(
|
||||
modifier = modifier
|
||||
.fillMaxSize()
|
||||
.padding(16.dp)
|
||||
) {
|
||||
// Header
|
||||
Text(
|
||||
text = "Validation Results",
|
||||
style = MaterialTheme.typography.headlineMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = "Review matches for \"$personName\" before scanning your entire library",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Quality Summary
|
||||
QualitySummaryCard(
|
||||
validationResult = validationResult,
|
||||
personName = personName
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Matches Grid
|
||||
if (validationResult.matches.isNotEmpty()) {
|
||||
Text(
|
||||
text = "Sample Matches (${validationResult.matchCount})",
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.SemiBold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
LazyVerticalGrid(
|
||||
columns = GridCells.Fixed(3),
|
||||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
items(validationResult.matches.take(15)) { match ->
|
||||
MatchPreviewCard(match = match)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No matches found
|
||||
NoMatchesCard()
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Action Buttons
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
// Reject button
|
||||
OutlinedButton(
|
||||
onClick = onReject,
|
||||
modifier = Modifier.weight(1f),
|
||||
colors = ButtonDefaults.outlinedButtonColors(
|
||||
contentColor = MaterialTheme.colorScheme.error
|
||||
)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Close,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text("Add More Photos")
|
||||
}
|
||||
|
||||
// Approve button
|
||||
Button(
|
||||
onClick = onApprove,
|
||||
modifier = Modifier.weight(1f),
|
||||
enabled = validationResult.qualityAssessment != ValidationQuality.NO_MATCHES
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Check,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text("Scan Library")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun QualitySummaryCard(
|
||||
validationResult: ValidationScanResult,
|
||||
personName: String
|
||||
) {
|
||||
val (backgroundColor, iconColor, statusText, statusIcon) = when (validationResult.qualityAssessment) {
|
||||
ValidationQuality.EXCELLENT -> {
|
||||
Quadruple(
|
||||
Color(0xFF1B5E20).copy(alpha = 0.1f),
|
||||
Color(0xFF1B5E20),
|
||||
"Excellent Match Quality",
|
||||
Icons.Default.CheckCircle
|
||||
)
|
||||
}
|
||||
ValidationQuality.GOOD -> {
|
||||
Quadruple(
|
||||
Color(0xFF2E7D32).copy(alpha = 0.1f),
|
||||
Color(0xFF2E7D32),
|
||||
"Good Match Quality",
|
||||
Icons.Default.ThumbUp
|
||||
)
|
||||
}
|
||||
ValidationQuality.FAIR -> {
|
||||
Quadruple(
|
||||
Color(0xFFF57F17).copy(alpha = 0.1f),
|
||||
Color(0xFFF57F17),
|
||||
"Fair Match Quality",
|
||||
Icons.Default.Warning
|
||||
)
|
||||
}
|
||||
ValidationQuality.POOR -> {
|
||||
Quadruple(
|
||||
Color(0xFFD32F2F).copy(alpha = 0.1f),
|
||||
Color(0xFFD32F2F),
|
||||
"Poor Match Quality",
|
||||
Icons.Default.Warning
|
||||
)
|
||||
}
|
||||
ValidationQuality.NO_MATCHES -> {
|
||||
Quadruple(
|
||||
Color(0xFFD32F2F).copy(alpha = 0.1f),
|
||||
Color(0xFFD32F2F),
|
||||
"No Matches Found",
|
||||
Icons.Default.Close
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Card(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = backgroundColor
|
||||
)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(16.dp)
|
||||
) {
|
||||
Row(
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Icon(
|
||||
imageVector = statusIcon,
|
||||
contentDescription = null,
|
||||
tint = iconColor,
|
||||
modifier = Modifier.size(24.dp)
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text(
|
||||
text = statusText,
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.Bold,
|
||||
color = iconColor
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
// Stats
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
StatItem(
|
||||
label = "Matches Found",
|
||||
value = "${validationResult.matchCount} / ${validationResult.sampleSize}"
|
||||
)
|
||||
StatItem(
|
||||
label = "Avg Confidence",
|
||||
value = "${(validationResult.averageConfidence * 100).toInt()}%"
|
||||
)
|
||||
StatItem(
|
||||
label = "Threshold",
|
||||
value = "${(validationResult.threshold * 100).toInt()}%"
|
||||
)
|
||||
}
|
||||
|
||||
// Recommendation
|
||||
if (validationResult.qualityAssessment != ValidationQuality.NO_MATCHES) {
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
val recommendation = when (validationResult.qualityAssessment) {
|
||||
ValidationQuality.EXCELLENT ->
|
||||
"✅ Model looks great! Safe to scan your full library."
|
||||
ValidationQuality.GOOD ->
|
||||
"✅ Model quality is good. You can proceed with the full scan."
|
||||
ValidationQuality.FAIR ->
|
||||
"⚠️ Model quality is acceptable but could be improved with more photos."
|
||||
ValidationQuality.POOR ->
|
||||
"⚠️ Consider adding more diverse, high-quality training photos."
|
||||
ValidationQuality.NO_MATCHES -> ""
|
||||
}
|
||||
|
||||
Text(
|
||||
text = recommendation,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
} else {
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
Text(
|
||||
text = "No matches found. The model may need more or better training photos, or the validation sample didn't include $personName.",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.error
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun StatItem(
|
||||
label: String,
|
||||
value: String
|
||||
) {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
Text(
|
||||
text = value,
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
Text(
|
||||
text = label,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun MatchPreviewCard(
|
||||
match: ValidationMatch
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.aspectRatio(1f)
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.background(MaterialTheme.colorScheme.surfaceVariant)
|
||||
) {
|
||||
AsyncImage(
|
||||
model = Uri.parse(match.imageUri),
|
||||
contentDescription = "Match preview",
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
contentScale = ContentScale.Crop
|
||||
)
|
||||
|
||||
// Confidence badge
|
||||
Surface(
|
||||
modifier = Modifier
|
||||
.align(Alignment.BottomEnd)
|
||||
.padding(4.dp),
|
||||
shape = RoundedCornerShape(4.dp),
|
||||
color = Color.Black.copy(alpha = 0.7f)
|
||||
) {
|
||||
Text(
|
||||
text = "${(match.confidence * 100).toInt()}%",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = Color.White,
|
||||
modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp)
|
||||
)
|
||||
}
|
||||
|
||||
// Face count indicator (if group photo)
|
||||
if (match.faceCount > 1) {
|
||||
Surface(
|
||||
modifier = Modifier
|
||||
.align(Alignment.TopEnd)
|
||||
.padding(4.dp),
|
||||
shape = RoundedCornerShape(4.dp),
|
||||
color = MaterialTheme.colorScheme.primary
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = null,
|
||||
tint = Color.White,
|
||||
modifier = Modifier.size(12.dp)
|
||||
)
|
||||
Text(
|
||||
text = "${match.faceCount}",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = Color.White
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun NoMatchesCard() {
|
||||
Card(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = MaterialTheme.colorScheme.errorContainer
|
||||
)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Warning,
|
||||
contentDescription = null,
|
||||
tint = MaterialTheme.colorScheme.error,
|
||||
modifier = Modifier.size(48.dp)
|
||||
)
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
Text(
|
||||
text = "No Matches Found",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
fontWeight = FontWeight.Bold,
|
||||
color = MaterialTheme.colorScheme.error
|
||||
)
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
Text(
|
||||
text = "The validation scan didn't find this person in the sample photos. This could mean:\n\n" +
|
||||
"• The model needs more training photos\n" +
|
||||
"• The training photos weren't diverse enough\n" +
|
||||
"• The person wasn't in the validation sample",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onErrorContainer
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper data class for quality indicator
|
||||
private data class Quadruple<A, B, C, D>(
|
||||
val first: A,
|
||||
val second: B,
|
||||
val third: C,
|
||||
val fourth: D
|
||||
)
|
||||
@@ -153,14 +153,4 @@ fun getDestinationByRoute(route: String?): AppDestinations? {
|
||||
AppRoutes.SETTINGS -> AppDestinations.Settings
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Legacy support (for backwards compatibility)
|
||||
* These match your old structure
|
||||
*/
|
||||
@Deprecated("Use organized groups instead", ReplaceWith("allMainDrawerDestinations"))
|
||||
val mainDrawerItems = allMainDrawerDestinations
|
||||
|
||||
@Deprecated("Use settingsDestination instead", ReplaceWith("listOf(settingsDestination)"))
|
||||
val utilityDrawerItems = listOf(settingsDestination)
|
||||
}
|
||||
@@ -15,7 +15,10 @@ import com.placeholder.sherpai2.ui.navigation.AppRoutes
|
||||
import kotlinx.coroutines.launch
|
||||
|
||||
/**
|
||||
* Clean main screen - NO duplicate FABs, Collections support, Discover People
|
||||
* MainScreen - FIXED double header issue
|
||||
*
|
||||
* BEST PRACTICE: Screens that manage their own TopAppBar should be excluded
|
||||
* from MainScreen's TopAppBar to prevent ugly double headers.
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
@@ -45,68 +48,77 @@ fun MainScreen() {
|
||||
)
|
||||
},
|
||||
) {
|
||||
// CRITICAL: Some screens manage their own TopAppBar
|
||||
// Hide MainScreen's TopAppBar for these routes to prevent double headers
|
||||
val screensWithOwnTopBar = setOf(
|
||||
AppRoutes.TRAINING_PHOTO_SELECTOR // Has its own TopAppBar with subtitle
|
||||
)
|
||||
val showTopBar = currentRoute !in screensWithOwnTopBar
|
||||
|
||||
Scaffold(
|
||||
topBar = {
|
||||
TopAppBar(
|
||||
title = {
|
||||
Column {
|
||||
Text(
|
||||
text = getScreenTitle(currentRoute),
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
getScreenSubtitle(currentRoute)?.let { subtitle ->
|
||||
if (showTopBar) {
|
||||
TopAppBar(
|
||||
title = {
|
||||
Column {
|
||||
Text(
|
||||
text = subtitle,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
text = getScreenTitle(currentRoute),
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
getScreenSubtitle(currentRoute)?.let { subtitle ->
|
||||
Text(
|
||||
text = subtitle,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
navigationIcon = {
|
||||
IconButton(
|
||||
onClick = { scope.launch { drawerState.open() } }
|
||||
) {
|
||||
Icon(
|
||||
Icons.Default.Menu,
|
||||
contentDescription = "Open Menu",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
},
|
||||
navigationIcon = {
|
||||
IconButton(
|
||||
onClick = { scope.launch { drawerState.open() } }
|
||||
) {
|
||||
Icon(
|
||||
Icons.Default.Menu,
|
||||
contentDescription = "Open Menu",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
},
|
||||
actions = {
|
||||
// Dynamic actions based on current screen
|
||||
when (currentRoute) {
|
||||
AppRoutes.SEARCH -> {
|
||||
IconButton(onClick = { /* TODO: Open filter dialog */ }) {
|
||||
Icon(
|
||||
Icons.Default.FilterList,
|
||||
contentDescription = "Filter",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
},
|
||||
actions = {
|
||||
// Dynamic actions based on current screen
|
||||
when (currentRoute) {
|
||||
AppRoutes.SEARCH -> {
|
||||
IconButton(onClick = { /* TODO: Open filter dialog */ }) {
|
||||
Icon(
|
||||
Icons.Default.FilterList,
|
||||
contentDescription = "Filter",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
AppRoutes.INVENTORY -> {
|
||||
IconButton(onClick = {
|
||||
navController.navigate(AppRoutes.TRAIN)
|
||||
}) {
|
||||
Icon(
|
||||
Icons.Default.PersonAdd,
|
||||
contentDescription = "Add Person",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
AppRoutes.INVENTORY -> {
|
||||
IconButton(onClick = {
|
||||
navController.navigate(AppRoutes.TRAIN)
|
||||
}) {
|
||||
Icon(
|
||||
Icons.Default.PersonAdd,
|
||||
contentDescription = "Add Person",
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
colors = TopAppBarDefaults.topAppBarColors(
|
||||
containerColor = MaterialTheme.colorScheme.surface,
|
||||
titleContentColor = MaterialTheme.colorScheme.onSurface,
|
||||
navigationIconContentColor = MaterialTheme.colorScheme.primary,
|
||||
actionIconContentColor = MaterialTheme.colorScheme.primary
|
||||
},
|
||||
colors = TopAppBarDefaults.topAppBarColors(
|
||||
containerColor = MaterialTheme.colorScheme.surface,
|
||||
titleContentColor = MaterialTheme.colorScheme.onSurface,
|
||||
navigationIconContentColor = MaterialTheme.colorScheme.primary,
|
||||
actionIconContentColor = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
) { paddingValues ->
|
||||
AppNavHost(
|
||||
@@ -125,10 +137,10 @@ private fun getScreenTitle(route: String): String {
|
||||
AppRoutes.SEARCH -> "Search"
|
||||
AppRoutes.EXPLORE -> "Explore"
|
||||
AppRoutes.COLLECTIONS -> "Collections"
|
||||
AppRoutes.DISCOVER -> "Discover People" // ✨ NEW!
|
||||
AppRoutes.DISCOVER -> "Discover People"
|
||||
AppRoutes.INVENTORY -> "People"
|
||||
AppRoutes.TRAIN -> "Train New Person"
|
||||
AppRoutes.MODELS -> "AI Models" // Deprecated, but keep for backwards compat
|
||||
AppRoutes.MODELS -> "AI Models"
|
||||
AppRoutes.TAGS -> "Tag Management"
|
||||
AppRoutes.UTILITIES -> "Photo Util."
|
||||
AppRoutes.SETTINGS -> "Settings"
|
||||
@@ -144,7 +156,7 @@ private fun getScreenSubtitle(route: String): String? {
|
||||
AppRoutes.SEARCH -> "Find photos by tags, people, or date"
|
||||
AppRoutes.EXPLORE -> "Browse your collection"
|
||||
AppRoutes.COLLECTIONS -> "Your photo collections"
|
||||
AppRoutes.DISCOVER -> "Auto-find faces in your library" // ✨ NEW!
|
||||
AppRoutes.DISCOVER -> "Auto-find faces in your library"
|
||||
AppRoutes.INVENTORY -> "Trained face models"
|
||||
AppRoutes.TRAIN -> "Add a new person to recognize"
|
||||
AppRoutes.TAGS -> "Organize your photo collection"
|
||||
|
||||
@@ -14,7 +14,9 @@ import javax.inject.Inject
|
||||
* ImageSelectorViewModel
|
||||
*
|
||||
* Provides face-tagged image URIs for smart filtering
|
||||
* during training photo selection
|
||||
* during training photo selection.
|
||||
*
|
||||
* PRIORITIZATION: Solo photos first (faceCount=1) for clearer training data
|
||||
*/
|
||||
@HiltViewModel
|
||||
class ImageSelectorViewModel @Inject constructor(
|
||||
@@ -31,8 +33,15 @@ class ImageSelectorViewModel @Inject constructor(
|
||||
private fun loadFaceTaggedImages() {
|
||||
viewModelScope.launch {
|
||||
try {
|
||||
// Get all images with faces
|
||||
val imagesWithFaces = imageDao.getImagesWithFaces()
|
||||
_faceTaggedImageUris.value = imagesWithFaces.map { it.imageUri }
|
||||
|
||||
// CRITICAL FIX: Sort by faceCount ASCENDING (solo photos first!)
|
||||
// Previously: Sorted by faceCount DESC (group photos first - WRONG!)
|
||||
// Now: Solo photos appear first, making training selection easier
|
||||
val sortedImages = imagesWithFaces.sortedBy { it.faceCount }
|
||||
|
||||
_faceTaggedImageUris.value = sortedImages.map { it.imageUri }
|
||||
} catch (e: Exception) {
|
||||
// If cache not available, just use empty list (filter disabled)
|
||||
_faceTaggedImageUris.value = emptyList()
|
||||
|
||||
@@ -46,6 +46,8 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
|
||||
*
|
||||
* Uses indexed query: SELECT * FROM images WHERE hasFaces = 1
|
||||
* Fast! (~10ms for 10k photos)
|
||||
*
|
||||
* SORTED: Solo photos (faceCount=1) first for best training quality
|
||||
*/
|
||||
private fun loadPhotosWithFaces() {
|
||||
viewModelScope.launch {
|
||||
@@ -55,8 +57,9 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
|
||||
// ✅ CRITICAL: Only get images with faces!
|
||||
val photos = imageDao.getImagesWithFaces()
|
||||
|
||||
// Sort by most faces first (better for training)
|
||||
val sorted = photos.sortedByDescending { it.faceCount ?: 0 }
|
||||
// ✅ FIX: Sort by LEAST faces first (solo photos = best training data)
|
||||
// faceCount=1 first, then faceCount=2, etc.
|
||||
val sorted = photos.sortedBy { it.faceCount ?: 999 }
|
||||
|
||||
_photosWithFaces.value = sorted
|
||||
|
||||
|
||||
@@ -0,0 +1,315 @@
|
||||
package com.placeholder.sherpai2.workers
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.BitmapFactory
|
||||
import android.net.Uri
|
||||
import androidx.hilt.work.HiltWorker
|
||||
import androidx.work.*
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.face.FaceDetection
|
||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
|
||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||
import dagger.assisted.Assisted
|
||||
import dagger.assisted.AssistedInject
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.tasks.await
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
/**
|
||||
* LibraryScanWorker - Full library background scan for a trained person
|
||||
*
|
||||
* PURPOSE: After user approves validation preview, scan entire library
|
||||
*
|
||||
* STRATEGY:
|
||||
* 1. Load all photos with faces (from cache)
|
||||
* 2. Scan each photo for the trained person
|
||||
* 3. Create PhotoFaceTagEntity for matches
|
||||
* 4. Progressive updates to "People" tab
|
||||
* 5. Supports pause/resume via WorkManager
|
||||
*
|
||||
* SCHEDULING:
|
||||
* - Runs in background with progress notifications
|
||||
* - Can be cancelled by user
|
||||
* - Automatically retries on failure
|
||||
*
|
||||
* INPUT DATA:
|
||||
* - personId: String (UUID)
|
||||
* - personName: String (for notifications)
|
||||
* - threshold: Float (optional, default 0.70)
|
||||
*
|
||||
* OUTPUT DATA:
|
||||
* - matchesFound: Int
|
||||
* - photosScanned: Int
|
||||
* - errorMessage: String? (if failed)
|
||||
*/
|
||||
@HiltWorker
|
||||
class LibraryScanWorker @AssistedInject constructor(
|
||||
@Assisted private val context: Context,
|
||||
@Assisted workerParams: WorkerParameters,
|
||||
private val imageDao: ImageDao,
|
||||
private val faceModelDao: FaceModelDao,
|
||||
private val photoFaceTagDao: PhotoFaceTagDao
|
||||
) : CoroutineWorker(context, workerParams) {
|
||||
|
||||
companion object {
|
||||
const val WORK_NAME_PREFIX = "library_scan_"
|
||||
const val KEY_PERSON_ID = "person_id"
|
||||
const val KEY_PERSON_NAME = "person_name"
|
||||
const val KEY_THRESHOLD = "threshold"
|
||||
const val KEY_PROGRESS_CURRENT = "progress_current"
|
||||
const val KEY_PROGRESS_TOTAL = "progress_total"
|
||||
const val KEY_MATCHES_FOUND = "matches_found"
|
||||
const val KEY_PHOTOS_SCANNED = "photos_scanned"
|
||||
|
||||
private const val DEFAULT_THRESHOLD = 0.70f // Slightly looser than validation
|
||||
private const val BATCH_SIZE = 20
|
||||
private const val MAX_RETRIES = 3
|
||||
|
||||
/**
|
||||
* Create work request for library scan
|
||||
*/
|
||||
fun createWorkRequest(
|
||||
personId: String,
|
||||
personName: String,
|
||||
threshold: Float = DEFAULT_THRESHOLD
|
||||
): OneTimeWorkRequest {
|
||||
val inputData = workDataOf(
|
||||
KEY_PERSON_ID to personId,
|
||||
KEY_PERSON_NAME to personName,
|
||||
KEY_THRESHOLD to threshold
|
||||
)
|
||||
|
||||
return OneTimeWorkRequestBuilder<LibraryScanWorker>()
|
||||
.setInputData(inputData)
|
||||
.setConstraints(
|
||||
Constraints.Builder()
|
||||
.setRequiresBatteryNotLow(true) // Don't drain battery
|
||||
.build()
|
||||
)
|
||||
.addTag(WORK_NAME_PREFIX + personId)
|
||||
.build()
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun doWork(): Result = withContext(Dispatchers.Default) {
|
||||
try {
|
||||
// Get input parameters
|
||||
val personId = inputData.getString(KEY_PERSON_ID)
|
||||
?: return@withContext Result.failure(
|
||||
workDataOf("error" to "Missing person ID")
|
||||
)
|
||||
|
||||
val personName = inputData.getString(KEY_PERSON_NAME) ?: "Unknown"
|
||||
val threshold = inputData.getFloat(KEY_THRESHOLD, DEFAULT_THRESHOLD)
|
||||
|
||||
// Check if stopped
|
||||
if (isStopped) {
|
||||
return@withContext Result.failure()
|
||||
}
|
||||
|
||||
// Step 1: Get face model
|
||||
val faceModel = withContext(Dispatchers.IO) {
|
||||
faceModelDao.getFaceModelByPersonId(personId)
|
||||
} ?: return@withContext Result.failure(
|
||||
workDataOf("error" to "Face model not found")
|
||||
)
|
||||
|
||||
setProgress(workDataOf(
|
||||
KEY_PROGRESS_CURRENT to 0,
|
||||
KEY_PROGRESS_TOTAL to 100
|
||||
))
|
||||
|
||||
// Step 2: Get all photos with faces (from cache)
|
||||
val photosWithFaces = withContext(Dispatchers.IO) {
|
||||
imageDao.getImagesWithFaces()
|
||||
}
|
||||
|
||||
if (photosWithFaces.isEmpty()) {
|
||||
return@withContext Result.success(
|
||||
workDataOf(
|
||||
KEY_MATCHES_FOUND to 0,
|
||||
KEY_PHOTOS_SCANNED to 0
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
// Step 3: Initialize ML components
|
||||
val faceNetModel = FaceNetModel(context)
|
||||
val detector = FaceDetection.getClient(
|
||||
FaceDetectorOptions.Builder()
|
||||
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||
.setMinFaceSize(0.15f)
|
||||
.build()
|
||||
)
|
||||
|
||||
val modelEmbedding = faceModel.getEmbeddingArray()
|
||||
var matchesFound = 0
|
||||
var photosScanned = 0
|
||||
|
||||
try {
|
||||
// Step 4: Process in batches
|
||||
photosWithFaces.chunked(BATCH_SIZE).forEach { batch ->
|
||||
if (isStopped) {
|
||||
return@forEach
|
||||
}
|
||||
|
||||
// Scan batch
|
||||
batch.forEach { photo ->
|
||||
try {
|
||||
val tags = scanPhotoForPerson(
|
||||
photo = photo,
|
||||
personId = personId,
|
||||
faceModelId = faceModel.id,
|
||||
modelEmbedding = modelEmbedding,
|
||||
faceNetModel = faceNetModel,
|
||||
detector = detector,
|
||||
threshold = threshold
|
||||
)
|
||||
|
||||
if (tags.isNotEmpty()) {
|
||||
// Save tags
|
||||
withContext(Dispatchers.IO) {
|
||||
photoFaceTagDao.insertTags(tags)
|
||||
}
|
||||
matchesFound += tags.size
|
||||
}
|
||||
|
||||
photosScanned++
|
||||
|
||||
// Update progress
|
||||
if (photosScanned % 10 == 0) {
|
||||
val progress = (photosScanned * 100 / photosWithFaces.size)
|
||||
setProgress(workDataOf(
|
||||
KEY_PROGRESS_CURRENT to photosScanned,
|
||||
KEY_PROGRESS_TOTAL to photosWithFaces.size,
|
||||
KEY_MATCHES_FOUND to matchesFound
|
||||
))
|
||||
}
|
||||
|
||||
} catch (e: Exception) {
|
||||
// Skip failed photos, continue scanning
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Success!
|
||||
Result.success(
|
||||
workDataOf(
|
||||
KEY_MATCHES_FOUND to matchesFound,
|
||||
KEY_PHOTOS_SCANNED to photosScanned
|
||||
)
|
||||
)
|
||||
|
||||
} finally {
|
||||
faceNetModel.close()
|
||||
detector.close()
|
||||
}
|
||||
|
||||
} catch (e: Exception) {
|
||||
// Retry on failure
|
||||
if (runAttemptCount < MAX_RETRIES) {
|
||||
Result.retry()
|
||||
} else {
|
||||
Result.failure(
|
||||
workDataOf("error" to (e.message ?: "Unknown error"))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan a single photo for the person
|
||||
*/
|
||||
private suspend fun scanPhotoForPerson(
|
||||
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
|
||||
personId: String,
|
||||
faceModelId: String,
|
||||
modelEmbedding: FloatArray,
|
||||
faceNetModel: FaceNetModel,
|
||||
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||
threshold: Float
|
||||
): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) {
|
||||
|
||||
try {
|
||||
// Load bitmap
|
||||
val bitmap = loadBitmapDownsampled(Uri.parse(photo.imageUri), 768)
|
||||
?: return@withContext emptyList()
|
||||
|
||||
// Detect faces
|
||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||
val faces = detector.process(inputImage).await()
|
||||
|
||||
// Check each face
|
||||
val tags = faces.mapNotNull { face ->
|
||||
try {
|
||||
// Crop face
|
||||
val faceBitmap = android.graphics.Bitmap.createBitmap(
|
||||
bitmap,
|
||||
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
||||
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
||||
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
||||
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
||||
)
|
||||
|
||||
// Generate embedding
|
||||
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||
faceBitmap.recycle()
|
||||
|
||||
// Calculate similarity
|
||||
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
||||
|
||||
if (similarity >= threshold) {
|
||||
PhotoFaceTagEntity.create(
|
||||
imageId = photo.imageId,
|
||||
faceModelId = faceModelId,
|
||||
boundingBox = face.boundingBox,
|
||||
confidence = similarity,
|
||||
faceEmbedding = faceEmbedding
|
||||
)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
bitmap.recycle()
|
||||
tags
|
||||
|
||||
} catch (e: Exception) {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load bitmap with downsampling for memory efficiency
|
||||
*/
|
||||
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): android.graphics.Bitmap? {
|
||||
return try {
|
||||
val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true }
|
||||
context.contentResolver.openInputStream(uri)?.use {
|
||||
BitmapFactory.decodeStream(it, null, opts)
|
||||
}
|
||||
|
||||
var sample = 1
|
||||
while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) {
|
||||
sample *= 2
|
||||
}
|
||||
|
||||
val finalOpts = BitmapFactory.Options().apply {
|
||||
inSampleSize = sample
|
||||
}
|
||||
|
||||
context.contentResolver.openInputStream(uri)?.use {
|
||||
BitmapFactory.decodeStream(it, null, finalOpts)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user