3 Commits

Author SHA1 Message Date
genki
804f3d5640 rollingscan very clean
likelihood -> find similar
REFRESHCLAUD.MD 20260126
2026-01-26 22:46:38 -05:00
genki
cfec2b980a toofasttooclaude 2026-01-26 14:15:54 -05:00
genki
1ef8faad17 jFc 2026-01-25 22:01:46 -05:00
27 changed files with 1369 additions and 251 deletions

View File

@@ -4,10 +4,10 @@
<selectionStates>
<SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-01-25T20:45:06.118763497Z">
<DropdownSelection timestamp="2026-01-27T00:21:15.014661014Z">
<Target type="DEFAULT_BOOT">
<handle>
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />
<DeviceId pluginId="PhysicalDevice" identifier="serial=R3CX106YYCB" />
</handle>
</Target>
</DropdownSelection>

View File

@@ -48,6 +48,9 @@ dependencies {
implementation(libs.androidx.lifecycle.viewmodel.compose)
implementation(libs.androidx.activity.compose)
// DataStore Preferences
implementation("androidx.datastore:datastore-preferences:1.1.1")
// Compose
implementation(platform(libs.androidx.compose.bom))
implementation(libs.androidx.compose.ui)

View File

@@ -10,6 +10,10 @@ import com.placeholder.sherpai2.data.local.entity.*
/**
* AppDatabase - Complete database for SherpAI2
*
* VERSION 12 - Distribution-based rejection stats
* - Added similarityStdDev, similarityMin to FaceModelEntity
* - Enables self-calibrating threshold for face matching
*
* VERSION 11 - Person Statistics
* - Added PersonStatisticsEntity for pre-computed person aggregates
*
* VERSION 10 - User Feedback Loop
* - Added UserFeedbackEntity for storing user corrections
* - Enables cluster refinement before training
@@ -44,14 +48,15 @@ import com.placeholder.sherpai2.data.local.entity.*
PhotoFaceTagEntity::class,
PersonAgeTagEntity::class,
FaceCacheEntity::class,
UserFeedbackEntity::class, // NEW: User corrections
UserFeedbackEntity::class,
PersonStatisticsEntity::class, // Pre-computed person stats
// ===== COLLECTIONS =====
CollectionEntity::class,
CollectionImageEntity::class,
CollectionFilterEntity::class
],
version = 10, // INCREMENTED for user feedback
version = 12, // INCREMENTED for distribution-based rejection stats
exportSchema = false
)
abstract class AppDatabase : RoomDatabase() {
@@ -70,7 +75,8 @@ abstract class AppDatabase : RoomDatabase() {
abstract fun photoFaceTagDao(): PhotoFaceTagDao
abstract fun personAgeTagDao(): PersonAgeTagDao
abstract fun faceCacheDao(): FaceCacheDao
abstract fun userFeedbackDao(): UserFeedbackDao // NEW
abstract fun userFeedbackDao(): UserFeedbackDao
abstract fun personStatisticsDao(): PersonStatisticsDao
// ===== COLLECTIONS DAO =====
abstract fun collectionDao(): CollectionDao
@@ -242,13 +248,60 @@ val MIGRATION_9_10 = object : Migration(9, 10) {
}
}
/**
 * MIGRATION 10 → 11 (Person Statistics)
 *
 * Changes:
 * 1. Create person_statistics table for pre-computed aggregates
 *    (photo counts, date range, average confidence, ages) so the People
 *    Dashboard can read stats without aggregating at query time.
 */
val MIGRATION_10_11 = object : Migration(10, 11) {
override fun migrate(database: SupportSQLiteDatabase) {
// Create person_statistics table
// NOTE(review): FOREIGN KEY references persons(id) — assumes the persons
// table already exists by schema version 10; confirm against earlier migrations.
database.execSQL("""
CREATE TABLE IF NOT EXISTS person_statistics (
personId TEXT PRIMARY KEY NOT NULL,
photoCount INTEGER NOT NULL DEFAULT 0,
firstPhotoDate INTEGER NOT NULL DEFAULT 0,
lastPhotoDate INTEGER NOT NULL DEFAULT 0,
averageConfidence REAL NOT NULL DEFAULT 0,
agesWithPhotos TEXT,
updatedAt INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE
)
""")
// Index for sorting by photo count (People Dashboard)
database.execSQL("CREATE INDEX IF NOT EXISTS index_person_statistics_photoCount ON person_statistics(photoCount)")
}
}
/**
 * MIGRATION 11 → 12 (Distribution-based Rejection Stats)
 *
 * Changes:
 * 1. Add similarityStdDev column to face_models (default 0.05)
 * 2. Add similarityMin column to face_models (default 0.6)
 *
 * These fields enable self-calibrating thresholds during scanning.
 * During training, we compute stats from training sample similarities
 * and use (mean - 2*stdDev) as a floor for matching.
 */
val MIGRATION_11_12 = object : Migration(11, 12) {
override fun migrate(database: SupportSQLiteDatabase) {
// Add distribution stats columns with sensible defaults for existing models.
// The SQL defaults (0.05 / 0.6) mirror FaceModelEntity's Kotlin property
// defaults, so pre-migration rows behave exactly like models created
// without computed stats.
database.execSQL("ALTER TABLE face_models ADD COLUMN similarityStdDev REAL NOT NULL DEFAULT 0.05")
database.execSQL("ALTER TABLE face_models ADD COLUMN similarityMin REAL NOT NULL DEFAULT 0.6")
}
}
/**
* PRODUCTION MIGRATION NOTES:
*
* Before shipping to users, update DatabaseModule to use migrations:
*
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11, MIGRATION_11_12) // Add all migrations
* // .fallbackToDestructiveMigration() // Remove this
* .build()
*/

View File

@@ -233,6 +233,33 @@ interface FaceCacheDao {
limit: Int = 500
): List<FaceCacheEntity>
/**
 * Get premium face CANDIDATES - same criteria but WITHOUT embedding requirement.
 * Used to find faces that need embedding generation.
 *
 * "Premium" means: the image contains exactly one face, the face is frontal,
 * covers at least [minAreaRatio] of the image, and meets [minQuality].
 * Only rows whose embedding has not yet been generated (embedding IS NULL)
 * are returned, best quality / largest face first.
 */
@Query("""
SELECT fc.* FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.faceAreaRatio >= :minAreaRatio
AND fc.isFrontal = 1
AND fc.qualityScore >= :minQuality
AND fc.embedding IS NULL
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
LIMIT :limit
""")
suspend fun getPremiumFaceCandidatesNeedingEmbeddings(
minAreaRatio: Float = 0.10f,
minQuality: Float = 0.7f,
limit: Int = 500
): List<FaceCacheEntity>
/**
 * Update embedding for a face cache entry
 *
 * @param embedding serialized embedding — the comma-separated float list
 *        produced by FaceCacheEntity.embeddingToJson
 */
@Query("UPDATE face_cache SET embedding = :embedding WHERE imageId = :imageId AND faceIndex = :faceIndex")
suspend fun updateEmbedding(imageId: String, faceIndex: Int, embedding: String)
/**
* Count of premium faces available
*/

View File

@@ -66,6 +66,9 @@ interface ImageDao {
@Query("SELECT * FROM images WHERE imageId = :imageId")
suspend fun getImageById(imageId: String): ImageEntity?
/**
 * Look up a single image row by its URI string; returns null if the URI
 * has not been indexed.
 */
@Query("SELECT * FROM images WHERE imageUri = :uri LIMIT 1")
suspend fun getImageByUri(uri: String): ImageEntity?
/**
* Stream images ordered by capture time (newest first).
*

View File

@@ -83,9 +83,89 @@ interface PhotoFaceTagDao {
*/
@Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit")
suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity>
// ===== CO-OCCURRENCE QUERIES =====
/**
 * Find people who appear in photos together with a given person.
 * Returns list of (otherFaceModelId, count) sorted by count descending.
 * Use case: "Who appears most with Mom?" or "Show photos of Mom WITH Dad"
 *
 * The self-join pairs every tag with every other tag on the same image;
 * COUNT(DISTINCT pft1.imageId) ensures each shared image is counted once
 * per pair even if a person has multiple tag rows in it.
 */
@Query("""
SELECT pft2.faceModelId as otherFaceModelId, COUNT(DISTINCT pft1.imageId) as coCount
FROM photo_face_tags pft1
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
WHERE pft1.faceModelId = :faceModelId
AND pft2.faceModelId != :faceModelId
GROUP BY pft2.faceModelId
ORDER BY coCount DESC
""")
suspend fun getCoOccurrences(faceModelId: String): List<PersonCoOccurrence>
/**
 * Get images where BOTH people appear together.
 *
 * NOTE(review): ORDER BY pft1.detectedAt on a SELECT DISTINCT is fragile —
 * detectedAt is not in the result set, and one imageId can have several
 * detectedAt rows, so the ordering is not deterministic. SQLite tolerates
 * the query; confirm the intended sort semantics.
 */
@Query("""
SELECT DISTINCT pft1.imageId
FROM photo_face_tags pft1
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
WHERE pft1.faceModelId = :faceModelId1
AND pft2.faceModelId = :faceModelId2
ORDER BY pft1.detectedAt DESC
""")
suspend fun getImagesWithBothPeople(faceModelId1: String, faceModelId2: String): List<String>
/**
 * Get images where person appears ALONE (no other trained faces).
 *
 * "Alone" means no OTHER trained face model has a tag on the image —
 * untagged/unknown faces in the photo are NOT excluded by this query.
 */
@Query("""
SELECT imageId FROM photo_face_tags
WHERE faceModelId = :faceModelId
AND imageId NOT IN (
SELECT imageId FROM photo_face_tags
WHERE faceModelId != :faceModelId
)
ORDER BY detectedAt DESC
""")
suspend fun getImagesWithPersonAlone(faceModelId: String): List<String>
/**
 * Get images where ALL specified people appear (N-way intersection).
 * For "Intersection Search" moonshot feature.
 *
 * @param requiredCount should equal the number of distinct ids in
 *        [faceModelIds]; only images whose distinct tagged models reach
 *        this count survive the HAVING filter.
 */
@Query("""
SELECT imageId FROM photo_face_tags
WHERE faceModelId IN (:faceModelIds)
GROUP BY imageId
HAVING COUNT(DISTINCT faceModelId) = :requiredCount
""")
suspend fun getImagesWithAllPeople(faceModelIds: List<String>, requiredCount: Int): List<String>
/**
 * Get images with at least N of the specified people (family portrait detection).
 *
 * Results are ordered with the most complete group photos first.
 * Note: HAVING references the SELECT alias memberCount — SQLite supports
 * alias resolution in HAVING.
 */
@Query("""
SELECT imageId, COUNT(DISTINCT faceModelId) as memberCount
FROM photo_face_tags
WHERE faceModelId IN (:faceModelIds)
GROUP BY imageId
HAVING memberCount >= :minMembers
ORDER BY memberCount DESC
""")
suspend fun getFamilyPortraits(faceModelIds: List<String>, minMembers: Int): List<FamilyPortraitResult>
}
/** Result row for getFamilyPortraits: an image plus how many of the requested people appear in it. */
data class FamilyPortraitResult(
val imageId: String,
val memberCount: Int
)
/** Projection row pairing a face model with its tagged-photo count. */
data class FaceModelPhotoCount(
val faceModelId: String,
val photoCount: Int
)
/** Result row for getCoOccurrences: another person's model id and the number of shared images. */
data class PersonCoOccurrence(
val otherFaceModelId: String,
val coCount: Int
)

View File

@@ -99,6 +99,13 @@ data class FaceCacheEntity(
companion object {
const val CURRENT_CACHE_VERSION = 1
/**
 * Serialize a face embedding for DB storage.
 *
 * Despite the name, the output is not JSON — it is a plain
 * comma-separated list of float values (e.g. "0.12,0.5,-0.3").
 */
fun embeddingToJson(embedding: FloatArray): String =
    embedding.joinToString(separator = ",")
/**
* Create from ML Kit face detection result
*/

View File

@@ -143,6 +143,13 @@ data class FaceModelEntity(
@ColumnInfo(name = "averageConfidence")
val averageConfidence: Float,
// Distribution stats for self-calibrating rejection
@ColumnInfo(name = "similarityStdDev")
val similarityStdDev: Float = 0.05f, // Default for backwards compat
@ColumnInfo(name = "similarityMin")
val similarityMin: Float = 0.6f, // Default for backwards compat
@ColumnInfo(name = "createdAt")
val createdAt: Long,
@@ -157,26 +164,29 @@ data class FaceModelEntity(
) {
companion object {
/**
* Backwards compatible create() method
* Used by existing FaceRecognitionRepository code
* Create with distribution stats for self-calibrating rejection
*/
fun create(
personId: String,
embeddingArray: FloatArray,
trainingImageCount: Int,
averageConfidence: Float
averageConfidence: Float,
similarityStdDev: Float = 0.05f,
similarityMin: Float = 0.6f
): FaceModelEntity {
return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence)
return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence, similarityStdDev, similarityMin)
}
/**
* Create from single embedding (backwards compatible)
* Create from single embedding with distribution stats
*/
fun createFromEmbedding(
personId: String,
embeddingArray: FloatArray,
trainingImageCount: Int,
averageConfidence: Float
averageConfidence: Float,
similarityStdDev: Float = 0.05f,
similarityMin: Float = 0.6f
): FaceModelEntity {
val now = System.currentTimeMillis()
val centroid = TemporalCentroid(
@@ -194,6 +204,8 @@ data class FaceModelEntity(
centroidsJson = serializeCentroids(listOf(centroid)),
trainingImageCount = trainingImageCount,
averageConfidence = averageConfidence,
similarityStdDev = similarityStdDev,
similarityMin = similarityMin,
createdAt = now,
updatedAt = now,
lastUsed = null,

View File

@@ -2,8 +2,10 @@ package com.placeholder.sherpai2.data.repository
import android.content.Context
import android.graphics.Bitmap
import android.util.Log
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonAgeTagDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.*
@@ -31,8 +33,12 @@ class FaceRecognitionRepository @Inject constructor(
private val personDao: PersonDao,
private val imageDao: ImageDao,
private val faceModelDao: FaceModelDao,
private val photoFaceTagDao: PhotoFaceTagDao
private val photoFaceTagDao: PhotoFaceTagDao,
private val personAgeTagDao: PersonAgeTagDao
) {
companion object {
private const val TAG = "FaceRecognitionRepo"
}
private val faceNetModel by lazy { FaceNetModel(context) }
@@ -93,11 +99,19 @@ class FaceRecognitionRepository @Inject constructor(
}
val avgConfidence = confidences.average().toFloat()
// Compute distribution stats for self-calibrating rejection
val stdDev = kotlin.math.sqrt(
confidences.map { (it - avgConfidence).toDouble().let { d -> d * d } }.average()
).toFloat()
val minSimilarity = confidences.minOrNull() ?: 0f
val faceModel = FaceModelEntity.create(
personId = personId,
embeddingArray = personEmbedding,
trainingImageCount = validImages.size,
averageConfidence = avgConfidence
averageConfidence = avgConfidence,
similarityStdDev = stdDev,
similarityMin = minSimilarity
)
faceModelDao.insertFaceModel(faceModel)
@@ -181,12 +195,15 @@ class FaceRecognitionRepository @Inject constructor(
var highestSimilarity = threshold
for (faceModel in faceModels) {
val modelEmbedding = faceModel.getEmbeddingArray()
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
// Check ALL centroids for best match (critical for children with age centroids)
val centroids = faceModel.getCentroids()
val bestCentroidSimilarity = centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid.getEmbeddingArray())
} ?: 0f
if (similarity > highestSimilarity) {
highestSimilarity = similarity
bestMatch = Pair(faceModel.id, similarity)
if (bestCentroidSimilarity > highestSimilarity) {
highestSimilarity = bestCentroidSimilarity
bestMatch = Pair(faceModel.id, bestCentroidSimilarity)
}
}
@@ -374,9 +391,49 @@ class FaceRecognitionRepository @Inject constructor(
onProgress = onProgress
)
// Generate age tags for children
if (person.isChild && person.dateOfBirth != null) {
generateAgeTagsForTraining(person, validImages)
}
person.id
}
/**
 * Generate age tags from training images for a child.
 *
 * For each valid training image, computes the child's age at capture time
 * (capture timestamp minus date of birth) and persists a PersonAgeTagEntity.
 * Best-effort: any failure is logged and never aborts training.
 */
private suspend fun generateAgeTagsForTraining(
person: PersonEntity,
validImages: List<TrainingSanityChecker.ValidTrainingImage>
) {
try {
val dob = person.dateOfBirth ?: return
val tags = validImages.mapNotNull { img ->
// Skip images we cannot resolve back to a DB row (no capture date available).
val imageEntity = imageDao.getImageByUri(img.uri.toString()) ?: return@mapNotNull null
val ageMs = imageEntity.capturedAt - dob
// 365.25 days/year accounts for leap years; truncation gives completed years.
val ageYears = (ageMs / (365.25 * 24 * 60 * 60 * 1000)).toInt()
// Drop implausible ages: photo predates DOB, or subject is past childhood.
if (ageYears < 0 || ageYears > 25) return@mapNotNull null
PersonAgeTagEntity.create(
personId = person.id,
personName = person.name,
imageId = imageEntity.imageId,
ageAtCapture = ageYears,
confidence = 1.0f // derived from user-provided training images, so treated as certain
)
}
if (tags.isNotEmpty()) {
personAgeTagDao.insertTags(tags)
Log.d(TAG, "Created ${tags.size} age tags for ${person.name}")
}
} catch (e: Exception) {
Log.e(TAG, "Failed to generate age tags", e)
}
}
/**
* Get face model by ID
*/

View File

@@ -61,14 +61,16 @@ abstract class RepositoryModule {
personDao: PersonDao,
imageDao: ImageDao,
faceModelDao: FaceModelDao,
photoFaceTagDao: PhotoFaceTagDao
photoFaceTagDao: PhotoFaceTagDao,
personAgeTagDao: PersonAgeTagDao
): FaceRecognitionRepository {
return FaceRecognitionRepository(
context = context,
personDao = personDao,
imageDao = imageDao,
faceModelDao = faceModelDao,
photoFaceTagDao = photoFaceTagDao
photoFaceTagDao = photoFaceTagDao,
personAgeTagDao = personAgeTagDao
)
}

View File

@@ -15,6 +15,7 @@ import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.ui.discover.DiscoverySettings
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
@@ -344,14 +345,9 @@ class FaceClusteringService @Inject constructor(
}
try {
// Crop and generate embedding
val faceBitmap = Bitmap.createBitmap(
bitmap,
mlFace.boundingBox.left.coerceIn(0, bitmap.width - 1),
mlFace.boundingBox.top.coerceIn(0, bitmap.height - 1),
mlFace.boundingBox.width().coerceAtMost(bitmap.width - mlFace.boundingBox.left),
mlFace.boundingBox.height().coerceAtMost(bitmap.height - mlFace.boundingBox.top)
)
// Crop and normalize face
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, mlFace)
?: return@forEach
val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()
@@ -591,13 +587,8 @@ class FaceClusteringService @Inject constructor(
if (!qualityCheck.isValid) return@mapNotNull null
try {
val faceBitmap = Bitmap.createBitmap(
bitmap,
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
?: return@mapNotNull null
val embedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()

View File

@@ -29,6 +29,64 @@ import kotlin.math.sqrt
*/
object FaceQualityFilter {
/**
 * Age group estimation for filtering (child vs adult detection)
 */
enum class AgeGroup { CHILD, ADULT, UNCERTAIN }
/**
 * Estimate whether a face belongs to a child or adult based on facial proportions.
 *
 * Uses two heuristics:
 * 1. Eye position ratio - Children have larger foreheads, so eyes are lower (~45% from top)
 * Adults have eyes at ~35% from top
 * 2. Face roundness (width/height ratio) - Children: ~0.85-1.0, Adults: ~0.7-0.85
 *
 * NOTE(review): imageWidth/imageHeight are currently unused — all ratios are
 * computed from the face bounding box alone. Confirm whether they were meant
 * to bounds-check the box, or drop them from the signature.
 *
 * @return AgeGroup.CHILD, AgeGroup.ADULT, or AgeGroup.UNCERTAIN
 */
fun estimateAgeGroup(face: Face, imageWidth: Int, imageHeight: Int): AgeGroup {
val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE)
val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE)
// Without both eye landmarks the heuristics cannot run.
if (leftEye == null || rightEye == null) {
return AgeGroup.UNCERTAIN
}
// Eye-to-face height ratio (where eyes sit relative to face top)
val faceHeight = face.boundingBox.height().toFloat()
val faceTop = face.boundingBox.top.toFloat()
val eyeY = (leftEye.position.y + rightEye.position.y) / 2
val eyePositionRatio = (eyeY - faceTop) / faceHeight
// Children: eyes at ~45% from top (larger forehead proportionally)
// Adults: eyes at ~35% from top
// Score: higher = more child-like
// Face roundness (width/height)
val faceWidth = face.boundingBox.width().toFloat()
val faceRatio = faceWidth / faceHeight
// Children: ratio ~0.85-1.0 (rounder faces)
// Adults: ratio ~0.7-0.85 (longer/narrower faces)
var childScore = 0
// Eye position scoring
if (eyePositionRatio > 0.45f) childScore += 2 // Strong child signal
else if (eyePositionRatio > 0.42f) childScore += 1 // Mild child signal
else if (eyePositionRatio < 0.35f) childScore -= 1 // Adult signal
// Face roundness scoring
if (faceRatio > 0.90f) childScore += 2 // Very round = child
else if (faceRatio > 0.82f) childScore += 1 // Somewhat round
else if (faceRatio < 0.75f) childScore -= 1 // Long face = adult
// CHILD requires strong combined evidence (>= 3); net-neutral or
// adult-leaning scores (<= 0) classify as ADULT; 1-2 is UNCERTAIN.
return when {
childScore >= 3 -> AgeGroup.CHILD
childScore <= 0 -> AgeGroup.ADULT
else -> AgeGroup.UNCERTAIN
}
}
/**
* Validate face for Discovery/Clustering
*

View File

@@ -75,7 +75,21 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
)
try {
val imagesToScan = imageDao.getImagesNeedingFaceDetection()
// Get images that need face detection (hasFaces IS NULL)
var imagesToScan = imageDao.getImagesNeedingFaceDetection()
// CRITICAL FIX: Also check for images marked as having faces but no FaceCacheEntity
if (imagesToScan.isEmpty()) {
val faceStats = faceCacheDao.getCacheStats()
if (faceStats.totalFaces == 0) {
// FaceCacheEntity is empty - rescan images that have faces
val imagesWithFaces = imageDao.getImagesWithFaces()
if (imagesWithFaces.isNotEmpty()) {
Log.w(TAG, "FaceCacheEntity empty but ${imagesWithFaces.size} images have faces - rescanning")
imagesToScan = imagesWithFaces
}
}
}
if (imagesToScan.isEmpty()) {
Log.d(TAG, "No images need scanning")
@@ -184,7 +198,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
imageUri = image.imageUri
)
// Create FaceCacheEntity entries for each face
// Create FaceCacheEntity entries for each face (NO embeddings - generated on demand)
val faceCacheEntries = faces.mapIndexed { index, face ->
createFaceCacheEntry(
imageId = image.imageId,
@@ -205,7 +219,8 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
/**
* Create FaceCacheEntity from ML Kit Face
*
* Uses FaceCacheEntity.create() which calculates quality metrics automatically
* Uses FaceCacheEntity.create() which calculates quality metrics automatically.
* Embeddings are NOT generated here - they're generated on-demand in Training/Discovery.
*/
private fun createFaceCacheEntry(
imageId: String,
@@ -225,7 +240,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
imageHeight = imageHeight,
confidence = 0.9f, // High confidence from accurate detector
isFrontal = isFrontal,
embedding = null // Will be generated later during Discovery
embedding = null // Generated on-demand in Training/Discovery
)
}
@@ -312,13 +327,27 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
val imageStats = imageDao.getFaceCacheStats()
val faceStats = faceCacheDao.getCacheStats()
// CRITICAL FIX: If ImageEntity says "scanned" but FaceCacheEntity is empty,
// we need to re-scan. This happens after DB migration clears face_cache table.
val imagesWithFaces = imageStats?.imagesWithFaces ?: 0
val facesCached = faceStats.totalFaces
// If we have images marked as having faces but no FaceCacheEntity entries,
// those images need re-scanning
val needsRescan = if (imagesWithFaces > 0 && facesCached == 0) {
Log.w(TAG, "⚠️ FaceCacheEntity is empty but $imagesWithFaces images marked as having faces - forcing rescan")
imagesWithFaces
} else {
imageStats?.needsScanning ?: 0
}
CacheStats(
totalImages = imageStats?.totalImages ?: 0,
imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0,
imagesWithFaces = imageStats?.imagesWithFaces ?: 0,
imagesWithFaces = imagesWithFaces,
imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0,
needsScanning = imageStats?.needsScanning ?: 0,
totalFacesCached = faceStats.totalFaces,
needsScanning = needsRescan,
totalFacesCached = facesCached,
facesWithEmbeddings = faceStats.withEmbeddings,
averageQuality = faceStats.avgQuality
)

View File

@@ -20,6 +20,7 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.navigation.NavController
import coil.compose.AsyncImage
import com.placeholder.sherpai2.data.local.entity.TagEntity
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.FaceTagInfo
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel
import net.engawapg.lib.zoomable.rememberZoomState
import net.engawapg.lib.zoomable.zoomable
@@ -51,8 +52,12 @@ fun ImageDetailScreen(
}
val tags by viewModel.tags.collectAsStateWithLifecycle()
val faceTags by viewModel.faceTags.collectAsStateWithLifecycle()
var showTags by remember { mutableStateOf(false) }
// Total tag count for badge
val totalTagCount = tags.size + faceTags.size
// Navigation state
val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1
val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0
@@ -84,27 +89,35 @@ fun ImageDetailScreen(
horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically
) {
if (tags.isNotEmpty()) {
if (totalTagCount > 0) {
Badge(
containerColor = if (showTags)
MaterialTheme.colorScheme.primary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.tertiary
else
MaterialTheme.colorScheme.surfaceVariant
) {
Text(
tags.size.toString(),
totalTagCount.toString(),
color = if (showTags)
MaterialTheme.colorScheme.onPrimary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.onTertiary
else
MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
Icon(
if (showTags) Icons.Default.Label else Icons.Default.LocalOffer,
if (faceTags.isNotEmpty()) Icons.Default.Face
else if (showTags) Icons.Default.Label
else Icons.Default.LocalOffer,
"Show Tags",
tint = if (showTags)
MaterialTheme.colorScheme.primary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.tertiary
else
MaterialTheme.colorScheme.onSurfaceVariant
)
@@ -189,6 +202,30 @@ fun ImageDetailScreen(
contentPadding = PaddingValues(16.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
// Face Tags Section (People in Photo)
if (faceTags.isNotEmpty()) {
item {
Text(
"People (${faceTags.size})",
style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.tertiary
)
}
items(faceTags, key = { it.tagId }) { faceTag ->
FaceTagCard(
faceTag = faceTag,
onRemove = { viewModel.removeFaceTag(faceTag) }
)
}
item {
Spacer(modifier = Modifier.height(8.dp))
}
}
// Regular Tags Section
item {
Text(
"Tags (${tags.size})",
@@ -197,7 +234,7 @@ fun ImageDetailScreen(
)
}
if (tags.isEmpty()) {
if (tags.isEmpty() && faceTags.isEmpty()) {
item {
Text(
"No tags yet",
@@ -205,6 +242,14 @@ fun ImageDetailScreen(
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
} else if (tags.isEmpty()) {
item {
Text(
"No other tags",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
items(tags, key = { it.tagId }) { tag ->
@@ -220,6 +265,83 @@ fun ImageDetailScreen(
}
}
/**
 * Card showing a person recognized in this photo via face recognition:
 * name, a confidence readout (color-coded), and a remove button.
 */
@Composable
private fun FaceTagCard(
    faceTag: FaceTagInfo,
    onRemove: () -> Unit
) {
    Card(
        modifier = Modifier.fillMaxWidth(),
        colors = CardDefaults.cardColors(
            containerColor = MaterialTheme.colorScheme.tertiaryContainer
        ),
        shape = RoundedCornerShape(8.dp)
    ) {
        Row(
            modifier = Modifier
                .fillMaxWidth()
                .padding(12.dp),
            horizontalArrangement = Arrangement.SpaceBetween,
            verticalAlignment = Alignment.CenterVertically
        ) {
            Column(modifier = Modifier.weight(1f)) {
                Row(
                    horizontalArrangement = Arrangement.spacedBy(8.dp),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Icon(
                        imageVector = Icons.Default.Face,
                        contentDescription = null,
                        modifier = Modifier.size(20.dp),
                        tint = MaterialTheme.colorScheme.tertiary
                    )
                    Text(
                        text = faceTag.personName,
                        style = MaterialTheme.typography.bodyLarge,
                        fontWeight = FontWeight.SemiBold
                    )
                }
                Row(
                    horizontalArrangement = Arrangement.spacedBy(4.dp),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Text(
                        text = "Face Recognition",
                        style = MaterialTheme.typography.labelSmall,
                        color = MaterialTheme.colorScheme.onSurfaceVariant
                    )
                    // FIX: this separator Text held an empty string, rendering an
                    // invisible zero-width element between the two labels.
                    Text(
                        text = "•",
                        style = MaterialTheme.typography.labelSmall,
                        color = MaterialTheme.colorScheme.onSurfaceVariant
                    )
                    Text(
                        text = "${(faceTag.confidence * 100).toInt()}% confidence",
                        style = MaterialTheme.typography.labelSmall,
                        // Color-code: >= 70% strong match, >= 50% borderline, else suspect.
                        color = if (faceTag.confidence >= 0.7f)
                            MaterialTheme.colorScheme.primary
                        else if (faceTag.confidence >= 0.5f)
                            MaterialTheme.colorScheme.secondary
                        else
                            MaterialTheme.colorScheme.error
                    )
                }
            }
            // Remove button
            IconButton(
                onClick = onRemove,
                colors = IconButtonDefaults.iconButtonColors(
                    contentColor = MaterialTheme.colorScheme.error
                )
            ) {
                Icon(Icons.Default.Delete, "Remove face tag")
            }
        }
    }
}
@Composable
private fun TagCard(
tag: TagEntity,

View File

@@ -2,6 +2,10 @@ package com.placeholder.sherpai2.ui.imagedetail.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.TagEntity
import com.placeholder.sherpai2.domain.repository.TaggingRepository
import dagger.hilt.android.lifecycle.HiltViewModel
@@ -10,17 +14,33 @@ import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch
import javax.inject.Inject
/**
 * Represents a person tagged in this photo via face recognition
 */
data class FaceTagInfo(
val personId: String, // id of the recognized PersonEntity
val personName: String, // display name resolved from the person record
val confidence: Float, // match confidence, treated as a fraction (×100 for display)
val faceModelId: String, // face model that produced the match
val tagId: String // PhotoFaceTagEntity id — used when removing the tag
)
/**
* ImageDetailViewModel
*
* Owns:
* - Image context
* - Tag write operations
* - Face tag display (people recognized in photo)
*/
@HiltViewModel
@OptIn(ExperimentalCoroutinesApi::class)
class ImageDetailViewModel @Inject constructor(
private val tagRepository: TaggingRepository
private val tagRepository: TaggingRepository,
private val imageDao: ImageDao,
private val photoFaceTagDao: PhotoFaceTagDao,
private val faceModelDao: FaceModelDao,
private val personDao: PersonDao
) : ViewModel() {
private val imageUri = MutableStateFlow<String?>(null)
@@ -37,8 +57,43 @@ class ImageDetailViewModel @Inject constructor(
initialValue = emptyList()
)
// Face tags (people recognized in this photo)
// Private mutable backing flow; UI observes the read-only StateFlow.
private val _faceTags = MutableStateFlow<List<FaceTagInfo>>(emptyList())
val faceTags: StateFlow<List<FaceTagInfo>> = _faceTags.asStateFlow()
// Sets the current image and kicks off loading of its face tags.
fun loadImage(uri: String) {
imageUri.value = uri
loadFaceTags(uri)
}
// Resolves face tags for an image: URI -> imageId -> tag rows -> (model, person).
// Tags whose face model or person no longer exists are silently dropped.
private fun loadFaceTags(uri: String) {
viewModelScope.launch {
try {
// Get imageId from URI
val image = imageDao.getImageByUri(uri) ?: return@launch
// Get face tags for this image
val faceTags = photoFaceTagDao.getTagsForImage(image.imageId)
// Resolve to person names
val faceTagInfos = faceTags.mapNotNull { tag ->
val faceModel = faceModelDao.getFaceModelById(tag.faceModelId) ?: return@mapNotNull null
val person = personDao.getPersonById(faceModel.personId) ?: return@mapNotNull null
FaceTagInfo(
personId = person.id,
personName = person.name,
confidence = tag.confidence,
faceModelId = tag.faceModelId,
tagId = tag.id
)
}
// Highest-confidence people first.
_faceTags.value = faceTagInfos.sortedByDescending { it.confidence }
} catch (e: Exception) {
// NOTE(review): broad catch silently maps any failure to "no face tags";
// consider logging the exception so DB errors are not invisible.
_faceTags.value = emptyList()
}
}
}
fun addTag(value: String) {
@@ -54,4 +109,15 @@ class ImageDetailViewModel @Inject constructor(
tagRepository.removeTagFromImage(uri, tag.value)
}
}
/**
 * Remove a face tag (person recognition)
 *
 * Deletes the tag row, then reloads the current image's face tags so the
 * UI reflects the removal.
 */
fun removeFaceTag(faceTagInfo: FaceTagInfo) {
viewModelScope.launch {
photoFaceTagDao.deleteTagById(faceTagInfo.tagId)
// Reload face tags
imageUri.value?.let { loadFaceTags(it) }
}
}
}

View File

@@ -95,6 +95,9 @@ fun PersonInventoryScreen(
},
onDelete = { personId ->
viewModel.deletePerson(personId)
},
onClearTags = { personId ->
viewModel.clearTagsForPerson(personId)
}
)
}
@@ -319,7 +322,8 @@ private fun PersonList(
persons: List<PersonWithModelInfo>,
onScan: (String) -> Unit,
onView: (String) -> Unit,
onDelete: (String) -> Unit
onDelete: (String) -> Unit,
onClearTags: (String) -> Unit
) {
LazyColumn(
contentPadding = PaddingValues(vertical = 8.dp)
@@ -332,7 +336,8 @@ private fun PersonList(
person = person,
onScan = { onScan(person.person.id) },
onView = { onView(person.person.id) },
onDelete = { onDelete(person.person.id) }
onDelete = { onDelete(person.person.id) },
onClearTags = { onClearTags(person.person.id) }
)
}
}
@@ -343,9 +348,34 @@ private fun PersonCard(
person: PersonWithModelInfo,
onScan: () -> Unit,
onView: () -> Unit,
onDelete: () -> Unit
onDelete: () -> Unit,
onClearTags: () -> Unit
) {
var showDeleteDialog by remember { mutableStateOf(false) }
var showClearDialog by remember { mutableStateOf(false) }
if (showClearDialog) {
AlertDialog(
onDismissRequest = { showClearDialog = false },
title = { Text("Clear tags for ${person.person.name}?") },
text = { Text("This will remove all ${person.taggedPhotoCount} photo tags but keep the face model. You can re-scan after clearing.") },
confirmButton = {
TextButton(
onClick = {
showClearDialog = false
onClearTags()
}
) {
Text("Clear Tags", color = MaterialTheme.colorScheme.error)
}
},
dismissButton = {
TextButton(onClick = { showClearDialog = false }) {
Text("Cancel")
}
}
)
}
if (showDeleteDialog) {
AlertDialog(
@@ -413,6 +443,17 @@ private fun PersonCard(
)
}
// Clear tags button (if has tags)
if (person.taggedPhotoCount > 0) {
IconButton(onClick = { showClearDialog = true }) {
Icon(
Icons.Default.ClearAll,
contentDescription = "Clear Tags",
tint = MaterialTheme.colorScheme.secondary
)
}
}
// Delete button
IconButton(onClick = { showDeleteDialog = true }) {
Icon(

View File

@@ -19,6 +19,7 @@ import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.ml.ThresholdStrategy
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers
@@ -105,6 +106,21 @@ class PersonInventoryViewModel @Inject constructor(
}
}
/**
 * Clear all face tags for a person while keeping the trained face model,
 * so the library can be re-scanned from scratch for this person.
 *
 * Runs on [Dispatchers.IO]; refreshes the person list when finished.
 *
 * @param personId ID of the person whose photo tags should be removed.
 */
fun clearTagsForPerson(personId: String) {
    viewModelScope.launch(Dispatchers.IO) {
        try {
            val faceModel = faceModelDao.getFaceModelByPersonId(personId)
            if (faceModel != null) {
                photoFaceTagDao.deleteTagsForFaceModel(faceModel.id)
            }
            loadPersons()
        } catch (e: Exception) {
            // BUG FIX: previously this was an empty catch that swallowed every
            // failure silently. Log it so DB errors are visible during debugging.
            android.util.Log.e("PersonInventoryVM", "Failed to clear tags for person $personId", e)
        }
    }
}
fun scanForPerson(personId: String) {
viewModelScope.launch(Dispatchers.IO) {
try {
@@ -127,16 +143,40 @@ class PersonInventoryViewModel @Inject constructor(
val detectorOptions = FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_NONE)
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // Needed for age estimation
.setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_NONE)
.setMinFaceSize(0.15f)
.build()
val detector = FaceDetection.getClient(detectorOptions)
val modelEmbedding = faceModel.getEmbeddingArray()
val faceNetModel = FaceNetModel(context)
// CRITICAL: Use ALL centroids for matching
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
val trainingCount = faceModel.trainingImageCount
val baseThreshold = ThresholdStrategy.getLiberalThreshold(trainingCount)
android.util.Log.e("PersonScan", "=== CENTROIDS: ${modelCentroids.size}, trainingCount: $trainingCount ===")
if (modelCentroids.isEmpty()) {
_scanningState.value = ScanningState.Error("No centroids found")
return@launch
}
val faceNetModel = FaceNetModel(context)
// Production threshold - STRICT to avoid false positives
// Solo face photos: 0.62, Group photos: 0.68
val baseThreshold = 0.62f
val groupPhotoThreshold = 0.68f // Higher bar for multi-face images
// Load ALL other models for "best match wins" comparison
val allModels = faceModelDao.getAllActiveFaceModels()
val otherModelCentroids = allModels
.filter { it.id != faceModel.id }
.map { model -> model.id to model.getCentroids().map { it.getEmbeddingArray() } }
// Distribution-based minimum threshold (self-calibrating)
val distributionMin = (faceModel.averageConfidence - 2 * faceModel.similarityStdDev)
.coerceAtLeast(faceModel.similarityMin - 0.05f)
.coerceAtLeast(0.50f) // Never go below 0.50 absolute floor
android.util.Log.d("PersonScan", "Using threshold: solo=$baseThreshold, group=$groupPhotoThreshold, distributionMin=$distributionMin (avgConf=${faceModel.averageConfidence}, stdDev=${faceModel.similarityStdDev}), centroids: ${modelCentroids.size}, competing models: ${otherModelCentroids.size}, isChild=${person.isChild}")
val completed = AtomicInteger(0)
val facesFound = AtomicInteger(0)
@@ -148,7 +188,7 @@ class PersonInventoryViewModel @Inject constructor(
val jobs = untaggedImages.map { image ->
async {
semaphore.withPermit {
processImage(image, detector, faceNetModel, modelEmbedding, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
processImage(image, detector, faceNetModel, modelCentroids, otherModelCentroids, trainingCount, baseThreshold, groupPhotoThreshold, distributionMin, person.isChild, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
}
}
}
@@ -175,7 +215,10 @@ class PersonInventoryViewModel @Inject constructor(
private suspend fun processImage(
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
modelEmbedding: FloatArray, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String,
modelCentroids: List<FloatArray>, otherModelCentroids: List<Pair<String, List<FloatArray>>>,
trainingCount: Int, baseThreshold: Float, groupPhotoThreshold: Float,
distributionMin: Float, isChildTarget: Boolean,
personId: String, faceModelId: String,
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
) {
@@ -200,9 +243,13 @@ class PersonInventoryViewModel @Inject constructor(
val scaleX = sizeOpts.outWidth.toFloat() / detectionBitmap.width
val scaleY = sizeOpts.outHeight.toFloat() / detectionBitmap.height
val imageQuality = ThresholdStrategy.estimateImageQuality(sizeOpts.outWidth, sizeOpts.outHeight)
val detectionContext = ThresholdStrategy.estimateDetectionContext(faces.size)
val threshold = ThresholdStrategy.getOptimalThreshold(trainingCount, imageQuality, detectionContext).coerceAtMost(baseThreshold)
// CRITICAL: Use higher threshold for group photos (more likely false positives)
val isGroupPhoto = faces.size > 1
val effectiveThreshold = if (isGroupPhoto) groupPhotoThreshold else baseThreshold
// Track best match in this image (only tag ONE face per image)
var bestMatchSimilarity = 0f
var foundMatch = false
for (face in faces) {
val scaledBounds = android.graphics.Rect(
@@ -212,22 +259,70 @@ class PersonInventoryViewModel @Inject constructor(
(face.boundingBox.bottom * scaleY).toInt()
)
val faceBitmap = loadFaceRegion(uri, scaledBounds) ?: continue
// Skip very small faces (less reliable)
val faceArea = scaledBounds.width() * scaledBounds.height()
val imageArea = sizeOpts.outWidth * sizeOpts.outHeight
val faceRatio = faceArea.toFloat() / imageArea
if (faceRatio < 0.02f) continue // Face must be at least 2% of image
// SIGNAL 2: Age plausibility check (if target is a child)
if (isChildTarget) {
val ageGroup = FaceQualityFilter.estimateAgeGroup(face, detectionBitmap.width, detectionBitmap.height)
if (ageGroup == FaceQualityFilter.AgeGroup.ADULT) {
continue // Reject clearly adult faces when searching for a child
}
}
// CRITICAL: Add padding to face crop (same as training)
val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
faceBitmap.recycle()
if (similarity >= threshold) {
batchUpdateMutex.withLock {
batchMatches.add(Triple(personId, image.imageId, similarity))
facesFound.incrementAndGet()
if (batchMatches.size >= BATCH_DB_SIZE) {
saveBatchMatches(batchMatches.toList(), faceModelId)
batchMatches.clear()
}
// Match against target person's centroids
val targetSimilarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
// SIGNAL 1: Distribution-based rejection
// If similarity is below (mean - 2*stdDev) or (min - 0.05), it's a statistical outlier
if (targetSimilarity < distributionMin) {
continue // Too far below training distribution
}
// SIGNAL 3: Basic threshold check
if (targetSimilarity < effectiveThreshold) {
continue
}
// SIGNAL 4: "Best match wins" - check if any OTHER model scores higher
// This prevents tagging siblings/similar people incorrectly
val bestOtherSimilarity = otherModelCentroids.maxOfOrNull { (_, centroids) ->
centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
} ?: 0f
val isTargetBestMatch = targetSimilarity > bestOtherSimilarity
// All signals must pass
if (isTargetBestMatch && targetSimilarity > bestMatchSimilarity) {
bestMatchSimilarity = targetSimilarity
foundMatch = true
}
}
// Only add ONE tag per image (the best match)
if (foundMatch) {
batchUpdateMutex.withLock {
batchMatches.add(Triple(personId, image.imageId, bestMatchSimilarity))
facesFound.incrementAndGet()
if (batchMatches.size >= BATCH_DB_SIZE) {
saveBatchMatches(batchMatches.toList(), faceModelId)
batchMatches.clear()
}
}
}
detectionBitmap.recycle()
} catch (e: Exception) {
} finally {
@@ -250,18 +345,32 @@ class PersonInventoryViewModel @Inject constructor(
} catch (e: Exception) { null }
}
private fun loadFaceRegion(uri: Uri, bounds: android.graphics.Rect): Bitmap? {
/**
* Load face region WITH 25% padding - CRITICAL for matching training conditions
*/
private fun loadFaceRegionWithPadding(uri: Uri, bounds: android.graphics.Rect, imgWidth: Int, imgHeight: Int): Bitmap? {
return try {
val full = context.contentResolver.openInputStream(uri)?.use {
BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 })
} ?: return null
val safeLeft = bounds.left.coerceIn(0, full.width - 1)
val safeTop = bounds.top.coerceIn(0, full.height - 1)
val safeWidth = bounds.width().coerceAtMost(full.width - safeLeft)
val safeHeight = bounds.height().coerceAtMost(full.height - safeTop)
// Add 25% padding (same as training)
val padding = (kotlin.math.max(bounds.width(), bounds.height()) * 0.25f).toInt()
val cropped = Bitmap.createBitmap(full, safeLeft, safeTop, safeWidth, safeHeight)
val left = (bounds.left - padding).coerceAtLeast(0)
val top = (bounds.top - padding).coerceAtLeast(0)
val right = (bounds.right + padding).coerceAtMost(full.width)
val bottom = (bounds.bottom + padding).coerceAtMost(full.height)
val width = right - left
val height = bottom - top
if (width <= 0 || height <= 0) {
full.recycle()
return null
}
val cropped = Bitmap.createBitmap(full, left, top, width, height)
full.recycle()
cropped
} catch (e: Exception) { null }

View File

@@ -339,10 +339,7 @@ fun AppNavHost(
* SETTINGS SCREEN
*/
composable(AppRoutes.SETTINGS) {
DummyScreen(
title = "Settings",
subtitle = "App preferences and configuration"
)
com.placeholder.sherpai2.ui.settings.SettingsScreen()
}
}
}

View File

@@ -78,6 +78,7 @@ fun MainScreen(
AppRoutes.DISCOVER -> "Discover People" // ✅ SHOWS NOW!
AppRoutes.INVENTORY -> "People"
AppRoutes.TRAIN -> "Train Model"
AppRoutes.ScanResultsScreen -> "Train New Person"
AppRoutes.TAGS -> "Tags"
AppRoutes.UTILITIES -> "Utilities"
AppRoutes.SETTINGS -> "Settings"

View File

@@ -2,7 +2,9 @@ package com.placeholder.sherpai2.ui.rollingscan
import android.net.Uri
import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.clickable
import androidx.compose.foundation.combinedClickable
import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.GridItemSpan
@@ -37,7 +39,7 @@ import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
* - Quick action buttons (Select Top N)
* - Submit button with validation
*/
@OptIn(ExperimentalMaterial3Api::class)
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@Composable
fun RollingScanScreen(
seedImageIds: List<String>,
@@ -48,6 +50,7 @@ fun RollingScanScreen(
) {
val uiState by viewModel.uiState.collectAsState()
val selectedImageIds by viewModel.selectedImageIds.collectAsState()
val negativeImageIds by viewModel.negativeImageIds.collectAsState()
val rankedPhotos by viewModel.rankedPhotos.collectAsState()
val isScanning by viewModel.isScanning.collectAsState()
@@ -70,6 +73,7 @@ fun RollingScanScreen(
isReadyForTraining = viewModel.isReadyForTraining(),
validationMessage = viewModel.getValidationMessage(),
onSelectTopN = { count -> viewModel.selectTopN(count) },
onSelectAboveThreshold = { threshold -> viewModel.selectAllAboveThreshold(threshold) },
onSubmit = {
val uris = viewModel.getSelectedImageUris()
onSubmitForTraining(uris)
@@ -93,8 +97,10 @@ fun RollingScanScreen(
RollingScanPhotoGrid(
rankedPhotos = rankedPhotos,
selectedImageIds = selectedImageIds,
negativeImageIds = negativeImageIds,
isScanning = isScanning,
onToggleSelection = { imageId -> viewModel.toggleSelection(imageId) },
onToggleNegative = { imageId -> viewModel.toggleNegative(imageId) },
modifier = Modifier.padding(padding)
)
}
@@ -159,19 +165,26 @@ private fun RollingScanTopBar(
}
// ═══════════════════════════════════════════════════════════
// PHOTO GRID
// PHOTO GRID - Similarity-based bucketing
// ═══════════════════════════════════════════════════════════
@OptIn(ExperimentalFoundationApi::class)
@Composable
private fun RollingScanPhotoGrid(
rankedPhotos: List<FaceSimilarityScorer.ScoredPhoto>,
selectedImageIds: Set<String>,
negativeImageIds: Set<String>,
isScanning: Boolean,
onToggleSelection: (String) -> Unit,
onToggleNegative: (String) -> Unit,
modifier: Modifier = Modifier
) {
Column(modifier = modifier.fillMaxSize()) {
// Bucket by similarity score
val veryLikely = rankedPhotos.filter { it.finalScore >= 0.60f }
val probably = rankedPhotos.filter { it.finalScore in 0.45f..0.599f }
val maybe = rankedPhotos.filter { it.finalScore < 0.45f }
Column(modifier = modifier.fillMaxSize()) {
// Scanning indicator
if (isScanning) {
LinearProgressIndicator(
@@ -180,69 +193,78 @@ private fun RollingScanPhotoGrid(
)
}
// Hint for negative marking
Text(
text = "Tap to select • Long-press to mark as NOT this person",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(horizontal = 12.dp, vertical = 4.dp)
)
LazyVerticalGrid(
columns = GridCells.Fixed(3),
contentPadding = PaddingValues(8.dp),
horizontalArrangement = Arrangement.spacedBy(8.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
// Section: Most Similar (top 10)
val topMatches = rankedPhotos.take(10)
if (topMatches.isNotEmpty()) {
// Section: Very Likely (>60%)
if (veryLikely.isNotEmpty()) {
item(span = { GridItemSpan(3) }) {
SectionHeader(
icon = Icons.Default.Whatshot,
text = "🔥 Most Similar (${topMatches.size})",
color = MaterialTheme.colorScheme.primary
text = "🟢 Very Likely (${veryLikely.size})",
color = Color(0xFF4CAF50)
)
}
items(topMatches, key = { it.imageId }) { photo ->
items(veryLikely, key = { it.imageId }) { photo ->
PhotoCard(
photo = photo,
isSelected = photo.imageId in selectedImageIds,
isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) },
showSimilarityBadge = true
)
}
}
// Section: Good Matches (11-30)
val goodMatches = rankedPhotos.drop(10).take(20)
if (goodMatches.isNotEmpty()) {
// Section: Probably (45-60%)
if (probably.isNotEmpty()) {
item(span = { GridItemSpan(3) }) {
SectionHeader(
icon = Icons.Default.CheckCircle,
text = "📊 Good Matches (${goodMatches.size})",
color = MaterialTheme.colorScheme.tertiary
text = "🟡 Probably (${probably.size})",
color = Color(0xFFFFC107)
)
}
items(goodMatches, key = { it.imageId }) { photo ->
items(probably, key = { it.imageId }) { photo ->
PhotoCard(
photo = photo,
isSelected = photo.imageId in selectedImageIds,
onToggle = { onToggleSelection(photo.imageId) }
isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) },
showSimilarityBadge = true
)
}
}
// Section: Other Photos
val otherPhotos = rankedPhotos.drop(30)
if (otherPhotos.isNotEmpty()) {
// Section: Maybe (<45%)
if (maybe.isNotEmpty()) {
item(span = { GridItemSpan(3) }) {
SectionHeader(
icon = Icons.Default.Photo,
text = "📷 Other Photos (${otherPhotos.size})",
color = MaterialTheme.colorScheme.onSurfaceVariant
text = "🟠 Maybe (${maybe.size})",
color = Color(0xFFFF9800)
)
}
items(otherPhotos, key = { it.imageId }) { photo ->
items(maybe, key = { it.imageId }) { photo ->
PhotoCard(
photo = photo,
isSelected = photo.imageId in selectedImageIds,
onToggle = { onToggleSelection(photo.imageId) }
isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) }
)
}
}
@@ -258,24 +280,34 @@ private fun RollingScanPhotoGrid(
}
// ═══════════════════════════════════════════════════════════
// PHOTO CARD
// PHOTO CARD - with long-press for negative marking
// ═══════════════════════════════════════════════════════════
@OptIn(ExperimentalFoundationApi::class)
@Composable
private fun PhotoCard(
photo: FaceSimilarityScorer.ScoredPhoto,
isSelected: Boolean,
isNegative: Boolean = false,
onToggle: () -> Unit,
onLongPress: () -> Unit = {},
showSimilarityBadge: Boolean = false
) {
val borderColor = when {
isNegative -> Color(0xFFE53935) // Red for negative
isSelected -> MaterialTheme.colorScheme.primary
else -> MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
}
val borderWidth = if (isSelected || isNegative) 3.dp else 1.dp
Card(
modifier = Modifier
.aspectRatio(1f)
.clickable(onClick = onToggle),
border = if (isSelected)
BorderStroke(3.dp, MaterialTheme.colorScheme.primary)
else
BorderStroke(1.dp, MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)),
.combinedClickable(
onClick = onToggle,
onLongClick = onLongPress
),
border = BorderStroke(borderWidth, borderColor),
elevation = CardDefaults.cardElevation(
defaultElevation = if (isSelected) 4.dp else 1.dp
)
@@ -289,22 +321,47 @@ private fun PhotoCard(
contentScale = ContentScale.Crop
)
// Similarity badge (top-left) - Only for top matches
if (showSimilarityBadge) {
// Dim overlay for negatives
if (isNegative) {
Box(
modifier = Modifier
.fillMaxSize()
.padding(0.dp),
contentAlignment = Alignment.Center
) {
Surface(
modifier = Modifier.fillMaxSize(),
color = Color.Black.copy(alpha = 0.5f)
) {}
Icon(
Icons.Default.Close,
contentDescription = "Not this person",
tint = Color.White,
modifier = Modifier.size(32.dp)
)
}
}
// Similarity badge (top-left)
if (showSimilarityBadge && !isNegative) {
Surface(
modifier = Modifier
.align(Alignment.TopStart)
.padding(6.dp),
shape = RoundedCornerShape(8.dp),
color = MaterialTheme.colorScheme.primary,
color = when {
photo.finalScore >= 0.60f -> Color(0xFF4CAF50)
photo.finalScore >= 0.45f -> Color(0xFFFFC107)
else -> Color(0xFFFF9800)
},
shadowElevation = 4.dp
) {
Text(
text = "${(photo.similarityScore * 100).toInt()}%",
text = "${(photo.finalScore * 100).toInt()}%",
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
style = MaterialTheme.typography.labelSmall,
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.onPrimary
color = Color.White
)
}
}
@@ -332,7 +389,7 @@ private fun PhotoCard(
}
// Face count badge (bottom-right)
if (photo.faceCount > 1) {
if (photo.faceCount > 1 && !isNegative) {
Surface(
modifier = Modifier
.align(Alignment.BottomEnd)
@@ -395,6 +452,7 @@ private fun RollingScanBottomBar(
isReadyForTraining: Boolean,
validationMessage: String?,
onSelectTopN: (Int) -> Unit,
onSelectAboveThreshold: (Float) -> Unit,
onSubmit: () -> Unit
) {
Surface(
@@ -416,39 +474,49 @@ private fun RollingScanBottomBar(
)
}
// First row: threshold selection
Row(
modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.spacedBy(8.dp)
horizontalArrangement = Arrangement.spacedBy(6.dp)
) {
// Quick select buttons
OutlinedButton(
onClick = { onSelectTopN(10) },
modifier = Modifier.weight(1f)
onClick = { onSelectAboveThreshold(0.60f) },
modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) {
Text("Top 10")
Text(">60%", style = MaterialTheme.typography.labelSmall)
}
OutlinedButton(
onClick = { onSelectTopN(20) },
modifier = Modifier.weight(1f)
onClick = { onSelectAboveThreshold(0.50f) },
modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) {
Text("Top 20")
Text(">50%", style = MaterialTheme.typography.labelSmall)
}
OutlinedButton(
onClick = { onSelectTopN(15) },
modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) {
Text("Top 15", style = MaterialTheme.typography.labelSmall)
}
}
// Submit button
Button(
onClick = onSubmit,
enabled = isReadyForTraining,
modifier = Modifier.weight(1.5f)
) {
Icon(
Icons.Default.Done,
contentDescription = null,
modifier = Modifier.size(18.dp)
)
Spacer(Modifier.width(8.dp))
Text("Train ($selectedCount)")
}
Spacer(Modifier.height(8.dp))
// Second row: submit
Button(
onClick = onSubmit,
enabled = isReadyForTraining,
modifier = Modifier.fillMaxWidth()
) {
Icon(
Icons.Default.Done,
contentDescription = null,
modifier = Modifier.size(18.dp)
)
Spacer(Modifier.width(8.dp))
Text("Train Model ($selectedCount photos)")
}
}
}

View File

@@ -44,6 +44,11 @@ class RollingScanViewModel @Inject constructor(
private const val TAG = "RollingScanVM"
private const val DEBOUNCE_DELAY_MS = 300L
private const val MIN_PHOTOS_FOR_TRAINING = 15
// Progressive thresholds based on selection count
private const val FLOOR_FEW_SEEDS = 0.30f // 1-3 seeds
private const val FLOOR_MEDIUM_SEEDS = 0.40f // 4-10 seeds
private const val FLOOR_MANY_SEEDS = 0.50f // 10+ seeds
}
// ═══════════════════════════════════════════════════════════
@@ -71,6 +76,11 @@ class RollingScanViewModel @Inject constructor(
// Cache of selected embeddings
private val selectedEmbeddings = mutableListOf<FloatArray>()
// Negative embeddings (marked as "not this person")
private val _negativeImageIds = MutableStateFlow<Set<String>>(emptySet())
val negativeImageIds: StateFlow<Set<String>> = _negativeImageIds.asStateFlow()
private val negativeEmbeddings = mutableListOf<FloatArray>()
// All available image IDs
private var allImageIds: List<String> = emptyList()
@@ -156,24 +166,55 @@ class RollingScanViewModel @Inject constructor(
current.remove(imageId)
viewModelScope.launch {
// Remove embedding from cache
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { selectedEmbeddings.remove(it) }
}
} else {
// Select
// Select (and remove from negatives if present)
current.add(imageId)
if (imageId in _negativeImageIds.value) {
toggleNegative(imageId)
}
viewModelScope.launch {
// Add embedding to cache
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { selectedEmbeddings.add(it) }
}
}
_selectedImageIds.value = current
_selectedImageIds.value = current.toSet() // Immutable copy
scanDebouncer.debounce {
triggerRollingScan()
}
}
/**
* Toggle negative marking ("Not this person")
*/
fun toggleNegative(imageId: String) {
val current = _negativeImageIds.value.toMutableSet()
if (imageId in current) {
current.remove(imageId)
viewModelScope.launch {
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { negativeEmbeddings.remove(it) }
}
} else {
current.add(imageId)
// Remove from selected if present
if (imageId in _selectedImageIds.value) {
toggleSelection(imageId)
}
viewModelScope.launch {
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { negativeEmbeddings.add(it) }
}
}
_negativeImageIds.value = current.toSet() // Immutable copy
// Debounced rescan
scanDebouncer.debounce {
triggerRollingScan()
}
@@ -190,13 +231,33 @@ class RollingScanViewModel @Inject constructor(
val current = _selectedImageIds.value.toMutableSet()
current.addAll(topPhotos)
_selectedImageIds.value = current
_selectedImageIds.value = current.toSet() // Immutable copy
viewModelScope.launch {
// Add embeddings
val embeddings = faceCacheDao.getEmbeddingsForImages(topPhotos.toList())
selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() })
triggerRollingScan()
}
}
/**
 * Select every ranked photo whose final score is at or above [threshold].
 *
 * Newly selected photos get their cached embeddings loaded into
 * [selectedEmbeddings], then a rolling rescan is triggered.
 *
 * @param threshold minimum `finalScore` (0..1) a photo must have to be selected.
 */
fun selectAllAboveThreshold(threshold: Float) {
    val photosAbove = _rankedPhotos.value
        .filter { it.finalScore >= threshold }
        .map { it.imageId }
    // BUG FIX: compute the newly-added IDs BEFORE mutating the selection state.
    // Previously newIds was derived from _selectedImageIds.value AFTER the
    // update (which already contained photosAbove), so newIds was always empty
    // and the embeddings for the new selections were never loaded.
    val previouslySelected = _selectedImageIds.value
    val newIds = photosAbove.filter { it !in previouslySelected }
    _selectedImageIds.value = previouslySelected + photosAbove // immutable copy
    viewModelScope.launch {
        if (newIds.isNotEmpty()) {
            val embeddings = faceCacheDao.getEmbeddingsForImages(newIds)
            selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() })
        }
        triggerRollingScan()
    }
}
@@ -207,17 +268,24 @@ class RollingScanViewModel @Inject constructor(
/**
 * Deselect all photos, drop their cached embeddings, and reset the ranked list.
 */
fun clearSelection() {
    selectedEmbeddings.clear()
    _selectedImageIds.value = emptySet()
    _rankedPhotos.value = emptyList() // reset ranking
}
/**
 * Remove all "not this person" markings and their cached embeddings,
 * then rescan with the remaining positive seeds.
 */
fun clearNegatives() {
    negativeEmbeddings.clear()
    _negativeImageIds.value = emptySet()
    // Debounced so rapid toggles don't trigger redundant scans.
    scanDebouncer.debounce { triggerRollingScan() }
}
// ═══════════════════════════════════════════════════════════
// ROLLING SCAN LOGIC
// ═══════════════════════════════════════════════════════════
/**
* CORE: Trigger rolling similarity scan
* CORE: Trigger rolling similarity scan with progressive filtering
*/
private suspend fun triggerRollingScan() {
if (selectedEmbeddings.isEmpty()) {
@@ -228,7 +296,15 @@ class RollingScanViewModel @Inject constructor(
try {
_isScanning.value = true
Log.d(TAG, "Starting scan with ${selectedEmbeddings.size} selected embeddings")
val selectionCount = selectedEmbeddings.size
Log.d(TAG, "Starting scan with $selectionCount selected, ${negativeEmbeddings.size} negative")
// Progressive threshold based on selection count
val similarityFloor = when {
selectionCount <= 3 -> FLOOR_FEW_SEEDS
selectionCount <= 10 -> FLOOR_MEDIUM_SEEDS
else -> FLOOR_MANY_SEEDS
}
// Calculate centroid from selected embeddings
val centroid = faceSimilarityScorer.calculateCentroid(selectedEmbeddings)
@@ -240,17 +316,38 @@ class RollingScanViewModel @Inject constructor(
centroid = centroid
)
// Update image URIs in scored photos
val photosWithUris = scoredPhotos.map { photo ->
photo.copy(
imageUri = imageUriCache[photo.imageId] ?: photo.imageId
)
}
// Apply negative penalty, quality boost, and floor filter
val filteredPhotos = scoredPhotos
.map { photo ->
// Calculate max similarity to any negative embedding
val negativePenalty = if (negativeEmbeddings.isNotEmpty()) {
negativeEmbeddings.maxOfOrNull { neg ->
cosineSimilarity(photo.cachedEmbedding, neg)
} ?: 0f
} else 0f
Log.d(TAG, "Scan complete. Scored ${photosWithUris.size} photos")
// Quality multiplier: solo face, large face, good quality
val qualityMultiplier = 1f +
(if (photo.faceCount == 1) 0.15f else 0f) +
(if (photo.faceAreaRatio > 0.15f) 0.10f else 0f) +
(if (photo.qualityScore > 0.7f) 0.10f else 0f)
// Update ranked list
_rankedPhotos.value = photosWithUris
// Final score = (similarity - negativePenalty) * qualityMultiplier
val adjustedScore = ((photo.similarityScore - negativePenalty * 0.5f) * qualityMultiplier)
.coerceIn(0f, 1f)
photo.copy(
imageUri = imageUriCache[photo.imageId] ?: photo.imageId,
finalScore = adjustedScore
)
}
.filter { it.finalScore >= similarityFloor } // Apply floor
.filter { it.imageId !in _negativeImageIds.value } // Hide negatives
.sortedByDescending { it.finalScore }
Log.d(TAG, "Scan complete. ${filteredPhotos.size} photos above floor $similarityFloor")
_rankedPhotos.value = filteredPhotos
} catch (e: Exception) {
Log.e(TAG, "Scan failed", e)
@@ -259,6 +356,19 @@ class RollingScanViewModel @Inject constructor(
}
}
/**
 * Cosine similarity between two embedding vectors.
 *
 * @return similarity in [-1, 1]; 0 when the sizes differ or either vector
 *         has zero magnitude (avoids division by zero).
 */
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
    if (a.size != b.size) return 0f
    var dotProduct = 0f
    var magASquared = 0f
    var magBSquared = 0f
    a.indices.forEach { i ->
        dotProduct += a[i] * b[i]
        magASquared += a[i] * a[i]
        magBSquared += b[i] * b[i]
    }
    if (magASquared <= 0f || magBSquared <= 0f) return 0f
    return dotProduct / (kotlin.math.sqrt(magASquared) * kotlin.math.sqrt(magBSquared))
}
// ═══════════════════════════════════════════════════════════
// SUBMISSION
// ═══════════════════════════════════════════════════════════
@@ -299,9 +409,11 @@ class RollingScanViewModel @Inject constructor(
fun reset() {
_uiState.value = RollingScanState.Idle
_selectedImageIds.value = emptySet()
_negativeImageIds.value = emptySet()
_rankedPhotos.value = emptyList()
_isScanning.value = false
selectedEmbeddings.clear()
negativeEmbeddings.clear()
allImageIds = emptyList()
imageUriCache = emptyMap()
scanDebouncer.cancel()

View File

@@ -6,8 +6,11 @@ import android.graphics.BitmapFactory
import android.graphics.Rect
import android.net.Uri
import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.Face
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
@@ -64,21 +67,30 @@ class FaceDetectionHelper(private val context: Context) {
val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await()
// Sort by face size (area) to get the largest face
val sortedFaces = faces.sortedByDescending { face ->
// Filter to quality faces - use lenient scanning filter
// (Discovery filter was too strict, rejecting faces from rolling scan)
val qualityFaces = faces.filter { face ->
FaceQualityFilter.validateForScanning(
face = face,
imageWidth = bitmap.width,
imageHeight = bitmap.height
)
}
// Sort by face size (area) to get the largest quality face
val sortedFaces = qualityFaces.sortedByDescending { face ->
face.boundingBox.width() * face.boundingBox.height()
}
val croppedFace = if (sortedFaces.isNotEmpty()) {
// Crop the LARGEST detected face (most likely the subject)
cropFaceFromBitmap(bitmap, sortedFaces[0].boundingBox)
FaceNormalizer.cropAndNormalize(bitmap, sortedFaces[0])
} else null
FaceDetectionResult(
uri = uri,
hasFace = faces.isNotEmpty(),
faceCount = faces.size,
faceBounds = faces.map { it.boundingBox },
hasFace = qualityFaces.isNotEmpty(),
faceCount = qualityFaces.size,
faceBounds = qualityFaces.map { it.boundingBox },
croppedFaceBitmap = croppedFace
)
} catch (e: Exception) {

View File

@@ -51,57 +51,41 @@ fun ScanResultsScreen(
}
}
Scaffold(
topBar = {
TopAppBar(
title = { Text("Train New Person") },
colors = TopAppBarDefaults.topAppBarColors(
containerColor = MaterialTheme.colorScheme.primaryContainer
// No Scaffold - MainScreen provides TopAppBar
Box(modifier = Modifier.fillMaxSize()) {
when (state) {
is ScanningState.Idle -> {}
is ScanningState.Processing -> {
ProcessingView(progress = state.progress, total = state.total)
}
is ScanningState.Success -> {
ImprovedResultsView(
result = state.sanityCheckResult,
onContinue = {
trainViewModel.createFaceModel(
trainViewModel.getPersonInfo()?.name ?: "Unknown"
)
},
onRetry = onFinish,
onReplaceImage = { oldUri, newUri ->
trainViewModel.replaceImage(oldUri, newUri)
},
onSelectFaceFromMultiple = { result ->
showFacePickerDialog = result
},
trainViewModel = trainViewModel
)
)
}
is ScanningState.Error -> {
ErrorView(message = state.message, onRetry = onFinish)
}
}
) { paddingValues ->
Box(
modifier = Modifier
.fillMaxSize()
.padding(paddingValues)
) {
when (state) {
is ScanningState.Idle -> {}
is ScanningState.Processing -> {
ProcessingView(progress = state.progress, total = state.total)
}
is ScanningState.Success -> {
ImprovedResultsView(
result = state.sanityCheckResult,
onContinue = {
// PersonInfo already captured in TrainingScreen!
// Just start training with stored info
trainViewModel.createFaceModel(
trainViewModel.getPersonInfo()?.name ?: "Unknown"
)
},
onRetry = onFinish,
onReplaceImage = { oldUri, newUri ->
trainViewModel.replaceImage(oldUri, newUri)
},
onSelectFaceFromMultiple = { result ->
showFacePickerDialog = result
},
trainViewModel = trainViewModel
)
}
is ScanningState.Error -> {
ErrorView(message = state.message, onRetry = onFinish)
}
}
if (trainingState is TrainingState.Processing) {
TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
}
if (trainingState is TrainingState.Processing) {
TrainingOverlay(trainingState = trainingState as TrainingState.Processing)
}
}

View File

@@ -5,11 +5,18 @@ import android.graphics.Bitmap
import android.net.Uri
import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import androidx.work.WorkManager
import android.content.Context
import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.repository.FaceRecognitionRepository
import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.workers.LibraryScanWorker
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
@@ -48,15 +55,20 @@ data class PersonInfo(
/**
* FIXED TrainViewModel with proper exclude functionality and efficient replace
*/
private val android.content.Context.dataStore by preferencesDataStore(name = "settings")
private val KEY_BACKGROUND_TAGGING = booleanPreferencesKey("background_recognition_tagging")
@HiltViewModel
class TrainViewModel @Inject constructor(
application: Application,
private val faceRecognitionRepository: FaceRecognitionRepository,
private val faceNetModel: FaceNetModel
private val faceNetModel: FaceNetModel,
private val workManager: WorkManager
) : AndroidViewModel(application) {
private val sanityChecker = TrainingSanityChecker(application)
private val faceDetectionHelper = FaceDetectionHelper(application)
private val dataStore = application.dataStore
private val _uiState = MutableStateFlow<ScanningState>(ScanningState.Idle)
val uiState: StateFlow<ScanningState> = _uiState.asStateFlow()
@@ -174,6 +186,20 @@ class TrainViewModel @Inject constructor(
relationship = person.relationship
)
// Trigger library scan if setting enabled
val backgroundTaggingEnabled = dataStore.data
.map { it[KEY_BACKGROUND_TAGGING] ?: true }
.first()
if (backgroundTaggingEnabled) {
// Use default threshold (0.62 solo, 0.68 group)
val scanRequest = LibraryScanWorker.createWorkRequest(
personId = personId,
personName = personName
)
workManager.enqueue(scanRequest)
}
} catch (e: Exception) {
_trainingState.value = TrainingState.Error(
e.message ?: "Failed to create face model"
@@ -355,7 +381,7 @@ class TrainViewModel @Inject constructor(
faceDetectionResults = updatedFaceResults,
validationErrors = updatedErrors,
validImagesWithFaces = updatedValidImages,
excludedImages = excludedImages
excludedImages = excludedImages.toSet() // Immutable copy for Compose state detection
)
}

View File

@@ -49,6 +49,7 @@ fun TrainingPhotoSelectorScreen(
val isRanking by viewModel.isRanking.collectAsStateWithLifecycle()
val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle()
val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle()
val embeddingProgress by viewModel.embeddingProgress.collectAsStateWithLifecycle()
Scaffold(
topBar = {
@@ -155,7 +156,33 @@ fun TrainingPhotoSelectorScreen(
modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.Center
) {
CircularProgressIndicator()
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
CircularProgressIndicator()
// Capture value to avoid race condition
val progress = embeddingProgress
if (progress != null) {
Text(
"Preparing faces: ${progress.current}/${progress.total}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
LinearProgressIndicator(
progress = { progress.current.toFloat() / progress.total },
modifier = Modifier
.width(200.dp)
.padding(top = 8.dp)
)
} else {
Text(
"Loading premium faces...",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
}
photos.isEmpty() -> {

View File

@@ -1,20 +1,31 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.app.Application
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Rect
import android.net.Uri
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import javax.inject.Inject
import kotlin.math.max
import kotlin.math.min
/**
* TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN
@@ -27,15 +38,18 @@ import javax.inject.Inject
*/
@HiltViewModel
class TrainingPhotoSelectorViewModel @Inject constructor(
application: Application,
private val imageDao: ImageDao,
private val faceCacheDao: FaceCacheDao,
private val faceSimilarityScorer: FaceSimilarityScorer
) : ViewModel() {
private val faceSimilarityScorer: FaceSimilarityScorer,
private val faceNetModel: FaceNetModel
) : AndroidViewModel(application) {
companion object {
private const val TAG = "PremiumSelector"
private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1
private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5
private const val MAX_EMBEDDINGS_TO_GENERATE = 500
}
// All photos (for fallback / full list)
@@ -56,6 +70,12 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
private val _isRanking = MutableStateFlow(false)
val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow()
// Embedding generation progress
private val _embeddingProgress = MutableStateFlow<EmbeddingProgress?>(null)
val embeddingProgress: StateFlow<EmbeddingProgress?> = _embeddingProgress.asStateFlow()
data class EmbeddingProgress(val current: Int, val total: Int)
// Premium mode toggle
private val _showPremiumOnly = MutableStateFlow(true)
val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow()
@@ -79,20 +99,47 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
/**
* Load PREMIUM faces first (solo, large, frontal, high quality)
* If no embeddings exist, generate them on-demand for premium candidates
*/
private fun loadPremiumFaces() {
viewModelScope.launch {
try {
_isLoading.value = true
// Get premium faces from cache
val premiumFaceCache = faceCacheDao.getPremiumFaces(
// First check if premium faces with embeddings exist
var premiumFaceCache = faceCacheDao.getPremiumFaces(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = 500
)
Log.d(TAG, " Found ${premiumFaceCache.size} premium faces")
Log.d(TAG, "📊 Found ${premiumFaceCache.size} premium faces with embeddings")
// If no premium faces with embeddings, generate them on-demand
if (premiumFaceCache.isEmpty()) {
Log.d(TAG, "⚠️ No premium faces with embeddings - generating on-demand")
val candidates = faceCacheDao.getPremiumFaceCandidatesNeedingEmbeddings(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = MAX_EMBEDDINGS_TO_GENERATE
)
Log.d(TAG, "📦 Found ${candidates.size} premium candidates needing embeddings")
if (candidates.isNotEmpty()) {
generateEmbeddingsForCandidates(candidates)
// Re-query after generating
premiumFaceCache = faceCacheDao.getPremiumFaces(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = 500
)
Log.d(TAG, "✅ After generation: ${premiumFaceCache.size} premium faces")
}
}
_premiumCount.value = premiumFaceCache.size
// Get corresponding ImageEntities
@@ -117,10 +164,108 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
loadAllFaces()
} finally {
_isLoading.value = false
_embeddingProgress.value = null
}
}
}
/**
 * Generate FaceNet embeddings for premium face candidates that lack them,
 * persisting each successful embedding to the face cache and publishing
 * progress via [_embeddingProgress].
 *
 * All decoding/inference runs on [Dispatchers.IO]. An all-zero embedding is
 * treated as a failed inference and is not persisted.
 */
private suspend fun generateEmbeddingsForCandidates(candidates: List<FaceCacheEntity>) {
    val context = getApplication<Application>()
    val total = candidates.size
    var succeeded = 0

    withContext(Dispatchers.IO) {
        // Resolve imageId -> URI once up front instead of querying per candidate.
        val imageIds = candidates.map { it.imageId }.distinct()
        val images = imageDao.getImagesByIds(imageIds)
        val imageUriMap = images.associate { it.imageId to it.imageUri }

        candidates.forEachIndexed { index, candidate ->
            try {
                val imageUri = imageUriMap[candidate.imageId]
                if (imageUri != null) {
                    val bitmap = loadBitmapOptimized(context, Uri.parse(imageUri))
                    if (bitmap != null) {
                        val croppedFace = cropFaceWithPadding(bitmap, candidate.getBoundingBox())
                        bitmap.recycle()
                        if (croppedFace != null) {
                            val embedding = faceNetModel.generateEmbedding(croppedFace)
                            croppedFace.recycle()
                            // An all-zero vector indicates inference failure; skip persisting it.
                            if (embedding.any { it != 0f }) {
                                val embeddingJson = FaceCacheEntity.embeddingToJson(embedding)
                                faceCacheDao.updateEmbedding(candidate.imageId, candidate.faceIndex, embeddingJson)
                                succeeded++
                            }
                        }
                    }
                }
            } catch (e: Exception) {
                Log.w(TAG, "Failed to generate embedding for ${candidate.imageId}: ${e.message}")
            }
            // BUG FIX: the original incremented its progress counter only on the
            // non-`continue` path, so skipped candidates (missing URI, decode or
            // crop failure) left the progress bar stalled short of `total`.
            // Advance for every candidate, regardless of outcome.
            // MutableStateFlow.value is thread-safe, so no Main-dispatcher hop is
            // needed per update.
            _embeddingProgress.value = EmbeddingProgress(index + 1, total)
        }
    }

    Log.d(TAG, "✅ Generated embeddings for $succeeded/$total candidates")
}
/**
 * Decode a bitmap from [uri], downsampled so neither dimension exceeds [maxDim].
 *
 * Uses the standard two-pass decode: first read only the bounds, then decode
 * with the smallest power-of-two `inSampleSize` that fits both dimensions
 * within [maxDim]. Returns null if the stream cannot be opened or decoding fails.
 */
private fun loadBitmapOptimized(context: android.content.Context, uri: Uri, maxDim: Int = 768): Bitmap? {
    return try {
        // Pass 1: bounds only — no pixel memory allocated.
        val boundsOptions = BitmapFactory.Options().apply { inJustDecodeBounds = true }
        context.contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream, null, boundsOptions)
        }

        // Smallest power-of-two factor bringing both dimensions within maxDim.
        var scale = 1
        while (boundsOptions.outWidth / scale > maxDim || boundsOptions.outHeight / scale > maxDim) {
            scale *= 2
        }

        // Pass 2: real decode at the reduced size.
        val decodeOptions = BitmapFactory.Options().apply {
            inSampleSize = scale
            inPreferredConfig = Bitmap.Config.ARGB_8888
        }
        context.contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream, null, decodeOptions)
        }
    } catch (e: Exception) {
        Log.w(TAG, "Failed to load bitmap: ${e.message}")
        null
    }
}
/**
 * Crop [boundingBox] out of [bitmap] with 25% padding on each side (relative
 * to the box's larger dimension), clamped to the bitmap bounds.
 *
 * Returns null when the clamped region is empty (box entirely outside the
 * bitmap) or when the crop throws.
 */
private fun cropFaceWithPadding(bitmap: Bitmap, boundingBox: Rect): Bitmap? {
    return try {
        val pad = (max(boundingBox.width(), boundingBox.height()) * 0.25f).toInt()

        // Expand by the padding, then clamp to the bitmap edges.
        val x0 = max(0, boundingBox.left - pad)
        val y0 = max(0, boundingBox.top - pad)
        val x1 = min(bitmap.width, boundingBox.right + pad)
        val y1 = min(bitmap.height, boundingBox.bottom + pad)

        val w = x1 - x0
        val h = y1 - y0
        if (w > 0 && h > 0) {
            Bitmap.createBitmap(bitmap, x0, y0, w, h)
        } else {
            null
        }
    } catch (e: Exception) {
        Log.w(TAG, "Failed to crop face: ${e.message}")
        null
    }
}
/**
* Fallback: load all photos with faces
*/

View File

@@ -9,6 +9,9 @@ import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
@@ -52,7 +55,8 @@ class LibraryScanWorker @AssistedInject constructor(
@Assisted workerParams: WorkerParameters,
private val imageDao: ImageDao,
private val faceModelDao: FaceModelDao,
private val photoFaceTagDao: PhotoFaceTagDao
private val photoFaceTagDao: PhotoFaceTagDao,
private val personDao: PersonDao
) : CoroutineWorker(context, workerParams) {
companion object {
@@ -65,7 +69,8 @@ class LibraryScanWorker @AssistedInject constructor(
const val KEY_MATCHES_FOUND = "matches_found"
const val KEY_PHOTOS_SCANNED = "photos_scanned"
private const val DEFAULT_THRESHOLD = 0.70f // Slightly looser than validation
private const val DEFAULT_THRESHOLD = 0.62f // Solo photos
private const val GROUP_THRESHOLD = 0.68f // Group photos (stricter)
private const val BATCH_SIZE = 20
private const val MAX_RETRIES = 3
@@ -137,16 +142,40 @@ class LibraryScanWorker @AssistedInject constructor(
)
}
// Step 2.5: Load person to check isChild flag
val person = withContext(Dispatchers.IO) {
personDao.getPersonById(personId)
}
val isChildTarget = person?.isChild ?: false
// Step 3: Initialize ML components
val faceNetModel = FaceNetModel(context)
val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // Needed for age estimation
.setMinFaceSize(0.15f)
.build()
)
val modelEmbedding = faceModel.getEmbeddingArray()
// Distribution-based minimum threshold (self-calibrating)
val distributionMin = (faceModel.averageConfidence - 2 * faceModel.similarityStdDev)
.coerceAtLeast(faceModel.similarityMin - 0.05f)
.coerceAtLeast(0.50f) // Never go below 0.50 absolute floor
// Get ALL centroids for multi-centroid matching (critical for children)
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
if (modelCentroids.isEmpty()) {
return@withContext Result.failure(workDataOf("error" to "No centroids in model"))
}
// Load ALL other models for "best match wins" comparison
// This prevents tagging siblings incorrectly
val allModels = withContext(Dispatchers.IO) { faceModelDao.getAllActiveFaceModels() }
val otherModelCentroids = allModels
.filter { it.id != faceModel.id }
.map { model -> model.id to model.getCentroids().map { it.getEmbeddingArray() } }
var matchesFound = 0
var photosScanned = 0
@@ -164,10 +193,13 @@ class LibraryScanWorker @AssistedInject constructor(
photo = photo,
personId = personId,
faceModelId = faceModel.id,
modelEmbedding = modelEmbedding,
modelCentroids = modelCentroids,
otherModelCentroids = otherModelCentroids,
faceNetModel = faceNetModel,
detector = detector,
threshold = threshold
threshold = threshold,
distributionMin = distributionMin,
isChildTarget = isChildTarget
)
if (tags.isNotEmpty()) {
@@ -228,10 +260,13 @@ class LibraryScanWorker @AssistedInject constructor(
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
personId: String,
faceModelId: String,
modelEmbedding: FloatArray,
modelCentroids: List<FloatArray>,
otherModelCentroids: List<Pair<String, List<FloatArray>>>,
faceNetModel: FaceNetModel,
detector: com.google.mlkit.vision.face.FaceDetector,
threshold: Float
threshold: Float,
distributionMin: Float,
isChildTarget: Boolean
): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) {
try {
@@ -243,43 +278,94 @@ class LibraryScanWorker @AssistedInject constructor(
val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await()
// Check each face
val tags = faces.mapNotNull { face ->
if (faces.isEmpty()) {
bitmap.recycle()
return@withContext emptyList()
}
// Use higher threshold for group photos
val isGroupPhoto = faces.size > 1
val effectiveThreshold = if (isGroupPhoto) GROUP_THRESHOLD else threshold
// Track best match (only tag ONE face per image to avoid false positives)
var bestMatch: PhotoFaceTagEntity? = null
var bestSimilarity = 0f
// Check each face (filter by quality first)
for (face in faces) {
// Quality check
if (!FaceQualityFilter.validateForScanning(face, bitmap.width, bitmap.height)) {
continue
}
// Skip very small faces
val faceArea = face.boundingBox.width() * face.boundingBox.height()
val imageArea = bitmap.width * bitmap.height
if (faceArea.toFloat() / imageArea < 0.02f) continue
// SIGNAL 2: Age plausibility check (if target is a child)
if (isChildTarget) {
val ageGroup = FaceQualityFilter.estimateAgeGroup(face, bitmap.width, bitmap.height)
if (ageGroup == FaceQualityFilter.AgeGroup.ADULT) {
continue // Reject clearly adult faces when searching for a child
}
}
try {
// Crop face
val faceBitmap = android.graphics.Bitmap.createBitmap(
bitmap,
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
)
// Crop and normalize face for best recognition
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
?: continue
// Generate embedding
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle()
// Calculate similarity
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
// Match against target person's centroids
val targetSimilarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
if (similarity >= threshold) {
PhotoFaceTagEntity.create(
// SIGNAL 1: Distribution-based rejection
// If similarity is below (mean - 2*stdDev) or (min - 0.05), it's a statistical outlier
if (targetSimilarity < distributionMin) {
continue // Too far below training distribution
}
// SIGNAL 3: Basic threshold check
if (targetSimilarity < effectiveThreshold) {
continue
}
// SIGNAL 4: "Best match wins" - check if any OTHER model scores higher
// This prevents tagging siblings incorrectly
val bestOtherSimilarity = otherModelCentroids.maxOfOrNull { (_, centroids) ->
centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
} ?: 0f
val isTargetBestMatch = targetSimilarity > bestOtherSimilarity
// All signals must pass
if (isTargetBestMatch && targetSimilarity > bestSimilarity) {
bestSimilarity = targetSimilarity
bestMatch = PhotoFaceTagEntity.create(
imageId = photo.imageId,
faceModelId = faceModelId,
boundingBox = face.boundingBox,
confidence = similarity,
confidence = targetSimilarity,
faceEmbedding = faceEmbedding
)
} else {
null
}
} catch (e: Exception) {
null
// Skip this face
}
}
bitmap.recycle()
tags
// Return only the best match (or empty)
if (bestMatch != null) listOf(bestMatch) else emptyList()
} catch (e: Exception) {
emptyList()