toofasttooclaude

This commit is contained in:
genki
2026-01-26 14:15:54 -05:00
parent 1ef8faad17
commit cfec2b980a
12 changed files with 661 additions and 42 deletions

View File

@@ -44,14 +44,15 @@ import com.placeholder.sherpai2.data.local.entity.*
PhotoFaceTagEntity::class,
PersonAgeTagEntity::class,
FaceCacheEntity::class,
UserFeedbackEntity::class, // NEW: User corrections
UserFeedbackEntity::class,
PersonStatisticsEntity::class, // Pre-computed person stats
// ===== COLLECTIONS =====
CollectionEntity::class,
CollectionImageEntity::class,
CollectionFilterEntity::class
],
version = 10, // INCREMENTED for user feedback
version = 11, // INCREMENTED for person statistics
exportSchema = false
)
abstract class AppDatabase : RoomDatabase() {
@@ -70,7 +71,8 @@ abstract class AppDatabase : RoomDatabase() {
abstract fun photoFaceTagDao(): PhotoFaceTagDao
abstract fun personAgeTagDao(): PersonAgeTagDao
abstract fun faceCacheDao(): FaceCacheDao
abstract fun userFeedbackDao(): UserFeedbackDao // NEW
abstract fun userFeedbackDao(): UserFeedbackDao
abstract fun personStatisticsDao(): PersonStatisticsDao
// ===== COLLECTIONS DAO =====
abstract fun collectionDao(): CollectionDao
@@ -242,13 +244,41 @@ val MIGRATION_9_10 = object : Migration(9, 10) {
}
}
/**
 * MIGRATION 10 → 11 (Person Statistics)
 *
 * Changes:
 * 1. Create person_statistics table for pre-computed per-person aggregates
 *    (photo count, first/last photo dates, average confidence), keyed by
 *    personId with ON DELETE CASCADE back to persons.
 *
 * NOTE(review): this CREATE TABLE must match the Room-generated schema of
 * PersonStatisticsEntity exactly (column names, affinities, NOT NULL,
 * defaults, foreign key), or Room will fail schema validation after the
 * migration runs — verify against the entity / exported schema.
 */
val MIGRATION_10_11 = object : Migration(10, 11) {
    override fun migrate(database: SupportSQLiteDatabase) {
        // Create person_statistics table
        database.execSQL("""
            CREATE TABLE IF NOT EXISTS person_statistics (
                personId TEXT PRIMARY KEY NOT NULL,
                photoCount INTEGER NOT NULL DEFAULT 0,
                firstPhotoDate INTEGER NOT NULL DEFAULT 0,
                lastPhotoDate INTEGER NOT NULL DEFAULT 0,
                averageConfidence REAL NOT NULL DEFAULT 0,
                agesWithPhotos TEXT,
                updatedAt INTEGER NOT NULL DEFAULT 0,
                FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE
            )
        """)
        // Index for sorting by photo count (People Dashboard)
        database.execSQL("CREATE INDEX IF NOT EXISTS index_person_statistics_photoCount ON person_statistics(photoCount)")
    }
}
/**
* PRODUCTION MIGRATION NOTES:
*
* Before shipping to users, update DatabaseModule to use migrations:
*
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11) // Add all migrations
* // .fallbackToDestructiveMigration() // Remove this
* .build()
*/

View File

@@ -233,6 +233,33 @@ interface FaceCacheDao {
limit: Int = 500
): List<FaceCacheEntity>
/**
 * Get premium face CANDIDATES - same criteria but WITHOUT embedding requirement.
 * Used to find faces that need embedding generation.
 *
 * Premium criteria: solo photo (image faceCount = 1), face covers at least
 * [minAreaRatio] of the frame, frontal pose, quality >= [minQuality], and
 * no embedding stored yet. Returned best-first (quality, then face area) so
 * the most useful faces get embeddings generated first.
 */
@Query("""
    SELECT fc.* FROM face_cache fc
    INNER JOIN images i ON fc.imageId = i.imageId
    WHERE i.faceCount = 1
    AND fc.faceAreaRatio >= :minAreaRatio
    AND fc.isFrontal = 1
    AND fc.qualityScore >= :minQuality
    AND fc.embedding IS NULL
    ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
    LIMIT :limit
""")
suspend fun getPremiumFaceCandidatesNeedingEmbeddings(
    minAreaRatio: Float = 0.10f,
    minQuality: Float = 0.7f,
    limit: Int = 500
): List<FaceCacheEntity>
/**
 * Update embedding for a face cache entry.
 *
 * The entry is identified by (imageId, faceIndex); [embedding] is the
 * serialized form produced by FaceCacheEntity.embeddingToJson
 * (comma-separated floats).
 */
@Query("UPDATE face_cache SET embedding = :embedding WHERE imageId = :imageId AND faceIndex = :faceIndex")
suspend fun updateEmbedding(imageId: String, faceIndex: Int, embedding: String)
/**
* Count of premium faces available
*/

View File

@@ -83,9 +83,89 @@ interface PhotoFaceTagDao {
*/
@Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit")
suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity>
// ===== CO-OCCURRENCE QUERIES =====

/**
 * Find people who appear in photos together with a given person.
 * Returns (otherFaceModelId, count of distinct shared images), sorted by
 * count descending via a self-join on imageId.
 *
 * Use case: "Who appears most with Mom?" or "Show photos of Mom WITH Dad".
 */
@Query("""
    SELECT pft2.faceModelId as otherFaceModelId, COUNT(DISTINCT pft1.imageId) as coCount
    FROM photo_face_tags pft1
    INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
    WHERE pft1.faceModelId = :faceModelId
    AND pft2.faceModelId != :faceModelId
    GROUP BY pft2.faceModelId
    ORDER BY coCount DESC
""")
suspend fun getCoOccurrences(faceModelId: String): List<PersonCoOccurrence>
/**
 * Get images where BOTH people appear together.
 *
 * FIX: the previous query used SELECT DISTINCT with ORDER BY on a column
 * (pft1.detectedAt) that was not in the result set, which SQLite rejects
 * ("ORDER BY term does not match any column in the result set").
 * GROUP BY imageId with ORDER BY MAX(detectedAt) is equivalent (unique
 * imageIds, most recently detected first) and is valid SQLite.
 */
@Query("""
    SELECT pft1.imageId
    FROM photo_face_tags pft1
    INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
    WHERE pft1.faceModelId = :faceModelId1
    AND pft2.faceModelId = :faceModelId2
    GROUP BY pft1.imageId
    ORDER BY MAX(pft1.detectedAt) DESC
""")
suspend fun getImagesWithBothPeople(faceModelId1: String, faceModelId2: String): List<String>
/**
 * Get images where person appears ALONE (no other trained faces).
 *
 * Excludes any image that also has a photo_face_tags row for a different
 * face model; results are newest detection first. Note "alone" is relative
 * to trained/tagged people only — untagged faces may still be present.
 */
@Query("""
    SELECT imageId FROM photo_face_tags
    WHERE faceModelId = :faceModelId
    AND imageId NOT IN (
        SELECT imageId FROM photo_face_tags
        WHERE faceModelId != :faceModelId
    )
    ORDER BY detectedAt DESC
""")
suspend fun getImagesWithPersonAlone(faceModelId: String): List<String>
/**
 * Get images where ALL specified people appear (N-way intersection).
 * For "Intersection Search" moonshot feature.
 *
 * Callers should pass requiredCount = faceModelIds.size; the HAVING clause
 * keeps only images matched by that many distinct face models.
 */
@Query("""
    SELECT imageId FROM photo_face_tags
    WHERE faceModelId IN (:faceModelIds)
    GROUP BY imageId
    HAVING COUNT(DISTINCT faceModelId) = :requiredCount
""")
suspend fun getImagesWithAllPeople(faceModelIds: List<String>, requiredCount: Int): List<String>
/**
 * Get images with at least N of the specified people (family portrait detection).
 *
 * Returns one row per image with how many of the requested members were
 * matched, largest groups first. (SQLite permits the select-list alias
 * memberCount in the HAVING clause.)
 */
@Query("""
    SELECT imageId, COUNT(DISTINCT faceModelId) as memberCount
    FROM photo_face_tags
    WHERE faceModelId IN (:faceModelIds)
    GROUP BY imageId
    HAVING memberCount >= :minMembers
    ORDER BY memberCount DESC
""")
suspend fun getFamilyPortraits(faceModelIds: List<String>, minMembers: Int): List<FamilyPortraitResult>
}
/**
 * Result row for [PhotoFaceTagDao.getFamilyPortraits]: an image plus how many
 * of the requested people were found in it.
 */
data class FamilyPortraitResult(
    val imageId: String,
    val memberCount: Int
)

/**
 * Projection of (face model, tagged photo count).
 * NOTE(review): not referenced by any query visible in this chunk — presumably
 * used by a count query elsewhere in the DAO; confirm before removing.
 */
data class FaceModelPhotoCount(
    val faceModelId: String,
    val photoCount: Int
)

/**
 * Result row for [PhotoFaceTagDao.getCoOccurrences]: another face model and
 * the number of distinct images it shares with the queried person.
 */
data class PersonCoOccurrence(
    val otherFaceModelId: String,
    val coCount: Int
)

View File

@@ -99,6 +99,13 @@ data class FaceCacheEntity(
companion object {
const val CURRENT_CACHE_VERSION = 1
/**
 * Serialize a FloatArray embedding to its stored string form: the float
 * values joined with "," (despite the name, this is plain CSV, not JSON).
 */
fun embeddingToJson(embedding: FloatArray): String =
    embedding.asSequence()
        .map(Float::toString)
        .joinToString(separator = ",")
/**
* Create from ML Kit face detection result
*/

View File

@@ -75,7 +75,21 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
)
try {
val imagesToScan = imageDao.getImagesNeedingFaceDetection()
// Get images that need face detection (hasFaces IS NULL)
var imagesToScan = imageDao.getImagesNeedingFaceDetection()
// CRITICAL FIX: Also check for images marked as having faces but no FaceCacheEntity
if (imagesToScan.isEmpty()) {
val faceStats = faceCacheDao.getCacheStats()
if (faceStats.totalFaces == 0) {
// FaceCacheEntity is empty - rescan images that have faces
val imagesWithFaces = imageDao.getImagesWithFaces()
if (imagesWithFaces.isNotEmpty()) {
Log.w(TAG, "FaceCacheEntity empty but ${imagesWithFaces.size} images have faces - rescanning")
imagesToScan = imagesWithFaces
}
}
}
if (imagesToScan.isEmpty()) {
Log.d(TAG, "No images need scanning")
@@ -184,7 +198,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
imageUri = image.imageUri
)
// Create FaceCacheEntity entries for each face
// Create FaceCacheEntity entries for each face (NO embeddings - generated on demand)
val faceCacheEntries = faces.mapIndexed { index, face ->
createFaceCacheEntry(
imageId = image.imageId,
@@ -205,7 +219,8 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
/**
* Create FaceCacheEntity from ML Kit Face
*
* Uses FaceCacheEntity.create() which calculates quality metrics automatically
* Uses FaceCacheEntity.create() which calculates quality metrics automatically.
* Embeddings are NOT generated here - they're generated on-demand in Training/Discovery.
*/
private fun createFaceCacheEntry(
imageId: String,
@@ -225,7 +240,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
imageHeight = imageHeight,
confidence = 0.9f, // High confidence from accurate detector
isFrontal = isFrontal,
embedding = null // Will be generated later during Discovery
embedding = null // Generated on-demand in Training/Discovery
)
}
@@ -312,13 +327,27 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
val imageStats = imageDao.getFaceCacheStats()
val faceStats = faceCacheDao.getCacheStats()
// CRITICAL FIX: If ImageEntity says "scanned" but FaceCacheEntity is empty,
// we need to re-scan. This happens after DB migration clears face_cache table.
val imagesWithFaces = imageStats?.imagesWithFaces ?: 0
val facesCached = faceStats.totalFaces
// If we have images marked as having faces but no FaceCacheEntity entries,
// those images need re-scanning
val needsRescan = if (imagesWithFaces > 0 && facesCached == 0) {
Log.w(TAG, "⚠️ FaceCacheEntity is empty but $imagesWithFaces images marked as having faces - forcing rescan")
imagesWithFaces
} else {
imageStats?.needsScanning ?: 0
}
CacheStats(
totalImages = imageStats?.totalImages ?: 0,
imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0,
imagesWithFaces = imageStats?.imagesWithFaces ?: 0,
imagesWithFaces = imagesWithFaces,
imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0,
needsScanning = imageStats?.needsScanning ?: 0,
totalFacesCached = faceStats.totalFaces,
needsScanning = needsRescan,
totalFacesCached = facesCached,
facesWithEmbeddings = faceStats.withEmbeddings,
averageQuality = faceStats.avgQuality
)

View File

@@ -20,6 +20,7 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.navigation.NavController
import coil.compose.AsyncImage
import com.placeholder.sherpai2.data.local.entity.TagEntity
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.FaceTagInfo
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel
import net.engawapg.lib.zoomable.rememberZoomState
import net.engawapg.lib.zoomable.zoomable
@@ -51,8 +52,12 @@ fun ImageDetailScreen(
}
val tags by viewModel.tags.collectAsStateWithLifecycle()
val faceTags by viewModel.faceTags.collectAsStateWithLifecycle()
var showTags by remember { mutableStateOf(false) }
// Total tag count for badge
val totalTagCount = tags.size + faceTags.size
// Navigation state
val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1
val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0
@@ -84,27 +89,35 @@ fun ImageDetailScreen(
horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically
) {
if (tags.isNotEmpty()) {
if (totalTagCount > 0) {
Badge(
containerColor = if (showTags)
MaterialTheme.colorScheme.primary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.tertiary
else
MaterialTheme.colorScheme.surfaceVariant
) {
Text(
tags.size.toString(),
totalTagCount.toString(),
color = if (showTags)
MaterialTheme.colorScheme.onPrimary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.onTertiary
else
MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
Icon(
if (showTags) Icons.Default.Label else Icons.Default.LocalOffer,
if (faceTags.isNotEmpty()) Icons.Default.Face
else if (showTags) Icons.Default.Label
else Icons.Default.LocalOffer,
"Show Tags",
tint = if (showTags)
MaterialTheme.colorScheme.primary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.tertiary
else
MaterialTheme.colorScheme.onSurfaceVariant
)
@@ -189,6 +202,30 @@ fun ImageDetailScreen(
contentPadding = PaddingValues(16.dp),
verticalArrangement = Arrangement.spacedBy(8.dp)
) {
// Face Tags Section (People in Photo)
if (faceTags.isNotEmpty()) {
item {
Text(
"People (${faceTags.size})",
style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.tertiary
)
}
items(faceTags, key = { it.tagId }) { faceTag ->
FaceTagCard(
faceTag = faceTag,
onRemove = { viewModel.removeFaceTag(faceTag) }
)
}
item {
Spacer(modifier = Modifier.height(8.dp))
}
}
// Regular Tags Section
item {
Text(
"Tags (${tags.size})",
@@ -197,7 +234,7 @@ fun ImageDetailScreen(
)
}
if (tags.isEmpty()) {
if (tags.isEmpty() && faceTags.isEmpty()) {
item {
Text(
"No tags yet",
@@ -205,6 +242,14 @@ fun ImageDetailScreen(
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
} else if (tags.isEmpty()) {
item {
Text(
"No other tags",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
items(tags, key = { it.tagId }) { tag ->
@@ -220,6 +265,83 @@ fun ImageDetailScreen(
}
}
/**
 * Card showing one person recognized in this photo via face recognition:
 * name, confidence (color-coded), and a remove button.
 *
 * FIX: the separator between "Face Recognition" and the confidence label was
 * an empty Text("") — restored the intended "•" bullet (Unicode characters
 * appear to have been stripped elsewhere in this change set).
 */
@Composable
private fun FaceTagCard(
    faceTag: FaceTagInfo,
    onRemove: () -> Unit
) {
    Card(
        modifier = Modifier.fillMaxWidth(),
        colors = CardDefaults.cardColors(
            containerColor = MaterialTheme.colorScheme.tertiaryContainer
        ),
        shape = RoundedCornerShape(8.dp)
    ) {
        Row(
            modifier = Modifier
                .fillMaxWidth()
                .padding(12.dp),
            horizontalArrangement = Arrangement.SpaceBetween,
            verticalAlignment = Alignment.CenterVertically
        ) {
            Column(modifier = Modifier.weight(1f)) {
                Row(
                    horizontalArrangement = Arrangement.spacedBy(8.dp),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Icon(
                        imageVector = Icons.Default.Face,
                        contentDescription = null,
                        modifier = Modifier.size(20.dp),
                        tint = MaterialTheme.colorScheme.tertiary
                    )
                    Text(
                        text = faceTag.personName,
                        style = MaterialTheme.typography.bodyLarge,
                        fontWeight = FontWeight.SemiBold
                    )
                }
                Row(
                    horizontalArrangement = Arrangement.spacedBy(4.dp),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Text(
                        text = "Face Recognition",
                        style = MaterialTheme.typography.labelSmall,
                        color = MaterialTheme.colorScheme.onSurfaceVariant
                    )
                    // Bullet separator between source label and confidence
                    Text(
                        text = "\u2022",
                        style = MaterialTheme.typography.labelSmall,
                        color = MaterialTheme.colorScheme.onSurfaceVariant
                    )
                    // Confidence color-coding: >=70% primary, >=50% secondary,
                    // below that error red.
                    Text(
                        text = "${(faceTag.confidence * 100).toInt()}% confidence",
                        style = MaterialTheme.typography.labelSmall,
                        color = if (faceTag.confidence >= 0.7f)
                            MaterialTheme.colorScheme.primary
                        else if (faceTag.confidence >= 0.5f)
                            MaterialTheme.colorScheme.secondary
                        else
                            MaterialTheme.colorScheme.error
                    )
                }
            }
            // Remove button
            IconButton(
                onClick = onRemove,
                colors = IconButtonDefaults.iconButtonColors(
                    contentColor = MaterialTheme.colorScheme.error
                )
            ) {
                Icon(Icons.Default.Delete, "Remove face tag")
            }
        }
    }
}
@Composable
private fun TagCard(
tag: TagEntity,

View File

@@ -2,6 +2,10 @@ package com.placeholder.sherpai2.ui.imagedetail.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.TagEntity
import com.placeholder.sherpai2.domain.repository.TaggingRepository
import dagger.hilt.android.lifecycle.HiltViewModel
@@ -10,17 +14,33 @@ import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch
import javax.inject.Inject
/**
 * Represents a person tagged in this photo via face recognition.
 *
 * @property personId    Id of the recognized person.
 * @property personName  Display name shown in the UI.
 * @property confidence  Match confidence in [0, 1] (rendered as a percentage).
 * @property faceModelId Id of the face model that produced the match.
 * @property tagId       Id of the underlying photo_face_tags row (used for removal).
 */
data class FaceTagInfo(
    val personId: String,
    val personName: String,
    val confidence: Float,
    val faceModelId: String,
    val tagId: String
)
/**
* ImageDetailViewModel
*
* Owns:
* - Image context
* - Tag write operations
* - Face tag display (people recognized in photo)
*/
@HiltViewModel
@OptIn(ExperimentalCoroutinesApi::class)
class ImageDetailViewModel @Inject constructor(
private val tagRepository: TaggingRepository
private val tagRepository: TaggingRepository,
private val imageDao: ImageDao,
private val photoFaceTagDao: PhotoFaceTagDao,
private val faceModelDao: FaceModelDao,
private val personDao: PersonDao
) : ViewModel() {
private val imageUri = MutableStateFlow<String?>(null)
@@ -37,8 +57,43 @@ class ImageDetailViewModel @Inject constructor(
initialValue = emptyList()
)
// Face tags (people recognized in this photo)
private val _faceTags = MutableStateFlow<List<FaceTagInfo>>(emptyList())
val faceTags: StateFlow<List<FaceTagInfo>> = _faceTags.asStateFlow()
/** Set the current image URI and kick off loading of its face tags. */
fun loadImage(uri: String) {
    imageUri.value = uri
    loadFaceTags(uri)
}
/**
 * Resolve the face tags for [uri] into display rows (person name +
 * confidence), sorted by confidence descending, and publish them to
 * [_faceTags].
 *
 * Tags whose face model or person can no longer be resolved are skipped.
 * Failures are treated as "no face tags" rather than surfaced as errors.
 *
 * FIX: the previous blanket catch also swallowed CancellationException;
 * cancellation is now rethrown so structured concurrency works correctly.
 */
private fun loadFaceTags(uri: String) {
    viewModelScope.launch {
        try {
            // Get imageId from URI
            val image = imageDao.getImageByUri(uri) ?: return@launch
            // Get face tags for this image
            val faceTags = photoFaceTagDao.getTagsForImage(image.imageId)
            // Resolve to person names; drop tags with dangling references
            val faceTagInfos = faceTags.mapNotNull { tag ->
                val faceModel = faceModelDao.getFaceModelById(tag.faceModelId) ?: return@mapNotNull null
                val person = personDao.getPersonById(faceModel.personId) ?: return@mapNotNull null
                FaceTagInfo(
                    personId = person.id,
                    personName = person.name,
                    confidence = tag.confidence,
                    faceModelId = tag.faceModelId,
                    tagId = tag.id
                )
            }
            _faceTags.value = faceTagInfos.sortedByDescending { it.confidence }
        } catch (e: kotlin.coroutines.cancellation.CancellationException) {
            throw e // never swallow cancellation
        } catch (e: Exception) {
            _faceTags.value = emptyList()
        }
    }
}
fun addTag(value: String) {
@@ -54,4 +109,15 @@ class ImageDetailViewModel @Inject constructor(
tagRepository.removeTagFromImage(uri, tag.value)
}
}
/**
 * Remove a face tag (person recognition) from the current image, then
 * refresh the face-tag list so the UI reflects the deletion.
 */
fun removeFaceTag(faceTagInfo: FaceTagInfo) {
    viewModelScope.launch {
        photoFaceTagDao.deleteTagById(faceTagInfo.tagId)
        // Reload face tags
        imageUri.value?.let { loadFaceTags(it) }
    }
}
}

View File

@@ -95,6 +95,9 @@ fun PersonInventoryScreen(
},
onDelete = { personId ->
viewModel.deletePerson(personId)
},
onClearTags = { personId ->
viewModel.clearTagsForPerson(personId)
}
)
}
@@ -319,7 +322,8 @@ private fun PersonList(
persons: List<PersonWithModelInfo>,
onScan: (String) -> Unit,
onView: (String) -> Unit,
onDelete: (String) -> Unit
onDelete: (String) -> Unit,
onClearTags: (String) -> Unit
) {
LazyColumn(
contentPadding = PaddingValues(vertical = 8.dp)
@@ -332,7 +336,8 @@ private fun PersonList(
person = person,
onScan = { onScan(person.person.id) },
onView = { onView(person.person.id) },
onDelete = { onDelete(person.person.id) }
onDelete = { onDelete(person.person.id) },
onClearTags = { onClearTags(person.person.id) }
)
}
}
@@ -343,9 +348,34 @@ private fun PersonCard(
person: PersonWithModelInfo,
onScan: () -> Unit,
onView: () -> Unit,
onDelete: () -> Unit
onDelete: () -> Unit,
onClearTags: () -> Unit
) {
var showDeleteDialog by remember { mutableStateOf(false) }
var showClearDialog by remember { mutableStateOf(false) }
if (showClearDialog) {
AlertDialog(
onDismissRequest = { showClearDialog = false },
title = { Text("Clear tags for ${person.person.name}?") },
text = { Text("This will remove all ${person.taggedPhotoCount} photo tags but keep the face model. You can re-scan after clearing.") },
confirmButton = {
TextButton(
onClick = {
showClearDialog = false
onClearTags()
}
) {
Text("Clear Tags", color = MaterialTheme.colorScheme.error)
}
},
dismissButton = {
TextButton(onClick = { showClearDialog = false }) {
Text("Cancel")
}
}
)
}
if (showDeleteDialog) {
AlertDialog(
@@ -413,6 +443,17 @@ private fun PersonCard(
)
}
// Clear tags button (if has tags)
if (person.taggedPhotoCount > 0) {
IconButton(onClick = { showClearDialog = true }) {
Icon(
Icons.Default.ClearAll,
contentDescription = "Clear Tags",
tint = MaterialTheme.colorScheme.secondary
)
}
}
// Delete button
IconButton(onClick = { showDeleteDialog = true }) {
Icon(

View File

@@ -105,6 +105,21 @@ class PersonInventoryViewModel @Inject constructor(
}
}
/**
 * Clear all face tags for a person (keep model, allow rescan).
 *
 * No-op if the person has no face model. The person list is reloaded
 * afterwards so tag counts refresh.
 *
 * FIX: the previous empty catch silently swallowed every failure, including
 * coroutine cancellation. Cancellation is now rethrown and other errors are
 * logged so tag-clearing failures are at least visible.
 */
fun clearTagsForPerson(personId: String) {
    viewModelScope.launch(Dispatchers.IO) {
        try {
            val faceModel = faceModelDao.getFaceModelByPersonId(personId)
            if (faceModel != null) {
                photoFaceTagDao.deleteTagsForFaceModel(faceModel.id)
            }
            loadPersons()
        } catch (e: kotlin.coroutines.cancellation.CancellationException) {
            throw e
        } catch (e: Exception) {
            android.util.Log.e("PersonInventory", "Failed to clear tags for person $personId", e)
        }
    }
}
fun scanForPerson(personId: String) {
viewModelScope.launch(Dispatchers.IO) {
try {
@@ -133,10 +148,20 @@ class PersonInventoryViewModel @Inject constructor(
.build()
val detector = FaceDetection.getClient(detectorOptions)
val modelEmbedding = faceModel.getEmbeddingArray()
val faceNetModel = FaceNetModel(context)
// CRITICAL: Use ALL centroids for matching
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
val trainingCount = faceModel.trainingImageCount
val baseThreshold = ThresholdStrategy.getLiberalThreshold(trainingCount)
android.util.Log.e("PersonScan", "=== CENTROIDS: ${modelCentroids.size}, trainingCount: $trainingCount ===")
if (modelCentroids.isEmpty()) {
_scanningState.value = ScanningState.Error("No centroids found")
return@launch
}
val faceNetModel = FaceNetModel(context)
// Production threshold - balance precision vs recall
val baseThreshold = 0.58f
android.util.Log.d("PersonScan", "Using threshold: $baseThreshold, centroids: ${modelCentroids.size}")
val completed = AtomicInteger(0)
val facesFound = AtomicInteger(0)
@@ -148,7 +173,7 @@ class PersonInventoryViewModel @Inject constructor(
val jobs = untaggedImages.map { image ->
async {
semaphore.withPermit {
processImage(image, detector, faceNetModel, modelEmbedding, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
processImage(image, detector, faceNetModel, modelCentroids, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
}
}
}
@@ -175,7 +200,7 @@ class PersonInventoryViewModel @Inject constructor(
private suspend fun processImage(
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
modelEmbedding: FloatArray, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String,
modelCentroids: List<FloatArray>, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String,
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
) {
@@ -212,14 +237,19 @@ class PersonInventoryViewModel @Inject constructor(
(face.boundingBox.bottom * scaleY).toInt()
)
val faceBitmap = loadFaceRegion(uri, scaledBounds) ?: continue
// CRITICAL: Add padding to face crop (same as training)
val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
faceBitmap.recycle()
if (similarity >= threshold) {
// Match against ALL centroids, use best match
val bestSimilarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
if (bestSimilarity >= threshold) {
batchUpdateMutex.withLock {
batchMatches.add(Triple(personId, image.imageId, similarity))
batchMatches.add(Triple(personId, image.imageId, bestSimilarity))
facesFound.incrementAndGet()
if (batchMatches.size >= BATCH_DB_SIZE) {
saveBatchMatches(batchMatches.toList(), faceModelId)
@@ -250,18 +280,32 @@ class PersonInventoryViewModel @Inject constructor(
} catch (e: Exception) { null }
}
private fun loadFaceRegion(uri: Uri, bounds: android.graphics.Rect): Bitmap? {
/**
* Load face region WITH 25% padding - CRITICAL for matching training conditions
*/
private fun loadFaceRegionWithPadding(uri: Uri, bounds: android.graphics.Rect, imgWidth: Int, imgHeight: Int): Bitmap? {
return try {
val full = context.contentResolver.openInputStream(uri)?.use {
BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 })
} ?: return null
val safeLeft = bounds.left.coerceIn(0, full.width - 1)
val safeTop = bounds.top.coerceIn(0, full.height - 1)
val safeWidth = bounds.width().coerceAtMost(full.width - safeLeft)
val safeHeight = bounds.height().coerceAtMost(full.height - safeTop)
// Add 25% padding (same as training)
val padding = (kotlin.math.max(bounds.width(), bounds.height()) * 0.25f).toInt()
val cropped = Bitmap.createBitmap(full, safeLeft, safeTop, safeWidth, safeHeight)
val left = (bounds.left - padding).coerceAtLeast(0)
val top = (bounds.top - padding).coerceAtLeast(0)
val right = (bounds.right + padding).coerceAtMost(full.width)
val bottom = (bounds.bottom + padding).coerceAtMost(full.height)
val width = right - left
val height = bottom - top
if (width <= 0 || height <= 0) {
full.recycle()
return null
}
val cropped = Bitmap.createBitmap(full, left, top, width, height)
full.recycle()
cropped
} catch (e: Exception) { null }

View File

@@ -192,10 +192,11 @@ class TrainViewModel @Inject constructor(
.first()
if (backgroundTaggingEnabled) {
// Lower threshold (0.55) since we use multi-centroid matching
val scanRequest = LibraryScanWorker.createWorkRequest(
personId = personId,
personName = personName,
threshold = 0.65f
threshold = 0.55f
)
workManager.enqueue(scanRequest)
}

View File

@@ -49,6 +49,7 @@ fun TrainingPhotoSelectorScreen(
val isRanking by viewModel.isRanking.collectAsStateWithLifecycle()
val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle()
val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle()
val embeddingProgress by viewModel.embeddingProgress.collectAsStateWithLifecycle()
Scaffold(
topBar = {
@@ -154,8 +155,34 @@ fun TrainingPhotoSelectorScreen(
Box(
modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.Center
) {
Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
CircularProgressIndicator()
// Capture value to avoid race condition
val progress = embeddingProgress
if (progress != null) {
Text(
"Preparing faces: ${progress.current}/${progress.total}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
LinearProgressIndicator(
progress = { progress.current.toFloat() / progress.total },
modifier = Modifier
.width(200.dp)
.padding(top = 8.dp)
)
} else {
Text(
"Loading premium faces...",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
}
}
photos.isEmpty() -> {

View File

@@ -1,20 +1,31 @@
package com.placeholder.sherpai2.ui.trainingprep
import android.app.Application
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Rect
import android.net.Uri
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import javax.inject.Inject
import kotlin.math.max
import kotlin.math.min
/**
* TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN
@@ -27,15 +38,18 @@ import javax.inject.Inject
*/
@HiltViewModel
class TrainingPhotoSelectorViewModel @Inject constructor(
application: Application,
private val imageDao: ImageDao,
private val faceCacheDao: FaceCacheDao,
private val faceSimilarityScorer: FaceSimilarityScorer
) : ViewModel() {
private val faceSimilarityScorer: FaceSimilarityScorer,
private val faceNetModel: FaceNetModel
) : AndroidViewModel(application) {
companion object {
private const val TAG = "PremiumSelector"
private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1
private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5
private const val MAX_EMBEDDINGS_TO_GENERATE = 500
}
// All photos (for fallback / full list)
@@ -56,6 +70,12 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
private val _isRanking = MutableStateFlow(false)
val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow()
// Embedding generation progress
private val _embeddingProgress = MutableStateFlow<EmbeddingProgress?>(null)
val embeddingProgress: StateFlow<EmbeddingProgress?> = _embeddingProgress.asStateFlow()
data class EmbeddingProgress(val current: Int, val total: Int)
// Premium mode toggle
private val _showPremiumOnly = MutableStateFlow(true)
val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow()
@@ -79,20 +99,47 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
/**
* Load PREMIUM faces first (solo, large, frontal, high quality)
* If no embeddings exist, generate them on-demand for premium candidates
*/
private fun loadPremiumFaces() {
viewModelScope.launch {
try {
_isLoading.value = true
// Get premium faces from cache
val premiumFaceCache = faceCacheDao.getPremiumFaces(
// First check if premium faces with embeddings exist
var premiumFaceCache = faceCacheDao.getPremiumFaces(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = 500
)
Log.d(TAG, " Found ${premiumFaceCache.size} premium faces")
Log.d(TAG, "📊 Found ${premiumFaceCache.size} premium faces with embeddings")
// If no premium faces with embeddings, generate them on-demand
if (premiumFaceCache.isEmpty()) {
Log.d(TAG, "⚠️ No premium faces with embeddings - generating on-demand")
val candidates = faceCacheDao.getPremiumFaceCandidatesNeedingEmbeddings(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = MAX_EMBEDDINGS_TO_GENERATE
)
Log.d(TAG, "📦 Found ${candidates.size} premium candidates needing embeddings")
if (candidates.isNotEmpty()) {
generateEmbeddingsForCandidates(candidates)
// Re-query after generating
premiumFaceCache = faceCacheDao.getPremiumFaces(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = 500
)
Log.d(TAG, "✅ After generation: ${premiumFaceCache.size} premium faces")
}
}
_premiumCount.value = premiumFaceCache.size
// Get corresponding ImageEntities
@@ -117,10 +164,108 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
loadAllFaces()
} finally {
_isLoading.value = false
_embeddingProgress.value = null
}
}
}
/**
 * Generate embeddings for premium face candidates.
 *
 * For each candidate: load a downsampled bitmap, crop the face with padding
 * (matching training conditions), run FaceNet, and persist the embedding.
 * Progress is published to [_embeddingProgress] after every candidate,
 * success or failure.
 *
 * FIXES:
 * - CancellationException is rethrown instead of being swallowed by the
 *   per-candidate catch, so cancelling the ViewModel scope stops the loop.
 * - Progress writes MutableStateFlow.value directly — StateFlow is
 *   thread-safe, so the per-item withContext(Dispatchers.Main) hop was
 *   unnecessary overhead.
 */
private suspend fun generateEmbeddingsForCandidates(candidates: List<FaceCacheEntity>) {
    val context = getApplication<Application>()
    val total = candidates.size
    var processed = 0
    withContext(Dispatchers.IO) {
        // Resolve image URIs for all candidates in a single query
        val imageIds = candidates.map { it.imageId }.distinct()
        val images = imageDao.getImagesByIds(imageIds)
        val imageUriMap = images.associate { it.imageId to it.imageUri }
        for (candidate in candidates) {
            try {
                val imageUri = imageUriMap[candidate.imageId] ?: continue
                // Load bitmap (downsampled)
                val bitmap = loadBitmapOptimized(context, Uri.parse(imageUri)) ?: continue
                // Crop face with padding
                val croppedFace = cropFaceWithPadding(bitmap, candidate.getBoundingBox())
                bitmap.recycle()
                if (croppedFace == null) continue
                // Generate embedding
                val embedding = faceNetModel.generateEmbedding(croppedFace)
                croppedFace.recycle()
                // An all-zero vector indicates a failed inference — skip it
                if (embedding.any { it != 0f }) {
                    val embeddingJson = FaceCacheEntity.embeddingToJson(embedding)
                    faceCacheDao.updateEmbedding(candidate.imageId, candidate.faceIndex, embeddingJson)
                }
            } catch (e: kotlin.coroutines.cancellation.CancellationException) {
                throw e
            } catch (e: Exception) {
                Log.w(TAG, "Failed to generate embedding for ${candidate.imageId}: ${e.message}")
            }
            processed++
            // MutableStateFlow is thread-safe; safe to update from IO
            _embeddingProgress.value = EmbeddingProgress(processed, total)
        }
    }
    Log.d(TAG, "✅ Generated embeddings for $processed/$total candidates")
}
/**
 * Decode a content URI into a Bitmap, subsampled so neither dimension
 * exceeds [maxDim]. Returns null if the image cannot be read.
 */
private fun loadBitmapOptimized(context: android.content.Context, uri: Uri, maxDim: Int = 768): Bitmap? {
    return try {
        // Pass 1: read only the image dimensions
        val bounds = BitmapFactory.Options().apply { inJustDecodeBounds = true }
        context.contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream, null, bounds)
        }
        // Smallest power-of-two sample size fitting both sides within maxDim
        var inSample = 1
        while (bounds.outWidth / inSample > maxDim || bounds.outHeight / inSample > maxDim) {
            inSample *= 2
        }
        // Pass 2: decode for real at the chosen sample size
        val decodeOptions = BitmapFactory.Options().apply {
            inSampleSize = inSample
            inPreferredConfig = Bitmap.Config.ARGB_8888
        }
        context.contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream, null, decodeOptions)
        }
    } catch (e: Exception) {
        Log.w(TAG, "Failed to load bitmap: ${e.message}")
        null
    }
}
/**
 * Crop [boundingBox] out of [bitmap] with 25% padding on each side, clamped
 * to the bitmap edges. Returns null for degenerate regions or on failure.
 */
private fun cropFaceWithPadding(bitmap: Bitmap, boundingBox: Rect): Bitmap? {
    return try {
        // Pad by a quarter of the larger box dimension
        val pad = (max(boundingBox.width(), boundingBox.height()) * 0.25f).toInt()
        val x0 = (boundingBox.left - pad).coerceAtLeast(0)
        val y0 = (boundingBox.top - pad).coerceAtLeast(0)
        val x1 = (boundingBox.right + pad).coerceAtMost(bitmap.width)
        val y1 = (boundingBox.bottom + pad).coerceAtMost(bitmap.height)
        if (x1 > x0 && y1 > y0) {
            Bitmap.createBitmap(bitmap, x0, y0, x1 - x0, y1 - y0)
        } else {
            null
        }
    } catch (e: Exception) {
        Log.w(TAG, "Failed to crop face: ${e.message}")
        null
    }
}
/**
* Fallback: load all photos with faces
*/