toofasttooclaude

This commit is contained in:
genki
2026-01-26 14:15:54 -05:00
parent 1ef8faad17
commit cfec2b980a
12 changed files with 661 additions and 42 deletions

View File

@@ -44,14 +44,15 @@ import com.placeholder.sherpai2.data.local.entity.*
PhotoFaceTagEntity::class, PhotoFaceTagEntity::class,
PersonAgeTagEntity::class, PersonAgeTagEntity::class,
FaceCacheEntity::class, FaceCacheEntity::class,
UserFeedbackEntity::class, // NEW: User corrections UserFeedbackEntity::class,
PersonStatisticsEntity::class, // Pre-computed person stats
// ===== COLLECTIONS ===== // ===== COLLECTIONS =====
CollectionEntity::class, CollectionEntity::class,
CollectionImageEntity::class, CollectionImageEntity::class,
CollectionFilterEntity::class CollectionFilterEntity::class
], ],
version = 10, // INCREMENTED for user feedback version = 11, // INCREMENTED for person statistics
exportSchema = false exportSchema = false
) )
abstract class AppDatabase : RoomDatabase() { abstract class AppDatabase : RoomDatabase() {
@@ -70,7 +71,8 @@ abstract class AppDatabase : RoomDatabase() {
abstract fun photoFaceTagDao(): PhotoFaceTagDao abstract fun photoFaceTagDao(): PhotoFaceTagDao
abstract fun personAgeTagDao(): PersonAgeTagDao abstract fun personAgeTagDao(): PersonAgeTagDao
abstract fun faceCacheDao(): FaceCacheDao abstract fun faceCacheDao(): FaceCacheDao
abstract fun userFeedbackDao(): UserFeedbackDao // NEW abstract fun userFeedbackDao(): UserFeedbackDao
abstract fun personStatisticsDao(): PersonStatisticsDao
// ===== COLLECTIONS DAO ===== // ===== COLLECTIONS DAO =====
abstract fun collectionDao(): CollectionDao abstract fun collectionDao(): CollectionDao
@@ -242,13 +244,41 @@ val MIGRATION_9_10 = object : Migration(9, 10) {
} }
} }
/**
* MIGRATION 10 → 11 (Person Statistics)
*
* Changes:
* 1. Create person_statistics table for pre-computed aggregates
*/
// Room migration 10 -> 11: adds the person_statistics table used for
// pre-computed per-person aggregates (photo counts, date range, confidence).
// NOTE(review): the CREATE TABLE statement must match the
// PersonStatisticsEntity annotations exactly (column types, FK, defaults),
// and the index name must match Room's generated name, or Room will fail
// schema validation when the database is opened — confirm against the entity.
val MIGRATION_10_11 = object : Migration(10, 11) {
override fun migrate(database: SupportSQLiteDatabase) {
// Create person_statistics table
// agesWithPhotos is nullable TEXT — presumably a serialized list; verify
// against PersonStatisticsEntity before shipping.
database.execSQL("""
CREATE TABLE IF NOT EXISTS person_statistics (
personId TEXT PRIMARY KEY NOT NULL,
photoCount INTEGER NOT NULL DEFAULT 0,
firstPhotoDate INTEGER NOT NULL DEFAULT 0,
lastPhotoDate INTEGER NOT NULL DEFAULT 0,
averageConfidence REAL NOT NULL DEFAULT 0,
agesWithPhotos TEXT,
updatedAt INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE
)
""")
// Index for sorting by photo count (People Dashboard)
database.execSQL("CREATE INDEX IF NOT EXISTS index_person_statistics_photoCount ON person_statistics(photoCount)")
}
}
/** /**
* PRODUCTION MIGRATION NOTES: * PRODUCTION MIGRATION NOTES:
* *
* Before shipping to users, update DatabaseModule to use migrations: * Before shipping to users, update DatabaseModule to use migrations:
* *
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db") * Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations * .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11) // Add all migrations
* // .fallbackToDestructiveMigration() // Remove this * // .fallbackToDestructiveMigration() // Remove this
* .build() * .build()
*/ */

View File

@@ -233,6 +233,33 @@ interface FaceCacheDao {
limit: Int = 500 limit: Int = 500
): List<FaceCacheEntity> ): List<FaceCacheEntity>
/**
* Get premium face CANDIDATES - same criteria but WITHOUT embedding requirement.
* Used to find faces that need embedding generation.
*
* Selection criteria (mirrors the premium-face query, plus `embedding IS NULL`):
* single-face images only, frontal pose, and minimum area/quality thresholds.
* Results are ordered best-first (quality, then face area).
*
* @param minAreaRatio minimum face-to-image area ratio (default 0.10)
* @param minQuality minimum quality score (default 0.7)
* @param limit maximum number of rows returned (default 500)
*/
@Query("""
SELECT fc.* FROM face_cache fc
INNER JOIN images i ON fc.imageId = i.imageId
WHERE i.faceCount = 1
AND fc.faceAreaRatio >= :minAreaRatio
AND fc.isFrontal = 1
AND fc.qualityScore >= :minQuality
AND fc.embedding IS NULL
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
LIMIT :limit
""")
suspend fun getPremiumFaceCandidatesNeedingEmbeddings(
minAreaRatio: Float = 0.10f,
minQuality: Float = 0.7f,
limit: Int = 500
): List<FaceCacheEntity>
/**
* Update embedding for a face cache entry.
*
* The row is keyed by (imageId, faceIndex); no-op if no such row exists.
*
* @param imageId owning image's id
* @param faceIndex index of the face within the image
* @param embedding serialized embedding string (see FaceCacheEntity.embeddingToJson)
*/
@Query("UPDATE face_cache SET embedding = :embedding WHERE imageId = :imageId AND faceIndex = :faceIndex")
suspend fun updateEmbedding(imageId: String, faceIndex: Int, embedding: String)
/** /**
* Count of premium faces available * Count of premium faces available
*/ */

View File

@@ -83,9 +83,89 @@ interface PhotoFaceTagDao {
*/ */
@Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit") @Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit")
suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity> suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity>
// ===== CO-OCCURRENCE QUERIES =====
/**
* Find people who appear in photos together with a given person.
* Returns list of (otherFaceModelId, count) sorted by count descending.
* Use case: "Who appears most with Mom?" or "Show photos of Mom WITH Dad"
*
* The count is the number of DISTINCT shared images (a face appearing twice
* in the same photo is counted once).
*
* @param faceModelId the face model whose co-occurring people are requested
*/
@Query("""
SELECT pft2.faceModelId as otherFaceModelId, COUNT(DISTINCT pft1.imageId) as coCount
FROM photo_face_tags pft1
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
WHERE pft1.faceModelId = :faceModelId
AND pft2.faceModelId != :faceModelId
GROUP BY pft2.faceModelId
ORDER BY coCount DESC
""")
suspend fun getCoOccurrences(faceModelId: String): List<PersonCoOccurrence>
/**
* Get images where BOTH people appear together.
*
* NOTE(review): ORDER BY pft1.detectedAt combined with SELECT DISTINCT is
* row-nondeterministic when the same imageId has multiple detectedAt values —
* SQLite may pick any of them for ordering. Confirm the intended sort, or
* order by an aggregate (e.g. MAX(detectedAt)) for stable results.
*/
@Query("""
SELECT DISTINCT pft1.imageId
FROM photo_face_tags pft1
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
WHERE pft1.faceModelId = :faceModelId1
AND pft2.faceModelId = :faceModelId2
ORDER BY pft1.detectedAt DESC
""")
suspend fun getImagesWithBothPeople(faceModelId1: String, faceModelId2: String): List<String>
/**
* Get images where person appears ALONE (no other trained faces).
*
* "Alone" means no other photo_face_tags row exists for the image — untagged
* (untrained) faces may still be present in the photo; only recognized people
* are excluded. Results are newest-first by detection time.
*/
@Query("""
SELECT imageId FROM photo_face_tags
WHERE faceModelId = :faceModelId
AND imageId NOT IN (
SELECT imageId FROM photo_face_tags
WHERE faceModelId != :faceModelId
)
ORDER BY detectedAt DESC
""")
suspend fun getImagesWithPersonAlone(faceModelId: String): List<String>
/**
* Get images where ALL specified people appear (N-way intersection).
* For "Intersection Search" moonshot feature.
*
* Callers are expected to pass requiredCount == faceModelIds.size; a smaller
* value degrades this to an "at least N of the list" query (see
* getFamilyPortraits for that use case).
*/
@Query("""
SELECT imageId FROM photo_face_tags
WHERE faceModelId IN (:faceModelIds)
GROUP BY imageId
HAVING COUNT(DISTINCT faceModelId) = :requiredCount
""")
suspend fun getImagesWithAllPeople(faceModelIds: List<String>, requiredCount: Int): List<String>
/**
* Get images with at least N of the specified people (family portrait detection).
*
* Returns (imageId, memberCount) pairs ordered by how many of the given people
* appear in each image, most-complete portraits first. memberCount counts
* DISTINCT face models, so duplicate detections of one person count once.
*/
@Query("""
SELECT imageId, COUNT(DISTINCT faceModelId) as memberCount
FROM photo_face_tags
WHERE faceModelId IN (:faceModelIds)
GROUP BY imageId
HAVING memberCount >= :minMembers
ORDER BY memberCount DESC
""")
suspend fun getFamilyPortraits(faceModelIds: List<String>, minMembers: Int): List<FamilyPortraitResult>
} }
/**
* Projection row for [PhotoFaceTagDao.getFamilyPortraits]: an image plus how
* many of the requested people were found in it.
*/
data class FamilyPortraitResult(
val imageId: String,
val memberCount: Int
)
data class FaceModelPhotoCount( data class FaceModelPhotoCount(
val faceModelId: String, val faceModelId: String,
val photoCount: Int val photoCount: Int
) )
/**
* Projection row for [PhotoFaceTagDao.getCoOccurrences]: another person's face
* model id and the number of distinct images shared with the queried person.
*/
data class PersonCoOccurrence(
val otherFaceModelId: String,
val coCount: Int
)

View File

@@ -99,6 +99,13 @@ data class FaceCacheEntity(
companion object { companion object {
const val CURRENT_CACHE_VERSION = 1 const val CURRENT_CACHE_VERSION = 1
/**
 * Serialize a FloatArray embedding into its comma-separated string form
 * for database storage (e.g. [0.5, -1.0] -> "0.5,-1.0").
 */
fun embeddingToJson(embedding: FloatArray): String =
    embedding.joinToString(separator = ",") { component -> component.toString() }
/** /**
* Create from ML Kit face detection result * Create from ML Kit face detection result
*/ */

View File

@@ -75,7 +75,21 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
) )
try { try {
val imagesToScan = imageDao.getImagesNeedingFaceDetection() // Get images that need face detection (hasFaces IS NULL)
var imagesToScan = imageDao.getImagesNeedingFaceDetection()
// CRITICAL FIX: Also check for images marked as having faces but no FaceCacheEntity
if (imagesToScan.isEmpty()) {
val faceStats = faceCacheDao.getCacheStats()
if (faceStats.totalFaces == 0) {
// FaceCacheEntity is empty - rescan images that have faces
val imagesWithFaces = imageDao.getImagesWithFaces()
if (imagesWithFaces.isNotEmpty()) {
Log.w(TAG, "FaceCacheEntity empty but ${imagesWithFaces.size} images have faces - rescanning")
imagesToScan = imagesWithFaces
}
}
}
if (imagesToScan.isEmpty()) { if (imagesToScan.isEmpty()) {
Log.d(TAG, "No images need scanning") Log.d(TAG, "No images need scanning")
@@ -184,7 +198,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
imageUri = image.imageUri imageUri = image.imageUri
) )
// Create FaceCacheEntity entries for each face // Create FaceCacheEntity entries for each face (NO embeddings - generated on demand)
val faceCacheEntries = faces.mapIndexed { index, face -> val faceCacheEntries = faces.mapIndexed { index, face ->
createFaceCacheEntry( createFaceCacheEntry(
imageId = image.imageId, imageId = image.imageId,
@@ -205,7 +219,8 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
/** /**
* Create FaceCacheEntity from ML Kit Face * Create FaceCacheEntity from ML Kit Face
* *
* Uses FaceCacheEntity.create() which calculates quality metrics automatically * Uses FaceCacheEntity.create() which calculates quality metrics automatically.
* Embeddings are NOT generated here - they're generated on-demand in Training/Discovery.
*/ */
private fun createFaceCacheEntry( private fun createFaceCacheEntry(
imageId: String, imageId: String,
@@ -225,7 +240,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
imageHeight = imageHeight, imageHeight = imageHeight,
confidence = 0.9f, // High confidence from accurate detector confidence = 0.9f, // High confidence from accurate detector
isFrontal = isFrontal, isFrontal = isFrontal,
embedding = null // Will be generated later during Discovery embedding = null // Generated on-demand in Training/Discovery
) )
} }
@@ -312,13 +327,27 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
val imageStats = imageDao.getFaceCacheStats() val imageStats = imageDao.getFaceCacheStats()
val faceStats = faceCacheDao.getCacheStats() val faceStats = faceCacheDao.getCacheStats()
// CRITICAL FIX: If ImageEntity says "scanned" but FaceCacheEntity is empty,
// we need to re-scan. This happens after DB migration clears face_cache table.
val imagesWithFaces = imageStats?.imagesWithFaces ?: 0
val facesCached = faceStats.totalFaces
// If we have images marked as having faces but no FaceCacheEntity entries,
// those images need re-scanning
val needsRescan = if (imagesWithFaces > 0 && facesCached == 0) {
Log.w(TAG, "⚠️ FaceCacheEntity is empty but $imagesWithFaces images marked as having faces - forcing rescan")
imagesWithFaces
} else {
imageStats?.needsScanning ?: 0
}
CacheStats( CacheStats(
totalImages = imageStats?.totalImages ?: 0, totalImages = imageStats?.totalImages ?: 0,
imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0, imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0,
imagesWithFaces = imageStats?.imagesWithFaces ?: 0, imagesWithFaces = imagesWithFaces,
imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0, imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0,
needsScanning = imageStats?.needsScanning ?: 0, needsScanning = needsRescan,
totalFacesCached = faceStats.totalFaces, totalFacesCached = facesCached,
facesWithEmbeddings = faceStats.withEmbeddings, facesWithEmbeddings = faceStats.withEmbeddings,
averageQuality = faceStats.avgQuality averageQuality = faceStats.avgQuality
) )

View File

@@ -20,6 +20,7 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.navigation.NavController import androidx.navigation.NavController
import coil.compose.AsyncImage import coil.compose.AsyncImage
import com.placeholder.sherpai2.data.local.entity.TagEntity import com.placeholder.sherpai2.data.local.entity.TagEntity
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.FaceTagInfo
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel
import net.engawapg.lib.zoomable.rememberZoomState import net.engawapg.lib.zoomable.rememberZoomState
import net.engawapg.lib.zoomable.zoomable import net.engawapg.lib.zoomable.zoomable
@@ -51,8 +52,12 @@ fun ImageDetailScreen(
} }
val tags by viewModel.tags.collectAsStateWithLifecycle() val tags by viewModel.tags.collectAsStateWithLifecycle()
val faceTags by viewModel.faceTags.collectAsStateWithLifecycle()
var showTags by remember { mutableStateOf(false) } var showTags by remember { mutableStateOf(false) }
// Total tag count for badge
val totalTagCount = tags.size + faceTags.size
// Navigation state // Navigation state
val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1 val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1
val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0 val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0
@@ -84,27 +89,35 @@ fun ImageDetailScreen(
horizontalArrangement = Arrangement.spacedBy(4.dp), horizontalArrangement = Arrangement.spacedBy(4.dp),
verticalAlignment = Alignment.CenterVertically verticalAlignment = Alignment.CenterVertically
) { ) {
if (tags.isNotEmpty()) { if (totalTagCount > 0) {
Badge( Badge(
containerColor = if (showTags) containerColor = if (showTags)
MaterialTheme.colorScheme.primary MaterialTheme.colorScheme.primary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.tertiary
else else
MaterialTheme.colorScheme.surfaceVariant MaterialTheme.colorScheme.surfaceVariant
) { ) {
Text( Text(
tags.size.toString(), totalTagCount.toString(),
color = if (showTags) color = if (showTags)
MaterialTheme.colorScheme.onPrimary MaterialTheme.colorScheme.onPrimary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.onTertiary
else else
MaterialTheme.colorScheme.onSurfaceVariant MaterialTheme.colorScheme.onSurfaceVariant
) )
} }
} }
Icon( Icon(
if (showTags) Icons.Default.Label else Icons.Default.LocalOffer, if (faceTags.isNotEmpty()) Icons.Default.Face
else if (showTags) Icons.Default.Label
else Icons.Default.LocalOffer,
"Show Tags", "Show Tags",
tint = if (showTags) tint = if (showTags)
MaterialTheme.colorScheme.primary MaterialTheme.colorScheme.primary
else if (faceTags.isNotEmpty())
MaterialTheme.colorScheme.tertiary
else else
MaterialTheme.colorScheme.onSurfaceVariant MaterialTheme.colorScheme.onSurfaceVariant
) )
@@ -189,6 +202,30 @@ fun ImageDetailScreen(
contentPadding = PaddingValues(16.dp), contentPadding = PaddingValues(16.dp),
verticalArrangement = Arrangement.spacedBy(8.dp) verticalArrangement = Arrangement.spacedBy(8.dp)
) { ) {
// Face Tags Section (People in Photo)
if (faceTags.isNotEmpty()) {
item {
Text(
"People (${faceTags.size})",
style = MaterialTheme.typography.titleMedium,
fontWeight = FontWeight.Bold,
color = MaterialTheme.colorScheme.tertiary
)
}
items(faceTags, key = { it.tagId }) { faceTag ->
FaceTagCard(
faceTag = faceTag,
onRemove = { viewModel.removeFaceTag(faceTag) }
)
}
item {
Spacer(modifier = Modifier.height(8.dp))
}
}
// Regular Tags Section
item { item {
Text( Text(
"Tags (${tags.size})", "Tags (${tags.size})",
@@ -197,7 +234,7 @@ fun ImageDetailScreen(
) )
} }
if (tags.isEmpty()) { if (tags.isEmpty() && faceTags.isEmpty()) {
item { item {
Text( Text(
"No tags yet", "No tags yet",
@@ -205,6 +242,14 @@ fun ImageDetailScreen(
color = MaterialTheme.colorScheme.onSurfaceVariant color = MaterialTheme.colorScheme.onSurfaceVariant
) )
} }
} else if (tags.isEmpty()) {
item {
Text(
"No other tags",
style = MaterialTheme.typography.bodySmall,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
} }
items(tags, key = { it.tagId }) { tag -> items(tags, key = { it.tagId }) { tag ->
@@ -220,6 +265,83 @@ fun ImageDetailScreen(
} }
} }
/**
 * Card showing one person recognized in the photo via face recognition.
 *
 * Displays the person's name with a face icon, a "Face Recognition" source
 * label, the match confidence color-coded by strength (>= 70% primary,
 * >= 50% secondary, otherwise error), and a remove button.
 *
 * @param faceTag the recognized person and match metadata to display
 * @param onRemove invoked when the user taps the delete button
 */
@Composable
private fun FaceTagCard(
    faceTag: FaceTagInfo,
    onRemove: () -> Unit
) {
    Card(
        modifier = Modifier.fillMaxWidth(),
        colors = CardDefaults.cardColors(
            containerColor = MaterialTheme.colorScheme.tertiaryContainer
        ),
        shape = RoundedCornerShape(8.dp)
    ) {
        Row(
            modifier = Modifier
                .fillMaxWidth()
                .padding(12.dp),
            horizontalArrangement = Arrangement.SpaceBetween,
            verticalAlignment = Alignment.CenterVertically
        ) {
            Column(modifier = Modifier.weight(1f)) {
                // Name row: face icon + person name
                Row(
                    horizontalArrangement = Arrangement.spacedBy(8.dp),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Icon(
                        imageVector = Icons.Default.Face,
                        contentDescription = null,
                        modifier = Modifier.size(20.dp),
                        tint = MaterialTheme.colorScheme.tertiary
                    )
                    Text(
                        text = faceTag.personName,
                        style = MaterialTheme.typography.bodyLarge,
                        fontWeight = FontWeight.SemiBold
                    )
                }
                // Metadata row: source label · confidence
                Row(
                    horizontalArrangement = Arrangement.spacedBy(4.dp),
                    verticalAlignment = Alignment.CenterVertically
                ) {
                    Text(
                        text = "Face Recognition",
                        style = MaterialTheme.typography.labelSmall,
                        color = MaterialTheme.colorScheme.onSurfaceVariant
                    )
                    // FIX: this was Text(text = "") — an invisible element
                    // between the source label and the confidence. Render a
                    // bullet separator instead of a dead empty node.
                    Text(
                        text = "\u2022",
                        style = MaterialTheme.typography.labelSmall,
                        color = MaterialTheme.colorScheme.onSurfaceVariant
                    )
                    Text(
                        text = "${(faceTag.confidence * 100).toInt()}% confidence",
                        style = MaterialTheme.typography.labelSmall,
                        color = if (faceTag.confidence >= 0.7f)
                            MaterialTheme.colorScheme.primary
                        else if (faceTag.confidence >= 0.5f)
                            MaterialTheme.colorScheme.secondary
                        else
                            MaterialTheme.colorScheme.error
                    )
                }
            }
            // Remove button — delegates the actual deletion to the caller
            IconButton(
                onClick = onRemove,
                colors = IconButtonDefaults.iconButtonColors(
                    contentColor = MaterialTheme.colorScheme.error
                )
            ) {
                Icon(Icons.Default.Delete, "Remove face tag")
            }
        }
    }
}
@Composable @Composable
private fun TagCard( private fun TagCard(
tag: TagEntity, tag: TagEntity,

View File

@@ -2,6 +2,10 @@ package com.placeholder.sherpai2.ui.imagedetail.viewmodel
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.dao.PersonDao
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
import com.placeholder.sherpai2.data.local.entity.TagEntity import com.placeholder.sherpai2.data.local.entity.TagEntity
import com.placeholder.sherpai2.domain.repository.TaggingRepository import com.placeholder.sherpai2.domain.repository.TaggingRepository
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
@@ -10,17 +14,33 @@ import kotlinx.coroutines.flow.*
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import javax.inject.Inject import javax.inject.Inject
/**
* Represents a person tagged in this photo via face recognition.
*
* Immutable UI model resolved from PhotoFaceTag -> FaceModel -> Person
* by ImageDetailViewModel.
*/
data class FaceTagInfo(
val personId: String,
val personName: String,
val confidence: Float,
val faceModelId: String,
val tagId: String
)
/** /**
* ImageDetailViewModel * ImageDetailViewModel
* *
* Owns: * Owns:
* - Image context * - Image context
* - Tag write operations * - Tag write operations
* - Face tag display (people recognized in photo)
*/ */
@HiltViewModel @HiltViewModel
@OptIn(ExperimentalCoroutinesApi::class) @OptIn(ExperimentalCoroutinesApi::class)
class ImageDetailViewModel @Inject constructor( class ImageDetailViewModel @Inject constructor(
private val tagRepository: TaggingRepository private val tagRepository: TaggingRepository,
private val imageDao: ImageDao,
private val photoFaceTagDao: PhotoFaceTagDao,
private val faceModelDao: FaceModelDao,
private val personDao: PersonDao
) : ViewModel() { ) : ViewModel() {
private val imageUri = MutableStateFlow<String?>(null) private val imageUri = MutableStateFlow<String?>(null)
@@ -37,8 +57,43 @@ class ImageDetailViewModel @Inject constructor(
initialValue = emptyList() initialValue = emptyList()
) )
// Face tags (people recognized in this photo)
private val _faceTags = MutableStateFlow<List<FaceTagInfo>>(emptyList())
val faceTags: StateFlow<List<FaceTagInfo>> = _faceTags.asStateFlow()
fun loadImage(uri: String) { fun loadImage(uri: String) {
imageUri.value = uri imageUri.value = uri
loadFaceTags(uri)
}
/**
 * Load the people recognized in the image at [uri] and publish them to
 * [_faceTags], sorted by descending confidence.
 *
 * Resolution chain: image URI -> ImageEntity -> photo face tags ->
 * FaceModel -> Person. Tags whose face model or person no longer exists
 * are skipped. Any load failure results in an empty list (best-effort UI).
 *
 * NOTE(review): this performs one DAO lookup per tag (N+1); acceptable for
 * the few faces in a single photo, but a join query would be cheaper.
 */
private fun loadFaceTags(uri: String) {
    viewModelScope.launch {
        try {
            // Resolve the URI to an image row; nothing to show if unknown.
            val image = imageDao.getImageByUri(uri) ?: return@launch

            // Get face tags for this image.
            // FIX: renamed local (was `faceTags`) so it no longer shadows
            // the public `faceTags` StateFlow property.
            val tagsForImage = photoFaceTagDao.getTagsForImage(image.imageId)

            // Resolve each tag to a displayable person, dropping orphans.
            val resolved = tagsForImage.mapNotNull { tag ->
                val faceModel = faceModelDao.getFaceModelById(tag.faceModelId) ?: return@mapNotNull null
                val person = personDao.getPersonById(faceModel.personId) ?: return@mapNotNull null
                FaceTagInfo(
                    personId = person.id,
                    personName = person.name,
                    confidence = tag.confidence,
                    faceModelId = tag.faceModelId,
                    tagId = tag.id
                )
            }
            _faceTags.value = resolved.sortedByDescending { it.confidence }
        } catch (e: kotlin.coroutines.cancellation.CancellationException) {
            // FIX: the broad `catch (e: Exception)` below used to swallow
            // coroutine cancellation — rethrow so structured concurrency works.
            throw e
        } catch (e: Exception) {
            // Best-effort: on any failure just show no face tags.
            _faceTags.value = emptyList()
        }
    }
}
} }
fun addTag(value: String) { fun addTag(value: String) {
@@ -54,4 +109,15 @@ class ImageDetailViewModel @Inject constructor(
tagRepository.removeTagFromImage(uri, tag.value) tagRepository.removeTagFromImage(uri, tag.value)
} }
} }
/**
 * Remove a face tag (person recognition) from the current image, then
 * refresh the face-tag list for the loaded image, if any.
 */
fun removeFaceTag(faceTagInfo: FaceTagInfo) {
    viewModelScope.launch {
        photoFaceTagDao.deleteTagById(faceTagInfo.tagId)
        // Refresh so the UI reflects the deletion immediately.
        val currentUri = imageUri.value
        if (currentUri != null) {
            loadFaceTags(currentUri)
        }
    }
}
} }

View File

@@ -95,6 +95,9 @@ fun PersonInventoryScreen(
}, },
onDelete = { personId -> onDelete = { personId ->
viewModel.deletePerson(personId) viewModel.deletePerson(personId)
},
onClearTags = { personId ->
viewModel.clearTagsForPerson(personId)
} }
) )
} }
@@ -319,7 +322,8 @@ private fun PersonList(
persons: List<PersonWithModelInfo>, persons: List<PersonWithModelInfo>,
onScan: (String) -> Unit, onScan: (String) -> Unit,
onView: (String) -> Unit, onView: (String) -> Unit,
onDelete: (String) -> Unit onDelete: (String) -> Unit,
onClearTags: (String) -> Unit
) { ) {
LazyColumn( LazyColumn(
contentPadding = PaddingValues(vertical = 8.dp) contentPadding = PaddingValues(vertical = 8.dp)
@@ -332,7 +336,8 @@ private fun PersonList(
person = person, person = person,
onScan = { onScan(person.person.id) }, onScan = { onScan(person.person.id) },
onView = { onView(person.person.id) }, onView = { onView(person.person.id) },
onDelete = { onDelete(person.person.id) } onDelete = { onDelete(person.person.id) },
onClearTags = { onClearTags(person.person.id) }
) )
} }
} }
@@ -343,9 +348,34 @@ private fun PersonCard(
person: PersonWithModelInfo, person: PersonWithModelInfo,
onScan: () -> Unit, onScan: () -> Unit,
onView: () -> Unit, onView: () -> Unit,
onDelete: () -> Unit onDelete: () -> Unit,
onClearTags: () -> Unit
) { ) {
var showDeleteDialog by remember { mutableStateOf(false) } var showDeleteDialog by remember { mutableStateOf(false) }
var showClearDialog by remember { mutableStateOf(false) }
if (showClearDialog) {
AlertDialog(
onDismissRequest = { showClearDialog = false },
title = { Text("Clear tags for ${person.person.name}?") },
text = { Text("This will remove all ${person.taggedPhotoCount} photo tags but keep the face model. You can re-scan after clearing.") },
confirmButton = {
TextButton(
onClick = {
showClearDialog = false
onClearTags()
}
) {
Text("Clear Tags", color = MaterialTheme.colorScheme.error)
}
},
dismissButton = {
TextButton(onClick = { showClearDialog = false }) {
Text("Cancel")
}
}
)
}
if (showDeleteDialog) { if (showDeleteDialog) {
AlertDialog( AlertDialog(
@@ -413,6 +443,17 @@ private fun PersonCard(
) )
} }
// Clear tags button (if has tags)
if (person.taggedPhotoCount > 0) {
IconButton(onClick = { showClearDialog = true }) {
Icon(
Icons.Default.ClearAll,
contentDescription = "Clear Tags",
tint = MaterialTheme.colorScheme.secondary
)
}
}
// Delete button // Delete button
IconButton(onClick = { showDeleteDialog = true }) { IconButton(onClick = { showDeleteDialog = true }) {
Icon( Icon(

View File

@@ -105,6 +105,21 @@ class PersonInventoryViewModel @Inject constructor(
} }
} }
/**
 * Clear all face tags for a person while keeping the trained face model,
 * so the library can be re-scanned for that person afterwards.
 *
 * No-op if the person has no face model. Refreshes the person list on
 * completion so tag counts update in the UI.
 */
fun clearTagsForPerson(personId: String) {
    viewModelScope.launch(Dispatchers.IO) {
        try {
            val faceModel = faceModelDao.getFaceModelByPersonId(personId)
            if (faceModel != null) {
                photoFaceTagDao.deleteTagsForFaceModel(faceModel.id)
            }
            loadPersons()
        } catch (e: kotlin.coroutines.cancellation.CancellationException) {
            // FIX: the empty catch below used to swallow coroutine
            // cancellation — rethrow so the scope can cancel cleanly.
            throw e
        } catch (e: Exception) {
            // FIX: was `catch (e: Exception) {}` — silently dropping
            // failures. At minimum, record them for diagnosis.
            android.util.Log.e("PersonInventory", "clearTagsForPerson failed for $personId", e)
        }
    }
}
fun scanForPerson(personId: String) { fun scanForPerson(personId: String) {
viewModelScope.launch(Dispatchers.IO) { viewModelScope.launch(Dispatchers.IO) {
try { try {
@@ -133,10 +148,20 @@ class PersonInventoryViewModel @Inject constructor(
.build() .build()
val detector = FaceDetection.getClient(detectorOptions) val detector = FaceDetection.getClient(detectorOptions)
val modelEmbedding = faceModel.getEmbeddingArray() // CRITICAL: Use ALL centroids for matching
val faceNetModel = FaceNetModel(context) val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
val trainingCount = faceModel.trainingImageCount val trainingCount = faceModel.trainingImageCount
val baseThreshold = ThresholdStrategy.getLiberalThreshold(trainingCount) android.util.Log.e("PersonScan", "=== CENTROIDS: ${modelCentroids.size}, trainingCount: $trainingCount ===")
if (modelCentroids.isEmpty()) {
_scanningState.value = ScanningState.Error("No centroids found")
return@launch
}
val faceNetModel = FaceNetModel(context)
// Production threshold - balance precision vs recall
val baseThreshold = 0.58f
android.util.Log.d("PersonScan", "Using threshold: $baseThreshold, centroids: ${modelCentroids.size}")
val completed = AtomicInteger(0) val completed = AtomicInteger(0)
val facesFound = AtomicInteger(0) val facesFound = AtomicInteger(0)
@@ -148,7 +173,7 @@ class PersonInventoryViewModel @Inject constructor(
val jobs = untaggedImages.map { image -> val jobs = untaggedImages.map { image ->
async { async {
semaphore.withPermit { semaphore.withPermit {
processImage(image, detector, faceNetModel, modelEmbedding, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name) processImage(image, detector, faceNetModel, modelCentroids, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
} }
} }
} }
@@ -175,7 +200,7 @@ class PersonInventoryViewModel @Inject constructor(
private suspend fun processImage( private suspend fun processImage(
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel, image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
modelEmbedding: FloatArray, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String, modelCentroids: List<FloatArray>, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String,
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex, batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
) { ) {
@@ -212,14 +237,19 @@ class PersonInventoryViewModel @Inject constructor(
(face.boundingBox.bottom * scaleY).toInt() (face.boundingBox.bottom * scaleY).toInt()
) )
val faceBitmap = loadFaceRegion(uri, scaledBounds) ?: continue // CRITICAL: Add padding to face crop (same as training)
val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap) val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
faceBitmap.recycle() faceBitmap.recycle()
if (similarity >= threshold) { // Match against ALL centroids, use best match
val bestSimilarity = modelCentroids.maxOfOrNull { centroid ->
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
} ?: 0f
if (bestSimilarity >= threshold) {
batchUpdateMutex.withLock { batchUpdateMutex.withLock {
batchMatches.add(Triple(personId, image.imageId, similarity)) batchMatches.add(Triple(personId, image.imageId, bestSimilarity))
facesFound.incrementAndGet() facesFound.incrementAndGet()
if (batchMatches.size >= BATCH_DB_SIZE) { if (batchMatches.size >= BATCH_DB_SIZE) {
saveBatchMatches(batchMatches.toList(), faceModelId) saveBatchMatches(batchMatches.toList(), faceModelId)
@@ -250,18 +280,32 @@ class PersonInventoryViewModel @Inject constructor(
} catch (e: Exception) { null } } catch (e: Exception) { null }
} }
private fun loadFaceRegion(uri: Uri, bounds: android.graphics.Rect): Bitmap? { /**
* Load face region WITH 25% padding - CRITICAL for matching training conditions
*/
private fun loadFaceRegionWithPadding(uri: Uri, bounds: android.graphics.Rect, imgWidth: Int, imgHeight: Int): Bitmap? {
return try { return try {
val full = context.contentResolver.openInputStream(uri)?.use { val full = context.contentResolver.openInputStream(uri)?.use {
BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 }) BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 })
} ?: return null } ?: return null
val safeLeft = bounds.left.coerceIn(0, full.width - 1) // Add 25% padding (same as training)
val safeTop = bounds.top.coerceIn(0, full.height - 1) val padding = (kotlin.math.max(bounds.width(), bounds.height()) * 0.25f).toInt()
val safeWidth = bounds.width().coerceAtMost(full.width - safeLeft)
val safeHeight = bounds.height().coerceAtMost(full.height - safeTop)
val cropped = Bitmap.createBitmap(full, safeLeft, safeTop, safeWidth, safeHeight) val left = (bounds.left - padding).coerceAtLeast(0)
val top = (bounds.top - padding).coerceAtLeast(0)
val right = (bounds.right + padding).coerceAtMost(full.width)
val bottom = (bounds.bottom + padding).coerceAtMost(full.height)
val width = right - left
val height = bottom - top
if (width <= 0 || height <= 0) {
full.recycle()
return null
}
val cropped = Bitmap.createBitmap(full, left, top, width, height)
full.recycle() full.recycle()
cropped cropped
} catch (e: Exception) { null } } catch (e: Exception) { null }

View File

@@ -192,10 +192,11 @@ class TrainViewModel @Inject constructor(
.first() .first()
if (backgroundTaggingEnabled) { if (backgroundTaggingEnabled) {
// Lower threshold (0.55) since we use multi-centroid matching
val scanRequest = LibraryScanWorker.createWorkRequest( val scanRequest = LibraryScanWorker.createWorkRequest(
personId = personId, personId = personId,
personName = personName, personName = personName,
threshold = 0.65f threshold = 0.55f
) )
workManager.enqueue(scanRequest) workManager.enqueue(scanRequest)
} }

View File

@@ -49,6 +49,7 @@ fun TrainingPhotoSelectorScreen(
val isRanking by viewModel.isRanking.collectAsStateWithLifecycle() val isRanking by viewModel.isRanking.collectAsStateWithLifecycle()
val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle() val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle()
val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle() val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle()
val embeddingProgress by viewModel.embeddingProgress.collectAsStateWithLifecycle()
Scaffold( Scaffold(
topBar = { topBar = {
@@ -155,7 +156,33 @@ fun TrainingPhotoSelectorScreen(
modifier = Modifier.fillMaxSize(), modifier = Modifier.fillMaxSize(),
contentAlignment = Alignment.Center contentAlignment = Alignment.Center
) { ) {
CircularProgressIndicator() Column(
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.spacedBy(16.dp)
) {
CircularProgressIndicator()
// Capture value to avoid race condition
val progress = embeddingProgress
if (progress != null) {
Text(
"Preparing faces: ${progress.current}/${progress.total}",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
LinearProgressIndicator(
progress = { progress.current.toFloat() / progress.total },
modifier = Modifier
.width(200.dp)
.padding(top = 8.dp)
)
} else {
Text(
"Loading premium faces...",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant
)
}
}
} }
} }
photos.isEmpty() -> { photos.isEmpty() -> {

View File

@@ -1,20 +1,31 @@
package com.placeholder.sherpai2.ui.trainingprep package com.placeholder.sherpai2.ui.trainingprep
import android.app.Application
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Rect
import android.net.Uri
import android.util.Log import android.util.Log
import androidx.lifecycle.ViewModel import androidx.lifecycle.AndroidViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
import com.placeholder.sherpai2.data.local.dao.ImageDao import com.placeholder.sherpai2.data.local.dao.ImageDao
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
import com.placeholder.sherpai2.data.local.entity.ImageEntity import com.placeholder.sherpai2.data.local.entity.ImageEntity
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
import com.placeholder.sherpai2.ml.FaceNetModel
import dagger.hilt.android.lifecycle.HiltViewModel import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import javax.inject.Inject import javax.inject.Inject
import kotlin.math.max
import kotlin.math.min
/** /**
* TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN * TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN
@@ -27,15 +38,18 @@ import javax.inject.Inject
*/ */
@HiltViewModel @HiltViewModel
class TrainingPhotoSelectorViewModel @Inject constructor( class TrainingPhotoSelectorViewModel @Inject constructor(
application: Application,
private val imageDao: ImageDao, private val imageDao: ImageDao,
private val faceCacheDao: FaceCacheDao, private val faceCacheDao: FaceCacheDao,
private val faceSimilarityScorer: FaceSimilarityScorer private val faceSimilarityScorer: FaceSimilarityScorer,
) : ViewModel() { private val faceNetModel: FaceNetModel
) : AndroidViewModel(application) {
companion object { companion object {
private const val TAG = "PremiumSelector" private const val TAG = "PremiumSelector"
private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1 private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1
private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5 private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5
private const val MAX_EMBEDDINGS_TO_GENERATE = 500
} }
// All photos (for fallback / full list) // All photos (for fallback / full list)
@@ -56,6 +70,12 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
private val _isRanking = MutableStateFlow(false) private val _isRanking = MutableStateFlow(false)
val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow() val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow()
// Progress of on-demand embedding generation; null while no generation is running.
private val _embeddingProgress = MutableStateFlow<EmbeddingProgress?>(null)
val embeddingProgress: StateFlow<EmbeddingProgress?> = _embeddingProgress.asStateFlow()
/** Snapshot of embedding generation progress: [current] candidates handled out of [total]. */ data class EmbeddingProgress(val current: Int, val total: Int)
// Premium mode toggle // Premium mode toggle
private val _showPremiumOnly = MutableStateFlow(true) private val _showPremiumOnly = MutableStateFlow(true)
val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow() val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow()
@@ -79,20 +99,47 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
/** /**
* Load PREMIUM faces first (solo, large, frontal, high quality) * Load PREMIUM faces first (solo, large, frontal, high quality)
* If no embeddings exist, generate them on-demand for premium candidates
*/ */
private fun loadPremiumFaces() { private fun loadPremiumFaces() {
viewModelScope.launch { viewModelScope.launch {
try { try {
_isLoading.value = true _isLoading.value = true
// Get premium faces from cache // First check if premium faces with embeddings exist
val premiumFaceCache = faceCacheDao.getPremiumFaces( var premiumFaceCache = faceCacheDao.getPremiumFaces(
minAreaRatio = 0.10f, minAreaRatio = 0.10f,
minQuality = 0.7f, minQuality = 0.7f,
limit = 500 limit = 500
) )
Log.d(TAG, " Found ${premiumFaceCache.size} premium faces") Log.d(TAG, "📊 Found ${premiumFaceCache.size} premium faces with embeddings")
// If no premium faces with embeddings, generate them on-demand
if (premiumFaceCache.isEmpty()) {
Log.d(TAG, "⚠️ No premium faces with embeddings - generating on-demand")
val candidates = faceCacheDao.getPremiumFaceCandidatesNeedingEmbeddings(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = MAX_EMBEDDINGS_TO_GENERATE
)
Log.d(TAG, "📦 Found ${candidates.size} premium candidates needing embeddings")
if (candidates.isNotEmpty()) {
generateEmbeddingsForCandidates(candidates)
// Re-query after generating
premiumFaceCache = faceCacheDao.getPremiumFaces(
minAreaRatio = 0.10f,
minQuality = 0.7f,
limit = 500
)
Log.d(TAG, "✅ After generation: ${premiumFaceCache.size} premium faces")
}
}
_premiumCount.value = premiumFaceCache.size _premiumCount.value = premiumFaceCache.size
// Get corresponding ImageEntities // Get corresponding ImageEntities
@@ -117,10 +164,108 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
loadAllFaces() loadAllFaces()
} finally { } finally {
_isLoading.value = false _isLoading.value = false
_embeddingProgress.value = null
} }
} }
} }
/**
 * Generate and persist FaceNet embeddings for premium face candidates.
 *
 * For each candidate: resolve its image URI, decode a downsampled bitmap,
 * crop the padded face region, run FaceNet, and save the embedding JSON to
 * the face cache. All-zero vectors are discarded as failed inferences.
 * Progress is published via [_embeddingProgress] after every candidate.
 *
 * @param candidates cached face rows that still lack an embedding
 */
private suspend fun generateEmbeddingsForCandidates(candidates: List<FaceCacheEntity>) {
    val context = getApplication<Application>()
    val total = candidates.size
    var succeeded = 0

    withContext(Dispatchers.IO) {
        // Resolve each candidate's image URI once up front (single DB round-trip).
        val imageIds = candidates.map { it.imageId }.distinct()
        val imageUriMap = imageDao.getImagesByIds(imageIds).associate { it.imageId to it.imageUri }

        candidates.forEachIndexed { index, candidate ->
            try {
                val imageUri = imageUriMap[candidate.imageId]
                if (imageUri != null) {
                    val bitmap = loadBitmapOptimized(context, Uri.parse(imageUri))
                    if (bitmap != null) {
                        val croppedFace = cropFaceWithPadding(bitmap, candidate.getBoundingBox())
                        bitmap.recycle()
                        if (croppedFace != null) {
                            val embedding = faceNetModel.generateEmbedding(croppedFace)
                            croppedFace.recycle()
                            // An all-zero vector indicates a failed inference — don't persist it.
                            if (embedding.any { it != 0f }) {
                                val embeddingJson = FaceCacheEntity.embeddingToJson(embedding)
                                faceCacheDao.updateEmbedding(candidate.imageId, candidate.faceIndex, embeddingJson)
                                succeeded++
                            }
                        }
                    }
                }
            } catch (e: Exception) {
                Log.w(TAG, "Failed to generate embedding for ${candidate.imageId}: ${e.message}")
            }
            // BUG FIX: the original incremented its counter AFTER the try block, so every
            // `continue` (missing URI, decode/crop failure) skipped both the increment and
            // the progress emission — the progress bar could stall below `total`. Emitting
            // here guarantees the UI reaches total exactly once per candidate.
            // MutableStateFlow is thread-safe, so no per-item Dispatchers.Main hop is needed.
            _embeddingProgress.value = EmbeddingProgress(index + 1, total)
        }
    }

    Log.d(TAG, "✅ Generated embeddings for $succeeded/$total candidates")
}
/**
 * Decode a bitmap from [uri], downsampled so neither dimension exceeds [maxDim].
 *
 * Uses the standard two-pass decode: first read only the bounds, then decode
 * pixels with the smallest power-of-two [BitmapFactory.Options.inSampleSize]
 * that fits both dimensions under the limit.
 *
 * @return the decoded bitmap, or null if the stream cannot be opened or decoded
 */
private fun loadBitmapOptimized(context: android.content.Context, uri: Uri, maxDim: Int = 768): Bitmap? =
    try {
        // Pass 1: dimensions only — no pixel allocation.
        val boundsOptions = BitmapFactory.Options().apply { inJustDecodeBounds = true }
        context.contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream, null, boundsOptions)
        }

        // Smallest power-of-two factor bringing both dimensions within maxDim.
        var factor = 1
        while (boundsOptions.outWidth / factor > maxDim || boundsOptions.outHeight / factor > maxDim) {
            factor *= 2
        }

        // Pass 2: decode pixels at the reduced resolution.
        val decodeOptions = BitmapFactory.Options().apply {
            inSampleSize = factor
            inPreferredConfig = Bitmap.Config.ARGB_8888
        }
        context.contentResolver.openInputStream(uri)?.use { stream ->
            BitmapFactory.decodeStream(stream, null, decodeOptions)
        }
    } catch (e: Exception) {
        Log.w(TAG, "Failed to load bitmap: ${e.message}")
        null
    }
/**
 * Crop [boundingBox] out of [bitmap] with 25% padding on each side, clamped
 * to the bitmap bounds.
 *
 * The input bitmap is never recycled here; the caller owns it.
 *
 * @return the cropped face bitmap, or null for a degenerate (empty) region
 *         or on any crop failure
 */
private fun cropFaceWithPadding(bitmap: Bitmap, boundingBox: Rect): Bitmap? {
    return try {
        // Pad by a quarter of the larger box dimension to keep facial context.
        val pad = (0.25f * max(boundingBox.width(), boundingBox.height())).toInt()
        val x0 = max(0, boundingBox.left - pad)
        val y0 = max(0, boundingBox.top - pad)
        val x1 = min(bitmap.width, boundingBox.right + pad)
        val y1 = min(bitmap.height, boundingBox.bottom + pad)

        if (x1 <= x0 || y1 <= y0) {
            null // box lies entirely outside the bitmap, or collapsed to nothing
        } else {
            Bitmap.createBitmap(bitmap, x0, y0, x1 - x0, y1 - y0)
        }
    } catch (e: Exception) {
        Log.w(TAG, "Failed to crop face: ${e.message}")
        null
    }
}
/** /**
* Fallback: load all photos with faces * Fallback: load all photos with faces
*/ */