3 Commits

Author SHA1 Message Date
genki
804f3d5640 rollingscan very clean
likelihood -> find similar
REFRESHCLAUD.MD 20260126
2026-01-26 22:46:38 -05:00
genki
cfec2b980a toofasttooclaude 2026-01-26 14:15:54 -05:00
genki
1ef8faad17 jFc 2026-01-25 22:01:46 -05:00
27 changed files with 1369 additions and 251 deletions

View File

@@ -4,10 +4,10 @@
<selectionStates> <selectionStates>
<SelectionState runConfigName="app"> <SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" /> <option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-01-25T20:45:06.118763497Z"> <DropdownSelection timestamp="2026-01-27T00:21:15.014661014Z">
<Target type="DEFAULT_BOOT"> <Target type="DEFAULT_BOOT">
<handle> <handle>
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" /> <DeviceId pluginId="PhysicalDevice" identifier="serial=R3CX106YYCB" />
</handle> </handle>
</Target> </Target>
</DropdownSelection> </DropdownSelection>

View File

@@ -48,6 +48,9 @@ dependencies {
implementation(libs.androidx.lifecycle.viewmodel.compose) implementation(libs.androidx.lifecycle.viewmodel.compose)
implementation(libs.androidx.activity.compose) implementation(libs.androidx.activity.compose)
// DataStore Preferences
implementation("androidx.datastore:datastore-preferences:1.1.1")
// Compose // Compose
implementation(platform(libs.androidx.compose.bom)) implementation(platform(libs.androidx.compose.bom))
implementation(libs.androidx.compose.ui) implementation(libs.androidx.compose.ui)

View File

@@ -10,6 +10,10 @@ import com.placeholder.sherpai2.data.local.entity.*
/** /**
* AppDatabase - Complete database for SherpAI2 * AppDatabase - Complete database for SherpAI2
* *
* VERSION 12 - Distribution-based rejection stats
* - Added similarityStdDev, similarityMin to FaceModelEntity
* - Enables self-calibrating threshold for face matching
*
* VERSION 10 - User Feedback Loop * VERSION 10 - User Feedback Loop
* - Added UserFeedbackEntity for storing user corrections * - Added UserFeedbackEntity for storing user corrections
* - Enables cluster refinement before training * - Enables cluster refinement before training
@@ -44,14 +48,15 @@ import com.placeholder.sherpai2.data.local.entity.*
PhotoFaceTagEntity::class, PhotoFaceTagEntity::class,
PersonAgeTagEntity::class, PersonAgeTagEntity::class,
FaceCacheEntity::class, FaceCacheEntity::class,
UserFeedbackEntity::class, // NEW: User corrections UserFeedbackEntity::class,
PersonStatisticsEntity::class, // Pre-computed person stats
// ===== COLLECTIONS ===== // ===== COLLECTIONS =====
CollectionEntity::class, CollectionEntity::class,
CollectionImageEntity::class, CollectionImageEntity::class,
CollectionFilterEntity::class CollectionFilterEntity::class
], ],
version = 10, // INCREMENTED for user feedback version = 12, // INCREMENTED for distribution-based rejection stats
exportSchema = false exportSchema = false
) )
abstract class AppDatabase : RoomDatabase() { abstract class AppDatabase : RoomDatabase() {
@@ -70,7 +75,8 @@ abstract class AppDatabase : RoomDatabase() {
abstract fun photoFaceTagDao(): PhotoFaceTagDao abstract fun photoFaceTagDao(): PhotoFaceTagDao
abstract fun personAgeTagDao(): PersonAgeTagDao abstract fun personAgeTagDao(): PersonAgeTagDao
abstract fun faceCacheDao(): FaceCacheDao abstract fun faceCacheDao(): FaceCacheDao
abstract fun userFeedbackDao(): UserFeedbackDao // NEW abstract fun userFeedbackDao(): UserFeedbackDao
abstract fun personStatisticsDao(): PersonStatisticsDao
// ===== COLLECTIONS DAO ===== // ===== COLLECTIONS DAO =====
abstract fun collectionDao(): CollectionDao abstract fun collectionDao(): CollectionDao
@@ -242,13 +248,60 @@ val MIGRATION_9_10 = object : Migration(9, 10) {
} }
} }
/**
* MIGRATION 10 → 11 (Person Statistics)
*
* Changes:
* 1. Create person_statistics table for pre-computed aggregates
*/
val MIGRATION_10_11 = object : Migration(10, 11) {
override fun migrate(database: SupportSQLiteDatabase) {
// Create person_statistics table
database.execSQL("""
CREATE TABLE IF NOT EXISTS person_statistics (
personId TEXT PRIMARY KEY NOT NULL,
photoCount INTEGER NOT NULL DEFAULT 0,
firstPhotoDate INTEGER NOT NULL DEFAULT 0,
lastPhotoDate INTEGER NOT NULL DEFAULT 0,
averageConfidence REAL NOT NULL DEFAULT 0,
agesWithPhotos TEXT,
updatedAt INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE
)
""")
// Index for sorting by photo count (People Dashboard)
database.execSQL("CREATE INDEX IF NOT EXISTS index_person_statistics_photoCount ON person_statistics(photoCount)")
}
}
/**
* MIGRATION 11 → 12 (Distribution-based Rejection Stats)
*
* Changes:
* 1. Add similarityStdDev column to face_models (default 0.05)
* 2. Add similarityMin column to face_models (default 0.6)
*
* These fields enable self-calibrating thresholds during scanning.
* During training, we compute stats from training sample similarities
* and use (mean - 2*stdDev) as a floor for matching.
*/
val MIGRATION_11_12 = object : Migration(11, 12) {
override fun migrate(database: SupportSQLiteDatabase) {
// Add distribution stats columns with sensible defaults for existing models
database.execSQL("ALTER TABLE face_models ADD COLUMN similarityStdDev REAL NOT NULL DEFAULT 0.05")
database.execSQL("ALTER TABLE face_models ADD COLUMN similarityMin REAL NOT NULL DEFAULT 0.6")
}
}
/** /**
* PRODUCTION MIGRATION NOTES: * PRODUCTION MIGRATION NOTES:
* *
* Before shipping to users, update DatabaseModule to use migrations: * Before shipping to users, update DatabaseModule to use migrations:
* *
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db") * Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations * .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11, MIGRATION_11_12) // Add all migrations
* // .fallbackToDestructiveMigration() // Remove this * // .fallbackToDestructiveMigration() // Remove this
* .build() * .build()
*/ */

View File

@@ -233,6 +233,33 @@ interface FaceCacheDao {
limit: Int = 500 limit: Int = 500
): List<FaceCacheEntity> ): List<FaceCacheEntity>
/**
* Get premium face CANDIDATES - same criteria but WITHOUT embedding requirement.
* Used to find faces that need embedding generation.
*/
@Query("""
SELECT fc.* FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.faceAreaRatio >= :minAreaRatio
AND fc.isFrontal = 1
AND fc.qualityScore >= :minQuality
AND fc.embedding IS NULL
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
LIMIT :limit
""")
suspend fun getPremiumFaceCandidatesNeedingEmbeddings(
minAreaRatio: Float = 0.10f,
minQuality: Float = 0.7f,
limit: Int = 500
): List<FaceCacheEntity>
/**
* Update embedding for a face cache entry
*/
@Query("UPDATE face_cache SET embedding = :embedding WHERE imageId = :imageId AND faceIndex = :faceIndex")
suspend fun updateEmbedding(imageId: String, faceIndex: Int, embedding: String)
/** /**
* Count of premium faces available * Count of premium faces available
*/ */

View File

@@ -66,6 +66,9 @@ interface ImageDao {
@Query("SELECT * FROM images WHERE imageId = :imageId") @Query("SELECT * FROM images WHERE imageId = :imageId")
suspend fun getImageById(imageId: String): ImageEntity? suspend fun getImageById(imageId: String): ImageEntity?
@Query("SELECT * FROM images WHERE imageUri = :uri LIMIT 1")
suspend fun getImageByUri(uri: String): ImageEntity?
/** /**
* Stream images ordered by capture time (newest first). * Stream images ordered by capture time (newest first).
* *

View File

@@ -83,9 +83,89 @@ interface PhotoFaceTagDao {
*/ */
@Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit") @Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit")
suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity> suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity>
// ===== CO-OCCURRENCE QUERIES =====
/**
* Find people who appear in photos together with a given person.
* Returns list of (otherFaceModelId, count) sorted by count descending.
* Use case: "Who appears most with Mom?" or "Show photos of Mom WITH Dad"
*/
@Query("""
SELECT pft2.faceModelId as otherFaceModelId, COUNT(DISTINCT pft1.imageId) as coCount
FROM photo_face_tags pft1
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
WHERE pft1.faceModelId = :faceModelId
AND pft2.faceModelId != :faceModelId
GROUP BY pft2.faceModelId
ORDER BY coCount DESC
""")
suspend fun getCoOccurrences(faceModelId: String): List<PersonCoOccurrence>
/**
* Get images where BOTH people appear together.
*/
@Query("""
SELECT DISTINCT pft1.imageId
FROM photo_face_tags pft1
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
WHERE pft1.faceModelId = :faceModelId1
AND pft2.faceModelId = :faceModelId2
ORDER BY pft1.detectedAt DESC
""")
suspend fun getImagesWithBothPeople(faceModelId1: String, faceModelId2: String): List<String>
/**
* Get images where person appears ALONE (no other trained faces).
*/
@Query("""
SELECT imageId FROM photo_face_tags
WHERE faceModelId = :faceModelId
AND imageId NOT IN (
SELECT imageId FROM photo_face_tags
WHERE faceModelId != :faceModelId
)
ORDER BY detectedAt DESC
""")
suspend fun getImagesWithPersonAlone(faceModelId: String): List<String>
/**
* Get images where ALL specified people appear (N-way intersection).
* For "Intersection Search" moonshot feature.
*/
@Query("""
SELECT imageId FROM photo_face_tags
WHERE faceModelId IN (:faceModelIds)
GROUP BY imageId
HAVING COUNT(DISTINCT faceModelId) = :requiredCount
""")
suspend fun getImagesWithAllPeople(faceModelIds: List<String>, requiredCount: Int): List<String>
/**
* Get images with at least N of the specified people (family portrait detection).
*/
@Query("""
SELECT imageId, COUNT(DISTINCT faceModelId) as memberCount
FROM photo_face_tags
WHERE faceModelId IN (:faceModelIds)
GROUP BY imageId
HAVING memberCount >= :minMembers
ORDER BY memberCount DESC
""")
suspend fun getFamilyPortraits(faceModelIds: List<String>, minMembers: Int): List<FamilyPortraitResult>
} }
data class FamilyPortraitResult(
val imageId: String,
val memberCount: Int
)
data class FaceModelPhotoCount( data class FaceModelPhotoCount(
val faceModelId: String, val faceModelId: String,
val photoCount: Int val photoCount: Int
) )
data class PersonCoOccurrence(
val otherFaceModelId: String,
val coCount: Int
)

View File

@@ -99,6 +99,13 @@ data class FaceCacheEntity(
companion object { companion object {
const val CURRENT_CACHE_VERSION = 1 const val CURRENT_CACHE_VERSION = 1
/**
* Convert FloatArray embedding to JSON string for storage
*/
fun embeddingToJson(embedding: FloatArray): String {
return embedding.joinToString(",")
}
/** /**
* Create from ML Kit face detection result * Create from ML Kit face detection result
*/ */

View File

@@ -143,6 +143,13 @@ data class FaceModelEntity(
@ColumnInfo(name = "averageConfidence") @ColumnInfo(name = "averageConfidence")
val averageConfidence: Float, val averageConfidence: Float,
// Distribution stats for self-calibrating rejection
@ColumnInfo(name = "similarityStdDev")
val similarityStdDev: Float = 0.05f, // Default for backwards compat
@ColumnInfo(name = "similarityMin")
val similarityMin: Float = 0.6f, // Default for backwards compat
@ColumnInfo(name = "createdAt") @ColumnInfo(name = "createdAt")
val createdAt: Long, val createdAt: Long,
@@ -157,26 +164,29 @@ data class FaceModelEntity(
) { ) {
companion object { companion object {
/** /**
* Backwards compatible create() method * Create with distribution stats for self-calibrating rejection
* Used by existing FaceRecognitionRepository code
*/ */
fun create( fun create(
personId: String, personId: String,
embeddingArray: FloatArray, embeddingArray: FloatArray,
trainingImageCount: Int, trainingImageCount: Int,
averageConfidence: Float averageConfidence: Float,
similarityStdDev: Float = 0.05f,
similarityMin: Float = 0.6f
): FaceModelEntity { ): FaceModelEntity {
return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence) return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence, similarityStdDev, similarityMin)
} }
/** /**
* Create from single embedding (backwards compatible) * Create from single embedding with distribution stats
*/ */
fun createFromEmbedding( fun createFromEmbedding(
personId: String, personId: String,
embeddingArray: FloatArray, embeddingArray: FloatArray,
trainingImageCount: Int, trainingImageCount: Int,
averageConfidence: Float averageConfidence: Float,
similarityStdDev: Float = 0.05f,
similarityMin: Float = 0.6f
): FaceModelEntity { ): FaceModelEntity {
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
val centroid = TemporalCentroid( val centroid = TemporalCentroid(
@@ -194,6 +204,8 @@ data class FaceModelEntity(
centroidsJson = serializeCentroids(listOf(centroid)), centroidsJson = serializeCentroids(listOf(centroid)),
trainingImageCount = trainingImageCount, trainingImageCount = trainingImageCount,
averageConfidence = averageConfidence, averageConfidence = averageConfidence,
similarityStdDev = similarityStdDev,
similarityMin = similarityMin,
createdAt = now, createdAt = now,
updatedAt = now, updatedAt = now,
lastUsed = null, lastUsed = null,

View File

@@ -2,8 +2,10 @@ package com.placeholder.sherpai2.data.repository
import android.content.Context import android.content.Context
import android.graphics.Bitmap import android.graphics.Bitmap
import android.util.Log
import com.placeholder.sherpai2.data.local.dao.FaceModelDao import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonAgeTagDao
import com.placeholder.sherpai2.data.local.dao.PersonDao import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.* import com.placeholder.sherpai2.data.local.entity.*
@@ -31,8 +33,12 @@ class FaceRecognitionRepository @Inject constructor(
private val personDao: PersonDao, private val personDao: PersonDao,
private val imageDao: ImageDao, private val imageDao: ImageDao,
private val faceModelDao: FaceModelDao, private val faceModelDao: FaceModelDao,
private val photoFaceTagDao: PhotoFaceTagDao private val photoFaceTagDao: PhotoFaceTagDao,
private val personAgeTagDao: PersonAgeTagDao
) { ) {
companion object {
private const val TAG = "FaceRecognitionRepo"
}
private val faceNetModel by lazy { FaceNetModel(context) } private val faceNetModel by lazy { FaceNetModel(context) }
@@ -93,11 +99,19 @@ class FaceRecognitionRepository @Inject constructor(
} }
val avgConfidence = confidences.average().toFloat() val avgConfidence = confidences.average().toFloat()
// Compute distribution stats for self-calibrating rejection
val stdDev = kotlin.math.sqrt(
confidences.map { (it - avgConfidence).toDouble().let { d -> d * d } }.average()
).toFloat()
val minSimilarity = confidences.minOrNull() ?: 0f
val faceModel = FaceModelEntity.create( val faceModel = FaceModelEntity.create(
personId = personId, personId = personId,
embeddingArray = personEmbedding, embeddingArray = personEmbedding,
trainingImageCount = validImages.size, trainingImageCount = validImages.size,
averageConfidence = avgConfidence averageConfidence = avgConfidence,
similarityStdDev = stdDev,
similarityMin = minSimilarity
) )
faceModelDao.insertFaceModel(faceModel) faceModelDao.insertFaceModel(faceModel)
@@ -181,12 +195,15 @@ class FaceRecognitionRepository @Inject constructor(
var highestSimilarity = threshold var highestSimilarity = threshold
for (faceModel in faceModels) { for (faceModel in faceModels) {
val modelEmbedding = faceModel.getEmbeddingArray() // Check ALL centroids for best match (critical for children with age centroids)
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding) val centroids = faceModel.getCentroids()
val bestCentroidSimilarity = centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid.getEmbeddingArray())
} ?: 0f
if (similarity > highestSimilarity) { if (bestCentroidSimilarity > highestSimilarity) {
highestSimilarity = similarity highestSimilarity = bestCentroidSimilarity
bestMatch = Pair(faceModel.id, similarity) bestMatch = Pair(faceModel.id, bestCentroidSimilarity)
} }
} }
@@ -374,9 +391,49 @@ class FaceRecognitionRepository @Inject constructor(
onProgress = onProgress onProgress = onProgress
) )
// Generate age tags for children
if (person.isChild && person.dateOfBirth != null) {
generateAgeTagsForTraining(person, validImages)
}
person.id person.id
} }
/**
* Generate age tags from training images for a child
*/
private suspend fun generateAgeTagsForTraining(
person: PersonEntity,
validImages: List<TrainingSanityChecker.ValidTrainingImage>
) {
try {
val dob = person.dateOfBirth ?: return
val tags = validImages.mapNotNull { img ->
val imageEntity = imageDao.getImageByUri(img.uri.toString()) ?: return@mapNotNull null
val ageMs = imageEntity.capturedAt - dob
val ageYears = (ageMs / (365.25 * 24 * 60 * 60 * 1000)).toInt()
if (ageYears < 0 || ageYears > 25) return@mapNotNull null
PersonAgeTagEntity.create(
personId = person.id,
personName = person.name,
imageId = imageEntity.imageId,
ageAtCapture = ageYears,
confidence = 1.0f
)
}
if (tags.isNotEmpty()) {
personAgeTagDao.insertTags(tags)
Log.d(TAG, "Created ${tags.size} age tags for ${person.name}")
}
} catch (e: Exception) {
Log.e(TAG, "Failed to generate age tags", e)
}
}
/** /**
* Get face model by ID * Get face model by ID
*/ */

View File

@@ -61,14 +61,16 @@ abstract class RepositoryModule {
personDao: PersonDao, personDao: PersonDao,
imageDao: ImageDao, imageDao: ImageDao,
faceModelDao: FaceModelDao, faceModelDao: FaceModelDao,
photoFaceTagDao: PhotoFaceTagDao photoFaceTagDao: PhotoFaceTagDao,
personAgeTagDao: PersonAgeTagDao
): FaceRecognitionRepository { ): FaceRecognitionRepository {
return FaceRecognitionRepository( return FaceRecognitionRepository(
context = context, context = context,
personDao = personDao, personDao = personDao,
imageDao = imageDao, imageDao = imageDao,
faceModelDao = faceModelDao, faceModelDao = faceModelDao,
photoFaceTagDao = photoFaceTagDao photoFaceTagDao = photoFaceTagDao,
personAgeTagDao = personAgeTagDao
) )
} }

View File

@@ -15,6 +15,7 @@ import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.ui.discover.DiscoverySettings import com.placeholder.sherpai2.ui.discover.DiscoverySettings
import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@@ -344,14 +345,9 @@ class FaceClusteringService @Inject constructor(
} }
try { try {
// Crop and generate embedding // Crop and normalize face
val faceBitmap = Bitmap.createBitmap( val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, mlFace)
bitmap, ?: return@forEach
mlFace.boundingBox.left.coerceIn(0, bitmap.width - 1),
mlFace.boundingBox.top.coerceIn(0, bitmap.height - 1),
mlFace.boundingBox.width().coerceAtMost(bitmap.width - mlFace.boundingBox.left),
mlFace.boundingBox.height().coerceAtMost(bitmap.height - mlFace.boundingBox.top)
)
val embedding = faceNetModel.generateEmbedding(faceBitmap) val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle() faceBitmap.recycle()
@@ -591,13 +587,8 @@ class FaceClusteringService @Inject constructor(
if (!qualityCheck.isValid) return@mapNotNull null if (!qualityCheck.isValid) return@mapNotNull null
try { try {
val faceBitmap = Bitmap.createBitmap( val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
bitmap, ?: return@mapNotNull null
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
val embedding = faceNetModel.generateEmbedding(faceBitmap) val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle() faceBitmap.recycle()

View File

@@ -29,6 +29,64 @@ import kotlin.math.sqrt
*/ */
object FaceQualityFilter { object FaceQualityFilter {
/**
* Age group estimation for filtering (child vs adult detection)
*/
enum class AgeGroup { CHILD, ADULT, UNCERTAIN }
/**
* Estimate whether a face belongs to a child or adult based on facial proportions.
*
* Uses two heuristics:
* 1. Eye position ratio - Children have larger foreheads, so eyes are lower (~45% from top)
* Adults have eyes at ~35% from top
* 2. Face roundness (width/height ratio) - Children: ~0.85-1.0, Adults: ~0.7-0.85
*
* @return AgeGroup.CHILD, AgeGroup.ADULT, or AgeGroup.UNCERTAIN
*/
fun estimateAgeGroup(face: Face, imageWidth: Int, imageHeight: Int): AgeGroup {
val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE)
val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE)
if (leftEye == null || rightEye == null) {
return AgeGroup.UNCERTAIN
}
// Eye-to-face height ratio (where eyes sit relative to face top)
val faceHeight = face.boundingBox.height().toFloat()
val faceTop = face.boundingBox.top.toFloat()
val eyeY = (leftEye.position.y + rightEye.position.y) / 2
val eyePositionRatio = (eyeY - faceTop) / faceHeight
// Children: eyes at ~45% from top (larger forehead proportionally)
// Adults: eyes at ~35% from top
// Score: higher = more child-like
// Face roundness (width/height)
val faceWidth = face.boundingBox.width().toFloat()
val faceRatio = faceWidth / faceHeight
// Children: ratio ~0.85-1.0 (rounder faces)
// Adults: ratio ~0.7-0.85 (longer/narrower faces)
var childScore = 0
// Eye position scoring
if (eyePositionRatio > 0.45f) childScore += 2 // Strong child signal
else if (eyePositionRatio > 0.42f) childScore += 1 // Mild child signal
else if (eyePositionRatio < 0.35f) childScore -= 1 // Adult signal
// Face roundness scoring
if (faceRatio > 0.90f) childScore += 2 // Very round = child
else if (faceRatio > 0.82f) childScore += 1 // Somewhat round
else if (faceRatio < 0.75f) childScore -= 1 // Long face = adult
return when {
childScore >= 3 -> AgeGroup.CHILD
childScore <= 0 -> AgeGroup.ADULT
else -> AgeGroup.UNCERTAIN
}
}
/** /**
* Validate face for Discovery/Clustering * Validate face for Discovery/Clustering
* *

View File

@@ -75,7 +75,21 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
) )
try { try {
val imagesToScan = imageDao.getImagesNeedingFaceDetection() // Get images that need face detection (hasFaces IS NULL)
var imagesToScan = imageDao.getImagesNeedingFaceDetection()
// CRITICAL FIX: Also check for images marked as having faces but no FaceCacheEntity
if (imagesToScan.isEmpty()) {
val faceStats = faceCacheDao.getCacheStats()
if (faceStats.totalFaces == 0) {
// FaceCacheEntity is empty - rescan images that have faces
val imagesWithFaces = imageDao.getImagesWithFaces()
if (imagesWithFaces.isNotEmpty()) {
Log.w(TAG, "FaceCacheEntity empty but ${imagesWithFaces.size} images have faces - rescanning")
imagesToScan = imagesWithFaces
}
}
}
if (imagesToScan.isEmpty()) { if (imagesToScan.isEmpty()) {
Log.d(TAG, "No images need scanning") Log.d(TAG, "No images need scanning")
@@ -184,7 +198,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
imageUri = image.imageUri imageUri = image.imageUri
) )
// Create FaceCacheEntity entries for each face // Create FaceCacheEntity entries for each face (NO embeddings - generated on demand)
val faceCacheEntries = faces.mapIndexed { index, face -> val faceCacheEntries = faces.mapIndexed { index, face ->
createFaceCacheEntry( createFaceCacheEntry(
imageId = image.imageId, imageId = image.imageId,
@@ -205,7 +219,8 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
/** /**
* Create FaceCacheEntity from ML Kit Face * Create FaceCacheEntity from ML Kit Face
* *
* Uses FaceCacheEntity.create() which calculates quality metrics automatically * Uses FaceCacheEntity.create() which calculates quality metrics automatically.
* Embeddings are NOT generated here - they're generated on-demand in Training/Discovery.
*/ */
private fun createFaceCacheEntry( private fun createFaceCacheEntry(
imageId: String, imageId: String,
@@ -225,7 +240,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
imageHeight = imageHeight, imageHeight = imageHeight,
confidence = 0.9f, // High confidence from accurate detector confidence = 0.9f, // High confidence from accurate detector
isFrontal = isFrontal, isFrontal = isFrontal,
embedding = null // Will be generated later during Discovery embedding = null // Generated on-demand in Training/Discovery
) )
} }
@@ -312,13 +327,27 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
val imageStats = imageDao.getFaceCacheStats() val imageStats = imageDao.getFaceCacheStats()
val faceStats = faceCacheDao.getCacheStats() val faceStats = faceCacheDao.getCacheStats()
// CRITICAL FIX: If ImageEntity says "scanned" but FaceCacheEntity is empty,
// we need to re-scan. This happens after DB migration clears face_cache table.
val imagesWithFaces = imageStats?.imagesWithFaces ?: 0
val facesCached = faceStats.totalFaces
// If we have images marked as having faces but no FaceCacheEntity entries,
// those images need re-scanning
val needsRescan = if (imagesWithFaces > 0 && facesCached == 0) {
Log.w(TAG, "⚠️ FaceCacheEntity is empty but $imagesWithFaces images marked as having faces - forcing rescan")
imagesWithFaces
} else {
imageStats?.needsScanning ?: 0
}
CacheStats( CacheStats(
totalImages = imageStats?.totalImages ?: 0, totalImages = imageStats?.totalImages ?: 0,
imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0, imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0,
imagesWithFaces = imageStats?.imagesWithFaces ?: 0, imagesWithFaces = imagesWithFaces,
imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0, imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0,
needsScanning = imageStats?.needsScanning ?: 0, needsScanning = needsRescan,
totalFacesCached = faceStats.totalFaces, totalFacesCached = facesCached,
facesWithEmbeddings = faceStats.withEmbeddings, facesWithEmbeddings = faceStats.withEmbeddings,
averageQuality = faceStats.avgQuality averageQuality = faceStats.avgQuality
) )

View File

@@ -20,6 +20,7 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.navigation.NavController import androidx.navigation.NavController
import coil.compose.AsyncImage import coil.compose.AsyncImage
import com.placeholder.sherpai2.data.local.entity.TagEntity import com.placeholder.sherpai2.data.local.entity.TagEntity
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.FaceTagInfo
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel
import net.engawapg.lib.zoomable.rememberZoomState import net.engawapg.lib.zoomable.rememberZoomState
import net.engawapg.lib.zoomable.zoomable import net.engawapg.lib.zoomable.zoomable
@@ -51,8 +52,12 @@ fun ImageDetailScreen(
} }
val tags by viewModel.tags.collectAsStateWithLifecycle() val tags by viewModel.tags.collectAsStateWithLifecycle()
val faceTags by viewModel.faceTags.collectAsStateWithLifecycle()
var showTags by remember { mutableStateOf(false) } var showTags by remember { mutableStateOf(false) }
// Total tag count for badge
val totalTagCount = tags.size + faceTags.size
// Navigation state // Navigation state
val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1 val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1
val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0 val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0
@@ -84,27 +89,35 @@ fun ImageDetailScreen(
horizontalArrangement = Arrangement.spacedBy(4.dp), horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically
) { ) {
if (tags.isNotEmpty()) { if (totalTagCount > 0) {
Badge( Badge(
containerColor = if (showTags) containerColor = if (showTags)
MaterialTheme.colorScheme.primary MaterialTheme.colorScheme.primary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.tertiary
else else
MaterialTheme.colorScheme.surfaceVariant MaterialTheme.colorScheme.surfaceVariant
) { ) {
Text( Text(
tags.size.toString(), totalTagCount.toString(),
color = if (showTags) color = if (showTags)
MaterialTheme.colorScheme.onPrimary MaterialTheme.colorScheme.onPrimary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.onTertiary
else else
MaterialTheme.colorScheme.onSurfaceVariant MaterialTheme.colorScheme.onSurfaceVariant
) )
} }
} }
Icon( Icon(
if (showTags) Icons.Default.Label else Icons.Default.LocalOffer, if (faceTags.isNotEmpty()) Icons.Default.Face
else if (showTags) Icons.Default.Label
else Icons.Default.LocalOffer,
"Show Tags", "Show Tags",
tint = if (showTags) tint = if (showTags)
MaterialTheme.colorScheme.primary MaterialTheme.colorScheme.primary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.tertiary
else else
MaterialTheme.colorScheme.onSurfaceVariant MaterialTheme.colorScheme.onSurfaceVariant
) )
@@ -189,6 +202,30 @@ fun ImageDetailScreen(
contentPadding = PaddingValues(16.dp), contentPadding = PaddingValues(16.dp),
verticalArrangement = Arrangement.spacedBy(8.dp) verticalArrangement = Arrangement.spacedBy(8.dp)
) { ) {
// Face Tags Section (People in Photo)
if (faceTags.isNotEmpty()) {
item {
Text(
"People (${faceTags.size})",
style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.tertiary
)
}
items(faceTags, key = { it.tagId }) { faceTag ->
FaceTagCard(
faceTag = faceTag,
onRemove = { viewModel.removeFaceTag(faceTag) }
)
}
item {
Spacer(modifier = Modifier.height(8.dp))
}
}
// Regular Tags Section
item { item {
Text( Text(
"Tags (${tags.size})", "Tags (${tags.size})",
@@ -197,7 +234,7 @@ fun ImageDetailScreen(
) )
} }
if (tags.isEmpty()) { if (tags.isEmpty() && faceTags.isEmpty()) {
item { item {
Text( Text(
"No tags yet", "No tags yet",
@@ -205,6 +242,14 @@ fun ImageDetailScreen(
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
) )
} }
} else if (tags.isEmpty()) {
item {
Text(
"No other tags",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
} }
items(tags, key = { it.tagId }) { tag -> items(tags, key = { it.tagId }) { tag ->
@@ -220,6 +265,83 @@ fun ImageDetailScreen(
} }
} }
@Composable
private fun FaceTagCard(
faceTag: FaceTagInfo,
onRemove: () -> Unit
) {
Card(
modifier = Modifier.fillMaxWidth(),
colors = CardDefaults.cardColors(
containerColor = MaterialTheme.colorScheme.tertiaryContainer
),
shape = RoundedCornerShape(8.dp)
) {
Row(
modifier = Modifier
.fillMaxWidth()
.padding(12.dp),
horizontalArrangement = Arrangement.SpaceBetween,
verticalAlignment = Alignment.CenterVertically
) {
Column(modifier = Modifier.weight(1f)) {
Row(
horizontalArrangement = Arrangement.spacedBy(8.dp),
verticalAlignment = Alignment.CenterVertically
) {
Icon(
imageVector = Icons.Default.Face,
contentDescription = null,
modifier = Modifier.size(20.dp),
tint = MaterialTheme.colorScheme.tertiary
)
Text(
text = faceTag.personName,
style = MaterialTheme.typography.bodyLarge,
fontWeight = FontWeight.SemiBold
)
}
Row(
horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically
) {
Text(
text = "Face Recognition",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
text = "",
style = MaterialTheme.typography.labelSmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
Text(
text = "${(faceTag.confidence * 100).toInt()}% confidence",
style = MaterialTheme.typography.labelSmall,
color = if (faceTag.confidence >= 0.7f)
MaterialTheme.colorScheme.primary
else if (faceTag.confidence >= 0.5f)
MaterialTheme.colorScheme.secondary
else
MaterialTheme.colorScheme.error
)
}
}
// Remove button
IconButton(
onClick = onRemove,
colors = IconButtonDefaults.iconButtonColors(
contentColor = MaterialTheme.colorScheme.error
)
) {
Icon(Icons.Default.Delete, "Remove face tag")
}
}
}
}
@Composable @Composable
private fun TagCard( private fun TagCard(
tag: TagEntity, tag: TagEntity,

View File

@@ -2,6 +2,10 @@ package com.placeholder.sherpai2.ui.imagedetail.viewmodel
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.TagEntity import com.placeholder.sherpai2.data.local.entity.TagEntity
import com.placeholder.sherpai2.domain.repository.TaggingRepository import com.placeholder.sherpai2.domain.repository.TaggingRepository
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
@@ -10,17 +14,33 @@ import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import javax.inject.Inject import javax.inject.Inject
/**
* Represents a person tagged in this photo via face recognition
*/
data class FaceTagInfo(
val personId: String,
val personName: String,
val confidence: Float,
val faceModelId: String,
val tagId: String
)
/** /**
* ImageDetailViewModel * ImageDetailViewModel
* *
* Owns: * Owns:
* - Image context * - Image context
* - Tag write operations * - Tag write operations
* - Face tag display (people recognized in photo)
*/ */
@HiltViewModel @HiltViewModel
@OptIn(ExperimentalCoroutinesApi::class) @OptIn(ExperimentalCoroutinesApi::class)
class ImageDetailViewModel @Inject constructor( class ImageDetailViewModel @Inject constructor(
private val tagRepository: TaggingRepository private val tagRepository: TaggingRepository,
private val imageDao: ImageDao,
private val photoFaceTagDao: PhotoFaceTagDao,
private val faceModelDao: FaceModelDao,
private val personDao: PersonDao
) : ViewModel() { ) : ViewModel() {
private val imageUri = MutableStateFlow<String?>(null) private val imageUri = MutableStateFlow<String?>(null)
@@ -37,8 +57,43 @@ class ImageDetailViewModel @Inject constructor(
initialValue = emptyList() initialValue = emptyList()
) )
// Face tags (people recognized in this photo)
private val _faceTags = MutableStateFlow<List<FaceTagInfo>>(emptyList())
val faceTags: StateFlow<List<FaceTagInfo>> = _faceTags.asStateFlow()
fun loadImage(uri: String) { fun loadImage(uri: String) {
imageUri.value = uri imageUri.value = uri
loadFaceTags(uri)
}
private fun loadFaceTags(uri: String) {
viewModelScope.launch {
try {
// Get imageId from URI
val image = imageDao.getImageByUri(uri) ?: return@launch
// Get face tags for this image
val faceTags = photoFaceTagDao.getTagsForImage(image.imageId)
// Resolve to person names
val faceTagInfos = faceTags.mapNotNull { tag ->
val faceModel = faceModelDao.getFaceModelById(tag.faceModelId) ?: return@mapNotNull null
val person = personDao.getPersonById(faceModel.personId) ?: return@mapNotNull null
FaceTagInfo(
personId = person.id,
personName = person.name,
confidence = tag.confidence,
faceModelId = tag.faceModelId,
tagId = tag.id
)
}
_faceTags.value = faceTagInfos.sortedByDescending { it.confidence }
} catch (e: Exception) {
_faceTags.value = emptyList()
}
}
} }
fun addTag(value: String) { fun addTag(value: String) {
@@ -54,4 +109,15 @@ class ImageDetailViewModel @Inject constructor(
tagRepository.removeTagFromImage(uri, tag.value) tagRepository.removeTagFromImage(uri, tag.value)
} }
} }
/**
* Remove a face tag (person recognition)
*/
fun removeFaceTag(faceTagInfo: FaceTagInfo) {
viewModelScope.launch {
photoFaceTagDao.deleteTagById(faceTagInfo.tagId)
// Reload face tags
imageUri.value?.let { loadFaceTags(it) }
}
}
} }

View File

@@ -95,6 +95,9 @@ fun PersonInventoryScreen(
}, },
onDelete = { personId -> onDelete = { personId ->
viewModel.deletePerson(personId) viewModel.deletePerson(personId)
},
onClearTags = { personId ->
viewModel.clearTagsForPerson(personId)
} }
) )
} }
@@ -319,7 +322,8 @@ private fun PersonList(
persons: List<PersonWithModelInfo>, persons: List<PersonWithModelInfo>,
onScan: (String) -> Unit, onScan: (String) -> Unit,
onView: (String) -> Unit, onView: (String) -> Unit,
onDelete: (String) -> Unit onDelete: (String) -> Unit,
onClearTags: (String) -> Unit
) { ) {
LazyColumn( LazyColumn(
contentPadding = PaddingValues(vertical = 8.dp) contentPadding = PaddingValues(vertical = 8.dp)
@@ -332,7 +336,8 @@ private fun PersonList(
person = person, person = person,
onScan = { onScan(person.person.id) }, onScan = { onScan(person.person.id) },
onView = { onView(person.person.id) }, onView = { onView(person.person.id) },
onDelete = { onDelete(person.person.id) } onDelete = { onDelete(person.person.id) },
onClearTags = { onClearTags(person.person.id) }
) )
} }
} }
@@ -343,9 +348,34 @@ private fun PersonCard(
person: PersonWithModelInfo, person: PersonWithModelInfo,
onScan: () -> Unit, onScan: () -> Unit,
onView: () -> Unit, onView: () -> Unit,
onDelete: () -> Unit onDelete: () -> Unit,
onClearTags: () -> Unit
) { ) {
var showDeleteDialog by remember { mutableStateOf(false) } var showDeleteDialog by remember { mutableStateOf(false) }
var showClearDialog by remember { mutableStateOf(false) }
if (showClearDialog) {
AlertDialog(
onDismissRequest = { showClearDialog = false },
title = { Text("Clear tags for ${person.person.name}?") },
text = { Text("This will remove all ${person.taggedPhotoCount} photo tags but keep the face model. You can re-scan after clearing.") },
confirmButton = {
TextButton(
onClick = {
showClearDialog = false
onClearTags()
}
) {
Text("Clear Tags", color = MaterialTheme.colorScheme.error)
}
},
dismissButton = {
TextButton(onClick = { showClearDialog = false }) {
Text("Cancel")
}
}
)
}
if (showDeleteDialog) { if (showDeleteDialog) {
AlertDialog( AlertDialog(
@@ -413,6 +443,17 @@ private fun PersonCard(
) )
} }
// Clear tags button (if has tags)
if (person.taggedPhotoCount > 0) {
IconButton(onClick = { showClearDialog = true }) {
Icon(
Icons.Default.ClearAll,
contentDescription = "Clear Tags",
tint = MaterialTheme.colorScheme.secondary
)
}
}
// Delete button // Delete button
IconButton(onClick = { showDeleteDialog = true }) { IconButton(onClick = { showDeleteDialog = true }) {
Icon( Icon(

View File

@@ -19,6 +19,7 @@ import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
import com.placeholder.sherpai2.ml.FaceNetModel import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.ml.ThresholdStrategy import com.placeholder.sherpai2.ml.ThresholdStrategy
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@@ -105,6 +106,21 @@ class PersonInventoryViewModel @Inject constructor(
} }
} }
/**
* Clear all face tags for a person (keep model, allow rescan)
*/
fun clearTagsForPerson(personId: String) {
viewModelScope.launch(Dispatchers.IO) {
try {
val faceModel = faceModelDao.getFaceModelByPersonId(personId)
if (faceModel != null) {
photoFaceTagDao.deleteTagsForFaceModel(faceModel.id)
}
loadPersons()
} catch (e: Exception) {}
}
}
fun scanForPerson(personId: String) { fun scanForPerson(personId: String) {
viewModelScope.launch(Dispatchers.IO) { viewModelScope.launch(Dispatchers.IO) {
try { try {
@@ -127,16 +143,40 @@ class PersonInventoryViewModel @Inject constructor(
val detectorOptions = FaceDetectorOptions.Builder() val detectorOptions = FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_NONE) .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // Needed for age estimation
.setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_NONE) .setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_NONE)
.setMinFaceSize(0.15f) .setMinFaceSize(0.15f)
.build() .build()
val detector = FaceDetection.getClient(detectorOptions) val detector = FaceDetection.getClient(detectorOptions)
val modelEmbedding = faceModel.getEmbeddingArray() // CRITICAL: Use ALL centroids for matching
val faceNetModel = FaceNetModel(context) val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
val trainingCount = faceModel.trainingImageCount val trainingCount = faceModel.trainingImageCount
val baseThreshold = ThresholdStrategy.getLiberalThreshold(trainingCount) android.util.Log.e("PersonScan", "=== CENTROIDS: ${modelCentroids.size}, trainingCount: $trainingCount ===")
if (modelCentroids.isEmpty()) {
_scanningState.value = ScanningState.Error("No centroids found")
return@launch
}
val faceNetModel = FaceNetModel(context)
// Production threshold - STRICT to avoid false positives
// Solo face photos: 0.62, Group photos: 0.68
val baseThreshold = 0.62f
val groupPhotoThreshold = 0.68f // Higher bar for multi-face images
// Load ALL other models for "best match wins" comparison
val allModels = faceModelDao.getAllActiveFaceModels()
val otherModelCentroids = allModels
.filter { it.id != faceModel.id }
.map { model -> model.id to model.getCentroids().map { it.getEmbeddingArray() } }
// Distribution-based minimum threshold (self-calibrating)
val distributionMin = (faceModel.averageConfidence - 2 * faceModel.similarityStdDev)
.coerceAtLeast(faceModel.similarityMin - 0.05f)
.coerceAtLeast(0.50f) // Never go below 0.50 absolute floor
android.util.Log.d("PersonScan", "Using threshold: solo=$baseThreshold, group=$groupPhotoThreshold, distributionMin=$distributionMin (avgConf=${faceModel.averageConfidence}, stdDev=${faceModel.similarityStdDev}), centroids: ${modelCentroids.size}, competing models: ${otherModelCentroids.size}, isChild=${person.isChild}")
val completed = AtomicInteger(0) val completed = AtomicInteger(0)
val facesFound = AtomicInteger(0) val facesFound = AtomicInteger(0)
@@ -148,7 +188,7 @@ class PersonInventoryViewModel @Inject constructor(
val jobs = untaggedImages.map { image -> val jobs = untaggedImages.map { image ->
async { async {
semaphore.withPermit { semaphore.withPermit {
processImage(image, detector, faceNetModel, modelEmbedding, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name) processImage(image, detector, faceNetModel, modelCentroids, otherModelCentroids, trainingCount, baseThreshold, groupPhotoThreshold, distributionMin, person.isChild, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
} }
} }
} }
@@ -175,7 +215,10 @@ class PersonInventoryViewModel @Inject constructor(
private suspend fun processImage( private suspend fun processImage(
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel, image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
modelEmbedding: FloatArray, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String, modelCentroids: List<FloatArray>, otherModelCentroids: List<Pair<String, List<FloatArray>>>,
trainingCount: Int, baseThreshold: Float, groupPhotoThreshold: Float,
distributionMin: Float, isChildTarget: Boolean,
personId: String, faceModelId: String,
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex, batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
) { ) {
@@ -200,9 +243,13 @@ class PersonInventoryViewModel @Inject constructor(
val scaleX = sizeOpts.outWidth.toFloat() / detectionBitmap.width val scaleX = sizeOpts.outWidth.toFloat() / detectionBitmap.width
val scaleY = sizeOpts.outHeight.toFloat() / detectionBitmap.height val scaleY = sizeOpts.outHeight.toFloat() / detectionBitmap.height
val imageQuality = ThresholdStrategy.estimateImageQuality(sizeOpts.outWidth, sizeOpts.outHeight) // CRITICAL: Use higher threshold for group photos (more likely false positives)
val detectionContext = ThresholdStrategy.estimateDetectionContext(faces.size) val isGroupPhoto = faces.size > 1
val threshold = ThresholdStrategy.getOptimalThreshold(trainingCount, imageQuality, detectionContext).coerceAtMost(baseThreshold) val effectiveThreshold = if (isGroupPhoto) groupPhotoThreshold else baseThreshold
// Track best match in this image (only tag ONE face per image)
var bestMatchSimilarity = 0f
var foundMatch = false
for (face in faces) { for (face in faces) {
val scaledBounds = android.graphics.Rect( val scaledBounds = android.graphics.Rect(
@@ -212,14 +259,62 @@ class PersonInventoryViewModel @Inject constructor(
(face.boundingBox.bottom * scaleY).toInt() (face.boundingBox.bottom * scaleY).toInt()
) )
val faceBitmap = loadFaceRegion(uri, scaledBounds) ?: continue // Skip very small faces (less reliable)
val faceArea = scaledBounds.width() * scaledBounds.height()
val imageArea = sizeOpts.outWidth * sizeOpts.outHeight
val faceRatio = faceArea.toFloat() / imageArea
if (faceRatio < 0.02f) continue // Face must be at least 2% of image
// SIGNAL 2: Age plausibility check (if target is a child)
if (isChildTarget) {
val ageGroup = FaceQualityFilter.estimateAgeGroup(face, detectionBitmap.width, detectionBitmap.height)
if (ageGroup == FaceQualityFilter.AgeGroup.ADULT) {
continue // Reject clearly adult faces when searching for a child
}
}
// CRITICAL: Add padding to face crop (same as training)
val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
faceBitmap.recycle() faceBitmap.recycle()
if (similarity >= threshold) { // Match against target person's centroids
val targetSimilarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
// SIGNAL 1: Distribution-based rejection
// If similarity is below (mean - 2*stdDev) or (min - 0.05), it's a statistical outlier
if (targetSimilarity < distributionMin) {
continue // Too far below training distribution
}
// SIGNAL 3: Basic threshold check
if (targetSimilarity < effectiveThreshold) {
continue
}
// SIGNAL 4: "Best match wins" - check if any OTHER model scores higher
// This prevents tagging siblings/similar people incorrectly
val bestOtherSimilarity = otherModelCentroids.maxOfOrNull { (_, centroids) ->
centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
} ?: 0f
val isTargetBestMatch = targetSimilarity > bestOtherSimilarity
// All signals must pass
if (isTargetBestMatch && targetSimilarity > bestMatchSimilarity) {
bestMatchSimilarity = targetSimilarity
foundMatch = true
}
}
// Only add ONE tag per image (the best match)
if (foundMatch) {
batchUpdateMutex.withLock { batchUpdateMutex.withLock {
batchMatches.add(Triple(personId, image.imageId, similarity)) batchMatches.add(Triple(personId, image.imageId, bestMatchSimilarity))
facesFound.incrementAndGet() facesFound.incrementAndGet()
if (batchMatches.size >= BATCH_DB_SIZE) { if (batchMatches.size >= BATCH_DB_SIZE) {
saveBatchMatches(batchMatches.toList(), faceModelId) saveBatchMatches(batchMatches.toList(), faceModelId)
@@ -227,7 +322,7 @@ class PersonInventoryViewModel @Inject constructor(
} }
} }
} }
}
detectionBitmap.recycle() detectionBitmap.recycle()
} catch (e: Exception) { } catch (e: Exception) {
} finally { } finally {
@@ -250,18 +345,32 @@ class PersonInventoryViewModel @Inject constructor(
} catch (e: Exception) { null } } catch (e: Exception) { null }
} }
private fun loadFaceRegion(uri: Uri, bounds: android.graphics.Rect): Bitmap? { /**
* Load face region WITH 25% padding - CRITICAL for matching training conditions
*/
private fun loadFaceRegionWithPadding(uri: Uri, bounds: android.graphics.Rect, imgWidth: Int, imgHeight: Int): Bitmap? {
return try { return try {
val full = context.contentResolver.openInputStream(uri)?.use { val full = context.contentResolver.openInputStream(uri)?.use {
BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 }) BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 })
} ?: return null } ?: return null
val safeLeft = bounds.left.coerceIn(0, full.width - 1) // Add 25% padding (same as training)
val safeTop = bounds.top.coerceIn(0, full.height - 1) val padding = (kotlin.math.max(bounds.width(), bounds.height()) * 0.25f).toInt()
val safeWidth = bounds.width().coerceAtMost(full.width - safeLeft)
val safeHeight = bounds.height().coerceAtMost(full.height - safeTop)
val cropped = Bitmap.createBitmap(full, safeLeft, safeTop, safeWidth, safeHeight) val left = (bounds.left - padding).coerceAtLeast(0)
val top = (bounds.top - padding).coerceAtLeast(0)
val right = (bounds.right + padding).coerceAtMost(full.width)
val bottom = (bounds.bottom + padding).coerceAtMost(full.height)
val width = right - left
val height = bottom - top
if (width <= 0 || height <= 0) {
full.recycle()
return null
}
val cropped = Bitmap.createBitmap(full, left, top, width, height)
full.recycle() full.recycle()
cropped cropped
} catch (e: Exception) { null } } catch (e: Exception) { null }

View File

@@ -339,10 +339,7 @@ fun AppNavHost(
* SETTINGS SCREEN * SETTINGS SCREEN
*/ */
composable(AppRoutes.SETTINGS) { composable(AppRoutes.SETTINGS) {
DummyScreen( com.placeholder.sherpai2.ui.settings.SettingsScreen()
title = "Settings",
subtitle = "App preferences and configuration"
)
} }
} }
} }

View File

@@ -78,6 +78,7 @@ fun MainScreen(
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW! AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
AppRoutes.INVENTORY -> "People" AppRoutes.INVENTORY -> "People"
AppRoutes.TRAIN -> "Train Model" AppRoutes.TRAIN -> "Train Model"
AppRoutes.ScanResultsScreen -> "Train New Person"
AppRoutes.TAGS -> "Tags" AppRoutes.TAGS -> "Tags"
AppRoutes.UTILITIES -> "Utilities" AppRoutes.UTILITIES -> "Utilities"
AppRoutes.SETTINGS -> "Settings" AppRoutes.SETTINGS -> "Settings"

View File

@@ -2,7 +2,9 @@ package com.placeholder.sherpai2.ui.rollingscan
import android.net.Uri import android.net.Uri
import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.combinedClickable
import androidx.compose.foundation.layout.* import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.GridItemSpan import androidx.compose.foundation.lazy.grid.GridItemSpan
@@ -37,7 +39,7 @@ import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
* - Quick action buttons (Select Top N) * - Quick action buttons (Select Top N)
* - Submit button with validation * - Submit button with validation
*/ */
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@Composable @Composable
fun RollingScanScreen( fun RollingScanScreen(
seedImageIds: List<String>, seedImageIds: List<String>,
@@ -48,6 +50,7 @@ fun RollingScanScreen(
) { ) {
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val selectedImageIds by viewModel.selectedImageIds.collectAsState() val selectedImageIds by viewModel.selectedImageIds.collectAsState()
val negativeImageIds by viewModel.negativeImageIds.collectAsState()
val rankedPhotos by viewModel.rankedPhotos.collectAsState() val rankedPhotos by viewModel.rankedPhotos.collectAsState()
val isScanning by viewModel.isScanning.collectAsState() val isScanning by viewModel.isScanning.collectAsState()
@@ -70,6 +73,7 @@ fun RollingScanScreen(
isReadyForTraining = viewModel.isReadyForTraining(), isReadyForTraining = viewModel.isReadyForTraining(),
validationMessage = viewModel.getValidationMessage(), validationMessage = viewModel.getValidationMessage(),
onSelectTopN = { count -> viewModel.selectTopN(count) }, onSelectTopN = { count -> viewModel.selectTopN(count) },
onSelectAboveThreshold = { threshold -> viewModel.selectAllAboveThreshold(threshold) },
onSubmit = { onSubmit = {
val uris = viewModel.getSelectedImageUris() val uris = viewModel.getSelectedImageUris()
onSubmitForTraining(uris) onSubmitForTraining(uris)
@@ -93,8 +97,10 @@ fun RollingScanScreen(
RollingScanPhotoGrid( RollingScanPhotoGrid(
rankedPhotos = rankedPhotos, rankedPhotos = rankedPhotos,
selectedImageIds = selectedImageIds, selectedImageIds = selectedImageIds,
negativeImageIds = negativeImageIds,
isScanning = isScanning, isScanning = isScanning,
onToggleSelection = { imageId -> viewModel.toggleSelection(imageId) }, onToggleSelection = { imageId -> viewModel.toggleSelection(imageId) },
onToggleNegative = { imageId -> viewModel.toggleNegative(imageId) },
modifier = Modifier.padding(padding) modifier = Modifier.padding(padding)
) )
} }
@@ -159,19 +165,26 @@ private fun RollingScanTopBar(
} }
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
// PHOTO GRID // PHOTO GRID - Similarity-based bucketing
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
@OptIn(ExperimentalFoundationApi::class)
@Composable @Composable
private fun RollingScanPhotoGrid( private fun RollingScanPhotoGrid(
rankedPhotos: List<FaceSimilarityScorer.ScoredPhoto>, rankedPhotos: List<FaceSimilarityScorer.ScoredPhoto>,
selectedImageIds: Set<String>, selectedImageIds: Set<String>,
negativeImageIds: Set<String>,
isScanning: Boolean, isScanning: Boolean,
onToggleSelection: (String) -> Unit, onToggleSelection: (String) -> Unit,
onToggleNegative: (String) -> Unit,
modifier: Modifier = Modifier modifier: Modifier = Modifier
) { ) {
Column(modifier = modifier.fillMaxSize()) { // Bucket by similarity score
val veryLikely = rankedPhotos.filter { it.finalScore >= 0.60f }
val probably = rankedPhotos.filter { it.finalScore in 0.45f..0.599f }
val maybe = rankedPhotos.filter { it.finalScore < 0.45f }
Column(modifier = modifier.fillMaxSize()) {
// Scanning indicator // Scanning indicator
if (isScanning) { if (isScanning) {
LinearProgressIndicator( LinearProgressIndicator(
@@ -180,69 +193,78 @@ private fun RollingScanPhotoGrid(
) )
} }
// Hint for negative marking
Text(
text = "Tap to select • Long-press to mark as NOT this person",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(horizontal = 12.dp, vertical = 4.dp)
)
LazyVerticalGrid( LazyVerticalGrid(
columns = GridCells.Fixed(3), columns = GridCells.Fixed(3),
contentPadding = PaddingValues(8.dp), contentPadding = PaddingValues(8.dp),
horizontalArrangement = Arrangement.spacedBy(8.dp), horizontalArrangement = Arrangement.spacedBy(8.dp),
verticalArrangement = Arrangement.spacedBy(8.dp) verticalArrangement = Arrangement.spacedBy(8.dp)
) { ) {
// Section: Most Similar (top 10) // Section: Very Likely (>60%)
val topMatches = rankedPhotos.take(10) if (veryLikely.isNotEmpty()) {
if (topMatches.isNotEmpty()) {
item(span = { GridItemSpan(3) }) { item(span = { GridItemSpan(3) }) {
SectionHeader( SectionHeader(
icon = Icons.Default.Whatshot, icon = Icons.Default.Whatshot,
text = "🔥 Most Similar (${topMatches.size})", text = "🟢 Very Likely (${veryLikely.size})",
color = MaterialTheme.colorScheme.primary color = Color(0xFF4CAF50)
) )
} }
items(veryLikely, key = { it.imageId }) { photo ->
items(topMatches, key = { it.imageId }) { photo ->
PhotoCard( PhotoCard(
photo = photo, photo = photo,
isSelected = photo.imageId in selectedImageIds, isSelected = photo.imageId in selectedImageIds,
isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) }, onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) },
showSimilarityBadge = true showSimilarityBadge = true
) )
} }
} }
// Section: Good Matches (11-30) // Section: Probably (45-60%)
val goodMatches = rankedPhotos.drop(10).take(20) if (probably.isNotEmpty()) {
if (goodMatches.isNotEmpty()) {
item(span = { GridItemSpan(3) }) { item(span = { GridItemSpan(3) }) {
SectionHeader( SectionHeader(
icon = Icons.Default.CheckCircle, icon = Icons.Default.CheckCircle,
text = "📊 Good Matches (${goodMatches.size})", text = "🟡 Probably (${probably.size})",
color = MaterialTheme.colorScheme.tertiary color = Color(0xFFFFC107)
) )
} }
items(probably, key = { it.imageId }) { photo ->
items(goodMatches, key = { it.imageId }) { photo ->
PhotoCard( PhotoCard(
photo = photo, photo = photo,
isSelected = photo.imageId in selectedImageIds, isSelected = photo.imageId in selectedImageIds,
onToggle = { onToggleSelection(photo.imageId) } isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) },
showSimilarityBadge = true
) )
} }
} }
// Section: Other Photos // Section: Maybe (<45%)
val otherPhotos = rankedPhotos.drop(30) if (maybe.isNotEmpty()) {
if (otherPhotos.isNotEmpty()) {
item(span = { GridItemSpan(3) }) { item(span = { GridItemSpan(3) }) {
SectionHeader( SectionHeader(
icon = Icons.Default.Photo, icon = Icons.Default.Photo,
text = "📷 Other Photos (${otherPhotos.size})", text = "🟠 Maybe (${maybe.size})",
color = MaterialTheme.colorScheme.onSurfaceVariant color = Color(0xFFFF9800)
) )
} }
items(maybe, key = { it.imageId }) { photo ->
items(otherPhotos, key = { it.imageId }) { photo ->
PhotoCard( PhotoCard(
photo = photo, photo = photo,
isSelected = photo.imageId in selectedImageIds, isSelected = photo.imageId in selectedImageIds,
onToggle = { onToggleSelection(photo.imageId) } isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) }
) )
} }
} }
@@ -258,24 +280,34 @@ private fun RollingScanPhotoGrid(
} }
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
// PHOTO CARD // PHOTO CARD - with long-press for negative marking
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
@OptIn(ExperimentalFoundationApi::class)
@Composable @Composable
private fun PhotoCard( private fun PhotoCard(
photo: FaceSimilarityScorer.ScoredPhoto, photo: FaceSimilarityScorer.ScoredPhoto,
isSelected: Boolean, isSelected: Boolean,
isNegative: Boolean = false,
onToggle: () -> Unit, onToggle: () -> Unit,
onLongPress: () -> Unit = {},
showSimilarityBadge: Boolean = false showSimilarityBadge: Boolean = false
) { ) {
val borderColor = when {
isNegative -> Color(0xFFE53935) // Red for negative
isSelected -> MaterialTheme.colorScheme.primary
else -> MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
}
val borderWidth = if (isSelected || isNegative) 3.dp else 1.dp
Card( Card(
modifier = Modifier modifier = Modifier
.aspectRatio(1f) .aspectRatio(1f)
.clickable(onClick = onToggle), .combinedClickable(
border = if (isSelected) onClick = onToggle,
BorderStroke(3.dp, MaterialTheme.colorScheme.primary) onLongClick = onLongPress
else ),
BorderStroke(1.dp, MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)), border = BorderStroke(borderWidth, borderColor),
elevation = CardDefaults.cardElevation( elevation = CardDefaults.cardElevation(
defaultElevation = if (isSelected) 4.dp else 1.dp defaultElevation = if (isSelected) 4.dp else 1.dp
) )
@@ -289,22 +321,47 @@ private fun PhotoCard(
contentScale = ContentScale.Crop contentScale = ContentScale.Crop
) )
// Similarity badge (top-left) - Only for top matches // Dim overlay for negatives
if (showSimilarityBadge) { if (isNegative) {
Box(
modifier = Modifier
.fillMaxSize()
.padding(0.dp),
contentAlignment = Alignment.Center
) {
Surface(
modifier = Modifier.fillMaxSize(),
color = Color.Black.copy(alpha = 0.5f)
) {}
Icon(
Icons.Default.Close,
contentDescription = "Not this person",
tint = Color.White,
modifier = Modifier.size(32.dp)
)
}
}
// Similarity badge (top-left)
if (showSimilarityBadge && !isNegative) {
Surface( Surface(
modifier = Modifier modifier = Modifier
.align(Alignment.TopStart) .align(Alignment.TopStart)
.padding(6.dp), .padding(6.dp),
shape = RoundedCornerShape(8.dp), shape = RoundedCornerShape(8.dp),
color = MaterialTheme.colorScheme.primary, color = when {
photo.finalScore >= 0.60f -> Color(0xFF4CAF50)
photo.finalScore >= 0.45f -> Color(0xFFFFC107)
else -> Color(0xFFFF9800)
},
shadowElevation = 4.dp shadowElevation = 4.dp
) { ) {
Text( Text(
text = "${(photo.similarityScore * 100).toInt()}%", text = "${(photo.finalScore * 100).toInt()}%",
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp), modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelSmall,
fontWeight = FontWeight.Bold, fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.onPrimary color = Color.White
) )
} }
} }
@@ -332,7 +389,7 @@ private fun PhotoCard(
} }
// Face count badge (bottom-right) // Face count badge (bottom-right)
if (photo.faceCount > 1) { if (photo.faceCount > 1 && !isNegative) {
Surface( Surface(
modifier = Modifier modifier = Modifier
.align(Alignment.BottomEnd) .align(Alignment.BottomEnd)
@@ -395,6 +452,7 @@ private fun RollingScanBottomBar(
isReadyForTraining: Boolean, isReadyForTraining: Boolean,
validationMessage: String?, validationMessage: String?,
onSelectTopN: (Int) -> Unit, onSelectTopN: (Int) -> Unit,
onSelectAboveThreshold: (Float) -> Unit,
onSubmit: () -> Unit onSubmit: () -> Unit
) { ) {
Surface( Surface(
@@ -416,30 +474,41 @@ private fun RollingScanBottomBar(
) )
} }
// First row: threshold selection
Row( Row(
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.spacedBy(8.dp) horizontalArrangement = Arrangement.spacedBy(6.dp)
) { ) {
// Quick select buttons
OutlinedButton( OutlinedButton(
onClick = { onSelectTopN(10) }, onClick = { onSelectAboveThreshold(0.60f) },
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) { ) {
Text("Top 10") Text(">60%", style = MaterialTheme.typography.labelSmall)
}
OutlinedButton(
onClick = { onSelectAboveThreshold(0.50f) },
modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) {
Text(">50%", style = MaterialTheme.typography.labelSmall)
}
OutlinedButton(
onClick = { onSelectTopN(15) },
modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) {
Text("Top 15", style = MaterialTheme.typography.labelSmall)
}
} }
OutlinedButton( Spacer(Modifier.height(8.dp))
onClick = { onSelectTopN(20) },
modifier = Modifier.weight(1f)
) {
Text("Top 20")
}
// Submit button // Second row: submit
Button( Button(
onClick = onSubmit, onClick = onSubmit,
enabled = isReadyForTraining, enabled = isReadyForTraining,
modifier = Modifier.weight(1.5f) modifier = Modifier.fillMaxWidth()
) { ) {
Icon( Icon(
Icons.Default.Done, Icons.Default.Done,
@@ -447,8 +516,7 @@ private fun RollingScanBottomBar(
modifier = Modifier.size(18.dp) modifier = Modifier.size(18.dp)
) )
Spacer(Modifier.width(8.dp)) Spacer(Modifier.width(8.dp))
Text("Train ($selectedCount)") Text("Train Model ($selectedCount photos)")
}
} }
} }
} }

View File

@@ -44,6 +44,11 @@ class RollingScanViewModel @Inject constructor(
private const val TAG = "RollingScanVM" private const val TAG = "RollingScanVM"
private const val DEBOUNCE_DELAY_MS = 300L private const val DEBOUNCE_DELAY_MS = 300L
private const val MIN_PHOTOS_FOR_TRAINING = 15 private const val MIN_PHOTOS_FOR_TRAINING = 15
// Progressive thresholds based on selection count
private const val FLOOR_FEW_SEEDS = 0.30f // 1-3 seeds
private const val FLOOR_MEDIUM_SEEDS = 0.40f // 4-10 seeds
private const val FLOOR_MANY_SEEDS = 0.50f // 10+ seeds
} }
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
@@ -71,6 +76,11 @@ class RollingScanViewModel @Inject constructor(
// Cache of selected embeddings // Cache of selected embeddings
private val selectedEmbeddings = mutableListOf<FloatArray>() private val selectedEmbeddings = mutableListOf<FloatArray>()
// Negative embeddings (marked as "not this person")
private val _negativeImageIds = MutableStateFlow<Set<String>>(emptySet())
val negativeImageIds: StateFlow<Set<String>> = _negativeImageIds.asStateFlow()
private val negativeEmbeddings = mutableListOf<FloatArray>()
// All available image IDs // All available image IDs
private var allImageIds: List<String> = emptyList() private var allImageIds: List<String> = emptyList()
@@ -156,24 +166,55 @@ class RollingScanViewModel @Inject constructor(
current.remove(imageId) current.remove(imageId)
viewModelScope.launch { viewModelScope.launch {
// Remove embedding from cache
val cached = faceCacheDao.getEmbeddingByImageId(imageId) val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { selectedEmbeddings.remove(it) } cached?.getEmbedding()?.let { selectedEmbeddings.remove(it) }
} }
} else { } else {
// Select // Select (and remove from negatives if present)
current.add(imageId) current.add(imageId)
if (imageId in _negativeImageIds.value) {
toggleNegative(imageId)
}
viewModelScope.launch { viewModelScope.launch {
// Add embedding to cache
val cached = faceCacheDao.getEmbeddingByImageId(imageId) val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { selectedEmbeddings.add(it) } cached?.getEmbedding()?.let { selectedEmbeddings.add(it) }
} }
} }
_selectedImageIds.value = current _selectedImageIds.value = current.toSet() // Immutable copy
scanDebouncer.debounce {
triggerRollingScan()
}
}
/**
* Toggle negative marking ("Not this person")
*/
fun toggleNegative(imageId: String) {
val current = _negativeImageIds.value.toMutableSet()
if (imageId in current) {
current.remove(imageId)
viewModelScope.launch {
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { negativeEmbeddings.remove(it) }
}
} else {
current.add(imageId)
// Remove from selected if present
if (imageId in _selectedImageIds.value) {
toggleSelection(imageId)
}
viewModelScope.launch {
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { negativeEmbeddings.add(it) }
}
}
_negativeImageIds.value = current.toSet() // Immutable copy
// Debounced rescan
scanDebouncer.debounce { scanDebouncer.debounce {
triggerRollingScan() triggerRollingScan()
} }
@@ -190,13 +231,33 @@ class RollingScanViewModel @Inject constructor(
val current = _selectedImageIds.value.toMutableSet() val current = _selectedImageIds.value.toMutableSet()
current.addAll(topPhotos) current.addAll(topPhotos)
_selectedImageIds.value = current _selectedImageIds.value = current.toSet() // Immutable copy
viewModelScope.launch { viewModelScope.launch {
// Add embeddings
val embeddings = faceCacheDao.getEmbeddingsForImages(topPhotos.toList()) val embeddings = faceCacheDao.getEmbeddingsForImages(topPhotos.toList())
selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() }) selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() })
triggerRollingScan()
}
}
/**
 * Select every ranked photo whose final score meets [threshold].
 *
 * Newly-selected photos have their cached embeddings appended to the seed
 * set, then a rescan is triggered so the ranking reflects the enlarged
 * selection.
 */
fun selectAllAboveThreshold(threshold: Float) {
    val photosAbove = _rankedPhotos.value
        .filter { it.finalScore >= threshold }
        .map { it.imageId }

    // BUG FIX: capture which IDs are genuinely new BEFORE mutating the
    // selection state. Previously the `it !in _selectedImageIds.value`
    // filter ran inside the coroutine AFTER the state update, so it was
    // always empty and no embeddings were ever added for these photos.
    val previouslySelected = _selectedImageIds.value
    val newIds = photosAbove.filter { it !in previouslySelected }

    _selectedImageIds.value = (previouslySelected + photosAbove).toSet() // Immutable copy

    viewModelScope.launch {
        if (newIds.isNotEmpty()) {
            val embeddings = faceCacheDao.getEmbeddingsForImages(newIds)
            selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() })
        }
        triggerRollingScan()
    }
}
@@ -207,17 +268,24 @@ class RollingScanViewModel @Inject constructor(
fun clearSelection() { fun clearSelection() {
_selectedImageIds.value = emptySet() _selectedImageIds.value = emptySet()
selectedEmbeddings.clear() selectedEmbeddings.clear()
// Reset ranking
_rankedPhotos.value = emptyList() _rankedPhotos.value = emptyList()
} }
/**
 * Clear all "not this person" markings and schedule a debounced rescan so
 * the ranking is recomputed without the negative penalty.
 */
fun clearNegatives() {
    negativeEmbeddings.clear()
    _negativeImageIds.value = emptySet()
    scanDebouncer.debounce {
        triggerRollingScan()
    }
}
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
// ROLLING SCAN LOGIC // ROLLING SCAN LOGIC
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
/** /**
* CORE: Trigger rolling similarity scan * CORE: Trigger rolling similarity scan with progressive filtering
*/ */
private suspend fun triggerRollingScan() { private suspend fun triggerRollingScan() {
if (selectedEmbeddings.isEmpty()) { if (selectedEmbeddings.isEmpty()) {
@@ -228,7 +296,15 @@ class RollingScanViewModel @Inject constructor(
try { try {
_isScanning.value = true _isScanning.value = true
Log.d(TAG, "Starting scan with ${selectedEmbeddings.size} selected embeddings") val selectionCount = selectedEmbeddings.size
Log.d(TAG, "Starting scan with $selectionCount selected, ${negativeEmbeddings.size} negative")
// Progressive threshold based on selection count
val similarityFloor = when {
selectionCount <= 3 -> FLOOR_FEW_SEEDS
selectionCount <= 10 -> FLOOR_MEDIUM_SEEDS
else -> FLOOR_MANY_SEEDS
}
// Calculate centroid from selected embeddings // Calculate centroid from selected embeddings
val centroid = faceSimilarityScorer.calculateCentroid(selectedEmbeddings) val centroid = faceSimilarityScorer.calculateCentroid(selectedEmbeddings)
@@ -240,17 +316,38 @@ class RollingScanViewModel @Inject constructor(
centroid = centroid centroid = centroid
) )
// Update image URIs in scored photos // Apply negative penalty, quality boost, and floor filter
val photosWithUris = scoredPhotos.map { photo -> val filteredPhotos = scoredPhotos
.map { photo ->
// Calculate max similarity to any negative embedding
val negativePenalty = if (negativeEmbeddings.isNotEmpty()) {
negativeEmbeddings.maxOfOrNull { neg ->
cosineSimilarity(photo.cachedEmbedding, neg)
} ?: 0f
} else 0f
// Quality multiplier: solo face, large face, good quality
val qualityMultiplier = 1f +
(if (photo.faceCount == 1) 0.15f else 0f) +
(if (photo.faceAreaRatio > 0.15f) 0.10f else 0f) +
(if (photo.qualityScore > 0.7f) 0.10f else 0f)
// Final score = (similarity - negativePenalty) * qualityMultiplier
val adjustedScore = ((photo.similarityScore - negativePenalty * 0.5f) * qualityMultiplier)
.coerceIn(0f, 1f)
photo.copy( photo.copy(
imageUri = imageUriCache[photo.imageId] ?: photo.imageId imageUri = imageUriCache[photo.imageId] ?: photo.imageId,
finalScore = adjustedScore
) )
} }
.filter { it.finalScore >= similarityFloor } // Apply floor
.filter { it.imageId !in _negativeImageIds.value } // Hide negatives
.sortedByDescending { it.finalScore }
Log.d(TAG, "Scan complete. Scored ${photosWithUris.size} photos") Log.d(TAG, "Scan complete. ${filteredPhotos.size} photos above floor $similarityFloor")
// Update ranked list _rankedPhotos.value = filteredPhotos
_rankedPhotos.value = photosWithUris
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Scan failed", e) Log.e(TAG, "Scan failed", e)
@@ -259,6 +356,19 @@ class RollingScanViewModel @Inject constructor(
} }
} }
/**
 * Cosine similarity between two embedding vectors.
 *
 * Returns 0 when the vectors differ in length or either has zero magnitude,
 * so callers never see NaN/Infinity from a degenerate embedding.
 */
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
    if (a.size != b.size) return 0f
    var dotProduct = 0f
    var magSqA = 0f
    var magSqB = 0f
    a.indices.forEach { idx ->
        dotProduct += a[idx] * b[idx]
        magSqA += a[idx] * a[idx]
        magSqB += b[idx] * b[idx]
    }
    if (magSqA <= 0f || magSqB <= 0f) return 0f
    return dotProduct / (kotlin.math.sqrt(magSqA) * kotlin.math.sqrt(magSqB))
}
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
// SUBMISSION // SUBMISSION
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
@@ -299,9 +409,11 @@ class RollingScanViewModel @Inject constructor(
fun reset() { fun reset() {
_uiState.value = RollingScanState.Idle _uiState.value = RollingScanState.Idle
_selectedImageIds.value = emptySet() _selectedImageIds.value = emptySet()
_negativeImageIds.value = emptySet()
_rankedPhotos.value = emptyList() _rankedPhotos.value = emptyList()
_isScanning.value = false _isScanning.value = false
selectedEmbeddings.clear() selectedEmbeddings.clear()
negativeEmbeddings.clear()
allImageIds = emptyList() allImageIds = emptyList()
imageUriCache = emptyMap() imageUriCache = emptyMap()
scanDebouncer.cancel() scanDebouncer.cancel()

View File

@@ -6,8 +6,11 @@ import android.graphics.BitmapFactory
import android.graphics.Rect import android.graphics.Rect
import android.net.Uri import android.net.Uri
import com.google.mlkit.vision.common.InputImage import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.Face
import com.google.mlkit.vision.face.FaceDetection import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll import kotlinx.coroutines.awaitAll
@@ -64,21 +67,30 @@ class FaceDetectionHelper(private val context: Context) {
val inputImage = InputImage.fromBitmap(bitmap, 0) val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await() val faces = detector.process(inputImage).await()
// Sort by face size (area) to get the largest face // Filter to quality faces - use lenient scanning filter
val sortedFaces = faces.sortedByDescending { face -> // (Discovery filter was too strict, rejecting faces from rolling scan)
val qualityFaces = faces.filter { face ->
FaceQualityFilter.validateForScanning(
face = face,
imageWidth = bitmap.width,
imageHeight = bitmap.height
)
}
// Sort by face size (area) to get the largest quality face
val sortedFaces = qualityFaces.sortedByDescending { face ->
face.boundingBox.width() * face.boundingBox.height() face.boundingBox.width() * face.boundingBox.height()
} }
val croppedFace = if (sortedFaces.isNotEmpty()) { val croppedFace = if (sortedFaces.isNotEmpty()) {
// Crop the LARGEST detected face (most likely the subject) FaceNormalizer.cropAndNormalize(bitmap, sortedFaces[0])
cropFaceFromBitmap(bitmap, sortedFaces[0].boundingBox)
} else null } else null
FaceDetectionResult( FaceDetectionResult(
uri = uri, uri = uri,
hasFace = faces.isNotEmpty(), hasFace = qualityFaces.isNotEmpty(),
faceCount = faces.size, faceCount = qualityFaces.size,
faceBounds = faces.map { it.boundingBox }, faceBounds = qualityFaces.map { it.boundingBox },
croppedFaceBitmap = croppedFace croppedFaceBitmap = croppedFace
) )
} catch (e: Exception) { } catch (e: Exception) {

View File

@@ -51,21 +51,8 @@ fun ScanResultsScreen(
} }
} }
Scaffold( // No Scaffold - MainScreen provides TopAppBar
topBar = { Box(modifier = Modifier.fillMaxSize()) {
TopAppBar(
title = { Text("Train New Person") },
colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.primaryContainer
)
)
}
) { paddingValues ->
Box(
modifier = Modifier
.fillMaxSize()
.padding(paddingValues)
) {
when (state) { when (state) {
is ScanningState.Idle -> {} is ScanningState.Idle -> {}
@@ -77,8 +64,6 @@ fun ScanResultsScreen(
ImprovedResultsView( ImprovedResultsView(
result = state.sanityCheckResult, result = state.sanityCheckResult,
onContinue = { onContinue = {
// PersonInfo already captured in TrainingScreen!
// Just start training with stored info
trainViewModel.createFaceModel( trainViewModel.createFaceModel(
trainViewModel.getPersonInfo()?.name ?: "Unknown" trainViewModel.getPersonInfo()?.name ?: "Unknown"
) )
@@ -103,7 +88,6 @@ fun ScanResultsScreen(
TrainingOverlay(trainingState = trainingState as TrainingState.Processing) TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
} }
} }
}
showFacePickerDialog?.let { result -> showFacePickerDialog?.let { result ->
FacePickerDialog( FacePickerDialog(

View File

@@ -5,11 +5,18 @@ import android.graphics.Bitmap
import android.net.Uri import android.net.Uri
import androidx.lifecycle.AndroidViewModel import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import androidx.work.WorkManager
import android.content.Context
import com.placeholder.sherpai2.data.local.entity.PersonEntity import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
import com.placeholder.sherpai2.ml.FaceNetModel import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.workers.LibraryScanWorker
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
@@ -48,15 +55,20 @@ data class PersonInfo(
/** /**
* FIXED TrainViewModel with proper exclude functionality and efficient replace * FIXED TrainViewModel with proper exclude functionality and efficient replace
*/ */
private val android.content.Context.dataStore by preferencesDataStore(name = "settings")
private val KEY_BACKGROUND_TAGGING = booleanPreferencesKey("background_recognition_tagging")
@HiltViewModel @HiltViewModel
class TrainViewModel @Inject constructor( class TrainViewModel @Inject constructor(
application: Application, application: Application,
private val faceRecognitionRepository: FaceRecognitionRepository, private val faceRecognitionRepository: FaceRecognitionRepository,
private val faceNetModel: FaceNetModel private val faceNetModel: FaceNetModel,
private val workManager: WorkManager
) : AndroidViewModel(application) { ) : AndroidViewModel(application) {
private val sanityChecker = TrainingSanityChecker(application) private val sanityChecker = TrainingSanityChecker(application)
private val faceDetectionHelper = FaceDetectionHelper(application) private val faceDetectionHelper = FaceDetectionHelper(application)
private val dataStore = application.dataStore
private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle) private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
val uiState: StateFlow<ScanningState> = _uiState.asStateFlow() val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()
@@ -174,6 +186,20 @@ class TrainViewModel @Inject constructor(
relationship = person.relationship relationship = person.relationship
) )
// Trigger library scan if setting enabled
val backgroundTaggingEnabled = dataStore.data
.map { it[KEY_BACKGROUND_TAGGING] ?: true }
.first()
if (backgroundTaggingEnabled) {
// Use default threshold (0.62 solo, 0.68 group)
val scanRequest = LibraryScanWorker.createWorkRequest(
personId = personId,
personName = personName
)
workManager.enqueue(scanRequest)
}
} catch (e: Exception) { } catch (e: Exception) {
_trainingState.value = TrainingState.Error( _trainingState.value = TrainingState.Error(
e.message ?: "Failed to create face model" e.message ?: "Failed to create face model"
@@ -355,7 +381,7 @@ class TrainViewModel @Inject constructor(
faceDetectionResults = updatedFaceResults, faceDetectionResults = updatedFaceResults,
validationErrors = updatedErrors, validationErrors = updatedErrors,
validImagesWithFaces = updatedValidImages, validImagesWithFaces = updatedValidImages,
excludedImages = excludedImages excludedImages = excludedImages.toSet() // Immutable copy for Compose state detection
) )
} }

View File

@@ -49,6 +49,7 @@ fun TrainingPhotoSelectorScreen(
val isRanking by viewModel.isRanking.collectAsStateWithLifecycle() val isRanking by viewModel.isRanking.collectAsStateWithLifecycle()
val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle() val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle()
val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle() val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle()
val embeddingProgress by viewModel.embeddingProgress.collectAsStateWithLifecycle()
Scaffold( Scaffold(
topBar = { topBar = {
@@ -154,8 +155,34 @@ fun TrainingPhotoSelectorScreen(
Box( Box(
modifier = Modifier.fillMaxSize(), modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.Center contentAlignment = Alignment.Center
) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(16.dp)
) { ) {
CircularProgressIndicator() CircularProgressIndicator()
// Capture value to avoid race condition
val progress = embeddingProgress
if (progress != null) {
Text(
"Preparing faces: ${progress.current}/${progress.total}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
LinearProgressIndicator(
progress = { progress.current.toFloat() / progress.total },
modifier = Modifier
.width(200.dp)
.padding(top = 8.dp)
)
} else {
Text(
"Loading premium faces...",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
} }
} }
photos.isEmpty() -> { photos.isEmpty() -> {

View File

@@ -1,20 +1,31 @@
package com.placeholder.sherpai2.ui.trainingprep package com.placeholder.sherpai2.ui.trainingprep
import android.app.Application
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Rect
import android.net.Uri
import android.util.Log import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import javax.inject.Inject import javax.inject.Inject
import kotlin.math.max
import kotlin.math.min
/** /**
* TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN * TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN
@@ -27,15 +38,18 @@ import javax.inject.Inject
*/ */
@HiltViewModel @HiltViewModel
class TrainingPhotoSelectorViewModel @Inject constructor( class TrainingPhotoSelectorViewModel @Inject constructor(
application: Application,
private val imageDao: ImageDao, private val imageDao: ImageDao,
private val faceCacheDao: FaceCacheDao, private val faceCacheDao: FaceCacheDao,
private val faceSimilarityScorer: FaceSimilarityScorer private val faceSimilarityScorer: FaceSimilarityScorer,
) : ViewModel() { private val faceNetModel: FaceNetModel
) : AndroidViewModel(application) {
companion object { companion object {
private const val TAG = "PremiumSelector" private const val TAG = "PremiumSelector"
private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1 private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1
private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5 private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5
private const val MAX_EMBEDDINGS_TO_GENERATE = 500
} }
// All photos (for fallback / full list) // All photos (for fallback / full list)
@@ -56,6 +70,12 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
private val _isRanking = MutableStateFlow(false) private val _isRanking = MutableStateFlow(false)
val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow() val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow()
// Embedding generation progress
private val _embeddingProgress = MutableStateFlow<EmbeddingProgress?>(null)
val embeddingProgress: StateFlow<EmbeddingProgress?> = _embeddingProgress.asStateFlow()
data class EmbeddingProgress(val current: Int, val total: Int)
// Premium mode toggle // Premium mode toggle
private val _showPremiumOnly = MutableStateFlow(true) private val _showPremiumOnly = MutableStateFlow(true)
val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow() val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow()
@@ -79,20 +99,47 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
/** /**
* Load PREMIUM faces first (solo, large, frontal, high quality) * Load PREMIUM faces first (solo, large, frontal, high quality)
* If no embeddings exist, generate them on-demand for premium candidates
*/ */
private fun loadPremiumFaces() { private fun loadPremiumFaces() {
viewModelScope.launch { viewModelScope.launch {
try { try {
_isLoading.value = true _isLoading.value = true
// Get premium faces from cache // First check if premium faces with embeddings exist
val premiumFaceCache = faceCacheDao.getPremiumFaces( var premiumFaceCache = faceCacheDao.getPremiumFaces(
minAreaRatio = 0.10f, minAreaRatio = 0.10f,
minQuality = 0.7f, minQuality = 0.7f,
limit = 500 limit = 500
) )
Log.d(TAG, " Found ${premiumFaceCache.size} premium faces") Log.d(TAG, "📊 Found ${premiumFaceCache.size} premium faces with embeddings")
// If no premium faces with embeddings, generate them on-demand
if (premiumFaceCache.isEmpty()) {
Log.d(TAG, "⚠️ No premium faces with embeddings - generating on-demand")
val candidates = faceCacheDao.getPremiumFaceCandidatesNeedingEmbeddings(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = MAX_EMBEDDINGS_TO_GENERATE
)
Log.d(TAG, "📦 Found ${candidates.size} premium candidates needing embeddings")
if (candidates.isNotEmpty()) {
generateEmbeddingsForCandidates(candidates)
// Re-query after generating
premiumFaceCache = faceCacheDao.getPremiumFaces(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = 500
)
Log.d(TAG, "✅ After generation: ${premiumFaceCache.size} premium faces")
}
}
_premiumCount.value = premiumFaceCache.size _premiumCount.value = premiumFaceCache.size
// Get corresponding ImageEntities // Get corresponding ImageEntities
@@ -117,10 +164,108 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
loadAllFaces() loadAllFaces()
} finally { } finally {
_isLoading.value = false _isLoading.value = false
_embeddingProgress.value = null
} }
} }
} }
/**
 * Generate FaceNet embeddings for premium face candidates that are missing
 * them, persisting each valid embedding back to the face cache.
 *
 * Publishes [EmbeddingProgress] after every candidate so the UI progress bar
 * is accurate, and logs a summary when done.
 */
private suspend fun generateEmbeddingsForCandidates(candidates: List<FaceCacheEntity>) {
    val context = getApplication<Application>()
    val total = candidates.size
    var processed = 0

    withContext(Dispatchers.IO) {
        // Resolve image URIs for all candidates in one query.
        val imageIds = candidates.map { it.imageId }.distinct()
        val images = imageDao.getImagesByIds(imageIds)
        val imageUriMap = images.associate { it.imageId to it.imageUri }

        for (candidate in candidates) {
            try {
                val imageUri = imageUriMap[candidate.imageId] ?: continue

                val bitmap = loadBitmapOptimized(context, Uri.parse(imageUri)) ?: continue

                // Recycle the full bitmap as soon as the crop is taken,
                // even if cropping throws.
                val croppedFace = try {
                    cropFaceWithPadding(bitmap, candidate.getBoundingBox())
                } finally {
                    bitmap.recycle()
                } ?: continue

                // BUG FIX: recycle the cropped face even when embedding
                // generation throws (previously leaked on that path).
                try {
                    val embedding = faceNetModel.generateEmbedding(croppedFace)

                    // All-zero embeddings indicate model failure — skip persisting.
                    if (embedding.any { it != 0f }) {
                        val embeddingJson = FaceCacheEntity.embeddingToJson(embedding)
                        faceCacheDao.updateEmbedding(candidate.imageId, candidate.faceIndex, embeddingJson)
                    }
                } finally {
                    croppedFace.recycle()
                }
            } catch (ce: kotlin.coroutines.cancellation.CancellationException) {
                // Never swallow cancellation — propagate so the scope can stop.
                throw ce
            } catch (e: Exception) {
                Log.w(TAG, "Failed to generate embedding for ${candidate.imageId}: ${e.message}")
            } finally {
                // BUG FIX: count EVERY candidate, including `continue` paths.
                // Previously skipped candidates were never counted, so the
                // progress bar stalled short of `total` and the final log
                // undercounted. MutableStateFlow is thread-safe, so no
                // dispatcher hop is needed here.
                processed++
                _embeddingProgress.value = EmbeddingProgress(processed, total)
            }
        }
    }

    Log.d(TAG, "✅ Generated embeddings for $processed/$total candidates")
}
/**
 * Decode a bitmap from [uri], downsampled so neither dimension exceeds
 * [maxDim]. Returns null on any decode failure.
 *
 * Uses the standard two-pass pattern: a bounds-only decode to learn the
 * source dimensions, then a pixel decode with a power-of-two sample size.
 */
private fun loadBitmapOptimized(context: android.content.Context, uri: Uri, maxDim: Int = 768): Bitmap? {
    return try {
        // Pass 1: read dimensions only — allocates no pixel memory.
        val boundsOptions = BitmapFactory.Options().apply { inJustDecodeBounds = true }
        context.contentResolver.openInputStream(uri)?.use { input ->
            BitmapFactory.decodeStream(input, null, boundsOptions)
        }

        // Smallest power-of-two sample size that brings both dimensions
        // within maxDim.
        var sample = 1
        while (boundsOptions.outWidth / sample > maxDim || boundsOptions.outHeight / sample > maxDim) {
            sample *= 2
        }

        // Pass 2: decode the downsampled pixels.
        val decodeOptions = BitmapFactory.Options().apply {
            inSampleSize = sample
            inPreferredConfig = Bitmap.Config.ARGB_8888
        }
        context.contentResolver.openInputStream(uri)?.use { input ->
            BitmapFactory.decodeStream(input, null, decodeOptions)
        }
    } catch (e: Exception) {
        Log.w(TAG, "Failed to load bitmap: ${e.message}")
        null
    }
}
/**
 * Crop [boundingBox] out of [bitmap] with 25% padding on every side
 * (relative to the larger box dimension), clamped to the bitmap bounds.
 * Returns null if the clamped region is empty or the crop fails.
 */
private fun cropFaceWithPadding(bitmap: Bitmap, boundingBox: Rect): Bitmap? {
    return try {
        val pad = (max(boundingBox.width(), boundingBox.height()) * 0.25f).toInt()
        val left = (boundingBox.left - pad).coerceAtLeast(0)
        val top = (boundingBox.top - pad).coerceAtLeast(0)
        val right = (boundingBox.right + pad).coerceAtMost(bitmap.width)
        val bottom = (boundingBox.bottom + pad).coerceAtMost(bitmap.height)

        if (right > left && bottom > top) {
            Bitmap.createBitmap(bitmap, left, top, right - left, bottom - top)
        } else {
            null
        }
    } catch (e: Exception) {
        Log.w(TAG, "Failed to crop face: ${e.message}")
        null
    }
}
/** /**
* Fallback: load all photos with faces * Fallback: load all photos with faces
*/ */

View File

@@ -9,6 +9,9 @@ import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceModelDao import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.data.local.dao.ImageDao import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
@@ -52,7 +55,8 @@ class LibraryScanWorker @AssistedInject constructor(
@Assisted workerParams: WorkerParameters, @Assisted workerParams: WorkerParameters,
private val imageDao: ImageDao, private val imageDao: ImageDao,
private val faceModelDao: FaceModelDao, private val faceModelDao: FaceModelDao,
private val photoFaceTagDao: PhotoFaceTagDao private val photoFaceTagDao: PhotoFaceTagDao,
private val personDao: PersonDao
) : CoroutineWorker(context, workerParams) { ) : CoroutineWorker(context, workerParams) {
companion object { companion object {
@@ -65,7 +69,8 @@ class LibraryScanWorker @AssistedInject constructor(
const val KEY_MATCHES_FOUND = "matches_found" const val KEY_MATCHES_FOUND = "matches_found"
const val KEY_PHOTOS_SCANNED = "photos_scanned" const val KEY_PHOTOS_SCANNED = "photos_scanned"
private const val DEFAULT_THRESHOLD = 0.70f // Slightly looser than validation private const val DEFAULT_THRESHOLD = 0.62f // Solo photos
private const val GROUP_THRESHOLD = 0.68f // Group photos (stricter)
private const val BATCH_SIZE = 20 private const val BATCH_SIZE = 20
private const val MAX_RETRIES = 3 private const val MAX_RETRIES = 3
@@ -137,16 +142,40 @@ class LibraryScanWorker @AssistedInject constructor(
) )
} }
// Step 2.5: Load person to check isChild flag
val person = withContext(Dispatchers.IO) {
personDao.getPersonById(personId)
}
val isChildTarget = person?.isChild ?: false
// Step 3: Initialize ML components // Step 3: Initialize ML components
val faceNetModel = FaceNetModel(context) val faceNetModel = FaceNetModel(context)
val detector = FaceDetection.getClient( val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder() FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // Needed for age estimation
.setMinFaceSize(0.15f) .setMinFaceSize(0.15f)
.build() .build()
) )
val modelEmbedding = faceModel.getEmbeddingArray() // Distribution-based minimum threshold (self-calibrating)
val distributionMin = (faceModel.averageConfidence - 2 * faceModel.similarityStdDev)
.coerceAtLeast(faceModel.similarityMin - 0.05f)
.coerceAtLeast(0.50f) // Never go below 0.50 absolute floor
// Get ALL centroids for multi-centroid matching (critical for children)
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
if (modelCentroids.isEmpty()) {
return@withContext Result.failure(workDataOf("error" to "No centroids in model"))
}
// Load ALL other models for "best match wins" comparison
// This prevents tagging siblings incorrectly
val allModels = withContext(Dispatchers.IO) { faceModelDao.getAllActiveFaceModels() }
val otherModelCentroids = allModels
.filter { it.id != faceModel.id }
.map { model -> model.id to model.getCentroids().map { it.getEmbeddingArray() } }
var matchesFound = 0 var matchesFound = 0
var photosScanned = 0 var photosScanned = 0
@@ -164,10 +193,13 @@ class LibraryScanWorker @AssistedInject constructor(
photo = photo, photo = photo,
personId = personId, personId = personId,
faceModelId = faceModel.id, faceModelId = faceModel.id,
modelEmbedding = modelEmbedding, modelCentroids = modelCentroids,
otherModelCentroids = otherModelCentroids,
faceNetModel = faceNetModel, faceNetModel = faceNetModel,
detector = detector, detector = detector,
threshold = threshold threshold = threshold,
distributionMin = distributionMin,
isChildTarget = isChildTarget
) )
if (tags.isNotEmpty()) { if (tags.isNotEmpty()) {
@@ -228,10 +260,13 @@ class LibraryScanWorker @AssistedInject constructor(
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity, photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
personId: String, personId: String,
faceModelId: String, faceModelId: String,
modelEmbedding: FloatArray, modelCentroids: List<FloatArray>,
otherModelCentroids: List<Pair<String, List<FloatArray>>>,
faceNetModel: FaceNetModel, faceNetModel: FaceNetModel,
detector: com.google.mlkit.vision.face.FaceDetector, detector: com.google.mlkit.vision.face.FaceDetector,
threshold: Float threshold: Float,
distributionMin: Float,
isChildTarget: Boolean
): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) { ): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) {
try { try {
@@ -243,43 +278,94 @@ class LibraryScanWorker @AssistedInject constructor(
val inputImage = InputImage.fromBitmap(bitmap, 0) val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await() val faces = detector.process(inputImage).await()
// Check each face if (faces.isEmpty()) {
val tags = faces.mapNotNull { face -> bitmap.recycle()
return@withContext emptyList()
}
// Use higher threshold for group photos
val isGroupPhoto = faces.size > 1
val effectiveThreshold = if (isGroupPhoto) GROUP_THRESHOLD else threshold
// Track best match (only tag ONE face per image to avoid false positives)
var bestMatch: PhotoFaceTagEntity? = null
var bestSimilarity = 0f
// Check each face (filter by quality first)
for (face in faces) {
// Quality check
if (!FaceQualityFilter.validateForScanning(face, bitmap.width, bitmap.height)) {
continue
}
// Skip very small faces
val faceArea = face.boundingBox.width() * face.boundingBox.height()
val imageArea = bitmap.width * bitmap.height
if (faceArea.toFloat() / imageArea < 0.02f) continue
// SIGNAL 2: Age plausibility check (if target is a child)
if (isChildTarget) {
val ageGroup = FaceQualityFilter.estimateAgeGroup(face, bitmap.width, bitmap.height)
if (ageGroup == FaceQualityFilter.AgeGroup.ADULT) {
continue // Reject clearly adult faces when searching for a child
}
}
try { try {
// Crop face // Crop and normalize face for best recognition
val faceBitmap = android.graphics.Bitmap.createBitmap( val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
bitmap, ?: continue
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
// Generate embedding // Generate embedding
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle() faceBitmap.recycle()
// Calculate similarity // Match against target person's centroids
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding) val targetSimilarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
if (similarity >= threshold) { // SIGNAL 1: Distribution-based rejection
PhotoFaceTagEntity.create( // If similarity is below (mean - 2*stdDev) or (min - 0.05), it's a statistical outlier
if (targetSimilarity < distributionMin) {
continue // Too far below training distribution
}
// SIGNAL 3: Basic threshold check
if (targetSimilarity < effectiveThreshold) {
continue
}
// SIGNAL 4: "Best match wins" - check if any OTHER model scores higher
// This prevents tagging siblings incorrectly
val bestOtherSimilarity = otherModelCentroids.maxOfOrNull { (_, centroids) ->
centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
} ?: 0f
val isTargetBestMatch = targetSimilarity > bestOtherSimilarity
// All signals must pass
if (isTargetBestMatch && targetSimilarity > bestSimilarity) {
bestSimilarity = targetSimilarity
bestMatch = PhotoFaceTagEntity.create(
imageId = photo.imageId, imageId = photo.imageId,
faceModelId = faceModelId, faceModelId = faceModelId,
boundingBox = face.boundingBox, boundingBox = face.boundingBox,
confidence = similarity, confidence = targetSimilarity,
faceEmbedding = faceEmbedding faceEmbedding = faceEmbedding
) )
} else {
null
} }
} catch (e: Exception) { } catch (e: Exception) {
null // Skip this face
} }
} }
bitmap.recycle() bitmap.recycle()
tags
// Return only the best match (or empty)
if (bestMatch != null) listOf(bestMatch) else emptyList()
} catch (e: Exception) { } catch (e: Exception) {
emptyList() emptyList()