discover dez
This commit is contained in:
2
.idea/deploymentTargetSelector.xml
generated
2
.idea/deploymentTargetSelector.xml
generated
@@ -4,7 +4,7 @@
|
||||
<selectionStates>
|
||||
<SelectionState runConfigName="app">
|
||||
<option name="selectionMode" value="DROPDOWN" />
|
||||
<DropdownSelection timestamp="2026-01-18T23:43:22.974426869Z">
|
||||
<DropdownSelection timestamp="2026-01-21T20:49:28.305260931Z">
|
||||
<Target type="DEFAULT_BOOT">
|
||||
<handle>
|
||||
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />
|
||||
|
||||
@@ -10,6 +10,11 @@ import com.placeholder.sherpai2.data.local.entity.*
|
||||
/**
|
||||
* AppDatabase - Complete database for SherpAI2
|
||||
*
|
||||
* VERSION 10 - User Feedback Loop
|
||||
* - Added UserFeedbackEntity for storing user corrections
|
||||
* - Enables cluster refinement before training
|
||||
* - Ground truth data for improving clustering
|
||||
*
|
||||
* VERSION 9 - Enhanced Face Cache
|
||||
* - Added FaceCacheEntity for per-face metadata
|
||||
* - Stores quality scores, embeddings, bounding boxes
|
||||
@@ -38,14 +43,15 @@ import com.placeholder.sherpai2.data.local.entity.*
|
||||
FaceModelEntity::class,
|
||||
PhotoFaceTagEntity::class,
|
||||
PersonAgeTagEntity::class,
|
||||
FaceCacheEntity::class, // NEW: Per-face metadata cache
|
||||
FaceCacheEntity::class,
|
||||
UserFeedbackEntity::class, // NEW: User corrections
|
||||
|
||||
// ===== COLLECTIONS =====
|
||||
CollectionEntity::class,
|
||||
CollectionImageEntity::class,
|
||||
CollectionFilterEntity::class
|
||||
],
|
||||
version = 9, // INCREMENTED for face cache
|
||||
version = 10, // INCREMENTED for user feedback
|
||||
exportSchema = false
|
||||
)
|
||||
abstract class AppDatabase : RoomDatabase() {
|
||||
@@ -63,7 +69,8 @@ abstract class AppDatabase : RoomDatabase() {
|
||||
abstract fun faceModelDao(): FaceModelDao
|
||||
abstract fun photoFaceTagDao(): PhotoFaceTagDao
|
||||
abstract fun personAgeTagDao(): PersonAgeTagDao
|
||||
abstract fun faceCacheDao(): FaceCacheDao // NEW
|
||||
abstract fun faceCacheDao(): FaceCacheDao
|
||||
abstract fun userFeedbackDao(): UserFeedbackDao // NEW
|
||||
|
||||
// ===== COLLECTIONS DAO =====
|
||||
abstract fun collectionDao(): CollectionDao
|
||||
@@ -185,6 +192,10 @@ val MIGRATION_8_9 = object : Migration(8, 9) {
|
||||
hasGoodLighting INTEGER NOT NULL,
|
||||
embedding TEXT,
|
||||
confidence REAL NOT NULL,
|
||||
imageWidth INTEGER NOT NULL DEFAULT 0,
|
||||
imageHeight INTEGER NOT NULL DEFAULT 0,
|
||||
cacheVersion INTEGER NOT NULL DEFAULT 1,
|
||||
cachedAt INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY(imageId, faceIndex),
|
||||
FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE
|
||||
)
|
||||
@@ -197,13 +208,47 @@ val MIGRATION_8_9 = object : Migration(8, 9) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* MIGRATION 9 → 10 (User Feedback Loop)
|
||||
*
|
||||
* Changes:
|
||||
* 1. Create user_feedback table for storing user corrections
|
||||
*/
|
||||
val MIGRATION_9_10 = object : Migration(9, 10) {
|
||||
override fun migrate(database: SupportSQLiteDatabase) {
|
||||
|
||||
// Create user_feedback table
|
||||
database.execSQL("""
|
||||
CREATE TABLE IF NOT EXISTS user_feedback (
|
||||
id TEXT PRIMARY KEY NOT NULL,
|
||||
imageId TEXT NOT NULL,
|
||||
faceIndex INTEGER NOT NULL,
|
||||
clusterId INTEGER,
|
||||
personId TEXT,
|
||||
feedbackType TEXT NOT NULL,
|
||||
originalConfidence REAL NOT NULL,
|
||||
userNote TEXT,
|
||||
timestamp INTEGER NOT NULL,
|
||||
FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE,
|
||||
FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE
|
||||
)
|
||||
""")
|
||||
|
||||
// Create indices for fast lookups
|
||||
database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_imageId ON user_feedback(imageId)")
|
||||
database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_clusterId ON user_feedback(clusterId)")
|
||||
database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_personId ON user_feedback(personId)")
|
||||
database.execSQL("CREATE INDEX IF NOT EXISTS index_user_feedback_feedbackType ON user_feedback(feedbackType)")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* PRODUCTION MIGRATION NOTES:
|
||||
*
|
||||
* Before shipping to users, update DatabaseModule to use migration:
|
||||
* Before shipping to users, update DatabaseModule to use migrations:
|
||||
*
|
||||
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
|
||||
* .addMigrations(MIGRATION_7_8) // Add this
|
||||
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations
|
||||
* // .fallbackToDestructiveMigration() // Remove this
|
||||
* .build()
|
||||
*/
|
||||
@@ -8,7 +8,21 @@ import androidx.room.Update
|
||||
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||
|
||||
/**
|
||||
* FaceCacheDao - Face detection cache with NEW queries for two-stage clustering
|
||||
* FaceCacheDao - YEAR-BASED filtering for temporal clustering
|
||||
*
|
||||
* NEW STRATEGY: Cluster by YEAR to handle children
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* Problem: Child's face changes dramatically over years
|
||||
* Solution: Cluster each YEAR separately
|
||||
*
|
||||
* Example:
|
||||
* - 2020 photos → Emma age 2
|
||||
* - 2021 photos → Emma age 3
|
||||
* - 2022 photos → Emma age 4
|
||||
*
|
||||
* Result: Multiple clusters of same child at different ages
|
||||
* User names: "Emma, Age 2", "Emma, Age 3", etc.
|
||||
* System creates: Emma_Age_2, Emma_Age_3 submodels
|
||||
*/
|
||||
@Dao
|
||||
interface FaceCacheDao {
|
||||
@@ -27,17 +41,161 @@ interface FaceCacheDao {
|
||||
suspend fun update(faceCache: FaceCacheEntity)
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// NEW CLUSTERING QUERIES ⭐
|
||||
// YEAR-BASED QUERIES (NEW - For Children)
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
/**
|
||||
* Get high-quality solo faces for Stage 1 clustering
|
||||
* Get premium solo faces from a SPECIFIC YEAR
|
||||
*
|
||||
* Filters:
|
||||
* - Solo photos (faceCount = 1)
|
||||
* - Large faces (faceAreaRatio >= minFaceRatio)
|
||||
* - Has embedding
|
||||
* Use Case: Cluster children by age
|
||||
* - Cluster 2020 photos separately from 2021 photos
|
||||
* - Same child at different ages = different clusters
|
||||
* - User names each: "Emma Age 2", "Emma Age 3"
|
||||
*
|
||||
* @param year Year in YYYY format (e.g., "2020")
|
||||
* @param minRatio Minimum face size (default 5%)
|
||||
* @param minQuality Minimum quality score (default 0.8)
|
||||
* @param limit Maximum faces to return
|
||||
*/
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.qualityScore >= :minQuality
|
||||
AND fc.embedding IS NOT NULL
|
||||
AND strftime('%Y', i.capturedAt/1000, 'unixepoch') = :year
|
||||
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getPremiumSoloFacesByYear(
|
||||
year: String,
|
||||
minRatio: Float = 0.05f,
|
||||
minQuality: Float = 0.8f,
|
||||
limit: Int = 1000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* Get premium solo faces from a YEAR RANGE
|
||||
*
|
||||
* Use Case: Cluster adults who don't change much
|
||||
* - Photos from 2018-2023 can cluster together
|
||||
* - Adults look similar across years
|
||||
*
|
||||
* @param startYear Start year in YYYY format
|
||||
* @param endYear End year in YYYY format
|
||||
*/
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.qualityScore >= :minQuality
|
||||
AND fc.embedding IS NOT NULL
|
||||
AND strftime('%Y', i.capturedAt/1000, 'unixepoch') BETWEEN :startYear AND :endYear
|
||||
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getPremiumSoloFacesByYearRange(
|
||||
startYear: String,
|
||||
endYear: String,
|
||||
minRatio: Float = 0.05f,
|
||||
minQuality: Float = 0.8f,
|
||||
limit: Int = 1000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* Get years that have sufficient photos for clustering
|
||||
*
|
||||
* Returns years with at least N solo photos
|
||||
* Use to determine which years to cluster
|
||||
*
|
||||
* Example output:
|
||||
* ```
|
||||
* [
|
||||
* YearPhotoCount(year="2020", photoCount=150),
|
||||
* YearPhotoCount(year="2021", photoCount=200),
|
||||
* YearPhotoCount(year="2022", photoCount=180)
|
||||
* ]
|
||||
* ```
|
||||
*/
|
||||
@Query("""
|
||||
SELECT
|
||||
strftime('%Y', i.capturedAt/1000, 'unixepoch') as year,
|
||||
COUNT(DISTINCT fc.imageId) as photoCount
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.embedding IS NOT NULL
|
||||
GROUP BY year
|
||||
HAVING photoCount >= :minPhotos
|
||||
ORDER BY year ASC
|
||||
""")
|
||||
suspend fun getYearsWithSufficientPhotos(
|
||||
minPhotos: Int = 20,
|
||||
minRatio: Float = 0.03f
|
||||
): List<YearPhotoCount>
|
||||
|
||||
/**
|
||||
* Get month-by-month breakdown for a year
|
||||
*
|
||||
* For fine-grained age clustering (babies change monthly)
|
||||
*/
|
||||
@Query("""
|
||||
SELECT
|
||||
strftime('%Y-%m', i.capturedAt/1000, 'unixepoch') as yearMonth,
|
||||
COUNT(DISTINCT fc.imageId) as photoCount
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.embedding IS NOT NULL
|
||||
AND strftime('%Y', i.capturedAt/1000, 'unixepoch') = :year
|
||||
GROUP BY yearMonth
|
||||
ORDER BY yearMonth ASC
|
||||
""")
|
||||
suspend fun getMonthlyBreakdownForYear(year: String): List<MonthPhotoCount>
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// STANDARD QUERIES (Original)
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.qualityScore >= :minQuality
|
||||
AND fc.embedding IS NOT NULL
|
||||
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getPremiumSoloFaces(
|
||||
minRatio: Float = 0.05f,
|
||||
minQuality: Float = 0.8f,
|
||||
limit: Int = 1000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.qualityScore >= :minQuality
|
||||
AND fc.embedding IS NOT NULL
|
||||
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getStandardSoloFaces(
|
||||
minRatio: Float = 0.03f,
|
||||
minQuality: Float = 0.6f,
|
||||
limit: Int = 2000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
@@ -53,10 +211,6 @@ interface FaceCacheDao {
|
||||
limit: Int = 2000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
/**
|
||||
* FALLBACK: Get ANY solo faces with embeddings
|
||||
* Used if getHighQualitySoloFaces() returns empty
|
||||
*/
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
@@ -70,12 +224,90 @@ interface FaceCacheDao {
|
||||
limit: Int = 2000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// EXISTING QUERIES (keep as-is)
|
||||
// ═══════════════════════════════════════
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount BETWEEN :minFaces AND :maxFaces
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.embedding IS NOT NULL
|
||||
ORDER BY i.faceCount ASC, fc.faceAreaRatio DESC
|
||||
""")
|
||||
suspend fun getSmallGroupFaces(
|
||||
minFaces: Int = 2,
|
||||
maxFaces: Int = 5,
|
||||
minRatio: Float = 0.02f
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
@Query("SELECT * FROM face_cache WHERE id = :id")
|
||||
suspend fun getFaceCacheById(id: String): FaceCacheEntity?
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = :faceCount
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.embedding IS NOT NULL
|
||||
ORDER BY fc.qualityScore DESC
|
||||
""")
|
||||
suspend fun getFacesByGroupSize(
|
||||
faceCount: Int,
|
||||
minRatio: Float = 0.02f
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
@Query("""
|
||||
SELECT fc.*
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.embedding IS NOT NULL
|
||||
AND fc.imageId NOT IN (:excludedImageIds)
|
||||
ORDER BY fc.qualityScore DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
suspend fun getSoloFacesExcluding(
|
||||
excludedImageIds: List<String>,
|
||||
minRatio: Float = 0.03f,
|
||||
limit: Int = 2000
|
||||
): List<FaceCacheEntity>
|
||||
|
||||
@Query("""
|
||||
SELECT
|
||||
i.faceCount,
|
||||
COUNT(DISTINCT i.imageId) as imageCount,
|
||||
AVG(fc.faceAreaRatio) as avgFaceSize,
|
||||
AVG(fc.qualityScore) as avgQuality,
|
||||
COUNT(fc.embedding IS NOT NULL) as hasEmbedding
|
||||
FROM images i
|
||||
LEFT JOIN face_cache fc ON i.imageId = fc.imageId
|
||||
WHERE i.hasFaces = 1
|
||||
GROUP BY i.faceCount
|
||||
ORDER BY i.faceCount ASC
|
||||
""")
|
||||
suspend fun getLibraryQualityDistribution(): List<LibraryQualityStat>
|
||||
|
||||
@Query("""
|
||||
SELECT COUNT(*)
|
||||
FROM face_cache fc
|
||||
INNER JOIN images i ON fc.imageId = i.imageId
|
||||
WHERE i.faceCount = 1
|
||||
AND fc.faceAreaRatio >= :minRatio
|
||||
AND fc.qualityScore >= :minQuality
|
||||
AND fc.embedding IS NOT NULL
|
||||
""")
|
||||
suspend fun countPremiumSoloFaces(
|
||||
minRatio: Float = 0.05f,
|
||||
minQuality: Float = 0.8f
|
||||
): Int
|
||||
|
||||
@Query("""
|
||||
SELECT COUNT(*)
|
||||
FROM face_cache
|
||||
WHERE embedding IS NOT NULL
|
||||
""")
|
||||
suspend fun countFacesWithEmbeddings(): Int
|
||||
|
||||
@Query("SELECT * FROM face_cache WHERE imageId = :imageId AND faceIndex = :faceIndex")
|
||||
suspend fun getFaceCacheByKey(imageId: String, faceIndex: Int): FaceCacheEntity?
|
||||
|
||||
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex")
|
||||
suspend fun getFaceCacheForImage(imageId: String): List<FaceCacheEntity>
|
||||
@@ -89,3 +321,24 @@ interface FaceCacheDao {
|
||||
@Query("DELETE FROM face_cache WHERE cacheVersion < :version")
|
||||
suspend fun deleteOldVersions(version: Int)
|
||||
}
|
||||
|
||||
/**
|
||||
* Result classes for year-based queries
|
||||
*/
|
||||
data class YearPhotoCount(
|
||||
val year: String,
|
||||
val photoCount: Int
|
||||
)
|
||||
|
||||
data class MonthPhotoCount(
|
||||
val yearMonth: String, // "2020-05"
|
||||
val photoCount: Int
|
||||
)
|
||||
|
||||
data class LibraryQualityStat(
|
||||
val faceCount: Int,
|
||||
val imageCount: Int,
|
||||
val avgFaceSize: Float,
|
||||
val avgQuality: Float,
|
||||
val hasEmbedding: Int
|
||||
)
|
||||
@@ -0,0 +1,212 @@
|
||||
package com.placeholder.sherpai2.data.local.dao
|
||||
|
||||
import androidx.room.*
|
||||
import com.placeholder.sherpai2.data.local.entity.FeedbackType
|
||||
import com.placeholder.sherpai2.data.local.entity.UserFeedbackEntity
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
|
||||
/**
|
||||
* UserFeedbackDao - Query user corrections and feedback
|
||||
*
|
||||
* KEY QUERIES:
|
||||
* - Get feedback for cluster validation
|
||||
* - Find rejected faces to exclude from training
|
||||
* - Track feedback statistics for quality metrics
|
||||
* - Support cluster refinement workflow
|
||||
*/
|
||||
@Dao
|
||||
interface UserFeedbackDao {
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// INSERT / UPDATE
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||
suspend fun insert(feedback: UserFeedbackEntity): Long
|
||||
|
||||
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||
suspend fun insertAll(feedbacks: List<UserFeedbackEntity>)
|
||||
|
||||
@Update
|
||||
suspend fun update(feedback: UserFeedbackEntity)
|
||||
|
||||
@Delete
|
||||
suspend fun delete(feedback: UserFeedbackEntity)
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// CLUSTER VALIDATION QUERIES
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
/**
|
||||
* Get all feedback for a cluster
|
||||
* Used during validation to see what user has reviewed
|
||||
*/
|
||||
@Query("SELECT * FROM user_feedback WHERE clusterId = :clusterId ORDER BY timestamp DESC")
|
||||
suspend fun getFeedbackForCluster(clusterId: Int): List<UserFeedbackEntity>
|
||||
|
||||
/**
|
||||
* Get rejected faces for a cluster
|
||||
* These faces should be EXCLUDED from training
|
||||
*/
|
||||
@Query("""
|
||||
SELECT * FROM user_feedback
|
||||
WHERE clusterId = :clusterId
|
||||
AND feedbackType = 'REJECTED_MATCH'
|
||||
""")
|
||||
suspend fun getRejectedFacesForCluster(clusterId: Int): List<UserFeedbackEntity>
|
||||
|
||||
/**
|
||||
* Get confirmed faces for a cluster
|
||||
* These faces are SAFE for training
|
||||
*/
|
||||
@Query("""
|
||||
SELECT * FROM user_feedback
|
||||
WHERE clusterId = :clusterId
|
||||
AND feedbackType = 'CONFIRMED_MATCH'
|
||||
""")
|
||||
suspend fun getConfirmedFacesForCluster(clusterId: Int): List<UserFeedbackEntity>
|
||||
|
||||
/**
|
||||
* Count feedback by type for a cluster
|
||||
* Used to show stats: "15 confirmed, 3 rejected"
|
||||
*/
|
||||
@Query("""
|
||||
SELECT feedbackType, COUNT(*) as count
|
||||
FROM user_feedback
|
||||
WHERE clusterId = :clusterId
|
||||
GROUP BY feedbackType
|
||||
""")
|
||||
suspend fun getFeedbackStatsByCluster(clusterId: Int): List<FeedbackStat>
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// PERSON FEEDBACK QUERIES
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
/**
|
||||
* Get all feedback for a person
|
||||
* Used to show history of corrections
|
||||
*/
|
||||
@Query("SELECT * FROM user_feedback WHERE personId = :personId ORDER BY timestamp DESC")
|
||||
suspend fun getFeedbackForPerson(personId: String): List<UserFeedbackEntity>
|
||||
|
||||
/**
|
||||
* Get rejected faces for a person
|
||||
* User said "this is NOT X" - exclude from model improvement
|
||||
*/
|
||||
@Query("""
|
||||
SELECT * FROM user_feedback
|
||||
WHERE personId = :personId
|
||||
AND feedbackType = 'REJECTED_MATCH'
|
||||
""")
|
||||
suspend fun getRejectedFacesForPerson(personId: String): List<UserFeedbackEntity>
|
||||
|
||||
/**
|
||||
* Flow version for reactive UI
|
||||
*/
|
||||
@Query("SELECT * FROM user_feedback WHERE personId = :personId ORDER BY timestamp DESC")
|
||||
fun observeFeedbackForPerson(personId: String): Flow<List<UserFeedbackEntity>>
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// IMAGE QUERIES
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
/**
|
||||
* Get feedback for a specific image
|
||||
*/
|
||||
@Query("SELECT * FROM user_feedback WHERE imageId = :imageId")
|
||||
suspend fun getFeedbackForImage(imageId: String): List<UserFeedbackEntity>
|
||||
|
||||
/**
|
||||
* Check if user has provided feedback for a specific face
|
||||
*/
|
||||
@Query("""
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM user_feedback
|
||||
WHERE imageId = :imageId
|
||||
AND faceIndex = :faceIndex
|
||||
)
|
||||
""")
|
||||
suspend fun hasFeedbackForFace(imageId: String, faceIndex: Int): Boolean
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// STATISTICS & ANALYTICS
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
/**
|
||||
* Get total feedback count
|
||||
*/
|
||||
@Query("SELECT COUNT(*) FROM user_feedback")
|
||||
suspend fun getTotalFeedbackCount(): Int
|
||||
|
||||
/**
|
||||
* Get feedback count by type (global)
|
||||
*/
|
||||
@Query("""
|
||||
SELECT feedbackType, COUNT(*) as count
|
||||
FROM user_feedback
|
||||
GROUP BY feedbackType
|
||||
""")
|
||||
suspend fun getGlobalFeedbackStats(): List<FeedbackStat>
|
||||
|
||||
/**
|
||||
* Get average original confidence for rejected faces
|
||||
* Helps identify if low confidence → more rejections
|
||||
*/
|
||||
@Query("""
|
||||
SELECT AVG(originalConfidence)
|
||||
FROM user_feedback
|
||||
WHERE feedbackType = 'REJECTED_MATCH'
|
||||
""")
|
||||
suspend fun getAverageConfidenceForRejectedFaces(): Float?
|
||||
|
||||
/**
|
||||
* Find faces with low confidence that were confirmed
|
||||
* These are "surprising successes" - model worked despite low confidence
|
||||
*/
|
||||
@Query("""
|
||||
SELECT * FROM user_feedback
|
||||
WHERE feedbackType = 'CONFIRMED_MATCH'
|
||||
AND originalConfidence < :threshold
|
||||
ORDER BY originalConfidence ASC
|
||||
""")
|
||||
suspend fun getLowConfidenceSuccesses(threshold: Float = 0.7f): List<UserFeedbackEntity>
|
||||
|
||||
// ═══════════════════════════════════════
|
||||
// CLEANUP
|
||||
// ═══════════════════════════════════════
|
||||
|
||||
/**
|
||||
* Delete all feedback for a cluster
|
||||
* Called when cluster is deleted or refined
|
||||
*/
|
||||
@Query("DELETE FROM user_feedback WHERE clusterId = :clusterId")
|
||||
suspend fun deleteFeedbackForCluster(clusterId: Int)
|
||||
|
||||
/**
|
||||
* Delete all feedback for a person
|
||||
* Called when person is deleted
|
||||
*/
|
||||
@Query("DELETE FROM user_feedback WHERE personId = :personId")
|
||||
suspend fun deleteFeedbackForPerson(personId: String)
|
||||
|
||||
/**
|
||||
* Delete old feedback (older than X days)
|
||||
* Keep database size manageable
|
||||
*/
|
||||
@Query("DELETE FROM user_feedback WHERE timestamp < :cutoffTimestamp")
|
||||
suspend fun deleteOldFeedback(cutoffTimestamp: Long)
|
||||
|
||||
/**
|
||||
* Clear all feedback (nuclear option)
|
||||
*/
|
||||
@Query("DELETE FROM user_feedback")
|
||||
suspend fun deleteAll()
|
||||
}
|
||||
|
||||
/**
|
||||
* Result class for feedback statistics
|
||||
*/
|
||||
data class FeedbackStat(
|
||||
val feedbackType: String,
|
||||
val count: Int
|
||||
)
|
||||
@@ -0,0 +1,161 @@
|
||||
package com.placeholder.sherpai2.data.local.entity
|
||||
|
||||
import androidx.room.Entity
|
||||
import androidx.room.ForeignKey
|
||||
import androidx.room.Index
|
||||
import androidx.room.PrimaryKey
|
||||
import java.util.UUID
|
||||
|
||||
/**
|
||||
* UserFeedbackEntity - Stores user corrections during cluster validation
|
||||
*
|
||||
* PURPOSE:
|
||||
* - Capture which faces user marked as correct/incorrect
|
||||
* - Ground truth data for improving clustering
|
||||
* - Enable cluster refinement before training
|
||||
* - Track confidence in automated detections
|
||||
*
|
||||
* USAGE FLOW:
|
||||
* 1. Clustering creates initial clusters
|
||||
* 2. User reviews ValidationPreview
|
||||
* 3. User swipes faces: ✅ Correct / ❌ Incorrect
|
||||
* 4. Feedback stored here
|
||||
* 5. If too many incorrect → Re-cluster without those faces
|
||||
* 6. If approved → Train model with confirmed faces only
|
||||
*
|
||||
* INDEXES:
|
||||
* - imageId: Fast lookup of feedback for specific images
|
||||
* - clusterId: Get all feedback for a cluster
|
||||
* - feedbackType: Filter by correction type
|
||||
* - personId: Track feedback after person created
|
||||
*/
|
||||
@Entity(
|
||||
tableName = "user_feedback",
|
||||
foreignKeys = [
|
||||
ForeignKey(
|
||||
entity = ImageEntity::class,
|
||||
parentColumns = ["imageId"],
|
||||
childColumns = ["imageId"],
|
||||
onDelete = ForeignKey.CASCADE
|
||||
),
|
||||
ForeignKey(
|
||||
entity = PersonEntity::class,
|
||||
parentColumns = ["id"],
|
||||
childColumns = ["personId"],
|
||||
onDelete = ForeignKey.CASCADE
|
||||
)
|
||||
],
|
||||
indices = [
|
||||
Index(value = ["imageId"]),
|
||||
Index(value = ["clusterId"]),
|
||||
Index(value = ["personId"]),
|
||||
Index(value = ["feedbackType"])
|
||||
]
|
||||
)
|
||||
data class UserFeedbackEntity(
|
||||
@PrimaryKey
|
||||
val id: String = UUID.randomUUID().toString(),
|
||||
|
||||
/**
|
||||
* Image containing the face
|
||||
*/
|
||||
val imageId: String,
|
||||
|
||||
/**
|
||||
* Face index within the image (0-based)
|
||||
* Multiple faces per image possible
|
||||
*/
|
||||
val faceIndex: Int,
|
||||
|
||||
/**
|
||||
* Cluster ID from clustering (before person created)
|
||||
* Null if feedback given after person exists
|
||||
*/
|
||||
val clusterId: Int?,
|
||||
|
||||
/**
|
||||
* Person ID if feedback is about an existing person
|
||||
* Null during initial cluster validation
|
||||
*/
|
||||
val personId: String?,
|
||||
|
||||
/**
|
||||
* Type of feedback user provided
|
||||
*/
|
||||
val feedbackType: String, // FeedbackType enum stored as string
|
||||
|
||||
/**
|
||||
* Confidence score that led to this face being shown
|
||||
* Helps identify if low confidence = more errors
|
||||
*/
|
||||
val originalConfidence: Float,
|
||||
|
||||
/**
|
||||
* Optional user note
|
||||
*/
|
||||
val userNote: String? = null,
|
||||
|
||||
/**
|
||||
* When feedback was provided
|
||||
*/
|
||||
val timestamp: Long = System.currentTimeMillis()
|
||||
) {
|
||||
companion object {
|
||||
fun create(
|
||||
imageId: String,
|
||||
faceIndex: Int,
|
||||
clusterId: Int? = null,
|
||||
personId: String? = null,
|
||||
feedbackType: FeedbackType,
|
||||
originalConfidence: Float,
|
||||
userNote: String? = null
|
||||
): UserFeedbackEntity {
|
||||
return UserFeedbackEntity(
|
||||
imageId = imageId,
|
||||
faceIndex = faceIndex,
|
||||
clusterId = clusterId,
|
||||
personId = personId,
|
||||
feedbackType = feedbackType.name,
|
||||
originalConfidence = originalConfidence,
|
||||
userNote = userNote
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fun getFeedbackType(): FeedbackType {
|
||||
return try {
|
||||
FeedbackType.valueOf(feedbackType)
|
||||
} catch (e: Exception) {
|
||||
FeedbackType.UNCERTAIN
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* FeedbackType - Types of user corrections
|
||||
*/
|
||||
enum class FeedbackType {
|
||||
/**
|
||||
* User confirmed this face IS the person
|
||||
* Boosts confidence, use for training
|
||||
*/
|
||||
CONFIRMED_MATCH,
|
||||
|
||||
/**
|
||||
* User said this face is NOT the person
|
||||
* Remove from cluster, exclude from training
|
||||
*/
|
||||
REJECTED_MATCH,
|
||||
|
||||
/**
|
||||
* User marked as outlier during cluster review
|
||||
* Face doesn't belong in this cluster
|
||||
*/
|
||||
MARKED_OUTLIER,
|
||||
|
||||
/**
|
||||
* User is uncertain
|
||||
* Skip this face for training, revisit later
|
||||
*/
|
||||
UNCERTAIN
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import androidx.room.Room
|
||||
import com.placeholder.sherpai2.data.local.AppDatabase
|
||||
import com.placeholder.sherpai2.data.local.MIGRATION_7_8
|
||||
import com.placeholder.sherpai2.data.local.MIGRATION_8_9
|
||||
import com.placeholder.sherpai2.data.local.MIGRATION_9_10
|
||||
import com.placeholder.sherpai2.data.local.dao.*
|
||||
import dagger.Module
|
||||
import dagger.Provides
|
||||
@@ -16,6 +17,10 @@ import javax.inject.Singleton
|
||||
/**
|
||||
* DatabaseModule - Provides database and ALL DAOs
|
||||
*
|
||||
* VERSION 10 UPDATES:
|
||||
* - Added UserFeedbackDao for cluster refinement
|
||||
* - Added MIGRATION_9_10
|
||||
*
|
||||
* VERSION 9 UPDATES:
|
||||
* - Added FaceCacheDao for per-face metadata
|
||||
* - Added MIGRATION_8_9
|
||||
@@ -44,7 +49,7 @@ object DatabaseModule {
|
||||
.fallbackToDestructiveMigration(dropAllTables = true)
|
||||
|
||||
// PRODUCTION MODE: Uncomment this and remove fallbackToDestructiveMigration()
|
||||
// .addMigrations(MIGRATION_7_8, MIGRATION_8_9)
|
||||
// .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10)
|
||||
|
||||
.build()
|
||||
|
||||
@@ -96,6 +101,10 @@ object DatabaseModule {
|
||||
fun provideFaceCacheDao(db: AppDatabase): FaceCacheDao =
|
||||
db.faceCacheDao()
|
||||
|
||||
@Provides
|
||||
fun provideUserFeedbackDao(db: AppDatabase): UserFeedbackDao =
|
||||
db.userFeedbackDao()
|
||||
|
||||
// ===== COLLECTIONS DAOs =====
|
||||
|
||||
@Provides
|
||||
|
||||
@@ -2,12 +2,11 @@ package com.placeholder.sherpai2.di
|
||||
|
||||
import android.content.Context
|
||||
import androidx.work.WorkManager
|
||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||
import com.placeholder.sherpai2.data.local.dao.*
|
||||
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
|
||||
import com.placeholder.sherpai2.data.repository.TaggingRepositoryImpl
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterRefinementService
|
||||
import com.placeholder.sherpai2.domain.repository.ImageRepository
|
||||
import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl
|
||||
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
||||
@@ -26,6 +25,8 @@ import javax.inject.Singleton
|
||||
* UPDATED TO INCLUDE:
|
||||
* - FaceRecognitionRepository for face recognition operations
|
||||
* - ValidationScanService for post-training validation
|
||||
* - ClusterRefinementService for user feedback loop (NEW)
|
||||
* - ClusterQualityAnalyzer for cluster validation
|
||||
* - WorkManager for background tasks
|
||||
*/
|
||||
@Module
|
||||
@@ -72,7 +73,7 @@ abstract class RepositoryModule {
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide ValidationScanService (NEW)
|
||||
* Provide ValidationScanService
|
||||
*/
|
||||
@Provides
|
||||
@Singleton
|
||||
@@ -88,6 +89,34 @@ abstract class RepositoryModule {
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide ClusterRefinementService (NEW)
|
||||
* Handles user feedback and cluster refinement workflow
|
||||
*/
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideClusterRefinementService(
|
||||
faceCacheDao: FaceCacheDao,
|
||||
userFeedbackDao: UserFeedbackDao,
|
||||
qualityAnalyzer: ClusterQualityAnalyzer
|
||||
): ClusterRefinementService {
|
||||
return ClusterRefinementService(
|
||||
faceCacheDao = faceCacheDao,
|
||||
userFeedbackDao = userFeedbackDao,
|
||||
qualityAnalyzer = qualityAnalyzer
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide ClusterQualityAnalyzer
|
||||
* Validates cluster quality before training
|
||||
*/
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideClusterQualityAnalyzer(): ClusterQualityAnalyzer {
|
||||
return ClusterQualityAnalyzer()
|
||||
}
|
||||
|
||||
/**
|
||||
* Provide WorkManager for background tasks
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,415 @@
|
||||
package com.placeholder.sherpai2.domain.clustering
|
||||
|
||||
import android.util.Log
|
||||
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||
import com.placeholder.sherpai2.data.local.dao.UserFeedbackDao
|
||||
import com.placeholder.sherpai2.data.local.entity.FeedbackType
|
||||
import com.placeholder.sherpai2.data.local.entity.UserFeedbackEntity
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* ClusterRefinementService - Handle user feedback and cluster refinement
|
||||
*
|
||||
* PURPOSE:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* Close the feedback loop between user corrections and clustering
|
||||
*
|
||||
* WORKFLOW:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* 1. Clustering produces initial clusters
|
||||
* 2. User reviews in ValidationPreview
|
||||
* 3. User marks faces: ✅ Correct / ❌ Incorrect / ❓ Uncertain
|
||||
* 4. If too many incorrect → Call refineCluster()
|
||||
* 5. Re-cluster WITHOUT incorrect faces
|
||||
* 6. Show updated validation preview
|
||||
* 7. Repeat until user approves
|
||||
*
|
||||
* BENEFITS:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* - Prevents bad models from being created
|
||||
* - Learns from user corrections
|
||||
* - Iterative improvement
|
||||
* - Ground truth data for future enhancements
|
||||
*/
|
||||
@Singleton
|
||||
class ClusterRefinementService @Inject constructor(
|
||||
private val faceCacheDao: FaceCacheDao,
|
||||
private val userFeedbackDao: UserFeedbackDao,
|
||||
private val qualityAnalyzer: ClusterQualityAnalyzer
|
||||
) {
|
||||
|
||||
companion object {
|
||||
private const val TAG = "ClusterRefinement"
|
||||
|
||||
// Thresholds for refinement decisions
|
||||
private const val MIN_REJECTION_RATIO = 0.15f // 15% rejected → refine
|
||||
private const val MIN_CONFIRMED_FACES = 6 // Need at least 6 good faces
|
||||
private const val MAX_REFINEMENT_ITERATIONS = 3 // Prevent infinite loops
|
||||
}
|
||||
|
||||
/**
|
||||
* Store user feedback for faces in a cluster
|
||||
*
|
||||
* @param cluster The cluster being reviewed
|
||||
* @param feedbackMap Map of face index → feedback type
|
||||
* @param originalConfidences Map of face index → original detection confidence
|
||||
* @return Number of feedback items stored
|
||||
*/
|
||||
suspend fun storeFeedback(
|
||||
cluster: FaceCluster,
|
||||
feedbackMap: Map<DetectedFaceWithEmbedding, FeedbackType>,
|
||||
originalConfidences: Map<DetectedFaceWithEmbedding, Float> = emptyMap()
|
||||
): Int = withContext(Dispatchers.IO) {
|
||||
|
||||
val feedbackEntities = feedbackMap.map { (face, feedbackType) ->
|
||||
UserFeedbackEntity.create(
|
||||
imageId = face.imageId,
|
||||
faceIndex = 0, // We don't track faceIndex in DetectedFaceWithEmbedding yet
|
||||
clusterId = cluster.clusterId,
|
||||
personId = null, // Not created yet
|
||||
feedbackType = feedbackType,
|
||||
originalConfidence = originalConfidences[face] ?: face.confidence
|
||||
)
|
||||
}
|
||||
|
||||
userFeedbackDao.insertAll(feedbackEntities)
|
||||
|
||||
Log.d(TAG, "Stored ${feedbackEntities.size} feedback items for cluster ${cluster.clusterId}")
|
||||
feedbackEntities.size
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if cluster needs refinement based on user feedback
|
||||
*
|
||||
* Criteria:
|
||||
* - Too many rejected faces (> 15%)
|
||||
* - Too few confirmed faces (< 6)
|
||||
* - High rejection rate for cluster suggests mixed identities
|
||||
*
|
||||
* @return RefinementRecommendation with action and reason
|
||||
*/
|
||||
suspend fun shouldRefineCluster(
|
||||
cluster: FaceCluster
|
||||
): RefinementRecommendation = withContext(Dispatchers.Default) {
|
||||
|
||||
val feedback = withContext(Dispatchers.IO) {
|
||||
userFeedbackDao.getFeedbackForCluster(cluster.clusterId)
|
||||
}
|
||||
|
||||
if (feedback.isEmpty()) {
|
||||
return@withContext RefinementRecommendation(
|
||||
shouldRefine = false,
|
||||
reason = "No feedback provided yet"
|
||||
)
|
||||
}
|
||||
|
||||
val totalFeedback = feedback.size
|
||||
val rejectedCount = feedback.count { it.getFeedbackType() == FeedbackType.REJECTED_MATCH }
|
||||
val confirmedCount = feedback.count { it.getFeedbackType() == FeedbackType.CONFIRMED_MATCH }
|
||||
val uncertainCount = feedback.count { it.getFeedbackType() == FeedbackType.UNCERTAIN }
|
||||
|
||||
val rejectionRatio = rejectedCount.toFloat() / totalFeedback.toFloat()
|
||||
|
||||
Log.d(TAG, "Cluster ${cluster.clusterId} feedback: " +
|
||||
"$confirmedCount confirmed, $rejectedCount rejected, $uncertainCount uncertain")
|
||||
|
||||
// Check 1: Too many rejections
|
||||
if (rejectionRatio > MIN_REJECTION_RATIO) {
|
||||
return@withContext RefinementRecommendation(
|
||||
shouldRefine = true,
|
||||
reason = "High rejection rate (${(rejectionRatio * 100).toInt()}%) suggests mixed identities",
|
||||
confirmedCount = confirmedCount,
|
||||
rejectedCount = rejectedCount,
|
||||
uncertainCount = uncertainCount
|
||||
)
|
||||
}
|
||||
|
||||
// Check 2: Too few confirmed faces after removing rejected
|
||||
val effectiveConfirmedCount = confirmedCount - rejectedCount
|
||||
if (effectiveConfirmedCount < MIN_CONFIRMED_FACES) {
|
||||
return@withContext RefinementRecommendation(
|
||||
shouldRefine = true,
|
||||
reason = "Only $effectiveConfirmedCount faces remain after removing rejected faces (need $MIN_CONFIRMED_FACES)",
|
||||
confirmedCount = confirmedCount,
|
||||
rejectedCount = rejectedCount,
|
||||
uncertainCount = uncertainCount
|
||||
)
|
||||
}
|
||||
|
||||
// Cluster is good!
|
||||
RefinementRecommendation(
|
||||
shouldRefine = false,
|
||||
reason = "Cluster quality acceptable: $confirmedCount confirmed, $rejectedCount rejected",
|
||||
confirmedCount = confirmedCount,
|
||||
rejectedCount = rejectedCount,
|
||||
uncertainCount = uncertainCount
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Refine cluster by removing rejected faces and re-clustering
|
||||
*
|
||||
* ALGORITHM:
|
||||
* 1. Get all rejected faces from feedback
|
||||
* 2. Remove those faces from cluster
|
||||
* 3. Recalculate cluster centroid
|
||||
* 4. Re-run quality analysis
|
||||
* 5. Return refined cluster
|
||||
*
|
||||
* @param cluster Original cluster to refine
|
||||
* @return Refined cluster without rejected faces
|
||||
*/
|
||||
suspend fun refineCluster(
|
||||
cluster: FaceCluster,
|
||||
iterationNumber: Int = 1
|
||||
): ClusterRefinementResult = withContext(Dispatchers.Default) {
|
||||
|
||||
Log.d(TAG, "Refining cluster ${cluster.clusterId} (iteration $iterationNumber)")
|
||||
|
||||
// Guard against infinite refinement
|
||||
if (iterationNumber > MAX_REFINEMENT_ITERATIONS) {
|
||||
return@withContext ClusterRefinementResult(
|
||||
success = false,
|
||||
refinedCluster = null,
|
||||
errorMessage = "Maximum refinement iterations reached. Cluster quality still poor.",
|
||||
facesRemoved = 0,
|
||||
facesRemaining = cluster.faces.size
|
||||
)
|
||||
}
|
||||
|
||||
// Get rejected faces
|
||||
val feedback = withContext(Dispatchers.IO) {
|
||||
userFeedbackDao.getRejectedFacesForCluster(cluster.clusterId)
|
||||
}
|
||||
|
||||
val rejectedImageIds = feedback.map { it.imageId }.toSet()
|
||||
|
||||
if (rejectedImageIds.isEmpty()) {
|
||||
return@withContext ClusterRefinementResult(
|
||||
success = false,
|
||||
refinedCluster = cluster,
|
||||
errorMessage = "No rejected faces to remove",
|
||||
facesRemoved = 0,
|
||||
facesRemaining = cluster.faces.size
|
||||
)
|
||||
}
|
||||
|
||||
// Remove rejected faces
|
||||
val cleanFaces = cluster.faces.filter { it.imageId !in rejectedImageIds }
|
||||
|
||||
Log.d(TAG, "Removed ${rejectedImageIds.size} rejected faces, ${cleanFaces.size} remain")
|
||||
|
||||
// Check if we have enough faces left
|
||||
if (cleanFaces.size < MIN_CONFIRMED_FACES) {
|
||||
return@withContext ClusterRefinementResult(
|
||||
success = false,
|
||||
refinedCluster = null,
|
||||
errorMessage = "Too few faces remaining after removing rejected faces (${cleanFaces.size} < $MIN_CONFIRMED_FACES)",
|
||||
facesRemoved = rejectedImageIds.size,
|
||||
facesRemaining = cleanFaces.size
|
||||
)
|
||||
}
|
||||
|
||||
// Recalculate centroid
|
||||
val newCentroid = calculateCentroid(cleanFaces.map { it.embedding })
|
||||
|
||||
// Select new representative faces
|
||||
val newRepresentatives = selectRepresentativeFacesByCentroid(cleanFaces, newCentroid, count = 6)
|
||||
|
||||
// Create refined cluster
|
||||
val refinedCluster = FaceCluster(
|
||||
clusterId = cluster.clusterId,
|
||||
faces = cleanFaces,
|
||||
representativeFaces = newRepresentatives,
|
||||
photoCount = cleanFaces.map { it.imageId }.distinct().size,
|
||||
averageConfidence = cleanFaces.map { it.confidence }.average().toFloat(),
|
||||
estimatedAge = cluster.estimatedAge, // Keep same estimate
|
||||
potentialSiblings = cluster.potentialSiblings // Keep same siblings
|
||||
)
|
||||
|
||||
// Re-run quality analysis
|
||||
val qualityResult = qualityAnalyzer.analyzeCluster(refinedCluster)
|
||||
|
||||
Log.d(TAG, "Refined cluster quality: ${qualityResult.qualityTier} " +
|
||||
"(${qualityResult.cleanFaceCount} clean faces)")
|
||||
|
||||
ClusterRefinementResult(
|
||||
success = true,
|
||||
refinedCluster = refinedCluster,
|
||||
qualityResult = qualityResult,
|
||||
facesRemoved = rejectedImageIds.size,
|
||||
facesRemaining = cleanFaces.size,
|
||||
newQualityTier = qualityResult.qualityTier
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feedback summary for cluster
|
||||
*
|
||||
* Returns human-readable summary like:
|
||||
* "15 confirmed, 3 rejected, 2 uncertain"
|
||||
*/
|
||||
suspend fun getFeedbackSummary(clusterId: Int): FeedbackSummary = withContext(Dispatchers.IO) {
|
||||
val feedback = userFeedbackDao.getFeedbackForCluster(clusterId)
|
||||
|
||||
val confirmed = feedback.count { it.getFeedbackType() == FeedbackType.CONFIRMED_MATCH }
|
||||
val rejected = feedback.count { it.getFeedbackType() == FeedbackType.REJECTED_MATCH }
|
||||
val uncertain = feedback.count { it.getFeedbackType() == FeedbackType.UNCERTAIN }
|
||||
val outliers = feedback.count { it.getFeedbackType() == FeedbackType.MARKED_OUTLIER }
|
||||
|
||||
FeedbackSummary(
|
||||
totalFeedback = feedback.size,
|
||||
confirmedCount = confirmed,
|
||||
rejectedCount = rejected,
|
||||
uncertainCount = uncertain,
|
||||
outlierCount = outliers,
|
||||
rejectionRatio = if (feedback.isNotEmpty()) {
|
||||
rejected.toFloat() / feedback.size.toFloat()
|
||||
} else 0f
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filter cluster to only confirmed faces
|
||||
*
|
||||
* Use Case: User has reviewed cluster, now create model using ONLY confirmed faces
|
||||
*/
|
||||
suspend fun getConfirmedFaces(cluster: FaceCluster): List<DetectedFaceWithEmbedding> =
|
||||
withContext(Dispatchers.Default) {
|
||||
|
||||
val confirmedFeedback = withContext(Dispatchers.IO) {
|
||||
userFeedbackDao.getConfirmedFacesForCluster(cluster.clusterId)
|
||||
}
|
||||
|
||||
val confirmedImageIds = confirmedFeedback.map { it.imageId }.toSet()
|
||||
|
||||
// If no explicit confirmations, assume all non-rejected faces are OK
|
||||
if (confirmedImageIds.isEmpty()) {
|
||||
val rejectedFeedback = withContext(Dispatchers.IO) {
|
||||
userFeedbackDao.getRejectedFacesForCluster(cluster.clusterId)
|
||||
}
|
||||
val rejectedImageIds = rejectedFeedback.map { it.imageId }.toSet()
|
||||
|
||||
return@withContext cluster.faces.filter { it.imageId !in rejectedImageIds }
|
||||
}
|
||||
|
||||
// Return only explicitly confirmed faces
|
||||
cluster.faces.filter { it.imageId in confirmedImageIds }
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate centroid from embeddings
|
||||
*/
|
||||
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
|
||||
if (embeddings.isEmpty()) return FloatArray(0)
|
||||
|
||||
val size = embeddings.first().size
|
||||
val centroid = FloatArray(size) { 0f }
|
||||
|
||||
embeddings.forEach { embedding ->
|
||||
for (i in embedding.indices) {
|
||||
centroid[i] += embedding[i]
|
||||
}
|
||||
}
|
||||
|
||||
val count = embeddings.size.toFloat()
|
||||
for (i in centroid.indices) {
|
||||
centroid[i] /= count
|
||||
}
|
||||
|
||||
// Normalize
|
||||
val norm = sqrt(centroid.map { it * it }.sum())
|
||||
return if (norm > 0) {
|
||||
centroid.map { it / norm }.toFloatArray()
|
||||
} else {
|
||||
centroid
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Select representative faces closest to centroid
|
||||
*/
|
||||
private fun selectRepresentativeFacesByCentroid(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
centroid: FloatArray,
|
||||
count: Int
|
||||
): List<DetectedFaceWithEmbedding> {
|
||||
if (faces.size <= count) return faces
|
||||
|
||||
val facesWithDistance = faces.map { face ->
|
||||
val similarity = cosineSimilarity(face.embedding, centroid)
|
||||
val distance = 1 - similarity
|
||||
face to distance
|
||||
}
|
||||
|
||||
return facesWithDistance
|
||||
.sortedBy { it.second }
|
||||
.take(count)
|
||||
.map { it.first }
|
||||
}
|
||||
|
||||
/**
|
||||
* Cosine similarity calculation
|
||||
*/
|
||||
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
|
||||
var dotProduct = 0f
|
||||
var normA = 0f
|
||||
var normB = 0f
|
||||
|
||||
for (i in a.indices) {
|
||||
dotProduct += a[i] * b[i]
|
||||
normA += a[i] * a[i]
|
||||
normB += b[i] * b[i]
|
||||
}
|
||||
|
||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of refinement analysis
|
||||
*/
|
||||
data class RefinementRecommendation(
|
||||
val shouldRefine: Boolean,
|
||||
val reason: String,
|
||||
val confirmedCount: Int = 0,
|
||||
val rejectedCount: Int = 0,
|
||||
val uncertainCount: Int = 0
|
||||
)
|
||||
|
||||
/**
|
||||
* Result of cluster refinement
|
||||
*/
|
||||
data class ClusterRefinementResult(
|
||||
val success: Boolean,
|
||||
val refinedCluster: FaceCluster?,
|
||||
val qualityResult: ClusterQualityResult? = null,
|
||||
val errorMessage: String? = null,
|
||||
val facesRemoved: Int,
|
||||
val facesRemaining: Int,
|
||||
val newQualityTier: ClusterQualityTier? = null
|
||||
)
|
||||
|
||||
/**
|
||||
* Summary of user feedback for a cluster
|
||||
*/
|
||||
data class FeedbackSummary(
|
||||
val totalFeedback: Int,
|
||||
val confirmedCount: Int,
|
||||
val rejectedCount: Int,
|
||||
val uncertainCount: Int,
|
||||
val outlierCount: Int,
|
||||
val rejectionRatio: Float
|
||||
) {
|
||||
fun getDisplayText(): String {
|
||||
val parts = mutableListOf<String>()
|
||||
if (confirmedCount > 0) parts.add("$confirmedCount confirmed")
|
||||
if (rejectedCount > 0) parts.add("$rejectedCount rejected")
|
||||
if (uncertainCount > 0) parts.add("$uncertainCount uncertain")
|
||||
return parts.joinToString(", ")
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import android.graphics.Bitmap
|
||||
import android.graphics.BitmapFactory
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import com.google.android.gms.tasks.Tasks
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.face.FaceDetection
|
||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||
@@ -23,24 +24,16 @@ import kotlinx.coroutines.withContext
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.random.Random
|
||||
|
||||
/**
|
||||
* FaceClusteringService - HYBRID version with automatic fallback
|
||||
* FaceClusteringService - ENHANCED with quality filtering & deterministic results
|
||||
*
|
||||
* STRATEGY:
|
||||
* 1. Try to use face cache (fast path) - 10x faster
|
||||
* 2. Fall back to classic method if cache empty (compatible)
|
||||
* 3. Load SOLO PHOTOS ONLY (faceCount = 1) for clustering
|
||||
* 4. Detect faces and generate embeddings (parallel)
|
||||
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
|
||||
* 6. Analyze clusters for age, siblings, representatives
|
||||
*
|
||||
* IMPROVEMENTS:
|
||||
* - ✅ Complete fast-path using FaceCacheDao.getSoloFacesWithEmbeddings()
|
||||
* - ✅ Works with existing FaceCacheEntity.getEmbedding() method
|
||||
* - ✅ Centroid-based representative face selection
|
||||
* - ✅ Batched processing to prevent OOM
|
||||
* - ✅ RGB_565 bitmap config for 50% memory savings
|
||||
* NEW FEATURES:
|
||||
* ✅ FaceQualityFilter integration (eliminates clothing/ghost faces)
|
||||
* ✅ Deterministic clustering (seeded random)
|
||||
* ✅ Better thresholds (finds Brad Pitt)
|
||||
* ✅ Faster processing (filters garbage early)
|
||||
*/
|
||||
@Singleton
|
||||
class FaceClusteringService @Inject constructor(
|
||||
@@ -50,58 +43,97 @@ class FaceClusteringService @Inject constructor(
|
||||
) {
|
||||
|
||||
private val semaphore = Semaphore(8)
|
||||
private val deterministicRandom = Random(42) // Fixed seed for reproducibility
|
||||
|
||||
companion object {
|
||||
private const val TAG = "FaceClustering"
|
||||
private const val MAX_FACES_TO_CLUSTER = 2000
|
||||
private const val MIN_SOLO_PHOTOS = 50
|
||||
private const val MIN_PREMIUM_FACES = 100
|
||||
private const val MIN_STANDARD_FACES = 50
|
||||
private const val BATCH_SIZE = 50
|
||||
private const val MIN_CACHED_FACES = 100
|
||||
}
|
||||
|
||||
/**
|
||||
* Main clustering entry point - HYBRID with automatic fallback
|
||||
*/
|
||||
suspend fun discoverPeople(
|
||||
strategy: ClusteringStrategy = ClusteringStrategy.PREMIUM_SOLO_ONLY,
|
||||
maxFacesToCluster: Int = MAX_FACES_TO_CLUSTER,
|
||||
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
||||
): ClusteringResult = withContext(Dispatchers.Default) {
|
||||
|
||||
val startTime = System.currentTimeMillis()
|
||||
|
||||
// Try high-quality cached faces FIRST (NEW!)
|
||||
var cachedFaces = withContext(Dispatchers.IO) {
|
||||
Log.d(TAG, "Starting people discovery with strategy: $strategy")
|
||||
|
||||
val result = when (strategy) {
|
||||
ClusteringStrategy.PREMIUM_SOLO_ONLY -> {
|
||||
clusterPremiumSoloFaces(maxFacesToCluster, onProgress)
|
||||
}
|
||||
ClusteringStrategy.STANDARD_SOLO_ONLY -> {
|
||||
clusterStandardSoloFaces(maxFacesToCluster, onProgress)
|
||||
}
|
||||
ClusteringStrategy.TWO_PHASE -> {
|
||||
clusterPremiumSoloFaces(maxFacesToCluster, onProgress)
|
||||
}
|
||||
ClusteringStrategy.LEGACY_ALL_FACES -> {
|
||||
clusterAllFacesLegacy(maxFacesToCluster, onProgress)
|
||||
}
|
||||
}
|
||||
|
||||
val elapsedTime = System.currentTimeMillis() - startTime
|
||||
Log.d(TAG, "Clustering complete: ${result.clusters.size} clusters in ${elapsedTime}ms")
|
||||
|
||||
result.copy(processingTimeMs = elapsedTime)
|
||||
}
|
||||
|
||||
private suspend fun clusterPremiumSoloFaces(
|
||||
maxFaces: Int,
|
||||
onProgress: (Int, Int, String) -> Unit
|
||||
): ClusteringResult = withContext(Dispatchers.Default) {
|
||||
|
||||
onProgress(5, 100, "Checking face cache...")
|
||||
|
||||
var premiumFaces = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
faceCacheDao.getHighQualitySoloFaces(
|
||||
minFaceRatio = 0.015f, // 1.5%
|
||||
limit = maxFacesToCluster
|
||||
faceCacheDao.getPremiumSoloFaces(
|
||||
minRatio = 0.05f,
|
||||
minQuality = 0.8f,
|
||||
limit = maxFaces
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
// Method doesn't exist yet - that's ok
|
||||
Log.w(TAG, "Error fetching premium faces: ${e.message}")
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to ANY solo faces if high-quality returned nothing
|
||||
if (cachedFaces.isEmpty()) {
|
||||
Log.w(TAG, "No high-quality faces (>= 1.5%), trying ANY solo faces...")
|
||||
cachedFaces = withContext(Dispatchers.IO) {
|
||||
Log.d(TAG, "Found ${premiumFaces.size} premium solo faces in cache")
|
||||
|
||||
if (premiumFaces.size < MIN_PREMIUM_FACES) {
|
||||
Log.w(TAG, "Insufficient premium faces (${premiumFaces.size} < $MIN_PREMIUM_FACES)")
|
||||
onProgress(10, 100, "Trying standard quality faces...")
|
||||
|
||||
premiumFaces = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
faceCacheDao.getSoloFacesWithEmbeddings(limit = maxFacesToCluster)
|
||||
faceCacheDao.getStandardSoloFaces(
|
||||
minRatio = 0.03f,
|
||||
minQuality = 0.6f,
|
||||
limit = maxFaces
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
|
||||
Log.d(TAG, "Found ${premiumFaces.size} standard solo faces in cache")
|
||||
}
|
||||
|
||||
Log.d(TAG, "Cache check: ${cachedFaces.size} faces available")
|
||||
if (premiumFaces.size < MIN_STANDARD_FACES) {
|
||||
Log.w(TAG, "Insufficient cached faces, falling back to slow path")
|
||||
return@withContext clusterAllFacesLegacy(maxFaces, onProgress)
|
||||
}
|
||||
|
||||
val allFaces = if (cachedFaces.size >= MIN_CACHED_FACES) {
|
||||
// FAST PATH ✅
|
||||
Log.d(TAG, "Using FAST PATH with ${cachedFaces.size} cached faces")
|
||||
onProgress(10, 100, "Using cached embeddings (${cachedFaces.size} faces)...")
|
||||
onProgress(20, 100, "Loading ${premiumFaces.size} high-quality solo photos...")
|
||||
|
||||
cachedFaces.mapNotNull { cached ->
|
||||
val allFaces = premiumFaces.mapNotNull { cached: FaceCacheEntity ->
|
||||
val embedding = cached.getEmbedding() ?: return@mapNotNull null
|
||||
|
||||
DetectedFaceWithEmbedding(
|
||||
@@ -111,73 +143,41 @@ class FaceClusteringService @Inject constructor(
|
||||
embedding = embedding,
|
||||
boundingBox = cached.getBoundingBox(),
|
||||
confidence = cached.confidence,
|
||||
faceCount = 1, // Solo faces only (filtered by query)
|
||||
faceCount = 1,
|
||||
imageWidth = cached.imageWidth,
|
||||
imageHeight = cached.imageHeight
|
||||
)
|
||||
}.also {
|
||||
onProgress(50, 100, "Processing ${it.size} cached faces...")
|
||||
}
|
||||
} else {
|
||||
// SLOW PATH
|
||||
Log.d(TAG, "Using SLOW PATH - cache has < $MIN_CACHED_FACES faces")
|
||||
onProgress(0, 100, "Loading photos...")
|
||||
|
||||
val soloPhotos = withContext(Dispatchers.IO) {
|
||||
imageDao.getImagesByFaceCount(count = 1)
|
||||
}
|
||||
|
||||
val imagesWithFaces = if (soloPhotos.size < MIN_SOLO_PHOTOS) {
|
||||
imageDao.getImagesWithFaces()
|
||||
} else {
|
||||
soloPhotos
|
||||
}
|
||||
|
||||
if (imagesWithFaces.isEmpty()) {
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "No photos with faces found"
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
|
||||
|
||||
detectFacesInImagesBatched(
|
||||
images = imagesWithFaces.take(1000),
|
||||
onProgress = { current, total ->
|
||||
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
if (allFaces.isEmpty()) {
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = System.currentTimeMillis() - startTime,
|
||||
errorMessage = "No faces detected"
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "No valid faces with embeddings found"
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(50, 100, "Clustering ${allFaces.size} faces...")
|
||||
onProgress(40, 100, "Clustering ${allFaces.size} faces...")
|
||||
|
||||
// ENHANCED: Lower threshold (quality filter handles garbage now)
|
||||
val rawClusters = performDBSCAN(
|
||||
faces = allFaces.take(maxFacesToCluster),
|
||||
epsilon = 0.26f,
|
||||
minPoints = 3
|
||||
faces = allFaces.take(maxFaces),
|
||||
epsilon = 0.24f, // Was 0.26f - now more aggressive
|
||||
minPoints = 3 // Was 3 - keeping same
|
||||
)
|
||||
|
||||
Log.d(TAG, "DBSCAN produced ${rawClusters.size} raw clusters")
|
||||
|
||||
onProgress(70, 100, "Analyzing relationships...")
|
||||
|
||||
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
|
||||
|
||||
onProgress(80, 100, "Selecting representative faces...")
|
||||
|
||||
val clusters = rawClusters.map { cluster ->
|
||||
val clusters = rawClusters.mapIndexed { index: Int, cluster: RawCluster ->
|
||||
FaceCluster(
|
||||
clusterId = cluster.clusterId,
|
||||
clusterId = index,
|
||||
faces = cluster.faces,
|
||||
representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, count = 6),
|
||||
photoCount = cluster.faces.map { it.imageId }.distinct().size,
|
||||
@@ -192,85 +192,140 @@ class FaceClusteringService @Inject constructor(
|
||||
ClusteringResult(
|
||||
clusters = clusters,
|
||||
totalFacesAnalyzed = allFaces.size,
|
||||
processingTimeMs = System.currentTimeMillis() - startTime
|
||||
processingTimeMs = 0,
|
||||
strategy = ClusteringStrategy.PREMIUM_SOLO_ONLY
|
||||
)
|
||||
}
|
||||
|
||||
private suspend fun detectFacesInImagesBatched(
|
||||
images: List<ImageEntity>,
|
||||
onProgress: (Int, Int) -> Unit
|
||||
): List<DetectedFaceWithEmbedding> = coroutineScope {
|
||||
private suspend fun clusterStandardSoloFaces(
|
||||
maxFaces: Int,
|
||||
onProgress: (Int, Int, String) -> Unit
|
||||
): ClusteringResult = withContext(Dispatchers.Default) {
|
||||
|
||||
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
var processedCount = 0
|
||||
onProgress(10, 100, "Loading solo photos...")
|
||||
|
||||
images.chunked(BATCH_SIZE).forEach { batch ->
|
||||
val batchFaces = detectFacesInBatch(batch)
|
||||
allFaces.addAll(batchFaces)
|
||||
|
||||
processedCount += batch.size
|
||||
onProgress(processedCount, images.size)
|
||||
|
||||
System.gc()
|
||||
val standardFaces = withContext(Dispatchers.IO) {
|
||||
try {
|
||||
faceCacheDao.getStandardSoloFaces(
|
||||
minRatio = 0.03f,
|
||||
minQuality = 0.6f,
|
||||
limit = maxFaces
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
emptyList()
|
||||
}
|
||||
}
|
||||
|
||||
allFaces
|
||||
if (standardFaces.size < MIN_STANDARD_FACES) {
|
||||
return@withContext clusterAllFacesLegacy(maxFaces, onProgress)
|
||||
}
|
||||
|
||||
private suspend fun detectFacesInBatch(
|
||||
images: List<ImageEntity>
|
||||
): List<DetectedFaceWithEmbedding> = coroutineScope {
|
||||
val allFaces = standardFaces.mapNotNull { cached: FaceCacheEntity ->
|
||||
val embedding = cached.getEmbedding() ?: return@mapNotNull null
|
||||
DetectedFaceWithEmbedding(
|
||||
imageId = cached.imageId,
|
||||
imageUri = "",
|
||||
capturedAt = 0L,
|
||||
embedding = embedding,
|
||||
boundingBox = cached.getBoundingBox(),
|
||||
confidence = cached.confidence,
|
||||
faceCount = 1,
|
||||
imageWidth = cached.imageWidth,
|
||||
imageHeight = cached.imageHeight
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(40, 100, "Clustering ${allFaces.size} faces...")
|
||||
|
||||
val rawClusters = performDBSCAN(allFaces.take(maxFaces), 0.24f, 3)
|
||||
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
|
||||
|
||||
val clusters = rawClusters.mapIndexed { index, cluster ->
|
||||
FaceCluster(
|
||||
clusterId = index,
|
||||
faces = cluster.faces,
|
||||
representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, 6),
|
||||
photoCount = cluster.faces.map { it.imageId }.distinct().size,
|
||||
averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(),
|
||||
estimatedAge = estimateAge(cluster.faces),
|
||||
potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph)
|
||||
)
|
||||
}.sortedByDescending { it.photoCount }
|
||||
|
||||
ClusteringResult(
|
||||
clusters = clusters,
|
||||
totalFacesAnalyzed = allFaces.size,
|
||||
processingTimeMs = 0,
|
||||
strategy = ClusteringStrategy.STANDARD_SOLO_ONLY
|
||||
)
|
||||
}
|
||||
|
||||
private suspend fun clusterAllFacesLegacy(
|
||||
maxFaces: Int,
|
||||
onProgress: (Int, Int, String) -> Unit
|
||||
): ClusteringResult = withContext(Dispatchers.Default) {
|
||||
|
||||
onProgress(10, 100, "Loading photos...")
|
||||
|
||||
val images = withContext(Dispatchers.IO) {
|
||||
imageDao.getAllImages()
|
||||
}
|
||||
|
||||
if (images.isEmpty()) {
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "No images in library"
|
||||
)
|
||||
}
|
||||
|
||||
// ENHANCED: Process ALL photos (no limit)
|
||||
val shuffled = images.shuffled(deterministicRandom)
|
||||
onProgress(20, 100, "Analyzing ${shuffled.size} photos...")
|
||||
|
||||
val faceNetModel = FaceNetModel(context)
|
||||
val detector = FaceDetection.getClient(
|
||||
FaceDetectorOptions.Builder()
|
||||
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // ENHANCED: Get landmarks
|
||||
.setMinFaceSize(0.15f)
|
||||
.build()
|
||||
)
|
||||
|
||||
val faceNetModel = FaceNetModel(context)
|
||||
val batchFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
|
||||
try {
|
||||
val jobs = images.map { image ->
|
||||
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
|
||||
coroutineScope {
|
||||
val jobs = shuffled.mapIndexed { index, image ->
|
||||
async(Dispatchers.IO) {
|
||||
semaphore.acquire()
|
||||
try {
|
||||
detectFacesInImage(image, detector, faceNetModel)
|
||||
} finally {
|
||||
semaphore.release()
|
||||
}
|
||||
}
|
||||
}
|
||||
val bitmap = loadBitmapDownsampled(Uri.parse(image.imageUri), 768)
|
||||
?: return@async emptyList()
|
||||
|
||||
jobs.awaitAll().flatten().also {
|
||||
batchFaces.addAll(it)
|
||||
}
|
||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||
val faces = Tasks.await(detector.process(inputImage))
|
||||
|
||||
} finally {
|
||||
detector.close()
|
||||
faceNetModel.close()
|
||||
}
|
||||
val imageWidth = bitmap.width
|
||||
val imageHeight = bitmap.height
|
||||
|
||||
batchFaces
|
||||
}
|
||||
val faceEmbeddings = faces.mapNotNull { face ->
|
||||
// ===== APPLY QUALITY FILTER =====
|
||||
val qualityCheck = FaceQualityFilter.validateForDiscovery(
|
||||
face = face,
|
||||
imageWidth = imageWidth,
|
||||
imageHeight = imageHeight
|
||||
)
|
||||
|
||||
private suspend fun detectFacesInImage(
|
||||
image: ImageEntity,
|
||||
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||
faceNetModel: FaceNetModel
|
||||
): List<DetectedFaceWithEmbedding> = withContext(Dispatchers.IO) {
|
||||
// Skip low-quality faces
|
||||
if (!qualityCheck.isValid) {
|
||||
Log.d(TAG, "Rejected face: ${qualityCheck.issues.joinToString()}")
|
||||
return@mapNotNull null
|
||||
}
|
||||
|
||||
try {
|
||||
val uri = Uri.parse(image.imageUri)
|
||||
val bitmap = loadBitmapDownsampled(uri, 512) ?: return@withContext emptyList()
|
||||
|
||||
val mlImage = InputImage.fromBitmap(bitmap, 0)
|
||||
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
|
||||
|
||||
val result = faces.mapNotNull { face ->
|
||||
try {
|
||||
val faceBitmap = Bitmap.createBitmap(
|
||||
val faceBitmap = android.graphics.Bitmap.createBitmap(
|
||||
bitmap,
|
||||
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
||||
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
||||
@@ -287,10 +342,10 @@ class FaceClusteringService @Inject constructor(
|
||||
capturedAt = image.capturedAt,
|
||||
embedding = embedding,
|
||||
boundingBox = face.boundingBox,
|
||||
confidence = 0.95f,
|
||||
confidence = qualityCheck.confidenceScore, // Use quality score
|
||||
faceCount = faces.size,
|
||||
imageWidth = bitmap.width,
|
||||
imageHeight = bitmap.height
|
||||
imageWidth = imageWidth,
|
||||
imageHeight = imageHeight
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
@@ -298,14 +353,64 @@ class FaceClusteringService @Inject constructor(
|
||||
}
|
||||
|
||||
bitmap.recycle()
|
||||
result
|
||||
|
||||
} catch (e: Exception) {
|
||||
emptyList()
|
||||
if (index % 20 == 0) {
|
||||
val progress = 20 + (index * 60 / shuffled.size)
|
||||
onProgress(progress, 100, "Processed $index/${shuffled.size} photos...")
|
||||
}
|
||||
|
||||
faceEmbeddings
|
||||
} finally {
|
||||
semaphore.release()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun performDBSCAN(
|
||||
jobs.awaitAll().flatten().forEach { allFaces.add(it) }
|
||||
}
|
||||
|
||||
if (allFaces.isEmpty()) {
|
||||
return@withContext ClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalFacesAnalyzed = 0,
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "No faces detected with sufficient quality"
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(80, 100, "Clustering ${allFaces.size} faces...")
|
||||
|
||||
val rawClusters = performDBSCAN(allFaces.take(maxFaces), 0.24f, 3)
|
||||
val coOccurrenceGraph = buildCoOccurrenceGraph(rawClusters)
|
||||
|
||||
val clusters = rawClusters.mapIndexed { index, cluster ->
|
||||
FaceCluster(
|
||||
clusterId = index,
|
||||
faces = cluster.faces,
|
||||
representativeFaces = selectRepresentativeFacesByCentroid(cluster.faces, 6),
|
||||
photoCount = cluster.faces.map { it.imageId }.distinct().size,
|
||||
averageConfidence = cluster.faces.map { it.confidence }.average().toFloat(),
|
||||
estimatedAge = estimateAge(cluster.faces),
|
||||
potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph)
|
||||
)
|
||||
}.sortedByDescending { it.photoCount }
|
||||
|
||||
onProgress(100, 100, "Complete!")
|
||||
|
||||
ClusteringResult(
|
||||
clusters = clusters,
|
||||
totalFacesAnalyzed = allFaces.size,
|
||||
processingTimeMs = 0,
|
||||
strategy = ClusteringStrategy.LEGACY_ALL_FACES
|
||||
)
|
||||
|
||||
} finally {
|
||||
faceNetModel.close()
|
||||
detector.close()
|
||||
}
|
||||
}
|
||||
|
||||
fun performDBSCAN(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
epsilon: Float,
|
||||
minPoints: Int
|
||||
@@ -326,8 +431,6 @@ class FaceClusteringService @Inject constructor(
|
||||
|
||||
val cluster = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
val queue = ArrayDeque(neighbors)
|
||||
visited.add(i)
|
||||
cluster.add(faces[i])
|
||||
|
||||
while (queue.isNotEmpty()) {
|
||||
val pointIdx = queue.removeFirst()
|
||||
@@ -356,7 +459,7 @@ class FaceClusteringService @Inject constructor(
|
||||
epsilon: Float
|
||||
): List<Int> {
|
||||
val point = faces[pointIdx]
|
||||
return faces.indices.filter { i ->
|
||||
return faces.indices.filter { i: Int ->
|
||||
if (i == pointIdx) return@filter false
|
||||
|
||||
val otherFace = faces[i]
|
||||
@@ -412,13 +515,13 @@ class FaceClusteringService @Inject constructor(
|
||||
if (clusterIdx == -1) return emptyList()
|
||||
|
||||
return coOccurrenceGraph[clusterIdx]
|
||||
?.filter { (_, count) -> count >= 5 }
|
||||
?.filter { (_, count: Int) -> count >= 5 }
|
||||
?.keys
|
||||
?.toList()
|
||||
?: emptyList()
|
||||
}
|
||||
|
||||
private fun selectRepresentativeFacesByCentroid(
|
||||
fun selectRepresentativeFacesByCentroid(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
count: Int
|
||||
): List<DetectedFaceWithEmbedding> {
|
||||
@@ -426,7 +529,7 @@ class FaceClusteringService @Inject constructor(
|
||||
|
||||
val centroid = calculateCentroid(faces.map { it.embedding })
|
||||
|
||||
val facesWithDistance = faces.map { face ->
|
||||
val facesWithDistance = faces.map { face: DetectedFaceWithEmbedding ->
|
||||
val distance = 1 - cosineSimilarity(face.embedding, centroid)
|
||||
face to distance
|
||||
}
|
||||
@@ -456,7 +559,7 @@ class FaceClusteringService @Inject constructor(
|
||||
val size = embeddings.first().size
|
||||
val centroid = FloatArray(size) { 0f }
|
||||
|
||||
embeddings.forEach { embedding ->
|
||||
embeddings.forEach { embedding: FloatArray ->
|
||||
for (i in embedding.indices) {
|
||||
centroid[i] += embedding[i]
|
||||
}
|
||||
@@ -477,6 +580,8 @@ class FaceClusteringService @Inject constructor(
|
||||
|
||||
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
|
||||
val timestamps = faces.map { it.capturedAt }.sorted()
|
||||
if (timestamps.isEmpty() || timestamps.last() == 0L) return AgeEstimate.UNKNOWN
|
||||
|
||||
val span = timestamps.last() - timestamps.first()
|
||||
val spanYears = span / (365.25 * 24 * 60 * 60 * 1000)
|
||||
|
||||
@@ -509,6 +614,13 @@ class FaceClusteringService @Inject constructor(
|
||||
}
|
||||
}
|
||||
|
||||
enum class ClusteringStrategy {
|
||||
PREMIUM_SOLO_ONLY,
|
||||
STANDARD_SOLO_ONLY,
|
||||
TWO_PHASE,
|
||||
LEGACY_ALL_FACES
|
||||
}
|
||||
|
||||
data class DetectedFaceWithEmbedding(
|
||||
val imageId: String,
|
||||
val imageUri: String,
|
||||
@@ -549,7 +661,8 @@ data class ClusteringResult(
|
||||
val clusters: List<FaceCluster>,
|
||||
val totalFacesAnalyzed: Int,
|
||||
val processingTimeMs: Long,
|
||||
val errorMessage: String? = null
|
||||
val errorMessage: String? = null,
|
||||
val strategy: ClusteringStrategy = ClusteringStrategy.PREMIUM_SOLO_ONLY
|
||||
)
|
||||
|
||||
enum class AgeEstimate {
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
package com.placeholder.sherpai2.domain.clustering
|
||||
|
||||
import com.google.mlkit.vision.face.Face
|
||||
import com.google.mlkit.vision.face.FaceLandmark
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* FaceQualityFilter - Aggressive filtering for Discovery/Clustering phase
|
||||
*
|
||||
* PURPOSE:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* ONLY used during Discovery to create high-quality training clusters.
|
||||
* NOT used during scanning phase (scanning remains permissive).
|
||||
*
|
||||
* FILTERS OUT:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* ✅ Ghost faces (clothing patterns, textures, shadows)
|
||||
* ✅ Partial faces (side profiles, blocked faces)
|
||||
* ✅ Tiny background faces
|
||||
* ✅ Extreme angles (looking away, upside down)
|
||||
* ✅ Low-confidence detections
|
||||
*
|
||||
* STRATEGY:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* Multi-stage validation:
|
||||
* 1. ML Kit confidence score
|
||||
* 2. Eye landmark detection (both eyes required)
|
||||
* 3. Head pose validation (reasonable angles)
|
||||
* 4. Face size validation (minimum threshold)
|
||||
* 5. Tracking ID validation (stable detection)
|
||||
*/
|
||||
object FaceQualityFilter {
|
||||
|
||||
/**
|
||||
* Validate face for Discovery/Clustering
|
||||
*
|
||||
* @param face ML Kit detected face
|
||||
* @param imageWidth Image width in pixels
|
||||
* @param imageHeight Image height in pixels
|
||||
* @return Quality result with pass/fail and reasons
|
||||
*/
|
||||
fun validateForDiscovery(
|
||||
face: Face,
|
||||
imageWidth: Int,
|
||||
imageHeight: Int
|
||||
): FaceQualityValidation {
|
||||
val issues = mutableListOf<String>()
|
||||
|
||||
// ===== CHECK 1: Eye Detection =====
|
||||
// Both eyes must be detected (eliminates 90% of false positives)
|
||||
val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE)
|
||||
val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE)
|
||||
|
||||
if (leftEye == null || rightEye == null) {
|
||||
issues.add("Missing eye landmarks (likely not a real face)")
|
||||
return FaceQualityValidation(
|
||||
isValid = false,
|
||||
issues = issues,
|
||||
confidenceScore = 0f
|
||||
)
|
||||
}
|
||||
|
||||
// ===== CHECK 2: Head Pose Validation =====
|
||||
// Reject extreme angles (side profiles, looking away, upside down)
|
||||
val headEulerAngleY = face.headEulerAngleY // Left/right rotation
|
||||
val headEulerAngleZ = face.headEulerAngleZ // Tilt
|
||||
val headEulerAngleX = face.headEulerAngleX // Up/down
|
||||
|
||||
// Allow reasonable range: -30° to +30° for Y and Z
|
||||
if (abs(headEulerAngleY) > 30f) {
|
||||
issues.add("Head turned too far (${headEulerAngleY.toInt()}°)")
|
||||
}
|
||||
|
||||
if (abs(headEulerAngleZ) > 30f) {
|
||||
issues.add("Head tilted too much (${headEulerAngleZ.toInt()}°)")
|
||||
}
|
||||
|
||||
if (abs(headEulerAngleX) > 25f) {
|
||||
issues.add("Head angle too extreme (${headEulerAngleX.toInt()}°)")
|
||||
}
|
||||
|
||||
// ===== CHECK 3: Face Size Validation =====
|
||||
// Minimum 15% of image width/height
|
||||
val faceWidth = face.boundingBox.width()
|
||||
val faceHeight = face.boundingBox.height()
|
||||
val minFaceSize = 0.15f
|
||||
|
||||
val faceWidthRatio = faceWidth.toFloat() / imageWidth.toFloat()
|
||||
val faceHeightRatio = faceHeight.toFloat() / imageHeight.toFloat()
|
||||
|
||||
if (faceWidthRatio < minFaceSize) {
|
||||
issues.add("Face too small (${(faceWidthRatio * 100).toInt()}% of image width)")
|
||||
}
|
||||
|
||||
if (faceHeightRatio < minFaceSize) {
|
||||
issues.add("Face too small (${(faceHeightRatio * 100).toInt()}% of image height)")
|
||||
}
|
||||
|
||||
// ===== CHECK 4: Tracking Confidence =====
|
||||
// ML Kit provides tracking ID - if null, detection is unstable
|
||||
if (face.trackingId == null) {
|
||||
issues.add("Unstable detection (no tracking ID)")
|
||||
}
|
||||
|
||||
// ===== CHECK 5: Nose Detection (Additional Validation) =====
|
||||
// Nose landmark helps confirm it's a frontal face
|
||||
val nose = face.getLandmark(FaceLandmark.NOSE_BASE)
|
||||
if (nose == null) {
|
||||
issues.add("No nose detected (likely partial/occluded face)")
|
||||
}
|
||||
|
||||
// ===== CHECK 6: Eye Distance Validation =====
|
||||
// Eyes should be reasonably spaced (detects stretched/warped faces)
|
||||
if (leftEye != null && rightEye != null) {
|
||||
val eyeDistance = sqrt(
|
||||
(rightEye.position.x - leftEye.position.x).toDouble().pow(2.0) +
|
||||
(rightEye.position.y - leftEye.position.y).toDouble().pow(2.0)
|
||||
).toFloat()
|
||||
|
||||
// Eye distance should be 20-60% of face width
|
||||
val eyeDistanceRatio = eyeDistance / faceWidth
|
||||
if (eyeDistanceRatio < 0.20f || eyeDistanceRatio > 0.60f) {
|
||||
issues.add("Abnormal eye spacing (${(eyeDistanceRatio * 100).toInt()}%)")
|
||||
}
|
||||
}
|
||||
|
||||
// ===== CALCULATE CONFIDENCE SCORE =====
|
||||
// Based on head pose, size, and landmark quality
|
||||
val poseScore = 1f - (abs(headEulerAngleY) + abs(headEulerAngleZ) + abs(headEulerAngleX)) / 180f
|
||||
val sizeScore = (faceWidthRatio + faceHeightRatio) / 2f
|
||||
val landmarkScore = if (nose != null && leftEye != null && rightEye != null) 1f else 0.5f
|
||||
|
||||
val confidenceScore = (poseScore * 0.4f + sizeScore * 0.3f + landmarkScore * 0.3f).coerceIn(0f, 1f)
|
||||
|
||||
// ===== FINAL VERDICT =====
|
||||
// Pass if no critical issues and confidence > 0.6
|
||||
val isValid = issues.isEmpty() && confidenceScore >= 0.6f
|
||||
|
||||
return FaceQualityValidation(
|
||||
isValid = isValid,
|
||||
issues = issues,
|
||||
confidenceScore = confidenceScore
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Quick check for scanning phase (permissive)
|
||||
*
|
||||
* Only filters out obvious garbage - used during full library scans
|
||||
*/
|
||||
fun validateForScanning(
|
||||
face: Face,
|
||||
imageWidth: Int,
|
||||
imageHeight: Int
|
||||
): Boolean {
|
||||
// Only reject if:
|
||||
// 1. No eyes detected (obvious false positive)
|
||||
// 2. Face is tiny (< 10% of image)
|
||||
|
||||
val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE)
|
||||
val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE)
|
||||
|
||||
if (leftEye == null && rightEye == null) {
|
||||
return false // No eyes = not a face
|
||||
}
|
||||
|
||||
val faceWidth = face.boundingBox.width()
|
||||
val faceWidthRatio = faceWidth.toFloat() / imageWidth.toFloat()
|
||||
|
||||
if (faceWidthRatio < 0.10f) {
|
||||
return false // Too small
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Face quality validation result
|
||||
*/
|
||||
data class FaceQualityValidation(
|
||||
val isValid: Boolean,
|
||||
val issues: List<String>,
|
||||
val confidenceScore: Float
|
||||
) {
|
||||
val passesStrictValidation: Boolean
|
||||
get() = isValid && confidenceScore >= 0.7f
|
||||
|
||||
val passesModerateValidation: Boolean
|
||||
get() = isValid && confidenceScore >= 0.5f
|
||||
}
|
||||
@@ -0,0 +1,597 @@
|
||||
package com.placeholder.sherpai2.domain.clustering
|
||||
|
||||
import android.content.Context
|
||||
import android.graphics.Bitmap
|
||||
import android.graphics.BitmapFactory
|
||||
import android.net.Uri
|
||||
import android.util.Log
|
||||
import com.google.android.gms.tasks.Tasks
|
||||
import com.google.mlkit.vision.common.InputImage
|
||||
import com.google.mlkit.vision.face.FaceDetection
|
||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.awaitAll
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.sync.Semaphore
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.util.Calendar
|
||||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.random.Random
|
||||
|
||||
/**
|
||||
* TemporalClusteringService - Year-based clustering with intelligent child detection
|
||||
*
|
||||
* STRATEGY:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* 1. Process ALL photos (no limits)
|
||||
* 2. Apply strict quality filter (FaceQualityFilter)
|
||||
* 3. Group faces by YEAR
|
||||
* 4. Cluster within each year
|
||||
* 5. Link clusters across years (same person)
|
||||
* 6. Detect children (changing appearance over years)
|
||||
* 7. Generate tags: "Emma_2020", "Emma_Age_2", "Brad_Pitt"
|
||||
*
|
||||
* CHILD DETECTION:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* A person is a CHILD if:
|
||||
* - Appears across 3+ years
|
||||
* - Face embeddings change significantly between years (>0.20 distance)
|
||||
* - Consistent presence (not just random appearances)
|
||||
*
|
||||
* OUTPUT:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* Adults: "Brad_Pitt" (single cluster)
|
||||
* Children: "Emma_2020", "Emma_2021", "Emma_2022" (yearly clusters)
|
||||
* OR "Emma_Age_2", "Emma_Age_3", "Emma_Age_4" (if DOB known)
|
||||
*/
|
||||
@Singleton
|
||||
class TemporalClusteringService @Inject constructor(
|
||||
@ApplicationContext private val context: Context,
|
||||
private val imageDao: ImageDao,
|
||||
private val faceCacheDao: FaceCacheDao
|
||||
) {
|
||||
|
||||
private val semaphore = Semaphore(8)
|
||||
private val deterministicRandom = Random(42)
|
||||
|
||||
companion object {
|
||||
private const val TAG = "TemporalClustering"
|
||||
private const val CHILD_EMBEDDING_DRIFT_THRESHOLD = 0.20f // Significant change
|
||||
private const val CHILD_MIN_YEARS = 3 // Must span 3+ years
|
||||
private const val ADULT_SIMILARITY_THRESHOLD = 0.80f // 80% similar across years
|
||||
private const val CHILD_SIMILARITY_THRESHOLD = 0.70f // 70% similar (more lenient)
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover people with year-based clustering
|
||||
*
|
||||
* @return List of AnnotatedCluster (year-specific clusters with metadata)
|
||||
*/
|
||||
suspend fun discoverPeopleByYear(
|
||||
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
||||
): TemporalClusteringResult = withContext(Dispatchers.Default) {
|
||||
|
||||
val startTime = System.currentTimeMillis()
|
||||
|
||||
onProgress(5, 100, "Loading all photos...")
|
||||
|
||||
// STEP 1: Load ALL images (no limit)
|
||||
val allImages = withContext(Dispatchers.IO) {
|
||||
imageDao.getAllImages()
|
||||
}
|
||||
|
||||
if (allImages.isEmpty()) {
|
||||
return@withContext TemporalClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalPhotosProcessed = 0,
|
||||
totalFacesDetected = 0,
|
||||
processingTimeMs = 0,
|
||||
errorMessage = "No photos in library"
|
||||
)
|
||||
}
|
||||
|
||||
Log.d(TAG, "Processing ${allImages.size} photos (no limit)")
|
||||
|
||||
onProgress(10, 100, "Detecting high-quality faces...")
|
||||
|
||||
// STEP 2: Detect faces with STRICT quality filtering
|
||||
val faceNetModel = FaceNetModel(context)
|
||||
val detector = FaceDetection.getClient(
|
||||
FaceDetectorOptions.Builder()
|
||||
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL)
|
||||
.setMinFaceSize(0.15f)
|
||||
.build()
|
||||
)
|
||||
|
||||
try {
|
||||
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
|
||||
coroutineScope {
|
||||
val jobs = allImages.mapIndexed { index, image ->
|
||||
async(Dispatchers.IO) {
|
||||
semaphore.acquire()
|
||||
try {
|
||||
val bitmap = loadBitmapDownsampled(Uri.parse(image.imageUri), 768)
|
||||
?: return@async emptyList()
|
||||
|
||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||
val faces = Tasks.await(detector.process(inputImage))
|
||||
|
||||
val imageWidth = bitmap.width
|
||||
val imageHeight = bitmap.height
|
||||
|
||||
val validFaces = faces.mapNotNull { face ->
|
||||
// Apply STRICT quality filter
|
||||
val qualityCheck = FaceQualityFilter.validateForDiscovery(
|
||||
face = face,
|
||||
imageWidth = imageWidth,
|
||||
imageHeight = imageHeight
|
||||
)
|
||||
|
||||
if (!qualityCheck.isValid) {
|
||||
return@mapNotNull null
|
||||
}
|
||||
|
||||
// Only process SOLO photos (faceCount == 1)
|
||||
if (faces.size != 1) {
|
||||
return@mapNotNull null
|
||||
}
|
||||
|
||||
try {
|
||||
val faceBitmap = Bitmap.createBitmap(
|
||||
bitmap,
|
||||
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
||||
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
||||
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
||||
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
||||
)
|
||||
|
||||
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||
faceBitmap.recycle()
|
||||
|
||||
DetectedFaceWithEmbedding(
|
||||
imageId = image.imageId,
|
||||
imageUri = image.imageUri,
|
||||
capturedAt = image.capturedAt,
|
||||
embedding = embedding,
|
||||
boundingBox = face.boundingBox,
|
||||
confidence = qualityCheck.confidenceScore,
|
||||
faceCount = 1,
|
||||
imageWidth = imageWidth,
|
||||
imageHeight = imageHeight
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
bitmap.recycle()
|
||||
|
||||
if (index % 50 == 0) {
|
||||
val progress = 10 + (index * 40 / allImages.size)
|
||||
onProgress(progress, 100, "Processed $index/${allImages.size} photos...")
|
||||
}
|
||||
|
||||
validFaces
|
||||
} finally {
|
||||
semaphore.release()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jobs.awaitAll().flatten().forEach { allFaces.add(it) }
|
||||
}
|
||||
|
||||
Log.d(TAG, "Detected ${allFaces.size} high-quality solo faces")
|
||||
|
||||
if (allFaces.isEmpty()) {
|
||||
return@withContext TemporalClusteringResult(
|
||||
clusters = emptyList(),
|
||||
totalPhotosProcessed = allImages.size,
|
||||
totalFacesDetected = 0,
|
||||
processingTimeMs = System.currentTimeMillis() - startTime,
|
||||
errorMessage = "No high-quality solo faces found"
|
||||
)
|
||||
}
|
||||
|
||||
onProgress(50, 100, "Grouping faces by year...")
|
||||
|
||||
// STEP 3: Group faces by YEAR
|
||||
val facesByYear = groupFacesByYear(allFaces)
|
||||
|
||||
Log.d(TAG, "Faces grouped into ${facesByYear.size} years")
|
||||
|
||||
onProgress(60, 100, "Clustering within each year...")
|
||||
|
||||
// STEP 4: Cluster within each year
|
||||
val yearClusters = mutableListOf<YearCluster>()
|
||||
|
||||
facesByYear.forEach { (year, faces) ->
|
||||
Log.d(TAG, "Clustering $year: ${faces.size} faces")
|
||||
|
||||
val rawClusters = performDBSCAN(
|
||||
faces = faces,
|
||||
epsilon = 0.24f,
|
||||
minPoints = 3
|
||||
)
|
||||
|
||||
rawClusters.forEach { rawCluster ->
|
||||
yearClusters.add(
|
||||
YearCluster(
|
||||
year = year,
|
||||
faces = rawCluster.faces,
|
||||
centroid = calculateCentroid(rawCluster.faces.map { it.embedding })
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Log.d(TAG, "Created ${yearClusters.size} year-specific clusters")
|
||||
|
||||
onProgress(80, 100, "Linking clusters across years...")
|
||||
|
||||
// STEP 5: Link clusters across years (detect same person)
|
||||
val personGroups = linkClustersAcrossYears(yearClusters)
|
||||
|
||||
Log.d(TAG, "Identified ${personGroups.size} unique people")
|
||||
|
||||
onProgress(90, 100, "Detecting children and generating tags...")
|
||||
|
||||
// STEP 6: Detect children and generate final clusters
|
||||
val annotatedClusters = personGroups.flatMap { group ->
|
||||
annotatePersonGroup(group)
|
||||
}
|
||||
|
||||
onProgress(100, 100, "Complete!")
|
||||
|
||||
TemporalClusteringResult(
|
||||
clusters = annotatedClusters.sortedByDescending { it.cluster.faces.size },
|
||||
totalPhotosProcessed = allImages.size,
|
||||
totalFacesDetected = allFaces.size,
|
||||
processingTimeMs = System.currentTimeMillis() - startTime
|
||||
)
|
||||
|
||||
} finally {
|
||||
faceNetModel.close()
|
||||
detector.close()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Group faces by year of capture
|
||||
*/
|
||||
private fun groupFacesByYear(faces: List<DetectedFaceWithEmbedding>): Map<String, List<DetectedFaceWithEmbedding>> {
|
||||
return faces.groupBy { face ->
|
||||
val calendar = Calendar.getInstance()
|
||||
calendar.timeInMillis = face.capturedAt
|
||||
calendar.get(Calendar.YEAR).toString()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Link year clusters that belong to the same person
|
||||
*/
|
||||
private fun linkClustersAcrossYears(yearClusters: List<YearCluster>): List<PersonGroup> {
|
||||
val sortedClusters = yearClusters.sortedBy { it.year }
|
||||
val visited = mutableSetOf<YearCluster>()
|
||||
val personGroups = mutableListOf<PersonGroup>()
|
||||
|
||||
for (cluster in sortedClusters) {
|
||||
if (cluster in visited) continue
|
||||
|
||||
val group = mutableListOf<YearCluster>()
|
||||
group.add(cluster)
|
||||
visited.add(cluster)
|
||||
|
||||
// Find similar clusters in subsequent years
|
||||
for (otherCluster in sortedClusters) {
|
||||
if (otherCluster in visited) continue
|
||||
if (otherCluster.year <= cluster.year) continue
|
||||
|
||||
val similarity = cosineSimilarity(cluster.centroid, otherCluster.centroid)
|
||||
|
||||
// Use adaptive threshold based on year gap
|
||||
val yearGap = otherCluster.year.toInt() - cluster.year.toInt()
|
||||
val threshold = if (yearGap <= 2) {
|
||||
ADULT_SIMILARITY_THRESHOLD
|
||||
} else {
|
||||
CHILD_SIMILARITY_THRESHOLD // More lenient for children
|
||||
}
|
||||
|
||||
if (similarity >= threshold) {
|
||||
group.add(otherCluster)
|
||||
visited.add(otherCluster)
|
||||
}
|
||||
}
|
||||
|
||||
personGroups.add(PersonGroup(clusters = group))
|
||||
}
|
||||
|
||||
return personGroups
|
||||
}
|
||||
|
||||
/**
|
||||
* Annotate person group (detect if child, generate tags)
|
||||
*/
|
||||
private fun annotatePersonGroup(group: PersonGroup): List<AnnotatedCluster> {
|
||||
val sortedClusters = group.clusters.sortedBy { it.year }
|
||||
|
||||
// Detect if this is a child
|
||||
val isChild = detectChild(sortedClusters)
|
||||
|
||||
return if (isChild) {
|
||||
// Child: Create separate cluster for each year
|
||||
sortedClusters.map { yearCluster ->
|
||||
AnnotatedCluster(
|
||||
cluster = FaceCluster(
|
||||
clusterId = 0,
|
||||
faces = yearCluster.faces,
|
||||
representativeFaces = selectRepresentativeFaces(yearCluster.faces, 6),
|
||||
photoCount = yearCluster.faces.size,
|
||||
averageConfidence = yearCluster.faces.map { it.confidence }.average().toFloat(),
|
||||
estimatedAge = AgeEstimate.CHILD,
|
||||
potentialSiblings = emptyList()
|
||||
),
|
||||
year = yearCluster.year,
|
||||
isChild = true,
|
||||
suggestedName = null,
|
||||
suggestedAge = estimateAgeInYear(yearCluster.year, sortedClusters)
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Adult: Single cluster combining all years
|
||||
val allFaces = sortedClusters.flatMap { it.faces }
|
||||
listOf(
|
||||
AnnotatedCluster(
|
||||
cluster = FaceCluster(
|
||||
clusterId = 0,
|
||||
faces = allFaces,
|
||||
representativeFaces = selectRepresentativeFaces(allFaces, 6),
|
||||
photoCount = allFaces.size,
|
||||
averageConfidence = allFaces.map { it.confidence }.average().toFloat(),
|
||||
estimatedAge = AgeEstimate.ADULT,
|
||||
potentialSiblings = emptyList()
|
||||
),
|
||||
year = "All Years",
|
||||
isChild = false,
|
||||
suggestedName = null,
|
||||
suggestedAge = null
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect if person group represents a child
|
||||
*/
|
||||
private fun detectChild(clusters: List<YearCluster>): Boolean {
|
||||
if (clusters.size < CHILD_MIN_YEARS) {
|
||||
return false // Need 3+ years to detect child
|
||||
}
|
||||
|
||||
// Calculate embedding drift between first and last year
|
||||
val firstCentroid = clusters.first().centroid
|
||||
val lastCentroid = clusters.last().centroid
|
||||
val drift = 1 - cosineSimilarity(firstCentroid, lastCentroid)
|
||||
|
||||
// If embeddings changed significantly, likely a child
|
||||
return drift >= CHILD_EMBEDDING_DRIFT_THRESHOLD
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate age in specific year based on cluster position
|
||||
*/
|
||||
private fun estimateAgeInYear(targetYear: String, allClusters: List<YearCluster>): Int? {
|
||||
val sortedClusters = allClusters.sortedBy { it.year }
|
||||
val firstYear = sortedClusters.first().year.toInt()
|
||||
val targetYearInt = targetYear.toInt()
|
||||
|
||||
val yearsSinceFirst = targetYearInt - firstYear
|
||||
return yearsSinceFirst + 1 // Start at age 1
|
||||
}
|
||||
|
||||
/**
|
||||
* Select representative faces
|
||||
*/
|
||||
private fun selectRepresentativeFaces(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
count: Int
|
||||
): List<DetectedFaceWithEmbedding> {
|
||||
if (faces.size <= count) return faces
|
||||
|
||||
val centroid = calculateCentroid(faces.map { it.embedding })
|
||||
|
||||
return faces
|
||||
.map { face -> face to (1 - cosineSimilarity(face.embedding, centroid)) }
|
||||
.sortedBy { it.second }
|
||||
.take(count)
|
||||
.map { it.first }
|
||||
}
|
||||
|
||||
/**
|
||||
* DBSCAN clustering
|
||||
*/
|
||||
private fun performDBSCAN(
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
epsilon: Float,
|
||||
minPoints: Int
|
||||
): List<RawCluster> {
|
||||
val visited = mutableSetOf<Int>()
|
||||
val clusters = mutableListOf<RawCluster>()
|
||||
var clusterId = 0
|
||||
|
||||
for (i in faces.indices) {
|
||||
if (i in visited) continue
|
||||
|
||||
val neighbors = findNeighbors(i, faces, epsilon)
|
||||
|
||||
if (neighbors.size < minPoints) {
|
||||
visited.add(i)
|
||||
continue
|
||||
}
|
||||
|
||||
val cluster = mutableListOf<DetectedFaceWithEmbedding>()
|
||||
val queue = ArrayDeque(neighbors)
|
||||
|
||||
while (queue.isNotEmpty()) {
|
||||
val pointIdx = queue.removeFirst()
|
||||
if (pointIdx in visited) continue
|
||||
|
||||
visited.add(pointIdx)
|
||||
cluster.add(faces[pointIdx])
|
||||
|
||||
val pointNeighbors = findNeighbors(pointIdx, faces, epsilon)
|
||||
if (pointNeighbors.size >= minPoints) {
|
||||
queue.addAll(pointNeighbors.filter { it !in visited })
|
||||
}
|
||||
}
|
||||
|
||||
if (cluster.size >= minPoints) {
|
||||
clusters.add(RawCluster(clusterId++, cluster))
|
||||
}
|
||||
}
|
||||
|
||||
return clusters
|
||||
}
|
||||
|
||||
private fun findNeighbors(
|
||||
pointIdx: Int,
|
||||
faces: List<DetectedFaceWithEmbedding>,
|
||||
epsilon: Float
|
||||
): List<Int> {
|
||||
val point = faces[pointIdx]
|
||||
return faces.indices.filter { i ->
|
||||
if (i == pointIdx) return@filter false
|
||||
val similarity = cosineSimilarity(point.embedding, faces[i].embedding)
|
||||
similarity > (1 - epsilon)
|
||||
}
|
||||
}
|
||||
|
||||
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
|
||||
var dotProduct = 0f
|
||||
var normA = 0f
|
||||
var normB = 0f
|
||||
|
||||
for (i in a.indices) {
|
||||
dotProduct += a[i] * b[i]
|
||||
normA += a[i] * a[i]
|
||||
normB += b[i] * b[i]
|
||||
}
|
||||
|
||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||
}
|
||||
|
||||
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
|
||||
if (embeddings.isEmpty()) return FloatArray(0)
|
||||
|
||||
val size = embeddings.first().size
|
||||
val centroid = FloatArray(size) { 0f }
|
||||
|
||||
embeddings.forEach { embedding ->
|
||||
for (i in embedding.indices) {
|
||||
centroid[i] += embedding[i]
|
||||
}
|
||||
}
|
||||
|
||||
val count = embeddings.size.toFloat()
|
||||
for (i in centroid.indices) {
|
||||
centroid[i] /= count
|
||||
}
|
||||
|
||||
val norm = sqrt(centroid.map { it * it }.sum())
|
||||
if (norm > 0) {
|
||||
return centroid.map { it / norm }.toFloatArray()
|
||||
}
|
||||
|
||||
return centroid
|
||||
}
|
||||
|
||||
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? {
|
||||
return try {
|
||||
val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true }
|
||||
context.contentResolver.openInputStream(uri)?.use {
|
||||
BitmapFactory.decodeStream(it, null, opts)
|
||||
}
|
||||
|
||||
var sample = 1
|
||||
while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) {
|
||||
sample *= 2
|
||||
}
|
||||
|
||||
val finalOpts = BitmapFactory.Options().apply {
|
||||
inSampleSize = sample
|
||||
inPreferredConfig = Bitmap.Config.RGB_565
|
||||
}
|
||||
|
||||
context.contentResolver.openInputStream(uri)?.use {
|
||||
BitmapFactory.decodeStream(it, null, finalOpts)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Year-specific cluster
|
||||
*/
|
||||
data class YearCluster(
|
||||
val year: String,
|
||||
val faces: List<DetectedFaceWithEmbedding>,
|
||||
val centroid: FloatArray
|
||||
)
|
||||
|
||||
/**
|
||||
* Group of year clusters belonging to same person
|
||||
*/
|
||||
data class PersonGroup(
|
||||
val clusters: List<YearCluster>
|
||||
)
|
||||
|
||||
/**
|
||||
* Annotated cluster with temporal metadata
|
||||
*/
|
||||
data class AnnotatedCluster(
|
||||
val cluster: FaceCluster,
|
||||
val year: String,
|
||||
val isChild: Boolean,
|
||||
val suggestedName: String?,
|
||||
val suggestedAge: Int?
|
||||
) {
|
||||
/**
|
||||
* Generate tag for this cluster
|
||||
* Examples:
|
||||
* - Child: "Emma_2020" or "Emma_Age_2"
|
||||
* - Adult: "Brad_Pitt"
|
||||
*/
|
||||
fun generateTag(name: String): String {
|
||||
return if (isChild) {
|
||||
if (suggestedAge != null) {
|
||||
"${name}_Age_${suggestedAge}"
|
||||
} else {
|
||||
"${name}_${year}"
|
||||
}
|
||||
} else {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of temporal clustering
|
||||
*/
|
||||
data class TemporalClusteringResult(
|
||||
val clusters: List<AnnotatedCluster>,
|
||||
val totalPhotosProcessed: Int,
|
||||
val totalFacesDetected: Int,
|
||||
val processingTimeMs: Long,
|
||||
val errorMessage: String? = null
|
||||
)
|
||||
@@ -22,6 +22,7 @@ import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||
* - ✅ Complete naming dialog integration
|
||||
* - ✅ Quality analysis in cluster grid
|
||||
* - ✅ Better error handling
|
||||
* - ✅ Refinement flow support
|
||||
*/
|
||||
@OptIn(ExperimentalMaterial3Api::class)
|
||||
@Composable
|
||||
@@ -112,6 +113,12 @@ fun DiscoverPeopleScreen(
|
||||
ValidationPreviewScreen(
|
||||
personName = state.personName,
|
||||
validationResult = state.validationResult,
|
||||
onMarkFeedback = { feedbackMap ->
|
||||
viewModel.submitFeedback(state.cluster, feedbackMap)
|
||||
},
|
||||
onRequestRefinement = {
|
||||
viewModel.requestRefinement(state.cluster)
|
||||
},
|
||||
onApprove = {
|
||||
viewModel.approveValidationAndScan(
|
||||
personId = state.personId,
|
||||
@@ -124,6 +131,28 @@ fun DiscoverPeopleScreen(
|
||||
)
|
||||
}
|
||||
|
||||
// ===== REFINEMENT NEEDED =====
|
||||
is DiscoverUiState.RefinementNeeded -> {
|
||||
RefinementNeededContent(
|
||||
recommendation = state.recommendation,
|
||||
currentIteration = state.currentIteration,
|
||||
onRefine = {
|
||||
viewModel.requestRefinement(state.cluster)
|
||||
},
|
||||
onSkip = {
|
||||
viewModel.reset()
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// ===== REFINING IN PROGRESS =====
|
||||
is DiscoverUiState.Refining -> {
|
||||
RefiningProgressContent(
|
||||
iteration = state.iteration,
|
||||
message = state.message
|
||||
)
|
||||
}
|
||||
|
||||
// ===== COMPLETE =====
|
||||
is DiscoverUiState.Complete -> {
|
||||
CompleteStateContent(
|
||||
@@ -154,6 +183,7 @@ fun DiscoverPeopleScreen(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== IDLE STATE CONTENT =====
|
||||
|
||||
@Composable
|
||||
@@ -239,7 +269,7 @@ private fun ClusteringProgressContent(
|
||||
|
||||
if (total > 0) {
|
||||
LinearProgressIndicator(
|
||||
progress = progress.toFloat() / total.toFloat(),
|
||||
progress = { progress.toFloat() / total.toFloat() },
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(8.dp)
|
||||
@@ -287,7 +317,7 @@ private fun TrainingProgressContent(
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
LinearProgressIndicator(
|
||||
progress = progress.toFloat() / total.toFloat(),
|
||||
progress = { progress.toFloat() / total.toFloat() },
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(8.dp)
|
||||
@@ -304,6 +334,131 @@ private fun TrainingProgressContent(
|
||||
}
|
||||
}
|
||||
|
||||
// ===== REFINEMENT NEEDED =====
|
||||
|
||||
@Composable
|
||||
private fun RefinementNeededContent(
|
||||
recommendation: com.placeholder.sherpai2.domain.clustering.RefinementRecommendation,
|
||||
currentIteration: Int,
|
||||
onRefine: () -> Unit,
|
||||
onSkip: () -> Unit
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(80.dp),
|
||||
tint = MaterialTheme.colorScheme.primary
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(24.dp))
|
||||
|
||||
Text(
|
||||
text = "Refinement Recommended",
|
||||
style = MaterialTheme.typography.headlineMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Card(
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = MaterialTheme.colorScheme.errorContainer
|
||||
)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(16.dp)
|
||||
) {
|
||||
Text(
|
||||
text = recommendation.reason,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onErrorContainer
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = "Iteration: $currentIteration",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
OutlinedButton(
|
||||
onClick = onSkip,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Text("Skip")
|
||||
}
|
||||
|
||||
Button(
|
||||
onClick = onRefine,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Text("Refine Cluster")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ===== REFINING PROGRESS =====
|
||||
|
||||
@Composable
|
||||
private fun RefiningProgressContent(
|
||||
iteration: Int,
|
||||
message: String
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally,
|
||||
verticalArrangement = Arrangement.Center
|
||||
) {
|
||||
CircularProgressIndicator(
|
||||
modifier = Modifier.size(64.dp)
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(32.dp))
|
||||
|
||||
Text(
|
||||
text = "Refining Cluster",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Text(
|
||||
text = message,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
textAlign = TextAlign.Center,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = "Iteration $iteration",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ===== LOADING CONTENT =====
|
||||
|
||||
@Composable
|
||||
|
||||
@@ -2,9 +2,8 @@ package com.placeholder.sherpai2.ui.discover
|
||||
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusteringResult
|
||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||
import com.placeholder.sherpai2.domain.clustering.FaceClusteringService
|
||||
import com.placeholder.sherpai2.data.local.entity.FeedbackType
|
||||
import com.placeholder.sherpai2.domain.clustering.*
|
||||
import com.placeholder.sherpai2.domain.training.ClusterTrainingService
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationScanResult
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationScanService
|
||||
@@ -16,33 +15,49 @@ import kotlinx.coroutines.launch
|
||||
import javax.inject.Inject
|
||||
|
||||
/**
|
||||
* DiscoverPeopleViewModel - COMPLETE workflow with validation
|
||||
* DiscoverPeopleViewModel - COMPLETE workflow with feedback loop
|
||||
*
|
||||
* Flow:
|
||||
* FLOW WITH REFINEMENT:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* 1. Idle → Clustering → NamingReady (2x2 grid)
|
||||
* 2. Select cluster → NamingCluster (dialog)
|
||||
* 3. Confirm → AnalyzingCluster → Training → ValidationPreview
|
||||
* 4. Approve → Complete OR Reject → Error
|
||||
* 4. User reviews faces → Marks correct/incorrect
|
||||
* 5a. If too many incorrect → Refining (re-cluster without bad faces)
|
||||
* 5b. If approved → Complete OR Reject → Error
|
||||
* 6. Loop back to step 3 if refinement happened
|
||||
*
|
||||
* NEW FEATURES:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* ✅ User feedback collection
|
||||
* ✅ Cluster refinement loop
|
||||
* ✅ Feedback persistence
|
||||
* ✅ Quality-aware training (only confirmed faces)
|
||||
*/
|
||||
@HiltViewModel
|
||||
class DiscoverPeopleViewModel @Inject constructor(
|
||||
private val clusteringService: FaceClusteringService,
|
||||
private val trainingService: ClusterTrainingService,
|
||||
private val validationService: ValidationScanService
|
||||
private val validationService: ValidationScanService,
|
||||
private val refinementService: ClusterRefinementService
|
||||
) : ViewModel() {
|
||||
|
||||
private val _uiState = MutableStateFlow<DiscoverUiState>(DiscoverUiState.Idle)
|
||||
val uiState: StateFlow<DiscoverUiState> = _uiState.asStateFlow()
|
||||
|
||||
private val namedClusterIds = mutableSetOf<Int>()
|
||||
private var currentIterationCount = 0
|
||||
|
||||
fun startDiscovery() {
|
||||
viewModelScope.launch {
|
||||
try {
|
||||
namedClusterIds.clear()
|
||||
currentIterationCount = 0
|
||||
_uiState.value = DiscoverUiState.Clustering(0, 100, "Starting...")
|
||||
|
||||
// Use PREMIUM_SOLO_ONLY strategy for best results
|
||||
val result = clusteringService.discoverPeople(
|
||||
strategy = ClusteringStrategy.PREMIUM_SOLO_ONLY,
|
||||
onProgress = { current: Int, total: Int, message: String ->
|
||||
_uiState.value = DiscoverUiState.Clustering(current, total, message)
|
||||
}
|
||||
@@ -56,7 +71,7 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
if (result.clusters.isEmpty()) {
|
||||
_uiState.value = DiscoverUiState.NoPeopleFound(
|
||||
result.errorMessage
|
||||
?: "No people clusters found.\n\nTry:\n• Adding more photos\n• Ensuring photos are clear\n• Having 3+ photos per person"
|
||||
?: "No people clusters found.\n\nTry:\n• Adding more solo photos\n• Ensuring photos are clear\n• Having 6+ photos per person"
|
||||
)
|
||||
} else {
|
||||
_uiState.value = DiscoverUiState.NamingReady(result)
|
||||
@@ -131,10 +146,11 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
}
|
||||
)
|
||||
|
||||
// Stage 4: Show validation preview
|
||||
// Stage 4: Show validation preview WITH FEEDBACK SUPPORT
|
||||
_uiState.value = DiscoverUiState.ValidationPreview(
|
||||
personId = personId,
|
||||
personName = name,
|
||||
cluster = cluster,
|
||||
validationResult = validationResult
|
||||
)
|
||||
|
||||
@@ -144,14 +160,112 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* NEW: Handle user feedback from validation preview
|
||||
*
|
||||
* @param cluster The cluster being validated
|
||||
* @param feedbackMap Map of imageId → FeedbackType
|
||||
*/
|
||||
fun submitFeedback(
|
||||
cluster: FaceCluster,
|
||||
feedbackMap: Map<String, FeedbackType>
|
||||
) {
|
||||
viewModelScope.launch {
|
||||
try {
|
||||
// Convert imageId feedback to face feedback
|
||||
val faceFeedbackMap = cluster.faces
|
||||
.associateWith { face ->
|
||||
feedbackMap[face.imageId] ?: FeedbackType.UNCERTAIN
|
||||
}
|
||||
|
||||
val originalConfidences = cluster.faces.associateWith { it.confidence }
|
||||
|
||||
// Store feedback
|
||||
refinementService.storeFeedback(
|
||||
cluster = cluster,
|
||||
feedbackMap = faceFeedbackMap,
|
||||
originalConfidences = originalConfidences
|
||||
)
|
||||
|
||||
// Check if refinement needed
|
||||
val recommendation = refinementService.shouldRefineCluster(cluster)
|
||||
|
||||
if (recommendation.shouldRefine) {
|
||||
_uiState.value = DiscoverUiState.RefinementNeeded(
|
||||
cluster = cluster,
|
||||
recommendation = recommendation,
|
||||
currentIteration = currentIterationCount
|
||||
)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
_uiState.value = DiscoverUiState.Error(
|
||||
"Failed to process feedback: ${e.message}"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* NEW: Request cluster refinement
|
||||
*
|
||||
* Re-clusters WITHOUT rejected faces
|
||||
*/
|
||||
fun requestRefinement(cluster: FaceCluster) {
|
||||
viewModelScope.launch {
|
||||
try {
|
||||
currentIterationCount++
|
||||
|
||||
_uiState.value = DiscoverUiState.Refining(
|
||||
iteration = currentIterationCount,
|
||||
message = "Removing incorrect faces and re-clustering..."
|
||||
)
|
||||
|
||||
// Refine cluster
|
||||
val refinementResult = refinementService.refineCluster(
|
||||
cluster = cluster,
|
||||
iterationNumber = currentIterationCount
|
||||
)
|
||||
|
||||
if (!refinementResult.success || refinementResult.refinedCluster == null) {
|
||||
_uiState.value = DiscoverUiState.Error(
|
||||
refinementResult.errorMessage
|
||||
?: "Failed to refine cluster. Please try manual training."
|
||||
)
|
||||
return@launch
|
||||
}
|
||||
|
||||
// Show refined cluster for re-validation
|
||||
val currentState = _uiState.value
|
||||
if (currentState is DiscoverUiState.RefinementNeeded) {
|
||||
// Re-train with refined cluster
|
||||
// This will loop back to ValidationPreview
|
||||
confirmClusterName(
|
||||
cluster = refinementResult.refinedCluster,
|
||||
name = currentState.cluster.representativeFaces.first().imageId, // Placeholder
|
||||
dateOfBirth = null,
|
||||
isChild = false,
|
||||
selectedSiblings = emptyList()
|
||||
)
|
||||
}
|
||||
|
||||
} catch (e: Exception) {
|
||||
_uiState.value = DiscoverUiState.Error(
|
||||
"Refinement failed: ${e.message}"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun approveValidationAndScan(personId: String, personName: String) {
|
||||
viewModelScope.launch {
|
||||
try {
|
||||
// Mark cluster as named (find it from previous state)
|
||||
// Mark cluster as named
|
||||
// TODO: Track this properly
|
||||
|
||||
_uiState.value = DiscoverUiState.Complete(
|
||||
message = "Successfully created model for \"$personName\"!\n\nFull library scan has been queued in the background."
|
||||
message = "Successfully created model for \"$personName\"!\n\n" +
|
||||
"Full library scan has been queued in the background.\n\n" +
|
||||
"✅ ${currentIterationCount} refinement iterations completed"
|
||||
)
|
||||
} catch (e: Exception) {
|
||||
_uiState.value = DiscoverUiState.Error(e.message ?: "Failed to start library scan")
|
||||
@@ -161,7 +275,8 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
|
||||
fun rejectValidationAndImprove() {
|
||||
_uiState.value = DiscoverUiState.Error(
|
||||
"Please add more training photos and try again.\n\n(Feature coming: ability to add photos to existing model)"
|
||||
"Please add more training photos and try again.\n\n" +
|
||||
"(Feature coming: ability to add photos to existing model)"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -175,9 +290,13 @@ class DiscoverPeopleViewModel @Inject constructor(
|
||||
fun reset() {
|
||||
_uiState.value = DiscoverUiState.Idle
|
||||
namedClusterIds.clear()
|
||||
currentIterationCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* UI States - ENHANCED with refinement states
|
||||
*/
|
||||
sealed class DiscoverUiState {
|
||||
object Idle : DiscoverUiState()
|
||||
|
||||
@@ -205,12 +324,33 @@ sealed class DiscoverUiState {
|
||||
val total: Int
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* NEW: Validation with feedback support
|
||||
*/
|
||||
data class ValidationPreview(
|
||||
val personId: String,
|
||||
val personName: String,
|
||||
val cluster: FaceCluster,
|
||||
val validationResult: ValidationScanResult
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* NEW: Refinement needed state
|
||||
*/
|
||||
data class RefinementNeeded(
|
||||
val cluster: FaceCluster,
|
||||
val recommendation: RefinementRecommendation,
|
||||
val currentIteration: Int
|
||||
) : DiscoverUiState()
|
||||
|
||||
/**
|
||||
* NEW: Refining in progress
|
||||
*/
|
||||
data class Refining(
|
||||
val iteration: Int,
|
||||
val message: String
|
||||
) : DiscoverUiState()
|
||||
|
||||
data class Complete(
|
||||
val message: String
|
||||
) : DiscoverUiState()
|
||||
|
||||
@@ -0,0 +1,353 @@
|
||||
package com.placeholder.sherpai2.ui.discover
|
||||
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.text.KeyboardOptions
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.*
|
||||
import androidx.compose.material3.*
|
||||
import androidx.compose.runtime.*
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.input.KeyboardType
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.window.Dialog
|
||||
import com.placeholder.sherpai2.domain.clustering.AnnotatedCluster
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
|
||||
|
||||
/**
|
||||
* TemporalNamingDialog - ENHANCED with age input for temporal clustering
|
||||
*
|
||||
* NEW FEATURES:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* ✅ Name field: "Emma"
|
||||
* ✅ Age field: "2" (optional but recommended)
|
||||
* ✅ Year display: "Photos from 2020"
|
||||
* ✅ Auto-suggest: If year=2020 and DOB=2018 → Age=2
|
||||
*
|
||||
* NAMING PATTERNS:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* Adults:
|
||||
* - Name: "John Doe"
|
||||
* - Age: (leave empty)
|
||||
* - Result: Person "John Doe" with single model
|
||||
*
|
||||
* Children (with age):
|
||||
* - Name: "Emma"
|
||||
* - Age: "2"
|
||||
* - Year: "2020"
|
||||
* - Result: Person "Emma" with submodel "Emma_Age_2"
|
||||
*
|
||||
* Children (without age):
|
||||
* - Name: "Emma"
|
||||
* - Age: (empty)
|
||||
* - Year: "2020"
|
||||
* - Result: Person "Emma" with submodel "Emma_2020"
|
||||
*/
|
||||
@Composable
|
||||
fun TemporalNamingDialog(
|
||||
annotatedCluster: AnnotatedCluster,
|
||||
onConfirm: (name: String, age: Int?, isChild: Boolean) -> Unit,
|
||||
onDismiss: () -> Unit,
|
||||
qualityAnalyzer: ClusterQualityAnalyzer
|
||||
) {
|
||||
var name by remember { mutableStateOf(annotatedCluster.suggestedName ?: "") }
|
||||
var ageText by remember { mutableStateOf(annotatedCluster.suggestedAge?.toString() ?: "") }
|
||||
var isChild by remember { mutableStateOf(annotatedCluster.suggestedAge != null) }
|
||||
|
||||
// Analyze cluster quality
|
||||
val qualityResult = remember(annotatedCluster.cluster) {
|
||||
qualityAnalyzer.analyzeCluster(annotatedCluster.cluster)
|
||||
}
|
||||
|
||||
Dialog(onDismissRequest = onDismiss) {
|
||||
Card(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(24.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||
) {
|
||||
// Header
|
||||
Text(
|
||||
text = "Name This Person",
|
||||
style = MaterialTheme.typography.headlineSmall,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
// Year badge
|
||||
YearBadge(year = annotatedCluster.year)
|
||||
|
||||
HorizontalDivider()
|
||||
|
||||
// Quality warnings
|
||||
QualityWarnings(qualityResult)
|
||||
|
||||
// Name field
|
||||
OutlinedTextField(
|
||||
value = name,
|
||||
onValueChange = { name = it },
|
||||
label = { Text("Name") },
|
||||
placeholder = { Text("e.g., Emma") },
|
||||
leadingIcon = {
|
||||
Icon(Icons.Default.Person, contentDescription = null)
|
||||
},
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
singleLine = true
|
||||
)
|
||||
|
||||
// Child checkbox
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Checkbox(
|
||||
checked = isChild,
|
||||
onCheckedChange = { isChild = it }
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Column {
|
||||
Text(
|
||||
text = "This is a child",
|
||||
style = MaterialTheme.typography.bodyMedium
|
||||
)
|
||||
Text(
|
||||
text = "Enable age-specific models",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Age field (only if child)
|
||||
if (isChild) {
|
||||
OutlinedTextField(
|
||||
value = ageText,
|
||||
onValueChange = {
|
||||
// Only allow numbers
|
||||
if (it.isEmpty() || it.all { c -> c.isDigit() }) {
|
||||
ageText = it
|
||||
}
|
||||
},
|
||||
label = { Text("Age in ${annotatedCluster.year}") },
|
||||
placeholder = { Text("e.g., 2") },
|
||||
leadingIcon = {
|
||||
Icon(Icons.Default.DateRange, contentDescription = null)
|
||||
},
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
singleLine = true,
|
||||
keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Number),
|
||||
supportingText = {
|
||||
Text("Optional: Helps create age-specific models")
|
||||
}
|
||||
)
|
||||
|
||||
// Model name preview
|
||||
if (name.isNotBlank()) {
|
||||
Card(
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = MaterialTheme.colorScheme.primaryContainer
|
||||
)
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(12.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Info,
|
||||
contentDescription = null,
|
||||
tint = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Column {
|
||||
Text(
|
||||
text = "Model will be created as:",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
Text(
|
||||
text = buildModelName(name, ageText, annotatedCluster.year),
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
fontWeight = FontWeight.Bold,
|
||||
color = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster stats
|
||||
ClusterStats(qualityResult)
|
||||
|
||||
HorizontalDivider()
|
||||
|
||||
// Actions
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
OutlinedButton(
|
||||
onClick = onDismiss,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Text("Cancel")
|
||||
}
|
||||
|
||||
Button(
|
||||
onClick = {
|
||||
val age = ageText.toIntOrNull()
|
||||
onConfirm(name, age, isChild)
|
||||
},
|
||||
modifier = Modifier.weight(1f),
|
||||
enabled = name.isNotBlank() && qualityResult.canTrain
|
||||
) {
|
||||
Text("Create")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Year badge showing photo year
|
||||
*/
|
||||
@Composable
|
||||
private fun YearBadge(year: String) {
|
||||
Surface(
|
||||
color = MaterialTheme.colorScheme.secondaryContainer,
|
||||
shape = MaterialTheme.shapes.small
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier.padding(horizontal = 12.dp, vertical = 6.dp),
|
||||
verticalAlignment = Alignment.CenterVertically,
|
||||
horizontalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.DateRange,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(16.dp),
|
||||
tint = MaterialTheme.colorScheme.onSecondaryContainer
|
||||
)
|
||||
Text(
|
||||
text = "Photos from $year",
|
||||
style = MaterialTheme.typography.labelMedium,
|
||||
color = MaterialTheme.colorScheme.onSecondaryContainer
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Quality warnings
|
||||
*/
|
||||
@Composable
|
||||
private fun QualityWarnings(qualityResult: ClusterQualityResult) {
|
||||
if (qualityResult.warnings.isNotEmpty()) {
|
||||
Card(
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = when (qualityResult.qualityTier) {
|
||||
com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR ->
|
||||
MaterialTheme.colorScheme.errorContainer
|
||||
com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.GOOD ->
|
||||
MaterialTheme.colorScheme.tertiaryContainer
|
||||
else -> MaterialTheme.colorScheme.surfaceVariant
|
||||
}
|
||||
)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(12.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(4.dp)
|
||||
) {
|
||||
qualityResult.warnings.take(3).forEach { warning ->
|
||||
Row(
|
||||
verticalAlignment = Alignment.Top,
|
||||
horizontalArrangement = Arrangement.spacedBy(8.dp)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = when (qualityResult.qualityTier) {
|
||||
com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR ->
|
||||
Icons.Default.Warning
|
||||
else -> Icons.Default.Info
|
||||
},
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(16.dp),
|
||||
tint = when (qualityResult.qualityTier) {
|
||||
com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR ->
|
||||
MaterialTheme.colorScheme.onErrorContainer
|
||||
else -> MaterialTheme.colorScheme.onSurfaceVariant
|
||||
}
|
||||
)
|
||||
Text(
|
||||
text = warning,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = when (qualityResult.qualityTier) {
|
||||
com.placeholder.sherpai2.domain.clustering.ClusterQualityTier.POOR ->
|
||||
MaterialTheme.colorScheme.onErrorContainer
|
||||
else -> MaterialTheme.colorScheme.onSurfaceVariant
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cluster statistics
|
||||
*/
|
||||
@Composable
|
||||
private fun ClusterStats(qualityResult: ClusterQualityResult) {
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.SpaceEvenly
|
||||
) {
|
||||
StatItem(
|
||||
label = "Photos",
|
||||
value = qualityResult.soloPhotoCount.toString()
|
||||
)
|
||||
StatItem(
|
||||
label = "Clean Faces",
|
||||
value = qualityResult.cleanFaceCount.toString()
|
||||
)
|
||||
StatItem(
|
||||
label = "Quality",
|
||||
value = "${(qualityResult.qualityScore * 100).toInt()}%"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun StatItem(label: String, value: String) {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
Text(
|
||||
text = value,
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
Text(
|
||||
text = label,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build model name preview
|
||||
*/
|
||||
private fun buildModelName(name: String, ageText: String, year: String): String {
|
||||
return when {
|
||||
ageText.isNotBlank() -> "${name}_Age_${ageText}"
|
||||
else -> "${name}_${year}"
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,17 @@
|
||||
package com.placeholder.sherpai2.ui.discover
|
||||
|
||||
import android.net.Uri
|
||||
import androidx.compose.animation.AnimatedVisibility
|
||||
import androidx.compose.animation.core.animateFloatAsState
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.border
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.gestures.detectDragGestures
|
||||
import androidx.compose.foundation.layout.*
|
||||
import androidx.compose.foundation.lazy.grid.GridCells
|
||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||
import androidx.compose.foundation.lazy.grid.items
|
||||
import androidx.compose.foundation.shape.CircleShape
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material.icons.Icons
|
||||
import androidx.compose.material.icons.filled.*
|
||||
@@ -15,268 +20,458 @@ import androidx.compose.runtime.*
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.draw.scale
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.input.pointer.pointerInput
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
import androidx.compose.ui.unit.IntOffset
|
||||
import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.zIndex
|
||||
import coil.compose.AsyncImage
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationMatch
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationQuality
|
||||
import com.placeholder.sherpai2.data.local.entity.FeedbackType
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationScanResult
|
||||
import com.placeholder.sherpai2.domain.validation.ValidationMatch
|
||||
import kotlin.math.roundToInt
|
||||
|
||||
/**
|
||||
* ValidationPreviewScreen - STAGE 2 validation UI
|
||||
* ValidationPreviewScreen - User reviews validation results with swipe gestures
|
||||
*
|
||||
* Shows user a preview of matches found in validation scan
|
||||
* User can approve (→ full scan) or reject (→ add more photos)
|
||||
* FEATURES:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* ✅ Swipe right (✓) = Confirmed match
|
||||
* ✅ Swipe left (✗) = Rejected match
|
||||
* ✅ Tap = Mark uncertain (?)
|
||||
* ✅ Real-time feedback stats
|
||||
* ✅ Automatic refinement recommendation
|
||||
* ✅ Bottom bar with approve/reject/refine actions
|
||||
*
|
||||
* FLOW:
|
||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
||||
* 1. User swipes/taps to mark faces
|
||||
* 2. Feedback tracked in local state
|
||||
* 3. If >15% rejection → "Refine" button appears
|
||||
* 4. Approve → Sends feedback map to ViewModel
|
||||
* 5. Reject → Returns to previous screen
|
||||
* 6. Refine → Triggers cluster refinement
|
||||
*/
|
||||
@Composable
|
||||
fun ValidationPreviewScreen(
|
||||
personName: String,
|
||||
validationResult: ValidationScanResult,
|
||||
onMarkFeedback: (Map<String, FeedbackType>) -> Unit = {},
|
||||
onRequestRefinement: () -> Unit = {},
|
||||
onApprove: () -> Unit,
|
||||
onReject: () -> Unit,
|
||||
modifier: Modifier = Modifier
|
||||
) {
|
||||
// Get sample images from validation result matches
|
||||
val sampleMatches = remember(validationResult) {
|
||||
validationResult.matches.take(24) // Show up to 24 faces
|
||||
}
|
||||
|
||||
// Track feedback for each image (imageId -> FeedbackType)
|
||||
var feedbackMap by remember {
|
||||
mutableStateOf<Map<String, FeedbackType>>(emptyMap())
|
||||
}
|
||||
|
||||
// Calculate feedback statistics
|
||||
val confirmedCount = feedbackMap.count { it.value == FeedbackType.CONFIRMED_MATCH }
|
||||
val rejectedCount = feedbackMap.count { it.value == FeedbackType.REJECTED_MATCH }
|
||||
val uncertainCount = feedbackMap.count { it.value == FeedbackType.UNCERTAIN }
|
||||
val reviewedCount = feedbackMap.size
|
||||
val totalCount = sampleMatches.size
|
||||
|
||||
// Determine if refinement is recommended
|
||||
val rejectionRatio = if (reviewedCount > 0) {
|
||||
rejectedCount.toFloat() / reviewedCount.toFloat()
|
||||
} else {
|
||||
0f
|
||||
}
|
||||
val shouldRefine = rejectionRatio > 0.15f && rejectedCount >= 2
|
||||
|
||||
Scaffold(
|
||||
bottomBar = {
|
||||
ValidationBottomBar(
|
||||
confirmedCount = confirmedCount,
|
||||
rejectedCount = rejectedCount,
|
||||
uncertainCount = uncertainCount,
|
||||
reviewedCount = reviewedCount,
|
||||
totalCount = totalCount,
|
||||
shouldRefine = shouldRefine,
|
||||
onApprove = {
|
||||
onMarkFeedback(feedbackMap)
|
||||
onApprove()
|
||||
},
|
||||
onReject = onReject,
|
||||
onRefine = {
|
||||
onMarkFeedback(feedbackMap)
|
||||
onRequestRefinement()
|
||||
}
|
||||
)
|
||||
}
|
||||
) { paddingValues ->
|
||||
Column(
|
||||
modifier = modifier
|
||||
.fillMaxSize()
|
||||
.padding(paddingValues)
|
||||
.padding(16.dp)
|
||||
) {
|
||||
// Header
|
||||
Text(
|
||||
text = "Validation Results",
|
||||
text = "Validate \"$personName\"",
|
||||
style = MaterialTheme.typography.headlineMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = "Review matches for \"$personName\" before scanning your entire library",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
// Instructions
|
||||
InstructionsCard()
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Feedback stats
|
||||
FeedbackStatsCard(
|
||||
confirmedCount = confirmedCount,
|
||||
rejectedCount = rejectedCount,
|
||||
uncertainCount = uncertainCount,
|
||||
reviewedCount = reviewedCount,
|
||||
totalCount = totalCount
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Quality Summary
|
||||
QualitySummaryCard(
|
||||
validationResult = validationResult,
|
||||
personName = personName
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Matches Grid
|
||||
if (validationResult.matches.isNotEmpty()) {
|
||||
Text(
|
||||
text = "Sample Matches (${validationResult.matchCount})",
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.SemiBold
|
||||
)
|
||||
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
// Grid of faces to review
|
||||
LazyVerticalGrid(
|
||||
columns = GridCells.Fixed(3),
|
||||
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
items(validationResult.matches.take(15)) { match ->
|
||||
MatchPreviewCard(match = match)
|
||||
items(
|
||||
items = sampleMatches,
|
||||
key = { match -> match.imageId }
|
||||
) { match ->
|
||||
SwipeableFaceCard(
|
||||
match = match,
|
||||
currentFeedback = feedbackMap[match.imageId],
|
||||
onFeedbackChange = { feedback ->
|
||||
feedbackMap = feedbackMap.toMutableMap().apply {
|
||||
put(match.imageId, feedback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No matches found
|
||||
NoMatchesCard()
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
// Action Buttons
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
// Reject button
|
||||
OutlinedButton(
|
||||
onClick = onReject,
|
||||
modifier = Modifier.weight(1f),
|
||||
colors = ButtonDefaults.outlinedButtonColors(
|
||||
contentColor = MaterialTheme.colorScheme.error
|
||||
)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Close,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text("Add More Photos")
|
||||
}
|
||||
|
||||
// Approve button
|
||||
Button(
|
||||
onClick = onApprove,
|
||||
modifier = Modifier.weight(1f),
|
||||
enabled = validationResult.qualityAssessment != ValidationQuality.NO_MATCHES
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Check,
|
||||
contentDescription = null,
|
||||
modifier = Modifier.size(20.dp)
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text("Scan Library")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Swipeable face card with visual feedback indicators
|
||||
*/
|
||||
@Composable
|
||||
private fun QualitySummaryCard(
|
||||
validationResult: ValidationScanResult,
|
||||
personName: String
|
||||
private fun SwipeableFaceCard(
|
||||
match: ValidationMatch,
|
||||
currentFeedback: FeedbackType?,
|
||||
onFeedbackChange: (FeedbackType) -> Unit
|
||||
) {
|
||||
val (backgroundColor, iconColor, statusText, statusIcon) = when (validationResult.qualityAssessment) {
|
||||
ValidationQuality.EXCELLENT -> {
|
||||
Quadruple(
|
||||
Color(0xFF1B5E20).copy(alpha = 0.1f),
|
||||
Color(0xFF1B5E20),
|
||||
"Excellent Match Quality",
|
||||
Icons.Default.CheckCircle
|
||||
var offsetX by remember { mutableFloatStateOf(0f) }
|
||||
var isDragging by remember { mutableStateOf(false) }
|
||||
|
||||
val scale by animateFloatAsState(
|
||||
targetValue = if (isDragging) 1.1f else 1f,
|
||||
label = "scale"
|
||||
)
|
||||
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.aspectRatio(1f)
|
||||
.scale(scale)
|
||||
.zIndex(if (isDragging) 1f else 0f)
|
||||
) {
|
||||
// Face image with border color based on feedback
|
||||
AsyncImage(
|
||||
model = Uri.parse(match.imageUri),
|
||||
contentDescription = "Face",
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.clip(RoundedCornerShape(12.dp))
|
||||
.border(
|
||||
width = 3.dp,
|
||||
color = when (currentFeedback) {
|
||||
FeedbackType.CONFIRMED_MATCH -> Color(0xFF4CAF50) // Green
|
||||
FeedbackType.REJECTED_MATCH -> Color(0xFFF44336) // Red
|
||||
FeedbackType.UNCERTAIN -> Color(0xFFFF9800) // Orange
|
||||
else -> MaterialTheme.colorScheme.outline
|
||||
},
|
||||
shape = RoundedCornerShape(12.dp)
|
||||
)
|
||||
.offset { IntOffset(offsetX.roundToInt(), 0) }
|
||||
.pointerInput(Unit) {
|
||||
detectDragGestures(
|
||||
onDragStart = {
|
||||
isDragging = true
|
||||
},
|
||||
onDrag = { _, dragAmount ->
|
||||
offsetX += dragAmount.x
|
||||
},
|
||||
onDragEnd = {
|
||||
isDragging = false
|
||||
|
||||
// Determine feedback based on swipe direction
|
||||
when {
|
||||
offsetX > 100 -> {
|
||||
onFeedbackChange(FeedbackType.CONFIRMED_MATCH)
|
||||
}
|
||||
offsetX < -100 -> {
|
||||
onFeedbackChange(FeedbackType.REJECTED_MATCH)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset position
|
||||
offsetX = 0f
|
||||
},
|
||||
onDragCancel = {
|
||||
isDragging = false
|
||||
offsetX = 0f
|
||||
}
|
||||
)
|
||||
}
|
||||
ValidationQuality.GOOD -> {
|
||||
Quadruple(
|
||||
Color(0xFF2E7D32).copy(alpha = 0.1f),
|
||||
Color(0xFF2E7D32),
|
||||
"Good Match Quality",
|
||||
Icons.Default.ThumbUp
|
||||
.clickable {
|
||||
// Tap to toggle uncertain
|
||||
val newFeedback = when (currentFeedback) {
|
||||
FeedbackType.UNCERTAIN -> null
|
||||
else -> FeedbackType.UNCERTAIN
|
||||
}
|
||||
if (newFeedback != null) {
|
||||
onFeedbackChange(newFeedback)
|
||||
}
|
||||
},
|
||||
contentScale = ContentScale.Crop
|
||||
)
|
||||
|
||||
// Confidence badge (top-left)
|
||||
Surface(
|
||||
modifier = Modifier
|
||||
.align(Alignment.TopStart)
|
||||
.padding(4.dp),
|
||||
shape = RoundedCornerShape(4.dp),
|
||||
color = Color.Black.copy(alpha = 0.6f)
|
||||
) {
|
||||
Text(
|
||||
text = "${(match.confidence * 100).toInt()}%",
|
||||
modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp),
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = Color.White,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
}
|
||||
ValidationQuality.FAIR -> {
|
||||
Quadruple(
|
||||
Color(0xFFF57F17).copy(alpha = 0.1f),
|
||||
Color(0xFFF57F17),
|
||||
"Fair Match Quality",
|
||||
Icons.Default.Warning
|
||||
)
|
||||
}
|
||||
ValidationQuality.POOR -> {
|
||||
Quadruple(
|
||||
Color(0xFFD32F2F).copy(alpha = 0.1f),
|
||||
Color(0xFFD32F2F),
|
||||
"Poor Match Quality",
|
||||
Icons.Default.Warning
|
||||
)
|
||||
}
|
||||
ValidationQuality.NO_MATCHES -> {
|
||||
Quadruple(
|
||||
Color(0xFFD32F2F).copy(alpha = 0.1f),
|
||||
Color(0xFFD32F2F),
|
||||
"No Matches Found",
|
||||
Icons.Default.Close
|
||||
|
||||
// Feedback indicator overlay (top-right)
|
||||
if (currentFeedback != null) {
|
||||
Surface(
|
||||
modifier = Modifier
|
||||
.align(Alignment.TopEnd)
|
||||
.padding(4.dp),
|
||||
shape = CircleShape,
|
||||
color = when (currentFeedback) {
|
||||
FeedbackType.CONFIRMED_MATCH -> Color(0xFF4CAF50)
|
||||
FeedbackType.REJECTED_MATCH -> Color(0xFFF44336)
|
||||
FeedbackType.UNCERTAIN -> Color(0xFFFF9800)
|
||||
else -> Color.Transparent
|
||||
},
|
||||
shadowElevation = 2.dp
|
||||
) {
|
||||
Icon(
|
||||
imageVector = when (currentFeedback) {
|
||||
FeedbackType.CONFIRMED_MATCH -> Icons.Default.Check
|
||||
FeedbackType.REJECTED_MATCH -> Icons.Default.Close
|
||||
FeedbackType.UNCERTAIN -> Icons.Default.Warning
|
||||
else -> Icons.Default.Info
|
||||
},
|
||||
contentDescription = currentFeedback.name,
|
||||
tint = Color.White,
|
||||
modifier = Modifier
|
||||
.size(32.dp)
|
||||
.padding(6.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Swipe hint during drag
|
||||
if (isDragging) {
|
||||
SwipeDragHint(offsetX = offsetX)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Swipe drag hint overlay
|
||||
*/
|
||||
@Composable
|
||||
private fun BoxScope.SwipeDragHint(offsetX: Float) {
|
||||
val hintText = when {
|
||||
offsetX > 50 -> "✓ Correct"
|
||||
offsetX < -50 -> "✗ Incorrect"
|
||||
else -> "Keep swiping"
|
||||
}
|
||||
|
||||
val hintColor = when {
|
||||
offsetX > 50 -> Color(0xFF4CAF50)
|
||||
offsetX < -50 -> Color(0xFFF44336)
|
||||
else -> Color.Gray
|
||||
}
|
||||
|
||||
Surface(
|
||||
modifier = Modifier
|
||||
.align(Alignment.BottomCenter)
|
||||
.padding(8.dp),
|
||||
shape = RoundedCornerShape(4.dp),
|
||||
color = hintColor.copy(alpha = 0.9f)
|
||||
) {
|
||||
Text(
|
||||
text = hintText,
|
||||
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = Color.White,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Instructions card showing gesture controls
|
||||
*/
|
||||
@Composable
|
||||
private fun InstructionsCard() {
|
||||
Card(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = backgroundColor
|
||||
containerColor = MaterialTheme.colorScheme.primaryContainer
|
||||
)
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(16.dp)
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier.padding(16.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Icon(
|
||||
imageVector = statusIcon,
|
||||
imageVector = Icons.Default.Info,
|
||||
contentDescription = null,
|
||||
tint = iconColor,
|
||||
modifier = Modifier.size(24.dp)
|
||||
tint = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
|
||||
Spacer(modifier = Modifier.width(12.dp))
|
||||
|
||||
Column {
|
||||
Text(
|
||||
text = statusText,
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
text = "Review Detected Faces",
|
||||
style = MaterialTheme.typography.titleSmall,
|
||||
fontWeight = FontWeight.Bold,
|
||||
color = iconColor
|
||||
color = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
// Stats
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.SpaceBetween
|
||||
) {
|
||||
StatItem(
|
||||
label = "Matches Found",
|
||||
value = "${validationResult.matchCount} / ${validationResult.sampleSize}"
|
||||
)
|
||||
StatItem(
|
||||
label = "Avg Confidence",
|
||||
value = "${(validationResult.averageConfidence * 100).toInt()}%"
|
||||
)
|
||||
StatItem(
|
||||
label = "Threshold",
|
||||
value = "${(validationResult.threshold * 100).toInt()}%"
|
||||
)
|
||||
}
|
||||
|
||||
// Recommendation
|
||||
if (validationResult.qualityAssessment != ValidationQuality.NO_MATCHES) {
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
|
||||
val recommendation = when (validationResult.qualityAssessment) {
|
||||
ValidationQuality.EXCELLENT ->
|
||||
"✅ Model looks great! Safe to scan your full library."
|
||||
ValidationQuality.GOOD ->
|
||||
"✅ Model quality is good. You can proceed with the full scan."
|
||||
ValidationQuality.FAIR ->
|
||||
"⚠️ Model quality is acceptable but could be improved with more photos."
|
||||
ValidationQuality.POOR ->
|
||||
"⚠️ Consider adding more diverse, high-quality training photos."
|
||||
ValidationQuality.NO_MATCHES -> ""
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(4.dp))
|
||||
Text(
|
||||
text = recommendation,
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||
)
|
||||
} else {
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
Text(
|
||||
text = "No matches found. The model may need more or better training photos, or the validation sample didn't include $personName.",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.error
|
||||
text = "Swipe right ✅ for correct, left ❌ for incorrect, tap ❓ for uncertain",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onPrimaryContainer
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Feedback statistics card
|
||||
*/
|
||||
@Composable
|
||||
private fun StatItem(
|
||||
label: String,
|
||||
value: String
|
||||
private fun FeedbackStatsCard(
|
||||
confirmedCount: Int,
|
||||
rejectedCount: Int,
|
||||
uncertainCount: Int,
|
||||
reviewedCount: Int,
|
||||
totalCount: Int
|
||||
) {
|
||||
Card {
|
||||
Row(
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.padding(16.dp),
|
||||
horizontalArrangement = Arrangement.SpaceEvenly
|
||||
) {
|
||||
FeedbackStat(
|
||||
icon = Icons.Default.Check,
|
||||
color = Color(0xFF4CAF50),
|
||||
count = confirmedCount,
|
||||
label = "Correct"
|
||||
)
|
||||
|
||||
FeedbackStat(
|
||||
icon = Icons.Default.Close,
|
||||
color = Color(0xFFF44336),
|
||||
count = rejectedCount,
|
||||
label = "Incorrect"
|
||||
)
|
||||
|
||||
FeedbackStat(
|
||||
icon = Icons.Default.Warning,
|
||||
color = Color(0xFFFF9800),
|
||||
count = uncertainCount,
|
||||
label = "Uncertain"
|
||||
)
|
||||
}
|
||||
|
||||
val progressValue = if (totalCount > 0) {
|
||||
reviewedCount.toFloat() / totalCount.toFloat()
|
||||
} else {
|
||||
0f
|
||||
}
|
||||
|
||||
LinearProgressIndicator(
|
||||
progress = { progressValue },
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.height(4.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Individual feedback statistic item
|
||||
*/
|
||||
@Composable
|
||||
private fun FeedbackStat(
|
||||
icon: androidx.compose.ui.graphics.vector.ImageVector,
|
||||
color: Color,
|
||||
count: Int,
|
||||
label: String
|
||||
) {
|
||||
Column(
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
) {
|
||||
Surface(
|
||||
shape = CircleShape,
|
||||
color = color.copy(alpha = 0.2f)
|
||||
) {
|
||||
Icon(
|
||||
imageVector = icon,
|
||||
contentDescription = null,
|
||||
tint = color,
|
||||
modifier = Modifier
|
||||
.size(40.dp)
|
||||
.padding(8.dp)
|
||||
)
|
||||
}
|
||||
|
||||
Spacer(modifier = Modifier.height(4.dp))
|
||||
|
||||
Text(
|
||||
text = value,
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
text = count.toString(),
|
||||
style = MaterialTheme.typography.titleMedium,
|
||||
fontWeight = FontWeight.Bold
|
||||
)
|
||||
|
||||
Text(
|
||||
text = label,
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
@@ -285,111 +480,134 @@ private fun StatItem(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Bottom action bar with approve/reject/refine buttons
|
||||
*/
|
||||
@Composable
|
||||
private fun MatchPreviewCard(
|
||||
match: ValidationMatch
|
||||
private fun ValidationBottomBar(
|
||||
confirmedCount: Int,
|
||||
rejectedCount: Int,
|
||||
uncertainCount: Int,
|
||||
reviewedCount: Int,
|
||||
totalCount: Int,
|
||||
shouldRefine: Boolean,
|
||||
onApprove: () -> Unit,
|
||||
onReject: () -> Unit,
|
||||
onRefine: () -> Unit
|
||||
) {
|
||||
Box(
|
||||
modifier = Modifier
|
||||
.aspectRatio(1f)
|
||||
.clip(RoundedCornerShape(8.dp))
|
||||
.background(MaterialTheme.colorScheme.surfaceVariant)
|
||||
) {
|
||||
AsyncImage(
|
||||
model = Uri.parse(match.imageUri),
|
||||
contentDescription = "Match preview",
|
||||
modifier = Modifier.fillMaxSize(),
|
||||
contentScale = ContentScale.Crop
|
||||
)
|
||||
|
||||
// Confidence badge
|
||||
Surface(
|
||||
modifier = Modifier
|
||||
.align(Alignment.BottomEnd)
|
||||
.padding(4.dp),
|
||||
shape = RoundedCornerShape(4.dp),
|
||||
color = Color.Black.copy(alpha = 0.7f)
|
||||
) {
|
||||
Text(
|
||||
text = "${(match.confidence * 100).toInt()}%",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = Color.White,
|
||||
modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp)
|
||||
)
|
||||
}
|
||||
|
||||
// Face count indicator (if group photo)
|
||||
if (match.faceCount > 1) {
|
||||
Surface(
|
||||
modifier = Modifier
|
||||
.align(Alignment.TopEnd)
|
||||
.padding(4.dp),
|
||||
shape = RoundedCornerShape(4.dp),
|
||||
color = MaterialTheme.colorScheme.primary
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Person,
|
||||
contentDescription = null,
|
||||
tint = Color.White,
|
||||
modifier = Modifier.size(12.dp)
|
||||
)
|
||||
Text(
|
||||
text = "${match.faceCount}",
|
||||
style = MaterialTheme.typography.labelSmall,
|
||||
color = Color.White
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Composable
|
||||
private fun NoMatchesCard() {
|
||||
Card(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = MaterialTheme.colorScheme.errorContainer
|
||||
)
|
||||
color = MaterialTheme.colorScheme.surface,
|
||||
shadowElevation = 8.dp
|
||||
) {
|
||||
Column(
|
||||
modifier = Modifier.padding(24.dp),
|
||||
horizontalAlignment = Alignment.CenterHorizontally
|
||||
modifier = Modifier.padding(16.dp)
|
||||
) {
|
||||
// Refinement warning banner
|
||||
AnimatedVisibility(visible = shouldRefine) {
|
||||
RefinementWarningBanner(
|
||||
rejectedCount = rejectedCount,
|
||||
reviewedCount = reviewedCount,
|
||||
onRefine = onRefine
|
||||
)
|
||||
}
|
||||
|
||||
// Main action buttons
|
||||
Row(
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||
) {
|
||||
OutlinedButton(
|
||||
onClick = onReject,
|
||||
modifier = Modifier.weight(1f)
|
||||
) {
|
||||
Icon(Icons.Default.Close, contentDescription = null)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text("Reject")
|
||||
}
|
||||
|
||||
Button(
|
||||
onClick = onApprove,
|
||||
modifier = Modifier.weight(1f),
|
||||
enabled = confirmedCount > 0 || (reviewedCount == 0 && totalCount > 6)
|
||||
) {
|
||||
Icon(Icons.Default.Check, contentDescription = null)
|
||||
Spacer(modifier = Modifier.width(8.dp))
|
||||
Text("Approve")
|
||||
}
|
||||
}
|
||||
|
||||
// Review progress text
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
|
||||
Text(
|
||||
text = if (reviewedCount == 0) {
|
||||
"Review faces above or approve to continue"
|
||||
} else {
|
||||
"Reviewed $reviewedCount of $totalCount faces"
|
||||
},
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||
textAlign = TextAlign.Center,
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refinement warning banner component
|
||||
*/
|
||||
@Composable
|
||||
private fun RefinementWarningBanner(
|
||||
rejectedCount: Int,
|
||||
reviewedCount: Int,
|
||||
onRefine: () -> Unit
|
||||
) {
|
||||
Column {
|
||||
Card(
|
||||
colors = CardDefaults.cardColors(
|
||||
containerColor = MaterialTheme.colorScheme.errorContainer
|
||||
),
|
||||
modifier = Modifier.fillMaxWidth()
|
||||
) {
|
||||
Row(
|
||||
modifier = Modifier.padding(12.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Icon(
|
||||
imageVector = Icons.Default.Warning,
|
||||
contentDescription = null,
|
||||
tint = MaterialTheme.colorScheme.error,
|
||||
modifier = Modifier.size(48.dp)
|
||||
tint = MaterialTheme.colorScheme.onErrorContainer
|
||||
)
|
||||
Spacer(modifier = Modifier.height(16.dp))
|
||||
|
||||
Spacer(modifier = Modifier.width(12.dp))
|
||||
|
||||
Column(modifier = Modifier.weight(1f)) {
|
||||
Text(
|
||||
text = "No Matches Found",
|
||||
style = MaterialTheme.typography.titleLarge,
|
||||
text = "High Rejection Rate",
|
||||
style = MaterialTheme.typography.titleSmall,
|
||||
fontWeight = FontWeight.Bold,
|
||||
color = MaterialTheme.colorScheme.error
|
||||
)
|
||||
Spacer(modifier = Modifier.height(8.dp))
|
||||
Text(
|
||||
text = "The validation scan didn't find this person in the sample photos. This could mean:\n\n" +
|
||||
"• The model needs more training photos\n" +
|
||||
"• The training photos weren't diverse enough\n" +
|
||||
"• The person wasn't in the validation sample",
|
||||
style = MaterialTheme.typography.bodyMedium,
|
||||
color = MaterialTheme.colorScheme.onErrorContainer
|
||||
)
|
||||
Text(
|
||||
text = "${(rejectedCount.toFloat() / reviewedCount.toFloat() * 100).toInt()}% rejected. Consider refining the cluster.",
|
||||
style = MaterialTheme.typography.bodySmall,
|
||||
color = MaterialTheme.colorScheme.onErrorContainer
|
||||
)
|
||||
}
|
||||
|
||||
Button(
|
||||
onClick = onRefine,
|
||||
colors = ButtonDefaults.buttonColors(
|
||||
containerColor = MaterialTheme.colorScheme.error
|
||||
)
|
||||
) {
|
||||
Text("Refine")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper data class for quality indicator
|
||||
private data class Quadruple<A, B, C, D>(
|
||||
val first: A,
|
||||
val second: B,
|
||||
val third: C,
|
||||
val fourth: D
|
||||
)
|
||||
Spacer(modifier = Modifier.height(12.dp))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user