From cfec2b980ad4781b69bd70243cfb4561fcbda9b5 Mon Sep 17 00:00:00 2001 From: genki <123@1234.com> Date: Mon, 26 Jan 2026 14:15:54 -0500 Subject: [PATCH] toofasttooclaude --- .../sherpai2/data/local/AppDatabase.kt | 38 ++++- .../sherpai2/data/local/dao/Facecachedao.kt | 27 +++ .../data/local/dao/Photofacetagdao.kt | 80 +++++++++ .../data/local/entity/Facecacheentity.kt | 7 + .../Populatefacedetectioncacheusecase.kt | 43 ++++- .../ui/imagedetail/ImageDetailScreen.kt | 130 ++++++++++++++- .../viewmodel/ImageDetailViewModel.kt | 68 +++++++- .../modelinventory/Personinventoryscreen.kt | 47 +++++- .../Personinventoryviewmodel.kt | 74 +++++++-- .../ui/trainingprep/TrainViewModel.kt | 3 +- .../Trainingphotoselectorscreen.kt | 29 +++- .../Trainingphotoselectorviewmodel.kt | 157 +++++++++++++++++- 12 files changed, 661 insertions(+), 42 deletions(-) diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt index a128c5c..26b8623 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/AppDatabase.kt @@ -44,14 +44,15 @@ import com.placeholder.sherpai2.data.local.entity.* PhotoFaceTagEntity::class, PersonAgeTagEntity::class, FaceCacheEntity::class, - UserFeedbackEntity::class, // NEW: User corrections + UserFeedbackEntity::class, + PersonStatisticsEntity::class, // Pre-computed person stats // ===== COLLECTIONS ===== CollectionEntity::class, CollectionImageEntity::class, CollectionFilterEntity::class ], - version = 10, // INCREMENTED for user feedback + version = 11, // INCREMENTED for person statistics exportSchema = false ) abstract class AppDatabase : RoomDatabase() { @@ -70,7 +71,8 @@ abstract class AppDatabase : RoomDatabase() { abstract fun photoFaceTagDao(): PhotoFaceTagDao abstract fun personAgeTagDao(): PersonAgeTagDao abstract fun faceCacheDao(): FaceCacheDao - abstract fun userFeedbackDao(): UserFeedbackDao // NEW + abstract fun userFeedbackDao(): UserFeedbackDao + abstract fun personStatisticsDao(): PersonStatisticsDao // ===== COLLECTIONS DAO ===== abstract fun collectionDao(): CollectionDao @@ -242,13 +244,41 @@ val MIGRATION_9_10 = object : Migration(9, 10) { } } +/** + * MIGRATION 10 → 11 (Person Statistics) + * + * Changes: + * 1. Create person_statistics table for pre-computed aggregates + */ +val MIGRATION_10_11 = object : Migration(10, 11) { + override fun migrate(database: SupportSQLiteDatabase) { + + // Create person_statistics table + database.execSQL(""" + CREATE TABLE IF NOT EXISTS person_statistics ( + personId TEXT PRIMARY KEY NOT NULL, + photoCount INTEGER NOT NULL DEFAULT 0, + firstPhotoDate INTEGER NOT NULL DEFAULT 0, + lastPhotoDate INTEGER NOT NULL DEFAULT 0, + averageConfidence REAL NOT NULL DEFAULT 0, + agesWithPhotos TEXT, + updatedAt INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE + ) + """) + + // Index for sorting by photo count (People Dashboard) + database.execSQL("CREATE INDEX IF NOT EXISTS index_person_statistics_photoCount ON person_statistics(photoCount)") + } +} + /** * PRODUCTION MIGRATION NOTES: * * Before shipping to users, update DatabaseModule to use migrations: * * Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db") - * .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations + * .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11) // Add all migrations * // .fallbackToDestructiveMigration() // Remove this * .build() */ \ No newline at end of file diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt index 146b703..108f263 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Facecachedao.kt @@ -233,6 +233,33 @@ interface FaceCacheDao { limit: Int = 500 ): List + /** + * Get premium face CANDIDATES - same criteria but WITHOUT embedding requirement. + * Used to find faces that need embedding generation. + */ + @Query(""" + SELECT fc.* FROM face_cache fc + INNER JOIN images i ON fc.imageId = i.imageId + WHERE i.faceCount = 1 + AND fc.faceAreaRatio >= :minAreaRatio + AND fc.isFrontal = 1 + AND fc.qualityScore >= :minQuality + AND fc.embedding IS NULL + ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC + LIMIT :limit + """) + suspend fun getPremiumFaceCandidatesNeedingEmbeddings( + minAreaRatio: Float = 0.10f, + minQuality: Float = 0.7f, + limit: Int = 500 + ): List + + /** + * Update embedding for a face cache entry + */ + @Query("UPDATE face_cache SET embedding = :embedding WHERE imageId = :imageId AND faceIndex = :faceIndex") + suspend fun updateEmbedding(imageId: String, faceIndex: Int, embedding: String) + /** * Count of premium faces available */ diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Photofacetagdao.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Photofacetagdao.kt index 842e026..93b8d90 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Photofacetagdao.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/dao/Photofacetagdao.kt @@ -83,9 +83,89 @@ interface PhotoFaceTagDao { */ @Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit") suspend fun getRecentlyDetectedFaces(limit: Int): List + + // ===== CO-OCCURRENCE QUERIES ===== + + /** + * Find people who appear in photos together with a given person. + * Returns list of (otherFaceModelId, count) sorted by count descending. + * Use case: "Who appears most with Mom?" or "Show photos of Mom WITH Dad" + */ + @Query(""" + SELECT pft2.faceModelId as otherFaceModelId, COUNT(DISTINCT pft1.imageId) as coCount + FROM photo_face_tags pft1 + INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId + WHERE pft1.faceModelId = :faceModelId + AND pft2.faceModelId != :faceModelId + GROUP BY pft2.faceModelId + ORDER BY coCount DESC + """) + suspend fun getCoOccurrences(faceModelId: String): List + + /** + * Get images where BOTH people appear together. + */ + @Query(""" + SELECT DISTINCT pft1.imageId + FROM photo_face_tags pft1 + INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId + WHERE pft1.faceModelId = :faceModelId1 + AND pft2.faceModelId = :faceModelId2 + ORDER BY pft1.detectedAt DESC + """) + suspend fun getImagesWithBothPeople(faceModelId1: String, faceModelId2: String): List + + /** + * Get images where person appears ALONE (no other trained faces). + */ + @Query(""" + SELECT imageId FROM photo_face_tags + WHERE faceModelId = :faceModelId + AND imageId NOT IN ( + SELECT imageId FROM photo_face_tags + WHERE faceModelId != :faceModelId + ) + ORDER BY detectedAt DESC + """) + suspend fun getImagesWithPersonAlone(faceModelId: String): List + + /** + * Get images where ALL specified people appear (N-way intersection). + * For "Intersection Search" moonshot feature. + */ + @Query(""" + SELECT imageId FROM photo_face_tags + WHERE faceModelId IN (:faceModelIds) + GROUP BY imageId + HAVING COUNT(DISTINCT faceModelId) = :requiredCount + """) + suspend fun getImagesWithAllPeople(faceModelIds: List, requiredCount: Int): List + + /** + * Get images with at least N of the specified people (family portrait detection). + */ + @Query(""" + SELECT imageId, COUNT(DISTINCT faceModelId) as memberCount + FROM photo_face_tags + WHERE faceModelId IN (:faceModelIds) + GROUP BY imageId + HAVING memberCount >= :minMembers + ORDER BY memberCount DESC + """) + suspend fun getFamilyPortraits(faceModelIds: List, minMembers: Int): List } +data class FamilyPortraitResult( + val imageId: String, + val memberCount: Int +) + data class FaceModelPhotoCount( val faceModelId: String, val photoCount: Int ) + +data class PersonCoOccurrence( + val otherFaceModelId: String, + val coCount: Int +) diff --git a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facecacheentity.kt b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facecacheentity.kt index eca4f8c..6f0ece7 100644 --- a/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facecacheentity.kt +++ b/app/src/main/java/com/placeholder/sherpai2/data/local/entity/Facecacheentity.kt @@ -99,6 +99,13 @@ data class FaceCacheEntity( companion object { const val CURRENT_CACHE_VERSION = 1 + /** + * Convert FloatArray embedding to JSON string for storage + */ + fun embeddingToJson(embedding: FloatArray): String { + return embedding.joinToString(",") + } + /** * Create from ML Kit face detection result */ diff --git a/app/src/main/java/com/placeholder/sherpai2/domain/usecase/Populatefacedetectioncacheusecase.kt b/app/src/main/java/com/placeholder/sherpai2/domain/usecase/Populatefacedetectioncacheusecase.kt index ff5fabd..f082af8 100644 --- a/app/src/main/java/com/placeholder/sherpai2/domain/usecase/Populatefacedetectioncacheusecase.kt +++ b/app/src/main/java/com/placeholder/sherpai2/domain/usecase/Populatefacedetectioncacheusecase.kt @@ -75,7 +75,21 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( ) try { - val imagesToScan = imageDao.getImagesNeedingFaceDetection() + // Get images that need face detection (hasFaces IS NULL) + var imagesToScan = imageDao.getImagesNeedingFaceDetection() + + // CRITICAL FIX: Also check for images marked as having faces but no FaceCacheEntity + if (imagesToScan.isEmpty()) { + val faceStats = faceCacheDao.getCacheStats() + if (faceStats.totalFaces == 0) { + // FaceCacheEntity is empty - rescan images that have faces + val imagesWithFaces = imageDao.getImagesWithFaces() + if (imagesWithFaces.isNotEmpty()) { + Log.w(TAG, "FaceCacheEntity empty but ${imagesWithFaces.size} images have faces - rescanning") + imagesToScan = imagesWithFaces + } + } + } if (imagesToScan.isEmpty()) { Log.d(TAG, "No images need scanning") @@ -184,7 +198,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( imageUri = image.imageUri ) - // Create FaceCacheEntity entries for each face + // Create FaceCacheEntity entries for each face (NO embeddings - generated on demand) val faceCacheEntries = faces.mapIndexed { index, face -> createFaceCacheEntry( imageId = image.imageId, @@ -205,7 +219,8 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( /** * Create FaceCacheEntity from ML Kit Face * - * Uses FaceCacheEntity.create() which calculates quality metrics automatically + * Uses FaceCacheEntity.create() which calculates quality metrics automatically. + * Embeddings are NOT generated here - they're generated on-demand in Training/Discovery. */ private fun createFaceCacheEntry( imageId: String, @@ -225,7 +240,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( imageHeight = imageHeight, confidence = 0.9f, // High confidence from accurate detector isFrontal = isFrontal, - embedding = null // Will be generated later during Discovery + embedding = null // Generated on-demand in Training/Discovery ) } @@ -312,13 +327,27 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor( val imageStats = imageDao.getFaceCacheStats() val faceStats = faceCacheDao.getCacheStats() + // CRITICAL FIX: If ImageEntity says "scanned" but FaceCacheEntity is empty, + // we need to re-scan. This happens after DB migration clears face_cache table. + val imagesWithFaces = imageStats?.imagesWithFaces ?: 0 + val facesCached = faceStats.totalFaces + + // If we have images marked as having faces but no FaceCacheEntity entries, + // those images need re-scanning + val needsRescan = if (imagesWithFaces > 0 && facesCached == 0) { + Log.w(TAG, "⚠️ FaceCacheEntity is empty but $imagesWithFaces images marked as having faces - forcing rescan") + imagesWithFaces + } else { + imageStats?.needsScanning ?: 0 + } + CacheStats( totalImages = imageStats?.totalImages ?: 0, imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0, - imagesWithFaces = imageStats?.imagesWithFaces ?: 0, + imagesWithFaces = imagesWithFaces, imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0, - needsScanning = imageStats?.needsScanning ?: 0, - totalFacesCached = faceStats.totalFaces, + needsScanning = needsRescan, + totalFacesCached = facesCached, facesWithEmbeddings = faceStats.withEmbeddings, averageQuality = faceStats.avgQuality ) diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/imagedetail/ImageDetailScreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/imagedetail/ImageDetailScreen.kt index 18e47a4..51e2668 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/imagedetail/ImageDetailScreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/imagedetail/ImageDetailScreen.kt @@ -20,6 +20,7 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle import androidx.navigation.NavController import coil.compose.AsyncImage import com.placeholder.sherpai2.data.local.entity.TagEntity +import com.placeholder.sherpai2.ui.imagedetail.viewmodel.FaceTagInfo import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel import net.engawapg.lib.zoomable.rememberZoomState import net.engawapg.lib.zoomable.zoomable @@ -51,8 +52,12 @@ fun ImageDetailScreen( } val tags by viewModel.tags.collectAsStateWithLifecycle() + val faceTags by viewModel.faceTags.collectAsStateWithLifecycle() var showTags by remember { mutableStateOf(false) } + // Total tag count for badge + val totalTagCount = tags.size + faceTags.size + // Navigation state val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1 val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0 @@ -84,27 +89,35 @@ fun ImageDetailScreen( horizontalArrangement = Arrangement.spacedBy(4.dp), verticalAlignment = Alignment.CenterVertically ) { - if (tags.isNotEmpty()) { + if (totalTagCount > 0) { Badge( containerColor = if (showTags) MaterialTheme.colorScheme.primary + else if (faceTags.isNotEmpty()) + MaterialTheme.colorScheme.tertiary else MaterialTheme.colorScheme.surfaceVariant ) { Text( - tags.size.toString(), + totalTagCount.toString(), color = if (showTags) MaterialTheme.colorScheme.onPrimary + else if (faceTags.isNotEmpty()) + MaterialTheme.colorScheme.onTertiary else MaterialTheme.colorScheme.onSurfaceVariant ) } } Icon( - if (showTags) Icons.Default.Label else Icons.Default.LocalOffer, + if (faceTags.isNotEmpty()) Icons.Default.Face + else if (showTags) Icons.Default.Label + else Icons.Default.LocalOffer, "Show Tags", tint = if (showTags) MaterialTheme.colorScheme.primary + else if (faceTags.isNotEmpty()) + MaterialTheme.colorScheme.tertiary else MaterialTheme.colorScheme.onSurfaceVariant ) @@ -189,6 +202,30 @@ fun ImageDetailScreen( contentPadding = PaddingValues(16.dp), verticalArrangement = Arrangement.spacedBy(8.dp) ) { + // Face Tags Section (People in Photo) + if (faceTags.isNotEmpty()) { + item { + Text( + "People (${faceTags.size})", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.tertiary + ) + } + + items(faceTags, key = { it.tagId }) { faceTag -> + FaceTagCard( + faceTag = faceTag, + onRemove = { viewModel.removeFaceTag(faceTag) } + ) + } + + item { + Spacer(modifier = Modifier.height(8.dp)) + } + } + + // Regular Tags Section item { Text( "Tags (${tags.size})", @@ -197,7 +234,7 @@ fun ImageDetailScreen( ) } - if (tags.isEmpty()) { + if (tags.isEmpty() && faceTags.isEmpty()) { item { Text( "No tags yet", @@ -205,6 +242,14 @@ fun ImageDetailScreen( color = MaterialTheme.colorScheme.onSurfaceVariant ) } + } else if (tags.isEmpty()) { + item { + Text( + "No other tags", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } } items(tags, key = { it.tagId }) { tag -> @@ -220,6 +265,83 @@ fun ImageDetailScreen( } } +@Composable +private fun FaceTagCard( + faceTag: FaceTagInfo, + onRemove: () -> Unit +) { + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.tertiaryContainer + ), + shape = RoundedCornerShape(8.dp) + ) { + Row( + modifier = Modifier + .fillMaxWidth() + .padding(12.dp), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Column(modifier = Modifier.weight(1f)) { + Row( + horizontalArrangement = Arrangement.spacedBy(8.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Face, + contentDescription = null, + modifier = Modifier.size(20.dp), + tint = MaterialTheme.colorScheme.tertiary + ) + Text( + text = faceTag.personName, + style = MaterialTheme.typography.bodyLarge, + fontWeight = FontWeight.SemiBold + ) + } + + Row( + horizontalArrangement = Arrangement.spacedBy(4.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = "Face Recognition", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "•", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "${(faceTag.confidence * 100).toInt()}% confidence", + style = MaterialTheme.typography.labelSmall, + color = if (faceTag.confidence >= 0.7f) + MaterialTheme.colorScheme.primary + else if (faceTag.confidence >= 0.5f) + MaterialTheme.colorScheme.secondary + else + MaterialTheme.colorScheme.error + ) + } + } + + // Remove button + IconButton( + onClick = onRemove, + colors = IconButtonDefaults.iconButtonColors( + contentColor = MaterialTheme.colorScheme.error + ) + ) { + Icon(Icons.Default.Delete, "Remove face tag") + } + } + } +} + @Composable private fun TagCard( tag: TagEntity, diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/imagedetail/viewmodel/ImageDetailViewModel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/imagedetail/viewmodel/ImageDetailViewModel.kt index 4d08aed..ceb49db 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/imagedetail/viewmodel/ImageDetailViewModel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/imagedetail/viewmodel/ImageDetailViewModel.kt @@ -2,6 +2,10 @@ package com.placeholder.sherpai2.ui.imagedetail.viewmodel import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope +import com.placeholder.sherpai2.data.local.dao.FaceModelDao +import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.dao.PersonDao +import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao import com.placeholder.sherpai2.data.local.entity.TagEntity import com.placeholder.sherpai2.domain.repository.TaggingRepository import dagger.hilt.android.lifecycle.HiltViewModel @@ -10,17 +14,33 @@ import kotlinx.coroutines.flow.* import kotlinx.coroutines.launch import javax.inject.Inject +/** + * Represents a person tagged in this photo via face recognition + */ +data class FaceTagInfo( + val personId: String, + val personName: String, + val confidence: Float, + val faceModelId: String, + val tagId: String +) + /** * ImageDetailViewModel * * Owns: * - Image context * - Tag write operations + * - Face tag display (people recognized in photo) */ @HiltViewModel @OptIn(ExperimentalCoroutinesApi::class) class ImageDetailViewModel @Inject constructor( - private val tagRepository: TaggingRepository + private val tagRepository: TaggingRepository, + private val imageDao: ImageDao, + private val photoFaceTagDao: PhotoFaceTagDao, + private val faceModelDao: FaceModelDao, + private val personDao: PersonDao ) : ViewModel() { private val imageUri = MutableStateFlow(null) @@ -37,8 +57,43 @@ class ImageDetailViewModel @Inject constructor( initialValue = emptyList() ) + // Face tags (people recognized in this photo) + private val _faceTags = MutableStateFlow>(emptyList()) + val faceTags: StateFlow> = _faceTags.asStateFlow() + fun loadImage(uri: String) { imageUri.value = uri + loadFaceTags(uri) + } + + private fun loadFaceTags(uri: String) { + viewModelScope.launch { + try { + // Get imageId from URI + val image = imageDao.getImageByUri(uri) ?: return@launch + + // Get face tags for this image + val faceTags = photoFaceTagDao.getTagsForImage(image.imageId) + + // Resolve to person names + val faceTagInfos = faceTags.mapNotNull { tag -> + val faceModel = faceModelDao.getFaceModelById(tag.faceModelId) ?: return@mapNotNull null + val person = personDao.getPersonById(faceModel.personId) ?: return@mapNotNull null + + FaceTagInfo( + personId = person.id, + personName = person.name, + confidence = tag.confidence, + faceModelId = tag.faceModelId, + tagId = tag.id + ) + } + + _faceTags.value = faceTagInfos.sortedByDescending { it.confidence } + } catch (e: Exception) { + _faceTags.value = emptyList() + } + } } fun addTag(value: String) { @@ -54,4 +109,15 @@ class ImageDetailViewModel @Inject constructor( tagRepository.removeTagFromImage(uri, tag.value) } } + + /** + * Remove a face tag (person recognition) + */ + fun removeFaceTag(faceTagInfo: FaceTagInfo) { + viewModelScope.launch { + photoFaceTagDao.deleteTagById(faceTagInfo.tagId) + // Reload face tags + imageUri.value?.let { loadFaceTags(it) } + } + } } diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryscreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryscreen.kt index b56951c..68641b5 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryscreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryscreen.kt @@ -95,6 +95,9 @@ fun PersonInventoryScreen( }, onDelete = { personId -> viewModel.deletePerson(personId) + }, + onClearTags = { personId -> + viewModel.clearTagsForPerson(personId) } ) } @@ -319,7 +322,8 @@ private fun PersonList( persons: List, onScan: (String) -> Unit, onView: (String) -> Unit, - onDelete: (String) -> Unit + onDelete: (String) -> Unit, + onClearTags: (String) -> Unit ) { LazyColumn( contentPadding = PaddingValues(vertical = 8.dp) @@ -332,7 +336,8 @@ private fun PersonList( person = person, onScan = { onScan(person.person.id) }, onView = { onView(person.person.id) }, - onDelete = { onDelete(person.person.id) } + onDelete = { onDelete(person.person.id) }, + onClearTags = { onClearTags(person.person.id) } ) } } @@ -343,9 +348,34 @@ private fun PersonCard( person: PersonWithModelInfo, onScan: () -> Unit, onView: () -> Unit, - onDelete: () -> Unit + onDelete: () -> Unit, + onClearTags: () -> Unit ) { var showDeleteDialog by remember { mutableStateOf(false) } + var showClearDialog by remember { mutableStateOf(false) } + + if (showClearDialog) { + AlertDialog( + onDismissRequest = { showClearDialog = false }, + title = { Text("Clear tags for ${person.person.name}?") }, + text = { Text("This will remove all ${person.taggedPhotoCount} photo tags but keep the face model. You can re-scan after clearing.") }, + confirmButton = { + TextButton( + onClick = { + showClearDialog = false + onClearTags() + } + ) { + Text("Clear Tags", color = MaterialTheme.colorScheme.error) + } + }, + dismissButton = { + TextButton(onClick = { showClearDialog = false }) { + Text("Cancel") + } + } + ) + } if (showDeleteDialog) { AlertDialog( @@ -413,6 +443,17 @@ private fun PersonCard( ) } + // Clear tags button (if has tags) + if (person.taggedPhotoCount > 0) { + IconButton(onClick = { showClearDialog = true }) { + Icon( + Icons.Default.ClearAll, + contentDescription = "Clear Tags", + tint = MaterialTheme.colorScheme.secondary + ) + } + } + // Delete button IconButton(onClick = { showDeleteDialog = true }) { Icon( diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryviewmodel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryviewmodel.kt index bd5c831..549a42d 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryviewmodel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/modelinventory/Personinventoryviewmodel.kt @@ -105,6 +105,21 @@ class PersonInventoryViewModel @Inject constructor( } } + /** + * Clear all face tags for a person (keep model, allow rescan) + */ + fun clearTagsForPerson(personId: String) { + viewModelScope.launch(Dispatchers.IO) { + try { + val faceModel = faceModelDao.getFaceModelByPersonId(personId) + if (faceModel != null) { + photoFaceTagDao.deleteTagsForFaceModel(faceModel.id) + } + loadPersons() + } catch (e: Exception) {} + } + } + fun scanForPerson(personId: String) { viewModelScope.launch(Dispatchers.IO) { try { @@ -133,10 +148,20 @@ class PersonInventoryViewModel @Inject constructor( .build() val detector = FaceDetection.getClient(detectorOptions) - val modelEmbedding = faceModel.getEmbeddingArray() - val faceNetModel = FaceNetModel(context) + // CRITICAL: Use ALL centroids for matching + val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() } val trainingCount = faceModel.trainingImageCount - val baseThreshold = ThresholdStrategy.getLiberalThreshold(trainingCount) + android.util.Log.e("PersonScan", "=== CENTROIDS: ${modelCentroids.size}, trainingCount: $trainingCount ===") + + if (modelCentroids.isEmpty()) { + _scanningState.value = ScanningState.Error("No centroids found") + return@launch + } + + val faceNetModel = FaceNetModel(context) + // Production threshold - balance precision vs recall + val baseThreshold = 0.58f + android.util.Log.d("PersonScan", "Using threshold: $baseThreshold, centroids: ${modelCentroids.size}") val completed = AtomicInteger(0) val facesFound = AtomicInteger(0) @@ -148,7 +173,7 @@ class PersonInventoryViewModel @Inject constructor( val jobs = untaggedImages.map { image -> async { semaphore.withPermit { - processImage(image, detector, faceNetModel, modelEmbedding, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name) + processImage(image, detector, faceNetModel, modelCentroids, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name) } } } @@ -175,7 +200,7 @@ class PersonInventoryViewModel @Inject constructor( private suspend fun processImage( image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel, - modelEmbedding: FloatArray, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String, + modelCentroids: List, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String, batchMatches: MutableList>, batchUpdateMutex: Mutex, completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String ) { @@ -212,14 +237,19 @@ class PersonInventoryViewModel @Inject constructor( (face.boundingBox.bottom * scaleY).toInt() ) - val faceBitmap = loadFaceRegion(uri, scaledBounds) ?: continue + // CRITICAL: Add padding to face crop (same as training) + val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) - val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding) faceBitmap.recycle() - if (similarity >= threshold) { + // Match against ALL centroids, use best match + val bestSimilarity = modelCentroids.maxOfOrNull { centroid -> + faceNetModel.calculateSimilarity(faceEmbedding, centroid) + } ?: 0f + + if (bestSimilarity >= threshold) { batchUpdateMutex.withLock { - batchMatches.add(Triple(personId, image.imageId, similarity)) + batchMatches.add(Triple(personId, image.imageId, bestSimilarity)) facesFound.incrementAndGet() if (batchMatches.size >= BATCH_DB_SIZE) { saveBatchMatches(batchMatches.toList(), faceModelId) @@ -250,18 +280,32 @@ class PersonInventoryViewModel @Inject constructor( } catch (e: Exception) { null } } - private fun loadFaceRegion(uri: Uri, bounds: android.graphics.Rect): Bitmap? { + /** + * Load face region WITH 25% padding - CRITICAL for matching training conditions + */ + private fun loadFaceRegionWithPadding(uri: Uri, bounds: android.graphics.Rect, imgWidth: Int, imgHeight: Int): Bitmap? { return try { val full = context.contentResolver.openInputStream(uri)?.use { BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 }) } ?: return null - val safeLeft = bounds.left.coerceIn(0, full.width - 1) - val safeTop = bounds.top.coerceIn(0, full.height - 1) - val safeWidth = bounds.width().coerceAtMost(full.width - safeLeft) - val safeHeight = bounds.height().coerceAtMost(full.height - safeTop) + // Add 25% padding (same as training) + val padding = (kotlin.math.max(bounds.width(), bounds.height()) * 0.25f).toInt() - val cropped = Bitmap.createBitmap(full, safeLeft, safeTop, safeWidth, safeHeight) + val left = (bounds.left - padding).coerceAtLeast(0) + val top = (bounds.top - padding).coerceAtLeast(0) + val right = (bounds.right + padding).coerceAtMost(full.width) + val bottom = (bounds.bottom + padding).coerceAtMost(full.height) + + val width = right - left + val height = bottom - top + + if (width <= 0 || height <= 0) { + full.recycle() + return null + } + + val cropped = Bitmap.createBitmap(full, left, top, width, height) full.recycle() cropped } catch (e: Exception) { null } diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt index 2e287a4..fc6c324 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/TrainViewModel.kt @@ -192,10 +192,11 @@ class TrainViewModel @Inject constructor( .first() if (backgroundTaggingEnabled) { + // Lower threshold (0.55) since we use multi-centroid matching val scanRequest = LibraryScanWorker.createWorkRequest( personId = personId, personName = personName, - threshold = 0.65f + threshold = 0.55f ) workManager.enqueue(scanRequest) } diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorscreen.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorscreen.kt index 9f39853..43364e8 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorscreen.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorscreen.kt @@ -49,6 +49,7 @@ fun TrainingPhotoSelectorScreen( val isRanking by viewModel.isRanking.collectAsStateWithLifecycle() val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle() val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle() + val embeddingProgress by viewModel.embeddingProgress.collectAsStateWithLifecycle() Scaffold( topBar = { @@ -155,7 +156,33 @@ fun TrainingPhotoSelectorScreen( modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center ) { - CircularProgressIndicator() + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(16.dp) + ) { + CircularProgressIndicator() + // Capture value to avoid race condition + val progress = embeddingProgress + if (progress != null) { + Text( + "Preparing faces: ${progress.current}/${progress.total}", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + LinearProgressIndicator( + progress = { progress.current.toFloat() / progress.total }, + modifier = Modifier + .width(200.dp) + .padding(top = 8.dp) + ) + } else { + Text( + "Loading premium faces...", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } } } photos.isEmpty() -> { diff --git a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorviewmodel.kt b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorviewmodel.kt index 284f0a7..d643061 100644 --- a/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorviewmodel.kt +++ b/app/src/main/java/com/placeholder/sherpai2/ui/trainingprep/Trainingphotoselectorviewmodel.kt @@ -1,20 +1,31 @@ package com.placeholder.sherpai2.ui.trainingprep +import android.app.Application +import android.graphics.Bitmap +import android.graphics.BitmapFactory +import android.graphics.Rect +import android.net.Uri import android.util.Log -import androidx.lifecycle.ViewModel +import androidx.lifecycle.AndroidViewModel import androidx.lifecycle.viewModelScope import com.placeholder.sherpai2.data.local.dao.FaceCacheDao import com.placeholder.sherpai2.data.local.dao.ImageDao +import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity import com.placeholder.sherpai2.data.local.entity.ImageEntity import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer +import com.placeholder.sherpai2.ml.FaceNetModel import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.delay import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import javax.inject.Inject +import kotlin.math.max +import kotlin.math.min /** * TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN @@ -27,15 +38,18 @@ import javax.inject.Inject */ @HiltViewModel class TrainingPhotoSelectorViewModel @Inject constructor( + application: Application, private val imageDao: ImageDao, private val faceCacheDao: FaceCacheDao, - private val faceSimilarityScorer: FaceSimilarityScorer -) : ViewModel() { + private val faceSimilarityScorer: FaceSimilarityScorer, + private val faceNetModel: FaceNetModel +) : AndroidViewModel(application) { companion object { private const val TAG = "PremiumSelector" private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1 private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5 + private const val MAX_EMBEDDINGS_TO_GENERATE = 500 } // All photos (for fallback / full list) @@ -56,6 +70,12 @@ class TrainingPhotoSelectorViewModel @Inject constructor( private val _isRanking = MutableStateFlow(false) val isRanking: StateFlow = _isRanking.asStateFlow() + // Embedding generation progress + private val _embeddingProgress = MutableStateFlow(null) + val embeddingProgress: StateFlow = _embeddingProgress.asStateFlow() + + data class EmbeddingProgress(val current: Int, val total: Int) + // Premium mode toggle private val _showPremiumOnly = MutableStateFlow(true) val showPremiumOnly: StateFlow = _showPremiumOnly.asStateFlow() @@ -79,20 +99,47 @@ class TrainingPhotoSelectorViewModel @Inject constructor( /** * Load PREMIUM faces first (solo, large, frontal, high quality) + * If no embeddings exist, generate them on-demand for premium candidates */ private fun loadPremiumFaces() { viewModelScope.launch { try { _isLoading.value = true - // Get premium faces from cache - val premiumFaceCache = faceCacheDao.getPremiumFaces( + // First check if premium faces with embeddings exist + var premiumFaceCache = faceCacheDao.getPremiumFaces( minAreaRatio = 0.10f, minQuality = 0.7f, limit = 500 ) - Log.d(TAG, "✅ Found ${premiumFaceCache.size} premium faces") + Log.d(TAG, "📊 Found ${premiumFaceCache.size} premium faces with embeddings") + + // If no premium faces with embeddings, generate them on-demand + if (premiumFaceCache.isEmpty()) { + Log.d(TAG, "⚠️ No premium faces with embeddings - generating on-demand") + + val candidates = faceCacheDao.getPremiumFaceCandidatesNeedingEmbeddings( + minAreaRatio = 0.10f, + minQuality = 0.7f, + limit = MAX_EMBEDDINGS_TO_GENERATE + ) + + Log.d(TAG, "📦 Found ${candidates.size} premium candidates needing embeddings") + + if (candidates.isNotEmpty()) { + generateEmbeddingsForCandidates(candidates) + + // Re-query after generating + premiumFaceCache = faceCacheDao.getPremiumFaces( + minAreaRatio = 0.10f, + minQuality = 0.7f, + limit = 500 + ) + Log.d(TAG, "✅ After generation: ${premiumFaceCache.size} premium faces") + } + } + _premiumCount.value = premiumFaceCache.size // Get corresponding ImageEntities @@ -117,10 +164,108 @@ class TrainingPhotoSelectorViewModel @Inject constructor( loadAllFaces() } finally { _isLoading.value = false + _embeddingProgress.value = null } } } + /** + * Generate embeddings for premium face candidates + */ + private suspend fun generateEmbeddingsForCandidates(candidates: List) { + val context = getApplication() + val total = candidates.size + var processed = 0 + + withContext(Dispatchers.IO) { + // Get image URIs for candidates + val imageIds = candidates.map { it.imageId }.distinct() + val images = imageDao.getImagesByIds(imageIds) + val imageUriMap = images.associate { it.imageId to it.imageUri } + + for (candidate in candidates) { + try { + val imageUri = imageUriMap[candidate.imageId] ?: continue + + // Load bitmap + val bitmap = loadBitmapOptimized(context, Uri.parse(imageUri)) ?: continue + + // Crop face + val croppedFace = cropFaceWithPadding(bitmap, candidate.getBoundingBox()) + bitmap.recycle() + + if (croppedFace == null) continue + + // Generate embedding + val embedding = faceNetModel.generateEmbedding(croppedFace) + croppedFace.recycle() + + // Validate embedding + if (embedding.any { it != 0f }) { + // Save to database + val embeddingJson = FaceCacheEntity.embeddingToJson(embedding) + faceCacheDao.updateEmbedding(candidate.imageId, candidate.faceIndex, embeddingJson) + } + + } catch (e: Exception) { + Log.w(TAG, "Failed to generate embedding for ${candidate.imageId}: ${e.message}") + } + + processed++ + withContext(Dispatchers.Main) { + _embeddingProgress.value = EmbeddingProgress(processed, total) + } + } + } + + Log.d(TAG, "✅ Generated embeddings for $processed/$total candidates") + } + + private fun loadBitmapOptimized(context: android.content.Context, uri: Uri, maxDim: Int = 768): Bitmap? { + return try { + val options = BitmapFactory.Options().apply { inJustDecodeBounds = true } + context.contentResolver.openInputStream(uri)?.use { stream -> + BitmapFactory.decodeStream(stream, null, options) + } + + var sampleSize = 1 + while (options.outWidth / sampleSize > maxDim || options.outHeight / sampleSize > maxDim) { + sampleSize *= 2 + } + + val finalOptions = BitmapFactory.Options().apply { + inSampleSize = sampleSize + inPreferredConfig = Bitmap.Config.ARGB_8888 + } + + context.contentResolver.openInputStream(uri)?.use { stream -> + BitmapFactory.decodeStream(stream, null, finalOptions) + } + } catch (e: Exception) { + Log.w(TAG, "Failed to load bitmap: ${e.message}") + null + } + } + + private fun cropFaceWithPadding(bitmap: Bitmap, boundingBox: Rect): Bitmap? { + return try { + val padding = (max(boundingBox.width(), boundingBox.height()) * 0.25f).toInt() + val left = max(0, boundingBox.left - padding) + val top = max(0, boundingBox.top - padding) + val right = min(bitmap.width, boundingBox.right + padding) + val bottom = min(bitmap.height, boundingBox.bottom + padding) + val width = right - left + val height = bottom - top + + if (width > 0 && height > 0) { + Bitmap.createBitmap(bitmap, left, top, width, height) + } else null + } catch (e: Exception) { + Log.w(TAG, "Failed to crop face: ${e.message}") + null + } + } + /** * Fallback: load all photos with faces */