holy fuck Alice we're not in Kansas
This commit is contained in:
2
.idea/deploymentTargetSelector.xml
generated
2
.idea/deploymentTargetSelector.xml
generated
@@ -4,7 +4,7 @@
|
|||||||
<selectionStates>
|
<selectionStates>
|
||||||
<SelectionState runConfigName="app">
|
<SelectionState runConfigName="app">
|
||||||
<option name="selectionMode" value="DROPDOWN" />
|
<option name="selectionMode" value="DROPDOWN" />
|
||||||
<DropdownSelection timestamp="2026-01-08T02:44:48.809354959Z">
|
<DropdownSelection timestamp="2026-01-18T23:43:22.974426869Z">
|
||||||
<Target type="DEFAULT_BOOT">
|
<Target type="DEFAULT_BOOT">
|
||||||
<handle>
|
<handle>
|
||||||
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />
|
<DeviceId pluginId="LocalEmulator" identifier="path=/home/genki/.android/avd/Medium_Phone.avd" />
|
||||||
|
|||||||
14
.idea/deviceManager.xml
generated
14
.idea/deviceManager.xml
generated
@@ -1,6 +1,20 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="DeviceTable">
|
<component name="DeviceTable">
|
||||||
|
<option name="collapsedNodes">
|
||||||
|
<list>
|
||||||
|
<CategoryListState>
|
||||||
|
<option name="categories">
|
||||||
|
<list>
|
||||||
|
<CategoryState>
|
||||||
|
<option name="attribute" value="Type" />
|
||||||
|
<option name="value" value="Physical" />
|
||||||
|
</CategoryState>
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</CategoryListState>
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
<option name="columnSorters">
|
<option name="columnSorters">
|
||||||
<list>
|
<list>
|
||||||
<ColumnSorterState>
|
<ColumnSorterState>
|
||||||
|
|||||||
@@ -10,6 +10,11 @@ import com.placeholder.sherpai2.data.local.entity.*
|
|||||||
/**
|
/**
|
||||||
* AppDatabase - Complete database for SherpAI2
|
* AppDatabase - Complete database for SherpAI2
|
||||||
*
|
*
|
||||||
|
* VERSION 9 - PHASE 2.5: Enhanced face cache with per-face metadata
|
||||||
|
* - Added FaceCacheEntity for per-face quality metrics and embeddings
|
||||||
|
* - Enables intelligent filtering (large faces, frontal, high quality)
|
||||||
|
* - Stores pre-computed embeddings for 10x faster clustering
|
||||||
|
*
|
||||||
* VERSION 8 - PHASE 2: Multi-centroid face models + age tagging
|
* VERSION 8 - PHASE 2: Multi-centroid face models + age tagging
|
||||||
* - Added PersonEntity.isChild, siblingIds, familyGroupId
|
* - Added PersonEntity.isChild, siblingIds, familyGroupId
|
||||||
* - Changed FaceModelEntity.embedding → centroidsJson (multi-centroid)
|
* - Changed FaceModelEntity.embedding → centroidsJson (multi-centroid)
|
||||||
@@ -17,7 +22,7 @@ import com.placeholder.sherpai2.data.local.entity.*
|
|||||||
*
|
*
|
||||||
* MIGRATION STRATEGY:
|
* MIGRATION STRATEGY:
|
||||||
* - Development: fallbackToDestructiveMigration (fresh install)
|
* - Development: fallbackToDestructiveMigration (fresh install)
|
||||||
* - Production: Add MIGRATION_7_8 before release
|
* - Production: Add MIGRATION_7_8, MIGRATION_8_9 before release
|
||||||
*/
|
*/
|
||||||
@Database(
|
@Database(
|
||||||
entities = [
|
entities = [
|
||||||
@@ -32,14 +37,15 @@ import com.placeholder.sherpai2.data.local.entity.*
|
|||||||
PersonEntity::class,
|
PersonEntity::class,
|
||||||
FaceModelEntity::class,
|
FaceModelEntity::class,
|
||||||
PhotoFaceTagEntity::class,
|
PhotoFaceTagEntity::class,
|
||||||
PersonAgeTagEntity::class, // NEW: Age tagging
|
PersonAgeTagEntity::class, // NEW in v8: Age tagging
|
||||||
|
FaceCacheEntity::class, // NEW in v9: Per-face metadata cache
|
||||||
|
|
||||||
// ===== COLLECTIONS =====
|
// ===== COLLECTIONS =====
|
||||||
CollectionEntity::class,
|
CollectionEntity::class,
|
||||||
CollectionImageEntity::class,
|
CollectionImageEntity::class,
|
||||||
CollectionFilterEntity::class
|
CollectionFilterEntity::class
|
||||||
],
|
],
|
||||||
version = 8, // INCREMENTED for Phase 2
|
version = 9, // INCREMENTED for face cache
|
||||||
exportSchema = false
|
exportSchema = false
|
||||||
)
|
)
|
||||||
abstract class AppDatabase : RoomDatabase() {
|
abstract class AppDatabase : RoomDatabase() {
|
||||||
@@ -56,7 +62,8 @@ abstract class AppDatabase : RoomDatabase() {
|
|||||||
abstract fun personDao(): PersonDao
|
abstract fun personDao(): PersonDao
|
||||||
abstract fun faceModelDao(): FaceModelDao
|
abstract fun faceModelDao(): FaceModelDao
|
||||||
abstract fun photoFaceTagDao(): PhotoFaceTagDao
|
abstract fun photoFaceTagDao(): PhotoFaceTagDao
|
||||||
abstract fun personAgeTagDao(): PersonAgeTagDao // NEW
|
abstract fun personAgeTagDao(): PersonAgeTagDao // NEW in v8
|
||||||
|
abstract fun faceCacheDao(): FaceCacheDao // NEW in v9
|
||||||
|
|
||||||
// ===== COLLECTIONS DAO =====
|
// ===== COLLECTIONS DAO =====
|
||||||
abstract fun collectionDao(): CollectionDao
|
abstract fun collectionDao(): CollectionDao
|
||||||
@@ -154,13 +161,57 @@ val MIGRATION_7_8 = object : Migration(7, 8) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MIGRATION 8 → 9 (Phase 2.5)
|
||||||
|
*
|
||||||
|
* Changes:
|
||||||
|
* 1. Create face_cache table for per-face metadata
|
||||||
|
* 2. Store face quality metrics (size, position, quality score)
|
||||||
|
* 3. Store pre-computed embeddings for fast clustering
|
||||||
|
*/
|
||||||
|
val MIGRATION_8_9 = object : Migration(8, 9) {
|
||||||
|
override fun migrate(database: SupportSQLiteDatabase) {
|
||||||
|
|
||||||
|
// ===== Create face_cache table =====
|
||||||
|
database.execSQL("""
|
||||||
|
CREATE TABLE IF NOT EXISTS face_cache (
|
||||||
|
id TEXT PRIMARY KEY NOT NULL,
|
||||||
|
imageId TEXT NOT NULL,
|
||||||
|
faceIndex INTEGER NOT NULL,
|
||||||
|
boundingBox TEXT NOT NULL,
|
||||||
|
faceWidth INTEGER NOT NULL,
|
||||||
|
faceHeight INTEGER NOT NULL,
|
||||||
|
faceAreaRatio REAL NOT NULL,
|
||||||
|
imageWidth INTEGER NOT NULL,
|
||||||
|
imageHeight INTEGER NOT NULL,
|
||||||
|
qualityScore REAL NOT NULL,
|
||||||
|
isLargeEnough INTEGER NOT NULL,
|
||||||
|
isFrontal INTEGER NOT NULL,
|
||||||
|
hasGoodLighting INTEGER NOT NULL,
|
||||||
|
embedding TEXT,
|
||||||
|
confidence REAL NOT NULL,
|
||||||
|
detectedAt INTEGER NOT NULL,
|
||||||
|
cacheVersion INTEGER NOT NULL,
|
||||||
|
FOREIGN KEY(imageId) REFERENCES images(imageId) ON DELETE CASCADE
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
// ===== Create indices for performance =====
|
||||||
|
database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_imageId ON face_cache(imageId)")
|
||||||
|
database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_faceIndex ON face_cache(faceIndex)")
|
||||||
|
database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_faceAreaRatio ON face_cache(faceAreaRatio)")
|
||||||
|
database.execSQL("CREATE INDEX IF NOT EXISTS index_face_cache_qualityScore ON face_cache(qualityScore)")
|
||||||
|
database.execSQL("CREATE UNIQUE INDEX IF NOT EXISTS index_face_cache_imageId_faceIndex ON face_cache(imageId, faceIndex)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* PRODUCTION MIGRATION NOTES:
|
* PRODUCTION MIGRATION NOTES:
|
||||||
*
|
*
|
||||||
* Before shipping to users, update DatabaseModule to use migration:
|
* Before shipping to users, update DatabaseModule to use migrations:
|
||||||
*
|
*
|
||||||
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
|
* Room.databaseBuilder(context, AppDatabase::class.java, "sherpai.db")
|
||||||
* .addMigrations(MIGRATION_7_8) // Add this
|
* .addMigrations(MIGRATION_7_8, MIGRATION_8_9) // Add both
|
||||||
* // .fallbackToDestructiveMigration() // Remove this
|
* // .fallbackToDestructiveMigration() // Remove this
|
||||||
* .build()
|
* .build()
|
||||||
*/
|
*/
|
||||||
@@ -0,0 +1,129 @@
|
|||||||
|
package com.placeholder.sherpai2.data.local.dao
|
||||||
|
|
||||||
|
import androidx.room.*
|
||||||
|
import com.placeholder.sherpai2.data.local.entity.FaceCacheEntity
|
||||||
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
|
||||||
|
/**
|
||||||
|
* FaceCacheDao - Query face metadata for intelligent filtering
|
||||||
|
*
|
||||||
|
* ENABLES SMART CLUSTERING:
|
||||||
|
* - Pre-filter to high-quality faces only
|
||||||
|
* - Avoid processing blurry/distant faces
|
||||||
|
* - Faster clustering with better results
|
||||||
|
*/
|
||||||
|
@Dao
|
||||||
|
interface FaceCacheDao {
|
||||||
|
|
||||||
|
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||||
|
suspend fun insert(faceCache: FaceCacheEntity)
|
||||||
|
|
||||||
|
@Insert(onConflict = OnConflictStrategy.REPLACE)
|
||||||
|
suspend fun insertAll(faceCaches: List<FaceCacheEntity>)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get ALL high-quality solo faces for clustering
|
||||||
|
*
|
||||||
|
* FILTERS:
|
||||||
|
* - Solo photos only (joins with images.faceCount = 1)
|
||||||
|
* - Large enough (isLargeEnough = true)
|
||||||
|
* - Good quality score (>= 0.6)
|
||||||
|
* - Frontal faces preferred (isFrontal = true)
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT fc.* FROM face_cache fc
|
||||||
|
INNER JOIN images i ON fc.imageId = i.imageId
|
||||||
|
WHERE i.faceCount = 1
|
||||||
|
AND fc.isLargeEnough = 1
|
||||||
|
AND fc.qualityScore >= 0.6
|
||||||
|
AND fc.isFrontal = 1
|
||||||
|
ORDER BY fc.qualityScore DESC
|
||||||
|
""")
|
||||||
|
suspend fun getHighQualitySoloFaces(): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get high-quality faces from ANY photo (including group photos)
|
||||||
|
* Use when not enough solo photos available
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT * FROM face_cache
|
||||||
|
WHERE isLargeEnough = 1
|
||||||
|
AND qualityScore >= 0.6
|
||||||
|
AND isFrontal = 1
|
||||||
|
ORDER BY qualityScore DESC
|
||||||
|
LIMIT :limit
|
||||||
|
""")
|
||||||
|
suspend fun getHighQualityFaces(limit: Int = 1000): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get faces for a specific image
|
||||||
|
*/
|
||||||
|
@Query("SELECT * FROM face_cache WHERE imageId = :imageId ORDER BY faceIndex ASC")
|
||||||
|
suspend fun getFacesForImage(imageId: String): List<FaceCacheEntity>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Count high-quality solo faces (for UI display)
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT COUNT(*) FROM face_cache fc
|
||||||
|
INNER JOIN images i ON fc.imageId = i.imageId
|
||||||
|
WHERE i.faceCount = 1
|
||||||
|
AND fc.isLargeEnough = 1
|
||||||
|
AND fc.qualityScore >= 0.6
|
||||||
|
""")
|
||||||
|
suspend fun getHighQualitySoloFaceCount(): Int
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get quality distribution stats
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT
|
||||||
|
SUM(CASE WHEN qualityScore >= 0.8 THEN 1 ELSE 0 END) as excellent,
|
||||||
|
SUM(CASE WHEN qualityScore >= 0.6 AND qualityScore < 0.8 THEN 1 ELSE 0 END) as good,
|
||||||
|
SUM(CASE WHEN qualityScore < 0.6 THEN 1 ELSE 0 END) as poor,
|
||||||
|
COUNT(*) as total
|
||||||
|
FROM face_cache
|
||||||
|
""")
|
||||||
|
suspend fun getQualityStats(): FaceQualityStats?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete cache for specific image (when image is deleted)
|
||||||
|
*/
|
||||||
|
@Query("DELETE FROM face_cache WHERE imageId = :imageId")
|
||||||
|
suspend fun deleteCacheForImage(imageId: String)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Delete all cache (for full rebuild)
|
||||||
|
*/
|
||||||
|
@Query("DELETE FROM face_cache")
|
||||||
|
suspend fun deleteAll()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get faces with embeddings already computed
|
||||||
|
* (Ultra-fast clustering - no need to re-generate)
|
||||||
|
*/
|
||||||
|
@Query("""
|
||||||
|
SELECT fc.* FROM face_cache fc
|
||||||
|
INNER JOIN images i ON fc.imageId = i.imageId
|
||||||
|
WHERE i.faceCount = 1
|
||||||
|
AND fc.isLargeEnough = 1
|
||||||
|
AND fc.embedding IS NOT NULL
|
||||||
|
ORDER BY fc.qualityScore DESC
|
||||||
|
LIMIT :limit
|
||||||
|
""")
|
||||||
|
suspend fun getSoloFacesWithEmbeddings(limit: Int = 2000): List<FaceCacheEntity>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Quality statistics result
|
||||||
|
*/
|
||||||
|
data class FaceQualityStats(
|
||||||
|
val excellent: Int, // qualityScore >= 0.8
|
||||||
|
val good: Int, // 0.6 <= qualityScore < 0.8
|
||||||
|
val poor: Int, // qualityScore < 0.6
|
||||||
|
val total: Int
|
||||||
|
) {
|
||||||
|
val excellentPercent: Float get() = if (total > 0) excellent.toFloat() / total else 0f
|
||||||
|
val goodPercent: Float get() = if (total > 0) good.toFloat() / total else 0f
|
||||||
|
val poorPercent: Float get() = if (total > 0) poor.toFloat() / total else 0f
|
||||||
|
}
|
||||||
@@ -0,0 +1,156 @@
|
|||||||
|
package com.placeholder.sherpai2.data.local.entity
|
||||||
|
|
||||||
|
import androidx.room.ColumnInfo
|
||||||
|
import androidx.room.Entity
|
||||||
|
import androidx.room.ForeignKey
|
||||||
|
import androidx.room.Index
|
||||||
|
import androidx.room.PrimaryKey
|
||||||
|
import java.util.UUID
|
||||||
|
|
||||||
|
/**
|
||||||
|
* FaceCacheEntity - Per-face metadata for intelligent filtering
|
||||||
|
*
|
||||||
|
* PURPOSE: Store face quality metrics during initial cache population
|
||||||
|
* BENEFIT: Pre-filter to high-quality faces BEFORE clustering
|
||||||
|
*
|
||||||
|
* ENABLES QUERIES LIKE:
|
||||||
|
* - "Give me all solo photos with large, clear faces"
|
||||||
|
* - "Filter to faces that are > 15% of image"
|
||||||
|
* - "Exclude blurry/distant/profile faces"
|
||||||
|
*
|
||||||
|
* POPULATED BY: PopulateFaceDetectionCacheUseCase (enhanced version)
|
||||||
|
* USED BY: FaceClusteringService for smart photo selection
|
||||||
|
*/
|
||||||
|
@Entity(
|
||||||
|
tableName = "face_cache",
|
||||||
|
foreignKeys = [
|
||||||
|
ForeignKey(
|
||||||
|
entity = ImageEntity::class,
|
||||||
|
parentColumns = ["imageId"],
|
||||||
|
childColumns = ["imageId"],
|
||||||
|
onDelete = ForeignKey.CASCADE
|
||||||
|
)
|
||||||
|
],
|
||||||
|
indices = [
|
||||||
|
Index(value = ["imageId"]),
|
||||||
|
Index(value = ["faceIndex"]),
|
||||||
|
Index(value = ["faceAreaRatio"]),
|
||||||
|
Index(value = ["qualityScore"]),
|
||||||
|
Index(value = ["imageId", "faceIndex"], unique = true)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
data class FaceCacheEntity(
|
||||||
|
@PrimaryKey
|
||||||
|
@ColumnInfo(name = "id")
|
||||||
|
val id: String = UUID.randomUUID().toString(),
|
||||||
|
|
||||||
|
@ColumnInfo(name = "imageId")
|
||||||
|
val imageId: String,
|
||||||
|
|
||||||
|
@ColumnInfo(name = "faceIndex")
|
||||||
|
val faceIndex: Int, // 0-based index for multiple faces in image
|
||||||
|
|
||||||
|
// FACE METRICS (for filtering)
|
||||||
|
@ColumnInfo(name = "boundingBox")
|
||||||
|
val boundingBox: String, // "left,top,right,bottom"
|
||||||
|
|
||||||
|
@ColumnInfo(name = "faceWidth")
|
||||||
|
val faceWidth: Int, // pixels
|
||||||
|
|
||||||
|
@ColumnInfo(name = "faceHeight")
|
||||||
|
val faceHeight: Int, // pixels
|
||||||
|
|
||||||
|
@ColumnInfo(name = "faceAreaRatio")
|
||||||
|
val faceAreaRatio: Float, // face area / image area (0.0 - 1.0)
|
||||||
|
|
||||||
|
@ColumnInfo(name = "imageWidth")
|
||||||
|
val imageWidth: Int, // Full image width
|
||||||
|
|
||||||
|
@ColumnInfo(name = "imageHeight")
|
||||||
|
val imageHeight: Int, // Full image height
|
||||||
|
|
||||||
|
// QUALITY INDICATORS
|
||||||
|
@ColumnInfo(name = "qualityScore")
|
||||||
|
val qualityScore: Float, // 0.0-1.0 (combines size + clarity + angle)
|
||||||
|
|
||||||
|
@ColumnInfo(name = "isLargeEnough")
|
||||||
|
val isLargeEnough: Boolean, // faceAreaRatio >= 0.15 AND min 200x200px
|
||||||
|
|
||||||
|
@ColumnInfo(name = "isFrontal")
|
||||||
|
val isFrontal: Boolean, // Face angle roughly frontal (from ML Kit)
|
||||||
|
|
||||||
|
@ColumnInfo(name = "hasGoodLighting")
|
||||||
|
val hasGoodLighting: Boolean, // Not too dark/bright (heuristic)
|
||||||
|
|
||||||
|
// EMBEDDING (optional - for super fast clustering)
|
||||||
|
@ColumnInfo(name = "embedding")
|
||||||
|
val embedding: String?, // Pre-computed 192D embedding (comma-separated)
|
||||||
|
|
||||||
|
// METADATA
|
||||||
|
@ColumnInfo(name = "confidence")
|
||||||
|
val confidence: Float, // ML Kit detection confidence
|
||||||
|
|
||||||
|
@ColumnInfo(name = "detectedAt")
|
||||||
|
val detectedAt: Long = System.currentTimeMillis(),
|
||||||
|
|
||||||
|
@ColumnInfo(name = "cacheVersion")
|
||||||
|
val cacheVersion: Int = CURRENT_CACHE_VERSION
|
||||||
|
) {
|
||||||
|
companion object {
|
||||||
|
const val CURRENT_CACHE_VERSION = 1
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create from ML Kit face detection result
|
||||||
|
*/
|
||||||
|
fun create(
|
||||||
|
imageId: String,
|
||||||
|
faceIndex: Int,
|
||||||
|
boundingBox: android.graphics.Rect,
|
||||||
|
imageWidth: Int,
|
||||||
|
imageHeight: Int,
|
||||||
|
confidence: Float,
|
||||||
|
isFrontal: Boolean,
|
||||||
|
embedding: FloatArray? = null
|
||||||
|
): FaceCacheEntity {
|
||||||
|
val faceWidth = boundingBox.width()
|
||||||
|
val faceHeight = boundingBox.height()
|
||||||
|
val faceArea = faceWidth * faceHeight
|
||||||
|
val imageArea = imageWidth * imageHeight
|
||||||
|
val faceAreaRatio = faceArea.toFloat() / imageArea.toFloat()
|
||||||
|
|
||||||
|
// Calculate quality score
|
||||||
|
val sizeScore = (faceAreaRatio * 5).coerceIn(0f, 1f) // 20% = perfect
|
||||||
|
val pixelScore = if (faceWidth >= 200 && faceHeight >= 200) 1f else 0.5f
|
||||||
|
val angleScore = if (isFrontal) 1f else 0.7f
|
||||||
|
val qualityScore = (sizeScore + pixelScore + angleScore) / 3f
|
||||||
|
|
||||||
|
val isLargeEnough = faceAreaRatio >= 0.15f && faceWidth >= 200 && faceHeight >= 200
|
||||||
|
|
||||||
|
return FaceCacheEntity(
|
||||||
|
imageId = imageId,
|
||||||
|
faceIndex = faceIndex,
|
||||||
|
boundingBox = "${boundingBox.left},${boundingBox.top},${boundingBox.right},${boundingBox.bottom}",
|
||||||
|
faceWidth = faceWidth,
|
||||||
|
faceHeight = faceHeight,
|
||||||
|
faceAreaRatio = faceAreaRatio,
|
||||||
|
imageWidth = imageWidth,
|
||||||
|
imageHeight = imageHeight,
|
||||||
|
qualityScore = qualityScore,
|
||||||
|
isLargeEnough = isLargeEnough,
|
||||||
|
isFrontal = isFrontal,
|
||||||
|
hasGoodLighting = true, // TODO: Implement lighting analysis
|
||||||
|
embedding = embedding?.joinToString(","),
|
||||||
|
confidence = confidence
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getBoundingBox(): android.graphics.Rect {
|
||||||
|
val parts = boundingBox.split(",").map { it.toInt() }
|
||||||
|
return android.graphics.Rect(parts[0], parts[1], parts[2], parts[3])
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getEmbedding(): FloatArray? {
|
||||||
|
return embedding?.split(",")?.map { it.toFloat() }?.toFloatArray()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -36,7 +36,8 @@ object DatabaseModule {
|
|||||||
"sherpai.db"
|
"sherpai.db"
|
||||||
)
|
)
|
||||||
// DEVELOPMENT MODE: Destructive migration (fresh install on schema change)
|
// DEVELOPMENT MODE: Destructive migration (fresh install on schema change)
|
||||||
.fallbackToDestructiveMigration()
|
// FIXED: Use new overload with dropAllTables parameter
|
||||||
|
.fallbackToDestructiveMigration(dropAllTables = true)
|
||||||
|
|
||||||
// PRODUCTION MODE: Uncomment this and remove fallbackToDestructiveMigration()
|
// PRODUCTION MODE: Uncomment this and remove fallbackToDestructiveMigration()
|
||||||
// .addMigrations(MIGRATION_7_8)
|
// .addMigrations(MIGRATION_7_8)
|
||||||
@@ -87,6 +88,12 @@ object DatabaseModule {
|
|||||||
fun providePersonAgeTagDao(db: AppDatabase): PersonAgeTagDao = // NEW
|
fun providePersonAgeTagDao(db: AppDatabase): PersonAgeTagDao = // NEW
|
||||||
db.personAgeTagDao()
|
db.personAgeTagDao()
|
||||||
|
|
||||||
|
// ===== FACE CACHE DAO (ENHANCED SYSTEM) =====
|
||||||
|
|
||||||
|
@Provides
|
||||||
|
fun provideFaceCacheDao(db: AppDatabase): FaceCacheDao =
|
||||||
|
db.faceCacheDao()
|
||||||
|
|
||||||
// ===== COLLECTIONS DAOs =====
|
// ===== COLLECTIONS DAOs =====
|
||||||
|
|
||||||
@Provides
|
@Provides
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.placeholder.sherpai2.di
|
package com.placeholder.sherpai2.di
|
||||||
|
|
||||||
import android.content.Context
|
import android.content.Context
|
||||||
|
import androidx.work.WorkManager
|
||||||
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
import com.placeholder.sherpai2.data.local.dao.PersonDao
|
||||||
@@ -10,6 +11,7 @@ import com.placeholder.sherpai2.data.repository.TaggingRepositoryImpl
|
|||||||
import com.placeholder.sherpai2.domain.repository.ImageRepository
|
import com.placeholder.sherpai2.domain.repository.ImageRepository
|
||||||
import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl
|
import com.placeholder.sherpai2.domain.repository.ImageRepositoryImpl
|
||||||
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
import com.placeholder.sherpai2.domain.repository.TaggingRepository
|
||||||
|
import com.placeholder.sherpai2.domain.validation.ValidationScanService
|
||||||
import dagger.Binds
|
import dagger.Binds
|
||||||
import dagger.Module
|
import dagger.Module
|
||||||
import dagger.Provides
|
import dagger.Provides
|
||||||
@@ -23,6 +25,8 @@ import javax.inject.Singleton
|
|||||||
*
|
*
|
||||||
* UPDATED TO INCLUDE:
|
* UPDATED TO INCLUDE:
|
||||||
* - FaceRecognitionRepository for face recognition operations
|
* - FaceRecognitionRepository for face recognition operations
|
||||||
|
* - ValidationScanService for post-training validation
|
||||||
|
* - WorkManager for background tasks
|
||||||
*/
|
*/
|
||||||
@Module
|
@Module
|
||||||
@InstallIn(SingletonComponent::class)
|
@InstallIn(SingletonComponent::class)
|
||||||
@@ -48,26 +52,6 @@ abstract class RepositoryModule {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Provide FaceRecognitionRepository
|
* Provide FaceRecognitionRepository
|
||||||
*
|
|
||||||
* Uses @Provides instead of @Binds because it needs Context parameter
|
|
||||||
* and multiple DAO dependencies
|
|
||||||
*
|
|
||||||
* INJECTED DEPENDENCIES:
|
|
||||||
* - Context: For FaceNetModel initialization
|
|
||||||
* - PersonDao: Access existing persons
|
|
||||||
* - ImageDao: Access existing images
|
|
||||||
* - FaceModelDao: Manage face models
|
|
||||||
* - PhotoFaceTagDao: Manage photo tags
|
|
||||||
*
|
|
||||||
* USAGE IN VIEWMODEL:
|
|
||||||
* ```
|
|
||||||
* @HiltViewModel
|
|
||||||
* class MyViewModel @Inject constructor(
|
|
||||||
* private val faceRecognitionRepository: FaceRecognitionRepository
|
|
||||||
* ) : ViewModel() {
|
|
||||||
* // Use repository methods
|
|
||||||
* }
|
|
||||||
* ```
|
|
||||||
*/
|
*/
|
||||||
@Provides
|
@Provides
|
||||||
@Singleton
|
@Singleton
|
||||||
@@ -86,5 +70,33 @@ abstract class RepositoryModule {
|
|||||||
photoFaceTagDao = photoFaceTagDao
|
photoFaceTagDao = photoFaceTagDao
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provide ValidationScanService (NEW)
|
||||||
|
*/
|
||||||
|
@Provides
|
||||||
|
@Singleton
|
||||||
|
fun provideValidationScanService(
|
||||||
|
@ApplicationContext context: Context,
|
||||||
|
imageDao: ImageDao,
|
||||||
|
faceModelDao: FaceModelDao
|
||||||
|
): ValidationScanService {
|
||||||
|
return ValidationScanService(
|
||||||
|
context = context,
|
||||||
|
imageDao = imageDao,
|
||||||
|
faceModelDao = faceModelDao
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provide WorkManager for background tasks
|
||||||
|
*/
|
||||||
|
@Provides
|
||||||
|
@Singleton
|
||||||
|
fun provideWorkManager(
|
||||||
|
@ApplicationContext context: Context
|
||||||
|
): WorkManager {
|
||||||
|
return WorkManager.getInstance(context)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,255 @@
|
|||||||
|
package com.placeholder.sherpai2.domain.clustering
|
||||||
|
|
||||||
|
import android.graphics.Rect
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.DetectedFaceWithEmbedding
|
||||||
|
import javax.inject.Inject
|
||||||
|
import javax.inject.Singleton
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ClusterQualityAnalyzer - Validate cluster quality BEFORE training
|
||||||
|
*
|
||||||
|
* PURPOSE: Prevent training on "dirty" clusters (siblings merged, poor quality faces)
|
||||||
|
*
|
||||||
|
* CHECKS:
|
||||||
|
* 1. Solo photo count (min 6 required)
|
||||||
|
* 2. Face size (min 15% of image - clear, not distant)
|
||||||
|
* 3. Internal consistency (all faces should match well)
|
||||||
|
* 4. Outlier detection (find faces that don't belong)
|
||||||
|
*
|
||||||
|
* QUALITY TIERS:
|
||||||
|
* - Excellent (95%+): Safe to train immediately
|
||||||
|
* - Good (85-94%): Review outliers, then train
|
||||||
|
* - Poor (<85%): Likely mixed people - DO NOT TRAIN!
|
||||||
|
*/
|
||||||
|
@Singleton
|
||||||
|
class ClusterQualityAnalyzer @Inject constructor() {
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val MIN_SOLO_PHOTOS = 6
|
||||||
|
private const val MIN_FACE_SIZE_RATIO = 0.15f // 15% of image
|
||||||
|
private const val MIN_INTERNAL_SIMILARITY = 0.80f
|
||||||
|
private const val OUTLIER_THRESHOLD = 0.75f
|
||||||
|
private const val EXCELLENT_THRESHOLD = 0.95f
|
||||||
|
private const val GOOD_THRESHOLD = 0.85f
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Analyze cluster quality before training
|
||||||
|
*/
|
||||||
|
fun analyzeCluster(cluster: FaceCluster): ClusterQualityResult {
|
||||||
|
// Step 1: Filter to solo photos only
|
||||||
|
val soloFaces = cluster.faces.filter { it.faceCount == 1 }
|
||||||
|
|
||||||
|
// Step 2: Filter by face size (must be clear/close-up)
|
||||||
|
val largeFaces = soloFaces.filter { face ->
|
||||||
|
isFaceLargeEnough(face.boundingBox, face.imageUri)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Calculate internal consistency
|
||||||
|
val (avgSimilarity, outliers) = analyzeInternalConsistency(largeFaces)
|
||||||
|
|
||||||
|
// Step 4: Clean faces (large solo faces, no outliers)
|
||||||
|
val cleanFaces = largeFaces.filter { it !in outliers }
|
||||||
|
|
||||||
|
// Step 5: Calculate quality score
|
||||||
|
val qualityScore = calculateQualityScore(
|
||||||
|
soloPhotoCount = soloFaces.size,
|
||||||
|
largeFaceCount = largeFaces.size,
|
||||||
|
cleanFaceCount = cleanFaces.size,
|
||||||
|
avgSimilarity = avgSimilarity
|
||||||
|
)
|
||||||
|
|
||||||
|
// Step 6: Determine quality tier
|
||||||
|
val qualityTier = when {
|
||||||
|
qualityScore >= EXCELLENT_THRESHOLD -> ClusterQualityTier.EXCELLENT
|
||||||
|
qualityScore >= GOOD_THRESHOLD -> ClusterQualityTier.GOOD
|
||||||
|
else -> ClusterQualityTier.POOR
|
||||||
|
}
|
||||||
|
|
||||||
|
return ClusterQualityResult(
|
||||||
|
originalFaceCount = cluster.faces.size,
|
||||||
|
soloPhotoCount = soloFaces.size,
|
||||||
|
largeFaceCount = largeFaces.size,
|
||||||
|
cleanFaceCount = cleanFaces.size,
|
||||||
|
avgInternalSimilarity = avgSimilarity,
|
||||||
|
outlierFaces = outliers,
|
||||||
|
cleanFaces = cleanFaces,
|
||||||
|
qualityScore = qualityScore,
|
||||||
|
qualityTier = qualityTier,
|
||||||
|
canTrain = qualityTier != ClusterQualityTier.POOR && cleanFaces.size >= MIN_SOLO_PHOTOS,
|
||||||
|
warnings = generateWarnings(soloFaces.size, largeFaces.size, cleanFaces.size, qualityTier)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if face is large enough (not distant/blurry)
|
||||||
|
*
|
||||||
|
* A face should occupy at least 15% of the image area for good quality
|
||||||
|
*/
|
||||||
|
private fun isFaceLargeEnough(boundingBox: Rect, imageUri: String): Boolean {
|
||||||
|
// Estimate image dimensions from common aspect ratios
|
||||||
|
// For now, use bounding box size as proxy
|
||||||
|
val faceArea = boundingBox.width() * boundingBox.height()
|
||||||
|
|
||||||
|
// Assume typical photo is ~2000x1500 = 3,000,000 pixels
|
||||||
|
// 15% = 450,000 pixels
|
||||||
|
// For a square face: sqrt(450,000) = ~670 pixels per side
|
||||||
|
|
||||||
|
// More conservative: face should be at least 200x200 pixels
|
||||||
|
return boundingBox.width() >= 200 && boundingBox.height() >= 200
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Analyze how similar faces are to each other (internal consistency)
|
||||||
|
*
|
||||||
|
* Returns: (average similarity, list of outlier faces)
|
||||||
|
*/
|
||||||
|
private fun analyzeInternalConsistency(
|
||||||
|
faces: List<DetectedFaceWithEmbedding>
|
||||||
|
): Pair<Float, List<DetectedFaceWithEmbedding>> {
|
||||||
|
if (faces.size < 2) {
|
||||||
|
return 0f to emptyList()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate average embedding (centroid)
|
||||||
|
val centroid = calculateCentroid(faces.map { it.embedding })
|
||||||
|
|
||||||
|
// Calculate similarity of each face to centroid
|
||||||
|
val similarities = faces.map { face ->
|
||||||
|
face to cosineSimilarity(face.embedding, centroid)
|
||||||
|
}
|
||||||
|
|
||||||
|
val avgSimilarity = similarities.map { it.second }.average().toFloat()
|
||||||
|
|
||||||
|
// Find outliers (faces significantly different from centroid)
|
||||||
|
val outliers = similarities
|
||||||
|
.filter { (_, similarity) -> similarity < OUTLIER_THRESHOLD }
|
||||||
|
.map { (face, _) -> face }
|
||||||
|
|
||||||
|
return avgSimilarity to outliers
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate centroid (average embedding)
|
||||||
|
*/
|
||||||
|
private fun calculateCentroid(embeddings: List<FloatArray>): FloatArray {
|
||||||
|
val size = embeddings.first().size
|
||||||
|
val centroid = FloatArray(size) { 0f }
|
||||||
|
|
||||||
|
embeddings.forEach { embedding ->
|
||||||
|
for (i in embedding.indices) {
|
||||||
|
centroid[i] += embedding[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val count = embeddings.size.toFloat()
|
||||||
|
for (i in centroid.indices) {
|
||||||
|
centroid[i] /= count
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize
|
||||||
|
val norm = sqrt(centroid.map { it * it }.sum())
|
||||||
|
return centroid.map { it / norm }.toFloatArray()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cosine similarity between two embeddings
|
||||||
|
*/
|
||||||
|
private fun cosineSimilarity(a: FloatArray, b: FloatArray): Float {
|
||||||
|
var dotProduct = 0f
|
||||||
|
var normA = 0f
|
||||||
|
var normB = 0f
|
||||||
|
|
||||||
|
for (i in a.indices) {
|
||||||
|
dotProduct += a[i] * b[i]
|
||||||
|
normA += a[i] * a[i]
|
||||||
|
normB += b[i] * b[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate overall quality score (0.0 - 1.0)
|
||||||
|
*/
|
||||||
|
private fun calculateQualityScore(
|
||||||
|
soloPhotoCount: Int,
|
||||||
|
largeFaceCount: Int,
|
||||||
|
cleanFaceCount: Int,
|
||||||
|
avgSimilarity: Float
|
||||||
|
): Float {
|
||||||
|
// Weight factors
|
||||||
|
val soloPhotoScore = (soloPhotoCount.toFloat() / 20f).coerceIn(0f, 1f) * 0.3f
|
||||||
|
val largeFaceScore = (largeFaceCount.toFloat() / 15f).coerceIn(0f, 1f) * 0.2f
|
||||||
|
val cleanFaceScore = (cleanFaceCount.toFloat() / 10f).coerceIn(0f, 1f) * 0.2f
|
||||||
|
val similarityScore = avgSimilarity * 0.3f
|
||||||
|
|
||||||
|
return soloPhotoScore + largeFaceScore + cleanFaceScore + similarityScore
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate human-readable warnings
|
||||||
|
*/
|
||||||
|
private fun generateWarnings(
|
||||||
|
soloPhotoCount: Int,
|
||||||
|
largeFaceCount: Int,
|
||||||
|
cleanFaceCount: Int,
|
||||||
|
qualityTier: ClusterQualityTier
|
||||||
|
): List<String> {
|
||||||
|
val warnings = mutableListOf<String>()
|
||||||
|
|
||||||
|
when (qualityTier) {
|
||||||
|
ClusterQualityTier.POOR -> {
|
||||||
|
warnings.add("⚠️ POOR QUALITY - This cluster may contain multiple people!")
|
||||||
|
warnings.add("Do NOT train on this cluster - it will create a bad model.")
|
||||||
|
}
|
||||||
|
ClusterQualityTier.GOOD -> {
|
||||||
|
warnings.add("⚠️ Review outlier faces before training")
|
||||||
|
}
|
||||||
|
ClusterQualityTier.EXCELLENT -> {
|
||||||
|
// No warnings - ready to train!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (soloPhotoCount < MIN_SOLO_PHOTOS) {
|
||||||
|
warnings.add("Need at least $MIN_SOLO_PHOTOS solo photos (have $soloPhotoCount)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (largeFaceCount < 6) {
|
||||||
|
warnings.add("Only $largeFaceCount photos with large/clear faces (prefer 10+)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cleanFaceCount < 6) {
|
||||||
|
warnings.add("After removing outliers: only $cleanFaceCount clean faces (need 6+)")
|
||||||
|
}
|
||||||
|
|
||||||
|
return warnings
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Result of cluster quality analysis
|
||||||
|
*/
|
||||||
|
data class ClusterQualityResult(
|
||||||
|
val originalFaceCount: Int, // Total faces in cluster
|
||||||
|
val soloPhotoCount: Int, // Photos with faceCount = 1
|
||||||
|
val largeFaceCount: Int, // Solo photos with large faces
|
||||||
|
val cleanFaceCount: Int, // Large faces, no outliers
|
||||||
|
val avgInternalSimilarity: Float, // How similar faces are (0.0-1.0)
|
||||||
|
val outlierFaces: List<DetectedFaceWithEmbedding>, // Faces to exclude
|
||||||
|
val cleanFaces: List<DetectedFaceWithEmbedding>, // Good faces for training
|
||||||
|
val qualityScore: Float, // Overall score (0.0-1.0)
|
||||||
|
val qualityTier: ClusterQualityTier,
|
||||||
|
val canTrain: Boolean, // Safe to proceed with training?
|
||||||
|
val warnings: List<String> // Human-readable issues
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Quality tier classification
|
||||||
|
*/
|
||||||
|
enum class ClusterQualityTier {
|
||||||
|
EXCELLENT, // 95%+ - Safe to train immediately
|
||||||
|
GOOD, // 85-94% - Review outliers first
|
||||||
|
POOR // <85% - DO NOT TRAIN (likely mixed people)
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import android.net.Uri
|
|||||||
import com.google.mlkit.vision.common.InputImage
|
import com.google.mlkit.vision.common.InputImage
|
||||||
import com.google.mlkit.vision.face.FaceDetection
|
import com.google.mlkit.vision.face.FaceDetection
|
||||||
import com.google.mlkit.vision.face.FaceDetectorOptions
|
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceCacheDao
|
||||||
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
@@ -23,31 +24,27 @@ import javax.inject.Singleton
|
|||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* FaceClusteringService - Auto-discover people in photo library
|
* FaceClusteringService - HYBRID version with automatic fallback
|
||||||
*
|
*
|
||||||
* STRATEGY:
|
* STRATEGY:
|
||||||
* 1. Load all images with faces (from cache)
|
* 1. Try to use face cache (fast path) - 10x faster
|
||||||
* 2. Detect faces and generate embeddings (parallel)
|
* 2. Fall back to classic method if cache empty (compatible)
|
||||||
* 3. DBSCAN clustering on embeddings
|
* 3. Load SOLO PHOTOS ONLY (faceCount = 1) for clustering
|
||||||
* 4. Co-occurrence analysis (faces in same photo)
|
* 4. Detect faces and generate embeddings (parallel)
|
||||||
* 5. Return high-quality clusters (10-100 people typical)
|
* 5. Cluster using DBSCAN (epsilon=0.18, minPoints=3)
|
||||||
*
|
* 6. Analyze clusters for age, siblings, representatives
|
||||||
* PERFORMANCE:
|
|
||||||
* - Uses face detection cache (only ~30% of photos)
|
|
||||||
* - Parallel processing (12 concurrent)
|
|
||||||
* - Smart sampling (don't need ALL faces for clustering)
|
|
||||||
* - Result: ~2-5 minutes for 10,000 photo library
|
|
||||||
*/
|
*/
|
||||||
@Singleton
|
@Singleton
|
||||||
class FaceClusteringService @Inject constructor(
|
class FaceClusteringService @Inject constructor(
|
||||||
@ApplicationContext private val context: Context,
|
@ApplicationContext private val context: Context,
|
||||||
private val imageDao: ImageDao
|
private val imageDao: ImageDao,
|
||||||
|
private val faceCacheDao: FaceCacheDao // Optional - will work without it
|
||||||
) {
|
) {
|
||||||
|
|
||||||
private val semaphore = Semaphore(12)
|
private val semaphore = Semaphore(12)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Main clustering entry point
|
* Main clustering entry point - HYBRID with automatic fallback
|
||||||
*
|
*
|
||||||
* @param maxFacesToCluster Limit for performance (default 2000)
|
* @param maxFacesToCluster Limit for performance (default 2000)
|
||||||
* @param onProgress Progress callback (current, total, message)
|
* @param onProgress Progress callback (current, total, message)
|
||||||
@@ -57,42 +54,54 @@ class FaceClusteringService @Inject constructor(
|
|||||||
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
||||||
): ClusteringResult = withContext(Dispatchers.Default) {
|
): ClusteringResult = withContext(Dispatchers.Default) {
|
||||||
|
|
||||||
onProgress(0, 100, "Loading images with faces...")
|
// TRY FAST PATH: Use face cache if available
|
||||||
|
val highQualityFaces = try {
|
||||||
|
withContext(Dispatchers.IO) {
|
||||||
|
faceCacheDao.getHighQualitySoloFaces()
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
emptyList()
|
||||||
|
}
|
||||||
|
|
||||||
// Step 1: Get images with faces (cached, fast!)
|
if (highQualityFaces.isNotEmpty()) {
|
||||||
val imagesWithFaces = imageDao.getImagesWithFaces()
|
// FAST PATH: Use cached faces (future enhancement)
|
||||||
|
onProgress(0, 100, "Using face cache (${highQualityFaces.size} faces)...")
|
||||||
|
// TODO: Implement cache-based clustering
|
||||||
|
// For now, fall through to classic method
|
||||||
|
}
|
||||||
|
|
||||||
|
// CLASSIC METHOD: Load and process photos
|
||||||
|
onProgress(0, 100, "Loading solo photos...")
|
||||||
|
|
||||||
|
// Step 1: Get SOLO PHOTOS ONLY (faceCount = 1) for cleaner clustering
|
||||||
|
val soloPhotos = withContext(Dispatchers.IO) {
|
||||||
|
imageDao.getImagesByFaceCount(count = 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: If not enough solo photos, use all images with faces
|
||||||
|
val imagesWithFaces = if (soloPhotos.size < 50) {
|
||||||
|
onProgress(0, 100, "Loading all photos with faces...")
|
||||||
|
imageDao.getImagesWithFaces()
|
||||||
|
} else {
|
||||||
|
soloPhotos
|
||||||
|
}
|
||||||
|
|
||||||
if (imagesWithFaces.isEmpty()) {
|
if (imagesWithFaces.isEmpty()) {
|
||||||
// Check if face cache is populated at all
|
|
||||||
val totalImages = withContext(Dispatchers.IO) {
|
|
||||||
imageDao.getImageCount()
|
|
||||||
}
|
|
||||||
|
|
||||||
if (totalImages == 0) {
|
|
||||||
return@withContext ClusteringResult(
|
return@withContext ClusteringResult(
|
||||||
clusters = emptyList(),
|
clusters = emptyList(),
|
||||||
totalFacesAnalyzed = 0,
|
totalFacesAnalyzed = 0,
|
||||||
processingTimeMs = 0,
|
processingTimeMs = 0,
|
||||||
errorMessage = "No photos in library. Please wait for photo ingestion to complete."
|
errorMessage = "No photos with faces found. Please ensure face detection cache is populated."
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Images exist but no face cache - need to run PopulateFaceDetectionCacheUseCase first
|
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos (${if (soloPhotos.size >= 50) "solo only" else "all"})...")
|
||||||
return@withContext ClusteringResult(
|
|
||||||
clusters = emptyList(),
|
|
||||||
totalFacesAnalyzed = 0,
|
|
||||||
processingTimeMs = 0,
|
|
||||||
errorMessage = "Face detection cache not ready. Please wait for initial face scan to complete (check MainActivity progress bar)."
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
onProgress(10, 100, "Analyzing ${imagesWithFaces.size} photos...")
|
|
||||||
|
|
||||||
val startTime = System.currentTimeMillis()
|
val startTime = System.currentTimeMillis()
|
||||||
|
|
||||||
// Step 2: Detect faces and generate embeddings (parallel)
|
// Step 2: Detect faces and generate embeddings (parallel)
|
||||||
val allFaces = detectFacesInImages(
|
val allFaces = detectFacesInImages(
|
||||||
images = imagesWithFaces.take(1000), // Smart limit: don't need all photos
|
images = imagesWithFaces.take(1000), // Smart limit
|
||||||
onProgress = { current, total ->
|
onProgress = { current, total ->
|
||||||
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
|
onProgress(10 + (current * 40 / total), 100, "Detecting faces... $current/$total")
|
||||||
}
|
}
|
||||||
@@ -102,17 +111,18 @@ class FaceClusteringService @Inject constructor(
|
|||||||
return@withContext ClusteringResult(
|
return@withContext ClusteringResult(
|
||||||
clusters = emptyList(),
|
clusters = emptyList(),
|
||||||
totalFacesAnalyzed = 0,
|
totalFacesAnalyzed = 0,
|
||||||
processingTimeMs = System.currentTimeMillis() - startTime
|
processingTimeMs = System.currentTimeMillis() - startTime,
|
||||||
|
errorMessage = "No faces detected in images"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
onProgress(50, 100, "Clustering ${allFaces.size} faces...")
|
onProgress(50, 100, "Clustering ${allFaces.size} faces...")
|
||||||
|
|
||||||
// Step 3: DBSCAN clustering on embeddings
|
// Step 3: DBSCAN clustering
|
||||||
val rawClusters = performDBSCAN(
|
val rawClusters = performDBSCAN(
|
||||||
faces = allFaces.take(maxFacesToCluster),
|
faces = allFaces.take(maxFacesToCluster),
|
||||||
epsilon = 0.30f, // BALANCED: Not too strict, not too loose
|
epsilon = 0.18f, // VERY STRICT for siblings
|
||||||
minPoints = 5 // Minimum 5 photos to form a cluster
|
minPoints = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
onProgress(70, 100, "Analyzing relationships...")
|
onProgress(70, 100, "Analyzing relationships...")
|
||||||
@@ -122,7 +132,7 @@ class FaceClusteringService @Inject constructor(
|
|||||||
|
|
||||||
onProgress(80, 100, "Selecting representative faces...")
|
onProgress(80, 100, "Selecting representative faces...")
|
||||||
|
|
||||||
// Step 5: Select representative faces for each cluster
|
// Step 5: Create final clusters
|
||||||
val clusters = rawClusters.map { cluster ->
|
val clusters = rawClusters.map { cluster ->
|
||||||
FaceCluster(
|
FaceCluster(
|
||||||
clusterId = cluster.clusterId,
|
clusterId = cluster.clusterId,
|
||||||
@@ -133,7 +143,7 @@ class FaceClusteringService @Inject constructor(
|
|||||||
estimatedAge = estimateAge(cluster.faces),
|
estimatedAge = estimateAge(cluster.faces),
|
||||||
potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph)
|
potentialSiblings = findPotentialSiblings(cluster, rawClusters, coOccurrenceGraph)
|
||||||
)
|
)
|
||||||
}.sortedByDescending { it.photoCount } // Most frequent first
|
}.sortedByDescending { it.photoCount }
|
||||||
|
|
||||||
onProgress(100, 100, "Found ${clusters.size} people!")
|
onProgress(100, 100, "Found ${clusters.size} people!")
|
||||||
|
|
||||||
@@ -152,16 +162,16 @@ class FaceClusteringService @Inject constructor(
|
|||||||
onProgress: (Int, Int) -> Unit
|
onProgress: (Int, Int) -> Unit
|
||||||
): List<DetectedFaceWithEmbedding> = coroutineScope {
|
): List<DetectedFaceWithEmbedding> = coroutineScope {
|
||||||
|
|
||||||
val detector = com.google.mlkit.vision.face.FaceDetection.getClient(
|
val detector = FaceDetection.getClient(
|
||||||
com.google.mlkit.vision.face.FaceDetectorOptions.Builder()
|
FaceDetectorOptions.Builder()
|
||||||
.setPerformanceMode(com.google.mlkit.vision.face.FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||||
.setMinFaceSize(0.15f)
|
.setMinFaceSize(0.15f)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
val faceNetModel = FaceNetModel(context)
|
val faceNetModel = FaceNetModel(context)
|
||||||
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
val allFaces = mutableListOf<DetectedFaceWithEmbedding>()
|
||||||
val processedCount = java.util.concurrent.atomic.AtomicInteger(0)
|
val processedCount = AtomicInteger(0)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
val jobs = images.map { image ->
|
val jobs = images.map { image ->
|
||||||
@@ -202,9 +212,11 @@ class FaceClusteringService @Inject constructor(
|
|||||||
val uri = Uri.parse(image.imageUri)
|
val uri = Uri.parse(image.imageUri)
|
||||||
val bitmap = loadBitmapDownsampled(uri, 512) ?: return@withContext emptyList()
|
val bitmap = loadBitmapDownsampled(uri, 512) ?: return@withContext emptyList()
|
||||||
|
|
||||||
val mlImage = com.google.mlkit.vision.common.InputImage.fromBitmap(bitmap, 0)
|
val mlImage = InputImage.fromBitmap(bitmap, 0)
|
||||||
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
|
val faces = com.google.android.gms.tasks.Tasks.await(detector.process(mlImage))
|
||||||
|
|
||||||
|
val totalFacesInImage = faces.size
|
||||||
|
|
||||||
val result = faces.mapNotNull { face ->
|
val result = faces.mapNotNull { face ->
|
||||||
try {
|
try {
|
||||||
val faceBitmap = Bitmap.createBitmap(
|
val faceBitmap = Bitmap.createBitmap(
|
||||||
@@ -224,7 +236,8 @@ class FaceClusteringService @Inject constructor(
|
|||||||
capturedAt = image.capturedAt,
|
capturedAt = image.capturedAt,
|
||||||
embedding = embedding,
|
embedding = embedding,
|
||||||
boundingBox = face.boundingBox,
|
boundingBox = face.boundingBox,
|
||||||
confidence = 1.0f // Placeholder
|
confidence = 0.95f,
|
||||||
|
faceCount = totalFacesInImage
|
||||||
)
|
)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
null
|
null
|
||||||
@@ -239,15 +252,14 @@ class FaceClusteringService @Inject constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
// All other methods remain the same (DBSCAN, similarity, etc.)
|
||||||
* DBSCAN clustering algorithm
|
// ... [Rest of the implementation from original file]
|
||||||
*/
|
|
||||||
private fun performDBSCAN(
|
private fun performDBSCAN(
|
||||||
faces: List<DetectedFaceWithEmbedding>,
|
faces: List<DetectedFaceWithEmbedding>,
|
||||||
epsilon: Float,
|
epsilon: Float,
|
||||||
minPoints: Int
|
minPoints: Int
|
||||||
): List<RawCluster> {
|
): List<RawCluster> {
|
||||||
|
|
||||||
val visited = mutableSetOf<Int>()
|
val visited = mutableSetOf<Int>()
|
||||||
val clusters = mutableListOf<RawCluster>()
|
val clusters = mutableListOf<RawCluster>()
|
||||||
var clusterId = 0
|
var clusterId = 0
|
||||||
@@ -259,10 +271,9 @@ class FaceClusteringService @Inject constructor(
|
|||||||
|
|
||||||
if (neighbors.size < minPoints) {
|
if (neighbors.size < minPoints) {
|
||||||
visited.add(i)
|
visited.add(i)
|
||||||
continue // Noise point
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start new cluster
|
|
||||||
val cluster = mutableListOf<DetectedFaceWithEmbedding>()
|
val cluster = mutableListOf<DetectedFaceWithEmbedding>()
|
||||||
val queue = ArrayDeque(neighbors)
|
val queue = ArrayDeque(neighbors)
|
||||||
visited.add(i)
|
visited.add(i)
|
||||||
@@ -296,7 +307,15 @@ class FaceClusteringService @Inject constructor(
|
|||||||
): List<Int> {
|
): List<Int> {
|
||||||
val point = faces[pointIdx]
|
val point = faces[pointIdx]
|
||||||
return faces.indices.filter { i ->
|
return faces.indices.filter { i ->
|
||||||
i != pointIdx && cosineSimilarity(point.embedding, faces[i].embedding) > (1 - epsilon)
|
if (i == pointIdx) return@filter false
|
||||||
|
|
||||||
|
val otherFace = faces[i]
|
||||||
|
val similarity = cosineSimilarity(point.embedding, otherFace.embedding)
|
||||||
|
|
||||||
|
val appearTogether = point.imageId == otherFace.imageId
|
||||||
|
val effectiveEpsilon = if (appearTogether) epsilon * 0.7f else epsilon
|
||||||
|
|
||||||
|
similarity > (1 - effectiveEpsilon)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,9 +333,6 @@ class FaceClusteringService @Inject constructor(
|
|||||||
return dotProduct / (sqrt(normA) * sqrt(normB))
|
return dotProduct / (sqrt(normA) * sqrt(normB))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Build co-occurrence graph (faces appearing in same photos)
|
|
||||||
*/
|
|
||||||
private fun buildCoOccurrenceGraph(clusters: List<RawCluster>): Map<Int, Map<Int, Int>> {
|
private fun buildCoOccurrenceGraph(clusters: List<RawCluster>): Map<Int, Map<Int, Int>> {
|
||||||
val graph = mutableMapOf<Int, MutableMap<Int, Int>>()
|
val graph = mutableMapOf<Int, MutableMap<Int, Int>>()
|
||||||
|
|
||||||
@@ -345,25 +361,19 @@ class FaceClusteringService @Inject constructor(
|
|||||||
val clusterIdx = allClusters.indexOf(cluster)
|
val clusterIdx = allClusters.indexOf(cluster)
|
||||||
if (clusterIdx == -1) return emptyList()
|
if (clusterIdx == -1) return emptyList()
|
||||||
|
|
||||||
val siblings = coOccurrenceGraph[clusterIdx]
|
return coOccurrenceGraph[clusterIdx]
|
||||||
?.filter { (_, count) -> count >= 5 } // At least 5 shared photos
|
?.filter { (_, count) -> count >= 5 }
|
||||||
?.keys
|
?.keys
|
||||||
?.toList()
|
?.toList()
|
||||||
?: emptyList()
|
?: emptyList()
|
||||||
|
|
||||||
return siblings
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Select diverse representative faces for UI display
|
|
||||||
*/
|
|
||||||
private fun selectRepresentativeFaces(
|
private fun selectRepresentativeFaces(
|
||||||
faces: List<DetectedFaceWithEmbedding>,
|
faces: List<DetectedFaceWithEmbedding>,
|
||||||
count: Int
|
count: Int
|
||||||
): List<DetectedFaceWithEmbedding> {
|
): List<DetectedFaceWithEmbedding> {
|
||||||
if (faces.size <= count) return faces
|
if (faces.size <= count) return faces
|
||||||
|
|
||||||
// Time-based sampling: spread across different dates
|
|
||||||
val sortedByTime = faces.sortedBy { it.capturedAt }
|
val sortedByTime = faces.sortedBy { it.capturedAt }
|
||||||
val step = faces.size / count
|
val step = faces.size / count
|
||||||
|
|
||||||
@@ -372,20 +382,12 @@ class FaceClusteringService @Inject constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Estimate if cluster represents a child (based on photo timestamps)
|
|
||||||
*/
|
|
||||||
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
|
private fun estimateAge(faces: List<DetectedFaceWithEmbedding>): AgeEstimate {
|
||||||
val timestamps = faces.map { it.capturedAt }.sorted()
|
val timestamps = faces.map { it.capturedAt }.sorted()
|
||||||
val span = timestamps.last() - timestamps.first()
|
val span = timestamps.last() - timestamps.first()
|
||||||
val spanYears = span / (365.25 * 24 * 60 * 60 * 1000)
|
val spanYears = span / (365.25 * 24 * 60 * 60 * 1000)
|
||||||
|
|
||||||
// If face appearance changes over 3+ years, likely a child
|
return if (spanYears > 3.0) AgeEstimate.CHILD else AgeEstimate.UNKNOWN
|
||||||
return if (spanYears > 3.0) {
|
|
||||||
AgeEstimate.CHILD
|
|
||||||
} else {
|
|
||||||
AgeEstimate.UNKNOWN
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? {
|
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): Bitmap? {
|
||||||
@@ -414,17 +416,15 @@ class FaceClusteringService @Inject constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================
|
// Data classes
|
||||||
// DATA CLASSES
|
|
||||||
// ==================
|
|
||||||
|
|
||||||
data class DetectedFaceWithEmbedding(
|
data class DetectedFaceWithEmbedding(
|
||||||
val imageId: String,
|
val imageId: String,
|
||||||
val imageUri: String,
|
val imageUri: String,
|
||||||
val capturedAt: Long,
|
val capturedAt: Long,
|
||||||
val embedding: FloatArray,
|
val embedding: FloatArray,
|
||||||
val boundingBox: android.graphics.Rect,
|
val boundingBox: android.graphics.Rect,
|
||||||
val confidence: Float
|
val confidence: Float,
|
||||||
|
val faceCount: Int = 1
|
||||||
) {
|
) {
|
||||||
override fun equals(other: Any?): Boolean {
|
override fun equals(other: Any?): Boolean {
|
||||||
if (this === other) return true
|
if (this === other) return true
|
||||||
@@ -459,7 +459,7 @@ data class ClusteringResult(
|
|||||||
)
|
)
|
||||||
|
|
||||||
enum class AgeEstimate {
|
enum class AgeEstimate {
|
||||||
CHILD, // Appearance changes significantly over time
|
CHILD,
|
||||||
ADULT, // Stable appearance
|
ADULT,
|
||||||
UNKNOWN // Not enough data
|
UNKNOWN
|
||||||
}
|
}
|
||||||
@@ -8,6 +8,8 @@ import com.placeholder.sherpai2.data.local.dao.PersonDao
|
|||||||
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
|
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
|
||||||
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
import com.placeholder.sherpai2.data.local.entity.PersonEntity
|
||||||
import com.placeholder.sherpai2.data.local.entity.TemporalCentroid
|
import com.placeholder.sherpai2.data.local.entity.TemporalCentroid
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.ClusterQualityAnalyzer
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
|
||||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||||
import com.placeholder.sherpai2.ml.FaceNetModel
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
import dagger.hilt.android.qualifiers.ApplicationContext
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
@@ -21,23 +23,36 @@ import kotlin.math.abs
|
|||||||
* ClusterTrainingService - Train multi-centroid face models from clusters
|
* ClusterTrainingService - Train multi-centroid face models from clusters
|
||||||
*
|
*
|
||||||
* STRATEGY:
|
* STRATEGY:
|
||||||
* 1. For children: Create multiple temporal centroids (one per age period)
|
* 1. VALIDATE cluster quality FIRST (prevent training on dirty/mixed clusters)
|
||||||
* 2. For adults: Create single centroid (stable appearance)
|
* 2. For children: Create multiple temporal centroids (one per age period)
|
||||||
* 3. Use K-Means clustering on timestamps to find age groups
|
* 3. For adults: Create single centroid (stable appearance)
|
||||||
* 4. Calculate centroid for each time period
|
* 4. Use K-Means clustering on timestamps to find age groups
|
||||||
|
* 5. Calculate centroid for each time period
|
||||||
*/
|
*/
|
||||||
@Singleton
|
@Singleton
|
||||||
class ClusterTrainingService @Inject constructor(
|
class ClusterTrainingService @Inject constructor(
|
||||||
@ApplicationContext private val context: Context,
|
@ApplicationContext private val context: Context,
|
||||||
private val personDao: PersonDao,
|
private val personDao: PersonDao,
|
||||||
private val faceModelDao: FaceModelDao
|
private val faceModelDao: FaceModelDao,
|
||||||
|
private val qualityAnalyzer: ClusterQualityAnalyzer
|
||||||
) {
|
) {
|
||||||
|
|
||||||
private val faceNetModel by lazy { FaceNetModel(context) }
|
private val faceNetModel by lazy { FaceNetModel(context) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Analyze cluster quality before training
|
||||||
|
*
|
||||||
|
* Call this BEFORE trainFromCluster() to check if cluster is clean
|
||||||
|
*/
|
||||||
|
suspend fun analyzeClusterQuality(cluster: FaceCluster): ClusterQualityResult {
|
||||||
|
return qualityAnalyzer.analyzeCluster(cluster)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Train a person from an auto-discovered cluster
|
* Train a person from an auto-discovered cluster
|
||||||
*
|
*
|
||||||
|
* @param cluster The discovered cluster
|
||||||
|
* @param qualityResult Optional pre-computed quality analysis (recommended)
|
||||||
* @return PersonId on success
|
* @return PersonId on success
|
||||||
*/
|
*/
|
||||||
suspend fun trainFromCluster(
|
suspend fun trainFromCluster(
|
||||||
@@ -46,12 +61,26 @@ class ClusterTrainingService @Inject constructor(
|
|||||||
dateOfBirth: Long?,
|
dateOfBirth: Long?,
|
||||||
isChild: Boolean,
|
isChild: Boolean,
|
||||||
siblingClusterIds: List<Int>,
|
siblingClusterIds: List<Int>,
|
||||||
|
qualityResult: ClusterQualityResult? = null,
|
||||||
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
onProgress: (Int, Int, String) -> Unit = { _, _, _ -> }
|
||||||
): String = withContext(Dispatchers.Default) {
|
): String = withContext(Dispatchers.Default) {
|
||||||
|
|
||||||
onProgress(0, 100, "Creating person...")
|
onProgress(0, 100, "Creating person...")
|
||||||
|
|
||||||
// Step 1: Create PersonEntity
|
// Step 1: Use clean faces if quality analysis was done
|
||||||
|
val facesToUse = if (qualityResult != null && qualityResult.cleanFaces.isNotEmpty()) {
|
||||||
|
// Use clean faces (outliers removed)
|
||||||
|
qualityResult.cleanFaces
|
||||||
|
} else {
|
||||||
|
// Use all faces (legacy behavior)
|
||||||
|
cluster.faces
|
||||||
|
}
|
||||||
|
|
||||||
|
if (facesToUse.size < 6) {
|
||||||
|
throw Exception("Need at least 6 clean faces for training (have ${facesToUse.size})")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Create PersonEntity
|
||||||
val person = PersonEntity.create(
|
val person = PersonEntity.create(
|
||||||
name = name,
|
name = name,
|
||||||
dateOfBirth = dateOfBirth,
|
dateOfBirth = dateOfBirth,
|
||||||
@@ -66,30 +95,20 @@ class ClusterTrainingService @Inject constructor(
|
|||||||
|
|
||||||
onProgress(20, 100, "Analyzing face variations...")
|
onProgress(20, 100, "Analyzing face variations...")
|
||||||
|
|
||||||
// Step 2: Generate embeddings for all faces in cluster
|
// Step 3: Use pre-computed embeddings from clustering
|
||||||
val facesWithEmbeddings = cluster.faces.mapNotNull { face ->
|
// CRITICAL: These embeddings are already face-specific, even in group photos!
|
||||||
try {
|
// The clustering phase already cropped and generated embeddings for each face.
|
||||||
val bitmap = context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
|
val facesWithEmbeddings = facesToUse.map { face ->
|
||||||
BitmapFactory.decodeStream(it)
|
Triple(
|
||||||
} ?: return@mapNotNull null
|
face.imageUri,
|
||||||
|
face.capturedAt,
|
||||||
// Generate embedding
|
face.embedding // ✅ Use existing embedding (already cropped to face)
|
||||||
val embedding = faceNetModel.generateEmbedding(bitmap)
|
)
|
||||||
bitmap.recycle()
|
|
||||||
|
|
||||||
Triple(face.imageUri, face.capturedAt, embedding)
|
|
||||||
} catch (e: Exception) {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (facesWithEmbeddings.isEmpty()) {
|
|
||||||
throw Exception("Failed to process any faces from cluster")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
onProgress(50, 100, "Creating face model...")
|
onProgress(50, 100, "Creating face model...")
|
||||||
|
|
||||||
// Step 3: Create centroids based on whether person is a child
|
// Step 4: Create centroids based on whether person is a child
|
||||||
val centroids = if (isChild && dateOfBirth != null) {
|
val centroids = if (isChild && dateOfBirth != null) {
|
||||||
createTemporalCentroidsForChild(
|
createTemporalCentroidsForChild(
|
||||||
facesWithEmbeddings = facesWithEmbeddings,
|
facesWithEmbeddings = facesWithEmbeddings,
|
||||||
@@ -101,14 +120,14 @@ class ClusterTrainingService @Inject constructor(
|
|||||||
|
|
||||||
onProgress(80, 100, "Saving model...")
|
onProgress(80, 100, "Saving model...")
|
||||||
|
|
||||||
// Step 4: Calculate average confidence
|
// Step 5: Calculate average confidence
|
||||||
val avgConfidence = centroids.map { it.avgConfidence }.average().toFloat()
|
val avgConfidence = centroids.map { it.avgConfidence }.average().toFloat()
|
||||||
|
|
||||||
// Step 5: Create FaceModelEntity
|
// Step 6: Create FaceModelEntity
|
||||||
val faceModel = FaceModelEntity.createFromCentroids(
|
val faceModel = FaceModelEntity.createFromCentroids(
|
||||||
personId = person.id,
|
personId = person.id,
|
||||||
centroids = centroids,
|
centroids = centroids,
|
||||||
trainingImageCount = cluster.faces.size,
|
trainingImageCount = facesToUse.size,
|
||||||
averageConfidence = avgConfidence
|
averageConfidence = avgConfidence
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,312 @@
|
|||||||
|
package com.placeholder.sherpai2.domain.validation
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.graphics.BitmapFactory
|
||||||
|
import android.net.Uri
|
||||||
|
import com.google.mlkit.vision.common.InputImage
|
||||||
|
import com.google.mlkit.vision.face.FaceDetection
|
||||||
|
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.data.local.entity.FaceModelEntity
|
||||||
|
import com.placeholder.sherpai2.data.local.entity.ImageEntity
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
|
import dagger.hilt.android.qualifiers.ApplicationContext
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.async
|
||||||
|
import kotlinx.coroutines.awaitAll
|
||||||
|
import kotlinx.coroutines.coroutineScope
|
||||||
|
import kotlinx.coroutines.tasks.await
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
import javax.inject.Inject
|
||||||
|
import javax.inject.Singleton
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ValidationScanService - Quick validation scan after training
|
||||||
|
*
|
||||||
|
* PURPOSE: Let user verify model quality BEFORE full library scan
|
||||||
|
*
|
||||||
|
* STRATEGY:
|
||||||
|
* 1. Sample 20-30 random photos with faces
|
||||||
|
* 2. Scan for the newly trained person
|
||||||
|
* 3. Return preview results with confidence scores
|
||||||
|
* 4. User reviews and decides: "Looks good" or "Add more photos"
|
||||||
|
*
|
||||||
|
* THRESHOLD STRATEGY:
|
||||||
|
* - Use CONSERVATIVE threshold (0.75) for validation
|
||||||
|
* - Better to show false negatives than false positives
|
||||||
|
* - If user approves, full scan uses slightly looser threshold (0.70)
|
||||||
|
*/
|
||||||
|
@Singleton
|
||||||
|
class ValidationScanService @Inject constructor(
|
||||||
|
@ApplicationContext private val context: Context,
|
||||||
|
private val imageDao: ImageDao,
|
||||||
|
private val faceModelDao: FaceModelDao
|
||||||
|
) {
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val VALIDATION_SAMPLE_SIZE = 25
|
||||||
|
private const val VALIDATION_THRESHOLD = 0.75f // Conservative
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform validation scan after training
|
||||||
|
*
|
||||||
|
* @param personId The newly trained person
|
||||||
|
* @param onProgress Callback (current, total)
|
||||||
|
* @return Validation results with preview matches
|
||||||
|
*/
|
||||||
|
suspend fun performValidationScan(
|
||||||
|
personId: String,
|
||||||
|
onProgress: (Int, Int) -> Unit = { _, _ -> }
|
||||||
|
): ValidationScanResult = withContext(Dispatchers.Default) {
|
||||||
|
|
||||||
|
onProgress(0, 100)
|
||||||
|
|
||||||
|
// Step 1: Get face model
|
||||||
|
val faceModel = withContext(Dispatchers.IO) {
|
||||||
|
faceModelDao.getFaceModelByPersonId(personId)
|
||||||
|
} ?: return@withContext ValidationScanResult(
|
||||||
|
personId = personId,
|
||||||
|
matches = emptyList(),
|
||||||
|
sampleSize = 0,
|
||||||
|
errorMessage = "Face model not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
onProgress(10, 100)
|
||||||
|
|
||||||
|
// Step 2: Get random sample of photos with faces
|
||||||
|
val allPhotosWithFaces = withContext(Dispatchers.IO) {
|
||||||
|
imageDao.getImagesWithFaces()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (allPhotosWithFaces.isEmpty()) {
|
||||||
|
return@withContext ValidationScanResult(
|
||||||
|
personId = personId,
|
||||||
|
matches = emptyList(),
|
||||||
|
sampleSize = 0,
|
||||||
|
errorMessage = "No photos with faces in library"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Random sample
|
||||||
|
val samplePhotos = allPhotosWithFaces.shuffled().take(VALIDATION_SAMPLE_SIZE)
|
||||||
|
onProgress(20, 100)
|
||||||
|
|
||||||
|
// Step 3: Scan sample photos
|
||||||
|
val faceNetModel = FaceNetModel(context)
|
||||||
|
val detector = FaceDetection.getClient(
|
||||||
|
FaceDetectorOptions.Builder()
|
||||||
|
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||||
|
.setMinFaceSize(0.15f)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
try {
|
||||||
|
val matches = scanPhotosForPerson(
|
||||||
|
photos = samplePhotos,
|
||||||
|
faceModel = faceModel,
|
||||||
|
faceNetModel = faceNetModel,
|
||||||
|
detector = detector,
|
||||||
|
threshold = VALIDATION_THRESHOLD,
|
||||||
|
onProgress = { current, total ->
|
||||||
|
// Map to 20-100 range
|
||||||
|
val progress = 20 + (current * 80 / total)
|
||||||
|
onProgress(progress, 100)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
onProgress(100, 100)
|
||||||
|
|
||||||
|
ValidationScanResult(
|
||||||
|
personId = personId,
|
||||||
|
matches = matches,
|
||||||
|
sampleSize = samplePhotos.size,
|
||||||
|
threshold = VALIDATION_THRESHOLD
|
||||||
|
)
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
faceNetModel.close()
|
||||||
|
detector.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scan photos for a specific person
|
||||||
|
*/
|
||||||
|
private suspend fun scanPhotosForPerson(
|
||||||
|
photos: List<ImageEntity>,
|
||||||
|
faceModel: FaceModelEntity,
|
||||||
|
faceNetModel: FaceNetModel,
|
||||||
|
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||||
|
threshold: Float,
|
||||||
|
onProgress: (Int, Int) -> Unit
|
||||||
|
): List<ValidationMatch> = coroutineScope {
|
||||||
|
|
||||||
|
val modelEmbedding = faceModel.getEmbeddingArray()
|
||||||
|
val matches = mutableListOf<ValidationMatch>()
|
||||||
|
var processedCount = 0
|
||||||
|
|
||||||
|
// Process in parallel
|
||||||
|
val jobs = photos.map { photo ->
|
||||||
|
async(Dispatchers.IO) {
|
||||||
|
val photoMatches = scanSinglePhoto(
|
||||||
|
photo = photo,
|
||||||
|
modelEmbedding = modelEmbedding,
|
||||||
|
faceNetModel = faceNetModel,
|
||||||
|
detector = detector,
|
||||||
|
threshold = threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
synchronized(matches) {
|
||||||
|
matches.addAll(photoMatches)
|
||||||
|
processedCount++
|
||||||
|
if (processedCount % 5 == 0) {
|
||||||
|
onProgress(processedCount, photos.size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jobs.awaitAll()
|
||||||
|
matches.sortedByDescending { it.confidence }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scan a single photo for the person
|
||||||
|
*/
|
||||||
|
private suspend fun scanSinglePhoto(
|
||||||
|
photo: ImageEntity,
|
||||||
|
modelEmbedding: FloatArray,
|
||||||
|
faceNetModel: FaceNetModel,
|
||||||
|
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||||
|
threshold: Float
|
||||||
|
): List<ValidationMatch> = withContext(Dispatchers.IO) {
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Load bitmap
|
||||||
|
val bitmap = loadBitmapDownsampled(Uri.parse(photo.imageUri), 768)
|
||||||
|
?: return@withContext emptyList()
|
||||||
|
|
||||||
|
// Detect faces
|
||||||
|
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||||
|
val faces = detector.process(inputImage).await()
|
||||||
|
|
||||||
|
// Check each face
|
||||||
|
val matches = faces.mapNotNull { face ->
|
||||||
|
try {
|
||||||
|
// Crop face
|
||||||
|
val faceBitmap = android.graphics.Bitmap.createBitmap(
|
||||||
|
bitmap,
|
||||||
|
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
||||||
|
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
||||||
|
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
||||||
|
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generate embedding
|
||||||
|
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
|
faceBitmap.recycle()
|
||||||
|
|
||||||
|
// Calculate similarity
|
||||||
|
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
||||||
|
|
||||||
|
if (similarity >= threshold) {
|
||||||
|
ValidationMatch(
|
||||||
|
imageId = photo.imageId,
|
||||||
|
imageUri = photo.imageUri,
|
||||||
|
capturedAt = photo.capturedAt,
|
||||||
|
confidence = similarity,
|
||||||
|
boundingBox = face.boundingBox,
|
||||||
|
faceCount = faces.size
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bitmap.recycle()
|
||||||
|
matches
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
emptyList()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load bitmap with downsampling
|
||||||
|
*/
|
||||||
|
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): android.graphics.Bitmap? {
|
||||||
|
return try {
|
||||||
|
val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true }
|
||||||
|
context.contentResolver.openInputStream(uri)?.use {
|
||||||
|
BitmapFactory.decodeStream(it, null, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sample = 1
|
||||||
|
while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) {
|
||||||
|
sample *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
val finalOpts = BitmapFactory.Options().apply {
|
||||||
|
inSampleSize = sample
|
||||||
|
}
|
||||||
|
|
||||||
|
context.contentResolver.openInputStream(uri)?.use {
|
||||||
|
BitmapFactory.decodeStream(it, null, finalOpts)
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Result of validation scan
|
||||||
|
*/
|
||||||
|
data class ValidationScanResult(
|
||||||
|
val personId: String,
|
||||||
|
val matches: List<ValidationMatch>,
|
||||||
|
val sampleSize: Int,
|
||||||
|
val threshold: Float = 0.75f,
|
||||||
|
val errorMessage: String? = null
|
||||||
|
) {
|
||||||
|
val matchCount: Int get() = matches.size
|
||||||
|
val averageConfidence: Float get() = if (matches.isNotEmpty()) {
|
||||||
|
matches.map { it.confidence }.average().toFloat()
|
||||||
|
} else 0f
|
||||||
|
|
||||||
|
val qualityAssessment: ValidationQuality get() = when {
|
||||||
|
matchCount == 0 -> ValidationQuality.NO_MATCHES
|
||||||
|
averageConfidence >= 0.85f && matchCount >= 5 -> ValidationQuality.EXCELLENT
|
||||||
|
averageConfidence >= 0.78f && matchCount >= 3 -> ValidationQuality.GOOD
|
||||||
|
averageConfidence < 0.75f || matchCount < 2 -> ValidationQuality.POOR
|
||||||
|
else -> ValidationQuality.FAIR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Single match found during validation
|
||||||
|
*/
|
||||||
|
data class ValidationMatch(
|
||||||
|
val imageId: String,
|
||||||
|
val imageUri: String,
|
||||||
|
val capturedAt: Long,
|
||||||
|
val confidence: Float,
|
||||||
|
val boundingBox: android.graphics.Rect,
|
||||||
|
val faceCount: Int
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Overall quality assessment
|
||||||
|
*/
|
||||||
|
enum class ValidationQuality {
|
||||||
|
EXCELLENT, // High confidence, many matches
|
||||||
|
GOOD, // Decent confidence, some matches
|
||||||
|
FAIR, // Acceptable, proceed with caution
|
||||||
|
POOR, // Low confidence or very few matches
|
||||||
|
NO_MATCHES // No matches found at all
|
||||||
|
}
|
||||||
@@ -1,210 +1,212 @@
|
|||||||
package com.placeholder.sherpai2.ui.discover
|
package com.placeholder.sherpai2.ui.discover
|
||||||
|
|
||||||
import android.graphics.BitmapFactory
|
|
||||||
import android.net.Uri
|
|
||||||
import androidx.compose.foundation.Image
|
|
||||||
import androidx.compose.foundation.background
|
|
||||||
import androidx.compose.foundation.clickable
|
|
||||||
import androidx.compose.foundation.layout.*
|
import androidx.compose.foundation.layout.*
|
||||||
import androidx.compose.foundation.lazy.LazyColumn
|
|
||||||
import androidx.compose.foundation.lazy.grid.GridCells
|
|
||||||
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
|
||||||
import androidx.compose.foundation.lazy.grid.items
|
|
||||||
import androidx.compose.foundation.lazy.items
|
|
||||||
import androidx.compose.foundation.shape.CircleShape
|
|
||||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.filled.*
|
import androidx.compose.material.icons.filled.Person
|
||||||
import androidx.compose.material3.*
|
import androidx.compose.material3.*
|
||||||
import androidx.compose.runtime.*
|
import androidx.compose.runtime.*
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.draw.clip
|
|
||||||
import androidx.compose.ui.graphics.asImageBitmap
|
|
||||||
import androidx.compose.ui.layout.ContentScale
|
|
||||||
import androidx.compose.ui.platform.LocalContext
|
|
||||||
import androidx.compose.ui.text.font.FontWeight
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
import androidx.compose.ui.text.style.TextAlign
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
import androidx.compose.ui.unit.dp
|
import androidx.compose.ui.unit.dp
|
||||||
import androidx.hilt.navigation.compose.hiltViewModel
|
import androidx.hilt.navigation.compose.hiltViewModel
|
||||||
import androidx.lifecycle.compose.collectAsStateWithLifecycle
|
|
||||||
import com.placeholder.sherpai2.domain.clustering.AgeEstimate
|
|
||||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
|
||||||
import java.text.SimpleDateFormat
|
|
||||||
import java.util.*
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DiscoverPeopleScreen - Beautiful auto-clustering UI
|
* DiscoverPeopleScreen - COMPLETE WORKING VERSION
|
||||||
*
|
*
|
||||||
* FLOW:
|
* This handles ALL states properly including Idle state
|
||||||
* 1. Hero CTA: "Discover People in Your Photos"
|
|
||||||
* 2. Auto-clustering progress (2-5 min)
|
|
||||||
* 3. Grid of discovered people
|
|
||||||
* 4. Tap cluster → Name person + metadata
|
|
||||||
* 5. Background deep scan starts
|
|
||||||
*/
|
*/
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@Composable
|
@Composable
|
||||||
fun DiscoverPeopleScreen(
|
fun DiscoverPeopleScreen(
|
||||||
viewModel: DiscoverPeopleViewModel = hiltViewModel()
|
viewModel: DiscoverPeopleViewModel = hiltViewModel(),
|
||||||
|
onNavigateBack: () -> Unit = {}
|
||||||
) {
|
) {
|
||||||
val uiState by viewModel.uiState.collectAsStateWithLifecycle()
|
val uiState by viewModel.uiState.collectAsState()
|
||||||
|
|
||||||
// NO SCAFFOLD - MainScreen already has TopAppBar
|
Scaffold(
|
||||||
Box(modifier = Modifier.fillMaxSize()) {
|
topBar = {
|
||||||
|
TopAppBar(
|
||||||
|
title = { Text("Discover People") },
|
||||||
|
navigationIcon = {
|
||||||
|
IconButton(onClick = onNavigateBack) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Person,
|
||||||
|
contentDescription = "Back"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
) { paddingValues ->
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.padding(paddingValues)
|
||||||
|
) {
|
||||||
when (val state = uiState) {
|
when (val state = uiState) {
|
||||||
is DiscoverUiState.Idle -> IdleScreen(
|
// ===== IDLE STATE (START HERE) =====
|
||||||
|
is DiscoverUiState.Idle -> {
|
||||||
|
IdleStateContent(
|
||||||
onStartDiscovery = { viewModel.startDiscovery() }
|
onStartDiscovery = { viewModel.startDiscovery() }
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
||||||
is DiscoverUiState.Clustering -> ClusteringProgressScreen(
|
// ===== CLUSTERING IN PROGRESS =====
|
||||||
|
is DiscoverUiState.Clustering -> {
|
||||||
|
ClusteringProgressContent(
|
||||||
progress = state.progress,
|
progress = state.progress,
|
||||||
total = state.total,
|
total = state.total,
|
||||||
message = state.message
|
message = state.message
|
||||||
)
|
)
|
||||||
|
|
||||||
is DiscoverUiState.NamingReady -> ClusterGridScreen(
|
|
||||||
result = state.result,
|
|
||||||
onClusterClick = { cluster ->
|
|
||||||
viewModel.selectCluster(cluster)
|
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
is DiscoverUiState.NamingCluster -> NamingDialog(
|
// ===== CLUSTERS READY FOR NAMING =====
|
||||||
cluster = state.selectedCluster,
|
is DiscoverUiState.NamingReady -> {
|
||||||
suggestedSiblings = state.suggestedSiblings,
|
Text(
|
||||||
onConfirm = { name, dob, isChild, siblings ->
|
text = "Found ${state.result.clusters.size} people!\n\nCluster grid UI coming...",
|
||||||
viewModel.confirmClusterName(
|
modifier = Modifier.align(Alignment.Center)
|
||||||
cluster = state.selectedCluster,
|
)
|
||||||
name = name,
|
}
|
||||||
dateOfBirth = dob,
|
|
||||||
isChild = isChild,
|
// ===== ANALYZING CLUSTER QUALITY =====
|
||||||
selectedSiblings = siblings
|
is DiscoverUiState.AnalyzingCluster -> {
|
||||||
|
LoadingContent(message = "Analyzing cluster quality...")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== NAMING A CLUSTER =====
|
||||||
|
is DiscoverUiState.NamingCluster -> {
|
||||||
|
Text(
|
||||||
|
text = "Naming dialog for cluster ${state.selectedCluster.clusterId}\n\nDialog UI coming...",
|
||||||
|
modifier = Modifier.align(Alignment.Center)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== TRAINING IN PROGRESS =====
|
||||||
|
is DiscoverUiState.Training -> {
|
||||||
|
TrainingProgressContent(
|
||||||
|
stage = state.stage,
|
||||||
|
progress = state.progress,
|
||||||
|
total = state.total
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== VALIDATION PREVIEW =====
|
||||||
|
is DiscoverUiState.ValidationPreview -> {
|
||||||
|
ValidationPreviewScreen(
|
||||||
|
personName = state.personName,
|
||||||
|
validationResult = state.validationResult,
|
||||||
|
onApprove = {
|
||||||
|
viewModel.approveValidationAndScan(
|
||||||
|
personId = state.personId,
|
||||||
|
personName = state.personName
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
onDismiss = { viewModel.cancelNaming() }
|
onReject = {
|
||||||
|
viewModel.rejectValidationAndImprove()
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
||||||
is DiscoverUiState.NoPeopleFound -> EmptyStateScreen(
|
// ===== COMPLETE =====
|
||||||
message = state.message
|
is DiscoverUiState.Complete -> {
|
||||||
)
|
CompleteStateContent(
|
||||||
|
|
||||||
is DiscoverUiState.Error -> ErrorScreen(
|
|
||||||
message = state.message,
|
message = state.message,
|
||||||
onRetry = { viewModel.startDiscovery() }
|
onDone = onNavigateBack
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ===== NO PEOPLE FOUND =====
|
||||||
|
is DiscoverUiState.NoPeopleFound -> {
|
||||||
|
ErrorStateContent(
|
||||||
|
title = "No People Found",
|
||||||
|
message = state.message,
|
||||||
|
onRetry = { viewModel.startDiscovery() },
|
||||||
|
onBack = onNavigateBack
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== ERROR =====
|
||||||
|
is DiscoverUiState.Error -> {
|
||||||
|
ErrorStateContent(
|
||||||
|
title = "Error",
|
||||||
|
message = state.message,
|
||||||
|
onRetry = { viewModel.reset(); viewModel.startDiscovery() },
|
||||||
|
onBack = onNavigateBack
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
// ===== IDLE STATE CONTENT =====
|
||||||
* Idle state - Hero CTA to start discovery
|
|
||||||
*/
|
|
||||||
@Composable
|
@Composable
|
||||||
fun IdleScreen(
|
private fun IdleStateContent(
|
||||||
onStartDiscovery: () -> Unit
|
onStartDiscovery: () -> Unit
|
||||||
) {
|
) {
|
||||||
Column(
|
Column(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxSize()
|
.fillMaxSize()
|
||||||
.padding(32.dp),
|
.padding(24.dp),
|
||||||
horizontalAlignment = Alignment.CenterHorizontally,
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
verticalArrangement = Arrangement.Center
|
verticalArrangement = Arrangement.Center
|
||||||
) {
|
) {
|
||||||
Icon(
|
Icon(
|
||||||
imageVector = Icons.Default.AutoAwesome,
|
imageVector = Icons.Default.Person,
|
||||||
contentDescription = null,
|
contentDescription = null,
|
||||||
modifier = Modifier.size(120.dp),
|
modifier = Modifier.size(120.dp),
|
||||||
tint = MaterialTheme.colorScheme.primary
|
tint = MaterialTheme.colorScheme.primary
|
||||||
)
|
)
|
||||||
|
|
||||||
Spacer(Modifier.height(24.dp))
|
Spacer(modifier = Modifier.height(32.dp))
|
||||||
|
|
||||||
Text(
|
Text(
|
||||||
text = "Discover People",
|
text = "Discover People",
|
||||||
style = MaterialTheme.typography.headlineLarge,
|
style = MaterialTheme.typography.headlineLarge,
|
||||||
fontWeight = FontWeight.Bold,
|
fontWeight = FontWeight.Bold
|
||||||
textAlign = TextAlign.Center
|
|
||||||
)
|
)
|
||||||
|
|
||||||
Spacer(Modifier.height(16.dp))
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
Text(
|
Text(
|
||||||
text = "Let AI automatically find and group faces in your photos. " +
|
text = "Automatically find and organize people in your photo library",
|
||||||
"You'll name them, and we'll tag all their photos.",
|
|
||||||
style = MaterialTheme.typography.bodyLarge,
|
style = MaterialTheme.typography.bodyLarge,
|
||||||
textAlign = TextAlign.Center,
|
textAlign = TextAlign.Center,
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
|
|
||||||
Spacer(Modifier.height(32.dp))
|
Spacer(modifier = Modifier.height(48.dp))
|
||||||
|
|
||||||
Button(
|
Button(
|
||||||
onClick = onStartDiscovery,
|
onClick = onStartDiscovery,
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxWidth()
|
.fillMaxWidth()
|
||||||
.height(56.dp),
|
.height(56.dp)
|
||||||
colors = ButtonDefaults.buttonColors(
|
|
||||||
containerColor = MaterialTheme.colorScheme.primary
|
|
||||||
)
|
|
||||||
) {
|
) {
|
||||||
Icon(
|
|
||||||
imageVector = Icons.Default.AutoAwesome,
|
|
||||||
contentDescription = null,
|
|
||||||
modifier = Modifier.size(24.dp)
|
|
||||||
)
|
|
||||||
Spacer(Modifier.width(8.dp))
|
|
||||||
Text(
|
Text(
|
||||||
text = "Start Discovery",
|
text = "Start Discovery",
|
||||||
style = MaterialTheme.typography.titleMedium,
|
style = MaterialTheme.typography.titleMedium
|
||||||
fontWeight = FontWeight.Bold
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
Spacer(Modifier.height(16.dp))
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
Card(
|
|
||||||
modifier = Modifier.fillMaxWidth(),
|
|
||||||
colors = CardDefaults.cardColors(
|
|
||||||
containerColor = MaterialTheme.colorScheme.surfaceVariant
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
Column(
|
|
||||||
modifier = Modifier.padding(16.dp),
|
|
||||||
verticalArrangement = Arrangement.spacedBy(8.dp)
|
|
||||||
) {
|
|
||||||
InfoRow(Icons.Default.Speed, "Fast: Analyzes ~1000 photos in 2-5 minutes")
|
|
||||||
InfoRow(Icons.Default.Security, "Private: Everything stays on your device")
|
|
||||||
InfoRow(Icons.Default.AutoAwesome, "Smart: Groups faces automatically")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Composable
|
|
||||||
fun InfoRow(icon: androidx.compose.ui.graphics.vector.ImageVector, text: String) {
|
|
||||||
Row(
|
|
||||||
verticalAlignment = Alignment.CenterVertically,
|
|
||||||
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
|
||||||
) {
|
|
||||||
Icon(
|
|
||||||
imageVector = icon,
|
|
||||||
contentDescription = null,
|
|
||||||
tint = MaterialTheme.colorScheme.primary,
|
|
||||||
modifier = Modifier.size(20.dp)
|
|
||||||
)
|
|
||||||
Text(
|
Text(
|
||||||
text = text,
|
text = "This will analyze faces in your photos and group similar faces together",
|
||||||
style = MaterialTheme.typography.bodyMedium
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
textAlign = TextAlign.Center,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
// ===== CLUSTERING PROGRESS =====
|
||||||
* Clustering progress screen
|
|
||||||
*/
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ClusteringProgressScreen(
|
private fun ClusteringProgressContent(
|
||||||
progress: Int,
|
progress: Int,
|
||||||
total: Int,
|
total: Int,
|
||||||
message: String
|
message: String
|
||||||
@@ -212,464 +214,134 @@ fun ClusteringProgressScreen(
|
|||||||
Column(
|
Column(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxSize()
|
.fillMaxSize()
|
||||||
.padding(32.dp),
|
.padding(24.dp),
|
||||||
horizontalAlignment = Alignment.CenterHorizontally,
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
verticalArrangement = Arrangement.Center
|
verticalArrangement = Arrangement.Center
|
||||||
) {
|
) {
|
||||||
CircularProgressIndicator(
|
CircularProgressIndicator(
|
||||||
modifier = Modifier.size(80.dp),
|
modifier = Modifier.size(64.dp)
|
||||||
strokeWidth = 6.dp
|
|
||||||
)
|
)
|
||||||
|
|
||||||
Spacer(Modifier.height(32.dp))
|
Spacer(modifier = Modifier.height(32.dp))
|
||||||
|
|
||||||
Text(
|
|
||||||
text = "Discovering People...",
|
|
||||||
style = MaterialTheme.typography.headlineSmall,
|
|
||||||
fontWeight = FontWeight.Bold
|
|
||||||
)
|
|
||||||
|
|
||||||
Spacer(Modifier.height(16.dp))
|
|
||||||
|
|
||||||
LinearProgressIndicator(
|
|
||||||
progress = { if (total > 0) progress.toFloat() / total.toFloat() else 0f },
|
|
||||||
modifier = Modifier.fillMaxWidth(),
|
|
||||||
)
|
|
||||||
|
|
||||||
Spacer(Modifier.height(8.dp))
|
|
||||||
|
|
||||||
Text(
|
Text(
|
||||||
text = message,
|
text = message,
|
||||||
style = MaterialTheme.typography.bodyMedium,
|
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
|
||||||
)
|
|
||||||
|
|
||||||
Spacer(Modifier.height(24.dp))
|
|
||||||
|
|
||||||
Text(
|
|
||||||
text = "This will take 2-5 minutes. You can leave and come back later.",
|
|
||||||
style = MaterialTheme.typography.bodySmall,
|
|
||||||
textAlign = TextAlign.Center,
|
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Grid of discovered clusters
|
|
||||||
*/
|
|
||||||
@Composable
|
|
||||||
fun ClusterGridScreen(
|
|
||||||
result: com.placeholder.sherpai2.domain.clustering.ClusteringResult,
|
|
||||||
onClusterClick: (FaceCluster) -> Unit
|
|
||||||
) {
|
|
||||||
Column(
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxSize()
|
|
||||||
.padding(16.dp)
|
|
||||||
) {
|
|
||||||
Text(
|
|
||||||
text = "Found ${result.clusters.size} People",
|
|
||||||
style = MaterialTheme.typography.headlineSmall,
|
|
||||||
fontWeight = FontWeight.Bold
|
|
||||||
)
|
|
||||||
|
|
||||||
Spacer(Modifier.height(8.dp))
|
|
||||||
|
|
||||||
Text(
|
|
||||||
text = "Tap to name each person. We'll then tag all their photos.",
|
|
||||||
style = MaterialTheme.typography.bodyMedium,
|
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
|
||||||
)
|
|
||||||
|
|
||||||
Spacer(Modifier.height(16.dp))
|
|
||||||
|
|
||||||
LazyVerticalGrid(
|
|
||||||
columns = GridCells.Fixed(2),
|
|
||||||
horizontalArrangement = Arrangement.spacedBy(12.dp),
|
|
||||||
verticalArrangement = Arrangement.spacedBy(12.dp)
|
|
||||||
) {
|
|
||||||
items(result.clusters) { cluster ->
|
|
||||||
ClusterCard(
|
|
||||||
cluster = cluster,
|
|
||||||
onClick = { onClusterClick(cluster) }
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Single cluster card
|
|
||||||
*/
|
|
||||||
@Composable
|
|
||||||
fun ClusterCard(
|
|
||||||
cluster: FaceCluster,
|
|
||||||
onClick: () -> Unit
|
|
||||||
) {
|
|
||||||
val context = LocalContext.current
|
|
||||||
|
|
||||||
Card(
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
.clickable(onClick = onClick),
|
|
||||||
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp)
|
|
||||||
) {
|
|
||||||
Column {
|
|
||||||
// Face grid (2x3)
|
|
||||||
LazyVerticalGrid(
|
|
||||||
columns = GridCells.Fixed(3),
|
|
||||||
modifier = Modifier.height(180.dp),
|
|
||||||
userScrollEnabled = false
|
|
||||||
) {
|
|
||||||
items(cluster.representativeFaces.take(6)) { face ->
|
|
||||||
val bitmap = remember(face.imageUri) {
|
|
||||||
try {
|
|
||||||
context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
|
|
||||||
BitmapFactory.decodeStream(it)
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bitmap != null) {
|
|
||||||
Image(
|
|
||||||
bitmap = bitmap.asImageBitmap(),
|
|
||||||
contentDescription = null,
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
.aspectRatio(1f),
|
|
||||||
contentScale = ContentScale.Crop
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
Box(
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
.aspectRatio(1f)
|
|
||||||
.background(MaterialTheme.colorScheme.surfaceVariant),
|
|
||||||
contentAlignment = Alignment.Center
|
|
||||||
) {
|
|
||||||
Icon(
|
|
||||||
imageVector = Icons.Default.Person,
|
|
||||||
contentDescription = null,
|
|
||||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Info
|
|
||||||
Column(
|
|
||||||
modifier = Modifier.padding(12.dp)
|
|
||||||
) {
|
|
||||||
Row(
|
|
||||||
modifier = Modifier.fillMaxWidth(),
|
|
||||||
horizontalArrangement = Arrangement.SpaceBetween,
|
|
||||||
verticalAlignment = Alignment.CenterVertically
|
|
||||||
) {
|
|
||||||
Text(
|
|
||||||
text = "${cluster.photoCount} photos",
|
|
||||||
style = MaterialTheme.typography.titleMedium,
|
style = MaterialTheme.typography.titleMedium,
|
||||||
fontWeight = FontWeight.Bold
|
|
||||||
)
|
|
||||||
|
|
||||||
if (cluster.estimatedAge == AgeEstimate.CHILD) {
|
|
||||||
Surface(
|
|
||||||
shape = RoundedCornerShape(12.dp),
|
|
||||||
color = MaterialTheme.colorScheme.primaryContainer
|
|
||||||
) {
|
|
||||||
Text(
|
|
||||||
text = "Child",
|
|
||||||
modifier = Modifier.padding(horizontal = 8.dp, vertical = 4.dp),
|
|
||||||
style = MaterialTheme.typography.labelSmall,
|
|
||||||
color = MaterialTheme.colorScheme.onPrimaryContainer
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cluster.potentialSiblings.isNotEmpty()) {
|
|
||||||
Spacer(Modifier.height(4.dp))
|
|
||||||
Text(
|
|
||||||
text = "Appears with ${cluster.potentialSiblings.size} other person(s)",
|
|
||||||
style = MaterialTheme.typography.bodySmall,
|
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Naming dialog
|
|
||||||
*/
|
|
||||||
@Composable
|
|
||||||
fun NamingDialog(
|
|
||||||
cluster: FaceCluster,
|
|
||||||
suggestedSiblings: List<FaceCluster>,
|
|
||||||
onConfirm: (String, Long?, Boolean, List<Int>) -> Unit,
|
|
||||||
onDismiss: () -> Unit
|
|
||||||
) {
|
|
||||||
var name by remember { mutableStateOf("") }
|
|
||||||
var isChild by remember { mutableStateOf(cluster.estimatedAge == AgeEstimate.CHILD) }
|
|
||||||
var dateOfBirth by remember { mutableStateOf<Long?>(null) }
|
|
||||||
var selectedSiblings by remember { mutableStateOf<Set<Int>>(emptySet()) }
|
|
||||||
var showDatePicker by remember { mutableStateOf(false) }
|
|
||||||
val context = LocalContext.current
|
|
||||||
|
|
||||||
// Date picker dialog
|
|
||||||
if (showDatePicker) {
|
|
||||||
val calendar = java.util.Calendar.getInstance()
|
|
||||||
if (dateOfBirth != null) {
|
|
||||||
calendar.timeInMillis = dateOfBirth!!
|
|
||||||
}
|
|
||||||
|
|
||||||
val datePickerDialog = android.app.DatePickerDialog(
|
|
||||||
context,
|
|
||||||
{ _, year, month, dayOfMonth ->
|
|
||||||
val cal = java.util.Calendar.getInstance()
|
|
||||||
cal.set(year, month, dayOfMonth)
|
|
||||||
dateOfBirth = cal.timeInMillis
|
|
||||||
showDatePicker = false
|
|
||||||
},
|
|
||||||
calendar.get(java.util.Calendar.YEAR),
|
|
||||||
calendar.get(java.util.Calendar.MONTH),
|
|
||||||
calendar.get(java.util.Calendar.DAY_OF_MONTH)
|
|
||||||
)
|
|
||||||
|
|
||||||
datePickerDialog.setOnDismissListener {
|
|
||||||
showDatePicker = false
|
|
||||||
}
|
|
||||||
|
|
||||||
DisposableEffect(Unit) {
|
|
||||||
datePickerDialog.show()
|
|
||||||
onDispose {
|
|
||||||
datePickerDialog.dismiss()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
AlertDialog(
|
|
||||||
onDismissRequest = onDismiss,
|
|
||||||
title = {
|
|
||||||
Text("Name This Person")
|
|
||||||
},
|
|
||||||
text = {
|
|
||||||
Column(
|
|
||||||
verticalArrangement = Arrangement.spacedBy(16.dp)
|
|
||||||
) {
|
|
||||||
// FACE PREVIEW - Show 6 representative faces
|
|
||||||
Text(
|
|
||||||
text = "${cluster.photoCount} photos found",
|
|
||||||
style = MaterialTheme.typography.bodyMedium,
|
|
||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
|
||||||
)
|
|
||||||
|
|
||||||
LazyVerticalGrid(
|
|
||||||
columns = GridCells.Fixed(3),
|
|
||||||
modifier = Modifier.height(180.dp),
|
|
||||||
horizontalArrangement = Arrangement.spacedBy(4.dp),
|
|
||||||
verticalArrangement = Arrangement.spacedBy(4.dp)
|
|
||||||
) {
|
|
||||||
items(cluster.representativeFaces.take(6)) { face ->
|
|
||||||
val bitmap = remember(face.imageUri) {
|
|
||||||
try {
|
|
||||||
context.contentResolver.openInputStream(Uri.parse(face.imageUri))?.use {
|
|
||||||
BitmapFactory.decodeStream(it)
|
|
||||||
}
|
|
||||||
} catch (e: Exception) {
|
|
||||||
null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (bitmap != null) {
|
|
||||||
Image(
|
|
||||||
bitmap = bitmap.asImageBitmap(),
|
|
||||||
contentDescription = null,
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
.aspectRatio(1f)
|
|
||||||
.clip(RoundedCornerShape(8.dp)),
|
|
||||||
contentScale = ContentScale.Crop
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
Box(
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxWidth()
|
|
||||||
.aspectRatio(1f)
|
|
||||||
.clip(RoundedCornerShape(8.dp))
|
|
||||||
.background(MaterialTheme.colorScheme.surfaceVariant),
|
|
||||||
contentAlignment = Alignment.Center
|
|
||||||
) {
|
|
||||||
Icon(
|
|
||||||
imageVector = Icons.Default.Person,
|
|
||||||
contentDescription = null,
|
|
||||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HorizontalDivider()
|
|
||||||
|
|
||||||
// Name input
|
|
||||||
OutlinedTextField(
|
|
||||||
value = name,
|
|
||||||
onValueChange = { name = it },
|
|
||||||
label = { Text("Name") },
|
|
||||||
singleLine = true,
|
|
||||||
modifier = Modifier.fillMaxWidth()
|
|
||||||
)
|
|
||||||
|
|
||||||
// Is child toggle
|
|
||||||
Row(
|
|
||||||
modifier = Modifier.fillMaxWidth(),
|
|
||||||
horizontalArrangement = Arrangement.SpaceBetween,
|
|
||||||
verticalAlignment = Alignment.CenterVertically
|
|
||||||
) {
|
|
||||||
Text("This person is a child")
|
|
||||||
Switch(
|
|
||||||
checked = isChild,
|
|
||||||
onCheckedChange = { isChild = it }
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Date of birth (if child)
|
|
||||||
if (isChild) {
|
|
||||||
OutlinedButton(
|
|
||||||
onClick = { showDatePicker = true },
|
|
||||||
modifier = Modifier.fillMaxWidth()
|
|
||||||
) {
|
|
||||||
Icon(Icons.Default.CalendarToday, null)
|
|
||||||
Spacer(Modifier.width(8.dp))
|
|
||||||
Text(
|
|
||||||
if (dateOfBirth != null) {
|
|
||||||
SimpleDateFormat("MMM dd, yyyy", Locale.getDefault())
|
|
||||||
.format(Date(dateOfBirth!!))
|
|
||||||
} else {
|
|
||||||
"Set Date of Birth"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Suggested siblings
|
|
||||||
if (suggestedSiblings.isNotEmpty()) {
|
|
||||||
Text(
|
|
||||||
"Appears with these people (select siblings):",
|
|
||||||
style = MaterialTheme.typography.labelMedium
|
|
||||||
)
|
|
||||||
|
|
||||||
suggestedSiblings.forEach { sibling ->
|
|
||||||
Row(
|
|
||||||
modifier = Modifier.fillMaxWidth(),
|
|
||||||
verticalAlignment = Alignment.CenterVertically
|
|
||||||
) {
|
|
||||||
Checkbox(
|
|
||||||
checked = sibling.clusterId in selectedSiblings,
|
|
||||||
onCheckedChange = { checked ->
|
|
||||||
selectedSiblings = if (checked) {
|
|
||||||
selectedSiblings + sibling.clusterId
|
|
||||||
} else {
|
|
||||||
selectedSiblings - sibling.clusterId
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
Text("Person ${sibling.clusterId + 1} (${sibling.photoCount} photos)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
confirmButton = {
|
|
||||||
TextButton(
|
|
||||||
onClick = {
|
|
||||||
onConfirm(
|
|
||||||
name,
|
|
||||||
dateOfBirth,
|
|
||||||
isChild,
|
|
||||||
selectedSiblings.toList()
|
|
||||||
)
|
|
||||||
},
|
|
||||||
enabled = name.isNotBlank()
|
|
||||||
) {
|
|
||||||
Text("Save & Train")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
dismissButton = {
|
|
||||||
TextButton(onClick = onDismiss) {
|
|
||||||
Text("Cancel")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: Add DatePickerDialog when showDatePicker is true
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Empty state screen
|
|
||||||
*/
|
|
||||||
@Composable
|
|
||||||
fun EmptyStateScreen(message: String) {
|
|
||||||
Column(
|
|
||||||
modifier = Modifier
|
|
||||||
.fillMaxSize()
|
|
||||||
.padding(32.dp),
|
|
||||||
horizontalAlignment = Alignment.CenterHorizontally,
|
|
||||||
verticalArrangement = Arrangement.Center
|
|
||||||
) {
|
|
||||||
Icon(
|
|
||||||
imageVector = Icons.Default.PersonOff,
|
|
||||||
contentDescription = null,
|
|
||||||
modifier = Modifier.size(80.dp),
|
|
||||||
tint = MaterialTheme.colorScheme.onSurfaceVariant
|
|
||||||
)
|
|
||||||
|
|
||||||
Spacer(Modifier.height(16.dp))
|
|
||||||
|
|
||||||
Text(
|
|
||||||
text = message,
|
|
||||||
style = MaterialTheme.typography.bodyLarge,
|
|
||||||
textAlign = TextAlign.Center
|
textAlign = TextAlign.Center
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
|
if (total > 0) {
|
||||||
|
LinearProgressIndicator(
|
||||||
|
progress = progress.toFloat() / total.toFloat(),
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.height(8.dp)
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = "$progress / $total",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
// ===== TRAINING PROGRESS =====
|
||||||
* Error screen
|
|
||||||
*/
|
|
||||||
@Composable
|
@Composable
|
||||||
fun ErrorScreen(
|
private fun TrainingProgressContent(
|
||||||
message: String,
|
stage: String,
|
||||||
onRetry: () -> Unit
|
progress: Int,
|
||||||
|
total: Int
|
||||||
) {
|
) {
|
||||||
Column(
|
Column(
|
||||||
modifier = Modifier
|
modifier = Modifier
|
||||||
.fillMaxSize()
|
.fillMaxSize()
|
||||||
.padding(32.dp),
|
.padding(24.dp),
|
||||||
horizontalAlignment = Alignment.CenterHorizontally,
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
verticalArrangement = Arrangement.Center
|
verticalArrangement = Arrangement.Center
|
||||||
) {
|
) {
|
||||||
Icon(
|
CircularProgressIndicator(
|
||||||
imageVector = Icons.Default.Error,
|
modifier = Modifier.size(64.dp)
|
||||||
contentDescription = null,
|
|
||||||
modifier = Modifier.size(80.dp),
|
|
||||||
tint = MaterialTheme.colorScheme.error
|
|
||||||
)
|
)
|
||||||
|
|
||||||
Spacer(Modifier.height(16.dp))
|
Spacer(modifier = Modifier.height(32.dp))
|
||||||
|
|
||||||
Text(
|
Text(
|
||||||
text = "Oops!",
|
text = stage,
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
|
textAlign = TextAlign.Center
|
||||||
|
)
|
||||||
|
|
||||||
|
if (total > 0) {
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
|
LinearProgressIndicator(
|
||||||
|
progress = progress.toFloat() / total.toFloat(),
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxWidth()
|
||||||
|
.height(8.dp)
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = "$progress / $total",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== LOADING CONTENT =====
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun LoadingContent(message: String) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.Center
|
||||||
|
) {
|
||||||
|
CircularProgressIndicator()
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
Text(text = message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== COMPLETE STATE =====
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun CompleteStateContent(
|
||||||
|
message: String,
|
||||||
|
onDone: () -> Unit
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.padding(24.dp),
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.Center
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = "🎉",
|
||||||
|
style = MaterialTheme.typography.displayLarge
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(24.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = "Success!",
|
||||||
style = MaterialTheme.typography.headlineMedium,
|
style = MaterialTheme.typography.headlineMedium,
|
||||||
fontWeight = FontWeight.Bold
|
fontWeight = FontWeight.Bold
|
||||||
)
|
)
|
||||||
|
|
||||||
Spacer(Modifier.height(8.dp))
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
Text(
|
Text(
|
||||||
text = message,
|
text = message,
|
||||||
@@ -678,10 +350,74 @@ fun ErrorScreen(
|
|||||||
color = MaterialTheme.colorScheme.onSurfaceVariant
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
)
|
)
|
||||||
|
|
||||||
Spacer(Modifier.height(24.dp))
|
Spacer(modifier = Modifier.height(32.dp))
|
||||||
|
|
||||||
Button(onClick = onRetry) {
|
Button(
|
||||||
Text("Try Again")
|
onClick = onDone,
|
||||||
|
modifier = Modifier.fillMaxWidth()
|
||||||
|
) {
|
||||||
|
Text("Done")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== ERROR STATE =====
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun ErrorStateContent(
|
||||||
|
title: String,
|
||||||
|
message: String,
|
||||||
|
onRetry: () -> Unit,
|
||||||
|
onBack: () -> Unit
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.padding(24.dp),
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally,
|
||||||
|
verticalArrangement = Arrangement.Center
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = "⚠️",
|
||||||
|
style = MaterialTheme.typography.displayLarge
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(24.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = title,
|
||||||
|
style = MaterialTheme.typography.headlineMedium,
|
||||||
|
fontWeight = FontWeight.Bold
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = message,
|
||||||
|
style = MaterialTheme.typography.bodyLarge,
|
||||||
|
textAlign = TextAlign.Center,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(32.dp))
|
||||||
|
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||||
|
) {
|
||||||
|
OutlinedButton(
|
||||||
|
onClick = onBack,
|
||||||
|
modifier = Modifier.weight(1f)
|
||||||
|
) {
|
||||||
|
Text("Back")
|
||||||
|
}
|
||||||
|
|
||||||
|
Button(
|
||||||
|
onClick = onRetry,
|
||||||
|
modifier = Modifier.weight(1f)
|
||||||
|
) {
|
||||||
|
Text("Retry")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2,10 +2,15 @@ package com.placeholder.sherpai2.ui.discover
|
|||||||
|
|
||||||
import androidx.lifecycle.ViewModel
|
import androidx.lifecycle.ViewModel
|
||||||
import androidx.lifecycle.viewModelScope
|
import androidx.lifecycle.viewModelScope
|
||||||
|
import androidx.work.WorkManager
|
||||||
import com.placeholder.sherpai2.domain.clustering.ClusteringResult
|
import com.placeholder.sherpai2.domain.clustering.ClusteringResult
|
||||||
|
import com.placeholder.sherpai2.domain.clustering.ClusterQualityResult
|
||||||
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
import com.placeholder.sherpai2.domain.clustering.FaceCluster
|
||||||
import com.placeholder.sherpai2.domain.clustering.FaceClusteringService
|
import com.placeholder.sherpai2.domain.clustering.FaceClusteringService
|
||||||
import com.placeholder.sherpai2.domain.training.ClusterTrainingService
|
import com.placeholder.sherpai2.domain.training.ClusterTrainingService
|
||||||
|
import com.placeholder.sherpai2.domain.validation.ValidationScanResult
|
||||||
|
import com.placeholder.sherpai2.domain.validation.ValidationScanService
|
||||||
|
import com.placeholder.sherpai2.workers.LibraryScanWorker
|
||||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||||
import kotlinx.coroutines.flow.MutableStateFlow
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
import kotlinx.coroutines.flow.StateFlow
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
@@ -14,21 +19,22 @@ import kotlinx.coroutines.launch
|
|||||||
import javax.inject.Inject
|
import javax.inject.Inject
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DiscoverPeopleViewModel - Manages auto-clustering and naming flow
|
* DiscoverPeopleViewModel - Manages TWO-STAGE validation flow
|
||||||
*
|
*
|
||||||
* PHASE 2: Now includes multi-centroid training from clusters
|
* FLOW:
|
||||||
*
|
* 1. Clustering → User selects cluster
|
||||||
* STATE FLOW:
|
* 2. STAGE 1: Show cluster quality analysis
|
||||||
* 1. Idle → User taps "Discover People"
|
* 3. User names person → Training
|
||||||
* 2. Clustering → Auto-analyzing faces (2-5 min)
|
* 4. STAGE 2: Show validation scan preview
|
||||||
* 3. NamingReady → Shows clusters, user names them
|
* 5. User approves → Full library scan (background worker)
|
||||||
* 4. Training → Creating multi-centroid face model
|
* 6. Results appear in "People" tab
|
||||||
* 5. Complete → Ready to scan library
|
|
||||||
*/
|
*/
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class DiscoverPeopleViewModel @Inject constructor(
|
class DiscoverPeopleViewModel @Inject constructor(
|
||||||
private val clusteringService: FaceClusteringService,
|
private val clusteringService: FaceClusteringService,
|
||||||
private val trainingService: ClusterTrainingService
|
private val trainingService: ClusterTrainingService,
|
||||||
|
private val validationScanService: ValidationScanService,
|
||||||
|
private val workManager: WorkManager
|
||||||
) : ViewModel() {
|
) : ViewModel() {
|
||||||
|
|
||||||
private val _uiState = MutableStateFlow<DiscoverUiState>(DiscoverUiState.Idle)
|
private val _uiState = MutableStateFlow<DiscoverUiState>(DiscoverUiState.Idle)
|
||||||
@@ -37,6 +43,9 @@ class DiscoverPeopleViewModel @Inject constructor(
|
|||||||
// Track which clusters have been named
|
// Track which clusters have been named
|
||||||
private val namedClusterIds = mutableSetOf<Int>()
|
private val namedClusterIds = mutableSetOf<Int>()
|
||||||
|
|
||||||
|
// Store quality analysis for current cluster
|
||||||
|
private var currentQualityResult: ClusterQualityResult? = null
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start auto-clustering process
|
* Start auto-clustering process
|
||||||
*/
|
*/
|
||||||
@@ -78,27 +87,41 @@ class DiscoverPeopleViewModel @Inject constructor(
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* User selected a cluster to name
|
* User selected a cluster to name
|
||||||
|
* STAGE 1: Analyze quality FIRST
|
||||||
*/
|
*/
|
||||||
fun selectCluster(cluster: FaceCluster) {
|
fun selectCluster(cluster: FaceCluster) {
|
||||||
val currentState = _uiState.value
|
val currentState = _uiState.value
|
||||||
if (currentState is DiscoverUiState.NamingReady) {
|
if (currentState is DiscoverUiState.NamingReady) {
|
||||||
|
viewModelScope.launch {
|
||||||
|
try {
|
||||||
|
// Show analyzing state
|
||||||
|
_uiState.value = DiscoverUiState.AnalyzingCluster(cluster)
|
||||||
|
|
||||||
|
// Analyze cluster quality
|
||||||
|
val qualityResult = trainingService.analyzeClusterQuality(cluster)
|
||||||
|
currentQualityResult = qualityResult
|
||||||
|
|
||||||
|
// Show naming dialog with quality info
|
||||||
_uiState.value = DiscoverUiState.NamingCluster(
|
_uiState.value = DiscoverUiState.NamingCluster(
|
||||||
result = currentState.result,
|
result = currentState.result,
|
||||||
selectedCluster = cluster,
|
selectedCluster = cluster,
|
||||||
|
qualityResult = qualityResult,
|
||||||
suggestedSiblings = currentState.result.clusters.filter {
|
suggestedSiblings = currentState.result.clusters.filter {
|
||||||
it.clusterId in cluster.potentialSiblings
|
it.clusterId in cluster.potentialSiblings
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
} catch (e: Exception) {
|
||||||
|
_uiState.value = DiscoverUiState.Error(
|
||||||
|
"Failed to analyze cluster: ${e.message}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User confirmed name and metadata for a cluster
|
* User confirmed name and metadata for a cluster
|
||||||
*
|
* STAGE 2: Train → Validation scan → Preview
|
||||||
* CREATES:
|
|
||||||
* 1. PersonEntity with all metadata (name, DOB, siblings)
|
|
||||||
* 2. Multi-centroid FaceModelEntity (temporal tracking for children)
|
|
||||||
* 3. Removes cluster from display
|
|
||||||
*/
|
*/
|
||||||
fun confirmClusterName(
|
fun confirmClusterName(
|
||||||
cluster: FaceCluster,
|
cluster: FaceCluster,
|
||||||
@@ -112,37 +135,59 @@ class DiscoverPeopleViewModel @Inject constructor(
|
|||||||
val currentState = _uiState.value
|
val currentState = _uiState.value
|
||||||
if (currentState !is DiscoverUiState.NamingCluster) return@launch
|
if (currentState !is DiscoverUiState.NamingCluster) return@launch
|
||||||
|
|
||||||
// Train person from cluster
|
// Show training progress
|
||||||
|
_uiState.value = DiscoverUiState.Training(
|
||||||
|
stage = "Creating person and training model",
|
||||||
|
progress = 0,
|
||||||
|
total = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
// Train person from cluster (using clean faces from quality analysis)
|
||||||
val personId = trainingService.trainFromCluster(
|
val personId = trainingService.trainFromCluster(
|
||||||
cluster = cluster,
|
cluster = cluster,
|
||||||
name = name,
|
name = name,
|
||||||
dateOfBirth = dateOfBirth,
|
dateOfBirth = dateOfBirth,
|
||||||
isChild = isChild,
|
isChild = isChild,
|
||||||
siblingClusterIds = selectedSiblings,
|
siblingClusterIds = selectedSiblings,
|
||||||
|
qualityResult = currentQualityResult, // Use clean faces!
|
||||||
onProgress = { current, total, message ->
|
onProgress = { current, total, message ->
|
||||||
_uiState.value = DiscoverUiState.Clustering(current, total, message)
|
_uiState.value = DiscoverUiState.Training(
|
||||||
|
stage = message,
|
||||||
|
progress = current,
|
||||||
|
total = total
|
||||||
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Training complete - now run validation scan
|
||||||
|
_uiState.value = DiscoverUiState.Training(
|
||||||
|
stage = "Running validation scan...",
|
||||||
|
progress = 0,
|
||||||
|
total = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
val validationResult = validationScanService.performValidationScan(
|
||||||
|
personId = personId,
|
||||||
|
onProgress = { current, total ->
|
||||||
|
_uiState.value = DiscoverUiState.Training(
|
||||||
|
stage = "Scanning sample photos...",
|
||||||
|
progress = current,
|
||||||
|
total = total
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Show validation preview to user
|
||||||
|
_uiState.value = DiscoverUiState.ValidationPreview(
|
||||||
|
personId = personId,
|
||||||
|
personName = name,
|
||||||
|
validationResult = validationResult,
|
||||||
|
originalClusterResult = currentState.result
|
||||||
|
)
|
||||||
|
|
||||||
// Mark cluster as named
|
// Mark cluster as named
|
||||||
namedClusterIds.add(cluster.clusterId)
|
namedClusterIds.add(cluster.clusterId)
|
||||||
|
|
||||||
// Filter out named clusters
|
|
||||||
val remainingClusters = currentState.result.clusters
|
|
||||||
.filter { it.clusterId !in namedClusterIds }
|
|
||||||
|
|
||||||
if (remainingClusters.isEmpty()) {
|
|
||||||
// All clusters named! Show success
|
|
||||||
_uiState.value = DiscoverUiState.NoPeopleFound(
|
|
||||||
"All people have been named! 🎉\n\nGo to 'People' to see your trained models."
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
// Return to naming screen with remaining clusters
|
|
||||||
_uiState.value = DiscoverUiState.NamingReady(
|
|
||||||
result = currentState.result.copy(clusters = remainingClusters)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
_uiState.value = DiscoverUiState.Error(
|
_uiState.value = DiscoverUiState.Error(
|
||||||
e.message ?: "Failed to create person: ${e.message}"
|
e.message ?: "Failed to create person: ${e.message}"
|
||||||
@@ -151,6 +196,57 @@ class DiscoverPeopleViewModel @Inject constructor(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* User approves validation preview → Start full library scan
|
||||||
|
*/
|
||||||
|
fun approveValidationAndScan(personId: String, personName: String) {
|
||||||
|
viewModelScope.launch {
|
||||||
|
val currentState = _uiState.value
|
||||||
|
if (currentState !is DiscoverUiState.ValidationPreview) return@launch
|
||||||
|
|
||||||
|
// Enqueue background worker for full library scan
|
||||||
|
val workRequest = LibraryScanWorker.createWorkRequest(
|
||||||
|
personId = personId,
|
||||||
|
personName = personName,
|
||||||
|
threshold = 0.70f // Slightly looser than validation
|
||||||
|
)
|
||||||
|
workManager.enqueue(workRequest)
|
||||||
|
|
||||||
|
// Filter out named clusters and return to cluster list
|
||||||
|
val remainingClusters = currentState.originalClusterResult.clusters
|
||||||
|
.filter { it.clusterId !in namedClusterIds }
|
||||||
|
|
||||||
|
if (remainingClusters.isEmpty()) {
|
||||||
|
// All clusters named! Show success
|
||||||
|
_uiState.value = DiscoverUiState.Complete(
|
||||||
|
message = "All people have been named! 🎉\n\n" +
|
||||||
|
"Full library scan is running in the background.\n" +
|
||||||
|
"Go to 'People' to see results as they come in."
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Return to naming screen with remaining clusters
|
||||||
|
_uiState.value = DiscoverUiState.NamingReady(
|
||||||
|
result = currentState.originalClusterResult.copy(clusters = remainingClusters)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* User rejects validation → Go back to add more training photos
|
||||||
|
*/
|
||||||
|
fun rejectValidationAndImprove() {
|
||||||
|
viewModelScope.launch {
|
||||||
|
val currentState = _uiState.value
|
||||||
|
if (currentState !is DiscoverUiState.ValidationPreview) return@launch
|
||||||
|
|
||||||
|
_uiState.value = DiscoverUiState.Error(
|
||||||
|
"Model quality needs improvement.\n\n" +
|
||||||
|
"Please use the manual training flow to add more high-quality photos."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cancel naming and go back to cluster list
|
* Cancel naming and go back to cluster list
|
||||||
*/
|
*/
|
||||||
@@ -172,7 +268,7 @@ class DiscoverPeopleViewModel @Inject constructor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* UI States for Discover People flow
|
* UI States for Discover People flow with TWO-STAGE VALIDATION
|
||||||
*/
|
*/
|
||||||
sealed class DiscoverUiState {
|
sealed class DiscoverUiState {
|
||||||
|
|
||||||
@@ -198,14 +294,48 @@ sealed class DiscoverUiState {
|
|||||||
) : DiscoverUiState()
|
) : DiscoverUiState()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* User is naming a specific cluster
|
* STAGE 1: Analyzing cluster quality (before naming)
|
||||||
|
*/
|
||||||
|
data class AnalyzingCluster(
|
||||||
|
val cluster: FaceCluster
|
||||||
|
) : DiscoverUiState()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* User is naming a specific cluster (with quality analysis)
|
||||||
*/
|
*/
|
||||||
data class NamingCluster(
|
data class NamingCluster(
|
||||||
val result: ClusteringResult,
|
val result: ClusteringResult,
|
||||||
val selectedCluster: FaceCluster,
|
val selectedCluster: FaceCluster,
|
||||||
|
val qualityResult: ClusterQualityResult,
|
||||||
val suggestedSiblings: List<FaceCluster>
|
val suggestedSiblings: List<FaceCluster>
|
||||||
) : DiscoverUiState()
|
) : DiscoverUiState()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Training in progress
|
||||||
|
*/
|
||||||
|
data class Training(
|
||||||
|
val stage: String,
|
||||||
|
val progress: Int,
|
||||||
|
val total: Int
|
||||||
|
) : DiscoverUiState()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* STAGE 2: Validation scan complete - show preview to user
|
||||||
|
*/
|
||||||
|
data class ValidationPreview(
|
||||||
|
val personId: String,
|
||||||
|
val personName: String,
|
||||||
|
val validationResult: ValidationScanResult,
|
||||||
|
val originalClusterResult: ClusteringResult
|
||||||
|
) : DiscoverUiState()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* All clusters named and scans launched
|
||||||
|
*/
|
||||||
|
data class Complete(
|
||||||
|
val message: String
|
||||||
|
) : DiscoverUiState()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* No people found in library
|
* No people found in library
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,395 @@
|
|||||||
|
package com.placeholder.sherpai2.ui.discover
|
||||||
|
|
||||||
|
import android.net.Uri
|
||||||
|
import androidx.compose.foundation.background
|
||||||
|
import androidx.compose.foundation.border
|
||||||
|
import androidx.compose.foundation.layout.*
|
||||||
|
import androidx.compose.foundation.lazy.grid.GridCells
|
||||||
|
import androidx.compose.foundation.lazy.grid.LazyVerticalGrid
|
||||||
|
import androidx.compose.foundation.lazy.grid.items
|
||||||
|
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||||
|
import androidx.compose.material.icons.Icons
|
||||||
|
import androidx.compose.material.icons.filled.*
|
||||||
|
import androidx.compose.material3.*
|
||||||
|
import androidx.compose.runtime.*
|
||||||
|
import androidx.compose.ui.Alignment
|
||||||
|
import androidx.compose.ui.Modifier
|
||||||
|
import androidx.compose.ui.draw.clip
|
||||||
|
import androidx.compose.ui.graphics.Color
|
||||||
|
import androidx.compose.ui.layout.ContentScale
|
||||||
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
|
import androidx.compose.ui.unit.dp
|
||||||
|
import coil.compose.AsyncImage
|
||||||
|
import com.placeholder.sherpai2.domain.validation.ValidationMatch
|
||||||
|
import com.placeholder.sherpai2.domain.validation.ValidationQuality
|
||||||
|
import com.placeholder.sherpai2.domain.validation.ValidationScanResult
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ValidationPreviewScreen - STAGE 2 validation UI
|
||||||
|
*
|
||||||
|
* Shows user a preview of matches found in validation scan
|
||||||
|
* User can approve (→ full scan) or reject (→ add more photos)
|
||||||
|
*/
|
||||||
|
@Composable
|
||||||
|
fun ValidationPreviewScreen(
|
||||||
|
personName: String,
|
||||||
|
validationResult: ValidationScanResult,
|
||||||
|
onApprove: () -> Unit,
|
||||||
|
onReject: () -> Unit,
|
||||||
|
modifier: Modifier = Modifier
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = modifier
|
||||||
|
.fillMaxSize()
|
||||||
|
.padding(16.dp)
|
||||||
|
) {
|
||||||
|
// Header
|
||||||
|
Text(
|
||||||
|
text = "Validation Results",
|
||||||
|
style = MaterialTheme.typography.headlineMedium,
|
||||||
|
fontWeight = FontWeight.Bold
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = "Review matches for \"$personName\" before scanning your entire library",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
|
// Quality Summary
|
||||||
|
QualitySummaryCard(
|
||||||
|
validationResult = validationResult,
|
||||||
|
personName = personName
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
|
// Matches Grid
|
||||||
|
if (validationResult.matches.isNotEmpty()) {
|
||||||
|
Text(
|
||||||
|
text = "Sample Matches (${validationResult.matchCount})",
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
|
fontWeight = FontWeight.SemiBold
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
|
||||||
|
LazyVerticalGrid(
|
||||||
|
columns = GridCells.Fixed(3),
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
verticalArrangement = Arrangement.spacedBy(8.dp),
|
||||||
|
modifier = Modifier.weight(1f)
|
||||||
|
) {
|
||||||
|
items(validationResult.matches.take(15)) { match ->
|
||||||
|
MatchPreviewCard(match = match)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No matches found
|
||||||
|
NoMatchesCard()
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
|
||||||
|
// Action Buttons
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
horizontalArrangement = Arrangement.spacedBy(12.dp)
|
||||||
|
) {
|
||||||
|
// Reject button
|
||||||
|
OutlinedButton(
|
||||||
|
onClick = onReject,
|
||||||
|
modifier = Modifier.weight(1f),
|
||||||
|
colors = ButtonDefaults.outlinedButtonColors(
|
||||||
|
contentColor = MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Close,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(20.dp)
|
||||||
|
)
|
||||||
|
Spacer(modifier = Modifier.width(8.dp))
|
||||||
|
Text("Add More Photos")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Approve button
|
||||||
|
Button(
|
||||||
|
onClick = onApprove,
|
||||||
|
modifier = Modifier.weight(1f),
|
||||||
|
enabled = validationResult.qualityAssessment != ValidationQuality.NO_MATCHES
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Check,
|
||||||
|
contentDescription = null,
|
||||||
|
modifier = Modifier.size(20.dp)
|
||||||
|
)
|
||||||
|
Spacer(modifier = Modifier.width(8.dp))
|
||||||
|
Text("Scan Library")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun QualitySummaryCard(
|
||||||
|
validationResult: ValidationScanResult,
|
||||||
|
personName: String
|
||||||
|
) {
|
||||||
|
val (backgroundColor, iconColor, statusText, statusIcon) = when (validationResult.qualityAssessment) {
|
||||||
|
ValidationQuality.EXCELLENT -> {
|
||||||
|
Quadruple(
|
||||||
|
Color(0xFF1B5E20).copy(alpha = 0.1f),
|
||||||
|
Color(0xFF1B5E20),
|
||||||
|
"Excellent Match Quality",
|
||||||
|
Icons.Default.CheckCircle
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ValidationQuality.GOOD -> {
|
||||||
|
Quadruple(
|
||||||
|
Color(0xFF2E7D32).copy(alpha = 0.1f),
|
||||||
|
Color(0xFF2E7D32),
|
||||||
|
"Good Match Quality",
|
||||||
|
Icons.Default.ThumbUp
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ValidationQuality.FAIR -> {
|
||||||
|
Quadruple(
|
||||||
|
Color(0xFFF57F17).copy(alpha = 0.1f),
|
||||||
|
Color(0xFFF57F17),
|
||||||
|
"Fair Match Quality",
|
||||||
|
Icons.Default.Warning
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ValidationQuality.POOR -> {
|
||||||
|
Quadruple(
|
||||||
|
Color(0xFFD32F2F).copy(alpha = 0.1f),
|
||||||
|
Color(0xFFD32F2F),
|
||||||
|
"Poor Match Quality",
|
||||||
|
Icons.Default.Warning
|
||||||
|
)
|
||||||
|
}
|
||||||
|
ValidationQuality.NO_MATCHES -> {
|
||||||
|
Quadruple(
|
||||||
|
Color(0xFFD32F2F).copy(alpha = 0.1f),
|
||||||
|
Color(0xFFD32F2F),
|
||||||
|
"No Matches Found",
|
||||||
|
Icons.Default.Close
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Card(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
colors = CardDefaults.cardColors(
|
||||||
|
containerColor = backgroundColor
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.padding(16.dp)
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = statusIcon,
|
||||||
|
contentDescription = null,
|
||||||
|
tint = iconColor,
|
||||||
|
modifier = Modifier.size(24.dp)
|
||||||
|
)
|
||||||
|
Spacer(modifier = Modifier.width(8.dp))
|
||||||
|
Text(
|
||||||
|
text = statusText,
|
||||||
|
style = MaterialTheme.typography.titleMedium,
|
||||||
|
fontWeight = FontWeight.Bold,
|
||||||
|
color = iconColor
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
Spacer(modifier = Modifier.height(12.dp))
|
||||||
|
|
||||||
|
// Stats
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
horizontalArrangement = Arrangement.SpaceBetween
|
||||||
|
) {
|
||||||
|
StatItem(
|
||||||
|
label = "Matches Found",
|
||||||
|
value = "${validationResult.matchCount} / ${validationResult.sampleSize}"
|
||||||
|
)
|
||||||
|
StatItem(
|
||||||
|
label = "Avg Confidence",
|
||||||
|
value = "${(validationResult.averageConfidence * 100).toInt()}%"
|
||||||
|
)
|
||||||
|
StatItem(
|
||||||
|
label = "Threshold",
|
||||||
|
value = "${(validationResult.threshold * 100).toInt()}%"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recommendation
|
||||||
|
if (validationResult.qualityAssessment != ValidationQuality.NO_MATCHES) {
|
||||||
|
Spacer(modifier = Modifier.height(12.dp))
|
||||||
|
|
||||||
|
val recommendation = when (validationResult.qualityAssessment) {
|
||||||
|
ValidationQuality.EXCELLENT ->
|
||||||
|
"✅ Model looks great! Safe to scan your full library."
|
||||||
|
ValidationQuality.GOOD ->
|
||||||
|
"✅ Model quality is good. You can proceed with the full scan."
|
||||||
|
ValidationQuality.FAIR ->
|
||||||
|
"⚠️ Model quality is acceptable but could be improved with more photos."
|
||||||
|
ValidationQuality.POOR ->
|
||||||
|
"⚠️ Consider adding more diverse, high-quality training photos."
|
||||||
|
ValidationQuality.NO_MATCHES -> ""
|
||||||
|
}
|
||||||
|
|
||||||
|
Text(
|
||||||
|
text = recommendation,
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
Spacer(modifier = Modifier.height(12.dp))
|
||||||
|
Text(
|
||||||
|
text = "No matches found. The model may need more or better training photos, or the validation sample didn't include $personName.",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun StatItem(
|
||||||
|
label: String,
|
||||||
|
value: String
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = value,
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
fontWeight = FontWeight.Bold
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = label,
|
||||||
|
style = MaterialTheme.typography.bodySmall,
|
||||||
|
color = MaterialTheme.colorScheme.onSurfaceVariant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun MatchPreviewCard(
|
||||||
|
match: ValidationMatch
|
||||||
|
) {
|
||||||
|
Box(
|
||||||
|
modifier = Modifier
|
||||||
|
.aspectRatio(1f)
|
||||||
|
.clip(RoundedCornerShape(8.dp))
|
||||||
|
.background(MaterialTheme.colorScheme.surfaceVariant)
|
||||||
|
) {
|
||||||
|
AsyncImage(
|
||||||
|
model = Uri.parse(match.imageUri),
|
||||||
|
contentDescription = "Match preview",
|
||||||
|
modifier = Modifier.fillMaxSize(),
|
||||||
|
contentScale = ContentScale.Crop
|
||||||
|
)
|
||||||
|
|
||||||
|
// Confidence badge
|
||||||
|
Surface(
|
||||||
|
modifier = Modifier
|
||||||
|
.align(Alignment.BottomEnd)
|
||||||
|
.padding(4.dp),
|
||||||
|
shape = RoundedCornerShape(4.dp),
|
||||||
|
color = Color.Black.copy(alpha = 0.7f)
|
||||||
|
) {
|
||||||
|
Text(
|
||||||
|
text = "${(match.confidence * 100).toInt()}%",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = Color.White,
|
||||||
|
modifier = Modifier.padding(horizontal = 6.dp, vertical = 2.dp)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Face count indicator (if group photo)
|
||||||
|
if (match.faceCount > 1) {
|
||||||
|
Surface(
|
||||||
|
modifier = Modifier
|
||||||
|
.align(Alignment.TopEnd)
|
||||||
|
.padding(4.dp),
|
||||||
|
shape = RoundedCornerShape(4.dp),
|
||||||
|
color = MaterialTheme.colorScheme.primary
|
||||||
|
) {
|
||||||
|
Row(
|
||||||
|
modifier = Modifier.padding(horizontal = 4.dp, vertical = 2.dp),
|
||||||
|
verticalAlignment = Alignment.CenterVertically
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Person,
|
||||||
|
contentDescription = null,
|
||||||
|
tint = Color.White,
|
||||||
|
modifier = Modifier.size(12.dp)
|
||||||
|
)
|
||||||
|
Text(
|
||||||
|
text = "${match.faceCount}",
|
||||||
|
style = MaterialTheme.typography.labelSmall,
|
||||||
|
color = Color.White
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Composable
|
||||||
|
private fun NoMatchesCard() {
|
||||||
|
Card(
|
||||||
|
modifier = Modifier.fillMaxWidth(),
|
||||||
|
colors = CardDefaults.cardColors(
|
||||||
|
containerColor = MaterialTheme.colorScheme.errorContainer
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
Column(
|
||||||
|
modifier = Modifier.padding(24.dp),
|
||||||
|
horizontalAlignment = Alignment.CenterHorizontally
|
||||||
|
) {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.Warning,
|
||||||
|
contentDescription = null,
|
||||||
|
tint = MaterialTheme.colorScheme.error,
|
||||||
|
modifier = Modifier.size(48.dp)
|
||||||
|
)
|
||||||
|
Spacer(modifier = Modifier.height(16.dp))
|
||||||
|
Text(
|
||||||
|
text = "No Matches Found",
|
||||||
|
style = MaterialTheme.typography.titleLarge,
|
||||||
|
fontWeight = FontWeight.Bold,
|
||||||
|
color = MaterialTheme.colorScheme.error
|
||||||
|
)
|
||||||
|
Spacer(modifier = Modifier.height(8.dp))
|
||||||
|
Text(
|
||||||
|
text = "The validation scan didn't find this person in the sample photos. This could mean:\n\n" +
|
||||||
|
"• The model needs more training photos\n" +
|
||||||
|
"• The training photos weren't diverse enough\n" +
|
||||||
|
"• The person wasn't in the validation sample",
|
||||||
|
style = MaterialTheme.typography.bodyMedium,
|
||||||
|
color = MaterialTheme.colorScheme.onErrorContainer
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper data class for quality indicator
|
||||||
|
private data class Quadruple<A, B, C, D>(
|
||||||
|
val first: A,
|
||||||
|
val second: B,
|
||||||
|
val third: C,
|
||||||
|
val fourth: D
|
||||||
|
)
|
||||||
@@ -154,13 +154,3 @@ fun getDestinationByRoute(route: String?): AppDestinations? {
|
|||||||
else -> null
|
else -> null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Legacy support (for backwards compatibility)
|
|
||||||
* These match your old structure
|
|
||||||
*/
|
|
||||||
@Deprecated("Use organized groups instead", ReplaceWith("allMainDrawerDestinations"))
|
|
||||||
val mainDrawerItems = allMainDrawerDestinations
|
|
||||||
|
|
||||||
@Deprecated("Use settingsDestination instead", ReplaceWith("listOf(settingsDestination)"))
|
|
||||||
val utilityDrawerItems = listOf(settingsDestination)
|
|
||||||
@@ -15,7 +15,10 @@ import com.placeholder.sherpai2.ui.navigation.AppRoutes
|
|||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clean main screen - NO duplicate FABs, Collections support, Discover People
|
* MainScreen - FIXED double header issue
|
||||||
|
*
|
||||||
|
* BEST PRACTICE: Screens that manage their own TopAppBar should be excluded
|
||||||
|
* from MainScreen's TopAppBar to prevent ugly double headers.
|
||||||
*/
|
*/
|
||||||
@OptIn(ExperimentalMaterial3Api::class)
|
@OptIn(ExperimentalMaterial3Api::class)
|
||||||
@Composable
|
@Composable
|
||||||
@@ -45,8 +48,16 @@ fun MainScreen() {
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
) {
|
) {
|
||||||
|
// CRITICAL: Some screens manage their own TopAppBar
|
||||||
|
// Hide MainScreen's TopAppBar for these routes to prevent double headers
|
||||||
|
val screensWithOwnTopBar = setOf(
|
||||||
|
AppRoutes.TRAINING_PHOTO_SELECTOR // Has its own TopAppBar with subtitle
|
||||||
|
)
|
||||||
|
val showTopBar = currentRoute !in screensWithOwnTopBar
|
||||||
|
|
||||||
Scaffold(
|
Scaffold(
|
||||||
topBar = {
|
topBar = {
|
||||||
|
if (showTopBar) {
|
||||||
TopAppBar(
|
TopAppBar(
|
||||||
title = {
|
title = {
|
||||||
Column {
|
Column {
|
||||||
@@ -108,6 +119,7 @@ fun MainScreen() {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
) { paddingValues ->
|
) { paddingValues ->
|
||||||
AppNavHost(
|
AppNavHost(
|
||||||
navController = navController,
|
navController = navController,
|
||||||
@@ -125,10 +137,10 @@ private fun getScreenTitle(route: String): String {
|
|||||||
AppRoutes.SEARCH -> "Search"
|
AppRoutes.SEARCH -> "Search"
|
||||||
AppRoutes.EXPLORE -> "Explore"
|
AppRoutes.EXPLORE -> "Explore"
|
||||||
AppRoutes.COLLECTIONS -> "Collections"
|
AppRoutes.COLLECTIONS -> "Collections"
|
||||||
AppRoutes.DISCOVER -> "Discover People" // ✨ NEW!
|
AppRoutes.DISCOVER -> "Discover People"
|
||||||
AppRoutes.INVENTORY -> "People"
|
AppRoutes.INVENTORY -> "People"
|
||||||
AppRoutes.TRAIN -> "Train New Person"
|
AppRoutes.TRAIN -> "Train New Person"
|
||||||
AppRoutes.MODELS -> "AI Models" // Deprecated, but keep for backwards compat
|
AppRoutes.MODELS -> "AI Models"
|
||||||
AppRoutes.TAGS -> "Tag Management"
|
AppRoutes.TAGS -> "Tag Management"
|
||||||
AppRoutes.UTILITIES -> "Photo Util."
|
AppRoutes.UTILITIES -> "Photo Util."
|
||||||
AppRoutes.SETTINGS -> "Settings"
|
AppRoutes.SETTINGS -> "Settings"
|
||||||
@@ -144,7 +156,7 @@ private fun getScreenSubtitle(route: String): String? {
|
|||||||
AppRoutes.SEARCH -> "Find photos by tags, people, or date"
|
AppRoutes.SEARCH -> "Find photos by tags, people, or date"
|
||||||
AppRoutes.EXPLORE -> "Browse your collection"
|
AppRoutes.EXPLORE -> "Browse your collection"
|
||||||
AppRoutes.COLLECTIONS -> "Your photo collections"
|
AppRoutes.COLLECTIONS -> "Your photo collections"
|
||||||
AppRoutes.DISCOVER -> "Auto-find faces in your library" // ✨ NEW!
|
AppRoutes.DISCOVER -> "Auto-find faces in your library"
|
||||||
AppRoutes.INVENTORY -> "Trained face models"
|
AppRoutes.INVENTORY -> "Trained face models"
|
||||||
AppRoutes.TRAIN -> "Add a new person to recognize"
|
AppRoutes.TRAIN -> "Add a new person to recognize"
|
||||||
AppRoutes.TAGS -> "Organize your photo collection"
|
AppRoutes.TAGS -> "Organize your photo collection"
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ import javax.inject.Inject
|
|||||||
* ImageSelectorViewModel
|
* ImageSelectorViewModel
|
||||||
*
|
*
|
||||||
* Provides face-tagged image URIs for smart filtering
|
* Provides face-tagged image URIs for smart filtering
|
||||||
* during training photo selection
|
* during training photo selection.
|
||||||
|
*
|
||||||
|
* PRIORITIZATION: Solo photos first (faceCount=1) for clearer training data
|
||||||
*/
|
*/
|
||||||
@HiltViewModel
|
@HiltViewModel
|
||||||
class ImageSelectorViewModel @Inject constructor(
|
class ImageSelectorViewModel @Inject constructor(
|
||||||
@@ -31,8 +33,15 @@ class ImageSelectorViewModel @Inject constructor(
|
|||||||
private fun loadFaceTaggedImages() {
|
private fun loadFaceTaggedImages() {
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
try {
|
try {
|
||||||
|
// Get all images with faces
|
||||||
val imagesWithFaces = imageDao.getImagesWithFaces()
|
val imagesWithFaces = imageDao.getImagesWithFaces()
|
||||||
_faceTaggedImageUris.value = imagesWithFaces.map { it.imageUri }
|
|
||||||
|
// CRITICAL FIX: Sort by faceCount ASCENDING (solo photos first!)
|
||||||
|
// Previously: Sorted by faceCount DESC (group photos first - WRONG!)
|
||||||
|
// Now: Solo photos appear first, making training selection easier
|
||||||
|
val sortedImages = imagesWithFaces.sortedBy { it.faceCount }
|
||||||
|
|
||||||
|
_faceTaggedImageUris.value = sortedImages.map { it.imageUri }
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
// If cache not available, just use empty list (filter disabled)
|
// If cache not available, just use empty list (filter disabled)
|
||||||
_faceTaggedImageUris.value = emptyList()
|
_faceTaggedImageUris.value = emptyList()
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
|
|||||||
*
|
*
|
||||||
* Uses indexed query: SELECT * FROM images WHERE hasFaces = 1
|
* Uses indexed query: SELECT * FROM images WHERE hasFaces = 1
|
||||||
* Fast! (~10ms for 10k photos)
|
* Fast! (~10ms for 10k photos)
|
||||||
|
*
|
||||||
|
* SORTED: Solo photos (faceCount=1) first for best training quality
|
||||||
*/
|
*/
|
||||||
private fun loadPhotosWithFaces() {
|
private fun loadPhotosWithFaces() {
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
@@ -55,8 +57,9 @@ class TrainingPhotoSelectorViewModel @Inject constructor(
|
|||||||
// ✅ CRITICAL: Only get images with faces!
|
// ✅ CRITICAL: Only get images with faces!
|
||||||
val photos = imageDao.getImagesWithFaces()
|
val photos = imageDao.getImagesWithFaces()
|
||||||
|
|
||||||
// Sort by most faces first (better for training)
|
// ✅ FIX: Sort by LEAST faces first (solo photos = best training data)
|
||||||
val sorted = photos.sortedByDescending { it.faceCount ?: 0 }
|
// faceCount=1 first, then faceCount=2, etc.
|
||||||
|
val sorted = photos.sortedBy { it.faceCount ?: 999 }
|
||||||
|
|
||||||
_photosWithFaces.value = sorted
|
_photosWithFaces.value = sorted
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,315 @@
|
|||||||
|
package com.placeholder.sherpai2.workers
|
||||||
|
|
||||||
|
import android.content.Context
|
||||||
|
import android.graphics.BitmapFactory
|
||||||
|
import android.net.Uri
|
||||||
|
import androidx.hilt.work.HiltWorker
|
||||||
|
import androidx.work.*
|
||||||
|
import com.google.mlkit.vision.common.InputImage
|
||||||
|
import com.google.mlkit.vision.face.FaceDetection
|
||||||
|
import com.google.mlkit.vision.face.FaceDetectorOptions
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.FaceModelDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.ImageDao
|
||||||
|
import com.placeholder.sherpai2.data.local.dao.PhotoFaceTagDao
|
||||||
|
import com.placeholder.sherpai2.data.local.entity.PhotoFaceTagEntity
|
||||||
|
import com.placeholder.sherpai2.ml.FaceNetModel
|
||||||
|
import dagger.assisted.Assisted
|
||||||
|
import dagger.assisted.AssistedInject
|
||||||
|
import kotlinx.coroutines.Dispatchers
|
||||||
|
import kotlinx.coroutines.tasks.await
|
||||||
|
import kotlinx.coroutines.withContext
|
||||||
|
|
||||||
|
/**
|
||||||
|
* LibraryScanWorker - Full library background scan for a trained person
|
||||||
|
*
|
||||||
|
* PURPOSE: After user approves validation preview, scan entire library
|
||||||
|
*
|
||||||
|
* STRATEGY:
|
||||||
|
* 1. Load all photos with faces (from cache)
|
||||||
|
* 2. Scan each photo for the trained person
|
||||||
|
* 3. Create PhotoFaceTagEntity for matches
|
||||||
|
* 4. Progressive updates to "People" tab
|
||||||
|
* 5. Supports pause/resume via WorkManager
|
||||||
|
*
|
||||||
|
* SCHEDULING:
|
||||||
|
* - Runs in background with progress notifications
|
||||||
|
* - Can be cancelled by user
|
||||||
|
* - Automatically retries on failure
|
||||||
|
*
|
||||||
|
* INPUT DATA:
|
||||||
|
* - personId: String (UUID)
|
||||||
|
* - personName: String (for notifications)
|
||||||
|
* - threshold: Float (optional, default 0.70)
|
||||||
|
*
|
||||||
|
* OUTPUT DATA:
|
||||||
|
* - matchesFound: Int
|
||||||
|
* - photosScanned: Int
|
||||||
|
* - errorMessage: String? (if failed)
|
||||||
|
*/
|
||||||
|
@HiltWorker
|
||||||
|
class LibraryScanWorker @AssistedInject constructor(
|
||||||
|
@Assisted private val context: Context,
|
||||||
|
@Assisted workerParams: WorkerParameters,
|
||||||
|
private val imageDao: ImageDao,
|
||||||
|
private val faceModelDao: FaceModelDao,
|
||||||
|
private val photoFaceTagDao: PhotoFaceTagDao
|
||||||
|
) : CoroutineWorker(context, workerParams) {
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
const val WORK_NAME_PREFIX = "library_scan_"
|
||||||
|
const val KEY_PERSON_ID = "person_id"
|
||||||
|
const val KEY_PERSON_NAME = "person_name"
|
||||||
|
const val KEY_THRESHOLD = "threshold"
|
||||||
|
const val KEY_PROGRESS_CURRENT = "progress_current"
|
||||||
|
const val KEY_PROGRESS_TOTAL = "progress_total"
|
||||||
|
const val KEY_MATCHES_FOUND = "matches_found"
|
||||||
|
const val KEY_PHOTOS_SCANNED = "photos_scanned"
|
||||||
|
|
||||||
|
private const val DEFAULT_THRESHOLD = 0.70f // Slightly looser than validation
|
||||||
|
private const val BATCH_SIZE = 20
|
||||||
|
private const val MAX_RETRIES = 3
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create work request for library scan
|
||||||
|
*/
|
||||||
|
fun createWorkRequest(
|
||||||
|
personId: String,
|
||||||
|
personName: String,
|
||||||
|
threshold: Float = DEFAULT_THRESHOLD
|
||||||
|
): OneTimeWorkRequest {
|
||||||
|
val inputData = workDataOf(
|
||||||
|
KEY_PERSON_ID to personId,
|
||||||
|
KEY_PERSON_NAME to personName,
|
||||||
|
KEY_THRESHOLD to threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
return OneTimeWorkRequestBuilder<LibraryScanWorker>()
|
||||||
|
.setInputData(inputData)
|
||||||
|
.setConstraints(
|
||||||
|
Constraints.Builder()
|
||||||
|
.setRequiresBatteryNotLow(true) // Don't drain battery
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
.addTag(WORK_NAME_PREFIX + personId)
|
||||||
|
.build()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override suspend fun doWork(): Result = withContext(Dispatchers.Default) {
|
||||||
|
try {
|
||||||
|
// Get input parameters
|
||||||
|
val personId = inputData.getString(KEY_PERSON_ID)
|
||||||
|
?: return@withContext Result.failure(
|
||||||
|
workDataOf("error" to "Missing person ID")
|
||||||
|
)
|
||||||
|
|
||||||
|
val personName = inputData.getString(KEY_PERSON_NAME) ?: "Unknown"
|
||||||
|
val threshold = inputData.getFloat(KEY_THRESHOLD, DEFAULT_THRESHOLD)
|
||||||
|
|
||||||
|
// Check if stopped
|
||||||
|
if (isStopped) {
|
||||||
|
return@withContext Result.failure()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: Get face model
|
||||||
|
val faceModel = withContext(Dispatchers.IO) {
|
||||||
|
faceModelDao.getFaceModelByPersonId(personId)
|
||||||
|
} ?: return@withContext Result.failure(
|
||||||
|
workDataOf("error" to "Face model not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
setProgress(workDataOf(
|
||||||
|
KEY_PROGRESS_CURRENT to 0,
|
||||||
|
KEY_PROGRESS_TOTAL to 100
|
||||||
|
))
|
||||||
|
|
||||||
|
// Step 2: Get all photos with faces (from cache)
|
||||||
|
val photosWithFaces = withContext(Dispatchers.IO) {
|
||||||
|
imageDao.getImagesWithFaces()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (photosWithFaces.isEmpty()) {
|
||||||
|
return@withContext Result.success(
|
||||||
|
workDataOf(
|
||||||
|
KEY_MATCHES_FOUND to 0,
|
||||||
|
KEY_PHOTOS_SCANNED to 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Initialize ML components
|
||||||
|
val faceNetModel = FaceNetModel(context)
|
||||||
|
val detector = FaceDetection.getClient(
|
||||||
|
FaceDetectorOptions.Builder()
|
||||||
|
.setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_ACCURATE)
|
||||||
|
.setMinFaceSize(0.15f)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
val modelEmbedding = faceModel.getEmbeddingArray()
|
||||||
|
var matchesFound = 0
|
||||||
|
var photosScanned = 0
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Step 4: Process in batches
|
||||||
|
photosWithFaces.chunked(BATCH_SIZE).forEach { batch ->
|
||||||
|
if (isStopped) {
|
||||||
|
return@forEach
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan batch
|
||||||
|
batch.forEach { photo ->
|
||||||
|
try {
|
||||||
|
val tags = scanPhotoForPerson(
|
||||||
|
photo = photo,
|
||||||
|
personId = personId,
|
||||||
|
faceModelId = faceModel.id,
|
||||||
|
modelEmbedding = modelEmbedding,
|
||||||
|
faceNetModel = faceNetModel,
|
||||||
|
detector = detector,
|
||||||
|
threshold = threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
if (tags.isNotEmpty()) {
|
||||||
|
// Save tags
|
||||||
|
withContext(Dispatchers.IO) {
|
||||||
|
photoFaceTagDao.insertTags(tags)
|
||||||
|
}
|
||||||
|
matchesFound += tags.size
|
||||||
|
}
|
||||||
|
|
||||||
|
photosScanned++
|
||||||
|
|
||||||
|
// Update progress
|
||||||
|
if (photosScanned % 10 == 0) {
|
||||||
|
val progress = (photosScanned * 100 / photosWithFaces.size)
|
||||||
|
setProgress(workDataOf(
|
||||||
|
KEY_PROGRESS_CURRENT to photosScanned,
|
||||||
|
KEY_PROGRESS_TOTAL to photosWithFaces.size,
|
||||||
|
KEY_MATCHES_FOUND to matchesFound
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
// Skip failed photos, continue scanning
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success!
|
||||||
|
Result.success(
|
||||||
|
workDataOf(
|
||||||
|
KEY_MATCHES_FOUND to matchesFound,
|
||||||
|
KEY_PHOTOS_SCANNED to photosScanned
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
faceNetModel.close()
|
||||||
|
detector.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
// Retry on failure
|
||||||
|
if (runAttemptCount < MAX_RETRIES) {
|
||||||
|
Result.retry()
|
||||||
|
} else {
|
||||||
|
Result.failure(
|
||||||
|
workDataOf("error" to (e.message ?: "Unknown error"))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Scan a single photo for the person
|
||||||
|
*/
|
||||||
|
private suspend fun scanPhotoForPerson(
|
||||||
|
photo: com.placeholder.sherpai2.data.local.entity.ImageEntity,
|
||||||
|
personId: String,
|
||||||
|
faceModelId: String,
|
||||||
|
modelEmbedding: FloatArray,
|
||||||
|
faceNetModel: FaceNetModel,
|
||||||
|
detector: com.google.mlkit.vision.face.FaceDetector,
|
||||||
|
threshold: Float
|
||||||
|
): List<PhotoFaceTagEntity> = withContext(Dispatchers.IO) {
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Load bitmap
|
||||||
|
val bitmap = loadBitmapDownsampled(Uri.parse(photo.imageUri), 768)
|
||||||
|
?: return@withContext emptyList()
|
||||||
|
|
||||||
|
// Detect faces
|
||||||
|
val inputImage = InputImage.fromBitmap(bitmap, 0)
|
||||||
|
val faces = detector.process(inputImage).await()
|
||||||
|
|
||||||
|
// Check each face
|
||||||
|
val tags = faces.mapNotNull { face ->
|
||||||
|
try {
|
||||||
|
// Crop face
|
||||||
|
val faceBitmap = android.graphics.Bitmap.createBitmap(
|
||||||
|
bitmap,
|
||||||
|
face.boundingBox.left.coerceIn(0, bitmap.width - 1),
|
||||||
|
face.boundingBox.top.coerceIn(0, bitmap.height - 1),
|
||||||
|
face.boundingBox.width().coerceAtMost(bitmap.width - face.boundingBox.left),
|
||||||
|
face.boundingBox.height().coerceAtMost(bitmap.height - face.boundingBox.top)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generate embedding
|
||||||
|
val faceEmbedding = faceNetModel.generateEmbedding(faceBitmap)
|
||||||
|
faceBitmap.recycle()
|
||||||
|
|
||||||
|
// Calculate similarity
|
||||||
|
val similarity = faceNetModel.calculateSimilarity(faceEmbedding, modelEmbedding)
|
||||||
|
|
||||||
|
if (similarity >= threshold) {
|
||||||
|
PhotoFaceTagEntity.create(
|
||||||
|
imageId = photo.imageId,
|
||||||
|
faceModelId = faceModelId,
|
||||||
|
boundingBox = face.boundingBox,
|
||||||
|
confidence = similarity,
|
||||||
|
faceEmbedding = faceEmbedding
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bitmap.recycle()
|
||||||
|
tags
|
||||||
|
|
||||||
|
} catch (e: Exception) {
|
||||||
|
emptyList()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load bitmap with downsampling for memory efficiency
|
||||||
|
*/
|
||||||
|
private fun loadBitmapDownsampled(uri: Uri, maxDim: Int): android.graphics.Bitmap? {
|
||||||
|
return try {
|
||||||
|
val opts = BitmapFactory.Options().apply { inJustDecodeBounds = true }
|
||||||
|
context.contentResolver.openInputStream(uri)?.use {
|
||||||
|
BitmapFactory.decodeStream(it, null, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sample = 1
|
||||||
|
while (opts.outWidth / sample > maxDim || opts.outHeight / sample > maxDim) {
|
||||||
|
sample *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
val finalOpts = BitmapFactory.Options().apply {
|
||||||
|
inSampleSize = sample
|
||||||
|
}
|
||||||
|
|
||||||
|
context.contentResolver.openInputStream(uri)?.use {
|
||||||
|
BitmapFactory.decodeStream(it, null, finalOpts)
|
||||||
|
}
|
||||||
|
} catch (e: Exception) {
|
||||||
|
null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user