rollingscan very clean

likelyhood -> find similar
REFRESHCLAUD.MD 20260126
This commit is contained in:
genki
2026-01-26 22:46:38 -05:00
parent cfec2b980a
commit 804f3d5640
11 changed files with 557 additions and 134 deletions

View File

@@ -4,10 +4,10 @@
<selectionStates> <selectionStates>
<SelectionState runConfigName="app"> <SelectionState runConfigName="app">
<option name="selectionMode" value="DROPDOWN" /> <option name="selectionMode" value="DROPDOWN" />
<DropdownSelection timestamp="2026-01-26T02:23:12.309011764Z"> <DropdownSelection timestamp="2026-01-27T00:21:15.014661014Z">
<Target type="DEFAULT_BOOT"> <Target type="DEFAULT_BOOT">
<handle> <handle>
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" /> <DeviceId pluginId="PhysicalDevice" identifier="serial=R3CX106YYCB" />
</handle> </handle>
</Target> </Target>
</DropdownSelection> </DropdownSelection>

View File

@@ -10,6 +10,10 @@ import com.placeholder.sherpai2.data.local.entity.*
/** /**
* AppDatabase - Complete database for SherpAI2 * AppDatabase - Complete database for SherpAI2
* *
* VERSION 12 - Distribution-based rejection stats
* - Added similarityStdDev, similarityMin to FaceModelEntity
* - Enables self-calibrating threshold for face matching
*
* VERSION 10 - User Feedback Loop * VERSION 10 - User Feedback Loop
* - Added UserFeedbackEntity for storing user corrections * - Added UserFeedbackEntity for storing user corrections
* - Enables cluster refinement before training * - Enables cluster refinement before training
@@ -52,7 +56,7 @@ import com.placeholder.sherpai2.data.local.entity.*
CollectionImageEntity::class, CollectionImageEntity::class,
CollectionFilterEntity::class CollectionFilterEntity::class
], ],
version = 11, // INCREMENTED for person statistics version = 12, // INCREMENTED for distribution-based rejection stats
exportSchema = false exportSchema = false
) )
abstract class AppDatabase : RoomDatabase() { abstract class AppDatabase : RoomDatabase() {
@@ -272,13 +276,32 @@ val MIGRATION_10_11 = object : Migration(10, 11) {
} }
} }
/**
* MIGRATION 11 → 12 (Distribution-based Rejection Stats)
*
* Changes:
* 1. Add similarityStdDev column to face_models (default 0.05)
* 2. Add similarityMin column to face_models (default 0.6)
*
* These fields enable self-calibrating thresholds during scanning.
* During training, we compute stats from training sample similarities
* and use (mean - 2*stdDev) as a floor for matching.
*/
val MIGRATION_11_12 = object : Migration(11, 12) {
override fun migrate(database: SupportSQLiteDatabase) {
// Add distribution stats columns with sensible defaults for existing models
database.execSQL("ALTER TABLE face_models ADD COLUMN similarityStdDev REAL NOT NULL DEFAULT 0.05")
database.execSQL("ALTER TABLE face_models ADD COLUMN similarityMin REAL NOT NULL DEFAULT 0.6")
}
}
/** /**
* PRODUCTION MIGRATION NOTES: * PRODUCTION MIGRATION NOTES:
* *
* Before shipping to users, update DatabaseModule to use migrations: * Before shipping to users, update DatabaseModule to use migrations:
* *
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db") * Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11) // Add all migrations * .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11, MIGRATION_11_12) // Add all migrations
* // .fallbackToDestructiveMigration() // Remove this * // .fallbackToDestructiveMigration() // Remove this
* .build() * .build()
*/ */

View File

@@ -143,6 +143,13 @@ data class FaceModelEntity(
@ColumnInfo(name = "averageConfidence") @ColumnInfo(name = "averageConfidence")
val averageConfidence: Float, val averageConfidence: Float,
// Distribution stats for self-calibrating rejection
@ColumnInfo(name = "similarityStdDev")
val similarityStdDev: Float = 0.05f, // Default for backwards compat
@ColumnInfo(name = "similarityMin")
val similarityMin: Float = 0.6f, // Default for backwards compat
@ColumnInfo(name = "createdAt") @ColumnInfo(name = "createdAt")
val createdAt: Long, val createdAt: Long,
@@ -157,26 +164,29 @@ data class FaceModelEntity(
) { ) {
companion object { companion object {
/** /**
* Backwards compatible create() method * Create with distribution stats for self-calibrating rejection
* Used by existing FaceRecognitionRepository code
*/ */
fun create( fun create(
personId: String, personId: String,
embeddingArray: FloatArray, embeddingArray: FloatArray,
trainingImageCount: Int, trainingImageCount: Int,
averageConfidence: Float averageConfidence: Float,
similarityStdDev: Float = 0.05f,
similarityMin: Float = 0.6f
): FaceModelEntity { ): FaceModelEntity {
return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence) return createFromEmbedding(personId, embeddingArray, trainingImageCount, averageConfidence, similarityStdDev, similarityMin)
} }
/** /**
* Create from single embedding (backwards compatible) * Create from single embedding with distribution stats
*/ */
fun createFromEmbedding( fun createFromEmbedding(
personId: String, personId: String,
embeddingArray: FloatArray, embeddingArray: FloatArray,
trainingImageCount: Int, trainingImageCount: Int,
averageConfidence: Float averageConfidence: Float,
similarityStdDev: Float = 0.05f,
similarityMin: Float = 0.6f
): FaceModelEntity { ): FaceModelEntity {
val now = System.currentTimeMillis() val now = System.currentTimeMillis()
val centroid = TemporalCentroid( val centroid = TemporalCentroid(
@@ -194,6 +204,8 @@ data class FaceModelEntity(
centroidsJson = serializeCentroids(listOf(centroid)), centroidsJson = serializeCentroids(listOf(centroid)),
trainingImageCount = trainingImageCount, trainingImageCount = trainingImageCount,
averageConfidence = averageConfidence, averageConfidence = averageConfidence,
similarityStdDev = similarityStdDev,
similarityMin = similarityMin,
createdAt = now, createdAt = now,
updatedAt = now, updatedAt = now,
lastUsed = null, lastUsed = null,

View File

@@ -99,11 +99,19 @@ class FaceRecognitionRepository @Inject constructor(
} }
val avgConfidence = confidences.average().toFloat() val avgConfidence = confidences.average().toFloat()
// Compute distribution stats for self-calibrating rejection
val stdDev = kotlin.math.sqrt(
confidences.map { (it - avgConfidence).toDouble().let { d -> d * d } }.average()
).toFloat()
val minSimilarity = confidences.minOrNull() ?: 0f
val faceModel = FaceModelEntity.create( val faceModel = FaceModelEntity.create(
personId = personId, personId = personId,
embeddingArray = personEmbedding, embeddingArray = personEmbedding,
trainingImageCount = validImages.size, trainingImageCount = validImages.size,
averageConfidence = avgConfidence averageConfidence = avgConfidence,
similarityStdDev = stdDev,
similarityMin = minSimilarity
) )
faceModelDao.insertFaceModel(faceModel) faceModelDao.insertFaceModel(faceModel)

View File

@@ -29,6 +29,64 @@ import kotlin.math.sqrt
*/ */
object FaceQualityFilter { object FaceQualityFilter {
/**
* Age group estimation for filtering (child vs adult detection)
*/
enum class AgeGroup { CHILD, ADULT, UNCERTAIN }
/**
* Estimate whether a face belongs to a child or adult based on facial proportions.
*
* Uses two heuristics:
* 1. Eye position ratio - Children have larger foreheads, so eyes are lower (~45% from top)
* Adults have eyes at ~35% from top
* 2. Face roundness (width/height ratio) - Children: ~0.85-1.0, Adults: ~0.7-0.85
*
* @return AgeGroup.CHILD, AgeGroup.ADULT, or AgeGroup.UNCERTAIN
*/
fun estimateAgeGroup(face: Face, imageWidth: Int, imageHeight: Int): AgeGroup {
val leftEye = face.getLandmark(FaceLandmark.LEFT_EYE)
val rightEye = face.getLandmark(FaceLandmark.RIGHT_EYE)
if (leftEye == null || rightEye == null) {
return AgeGroup.UNCERTAIN
}
// Eye-to-face height ratio (where eyes sit relative to face top)
val faceHeight = face.boundingBox.height().toFloat()
val faceTop = face.boundingBox.top.toFloat()
val eyeY = (leftEye.position.y + rightEye.position.y) / 2
val eyePositionRatio = (eyeY - faceTop) / faceHeight
// Children: eyes at ~45% from top (larger forehead proportionally)
// Adults: eyes at ~35% from top
// Score: higher = more child-like
// Face roundness (width/height)
val faceWidth = face.boundingBox.width().toFloat()
val faceRatio = faceWidth / faceHeight
// Children: ratio ~0.85-1.0 (rounder faces)
// Adults: ratio ~0.7-0.85 (longer/narrower faces)
var childScore = 0
// Eye position scoring
if (eyePositionRatio > 0.45f) childScore += 2 // Strong child signal
else if (eyePositionRatio > 0.42f) childScore += 1 // Mild child signal
else if (eyePositionRatio < 0.35f) childScore -= 1 // Adult signal
// Face roundness scoring
if (faceRatio > 0.90f) childScore += 2 // Very round = child
else if (faceRatio > 0.82f) childScore += 1 // Somewhat round
else if (faceRatio < 0.75f) childScore -= 1 // Long face = adult
return when {
childScore >= 3 -> AgeGroup.CHILD
childScore <= 0 -> AgeGroup.ADULT
else -> AgeGroup.UNCERTAIN
}
}
/** /**
* Validate face for Discovery/Clustering * Validate face for Discovery/Clustering
* *

View File

@@ -19,6 +19,7 @@ import com.placeholder.sherpai2.data.local.entity.PersonEntity
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
import com.placeholder.sherpai2.ml.FaceNetModel import com.placeholder.sherpai2.ml.FaceNetModel
import com.placeholder.sherpai2.ml.ThresholdStrategy import com.placeholder.sherpai2.ml.ThresholdStrategy
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
@@ -142,7 +143,7 @@ class PersonInventoryViewModel @Inject constructor(
val detectorOptions = FaceDetectorOptions.Builder() val detectorOptions = FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_NONE) .setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // Needed for age estimation
.setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_NONE) .setClassificationMode(FaceDetectorOptions.CLASSIFICATION_MODE_NONE)
.setMinFaceSize(0.15f) .setMinFaceSize(0.15f)
.build() .build()
@@ -159,9 +160,23 @@ class PersonInventoryViewModel @Inject constructor(
} }
val faceNetModel = FaceNetModel(context) val faceNetModel = FaceNetModel(context)
// Production threshold - balance precision vs recall // Production threshold - STRICT to avoid false positives
val baseThreshold = 0.58f // Solo face photos: 0.62, Group photos: 0.68
android.util.Log.d("PersonScan", "Using threshold: $baseThreshold, centroids: ${modelCentroids.size}") val baseThreshold = 0.62f
val groupPhotoThreshold = 0.68f // Higher bar for multi-face images
// Load ALL other models for "best match wins" comparison
val allModels = faceModelDao.getAllActiveFaceModels()
val otherModelCentroids = allModels
.filter { it.id != faceModel.id }
.map { model -> model.id to model.getCentroids().map { it.getEmbeddingArray() } }
// Distribution-based minimum threshold (self-calibrating)
val distributionMin = (faceModel.averageConfidence - 2 * faceModel.similarityStdDev)
.coerceAtLeast(faceModel.similarityMin - 0.05f)
.coerceAtLeast(0.50f) // Never go below 0.50 absolute floor
android.util.Log.d("PersonScan", "Using threshold: solo=$baseThreshold, group=$groupPhotoThreshold, distributionMin=$distributionMin (avgConf=${faceModel.averageConfidence}, stdDev=${faceModel.similarityStdDev}), centroids: ${modelCentroids.size}, competing models: ${otherModelCentroids.size}, isChild=${person.isChild}")
val completed = AtomicInteger(0) val completed = AtomicInteger(0)
val facesFound = AtomicInteger(0) val facesFound = AtomicInteger(0)
@@ -173,7 +188,7 @@ class PersonInventoryViewModel @Inject constructor(
val jobs = untaggedImages.map { image -> val jobs = untaggedImages.map { image ->
async { async {
semaphore.withPermit { semaphore.withPermit {
processImage(image, detector, faceNetModel, modelCentroids, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name) processImage(image, detector, faceNetModel, modelCentroids, otherModelCentroids, trainingCount, baseThreshold, groupPhotoThreshold, distributionMin, person.isChild, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
} }
} }
} }
@@ -200,7 +215,10 @@ class PersonInventoryViewModel @Inject constructor(
private suspend fun processImage( private suspend fun processImage(
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel, image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
modelCentroids: List<FloatArray>, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String, modelCentroids: List<FloatArray>, otherModelCentroids: List<Pair<String, List<FloatArray>>>,
trainingCount: Int, baseThreshold: Float, groupPhotoThreshold: Float,
distributionMin: Float, isChildTarget: Boolean,
personId: String, faceModelId: String,
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex, batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
) { ) {
@@ -225,9 +243,13 @@ class PersonInventoryViewModel @Inject constructor(
val scaleX = sizeOpts.outWidth.toFloat() / detectionBitmap.width val scaleX = sizeOpts.outWidth.toFloat() / detectionBitmap.width
val scaleY = sizeOpts.outHeight.toFloat() / detectionBitmap.height val scaleY = sizeOpts.outHeight.toFloat() / detectionBitmap.height
val imageQuality = ThresholdStrategy.estimateImageQuality(sizeOpts.outWidth, sizeOpts.outHeight) // CRITICAL: Use higher threshold for group photos (more likely false positives)
val detectionContext = ThresholdStrategy.estimateDetectionContext(faces.size) val isGroupPhoto = faces.size > 1
val threshold = ThresholdStrategy.getOptimalThreshold(trainingCount, imageQuality, detectionContext).coerceAtMost(baseThreshold) val effectiveThreshold = if (isGroupPhoto) groupPhotoThreshold else baseThreshold
// Track best match in this image (only tag ONE face per image)
var bestMatchSimilarity = 0f
var foundMatch = false
for (face in faces) { for (face in faces) {
val scaledBounds = android.graphics.Rect( val scaledBounds = android.graphics.Rect(
@@ -237,27 +259,70 @@ class PersonInventoryViewModel @Inject constructor(
(face.boundingBox.bottom * scaleY).toInt() (face.boundingBox.bottom * scaleY).toInt()
) )
// Skip very small faces (less reliable)
val faceArea = scaledBounds.width() * scaledBounds.height()
val imageArea = sizeOpts.outWidth * sizeOpts.outHeight
val faceRatio = faceArea.toFloat() / imageArea
if (faceRatio < 0.02f) continue // Face must be at least 2% of image
// SIGNAL 2: Age plausibility check (if target is a child)
if (isChildTarget) {
val ageGroup = FaceQualityFilter.estimateAgeGroup(face, detectionBitmap.width, detectionBitmap.height)
if (ageGroup == FaceQualityFilter.AgeGroup.ADULT) {
continue // Reject clearly adult faces when searching for a child
}
}
// CRITICAL: Add padding to face crop (same as training) // CRITICAL: Add padding to face crop (same as training)
val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle() faceBitmap.recycle()
// Match against ALL centroids, use best match // Match against target person's centroids
val bestSimilarity = modelCentroids.maxOfOrNull { centroid -> val targetSimilarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid) faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f } ?: 0f
if (bestSimilarity >= threshold) { // SIGNAL 1: Distribution-based rejection
batchUpdateMutex.withLock { // If similarity is below (mean - 2*stdDev) or (min - 0.05), it's a statistical outlier
batchMatches.add(Triple(personId, image.imageId, bestSimilarity)) if (targetSimilarity < distributionMin) {
facesFound.incrementAndGet() continue // Too far below training distribution
if (batchMatches.size >= BATCH_DB_SIZE) { }
saveBatchMatches(batchMatches.toList(), faceModelId)
batchMatches.clear() // SIGNAL 3: Basic threshold check
} if (targetSimilarity < effectiveThreshold) {
continue
}
// SIGNAL 4: "Best match wins" - check if any OTHER model scores higher
// This prevents tagging siblings/similar people incorrectly
val bestOtherSimilarity = otherModelCentroids.maxOfOrNull { (_, centroids) ->
centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
} ?: 0f
val isTargetBestMatch = targetSimilarity > bestOtherSimilarity
// All signals must pass
if (isTargetBestMatch && targetSimilarity > bestMatchSimilarity) {
bestMatchSimilarity = targetSimilarity
foundMatch = true
}
}
// Only add ONE tag per image (the best match)
if (foundMatch) {
batchUpdateMutex.withLock {
batchMatches.add(Triple(personId, image.imageId, bestMatchSimilarity))
facesFound.incrementAndGet()
if (batchMatches.size >= BATCH_DB_SIZE) {
saveBatchMatches(batchMatches.toList(), faceModelId)
batchMatches.clear()
} }
} }
} }
detectionBitmap.recycle() detectionBitmap.recycle()
} catch (e: Exception) { } catch (e: Exception) {
} finally { } finally {

View File

@@ -2,7 +2,9 @@ package com.placeholder.sherpai2.ui.rollingscan
import android.net.Uri import android.net.Uri
import androidx.compose.foundation.BorderStroke import androidx.compose.foundation.BorderStroke
import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.clickable import androidx.compose.foundation.clickable
import androidx.compose.foundation.combinedClickable
import androidx.compose.foundation.layout.* import androidx.compose.foundation.layout.*
import androidx.compose.foundation.lazy.grid.GridCells import androidx.compose.foundation.lazy.grid.GridCells
import androidx.compose.foundation.lazy.grid.GridItemSpan import androidx.compose.foundation.lazy.grid.GridItemSpan
@@ -37,7 +39,7 @@ import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
* - Quick action buttons (Select Top N) * - Quick action buttons (Select Top N)
* - Submit button with validation * - Submit button with validation
*/ */
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@Composable @Composable
fun RollingScanScreen( fun RollingScanScreen(
seedImageIds: List<String>, seedImageIds: List<String>,
@@ -48,6 +50,7 @@ fun RollingScanScreen(
) { ) {
val uiState by viewModel.uiState.collectAsState() val uiState by viewModel.uiState.collectAsState()
val selectedImageIds by viewModel.selectedImageIds.collectAsState() val selectedImageIds by viewModel.selectedImageIds.collectAsState()
val negativeImageIds by viewModel.negativeImageIds.collectAsState()
val rankedPhotos by viewModel.rankedPhotos.collectAsState() val rankedPhotos by viewModel.rankedPhotos.collectAsState()
val isScanning by viewModel.isScanning.collectAsState() val isScanning by viewModel.isScanning.collectAsState()
@@ -70,6 +73,7 @@ fun RollingScanScreen(
isReadyForTraining = viewModel.isReadyForTraining(), isReadyForTraining = viewModel.isReadyForTraining(),
validationMessage = viewModel.getValidationMessage(), validationMessage = viewModel.getValidationMessage(),
onSelectTopN = { count -> viewModel.selectTopN(count) }, onSelectTopN = { count -> viewModel.selectTopN(count) },
onSelectAboveThreshold = { threshold -> viewModel.selectAllAboveThreshold(threshold) },
onSubmit = { onSubmit = {
val uris = viewModel.getSelectedImageUris() val uris = viewModel.getSelectedImageUris()
onSubmitForTraining(uris) onSubmitForTraining(uris)
@@ -93,8 +97,10 @@ fun RollingScanScreen(
RollingScanPhotoGrid( RollingScanPhotoGrid(
rankedPhotos = rankedPhotos, rankedPhotos = rankedPhotos,
selectedImageIds = selectedImageIds, selectedImageIds = selectedImageIds,
negativeImageIds = negativeImageIds,
isScanning = isScanning, isScanning = isScanning,
onToggleSelection = { imageId -> viewModel.toggleSelection(imageId) }, onToggleSelection = { imageId -> viewModel.toggleSelection(imageId) },
onToggleNegative = { imageId -> viewModel.toggleNegative(imageId) },
modifier = Modifier.padding(padding) modifier = Modifier.padding(padding)
) )
} }
@@ -159,19 +165,26 @@ private fun RollingScanTopBar(
} }
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
// PHOTO GRID // PHOTO GRID - Similarity-based bucketing
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
@OptIn(ExperimentalFoundationApi::class)
@Composable @Composable
private fun RollingScanPhotoGrid( private fun RollingScanPhotoGrid(
rankedPhotos: List<FaceSimilarityScorer.ScoredPhoto>, rankedPhotos: List<FaceSimilarityScorer.ScoredPhoto>,
selectedImageIds: Set<String>, selectedImageIds: Set<String>,
negativeImageIds: Set<String>,
isScanning: Boolean, isScanning: Boolean,
onToggleSelection: (String) -> Unit, onToggleSelection: (String) -> Unit,
onToggleNegative: (String) -> Unit,
modifier: Modifier = Modifier modifier: Modifier = Modifier
) { ) {
Column(modifier = modifier.fillMaxSize()) { // Bucket by similarity score
val veryLikely = rankedPhotos.filter { it.finalScore >= 0.60f }
val probably = rankedPhotos.filter { it.finalScore in 0.45f..0.599f }
val maybe = rankedPhotos.filter { it.finalScore < 0.45f }
Column(modifier = modifier.fillMaxSize()) {
// Scanning indicator // Scanning indicator
if (isScanning) { if (isScanning) {
LinearProgressIndicator( LinearProgressIndicator(
@@ -180,69 +193,78 @@ private fun RollingScanPhotoGrid(
) )
} }
// Hint for negative marking
Text(
text = "Tap to select • Long-press to mark as NOT this person",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(horizontal = 12.dp, vertical = 4.dp)
)
LazyVerticalGrid( LazyVerticalGrid(
columns = GridCells.Fixed(3), columns = GridCells.Fixed(3),
contentPadding = PaddingValues(8.dp), contentPadding = PaddingValues(8.dp),
horizontalArrangement = Arrangement.spacedBy(8.dp), horizontalArrangement = Arrangement.spacedBy(8.dp),
verticalArrangement = Arrangement.spacedBy(8.dp) verticalArrangement = Arrangement.spacedBy(8.dp)
) { ) {
// Section: Most Similar (top 10) // Section: Very Likely (>60%)
val topMatches = rankedPhotos.take(10) if (veryLikely.isNotEmpty()) {
if (topMatches.isNotEmpty()) {
item(span = { GridItemSpan(3) }) { item(span = { GridItemSpan(3) }) {
SectionHeader( SectionHeader(
icon = Icons.Default.Whatshot, icon = Icons.Default.Whatshot,
text = "🔥 Most Similar (${topMatches.size})", text = "🟢 Very Likely (${veryLikely.size})",
color = MaterialTheme.colorScheme.primary color = Color(0xFF4CAF50)
) )
} }
items(veryLikely, key = { it.imageId }) { photo ->
items(topMatches, key = { it.imageId }) { photo ->
PhotoCard( PhotoCard(
photo = photo, photo = photo,
isSelected = photo.imageId in selectedImageIds, isSelected = photo.imageId in selectedImageIds,
isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) }, onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) },
showSimilarityBadge = true showSimilarityBadge = true
) )
} }
} }
// Section: Good Matches (11-30) // Section: Probably (45-60%)
val goodMatches = rankedPhotos.drop(10).take(20) if (probably.isNotEmpty()) {
if (goodMatches.isNotEmpty()) {
item(span = { GridItemSpan(3) }) { item(span = { GridItemSpan(3) }) {
SectionHeader( SectionHeader(
icon = Icons.Default.CheckCircle, icon = Icons.Default.CheckCircle,
text = "📊 Good Matches (${goodMatches.size})", text = "🟡 Probably (${probably.size})",
color = MaterialTheme.colorScheme.tertiary color = Color(0xFFFFC107)
) )
} }
items(probably, key = { it.imageId }) { photo ->
items(goodMatches, key = { it.imageId }) { photo ->
PhotoCard( PhotoCard(
photo = photo, photo = photo,
isSelected = photo.imageId in selectedImageIds, isSelected = photo.imageId in selectedImageIds,
onToggle = { onToggleSelection(photo.imageId) } isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) },
showSimilarityBadge = true
) )
} }
} }
// Section: Other Photos // Section: Maybe (<45%)
val otherPhotos = rankedPhotos.drop(30) if (maybe.isNotEmpty()) {
if (otherPhotos.isNotEmpty()) {
item(span = { GridItemSpan(3) }) { item(span = { GridItemSpan(3) }) {
SectionHeader( SectionHeader(
icon = Icons.Default.Photo, icon = Icons.Default.Photo,
text = "📷 Other Photos (${otherPhotos.size})", text = "🟠 Maybe (${maybe.size})",
color = MaterialTheme.colorScheme.onSurfaceVariant color = Color(0xFFFF9800)
) )
} }
items(maybe, key = { it.imageId }) { photo ->
items(otherPhotos, key = { it.imageId }) { photo ->
PhotoCard( PhotoCard(
photo = photo, photo = photo,
isSelected = photo.imageId in selectedImageIds, isSelected = photo.imageId in selectedImageIds,
onToggle = { onToggleSelection(photo.imageId) } isNegative = photo.imageId in negativeImageIds,
onToggle = { onToggleSelection(photo.imageId) },
onLongPress = { onToggleNegative(photo.imageId) }
) )
} }
} }
@@ -258,24 +280,34 @@ private fun RollingScanPhotoGrid(
} }
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
// PHOTO CARD // PHOTO CARD - with long-press for negative marking
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
@OptIn(ExperimentalFoundationApi::class)
@Composable @Composable
private fun PhotoCard( private fun PhotoCard(
photo: FaceSimilarityScorer.ScoredPhoto, photo: FaceSimilarityScorer.ScoredPhoto,
isSelected: Boolean, isSelected: Boolean,
isNegative: Boolean = false,
onToggle: () -> Unit, onToggle: () -> Unit,
onLongPress: () -> Unit = {},
showSimilarityBadge: Boolean = false showSimilarityBadge: Boolean = false
) { ) {
val borderColor = when {
isNegative -> Color(0xFFE53935) // Red for negative
isSelected -> MaterialTheme.colorScheme.primary
else -> MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)
}
val borderWidth = if (isSelected || isNegative) 3.dp else 1.dp
Card( Card(
modifier = Modifier modifier = Modifier
.aspectRatio(1f) .aspectRatio(1f)
.clickable(onClick = onToggle), .combinedClickable(
border = if (isSelected) onClick = onToggle,
BorderStroke(3.dp, MaterialTheme.colorScheme.primary) onLongClick = onLongPress
else ),
BorderStroke(1.dp, MaterialTheme.colorScheme.outline.copy(alpha = 0.3f)), border = BorderStroke(borderWidth, borderColor),
elevation = CardDefaults.cardElevation( elevation = CardDefaults.cardElevation(
defaultElevation = if (isSelected) 4.dp else 1.dp defaultElevation = if (isSelected) 4.dp else 1.dp
) )
@@ -289,22 +321,47 @@ private fun PhotoCard(
contentScale = ContentScale.Crop contentScale = ContentScale.Crop
) )
// Similarity badge (top-left) - Only for top matches // Dim overlay for negatives
if (showSimilarityBadge) { if (isNegative) {
Box(
modifier = Modifier
.fillMaxSize()
.padding(0.dp),
contentAlignment = Alignment.Center
) {
Surface(
modifier = Modifier.fillMaxSize(),
color = Color.Black.copy(alpha = 0.5f)
) {}
Icon(
Icons.Default.Close,
contentDescription = "Not this person",
tint = Color.White,
modifier = Modifier.size(32.dp)
)
}
}
// Similarity badge (top-left)
if (showSimilarityBadge && !isNegative) {
Surface( Surface(
modifier = Modifier modifier = Modifier
.align(Alignment.TopStart) .align(Alignment.TopStart)
.padding(6.dp), .padding(6.dp),
shape = RoundedCornerShape(8.dp), shape = RoundedCornerShape(8.dp),
color = MaterialTheme.colorScheme.primary, color = when {
photo.finalScore >= 0.60f -> Color(0xFF4CAF50)
photo.finalScore >= 0.45f -> Color(0xFFFFC107)
else -> Color(0xFFFF9800)
},
shadowElevation = 4.dp shadowElevation = 4.dp
) { ) {
Text( Text(
text = "${(photo.similarityScore * 100).toInt()}%", text = "${(photo.finalScore * 100).toInt()}%",
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp), modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
style = MaterialTheme.typography.labelSmall, style = MaterialTheme.typography.labelSmall,
fontWeight = FontWeight.Bold, fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.onPrimary color = Color.White
) )
} }
} }
@@ -332,7 +389,7 @@ private fun PhotoCard(
} }
// Face count badge (bottom-right) // Face count badge (bottom-right)
if (photo.faceCount > 1) { if (photo.faceCount > 1 && !isNegative) {
Surface( Surface(
modifier = Modifier modifier = Modifier
.align(Alignment.BottomEnd) .align(Alignment.BottomEnd)
@@ -395,6 +452,7 @@ private fun RollingScanBottomBar(
isReadyForTraining: Boolean, isReadyForTraining: Boolean,
validationMessage: String?, validationMessage: String?,
onSelectTopN: (Int) -> Unit, onSelectTopN: (Int) -> Unit,
onSelectAboveThreshold: (Float) -> Unit,
onSubmit: () -> Unit onSubmit: () -> Unit
) { ) {
Surface( Surface(
@@ -416,39 +474,49 @@ private fun RollingScanBottomBar(
) )
} }
// First row: threshold selection
Row( Row(
modifier = Modifier.fillMaxWidth(), modifier = Modifier.fillMaxWidth(),
horizontalArrangement = Arrangement.spacedBy(8.dp) horizontalArrangement = Arrangement.spacedBy(6.dp)
) { ) {
// Quick select buttons
OutlinedButton( OutlinedButton(
onClick = { onSelectTopN(10) }, onClick = { onSelectAboveThreshold(0.60f) },
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) { ) {
Text("Top 10") Text(">60%", style = MaterialTheme.typography.labelSmall)
} }
OutlinedButton( OutlinedButton(
onClick = { onSelectTopN(20) }, onClick = { onSelectAboveThreshold(0.50f) },
modifier = Modifier.weight(1f) modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) { ) {
Text("Top 20") Text(">50%", style = MaterialTheme.typography.labelSmall)
} }
OutlinedButton(
onClick = { onSelectTopN(15) },
modifier = Modifier.weight(1f),
contentPadding = PaddingValues(horizontal = 8.dp, vertical = 4.dp)
) {
Text("Top 15", style = MaterialTheme.typography.labelSmall)
}
}
// Submit button Spacer(Modifier.height(8.dp))
Button(
onClick = onSubmit, // Second row: submit
enabled = isReadyForTraining, Button(
modifier = Modifier.weight(1.5f) onClick = onSubmit,
) { enabled = isReadyForTraining,
Icon( modifier = Modifier.fillMaxWidth()
Icons.Default.Done, ) {
contentDescription = null, Icon(
modifier = Modifier.size(18.dp) Icons.Default.Done,
) contentDescription = null,
Spacer(Modifier.width(8.dp)) modifier = Modifier.size(18.dp)
Text("Train ($selectedCount)") )
} Spacer(Modifier.width(8.dp))
Text("Train Model ($selectedCount photos)")
} }
} }
} }

View File

@@ -44,6 +44,11 @@ class RollingScanViewModel @Inject constructor(
private const val TAG = "RollingScanVM" private const val TAG = "RollingScanVM"
private const val DEBOUNCE_DELAY_MS = 300L private const val DEBOUNCE_DELAY_MS = 300L
private const val MIN_PHOTOS_FOR_TRAINING = 15 private const val MIN_PHOTOS_FOR_TRAINING = 15
// Progressive thresholds based on selection count
private const val FLOOR_FEW_SEEDS = 0.30f // 1-3 seeds
private const val FLOOR_MEDIUM_SEEDS = 0.40f // 4-10 seeds
private const val FLOOR_MANY_SEEDS = 0.50f // 10+ seeds
} }
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
@@ -71,6 +76,11 @@ class RollingScanViewModel @Inject constructor(
// Cache of selected embeddings // Cache of selected embeddings
private val selectedEmbeddings = mutableListOf<FloatArray>() private val selectedEmbeddings = mutableListOf<FloatArray>()
// Negative embeddings (marked as "not this person")
private val _negativeImageIds = MutableStateFlow<Set<String>>(emptySet())
val negativeImageIds: StateFlow<Set<String>> = _negativeImageIds.asStateFlow()
private val negativeEmbeddings = mutableListOf<FloatArray>()
// All available image IDs // All available image IDs
private var allImageIds: List<String> = emptyList() private var allImageIds: List<String> = emptyList()
@@ -156,24 +166,55 @@ class RollingScanViewModel @Inject constructor(
current.remove(imageId) current.remove(imageId)
viewModelScope.launch { viewModelScope.launch {
// Remove embedding from cache
val cached = faceCacheDao.getEmbeddingByImageId(imageId) val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { selectedEmbeddings.remove(it) } cached?.getEmbedding()?.let { selectedEmbeddings.remove(it) }
} }
} else { } else {
// Select // Select (and remove from negatives if present)
current.add(imageId) current.add(imageId)
if (imageId in _negativeImageIds.value) {
toggleNegative(imageId)
}
viewModelScope.launch { viewModelScope.launch {
// Add embedding to cache
val cached = faceCacheDao.getEmbeddingByImageId(imageId) val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { selectedEmbeddings.add(it) } cached?.getEmbedding()?.let { selectedEmbeddings.add(it) }
} }
} }
_selectedImageIds.value = current _selectedImageIds.value = current.toSet() // Immutable copy
scanDebouncer.debounce {
triggerRollingScan()
}
}
/**
* Toggle negative marking ("Not this person")
*/
fun toggleNegative(imageId: String) {
val current = _negativeImageIds.value.toMutableSet()
if (imageId in current) {
current.remove(imageId)
viewModelScope.launch {
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { negativeEmbeddings.remove(it) }
}
} else {
current.add(imageId)
// Remove from selected if present
if (imageId in _selectedImageIds.value) {
toggleSelection(imageId)
}
viewModelScope.launch {
val cached = faceCacheDao.getEmbeddingByImageId(imageId)
cached?.getEmbedding()?.let { negativeEmbeddings.add(it) }
}
}
_negativeImageIds.value = current.toSet() // Immutable copy
// Debounced rescan
scanDebouncer.debounce { scanDebouncer.debounce {
triggerRollingScan() triggerRollingScan()
} }
@@ -190,13 +231,33 @@ class RollingScanViewModel @Inject constructor(
val current = _selectedImageIds.value.toMutableSet() val current = _selectedImageIds.value.toMutableSet()
current.addAll(topPhotos) current.addAll(topPhotos)
_selectedImageIds.value = current _selectedImageIds.value = current.toSet() // Immutable copy
viewModelScope.launch { viewModelScope.launch {
// Add embeddings
val embeddings = faceCacheDao.getEmbeddingsForImages(topPhotos.toList()) val embeddings = faceCacheDao.getEmbeddingsForImages(topPhotos.toList())
selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() }) selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() })
triggerRollingScan()
}
}
/**
* Select all photos above a similarity threshold
*/
fun selectAllAboveThreshold(threshold: Float) {
val photosAbove = _rankedPhotos.value
.filter { it.finalScore >= threshold }
.map { it.imageId }
val current = _selectedImageIds.value.toMutableSet()
current.addAll(photosAbove)
_selectedImageIds.value = current.toSet() // Immutable copy
viewModelScope.launch {
val newIds = photosAbove.filter { it !in _selectedImageIds.value }
if (newIds.isNotEmpty()) {
val embeddings = faceCacheDao.getEmbeddingsForImages(newIds)
selectedEmbeddings.addAll(embeddings.mapNotNull { it.getEmbedding() })
}
triggerRollingScan() triggerRollingScan()
} }
} }
@@ -207,17 +268,24 @@ class RollingScanViewModel @Inject constructor(
fun clearSelection() { fun clearSelection() {
_selectedImageIds.value = emptySet() _selectedImageIds.value = emptySet()
selectedEmbeddings.clear() selectedEmbeddings.clear()
// Reset ranking
_rankedPhotos.value = emptyList() _rankedPhotos.value = emptyList()
} }
/**
* Clear negative markings
*/
fun clearNegatives() {
_negativeImageIds.value = emptySet()
negativeEmbeddings.clear()
scanDebouncer.debounce { triggerRollingScan() }
}
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
// ROLLING SCAN LOGIC // ROLLING SCAN LOGIC
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
/** /**
* CORE: Trigger rolling similarity scan * CORE: Trigger rolling similarity scan with progressive filtering
*/ */
private suspend fun triggerRollingScan() { private suspend fun triggerRollingScan() {
if (selectedEmbeddings.isEmpty()) { if (selectedEmbeddings.isEmpty()) {
@@ -228,7 +296,15 @@ class RollingScanViewModel @Inject constructor(
try { try {
_isScanning.value = true _isScanning.value = true
Log.d(TAG, "Starting scan with ${selectedEmbeddings.size} selected embeddings") val selectionCount = selectedEmbeddings.size
Log.d(TAG, "Starting scan with $selectionCount selected, ${negativeEmbeddings.size} negative")
// Progressive threshold based on selection count
val similarityFloor = when {
selectionCount <= 3 -> FLOOR_FEW_SEEDS
selectionCount <= 10 -> FLOOR_MEDIUM_SEEDS
else -> FLOOR_MANY_SEEDS
}
// Calculate centroid from selected embeddings // Calculate centroid from selected embeddings
val centroid = faceSimilarityScorer.calculateCentroid(selectedEmbeddings) val centroid = faceSimilarityScorer.calculateCentroid(selectedEmbeddings)
@@ -240,17 +316,38 @@ class RollingScanViewModel @Inject constructor(
centroid = centroid centroid = centroid
) )
// Update image URIs in scored photos // Apply negative penalty, quality boost, and floor filter
val photosWithUris = scoredPhotos.map { photo -> val filteredPhotos = scoredPhotos
photo.copy( .map { photo ->
imageUri = imageUriCache[photo.imageId] ?: photo.imageId // Calculate max similarity to any negative embedding
) val negativePenalty = if (negativeEmbeddings.isNotEmpty()) {
} negativeEmbeddings.maxOfOrNull { neg ->
cosineSimilarity(photo.cachedEmbedding, neg)
} ?: 0f
} else 0f
Log.d(TAG, "Scan complete. Scored ${photosWithUris.size} photos") // Quality multiplier: solo face, large face, good quality
val qualityMultiplier = 1f +
(if (photo.faceCount == 1) 0.15f else 0f) +
(if (photo.faceAreaRatio > 0.15f) 0.10f else 0f) +
(if (photo.qualityScore > 0.7f) 0.10f else 0f)
// Update ranked list // Final score = (similarity - negativePenalty) * qualityMultiplier
_rankedPhotos.value = photosWithUris val adjustedScore = ((photo.similarityScore - negativePenalty * 0.5f) * qualityMultiplier)
.coerceIn(0f, 1f)
photo.copy(
imageUri = imageUriCache[photo.imageId] ?: photo.imageId,
finalScore = adjustedScore
)
}
.filter { it.finalScore >= similarityFloor } // Apply floor
.filter { it.imageId !in _negativeImageIds.value } // Hide negatives
.sortedByDescending { it.finalScore }
Log.d(TAG, "Scan complete. ${filteredPhotos.size} photos above floor $similarityFloor")
_rankedPhotos.value = filteredPhotos
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Scan failed", e) Log.e(TAG, "Scan failed", e)
@@ -259,6 +356,19 @@ class RollingScanViewModel @Inject constructor(
} }
} }
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
if (a.size != b.size) return 0f
var dot = 0f
var normA = 0f
var normB = 0f
for (i in a.indices) {
dot += a[i] * b[i]
normA += a[i] * a[i]
normB += b[i] * b[i]
}
return if (normA > 0 && normB > 0) dot / (kotlin.math.sqrt(normA) * kotlin.math.sqrt(normB)) else 0f
}
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
// SUBMISSION // SUBMISSION
// ═══════════════════════════════════════════════════════════ // ═══════════════════════════════════════════════════════════
@@ -299,9 +409,11 @@ class RollingScanViewModel @Inject constructor(
fun reset() { fun reset() {
_uiState.value = RollingScanState.Idle _uiState.value = RollingScanState.Idle
_selectedImageIds.value = emptySet() _selectedImageIds.value = emptySet()
_negativeImageIds.value = emptySet()
_rankedPhotos.value = emptyList() _rankedPhotos.value = emptyList()
_isScanning.value = false _isScanning.value = false
selectedEmbeddings.clear() selectedEmbeddings.clear()
negativeEmbeddings.clear()
allImageIds = emptyList() allImageIds = emptyList()
imageUriCache = emptyMap() imageUriCache = emptyMap()
scanDebouncer.cancel() scanDebouncer.cancel()

View File

@@ -67,13 +67,14 @@ class FaceDetectionHelper(private val context: Context) {
val inputImage = InputImage.fromBitmap(bitmap, 0) val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await() val faces = detector.process(inputImage).await()
// Filter to quality faces only // Filter to quality faces - use lenient scanning filter
// (Discovery filter was too strict, rejecting faces from rolling scan)
val qualityFaces = faces.filter { face -> val qualityFaces = faces.filter { face ->
FaceQualityFilter.validateForDiscovery( FaceQualityFilter.validateForScanning(
face = face, face = face,
imageWidth = bitmap.width, imageWidth = bitmap.width,
imageHeight = bitmap.height imageHeight = bitmap.height
).isValid )
} }
// Sort by face size (area) to get the largest quality face // Sort by face size (area) to get the largest quality face

View File

@@ -192,11 +192,10 @@ class TrainViewModel @Inject constructor(
.first() .first()
if (backgroundTaggingEnabled) { if (backgroundTaggingEnabled) {
// Lower threshold (0.55) since we use multi-centroid matching // Use default threshold (0.62 solo, 0.68 group)
val scanRequest = LibraryScanWorker.createWorkRequest( val scanRequest = LibraryScanWorker.createWorkRequest(
personId = personId, personId = personId,
personName = personName, personName = personName
threshold = 0.55f
) )
workManager.enqueue(scanRequest) workManager.enqueue(scanRequest)
} }
@@ -382,7 +381,7 @@ class TrainViewModel @Inject constructor(
faceDetectionResults = updatedFaceResults, faceDetectionResults = updatedFaceResults,
validationErrors = updatedErrors, validationErrors = updatedErrors,
validImagesWithFaces = updatedValidImages, validImagesWithFaces = updatedValidImages,
excludedImages = excludedImages excludedImages = excludedImages.toSet() // Immutable copy for Compose state detection
) )
} }

View File

@@ -9,6 +9,7 @@ import com.google.mlkit.vision.common.InputImage
import com.google.mlkit.vision.face.FaceDetection import com.google.mlkit.vision.face.FaceDetection
import com.google.mlkit.vision.face.FaceDetectorOptions import com.google.mlkit.vision.face.FaceDetectorOptions
import com.placeholder.sherpai2.data.local.dao.FaceModelDao import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter import com.placeholder.sherpai2.domain.clustering.FaceQualityFilter
import com.placeholder.sherpai2.ml.FaceNormalizer import com.placeholder.sherpai2.ml.FaceNormalizer
import com.placeholder.sherpai2.data.local.dao.ImageDao import com.placeholder.sherpai2.data.local.dao.ImageDao
@@ -54,7 +55,8 @@ class LibraryScanWorker @AssistedInject constructor(
@Assisted workerParams: WorkerParameters, @Assisted workerParams: WorkerParameters,
private val imageDao: ImageDao, private val imageDao: ImageDao,
private val faceModelDao: FaceModelDao, private val faceModelDao: FaceModelDao,
private val photoFaceTagDao: PhotoFaceTagDao private val photoFaceTagDao: PhotoFaceTagDao,
private val personDao: PersonDao
) : CoroutineWorker(context, workerParams) { ) : CoroutineWorker(context, workerParams) {
companion object { companion object {
@@ -67,7 +69,8 @@ class LibraryScanWorker @AssistedInject constructor(
const val KEY_MATCHES_FOUND = "matches_found" const val KEY_MATCHES_FOUND = "matches_found"
const val KEY_PHOTOS_SCANNED = "photos_scanned" const val KEY_PHOTOS_SCANNED = "photos_scanned"
private const val DEFAULT_THRESHOLD = 0.70f // Slightly looser than validation private const val DEFAULT_THRESHOLD = 0.62f // Solo photos
private const val GROUP_THRESHOLD = 0.68f // Group photos (stricter)
private const val BATCH_SIZE = 20 private const val BATCH_SIZE = 20
private const val MAX_RETRIES = 3 private const val MAX_RETRIES = 3
@@ -139,21 +142,40 @@ class LibraryScanWorker @AssistedInject constructor(
) )
} }
// Step 2.5: Load person to check isChild flag
val person = withContext(Dispatchers.IO) {
personDao.getPersonById(personId)
}
val isChildTarget = person?.isChild ?: false
// Step 3: Initialize ML components // Step 3: Initialize ML components
val faceNetModel = FaceNetModel(context) val faceNetModel = FaceNetModel(context)
val detector = FaceDetection.getClient( val detector = FaceDetection.getClient(
FaceDetectorOptions.Builder() FaceDetectorOptions.Builder()
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE) .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
.setLandmarkMode(FaceDetectorOptions.LANDMARK_MODE_ALL) // Needed for age estimation
.setMinFaceSize(0.15f) .setMinFaceSize(0.15f)
.build() .build()
) )
// Distribution-based minimum threshold (self-calibrating)
val distributionMin = (faceModel.averageConfidence - 2 * faceModel.similarityStdDev)
.coerceAtLeast(faceModel.similarityMin - 0.05f)
.coerceAtLeast(0.50f) // Never go below 0.50 absolute floor
// Get ALL centroids for multi-centroid matching (critical for children) // Get ALL centroids for multi-centroid matching (critical for children)
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() } val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
if (modelCentroids.isEmpty()) { if (modelCentroids.isEmpty()) {
return@withContext Result.failure(workDataOf("error" to "No centroids in model")) return@withContext Result.failure(workDataOf("error" to "No centroids in model"))
} }
// Load ALL other models for "best match wins" comparison
// This prevents tagging siblings incorrectly
val allModels = withContext(Dispatchers.IO) { faceModelDao.getAllActiveFaceModels() }
val otherModelCentroids = allModels
.filter { it.id != faceModel.id }
.map { model -> model.id to model.getCentroids().map { it.getEmbeddingArray() } }
var matchesFound = 0 var matchesFound = 0
var photosScanned = 0 var photosScanned = 0
@@ -172,9 +194,12 @@ class LibraryScanWorker @AssistedInject constructor(
personId = personId, personId = personId,
faceModelId = faceModel.id, faceModelId = faceModel.id,
modelCentroids = modelCentroids, modelCentroids = modelCentroids,
otherModelCentroids = otherModelCentroids,
faceNetModel = faceNetModel, faceNetModel = faceNetModel,
detector = detector, detector = detector,
threshold = threshold threshold = threshold,
distributionMin = distributionMin,
isChildTarget = isChildTarget
) )
if (tags.isNotEmpty()) { if (tags.isNotEmpty()) {
@@ -236,9 +261,12 @@ class LibraryScanWorker @AssistedInject constructor(
personId: String, personId: String,
faceModelId: String, faceModelId: String,
modelCentroids: List<FloatArray>, modelCentroids: List<FloatArray>,
otherModelCentroids: List<Pair<String, List<FloatArray>>>,
faceNetModel: FaceNetModel, faceNetModel: FaceNetModel,
detector: com.google.mlkit.vision.face.FaceDetector, detector: com.google.mlkit.vision.face.FaceDetector,
threshold: Float threshold: Float,
distributionMin: Float,
isChildTarget: Boolean
): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) { ): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) {
try { try {
@@ -250,45 +278,94 @@ class LibraryScanWorker @AssistedInject constructor(
val inputImage = InputImage.fromBitmap(bitmap, 0) val inputImage = InputImage.fromBitmap(bitmap, 0)
val faces = detector.process(inputImage).await() val faces = detector.process(inputImage).await()
if (faces.isEmpty()) {
bitmap.recycle()
return@withContext emptyList()
}
// Use higher threshold for group photos
val isGroupPhoto = faces.size > 1
val effectiveThreshold = if (isGroupPhoto) GROUP_THRESHOLD else threshold
// Track best match (only tag ONE face per image to avoid false positives)
var bestMatch: PhotoFaceTagEntity? = null
var bestSimilarity = 0f
// Check each face (filter by quality first) // Check each face (filter by quality first)
val tags = faces.mapNotNull { face -> for (face in faces) {
// Quality check // Quality check
if (!FaceQualityFilter.validateForScanning(face, bitmap.width, bitmap.height)) { if (!FaceQualityFilter.validateForScanning(face, bitmap.width, bitmap.height)) {
return@mapNotNull null continue
}
// Skip very small faces
val faceArea = face.boundingBox.width() * face.boundingBox.height()
val imageArea = bitmap.width * bitmap.height
if (faceArea.toFloat() / imageArea < 0.02f) continue
// SIGNAL 2: Age plausibility check (if target is a child)
if (isChildTarget) {
val ageGroup = FaceQualityFilter.estimateAgeGroup(face, bitmap.width, bitmap.height)
if (ageGroup == FaceQualityFilter.AgeGroup.ADULT) {
continue // Reject clearly adult faces when searching for a child
}
} }
try { try {
// Crop and normalize face for best recognition // Crop and normalize face for best recognition
val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face) val faceBitmap = FaceNormalizer.cropAndNormalize(bitmap, face)
?: return@mapNotNull null ?: continue
// Generate embedding // Generate embedding
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
faceBitmap.recycle() faceBitmap.recycle()
// Match against ALL centroids, use best match (critical for children) // Match against target person's centroids
val similarity = modelCentroids.maxOfOrNull { centroid -> val targetSimilarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid) faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f } ?: 0f
if (similarity >= threshold) { // SIGNAL 1: Distribution-based rejection
PhotoFaceTagEntity.create( // If similarity is below (mean - 2*stdDev) or (min - 0.05), it's a statistical outlier
if (targetSimilarity < distributionMin) {
continue // Too far below training distribution
}
// SIGNAL 3: Basic threshold check
if (targetSimilarity < effectiveThreshold) {
continue
}
// SIGNAL 4: "Best match wins" - check if any OTHER model scores higher
// This prevents tagging siblings incorrectly
val bestOtherSimilarity = otherModelCentroids.maxOfOrNull { (_, centroids) ->
centroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
} ?: 0f
val isTargetBestMatch = targetSimilarity > bestOtherSimilarity
// All signals must pass
if (isTargetBestMatch && targetSimilarity > bestSimilarity) {
bestSimilarity = targetSimilarity
bestMatch = PhotoFaceTagEntity.create(
imageId = photo.imageId, imageId = photo.imageId,
faceModelId = faceModelId, faceModelId = faceModelId,
boundingBox = face.boundingBox, boundingBox = face.boundingBox,
confidence = similarity, confidence = targetSimilarity,
faceEmbedding = faceEmbedding faceEmbedding = faceEmbedding
) )
} else {
null
} }
} catch (e: Exception) { } catch (e: Exception) {
null // Skip this face
} }
} }
bitmap.recycle() bitmap.recycle()
tags
// Return only the best match (or empty)
if (bestMatch != null) listOf(bestMatch) else emptyList()
} catch (e: Exception) { } catch (e: Exception) {
emptyList() emptyList()