toofasttooclaude
This commit is contained in:
@@ -44,14 +44,15 @@ import com.placeholder.sherpai2.data.local.entity.*
|
|||||||
PhotoFaceTagEntity::class,
|
PhotoFaceTagEntity::class,
|
||||||
PersonAgeTagEntity::class,
|
PersonAgeTagEntity::class,
|
||||||
FaceCacheEntity::class,
|
FaceCacheEntity::class,
|
||||||
UserFeedbackEntity::class, // NEW: User corrections
|
UserFeedbackEntity::class,
|
||||||
|
PersonStatisticsEntity::class, // Pre-computed person stats
|
||||||
|
|
||||||
// ===== COLLECTIONS =====
|
// ===== COLLECTIONS =====
|
||||||
CollectionEntity::class,
|
CollectionEntity::class,
|
||||||
CollectionImageEntity::class,
|
CollectionImageEntity::class,
|
||||||
CollectionFilterEntity::class
|
CollectionFilterEntity::class
|
||||||
],
|
],
|
||||||
version = 10, // INCREMENTED for user feedback
|
version = 11, // INCREMENTED for person statistics
|
||||||
exportSchema = false
|
exportSchema = false
|
||||||
)
|
)
|
||||||
abstract class AppDatabase : RoomDatabase() {
|
abstract class AppDatabase : RoomDatabase() {
|
||||||
@@ -70,7 +71,8 @@ abstract class AppDatabase : RoomDatabase() {
|
|||||||
abstract fun photoFaceTagDao(): PhotoFaceTagDao
|
abstract fun photoFaceTagDao(): PhotoFaceTagDao
|
||||||
abstract fun personAgeTagDao(): PersonAgeTagDao
|
abstract fun personAgeTagDao(): PersonAgeTagDao
|
||||||
abstract fun faceCacheDao(): FaceCacheDao
|
abstract fun faceCacheDao(): FaceCacheDao
|
||||||
abstract fun userFeedbackDao(): UserFeedbackDao // NEW
|
abstract fun userFeedbackDao(): UserFeedbackDao
|
||||||
|
abstract fun personStatisticsDao(): PersonStatisticsDao
|
||||||
|
|
||||||
// ===== COLLECTIONS DAO =====
|
// ===== COLLECTIONS DAO =====
|
||||||
abstract fun collectionDao(): CollectionDao
|
abstract fun collectionDao(): CollectionDao
|
||||||
@@ -242,13 +244,41 @@ val MIGRATION_9_10 = object : Migration(9, 10) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MIGRATION 10 → 11 (Person Statistics)
|
||||||
|
*
|
||||||
|
* Changes:
|
||||||
|
* 1. Create person_statistics table for pre-computed aggregates
|
||||||
|
*/
|
||||||
|
val MIGRATION_10_11 = object : Migration(10, 11) {
|
||||||
|
override fun migrate(database: SupportSQLiteDatabase) {
|
||||||
|
|
||||||
|
// Create person_statistics table
|
||||||
|
database.execSQL("""
|
||||||
|
CREATE TABLE IF NOT EXISTS person_statistics (
|
||||||
|
personId TEXT PRIMARY KEY NOT NULL,
|
||||||
|
photoCount INTEGER NOT NULL DEFAULT 0,
|
||||||
|
firstPhotoDate INTEGER NOT NULL DEFAULT 0,
|
||||||
|
lastPhotoDate INTEGER NOT NULL DEFAULT 0,
|
||||||
|
averageConfidence REAL NOT NULL DEFAULT 0,
|
||||||
|
agesWithPhotos TEXT,
|
||||||
|
updatedAt INTEGER NOT NULL DEFAULT 0,
|
||||||
|
FOREIGN KEY(personId) REFERENCES persons(id) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
// Index for sorting by photo count (People Dashboard)
|
||||||
|
database.execSQL("CREATE INDEX IF NOT EXISTS index_person_statistics_photoCount ON person_statistics(photoCount)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* PRODUCTION MIGRATION NOTES:
|
* PRODUCTION MIGRATION NOTES:
|
||||||
*
|
*
|
||||||
* Before shipping to users, update DatabaseModule to use migrations:
|
* Before shipping to users, update DatabaseModule to use migrations:
|
||||||
*
|
*
|
||||||
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
|
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
|
||||||
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10) // Add all migrations
|
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9, MIGRATION_9_10, MIGRATION_10_11) // Add all migrations
|
||||||
* // .fallbackToDestructiveMigration() // Remove this
|
* // .fallbackToDestructiveMigration() // Remove this
|
||||||
* .build()
|
* .build()
|
||||||
*/
|
*/
|
||||||
@@ -233,6 +233,33 @@ interface FaceCacheDao {
|
|||||||
limit: Int = 500
|
limit: Int = 500
|
||||||
): List<FaceCacheEntity>
|
): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get premium face CANDIDATES - same criteria but WITHOUT embedding requirement.
|
||||||
|
* Used to find faces that need embedding generation.
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT fc.* FROM face_cache fc
|
||||||
|
INNER JOIN images i ON fc.imageId = i.imageId
|
||||||
|
WHERE i.faceCount = 1
|
||||||
|
AND fc.faceAreaRatio >= :minAreaRatio
|
||||||
|
AND fc.isFrontal = 1
|
||||||
|
AND fc.qualityScore >= :minQuality
|
||||||
|
AND fc.embedding IS NULL
|
||||||
|
ORDER BY fc.qualityScore DESC, fc.faceAreaRatio DESC
|
||||||
|
LIMIT :limit
|
||||||
|
""")
|
||||||
|
suspend fun getPremiumFaceCandidatesNeedingEmbeddings(
|
||||||
|
minAreaRatio: Float = 0.10f,
|
||||||
|
minQuality: Float = 0.7f,
|
||||||
|
limit: Int = 500
|
||||||
|
): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update embedding for a face cache entry
|
||||||
|
*/
|
||||||
|
@Query("UPDATE face_cache SET embedding = :embedding WHERE imageId = :imageId AND faceIndex = :faceIndex")
|
||||||
|
suspend fun updateEmbedding(imageId: String, faceIndex: Int, embedding: String)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Count of premium faces available
|
* Count of premium faces available
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -83,9 +83,89 @@ interface PhotoFaceTagDao {
|
|||||||
*/
|
*/
|
||||||
@Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit")
|
@Query("SELECT * FROM photo_face_tags ORDER BY detectedAt DESC LIMIT :limit")
|
||||||
suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity>
|
suspend fun getRecentlyDetectedFaces(limit: Int): List<PhotoFaceTagEntity>
|
||||||
|
|
||||||
|
// ===== CO-OCCURRENCE QUERIES =====
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find people who appear in photos together with a given person.
|
||||||
|
* Returns list of (otherFaceModelId, count) sorted by count descending.
|
||||||
|
* Use case: "Who appears most with Mom?" or "Show photos of Mom WITH Dad"
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT pft2.faceModelId as otherFaceModelId, COUNT(DISTINCT pft1.imageId) as coCount
|
||||||
|
FROM photo_face_tags pft1
|
||||||
|
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
|
||||||
|
WHERE pft1.faceModelId = :faceModelId
|
||||||
|
AND pft2.faceModelId != :faceModelId
|
||||||
|
GROUP BY pft2.faceModelId
|
||||||
|
ORDER BY coCount DESC
|
||||||
|
""")
|
||||||
|
suspend fun getCoOccurrences(faceModelId: String): List<PersonCoOccurrence>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get images where BOTH people appear together.
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT DISTINCT pft1.imageId
|
||||||
|
FROM photo_face_tags pft1
|
||||||
|
INNER JOIN photo_face_tags pft2 ON pft1.imageId = pft2.imageId
|
||||||
|
WHERE pft1.faceModelId = :faceModelId1
|
||||||
|
AND pft2.faceModelId = :faceModelId2
|
||||||
|
ORDER BY pft1.detectedAt DESC
|
||||||
|
""")
|
||||||
|
suspend fun getImagesWithBothPeople(faceModelId1: String, faceModelId2: String): List<String>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get images where person appears ALONE (no other trained faces).
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT imageId FROM photo_face_tags
|
||||||
|
WHERE faceModelId = :faceModelId
|
||||||
|
AND imageId NOT IN (
|
||||||
|
SELECT imageId FROM photo_face_tags
|
||||||
|
WHERE faceModelId != :faceModelId
|
||||||
|
)
|
||||||
|
ORDER BY detectedAt DESC
|
||||||
|
""")
|
||||||
|
suspend fun getImagesWithPersonAlone(faceModelId: String): List<String>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get images where ALL specified people appear (N-way intersection).
|
||||||
|
* For "Intersection Search" moonshot feature.
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT imageId FROM photo_face_tags
|
||||||
|
WHERE faceModelId IN (:faceModelIds)
|
||||||
|
GROUP BY imageId
|
||||||
|
HAVING COUNT(DISTINCT faceModelId) = :requiredCount
|
||||||
|
""")
|
||||||
|
suspend fun getImagesWithAllPeople(faceModelIds: List<String>, requiredCount: Int): List<String>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get images with at least N of the specified people (family portrait detection).
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT imageId, COUNT(DISTINCT faceModelId) as memberCount
|
||||||
|
FROM photo_face_tags
|
||||||
|
WHERE faceModelId IN (:faceModelIds)
|
||||||
|
GROUP BY imageId
|
||||||
|
HAVING memberCount >= :minMembers
|
||||||
|
ORDER BY memberCount DESC
|
||||||
|
""")
|
||||||
|
suspend fun getFamilyPortraits(faceModelIds: List<String>, minMembers: Int): List<FamilyPortraitResult>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data class FamilyPortraitResult(
|
||||||
|
val imageId: String,
|
||||||
|
val memberCount: Int
|
||||||
|
)
|
||||||
|
|
||||||
data class FaceModelPhotoCount(
|
data class FaceModelPhotoCount(
|
||||||
val faceModelId: String,
|
val faceModelId: String,
|
||||||
val photoCount: Int
|
val photoCount: Int
|
||||||
)
|
)
|
||||||
|
|
||||||
|
data class PersonCoOccurrence(
|
||||||
|
val otherFaceModelId: String,
|
||||||
|
val coCount: Int
|
||||||
|
)
|
||||||
|
|||||||
@@ -99,6 +99,13 @@ data class FaceCacheEntity(
|
|||||||
companion object {
|
companion object {
|
||||||
const val CURRENT_CACHE_VERSION = 1
|
const val CURRENT_CACHE_VERSION = 1
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert FloatArray embedding to JSON string for storage
|
||||||
|
*/
|
||||||
|
fun embeddingToJson(embedding: FloatArray): String {
|
||||||
|
return embedding.joinToString(",")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create from ML Kit face detection result
|
* Create from ML Kit face detection result
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -75,7 +75,21 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
val imagesToScan = imageDao.getImagesNeedingFaceDetection()
|
// Get images that need face detection (hasFaces IS NULL)
|
||||||
|
var imagesToScan = imageDao.getImagesNeedingFaceDetection()
|
||||||
|
|
||||||
|
// CRITICAL FIX: Also check for images marked as having faces but no FaceCacheEntity
|
||||||
|
if (imagesToScan.isEmpty()) {
|
||||||
|
val faceStats = faceCacheDao.getCacheStats()
|
||||||
|
if (faceStats.totalFaces == 0) {
|
||||||
|
// FaceCacheEntity is empty - rescan images that have faces
|
||||||
|
val imagesWithFaces = imageDao.getImagesWithFaces()
|
||||||
|
if (imagesWithFaces.isNotEmpty()) {
|
||||||
|
Log.w(TAG, "FaceCacheEntity empty but ${imagesWithFaces.size} images have faces - rescanning")
|
||||||
|
imagesToScan = imagesWithFaces
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (imagesToScan.isEmpty()) {
|
if (imagesToScan.isEmpty()) {
|
||||||
Log.d(TAG, "No images need scanning")
|
Log.d(TAG, "No images need scanning")
|
||||||
@@ -184,7 +198,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
imageUri = image.imageUri
|
imageUri = image.imageUri
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create FaceCacheEntity entries for each face
|
// Create FaceCacheEntity entries for each face (NO embeddings - generated on demand)
|
||||||
val faceCacheEntries = faces.mapIndexed { index, face ->
|
val faceCacheEntries = faces.mapIndexed { index, face ->
|
||||||
createFaceCacheEntry(
|
createFaceCacheEntry(
|
||||||
imageId = image.imageId,
|
imageId = image.imageId,
|
||||||
@@ -205,7 +219,8 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
/**
|
/**
|
||||||
* Create FaceCacheEntity from ML Kit Face
|
* Create FaceCacheEntity from ML Kit Face
|
||||||
*
|
*
|
||||||
* Uses FaceCacheEntity.create() which calculates quality metrics automatically
|
* Uses FaceCacheEntity.create() which calculates quality metrics automatically.
|
||||||
|
* Embeddings are NOT generated here - they're generated on-demand in Training/Discovery.
|
||||||
*/
|
*/
|
||||||
private fun createFaceCacheEntry(
|
private fun createFaceCacheEntry(
|
||||||
imageId: String,
|
imageId: String,
|
||||||
@@ -225,7 +240,7 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
imageHeight = imageHeight,
|
imageHeight = imageHeight,
|
||||||
confidence = 0.9f, // High confidence from accurate detector
|
confidence = 0.9f, // High confidence from accurate detector
|
||||||
isFrontal = isFrontal,
|
isFrontal = isFrontal,
|
||||||
embedding = null // Will be generated later during Discovery
|
embedding = null // Generated on-demand in Training/Discovery
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,13 +327,27 @@ class PopulateFaceDetectionCacheUseCase @Inject constructor(
|
|||||||
val imageStats = imageDao.getFaceCacheStats()
|
val imageStats = imageDao.getFaceCacheStats()
|
||||||
val faceStats = faceCacheDao.getCacheStats()
|
val faceStats = faceCacheDao.getCacheStats()
|
||||||
|
|
||||||
|
// CRITICAL FIX: If ImageEntity says "scanned" but FaceCacheEntity is empty,
|
||||||
|
// we need to re-scan. This happens after DB migration clears face_cache table.
|
||||||
|
val imagesWithFaces = imageStats?.imagesWithFaces ?: 0
|
||||||
|
val facesCached = faceStats.totalFaces
|
||||||
|
|
||||||
|
// If we have images marked as having faces but no FaceCacheEntity entries,
|
||||||
|
// those images need re-scanning
|
||||||
|
val needsRescan = if (imagesWithFaces > 0 && facesCached == 0) {
|
||||||
|
Log.w(TAG, "⚠️ FaceCacheEntity is empty but $imagesWithFaces images marked as having faces - forcing rescan")
|
||||||
|
imagesWithFaces
|
||||||
|
} else {
|
||||||
|
imageStats?.needsScanning ?: 0
|
||||||
|
}
|
||||||
|
|
||||||
CacheStats(
|
CacheStats(
|
||||||
totalImages = imageStats?.totalImages ?: 0,
|
totalImages = imageStats?.totalImages ?: 0,
|
||||||
imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0,
|
imagesWithFaceCache = imageStats?.imagesWithFaceCache ?: 0,
|
||||||
imagesWithFaces = imageStats?.imagesWithFaces ?: 0,
|
imagesWithFaces = imagesWithFaces,
|
||||||
imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0,
|
imagesWithoutFaces = imageStats?.imagesWithoutFaces ?: 0,
|
||||||
needsScanning = imageStats?.needsScanning ?: 0,
|
needsScanning = needsRescan,
|
||||||
totalFacesCached = faceStats.totalFaces,
|
totalFacesCached = facesCached,
|
||||||
facesWithEmbeddings = faceStats.withEmbeddings,
|
facesWithEmbeddings = faceStats.withEmbeddings,
|
||||||
averageQuality = faceStats.avgQuality
|
averageQuality = faceStats.avgQuality
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
|||||||
import androidx.navigation.NavController
|
import androidx.navigation.NavController
|
||||||
import coil.compose.AsyncImage
|
import coil.compose.AsyncImage
|
||||||
import com.placeholder.sherpai2.data.local.entity.TagEntity
|
import com.placeholder.sherpai2.data.local.entity.TagEntity
|
||||||
|
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.FaceTagInfo
|
||||||
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel
|
import com.placeholder.sherpai2.ui.imagedetail.viewmodel.ImageDetailViewModel
|
||||||
import net.engawapg.lib.zoomable.rememberZoomState
|
import net.engawapg.lib.zoomable.rememberZoomState
|
||||||
import net.engawapg.lib.zoomable.zoomable
|
import net.engawapg.lib.zoomable.zoomable
|
||||||
@@ -51,8 +52,12 @@ fun ImageDetailScreen(
|
|||||||
}
|
}
|
||||||
|
|
||||||
val tags by viewModel.tags.collectAsStateWithLifecycle()
|
val tags by viewModel.tags.collectAsStateWithLifecycle()
|
||||||
|
val faceTags by viewModel.faceTags.collectAsStateWithLifecycle()
|
||||||
var showTags by remember { mutableStateOf(false) }
|
var showTags by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
|
// Total tag count for badge
|
||||||
|
val totalTagCount = tags.size + faceTags.size
|
||||||
|
|
||||||
// Navigation state
|
// Navigation state
|
||||||
val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1
|
val currentIndex = if (allImageUris.isNotEmpty()) allImageUris.indexOf(imageUri) else -1
|
||||||
val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0
|
val hasNavigation = allImageUris.isNotEmpty() && currentIndex >= 0
|
||||||
@@ -84,27 +89,35 @@ fun ImageDetailScreen(
|
|||||||
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
verticalAlignment = Alignment.CenterVertically
|
verticalAlignment = Alignment.CenterVertically
|
||||||
) {
|
) {
|
||||||
if (tags.isNotEmpty()) {
|
if (totalTagCount > 0) {
|
||||||
Badge(
|
Badge(
|
||||||
containerColor = if (showTags)
|
containerColor = if (showTags)
|
||||||
MaterialTheme.colorScheme.primary
|
MaterialTheme.colorScheme.primary
|
||||||
|
else if (faceTags.isNotEmpty())
|
||||||
|
MaterialTheme.colorScheme.tertiary
|
||||||
else
|
else
|
||||||
MaterialTheme.colorScheme.surfaceVariant
|
MaterialTheme.colorScheme.surfaceVariant
|
||||||
) {
|
) {
|
||||||
Text(
|
Text(
|
||||||
tags.size.toString(),
|
totalTagCount.toString(),
|
||||||
color = if (showTags)
|
color = if (showTags)
|
||||||
MaterialTheme.colorScheme.onPrimary
|
MaterialTheme.colorScheme.onPrimary
|
||||||
|
else if (faceTags.isNotEmpty())
|
||||||
|
MaterialTheme.colorScheme.onTertiary
|
||||||
else
|
else
|
||||||
MaterialTheme.colorScheme.onSurfaceVariant
|
MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Icon(
|
Icon(
|
||||||
if (showTags) Icons.Default.Label else Icons.Default.LocalOffer,
|
if (faceTags.isNotEmpty()) Icons.Default.Face
|
||||||
|
else if (showTags) Icons.Default.Label
|
||||||
|
else Icons.Default.LocalOffer,
|
||||||
"Show Tags",
|
"Show Tags",
|
||||||
tint = if (showTags)
|
tint = if (showTags)
|
||||||
MaterialTheme.colorScheme.primary
|
MaterialTheme.colorScheme.primary
|
||||||
|
else if (faceTags.isNotEmpty())
|
||||||
|
MaterialTheme.colorScheme.tertiary
|
||||||
else
|
else
|
||||||
MaterialTheme.colorScheme.onSurfaceVariant
|
MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
@@ -189,6 +202,30 @@ fun ImageDetailScreen(
|
|||||||
contentPadding = PaddingValues(16.dp),
|
contentPadding = PaddingValues(16.dp),
|
||||||
verticalArrangement = Arrangement.spacedBy(8.dp)
|
verticalArrangement = Arrangement.spacedBy(8.dp)
|
||||||
) {
|
) {
|
||||||
|
// Face Tags Section (People in Photo)
|
||||||
|
if (faceTags.isNotEmpty()) {
|
||||||
|
item {
|
||||||
|
Text(
|
||||||
|
"People (${faceTags.size})",
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
|
fontWeight = FontWeight.Bold,
|
||||||
|
color = MaterialTheme.colorScheme.tertiary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
items(faceTags, key = { it.tagId }) { faceTag ->
|
||||||
|
FaceTagCard(
|
||||||
|
faceTag = faceTag,
|
||||||
|
onRemove = { viewModel.removeFaceTag(faceTag) }
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
item {
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular Tags Section
|
||||||
item {
|
item {
|
||||||
Text(
|
Text(
|
||||||
"Tags (${tags.size})",
|
"Tags (${tags.size})",
|
||||||
@@ -197,7 +234,7 @@ fun ImageDetailScreen(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tags.isEmpty()) {
|
if (tags.isEmpty() && faceTags.isEmpty()) {
|
||||||
item {
|
item {
|
||||||
Text(
|
Text(
|
||||||
"No tags yet",
|
"No tags yet",
|
||||||
@@ -205,6 +242,14 @@ fun ImageDetailScreen(
|
|||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
} else if (tags.isEmpty()) {
|
||||||
|
item {
|
||||||
|
Text(
|
||||||
|
"No other tags",
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
items(tags, key = { it.tagId }) { tag ->
|
items(tags, key = { it.tagId }) { tag ->
|
||||||
@@ -220,6 +265,83 @@ fun ImageDetailScreen(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun FaceTagCard(
|
||||||
|
faceTag: FaceTagInfo,
|
||||||
|
onRemove: () -> Unit
|
||||||
|
) {
|
||||||
|
Card(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
colors = CardDefaults.cardColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.tertiaryContainer
|
||||||
|
),
|
||||||
|
shape = RoundedCornerShape(8.dp)
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.padding(12.dp),
|
||||||
|
horizontalArrangement = Arrangement.SpaceBetween,
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Column(modifier = Modifier.weight(1f)) {
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Face,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(20.dp),
|
||||||
|
tint = MaterialTheme.colorScheme.tertiary
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = faceTag.personName,
|
||||||
|
style = MaterialTheme.typography.bodyLarge,
|
||||||
|
fontWeight = FontWeight.SemiBold
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
Row(
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = "Face Recognition",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = "•",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = "${(faceTag.confidence * 100).toInt()}% confidence",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = if (faceTag.confidence >= 0.7f)
|
||||||
|
MaterialTheme.colorScheme.primary
|
||||||
|
else if (faceTag.confidence >= 0.5f)
|
||||||
|
MaterialTheme.colorScheme.secondary
|
||||||
|
else
|
||||||
|
MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove button
|
||||||
|
IconButton(
|
||||||
|
onClick = onRemove,
|
||||||
|
colors = IconButtonDefaults.iconButtonColors(
|
||||||
|
contentColor = MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Icon(Icons.Default.Delete, "Remove face tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Composable
|
@Composable
|
||||||
private fun TagCard(
|
private fun TagCard(
|
||||||
tag: TagEntity,
|
tag: TagEntity,
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ package com.placeholder.sherpai2.ui.imagedetail.viewmodel
|
|||||||
|
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||||
import com.placeholder.sherpai2.data.local.entity.TagEntity
|
import com.placeholder.sherpai2.data.local.entity.TagEntity
|
||||||
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
@@ -10,17 +14,33 @@ import kotlinx.coroutines.flow.*
|
|||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a person tagged in this photo via face recognition
|
||||||
|
*/
|
||||||
|
data class FaceTagInfo(
|
||||||
|
val personId: String,
|
||||||
|
val personName: String,
|
||||||
|
val confidence: Float,
|
||||||
|
val faceModelId: String,
|
||||||
|
val tagId: String
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ImageDetailViewModel
|
* ImageDetailViewModel
|
||||||
*
|
*
|
||||||
* Owns:
|
* Owns:
|
||||||
* - Image context
|
* - Image context
|
||||||
* - Tag write operations
|
* - Tag write operations
|
||||||
|
* - Face tag display (people recognized in photo)
|
||||||
*/
|
*/
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
@OptIn(ExperimentalCoroutinesApi::class)
|
@OptIn(ExperimentalCoroutinesApi::class)
|
||||||
class ImageDetailViewModel @Inject constructor(
|
class ImageDetailViewModel @Inject constructor(
|
||||||
private val tagRepository: TaggingRepository
|
private val tagRepository: TaggingRepository,
|
||||||
|
private val imageDao: ImageDao,
|
||||||
|
private val photoFaceTagDao: PhotoFaceTagDao,
|
||||||
|
private val faceModelDao: FaceModelDao,
|
||||||
|
private val personDao: PersonDao
|
||||||
) : ViewModel() {
|
) : ViewModel() {
|
||||||
|
|
||||||
private val imageUri = MutableStateFlow<String?>(null)
|
private val imageUri = MutableStateFlow<String?>(null)
|
||||||
@@ -37,8 +57,43 @@ class ImageDetailViewModel @Inject constructor(
|
|||||||
initialValue = emptyList()
|
initialValue = emptyList()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Face tags (people recognized in this photo)
|
||||||
|
private val _faceTags = MutableStateFlow<List<FaceTagInfo>>(emptyList())
|
||||||
|
val faceTags: StateFlow<List<FaceTagInfo>> = _faceTags.asStateFlow()
|
||||||
|
|
||||||
fun loadImage(uri: String) {
|
fun loadImage(uri: String) {
|
||||||
imageUri.value = uri
|
imageUri.value = uri
|
||||||
|
loadFaceTags(uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun loadFaceTags(uri: String) {
|
||||||
|
viewModelScope.launch {
|
||||||
|
try {
|
||||||
|
// Get imageId from URI
|
||||||
|
val image = imageDao.getImageByUri(uri) ?: return@launch
|
||||||
|
|
||||||
|
// Get face tags for this image
|
||||||
|
val faceTags = photoFaceTagDao.getTagsForImage(image.imageId)
|
||||||
|
|
||||||
|
// Resolve to person names
|
||||||
|
val faceTagInfos = faceTags.mapNotNull { tag ->
|
||||||
|
val faceModel = faceModelDao.getFaceModelById(tag.faceModelId) ?: return@mapNotNull null
|
||||||
|
val person = personDao.getPersonById(faceModel.personId) ?: return@mapNotNull null
|
||||||
|
|
||||||
|
FaceTagInfo(
|
||||||
|
personId = person.id,
|
||||||
|
personName = person.name,
|
||||||
|
confidence = tag.confidence,
|
||||||
|
faceModelId = tag.faceModelId,
|
||||||
|
tagId = tag.id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
_faceTags.value = faceTagInfos.sortedByDescending { it.confidence }
|
||||||
|
} catch (e: Exception) {
|
||||||
|
_faceTags.value = emptyList()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun addTag(value: String) {
|
fun addTag(value: String) {
|
||||||
@@ -54,4 +109,15 @@ class ImageDetailViewModel @Inject constructor(
|
|||||||
tagRepository.removeTagFromImage(uri, tag.value)
|
tagRepository.removeTagFromImage(uri, tag.value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a face tag (person recognition)
|
||||||
|
*/
|
||||||
|
fun removeFaceTag(faceTagInfo: FaceTagInfo) {
|
||||||
|
viewModelScope.launch {
|
||||||
|
photoFaceTagDao.deleteTagById(faceTagInfo.tagId)
|
||||||
|
// Reload face tags
|
||||||
|
imageUri.value?.let { loadFaceTags(it) }
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,6 +95,9 @@ fun PersonInventoryScreen(
|
|||||||
},
|
},
|
||||||
onDelete = { personId ->
|
onDelete = { personId ->
|
||||||
viewModel.deletePerson(personId)
|
viewModel.deletePerson(personId)
|
||||||
|
},
|
||||||
|
onClearTags = { personId ->
|
||||||
|
viewModel.clearTagsForPerson(personId)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -319,7 +322,8 @@ private fun PersonList(
|
|||||||
persons: List<PersonWithModelInfo>,
|
persons: List<PersonWithModelInfo>,
|
||||||
onScan: (String) -> Unit,
|
onScan: (String) -> Unit,
|
||||||
onView: (String) -> Unit,
|
onView: (String) -> Unit,
|
||||||
onDelete: (String) -> Unit
|
onDelete: (String) -> Unit,
|
||||||
|
onClearTags: (String) -> Unit
|
||||||
) {
|
) {
|
||||||
LazyColumn(
|
LazyColumn(
|
||||||
contentPadding = PaddingValues(vertical = 8.dp)
|
contentPadding = PaddingValues(vertical = 8.dp)
|
||||||
@@ -332,7 +336,8 @@ private fun PersonList(
|
|||||||
person = person,
|
person = person,
|
||||||
onScan = { onScan(person.person.id) },
|
onScan = { onScan(person.person.id) },
|
||||||
onView = { onView(person.person.id) },
|
onView = { onView(person.person.id) },
|
||||||
onDelete = { onDelete(person.person.id) }
|
onDelete = { onDelete(person.person.id) },
|
||||||
|
onClearTags = { onClearTags(person.person.id) }
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -343,9 +348,34 @@ private fun PersonCard(
|
|||||||
person: PersonWithModelInfo,
|
person: PersonWithModelInfo,
|
||||||
onScan: () -> Unit,
|
onScan: () -> Unit,
|
||||||
onView: () -> Unit,
|
onView: () -> Unit,
|
||||||
onDelete: () -> Unit
|
onDelete: () -> Unit,
|
||||||
|
onClearTags: () -> Unit
|
||||||
) {
|
) {
|
||||||
var showDeleteDialog by remember { mutableStateOf(false) }
|
var showDeleteDialog by remember { mutableStateOf(false) }
|
||||||
|
var showClearDialog by remember { mutableStateOf(false) }
|
||||||
|
|
||||||
|
if (showClearDialog) {
|
||||||
|
AlertDialog(
|
||||||
|
onDismissRequest = { showClearDialog = false },
|
||||||
|
title = { Text("Clear tags for ${person.person.name}?") },
|
||||||
|
text = { Text("This will remove all ${person.taggedPhotoCount} photo tags but keep the face model. You can re-scan after clearing.") },
|
||||||
|
confirmButton = {
|
||||||
|
TextButton(
|
||||||
|
onClick = {
|
||||||
|
showClearDialog = false
|
||||||
|
onClearTags()
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
Text("Clear Tags", color = MaterialTheme.colorScheme.error)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
dismissButton = {
|
||||||
|
TextButton(onClick = { showClearDialog = false }) {
|
||||||
|
Text("Cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
if (showDeleteDialog) {
|
if (showDeleteDialog) {
|
||||||
AlertDialog(
|
AlertDialog(
|
||||||
@@ -413,6 +443,17 @@ private fun PersonCard(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Clear tags button (if has tags)
|
||||||
|
if (person.taggedPhotoCount > 0) {
|
||||||
|
IconButton(onClick = { showClearDialog = true }) {
|
||||||
|
Icon(
|
||||||
|
Icons.Default.ClearAll,
|
||||||
|
contentDescription = "Clear Tags",
|
||||||
|
tint = MaterialTheme.colorScheme.secondary
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Delete button
|
// Delete button
|
||||||
IconButton(onClick = { showDeleteDialog = true }) {
|
IconButton(onClick = { showDeleteDialog = true }) {
|
||||||
Icon(
|
Icon(
|
||||||
|
|||||||
@@ -105,6 +105,21 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear all face tags for a person (keep model, allow rescan)
|
||||||
|
*/
|
||||||
|
fun clearTagsForPerson(personId: String) {
|
||||||
|
viewModelScope.launch(Dispatchers.IO) {
|
||||||
|
try {
|
||||||
|
val faceModel = faceModelDao.getFaceModelByPersonId(personId)
|
||||||
|
if (faceModel != null) {
|
||||||
|
photoFaceTagDao.deleteTagsForFaceModel(faceModel.id)
|
||||||
|
}
|
||||||
|
loadPersons()
|
||||||
|
} catch (e: Exception) {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fun scanForPerson(personId: String) {
|
fun scanForPerson(personId: String) {
|
||||||
viewModelScope.launch(Dispatchers.IO) {
|
viewModelScope.launch(Dispatchers.IO) {
|
||||||
try {
|
try {
|
||||||
@@ -133,10 +148,20 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
.build()
|
.build()
|
||||||
|
|
||||||
val detector = FaceDetection.getClient(detectorOptions)
|
val detector = FaceDetection.getClient(detectorOptions)
|
||||||
val modelEmbedding = faceModel.getEmbeddingArray()
|
// CRITICAL: Use ALL centroids for matching
|
||||||
val faceNetModel = FaceNetModel(context)
|
val modelCentroids = faceModel.getCentroids().map { it.getEmbeddingArray() }
|
||||||
val trainingCount = faceModel.trainingImageCount
|
val trainingCount = faceModel.trainingImageCount
|
||||||
val baseThreshold = ThresholdStrategy.getLiberalThreshold(trainingCount)
|
android.util.Log.e("PersonScan", "=== CENTROIDS: ${modelCentroids.size}, trainingCount: $trainingCount ===")
|
||||||
|
|
||||||
|
if (modelCentroids.isEmpty()) {
|
||||||
|
_scanningState.value = ScanningState.Error("No centroids found")
|
||||||
|
return@launch
|
||||||
|
}
|
||||||
|
|
||||||
|
val faceNetModel = FaceNetModel(context)
|
||||||
|
// Production threshold - balance precision vs recall
|
||||||
|
val baseThreshold = 0.58f
|
||||||
|
android.util.Log.d("PersonScan", "Using threshold: $baseThreshold, centroids: ${modelCentroids.size}")
|
||||||
|
|
||||||
val completed = AtomicInteger(0)
|
val completed = AtomicInteger(0)
|
||||||
val facesFound = AtomicInteger(0)
|
val facesFound = AtomicInteger(0)
|
||||||
@@ -148,7 +173,7 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
val jobs = untaggedImages.map { image ->
|
val jobs = untaggedImages.map { image ->
|
||||||
async {
|
async {
|
||||||
semaphore.withPermit {
|
semaphore.withPermit {
|
||||||
processImage(image, detector, faceNetModel, modelEmbedding, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
|
processImage(image, detector, faceNetModel, modelCentroids, trainingCount, baseThreshold, personId, faceModel.id, batchMatches, batchUpdateMutex, completed, facesFound, startTime, totalToScan, person.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -175,7 +200,7 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
|
|
||||||
private suspend fun processImage(
|
private suspend fun processImage(
|
||||||
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
|
image: ImageEntity, detector: com.google.mlkit.vision.face.FaceDetector, faceNetModel: FaceNetModel,
|
||||||
modelEmbedding: FloatArray, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String,
|
modelCentroids: List<FloatArray>, trainingCount: Int, baseThreshold: Float, personId: String, faceModelId: String,
|
||||||
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
|
batchMatches: MutableList<Triple<String, String, Float>>, batchUpdateMutex: Mutex,
|
||||||
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
|
completed: AtomicInteger, facesFound: AtomicInteger, startTime: Long, totalToScan: Int, personName: String
|
||||||
) {
|
) {
|
||||||
@@ -212,14 +237,19 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
(face.boundingBox.bottom * scaleY).toInt()
|
(face.boundingBox.bottom * scaleY).toInt()
|
||||||
)
|
)
|
||||||
|
|
||||||
val faceBitmap = loadFaceRegion(uri, scaledBounds) ?: continue
|
// CRITICAL: Add padding to face crop (same as training)
|
||||||
|
val faceBitmap = loadFaceRegionWithPadding(uri, scaledBounds, sizeOpts.outWidth, sizeOpts.outHeight) ?: continue
|
||||||
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
|
||||||
faceBitmap.recycle()
|
faceBitmap.recycle()
|
||||||
|
|
||||||
if (similarity >= threshold) {
|
// Match against ALL centroids, use best match
|
||||||
|
val bestSimilarity = modelCentroids.maxOfOrNull { centroid ->
|
||||||
|
faceNetModel.calculateSimilarity(faceEmbedding, centroid)
|
||||||
|
} ?: 0f
|
||||||
|
|
||||||
|
if (bestSimilarity >= threshold) {
|
||||||
batchUpdateMutex.withLock {
|
batchUpdateMutex.withLock {
|
||||||
batchMatches.add(Triple(personId, image.imageId, similarity))
|
batchMatches.add(Triple(personId, image.imageId, bestSimilarity))
|
||||||
facesFound.incrementAndGet()
|
facesFound.incrementAndGet()
|
||||||
if (batchMatches.size >= BATCH_DB_SIZE) {
|
if (batchMatches.size >= BATCH_DB_SIZE) {
|
||||||
saveBatchMatches(batchMatches.toList(), faceModelId)
|
saveBatchMatches(batchMatches.toList(), faceModelId)
|
||||||
@@ -250,18 +280,32 @@ class PersonInventoryViewModel @Inject constructor(
|
|||||||
} catch (e: Exception) { null }
|
} catch (e: Exception) { null }
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun loadFaceRegion(uri: Uri, bounds: android.graphics.Rect): Bitmap? {
|
/**
|
||||||
|
* Load face region WITH 25% padding - CRITICAL for matching training conditions
|
||||||
|
*/
|
||||||
|
private fun loadFaceRegionWithPadding(uri: Uri, bounds: android.graphics.Rect, imgWidth: Int, imgHeight: Int): Bitmap? {
|
||||||
return try {
|
return try {
|
||||||
val full = context.contentResolver.openInputStream(uri)?.use {
|
val full = context.contentResolver.openInputStream(uri)?.use {
|
||||||
BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 })
|
BitmapFactory.decodeStream(it, null, BitmapFactory.Options().apply { inPreferredConfig = Bitmap.Config.ARGB_8888 })
|
||||||
} ?: return null
|
} ?: return null
|
||||||
|
|
||||||
val safeLeft = bounds.left.coerceIn(0, full.width - 1)
|
// Add 25% padding (same as training)
|
||||||
val safeTop = bounds.top.coerceIn(0, full.height - 1)
|
val padding = (kotlin.math.max(bounds.width(), bounds.height()) * 0.25f).toInt()
|
||||||
val safeWidth = bounds.width().coerceAtMost(full.width - safeLeft)
|
|
||||||
val safeHeight = bounds.height().coerceAtMost(full.height - safeTop)
|
|
||||||
|
|
||||||
val cropped = Bitmap.createBitmap(full, safeLeft, safeTop, safeWidth, safeHeight)
|
val left = (bounds.left - padding).coerceAtLeast(0)
|
||||||
|
val top = (bounds.top - padding).coerceAtLeast(0)
|
||||||
|
val right = (bounds.right + padding).coerceAtMost(full.width)
|
||||||
|
val bottom = (bounds.bottom + padding).coerceAtMost(full.height)
|
||||||
|
|
||||||
|
val width = right - left
|
||||||
|
val height = bottom - top
|
||||||
|
|
||||||
|
if (width <= 0 || height <= 0) {
|
||||||
|
full.recycle()
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
val cropped = Bitmap.createBitmap(full, left, top, width, height)
|
||||||
full.recycle()
|
full.recycle()
|
||||||
cropped
|
cropped
|
||||||
} catch (e: Exception) { null }
|
} catch (e: Exception) { null }
|
||||||
|
|||||||
@@ -192,10 +192,11 @@ class TrainViewModel @Inject constructor(
|
|||||||
.first()
|
.first()
|
||||||
|
|
||||||
if (backgroundTaggingEnabled) {
|
if (backgroundTaggingEnabled) {
|
||||||
|
// Lower threshold (0.55) since we use multi-centroid matching
|
||||||
val scanRequest = LibraryScanWorker.createWorkRequest(
|
val scanRequest = LibraryScanWorker.createWorkRequest(
|
||||||
personId = personId,
|
personId = personId,
|
||||||
personName = personName,
|
personName = personName,
|
||||||
threshold = 0.65f
|
threshold = 0.55f
|
||||||
)
|
)
|
||||||
workManager.enqueue(scanRequest)
|
workManager.enqueue(scanRequest)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ fun TrainingPhotoSelectorScreen(
|
|||||||
val isRanking by viewModel.isRanking.collectAsStateWithLifecycle()
|
val isRanking by viewModel.isRanking.collectAsStateWithLifecycle()
|
||||||
val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle()
|
val showPremiumOnly by viewModel.showPremiumOnly.collectAsStateWithLifecycle()
|
||||||
val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle()
|
val premiumCount by viewModel.premiumCount.collectAsStateWithLifecycle()
|
||||||
|
val embeddingProgress by viewModel.embeddingProgress.collectAsStateWithLifecycle()
|
||||||
|
|
||||||
Scaffold(
|
Scaffold(
|
||||||
topBar = {
|
topBar = {
|
||||||
@@ -155,7 +156,33 @@ fun TrainingPhotoSelectorScreen(
|
|||||||
modifier = Modifier.fillMaxSize(),
|
modifier = Modifier.fillMaxSize(),
|
||||||
contentAlignment = Alignment.Center
|
contentAlignment = Alignment.Center
|
||||||
) {
|
) {
|
||||||
CircularProgressIndicator()
|
Column(
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.spacedBy(16.dp)
|
||||||
|
) {
|
||||||
|
CircularProgressIndicator()
|
||||||
|
// Capture value to avoid race condition
|
||||||
|
val progress = embeddingProgress
|
||||||
|
if (progress != null) {
|
||||||
|
Text(
|
||||||
|
"Preparing faces: ${progress.current}/${progress.total}",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
LinearProgressIndicator(
|
||||||
|
progress = { progress.current.toFloat() / progress.total },
|
||||||
|
modifier = Modifier
|
||||||
|
.width(200.dp)
|
||||||
|
.padding(top = 8.dp)
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
Text(
|
||||||
|
"Loading premium faces...",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
photos.isEmpty() -> {
|
photos.isEmpty() -> {
|
||||||
|
|||||||
@@ -1,20 +1,31 @@
|
|||||||
package com.placeholder.sherpai2.ui.trainingprep
|
package com.placeholder.sherpai2.ui.trainingprep
|
||||||
|
|
||||||
|
import android.app.Application
|
||||||
|
import android.graphics.Bitmap
|
||||||
|
import android.graphics.BitmapFactory
|
||||||
|
import android.graphics.Rect
|
||||||
|
import android.net.Uri
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.AndroidViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||||
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
|
import com.placeholder.sherpai2.domain.similarity.FaceSimilarityScorer
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
import kotlinx.coroutines.Job
|
import kotlinx.coroutines.Job
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
import kotlinx.coroutines.flow.asStateFlow
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
|
import kotlin.math.max
|
||||||
|
import kotlin.math.min
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN
|
* TrainingPhotoSelectorViewModel - PREMIUM GRID + ROLLING SCAN
|
||||||
@@ -27,15 +38,18 @@ import javax.inject.Inject
|
|||||||
*/
|
*/
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class TrainingPhotoSelectorViewModel @Inject constructor(
|
class TrainingPhotoSelectorViewModel @Inject constructor(
|
||||||
|
application: Application,
|
||||||
private val imageDao: ImageDao,
|
private val imageDao: ImageDao,
|
||||||
private val faceCacheDao: FaceCacheDao,
|
private val faceCacheDao: FaceCacheDao,
|
||||||
private val faceSimilarityScorer: FaceSimilarityScorer
|
private val faceSimilarityScorer: FaceSimilarityScorer,
|
||||||
) : ViewModel() {
|
private val faceNetModel: FaceNetModel
|
||||||
|
) : AndroidViewModel(application) {
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private const val TAG = "PremiumSelector"
|
private const val TAG = "PremiumSelector"
|
||||||
private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1
|
private const val MIN_SEEDS_FOR_ROLLING_SCAN = 1
|
||||||
private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5
|
private const val MAX_SEEDS_FOR_ROLLING_SCAN = 5
|
||||||
|
private const val MAX_EMBEDDINGS_TO_GENERATE = 500
|
||||||
}
|
}
|
||||||
|
|
||||||
// All photos (for fallback / full list)
|
// All photos (for fallback / full list)
|
||||||
@@ -56,6 +70,12 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
|
|||||||
private val _isRanking = MutableStateFlow(false)
|
private val _isRanking = MutableStateFlow(false)
|
||||||
val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow()
|
val isRanking: StateFlow<Boolean> = _isRanking.asStateFlow()
|
||||||
|
|
||||||
|
// Embedding generation progress
|
||||||
|
private val _embeddingProgress = MutableStateFlow<EmbeddingProgress?>(null)
|
||||||
|
val embeddingProgress: StateFlow<EmbeddingProgress?> = _embeddingProgress.asStateFlow()
|
||||||
|
|
||||||
|
data class EmbeddingProgress(val current: Int, val total: Int)
|
||||||
|
|
||||||
// Premium mode toggle
|
// Premium mode toggle
|
||||||
private val _showPremiumOnly = MutableStateFlow(true)
|
private val _showPremiumOnly = MutableStateFlow(true)
|
||||||
val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow()
|
val showPremiumOnly: StateFlow<Boolean> = _showPremiumOnly.asStateFlow()
|
||||||
@@ -79,20 +99,47 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Load PREMIUM faces first (solo, large, frontal, high quality)
|
* Load PREMIUM faces first (solo, large, frontal, high quality)
|
||||||
|
* If no embeddings exist, generate them on-demand for premium candidates
|
||||||
*/
|
*/
|
||||||
private fun loadPremiumFaces() {
|
private fun loadPremiumFaces() {
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
try {
|
try {
|
||||||
_isLoading.value = true
|
_isLoading.value = true
|
||||||
|
|
||||||
// Get premium faces from cache
|
// First check if premium faces with embeddings exist
|
||||||
val premiumFaceCache = faceCacheDao.getPremiumFaces(
|
var premiumFaceCache = faceCacheDao.getPremiumFaces(
|
||||||
minAreaRatio = 0.10f,
|
minAreaRatio = 0.10f,
|
||||||
minQuality = 0.7f,
|
minQuality = 0.7f,
|
||||||
limit = 500
|
limit = 500
|
||||||
)
|
)
|
||||||
|
|
||||||
Log.d(TAG, "✅ Found ${premiumFaceCache.size} premium faces")
|
Log.d(TAG, "📊 Found ${premiumFaceCache.size} premium faces with embeddings")
|
||||||
|
|
||||||
|
// If no premium faces with embeddings, generate them on-demand
|
||||||
|
if (premiumFaceCache.isEmpty()) {
|
||||||
|
Log.d(TAG, "⚠️ No premium faces with embeddings - generating on-demand")
|
||||||
|
|
||||||
|
val candidates = faceCacheDao.getPremiumFaceCandidatesNeedingEmbeddings(
|
||||||
|
minAreaRatio = 0.10f,
|
||||||
|
minQuality = 0.7f,
|
||||||
|
limit = MAX_EMBEDDINGS_TO_GENERATE
|
||||||
|
)
|
||||||
|
|
||||||
|
Log.d(TAG, "📦 Found ${candidates.size} premium candidates needing embeddings")
|
||||||
|
|
||||||
|
if (candidates.isNotEmpty()) {
|
||||||
|
generateEmbeddingsForCandidates(candidates)
|
||||||
|
|
||||||
|
// Re-query after generating
|
||||||
|
premiumFaceCache = faceCacheDao.getPremiumFaces(
|
||||||
|
minAreaRatio = 0.10f,
|
||||||
|
minQuality = 0.7f,
|
||||||
|
limit = 500
|
||||||
|
)
|
||||||
|
Log.d(TAG, "✅ After generation: ${premiumFaceCache.size} premium faces")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
_premiumCount.value = premiumFaceCache.size
|
_premiumCount.value = premiumFaceCache.size
|
||||||
|
|
||||||
// Get corresponding ImageEntities
|
// Get corresponding ImageEntities
|
||||||
@@ -117,10 +164,108 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
|
|||||||
loadAllFaces()
|
loadAllFaces()
|
||||||
} finally {
|
} finally {
|
||||||
_isLoading.value = false
|
_isLoading.value = false
|
||||||
|
_embeddingProgress.value = null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate embeddings for premium face candidates
|
||||||
|
*/
|
||||||
|
private suspend fun generateEmbeddingsForCandidates(candidates: List<FaceCacheEntity>) {
|
||||||
|
val context = getApplication<Application>()
|
||||||
|
val total = candidates.size
|
||||||
|
var processed = 0
|
||||||
|
|
||||||
|
withContext(Dispatchers.IO) {
|
||||||
|
// Get image URIs for candidates
|
||||||
|
val imageIds = candidates.map { it.imageId }.distinct()
|
||||||
|
val images = imageDao.getImagesByIds(imageIds)
|
||||||
|
val imageUriMap = images.associate { it.imageId to it.imageUri }
|
||||||
|
|
||||||
|
for (candidate in candidates) {
|
||||||
|
try {
|
||||||
|
val imageUri = imageUriMap[candidate.imageId] ?: continue
|
||||||
|
|
||||||
|
// Load bitmap
|
||||||
|
val bitmap = loadBitmapOptimized(context, Uri.parse(imageUri)) ?: continue
|
||||||
|
|
||||||
|
// Crop face
|
||||||
|
val croppedFace = cropFaceWithPadding(bitmap, candidate.getBoundingBox())
|
||||||
|
bitmap.recycle()
|
||||||
|
|
||||||
|
if (croppedFace == null) continue
|
||||||
|
|
||||||
|
// Generate embedding
|
||||||
|
val embedding = faceNetModel.generateEmbedding(croppedFace)
|
||||||
|
croppedFace.recycle()
|
||||||
|
|
||||||
|
// Validate embedding
|
||||||
|
if (embedding.any { it != 0f }) {
|
||||||
|
// Save to database
|
||||||
|
val embeddingJson = FaceCacheEntity.embeddingToJson(embedding)
|
||||||
|
faceCacheDao.updateEmbedding(candidate.imageId, candidate.faceIndex, embeddingJson)
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.w(TAG, "Failed to generate embedding for ${candidate.imageId}: ${e.message}")
|
||||||
|
}
|
||||||
|
|
||||||
|
processed++
|
||||||
|
withContext(Dispatchers.Main) {
|
||||||
|
_embeddingProgress.value = EmbeddingProgress(processed, total)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.d(TAG, "✅ Generated embeddings for $processed/$total candidates")
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun loadBitmapOptimized(context: android.content.Context, uri: Uri, maxDim: Int = 768): Bitmap? {
|
||||||
|
return try {
|
||||||
|
val options = BitmapFactory.Options().apply { inJustDecodeBounds = true }
|
||||||
|
context.contentResolver.openInputStream(uri)?.use { stream ->
|
||||||
|
BitmapFactory.decodeStream(stream, null, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sampleSize = 1
|
||||||
|
while (options.outWidth / sampleSize > maxDim || options.outHeight / sampleSize > maxDim) {
|
||||||
|
sampleSize *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
val finalOptions = BitmapFactory.Options().apply {
|
||||||
|
inSampleSize = sampleSize
|
||||||
|
inPreferredConfig = Bitmap.Config.ARGB_8888
|
||||||
|
}
|
||||||
|
|
||||||
|
context.contentResolver.openInputStream(uri)?.use { stream ->
|
||||||
|
BitmapFactory.decodeStream(stream, null, finalOptions)
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.w(TAG, "Failed to load bitmap: ${e.message}")
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun cropFaceWithPadding(bitmap: Bitmap, boundingBox: Rect): Bitmap? {
|
||||||
|
return try {
|
||||||
|
val padding = (max(boundingBox.width(), boundingBox.height()) * 0.25f).toInt()
|
||||||
|
val left = max(0, boundingBox.left - padding)
|
||||||
|
val top = max(0, boundingBox.top - padding)
|
||||||
|
val right = min(bitmap.width, boundingBox.right + padding)
|
||||||
|
val bottom = min(bitmap.height, boundingBox.bottom + padding)
|
||||||
|
val width = right - left
|
||||||
|
val height = bottom - top
|
||||||
|
|
||||||
|
if (width > 0 && height > 0) {
|
||||||
|
Bitmap.createBitmap(bitmap, left, top, width, height)
|
||||||
|
} else null
|
||||||
|
} catch (e: Exception) {
|
||||||
|
Log.w(TAG, "Failed to crop face: ${e.message}")
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fallback: load all photos with faces
|
* Fallback: load all photos with faces
|
||||||
*/
|
*/
|
||||||
|
|||||||
Reference in New Issue
Block a user