Compare commits
5 Commits
d1032a0e6e
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
804f3d5640 | ||
|
|
cfec2b980a | ||
|
|
1ef8faad17 | ||
|
|
941337f671 | ||
|
|
4aa3499bb3 |
4
.idea/deploymentTargetSelector.xml
generated
4
.idea/deploymentTargetSelector.xml
generated
@@ -4,10 +4,10 @@
|
|||||||
<selectionStates>
|
<selectionStates>
|
||||||
<SelectionState runConfigName="app">
|
<SelectionState runConfigName="app">
|
||||||
<option name="selectionMode" value="DROPDOWN" />
|
<option name="selectionMode" value="DROPDOWN" />
|
||||||
<DropdownSelection timestamp="2026-01-20T00:30:16.888577418Z">
|
<DropdownSelection timestamp="2026-01-27T00:21:15.014661014Z">
|
||||||
<Target type="DEFAULT_BOOT">
|
<Target type="DEFAULT_BOOT">
|
||||||
<handle>
|
<handle>
|
||||||
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />
|
<DeviceId pluginId="PhysicalDevice" identifier="serial=R3CX106YYCB" />
|
||||||
</handle>
|
</handle>
|
||||||
</Target>
|
</Target>
|
||||||
</DropdownSelection>
|
</DropdownSelection>
|
||||||
|
|||||||
93
.idea/deviceManager.xml
generated
93
.idea/deviceManager.xml
generated
@@ -3,6 +3,24 @@
|
|||||||
<component name="DeviceTable">
|
<component name="DeviceTable">
|
||||||
<option name="collapsedNodes">
|
<option name="collapsedNodes">
|
||||||
<list>
|
<list>
|
||||||
|
<CategoryListState>
|
||||||
|
<option name="categories">
|
||||||
|
<list>
|
||||||
|
<CategoryState>
|
||||||
|
<option name="attribute" value="Type" />
|
||||||
|
<option name="value" value="Virtual" />
|
||||||
|
</CategoryState>
|
||||||
|
<CategoryState>
|
||||||
|
<option name="attribute" value="Type" />
|
||||||
|
<option name="value" value="Virtual" />
|
||||||
|
</CategoryState>
|
||||||
|
<CategoryState>
|
||||||
|
<option name="attribute" value="Type" />
|
||||||
|
<option name="value" value="Virtual" />
|
||||||
|
</CategoryState>
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</CategoryListState>
|
||||||
<CategoryListState>
|
<CategoryListState>
|
||||||
<option name="categories">
|
<option name="categories">
|
||||||
<list>
|
<list>
|
||||||
@@ -42,6 +60,81 @@
|
|||||||
<option value="Type" />
|
<option value="Type" />
|
||||||
<option value="Type" />
|
<option value="Type" />
|
||||||
<option value="Type" />
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
|
<option value="Type" />
|
||||||
</list>
|
</list>
|
||||||
</option>
|
</option>
|
||||||
</component>
|
</component>
|
||||||
|
|||||||
@@ -48,6 +48,9 @@ dependencies {
|
|||||||
implementation(libs.androidx.lifecycle.viewmodel.compose)
|
implementation(libs.androidx.lifecycle.viewmodel.compose)
|
||||||
implementation(libs.androidx.activity.compose)
|
implementation(libs.androidx.activity.compose)
|
||||||
|
|
||||||
|
// DataStore Preferences
|
||||||
|
implementation("androidx.datastore:datastore-preferences:1.1.1")
|
||||||
|
|
||||||
// Compose
|
// Compose
|
||||||
implementation(platform(libs.androidx.compose.bom))
|
implementation(platform(libs.androidx.compose.bom))
|
||||||
implementation(libs.androidx.compose.ui)
|
implementation(libs.androidx.compose.ui)
|
||||||
|
|||||||
@@ -10,6 +10,10 @@ import com.placeholder.sherpai2.data.local.entity.*
|
|||||||
/**
|
/**
|
||||||
* AppDatabase - Complete database for SherpAI2
|
* AppDatabase - Complete database for SherpAI2
|
||||||
*
|
*
|
||||||
|
* VERSION 12 - Distribution-based rejection stats
|
||||||
|
* - Added similarityStdDev, similarityMin to FaceModelEntity
|
||||||
|
* - Enables self-calibrating threshold for face matching
|
||||||
|
*
|
||||||
* VERSION 10 - User Feedback Loop
|
* VERSION 10 - User Feedback Loop
|
||||||
* - Added UserFeedbackEntity for storing user corrections
|
* - Added UserFeedbackEntity for storing user corrections
|
||||||
* - Enables cluster refinement before training
|
* - Enables cluster refinement before training
|
||||||
@@ -44,14 +48,15 @@ import com.placeholder.sherpai2.data.local.entity.*
|
|||||||
PhotoFaceTagEntity::class,
|
PhotoFaceTagEntity::class,
|
||||||
PersonAgeTagEntity::class,
|
PersonAgeTagEntity::class,
|
||||||
FaceCacheEntity::class,
|
FaceCacheEntity::class,
|
||||||
UserFeedbackEntity::class, // NEW: User corrections
|
UserFeedbackEntity::class,
|
||||||
|
PersonStatisticsEntity::class, // Pre-computed person stats
|
||||||
|
|
||||||
// ===== COLLECTIONS =====
|
// ===== COLLECTIONS =====
|
||||||
CollectionEntity::class,
|
CollectionEntity::class,
|
||||||
CollectionImageEntity::class,
|
CollectionImageEntity::class,
|
||||||
CollectionFilterEntity::class
|
CollectionFilterEntity::class
|
||||||
],
|
],
|
||||||
version = 10, // INCREMENTED for user feedback
|
version = 12, // INCREMENTED for distribution-based rejection stats
|
||||||
exportSchema = false
|
exportSchema = false
|
||||||
)
|
)
|
||||||
abstract class AppDatabase : RoomDatabase() {
|
abstract class AppDatabase : RoomDatabase() {
|
||||||
@@ -70,7 +75,8 @@ abstract class AppDatabase : RoomDatabase() {
|
|||||||
abstract fun photoFaceTagDao(): PhotoFaceTagDao
|
abstract fun photoFaceTagDao(): PhotoFaceTagDao
|
||||||
abstract fun personAgeTagDao(): PersonAgeTagDao
|
abstract fun personAgeTagDao(): PersonAgeTagDao
|
||||||
abstract fun faceCacheDao(): FaceCacheDao
|
abstract fun faceCacheDao(): FaceCacheDao
|
||||||
abstract fun userFeedbackDao(): UserFeedbackDao // NEW
|
abstract fun userFeedbackDao(): UserFeedbackDao
|
||||||
|
abstract fun personStatisticsDao(): PersonStatisticsDao
|
||||||
|
|
||||||
// ===== COLLECTIONS DAO =====
|
// ===== COLLECTIONS DAO =====
|
||||||
abstract fun collectionDao(): CollectionDao
|
abstract fun collectionDao(): CollectionDao
|
||||||
@@ -242,13 +248,60 @@ val MIGRATION_9_10 = object : Migration(9, 10) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MIGRATION 10 → 11 (Person Statistics)
|
||||||
|
*
|
||||||
|
* Changes:
|
||||||
|
* 1. Create person_statistics table for pre-computed aggregates
|
||||||
|
*/
|
||||||
|
val MIGRATION_10_11 = object : Migration(10, 11) {
|
||||||
|
override fun migrate(database: SupportSQLiteDatabase) {
|
||||||
|
|
||||||
|
// Create person_statistics table
|
||||||
|
database.execSQL("""
|
||||||
|
CREATE TABLE IF NOT EXISTS person_statistics (
|
||||||
|
personId TEXT PRIMARY KEY NOT NULL,
|
||||||
|
photoCount INTEGER NOT NULL DEFAULT 0,
|
||||||
|
firstPhotoDate INTEGER NOT NULL DEFAULT 0,
|
||||||
|
lastPhotoDate INTEGER NOT NULL DEFAULT 0,
|
||||||
|
averageConfidence REAL NOT NULL DEFAULT 0,
|
||||||
|
agesWithPhotos TEXT,
|
||||||
|
updatedAt INTEGER NOT NULL DEFAULT 0,
|
||||||
|
FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
// Index for sorting by photo count (People Dashboard)
|
||||||
|
database.execSQL("CREATE INDEX IF NOT EXISTS index_person_statistics_photoCount ON person_statistics(photoCount)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MIGRATION 11 → 12 (Distribution-based Rejection Stats)
|
||||||
|
*
|
||||||
|
* Changes:
|
||||||
|
* 1. Add similarityStdDev column to face_models (default 0.05)
|
||||||
|
* 2. Add similarityMin column to face_models (default 0.6)
|
||||||
|
*
|
||||||
|
* These fields enable self-calibrating thresholds during scanning.
|
||||||
|
* During training, we compute stats from training sample similarities
|
||||||
|
* and use (mean - 2*stdDev) as a floor for matching.
|
||||||
|
*/
|
||||||
|
val MIGRATION_11_12 = object : Migration(11, 12) {
|
||||||
|
override fun migrate(database: SupportSQLiteDatabase) {
|
||||||
|
// Add distribution stats columns with sensible defaults for existing models
|
||||||
|
database.execSQL("ALTER TABLE face_models ADD COLUMN similarityStdDev REAL NOT NULL DEFAULT 0.05")
|
||||||
|
database.execSQL("ALTER TABLE face_models ADD COLUMN similarityMin REAL NOT NULL DEFAULT 0.6")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* PRODUCTION MIGRATION NOTES:
|
* PRODUCTION MIGRATION NOTES:
|
||||||
*
|
*
|
||||||
* Before shipping to users, update DatabaseModule to use migrations:
|
* Before shipping to users, update DatabaseModule to use migrations:
|
||||||
*
|
*
|
||||||
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
|
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
|
||||||
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations
|
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11, MIGRATION_11_12) // Add all migrations
|
||||||
* // .fallbackToDestructiveMigration() // Remove this
|
* // .fallbackToDestructiveMigration() // Remove this
|
||||||
* .build()
|
* .build()
|
||||||
*/
|
*/
|
||||||
@@ -4,21 +4,9 @@ import androidx.room.*
|
|||||||
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* FaceCacheDao - NO SOLO-PHOTO FILTER
|
* FaceCacheDao - ENHANCED with Rolling Scan support
|
||||||
*
|
*
|
||||||
* CRITICAL CHANGE:
|
* FIXED: Replaced Map return type with proper data class
|
||||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
|
||||||
* Removed all faceCount filters from queries
|
|
||||||
*
|
|
||||||
* WHY:
|
|
||||||
* - Group photos contain high-quality faces (especially for children)
|
|
||||||
* - IoU matching ensures we extract the CORRECT face from group photos
|
|
||||||
* - Rejecting group photos was eliminating 60-70% of quality faces!
|
|
||||||
*
|
|
||||||
* RESULT:
|
|
||||||
* - 2-3x more faces for clustering
|
|
||||||
* - Quality remains high (still filter by size + score)
|
|
||||||
* - Better clusters, especially for children
|
|
||||||
*/
|
*/
|
||||||
@Dao
|
@Dao
|
||||||
interface FaceCacheDao {
|
interface FaceCacheDao {
|
||||||
@@ -124,8 +112,179 @@ interface FaceCacheDao {
|
|||||||
|
|
||||||
@Query("DELETE FROM face_cache")
|
@Query("DELETE FROM face_cache")
|
||||||
suspend fun deleteAll()
|
suspend fun deleteAll()
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
// NEW: ROLLING SCAN SUPPORT
|
||||||
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CRITICAL: Batch get face cache entries by image IDs
|
||||||
|
*
|
||||||
|
* Used by FaceSimilarityScorer to retrieve embeddings for scoring
|
||||||
|
*
|
||||||
|
* Performance: ~10ms for 1000 images with index on imageId
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT * FROM face_cache
|
||||||
|
WHERE imageId IN (:imageIds)
|
||||||
|
AND embedding IS NOT NULL
|
||||||
|
ORDER BY qualityScore DESC
|
||||||
|
""")
|
||||||
|
suspend fun getFaceCacheByImageIds(imageIds: List<String>): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get ALL photos with cached faces for rolling scan
|
||||||
|
*
|
||||||
|
* Returns all high-quality faces with embeddings
|
||||||
|
* Sorted by quality (solo photos first due to quality boost)
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT * FROM face_cache
|
||||||
|
WHERE embedding IS NOT NULL
|
||||||
|
AND qualityScore >= :minQuality
|
||||||
|
AND faceAreaRatio >= :minRatio
|
||||||
|
ORDER BY qualityScore DESC, faceAreaRatio DESC
|
||||||
|
""")
|
||||||
|
suspend fun getAllPhotosWithFacesForScanning(
|
||||||
|
minQuality: Float = 0.6f,
|
||||||
|
minRatio: Float = 0.03f
|
||||||
|
): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get embedding for a single image
|
||||||
|
*
|
||||||
|
* If multiple faces in image, returns the highest quality face
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT * FROM face_cache
|
||||||
|
WHERE imageId = :imageId
|
||||||
|
AND embedding IS NOT NULL
|
||||||
|
ORDER BY qualityScore DESC
|
||||||
|
LIMIT 1
|
||||||
|
""")
|
||||||
|
suspend fun getEmbeddingByImageId(imageId: String): FaceCacheEntity?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get distinct image IDs with cached embeddings
|
||||||
|
*
|
||||||
|
* Useful for getting list of all scannable images
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT DISTINCT imageId FROM face_cache
|
||||||
|
WHERE embedding IS NOT NULL
|
||||||
|
AND qualityScore >= :minQuality
|
||||||
|
ORDER BY qualityScore DESC
|
||||||
|
""")
|
||||||
|
suspend fun getDistinctImageIdsWithEmbeddings(
|
||||||
|
minQuality: Float = 0.6f
|
||||||
|
): List<String>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get face count per image (for quality boosting)
|
||||||
|
*
|
||||||
|
* FIXED: Returns List<ImageFaceCount> instead of Map
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT imageId, COUNT(*) as faceCount
|
||||||
|
FROM face_cache
|
||||||
|
WHERE embedding IS NOT NULL
|
||||||
|
GROUP BY imageId
|
||||||
|
""")
|
||||||
|
suspend fun getFaceCountsPerImage(): List<ImageFaceCount>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get embeddings for specific images (for centroid calculation)
|
||||||
|
*
|
||||||
|
* Used when initializing rolling scan with seed photos
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT * FROM face_cache
|
||||||
|
WHERE imageId IN (:imageIds)
|
||||||
|
AND embedding IS NOT NULL
|
||||||
|
ORDER BY qualityScore DESC
|
||||||
|
""")
|
||||||
|
suspend fun getEmbeddingsForImages(imageIds: List<String>): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
// PREMIUM FACES - For training photo selection
|
||||||
|
// ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get PREMIUM faces only - ideal for training seeds
|
||||||
|
*
|
||||||
|
* Premium = solo photo (faceCount=1) + large face + frontal + high quality
|
||||||
|
*
|
||||||
|
* These are the clearest, most unambiguous faces for user to pick seeds from.
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT fc.* FROM face_cache fc
|
||||||
|
INNER JOIN images i ON fc.imageId = i.imageId
|
||||||
|
WHERE i.faceCount = 1
|
||||||
|
AND fc.faceAreaRatio >= :minAreaRatio
|
||||||
|
AND fc.isFrontal = 1
|
||||||
|
AND fc.qualityScore >= :minQuality
|
||||||
|
AND fc.embedding IS NOT NULL
|
||||||
|
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
|
||||||
|
LIMIT :limit
|
||||||
|
""")
|
||||||
|
suspend fun getPremiumFaces(
|
||||||
|
minAreaRatio: Float = 0.10f,
|
||||||
|
minQuality: Float = 0.7f,
|
||||||
|
limit: Int = 500
|
||||||
|
): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get premium face CANDIDATES - same criteria but WITHOUT embedding requirement.
|
||||||
|
* Used to find faces that need embedding generation.
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT fc.* FROM face_cache fc
|
||||||
|
INNER JOIN images i ON fc.imageId = i.imageId
|
||||||
|
WHERE i.faceCount = 1
|
||||||
|
AND fc.faceAreaRatio >= :minAreaRatio
|
||||||
|
AND fc.isFrontal = 1
|
||||||
|
AND fc.qualityScore >= :minQuality
|
||||||
|
AND fc.embedding IS NULL
|
||||||
|
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
|
||||||
|
LIMIT :limit
|
||||||
|
""")
|
||||||
|
suspend fun getPremiumFaceCandidatesNeedingEmbeddings(
|
||||||
|
minAreaRatio: Float = 0.10f,
|
||||||
|
minQuality: Float = 0.7f,
|
||||||
|
limit: Int = 500
|
||||||
|
): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update embedding for a face cache entry
|
||||||
|
*/
|
||||||
|
@Query("UPDATE face_cache SET embedding = :embedding WHERE imageId = :imageId AND faceIndex = :faceIndex")
|
||||||
|
suspend fun updateEmbedding(imageId: String, faceIndex: Int, embedding: String)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Count of premium faces available
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT COUNT(*) FROM face_cache fc
|
||||||
|
INNER JOIN images i ON fc.imageId = i.imageId
|
||||||
|
WHERE i.faceCount = 1
|
||||||
|
AND fc.faceAreaRatio >= 0.10
|
||||||
|
AND fc.isFrontal = 1
|
||||||
|
AND fc.qualityScore >= 0.7
|
||||||
|
AND fc.embedding IS NOT NULL
|
||||||
|
""")
|
||||||
|
suspend fun countPremiumFaces(): Int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Data class for face count per image
|
||||||
|
*
|
||||||
|
* Used by getFaceCountsPerImage() query
|
||||||
|
*/
|
||||||
|
data class ImageFaceCount(
|
||||||
|
val imageId: String,
|
||||||
|
val faceCount: Int
|
||||||
|
)
|
||||||
|
|
||||||
data class CacheStats(
|
data class CacheStats(
|
||||||
val totalFaces: Int,
|
val totalFaces: Int,
|
||||||
val withEmbeddings: Int,
|
val withEmbeddings: Int,
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ interface ImageDao {
|
|||||||
@Query("SELECT * FROM images WHERE imageId = :imageId")
|
@Query("SELECT * FROM images WHERE imageId = :imageId")
|
||||||
suspend fun getImageById(imageId: String): ImageEntity?
|
suspend fun getImageById(imageId: String): ImageEntity?
|
||||||
|
|
||||||
|
@Query("SELECT * FROM images WHERE imageUri = :uri LIMIT 1")
|
||||||
|
suspend fun getImageByUri(uri: String): ImageEntity?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stream images ordered by capture time (newest first).
|
* Stream images ordered by capture time (newest first).
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -83,9 +83,89 @@ interface PhotoFaceTagDao {
|
|||||||
*/
|
*/
|
||||||
@Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit")
|
@Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit")
|
||||||
suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity>
|
suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity>
|
||||||
|
|
||||||
|
// ===== CO-OCCURRENCE QUERIES =====
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find people who appear in photos together with a given person.
|
||||||
|
* Returns list of (otherFaceModelId, count) sorted by count descending.
|
||||||
|
* Use case: "Who appears most with Mom?" or "Show photos of Mom WITH Dad"
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT pft2.faceModelId as otherFaceModelId, COUNT(DISTINCT pft1.imageId) as coCount
|
||||||
|
FROM photo_face_tags pft1
|
||||||
|
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
|
||||||
|
WHERE pft1.faceModelId = :faceModelId
|
||||||
|
AND pft2.faceModelId != :faceModelId
|
||||||
|
GROUP BY pft2.faceModelId
|
||||||
|
ORDER BY coCount DESC
|
||||||
|
""")
|
||||||
|
suspend fun getCoOccurrences(faceModelId: String): List<PersonCoOccurrence>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get images where BOTH people appear together.
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT DISTINCT pft1.imageId
|
||||||
|
FROM photo_face_tags pft1
|
||||||
|
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
|
||||||
|
WHERE pft1.faceModelId = :faceModelId1
|
||||||
|
AND pft2.faceModelId = :faceModelId2
|
||||||
|
ORDER BY pft1.detectedAt DESC
|
||||||
|
""")
|
||||||
|
suspend fun getImagesWithBothPeople(faceModelId1: String, faceModelId2: String): List<String>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get images where person appears ALONE (no other trained faces).
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT imageId FROM photo_face_tags
|
||||||
|
WHERE faceModelId = :faceModelId
|
||||||
|
AND imageId NOT IN (
|
||||||
|
SELECT imageId FROM photo_face_tags
|
||||||
|
WHERE faceModelId != :faceModelId
|
||||||
|
)
|
||||||
|
ORDER BY detectedAt DESC
|
||||||
|
""")
|
||||||
|
suspend fun getImagesWithPersonAlone(faceModelId: String): List<String>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get images where ALL specified people appear (N-way intersection).
|
||||||
|
* For "Intersection Search" moonshot feature.
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT imageId FROM photo_face_tags
|
||||||
|
WHERE faceModelId IN (:faceModelIds)
|
||||||
|
GROUP BY imageId
|
||||||
|
HAVING COUNT(DISTINCT faceModelId) = :requiredCount
|
||||||
|
""")
|
||||||
|
suspend fun getImagesWithAllPeople(faceModelIds: List<String>, requiredCount: Int): List<String>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get images with at least N of the specified people (family portrait detection).
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT imageId, COUNT(DISTINCT faceModelId) as memberCount
|
||||||
|
FROM photo_face_tags
|
||||||
|
WHERE faceModelId IN (:faceModelIds)
|
||||||
|
GROUP BY imageId
|
||||||
|
HAVING memberCount >= :minMembers
|
||||||
|
ORDER BY memberCount DESC
|
||||||
|
""")
|
||||||
|
suspend fun getFamilyPortraits(faceModelIds: List<String>, minMembers: Int): List<FamilyPortraitResult>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data class FamilyPortraitResult(
|
||||||
|
val imageId: String,
|
||||||
|
val memberCount: Int
|
||||||
|
)
|
||||||
|
|
||||||
data class FaceModelPhotoCount(
|
data class FaceModelPhotoCount(
|
||||||
val faceModelId: String,
|
val faceModelId: String,
|
||||||
val photoCount: Int
|
val photoCount: Int
|
||||||
)
|
)
|
||||||
|
|
||||||
|
data class PersonCoOccurrence(
|
||||||
|
val otherFaceModelId: String,
|
||||||
|
val coCount: Int
|
||||||
|
)
|
||||||
|
|||||||
@@ -99,6 +99,13 @@ data class FaceCacheEntity(
|
|||||||
companion object {
|
companion object {
|
||||||
const val CURRENT_CACHE_VERSION = 1
|
const val CURRENT_CACHE_VERSION = 1
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert FloatArray embedding to JSON string for storage
|
||||||
|
*/
|
||||||
|
fun embeddingToJson(embedding: FloatArray): String {
|
||||||
|
return embedding.joinToString(",")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create from ML Kit face detection result
|
* Create from ML Kit face detection result
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -143,6 +143,13 @@ data class FaceModelEntity(
|
|||||||
@ColumnInfo(name = "averageConfidence")
|
@ColumnInfo(name = "averageConfidence")
|
||||||
val averageConfidence: Float,
|
val averageConfidence: Float,
|
||||||
|
|
||||||
|
// Distribution stats for self-calibrating rejection
|
||||||
|
@ColumnInfo(name = "similarityStdDev")
|
||||||
|
val similarityStdDev: Float = 0.05f, // Default for backwards compat
|
||||||
|
|
||||||
|
@ColumnInfo(name = "similarityMin")
|
||||||
|
val similarityMin: Float = 0.6f, // Default for backwards compat
|
||||||
|
|
||||||
@ColumnInfo(name = "createdAt")
|
@ColumnInfo(name = "createdAt")
|
||||||
val createdAt: Long,
|
val createdAt: Long,
|
||||||
|
|
||||||
@@ -157,26 +164,29 @@ data class FaceModelEntity(
|
|||||||
) {
|
) {
|
||||||
companion object {
|
companion object {
|
||||||
/**
|
/**
|
||||||
* Backwards compatible create() method
|
* Create with distribution stats for self-calibrating rejection
|
||||||
* Used by existing FaceRecognitionRepository code
|
|
||||||
*/
|
*/
|
||||||
fun create(
|
fun create(
|
||||||
personId: String,
|
personId: String,
|
||||||
embeddingArray: FloatArray,
|
embeddingArray: FloatArray,
|
||||||
trainingImageCount: Int,
|
trainingImageCount: Int,
|
||||||
averageConfidence: Float
|
averageConfidence: Float,
|
||||||
|
similarityStdDev: Float = 0.05f,
|
||||||
|
similarityMin: Float = 0.6f
|
||||||
): FaceModelEntity {
|
): FaceModelEntity {
|
||||||
return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence)
|
return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence, similarityStdDev, similarityMin)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create from single embedding (backwards compatible)
|
* Create from single embedding with distribution stats
|
||||||
*/
|
*/
|
||||||
fun createFromEmbedding(
|
fun createFromEmbedding(
|
||||||
personId: String,
|
personId: String,
|
||||||
embeddingArray: FloatArray,
|
embeddingArray: FloatArray,
|
||||||
trainingImageCount: Int,
|
trainingImageCount: Int,
|
||||||
averageConfidence: Float
|
averageConfidence: Float,
|
||||||
|
similarityStdDev: Float = 0.05f,
|
||||||
|
similarityMin: Float = 0.6f
|
||||||
): FaceModelEntity {
|
): FaceModelEntity {
|
||||||
val now = System.currentTimeMillis()
|
val now = System.currentTimeMillis()
|
||||||
val centroid = TemporalCentroid(
|
val centroid = TemporalCentroid(
|
||||||
@@ -194,6 +204,8 @@ data class FaceModelEntity(
|
|||||||
centroidsJson = serializeCentroids(listOf(centroid)),
|
centroidsJson = serializeCentroids(listOf(centroid)),
|
||||||
trainingImageCount = trainingImageCount,
|
trainingImageCount = trainingImageCount,
|
||||||
averageConfidence = averageConfidence,
|
averageConfidence = averageConfidence,
|
||||||
|
similarityStdDev = similarityStdDev,
|
||||||
|
similarityMin = similarityMin,
|
||||||
createdAt = now,
|
createdAt = now,
|
||||||
updatedAt = now,
|
updatedAt = now,
|
||||||
lastUsed = null,
|
lastUsed = null,
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package com.placeholder.sherpai2.data.repository
|
|||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.graphics.Bitmap
|
import android.graphics.Bitmap
|
||||||
|
import android.util.Log
|
||||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PersonAgeTagDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||||
import com.placeholder.sherpai2.data.local.entity.*
|
import com.placeholder.sherpai2.data.local.entity.*
|
||||||
@@ -31,8 +33,12 @@ class FaceRecognitionRepository @Inject constructor(
|
|||||||
private val personDao: PersonDao,
|
private val personDao: PersonDao,
|
||||||
private val imageDao: ImageDao,
|
private val imageDao: ImageDao,
|
||||||
private val faceModelDao: FaceModelDao,
|
private val faceModelDao: FaceModelDao,
|
||||||
private val photoFaceTagDao: PhotoFaceTagDao
|
private val photoFaceTagDao: PhotoFaceTagDao,
|
||||||
|
private val personAgeTagDao: PersonAgeTagDao
|
||||||
) {
|
) {
|
||||||
|
companion object {
|
||||||
|
private const val TAG = "FaceRecognitionRepo"
|
||||||
|
}
|
||||||
|
|
||||||
private val faceNetModel by lazy { FaceNetModel(context) }
|
private val faceNetModel by lazy { FaceNetModel(context) }
|
||||||
|
|
||||||
@@ -93,11 +99,19 @@ class FaceRecognitionRepository @Inject constructor(
|
|||||||
}
|
}
|
||||||
val avgConfidence = confidences.average().toFloat()
|
val avgConfidence = confidences.average().toFloat()
|
||||||
|
|
||||||
|
// Compute distribution stats for self-calibrating rejection
|
||||||
|
val stdDev = kotlin.math.sqrt(
|
||||||
|
confidences.map { (it - avgConfidence).toDouble().let { d -> d * d } }.average()
|
||||||
|
).toFloat()
|
||||||
|
val minSimilarity = confidences.minOrNull() ?: 0f
|
||||||
|
|
||||||
val faceModel = FaceModelEntity.create(
|
val faceModel = FaceModelEntity.create(
|
||||||
personId = personId,
|
personId = personId,
|
||||||
embeddingArray = personEmbedding,
|
embeddingArray = personEmbedding,
|
||||||
trainingImageCount = validImages.size,
|
trainingImageCount = validImages.size,
|
||||||
averageConfidence = avgConfidence
|
averageConfidence = avgConfidence,
|
||||||
|
similarityStdDev = stdDev,
|
||||||
|
similarityMin = minSimilarity
|
||||||
)
|
)
|
||||||
|
|
||||||
faceModelDao.insertFaceModel(faceModel)
|
faceModelDao.insertFaceModel(faceModel)
|
||||||
@@ -181,12 +195,15 @@ class FaceRecognitionRepository @Inject constructor(
|
|||||||
var highestSimilarity = threshold
|
var highestSimilarity = threshold
|
||||||
|
|
||||||
for (faceModel in faceModels) {
|
for (faceModel in faceModels) {
|
||||||
val modelEmbedding = faceModel.getEmbeddingArray()
|
// Check ALL centroids for best match (critical for children with age centroids)
|
||||||
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
val centroids = faceModel.getCentroids()
|
||||||
|
val bestCentroidSimilarity = centroids.maxOfOrNull { centroid ->
|
||||||
|
faceNetModel.calculateSimilarity(faceEmbedding, centroid.getEmbeddingArray())
|
||||||
|
} ?: 0f
|
||||||
|
|
||||||
if (similarity > highestSimilarity) {
|
if (bestCentroidSimilarity > highestSimilarity) {
|
||||||
highestSimilarity = similarity
|
highestSimilarity = bestCentroidSimilarity
|
||||||
bestMatch = Pair(faceModel.id, similarity)
|
bestMatch = Pair(faceModel.id, bestCentroidSimilarity)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -374,9 +391,49 @@ class FaceRecognitionRepository @Inject constructor(
|
|||||||
onProgress = onProgress
|
onProgress = onProgress
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Generate age tags for children
|
||||||
|
if (person.isChild && person.dateOfBirth != null) {
|
||||||
|
generateAgeTagsForTraining(person, validImages)
|
||||||
|
}
|
||||||
|
|
||||||
person.id
|
person.id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate age tags from training images for a child
|
||||||
|
*/
|
||||||
|
private suspend fun generateAgeTagsForTraining(
|
||||||
|
person: PersonEntity,
|
||||||
|
validImages: List<TrainingSanityChecker.ValidTrainingImage>
|
||||||
|
) {
|
||||||
|
try {
|
||||||
|
val dob = person.dateOfBirth ?: return
|
||||||
|
|
||||||
|
val tags = validImages.mapNotNull { img ->
|
||||||
|
val imageEntity = imageDao.getImageByUri(img.uri.toString()) ?: return@mapNotNull null
|
||||||
|
val ageMs = imageEntity.capturedAt - dob
|
||||||
|
val ageYears = (ageMs / (365.25 * 24 * 60 * 60 * 1000)).toInt()
|
||||||
|
|
||||||
|
if (ageYears < 0 || ageYears > 25) return@mapNotNull null
|
||||||
|
|
||||||
|
PersonAgeTagEntity.create(
|
||||||
|
personId = person.id,
|
||||||
|
personName = person.name,
|
||||||
|
imageId = imageEntity.imageId,
|
||||||
|
ageAtCapture = ageYears,
|
||||||
|
confidence = 1.0f
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tags.isNotEmpty()) {
|
||||||
|
personAgeTagDao.insertTags(tags)
|
||||||
|
Log.d(TAG, "Created ${tags.size} age tags for ${person.name}")
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Failed to generate age tags", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get face model by ID
|
* Get face model by ID
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -61,14 +61,16 @@ abstract class RepositoryModule {
|
|||||||
personDao: PersonDao,
|
personDao: PersonDao,
|
||||||
imageDao: ImageDao,
|
imageDao: ImageDao,
|
||||||
faceModelDao: FaceModelDao,
|
faceModelDao: FaceModelDao,
|
||||||
photoFaceTagDao: PhotoFaceTagDao
|
photoFaceTagDao: PhotoFaceTagDao,
|
||||||
|
personAgeTagDao: PersonAgeTagDao
|
||||||
): FaceRecognitionRepository {
|
): FaceRecognitionRepository {
|
||||||
return FaceRecognitionRepository(
|
return FaceRecognitionRepository(
|
||||||
context = context,
|
context = context,
|
||||||
personDao = personDao,
|
personDao = personDao,
|
||||||
imageDao = imageDao,
|
imageDao = imageDao,
|
||||||
faceModelDao = faceModelDao,
|
faceModelDao = faceModelDao,
|
||||||
photoFaceTagDao = photoFaceTagDao
|
photoFaceTagDao = photoFaceTagDao,
|
||||||
|
personAgeTagDao = personAgeTagDao
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package com.placeholder.sherpai2.di
|
||||||
|
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||||
|
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
|
||||||
|
import dagger.Module
|
||||||
|
import dagger.Provides
|
||||||
|
import dagger.hilt.InstallIn
|
||||||
|
import dagger.hilt.components.SingletonComponent
|
||||||
|
import javax.inject.Singleton
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SimilarityModule - Provides similarity scoring dependencies
|
||||||
|
*
|
||||||
|
* This module provides FaceSimilarityScorer for Rolling Scan feature
|
||||||
|
*/
|
||||||
|
@Module
|
||||||
|
@InstallIn(SingletonComponent::class)
|
||||||
|
object SimilarityModule {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provide FaceSimilarityScorer singleton
|
||||||
|
*
|
||||||
|
* FaceSimilarityScorer handles real-time similarity scoring
|
||||||
|
* for the Rolling Scan feature
|
||||||
|
*/
|
||||||
|
@Provides
|
||||||
|
@Singleton
|
||||||
|
fun provideFaceSimilarityScorer(
|
||||||
|
faceCacheDao: FaceCacheDao
|
||||||
|
): FaceSimilarityScorer {
|
||||||
|
return FaceSimilarityScorer(faceCacheDao)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ import com.placeholder.sherpai2.data.local.dao.ImageDao
|
|||||||
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNormalizer
|
||||||
import com.placeholder.sherpai2.ui.discover.DiscoverySettings
|
import com.placeholder.sherpai2.ui.discover.DiscoverySettings
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
@@ -344,14 +345,9 @@ class FaceClusteringService @Inject constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Crop and generate embedding
|
// Crop and normalize face
|
||||||
val faceBitmap = Bitmap.createBitmap(
|
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, mlFace)
|
||||||
bitmap,
|
?: return@forEach
|
||||||
mlFace.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
|
||||||
mlFace.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
|
||||||
mlFace.boundingBox.width().coerceAtMost(bitmap.width - mlFace.boundingBox.left),
|
|
||||||
mlFace.boundingBox.height().coerceAtMost(bitmap.height - mlFace.boundingBox.top)
|
|
||||||
)
|
|
||||||
|
|
||||||
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
faceBitmap.recycle()
|
faceBitmap.recycle()
|
||||||
@@ -591,13 +587,8 @@ class FaceClusteringService @Inject constructor(
|
|||||||
if (!qualityCheck.isValid) return@mapNotNull null
|
if (!qualityCheck.isValid) return@mapNotNull null
|
||||||
|
|
||||||
try {
|
try {
|
||||||
val faceBitmap = Bitmap.createBitmap(
|
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
|
||||||
bitmap,
|
?: return@mapNotNull null
|
||||||
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
|
||||||
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
|
||||||
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
|
||||||
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
|
||||||
)
|
|
||||||
|
|
||||||
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
val embedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
faceBitmap.recycle()
|
faceBitmap.recycle()
|
||||||
|
|||||||
@@ -29,6 +29,64 @@ import kotlin.math.sqrt
|
|||||||
*/
|
*/
|
||||||
object FaceQualityFilter {
|
object FaceQualityFilter {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Age group estimation for filtering (child vs adult detection)
|
||||||
|
*/
|
||||||
|
enum class AgeGroup { CHILD, ADULT, UNCERTAIN }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Estimate whether a face belongs to a child or adult based on facial proportions.
|
||||||
|
*
|
||||||
|
* Uses two heuristics:
|
||||||
|
* 1. Eye position ratio - Children have larger foreheads, so eyes are lower (~45% from top)
|
||||||
|
* Adults have eyes at ~35% from top
|
||||||
|
* 2. Face roundness (width/height ratio) - Children: ~0.85-1.0, Adults: ~0.7-0.85
|
||||||
|
*
|
||||||
|
* @return AgeGroup.CHILD, AgeGroup.ADULT, or AgeGroup.UNCERTAIN
|
||||||
|
*/
|
||||||
|
fun estimateAgeGroup(face: Face, imageWidth: Int, imageHeight: Int): AgeGroup {
|
||||||
|
val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE)
|
||||||
|
val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE)
|
||||||
|
|
||||||
|
if (leftEye == null || rightEye == null) {
|
||||||
|
return AgeGroup.UNCERTAIN
|
||||||
|
}
|
||||||
|
|
||||||
|
// Eye-to-face height ratio (where eyes sit relative to face top)
|
||||||
|
val faceHeight = face.boundingBox.height().toFloat()
|
||||||
|
val faceTop = face.boundingBox.top.toFloat()
|
||||||
|
val eyeY = (leftEye.position.y + rightEye.position.y) / 2
|
||||||
|
val eyePositionRatio = (eyeY - faceTop) / faceHeight
|
||||||
|
|
||||||
|
// Children: eyes at ~45% from top (larger forehead proportionally)
|
||||||
|
// Adults: eyes at ~35% from top
|
||||||
|
// Score: higher = more child-like
|
||||||
|
|
||||||
|
// Face roundness (width/height)
|
||||||
|
val faceWidth = face.boundingBox.width().toFloat()
|
||||||
|
val faceRatio = faceWidth / faceHeight
|
||||||
|
// Children: ratio ~0.85-1.0 (rounder faces)
|
||||||
|
// Adults: ratio ~0.7-0.85 (longer/narrower faces)
|
||||||
|
|
||||||
|
var childScore = 0
|
||||||
|
|
||||||
|
// Eye position scoring
|
||||||
|
if (eyePositionRatio > 0.45f) childScore += 2 // Strong child signal
|
||||||
|
else if (eyePositionRatio > 0.42f) childScore += 1 // Mild child signal
|
||||||
|
else if (eyePositionRatio < 0.35f) childScore -= 1 // Adult signal
|
||||||
|
|
||||||
|
// Face roundness scoring
|
||||||
|
if (faceRatio > 0.90f) childScore += 2 // Very round = child
|
||||||
|
else if (faceRatio > 0.82f) childScore += 1 // Somewhat round
|
||||||
|
else if (faceRatio < 0.75f) childScore -= 1 // Long face = adult
|
||||||
|
|
||||||
|
return when {
|
||||||
|
childScore >= 3 -> AgeGroup.CHILD
|
||||||
|
childScore <= 0 -> AgeGroup.ADULT
|
||||||
|
else -> AgeGroup.UNCERTAIN
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Validate face for Discovery/Clustering
|
* Validate face for Discovery/Clustering
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -0,0 +1,353 @@
|
|||||||
|
package com.placeholder.sherpai2.domain.similarity
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||||
|
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
import javax.inject.Inject
|
||||||
|
import javax.inject.Singleton
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
/**
|
||||||
|
* FaceSimilarityScorer - Real-time similarity scoring for Rolling Scan
|
||||||
|
*
|
||||||
|
* CORE RESPONSIBILITIES:
|
||||||
|
* 1. Calculate centroid from selected face embeddings
|
||||||
|
* 2. Score all unselected photos against centroid
|
||||||
|
* 3. Apply quality boosting (solo photos, high confidence, etc.)
|
||||||
|
* 4. Rank photos by final score (similarity + quality boost)
|
||||||
|
*
|
||||||
|
* KEY OPTIMIZATION: Uses cached embeddings from FaceCacheEntity
|
||||||
|
* - No embedding generation needed (already done!)
|
||||||
|
* - Blazing fast scoring (just cosine similarity)
|
||||||
|
* - Can score 1000+ photos in ~100ms
|
||||||
|
*/
|
||||||
|
@Singleton
|
||||||
|
class FaceSimilarityScorer @Inject constructor(
|
||||||
|
private val faceCacheDao: FaceCacheDao
|
||||||
|
) {
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val TAG = "FaceSimilarityScorer"
|
||||||
|
|
||||||
|
// Quality boost constants
|
||||||
|
private const val SOLO_PHOTO_BOOST = 0.15f
|
||||||
|
private const val HIGH_CONFIDENCE_BOOST = 0.05f
|
||||||
|
private const val GROUP_PHOTO_PENALTY = -0.10f
|
||||||
|
private const val HIGH_QUALITY_BOOST = 0.03f
|
||||||
|
|
||||||
|
// Thresholds
|
||||||
|
private const val HIGH_CONFIDENCE_THRESHOLD = 0.8f
|
||||||
|
private const val HIGH_QUALITY_THRESHOLD = 0.8f
|
||||||
|
private const val GROUP_PHOTO_THRESHOLD = 3
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scored photo with similarity and quality metrics
|
||||||
|
*/
|
||||||
|
data class ScoredPhoto(
|
||||||
|
val imageId: String,
|
||||||
|
val imageUri: String,
|
||||||
|
val faceIndex: Int,
|
||||||
|
val similarityScore: Float, // 0.0 - 1.0 (cosine similarity to centroid)
|
||||||
|
val qualityBoost: Float, // -0.2 to +0.2 (quality adjustments)
|
||||||
|
val finalScore: Float, // similarity + qualityBoost
|
||||||
|
val faceCount: Int, // Number of faces in image
|
||||||
|
val faceAreaRatio: Float, // Size of face in image
|
||||||
|
val qualityScore: Float, // Overall face quality
|
||||||
|
val cachedEmbedding: FloatArray // For further operations
|
||||||
|
) {
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other !is ScoredPhoto) return false
|
||||||
|
return imageId == other.imageId && faceIndex == other.faceIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
return imageId.hashCode() * 31 + faceIndex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate centroid from multiple embeddings
|
||||||
|
*
|
||||||
|
* Centroid = average of all embedding vectors
|
||||||
|
* This represents the "average face" of selected photos
|
||||||
|
*/
|
||||||
|
fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
|
||||||
|
if (embeddings.isEmpty()) {
|
||||||
|
Log.w(TAG, "Cannot calculate centroid from empty list")
|
||||||
|
return FloatArray(192) { 0f }
|
||||||
|
}
|
||||||
|
|
||||||
|
val dimension = embeddings.first().size
|
||||||
|
val centroid = FloatArray(dimension) { 0f }
|
||||||
|
|
||||||
|
// Sum all embeddings
|
||||||
|
embeddings.forEach { embedding ->
|
||||||
|
if (embedding.size != dimension) {
|
||||||
|
Log.e(TAG, "Embedding size mismatch: ${embedding.size} vs $dimension")
|
||||||
|
return@forEach
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding.forEachIndexed { i, value ->
|
||||||
|
centroid[i] += value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Average
|
||||||
|
val count = embeddings.size.toFloat()
|
||||||
|
centroid.forEachIndexed { i, _ ->
|
||||||
|
centroid[i] /= count
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize to unit length
|
||||||
|
return normalizeEmbedding(centroid)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Score a single photo against centroid
|
||||||
|
* Uses cosine similarity
|
||||||
|
*/
|
||||||
|
fun scorePhotoAgainstCentroid(
|
||||||
|
photoEmbedding: FloatArray,
|
||||||
|
centroid: FloatArray
|
||||||
|
): Float {
|
||||||
|
return cosineSimilarity(photoEmbedding, centroid)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CRITICAL: Batch score all photos against centroid
|
||||||
|
*
|
||||||
|
* This is the main function used by RollingScanViewModel
|
||||||
|
*
|
||||||
|
* @param allImageIds All available image IDs (with cached embeddings)
|
||||||
|
* @param selectedImageIds Already selected images (exclude from results)
|
||||||
|
* @param centroid Centroid calculated from selected embeddings
|
||||||
|
* @return List of scored photos, sorted by finalScore DESC
|
||||||
|
*/
|
||||||
|
suspend fun scorePhotosAgainstCentroid(
|
||||||
|
allImageIds: List<String>,
|
||||||
|
selectedImageIds: Set<String>,
|
||||||
|
centroid: FloatArray
|
||||||
|
): List<ScoredPhoto> = withContext(Dispatchers.Default) {
|
||||||
|
|
||||||
|
if (centroid.all { it == 0f }) {
|
||||||
|
Log.w(TAG, "Centroid is all zeros, cannot score")
|
||||||
|
return@withContext emptyList()
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.d(TAG, "Scoring ${allImageIds.size} photos (excluding ${selectedImageIds.size} selected)")
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Get ALL cached face entries for these images
|
||||||
|
val cachedFaces = faceCacheDao.getFaceCacheByImageIds(allImageIds)
|
||||||
|
|
||||||
|
Log.d(TAG, "Retrieved ${cachedFaces.size} cached faces")
|
||||||
|
|
||||||
|
// Filter to unselected images with embeddings
|
||||||
|
val scorablePhotos = cachedFaces
|
||||||
|
.filter { it.imageId !in selectedImageIds }
|
||||||
|
.filter { it.embedding != null }
|
||||||
|
|
||||||
|
Log.d(TAG, "Scorable photos: ${scorablePhotos.size}")
|
||||||
|
|
||||||
|
// Score each photo
|
||||||
|
val scoredPhotos = scorablePhotos.mapNotNull { cachedFace ->
|
||||||
|
try {
|
||||||
|
val embedding = cachedFace.getEmbedding() ?: return@mapNotNull null
|
||||||
|
|
||||||
|
// Calculate similarity to centroid
|
||||||
|
val similarityScore = cosineSimilarity(embedding, centroid)
|
||||||
|
|
||||||
|
// Calculate quality boost
|
||||||
|
val qualityBoost = calculateQualityBoost(
|
||||||
|
faceCount = getFaceCountForImage(cachedFace.imageId, cachedFaces),
|
||||||
|
confidence = cachedFace.confidence,
|
||||||
|
qualityScore = cachedFace.qualityScore,
|
||||||
|
faceAreaRatio = cachedFace.faceAreaRatio
|
||||||
|
)
|
||||||
|
|
||||||
|
// Final score
|
||||||
|
val finalScore = (similarityScore + qualityBoost).coerceIn(0f, 1f)
|
||||||
|
|
||||||
|
ScoredPhoto(
|
||||||
|
imageId = cachedFace.imageId,
|
||||||
|
imageUri = getImageUri(cachedFace.imageId), // Will need to fetch
|
||||||
|
faceIndex = cachedFace.faceIndex,
|
||||||
|
similarityScore = similarityScore,
|
||||||
|
qualityBoost = qualityBoost,
|
||||||
|
finalScore = finalScore,
|
||||||
|
faceCount = getFaceCountForImage(cachedFace.imageId, cachedFaces),
|
||||||
|
faceAreaRatio = cachedFace.faceAreaRatio,
|
||||||
|
qualityScore = cachedFace.qualityScore,
|
||||||
|
cachedEmbedding = embedding
|
||||||
|
)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.w(TAG, "Error scoring photo ${cachedFace.imageId}: ${e.message}")
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by final score (highest first)
|
||||||
|
val sorted = scoredPhotos.sortedByDescending { it.finalScore }
|
||||||
|
|
||||||
|
Log.d(TAG, "Scored ${sorted.size} photos. Top score: ${sorted.firstOrNull()?.finalScore}")
|
||||||
|
|
||||||
|
sorted
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Error in batch scoring", e)
|
||||||
|
emptyList()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate quality boost based on photo characteristics
|
||||||
|
*
|
||||||
|
* Boosts:
|
||||||
|
* - Solo photos (faceCount == 1): +0.15
|
||||||
|
* - High confidence (>0.8): +0.05
|
||||||
|
* - High quality score (>0.8): +0.03
|
||||||
|
*
|
||||||
|
* Penalties:
|
||||||
|
* - Group photos (faceCount >= 3): -0.10
|
||||||
|
*/
|
||||||
|
private fun calculateQualityBoost(
|
||||||
|
faceCount: Int,
|
||||||
|
confidence: Float,
|
||||||
|
qualityScore: Float,
|
||||||
|
faceAreaRatio: Float
|
||||||
|
): Float {
|
||||||
|
var boost = 0f
|
||||||
|
|
||||||
|
// MAJOR boost for solo photos (easier to verify, less confusion)
|
||||||
|
if (faceCount == 1) {
|
||||||
|
boost += SOLO_PHOTO_BOOST
|
||||||
|
}
|
||||||
|
|
||||||
|
// Penalize group photos (harder to verify correct face)
|
||||||
|
if (faceCount >= GROUP_PHOTO_THRESHOLD) {
|
||||||
|
boost += GROUP_PHOTO_PENALTY
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boost high-confidence detections
|
||||||
|
if (confidence > HIGH_CONFIDENCE_THRESHOLD) {
|
||||||
|
boost += HIGH_CONFIDENCE_BOOST
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boost high-quality faces (large, clear, frontal)
|
||||||
|
if (qualityScore > HIGH_QUALITY_THRESHOLD) {
|
||||||
|
boost += HIGH_QUALITY_BOOST
|
||||||
|
}
|
||||||
|
|
||||||
|
// Coerce to reasonable range
|
||||||
|
return boost.coerceIn(-0.2f, 0.2f)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get face count for an image
|
||||||
|
* (Multiple faces in same image share imageId but different faceIndex)
|
||||||
|
*/
|
||||||
|
private fun getFaceCountForImage(
|
||||||
|
imageId: String,
|
||||||
|
allCachedFaces: List<FaceCacheEntity>
|
||||||
|
): Int {
|
||||||
|
return allCachedFaces.count { it.imageId == imageId }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get image URI for an imageId
|
||||||
|
*
|
||||||
|
* NOTE: This is a temporary implementation
|
||||||
|
* In production, we'd join with ImageEntity or cache URIs
|
||||||
|
*/
|
||||||
|
private suspend fun getImageUri(imageId: String): String {
|
||||||
|
// TODO: Implement proper URI retrieval
|
||||||
|
// For now, return imageId as placeholder
|
||||||
|
return imageId
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cosine similarity calculation
|
||||||
|
*
|
||||||
|
* Returns value between -1.0 and 1.0
|
||||||
|
* Higher = more similar
|
||||||
|
*/
|
||||||
|
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
|
||||||
|
if (a.size != b.size) {
|
||||||
|
Log.e(TAG, "Embedding size mismatch: ${a.size} vs ${b.size}")
|
||||||
|
return 0f
|
||||||
|
}
|
||||||
|
|
||||||
|
var dotProduct = 0f
|
||||||
|
var normA = 0f
|
||||||
|
var normB = 0f
|
||||||
|
|
||||||
|
a.indices.forEach { i ->
|
||||||
|
dotProduct += a[i] * b[i]
|
||||||
|
normA += a[i] * a[i]
|
||||||
|
normB += b[i] * b[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normA == 0f || normB == 0f) {
|
||||||
|
Log.w(TAG, "Zero norm in similarity calculation")
|
||||||
|
return 0f
|
||||||
|
}
|
||||||
|
|
||||||
|
val similarity = dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
|
|
||||||
|
// Handle NaN/Infinity
|
||||||
|
if (similarity.isNaN() || similarity.isInfinite()) {
|
||||||
|
Log.w(TAG, "Invalid similarity: $similarity")
|
||||||
|
return 0f
|
||||||
|
}
|
||||||
|
|
||||||
|
return similarity
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Normalize embedding to unit length
|
||||||
|
*/
|
||||||
|
private fun normalizeEmbedding(embedding: FloatArray): FloatArray {
|
||||||
|
var norm = 0f
|
||||||
|
for (value in embedding) {
|
||||||
|
norm += value * value
|
||||||
|
}
|
||||||
|
norm = sqrt(norm)
|
||||||
|
|
||||||
|
return if (norm > 0) {
|
||||||
|
FloatArray(embedding.size) { i -> embedding[i] / norm }
|
||||||
|
} else {
|
||||||
|
Log.w(TAG, "Cannot normalize zero embedding")
|
||||||
|
embedding
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Incremental scoring for viewport optimization
|
||||||
|
*
|
||||||
|
* Only scores photos in visible range + next batch
|
||||||
|
* Useful for large libraries (5000+ photos)
|
||||||
|
*/
|
||||||
|
suspend fun scorePhotosIncrementally(
|
||||||
|
visibleRange: IntRange,
|
||||||
|
batchSize: Int = 50,
|
||||||
|
allImageIds: List<String>,
|
||||||
|
selectedImageIds: Set<String>,
|
||||||
|
centroid: FloatArray
|
||||||
|
): List<ScoredPhoto> {
|
||||||
|
|
||||||
|
val rangeToScan = visibleRange.first until
|
||||||
|
(visibleRange.last + batchSize).coerceAtMost(allImageIds.size)
|
||||||
|
|
||||||
|
val imageIdsToScan = allImageIds.slice(rangeToScan)
|
||||||
|
|
||||||
|
return scorePhotosAgainstCentroid(
|
||||||
|
allImageIds = imageIdsToScan,
|
||||||
|
selectedImageIds = selectedImageIds,
|
||||||
|
centroid = centroid
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,13 +3,17 @@ package com.placeholder.sherpai2.domain.training
|
|||||||
import android.content.Context
|
import android.content.Context
|
||||||
import android.graphics.BitmapFactory
|
import android.graphics.BitmapFactory
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
|
import android.util.Log
|
||||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PersonAgeTagDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||||
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
|
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
|
||||||
|
import com.placeholder.sherpai2.data.local.entity.PersonAgeTagEntity
|
||||||
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
||||||
import com.placeholder.sherpai2.data.local.entity.TemporalCentroid
|
import com.placeholder.sherpai2.data.local.entity.TemporalCentroid
|
||||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||||
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
|
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding
|
||||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
@@ -34,8 +38,12 @@ class ClusterTrainingService @Inject constructor(
|
|||||||
@ApplicationContext private val context: Context,
|
@ApplicationContext private val context: Context,
|
||||||
private val personDao: PersonDao,
|
private val personDao: PersonDao,
|
||||||
private val faceModelDao: FaceModelDao,
|
private val faceModelDao: FaceModelDao,
|
||||||
|
private val personAgeTagDao: PersonAgeTagDao,
|
||||||
private val qualityAnalyzer: ClusterQualityAnalyzer
|
private val qualityAnalyzer: ClusterQualityAnalyzer
|
||||||
) {
|
) {
|
||||||
|
companion object {
|
||||||
|
private const val TAG = "ClusterTraining"
|
||||||
|
}
|
||||||
|
|
||||||
private val faceNetModel by lazy { FaceNetModel(context) }
|
private val faceNetModel by lazy { FaceNetModel(context) }
|
||||||
|
|
||||||
@@ -135,11 +143,65 @@ class ClusterTrainingService @Inject constructor(
|
|||||||
faceModelDao.insertFaceModel(faceModel)
|
faceModelDao.insertFaceModel(faceModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Step 7: Generate age tags for children
|
||||||
|
if (isChild && dateOfBirth != null) {
|
||||||
|
onProgress(90, 100, "Creating age tags...")
|
||||||
|
generateAgeTags(
|
||||||
|
personId = person.id,
|
||||||
|
personName = name,
|
||||||
|
faces = facesToUse,
|
||||||
|
dateOfBirth = dateOfBirth
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
onProgress(100, 100, "Complete!")
|
onProgress(100, 100, "Complete!")
|
||||||
|
|
||||||
person.id
|
person.id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate PersonAgeTagEntity records for a child's photos
|
||||||
|
*
|
||||||
|
* Creates searchable tags like "emma_age2", "emma_age3" etc.
|
||||||
|
* Enables queries like "Show all photos of Emma at age 2"
|
||||||
|
*/
|
||||||
|
private suspend fun generateAgeTags(
|
||||||
|
personId: String,
|
||||||
|
personName: String,
|
||||||
|
faces: List<com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding>,
|
||||||
|
dateOfBirth: Long
|
||||||
|
) = withContext(Dispatchers.IO) {
|
||||||
|
try {
|
||||||
|
val tags = faces.mapNotNull { face ->
|
||||||
|
// Calculate age at capture
|
||||||
|
val ageMs = face.capturedAt - dateOfBirth
|
||||||
|
val ageYears = (ageMs / (365.25 * 24 * 60 * 60 * 1000)).toInt()
|
||||||
|
|
||||||
|
// Skip if age is negative or unreasonably high
|
||||||
|
if (ageYears < 0 || ageYears > 25) {
|
||||||
|
Log.w(TAG, "Skipping face with invalid age: $ageYears years")
|
||||||
|
return@mapNotNull null
|
||||||
|
}
|
||||||
|
|
||||||
|
PersonAgeTagEntity.create(
|
||||||
|
personId = personId,
|
||||||
|
personName = personName,
|
||||||
|
imageId = face.imageId,
|
||||||
|
ageAtCapture = ageYears,
|
||||||
|
confidence = 1.0f // High confidence since this is from training data
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tags.isNotEmpty()) {
|
||||||
|
personAgeTagDao.insertTags(tags)
|
||||||
|
Log.d(TAG, "Created ${tags.size} age tags for $personName (ages: ${tags.map { it.ageAtCapture }.distinct().sorted()})")
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Failed to generate age tags", e)
|
||||||
|
// Non-fatal - continue without tags
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create temporal centroids for a child
|
* Create temporal centroids for a child
|
||||||
* Groups faces by age and creates one centroid per age period
|
* Groups faces by age and creates one centroid per age period
|
||||||
|
|||||||
@@ -75,7 +75,21 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
val imagesToScan = imageDao.getImagesNeedingFaceDetection()
|
// Get images that need face detection (hasFaces IS NULL)
|
||||||
|
var imagesToScan = imageDao.getImagesNeedingFaceDetection()
|
||||||
|
|
||||||
|
// CRITICAL FIX: Also check for images marked as having faces but no FaceCacheEntity
|
||||||
|
if (imagesToScan.isEmpty()) {
|
||||||
|
val faceStats = faceCacheDao.getCacheStats()
|
||||||
|
if (faceStats.totalFaces == 0) {
|
||||||
|
// FaceCacheEntity is empty - rescan images that have faces
|
||||||
|
val imagesWithFaces = imageDao.getImagesWithFaces()
|
||||||
|
if (imagesWithFaces.isNotEmpty()) {
|
||||||
|
Log.w(TAG, "FaceCacheEntity empty but ${imagesWithFaces.size} images have faces - rescanning")
|
||||||
|
imagesToScan = imagesWithFaces
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (imagesToScan.isEmpty()) {
|
if (imagesToScan.isEmpty()) {
|
||||||
Log.d(TAG, "No images need scanning")
|
Log.d(TAG, "No images need scanning")
|
||||||
@@ -184,7 +198,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
imageUri = image.imageUri
|
imageUri = image.imageUri
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create FaceCacheEntity entries for each face
|
// Create FaceCacheEntity entries for each face (NO embeddings - generated on demand)
|
||||||
val faceCacheEntries = faces.mapIndexed { index, face ->
|
val faceCacheEntries = faces.mapIndexed { index, face ->
|
||||||
createFaceCacheEntry(
|
createFaceCacheEntry(
|
||||||
imageId = image.imageId,
|
imageId = image.imageId,
|
||||||
@@ -205,7 +219,8 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
/**
|
/**
|
||||||
* Create FaceCacheEntity from ML Kit Face
|
* Create FaceCacheEntity from ML Kit Face
|
||||||
*
|
*
|
||||||
* Uses FaceCacheEntity.create() which calculates quality metrics automatically
|
* Uses FaceCacheEntity.create() which calculates quality metrics automatically.
|
||||||
|
* Embeddings are NOT generated here - they're generated on-demand in Training/Discovery.
|
||||||
*/
|
*/
|
||||||
private fun createFaceCacheEntry(
|
private fun createFaceCacheEntry(
|
||||||
imageId: String,
|
imageId: String,
|
||||||
@@ -225,7 +240,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
imageHeight = imageHeight,
|
imageHeight = imageHeight,
|
||||||
confidence = 0.9f, // High confidence from accurate detector
|
confidence = 0.9f, // High confidence from accurate detector
|
||||||
isFrontal = isFrontal,
|
isFrontal = isFrontal,
|
||||||
embedding = null // Will be generated later during Discovery
|
embedding = null // Generated on-demand in Training/Discovery
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,13 +327,27 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
val imageStats = imageDao.getFaceCacheStats()
|
val imageStats = imageDao.getFaceCacheStats()
|
||||||
val faceStats = faceCacheDao.getCacheStats()
|
val faceStats = faceCacheDao.getCacheStats()
|
||||||
|
|
||||||
|
// CRITICAL FIX: If ImageEntity says "scanned" but FaceCacheEntity is empty,
|
||||||
|
// we need to re-scan. This happens after DB migration clears face_cache table.
|
||||||
|
val imagesWithFaces = imageStats?.imagesWithFaces ?: 0
|
||||||
|
val facesCached = faceStats.totalFaces
|
||||||
|
|
||||||
|
// If we have images marked as having faces but no FaceCacheEntity entries,
|
||||||
|
// those images need re-scanning
|
||||||
|
val needsRescan = if (imagesWithFaces > 0 && facesCached == 0) {
|
||||||
|
Log.w(TAG, "⚠️ FaceCacheEntity is empty but $imagesWithFaces images marked as having faces - forcing rescan")
|
||||||
|
imagesWithFaces
|
||||||
|
} else {
|
||||||
|
imageStats?.needsScanning ?: 0
|
||||||
|
}
|
||||||
|
|
||||||
CacheStats(
|
CacheStats(
|
||||||
totalImages = imageStats?.totalImages ?: 0,
|
totalImages = imageStats?.totalImages ?: 0,
|
||||||
imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0,
|
imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0,
|
||||||
imagesWithFaces = imageStats?.imagesWithFaces ?: 0,
|
imagesWithFaces = imagesWithFaces,
|
||||||
imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0,
|
imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0,
|
||||||
needsScanning = imageStats?.needsScanning ?: 0,
|
needsScanning = needsRescan,
|
||||||
totalFacesCached = faceStats.totalFaces,
|
totalFacesCached = facesCached,
|
||||||
facesWithEmbeddings = faceStats.withEmbeddings,
|
facesWithEmbeddings = faceStats.withEmbeddings,
|
||||||
averageQuality = faceStats.avgQuality
|
averageQuality = faceStats.avgQuality
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
|||||||
import androidx.navigation.NavController
|
import androidx.navigation.NavController
|
||||||
import coil.compose.AsyncImage
|
import coil.compose.AsyncImage
|
||||||
import com.placeholder.sherpai2.data.local.entity.TagEntity
|
import com.placeholder.sherpai2.data.local.entity.TagEntity
|
||||||
|
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.FaceTagInfo
|
||||||
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel
|
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel
|
||||||
import net.engawapg.lib.zoomable.rememberZoomState
|
import net.engawapg.lib.zoomable.rememberZoomState
|
||||||
import net.engawapg.lib.zoomable.zoomable
|
import net.engawapg.lib.zoomable.zoomable
|
||||||
@@ -51,8 +52,12 @@ fun ImageDetailScreen(
|
|||||||
}
|
}
|
||||||
|
|
||||||
val tags by viewModel.tags.collectAsStateWithLifecycle()
|
val tags by viewModel.tags.collectAsStateWithLifecycle()
|
||||||
|
val faceTags by viewModel.faceTags.collectAsStateWithLifecycle()
|
||||||
var showTags by remember { mutableStateOf(false) }
|
var showTags by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
|
// Total tag count for badge
|
||||||
|
val totalTagCount = tags.size + faceTags.size
|
||||||
|
|
||||||
// Navigation state
|
// Navigation state
|
||||||
val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1
|
val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1
|
||||||
val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0
|
val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0
|
||||||
@@ -84,27 +89,35 @@ fun ImageDetailScreen(
|
|||||||
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
verticalAlignment = Alignment.CenterVertically
|
verticalAlignment = Alignment.CenterVertically
|
||||||
) {
|
) {
|
||||||
if (tags.isNotEmpty()) {
|
if (totalTagCount > 0) {
|
||||||
Badge(
|
Badge(
|
||||||
containerColor = if (showTags)
|
containerColor = if (showTags)
|
||||||
MaterialTheme.colorScheme.primary
|
MaterialTheme.colorScheme.primary
|
||||||
|
else if (faceTags.isNotEmpty())
|
||||||
|
MaterialTheme.colorScheme.tertiary
|
||||||
else
|
else
|
||||||
MaterialTheme.colorScheme.surfaceVariant
|
MaterialTheme.colorScheme.surfaceVariant
|
||||||
) {
|
) {
|
||||||
Text(
|
Text(
|
||||||
tags.size.toString(),
|
totalTagCount.toString(),
|
||||||
color = if (showTags)
|
color = if (showTags)
|
||||||
MaterialTheme.colorScheme.onPrimary
|
MaterialTheme.colorScheme.onPrimary
|
||||||
|
else if (faceTags.isNotEmpty())
|
||||||
|
MaterialTheme.colorScheme.onTertiary
|
||||||
else
|
else
|
||||||
MaterialTheme.colorScheme.onSurfaceVariant
|
MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Icon(
|
Icon(
|
||||||
if (showTags) Icons.Default.Label else Icons.Default.LocalOffer,
|
if (faceTags.isNotEmpty()) Icons.Default.Face
|
||||||
|
else if (showTags) Icons.Default.Label
|
||||||
|
else Icons.Default.LocalOffer,
|
||||||
"Show Tags",
|
"Show Tags",
|
||||||
tint = if (showTags)
|
tint = if (showTags)
|
||||||
MaterialTheme.colorScheme.primary
|
MaterialTheme.colorScheme.primary
|
||||||
|
else if (faceTags.isNotEmpty())
|
||||||
|
MaterialTheme.colorScheme.tertiary
|
||||||
else
|
else
|
||||||
MaterialTheme.colorScheme.onSurfaceVariant
|
MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
@@ -189,6 +202,30 @@ fun ImageDetailScreen(
|
|||||||
contentPadding = PaddingValues(16.dp),
|
contentPadding = PaddingValues(16.dp),
|
||||||
verticalArrangement = Arrangement.spacedBy(8.dp)
|
verticalArrangement = Arrangement.spacedBy(8.dp)
|
||||||
) {
|
) {
|
||||||
|
// Face Tags Section (People in Photo)
|
||||||
|
if (faceTags.isNotEmpty()) {
|
||||||
|
item {
|
||||||
|
Text(
|
||||||
|
"People (${faceTags.size})",
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
|
fontWeight = FontWeight.Bold,
|
||||||
|
color = MaterialTheme.colorScheme.tertiary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
items(faceTags, key = { it.tagId }) { faceTag ->
|
||||||
|
FaceTagCard(
|
||||||
|
faceTag = faceTag,
|
||||||
|
onRemove = { viewModel.removeFaceTag(faceTag) }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
item {
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular Tags Section
|
||||||
item {
|
item {
|
||||||
Text(
|
Text(
|
||||||
"Tags (${tags.size})",
|
"Tags (${tags.size})",
|
||||||
@@ -197,7 +234,7 @@ fun ImageDetailScreen(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tags.isEmpty()) {
|
if (tags.isEmpty() && faceTags.isEmpty()) {
|
||||||
item {
|
item {
|
||||||
Text(
|
Text(
|
||||||
"No tags yet",
|
"No tags yet",
|
||||||
@@ -205,6 +242,14 @@ fun ImageDetailScreen(
|
|||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
} else if (tags.isEmpty()) {
|
||||||
|
item {
|
||||||
|
Text(
|
||||||
|
"No other tags",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
items(tags, key = { it.tagId }) { tag ->
|
items(tags, key = { it.tagId }) { tag ->
|
||||||
@@ -220,6 +265,83 @@ fun ImageDetailScreen(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun FaceTagCard(
|
||||||
|
faceTag: FaceTagInfo,
|
||||||
|
onRemove: () -> Unit
|
||||||
|
) {
|
||||||
|
Card(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
colors = CardDefaults.cardColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.tertiaryContainer
|
||||||
|
),
|
||||||
|
shape = RoundedCornerShape(8.dp)
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(12.dp),
|
||||||
|
horizontalArrangement = Arrangement.SpaceBetween,
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Column(modifier = Modifier.weight(1f)) {
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Face,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(20.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.tertiary
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = faceTag.personName,
|
||||||
|
style = MaterialTheme.typography.bodyLarge,
|
||||||
|
fontWeight = FontWeight.SemiBold
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = "Face Recognition",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = "•",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = "${(faceTag.confidence * 100).toInt()}% confidence",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = if (faceTag.confidence >= 0.7f)
|
||||||
|
MaterialTheme.colorScheme.primary
|
||||||
|
else if (faceTag.confidence >= 0.5f)
|
||||||
|
MaterialTheme.colorScheme.secondary
|
||||||
|
else
|
||||||
|
MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove button
|
||||||
|
IconButton(
|
||||||
|
onClick = onRemove,
|
||||||
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
contentColor = MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Icon(Icons.Default.Delete, "Remove face tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun TagCard(
|
private fun TagCard(
|
||||||
tag: TagEntity,
|
tag: TagEntity,
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ package com.placeholder.sherpai2.ui.imagedetail.viewmodel
|
|||||||
|
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||||
import com.placeholder.sherpai2.data.local.entity.TagEntity
|
import com.placeholder.sherpai2.data.local.entity.TagEntity
|
||||||
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
@@ -10,17 +14,33 @@ import kotlinx.coroutines.flow.*
|
|||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a person tagged in this photo via face recognition
|
||||||
|
*/
|
||||||
|
data class FaceTagInfo(
|
||||||
|
val personId: String,
|
||||||
|
val personName: String,
|
||||||
|
val confidence: Float,
|
||||||
|
val faceModelId: String,
|
||||||
|
val tagId: String
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ImageDetailViewModel
|
* ImageDetailViewModel
|
||||||
*
|
*
|
||||||
* Owns:
|
* Owns:
|
||||||
* - Image context
|
* - Image context
|
||||||
* - Tag write operations
|
* - Tag write operations
|
||||||
|
* - Face tag display (people recognized in photo)
|
||||||
*/
|
*/
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
@OptIn(ExperimentalCoroutinesApi::class)
|
@OptIn(ExperimentalCoroutinesApi::class)
|
||||||
class ImageDetailViewModel @Inject constructor(
|
class ImageDetailViewModel @Inject constructor(
|
||||||
private val tagRepository: TaggingRepository
|
private val tagRepository: TaggingRepository,
|
||||||
|
private val imageDao: ImageDao,
|
||||||
|
private val photoFaceTagDao: PhotoFaceTagDao,
|
||||||
|
private val faceModelDao: FaceModelDao,
|
||||||
|
private val personDao: PersonDao
|
||||||
) : ViewModel() {
|
) : ViewModel() {
|
||||||
|
|
||||||
private val imageUri = MutableStateFlow<String?>(null)
|
private val imageUri = MutableStateFlow<String?>(null)
|
||||||
@@ -37,8 +57,43 @@ class ImageDetailViewModel @Inject constructor(
|
|||||||
initialValue = emptyList()
|
initialValue = emptyList()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Face tags (people recognized in this photo)
|
||||||
|
private val _faceTags = MutableStateFlow<List<FaceTagInfo>>(emptyList())
|
||||||
|
val faceTags: StateFlow<List<FaceTagInfo>> = _faceTags.asStateFlow()
|
||||||
|
|
||||||
fun loadImage(uri: String) {
|
fun loadImage(uri: String) {
|
||||||
imageUri.value = uri
|
imageUri.value = uri
|
||||||
|
loadFaceTags(uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun loadFaceTags(uri: String) {
|
||||||
|
viewModelScope.launch {
|
||||||
|
try {
|
||||||
|
// Get imageId from URI
|
||||||
|
val image = imageDao.getImageByUri(uri) ?: return@launch
|
||||||
|
|
||||||
|
// Get face tags for this image
|
||||||
|
val faceTags = photoFaceTagDao.getTagsForImage(image.imageId)
|
||||||
|
|
||||||
|
// Resolve to person names
|
||||||
|
val faceTagInfos = faceTags.mapNotNull { tag ->
|
||||||
|
val faceModel = faceModelDao.getFaceModelById(tag.faceModelId) ?: return@mapNotNull null
|
||||||
|
val person = personDao.getPersonById(faceModel.personId) ?: return@mapNotNull null
|
||||||
|
|
||||||
|
FaceTagInfo(
|
||||||
|
personId = person.id,
|
||||||
|
personName = person.name,
|
||||||
|
confidence = tag.confidence,
|
||||||
|
faceModelId = tag.faceModelId,
|
||||||
|
tagId = tag.id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
_faceTags.value = faceTagInfos.sortedByDescending { it.confidence }
|
||||||
|
} catch (e: Exception) {
|
||||||
|
_faceTags.value = emptyList()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun addTag(value: String) {
|
fun addTag(value: String) {
|
||||||
@@ -54,4 +109,15 @@ class ImageDetailViewModel @Inject constructor(
|
|||||||
tagRepository.removeTagFromImage(uri, tag.value)
|
tagRepository.removeTagFromImage(uri, tag.value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a face tag (person recognition)
|
||||||
|
*/
|
||||||
|
fun removeFaceTag(faceTagInfo: FaceTagInfo) {
|
||||||
|
viewModelScope.launch {
|
||||||
|
photoFaceTagDao.deleteTagById(faceTagInfo.tagId)
|
||||||
|
// Reload face tags
|
||||||
|
imageUri.value?.let { loadFaceTags(it) }
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,6 +95,9 @@ fun PersonInventoryScreen(
|
|||||||
},
|
},
|
||||||
onDelete = { personId ->
|
onDelete = { personId ->
|
||||||
viewModel.deletePerson(personId)
|
viewModel.deletePerson(personId)
|
||||||
|
},
|
||||||
|
onClearTags = { personId ->
|
||||||
|
viewModel.clearTagsForPerson(personId)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -319,7 +322,8 @@ private fun PersonList(
|
|||||||
persons: List<PersonWithModelInfo>,
|
persons: List<PersonWithModelInfo>,
|
||||||
onScan: (String) -> Unit,
|
onScan: (String) -> Unit,
|
||||||
onView: (String) -> Unit,
|
onView: (String) -> Unit,
|
||||||
onDelete: (String) -> Unit
|
onDelete: (String) -> Unit,
|
||||||
|
onClearTags: (String) -> Unit
|
||||||
) {
|
) {
|
||||||
LazyColumn(
|
LazyColumn(
|
||||||
contentPadding = PaddingValues(vertical = 8.dp)
|
contentPadding = PaddingValues(vertical = 8.dp)
|
||||||
@@ -332,7 +336,8 @@ private fun PersonList(
|
|||||||
person = person,
|
person = person,
|
||||||
onScan = { onScan(person.person.id) },
|
onScan = { onScan(person.person.id) },
|
||||||
onView = { onView(person.person.id) },
|
onView = { onView(person.person.id) },
|
||||||
onDelete = { onDelete(person.person.id) }
|
onDelete = { onDelete(person.person.id) },
|
||||||
|
onClearTags = { onClearTags(person.person.id) }
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -343,9 +348,34 @@ private fun PersonCard(
|
|||||||
person: PersonWithModelInfo,
|
person: PersonWithModelInfo,
|
||||||
onScan: () -> Unit,
|
onScan: () -> Unit,
|
||||||
onView: () -> Unit,
|
onView: () -> Unit,
|
||||||
onDelete: () -> Unit
|
onDelete: () -> Unit,
|
||||||
|
onClearTags: () -> Unit
|
||||||
) {
|
) {
|
||||||
var showDeleteDialog by remember { mutableStateOf(false) }
|
var showDeleteDialog by remember { mutableStateOf(false) }
|
||||||
|
var showClearDialog by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
|
if (showClearDialog) {
|
||||||
|
AlertDialog(
|
||||||
|
onDismissRequest = { showClearDialog = false },
|
||||||
|
title = { Text("Clear tags for ${person.person.name}?") },
|
||||||
|
text = { Text("This will remove all ${person.taggedPhotoCount} photo tags but keep the face model. You can re-scan after clearing.") },
|
||||||
|
confirmButton = {
|
||||||
|
TextButton(
|
||||||
|
onClick = {
|
||||||
|
showClearDialog = false
|
||||||
|
onClearTags()
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
Text("Clear Tags", color = MaterialTheme.colorScheme.error)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
dismissButton = {
|
||||||
|
TextButton(onClick = { showClearDialog = false }) {
|
||||||
|
Text("Cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
if (showDeleteDialog) {
|
if (showDeleteDialog) {
|
||||||
AlertDialog(
|
AlertDialog(
|
||||||
@@ -413,6 +443,17 @@ private fun PersonCard(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear tags button (if has tags)
|
||||||
|
if (person.taggedPhotoCount > 0) {
|
||||||
|
IconButton(onClick = { showClearDialog = true }) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.ClearAll,
|
||||||
|
contentDescription = "Clear Tags",
|
||||||
|
tint = MaterialTheme.colorScheme.secondary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Delete button
|
// Delete button
|
||||||
IconButton(onClick = { showDeleteDialog = true }) {
|
IconButton(onClick = { showDeleteDialog = true }) {
|
||||||
Icon(
|
Icon(
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
|||||||
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
|
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
|
||||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
import com.placeholder.sherpai2.ml.ThresholdStrategy
|
import com.placeholder.sherpai2.ml.ThresholdStrategy
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
@@ -105,6 +106,21 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear all face tags for a person (keep model, allow rescan)
|
||||||
|
*/
|
||||||
|
fun clearTagsForPerson(personId: String) {
|
||||||
|
viewModelScope.launch(Dispatchers.IO) {
|
||||||
|
try {
|
||||||
|
val faceModel = faceModelDao.getFaceModelByPersonId(personId)
|
||||||
|
if (faceModel != null) {
|
||||||
|
photoFaceTagDao.deleteTagsForFaceModel(faceModel.id)
|
||||||
|
}
|
||||||
|
loadPersons()
|
||||||
|
} catch (e: Exception) {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fun scanForPerson(personId: String) {
|
fun scanForPerson(personId: String) {
|
||||||
viewModelScope.launch(Dispatchers.IO) {
|
viewModelScope.launch(Dispatchers.IO) {
|
||||||
try {
|
try {
|
||||||
@@ -127,16 +143,40 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
|
|
||||||
val detectorOptions = FaceDetectorOptions.Builder()
|
val detectorOptions = FaceDetectorOptions.Builder()
|
||||||
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||||
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_NONE)
|
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // Needed for age estimation
|
||||||
.setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_NONE)
|
.setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_NONE)
|
||||||
.setMinFaceSize(0.15f)
|
.setMinFaceSize(0.15f)
|
||||||
.build()
|
.build()
|
||||||
|
|
||||||
val detector = FaceDetection.getClient(detectorOptions)
|
val detector = FaceDetection.getClient(detectorOptions)
|
||||||
val modelEmbedding = faceModel.getEmbeddingArray()
|
// CRITICAL: Use ALL centroids for matching
|
||||||
val faceNetModel = FaceNetModel(context)
|
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
|
||||||
val trainingCount = faceModel.trainingImageCount
|
val trainingCount = faceModel.trainingImageCount
|
||||||
val baseThreshold = ThresholdStrategy.getLiberalThreshold(trainingCount)
|
android.util.Log.e("PersonScan", "=== CENTROIDS: ${modelCentroids.size}, trainingCount: $trainingCount ===")
|
||||||
|
|
||||||
|
if (modelCentroids.isEmpty()) {
|
||||||
|
_scanningState.value = ScanningState.Error("No centroids found")
|
||||||
|
return@launch
|
||||||
|
}
|
||||||
|
|
||||||
|
val faceNetModel = FaceNetModel(context)
|
||||||
|
// Production threshold - STRICT to avoid false positives
|
||||||
|
// Solo face photos: 0.62, Group photos: 0.68
|
||||||
|
val baseThreshold = 0.62f
|
||||||
|
val groupPhotoThreshold = 0.68f // Higher bar for multi-face images
|
||||||
|
|
||||||
|
// Load ALL other models for "best match wins" comparison
|
||||||
|
val allModels = faceModelDao.getAllActiveFaceModels()
|
||||||
|
val otherModelCentroids = allModels
|
||||||
|
.filter { it.id != faceModel.id }
|
||||||
|
.map { model -> model.id to model.getCentroids().map { it.getEmbeddingArray() } }
|
||||||
|
|
||||||
|
// Distribution-based minimum threshold (self-calibrating)
|
||||||
|
val distributionMin = (faceModel.averageConfidence - 2 * faceModel.similarityStdDev)
|
||||||
|
.coerceAtLeast(faceModel.similarityMin - 0.05f)
|
||||||
|
.coerceAtLeast(0.50f) // Never go below 0.50 absolute floor
|
||||||
|
|
||||||
|
android.util.Log.d("PersonScan", "Using threshold: solo=$baseThreshold, group=$groupPhotoThreshold, distributionMin=$distributionMin (avgConf=${faceModel.averageConfidence}, stdDev=${faceModel.similarityStdDev}), centroids: ${modelCentroids.size}, competing models: ${otherModelCentroids.size}, isChild=${person.isChild}")
|
||||||
|
|
||||||
val completed = AtomicInteger(0)
|
val completed = AtomicInteger(0)
|
||||||
val facesFound = AtomicInteger(0)
|
val facesFound = AtomicInteger(0)
|
||||||
@@ -148,7 +188,7 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
val jobs = untaggedImages.map { image ->
|
val jobs = untaggedImages.map { image ->
|
||||||
async {
|
async {
|
||||||
semaphore.withPermit {
|
semaphore.withPermit {
|
||||||
processImage(image, detector, faceNetModel, modelEmbedding, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
|
processImage(image, detector, faceNetModel, modelCentroids, otherModelCentroids, trainingCount, baseThreshold, groupPhotoThreshold, distributionMin, person.isChild, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -175,7 +215,10 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
|
|
||||||
private suspend fun processImage(
|
private suspend fun processImage(
|
||||||
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
|
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
|
||||||
modelEmbedding: FloatArray, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String,
|
modelCentroids: List<FloatArray>, otherModelCentroids: List<Pair<String, List<FloatArray>>>,
|
||||||
|
trainingCount: Int, baseThreshold: Float, groupPhotoThreshold: Float,
|
||||||
|
distributionMin: Float, isChildTarget: Boolean,
|
||||||
|
personId: String, faceModelId: String,
|
||||||
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
|
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
|
||||||
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
|
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
|
||||||
) {
|
) {
|
||||||
@@ -200,9 +243,13 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
val scaleX = sizeOpts.outWidth.toFloat() / detectionBitmap.width
|
val scaleX = sizeOpts.outWidth.toFloat() / detectionBitmap.width
|
||||||
val scaleY = sizeOpts.outHeight.toFloat() / detectionBitmap.height
|
val scaleY = sizeOpts.outHeight.toFloat() / detectionBitmap.height
|
||||||
|
|
||||||
val imageQuality = ThresholdStrategy.estimateImageQuality(sizeOpts.outWidth, sizeOpts.outHeight)
|
// CRITICAL: Use higher threshold for group photos (more likely false positives)
|
||||||
val detectionContext = ThresholdStrategy.estimateDetectionContext(faces.size)
|
val isGroupPhoto = faces.size > 1
|
||||||
val threshold = ThresholdStrategy.getOptimalThreshold(trainingCount, imageQuality, detectionContext).coerceAtMost(baseThreshold)
|
val effectiveThreshold = if (isGroupPhoto) groupPhotoThreshold else baseThreshold
|
||||||
|
|
||||||
|
// Track best match in this image (only tag ONE face per image)
|
||||||
|
var bestMatchSimilarity = 0f
|
||||||
|
var foundMatch = false
|
||||||
|
|
||||||
for (face in faces) {
|
for (face in faces) {
|
||||||
val scaledBounds = android.graphics.Rect(
|
val scaledBounds = android.graphics.Rect(
|
||||||
@@ -212,22 +259,70 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
(face.boundingBox.bottom * scaleY).toInt()
|
(face.boundingBox.bottom * scaleY).toInt()
|
||||||
)
|
)
|
||||||
|
|
||||||
val faceBitmap = loadFaceRegion(uri, scaledBounds) ?: continue
|
// Skip very small faces (less reliable)
|
||||||
|
val faceArea = scaledBounds.width() * scaledBounds.height()
|
||||||
|
val imageArea = sizeOpts.outWidth * sizeOpts.outHeight
|
||||||
|
val faceRatio = faceArea.toFloat() / imageArea
|
||||||
|
if (faceRatio < 0.02f) continue // Face must be at least 2% of image
|
||||||
|
|
||||||
|
// SIGNAL 2: Age plausibility check (if target is a child)
|
||||||
|
if (isChildTarget) {
|
||||||
|
val ageGroup = FaceQualityFilter.estimateAgeGroup(face, detectionBitmap.width, detectionBitmap.height)
|
||||||
|
if (ageGroup == FaceQualityFilter.AgeGroup.ADULT) {
|
||||||
|
continue // Reject clearly adult faces when searching for a child
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CRITICAL: Add padding to face crop (same as training)
|
||||||
|
val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue
|
||||||
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
|
||||||
faceBitmap.recycle()
|
faceBitmap.recycle()
|
||||||
|
|
||||||
if (similarity >= threshold) {
|
// Match against target person's centroids
|
||||||
batchUpdateMutex.withLock {
|
val targetSimilarity = modelCentroids.maxOfOrNull { centroid ->
|
||||||
batchMatches.add(Triple(personId, image.imageId, similarity))
|
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
|
||||||
facesFound.incrementAndGet()
|
} ?: 0f
|
||||||
if (batchMatches.size >= BATCH_DB_SIZE) {
|
|
||||||
saveBatchMatches(batchMatches.toList(), faceModelId)
|
// SIGNAL 1: Distribution-based rejection
|
||||||
batchMatches.clear()
|
// If similarity is below (mean - 2*stdDev) or (min - 0.05), it's a statistical outlier
|
||||||
}
|
if (targetSimilarity < distributionMin) {
|
||||||
|
continue // Too far below training distribution
|
||||||
|
}
|
||||||
|
|
||||||
|
// SIGNAL 3: Basic threshold check
|
||||||
|
if (targetSimilarity < effectiveThreshold) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// SIGNAL 4: "Best match wins" - check if any OTHER model scores higher
|
||||||
|
// This prevents tagging siblings/similar people incorrectly
|
||||||
|
val bestOtherSimilarity = otherModelCentroids.maxOfOrNull { (_, centroids) ->
|
||||||
|
centroids.maxOfOrNull { centroid ->
|
||||||
|
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
|
||||||
|
} ?: 0f
|
||||||
|
} ?: 0f
|
||||||
|
|
||||||
|
val isTargetBestMatch = targetSimilarity > bestOtherSimilarity
|
||||||
|
|
||||||
|
// All signals must pass
|
||||||
|
if (isTargetBestMatch && targetSimilarity > bestMatchSimilarity) {
|
||||||
|
bestMatchSimilarity = targetSimilarity
|
||||||
|
foundMatch = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only add ONE tag per image (the best match)
|
||||||
|
if (foundMatch) {
|
||||||
|
batchUpdateMutex.withLock {
|
||||||
|
batchMatches.add(Triple(personId, image.imageId, bestMatchSimilarity))
|
||||||
|
facesFound.incrementAndGet()
|
||||||
|
if (batchMatches.size >= BATCH_DB_SIZE) {
|
||||||
|
saveBatchMatches(batchMatches.toList(), faceModelId)
|
||||||
|
batchMatches.clear()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
detectionBitmap.recycle()
|
detectionBitmap.recycle()
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
} finally {
|
} finally {
|
||||||
@@ -250,18 +345,32 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
} catch (e: Exception) { null }
|
} catch (e: Exception) { null }
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun loadFaceRegion(uri: Uri, bounds: android.graphics.Rect): Bitmap? {
|
/**
|
||||||
|
* Load face region WITH 25% padding - CRITICAL for matching training conditions
|
||||||
|
*/
|
||||||
|
private fun loadFaceRegionWithPadding(uri: Uri, bounds: android.graphics.Rect, imgWidth: Int, imgHeight: Int): Bitmap? {
|
||||||
return try {
|
return try {
|
||||||
val full = context.contentResolver.openInputStream(uri)?.use {
|
val full = context.contentResolver.openInputStream(uri)?.use {
|
||||||
BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 })
|
BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 })
|
||||||
} ?: return null
|
} ?: return null
|
||||||
|
|
||||||
val safeLeft = bounds.left.coerceIn(0, full.width - 1)
|
// Add 25% padding (same as training)
|
||||||
val safeTop = bounds.top.coerceIn(0, full.height - 1)
|
val padding = (kotlin.math.max(bounds.width(), bounds.height()) * 0.25f).toInt()
|
||||||
val safeWidth = bounds.width().coerceAtMost(full.width - safeLeft)
|
|
||||||
val safeHeight = bounds.height().coerceAtMost(full.height - safeTop)
|
|
||||||
|
|
||||||
val cropped = Bitmap.createBitmap(full, safeLeft, safeTop, safeWidth, safeHeight)
|
val left = (bounds.left - padding).coerceAtLeast(0)
|
||||||
|
val top = (bounds.top - padding).coerceAtLeast(0)
|
||||||
|
val right = (bounds.right + padding).coerceAtMost(full.width)
|
||||||
|
val bottom = (bounds.bottom + padding).coerceAtMost(full.height)
|
||||||
|
|
||||||
|
val width = right - left
|
||||||
|
val height = bottom - top
|
||||||
|
|
||||||
|
if (width <= 0 || height <= 0) {
|
||||||
|
full.recycle()
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
val cropped = Bitmap.createBitmap(full, left, top, width, height)
|
||||||
full.recycle()
|
full.recycle()
|
||||||
cropped
|
cropped
|
||||||
} catch (e: Exception) { null }
|
} catch (e: Exception) { null }
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ import com.placeholder.sherpai2.ui.trainingprep.ScanningState
|
|||||||
import com.placeholder.sherpai2.ui.trainingprep.TrainViewModel
|
import com.placeholder.sherpai2.ui.trainingprep.TrainViewModel
|
||||||
import com.placeholder.sherpai2.ui.trainingprep.TrainingScreen
|
import com.placeholder.sherpai2.ui.trainingprep.TrainingScreen
|
||||||
import com.placeholder.sherpai2.ui.trainingprep.TrainingPhotoSelectorScreen
|
import com.placeholder.sherpai2.ui.trainingprep.TrainingPhotoSelectorScreen
|
||||||
|
import com.placeholder.sherpai2.ui.rollingscan.RollingScanScreen
|
||||||
import com.placeholder.sherpai2.ui.utilities.PhotoUtilitiesScreen
|
import com.placeholder.sherpai2.ui.utilities.PhotoUtilitiesScreen
|
||||||
import java.net.URLDecoder
|
import java.net.URLDecoder
|
||||||
import java.net.URLEncoder
|
import java.net.URLEncoder
|
||||||
@@ -249,7 +250,7 @@ fun AppNavHost(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TRAINING PHOTO SELECTOR - Custom gallery with face filtering
|
* TRAINING PHOTO SELECTOR - Premium grid with rolling scan
|
||||||
*/
|
*/
|
||||||
composable(AppRoutes.TRAINING_PHOTO_SELECTOR) {
|
composable(AppRoutes.TRAINING_PHOTO_SELECTOR) {
|
||||||
TrainingPhotoSelectorScreen(
|
TrainingPhotoSelectorScreen(
|
||||||
@@ -262,6 +263,42 @@ fun AppNavHost(
|
|||||||
?.savedStateHandle
|
?.savedStateHandle
|
||||||
?.set("selected_image_uris", uris)
|
?.set("selected_image_uris", uris)
|
||||||
navController.popBackStack()
|
navController.popBackStack()
|
||||||
|
},
|
||||||
|
onLaunchRollingScan = { seedImageIds ->
|
||||||
|
// Navigate to rolling scan with seeds
|
||||||
|
navController.navigate(AppRoutes.rollingScanRoute(seedImageIds))
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ROLLING SCAN - Similarity-based photo discovery
|
||||||
|
*
|
||||||
|
* Takes seed image IDs, finds similar faces across library
|
||||||
|
*/
|
||||||
|
composable(
|
||||||
|
route = AppRoutes.ROLLING_SCAN,
|
||||||
|
arguments = listOf(
|
||||||
|
navArgument("seedImageIds") {
|
||||||
|
type = NavType.StringType
|
||||||
|
}
|
||||||
|
)
|
||||||
|
) { backStackEntry ->
|
||||||
|
val seedImageIdsString = backStackEntry.arguments?.getString("seedImageIds") ?: ""
|
||||||
|
val seedImageIds = seedImageIdsString.split(",").filter { it.isNotBlank() }
|
||||||
|
|
||||||
|
RollingScanScreen(
|
||||||
|
seedImageIds = seedImageIds,
|
||||||
|
onSubmitForTraining = { selectedUris ->
|
||||||
|
// Pass selected URIs back to training flow (via photo selector)
|
||||||
|
navController.getBackStackEntry(AppRoutes.TRAIN)
|
||||||
|
.savedStateHandle
|
||||||
|
.set("selected_image_uris", selectedUris.map { Uri.parse(it) })
|
||||||
|
// Pop back to training screen
|
||||||
|
navController.popBackStack(AppRoutes.TRAIN, inclusive = false)
|
||||||
|
},
|
||||||
|
onNavigateBack = {
|
||||||
|
navController.popBackStack()
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -302,10 +339,7 @@ fun AppNavHost(
|
|||||||
* SETTINGS SCREEN
|
* SETTINGS SCREEN
|
||||||
*/
|
*/
|
||||||
composable(AppRoutes.SETTINGS) {
|
composable(AppRoutes.SETTINGS) {
|
||||||
DummyScreen(
|
com.placeholder.sherpai2.ui.settings.SettingsScreen()
|
||||||
title = "Settings",
|
|
||||||
subtitle = "App preferences and configuration"
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -32,10 +32,17 @@ object AppRoutes {
|
|||||||
// Internal training flow screens
|
// Internal training flow screens
|
||||||
const val IMAGE_SELECTOR = "Image Selection" // DEPRECATED - kept for reference only
|
const val IMAGE_SELECTOR = "Image Selection" // DEPRECATED - kept for reference only
|
||||||
const val TRAINING_PHOTO_SELECTOR = "training_photo_selector" // Face-filtered gallery
|
const val TRAINING_PHOTO_SELECTOR = "training_photo_selector" // Face-filtered gallery
|
||||||
|
const val ROLLING_SCAN = "rolling_scan/{seedImageIds}" // Similarity-based photo finder
|
||||||
const val CROP_SCREEN = "CROP_SCREEN"
|
const val CROP_SCREEN = "CROP_SCREEN"
|
||||||
const val TRAINING_SCREEN = "TRAINING_SCREEN"
|
const val TRAINING_SCREEN = "TRAINING_SCREEN"
|
||||||
const val ScanResultsScreen = "First Scan Results"
|
const val ScanResultsScreen = "First Scan Results"
|
||||||
|
|
||||||
|
// Rolling scan helper
|
||||||
|
fun rollingScanRoute(seedImageIds: List<String>): String {
|
||||||
|
val encoded = seedImageIds.joinToString(",")
|
||||||
|
return "rolling_scan/$encoded"
|
||||||
|
}
|
||||||
|
|
||||||
// Album view
|
// Album view
|
||||||
const val ALBUM_VIEW = "album/{albumType}/{albumId}"
|
const val ALBUM_VIEW = "album/{albumType}/{albumId}"
|
||||||
fun albumRoute(albumType: String, albumId: String) = "album/$albumType/$albumId"
|
fun albumRoute(albumType: String, albumId: String) = "album/$albumType/$albumId"
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ fun MainScreen(
|
|||||||
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
|
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
|
||||||
AppRoutes.INVENTORY -> "People"
|
AppRoutes.INVENTORY -> "People"
|
||||||
AppRoutes.TRAIN -> "Train Model"
|
AppRoutes.TRAIN -> "Train Model"
|
||||||
|
AppRoutes.ScanResultsScreen -> "Train New Person"
|
||||||
AppRoutes.TAGS -> "Tags"
|
AppRoutes.TAGS -> "Tags"
|
||||||
AppRoutes.UTILITIES -> "Utilities"
|
AppRoutes.UTILITIES -> "Utilities"
|
||||||
AppRoutes.SETTINGS -> "Settings"
|
AppRoutes.SETTINGS -> "Settings"
|
||||||
|
|||||||
@@ -0,0 +1,206 @@
|
|||||||
|
package com.placeholder.sherpai2.ui.rollingscan
|
||||||
|
|
||||||
|
import androidx.compose.foundation.layout.*
|
||||||
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.filled.*
|
||||||
|
import androidx.compose.material3.*
|
||||||
|
import androidx.compose.runtime.*
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.compose.ui.window.Dialog
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RollingScanModeDialog - Offers Rolling Scan after initial photo selection
|
||||||
|
*
|
||||||
|
* USER JOURNEY:
|
||||||
|
* 1. User selects 3-5 seed photos from photo picker
|
||||||
|
* 2. This dialog appears: "Want to find more similar photos?"
|
||||||
|
* 3. User can:
|
||||||
|
* - "Search & Add More" → Go to Rolling Scan (recommended)
|
||||||
|
* - "Continue with N photos" → Skip to validation
|
||||||
|
*
|
||||||
|
* BENEFITS:
|
||||||
|
* - Suggests intelligent workflow
|
||||||
|
* - Optional (doesn't force)
|
||||||
|
* - Shows potential (N → N*3 photos)
|
||||||
|
* - Fast path for power users
|
||||||
|
*/
|
||||||
|
@Composable
|
||||||
|
fun RollingScanModeDialog(
|
||||||
|
currentPhotoCount: Int,
|
||||||
|
onUseRollingScan: () -> Unit,
|
||||||
|
onContinueWithCurrent: () -> Unit,
|
||||||
|
onDismiss: () -> Unit
|
||||||
|
) {
|
||||||
|
Dialog(onDismissRequest = onDismiss) {
|
||||||
|
Card(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth(0.92f)
|
||||||
|
.wrapContentHeight(),
|
||||||
|
shape = RoundedCornerShape(24.dp),
|
||||||
|
colors = CardDefaults.cardColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.surface
|
||||||
|
),
|
||||||
|
elevation = CardDefaults.cardElevation(defaultElevation = 8.dp)
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(24.dp),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(20.dp),
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally
|
||||||
|
) {
|
||||||
|
// Icon
|
||||||
|
Surface(
|
||||||
|
shape = RoundedCornerShape(20.dp),
|
||||||
|
color = MaterialTheme.colorScheme.primaryContainer,
|
||||||
|
modifier = Modifier.size(80.dp)
|
||||||
|
) {
|
||||||
|
Box(contentAlignment = Alignment.Center) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.AutoAwesome,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(44.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Title
|
||||||
|
Text(
|
||||||
|
"Find More Similar Photos?",
|
||||||
|
style = MaterialTheme.typography.headlineSmall,
|
||||||
|
fontWeight = FontWeight.Bold,
|
||||||
|
textAlign = TextAlign.Center
|
||||||
|
)
|
||||||
|
|
||||||
|
// Description
|
||||||
|
Column(
|
||||||
|
verticalArrangement = Arrangement.spacedBy(12.dp)
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"You've selected $currentPhotoCount ${if (currentPhotoCount == 1) "photo" else "photos"}. " +
|
||||||
|
"Our AI can scan your library and find similar photos automatically!",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||||
|
textAlign = TextAlign.Center
|
||||||
|
)
|
||||||
|
|
||||||
|
// Feature highlights
|
||||||
|
Card(
|
||||||
|
colors = CardDefaults.cardColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.secondaryContainer.copy(alpha = 0.3f)
|
||||||
|
),
|
||||||
|
shape = RoundedCornerShape(12.dp)
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.padding(16.dp),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(10.dp)
|
||||||
|
) {
|
||||||
|
FeatureRow(
|
||||||
|
icon = Icons.Default.Speed,
|
||||||
|
text = "Real-time similarity ranking"
|
||||||
|
)
|
||||||
|
FeatureRow(
|
||||||
|
icon = Icons.Default.PhotoLibrary,
|
||||||
|
text = "Get 20-30 photos in seconds"
|
||||||
|
)
|
||||||
|
FeatureRow(
|
||||||
|
icon = Icons.Default.HighQuality,
|
||||||
|
text = "Better training quality"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Action buttons
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(12.dp)
|
||||||
|
) {
|
||||||
|
// Primary: Use Rolling Scan (RECOMMENDED)
|
||||||
|
Button(
|
||||||
|
onClick = onUseRollingScan,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.height(56.dp),
|
||||||
|
shape = RoundedCornerShape(16.dp),
|
||||||
|
colors = ButtonDefaults.buttonColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.AutoAwesome,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(22.dp)
|
||||||
|
)
|
||||||
|
Spacer(Modifier.width(12.dp))
|
||||||
|
Column(
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"Search & Add More",
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
|
fontWeight = FontWeight.Bold
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
"Recommended",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = MaterialTheme.colorScheme.onPrimary.copy(alpha = 0.8f)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Secondary: Skip Rolling Scan
|
||||||
|
OutlinedButton(
|
||||||
|
onClick = onContinueWithCurrent,
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.height(48.dp),
|
||||||
|
shape = RoundedCornerShape(16.dp)
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"Continue with $currentPhotoCount ${if (currentPhotoCount == 1) "Photo" else "Photos"}",
|
||||||
|
style = MaterialTheme.typography.titleSmall
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tertiary: Cancel/Back
|
||||||
|
TextButton(
|
||||||
|
onClick = onDismiss,
|
||||||
|
modifier = Modifier.fillMaxWidth()
|
||||||
|
) {
|
||||||
|
Text("Go Back")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun FeatureRow(
|
||||||
|
icon: androidx.compose.ui.graphics.vector.ImageVector,
|
||||||
|
text: String
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
icon,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(20.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text,
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSecondaryContainer
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,611 @@
|
|||||||
|
package com.placeholder.sherpai2.ui.rollingscan
|
||||||
|
|
||||||
|
import android.net.Uri
|
||||||
|
import androidx.compose.foundation.BorderStroke
|
||||||
|
import androidx.compose.foundation.ExperimentalFoundationApi
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
|
import androidx.compose.foundation.combinedClickable
|
||||||
|
import androidx.compose.foundation.layout.*
|
||||||
|
import androidx.compose.foundation.lazy.grid.GridCells
|
||||||
|
import androidx.compose.foundation.lazy.grid.GridItemSpan
|
||||||
|
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||||
|
import androidx.compose.foundation.lazy.grid.items
|
||||||
|
import androidx.compose.foundation.shape.CircleShape
|
||||||
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.filled.*
|
||||||
|
import androidx.compose.material3.*
|
||||||
|
import androidx.compose.runtime.*
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.graphics.Color
|
||||||
|
import androidx.compose.ui.graphics.vector.ImageVector
|
||||||
|
import androidx.compose.ui.layout.ContentScale
|
||||||
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import androidx.hilt.navigation.compose.hiltViewModel
|
||||||
|
import coil.compose.AsyncImage
|
||||||
|
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RollingScanScreen - Real-time photo ranking UI
|
||||||
|
*
|
||||||
|
* FEATURES:
|
||||||
|
* - Section headers (Most Similar / Good / Other)
|
||||||
|
* - Similarity badges on top matches
|
||||||
|
* - Selection checkmarks
|
||||||
|
* - Face count indicators
|
||||||
|
* - Scanning progress bar
|
||||||
|
* - Quick action buttons (Select Top N)
|
||||||
|
* - Submit button with validation
|
||||||
|
*/
|
||||||
|
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
|
||||||
|
@Composable
|
||||||
|
fun RollingScanScreen(
|
||||||
|
seedImageIds: List<String>,
|
||||||
|
onSubmitForTraining: (List<String>) -> Unit,
|
||||||
|
onNavigateBack: () -> Unit,
|
||||||
|
modifier: Modifier = Modifier,
|
||||||
|
viewModel: RollingScanViewModel = hiltViewModel()
|
||||||
|
) {
|
||||||
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
|
val selectedImageIds by viewModel.selectedImageIds.collectAsState()
|
||||||
|
val negativeImageIds by viewModel.negativeImageIds.collectAsState()
|
||||||
|
val rankedPhotos by viewModel.rankedPhotos.collectAsState()
|
||||||
|
val isScanning by viewModel.isScanning.collectAsState()
|
||||||
|
|
||||||
|
// Initialize on first composition
|
||||||
|
LaunchedEffect(seedImageIds) {
|
||||||
|
viewModel.initialize(seedImageIds)
|
||||||
|
}
|
||||||
|
|
||||||
|
Scaffold(
|
||||||
|
topBar = {
|
||||||
|
RollingScanTopBar(
|
||||||
|
selectedCount = selectedImageIds.size,
|
||||||
|
onNavigateBack = onNavigateBack,
|
||||||
|
onClearSelection = { viewModel.clearSelection() }
|
||||||
|
)
|
||||||
|
},
|
||||||
|
bottomBar = {
|
||||||
|
RollingScanBottomBar(
|
||||||
|
selectedCount = selectedImageIds.size,
|
||||||
|
isReadyForTraining = viewModel.isReadyForTraining(),
|
||||||
|
validationMessage = viewModel.getValidationMessage(),
|
||||||
|
onSelectTopN = { count -> viewModel.selectTopN(count) },
|
||||||
|
onSelectAboveThreshold = { threshold -> viewModel.selectAllAboveThreshold(threshold) },
|
||||||
|
onSubmit = {
|
||||||
|
val uris = viewModel.getSelectedImageUris()
|
||||||
|
onSubmitForTraining(uris)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
modifier = modifier
|
||||||
|
) { padding ->
|
||||||
|
|
||||||
|
when (val state = uiState) {
|
||||||
|
is RollingScanState.Idle -> {
|
||||||
|
// Waiting for initialization
|
||||||
|
LoadingContent()
|
||||||
|
}
|
||||||
|
|
||||||
|
is RollingScanState.Loading -> {
|
||||||
|
LoadingContent()
|
||||||
|
}
|
||||||
|
|
||||||
|
is RollingScanState.Ready -> {
|
||||||
|
RollingScanPhotoGrid(
|
||||||
|
rankedPhotos = rankedPhotos,
|
||||||
|
selectedImageIds = selectedImageIds,
|
||||||
|
negativeImageIds = negativeImageIds,
|
||||||
|
isScanning = isScanning,
|
||||||
|
onToggleSelection = { imageId -> viewModel.toggleSelection(imageId) },
|
||||||
|
onToggleNegative = { imageId -> viewModel.toggleNegative(imageId) },
|
||||||
|
modifier = Modifier.padding(padding)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
is RollingScanState.Error -> {
|
||||||
|
ErrorContent(
|
||||||
|
message = state.message,
|
||||||
|
onRetry = { viewModel.initialize(seedImageIds) },
|
||||||
|
onBack = onNavigateBack
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
is RollingScanState.SubmittedForTraining -> {
|
||||||
|
// Navigate back handled by parent
|
||||||
|
LaunchedEffect(Unit) {
|
||||||
|
onNavigateBack()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// TOP BAR
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
|
@Composable
|
||||||
|
private fun RollingScanTopBar(
|
||||||
|
selectedCount: Int,
|
||||||
|
onNavigateBack: () -> Unit,
|
||||||
|
onClearSelection: () -> Unit
|
||||||
|
) {
|
||||||
|
TopAppBar(
|
||||||
|
title = {
|
||||||
|
Column {
|
||||||
|
Text(
|
||||||
|
"Find Similar Photos",
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
fontWeight = FontWeight.Bold
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
"$selectedCount selected",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
navigationIcon = {
|
||||||
|
IconButton(onClick = onNavigateBack) {
|
||||||
|
Icon(Icons.Default.ArrowBack, "Back")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
actions = {
|
||||||
|
if (selectedCount > 0) {
|
||||||
|
TextButton(onClick = onClearSelection) {
|
||||||
|
Text("Clear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// PHOTO GRID - Similarity-based bucketing
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
@OptIn(ExperimentalFoundationApi::class)
|
||||||
|
@Composable
|
||||||
|
private fun RollingScanPhotoGrid(
|
||||||
|
rankedPhotos: List<FaceSimilarityScorer.ScoredPhoto>,
|
||||||
|
selectedImageIds: Set<String>,
|
||||||
|
negativeImageIds: Set<String>,
|
||||||
|
isScanning: Boolean,
|
||||||
|
onToggleSelection: (String) -> Unit,
|
||||||
|
onToggleNegative: (String) -> Unit,
|
||||||
|
modifier: Modifier = Modifier
|
||||||
|
) {
|
||||||
|
// Bucket by similarity score
|
||||||
|
val veryLikely = rankedPhotos.filter { it.finalScore >= 0.60f }
|
||||||
|
val probably = rankedPhotos.filter { it.finalScore in 0.45f..0.599f }
|
||||||
|
val maybe = rankedPhotos.filter { it.finalScore < 0.45f }
|
||||||
|
|
||||||
|
Column(modifier = modifier.fillMaxSize()) {
|
||||||
|
// Scanning indicator
|
||||||
|
if (isScanning) {
|
||||||
|
LinearProgressIndicator(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
color = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hint for negative marking
|
||||||
|
Text(
|
||||||
|
text = "Tap to select • Long-press to mark as NOT this person",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant,
|
||||||
|
modifier = Modifier.padding(horizontal = 12.dp, vertical = 4.dp)
|
||||||
|
)
|
||||||
|
|
||||||
|
LazyVerticalGrid(
|
||||||
|
columns = GridCells.Fixed(3),
|
||||||
|
contentPadding = PaddingValues(8.dp),
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(8.dp)
|
||||||
|
) {
|
||||||
|
// Section: Very Likely (>60%)
|
||||||
|
if (veryLikely.isNotEmpty()) {
|
||||||
|
item(span = { GridItemSpan(3) }) {
|
||||||
|
SectionHeader(
|
||||||
|
icon = Icons.Default.Whatshot,
|
||||||
|
text = "🟢 Very Likely (${veryLikely.size})",
|
||||||
|
color = Color(0xFF4CAF50)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
items(veryLikely, key = { it.imageId }) { photo ->
|
||||||
|
PhotoCard(
|
||||||
|
photo = photo,
|
||||||
|
isSelected = photo.imageId in selectedImageIds,
|
||||||
|
isNegative = photo.imageId in negativeImageIds,
|
||||||
|
onToggle = { onToggleSelection(photo.imageId) },
|
||||||
|
onLongPress = { onToggleNegative(photo.imageId) },
|
||||||
|
showSimilarityBadge = true
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Section: Probably (45-60%)
|
||||||
|
if (probably.isNotEmpty()) {
|
||||||
|
item(span = { GridItemSpan(3) }) {
|
||||||
|
SectionHeader(
|
||||||
|
icon = Icons.Default.CheckCircle,
|
||||||
|
text = "🟡 Probably (${probably.size})",
|
||||||
|
color = Color(0xFFFFC107)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
items(probably, key = { it.imageId }) { photo ->
|
||||||
|
PhotoCard(
|
||||||
|
photo = photo,
|
||||||
|
isSelected = photo.imageId in selectedImageIds,
|
||||||
|
isNegative = photo.imageId in negativeImageIds,
|
||||||
|
onToggle = { onToggleSelection(photo.imageId) },
|
||||||
|
onLongPress = { onToggleNegative(photo.imageId) },
|
||||||
|
showSimilarityBadge = true
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Section: Maybe (<45%)
|
||||||
|
if (maybe.isNotEmpty()) {
|
||||||
|
item(span = { GridItemSpan(3) }) {
|
||||||
|
SectionHeader(
|
||||||
|
icon = Icons.Default.Photo,
|
||||||
|
text = "🟠 Maybe (${maybe.size})",
|
||||||
|
color = Color(0xFFFF9800)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
items(maybe, key = { it.imageId }) { photo ->
|
||||||
|
PhotoCard(
|
||||||
|
photo = photo,
|
||||||
|
isSelected = photo.imageId in selectedImageIds,
|
||||||
|
isNegative = photo.imageId in negativeImageIds,
|
||||||
|
onToggle = { onToggleSelection(photo.imageId) },
|
||||||
|
onLongPress = { onToggleNegative(photo.imageId) }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty state
|
||||||
|
if (rankedPhotos.isEmpty()) {
|
||||||
|
item(span = { GridItemSpan(3) }) {
|
||||||
|
EmptyStateContent()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// PHOTO CARD - with long-press for negative marking
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
@OptIn(ExperimentalFoundationApi::class)
|
||||||
|
@Composable
|
||||||
|
private fun PhotoCard(
|
||||||
|
photo: FaceSimilarityScorer.ScoredPhoto,
|
||||||
|
isSelected: Boolean,
|
||||||
|
isNegative: Boolean = false,
|
||||||
|
onToggle: () -> Unit,
|
||||||
|
onLongPress: () -> Unit = {},
|
||||||
|
showSimilarityBadge: Boolean = false
|
||||||
|
) {
|
||||||
|
val borderColor = when {
|
||||||
|
isNegative -> Color(0xFFE53935) // Red for negative
|
||||||
|
isSelected -> MaterialTheme.colorScheme.primary
|
||||||
|
else -> MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
|
||||||
|
}
|
||||||
|
val borderWidth = if (isSelected || isNegative) 3.dp else 1.dp
|
||||||
|
|
||||||
|
Card(
|
||||||
|
modifier = Modifier
|
||||||
|
.aspectRatio(1f)
|
||||||
|
.combinedClickable(
|
||||||
|
onClick = onToggle,
|
||||||
|
onLongClick = onLongPress
|
||||||
|
),
|
||||||
|
border = BorderStroke(borderWidth, borderColor),
|
||||||
|
elevation = CardDefaults.cardElevation(
|
||||||
|
defaultElevation = if (isSelected) 4.dp else 1.dp
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Box(modifier = Modifier.fillMaxSize()) {
|
||||||
|
// Photo
|
||||||
|
AsyncImage(
|
||||||
|
model = Uri.parse(photo.imageUri),
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
contentScale = ContentScale.Crop
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dim overlay for negatives
|
||||||
|
if (isNegative) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.padding(0.dp),
|
||||||
|
contentAlignment = Alignment.Center
|
||||||
|
) {
|
||||||
|
Surface(
|
||||||
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
color = Color.Black.copy(alpha = 0.5f)
|
||||||
|
) {}
|
||||||
|
Icon(
|
||||||
|
Icons.Default.Close,
|
||||||
|
contentDescription = "Not this person",
|
||||||
|
tint = Color.White,
|
||||||
|
modifier = Modifier.size(32.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Similarity badge (top-left)
|
||||||
|
if (showSimilarityBadge && !isNegative) {
|
||||||
|
Surface(
|
||||||
|
modifier = Modifier
|
||||||
|
.align(Alignment.TopStart)
|
||||||
|
.padding(6.dp),
|
||||||
|
shape = RoundedCornerShape(8.dp),
|
||||||
|
color = when {
|
||||||
|
photo.finalScore >= 0.60f -> Color(0xFF4CAF50)
|
||||||
|
photo.finalScore >= 0.45f -> Color(0xFFFFC107)
|
||||||
|
else -> Color(0xFFFF9800)
|
||||||
|
},
|
||||||
|
shadowElevation = 4.dp
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = "${(photo.finalScore * 100).toInt()}%",
|
||||||
|
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
fontWeight = FontWeight.Bold,
|
||||||
|
color = Color.White
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Selection checkmark (top-right)
|
||||||
|
if (isSelected) {
|
||||||
|
Surface(
|
||||||
|
modifier = Modifier
|
||||||
|
.align(Alignment.TopEnd)
|
||||||
|
.padding(6.dp)
|
||||||
|
.size(28.dp),
|
||||||
|
shape = CircleShape,
|
||||||
|
color = MaterialTheme.colorScheme.primary,
|
||||||
|
shadowElevation = 4.dp
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.CheckCircle,
|
||||||
|
contentDescription = "Selected",
|
||||||
|
modifier = Modifier
|
||||||
|
.padding(4.dp)
|
||||||
|
.size(20.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.onPrimary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Face count badge (bottom-right)
|
||||||
|
if (photo.faceCount > 1 && !isNegative) {
|
||||||
|
Surface(
|
||||||
|
modifier = Modifier
|
||||||
|
.align(Alignment.BottomEnd)
|
||||||
|
.padding(6.dp),
|
||||||
|
shape = CircleShape,
|
||||||
|
color = MaterialTheme.colorScheme.secondary
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = "${photo.faceCount}",
|
||||||
|
modifier = Modifier.padding(6.dp),
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
fontWeight = FontWeight.Bold,
|
||||||
|
color = MaterialTheme.colorScheme.onSecondary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// SECTION HEADER
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun SectionHeader(
|
||||||
|
icon: ImageVector,
|
||||||
|
text: String,
|
||||||
|
color: Color
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(vertical = 12.dp),
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
icon,
|
||||||
|
contentDescription = null,
|
||||||
|
tint = color,
|
||||||
|
modifier = Modifier.size(24.dp)
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = text,
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
|
fontWeight = FontWeight.Bold,
|
||||||
|
color = color
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// BOTTOM BAR
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun RollingScanBottomBar(
|
||||||
|
selectedCount: Int,
|
||||||
|
isReadyForTraining: Boolean,
|
||||||
|
validationMessage: String?,
|
||||||
|
onSelectTopN: (Int) -> Unit,
|
||||||
|
onSelectAboveThreshold: (Float) -> Unit,
|
||||||
|
onSubmit: () -> Unit
|
||||||
|
) {
|
||||||
|
Surface(
|
||||||
|
tonalElevation = 8.dp,
|
||||||
|
shadowElevation = 8.dp
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(16.dp)
|
||||||
|
) {
|
||||||
|
// Validation message
|
||||||
|
if (validationMessage != null) {
|
||||||
|
Text(
|
||||||
|
text = validationMessage,
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.error,
|
||||||
|
modifier = Modifier.padding(bottom = 8.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First row: threshold selection
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(6.dp)
|
||||||
|
) {
|
||||||
|
OutlinedButton(
|
||||||
|
onClick = { onSelectAboveThreshold(0.60f) },
|
||||||
|
modifier = Modifier.weight(1f),
|
||||||
|
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
|
||||||
|
) {
|
||||||
|
Text(">60%", style = MaterialTheme.typography.labelSmall)
|
||||||
|
}
|
||||||
|
OutlinedButton(
|
||||||
|
onClick = { onSelectAboveThreshold(0.50f) },
|
||||||
|
modifier = Modifier.weight(1f),
|
||||||
|
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
|
||||||
|
) {
|
||||||
|
Text(">50%", style = MaterialTheme.typography.labelSmall)
|
||||||
|
}
|
||||||
|
OutlinedButton(
|
||||||
|
onClick = { onSelectTopN(15) },
|
||||||
|
modifier = Modifier.weight(1f),
|
||||||
|
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
|
||||||
|
) {
|
||||||
|
Text("Top 15", style = MaterialTheme.typography.labelSmall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(Modifier.height(8.dp))
|
||||||
|
|
||||||
|
// Second row: submit
|
||||||
|
Button(
|
||||||
|
onClick = onSubmit,
|
||||||
|
enabled = isReadyForTraining,
|
||||||
|
modifier = Modifier.fillMaxWidth()
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.Done,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(18.dp)
|
||||||
|
)
|
||||||
|
Spacer(Modifier.width(8.dp))
|
||||||
|
Text("Train Model ($selectedCount photos)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// STATE SCREENS
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun LoadingContent() {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
contentAlignment = Alignment.Center
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
CircularProgressIndicator()
|
||||||
|
Text(
|
||||||
|
"Loading photos...",
|
||||||
|
style = MaterialTheme.typography.bodyLarge
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun ErrorContent(
|
||||||
|
message: String,
|
||||||
|
onRetry: () -> Unit,
|
||||||
|
onBack: () -> Unit
|
||||||
|
) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
contentAlignment = Alignment.Center
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.padding(32.dp),
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.Error,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(64.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
|
||||||
|
Text(
|
||||||
|
"Oops!",
|
||||||
|
style = MaterialTheme.typography.headlineMedium,
|
||||||
|
fontWeight = FontWeight.Bold
|
||||||
|
)
|
||||||
|
|
||||||
|
Text(
|
||||||
|
message,
|
||||||
|
style = MaterialTheme.typography.bodyLarge,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp)
|
||||||
|
) {
|
||||||
|
OutlinedButton(onClick = onBack) {
|
||||||
|
Text("Back")
|
||||||
|
}
|
||||||
|
|
||||||
|
Button(onClick = onRetry) {
|
||||||
|
Text("Retry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun EmptyStateContent() {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.height(200.dp),
|
||||||
|
contentAlignment = Alignment.Center
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
"Select a photo to find similar ones",
|
||||||
|
style = MaterialTheme.typography.bodyLarge,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package com.placeholder.sherpai2.ui.rollingscan
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RollingScanState - UI states for Rolling Scan feature
|
||||||
|
*
|
||||||
|
* State machine:
|
||||||
|
* Idle → Loading → Ready ⇄ Error
|
||||||
|
* ↓
|
||||||
|
* SubmittedForTraining
|
||||||
|
*/
|
||||||
|
sealed class RollingScanState {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initial state - not started
|
||||||
|
*/
|
||||||
|
object Idle : RollingScanState()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loading initial data
|
||||||
|
* - Fetching cached embeddings
|
||||||
|
* - Building image URI cache
|
||||||
|
* - Loading seed embeddings
|
||||||
|
*/
|
||||||
|
object Loading : RollingScanState()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ready for user interaction
|
||||||
|
*
|
||||||
|
* @param totalPhotos Total number of scannable photos
|
||||||
|
* @param selectedCount Number of currently selected photos
|
||||||
|
*/
|
||||||
|
data class Ready(
|
||||||
|
val totalPhotos: Int,
|
||||||
|
val selectedCount: Int
|
||||||
|
) : RollingScanState()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Error state
|
||||||
|
*
|
||||||
|
* @param message Error message to display
|
||||||
|
*/
|
||||||
|
data class Error(val message: String) : RollingScanState()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Photos submitted for training
|
||||||
|
* Navigate back to training flow
|
||||||
|
*/
|
||||||
|
object SubmittedForTraining : RollingScanState()
|
||||||
|
}
|
||||||
@@ -0,0 +1,459 @@
|
|||||||
|
package com.placeholder.sherpai2.ui.rollingscan
|
||||||
|
|
||||||
|
import android.net.Uri
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.lifecycle.ViewModel
|
||||||
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
|
||||||
|
import com.placeholder.sherpai2.util.Debouncer
|
||||||
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
import javax.inject.Inject
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RollingScanViewModel - Real-time photo ranking based on similarity
|
||||||
|
*
|
||||||
|
* WORKFLOW:
|
||||||
|
* 1. Initialize with seed photos (from initial selection or cluster)
|
||||||
|
* 2. Load all scannable photos with cached embeddings
|
||||||
|
* 3. User selects/deselects photos
|
||||||
|
* 4. Debounced scan triggers → Calculate centroid → Rank all photos
|
||||||
|
* 5. UI updates with ranked photos (most similar first)
|
||||||
|
* 6. User continues selecting until satisfied
|
||||||
|
* 7. Submit selected photos for training
|
||||||
|
*
|
||||||
|
* PERFORMANCE:
|
||||||
|
* - Debounced scanning (300ms delay) avoids excessive re-ranking
|
||||||
|
* - Batch queries fetch 1000+ photos in ~10ms
|
||||||
|
* - Similarity scoring ~100ms for 1000 photos
|
||||||
|
* - Total scan cycle: ~120ms (smooth real-time UI)
|
||||||
|
*/
|
||||||
|
@HiltViewModel
|
||||||
|
class RollingScanViewModel @Inject constructor(
|
||||||
|
private val faceSimilarityScorer: FaceSimilarityScorer,
|
||||||
|
private val faceCacheDao: FaceCacheDao,
|
||||||
|
private val imageDao: ImageDao
|
||||||
|
) : ViewModel() {
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val TAG = "RollingScanVM"
|
||||||
|
private const val DEBOUNCE_DELAY_MS = 300L
|
||||||
|
private const val MIN_PHOTOS_FOR_TRAINING = 15
|
||||||
|
|
||||||
|
// Progressive thresholds based on selection count
|
||||||
|
private const val FLOOR_FEW_SEEDS = 0.30f // 1-3 seeds
|
||||||
|
private const val FLOOR_MEDIUM_SEEDS = 0.40f // 4-10 seeds
|
||||||
|
private const val FLOOR_MANY_SEEDS = 0.50f // 10+ seeds
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// STATE
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
private val _uiState = MutableStateFlow<RollingScanState>(RollingScanState.Idle)
|
||||||
|
val uiState: StateFlow<RollingScanState> = _uiState.asStateFlow()
|
||||||
|
|
||||||
|
private val _selectedImageIds = MutableStateFlow<Set<String>>(emptySet())
|
||||||
|
val selectedImageIds: StateFlow<Set<String>> = _selectedImageIds.asStateFlow()
|
||||||
|
|
||||||
|
private val _rankedPhotos = MutableStateFlow<List<FaceSimilarityScorer.ScoredPhoto>>(emptyList())
|
||||||
|
val rankedPhotos: StateFlow<List<FaceSimilarityScorer.ScoredPhoto>> = _rankedPhotos.asStateFlow()
|
||||||
|
|
||||||
|
private val _isScanning = MutableStateFlow(false)
|
||||||
|
val isScanning: StateFlow<Boolean> = _isScanning.asStateFlow()
|
||||||
|
|
||||||
|
// Debouncer to avoid re-scanning on every selection
|
||||||
|
private val scanDebouncer = Debouncer(
|
||||||
|
delayMs = DEBOUNCE_DELAY_MS,
|
||||||
|
scope = viewModelScope
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cache of selected embeddings
|
||||||
|
private val selectedEmbeddings = mutableListOf<FloatArray>()
|
||||||
|
|
||||||
|
// Negative embeddings (marked as "not this person")
|
||||||
|
private val _negativeImageIds = MutableStateFlow<Set<String>>(emptySet())
|
||||||
|
val negativeImageIds: StateFlow<Set<String>> = _negativeImageIds.asStateFlow()
|
||||||
|
private val negativeEmbeddings = mutableListOf<FloatArray>()
|
||||||
|
|
||||||
|
// All available image IDs
|
||||||
|
private var allImageIds: List<String> = emptyList()
|
||||||
|
|
||||||
|
// Image URI cache (imageId -> imageUri)
|
||||||
|
private var imageUriCache: Map<String, String> = emptyMap()
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// INITIALIZATION
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize with seed photos (from initial selection or cluster)
|
||||||
|
*
|
||||||
|
* @param seedImageIds List of image IDs to start with
|
||||||
|
*/
|
||||||
|
fun initialize(seedImageIds: List<String>) {
|
||||||
|
viewModelScope.launch {
|
||||||
|
try {
|
||||||
|
_uiState.value = RollingScanState.Loading
|
||||||
|
|
||||||
|
Log.d(TAG, "Initializing with ${seedImageIds.size} seed photos")
|
||||||
|
|
||||||
|
// Add seed photos to selection
|
||||||
|
_selectedImageIds.value = seedImageIds.toSet()
|
||||||
|
|
||||||
|
// Load ALL photos with cached embeddings
|
||||||
|
val cachedPhotos = faceCacheDao.getAllPhotosWithFacesForScanning()
|
||||||
|
|
||||||
|
Log.d(TAG, "Loaded ${cachedPhotos.size} photos with cached embeddings")
|
||||||
|
|
||||||
|
if (cachedPhotos.isEmpty()) {
|
||||||
|
_uiState.value = RollingScanState.Error(
|
||||||
|
"No cached embeddings found. Please run face cache population first."
|
||||||
|
)
|
||||||
|
return@launch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract image IDs
|
||||||
|
allImageIds = cachedPhotos.map { it.imageId }.distinct()
|
||||||
|
|
||||||
|
// Build URI cache from ImageDao
|
||||||
|
val images = imageDao.getImagesByIds(allImageIds)
|
||||||
|
imageUriCache = images.associate { it.imageId to it.imageUri }
|
||||||
|
|
||||||
|
Log.d(TAG, "Built URI cache for ${imageUriCache.size} images")
|
||||||
|
|
||||||
|
// Get embeddings for seed photos
|
||||||
|
val seedEmbeddings = faceCacheDao.getEmbeddingsForImages(seedImageIds)
|
||||||
|
selectedEmbeddings.clear()
|
||||||
|
selectedEmbeddings.addAll(seedEmbeddings.mapNotNull { it.getEmbedding() })
|
||||||
|
|
||||||
|
Log.d(TAG, "Loaded ${selectedEmbeddings.size} seed embeddings")
|
||||||
|
|
||||||
|
// Initial scan
|
||||||
|
triggerRollingScan()
|
||||||
|
|
||||||
|
_uiState.value = RollingScanState.Ready(
|
||||||
|
totalPhotos = allImageIds.size,
|
||||||
|
selectedCount = seedImageIds.size
|
||||||
|
)
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Failed to initialize", e)
|
||||||
|
_uiState.value = RollingScanState.Error(
|
||||||
|
"Failed to initialize: ${e.message}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// SELECTION MANAGEMENT
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Toggle photo selection
|
||||||
|
*/
|
||||||
|
fun toggleSelection(imageId: String) {
|
||||||
|
val current = _selectedImageIds.value.toMutableSet()
|
||||||
|
|
||||||
|
if (imageId in current) {
|
||||||
|
// Deselect
|
||||||
|
current.remove(imageId)
|
||||||
|
|
||||||
|
viewModelScope.launch {
|
||||||
|
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
|
||||||
|
cached?.getEmbedding()?.let { selectedEmbeddings.remove(it) }
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Select (and remove from negatives if present)
|
||||||
|
current.add(imageId)
|
||||||
|
if (imageId in _negativeImageIds.value) {
|
||||||
|
toggleNegative(imageId)
|
||||||
|
}
|
||||||
|
|
||||||
|
viewModelScope.launch {
|
||||||
|
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
|
||||||
|
cached?.getEmbedding()?.let { selectedEmbeddings.add(it) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_selectedImageIds.value = current.toSet() // Immutable copy
|
||||||
|
|
||||||
|
scanDebouncer.debounce {
|
||||||
|
triggerRollingScan()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Toggle negative marking ("Not this person")
|
||||||
|
*/
|
||||||
|
fun toggleNegative(imageId: String) {
|
||||||
|
val current = _negativeImageIds.value.toMutableSet()
|
||||||
|
|
||||||
|
if (imageId in current) {
|
||||||
|
current.remove(imageId)
|
||||||
|
viewModelScope.launch {
|
||||||
|
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
|
||||||
|
cached?.getEmbedding()?.let { negativeEmbeddings.remove(it) }
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
current.add(imageId)
|
||||||
|
// Remove from selected if present
|
||||||
|
if (imageId in _selectedImageIds.value) {
|
||||||
|
toggleSelection(imageId)
|
||||||
|
}
|
||||||
|
viewModelScope.launch {
|
||||||
|
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
|
||||||
|
cached?.getEmbedding()?.let { negativeEmbeddings.add(it) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_negativeImageIds.value = current.toSet() // Immutable copy
|
||||||
|
|
||||||
|
scanDebouncer.debounce {
|
||||||
|
triggerRollingScan()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Select top N photos
|
||||||
|
*/
|
||||||
|
fun selectTopN(count: Int) {
|
||||||
|
val topPhotos = _rankedPhotos.value
|
||||||
|
.take(count)
|
||||||
|
.map { it.imageId }
|
||||||
|
.toSet()
|
||||||
|
|
||||||
|
val current = _selectedImageIds.value.toMutableSet()
|
||||||
|
current.addAll(topPhotos)
|
||||||
|
_selectedImageIds.value = current.toSet() // Immutable copy
|
||||||
|
|
||||||
|
viewModelScope.launch {
|
||||||
|
val embeddings = faceCacheDao.getEmbeddingsForImages(topPhotos.toList())
|
||||||
|
selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() })
|
||||||
|
triggerRollingScan()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Select all photos above a similarity threshold
|
||||||
|
*/
|
||||||
|
fun selectAllAboveThreshold(threshold: Float) {
|
||||||
|
val photosAbove = _rankedPhotos.value
|
||||||
|
.filter { it.finalScore >= threshold }
|
||||||
|
.map { it.imageId }
|
||||||
|
|
||||||
|
val current = _selectedImageIds.value.toMutableSet()
|
||||||
|
current.addAll(photosAbove)
|
||||||
|
_selectedImageIds.value = current.toSet() // Immutable copy
|
||||||
|
|
||||||
|
viewModelScope.launch {
|
||||||
|
val newIds = photosAbove.filter { it !in _selectedImageIds.value }
|
||||||
|
if (newIds.isNotEmpty()) {
|
||||||
|
val embeddings = faceCacheDao.getEmbeddingsForImages(newIds)
|
||||||
|
selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() })
|
||||||
|
}
|
||||||
|
triggerRollingScan()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear all selections
|
||||||
|
*/
|
||||||
|
fun clearSelection() {
|
||||||
|
_selectedImageIds.value = emptySet()
|
||||||
|
selectedEmbeddings.clear()
|
||||||
|
_rankedPhotos.value = emptyList()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear negative markings
|
||||||
|
*/
|
||||||
|
fun clearNegatives() {
|
||||||
|
_negativeImageIds.value = emptySet()
|
||||||
|
negativeEmbeddings.clear()
|
||||||
|
scanDebouncer.debounce { triggerRollingScan() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// ROLLING SCAN LOGIC
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
|
||||||
|
* CORE: Trigger rolling similarity scan with progressive filtering
|
||||||
|
*/
|
||||||
|
private suspend fun triggerRollingScan() {
|
||||||
|
if (selectedEmbeddings.isEmpty()) {
|
||||||
|
_rankedPhotos.value = emptyList()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
_isScanning.value = true
|
||||||
|
|
||||||
|
val selectionCount = selectedEmbeddings.size
|
||||||
|
Log.d(TAG, "Starting scan with $selectionCount selected, ${negativeEmbeddings.size} negative")
|
||||||
|
|
||||||
|
// Progressive threshold based on selection count
|
||||||
|
val similarityFloor = when {
|
||||||
|
selectionCount <= 3 -> FLOOR_FEW_SEEDS
|
||||||
|
selectionCount <= 10 -> FLOOR_MEDIUM_SEEDS
|
||||||
|
else -> FLOOR_MANY_SEEDS
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate centroid from selected embeddings
|
||||||
|
val centroid = faceSimilarityScorer.calculateCentroid(selectedEmbeddings)
|
||||||
|
|
||||||
|
// Score all unselected photos
|
||||||
|
val scoredPhotos = faceSimilarityScorer.scorePhotosAgainstCentroid(
|
||||||
|
allImageIds = allImageIds,
|
||||||
|
selectedImageIds = _selectedImageIds.value,
|
||||||
|
centroid = centroid
|
||||||
|
)
|
||||||
|
|
||||||
|
// Apply negative penalty, quality boost, and floor filter
|
||||||
|
val filteredPhotos = scoredPhotos
|
||||||
|
.map { photo ->
|
||||||
|
// Calculate max similarity to any negative embedding
|
||||||
|
val negativePenalty = if (negativeEmbeddings.isNotEmpty()) {
|
||||||
|
negativeEmbeddings.maxOfOrNull { neg ->
|
||||||
|
cosineSimilarity(photo.cachedEmbedding, neg)
|
||||||
|
} ?: 0f
|
||||||
|
} else 0f
|
||||||
|
|
||||||
|
// Quality multiplier: solo face, large face, good quality
|
||||||
|
val qualityMultiplier = 1f +
|
||||||
|
(if (photo.faceCount == 1) 0.15f else 0f) +
|
||||||
|
(if (photo.faceAreaRatio > 0.15f) 0.10f else 0f) +
|
||||||
|
(if (photo.qualityScore > 0.7f) 0.10f else 0f)
|
||||||
|
|
||||||
|
// Final score = (similarity - negativePenalty) * qualityMultiplier
|
||||||
|
val adjustedScore = ((photo.similarityScore - negativePenalty * 0.5f) * qualityMultiplier)
|
||||||
|
.coerceIn(0f, 1f)
|
||||||
|
|
||||||
|
photo.copy(
|
||||||
|
imageUri = imageUriCache[photo.imageId] ?: photo.imageId,
|
||||||
|
finalScore = adjustedScore
|
||||||
|
)
|
||||||
|
}
|
||||||
|
.filter { it.finalScore >= similarityFloor } // Apply floor
|
||||||
|
.filter { it.imageId !in _negativeImageIds.value } // Hide negatives
|
||||||
|
.sortedByDescending { it.finalScore }
|
||||||
|
|
||||||
|
Log.d(TAG, "Scan complete. ${filteredPhotos.size} photos above floor $similarityFloor")
|
||||||
|
|
||||||
|
_rankedPhotos.value = filteredPhotos
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "Scan failed", e)
|
||||||
|
} finally {
|
||||||
|
_isScanning.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
|
||||||
|
if (a.size != b.size) return 0f
|
||||||
|
var dot = 0f
|
||||||
|
var normA = 0f
|
||||||
|
var normB = 0f
|
||||||
|
for (i in a.indices) {
|
||||||
|
dot += a[i] * b[i]
|
||||||
|
normA += a[i] * a[i]
|
||||||
|
normB += b[i] * b[i]
|
||||||
|
}
|
||||||
|
return if (normA > 0 && normB > 0) dot / (kotlin.math.sqrt(normA) * kotlin.math.sqrt(normB)) else 0f
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// SUBMISSION
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get selected image URIs for training submission
|
||||||
|
*
|
||||||
|
* @return List of URIs as strings
|
||||||
|
*/
|
||||||
|
fun getSelectedImageUris(): List<String> {
|
||||||
|
return _selectedImageIds.value.mapNotNull { imageId ->
|
||||||
|
imageUriCache[imageId]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if ready for training
|
||||||
|
*/
|
||||||
|
fun isReadyForTraining(): Boolean {
|
||||||
|
return _selectedImageIds.value.size >= MIN_PHOTOS_FOR_TRAINING
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get validation message
|
||||||
|
*/
|
||||||
|
fun getValidationMessage(): String? {
|
||||||
|
val selectedCount = _selectedImageIds.value.size
|
||||||
|
return when {
|
||||||
|
selectedCount < MIN_PHOTOS_FOR_TRAINING ->
|
||||||
|
"Need at least $MIN_PHOTOS_FOR_TRAINING photos, have $selectedCount"
|
||||||
|
else -> null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reset state
|
||||||
|
*/
|
||||||
|
fun reset() {
|
||||||
|
_uiState.value = RollingScanState.Idle
|
||||||
|
_selectedImageIds.value = emptySet()
|
||||||
|
_negativeImageIds.value = emptySet()
|
||||||
|
_rankedPhotos.value = emptyList()
|
||||||
|
_isScanning.value = false
|
||||||
|
selectedEmbeddings.clear()
|
||||||
|
negativeEmbeddings.clear()
|
||||||
|
allImageIds = emptyList()
|
||||||
|
imageUriCache = emptyMap()
|
||||||
|
scanDebouncer.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onCleared() {
|
||||||
|
super.onCleared()
|
||||||
|
scanDebouncer.cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
// HELPER EXTENSION
|
||||||
|
// ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copy ScoredPhoto with updated imageUri
|
||||||
|
*/
|
||||||
|
private fun FaceSimilarityScorer.ScoredPhoto.copy(
|
||||||
|
imageId: String = this.imageId,
|
||||||
|
imageUri: String = this.imageUri,
|
||||||
|
faceIndex: Int = this.faceIndex,
|
||||||
|
similarityScore: Float = this.similarityScore,
|
||||||
|
qualityBoost: Float = this.qualityBoost,
|
||||||
|
finalScore: Float = this.finalScore,
|
||||||
|
faceCount: Int = this.faceCount,
|
||||||
|
faceAreaRatio: Float = this.faceAreaRatio,
|
||||||
|
qualityScore: Float = this.qualityScore,
|
||||||
|
cachedEmbedding: FloatArray = this.cachedEmbedding
|
||||||
|
): FaceSimilarityScorer.ScoredPhoto {
|
||||||
|
return FaceSimilarityScorer.ScoredPhoto(
|
||||||
|
imageId = imageId,
|
||||||
|
imageUri = imageUri,
|
||||||
|
faceIndex = faceIndex,
|
||||||
|
similarityScore = similarityScore,
|
||||||
|
qualityBoost = qualityBoost,
|
||||||
|
finalScore = finalScore,
|
||||||
|
faceCount = faceCount,
|
||||||
|
faceAreaRatio = faceAreaRatio,
|
||||||
|
qualityScore = qualityScore,
|
||||||
|
cachedEmbedding = cachedEmbedding
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import android.os.Build
|
|||||||
import android.view.View
|
import android.view.View
|
||||||
import android.view.autofill.AutofillManager
|
import android.view.autofill.AutofillManager
|
||||||
import androidx.annotation.RequiresApi
|
import androidx.annotation.RequiresApi
|
||||||
|
import androidx.compose.foundation.clickable
|
||||||
import androidx.compose.foundation.layout.*
|
import androidx.compose.foundation.layout.*
|
||||||
import androidx.compose.foundation.rememberScrollState
|
import androidx.compose.foundation.rememberScrollState
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
@@ -28,11 +29,12 @@ import java.util.*
|
|||||||
@Composable
|
@Composable
|
||||||
fun BeautifulPersonInfoDialog(
|
fun BeautifulPersonInfoDialog(
|
||||||
onDismiss: () -> Unit,
|
onDismiss: () -> Unit,
|
||||||
onConfirm: (name: String, dateOfBirth: Long?, relationship: String) -> Unit
|
onConfirm: (name: String, dateOfBirth: Long?, relationship: String, isChild: Boolean) -> Unit
|
||||||
) {
|
) {
|
||||||
var name by remember { mutableStateOf("") }
|
var name by remember { mutableStateOf("") }
|
||||||
var dateOfBirth by remember { mutableStateOf<Long?>(null) }
|
var dateOfBirth by remember { mutableStateOf<Long?>(null) }
|
||||||
var selectedRelationship by remember { mutableStateOf("Other") }
|
var selectedRelationship by remember { mutableStateOf("Other") }
|
||||||
|
var isChild by remember { mutableStateOf(false) }
|
||||||
var showDatePicker by remember { mutableStateOf(false) }
|
var showDatePicker by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
// ✅ Disable autofill for this dialog
|
// ✅ Disable autofill for this dialog
|
||||||
@@ -108,8 +110,75 @@ fun BeautifulPersonInfoDialog(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Child toggle
|
||||||
|
Surface(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.clickable { isChild = !isChild },
|
||||||
|
color = if (isChild) MaterialTheme.colorScheme.primaryContainer
|
||||||
|
else MaterialTheme.colorScheme.surfaceVariant,
|
||||||
|
shape = RoundedCornerShape(16.dp)
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(16.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.SpaceBetween
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Face,
|
||||||
|
contentDescription = null,
|
||||||
|
tint = if (isChild) MaterialTheme.colorScheme.primary
|
||||||
|
else MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
Column {
|
||||||
|
Text(
|
||||||
|
"This is a child",
|
||||||
|
style = MaterialTheme.typography.bodyLarge,
|
||||||
|
fontWeight = FontWeight.Medium,
|
||||||
|
color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer
|
||||||
|
else MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
"Creates age tags (emma_age2, emma_age3...)",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = if (isChild) MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
|
||||||
|
else MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Switch(
|
||||||
|
checked = isChild,
|
||||||
|
onCheckedChange = { isChild = it }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Birthday (more prominent for children)
|
||||||
Column(verticalArrangement = Arrangement.spacedBy(8.dp)) {
|
Column(verticalArrangement = Arrangement.spacedBy(8.dp)) {
|
||||||
Text("Birthday", style = MaterialTheme.typography.titleSmall, fontWeight = FontWeight.SemiBold, color = MaterialTheme.colorScheme.primary)
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically,
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp)
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
if (isChild) "Birthday *" else "Birthday",
|
||||||
|
style = MaterialTheme.typography.titleSmall,
|
||||||
|
fontWeight = FontWeight.SemiBold,
|
||||||
|
color = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
if (isChild && dateOfBirth == null) {
|
||||||
|
Text(
|
||||||
|
"(required for age tags)",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
OutlinedTextField(
|
OutlinedTextField(
|
||||||
value = dateOfBirth?.let { SimpleDateFormat("MMM d, yyyy", Locale.getDefault()).format(Date(it)) } ?: "",
|
value = dateOfBirth?.let { SimpleDateFormat("MMM d, yyyy", Locale.getDefault()).format(Date(it)) } ?: "",
|
||||||
onValueChange = {},
|
onValueChange = {},
|
||||||
@@ -169,8 +238,8 @@ fun BeautifulPersonInfoDialog(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Button(
|
Button(
|
||||||
onClick = { onConfirm(name.trim(), dateOfBirth, selectedRelationship) },
|
onClick = { onConfirm(name.trim(), dateOfBirth, selectedRelationship, isChild) },
|
||||||
enabled = name.trim().isNotEmpty(),
|
enabled = name.trim().isNotEmpty() && (!isChild || dateOfBirth != null),
|
||||||
modifier = Modifier.weight(1f).height(56.dp),
|
modifier = Modifier.weight(1f).height(56.dp),
|
||||||
shape = RoundedCornerShape(16.dp)
|
shape = RoundedCornerShape(16.dp)
|
||||||
) {
|
) {
|
||||||
|
|||||||
@@ -6,8 +6,11 @@ import android.graphics.BitmapFactory
|
|||||||
import android.graphics.Rect
|
import android.graphics.Rect
|
||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import com.google.mlkit.vision.common.InputImage
|
import com.google.mlkit.vision.common.InputImage
|
||||||
|
import com.google.mlkit.vision.face.Face
|
||||||
import com.google.mlkit.vision.face.FaceDetection
|
import com.google.mlkit.vision.face.FaceDetection
|
||||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNormalizer
|
||||||
import kotlinx.coroutines.Dispatchers
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.async
|
import kotlinx.coroutines.async
|
||||||
import kotlinx.coroutines.awaitAll
|
import kotlinx.coroutines.awaitAll
|
||||||
@@ -64,21 +67,30 @@ class FaceDetectionHelper(private val context: Context) {
|
|||||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||||
val faces = detector.process(inputImage).await()
|
val faces = detector.process(inputImage).await()
|
||||||
|
|
||||||
// Sort by face size (area) to get the largest face
|
// Filter to quality faces - use lenient scanning filter
|
||||||
val sortedFaces = faces.sortedByDescending { face ->
|
// (Discovery filter was too strict, rejecting faces from rolling scan)
|
||||||
|
val qualityFaces = faces.filter { face ->
|
||||||
|
FaceQualityFilter.validateForScanning(
|
||||||
|
face = face,
|
||||||
|
imageWidth = bitmap.width,
|
||||||
|
imageHeight = bitmap.height
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by face size (area) to get the largest quality face
|
||||||
|
val sortedFaces = qualityFaces.sortedByDescending { face ->
|
||||||
face.boundingBox.width() * face.boundingBox.height()
|
face.boundingBox.width() * face.boundingBox.height()
|
||||||
}
|
}
|
||||||
|
|
||||||
val croppedFace = if (sortedFaces.isNotEmpty()) {
|
val croppedFace = if (sortedFaces.isNotEmpty()) {
|
||||||
// Crop the LARGEST detected face (most likely the subject)
|
FaceNormalizer.cropAndNormalize(bitmap, sortedFaces[0])
|
||||||
cropFaceFromBitmap(bitmap, sortedFaces[0].boundingBox)
|
|
||||||
} else null
|
} else null
|
||||||
|
|
||||||
FaceDetectionResult(
|
FaceDetectionResult(
|
||||||
uri = uri,
|
uri = uri,
|
||||||
hasFace = faces.isNotEmpty(),
|
hasFace = qualityFaces.isNotEmpty(),
|
||||||
faceCount = faces.size,
|
faceCount = qualityFaces.size,
|
||||||
faceBounds = faces.map { it.boundingBox },
|
faceBounds = qualityFaces.map { it.boundingBox },
|
||||||
croppedFaceBitmap = croppedFace
|
croppedFaceBitmap = croppedFace
|
||||||
)
|
)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
|
|||||||
@@ -19,31 +19,39 @@ import androidx.compose.ui.text.font.FontWeight
|
|||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.hilt.navigation.compose.hiltViewModel
|
import androidx.hilt.navigation.compose.hiltViewModel
|
||||||
import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.ui.rollingscan.RollingScanModeDialog
|
||||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
|
||||||
import kotlinx.coroutines.launch
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* OPTIMIZED ImageSelectorScreen
|
* ImageSelectorScreen - WITH ROLLING SCAN INTEGRATION
|
||||||
*
|
*
|
||||||
* 🎯 NEW FEATURE: Filter to only show face-tagged images!
|
* ENHANCED FEATURES:
|
||||||
* ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
|
* ✅ Smart filtering (photos with faces)
|
||||||
* - Uses face detection cache to pre-filter
|
* ✅ Rolling Scan integration (NEW!)
|
||||||
* - Shows "Only photos with faces" toggle
|
* ✅ Same signature as original
|
||||||
* - Dramatically faster photo selection
|
* ✅ Drop-in replacement
|
||||||
* - Better training quality (no manual filtering needed)
|
*
|
||||||
|
* FLOW:
|
||||||
|
* 1. User selects 3-5 photos
|
||||||
|
* 2. RollingScanModeDialog appears
|
||||||
|
* 3. User can:
|
||||||
|
* - Use Rolling Scan (recommended) → Navigate to Rolling Scan
|
||||||
|
* - Continue with current → Call onImagesSelected
|
||||||
|
* - Go back → Stay on selector
|
||||||
*/
|
*/
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun ImageSelectorScreen(
|
fun ImageSelectorScreen(
|
||||||
onImagesSelected: (List<Uri>) -> Unit
|
onImagesSelected: (List<Uri>) -> Unit,
|
||||||
|
// NEW: Optional callback for Rolling Scan navigation
|
||||||
|
// If null, Rolling Scan option is hidden
|
||||||
|
onLaunchRollingScan: ((seedImageIds: List<String>) -> Unit)? = null
|
||||||
) {
|
) {
|
||||||
// Inject ImageDao via Hilt ViewModel pattern
|
|
||||||
val viewModel: ImageSelectorViewModel = hiltViewModel()
|
val viewModel: ImageSelectorViewModel = hiltViewModel()
|
||||||
val faceTaggedUris by viewModel.faceTaggedImageUris.collectAsStateWithLifecycle()
|
val faceTaggedUris by viewModel.faceTaggedImageUris.collectAsStateWithLifecycle()
|
||||||
|
|
||||||
var selectedImages by remember { mutableStateOf<List<Uri>>(emptyList()) }
|
var selectedImages by remember { mutableStateOf<List<Uri>>(emptyList()) }
|
||||||
var onlyShowFaceImages by remember { mutableStateOf(true) } // Default: smart filtering
|
var onlyShowFaceImages by remember { mutableStateOf(true) }
|
||||||
|
var showRollingScanDialog by remember { mutableStateOf(false) } // NEW!
|
||||||
val scrollState = rememberScrollState()
|
val scrollState = rememberScrollState()
|
||||||
|
|
||||||
val photoPicker = rememberLauncherForActivityResult(
|
val photoPicker = rememberLauncherForActivityResult(
|
||||||
@@ -56,6 +64,13 @@ fun ImageSelectorScreen(
|
|||||||
} else {
|
} else {
|
||||||
uris
|
uris
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NEW: Show Rolling Scan dialog if:
|
||||||
|
// - Rolling Scan is available (callback provided)
|
||||||
|
// - User selected 3-10 photos (sweet spot)
|
||||||
|
if (onLaunchRollingScan != null && selectedImages.size in 3..10) {
|
||||||
|
showRollingScanDialog = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,12 +174,17 @@ fun ImageSelectorScreen(
|
|||||||
|
|
||||||
Column {
|
Column {
|
||||||
Text(
|
Text(
|
||||||
"Training Tips",
|
// NEW: Changed text if Rolling Scan available
|
||||||
|
if (onLaunchRollingScan != null) "Quick Start" else "Training Tips",
|
||||||
style = MaterialTheme.typography.titleLarge,
|
style = MaterialTheme.typography.titleLarge,
|
||||||
fontWeight = FontWeight.Bold
|
fontWeight = FontWeight.Bold
|
||||||
)
|
)
|
||||||
Text(
|
Text(
|
||||||
"More photos = better recognition",
|
// NEW: Changed text if Rolling Scan available
|
||||||
|
if (onLaunchRollingScan != null)
|
||||||
|
"Pick a few photos, we'll help find more"
|
||||||
|
else
|
||||||
|
"More photos = better recognition",
|
||||||
style = MaterialTheme.typography.bodyMedium,
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
|
color = MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.7f)
|
||||||
)
|
)
|
||||||
@@ -173,11 +193,18 @@ fun ImageSelectorScreen(
|
|||||||
|
|
||||||
Spacer(Modifier.height(4.dp))
|
Spacer(Modifier.height(4.dp))
|
||||||
|
|
||||||
TipItem("✓ Select 20-30 photos for best results", true)
|
// NEW: Different tips if Rolling Scan available
|
||||||
TipItem("✓ Include different angles and lighting", true)
|
if (onLaunchRollingScan != null) {
|
||||||
TipItem("✓ Mix expressions (smile, neutral, laugh)", true)
|
TipItem("✓ Start with just 3-5 good photos", true)
|
||||||
TipItem("✓ With/without glasses if applicable", true)
|
TipItem("✓ AI will find similar ones automatically", true)
|
||||||
TipItem("✗ Avoid blurry or very dark photos", false)
|
TipItem("✓ Or select all 20-30 manually if you prefer", true)
|
||||||
|
} else {
|
||||||
|
TipItem("✓ Select 20-30 photos for best results", true)
|
||||||
|
TipItem("✓ Include different angles and lighting", true)
|
||||||
|
TipItem("✓ Mix expressions (smile, neutral, laugh)", true)
|
||||||
|
TipItem("✓ With/without glasses if applicable", true)
|
||||||
|
TipItem("✗ Avoid blurry or very dark photos", false)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,20 +222,20 @@ fun ImageSelectorScreen(
|
|||||||
),
|
),
|
||||||
contentPadding = PaddingValues(vertical = 16.dp)
|
contentPadding = PaddingValues(vertical = 16.dp)
|
||||||
) {
|
) {
|
||||||
Icon(Icons.Default.PhotoLibrary, contentDescription = null)
|
Icon(Icons.Default.AddPhotoAlternate, contentDescription = null)
|
||||||
Spacer(Modifier.width(8.dp))
|
Spacer(Modifier.width(8.dp))
|
||||||
Text(
|
Text(
|
||||||
if (selectedImages.isEmpty()) {
|
// NEW: Different text if Rolling Scan available
|
||||||
"Select Training Photos"
|
if (onLaunchRollingScan != null)
|
||||||
} else {
|
"Pick Seed Photos"
|
||||||
"Selected: ${selectedImages.size} photos - Tap to change"
|
else
|
||||||
},
|
"Select Photos",
|
||||||
style = MaterialTheme.typography.titleMedium
|
style = MaterialTheme.typography.titleMedium
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Continue button
|
// Continue button (only if photos selected)
|
||||||
AnimatedVisibility(selectedImages.size >= 15) {
|
AnimatedVisibility(selectedImages.isNotEmpty()) {
|
||||||
Button(
|
Button(
|
||||||
onClick = { onImagesSelected(selectedImages) },
|
onClick = { onImagesSelected(selectedImages) },
|
||||||
modifier = Modifier.fillMaxWidth(),
|
modifier = Modifier.fillMaxWidth(),
|
||||||
@@ -261,10 +288,34 @@ fun ImageSelectorScreen(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bottom spacing to ensure last item is visible
|
// Bottom spacing
|
||||||
Spacer(Modifier.height(32.dp))
|
Spacer(Modifier.height(32.dp))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NEW: Rolling Scan Mode Dialog
|
||||||
|
if (showRollingScanDialog && selectedImages.isNotEmpty() && onLaunchRollingScan != null) {
|
||||||
|
RollingScanModeDialog(
|
||||||
|
currentPhotoCount = selectedImages.size,
|
||||||
|
onUseRollingScan = {
|
||||||
|
showRollingScanDialog = false
|
||||||
|
|
||||||
|
// Convert URIs to image IDs
|
||||||
|
// Note: Using URI strings as IDs for now
|
||||||
|
// RollingScanViewModel will convert to actual IDs
|
||||||
|
val seedImageIds = selectedImages.map { it.toString() }
|
||||||
|
onLaunchRollingScan(seedImageIds)
|
||||||
|
},
|
||||||
|
onContinueWithCurrent = {
|
||||||
|
showRollingScanDialog = false
|
||||||
|
onImagesSelected(selectedImages)
|
||||||
|
},
|
||||||
|
onDismiss = {
|
||||||
|
showRollingScanDialog = false
|
||||||
|
// Keep selection, user can re-pick or continue
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
|
|||||||
@@ -51,57 +51,41 @@ fun ScanResultsScreen(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Scaffold(
|
// No Scaffold - MainScreen provides TopAppBar
|
||||||
topBar = {
|
Box(modifier = Modifier.fillMaxSize()) {
|
||||||
TopAppBar(
|
when (state) {
|
||||||
title = { Text("Train New Person") },
|
is ScanningState.Idle -> {}
|
||||||
colors = TopAppBarDefaults.topAppBarColors(
|
|
||||||
containerColor = MaterialTheme.colorScheme.primaryContainer
|
is ScanningState.Processing -> {
|
||||||
|
ProcessingView(progress = state.progress, total = state.total)
|
||||||
|
}
|
||||||
|
|
||||||
|
is ScanningState.Success -> {
|
||||||
|
ImprovedResultsView(
|
||||||
|
result = state.sanityCheckResult,
|
||||||
|
onContinue = {
|
||||||
|
trainViewModel.createFaceModel(
|
||||||
|
trainViewModel.getPersonInfo()?.name ?: "Unknown"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
onRetry = onFinish,
|
||||||
|
onReplaceImage = { oldUri, newUri ->
|
||||||
|
trainViewModel.replaceImage(oldUri, newUri)
|
||||||
|
},
|
||||||
|
onSelectFaceFromMultiple = { result ->
|
||||||
|
showFacePickerDialog = result
|
||||||
|
},
|
||||||
|
trainViewModel = trainViewModel
|
||||||
)
|
)
|
||||||
)
|
}
|
||||||
|
|
||||||
|
is ScanningState.Error -> {
|
||||||
|
ErrorView(message = state.message, onRetry = onFinish)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
) { paddingValues ->
|
|
||||||
Box(
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxSize()
|
|
||||||
.padding(paddingValues)
|
|
||||||
) {
|
|
||||||
when (state) {
|
|
||||||
is ScanningState.Idle -> {}
|
|
||||||
|
|
||||||
is ScanningState.Processing -> {
|
if (trainingState is TrainingState.Processing) {
|
||||||
ProcessingView(progress = state.progress, total = state.total)
|
TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
|
||||||
}
|
|
||||||
|
|
||||||
is ScanningState.Success -> {
|
|
||||||
ImprovedResultsView(
|
|
||||||
result = state.sanityCheckResult,
|
|
||||||
onContinue = {
|
|
||||||
// PersonInfo already captured in TrainingScreen!
|
|
||||||
// Just start training with stored info
|
|
||||||
trainViewModel.createFaceModel(
|
|
||||||
trainViewModel.getPersonInfo()?.name ?: "Unknown"
|
|
||||||
)
|
|
||||||
},
|
|
||||||
onRetry = onFinish,
|
|
||||||
onReplaceImage = { oldUri, newUri ->
|
|
||||||
trainViewModel.replaceImage(oldUri, newUri)
|
|
||||||
},
|
|
||||||
onSelectFaceFromMultiple = { result ->
|
|
||||||
showFacePickerDialog = result
|
|
||||||
},
|
|
||||||
trainViewModel = trainViewModel
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
is ScanningState.Error -> {
|
|
||||||
ErrorView(message = state.message, onRetry = onFinish)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (trainingState is TrainingState.Processing) {
|
|
||||||
TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,18 @@ import android.graphics.Bitmap
|
|||||||
import android.net.Uri
|
import android.net.Uri
|
||||||
import androidx.lifecycle.AndroidViewModel
|
import androidx.lifecycle.AndroidViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import androidx.datastore.preferences.core.booleanPreferencesKey
|
||||||
|
import androidx.datastore.preferences.preferencesDataStore
|
||||||
|
import androidx.work.WorkManager
|
||||||
|
import android.content.Context
|
||||||
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
||||||
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
|
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
|
||||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
|
import com.placeholder.sherpai2.workers.LibraryScanWorker
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
|
import kotlinx.coroutines.flow.first
|
||||||
|
import kotlinx.coroutines.flow.map
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
@@ -41,21 +48,27 @@ sealed class TrainingState {
|
|||||||
data class PersonInfo(
|
data class PersonInfo(
|
||||||
val name: String,
|
val name: String,
|
||||||
val dateOfBirth: Long?,
|
val dateOfBirth: Long?,
|
||||||
val relationship: String
|
val relationship: String,
|
||||||
|
val isChild: Boolean = false
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* FIXED TrainViewModel with proper exclude functionality and efficient replace
|
* FIXED TrainViewModel with proper exclude functionality and efficient replace
|
||||||
*/
|
*/
|
||||||
|
private val android.content.Context.dataStore by preferencesDataStore(name = "settings")
|
||||||
|
private val KEY_BACKGROUND_TAGGING = booleanPreferencesKey("background_recognition_tagging")
|
||||||
|
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class TrainViewModel @Inject constructor(
|
class TrainViewModel @Inject constructor(
|
||||||
application: Application,
|
application: Application,
|
||||||
private val faceRecognitionRepository: FaceRecognitionRepository,
|
private val faceRecognitionRepository: FaceRecognitionRepository,
|
||||||
private val faceNetModel: FaceNetModel
|
private val faceNetModel: FaceNetModel,
|
||||||
|
private val workManager: WorkManager
|
||||||
) : AndroidViewModel(application) {
|
) : AndroidViewModel(application) {
|
||||||
|
|
||||||
private val sanityChecker = TrainingSanityChecker(application)
|
private val sanityChecker = TrainingSanityChecker(application)
|
||||||
private val faceDetectionHelper = FaceDetectionHelper(application)
|
private val faceDetectionHelper = FaceDetectionHelper(application)
|
||||||
|
private val dataStore = application.dataStore
|
||||||
|
|
||||||
private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
|
private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
|
||||||
val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()
|
val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()
|
||||||
@@ -80,8 +93,8 @@ class TrainViewModel @Inject constructor(
|
|||||||
/**
|
/**
|
||||||
* Store person info before photo selection
|
* Store person info before photo selection
|
||||||
*/
|
*/
|
||||||
fun setPersonInfo(name: String, dateOfBirth: Long?, relationship: String) {
|
fun setPersonInfo(name: String, dateOfBirth: Long?, relationship: String, isChild: Boolean = false) {
|
||||||
personInfo = PersonInfo(name, dateOfBirth, relationship)
|
personInfo = PersonInfo(name, dateOfBirth, relationship, isChild)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -151,6 +164,7 @@ class TrainViewModel @Inject constructor(
|
|||||||
val person = PersonEntity.create(
|
val person = PersonEntity.create(
|
||||||
name = personName,
|
name = personName,
|
||||||
dateOfBirth = personInfo?.dateOfBirth,
|
dateOfBirth = personInfo?.dateOfBirth,
|
||||||
|
isChild = personInfo?.isChild ?: false,
|
||||||
relationship = personInfo?.relationship
|
relationship = personInfo?.relationship
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,6 +186,20 @@ class TrainViewModel @Inject constructor(
|
|||||||
relationship = person.relationship
|
relationship = person.relationship
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Trigger library scan if setting enabled
|
||||||
|
val backgroundTaggingEnabled = dataStore.data
|
||||||
|
.map { it[KEY_BACKGROUND_TAGGING] ?: true }
|
||||||
|
.first()
|
||||||
|
|
||||||
|
if (backgroundTaggingEnabled) {
|
||||||
|
// Use default threshold (0.62 solo, 0.68 group)
|
||||||
|
val scanRequest = LibraryScanWorker.createWorkRequest(
|
||||||
|
personId = personId,
|
||||||
|
personName = personName
|
||||||
|
)
|
||||||
|
workManager.enqueue(scanRequest)
|
||||||
|
}
|
||||||
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
_trainingState.value = TrainingState.Error(
|
_trainingState.value = TrainingState.Error(
|
||||||
e.message ?: "Failed to create face model"
|
e.message ?: "Failed to create face model"
|
||||||
@@ -353,7 +381,7 @@ class TrainViewModel @Inject constructor(
|
|||||||
faceDetectionResults = updatedFaceResults,
|
faceDetectionResults = updatedFaceResults,
|
||||||
validationErrors = updatedErrors,
|
validationErrors = updatedErrors,
|
||||||
validImagesWithFaces = updatedValidImages,
|
validImagesWithFaces = updatedValidImages,
|
||||||
excludedImages = excludedImages
|
excludedImages = excludedImages.toSet() // Immutable copy for Compose state detection
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,9 +61,9 @@ fun TrainingScreen(
|
|||||||
if (showInfoDialog) {
|
if (showInfoDialog) {
|
||||||
BeautifulPersonInfoDialog(
|
BeautifulPersonInfoDialog(
|
||||||
onDismiss = { showInfoDialog = false },
|
onDismiss = { showInfoDialog = false },
|
||||||
onConfirm = { name, dob, relationship ->
|
onConfirm = { name, dob, relationship, isChild ->
|
||||||
showInfoDialog = false
|
showInfoDialog = false
|
||||||
trainViewModel.setPersonInfo(name, dob, relationship)
|
trainViewModel.setPersonInfo(name, dob, relationship, isChild)
|
||||||
onSelectImages()
|
onSelectImages()
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.placeholder.sherpai2.ui.trainingprep
|
package com.placeholder.sherpai2.ui.trainingprep
|
||||||
|
|
||||||
import androidx.compose.animation.AnimatedVisibility
|
import androidx.compose.animation.AnimatedVisibility
|
||||||
|
import androidx.compose.animation.core.animateFloatAsState
|
||||||
import androidx.compose.foundation.BorderStroke
|
import androidx.compose.foundation.BorderStroke
|
||||||
import androidx.compose.foundation.ExperimentalFoundationApi
|
import androidx.compose.foundation.ExperimentalFoundationApi
|
||||||
import androidx.compose.foundation.background
|
import androidx.compose.foundation.background
|
||||||
@@ -15,7 +16,7 @@ import androidx.compose.material3.*
|
|||||||
import androidx.compose.runtime.*
|
import androidx.compose.runtime.*
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.clip
|
import androidx.compose.ui.draw.alpha
|
||||||
import androidx.compose.ui.graphics.Color
|
import androidx.compose.ui.graphics.Color
|
||||||
import androidx.compose.ui.layout.ContentScale
|
import androidx.compose.ui.layout.ContentScale
|
||||||
import androidx.compose.ui.text.font.FontWeight
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
@@ -26,50 +27,79 @@ import coil.compose.AsyncImage
|
|||||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TrainingPhotoSelectorScreen - Smart photo selector for face training
|
* TrainingPhotoSelectorScreen - PREMIUM GRID + ROLLING SCAN
|
||||||
*
|
*
|
||||||
* SOLVES THE PROBLEM:
|
* FLOW:
|
||||||
* - User has 10,000 photos total
|
* 1. Shows PREMIUM faces only (solo, large, frontal)
|
||||||
* - Only ~500 have faces (hasFaces=true)
|
* 2. User picks 1-3 seed photos
|
||||||
* - Shows ONLY photos with faces
|
* 3. "Find Similar" button appears → launches RollingScanScreen
|
||||||
* - Multi-select mode for quick selection
|
* 4. Toggle to show all photos if needed
|
||||||
* - Face count badges on each photo
|
|
||||||
* - Minimum 15 photos enforced
|
|
||||||
*
|
|
||||||
* REUSES:
|
|
||||||
* - Existing ImageDao.getImagesWithFaces()
|
|
||||||
* - Existing face detection cache
|
|
||||||
* - Proven album grid layout
|
|
||||||
*/
|
*/
|
||||||
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
|
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun TrainingPhotoSelectorScreen(
|
fun TrainingPhotoSelectorScreen(
|
||||||
onBack: () -> Unit,
|
onBack: () -> Unit,
|
||||||
onPhotosSelected: (List<android.net.Uri>) -> Unit,
|
onPhotosSelected: (List<android.net.Uri>) -> Unit,
|
||||||
|
onLaunchRollingScan: ((List<String>) -> Unit)? = null, // NEW: Navigate to rolling scan
|
||||||
viewModel: TrainingPhotoSelectorViewModel = hiltViewModel()
|
viewModel: TrainingPhotoSelectorViewModel = hiltViewModel()
|
||||||
) {
|
) {
|
||||||
val photos by viewModel.photosWithFaces.collectAsStateWithLifecycle()
|
val photos by viewModel.photosWithFaces.collectAsStateWithLifecycle()
|
||||||
val selectedPhotos by viewModel.selectedPhotos.collectAsStateWithLifecycle()
|
val selectedPhotos by viewModel.selectedPhotos.collectAsStateWithLifecycle()
|
||||||
val isLoading by viewModel.isLoading.collectAsStateWithLifecycle()
|
val isLoading by viewModel.isLoading.collectAsStateWithLifecycle()
|
||||||
|
val isRanking by viewModel.isRanking.collectAsStateWithLifecycle()
|
||||||
|
val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle()
|
||||||
|
val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle()
|
||||||
|
val embeddingProgress by viewModel.embeddingProgress.collectAsStateWithLifecycle()
|
||||||
|
|
||||||
Scaffold(
|
Scaffold(
|
||||||
topBar = {
|
topBar = {
|
||||||
TopAppBar(
|
TopAppBar(
|
||||||
title = {
|
title = {
|
||||||
Column {
|
Column {
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
if (selectedPhotos.isEmpty()) {
|
||||||
|
"Select Training Photos"
|
||||||
|
} else {
|
||||||
|
"${selectedPhotos.size} selected"
|
||||||
|
},
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
fontWeight = FontWeight.Bold
|
||||||
|
)
|
||||||
|
|
||||||
|
// NEW: Ranking indicator
|
||||||
|
if (isRanking) {
|
||||||
|
CircularProgressIndicator(
|
||||||
|
modifier = Modifier.size(16.dp),
|
||||||
|
strokeWidth = 2.dp,
|
||||||
|
color = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
} else if (selectedPhotos.isNotEmpty()) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.AutoAwesome,
|
||||||
|
contentDescription = "AI Ranked",
|
||||||
|
modifier = Modifier.size(20.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.primary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status text
|
||||||
Text(
|
Text(
|
||||||
if (selectedPhotos.isEmpty()) {
|
when {
|
||||||
"Select Training Photos"
|
isRanking -> "Ranking similar photos..."
|
||||||
} else {
|
showPremiumOnly -> "Showing $premiumCount premium faces"
|
||||||
"${selectedPhotos.size} selected"
|
else -> "Showing ${photos.size} photos with faces"
|
||||||
},
|
},
|
||||||
style = MaterialTheme.typography.titleLarge,
|
|
||||||
fontWeight = FontWeight.Bold
|
|
||||||
)
|
|
||||||
Text(
|
|
||||||
"Showing ${photos.size} photos with faces",
|
|
||||||
style = MaterialTheme.typography.bodySmall,
|
style = MaterialTheme.typography.bodySmall,
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
color = when {
|
||||||
|
isRanking -> MaterialTheme.colorScheme.primary
|
||||||
|
showPremiumOnly -> MaterialTheme.colorScheme.tertiary
|
||||||
|
else -> MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -79,6 +109,14 @@ fun TrainingPhotoSelectorScreen(
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
actions = {
|
actions = {
|
||||||
|
// Toggle premium/all
|
||||||
|
IconButton(onClick = { viewModel.togglePremiumOnly() }) {
|
||||||
|
Icon(
|
||||||
|
if (showPremiumOnly) Icons.Default.Star else Icons.Default.GridView,
|
||||||
|
contentDescription = if (showPremiumOnly) "Show all" else "Show premium only",
|
||||||
|
tint = if (showPremiumOnly) MaterialTheme.colorScheme.tertiary else MaterialTheme.colorScheme.onSurface
|
||||||
|
)
|
||||||
|
}
|
||||||
if (selectedPhotos.isNotEmpty()) {
|
if (selectedPhotos.isNotEmpty()) {
|
||||||
TextButton(onClick = { viewModel.clearSelection() }) {
|
TextButton(onClick = { viewModel.clearSelection() }) {
|
||||||
Text("Clear")
|
Text("Clear")
|
||||||
@@ -94,7 +132,11 @@ fun TrainingPhotoSelectorScreen(
|
|||||||
AnimatedVisibility(visible = selectedPhotos.isNotEmpty()) {
|
AnimatedVisibility(visible = selectedPhotos.isNotEmpty()) {
|
||||||
SelectionBottomBar(
|
SelectionBottomBar(
|
||||||
selectedCount = selectedPhotos.size,
|
selectedCount = selectedPhotos.size,
|
||||||
|
canLaunchRollingScan = viewModel.canLaunchRollingScan && onLaunchRollingScan != null,
|
||||||
onClear = { viewModel.clearSelection() },
|
onClear = { viewModel.clearSelection() },
|
||||||
|
onFindSimilar = {
|
||||||
|
onLaunchRollingScan?.invoke(viewModel.getSeedImageIds())
|
||||||
|
},
|
||||||
onContinue = {
|
onContinue = {
|
||||||
val uris = selectedPhotos.map { android.net.Uri.parse(it.imageUri) }
|
val uris = selectedPhotos.map { android.net.Uri.parse(it.imageUri) }
|
||||||
onPhotosSelected(uris)
|
onPhotosSelected(uris)
|
||||||
@@ -114,7 +156,33 @@ fun TrainingPhotoSelectorScreen(
|
|||||||
modifier = Modifier.fillMaxSize(),
|
modifier = Modifier.fillMaxSize(),
|
||||||
contentAlignment = Alignment.Center
|
contentAlignment = Alignment.Center
|
||||||
) {
|
) {
|
||||||
CircularProgressIndicator()
|
Column(
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
CircularProgressIndicator()
|
||||||
|
// Capture value to avoid race condition
|
||||||
|
val progress = embeddingProgress
|
||||||
|
if (progress != null) {
|
||||||
|
Text(
|
||||||
|
"Preparing faces: ${progress.current}/${progress.total}",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
LinearProgressIndicator(
|
||||||
|
progress = { progress.current.toFloat() / progress.total },
|
||||||
|
modifier = Modifier
|
||||||
|
.width(200.dp)
|
||||||
|
.padding(top = 8.dp)
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
Text(
|
||||||
|
"Loading premium faces...",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
photos.isEmpty() -> {
|
photos.isEmpty() -> {
|
||||||
@@ -135,7 +203,9 @@ fun TrainingPhotoSelectorScreen(
|
|||||||
@Composable
|
@Composable
|
||||||
private fun SelectionBottomBar(
|
private fun SelectionBottomBar(
|
||||||
selectedCount: Int,
|
selectedCount: Int,
|
||||||
|
canLaunchRollingScan: Boolean,
|
||||||
onClear: () -> Unit,
|
onClear: () -> Unit,
|
||||||
|
onFindSimilar: () -> Unit,
|
||||||
onContinue: () -> Unit
|
onContinue: () -> Unit
|
||||||
) {
|
) {
|
||||||
Surface(
|
Surface(
|
||||||
@@ -143,42 +213,72 @@ private fun SelectionBottomBar(
|
|||||||
color = MaterialTheme.colorScheme.primaryContainer,
|
color = MaterialTheme.colorScheme.primaryContainer,
|
||||||
shadowElevation = 8.dp
|
shadowElevation = 8.dp
|
||||||
) {
|
) {
|
||||||
Row(
|
Column(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
.padding(16.dp),
|
.padding(16.dp)
|
||||||
horizontalArrangement = Arrangement.SpaceBetween,
|
|
||||||
verticalAlignment = Alignment.CenterVertically
|
|
||||||
) {
|
) {
|
||||||
Column {
|
Row(
|
||||||
Text(
|
modifier = Modifier.fillMaxWidth(),
|
||||||
"$selectedCount photos selected",
|
horizontalArrangement = Arrangement.SpaceBetween,
|
||||||
style = MaterialTheme.typography.titleMedium,
|
verticalAlignment = Alignment.CenterVertically
|
||||||
fontWeight = FontWeight.Bold
|
) {
|
||||||
)
|
Column {
|
||||||
Text(
|
Text(
|
||||||
when {
|
"$selectedCount seed${if (selectedCount != 1) "s" else ""} selected",
|
||||||
selectedCount < 15 -> "Need ${15 - selectedCount} more"
|
style = MaterialTheme.typography.titleMedium,
|
||||||
selectedCount < 20 -> "Good start!"
|
fontWeight = FontWeight.Bold
|
||||||
selectedCount < 30 -> "Great selection!"
|
)
|
||||||
else -> "Excellent coverage!"
|
Text(
|
||||||
},
|
when {
|
||||||
style = MaterialTheme.typography.bodySmall,
|
selectedCount == 0 -> "Pick 1-3 clear photos of the same person"
|
||||||
color = when {
|
selectedCount in 1..3 -> "Tap 'Find Similar' to discover more"
|
||||||
selectedCount < 15 -> MaterialTheme.colorScheme.error
|
selectedCount < 15 -> "Need ${15 - selectedCount} more for training"
|
||||||
else -> MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.8f)
|
else -> "Ready to train!"
|
||||||
}
|
},
|
||||||
)
|
style = MaterialTheme.typography.bodySmall,
|
||||||
}
|
color = when {
|
||||||
|
selectedCount in 1..3 -> MaterialTheme.colorScheme.tertiary
|
||||||
|
selectedCount < 15 -> MaterialTheme.colorScheme.error
|
||||||
|
else -> MaterialTheme.colorScheme.onPrimaryContainer.copy(alpha = 0.8f)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
Row(horizontalArrangement = Arrangement.spacedBy(8.dp)) {
|
|
||||||
OutlinedButton(onClick = onClear) {
|
OutlinedButton(onClick = onClear) {
|
||||||
Text("Clear")
|
Text("Clear")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(Modifier.height(12.dp))
|
||||||
|
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp)
|
||||||
|
) {
|
||||||
|
// Find Similar button (prominent when 1-5 seeds selected)
|
||||||
|
Button(
|
||||||
|
onClick = onFindSimilar,
|
||||||
|
enabled = canLaunchRollingScan,
|
||||||
|
modifier = Modifier.weight(1f),
|
||||||
|
colors = ButtonDefaults.buttonColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.tertiary
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.AutoAwesome,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(20.dp)
|
||||||
|
)
|
||||||
|
Spacer(Modifier.width(8.dp))
|
||||||
|
Text("Find Similar")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Continue button (for manual selection path)
|
||||||
Button(
|
Button(
|
||||||
onClick = onContinue,
|
onClick = onContinue,
|
||||||
enabled = selectedCount >= 15
|
enabled = selectedCount >= 15,
|
||||||
|
modifier = Modifier.weight(1f)
|
||||||
) {
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
Icons.Default.Check,
|
Icons.Default.Check,
|
||||||
@@ -186,7 +286,7 @@ private fun SelectionBottomBar(
|
|||||||
modifier = Modifier.size(20.dp)
|
modifier = Modifier.size(20.dp)
|
||||||
)
|
)
|
||||||
Spacer(Modifier.width(8.dp))
|
Spacer(Modifier.width(8.dp))
|
||||||
Text("Continue")
|
Text("Train ($selectedCount)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -205,7 +305,7 @@ private fun PhotoGrid(
|
|||||||
contentPadding = PaddingValues(
|
contentPadding = PaddingValues(
|
||||||
start = 4.dp,
|
start = 4.dp,
|
||||||
end = 4.dp,
|
end = 4.dp,
|
||||||
bottom = 100.dp // Space for bottom bar
|
bottom = 100.dp
|
||||||
),
|
),
|
||||||
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
verticalArrangement = Arrangement.spacedBy(4.dp)
|
verticalArrangement = Arrangement.spacedBy(4.dp)
|
||||||
@@ -230,10 +330,17 @@ private fun PhotoThumbnail(
|
|||||||
isSelected: Boolean,
|
isSelected: Boolean,
|
||||||
onClick: () -> Unit
|
onClick: () -> Unit
|
||||||
) {
|
) {
|
||||||
|
// NEW: Fade animation for non-selected photos
|
||||||
|
val alpha by animateFloatAsState(
|
||||||
|
targetValue = if (isSelected) 1f else 1f,
|
||||||
|
label = "photoAlpha"
|
||||||
|
)
|
||||||
|
|
||||||
Card(
|
Card(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
.aspectRatio(1f)
|
.aspectRatio(1f)
|
||||||
|
.alpha(alpha)
|
||||||
.combinedClickable(onClick = onClick),
|
.combinedClickable(onClick = onClick),
|
||||||
shape = RoundedCornerShape(4.dp),
|
shape = RoundedCornerShape(4.dp),
|
||||||
border = if (isSelected) {
|
border = if (isSelected) {
|
||||||
|
|||||||
@@ -1,119 +1,449 @@
|
|||||||
package com.placeholder.sherpai2.ui.trainingprep
|
package com.placeholder.sherpai2.ui.trainingprep
|
||||||
|
|
||||||
import androidx.lifecycle.ViewModel
|
import android.app.Application
|
||||||
|
import android.graphics.Bitmap
|
||||||
|
import android.graphics.BitmapFactory
|
||||||
|
import android.graphics.Rect
|
||||||
|
import android.net.Uri
|
||||||
|
import android.util.Log
|
||||||
|
import androidx.lifecycle.AndroidViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||||
|
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.Job
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
|
import kotlin.math.max
|
||||||
|
import kotlin.math.min
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TrainingPhotoSelectorViewModel - Smart photo selector for training
|
* TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN
|
||||||
*
|
*
|
||||||
* KEY OPTIMIZATION:
|
* FLOW:
|
||||||
* - Only loads images with hasFaces=true from database
|
* 1. Start with PREMIUM faces only (solo, large, frontal, high quality)
|
||||||
* - Result: 10,000 photos → ~500 with faces
|
* 2. User picks 1-3 seed photos
|
||||||
* - User can quickly select 20-30 good ones
|
* 3. User taps "Find Similar" → navigate to RollingScanScreen
|
||||||
* - Multi-select state management
|
* 4. RollingScanScreen returns with full selection
|
||||||
*/
|
*/
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class TrainingPhotoSelectorViewModel @Inject constructor(
|
class TrainingPhotoSelectorViewModel @Inject constructor(
|
||||||
private val imageDao: ImageDao
|
application: Application,
|
||||||
) : ViewModel() {
|
private val imageDao: ImageDao,
|
||||||
|
private val faceCacheDao: FaceCacheDao,
|
||||||
|
private val faceSimilarityScorer: FaceSimilarityScorer,
|
||||||
|
private val faceNetModel: FaceNetModel
|
||||||
|
) : AndroidViewModel(application) {
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val TAG = "PremiumSelector"
|
||||||
|
private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1
|
||||||
|
private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5
|
||||||
|
private const val MAX_EMBEDDINGS_TO_GENERATE = 500
|
||||||
|
}
|
||||||
|
|
||||||
|
// All photos (for fallback / full list)
|
||||||
|
private var allPhotosWithFaces: List<ImageEntity> = emptyList()
|
||||||
|
|
||||||
|
// Premium-only photos (initial view)
|
||||||
|
private var premiumPhotos: List<ImageEntity> = emptyList()
|
||||||
|
|
||||||
// Photos with faces (hasFaces=true)
|
|
||||||
private val _photosWithFaces = MutableStateFlow<List<ImageEntity>>(emptyList())
|
private val _photosWithFaces = MutableStateFlow<List<ImageEntity>>(emptyList())
|
||||||
val photosWithFaces: StateFlow<List<ImageEntity>> = _photosWithFaces.asStateFlow()
|
val photosWithFaces: StateFlow<List<ImageEntity>> = _photosWithFaces.asStateFlow()
|
||||||
|
|
||||||
// Selected photos (multi-select)
|
|
||||||
private val _selectedPhotos = MutableStateFlow<Set<ImageEntity>>(emptySet())
|
private val _selectedPhotos = MutableStateFlow<Set<ImageEntity>>(emptySet())
|
||||||
val selectedPhotos: StateFlow<Set<ImageEntity>> = _selectedPhotos.asStateFlow()
|
val selectedPhotos: StateFlow<Set<ImageEntity>> = _selectedPhotos.asStateFlow()
|
||||||
|
|
||||||
// Loading state
|
|
||||||
private val _isLoading = MutableStateFlow(true)
|
private val _isLoading = MutableStateFlow(true)
|
||||||
val isLoading: StateFlow<Boolean> = _isLoading.asStateFlow()
|
val isLoading: StateFlow<Boolean> = _isLoading.asStateFlow()
|
||||||
|
|
||||||
|
private val _isRanking = MutableStateFlow(false)
|
||||||
|
val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow()
|
||||||
|
|
||||||
|
// Embedding generation progress
|
||||||
|
private val _embeddingProgress = MutableStateFlow<EmbeddingProgress?>(null)
|
||||||
|
val embeddingProgress: StateFlow<EmbeddingProgress?> = _embeddingProgress.asStateFlow()
|
||||||
|
|
||||||
|
data class EmbeddingProgress(val current: Int, val total: Int)
|
||||||
|
|
||||||
|
// Premium mode toggle
|
||||||
|
private val _showPremiumOnly = MutableStateFlow(true)
|
||||||
|
val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow()
|
||||||
|
|
||||||
|
// Premium face count for UI
|
||||||
|
private val _premiumCount = MutableStateFlow(0)
|
||||||
|
val premiumCount: StateFlow<Int> = _premiumCount.asStateFlow()
|
||||||
|
|
||||||
|
// Can launch rolling scan?
|
||||||
|
val canLaunchRollingScan: Boolean
|
||||||
|
get() = _selectedPhotos.value.size in MIN_SEEDS_FOR_ROLLING_SCAN..MAX_SEEDS_FOR_ROLLING_SCAN
|
||||||
|
|
||||||
|
// Get seed image IDs for rolling scan navigation
|
||||||
|
fun getSeedImageIds(): List<String> = _selectedPhotos.value.map { it.imageId }
|
||||||
|
|
||||||
|
private var rankingJob: Job? = null
|
||||||
|
|
||||||
init {
|
init {
|
||||||
loadPhotosWithFaces()
|
loadPremiumFaces()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load ONLY photos with hasFaces=true
|
* Load PREMIUM faces first (solo, large, frontal, high quality)
|
||||||
*
|
* If no embeddings exist, generate them on-demand for premium candidates
|
||||||
* Uses indexed query: SELECT * FROM images WHERE hasFaces = 1
|
|
||||||
* Fast! (~10ms for 10k photos)
|
|
||||||
*
|
|
||||||
* SORTED: Solo photos (faceCount=1) first for best training quality
|
|
||||||
*/
|
*/
|
||||||
private fun loadPhotosWithFaces() {
|
private fun loadPremiumFaces() {
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
try {
|
try {
|
||||||
_isLoading.value = true
|
_isLoading.value = true
|
||||||
|
|
||||||
// ✅ CRITICAL: Only get images with faces!
|
// First check if premium faces with embeddings exist
|
||||||
val photos = imageDao.getImagesWithFaces()
|
var premiumFaceCache = faceCacheDao.getPremiumFaces(
|
||||||
|
minAreaRatio = 0.10f,
|
||||||
|
minQuality = 0.7f,
|
||||||
|
limit = 500
|
||||||
|
)
|
||||||
|
|
||||||
// ✅ FIX: Sort by LEAST faces first (solo photos = best training data)
|
Log.d(TAG, "📊 Found ${premiumFaceCache.size} premium faces with embeddings")
|
||||||
// faceCount=1 first, then faceCount=2, etc.
|
|
||||||
val sorted = photos.sortedBy { it.faceCount ?: 999 }
|
|
||||||
|
|
||||||
_photosWithFaces.value = sorted
|
// If no premium faces with embeddings, generate them on-demand
|
||||||
|
if (premiumFaceCache.isEmpty()) {
|
||||||
|
Log.d(TAG, "⚠️ No premium faces with embeddings - generating on-demand")
|
||||||
|
|
||||||
|
val candidates = faceCacheDao.getPremiumFaceCandidatesNeedingEmbeddings(
|
||||||
|
minAreaRatio = 0.10f,
|
||||||
|
minQuality = 0.7f,
|
||||||
|
limit = MAX_EMBEDDINGS_TO_GENERATE
|
||||||
|
)
|
||||||
|
|
||||||
|
Log.d(TAG, "📦 Found ${candidates.size} premium candidates needing embeddings")
|
||||||
|
|
||||||
|
if (candidates.isNotEmpty()) {
|
||||||
|
generateEmbeddingsForCandidates(candidates)
|
||||||
|
|
||||||
|
// Re-query after generating
|
||||||
|
premiumFaceCache = faceCacheDao.getPremiumFaces(
|
||||||
|
minAreaRatio = 0.10f,
|
||||||
|
minQuality = 0.7f,
|
||||||
|
limit = 500
|
||||||
|
)
|
||||||
|
Log.d(TAG, "✅ After generation: ${premiumFaceCache.size} premium faces")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_premiumCount.value = premiumFaceCache.size
|
||||||
|
|
||||||
|
// Get corresponding ImageEntities
|
||||||
|
val premiumImageIds = premiumFaceCache.map { it.imageId }.distinct()
|
||||||
|
val images = imageDao.getImagesByIds(premiumImageIds)
|
||||||
|
|
||||||
|
// Sort by quality (highest first)
|
||||||
|
val imageQualityMap = premiumFaceCache.associate { it.imageId to it.qualityScore }
|
||||||
|
premiumPhotos = images.sortedByDescending { imageQualityMap[it.imageId] ?: 0f }
|
||||||
|
|
||||||
|
_photosWithFaces.value = premiumPhotos
|
||||||
|
|
||||||
|
// Also load all photos for fallback
|
||||||
|
allPhotosWithFaces = imageDao.getImagesWithFaces()
|
||||||
|
.sortedBy { it.faceCount ?: 999 }
|
||||||
|
|
||||||
|
Log.d(TAG, "✅ Premium: ${premiumPhotos.size}, Total: ${allPhotosWithFaces.size}")
|
||||||
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
// If face cache not populated, empty list
|
Log.e(TAG, "❌ Failed to load premium faces", e)
|
||||||
_photosWithFaces.value = emptyList()
|
// Fallback to all faces
|
||||||
|
loadAllFaces()
|
||||||
} finally {
|
} finally {
|
||||||
_isLoading.value = false
|
_isLoading.value = false
|
||||||
|
_embeddingProgress.value = null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Toggle photo selection
|
* Generate embeddings for premium face candidates
|
||||||
*/
|
*/
|
||||||
|
private suspend fun generateEmbeddingsForCandidates(candidates: List<FaceCacheEntity>) {
|
||||||
|
val context = getApplication<Application>()
|
||||||
|
val total = candidates.size
|
||||||
|
var processed = 0
|
||||||
|
|
||||||
|
withContext(Dispatchers.IO) {
|
||||||
|
// Get image URIs for candidates
|
||||||
|
val imageIds = candidates.map { it.imageId }.distinct()
|
||||||
|
val images = imageDao.getImagesByIds(imageIds)
|
||||||
|
val imageUriMap = images.associate { it.imageId to it.imageUri }
|
||||||
|
|
||||||
|
for (candidate in candidates) {
|
||||||
|
try {
|
||||||
|
val imageUri = imageUriMap[candidate.imageId] ?: continue
|
||||||
|
|
||||||
|
// Load bitmap
|
||||||
|
val bitmap = loadBitmapOptimized(context, Uri.parse(imageUri)) ?: continue
|
||||||
|
|
||||||
|
// Crop face
|
||||||
|
val croppedFace = cropFaceWithPadding(bitmap, candidate.getBoundingBox())
|
||||||
|
bitmap.recycle()
|
||||||
|
|
||||||
|
if (croppedFace == null) continue
|
||||||
|
|
||||||
|
// Generate embedding
|
||||||
|
val embedding = faceNetModel.generateEmbedding(croppedFace)
|
||||||
|
croppedFace.recycle()
|
||||||
|
|
||||||
|
// Validate embedding
|
||||||
|
if (embedding.any { it != 0f }) {
|
||||||
|
// Save to database
|
||||||
|
val embeddingJson = FaceCacheEntity.embeddingToJson(embedding)
|
||||||
|
faceCacheDao.updateEmbedding(candidate.imageId, candidate.faceIndex, embeddingJson)
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.w(TAG, "Failed to generate embedding for ${candidate.imageId}: ${e.message}")
|
||||||
|
}
|
||||||
|
|
||||||
|
processed++
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
_embeddingProgress.value = EmbeddingProgress(processed, total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.d(TAG, "✅ Generated embeddings for $processed/$total candidates")
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun loadBitmapOptimized(context: android.content.Context, uri: Uri, maxDim: Int = 768): Bitmap? {
|
||||||
|
return try {
|
||||||
|
val options = BitmapFactory.Options().apply { inJustDecodeBounds = true }
|
||||||
|
context.contentResolver.openInputStream(uri)?.use { stream ->
|
||||||
|
BitmapFactory.decodeStream(stream, null, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sampleSize = 1
|
||||||
|
while (options.outWidth / sampleSize > maxDim || options.outHeight / sampleSize > maxDim) {
|
||||||
|
sampleSize *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
val finalOptions = BitmapFactory.Options().apply {
|
||||||
|
inSampleSize = sampleSize
|
||||||
|
inPreferredConfig = Bitmap.Config.ARGB_8888
|
||||||
|
}
|
||||||
|
|
||||||
|
context.contentResolver.openInputStream(uri)?.use { stream ->
|
||||||
|
BitmapFactory.decodeStream(stream, null, finalOptions)
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.w(TAG, "Failed to load bitmap: ${e.message}")
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun cropFaceWithPadding(bitmap: Bitmap, boundingBox: Rect): Bitmap? {
|
||||||
|
return try {
|
||||||
|
val padding = (max(boundingBox.width(), boundingBox.height()) * 0.25f).toInt()
|
||||||
|
val left = max(0, boundingBox.left - padding)
|
||||||
|
val top = max(0, boundingBox.top - padding)
|
||||||
|
val right = min(bitmap.width, boundingBox.right + padding)
|
||||||
|
val bottom = min(bitmap.height, boundingBox.bottom + padding)
|
||||||
|
val width = right - left
|
||||||
|
val height = bottom - top
|
||||||
|
|
||||||
|
if (width > 0 && height > 0) {
|
||||||
|
Bitmap.createBitmap(bitmap, left, top, width, height)
|
||||||
|
} else null
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.w(TAG, "Failed to crop face: ${e.message}")
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fallback: load all photos with faces
|
||||||
|
*/
|
||||||
|
private suspend fun loadAllFaces() {
|
||||||
|
try {
|
||||||
|
val photos = imageDao.getImagesWithFaces()
|
||||||
|
allPhotosWithFaces = photos.sortedBy { it.faceCount ?: 999 }
|
||||||
|
premiumPhotos = allPhotosWithFaces.filter { it.faceCount == 1 }.take(200)
|
||||||
|
_photosWithFaces.value = if (_showPremiumOnly.value) premiumPhotos else allPhotosWithFaces
|
||||||
|
Log.d(TAG, "✅ Fallback loaded ${allPhotosWithFaces.size} photos")
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "❌ Failed fallback load", e)
|
||||||
|
allPhotosWithFaces = emptyList()
|
||||||
|
premiumPhotos = emptyList()
|
||||||
|
_photosWithFaces.value = emptyList()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Toggle between premium-only and all photos
|
||||||
|
*/
|
||||||
|
fun togglePremiumOnly() {
|
||||||
|
_showPremiumOnly.value = !_showPremiumOnly.value
|
||||||
|
_photosWithFaces.value = if (_showPremiumOnly.value) premiumPhotos else allPhotosWithFaces
|
||||||
|
Log.d(TAG, "📊 Showing ${if (_showPremiumOnly.value) "premium only" else "all photos"}")
|
||||||
|
}
|
||||||
|
|
||||||
fun toggleSelection(photo: ImageEntity) {
|
fun toggleSelection(photo: ImageEntity) {
|
||||||
val current = _selectedPhotos.value.toMutableSet()
|
val current = _selectedPhotos.value.toMutableSet()
|
||||||
|
|
||||||
if (photo in current) {
|
if (photo in current) {
|
||||||
current.remove(photo)
|
current.remove(photo)
|
||||||
|
Log.d(TAG, "➖ Deselected photo: ${photo.imageId}")
|
||||||
} else {
|
} else {
|
||||||
current.add(photo)
|
current.add(photo)
|
||||||
|
Log.d(TAG, "➕ Selected photo: ${photo.imageId}")
|
||||||
}
|
}
|
||||||
|
|
||||||
_selectedPhotos.value = current
|
_selectedPhotos.value = current
|
||||||
|
Log.d(TAG, "📊 Total selected: ${current.size}")
|
||||||
|
|
||||||
|
// Trigger ranking
|
||||||
|
triggerLiveRanking()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun triggerLiveRanking() {
|
||||||
|
Log.d(TAG, "🔄 triggerLiveRanking() called")
|
||||||
|
|
||||||
|
// Cancel previous ranking job
|
||||||
|
rankingJob?.cancel()
|
||||||
|
|
||||||
|
val selectedCount = _selectedPhotos.value.size
|
||||||
|
|
||||||
|
if (selectedCount == 0) {
|
||||||
|
Log.d(TAG, "⏹️ No photos selected, resetting to original order")
|
||||||
|
_photosWithFaces.value = allPhotosWithFaces
|
||||||
|
_isRanking.value = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.d(TAG, "⏳ Starting debounced ranking (300ms delay)...")
|
||||||
|
|
||||||
|
// Debounce ranking by 300ms
|
||||||
|
rankingJob = viewModelScope.launch {
|
||||||
|
try {
|
||||||
|
delay(300)
|
||||||
|
Log.d(TAG, "✓ Debounce complete, starting ranking...")
|
||||||
|
|
||||||
|
_isRanking.value = true
|
||||||
|
|
||||||
|
// Get embeddings for selected photos
|
||||||
|
val selectedImageIds = _selectedPhotos.value.map { it.imageId }
|
||||||
|
Log.d(TAG, "📥 Getting embeddings for ${selectedImageIds.size} selected photos...")
|
||||||
|
|
||||||
|
val selectedEmbeddings = faceCacheDao.getEmbeddingsForImages(selectedImageIds)
|
||||||
|
.mapNotNull { it.getEmbedding() }
|
||||||
|
|
||||||
|
Log.d(TAG, "📦 Retrieved ${selectedEmbeddings.size} embeddings")
|
||||||
|
|
||||||
|
if (selectedEmbeddings.isEmpty()) {
|
||||||
|
Log.w(TAG, "⚠️ No embeddings available! Check if face cache is populated.")
|
||||||
|
_photosWithFaces.value = allPhotosWithFaces
|
||||||
|
return@launch
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate centroid
|
||||||
|
Log.d(TAG, "🧮 Calculating centroid from ${selectedEmbeddings.size} embeddings...")
|
||||||
|
val centroidStart = System.currentTimeMillis()
|
||||||
|
val centroid = faceSimilarityScorer.calculateCentroid(selectedEmbeddings)
|
||||||
|
val centroidTime = System.currentTimeMillis() - centroidStart
|
||||||
|
Log.d(TAG, "✓ Centroid calculated in ${centroidTime}ms")
|
||||||
|
|
||||||
|
// Score all photos
|
||||||
|
val allImageIds = allPhotosWithFaces.map { it.imageId }
|
||||||
|
Log.d(TAG, "🎯 Scoring ${allImageIds.size} photos against centroid...")
|
||||||
|
|
||||||
|
val scoringStart = System.currentTimeMillis()
|
||||||
|
val scoredPhotos = faceSimilarityScorer.scorePhotosAgainstCentroid(
|
||||||
|
allImageIds = allImageIds,
|
||||||
|
selectedImageIds = selectedImageIds.toSet(),
|
||||||
|
centroid = centroid
|
||||||
|
)
|
||||||
|
val scoringTime = System.currentTimeMillis() - scoringStart
|
||||||
|
Log.d(TAG, "✓ Scoring completed in ${scoringTime}ms")
|
||||||
|
Log.d(TAG, "📊 Scored ${scoredPhotos.size} photos")
|
||||||
|
|
||||||
|
// Create score map
|
||||||
|
val scoreMap = scoredPhotos.associate { it.imageId to it.finalScore }
|
||||||
|
|
||||||
|
// Log top 5 scores for debugging
|
||||||
|
val top5 = scoredPhotos.take(5)
|
||||||
|
top5.forEach { scored ->
|
||||||
|
Log.d(TAG, " 🏆 Top photo: ${scored.imageId.take(8)} - score: ${scored.finalScore}")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-rank photos
|
||||||
|
val rankingStart = System.currentTimeMillis()
|
||||||
|
val rankedPhotos = allPhotosWithFaces.sortedByDescending { photo ->
|
||||||
|
if (photo in _selectedPhotos.value) {
|
||||||
|
1.0f // Selected photos stay at top
|
||||||
|
} else {
|
||||||
|
scoreMap[photo.imageId] ?: 0f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val rankingTime = System.currentTimeMillis() - rankingStart
|
||||||
|
Log.d(TAG, "✓ Ranking completed in ${rankingTime}ms")
|
||||||
|
|
||||||
|
// Update UI
|
||||||
|
_photosWithFaces.value = rankedPhotos
|
||||||
|
|
||||||
|
val totalTime = centroidTime + scoringTime + rankingTime
|
||||||
|
Log.d(TAG, "🎉 Live ranking complete! Total time: ${totalTime}ms")
|
||||||
|
Log.d(TAG, " - Centroid: ${centroidTime}ms")
|
||||||
|
Log.d(TAG, " - Scoring: ${scoringTime}ms")
|
||||||
|
Log.d(TAG, " - Ranking: ${rankingTime}ms")
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.e(TAG, "❌ Ranking failed!", e)
|
||||||
|
Log.e(TAG, " Error: ${e.message}")
|
||||||
|
Log.e(TAG, " Stack: ${e.stackTraceToString()}")
|
||||||
|
} finally {
|
||||||
|
_isRanking.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear all selections
|
|
||||||
*/
|
|
||||||
fun clearSelection() {
|
fun clearSelection() {
|
||||||
|
Log.d(TAG, "🗑️ Clearing selection")
|
||||||
_selectedPhotos.value = emptySet()
|
_selectedPhotos.value = emptySet()
|
||||||
|
_photosWithFaces.value = allPhotosWithFaces
|
||||||
|
_isRanking.value = false
|
||||||
|
rankingJob?.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Auto-select first N photos (quick start)
|
|
||||||
*/
|
|
||||||
fun autoSelect(count: Int = 25) {
|
fun autoSelect(count: Int = 25) {
|
||||||
val photos = _photosWithFaces.value.take(count)
|
val photos = allPhotosWithFaces.take(count)
|
||||||
_selectedPhotos.value = photos.toSet()
|
_selectedPhotos.value = photos.toSet()
|
||||||
|
Log.d(TAG, "🤖 Auto-selected ${photos.size} photos")
|
||||||
|
triggerLiveRanking()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Select photos with single face only (best for training)
|
|
||||||
*/
|
|
||||||
fun selectSingleFacePhotos(count: Int = 25) {
|
fun selectSingleFacePhotos(count: Int = 25) {
|
||||||
val singleFacePhotos = _photosWithFaces.value
|
val singleFacePhotos = allPhotosWithFaces
|
||||||
.filter { it.faceCount == 1 }
|
.filter { it.faceCount == 1 }
|
||||||
.take(count)
|
.take(count)
|
||||||
_selectedPhotos.value = singleFacePhotos.toSet()
|
_selectedPhotos.value = singleFacePhotos.toSet()
|
||||||
|
Log.d(TAG, "👤 Selected ${singleFacePhotos.size} single-face photos")
|
||||||
|
triggerLiveRanking()
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Refresh data (call after face detection cache updates)
|
|
||||||
*/
|
|
||||||
fun refresh() {
|
fun refresh() {
|
||||||
loadPhotosWithFaces()
|
Log.d(TAG, "🔄 Refreshing data")
|
||||||
|
loadPremiumFaces()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onCleared() {
|
||||||
|
super.onCleared()
|
||||||
|
Log.d(TAG, "🧹 ViewModel cleared")
|
||||||
|
rankingJob?.cancel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
61
app/src/main/java/com/placeholder/sherpai2/util/Debouncer.kt
Normal file
61
app/src/main/java/com/placeholder/sherpai2/util/Debouncer.kt
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package com.placeholder.sherpai2.util
|
||||||
|
|
||||||
|
import kotlinx.coroutines.CoroutineScope
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.Job
|
||||||
|
import kotlinx.coroutines.delay
|
||||||
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Debouncer - Delays execution until a pause in rapid calls
|
||||||
|
*
|
||||||
|
* Used by RollingScanViewModel to avoid re-scanning on every selection change
|
||||||
|
*
|
||||||
|
* EXAMPLE:
|
||||||
|
* User selects photos rapidly:
|
||||||
|
* - Select photo 1 → Debouncer starts 300ms timer
|
||||||
|
* - Select photo 2 (100ms later) → Timer resets to 300ms
|
||||||
|
* - Select photo 3 (100ms later) → Timer resets to 300ms
|
||||||
|
* - Wait 300ms → Scan executes ONCE
|
||||||
|
*
|
||||||
|
* RESULT: 3 selections = 1 scan (instead of 3 scans!)
|
||||||
|
*/
|
||||||
|
class Debouncer(
|
||||||
|
private val delayMs: Long = 300L,
|
||||||
|
private val scope: CoroutineScope = CoroutineScope(Dispatchers.Main)
|
||||||
|
) {
|
||||||
|
|
||||||
|
private var debounceJob: Job? = null
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Debounce an action
|
||||||
|
*
|
||||||
|
* Cancels any pending action and schedules a new one
|
||||||
|
*
|
||||||
|
* @param action Suspend function to execute after delay
|
||||||
|
*/
|
||||||
|
fun debounce(action: suspend () -> Unit) {
|
||||||
|
// Cancel previous job
|
||||||
|
debounceJob?.cancel()
|
||||||
|
|
||||||
|
// Schedule new job
|
||||||
|
debounceJob = scope.launch {
|
||||||
|
delay(delayMs)
|
||||||
|
action()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cancel any pending debounced action
|
||||||
|
*/
|
||||||
|
fun cancel() {
|
||||||
|
debounceJob?.cancel()
|
||||||
|
debounceJob = null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if debouncer has a pending action
|
||||||
|
*/
|
||||||
|
val isPending: Boolean
|
||||||
|
get() = debounceJob?.isActive == true
|
||||||
|
}
|
||||||
@@ -9,6 +9,9 @@ import com.google.mlkit.vision.common.InputImage
|
|||||||
import com.google.mlkit.vision.face.FaceDetection
|
import com.google.mlkit.vision.face.FaceDetection
|
||||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNormalizer
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||||
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
|
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
|
||||||
@@ -52,7 +55,8 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
@Assisted workerParams: WorkerParameters,
|
@Assisted workerParams: WorkerParameters,
|
||||||
private val imageDao: ImageDao,
|
private val imageDao: ImageDao,
|
||||||
private val faceModelDao: FaceModelDao,
|
private val faceModelDao: FaceModelDao,
|
||||||
private val photoFaceTagDao: PhotoFaceTagDao
|
private val photoFaceTagDao: PhotoFaceTagDao,
|
||||||
|
private val personDao: PersonDao
|
||||||
) : CoroutineWorker(context, workerParams) {
|
) : CoroutineWorker(context, workerParams) {
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
@@ -65,7 +69,8 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
const val KEY_MATCHES_FOUND = "matches_found"
|
const val KEY_MATCHES_FOUND = "matches_found"
|
||||||
const val KEY_PHOTOS_SCANNED = "photos_scanned"
|
const val KEY_PHOTOS_SCANNED = "photos_scanned"
|
||||||
|
|
||||||
private const val DEFAULT_THRESHOLD = 0.70f // Slightly looser than validation
|
private const val DEFAULT_THRESHOLD = 0.62f // Solo photos
|
||||||
|
private const val GROUP_THRESHOLD = 0.68f // Group photos (stricter)
|
||||||
private const val BATCH_SIZE = 20
|
private const val BATCH_SIZE = 20
|
||||||
private const val MAX_RETRIES = 3
|
private const val MAX_RETRIES = 3
|
||||||
|
|
||||||
@@ -137,16 +142,40 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Step 2.5: Load person to check isChild flag
|
||||||
|
val person = withContext(Dispatchers.IO) {
|
||||||
|
personDao.getPersonById(personId)
|
||||||
|
}
|
||||||
|
val isChildTarget = person?.isChild ?: false
|
||||||
|
|
||||||
// Step 3: Initialize ML components
|
// Step 3: Initialize ML components
|
||||||
val faceNetModel = FaceNetModel(context)
|
val faceNetModel = FaceNetModel(context)
|
||||||
val detector = FaceDetection.getClient(
|
val detector = FaceDetection.getClient(
|
||||||
FaceDetectorOptions.Builder()
|
FaceDetectorOptions.Builder()
|
||||||
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||||
|
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // Needed for age estimation
|
||||||
.setMinFaceSize(0.15f)
|
.setMinFaceSize(0.15f)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
val modelEmbedding = faceModel.getEmbeddingArray()
|
// Distribution-based minimum threshold (self-calibrating)
|
||||||
|
val distributionMin = (faceModel.averageConfidence - 2 * faceModel.similarityStdDev)
|
||||||
|
.coerceAtLeast(faceModel.similarityMin - 0.05f)
|
||||||
|
.coerceAtLeast(0.50f) // Never go below 0.50 absolute floor
|
||||||
|
|
||||||
|
// Get ALL centroids for multi-centroid matching (critical for children)
|
||||||
|
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
|
||||||
|
if (modelCentroids.isEmpty()) {
|
||||||
|
return@withContext Result.failure(workDataOf("error" to "No centroids in model"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load ALL other models for "best match wins" comparison
|
||||||
|
// This prevents tagging siblings incorrectly
|
||||||
|
val allModels = withContext(Dispatchers.IO) { faceModelDao.getAllActiveFaceModels() }
|
||||||
|
val otherModelCentroids = allModels
|
||||||
|
.filter { it.id != faceModel.id }
|
||||||
|
.map { model -> model.id to model.getCentroids().map { it.getEmbeddingArray() } }
|
||||||
|
|
||||||
var matchesFound = 0
|
var matchesFound = 0
|
||||||
var photosScanned = 0
|
var photosScanned = 0
|
||||||
|
|
||||||
@@ -164,10 +193,13 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
photo = photo,
|
photo = photo,
|
||||||
personId = personId,
|
personId = personId,
|
||||||
faceModelId = faceModel.id,
|
faceModelId = faceModel.id,
|
||||||
modelEmbedding = modelEmbedding,
|
modelCentroids = modelCentroids,
|
||||||
|
otherModelCentroids = otherModelCentroids,
|
||||||
faceNetModel = faceNetModel,
|
faceNetModel = faceNetModel,
|
||||||
detector = detector,
|
detector = detector,
|
||||||
threshold = threshold
|
threshold = threshold,
|
||||||
|
distributionMin = distributionMin,
|
||||||
|
isChildTarget = isChildTarget
|
||||||
)
|
)
|
||||||
|
|
||||||
if (tags.isNotEmpty()) {
|
if (tags.isNotEmpty()) {
|
||||||
@@ -228,10 +260,13 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
|
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
|
||||||
personId: String,
|
personId: String,
|
||||||
faceModelId: String,
|
faceModelId: String,
|
||||||
modelEmbedding: FloatArray,
|
modelCentroids: List<FloatArray>,
|
||||||
|
otherModelCentroids: List<Pair<String, List<FloatArray>>>,
|
||||||
faceNetModel: FaceNetModel,
|
faceNetModel: FaceNetModel,
|
||||||
detector: com.google.mlkit.vision.face.FaceDetector,
|
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||||
threshold: Float
|
threshold: Float,
|
||||||
|
distributionMin: Float,
|
||||||
|
isChildTarget: Boolean
|
||||||
): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) {
|
): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -243,43 +278,94 @@ class LibraryScanWorker @AssistedInject constructor(
|
|||||||
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||||
val faces = detector.process(inputImage).await()
|
val faces = detector.process(inputImage).await()
|
||||||
|
|
||||||
// Check each face
|
if (faces.isEmpty()) {
|
||||||
val tags = faces.mapNotNull { face ->
|
bitmap.recycle()
|
||||||
|
return@withContext emptyList()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use higher threshold for group photos
|
||||||
|
val isGroupPhoto = faces.size > 1
|
||||||
|
val effectiveThreshold = if (isGroupPhoto) GROUP_THRESHOLD else threshold
|
||||||
|
|
||||||
|
// Track best match (only tag ONE face per image to avoid false positives)
|
||||||
|
var bestMatch: PhotoFaceTagEntity? = null
|
||||||
|
var bestSimilarity = 0f
|
||||||
|
|
||||||
|
// Check each face (filter by quality first)
|
||||||
|
for (face in faces) {
|
||||||
|
// Quality check
|
||||||
|
if (!FaceQualityFilter.validateForScanning(face, bitmap.width, bitmap.height)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip very small faces
|
||||||
|
val faceArea = face.boundingBox.width() * face.boundingBox.height()
|
||||||
|
val imageArea = bitmap.width * bitmap.height
|
||||||
|
if (faceArea.toFloat() / imageArea < 0.02f) continue
|
||||||
|
|
||||||
|
// SIGNAL 2: Age plausibility check (if target is a child)
|
||||||
|
if (isChildTarget) {
|
||||||
|
val ageGroup = FaceQualityFilter.estimateAgeGroup(face, bitmap.width, bitmap.height)
|
||||||
|
if (ageGroup == FaceQualityFilter.AgeGroup.ADULT) {
|
||||||
|
continue // Reject clearly adult faces when searching for a child
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Crop face
|
// Crop and normalize face for best recognition
|
||||||
val faceBitmap = android.graphics.Bitmap.createBitmap(
|
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
|
||||||
bitmap,
|
?: continue
|
||||||
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
|
||||||
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
|
||||||
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
|
||||||
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
|
||||||
)
|
|
||||||
|
|
||||||
// Generate embedding
|
// Generate embedding
|
||||||
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
faceBitmap.recycle()
|
faceBitmap.recycle()
|
||||||
|
|
||||||
// Calculate similarity
|
// Match against target person's centroids
|
||||||
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
val targetSimilarity = modelCentroids.maxOfOrNull { centroid ->
|
||||||
|
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
|
||||||
|
} ?: 0f
|
||||||
|
|
||||||
if (similarity >= threshold) {
|
// SIGNAL 1: Distribution-based rejection
|
||||||
PhotoFaceTagEntity.create(
|
// If similarity is below (mean - 2*stdDev) or (min - 0.05), it's a statistical outlier
|
||||||
|
if (targetSimilarity < distributionMin) {
|
||||||
|
continue // Too far below training distribution
|
||||||
|
}
|
||||||
|
|
||||||
|
// SIGNAL 3: Basic threshold check
|
||||||
|
if (targetSimilarity < effectiveThreshold) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// SIGNAL 4: "Best match wins" - check if any OTHER model scores higher
|
||||||
|
// This prevents tagging siblings incorrectly
|
||||||
|
val bestOtherSimilarity = otherModelCentroids.maxOfOrNull { (_, centroids) ->
|
||||||
|
centroids.maxOfOrNull { centroid ->
|
||||||
|
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
|
||||||
|
} ?: 0f
|
||||||
|
} ?: 0f
|
||||||
|
|
||||||
|
val isTargetBestMatch = targetSimilarity > bestOtherSimilarity
|
||||||
|
|
||||||
|
// All signals must pass
|
||||||
|
if (isTargetBestMatch && targetSimilarity > bestSimilarity) {
|
||||||
|
bestSimilarity = targetSimilarity
|
||||||
|
bestMatch = PhotoFaceTagEntity.create(
|
||||||
imageId = photo.imageId,
|
imageId = photo.imageId,
|
||||||
faceModelId = faceModelId,
|
faceModelId = faceModelId,
|
||||||
boundingBox = face.boundingBox,
|
boundingBox = face.boundingBox,
|
||||||
confidence = similarity,
|
confidence = targetSimilarity,
|
||||||
faceEmbedding = faceEmbedding
|
faceEmbedding = faceEmbedding
|
||||||
)
|
)
|
||||||
} else {
|
|
||||||
null
|
|
||||||
}
|
}
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
null
|
// Skip this face
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bitmap.recycle()
|
bitmap.recycle()
|
||||||
tags
|
|
||||||
|
// Return only the best match (or empty)
|
||||||
|
if (bestMatch != null) listOf(bestMatch) else emptyList()
|
||||||
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
emptyList()
|
emptyList()
|
||||||
|
|||||||
Reference in New Issue
Block a user