discover dez

This commit is contained in:
genki
2026-01-21 15:59:41 -05:00
parent 4474365cd6
commit fa68138c15
15 changed files with 3402 additions and 509 deletions

View File

@@ -4,7 +4,7 @@
<selectionStates> <selectionStates>
<SelectionState runConfigName="app"> <SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" /> <option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-01-18T23:43:22.974426869Z"> <DropdownSelection timestamp="2026-01-21T20:49:28.305260931Z">
<Target type="DEFAULT_BOOT"> <Target type="DEFAULT_BOOT">
<handle> <handle>
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" /> <DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />

View File

@@ -10,6 +10,11 @@ import com.placeholder.sherpai2.data.local.entity.*
/** /**
* AppDatabase - Complete database for SherpAI2 * AppDatabase - Complete database for SherpAI2
* *
* VERSION 10 - User Feedback Loop
* - Added UserFeedbackEntity for storing user corrections
* - Enables cluster refinement before training
* - Ground truth data for improving clustering
*
* VERSION 9 - Enhanced Face Cache * VERSION 9 - Enhanced Face Cache
* - Added FaceCacheEntity for per-face metadata * - Added FaceCacheEntity for per-face metadata
* - Stores quality scores, embeddings, bounding boxes * - Stores quality scores, embeddings, bounding boxes
@@ -38,14 +43,15 @@ import com.placeholder.sherpai2.data.local.entity.*
FaceModelEntity::class, FaceModelEntity::class,
PhotoFaceTagEntity::class, PhotoFaceTagEntity::class,
PersonAgeTagEntity::class, PersonAgeTagEntity::class,
FaceCacheEntity::class, // NEW: Per-face metadata cache FaceCacheEntity::class,
UserFeedbackEntity::class, // NEW: User corrections
// ===== COLLECTIONS ===== // ===== COLLECTIONS =====
CollectionEntity::class, CollectionEntity::class,
CollectionImageEntity::class, CollectionImageEntity::class,
CollectionFilterEntity::class CollectionFilterEntity::class
], ],
version = 9, // INCREMENTED for face cache version = 10, // INCREMENTED for user feedback
exportSchema = false exportSchema = false
) )
abstract class AppDatabase : RoomDatabase() { abstract class AppDatabase : RoomDatabase() {
@@ -63,7 +69,8 @@ abstract class AppDatabase : RoomDatabase() {
abstract fun faceModelDao(): FaceModelDao abstract fun faceModelDao(): FaceModelDao
abstract fun photoFaceTagDao(): PhotoFaceTagDao abstract fun photoFaceTagDao(): PhotoFaceTagDao
abstract fun personAgeTagDao(): PersonAgeTagDao abstract fun personAgeTagDao(): PersonAgeTagDao
abstract fun faceCacheDao(): FaceCacheDao // NEW abstract fun faceCacheDao(): FaceCacheDao
abstract fun userFeedbackDao(): UserFeedbackDao // NEW
// ===== COLLECTIONS DAO ===== // ===== COLLECTIONS DAO =====
abstract fun collectionDao(): CollectionDao abstract fun collectionDao(): CollectionDao
@@ -185,6 +192,10 @@ val MIGRATION_8_9 = object : Migration(8, 9) {
hasGoodLighting INTEGER NOT NULL, hasGoodLighting INTEGER NOT NULL,
embedding TEXT, embedding TEXT,
confidence REAL NOT NULL, confidence REAL NOT NULL,
imageWidth INTEGER NOT NULL DEFAULT 0,
imageHeight INTEGER NOT NULL DEFAULT 0,
cacheVersion INTEGER NOT NULL DEFAULT 1,
cachedAt INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY(imageId, faceIndex), PRIMARY KEY(imageId, faceIndex),
FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE
) )
@@ -197,13 +208,47 @@ val MIGRATION_8_9 = object : Migration(8, 9) {
} }
} }
/**
 * MIGRATION 9 → 10 (User Feedback Loop)
 *
 * Changes:
 * 1. Create user_feedback table for storing user corrections
 *
 * The table shape (columns, PK, FKs) and the four indices must match the
 * UserFeedbackEntity declaration exactly, otherwise Room's schema
 * validation fails at runtime after the migration.
 */
val MIGRATION_9_10 = object : Migration(9, 10) {
    override fun migrate(database: SupportSQLiteDatabase) {
        // Create user_feedback table.
        // clusterId and personId are nullable on purpose: feedback can be
        // recorded either during cluster validation (no person yet) or after
        // a person exists (no cluster id).
        database.execSQL("""
            CREATE TABLE IF NOT EXISTS user_feedback (
                id TEXT PRIMARY KEY NOT NULL,
                imageId TEXT NOT NULL,
                faceIndex INTEGER NOT NULL,
                clusterId INTEGER,
                personId TEXT,
                feedbackType TEXT NOT NULL,
                originalConfidence REAL NOT NULL,
                userNote TEXT,
                timestamp INTEGER NOT NULL,
                FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE,
                FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE
            )
        """)
        // Create indices for fast lookups — these mirror the indices declared
        // on the @Entity (imageId, clusterId, personId, feedbackType).
        database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_imageId ON user_feedback(imageId)")
        database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_clusterId ON user_feedback(clusterId)")
        database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_personId ON user_feedback(personId)")
        database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_feedbackType ON user_feedback(feedbackType)")
    }
}
/** /**
* PRODUCTION MIGRATION NOTES: * PRODUCTION MIGRATION NOTES:
* *
* Before shipping to users, update DatabaseModule to use migration: * Before shipping to users, update DatabaseModule to use migrations:
* *
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db") * Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
* .addMigrations(MIGRATION_7_8) // Add this * .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations
* // .fallbackToDestructiveMigration() // Remove this * // .fallbackToDestructiveMigration() // Remove this
* .build() * .build()
*/ */

View File

@@ -8,7 +8,21 @@ import androidx.room.Update
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
/** /**
* FaceCacheDao - Face detection cache with NEW queries for two-stage clustering * FaceCacheDao - YEAR-BASED filtering for temporal clustering
*
* NEW STRATEGY: Cluster by YEAR to handle children
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* Problem: Child's face changes dramatically over years
* Solution: Cluster each YEAR separately
*
* Example:
* - 2020 photos → Emma age 2
* - 2021 photos → Emma age 3
* - 2022 photos → Emma age 4
*
* Result: Multiple clusters of same child at different ages
* User names: "Emma, Age 2", "Emma, Age 3", etc.
* System creates: Emma_Age_2, Emma_Age_3 submodels
*/ */
@Dao @Dao
interface FaceCacheDao { interface FaceCacheDao {
@@ -27,17 +41,161 @@ interface FaceCacheDao {
suspend fun update(faceCache: FaceCacheEntity) suspend fun update(faceCache: FaceCacheEntity)
// ═══════════════════════════════════════ // ═══════════════════════════════════════
// NEW CLUSTERING QUERIES ⭐ // YEAR-BASED QUERIES (NEW - For Children)
// ═══════════════════════════════════════ // ═══════════════════════════════════════
/** /**
* Get high-quality solo faces for Stage 1 clustering * Get premium solo faces from a SPECIFIC YEAR
* *
* Filters: * Use Case: Cluster children by age
* - Solo photos (faceCount = 1) * - Cluster 2020 photos separately from 2021 photos
* - Large faces (faceAreaRatio >= minFaceRatio) * - Same child at different ages = different clusters
* - Has embedding * - User names each: "Emma Age 2", "Emma Age 3"
*
* @param year Year in YYYY format (e.g., "2020")
* @param minRatio Minimum face size (default 5%)
* @param minQuality Minimum quality score (default 0.8)
* @param limit Maximum faces to return
*/ */
@Query("""
    SELECT fc.*
    FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.faceAreaRatio >= :minRatio
    AND fc.qualityScore >= :minQuality
    AND fc.embedding IS NOT NULL
    AND strftime('%Y', i.capturedAt/1000, 'unixepoch') = :year
    ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
    LIMIT :limit
""")
suspend fun getPremiumSoloFacesByYear(
    year: String,            // "YYYY" — compared as text against strftime('%Y', ...) output
    minRatio: Float = 0.05f, // minimum face-to-image area ratio (5%)
    minQuality: Float = 0.8f,
    limit: Int = 1000
): List<FaceCacheEntity>
// NOTE: the /1000 + 'unixepoch' conversion implies images.capturedAt is stored
// in epoch MILLISECONDS — confirm against the images table definition.
/**
 * Get premium solo faces from a YEAR RANGE.
 *
 * Use Case: Cluster adults who don't change much
 * - Photos from 2018-2023 can cluster together
 * - Adults look similar across years
 *
 * BETWEEN here compares 4-digit year strings lexicographically, which
 * matches numeric order for "YYYY" values.
 *
 * @param startYear Start year in YYYY format (inclusive)
 * @param endYear End year in YYYY format (inclusive)
 */
@Query("""
    SELECT fc.*
    FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.faceAreaRatio >= :minRatio
    AND fc.qualityScore >= :minQuality
    AND fc.embedding IS NOT NULL
    AND strftime('%Y', i.capturedAt/1000, 'unixepoch') BETWEEN :startYear AND :endYear
    ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
    LIMIT :limit
""")
suspend fun getPremiumSoloFacesByYearRange(
    startYear: String,
    endYear: String,
    minRatio: Float = 0.05f,
    minQuality: Float = 0.8f,
    limit: Int = 1000
): List<FaceCacheEntity>
/**
 * Get years that have sufficient photos for clustering.
 *
 * Returns years with at least [minPhotos] distinct solo-face images.
 * Use to determine which years are worth clustering separately.
 *
 * Note: no quality-score filter here (unlike the premium queries) — a year
 * qualifies based on volume of usable embeddings, not face quality.
 *
 * Example output:
 * ```
 * [
 *   YearPhotoCount(year="2020", photoCount=150),
 *   YearPhotoCount(year="2021", photoCount=200),
 *   YearPhotoCount(year="2022", photoCount=180)
 * ]
 * ```
 */
@Query("""
    SELECT
        strftime('%Y', i.capturedAt/1000, 'unixepoch') as year,
        COUNT(DISTINCT fc.imageId) as photoCount
    FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.faceAreaRatio >= :minRatio
    AND fc.embedding IS NOT NULL
    GROUP BY year
    HAVING photoCount >= :minPhotos
    ORDER BY year ASC
""")
suspend fun getYearsWithSufficientPhotos(
    minPhotos: Int = 20,
    minRatio: Float = 0.03f
): List<YearPhotoCount>
/**
 * Get month-by-month breakdown for a year.
 *
 * For fine-grained age clustering (babies change monthly).
 * Counts distinct solo-face images per "YYYY-MM" bucket; only rows with an
 * embedding are considered. No face-size or quality filter is applied here.
 *
 * @param year "YYYY" string matching strftime('%Y', ...) output
 */
@Query("""
    SELECT
        strftime('%Y-%m', i.capturedAt/1000, 'unixepoch') as yearMonth,
        COUNT(DISTINCT fc.imageId) as photoCount
    FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.embedding IS NOT NULL
    AND strftime('%Y', i.capturedAt/1000, 'unixepoch') = :year
    GROUP BY yearMonth
    ORDER BY yearMonth ASC
""")
suspend fun getMonthlyBreakdownForYear(year: String): List<MonthPhotoCount>
// ═══════════════════════════════════════
// STANDARD QUERIES (Original)
// ═══════════════════════════════════════
/**
 * Get high-quality solo faces across ALL years (no temporal filter).
 *
 * Filters: solo photos (faceCount = 1), face covers >= [minRatio] of the
 * image, quality >= [minQuality], and an embedding is present.
 * Results are best-first (quality, then face size).
 */
@Query("""
    SELECT fc.*
    FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.faceAreaRatio >= :minRatio
    AND fc.qualityScore >= :minQuality
    AND fc.embedding IS NOT NULL
    ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
    LIMIT :limit
""")
suspend fun getPremiumSoloFaces(
    minRatio: Float = 0.05f,
    minQuality: Float = 0.8f,
    limit: Int = 1000
): List<FaceCacheEntity>
/**
 * Relaxed variant of getPremiumSoloFaces.
 *
 * The SQL is identical to the premium query — only the default thresholds
 * differ (3% face size, 0.6 quality, 2000 results), admitting a larger,
 * lower-quality candidate pool.
 */
@Query("""
    SELECT fc.*
    FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.faceAreaRatio >= :minRatio
    AND fc.qualityScore >= :minQuality
    AND fc.embedding IS NOT NULL
    ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
    LIMIT :limit
""")
suspend fun getStandardSoloFaces(
    minRatio: Float = 0.03f,
    minQuality: Float = 0.6f,
    limit: Int = 2000
): List<FaceCacheEntity>
@Query(""" @Query("""
SELECT fc.* SELECT fc.*
FROM face_cache fc FROM face_cache fc
@@ -53,10 +211,6 @@ interface FaceCacheDao {
limit: Int = 2000 limit: Int = 2000
): List<FaceCacheEntity> ): List<FaceCacheEntity>
/**
* FALLBACK: Get ANY solo faces with embeddings
* Used if getHighQualitySoloFaces() returns empty
*/
@Query(""" @Query("""
SELECT fc.* SELECT fc.*
FROM face_cache fc FROM face_cache fc
@@ -70,12 +224,90 @@ interface FaceCacheDao {
limit: Int = 2000 limit: Int = 2000
): List<FaceCacheEntity> ): List<FaceCacheEntity>
// ═══════════════════════════════════════ @Query("""
// EXISTING QUERIES (keep as-is) SELECT fc.*
// ═══════════════════════════════════════ FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount BETWEEN :minFaces AND :maxFaces
AND fc.faceAreaRatio >= :minRatio
AND fc.embedding IS NOT NULL
ORDER BY i.faceCount ASC, fc.faceAreaRatio DESC
""")
suspend fun getSmallGroupFaces(
minFaces: Int = 2,
maxFaces: Int = 5,
minRatio: Float = 0.02f
): List<FaceCacheEntity>
@Query("SELECT * FROM face_cache WHERE id = :id") @Query("""
suspend fun getFaceCacheById(id: String): FaceCacheEntity? SELECT fc.*
FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = :faceCount
AND fc.faceAreaRatio >= :minRatio
AND fc.embedding IS NOT NULL
ORDER BY fc.qualityScore DESC
""")
suspend fun getFacesByGroupSize(
faceCount: Int,
minRatio: Float = 0.02f
): List<FaceCacheEntity>
/**
 * Get solo faces whose image is NOT in [excludedImageIds].
 *
 * Used to fetch remaining candidates after some images have already been
 * consumed (e.g. by an earlier clustering pass).
 *
 * NOTE(review): if [excludedImageIds] can be empty, verify Room's expansion
 * of `NOT IN (:excludedImageIds)` — an empty collection may produce invalid
 * SQL (`NOT IN ()`) on some Room versions; guard at the call site if so.
 */
@Query("""
    SELECT fc.*
    FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.faceAreaRatio >= :minRatio
    AND fc.embedding IS NOT NULL
    AND fc.imageId NOT IN (:excludedImageIds)
    ORDER BY fc.qualityScore DESC
    LIMIT :limit
""")
suspend fun getSoloFacesExcluding(
    excludedImageIds: List<String>,
    minRatio: Float = 0.03f,
    limit: Int = 2000
): List<FaceCacheEntity>
/**
 * Per-group-size quality statistics over the whole library.
 *
 * For each distinct faceCount (images grouped by how many faces they
 * contain) returns: number of distinct images, average face size, average
 * quality score, and how many face_cache rows actually have an embedding.
 *
 * BUG FIX: the previous `COUNT(fc.embedding IS NOT NULL)` counted EVERY
 * joined row — `x IS NOT NULL` evaluates to 0 or 1 and is itself never
 * NULL, so COUNT() over it never skips a row. `COUNT(fc.embedding)` counts
 * only rows where the embedding column is non-NULL, which is the intended
 * "hasEmbedding" statistic (and correctly ignores unmatched LEFT JOIN rows).
 */
@Query("""
    SELECT
        i.faceCount,
        COUNT(DISTINCT i.imageId) as imageCount,
        AVG(fc.faceAreaRatio) as avgFaceSize,
        AVG(fc.qualityScore) as avgQuality,
        COUNT(fc.embedding) as hasEmbedding
    FROM images i
    LEFT JOIN face_cache fc ON i.imageId = fc.imageId
    WHERE i.hasFaces = 1
    GROUP BY i.faceCount
    ORDER BY i.faceCount ASC
""")
suspend fun getLibraryQualityDistribution(): List<LibraryQualityStat>
/**
 * Count faces matching the premium-solo criteria.
 * Uses the same WHERE clause as getPremiumSoloFaces (and the same default
 * thresholds), so the two stay in sync for progress/UI reporting.
 */
@Query("""
    SELECT COUNT(*)
    FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.faceAreaRatio >= :minRatio
    AND fc.qualityScore >= :minQuality
    AND fc.embedding IS NOT NULL
""")
suspend fun countPremiumSoloFaces(
    minRatio: Float = 0.05f,
    minQuality: Float = 0.8f
): Int

/** Total number of cached faces that have an embedding (any quality, any group size). */
@Query("""
    SELECT COUNT(*)
    FROM face_cache
    WHERE embedding IS NOT NULL
""")
suspend fun countFacesWithEmbeddings(): Int
/** Look up a single cached face by its composite key (imageId, faceIndex); null if absent. */
@Query("SELECT * FROM face_cache WHERE imageId = :imageId AND faceIndex = :faceIndex")
suspend fun getFaceCacheByKey(imageId: String, faceIndex: Int): FaceCacheEntity?
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex") @Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex")
suspend fun getFaceCacheForImage(imageId: String): List<FaceCacheEntity> suspend fun getFaceCacheForImage(imageId: String): List<FaceCacheEntity>
@@ -88,4 +320,25 @@ interface FaceCacheDao {
@Query("DELETE FROM face_cache WHERE cacheVersion < :version") @Query("DELETE FROM face_cache WHERE cacheVersion < :version")
suspend fun deleteOldVersions(version: Int) suspend fun deleteOldVersions(version: Int)
} }
/**
 * Result classes for year-based queries.
 * Field names must match the column aliases in the corresponding @Query
 * projections so Room can map rows onto them.
 */
data class YearPhotoCount(
    val year: String,      // "YYYY", as produced by strftime('%Y', ...)
    val photoCount: Int    // distinct solo-face images in that year
)

data class MonthPhotoCount(
    val yearMonth: String, // "2020-05"
    val photoCount: Int
)

/** One row per faceCount group from getLibraryQualityDistribution. */
data class LibraryQualityStat(
    val faceCount: Int,    // faces per image for this group
    val imageCount: Int,   // distinct images in the group
    val avgFaceSize: Float,
    val avgQuality: Float,
    val hasEmbedding: Int  // face rows with a non-null embedding
)

View File

@@ -0,0 +1,212 @@
package com.placeholder.sherpai2.data.local.dao
import androidx.room.*
import com.placeholder.sherpai2.data.local.entity.FeedbackType
import com.placeholder.sherpai2.data.local.entity.UserFeedbackEntity
import kotlinx.coroutines.flow.Flow
/**
* UserFeedbackDao - Query user corrections and feedback
*
* KEY QUERIES:
* - Get feedback for cluster validation
* - Find rejected faces to exclude from training
* - Track feedback statistics for quality metrics
* - Support cluster refinement workflow
*/
@Dao
interface UserFeedbackDao {

    // ═══════════════════════════════════════
    // INSERT / UPDATE
    // ═══════════════════════════════════════

    /** Insert one feedback row; REPLACE overwrites an existing row with the same id. Returns the SQLite rowId. */
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    suspend fun insert(feedback: UserFeedbackEntity): Long

    /** Bulk insert with the same REPLACE-on-conflict semantics as [insert]. */
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    suspend fun insertAll(feedbacks: List<UserFeedbackEntity>)

    @Update
    suspend fun update(feedback: UserFeedbackEntity)

    @Delete
    suspend fun delete(feedback: UserFeedbackEntity)

    // ═══════════════════════════════════════
    // CLUSTER VALIDATION QUERIES
    // ═══════════════════════════════════════

    /**
     * Get all feedback for a cluster, newest first.
     * Used during validation to see what user has reviewed.
     */
    @Query("SELECT * FROM user_feedback WHERE clusterId = :clusterId ORDER BY timestamp DESC")
    suspend fun getFeedbackForCluster(clusterId: Int): List<UserFeedbackEntity>

    /**
     * Get rejected faces for a cluster.
     * These faces should be EXCLUDED from training.
     * The literal 'REJECTED_MATCH' must stay in sync with FeedbackType.REJECTED_MATCH.name.
     */
    @Query("""
        SELECT * FROM user_feedback
        WHERE clusterId = :clusterId
        AND feedbackType = 'REJECTED_MATCH'
    """)
    suspend fun getRejectedFacesForCluster(clusterId: Int): List<UserFeedbackEntity>

    /**
     * Get confirmed faces for a cluster.
     * These faces are SAFE for training.
     */
    @Query("""
        SELECT * FROM user_feedback
        WHERE clusterId = :clusterId
        AND feedbackType = 'CONFIRMED_MATCH'
    """)
    suspend fun getConfirmedFacesForCluster(clusterId: Int): List<UserFeedbackEntity>

    /**
     * Count feedback by type for a cluster.
     * Used to show stats: "15 confirmed, 3 rejected".
     * Column aliases (feedbackType, count) map onto [FeedbackStat].
     */
    @Query("""
        SELECT feedbackType, COUNT(*) as count
        FROM user_feedback
        WHERE clusterId = :clusterId
        GROUP BY feedbackType
    """)
    suspend fun getFeedbackStatsByCluster(clusterId: Int): List<FeedbackStat>

    // ═══════════════════════════════════════
    // PERSON FEEDBACK QUERIES
    // ═══════════════════════════════════════

    /**
     * Get all feedback for a person, newest first.
     * Used to show history of corrections.
     */
    @Query("SELECT * FROM user_feedback WHERE personId = :personId ORDER BY timestamp DESC")
    suspend fun getFeedbackForPerson(personId: String): List<UserFeedbackEntity>

    /**
     * Get rejected faces for a person.
     * User said "this is NOT X" - exclude from model improvement.
     */
    @Query("""
        SELECT * FROM user_feedback
        WHERE personId = :personId
        AND feedbackType = 'REJECTED_MATCH'
    """)
    suspend fun getRejectedFacesForPerson(personId: String): List<UserFeedbackEntity>

    /** Flow version of [getFeedbackForPerson] for reactive UI; re-emits on table changes. */
    @Query("SELECT * FROM user_feedback WHERE personId = :personId ORDER BY timestamp DESC")
    fun observeFeedbackForPerson(personId: String): Flow<List<UserFeedbackEntity>>

    // ═══════════════════════════════════════
    // IMAGE QUERIES
    // ═══════════════════════════════════════

    /** Get feedback for a specific image (all faces, unordered). */
    @Query("SELECT * FROM user_feedback WHERE imageId = :imageId")
    suspend fun getFeedbackForImage(imageId: String): List<UserFeedbackEntity>

    /** Check if user has provided any feedback for a specific face (imageId + faceIndex). */
    @Query("""
        SELECT EXISTS(
            SELECT 1 FROM user_feedback
            WHERE imageId = :imageId
            AND faceIndex = :faceIndex
        )
    """)
    suspend fun hasFeedbackForFace(imageId: String, faceIndex: Int): Boolean

    // ═══════════════════════════════════════
    // STATISTICS & ANALYTICS
    // ═══════════════════════════════════════

    /** Get total feedback count across all clusters and persons. */
    @Query("SELECT COUNT(*) FROM user_feedback")
    suspend fun getTotalFeedbackCount(): Int

    /** Get feedback count by type (global, not scoped to a cluster). */
    @Query("""
        SELECT feedbackType, COUNT(*) as count
        FROM user_feedback
        GROUP BY feedbackType
    """)
    suspend fun getGlobalFeedbackStats(): List<FeedbackStat>

    /**
     * Get average original confidence for rejected faces.
     * Helps identify if low confidence → more rejections.
     * Returns null when there are no rejected rows (AVG over empty set).
     */
    @Query("""
        SELECT AVG(originalConfidence)
        FROM user_feedback
        WHERE feedbackType = 'REJECTED_MATCH'
    """)
    suspend fun getAverageConfidenceForRejectedFaces(): Float?

    /**
     * Find faces with low confidence that were confirmed.
     * These are "surprising successes" - model worked despite low confidence.
     */
    @Query("""
        SELECT * FROM user_feedback
        WHERE feedbackType = 'CONFIRMED_MATCH'
        AND originalConfidence < :threshold
        ORDER BY originalConfidence ASC
    """)
    suspend fun getLowConfidenceSuccesses(threshold: Float = 0.7f): List<UserFeedbackEntity>

    // ═══════════════════════════════════════
    // CLEANUP
    // ═══════════════════════════════════════

    /**
     * Delete all feedback for a cluster.
     * Called when cluster is deleted or refined.
     */
    @Query("DELETE FROM user_feedback WHERE clusterId = :clusterId")
    suspend fun deleteFeedbackForCluster(clusterId: Int)

    /**
     * Delete all feedback for a person.
     * Called when person is deleted.
     */
    @Query("DELETE FROM user_feedback WHERE personId = :personId")
    suspend fun deleteFeedbackForPerson(personId: String)

    /**
     * Delete old feedback (older than X days).
     * Keep database size manageable.
     * @param cutoffTimestamp epoch millis; rows strictly older are removed
     */
    @Query("DELETE FROM user_feedback WHERE timestamp < :cutoffTimestamp")
    suspend fun deleteOldFeedback(cutoffTimestamp: Long)

    /** Clear all feedback (nuclear option). */
    @Query("DELETE FROM user_feedback")
    suspend fun deleteAll()
}
/**
 * Result class for feedback statistics.
 * Maps the `feedbackType, COUNT(*) as count` projection used by
 * getFeedbackStatsByCluster / getGlobalFeedbackStats.
 */
data class FeedbackStat(
    val feedbackType: String, // FeedbackType enum name stored as string
    val count: Int
)

View File

@@ -0,0 +1,161 @@
package com.placeholder.sherpai2.data.local.entity
import androidx.room.Entity
import androidx.room.ForeignKey
import androidx.room.Index
import androidx.room.PrimaryKey
import java.util.UUID
/**
 * UserFeedbackEntity - Stores user corrections during cluster validation
 *
 * PURPOSE:
 * - Capture which faces user marked as correct/incorrect
 * - Ground truth data for improving clustering
 * - Enable cluster refinement before training
 * - Track confidence in automated detections
 *
 * USAGE FLOW:
 * 1. Clustering creates initial clusters
 * 2. User reviews ValidationPreview
 * 3. User swipes faces: ✅ Correct / ❌ Incorrect
 * 4. Feedback stored here
 * 5. If too many incorrect → Re-cluster without those faces
 * 6. If approved → Train model with confirmed faces only
 *
 * INDEXES:
 * - imageId: Fast lookup of feedback for specific images
 * - clusterId: Get all feedback for a cluster
 * - feedbackType: Filter by correction type
 * - personId: Track feedback after person created
 */
@Entity(
    tableName = "user_feedback",
    foreignKeys = [
        ForeignKey(
            entity = ImageEntity::class,
            parentColumns = ["imageId"],
            childColumns = ["imageId"],
            onDelete = ForeignKey.CASCADE
        ),
        ForeignKey(
            entity = PersonEntity::class,
            parentColumns = ["id"],
            childColumns = ["personId"],
            onDelete = ForeignKey.CASCADE
        )
    ],
    indices = [
        Index(value = ["imageId"]),
        Index(value = ["clusterId"]),
        Index(value = ["personId"]),
        Index(value = ["feedbackType"])
    ]
)
data class UserFeedbackEntity(
    @PrimaryKey
    val id: String = UUID.randomUUID().toString(),

    // Image containing the face
    val imageId: String,

    // Face index within the image (0-based); multiple faces per image possible
    val faceIndex: Int,

    // Cluster ID from clustering (before person created).
    // Null if feedback given after person exists.
    val clusterId: Int?,

    // Person ID if feedback is about an existing person.
    // Null during initial cluster validation.
    val personId: String?,

    // Type of feedback user provided (FeedbackType enum stored as string)
    val feedbackType: String,

    // Confidence score that led to this face being shown.
    // Helps identify if low confidence = more errors.
    val originalConfidence: Float,

    // Optional user note
    val userNote: String? = null,

    // When feedback was provided (epoch millis)
    val timestamp: Long = System.currentTimeMillis()
) {
    companion object {
        /**
         * Type-safe factory: accepts the FeedbackType enum and stores its
         * name, so callers never hand-write the string column.
         */
        fun create(
            imageId: String,
            faceIndex: Int,
            clusterId: Int? = null,
            personId: String? = null,
            feedbackType: FeedbackType,
            originalConfidence: Float,
            userNote: String? = null
        ): UserFeedbackEntity {
            return UserFeedbackEntity(
                imageId = imageId,
                faceIndex = faceIndex,
                clusterId = clusterId,
                personId = personId,
                feedbackType = feedbackType.name,
                originalConfidence = originalConfidence,
                userNote = userNote
            )
        }
    }

    /**
     * Parse the stored string back into a FeedbackType.
     *
     * FIX: catch only IllegalArgumentException — that is the sole exception
     * valueOf() throws for an unknown name. The previous broad
     * `catch (e: Exception)` would also have swallowed unrelated failures.
     * Unknown/legacy values degrade to UNCERTAIN (safe: skipped in training).
     */
    fun getFeedbackType(): FeedbackType {
        return try {
            FeedbackType.valueOf(feedbackType)
        } catch (e: IllegalArgumentException) {
            FeedbackType.UNCERTAIN
        }
    }
}
/**
 * FeedbackType - Types of user corrections.
 *
 * Persisted by name (String) in UserFeedbackEntity.feedbackType, so renaming
 * an entry is a data migration — the DAO also matches these names as SQL
 * string literals.
 */
enum class FeedbackType {
    /**
     * User confirmed this face IS the person.
     * Boosts confidence, use for training.
     */
    CONFIRMED_MATCH,

    /**
     * User said this face is NOT the person.
     * Remove from cluster, exclude from training.
     */
    REJECTED_MATCH,

    /**
     * User marked as outlier during cluster review.
     * Face doesn't belong in this cluster.
     */
    MARKED_OUTLIER,

    /**
     * User is uncertain.
     * Skip this face for training, revisit later.
     * Also the fallback when a stored string fails to parse.
     */
    UNCERTAIN
}

View File

@@ -5,6 +5,7 @@ import androidx.room.Room
import com.placeholder.sherpai2.data.local.AppDatabase import com.placeholder.sherpai2.data.local.AppDatabase
import com.placeholder.sherpai2.data.local.MIGRATION_7_8 import com.placeholder.sherpai2.data.local.MIGRATION_7_8
import com.placeholder.sherpai2.data.local.MIGRATION_8_9 import com.placeholder.sherpai2.data.local.MIGRATION_8_9
import com.placeholder.sherpai2.data.local.MIGRATION_9_10
import com.placeholder.sherpai2.data.local.dao.* import com.placeholder.sherpai2.data.local.dao.*
import dagger.Module import dagger.Module
import dagger.Provides import dagger.Provides
@@ -16,6 +17,10 @@ import javax.inject.Singleton
/** /**
* DatabaseModule - Provides database and ALL DAOs * DatabaseModule - Provides database and ALL DAOs
* *
* VERSION 10 UPDATES:
* - Added UserFeedbackDao for cluster refinement
* - Added MIGRATION_9_10
*
* VERSION 9 UPDATES: * VERSION 9 UPDATES:
* - Added FaceCacheDao for per-face metadata * - Added FaceCacheDao for per-face metadata
* - Added MIGRATION_8_9 * - Added MIGRATION_8_9
@@ -44,7 +49,7 @@ object DatabaseModule {
.fallbackToDestructiveMigration(dropAllTables = true) .fallbackToDestructiveMigration(dropAllTables = true)
// PRODUCTION MODE: Uncomment this and remove fallbackToDestructiveMigration() // PRODUCTION MODE: Uncomment this and remove fallbackToDestructiveMigration()
// .addMigrations(MIGRATION_7_8, MIGRATION_8_9) // .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10)
.build() .build()
@@ -96,6 +101,10 @@ object DatabaseModule {
fun provideFaceCacheDao(db: AppDatabase): FaceCacheDao = fun provideFaceCacheDao(db: AppDatabase): FaceCacheDao =
db.faceCacheDao() db.faceCacheDao()
@Provides
fun provideUserFeedbackDao(db: AppDatabase): UserFeedbackDao =
db.userFeedbackDao()
// ===== COLLECTIONS DAOs ===== // ===== COLLECTIONS DAOs =====
@Provides @Provides

View File

@@ -2,12 +2,11 @@ package com.placeholder.sherpai2.di
import android.content.Context import android.content.Context
import androidx.work.WorkManager import androidx.work.WorkManager
import com.placeholder.sherpai2.data.local.dao.FaceModelDao import com.placeholder.sherpai2.data.local.dao.*
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
import com.placeholder.sherpai2.data.repository.TaggingRepositoryImpl import com.placeholder.sherpai2.data.repository.TaggingRepositoryImpl
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
import com.placeholder.sherpai2.domain.clustering.ClusterRefinementService
import com.placeholder.sherpai2.domain.repository.ImageRepository import com.placeholder.sherpai2.domain.repository.ImageRepository
import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl
import com.placeholder.sherpai2.domain.repository.TaggingRepository import com.placeholder.sherpai2.domain.repository.TaggingRepository
@@ -26,6 +25,8 @@ import javax.inject.Singleton
* UPDATED TO INCLUDE: * UPDATED TO INCLUDE:
* - FaceRecognitionRepository for face recognition operations * - FaceRecognitionRepository for face recognition operations
* - ValidationScanService for post-training validation * - ValidationScanService for post-training validation
* - ClusterRefinementService for user feedback loop (NEW)
* - ClusterQualityAnalyzer for cluster validation
* - WorkManager for background tasks * - WorkManager for background tasks
*/ */
@Module @Module
@@ -72,7 +73,7 @@ abstract class RepositoryModule {
} }
/** /**
* Provide ValidationScanService (NEW) * Provide ValidationScanService
*/ */
@Provides @Provides
@Singleton @Singleton
@@ -88,6 +89,34 @@ abstract class RepositoryModule {
) )
} }
/**
* Provide ClusterRefinementService (NEW)
* Handles user feedback and cluster refinement workflow
*/
@Provides
@Singleton
fun provideClusterRefinementService(
faceCacheDao: FaceCacheDao,
userFeedbackDao: UserFeedbackDao,
qualityAnalyzer: ClusterQualityAnalyzer
): ClusterRefinementService {
return ClusterRefinementService(
faceCacheDao = faceCacheDao,
userFeedbackDao = userFeedbackDao,
qualityAnalyzer = qualityAnalyzer
)
}
/**
* Provide ClusterQualityAnalyzer
* Validates cluster quality before training
*/
@Provides
@Singleton
fun provideClusterQualityAnalyzer(): ClusterQualityAnalyzer {
return ClusterQualityAnalyzer()
}
/** /**
* Provide WorkManager for background tasks * Provide WorkManager for background tasks
*/ */

View File

@@ -0,0 +1,415 @@
package com.placeholder.sherpai2.domain.clustering
import android.util.Log
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.UserFeedbackDao
import com.placeholder.sherpai2.data.local.entity.FeedbackType
import com.placeholder.sherpai2.data.local.entity.UserFeedbackEntity
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.math.sqrt
/**
* ClusterRefinementService - Handle user feedback and cluster refinement
*
* PURPOSE:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* Close the feedback loop between user corrections and clustering
*
* WORKFLOW:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* 1. Clustering produces initial clusters
* 2. User reviews in ValidationPreview
* 3. User marks faces: ✅ Correct / ❌ Incorrect / ❓ Uncertain
* 4. If too many incorrect → Call refineCluster()
* 5. Re-cluster WITHOUT incorrect faces
* 6. Show updated validation preview
* 7. Repeat until user approves
*
* BENEFITS:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* - Prevents bad models from being created
* - Learns from user corrections
* - Iterative improvement
* - Ground truth data for future enhancements
*/
@Singleton
class ClusterRefinementService @Inject constructor(
private val faceCacheDao: FaceCacheDao,
private val userFeedbackDao: UserFeedbackDao,
private val qualityAnalyzer: ClusterQualityAnalyzer
) {
companion object {
    private const val TAG = "ClusterRefinement"

    // Thresholds for refinement decisions.
    // NOTE(review): despite the MIN_ prefix, this is the MAXIMUM acceptable
    // rejection ratio — shouldRefineCluster() triggers refinement when the
    // observed ratio EXCEEDS it.
    private const val MIN_REJECTION_RATIO = 0.15f // 15% rejected → refine
    private const val MIN_CONFIRMED_FACES = 6 // Need at least 6 good faces
    private const val MAX_REFINEMENT_ITERATIONS = 3 // Prevent infinite loops
}
/**
 * Store user feedback for faces in a cluster.
 *
 * Runs on Dispatchers.IO since it writes through the DAO.
 *
 * @param cluster The cluster being reviewed
 * @param feedbackMap Map of detected face → feedback type (NOT face index —
 *        keys are the DetectedFaceWithEmbedding objects themselves)
 * @param originalConfidences Map of detected face → original detection
 *        confidence; falls back to the face's own confidence when absent
 * @return Number of feedback items stored
 */
suspend fun storeFeedback(
    cluster: FaceCluster,
    feedbackMap: Map<DetectedFaceWithEmbedding, FeedbackType>,
    originalConfidences: Map<DetectedFaceWithEmbedding, Float> = emptyMap()
): Int = withContext(Dispatchers.IO) {
    val feedbackEntities = feedbackMap.map { (face, feedbackType) ->
        UserFeedbackEntity.create(
            imageId = face.imageId,
            // TODO: DetectedFaceWithEmbedding does not carry a faceIndex yet,
            // so every row is stored with index 0 — ambiguous for multi-face
            // images; fix once the face model exposes its index.
            faceIndex = 0, // We don't track faceIndex in DetectedFaceWithEmbedding yet
            clusterId = cluster.clusterId,
            personId = null, // Not created yet
            feedbackType = feedbackType,
            originalConfidence = originalConfidences[face] ?: face.confidence
        )
    }
    userFeedbackDao.insertAll(feedbackEntities)
    Log.d(TAG, "Stored ${feedbackEntities.size} feedback items for cluster ${cluster.clusterId}")
    feedbackEntities.size
}
/**
 * Check if cluster needs refinement based on user feedback.
 *
 * Criteria:
 * - Too many rejected faces (rejection ratio above MIN_REJECTION_RATIO)
 * - Too few confirmed faces (fewer than MIN_CONFIRMED_FACES)
 *
 * FIX: the trainable-face count is simply the number of CONFIRMED faces.
 * The previous code computed `confirmedCount - rejectedCount`, but confirmed
 * and rejected are disjoint feedback sets — a rejected face was never in the
 * confirmed set, so subtracting double-counted rejections and could even go
 * negative.
 *
 * @return RefinementRecommendation with action and reason
 */
suspend fun shouldRefineCluster(
    cluster: FaceCluster
): RefinementRecommendation = withContext(Dispatchers.Default) {
    // DAO access on the IO dispatcher; counting/ratio math stays on Default.
    val feedback = withContext(Dispatchers.IO) {
        userFeedbackDao.getFeedbackForCluster(cluster.clusterId)
    }

    if (feedback.isEmpty()) {
        return@withContext RefinementRecommendation(
            shouldRefine = false,
            reason = "No feedback provided yet"
        )
    }

    val totalFeedback = feedback.size
    val rejectedCount = feedback.count { it.getFeedbackType() == FeedbackType.REJECTED_MATCH }
    val confirmedCount = feedback.count { it.getFeedbackType() == FeedbackType.CONFIRMED_MATCH }
    val uncertainCount = feedback.count { it.getFeedbackType() == FeedbackType.UNCERTAIN }
    val rejectionRatio = rejectedCount.toFloat() / totalFeedback.toFloat()

    Log.d(TAG, "Cluster ${cluster.clusterId} feedback: " +
        "$confirmedCount confirmed, $rejectedCount rejected, $uncertainCount uncertain")

    // Check 1: Too many rejections → likely mixed identities
    if (rejectionRatio > MIN_REJECTION_RATIO) {
        return@withContext RefinementRecommendation(
            shouldRefine = true,
            reason = "High rejection rate (${(rejectionRatio * 100).toInt()}%) suggests mixed identities",
            confirmedCount = confirmedCount,
            rejectedCount = rejectedCount,
            uncertainCount = uncertainCount
        )
    }

    // Check 2: Not enough confirmed faces to train on.
    // Confirmed faces are the only ones usable for training; rejected and
    // uncertain faces are excluded regardless.
    if (confirmedCount < MIN_CONFIRMED_FACES) {
        return@withContext RefinementRecommendation(
            shouldRefine = true,
            reason = "Only $confirmedCount confirmed faces available for training (need $MIN_CONFIRMED_FACES)",
            confirmedCount = confirmedCount,
            rejectedCount = rejectedCount,
            uncertainCount = uncertainCount
        )
    }

    // Cluster is good!
    RefinementRecommendation(
        shouldRefine = false,
        reason = "Cluster quality acceptable: $confirmedCount confirmed, $rejectedCount rejected",
        confirmedCount = confirmedCount,
        rejectedCount = rejectedCount,
        uncertainCount = uncertainCount
    )
}
/**
 * Refine cluster by removing rejected faces and re-clustering
 *
 * ALGORITHM:
 * 1. Get all rejected faces from feedback
 * 2. Remove those faces from cluster
 * 3. Recalculate cluster centroid
 * 4. Re-run quality analysis
 * 5. Return refined cluster
 *
 * NOTE(review): rejection is keyed by imageId, so removing a "rejected face"
 * removes every face from that image — fine for solo photos (faceCount = 1),
 * worth confirming for multi-face images.
 *
 * @param cluster Original cluster to refine
 * @param iterationNumber 1-based refinement attempt count; attempts beyond
 *        MAX_REFINEMENT_ITERATIONS fail fast to avoid endless refine loops
 * @return Refined cluster without rejected faces
 */
suspend fun refineCluster(
    cluster: FaceCluster,
    iterationNumber: Int = 1
): ClusterRefinementResult = withContext(Dispatchers.Default) {
    Log.d(TAG, "Refining cluster ${cluster.clusterId} (iteration $iterationNumber)")
    // Guard against infinite refinement
    if (iterationNumber > MAX_REFINEMENT_ITERATIONS) {
        return@withContext ClusterRefinementResult(
            success = false,
            refinedCluster = null,
            errorMessage = "Maximum refinement iterations reached. Cluster quality still poor.",
            facesRemoved = 0,
            facesRemaining = cluster.faces.size
        )
    }
    // Get rejected faces (DB read stays on the IO dispatcher)
    val feedback = withContext(Dispatchers.IO) {
        userFeedbackDao.getRejectedFacesForCluster(cluster.clusterId)
    }
    val rejectedImageIds = feedback.map { it.imageId }.toSet()
    // Nothing to do: refinement only ever removes rejected faces
    if (rejectedImageIds.isEmpty()) {
        return@withContext ClusterRefinementResult(
            success = false,
            refinedCluster = cluster,
            errorMessage = "No rejected faces to remove",
            facesRemoved = 0,
            facesRemaining = cluster.faces.size
        )
    }
    // Remove rejected faces
    val cleanFaces = cluster.faces.filter { it.imageId !in rejectedImageIds }
    Log.d(TAG, "Removed ${rejectedImageIds.size} rejected faces, ${cleanFaces.size} remain")
    // Check if we have enough faces left to form a usable training cluster
    if (cleanFaces.size < MIN_CONFIRMED_FACES) {
        return@withContext ClusterRefinementResult(
            success = false,
            refinedCluster = null,
            errorMessage = "Too few faces remaining after removing rejected faces (${cleanFaces.size} < $MIN_CONFIRMED_FACES)",
            facesRemoved = rejectedImageIds.size,
            facesRemaining = cleanFaces.size
        )
    }
    // Recalculate centroid from the surviving embeddings only
    val newCentroid = calculateCentroid(cleanFaces.map { it.embedding })
    // Select new representative faces (closest to the fresh centroid)
    val newRepresentatives = selectRepresentativeFacesByCentroid(cleanFaces, newCentroid, count = 6)
    // Create refined cluster; keeps the original clusterId so feedback rows stay linked
    val refinedCluster = FaceCluster(
        clusterId = cluster.clusterId,
        faces = cleanFaces,
        representativeFaces = newRepresentatives,
        photoCount = cleanFaces.map { it.imageId }.distinct().size,
        averageConfidence = cleanFaces.map { it.confidence }.average().toFloat(),
        estimatedAge = cluster.estimatedAge, // Keep same estimate
        potentialSiblings = cluster.potentialSiblings // Keep same siblings
    )
    // Re-run quality analysis
    val qualityResult = qualityAnalyzer.analyzeCluster(refinedCluster)
    Log.d(TAG, "Refined cluster quality: ${qualityResult.qualityTier} " +
        "(${qualityResult.cleanFaceCount} clean faces)")
    ClusterRefinementResult(
        success = true,
        refinedCluster = refinedCluster,
        qualityResult = qualityResult,
        facesRemoved = rejectedImageIds.size,
        facesRemaining = cleanFaces.size,
        newQualityTier = qualityResult.qualityTier
    )
}
/**
 * Get feedback summary for cluster
 *
 * Aggregates stored feedback into counts per feedback type plus the
 * overall rejection ratio (0 when no feedback exists).
 */
suspend fun getFeedbackSummary(clusterId: Int): FeedbackSummary = withContext(Dispatchers.IO) {
    val feedback = userFeedbackDao.getFeedbackForCluster(clusterId)
    // Single pass over the feedback instead of one count() scan per type.
    val countsByType = feedback.groupingBy { it.getFeedbackType() }.eachCount()
    val confirmed = countsByType[FeedbackType.CONFIRMED_MATCH] ?: 0
    val rejected = countsByType[FeedbackType.REJECTED_MATCH] ?: 0
    val uncertain = countsByType[FeedbackType.UNCERTAIN] ?: 0
    val outliers = countsByType[FeedbackType.MARKED_OUTLIER] ?: 0
    FeedbackSummary(
        totalFeedback = feedback.size,
        confirmedCount = confirmed,
        rejectedCount = rejected,
        uncertainCount = uncertain,
        outlierCount = outliers,
        rejectionRatio = if (feedback.isEmpty()) 0f else rejected.toFloat() / feedback.size.toFloat()
    )
}
/**
 * Filter cluster to only confirmed faces
 *
 * Use Case: User has reviewed cluster, now create model using ONLY confirmed faces.
 * When the user gave no explicit confirmations, every face that was not
 * explicitly rejected is treated as confirmed.
 */
suspend fun getConfirmedFaces(cluster: FaceCluster): List<DetectedFaceWithEmbedding> =
    withContext(Dispatchers.Default) {
        val confirmedIds = withContext(Dispatchers.IO) {
            userFeedbackDao.getConfirmedFacesForCluster(cluster.clusterId)
        }.map { it.imageId }.toSet()
        if (confirmedIds.isNotEmpty()) {
            // Explicit confirmations exist — keep only those faces.
            cluster.faces.filter { it.imageId in confirmedIds }
        } else {
            // No explicit confirmations — fall back to "everything not rejected".
            val rejectedIds = withContext(Dispatchers.IO) {
                userFeedbackDao.getRejectedFacesForCluster(cluster.clusterId)
            }.map { it.imageId }.toSet()
            cluster.faces.filter { it.imageId !in rejectedIds }
        }
    }
/**
 * Compute the L2-normalized mean of a set of embedding vectors.
 *
 * Returns an empty array for empty input; if the mean vector has zero
 * magnitude it is returned un-normalized.
 */
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
    if (embeddings.isEmpty()) return FloatArray(0)
    val dimensions = embeddings.first().size
    val mean = FloatArray(dimensions)
    // Accumulate component-wise sums, then divide by the vector count.
    for (vector in embeddings) {
        for (d in vector.indices) {
            mean[d] += vector[d]
        }
    }
    for (d in mean.indices) {
        mean[d] /= embeddings.size.toFloat()
    }
    // L2-normalize so downstream cosine comparisons are well-behaved.
    var sumOfSquares = 0f
    for (value in mean) {
        sumOfSquares += value * value
    }
    val magnitude = sqrt(sumOfSquares)
    return if (magnitude > 0) {
        FloatArray(dimensions) { d -> mean[d] / magnitude }
    } else {
        mean
    }
}
/**
 * Pick the [count] faces whose embeddings are most similar to [centroid].
 * Returns the input unchanged when it already contains [count] or fewer faces.
 */
private fun selectRepresentativeFacesByCentroid(
    faces: List<DetectedFaceWithEmbedding>,
    centroid: FloatArray,
    count: Int
): List<DetectedFaceWithEmbedding> {
    if (faces.size <= count) return faces
    // Higher cosine similarity means smaller centroid distance, so sorting by
    // descending similarity is equivalent to sorting by ascending distance.
    return faces
        .sortedByDescending { cosineSimilarity(it.embedding, centroid) }
        .take(count)
}
/**
 * Cosine similarity between two embedding vectors.
 *
 * Iterates over a.indices, so both arrays are assumed to have equal length.
 * Returns 0 when either vector has zero magnitude — the unguarded division
 * previously produced NaN, which would poison downstream distance sorting.
 */
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
    var dotProduct = 0f
    var normA = 0f
    var normB = 0f
    for (i in a.indices) {
        dotProduct += a[i] * b[i]
        normA += a[i] * a[i]
        normB += b[i] * b[i]
    }
    val denominator = sqrt(normA) * sqrt(normB)
    // Guard against division by zero (zero-magnitude vector on either side).
    return if (denominator > 0f) dotProduct / denominator else 0f
}
}
/**
 * Result of refinement analysis
 *
 * Returned by shouldRefineCluster to tell the caller whether a cluster
 * should be refined before training, and why.
 */
data class RefinementRecommendation(
    val shouldRefine: Boolean,   // true when the cluster should be refined before use
    val reason: String,          // human-readable explanation of the decision
    val confirmedCount: Int = 0, // feedback items confirming cluster membership
    val rejectedCount: Int = 0,  // feedback items rejecting cluster membership
    val uncertainCount: Int = 0  // feedback items the user was unsure about
)
/**
 * Result of cluster refinement
 *
 * Returned by refineCluster. On failure, refinedCluster is null (or the
 * unchanged input when there was nothing to remove) and errorMessage is set.
 */
data class ClusterRefinementResult(
    val success: Boolean,                              // whether refinement produced a usable cluster
    val refinedCluster: FaceCluster?,                  // cluster with rejected faces removed, or null on failure
    val qualityResult: ClusterQualityResult? = null,   // fresh quality analysis of the refined cluster
    val errorMessage: String? = null,                  // failure reason when success is false
    val facesRemoved: Int,                             // number of rejected faces stripped from the cluster
    val facesRemaining: Int,                           // face count after removal
    val newQualityTier: ClusterQualityTier? = null     // quality tier of the refined cluster, if analyzed
)
/**
 * Summary of user feedback for a cluster
 *
 * Aggregated counts per feedback type plus the rejection ratio
 * (rejected / total, 0 when no feedback exists).
 */
data class FeedbackSummary(
    val totalFeedback: Int,
    val confirmedCount: Int,
    val rejectedCount: Int,
    val uncertainCount: Int,
    val outlierCount: Int,
    val rejectionRatio: Float
) {
    /**
     * Human-readable summary such as "15 confirmed, 3 rejected, 2 uncertain".
     *
     * FIX: outlierCount was tracked but never displayed, and a summary with
     * no positive counts rendered as an empty string; it now falls back to
     * an explicit "No feedback yet".
     */
    fun getDisplayText(): String {
        val parts = mutableListOf<String>()
        if (confirmedCount > 0) parts.add("$confirmedCount confirmed")
        if (rejectedCount > 0) parts.add("$rejectedCount rejected")
        if (uncertainCount > 0) parts.add("$uncertainCount uncertain")
        if (outlierCount > 0) parts.add("$outlierCount outliers")
        return if (parts.isEmpty()) "No feedback yet" else parts.joinToString(", ")
    }
}

View File

@@ -5,6 +5,7 @@ import android.graphics.Bitmap
import android.graphics.BitmapFactory import android.graphics.BitmapFactory
import android.net.Uri import android.net.Uri
import android.util.Log import android.util.Log
import com.google.android.gms.tasks.Tasks
import com.google.mlkit.vision.common.InputImage import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions import com.google.mlkit.vision.face.FaceDetectorOptions
@@ -23,24 +24,16 @@ import kotlinx.coroutines.withContext
import javax.inject.Inject import javax.inject.Inject
import javax.inject.Singleton import javax.inject.Singleton
import kotlin.math.sqrt import kotlin.math.sqrt
import kotlin.random.Random
/** /**
* FaceClusteringService - HYBRID version with automatic fallback * FaceClusteringService - ENHANCED with quality filtering & deterministic results
* *
* STRATEGY: * NEW FEATURES:
* 1. Try to use face cache (fast path) - 10x faster * ✅ FaceQualityFilter integration (eliminates clothing/ghost faces)
* 2. Fall back to classic method if cache empty (compatible) * ✅ Deterministic clustering (seeded random)
* 3. Load SOLO PHOTOS ONLY (faceCount = 1) for clustering * ✅ Better thresholds (finds Brad Pitt)
* 4. Detect faces and generate embeddings (parallel) * ✅ Faster processing (filters garbage early)
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
* 6. Analyze clusters for age, siblings, representatives
*
* IMPROVEMENTS:
* - ✅ Complete fast-path using FaceCacheDao.getSoloFacesWithEmbeddings()
* - ✅ Works with existing FaceCacheEntity.getEmbedding() method
* - ✅ Centroid-based representative face selection
* - ✅ Batched processing to prevent OOM
* - ✅ RGB_565 bitmap config for 50% memory savings
*/ */
@Singleton @Singleton
class FaceClusteringService @Inject constructor( class FaceClusteringService @Inject constructor(
@@ -50,105 +43,109 @@ class FaceClusteringService @Inject constructor(
) { ) {
private val semaphore = Semaphore(8) private val semaphore = Semaphore(8)
private val deterministicRandom = Random(42) // Fixed seed for reproducibility
companion object { companion object {
private const val TAG = "FaceClustering" private const val TAG = "FaceClustering"
private const val MAX_FACES_TO_CLUSTER = 2000 private const val MAX_FACES_TO_CLUSTER = 2000
private const val MIN_SOLO_PHOTOS = 50 private const val MIN_SOLO_PHOTOS = 50
private const val MIN_PREMIUM_FACES = 100
private const val MIN_STANDARD_FACES = 50
private const val BATCH_SIZE = 50 private const val BATCH_SIZE = 50
private const val MIN_CACHED_FACES = 100
} }
/**
* Main clustering entry point - HYBRID with automatic fallback
*/
suspend fun discoverPeople( suspend fun discoverPeople(
strategy: ClusteringStrategy = ClusteringStrategy.PREMIUM_SOLO_ONLY,
maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER, maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER,
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> } onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
): ClusteringResult = withContext(Dispatchers.Default) { ): ClusteringResult = withContext(Dispatchers.Default) {
val startTime = System.currentTimeMillis() val startTime = System.currentTimeMillis()
// Try high-quality cached faces FIRST (NEW!) Log.d(TAG, "Starting people discovery with strategy: $strategy")
var cachedFaces = withContext(Dispatchers.IO) {
val result = when (strategy) {
ClusteringStrategy.PREMIUM_SOLO_ONLY -> {
clusterPremiumSoloFaces(maxFacesToCluster, onProgress)
}
ClusteringStrategy.STANDARD_SOLO_ONLY -> {
clusterStandardSoloFaces(maxFacesToCluster, onProgress)
}
ClusteringStrategy.TWO_PHASE -> {
clusterPremiumSoloFaces(maxFacesToCluster, onProgress)
}
ClusteringStrategy.LEGACY_ALL_FACES -> {
clusterAllFacesLegacy(maxFacesToCluster, onProgress)
}
}
val elapsedTime = System.currentTimeMillis() - startTime
Log.d(TAG, "Clustering complete: ${result.clusters.size} clusters in ${elapsedTime}ms")
result.copy(processingTimeMs = elapsedTime)
}
private suspend fun clusterPremiumSoloFaces(
maxFaces: Int,
onProgress: (Int, Int, String) -> Unit
): ClusteringResult = withContext(Dispatchers.Default) {
onProgress(5, 100, "Checking face cache...")
var premiumFaces = withContext(Dispatchers.IO) {
try { try {
faceCacheDao.getHighQualitySoloFaces( faceCacheDao.getPremiumSoloFaces(
minFaceRatio = 0.015f, // 1.5% minRatio = 0.05f,
limit = maxFacesToCluster minQuality = 0.8f,
limit = maxFaces
) )
} catch (e: Exception) { } catch (e: Exception) {
// Method doesn't exist yet - that's ok Log.w(TAG, "Error fetching premium faces: ${e.message}")
emptyList() emptyList()
} }
} }
// Fallback to ANY solo faces if high-quality returned nothing Log.d(TAG, "Found ${premiumFaces.size} premium solo faces in cache")
if (cachedFaces.isEmpty()) {
Log.w(TAG, "No high-quality faces (>= 1.5%), trying ANY solo faces...") if (premiumFaces.size < MIN_PREMIUM_FACES) {
cachedFaces = withContext(Dispatchers.IO) { Log.w(TAG, "Insufficient premium faces (${premiumFaces.size} < $MIN_PREMIUM_FACES)")
onProgress(10, 100, "Trying standard quality faces...")
premiumFaces = withContext(Dispatchers.IO) {
try { try {
faceCacheDao.getSoloFacesWithEmbeddings(limit = maxFacesToCluster) faceCacheDao.getStandardSoloFaces(
minRatio = 0.03f,
minQuality = 0.6f,
limit = maxFaces
)
} catch (e: Exception) { } catch (e: Exception) {
emptyList() emptyList()
} }
} }
Log.d(TAG, "Found ${premiumFaces.size} standard solo faces in cache")
} }
Log.d(TAG, "Cache check: ${cachedFaces.size} faces available") if (premiumFaces.size < MIN_STANDARD_FACES) {
Log.w(TAG, "Insufficient cached faces, falling back to slow path")
return@withContext clusterAllFacesLegacy(maxFaces, onProgress)
}
val allFaces = if (cachedFaces.size >= MIN_CACHED_FACES) { onProgress(20, 100, "Loading ${premiumFaces.size} high-quality solo photos...")
// FAST PATH ✅
Log.d(TAG, "Using FAST PATH with ${cachedFaces.size} cached faces")
onProgress(10, 100, "Using cached embeddings (${cachedFaces.size} faces)...")
cachedFaces.mapNotNull { cached -> val allFaces = premiumFaces.mapNotNull { cached: FaceCacheEntity ->
val embedding = cached.getEmbedding() ?: return@mapNotNull null val embedding = cached.getEmbedding() ?: return@mapNotNull null
DetectedFaceWithEmbedding( DetectedFaceWithEmbedding(
imageId = cached.imageId, imageId = cached.imageId,
imageUri = "", imageUri = "",
capturedAt = 0L, capturedAt = 0L,
embedding = embedding, embedding = embedding,
boundingBox = cached.getBoundingBox(), boundingBox = cached.getBoundingBox(),
confidence = cached.confidence, confidence = cached.confidence,
faceCount = 1, // Solo faces only (filtered by query) faceCount = 1,
imageWidth = cached.imageWidth, imageWidth = cached.imageWidth,
imageHeight = cached.imageHeight imageHeight = cached.imageHeight
)
}.also {
onProgress(50, 100, "Processing ${it.size} cached faces...")
}
} else {
// SLOW PATH
Log.d(TAG, "Using SLOW PATH - cache has < $MIN_CACHED_FACES faces")
onProgress(0, 100, "Loading photos...")
val soloPhotos = withContext(Dispatchers.IO) {
imageDao.getImagesByFaceCount(count = 1)
}
val imagesWithFaces = if (soloPhotos.size < MIN_SOLO_PHOTOS) {
imageDao.getImagesWithFaces()
} else {
soloPhotos
}
if (imagesWithFaces.isEmpty()) {
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = 0,
errorMessage = "No photos with faces found"
)
}
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
detectFacesInImagesBatched(
images = imagesWithFaces.take(1000),
onProgress = { current, total ->
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
}
) )
} }
@@ -156,28 +153,31 @@ class FaceClusteringService @Inject constructor(
return@withContext ClusteringResult( return@withContext ClusteringResult(
clusters = emptyList(), clusters = emptyList(),
totalFacesAnalyzed = 0, totalFacesAnalyzed = 0,
processingTimeMs = System.currentTimeMillis() - startTime, processingTimeMs = 0,
errorMessage = "No faces detected" errorMessage = "No valid faces with embeddings found"
) )
} }
onProgress(50, 100, "Clustering ${allFaces.size} faces...") onProgress(40, 100, "Clustering ${allFaces.size} faces...")
// ENHANCED: Lower threshold (quality filter handles garbage now)
val rawClusters = performDBSCAN( val rawClusters = performDBSCAN(
faces = allFaces.take(maxFacesToCluster), faces = allFaces.take(maxFaces),
epsilon = 0.26f, epsilon = 0.24f, // Was 0.26f - now more aggressive
minPoints = 3 minPoints = 3 // Was 3 - keeping same
) )
Log.d(TAG, "DBSCAN produced ${rawClusters.size} raw clusters")
onProgress(70, 100, "Analyzing relationships...") onProgress(70, 100, "Analyzing relationships...")
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters) val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
onProgress(80, 100, "Selecting representative faces...") onProgress(80, 100, "Selecting representative faces...")
val clusters = rawClusters.map { cluster -> val clusters = rawClusters.mapIndexed { index: Int, cluster: RawCluster ->
FaceCluster( FaceCluster(
clusterId = cluster.clusterId, clusterId = index,
faces = cluster.faces, faces = cluster.faces,
representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6), representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6),
photoCount = cluster.faces.map { it.imageId }.distinct().size, photoCount = cluster.faces.map { it.imageId }.distinct().size,
@@ -192,120 +192,225 @@ class FaceClusteringService @Inject constructor(
ClusteringResult( ClusteringResult(
clusters = clusters, clusters = clusters,
totalFacesAnalyzed = allFaces.size, totalFacesAnalyzed = allFaces.size,
processingTimeMs = System.currentTimeMillis() - startTime processingTimeMs = 0,
strategy = ClusteringStrategy.PREMIUM_SOLO_ONLY
) )
} }
private suspend fun detectFacesInImagesBatched( private suspend fun clusterStandardSoloFaces(
images: List<ImageEntity>, maxFaces: Int,
onProgress: (Int, Int) -> Unit onProgress: (Int, Int, String) -> Unit
): List<DetectedFaceWithEmbedding> = coroutineScope { ): ClusteringResult = withContext(Dispatchers.Default) {
val allFaces = mutableListOf<DetectedFaceWithEmbedding>() onProgress(10, 100, "Loading solo photos...")
var processedCount = 0
images.chunked(BATCH_SIZE).forEach { batch -> val standardFaces = withContext(Dispatchers.IO) {
val batchFaces = detectFacesInBatch(batch) try {
allFaces.addAll(batchFaces) faceCacheDao.getStandardSoloFaces(
minRatio = 0.03f,
processedCount += batch.size minQuality = 0.6f,
onProgress(processedCount, images.size) limit = maxFaces
)
System.gc() } catch (e: Exception) {
emptyList()
}
} }
allFaces if (standardFaces.size < MIN_STANDARD_FACES) {
return@withContext clusterAllFacesLegacy(maxFaces, onProgress)
}
val allFaces = standardFaces.mapNotNull { cached: FaceCacheEntity ->
val embedding = cached.getEmbedding() ?: return@mapNotNull null
DetectedFaceWithEmbedding(
imageId = cached.imageId,
imageUri = "",
capturedAt = 0L,
embedding = embedding,
boundingBox = cached.getBoundingBox(),
confidence = cached.confidence,
faceCount = 1,
imageWidth = cached.imageWidth,
imageHeight = cached.imageHeight
)
}
onProgress(40, 100, "Clustering ${allFaces.size} faces...")
val rawClusters = performDBSCAN(allFaces.take(maxFaces), 0.24f, 3)
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
val clusters = rawClusters.mapIndexed { index, cluster ->
FaceCluster(
clusterId = index,
faces = cluster.faces,
representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, 6),
photoCount = cluster.faces.map { it.imageId }.distinct().size,
averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(),
estimatedAge = estimateAge(cluster.faces),
potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph)
)
}.sortedByDescending { it.photoCount }
ClusteringResult(
clusters = clusters,
totalFacesAnalyzed = allFaces.size,
processingTimeMs = 0,
strategy = ClusteringStrategy.STANDARD_SOLO_ONLY
)
} }
private suspend fun detectFacesInBatch( private suspend fun clusterAllFacesLegacy(
images: List<ImageEntity> maxFaces: Int,
): List<DetectedFaceWithEmbedding> = coroutineScope { onProgress: (Int, Int, String) -> Unit
): ClusteringResult = withContext(Dispatchers.Default) {
onProgress(10, 100, "Loading photos...")
val images = withContext(Dispatchers.IO) {
imageDao.getAllImages()
}
if (images.isEmpty()) {
return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = 0,
errorMessage = "No images in library"
)
}
// ENHANCED: Process ALL photos (no limit)
val shuffled = images.shuffled(deterministicRandom)
onProgress(20, 100, "Analyzing ${shuffled.size} photos...")
val faceNetModel = FaceNetModel(context)
val detector = FaceDetection.getClient( val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder() FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // ENHANCED: Get landmarks
.setMinFaceSize(0.15f) .setMinFaceSize(0.15f)
.build() .build()
) )
val faceNetModel = FaceNetModel(context)
val batchFaces = mutableListOf<DetectedFaceWithEmbedding>()
try { try {
val jobs = images.map { image -> val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
async(Dispatchers.IO) {
semaphore.acquire() coroutineScope {
try { val jobs = shuffled.mapIndexed { index, image ->
detectFacesInImage(image, detector, faceNetModel) async(Dispatchers.IO) {
} finally { semaphore.acquire()
semaphore.release() try {
val bitmap = loadBitmapDownsampled(Uri.parse(image.imageUri), 768)
?: return@async emptyList()
val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = Tasks.await(detector.process(inputImage))
val imageWidth = bitmap.width
val imageHeight = bitmap.height
val faceEmbeddings = faces.mapNotNull { face ->
// ===== APPLY QUALITY FILTER =====
val qualityCheck = FaceQualityFilter.validateForDiscovery(
face = face,
imageWidth = imageWidth,
imageHeight = imageHeight
)
// Skip low-quality faces
if (!qualityCheck.isValid) {
Log.d(TAG, "Rejected face: ${qualityCheck.issues.joinToString()}")
return@mapNotNull null
}
try {
val faceBitmap = android.graphics.Bitmap.createBitmap(
bitmap,
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()
DetectedFaceWithEmbedding(
imageId = image.imageId,
imageUri = image.imageUri,
capturedAt = image.capturedAt,
embedding = embedding,
boundingBox = face.boundingBox,
confidence = qualityCheck.confidenceScore, // Use quality score
faceCount = faces.size,
imageWidth = imageWidth,
imageHeight = imageHeight
)
} catch (e: Exception) {
null
}
}
bitmap.recycle()
if (index % 20 == 0) {
val progress = 20 + (index * 60 / shuffled.size)
onProgress(progress, 100, "Processed $index/${shuffled.size} photos...")
}
faceEmbeddings
} finally {
semaphore.release()
}
} }
} }
jobs.awaitAll().flatten().forEach { allFaces.add(it) }
} }
jobs.awaitAll().flatten().also { if (allFaces.isEmpty()) {
batchFaces.addAll(it) return@withContext ClusteringResult(
clusters = emptyList(),
totalFacesAnalyzed = 0,
processingTimeMs = 0,
errorMessage = "No faces detected with sufficient quality"
)
} }
onProgress(80, 100, "Clustering ${allFaces.size} faces...")
val rawClusters = performDBSCAN(allFaces.take(maxFaces), 0.24f, 3)
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
val clusters = rawClusters.mapIndexed { index, cluster ->
FaceCluster(
clusterId = index,
faces = cluster.faces,
representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, 6),
photoCount = cluster.faces.map { it.imageId }.distinct().size,
averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(),
estimatedAge = estimateAge(cluster.faces),
potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph)
)
}.sortedByDescending { it.photoCount }
onProgress(100, 100, "Complete!")
ClusteringResult(
clusters = clusters,
totalFacesAnalyzed = allFaces.size,
processingTimeMs = 0,
strategy = ClusteringStrategy.LEGACY_ALL_FACES
)
} finally { } finally {
detector.close()
faceNetModel.close() faceNetModel.close()
} detector.close()
batchFaces
}
private suspend fun detectFacesInImage(
image: ImageEntity,
detector: com.google.mlkit.vision.face.FaceDetector,
faceNetModel: FaceNetModel
): List<DetectedFaceWithEmbedding> = withContext(Dispatchers.IO) {
try {
val uri = Uri.parse(image.imageUri)
val bitmap = loadBitmapDownsampled(uri, 512) ?: return@withContext emptyList()
val mlImage = InputImage.fromBitmap(bitmap, 0)
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
val result = faces.mapNotNull { face ->
try {
val faceBitmap = Bitmap.createBitmap(
bitmap,
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()
DetectedFaceWithEmbedding(
imageId = image.imageId,
imageUri = image.imageUri,
capturedAt = image.capturedAt,
embedding = embedding,
boundingBox = face.boundingBox,
confidence = 0.95f,
faceCount = faces.size,
imageWidth = bitmap.width,
imageHeight = bitmap.height
)
} catch (e: Exception) {
null
}
}
bitmap.recycle()
result
} catch (e: Exception) {
emptyList()
} }
} }
private fun performDBSCAN( fun performDBSCAN(
faces: List<DetectedFaceWithEmbedding>, faces: List<DetectedFaceWithEmbedding>,
epsilon: Float, epsilon: Float,
minPoints: Int minPoints: Int
@@ -326,8 +431,6 @@ class FaceClusteringService @Inject constructor(
val cluster = mutableListOf<DetectedFaceWithEmbedding>() val cluster = mutableListOf<DetectedFaceWithEmbedding>()
val queue = ArrayDeque(neighbors) val queue = ArrayDeque(neighbors)
visited.add(i)
cluster.add(faces[i])
while (queue.isNotEmpty()) { while (queue.isNotEmpty()) {
val pointIdx = queue.removeFirst() val pointIdx = queue.removeFirst()
@@ -356,7 +459,7 @@ class FaceClusteringService @Inject constructor(
epsilon: Float epsilon: Float
): List<Int> { ): List<Int> {
val point = faces[pointIdx] val point = faces[pointIdx]
return faces.indices.filter { i -> return faces.indices.filter { i: Int ->
if (i == pointIdx) return@filter false if (i == pointIdx) return@filter false
val otherFace = faces[i] val otherFace = faces[i]
@@ -412,13 +515,13 @@ class FaceClusteringService @Inject constructor(
if (clusterIdx == -1) return emptyList() if (clusterIdx == -1) return emptyList()
return coOccurrenceGraph[clusterIdx] return coOccurrenceGraph[clusterIdx]
?.filter { (_, count) -> count >= 5 } ?.filter { (_, count: Int) -> count >= 5 }
?.keys ?.keys
?.toList() ?.toList()
?: emptyList() ?: emptyList()
} }
private fun selectRepresentativeFacesByCentroid( fun selectRepresentativeFacesByCentroid(
faces: List<DetectedFaceWithEmbedding>, faces: List<DetectedFaceWithEmbedding>,
count: Int count: Int
): List<DetectedFaceWithEmbedding> { ): List<DetectedFaceWithEmbedding> {
@@ -426,7 +529,7 @@ class FaceClusteringService @Inject constructor(
val centroid = calculateCentroid(faces.map { it.embedding }) val centroid = calculateCentroid(faces.map { it.embedding })
val facesWithDistance = faces.map { face -> val facesWithDistance = faces.map { face: DetectedFaceWithEmbedding ->
val distance = 1 - cosineSimilarity(face.embedding, centroid) val distance = 1 - cosineSimilarity(face.embedding, centroid)
face to distance face to distance
} }
@@ -456,7 +559,7 @@ class FaceClusteringService @Inject constructor(
val size = embeddings.first().size val size = embeddings.first().size
val centroid = FloatArray(size) { 0f } val centroid = FloatArray(size) { 0f }
embeddings.forEach { embedding -> embeddings.forEach { embedding: FloatArray ->
for (i in embedding.indices) { for (i in embedding.indices) {
centroid[i] += embedding[i] centroid[i] += embedding[i]
} }
@@ -477,6 +580,8 @@ class FaceClusteringService @Inject constructor(
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate { private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
val timestamps = faces.map { it.capturedAt }.sorted() val timestamps = faces.map { it.capturedAt }.sorted()
if (timestamps.isEmpty() || timestamps.last() == 0L) return AgeEstimate.UNKNOWN
val span = timestamps.last() - timestamps.first() val span = timestamps.last() - timestamps.first()
val spanYears = span / (365.25 * 24 * 60 * 60 * 1000) val spanYears = span / (365.25 * 24 * 60 * 60 * 1000)
@@ -509,6 +614,13 @@ class FaceClusteringService @Inject constructor(
} }
} }
enum class ClusteringStrategy {
PREMIUM_SOLO_ONLY,
STANDARD_SOLO_ONLY,
TWO_PHASE,
LEGACY_ALL_FACES
}
data class DetectedFaceWithEmbedding( data class DetectedFaceWithEmbedding(
val imageId: String, val imageId: String,
val imageUri: String, val imageUri: String,
@@ -549,7 +661,8 @@ data class ClusteringResult(
val clusters: List<FaceCluster>, val clusters: List<FaceCluster>,
val totalFacesAnalyzed: Int, val totalFacesAnalyzed: Int,
val processingTimeMs: Long, val processingTimeMs: Long,
val errorMessage: String? = null val errorMessage: String? = null,
val strategy: ClusteringStrategy = ClusteringStrategy.PREMIUM_SOLO_ONLY
) )
enum class AgeEstimate { enum class AgeEstimate {

View File

@@ -0,0 +1,193 @@
package com.placeholder.sherpai2.domain.clustering
import com.google.mlkit.vision.face.Face
import com.google.mlkit.vision.face.FaceLandmark
import kotlin.math.abs
import kotlin.math.pow
import kotlin.math.sqrt
/**
* FaceQualityFilter - Aggressive filtering for Discovery/Clustering phase
*
* PURPOSE:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* ONLY used during Discovery to create high-quality training clusters.
* NOT used during scanning phase (scanning remains permissive).
*
* FILTERS OUT:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* ✅ Ghost faces (clothing patterns, textures, shadows)
* ✅ Partial faces (side profiles, blocked faces)
* ✅ Tiny background faces
* ✅ Extreme angles (looking away, upside down)
* ✅ Low-confidence detections
*
* STRATEGY:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* Multi-stage validation:
* 1. ML Kit confidence score
* 2. Eye landmark detection (both eyes required)
* 3. Head pose validation (reasonable angles)
* 4. Face size validation (minimum threshold)
* 5. Tracking ID validation (stable detection)
*/
object FaceQualityFilter {

    /**
     * Validate a detected face for the Discovery/Clustering phase (strict).
     *
     * Runs a sequence of checks; each failed check appends a human-readable
     * issue. A composite confidence score blends head pose (40%), relative
     * face size (30%), and landmark completeness (30%).
     *
     * @param face ML Kit detected face
     * @param imageWidth source image width in pixels
     * @param imageHeight source image height in pixels
     * @return result that is valid only when there are no issues AND the
     *         blended confidence score is >= 0.6
     */
    fun validateForDiscovery(
        face: Face,
        imageWidth: Int,
        imageHeight: Int
    ): FaceQualityValidation {
        val issues = mutableListOf<String>()

        // ===== CHECK 1: Eye Detection =====
        // Both eyes must be detected (eliminates most false positives such as
        // clothing patterns and shadows). Hard fail: nothing else is scored.
        val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE)
        val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE)
        if (leftEye == null || rightEye == null) {
            issues.add("Missing eye landmarks (likely not a real face)")
            return FaceQualityValidation(
                isValid = false,
                issues = issues,
                confidenceScore = 0f
            )
        }

        // ===== CHECK 2: Head Pose Validation =====
        // Reject extreme angles (side profiles, looking away, upside down).
        // Allowed range: ±30° for yaw (Y) and roll (Z), ±25° for pitch (X).
        val headEulerAngleY = face.headEulerAngleY // Left/right rotation
        val headEulerAngleZ = face.headEulerAngleZ // Tilt
        val headEulerAngleX = face.headEulerAngleX // Up/down
        if (abs(headEulerAngleY) > 30f) {
            issues.add("Head turned too far (${headEulerAngleY.toInt()}°)")
        }
        if (abs(headEulerAngleZ) > 30f) {
            issues.add("Head tilted too much (${headEulerAngleZ.toInt()}°)")
        }
        if (abs(headEulerAngleX) > 25f) {
            issues.add("Head angle too extreme (${headEulerAngleX.toInt()}°)")
        }

        // ===== CHECK 3: Face Size Validation =====
        // Face must span at least 15% of the image in both dimensions.
        val faceWidth = face.boundingBox.width()
        val faceHeight = face.boundingBox.height()
        val minFaceSize = 0.15f
        val faceWidthRatio = faceWidth.toFloat() / imageWidth.toFloat()
        val faceHeightRatio = faceHeight.toFloat() / imageHeight.toFloat()
        if (faceWidthRatio < minFaceSize) {
            issues.add("Face too small (${(faceWidthRatio * 100).toInt()}% of image width)")
        }
        if (faceHeightRatio < minFaceSize) {
            issues.add("Face too small (${(faceHeightRatio * 100).toInt()}% of image height)")
        }

        // ===== CHECK 4: Tracking Confidence =====
        // ML Kit assigns a tracking ID to stable detections; its absence
        // suggests a flaky one-off hit.
        if (face.trackingId == null) {
            issues.add("Unstable detection (no tracking ID)")
        }

        // ===== CHECK 5: Nose Detection (Additional Validation) =====
        // A nose landmark helps confirm a frontal, unoccluded face.
        val nose = face.getLandmark(FaceLandmark.NOSE_BASE)
        if (nose == null) {
            issues.add("No nose detected (likely partial/occluded face)")
        }

        // ===== CHECK 6: Eye Distance Validation =====
        // Eyes should span 20-60% of the face width (catches stretched or
        // warped detections). Both eyes are guaranteed non-null here thanks
        // to the early return in CHECK 1.
        if (faceWidth > 0) {
            val eyeDistance = sqrt(
                (rightEye.position.x - leftEye.position.x).toDouble().pow(2.0) +
                (rightEye.position.y - leftEye.position.y).toDouble().pow(2.0)
            ).toFloat()
            val eyeDistanceRatio = eyeDistance / faceWidth
            if (eyeDistanceRatio < 0.20f || eyeDistanceRatio > 0.60f) {
                issues.add("Abnormal eye spacing (${(eyeDistanceRatio * 100).toInt()}%)")
            }
        }

        // ===== CALCULATE CONFIDENCE SCORE =====
        // Weighted blend of pose, relative size, and landmark completeness.
        val poseScore = 1f - (abs(headEulerAngleY) + abs(headEulerAngleZ) + abs(headEulerAngleX)) / 180f
        val sizeScore = (faceWidthRatio + faceHeightRatio) / 2f
        // Eyes are already guaranteed present; only the nose can be missing.
        val landmarkScore = if (nose != null) 1f else 0.5f
        val confidenceScore = (poseScore * 0.4f + sizeScore * 0.3f + landmarkScore * 0.3f).coerceIn(0f, 1f)

        // ===== FINAL VERDICT =====
        // Pass only when every check passed and the blended score clears 0.6.
        val isValid = issues.isEmpty() && confidenceScore >= 0.6f
        return FaceQualityValidation(
            isValid = isValid,
            issues = issues,
            confidenceScore = confidenceScore
        )
    }

    /**
     * Quick check for the scanning phase (permissive).
     *
     * Only rejects obvious garbage: detections with BOTH eyes missing, or
     * faces narrower than 10% of the image width. Used during full library
     * scans where recall matters more than precision.
     *
     * @return true when the face should be kept for scanning
     */
    fun validateForScanning(
        face: Face,
        imageWidth: Int,
        imageHeight: Int
    ): Boolean {
        val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE)
        val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE)
        if (leftEye == null && rightEye == null) {
            return false // No eyes = not a face
        }
        val faceWidth = face.boundingBox.width()
        val faceWidthRatio = faceWidth.toFloat() / imageWidth.toFloat()
        if (faceWidthRatio < 0.10f) {
            return false // Too small
        }
        return true
    }
}
/**
 * Outcome of a face quality check.
 *
 * @property isValid true when every baseline check passed
 * @property issues human-readable rejection reasons (empty when valid)
 * @property confidenceScore composite quality score in [0, 1]
 */
data class FaceQualityValidation(
    val isValid: Boolean,
    val issues: List<String>,
    val confidenceScore: Float
) {
    // Stricter gate: keep only top-quality faces (score >= 0.7).
    val passesStrictValidation: Boolean
        get() = confidenceScore >= 0.7f && isValid

    // Relaxed gate: tolerate mid-range confidence (score >= 0.5).
    val passesModerateValidation: Boolean
        get() = confidenceScore >= 0.5f && isValid
}

View File

@@ -0,0 +1,597 @@
package com.placeholder.sherpai2.domain.clustering
import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import android.util.Log
import com.google.android.gms.tasks.Tasks
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.withContext
import java.util.Calendar
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.math.sqrt
import kotlin.random.Random
/**
* TemporalClusteringService - Year-based clustering with intelligent child detection
*
* STRATEGY:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* 1. Process ALL photos (no limits)
* 2. Apply strict quality filter (FaceQualityFilter)
* 3. Group faces by YEAR
* 4. Cluster within each year
* 5. Link clusters across years (same person)
* 6. Detect children (changing appearance over years)
* 7. Generate tags: "Emma_2020", "Emma_Age_2", "Brad_Pitt"
*
* CHILD DETECTION:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* A person is a CHILD if:
* - Appears across 3+ years
* - Face embeddings change significantly between years (>0.20 distance)
* - Consistent presence (not just random appearances)
*
* OUTPUT:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* Adults: "Brad_Pitt" (single cluster)
* Children: "Emma_2020", "Emma_2021", "Emma_2022" (yearly clusters)
* OR "Emma_Age_2", "Emma_Age_3", "Emma_Age_4" (if DOB known)
*/
@Singleton
class TemporalClusteringService @Inject constructor(
    @ApplicationContext private val context: Context,
    private val imageDao: ImageDao,
    private val faceCacheDao: FaceCacheDao
) {

    // Caps concurrent decode/detect/embed work so large libraries don't OOM.
    private val semaphore = Semaphore(8)

    companion object {
        private const val TAG = "TemporalClustering"

        // Embedding drift (1 - cosine similarity) between the first and last
        // year of appearance that marks a person as a growing child.
        private const val CHILD_EMBEDDING_DRIFT_THRESHOLD = 0.20f

        // A child must appear across at least this many distinct years.
        private const val CHILD_MIN_YEARS = 3

        // Cross-year linking: strict for small year gaps (adults change
        // little), lenient for bigger gaps (children change a lot).
        private const val ADULT_SIMILARITY_THRESHOLD = 0.80f
        private const val CHILD_SIMILARITY_THRESHOLD = 0.70f
    }

    /**
     * Discover people with year-based clustering.
     *
     * Pipeline: load all photos → detect + strictly filter SOLO faces →
     * group by year → DBSCAN per year → link clusters across years →
     * annotate (child detection, age estimates).
     *
     * @param onProgress callback (current, total, message); total is always 100
     * @return annotated clusters sorted by descending face count, plus stats
     */
    suspend fun discoverPeopleByYear(
        onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
    ): TemporalClusteringResult = withContext(Dispatchers.Default) {
        val startTime = System.currentTimeMillis()
        onProgress(5, 100, "Loading all photos...")

        // STEP 1: Load ALL images (no limit)
        val allImages = withContext(Dispatchers.IO) {
            imageDao.getAllImages()
        }
        if (allImages.isEmpty()) {
            return@withContext TemporalClusteringResult(
                clusters = emptyList(),
                totalPhotosProcessed = 0,
                totalFacesDetected = 0,
                processingTimeMs = 0,
                errorMessage = "No photos in library"
            )
        }
        Log.d(TAG, "Processing ${allImages.size} photos (no limit)")
        onProgress(10, 100, "Detecting high-quality faces...")

        // STEP 2: Detect faces with STRICT quality filtering
        val faceNetModel = FaceNetModel(context)
        val detector = FaceDetection.getClient(
            FaceDetectorOptions.Builder()
                .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
                .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL)
                .setMinFaceSize(0.15f)
                .build()
        )
        try {
            val allFaces: List<DetectedFaceWithEmbedding> = coroutineScope {
                allImages.mapIndexed { index, image ->
                    async(Dispatchers.IO) {
                        semaphore.acquire()
                        try {
                            val bitmap = loadBitmapDownsampled(Uri.parse(image.imageUri), 768)
                                ?: return@async emptyList<DetectedFaceWithEmbedding>()
                            val validFaces = try {
                                val inputImage = InputImage.fromBitmap(bitmap, 0)
                                // A detection failure on one photo must not
                                // abort the whole discovery run.
                                val faces = try {
                                    Tasks.await(detector.process(inputImage))
                                } catch (e: Exception) {
                                    emptyList()
                                }
                                if (faces.size != 1) {
                                    // Discovery only trusts SOLO photos
                                    // (exactly one face in frame).
                                    emptyList()
                                } else {
                                    val imageWidth = bitmap.width
                                    val imageHeight = bitmap.height
                                    faces.mapNotNull { face ->
                                        // Apply STRICT quality filter
                                        val qualityCheck = FaceQualityFilter.validateForDiscovery(
                                            face = face,
                                            imageWidth = imageWidth,
                                            imageHeight = imageHeight
                                        )
                                        if (!qualityCheck.isValid) {
                                            return@mapNotNull null
                                        }
                                        try {
                                            // Clamp the crop rect fully inside the bitmap:
                                            // ML Kit boxes can extend past the image edges.
                                            val left = face.boundingBox.left.coerceIn(0, imageWidth - 1)
                                            val top = face.boundingBox.top.coerceIn(0, imageHeight - 1)
                                            val cropWidth = face.boundingBox.right.coerceAtMost(imageWidth) - left
                                            val cropHeight = face.boundingBox.bottom.coerceAtMost(imageHeight) - top
                                            if (cropWidth <= 0 || cropHeight <= 0) {
                                                return@mapNotNull null
                                            }
                                            val faceBitmap = Bitmap.createBitmap(bitmap, left, top, cropWidth, cropHeight)
                                            val embedding = faceNetModel.generateEmbedding(faceBitmap)
                                            faceBitmap.recycle()
                                            DetectedFaceWithEmbedding(
                                                imageId = image.imageId,
                                                imageUri = image.imageUri,
                                                capturedAt = image.capturedAt,
                                                embedding = embedding,
                                                boundingBox = face.boundingBox,
                                                confidence = qualityCheck.confidenceScore,
                                                faceCount = 1,
                                                imageWidth = imageWidth,
                                                imageHeight = imageHeight
                                            )
                                        } catch (e: Exception) {
                                            null
                                        }
                                    }
                                }
                            } finally {
                                bitmap.recycle()
                            }
                            if (index % 50 == 0) {
                                val progress = 10 + (index * 40 / allImages.size)
                                onProgress(progress, 100, "Processed $index/${allImages.size} photos...")
                            }
                            validFaces
                        } finally {
                            semaphore.release()
                        }
                    }
                }.awaitAll().flatten()
            }
            Log.d(TAG, "Detected ${allFaces.size} high-quality solo faces")
            if (allFaces.isEmpty()) {
                return@withContext TemporalClusteringResult(
                    clusters = emptyList(),
                    totalPhotosProcessed = allImages.size,
                    totalFacesDetected = 0,
                    processingTimeMs = System.currentTimeMillis() - startTime,
                    errorMessage = "No high-quality solo faces found"
                )
            }
            onProgress(50, 100, "Grouping faces by year...")

            // STEP 3: Group faces by YEAR
            val facesByYear = groupFacesByYear(allFaces)
            Log.d(TAG, "Faces grouped into ${facesByYear.size} years")
            onProgress(60, 100, "Clustering within each year...")

            // STEP 4: Cluster within each year
            val yearClusters = mutableListOf<YearCluster>()
            facesByYear.forEach { (year, faces) ->
                Log.d(TAG, "Clustering $year: ${faces.size} faces")
                val rawClusters = performDBSCAN(
                    faces = faces,
                    epsilon = 0.24f,
                    minPoints = 3
                )
                rawClusters.forEach { rawCluster ->
                    yearClusters.add(
                        YearCluster(
                            year = year,
                            faces = rawCluster.faces,
                            centroid = calculateCentroid(rawCluster.faces.map { it.embedding })
                        )
                    )
                }
            }
            Log.d(TAG, "Created ${yearClusters.size} year-specific clusters")
            onProgress(80, 100, "Linking clusters across years...")

            // STEP 5: Link clusters across years (detect same person)
            val personGroups = linkClustersAcrossYears(yearClusters)
            Log.d(TAG, "Identified ${personGroups.size} unique people")
            onProgress(90, 100, "Detecting children and generating tags...")

            // STEP 6: Detect children and generate final clusters
            val annotatedClusters = personGroups.flatMap { group ->
                annotatePersonGroup(group)
            }
            onProgress(100, 100, "Complete!")
            TemporalClusteringResult(
                clusters = annotatedClusters.sortedByDescending { it.cluster.faces.size },
                totalPhotosProcessed = allImages.size,
                totalFacesDetected = allFaces.size,
                processingTimeMs = System.currentTimeMillis() - startTime
            )
        } finally {
            // Always release native resources, even on early return/exception.
            faceNetModel.close()
            detector.close()
        }
    }

    /**
     * Group faces by the calendar year of their capture timestamp.
     *
     * @return map of 4-digit year string → faces captured in that year
     */
    private fun groupFacesByYear(faces: List<DetectedFaceWithEmbedding>): Map<String, List<DetectedFaceWithEmbedding>> {
        // One Calendar reused across the loop; this method runs sequentially.
        val calendar = Calendar.getInstance()
        return faces.groupBy { face ->
            calendar.timeInMillis = face.capturedAt
            calendar.get(Calendar.YEAR).toString()
        }
    }

    /**
     * Link year clusters that belong to the same person.
     *
     * Greedy pass over clusters sorted by year: each unvisited cluster seeds
     * a group and absorbs similar clusters from LATER years. The similarity
     * threshold relaxes with the year gap (children drift more).
     * NOTE: 4-digit year strings compare correctly lexicographically.
     */
    private fun linkClustersAcrossYears(yearClusters: List<YearCluster>): List<PersonGroup> {
        val sortedClusters = yearClusters.sortedBy { it.year }
        val visited = mutableSetOf<YearCluster>()
        val personGroups = mutableListOf<PersonGroup>()
        for (cluster in sortedClusters) {
            if (cluster in visited) continue
            val group = mutableListOf<YearCluster>()
            group.add(cluster)
            visited.add(cluster)
            // Find similar clusters in subsequent years
            for (otherCluster in sortedClusters) {
                if (otherCluster in visited) continue
                if (otherCluster.year <= cluster.year) continue
                val similarity = cosineSimilarity(cluster.centroid, otherCluster.centroid)
                // Adaptive threshold: small gaps demand adult-level similarity,
                // larger gaps get the child-friendly (lenient) threshold.
                val yearGap = otherCluster.year.toInt() - cluster.year.toInt()
                val threshold = if (yearGap <= 2) {
                    ADULT_SIMILARITY_THRESHOLD
                } else {
                    CHILD_SIMILARITY_THRESHOLD
                }
                if (similarity >= threshold) {
                    group.add(otherCluster)
                    visited.add(otherCluster)
                }
            }
            personGroups.add(PersonGroup(clusters = group))
        }
        return personGroups
    }

    /**
     * Annotate a person group: detect if child, then produce either one
     * combined cluster (adult) or one cluster per year (child).
     */
    private fun annotatePersonGroup(group: PersonGroup): List<AnnotatedCluster> {
        val sortedClusters = group.clusters.sortedBy { it.year }
        val isChild = detectChild(sortedClusters)
        return if (isChild) {
            // Child: separate cluster per year (appearance changes yearly)
            sortedClusters.map { yearCluster ->
                AnnotatedCluster(
                    cluster = FaceCluster(
                        clusterId = 0,
                        faces = yearCluster.faces,
                        representativeFaces = selectRepresentativeFaces(yearCluster.faces, 6),
                        photoCount = yearCluster.faces.size,
                        averageConfidence = yearCluster.faces.map { it.confidence }.average().toFloat(),
                        estimatedAge = AgeEstimate.CHILD,
                        potentialSiblings = emptyList()
                    ),
                    year = yearCluster.year,
                    isChild = true,
                    suggestedName = null,
                    suggestedAge = estimateAgeInYear(yearCluster.year, sortedClusters)
                )
            }
        } else {
            // Adult: single cluster combining all years
            val allFaces = sortedClusters.flatMap { it.faces }
            listOf(
                AnnotatedCluster(
                    cluster = FaceCluster(
                        clusterId = 0,
                        faces = allFaces,
                        representativeFaces = selectRepresentativeFaces(allFaces, 6),
                        photoCount = allFaces.size,
                        averageConfidence = allFaces.map { it.confidence }.average().toFloat(),
                        estimatedAge = AgeEstimate.ADULT,
                        potentialSiblings = emptyList()
                    ),
                    year = "All Years",
                    isChild = false,
                    suggestedName = null,
                    suggestedAge = null
                )
            )
        }
    }

    /**
     * Heuristic child detection: the person spans 3+ years AND their
     * embedding drifted significantly between the first and last year.
     */
    private fun detectChild(clusters: List<YearCluster>): Boolean {
        if (clusters.size < CHILD_MIN_YEARS) {
            return false // Need 3+ years to detect child
        }
        val firstCentroid = clusters.first().centroid
        val lastCentroid = clusters.last().centroid
        val drift = 1 - cosineSimilarity(firstCentroid, lastCentroid)
        return drift >= CHILD_EMBEDDING_DRIFT_THRESHOLD
    }

    /**
     * Estimate age in a specific year: years since the person's first
     * appearance, starting at age 1.
     */
    private fun estimateAgeInYear(targetYear: String, allClusters: List<YearCluster>): Int? {
        val firstYear = allClusters.minOf { it.year.toInt() }
        return targetYear.toInt() - firstYear + 1
    }

    /**
     * Select the [count] faces closest to the cluster centroid (smallest
     * cosine distance) as representatives.
     */
    private fun selectRepresentativeFaces(
        faces: List<DetectedFaceWithEmbedding>,
        count: Int
    ): List<DetectedFaceWithEmbedding> {
        if (faces.size <= count) return faces
        val centroid = calculateCentroid(faces.map { it.embedding })
        return faces
            .sortedBy { face -> 1 - cosineSimilarity(face.embedding, centroid) }
            .take(count)
    }

    /**
     * DBSCAN clustering over face embeddings (distance = 1 - cosine sim).
     *
     * FIX: the seed core point is now marked visited and included in its own
     * cluster; previously it was only recovered if a neighbor happened to
     * expand back to it, so some faces were silently dropped.
     */
    private fun performDBSCAN(
        faces: List<DetectedFaceWithEmbedding>,
        epsilon: Float,
        minPoints: Int
    ): List<RawCluster> {
        val visited = mutableSetOf<Int>()
        val clusters = mutableListOf<RawCluster>()
        var clusterId = 0
        for (i in faces.indices) {
            if (i in visited) continue
            val neighbors = findNeighbors(i, faces, epsilon)
            if (neighbors.size < minPoints) {
                visited.add(i) // noise point
                continue
            }
            // Core point: it belongs to the cluster it seeds.
            visited.add(i)
            val cluster = mutableListOf(faces[i])
            val queue = ArrayDeque(neighbors)
            while (queue.isNotEmpty()) {
                val pointIdx = queue.removeFirst()
                if (pointIdx in visited) continue
                visited.add(pointIdx)
                cluster.add(faces[pointIdx])
                val pointNeighbors = findNeighbors(pointIdx, faces, epsilon)
                if (pointNeighbors.size >= minPoints) {
                    queue.addAll(pointNeighbors.filter { it !in visited })
                }
            }
            if (cluster.size >= minPoints) {
                clusters.add(RawCluster(clusterId++, cluster))
            }
        }
        return clusters
    }

    // Indices of all faces within cosine distance `epsilon` of faces[pointIdx]
    // (the point itself is excluded).
    private fun findNeighbors(
        pointIdx: Int,
        faces: List<DetectedFaceWithEmbedding>,
        epsilon: Float
    ): List<Int> {
        val point = faces[pointIdx]
        return faces.indices.filter { i ->
            if (i == pointIdx) return@filter false
            val similarity = cosineSimilarity(point.embedding, faces[i].embedding)
            similarity > (1 - epsilon)
        }
    }

    // Cosine similarity in [-1, 1]; returns 0 for a zero-norm vector instead
    // of propagating NaN from a division by zero.
    private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
        var dotProduct = 0f
        var normA = 0f
        var normB = 0f
        for (i in a.indices) {
            dotProduct += a[i] * b[i]
            normA += a[i] * a[i]
            normB += b[i] * b[i]
        }
        val denominator = sqrt(normA) * sqrt(normB)
        return if (denominator > 0f) dotProduct / denominator else 0f
    }

    // Mean of the embeddings, L2-normalized (unit length) when non-zero.
    private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
        if (embeddings.isEmpty()) return FloatArray(0)
        val size = embeddings.first().size
        val centroid = FloatArray(size) { 0f }
        embeddings.forEach { embedding ->
            for (i in embedding.indices) {
                centroid[i] += embedding[i]
            }
        }
        val count = embeddings.size.toFloat()
        for (i in centroid.indices) {
            centroid[i] /= count
        }
        val norm = sqrt(centroid.map { it * it }.sum())
        if (norm > 0) {
            return centroid.map { it / norm }.toFloatArray()
        }
        return centroid
    }

    /**
     * Decode a bitmap capped at [maxDim] on the longer side (power-of-two
     * subsampling, RGB_565 to halve memory). Returns null on any failure.
     */
    private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? {
        return try {
            val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true }
            context.contentResolver.openInputStream(uri)?.use {
                BitmapFactory.decodeStream(it, null, opts)
            }
            // Bounds decode failed (corrupt/unsupported image).
            if (opts.outWidth <= 0 || opts.outHeight <= 0) return null
            var sample = 1
            while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) {
                sample *= 2
            }
            val finalOpts = BitmapFactory.Options().apply {
                inSampleSize = sample
                inPreferredConfig = Bitmap.Config.RGB_565
            }
            context.contentResolver.openInputStream(uri)?.use {
                BitmapFactory.decodeStream(it, null, finalOpts)
            }
        } catch (e: Exception) {
            null
        }
    }
}
/**
 * All faces assigned to one person within a single calendar year.
 *
 * @property year 4-digit year string (e.g. "2021")
 * @property faces faces clustered together within that year
 * @property centroid mean embedding of [faces] (L2-normalized upstream)
 */
data class YearCluster(
    val year: String,
    val faces: List<DetectedFaceWithEmbedding>,
    val centroid: FloatArray
) {
    // FloatArray compares by reference in a generated data-class equals,
    // silently breaking structural equality; compare array CONTENT instead.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is YearCluster) return false
        return year == other.year &&
            faces == other.faces &&
            centroid.contentEquals(other.centroid)
    }

    override fun hashCode(): Int {
        var result = year.hashCode()
        result = 31 * result + faces.hashCode()
        result = 31 * result + centroid.contentHashCode()
        return result
    }
}
/**
 * Year-by-year clusters linked together as one individual.
 *
 * Produced by the cross-year linking pass: the seed cluster comes first,
 * followed by the later-year clusters it absorbed.
 */
data class PersonGroup(
    val clusters: List<YearCluster>
)
/**
 * A face cluster enriched with temporal metadata, ready for user review.
 *
 * @property cluster the underlying face cluster
 * @property year 4-digit year for child clusters, "All Years" for adults
 * @property isChild true when the person was detected as a growing child
 * @property suggestedName optional name suggestion (null when unknown)
 * @property suggestedAge estimated age for child clusters (null for adults)
 */
data class AnnotatedCluster(
    val cluster: FaceCluster,
    val year: String,
    val isChild: Boolean,
    val suggestedName: String?,
    val suggestedAge: Int?
) {
    /**
     * Build the persisted tag for this cluster.
     *
     * Adults get the plain name ("Brad_Pitt"); children get an age-qualified
     * tag when the age is known ("Emma_Age_2"), otherwise a year-qualified
     * one ("Emma_2020").
     */
    fun generateTag(name: String): String = when {
        !isChild -> name
        suggestedAge != null -> "${name}_Age_$suggestedAge"
        else -> "${name}_$year"
    }
}
/**
 * Result of a full temporal clustering (discovery) pass.
 */
data class TemporalClusteringResult(
    // Annotated clusters, sorted by descending face count by the producer.
    val clusters: List<AnnotatedCluster>,
    // Number of library photos examined (0 when the library was empty).
    val totalPhotosProcessed: Int,
    // High-quality solo faces that survived strict filtering.
    val totalFacesDetected: Int,
    // Wall-clock duration of the discovery pass, in milliseconds.
    val processingTimeMs: Long,
    // Set when the pass aborted early (empty library, no usable faces).
    val errorMessage: String? = null
)

View File

@@ -22,6 +22,7 @@ import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
* - ✅ Complete naming dialog integration * - ✅ Complete naming dialog integration
* - ✅ Quality analysis in cluster grid * - ✅ Quality analysis in cluster grid
* - ✅ Better error handling * - ✅ Better error handling
* - ✅ Refinement flow support
*/ */
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
@@ -112,6 +113,12 @@ fun DiscoverPeopleScreen(
ValidationPreviewScreen( ValidationPreviewScreen(
personName = state.personName, personName = state.personName,
validationResult = state.validationResult, validationResult = state.validationResult,
onMarkFeedback = { feedbackMap ->
viewModel.submitFeedback(state.cluster, feedbackMap)
},
onRequestRefinement = {
viewModel.requestRefinement(state.cluster)
},
onApprove = { onApprove = {
viewModel.approveValidationAndScan( viewModel.approveValidationAndScan(
personId = state.personId, personId = state.personId,
@@ -124,6 +131,28 @@ fun DiscoverPeopleScreen(
) )
} }
// ===== REFINEMENT NEEDED =====
is DiscoverUiState.RefinementNeeded -> {
RefinementNeededContent(
recommendation = state.recommendation,
currentIteration = state.currentIteration,
onRefine = {
viewModel.requestRefinement(state.cluster)
},
onSkip = {
viewModel.reset()
}
)
}
// ===== REFINING IN PROGRESS =====
is DiscoverUiState.Refining -> {
RefiningProgressContent(
iteration = state.iteration,
message = state.message
)
}
// ===== COMPLETE ===== // ===== COMPLETE =====
is DiscoverUiState.Complete -> { is DiscoverUiState.Complete -> {
CompleteStateContent( CompleteStateContent(
@@ -154,6 +183,7 @@ fun DiscoverPeopleScreen(
} }
} }
} }
// ===== IDLE STATE CONTENT ===== // ===== IDLE STATE CONTENT =====
@Composable @Composable
@@ -239,7 +269,7 @@ private fun ClusteringProgressContent(
if (total > 0) { if (total > 0) {
LinearProgressIndicator( LinearProgressIndicator(
progress = progress.toFloat() / total.toFloat(), progress = { progress.toFloat() / total.toFloat() },
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.height(8.dp) .height(8.dp)
@@ -287,7 +317,7 @@ private fun TrainingProgressContent(
Spacer(modifier = Modifier.height(16.dp)) Spacer(modifier = Modifier.height(16.dp))
LinearProgressIndicator( LinearProgressIndicator(
progress = progress.toFloat() / total.toFloat(), progress = { progress.toFloat() / total.toFloat() },
modifier = Modifier modifier = Modifier
.fillMaxWidth() .fillMaxWidth()
.height(8.dp) .height(8.dp)
@@ -304,6 +334,131 @@ private fun TrainingProgressContent(
} }
} }
// ===== REFINEMENT NEEDED =====
@Composable
private fun RefinementNeededContent(
recommendation: com.placeholder.sherpai2.domain.clustering.RefinementRecommendation,
currentIteration: Int,
onRefine: () -> Unit,
onSkip: () -> Unit
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
Icon(
imageVector = Icons.Default.Person,
contentDescription = null,
modifier = Modifier.size(80.dp),
tint = MaterialTheme.colorScheme.primary
)
Spacer(modifier = Modifier.height(24.dp))
Text(
text = "Refinement Recommended",
style = MaterialTheme.typography.headlineMedium,
fontWeight = FontWeight.Bold
)
Spacer(modifier = Modifier.height(16.dp))
Card(
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.errorContainer
)
) {
Column(
modifier = Modifier.padding(16.dp)
) {
Text(
text = recommendation.reason,
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onErrorContainer
)
}
}
Spacer(modifier = Modifier.height(16.dp))
Text(
text = "Iteration: $currentIteration",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.height(32.dp))
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.spacedBy(12.dp)
) {
OutlinedButton(
onClick = onSkip,
modifier = Modifier.weight(1f)
) {
Text("Skip")
}
Button(
onClick = onRefine,
modifier = Modifier.weight(1f)
) {
Text("Refine Cluster")
}
}
}
}
// ===== REFINING PROGRESS =====
@Composable
private fun RefiningProgressContent(
iteration: Int,
message: String
) {
Column(
modifier = Modifier
.fillMaxSize()
.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
CircularProgressIndicator(
modifier = Modifier.size(64.dp)
)
Spacer(modifier = Modifier.height(32.dp))
Text(
text = "Refining Cluster",
style = MaterialTheme.typography.titleLarge,
fontWeight = FontWeight.Bold
)
Spacer(modifier = Modifier.height(16.dp))
Text(
text = message,
style = MaterialTheme.typography.bodyMedium,
textAlign = TextAlign.Center,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = "Iteration $iteration",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
// ===== LOADING CONTENT ===== // ===== LOADING CONTENT =====
@Composable @Composable

View File

@@ -2,9 +2,8 @@ package com.placeholder.sherpai2.ui.discover
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.domain.clustering.ClusteringResult import com.placeholder.sherpai2.data.local.entity.FeedbackType
import com.placeholder.sherpai2.domain.clustering.FaceCluster import com.placeholder.sherpai2.domain.clustering.*
import com.placeholder.sherpai2.domain.clustering.FaceClusteringService
import com.placeholder.sherpai2.domain.training.ClusterTrainingService import com.placeholder.sherpai2.domain.training.ClusterTrainingService
import com.placeholder.sherpai2.domain.validation.ValidationScanResult import com.placeholder.sherpai2.domain.validation.ValidationScanResult
import com.placeholder.sherpai2.domain.validation.ValidationScanService import com.placeholder.sherpai2.domain.validation.ValidationScanService
@@ -16,33 +15,49 @@ import kotlinx.coroutines.launch
import javax.inject.Inject import javax.inject.Inject
/** /**
* DiscoverPeopleViewModel - COMPLETE workflow with validation * DiscoverPeopleViewModel - COMPLETE workflow with feedback loop
* *
* Flow: * FLOW WITH REFINEMENT:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* 1. Idle → Clustering → NamingReady (2x2 grid) * 1. Idle → Clustering → NamingReady (2x2 grid)
* 2. Select cluster → NamingCluster (dialog) * 2. Select cluster → NamingCluster (dialog)
* 3. Confirm → AnalyzingCluster → Training → ValidationPreview * 3. Confirm → AnalyzingCluster → Training → ValidationPreview
* 4. Approve → Complete OR Reject → Error * 4. User reviews faces → Marks correct/incorrect
* 5a. If too many incorrect → Refining (re-cluster without bad faces)
* 5b. If approved → Complete OR Reject → Error
* 6. Loop back to step 3 if refinement happened
*
* NEW FEATURES:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* ✅ User feedback collection
* ✅ Cluster refinement loop
* ✅ Feedback persistence
* ✅ Quality-aware training (only confirmed faces)
*/ */
@HiltViewModel @HiltViewModel
class DiscoverPeopleViewModel @Inject constructor( class DiscoverPeopleViewModel @Inject constructor(
private val clusteringService: FaceClusteringService, private val clusteringService: FaceClusteringService,
private val trainingService: ClusterTrainingService, private val trainingService: ClusterTrainingService,
private val validationService: ValidationScanService private val validationService: ValidationScanService,
private val refinementService: ClusterRefinementService
) : ViewModel() { ) : ViewModel() {
private val _uiState = MutableStateFlow<DiscoverUiState>(DiscoverUiState.Idle) private val _uiState = MutableStateFlow<DiscoverUiState>(DiscoverUiState.Idle)
val uiState: StateFlow<DiscoverUiState> = _uiState.asStateFlow() val uiState: StateFlow<DiscoverUiState> = _uiState.asStateFlow()
private val namedClusterIds = mutableSetOf<Int>() private val namedClusterIds = mutableSetOf<Int>()
private var currentIterationCount = 0
fun startDiscovery() { fun startDiscovery() {
viewModelScope.launch { viewModelScope.launch {
try { try {
namedClusterIds.clear() namedClusterIds.clear()
currentIterationCount = 0
_uiState.value = DiscoverUiState.Clustering(0, 100, "Starting...") _uiState.value = DiscoverUiState.Clustering(0, 100, "Starting...")
// Use PREMIUM_SOLO_ONLY strategy for best results
val result = clusteringService.discoverPeople( val result = clusteringService.discoverPeople(
strategy = ClusteringStrategy.PREMIUM_SOLO_ONLY,
onProgress = { current: Int, total: Int, message: String -> onProgress = { current: Int, total: Int, message: String ->
_uiState.value = DiscoverUiState.Clustering(current, total, message) _uiState.value = DiscoverUiState.Clustering(current, total, message)
} }
@@ -56,7 +71,7 @@ class DiscoverPeopleViewModel @Inject constructor(
if (result.clusters.isEmpty()) { if (result.clusters.isEmpty()) {
_uiState.value = DiscoverUiState.NoPeopleFound( _uiState.value = DiscoverUiState.NoPeopleFound(
result.errorMessage result.errorMessage
?: "No people clusters found.\n\nTry:\n• Adding more photos\n• Ensuring photos are clear\n• Having 3+ photos per person" ?: "No people clusters found.\n\nTry:\n• Adding more solo photos\n• Ensuring photos are clear\n• Having 6+ photos per person"
) )
} else { } else {
_uiState.value = DiscoverUiState.NamingReady(result) _uiState.value = DiscoverUiState.NamingReady(result)
@@ -131,10 +146,11 @@ class DiscoverPeopleViewModel @Inject constructor(
} }
) )
// Stage 4: Show validation preview // Stage 4: Show validation preview WITH FEEDBACK SUPPORT
_uiState.value = DiscoverUiState.ValidationPreview( _uiState.value = DiscoverUiState.ValidationPreview(
personId = personId, personId = personId,
personName = name, personName = name,
cluster = cluster,
validationResult = validationResult validationResult = validationResult
) )
@@ -144,14 +160,112 @@ class DiscoverPeopleViewModel @Inject constructor(
} }
} }
/**
* NEW: Handle user feedback from validation preview
*
* @param cluster The cluster being validated
* @param feedbackMap Map of imageId → FeedbackType
*/
fun submitFeedback(
cluster: FaceCluster,
feedbackMap: Map<String, FeedbackType>
) {
viewModelScope.launch {
try {
// Convert imageId feedback to face feedback
val faceFeedbackMap = cluster.faces
.associateWith { face ->
feedbackMap[face.imageId] ?: FeedbackType.UNCERTAIN
}
val originalConfidences = cluster.faces.associateWith { it.confidence }
// Store feedback
refinementService.storeFeedback(
cluster = cluster,
feedbackMap = faceFeedbackMap,
originalConfidences = originalConfidences
)
// Check if refinement needed
val recommendation = refinementService.shouldRefineCluster(cluster)
if (recommendation.shouldRefine) {
_uiState.value = DiscoverUiState.RefinementNeeded(
cluster = cluster,
recommendation = recommendation,
currentIteration = currentIterationCount
)
}
} catch (e: Exception) {
_uiState.value = DiscoverUiState.Error(
"Failed to process feedback: ${e.message}"
)
}
}
}
/**
* NEW: Request cluster refinement
*
* Re-clusters WITHOUT rejected faces
*/
fun requestRefinement(cluster: FaceCluster) {
viewModelScope.launch {
try {
currentIterationCount++
_uiState.value = DiscoverUiState.Refining(
iteration = currentIterationCount,
message = "Removing incorrect faces and re-clustering..."
)
// Refine cluster
val refinementResult = refinementService.refineCluster(
cluster = cluster,
iterationNumber = currentIterationCount
)
if (!refinementResult.success || refinementResult.refinedCluster == null) {
_uiState.value = DiscoverUiState.Error(
refinementResult.errorMessage
?: "Failed to refine cluster. Please try manual training."
)
return@launch
}
// Show refined cluster for re-validation
val currentState = _uiState.value
if (currentState is DiscoverUiState.RefinementNeeded) {
// Re-train with refined cluster
// This will loop back to ValidationPreview
confirmClusterName(
cluster = refinementResult.refinedCluster,
name = currentState.cluster.representativeFaces.first().imageId, // Placeholder
dateOfBirth = null,
isChild = false,
selectedSiblings = emptyList()
)
}
} catch (e: Exception) {
_uiState.value = DiscoverUiState.Error(
"Refinement failed: ${e.message}"
)
}
}
}
fun approveValidationAndScan(personId: String, personName: String) { fun approveValidationAndScan(personId: String, personName: String) {
viewModelScope.launch { viewModelScope.launch {
try { try {
// Mark cluster as named (find it from previous state) // Mark cluster as named
// TODO: Track this properly // TODO: Track this properly
_uiState.value = DiscoverUiState.Complete( _uiState.value = DiscoverUiState.Complete(
message = "Successfully created model for \"$personName\"!\n\nFull library scan has been queued in the background." message = "Successfully created model for \"$personName\"!\n\n" +
"Full library scan has been queued in the background.\n\n" +
"${currentIterationCount} refinement iterations completed"
) )
} catch (e: Exception) { } catch (e: Exception) {
_uiState.value = DiscoverUiState.Error(e.message ?: "Failed to start library scan") _uiState.value = DiscoverUiState.Error(e.message ?: "Failed to start library scan")
@@ -161,7 +275,8 @@ class DiscoverPeopleViewModel @Inject constructor(
fun rejectValidationAndImprove() { fun rejectValidationAndImprove() {
_uiState.value = DiscoverUiState.Error( _uiState.value = DiscoverUiState.Error(
"Please add more training photos and try again.\n\n(Feature coming: ability to add photos to existing model)" "Please add more training photos and try again.\n\n" +
"(Feature coming: ability to add photos to existing model)"
) )
} }
@@ -175,9 +290,13 @@ class DiscoverPeopleViewModel @Inject constructor(
fun reset() { fun reset() {
_uiState.value = DiscoverUiState.Idle _uiState.value = DiscoverUiState.Idle
namedClusterIds.clear() namedClusterIds.clear()
currentIterationCount = 0
} }
} }
/**
* UI States - ENHANCED with refinement states
*/
sealed class DiscoverUiState { sealed class DiscoverUiState {
object Idle : DiscoverUiState() object Idle : DiscoverUiState()
@@ -205,12 +324,33 @@ sealed class DiscoverUiState {
val total: Int val total: Int
) : DiscoverUiState() ) : DiscoverUiState()
/**
* NEW: Validation with feedback support
*/
data class ValidationPreview( data class ValidationPreview(
val personId: String, val personId: String,
val personName: String, val personName: String,
val cluster: FaceCluster,
val validationResult: ValidationScanResult val validationResult: ValidationScanResult
) : DiscoverUiState() ) : DiscoverUiState()
/**
* NEW: Refinement needed state
*/
data class RefinementNeeded(
val cluster: FaceCluster,
val recommendation: RefinementRecommendation,
val currentIteration: Int
) : DiscoverUiState()
/**
* NEW: Refining in progress
*/
data class Refining(
val iteration: Int,
val message: String
) : DiscoverUiState()
data class Complete( data class Complete(
val message: String val message: String
) : DiscoverUiState() ) : DiscoverUiState()

View File

@@ -0,0 +1,353 @@
package com.placeholder.sherpai2.ui.discover
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.text.KeyboardOptions
import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.*
import androidx.compose.material3.*
import androidx.compose.runtime.*
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.input.KeyboardType
import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.Dialog
import com.placeholder.sherpai2.domain.clustering.AnnotatedCluster
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
/**
 * TemporalNamingDialog - ENHANCED with age input for temporal clustering
 *
 * Collects a name (and, for children, an optional age) for a face cluster
 * and previews the model name that will be created for it.
 *
 * NAMING PATTERNS:
 * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
 * Adults:
 *   - Name: "John Doe", Age: (leave empty)
 *   - Result: Person "John Doe" with single model
 *
 * Children (with age):
 *   - Name: "Emma", Age: "2", Year: "2020"
 *   - Result: Person "Emma" with submodel "Emma_Age_2"
 *
 * Children (without age):
 *   - Name: "Emma", Age: (empty), Year: "2020"
 *   - Result: Person "Emma" with submodel "Emma_2020"
 *
 * @param annotatedCluster cluster being named; supplies year and optional
 *        suggested name/age used to pre-fill the fields
 * @param onConfirm called with the trimmed name, the parsed age (null unless
 *        "This is a child" is checked and a number was entered), and the
 *        child flag
 * @param onDismiss called when the dialog is cancelled/dismissed
 * @param qualityAnalyzer used to score the cluster and gate the Create button
 */
@Composable
fun TemporalNamingDialog(
    annotatedCluster: AnnotatedCluster,
    onConfirm: (name: String, age: Int?, isChild: Boolean) -> Unit,
    onDismiss: () -> Unit,
    qualityAnalyzer: ClusterQualityAnalyzer
) {
    // Pre-fill from cluster suggestions when available; a suggested age
    // implies the cluster was detected as a child.
    var name by remember { mutableStateOf(annotatedCluster.suggestedName ?: "") }
    var ageText by remember { mutableStateOf(annotatedCluster.suggestedAge?.toString() ?: "") }
    var isChild by remember { mutableStateOf(annotatedCluster.suggestedAge != null) }

    // Analyze cluster quality once per cluster instance (remember-keyed).
    val qualityResult = remember(annotatedCluster.cluster) {
        qualityAnalyzer.analyzeCluster(annotatedCluster.cluster)
    }

    Dialog(onDismissRequest = onDismiss) {
        Card(
            modifier = Modifier
                .fillMaxWidth()
                .padding(16.dp)
        ) {
            Column(
                modifier = Modifier.padding(24.dp),
                verticalArrangement = Arrangement.spacedBy(16.dp)
            ) {
                // Header
                Text(
                    text = "Name This Person",
                    style = MaterialTheme.typography.headlineSmall,
                    fontWeight = FontWeight.Bold
                )

                // Year badge
                YearBadge(year = annotatedCluster.year)

                HorizontalDivider()

                // Quality warnings (renders nothing when there are none)
                QualityWarnings(qualityResult)

                // Name field
                OutlinedTextField(
                    value = name,
                    onValueChange = { name = it },
                    label = { Text("Name") },
                    placeholder = { Text("e.g., Emma") },
                    leadingIcon = {
                        Icon(Icons.Default.Person, contentDescription = null)
                    },
                    modifier = Modifier.fillMaxWidth(),
                    singleLine = true
                )

                // Child checkbox — toggles the age field below
                Row(
                    modifier = Modifier.fillMaxWidth(),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Checkbox(
                        checked = isChild,
                        onCheckedChange = { isChild = it }
                    )
                    Spacer(modifier = Modifier.width(8.dp))
                    Column {
                        Text(
                            text = "This is a child",
                            style = MaterialTheme.typography.bodyMedium
                        )
                        Text(
                            text = "Enable age-specific models",
                            style = MaterialTheme.typography.bodySmall,
                            color = MaterialTheme.colorScheme.onSurfaceVariant
                        )
                    }
                }

                // Age field (only if child)
                if (isChild) {
                    OutlinedTextField(
                        value = ageText,
                        onValueChange = {
                            // Only allow numbers
                            if (it.isEmpty() || it.all { c -> c.isDigit() }) {
                                ageText = it
                            }
                        },
                        label = { Text("Age in ${annotatedCluster.year}") },
                        placeholder = { Text("e.g., 2") },
                        leadingIcon = {
                            Icon(Icons.Default.DateRange, contentDescription = null)
                        },
                        modifier = Modifier.fillMaxWidth(),
                        singleLine = true,
                        keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
                        supportingText = {
                            Text("Optional: Helps create age-specific models")
                        }
                    )

                    // Model name preview (only meaningful once a name exists)
                    if (name.isNotBlank()) {
                        Card(
                            colors = CardDefaults.cardColors(
                                containerColor = MaterialTheme.colorScheme.primaryContainer
                            )
                        ) {
                            Row(
                                modifier = Modifier
                                    .fillMaxWidth()
                                    .padding(12.dp),
                                verticalAlignment = Alignment.CenterVertically
                            ) {
                                Icon(
                                    imageVector = Icons.Default.Info,
                                    contentDescription = null,
                                    tint = MaterialTheme.colorScheme.onPrimaryContainer
                                )
                                Spacer(modifier = Modifier.width(8.dp))
                                Column {
                                    Text(
                                        text = "Model will be created as:",
                                        style = MaterialTheme.typography.bodySmall,
                                        color = MaterialTheme.colorScheme.onPrimaryContainer
                                    )
                                    Text(
                                        text = buildModelName(name, ageText, annotatedCluster.year),
                                        style = MaterialTheme.typography.bodyMedium,
                                        fontWeight = FontWeight.Bold,
                                        color = MaterialTheme.colorScheme.onPrimaryContainer
                                    )
                                }
                            }
                        }
                    }
                }

                // Cluster stats
                ClusterStats(qualityResult)

                HorizontalDivider()

                // Actions
                Row(
                    modifier = Modifier.fillMaxWidth(),
                    horizontalArrangement = Arrangement.spacedBy(12.dp)
                ) {
                    OutlinedButton(
                        onClick = onDismiss,
                        modifier = Modifier.weight(1f)
                    ) {
                        Text("Cancel")
                    }
                    Button(
                        onClick = {
                            // FIX: only forward an age when the child checkbox is
                            // checked. Previously, an age typed before unchecking
                            // "This is a child" leaked through as a non-null age.
                            // Also trim the name so stray whitespace doesn't end
                            // up in the person / model name.
                            val age = if (isChild) ageText.toIntOrNull() else null
                            onConfirm(name.trim(), age, isChild)
                        },
                        modifier = Modifier.weight(1f),
                        // Require a non-blank name and a trainable cluster.
                        enabled = name.isNotBlank() && qualityResult.canTrain
                    ) {
                        Text("Create")
                    }
                }
            }
        }
    }
}
/**
 * Small pill-shaped badge showing which year the cluster's photos come from.
 */
@Composable
private fun YearBadge(year: String) {
    // Icon and label share the on-secondary-container tint.
    val contentTint = MaterialTheme.colorScheme.onSecondaryContainer

    Surface(
        color = MaterialTheme.colorScheme.secondaryContainer,
        shape = MaterialTheme.shapes.small
    ) {
        Row(
            modifier = Modifier.padding(horizontal = 12.dp, vertical = 6.dp),
            verticalAlignment = Alignment.CenterVertically,
            horizontalArrangement = Arrangement.spacedBy(4.dp)
        ) {
            Icon(
                imageVector = Icons.Default.DateRange,
                contentDescription = null,
                modifier = Modifier.size(16.dp),
                tint = contentTint
            )
            Text(
                text = "Photos from $year",
                style = MaterialTheme.typography.labelMedium,
                color = contentTint
            )
        }
    }
}
/**
 * Shows up to three cluster-quality warnings in a card whose colors and icon
 * reflect the cluster's quality tier. Renders nothing when there are no
 * warnings.
 */
@Composable
private fun QualityWarnings(qualityResult: ClusterQualityResult) {
    // Guard clause: no warnings, no card.
    if (qualityResult.warnings.isEmpty()) return

    val isPoor = qualityResult.qualityTier ==
        com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR
    val isGood = qualityResult.qualityTier ==
        com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.GOOD

    // Tier-dependent styling, hoisted so each warning row reuses it.
    val containerColor = when {
        isPoor -> MaterialTheme.colorScheme.errorContainer
        isGood -> MaterialTheme.colorScheme.tertiaryContainer
        else -> MaterialTheme.colorScheme.surfaceVariant
    }
    val contentColor =
        if (isPoor) MaterialTheme.colorScheme.onErrorContainer
        else MaterialTheme.colorScheme.onSurfaceVariant
    val warningIcon = if (isPoor) Icons.Default.Warning else Icons.Default.Info

    Card(
        colors = CardDefaults.cardColors(containerColor = containerColor)
    ) {
        Column(
            modifier = Modifier.padding(12.dp),
            verticalArrangement = Arrangement.spacedBy(4.dp)
        ) {
            // Cap at three warnings to keep the dialog compact.
            for (warning in qualityResult.warnings.take(3)) {
                Row(
                    verticalAlignment = Alignment.Top,
                    horizontalArrangement = Arrangement.spacedBy(8.dp)
                ) {
                    Icon(
                        imageVector = warningIcon,
                        contentDescription = null,
                        modifier = Modifier.size(16.dp),
                        tint = contentColor
                    )
                    Text(
                        text = warning,
                        style = MaterialTheme.typography.bodySmall,
                        color = contentColor
                    )
                }
            }
        }
    }
}
/**
 * Row of headline statistics for the cluster: photo count, clean-face count,
 * and overall quality as a percentage.
 */
@Composable
private fun ClusterStats(qualityResult: ClusterQualityResult) {
    // Label/value pairs, rendered left to right.
    val stats = listOf(
        "Photos" to qualityResult.soloPhotoCount.toString(),
        "Clean Faces" to qualityResult.cleanFaceCount.toString(),
        "Quality" to "${(qualityResult.qualityScore * 100).toInt()}%"
    )

    Row(
        modifier = Modifier.fillMaxWidth(),
        horizontalArrangement = Arrangement.SpaceEvenly
    ) {
        stats.forEach { (label, value) ->
            StatItem(label = label, value = value)
        }
    }
}
/**
 * Single statistic: bold value on top, muted label underneath.
 */
@Composable
private fun StatItem(label: String, value: String) {
    Column(horizontalAlignment = Alignment.CenterHorizontally) {
        Text(
            text = value,
            fontWeight = FontWeight.Bold,
            style = MaterialTheme.typography.titleMedium
        )
        Text(
            text = label,
            color = MaterialTheme.colorScheme.onSurfaceVariant,
            style = MaterialTheme.typography.bodySmall
        )
    }
}
/**
 * Builds the model-name preview shown in the dialog.
 *
 * Age takes precedence over year: name "Emma" with age "2" yields
 * "Emma_Age_2"; with a blank age the photo year is used instead,
 * yielding "Emma_2020".
 */
private fun buildModelName(name: String, ageText: String, year: String): String =
    if (ageText.isNotBlank()) "${name}_Age_$ageText" else "${name}_$year"

View File

@@ -1,12 +1,17 @@
package com.placeholder.sherpai2.ui.discover package com.placeholder.sherpai2.ui.discover
import android.net.Uri import android.net.Uri
import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.core.animateFloatAsState
import androidx.compose.foundation.background import androidx.compose.foundation.background
import androidx.compose.foundation.border import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
import androidx.compose.foundation.gestures.detectDragGestures
import androidx.compose.foundation.layout.* import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
import androidx.compose.foundation.lazy.grid.items import androidx.compose.foundation.lazy.grid.items
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.* import androidx.compose.material.icons.filled.*
@@ -15,268 +20,458 @@ import androidx.compose.runtime.*
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.draw.scale
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.input.pointer.pointerInput
import androidx.compose.ui.layout.ContentScale import androidx.compose.ui.layout.ContentScale
import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign
import androidx.compose.ui.unit.IntOffset
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.compose.ui.zIndex
import coil.compose.AsyncImage import coil.compose.AsyncImage
import com.placeholder.sherpai2.domain.validation.ValidationMatch import com.placeholder.sherpai2.data.local.entity.FeedbackType
import com.placeholder.sherpai2.domain.validation.ValidationQuality
import com.placeholder.sherpai2.domain.validation.ValidationScanResult import com.placeholder.sherpai2.domain.validation.ValidationScanResult
import com.placeholder.sherpai2.domain.validation.ValidationMatch
import kotlin.math.roundToInt
/** /**
* ValidationPreviewScreen - STAGE 2 validation UI * ValidationPreviewScreen - User reviews validation results with swipe gestures
* *
* Shows user a preview of matches found in validation scan * FEATURES:
* User can approve (→ full scan) or reject (→ add more photos) * ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* ✅ Swipe right (✓) = Confirmed match
* ✅ Swipe left (✗) = Rejected match
* ✅ Tap = Mark uncertain (?)
* ✅ Real-time feedback stats
* ✅ Automatic refinement recommendation
* ✅ Bottom bar with approve/reject/refine actions
*
* FLOW:
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
* 1. User swipes/taps to mark faces
* 2. Feedback tracked in local state
* 3. If >15% rejection → "Refine" button appears
* 4. Approve → Sends feedback map to ViewModel
* 5. Reject → Returns to previous screen
* 6. Refine → Triggers cluster refinement
*/ */
@Composable @Composable
fun ValidationPreviewScreen( fun ValidationPreviewScreen(
personName: String, personName: String,
validationResult: ValidationScanResult, validationResult: ValidationScanResult,
onMarkFeedback: (Map<String, FeedbackType>) -> Unit = {},
onRequestRefinement: () -> Unit = {},
onApprove: () -> Unit, onApprove: () -> Unit,
onReject: () -> Unit, onReject: () -> Unit,
modifier: Modifier = Modifier modifier: Modifier = Modifier
) { ) {
Column( // Get sample images from validation result matches
modifier = modifier val sampleMatches = remember(validationResult) {
.fillMaxSize() validationResult.matches.take(24) // Show up to 24 faces
.padding(16.dp) }
) {
// Header
Text(
text = "Validation Results",
style = MaterialTheme.typography.headlineMedium,
fontWeight = FontWeight.Bold
)
Spacer(modifier = Modifier.height(8.dp)) // Track feedback for each image (imageId -> FeedbackType)
var feedbackMap by remember {
mutableStateOf<Map<String, FeedbackType>>(emptyMap())
}
Text( // Calculate feedback statistics
text = "Review matches for \"$personName\" before scanning your entire library", val confirmedCount = feedbackMap.count { it.value == FeedbackType.CONFIRMED_MATCH }
style = MaterialTheme.typography.bodyMedium, val rejectedCount = feedbackMap.count { it.value == FeedbackType.REJECTED_MATCH }
color = MaterialTheme.colorScheme.onSurfaceVariant val uncertainCount = feedbackMap.count { it.value == FeedbackType.UNCERTAIN }
) val reviewedCount = feedbackMap.size
val totalCount = sampleMatches.size
Spacer(modifier = Modifier.height(16.dp)) // Determine if refinement is recommended
val rejectionRatio = if (reviewedCount > 0) {
rejectedCount.toFloat() / reviewedCount.toFloat()
} else {
0f
}
val shouldRefine = rejectionRatio > 0.15f && rejectedCount >= 2
// Quality Summary Scaffold(
QualitySummaryCard( bottomBar = {
validationResult = validationResult, ValidationBottomBar(
personName = personName confirmedCount = confirmedCount,
) rejectedCount = rejectedCount,
uncertainCount = uncertainCount,
Spacer(modifier = Modifier.height(16.dp)) reviewedCount = reviewedCount,
totalCount = totalCount,
// Matches Grid shouldRefine = shouldRefine,
if (validationResult.matches.isNotEmpty()) { onApprove = {
onMarkFeedback(feedbackMap)
onApprove()
},
onReject = onReject,
onRefine = {
onMarkFeedback(feedbackMap)
onRequestRefinement()
}
)
}
) { paddingValues ->
Column(
modifier = modifier
.fillMaxSize()
.padding(paddingValues)
.padding(16.dp)
) {
// Header
Text( Text(
text = "Sample Matches (${validationResult.matchCount})", text = "Validate \"$personName\"",
style = MaterialTheme.typography.titleMedium, style = MaterialTheme.typography.headlineMedium,
fontWeight = FontWeight.SemiBold fontWeight = FontWeight.Bold
) )
Spacer(modifier = Modifier.height(8.dp)) Spacer(modifier = Modifier.height(8.dp))
// Instructions
InstructionsCard()
Spacer(modifier = Modifier.height(16.dp))
// Feedback stats
FeedbackStatsCard(
confirmedCount = confirmedCount,
rejectedCount = rejectedCount,
uncertainCount = uncertainCount,
reviewedCount = reviewedCount,
totalCount = totalCount
)
Spacer(modifier = Modifier.height(16.dp))
// Grid of faces to review
LazyVerticalGrid( LazyVerticalGrid(
columns = GridCells.Fixed(3), columns = GridCells.Fixed(3),
horizontalArrangement = Arrangement.spacedBy(8.dp), horizontalArrangement = Arrangement.spacedBy(8.dp),
verticalArrangement = Arrangement.spacedBy(8.dp), verticalArrangement = Arrangement.spacedBy(8.dp),
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f)
) { ) {
items(validationResult.matches.take(15)) { match -> items(
MatchPreviewCard(match = match) items = sampleMatches,
key = { match -> match.imageId }
) { match ->
SwipeableFaceCard(
match = match,
currentFeedback = feedbackMap[match.imageId],
onFeedbackChange = { feedback ->
feedbackMap = feedbackMap.toMutableMap().apply {
put(match.imageId, feedback)
}
}
)
} }
} }
} else {
// No matches found
NoMatchesCard()
}
Spacer(modifier = Modifier.height(16.dp))
// Action Buttons
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.spacedBy(12.dp)
) {
// Reject button
OutlinedButton(
onClick = onReject,
modifier = Modifier.weight(1f),
colors = ButtonDefaults.outlinedButtonColors(
contentColor = MaterialTheme.colorScheme.error
)
) {
Icon(
imageVector = Icons.Default.Close,
contentDescription = null,
modifier = Modifier.size(20.dp)
)
Spacer(modifier = Modifier.width(8.dp))
Text("Add More Photos")
}
// Approve button
Button(
onClick = onApprove,
modifier = Modifier.weight(1f),
enabled = validationResult.qualityAssessment != ValidationQuality.NO_MATCHES
) {
Icon(
imageVector = Icons.Default.Check,
contentDescription = null,
modifier = Modifier.size(20.dp)
)
Spacer(modifier = Modifier.width(8.dp))
Text("Scan Library")
}
} }
} }
} }
/**
* Swipeable face card with visual feedback indicators
*/
@Composable @Composable
private fun QualitySummaryCard( private fun SwipeableFaceCard(
validationResult: ValidationScanResult, match: ValidationMatch,
personName: String currentFeedback: FeedbackType?,
onFeedbackChange: (FeedbackType) -> Unit
) { ) {
val (backgroundColor, iconColor, statusText, statusIcon) = when (validationResult.qualityAssessment) { var offsetX by remember { mutableFloatStateOf(0f) }
ValidationQuality.EXCELLENT -> { var isDragging by remember { mutableStateOf(false) }
Quadruple(
Color(0xFF1B5E20).copy(alpha = 0.1f), val scale by animateFloatAsState(
Color(0xFF1B5E20), targetValue = if (isDragging) 1.1f else 1f,
"Excellent Match Quality", label = "scale"
Icons.Default.CheckCircle )
Box(
modifier = Modifier
.aspectRatio(1f)
.scale(scale)
.zIndex(if (isDragging) 1f else 0f)
) {
// Face image with border color based on feedback
AsyncImage(
model = Uri.parse(match.imageUri),
contentDescription = "Face",
modifier = Modifier
.fillMaxSize()
.clip(RoundedCornerShape(12.dp))
.border(
width = 3.dp,
color = when (currentFeedback) {
FeedbackType.CONFIRMED_MATCH -> Color(0xFF4CAF50) // Green
FeedbackType.REJECTED_MATCH -> Color(0xFFF44336) // Red
FeedbackType.UNCERTAIN -> Color(0xFFFF9800) // Orange
else -> MaterialTheme.colorScheme.outline
},
shape = RoundedCornerShape(12.dp)
)
.offset { IntOffset(offsetX.roundToInt(), 0) }
.pointerInput(Unit) {
detectDragGestures(
onDragStart = {
isDragging = true
},
onDrag = { _, dragAmount ->
offsetX += dragAmount.x
},
onDragEnd = {
isDragging = false
// Determine feedback based on swipe direction
when {
offsetX > 100 -> {
onFeedbackChange(FeedbackType.CONFIRMED_MATCH)
}
offsetX < -100 -> {
onFeedbackChange(FeedbackType.REJECTED_MATCH)
}
}
// Reset position
offsetX = 0f
},
onDragCancel = {
isDragging = false
offsetX = 0f
}
)
}
.clickable {
// Tap to toggle uncertain
val newFeedback = when (currentFeedback) {
FeedbackType.UNCERTAIN -> null
else -> FeedbackType.UNCERTAIN
}
if (newFeedback != null) {
onFeedbackChange(newFeedback)
}
},
contentScale = ContentScale.Crop
)
// Confidence badge (top-left)
Surface(
modifier = Modifier
.align(Alignment.TopStart)
.padding(4.dp),
shape = RoundedCornerShape(4.dp),
color = Color.Black.copy(alpha = 0.6f)
) {
Text(
text = "${(match.confidence * 100).toInt()}%",
modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp),
style = MaterialTheme.typography.labelSmall,
color = Color.White,
fontWeight = FontWeight.Bold
) )
} }
ValidationQuality.GOOD -> {
Quadruple( // Feedback indicator overlay (top-right)
Color(0xFF2E7D32).copy(alpha = 0.1f), if (currentFeedback != null) {
Color(0xFF2E7D32), Surface(
"Good Match Quality", modifier = Modifier
Icons.Default.ThumbUp .align(Alignment.TopEnd)
) .padding(4.dp),
shape = CircleShape,
color = when (currentFeedback) {
FeedbackType.CONFIRMED_MATCH -> Color(0xFF4CAF50)
FeedbackType.REJECTED_MATCH -> Color(0xFFF44336)
FeedbackType.UNCERTAIN -> Color(0xFFFF9800)
else -> Color.Transparent
},
shadowElevation = 2.dp
) {
Icon(
imageVector = when (currentFeedback) {
FeedbackType.CONFIRMED_MATCH -> Icons.Default.Check
FeedbackType.REJECTED_MATCH -> Icons.Default.Close
FeedbackType.UNCERTAIN -> Icons.Default.Warning
else -> Icons.Default.Info
},
contentDescription = currentFeedback.name,
tint = Color.White,
modifier = Modifier
.size(32.dp)
.padding(6.dp)
)
}
} }
ValidationQuality.FAIR -> {
Quadruple( // Swipe hint during drag
Color(0xFFF57F17).copy(alpha = 0.1f), if (isDragging) {
Color(0xFFF57F17), SwipeDragHint(offsetX = offsetX)
"Fair Match Quality",
Icons.Default.Warning
)
}
ValidationQuality.POOR -> {
Quadruple(
Color(0xFFD32F2F).copy(alpha = 0.1f),
Color(0xFFD32F2F),
"Poor Match Quality",
Icons.Default.Warning
)
}
ValidationQuality.NO_MATCHES -> {
Quadruple(
Color(0xFFD32F2F).copy(alpha = 0.1f),
Color(0xFFD32F2F),
"No Matches Found",
Icons.Default.Close
)
} }
} }
}
/**
* Swipe drag hint overlay
*/
@Composable
private fun BoxScope.SwipeDragHint(offsetX: Float) {
val hintText = when {
offsetX > 50 -> "✓ Correct"
offsetX < -50 -> "✗ Incorrect"
else -> "Keep swiping"
}
val hintColor = when {
offsetX > 50 -> Color(0xFF4CAF50)
offsetX < -50 -> Color(0xFFF44336)
else -> Color.Gray
}
Surface(
modifier = Modifier
.align(Alignment.BottomCenter)
.padding(8.dp),
shape = RoundedCornerShape(4.dp),
color = hintColor.copy(alpha = 0.9f)
) {
Text(
text = hintText,
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
style = MaterialTheme.typography.labelSmall,
color = Color.White,
fontWeight = FontWeight.Bold
)
}
}
/**
* Instructions card showing gesture controls
*/
@Composable
private fun InstructionsCard() {
Card( Card(
modifier = Modifier.fillMaxWidth(),
colors = CardDefaults.cardColors( colors = CardDefaults.cardColors(
containerColor = backgroundColor containerColor = MaterialTheme.colorScheme.primaryContainer
) )
) { ) {
Column( Row(
modifier = Modifier.padding(16.dp) modifier = Modifier.padding(16.dp),
verticalAlignment = Alignment.CenterVertically
) { ) {
Row( Icon(
verticalAlignment = Alignment.CenterVertically imageVector = Icons.Default.Info,
) { contentDescription = null,
Icon( tint = MaterialTheme.colorScheme.onPrimaryContainer
imageVector = statusIcon, )
contentDescription = null,
tint = iconColor, Spacer(modifier = Modifier.width(12.dp))
modifier = Modifier.size(24.dp)
) Column {
Spacer(modifier = Modifier.width(8.dp))
Text( Text(
text = statusText, text = "Review Detected Faces",
style = MaterialTheme.typography.titleMedium, style = MaterialTheme.typography.titleSmall,
fontWeight = FontWeight.Bold, fontWeight = FontWeight.Bold,
color = iconColor color = MaterialTheme.colorScheme.onPrimaryContainer
) )
} Spacer(modifier = Modifier.height(4.dp))
Spacer(modifier = Modifier.height(12.dp))
// Stats
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.SpaceBetween
) {
StatItem(
label = "Matches Found",
value = "${validationResult.matchCount} / ${validationResult.sampleSize}"
)
StatItem(
label = "Avg Confidence",
value = "${(validationResult.averageConfidence * 100).toInt()}%"
)
StatItem(
label = "Threshold",
value = "${(validationResult.threshold * 100).toInt()}%"
)
}
// Recommendation
if (validationResult.qualityAssessment != ValidationQuality.NO_MATCHES) {
Spacer(modifier = Modifier.height(12.dp))
val recommendation = when (validationResult.qualityAssessment) {
ValidationQuality.EXCELLENT ->
"✅ Model looks great! Safe to scan your full library."
ValidationQuality.GOOD ->
"✅ Model quality is good. You can proceed with the full scan."
ValidationQuality.FAIR ->
"⚠️ Model quality is acceptable but could be improved with more photos."
ValidationQuality.POOR ->
"⚠️ Consider adding more diverse, high-quality training photos."
ValidationQuality.NO_MATCHES -> ""
}
Text( Text(
text = recommendation, text = "Swipe right ✅ for correct, left ❌ for incorrect, tap ❓ for uncertain",
style = MaterialTheme.typography.bodyMedium, style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onPrimaryContainer
)
} else {
Spacer(modifier = Modifier.height(12.dp))
Text(
text = "No matches found. The model may need more or better training photos, or the validation sample didn't include $personName.",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.error
) )
} }
} }
} }
} }
/**
* Feedback statistics card
*/
@Composable @Composable
private fun StatItem( private fun FeedbackStatsCard(
label: String, confirmedCount: Int,
value: String rejectedCount: Int,
uncertainCount: Int,
reviewedCount: Int,
totalCount: Int
) {
Card {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(16.dp),
horizontalArrangement = Arrangement.SpaceEvenly
) {
FeedbackStat(
icon = Icons.Default.Check,
color = Color(0xFF4CAF50),
count = confirmedCount,
label = "Correct"
)
FeedbackStat(
icon = Icons.Default.Close,
color = Color(0xFFF44336),
count = rejectedCount,
label = "Incorrect"
)
FeedbackStat(
icon = Icons.Default.Warning,
color = Color(0xFFFF9800),
count = uncertainCount,
label = "Uncertain"
)
}
val progressValue = if (totalCount > 0) {
reviewedCount.toFloat() / totalCount.toFloat()
} else {
0f
}
LinearProgressIndicator(
progress = { progressValue },
modifier = Modifier
.fillMaxWidth()
.height(4.dp)
)
}
}
/**
* Individual feedback statistic item
*/
@Composable
private fun FeedbackStat(
icon: androidx.compose.ui.graphics.vector.ImageVector,
color: Color,
count: Int,
label: String
) { ) {
Column( Column(
horizontalAlignment = Alignment.CenterHorizontally horizontalAlignment = Alignment.CenterHorizontally
) { ) {
Surface(
shape = CircleShape,
color = color.copy(alpha = 0.2f)
) {
Icon(
imageVector = icon,
contentDescription = null,
tint = color,
modifier = Modifier
.size(40.dp)
.padding(8.dp)
)
}
Spacer(modifier = Modifier.height(4.dp))
Text( Text(
text = value, text = count.toString(),
style = MaterialTheme.typography.titleLarge, style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.Bold fontWeight = FontWeight.Bold
) )
Text( Text(
text = label, text = label,
style = MaterialTheme.typography.bodySmall, style = MaterialTheme.typography.bodySmall,
@@ -285,111 +480,134 @@ private fun StatItem(
} }
} }
/**
* Bottom action bar with approve/reject/refine buttons
*/
@Composable @Composable
private fun MatchPreviewCard( private fun ValidationBottomBar(
match: ValidationMatch confirmedCount: Int,
rejectedCount: Int,
uncertainCount: Int,
reviewedCount: Int,
totalCount: Int,
shouldRefine: Boolean,
onApprove: () -> Unit,
onReject: () -> Unit,
onRefine: () -> Unit
) { ) {
Box( Surface(
modifier = Modifier modifier = Modifier.fillMaxWidth(),
.aspectRatio(1f) color = MaterialTheme.colorScheme.surface,
.clip(RoundedCornerShape(8.dp)) shadowElevation = 8.dp
.background(MaterialTheme.colorScheme.surfaceVariant)
) { ) {
AsyncImage( Column(
model = Uri.parse(match.imageUri), modifier = Modifier.padding(16.dp)
contentDescription = "Match preview",
modifier = Modifier.fillMaxSize(),
contentScale = ContentScale.Crop
)
// Confidence badge
Surface(
modifier = Modifier
.align(Alignment.BottomEnd)
.padding(4.dp),
shape = RoundedCornerShape(4.dp),
color = Color.Black.copy(alpha = 0.7f)
) { ) {
// Refinement warning banner
AnimatedVisibility(visible = shouldRefine) {
RefinementWarningBanner(
rejectedCount = rejectedCount,
reviewedCount = reviewedCount,
onRefine = onRefine
)
}
// Main action buttons
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.spacedBy(12.dp)
) {
OutlinedButton(
onClick = onReject,
modifier = Modifier.weight(1f)
) {
Icon(Icons.Default.Close, contentDescription = null)
Spacer(modifier = Modifier.width(8.dp))
Text("Reject")
}
Button(
onClick = onApprove,
modifier = Modifier.weight(1f),
enabled = confirmedCount > 0 || (reviewedCount == 0 && totalCount > 6)
) {
Icon(Icons.Default.Check, contentDescription = null)
Spacer(modifier = Modifier.width(8.dp))
Text("Approve")
}
}
// Review progress text
Spacer(modifier = Modifier.height(8.dp))
Text( Text(
text = "${(match.confidence * 100).toInt()}%", text = if (reviewedCount == 0) {
style = MaterialTheme.typography.labelSmall, "Review faces above or approve to continue"
color = Color.White, } else {
modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp) "Reviewed $reviewedCount of $totalCount faces"
},
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant,
textAlign = TextAlign.Center,
modifier = Modifier.fillMaxWidth()
) )
} }
}
}
// Face count indicator (if group photo) /**
if (match.faceCount > 1) { * Refinement warning banner component
Surface( */
modifier = Modifier @Composable
.align(Alignment.TopEnd) private fun RefinementWarningBanner(
.padding(4.dp), rejectedCount: Int,
shape = RoundedCornerShape(4.dp), reviewedCount: Int,
color = MaterialTheme.colorScheme.primary onRefine: () -> Unit
) {
Column {
Card(
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.errorContainer
),
modifier = Modifier.fillMaxWidth()
) {
Row(
modifier = Modifier.padding(12.dp),
verticalAlignment = Alignment.CenterVertically
) { ) {
Row( Icon(
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp), imageVector = Icons.Default.Warning,
verticalAlignment = Alignment.CenterVertically contentDescription = null,
) { tint = MaterialTheme.colorScheme.onErrorContainer
Icon( )
imageVector = Icons.Default.Person,
contentDescription = null, Spacer(modifier = Modifier.width(12.dp))
tint = Color.White,
modifier = Modifier.size(12.dp) Column(modifier = Modifier.weight(1f)) {
Text(
text = "High Rejection Rate",
style = MaterialTheme.typography.titleSmall,
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.onErrorContainer
) )
Text( Text(
text = "${match.faceCount}", text = "${(rejectedCount.toFloat() / reviewedCount.toFloat() * 100).toInt()}% rejected. Consider refining the cluster.",
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.bodySmall,
color = Color.White color = MaterialTheme.colorScheme.onErrorContainer
) )
} }
Button(
onClick = onRefine,
colors = ButtonDefaults.buttonColors(
containerColor = MaterialTheme.colorScheme.error
)
) {
Text("Refine")
}
} }
} }
}
}
@Composable Spacer(modifier = Modifier.height(12.dp))
private fun NoMatchesCard() {
Card(
modifier = Modifier.fillMaxWidth(),
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.errorContainer
)
) {
Column(
modifier = Modifier.padding(24.dp),
horizontalAlignment = Alignment.CenterHorizontally
) {
Icon(
imageVector = Icons.Default.Warning,
contentDescription = null,
tint = MaterialTheme.colorScheme.error,
modifier = Modifier.size(48.dp)
)
Spacer(modifier = Modifier.height(16.dp))
Text(
text = "No Matches Found",
style = MaterialTheme.typography.titleLarge,
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.error
)
Spacer(modifier = Modifier.height(8.dp))
Text(
text = "The validation scan didn't find this person in the sample photos. This could mean:\n\n" +
"• The model needs more training photos\n" +
"• The training photos weren't diverse enough\n" +
"• The person wasn't in the validation sample",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onErrorContainer
)
}
} }
} }
// Helper data class for quality indicator
private data class Quadruple<A, B, C, D>(
val first: A,
val second: B,
val third: C,
val fourth: D
)